**Sliding Window Attention in the base GPT2 Implementation**

In [1]:
!pip install einops xformers

Collecting einops
  Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m657.3 kB/s[0m eta [36m0:00:00[0m
[?25hCollecting xformers
  Downloading xformers-0.0.23.post1-cp310-cp310-manylinux2014_x86_64.whl (213.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m213.0/213.0 MB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
Collecting torch==2.1.2 (from xformers)
  Downloading torch-2.1.2-cp310-cp310-manylinux1_x86_64.whl (670.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m670.2/670.2 MB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch==2.1.2->xformers)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.7/23.7 MB[0m [31m43.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-runtime-cu12==12.1.105 (from tor

In [2]:
import torch.nn as nn
import torch
import copy
import math
import numpy as np
import torch.nn.functional as F

from torch.nn.parameter import Parameter
from torch.nn.modules import ModuleList
from dataclasses import dataclass

from typing import Optional
from xformers.components.attention import (
    AttentionMask,
    maybe_sparsify,
    sparsify


)
from xformers.components.attention.attention_patterns import (
    causal_1d_pattern,
    local_1d_pattern,
)

from torch import broadcast_tensors
from einops import rearrange, repeat
from einops import rearrange, repeat, pack, unpack
from torch import nn, einsum



In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
class LayerNorm(nn.Module):

    def __init__(self, hidden_size, epsilon=1e-12):
      """
      Initialize LayerNorm module.
      """
      super().__init__()

      # Learnable weight parameter for scaling.
      self.weight = nn.Parameter(torch.ones(hidden_size))

      # Learnable bias parameter for shifting.
      self.bias = nn.Parameter(torch.zeros(hidden_size))

      # Small value to avoid division by zero in normalization.
      self.epsilon = epsilon

    def forward(self, x):

      # Compute mean and variance along the last dimension.
      u = x.mean(-1, keepdim=True)
      s = (x - u).pow(2).mean(-1, keepdim=True)

      # Normalize the input tensor.
      x = (x - u) / torch.sqrt(s + self.epsilon)

      # Scale and shift using learnable parameters.
      return self.weight * x + self.bias

In [5]:
class Conv1D(nn.Module):
    def __init__(self, nx, nf):
        '''
        nx: Number of input features.
        nf: Number of filters (output channels).
        '''
        super().__init__()
        self.nf = nf
        #intialising an empty matrix as weights for size of (nx)X(nf)
        w = torch.empty(nx, nf)
        #initialising these weights as normal distribution
        nn.init.normal_(w, std=0.02)
        #calculating the weights and biases by encoding them using nn.Parameter
        self.weight = nn.Parameter(w)
        self.bias = nn.Parameter(torch.zeros(nf))

    def forward(self, x):
        '''x:The input tensor'''
        #this size output is summation of x second dimension and the nf dimension
        size_out = x.size()[:-1] + (self.nf,)
        # dot multiplying Q,K(transpose) and V
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)#x.view helps in taking the transpose out
        x = x.view(*size_out)
        return x


In [6]:
class FeedForward(nn.Module):
    def __init__(self, dropout, d_model=768, nx=768*4):
        super().__init__()
        self.c_fc    = Conv1D(d_model, nx)
        self.c_proj  = Conv1D(nx, d_model)
        self.act     = F.gelu
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout(self.c_proj(self.act(self.c_fc(x))))

**Changing Mulihead Attention in the base implementation with Sliding Window Attention**

In [7]:
def exists(val):
    return val is not None

def default(value, d):
    return d if not exists(value) else value

def to(t):
    return {'device': t.device, 'dtype': t.dtype}

def max_neg_value(tensor):
    return -torch.finfo(tensor.dtype).max

def pad_to_multiple(tensor, multiple, dim=-1, value=0):
    '''Function for padding over tensor for how many times which is multiple here
    if the seqlen is not a muliple of multiple so we need to pad the remaining for which we caluclate the remainder and then pad the remainder
    Params:
    tensor:The tensor that needs to be pad
    multiple:The multiple upto which padding is happenining
    dim: the dimension accross for padding
    value:what should the padded values'''
    seqlen = tensor.shape[dim]
    m = seqlen / multiple
    if m.is_integer():
        return False, tensor
    remainder = math.ceil(m) * multiple - seqlen#calculating the remaninder after the multiple padding has happened
    pad_offset = (0,) * (-1 - dim) * 2
    return True, F.pad(tensor, (*pad_offset, 0, remainder), value = value)

def look_around(x, backward = 1, forward = 0, pad_value = -1, dim = 2):
    '''This is a function for padding our x with the sliding window attention mechanism first getting the shapes and dimensions
    then padding the x from backward to forward which is of range of window 2n+1
    now we iteratively pad with different combination of windows by loop over the forward+backward+1
    and finally concatenating all the different tensors that are formed resulting in our final attention'''
    t = x.shape[1]
    dims = (len(x.shape) - dim) * (0, 0)
    padded_x = F.pad(x, (*dims, backward, forward), value = pad_value)
    tensors = [padded_x[:, ind:(ind + t), ...] for ind in range(forward + backward + 1)]
    return torch.cat(tensors, dim = dim)

In [8]:
class SlidingWindowAttention(nn.Module):
    def __init__(self, d_model=768, n_head=12,window_size=3, n_ctx=1024, d_head=64, bias=True, scale=False,look_forward=1,look_backward=1):
        '''An implementation of a sliding window attention, as proposed in Longformer I am also combing the rotationaol embeddings with it for
        checking out the results
        Params:
        d_model:The dimension that needs to be feed into our model
        n_head:The number of heads for attention
        n_ctx:a parameters for buffer registry for bias
        d_head:the dimension head output
        bias:A bool for including or not including bias
        scale: Whether to scale the attention scores by the square root of the dimension of the queries(use sqrt(dk) or not) "
        '''
        super().__init__()
        self.n_head  = n_head
        self.d_model = d_model
        self.c_attn  = Conv1D(d_model, d_model*3)
        self.proj_out = nn.Linear(n_head * d_head, d_model)
        self.scale   = scale
        self.softmax = nn.Softmax(dim=-1)
        self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
        self.dropout = nn.Dropout(0.1)
        self.c_proj  = Conv1D(d_model, d_model)
        # self.rotary_emb = RotaryEmbedding(dim = 32)#intializing the rotatory embedding with dimension 32
        self.window_size = window_size
        # Properties specific to this attention mechanism
        self.supports_attention_mask = True
        self.supports_key_padding_mask = False

        self.attention_mask: Optional[torch.Tensor] = None#attention mask to store the values of the slided attention window
        self.requires_same_k_q_dimensions = True

        self.look_backward=look_backward
        self.look_forward=look_forward

        self.causal=False
        self.force_sparsity=False
        self.shared_qk=False

        self.attn_mask=None
        self.TOKEN_SELF_ATTN_VALUE = -5e4

    def _get_local_mask(self, shape: torch.Size) -> AttentionMask:
      self.window_size = min(self.window_size * 2 + 1, shape[1]) if self.causal else min(self.window_size, shape[1])
      mask = local_1d_pattern(shape[1], window_size)

      if self.causal:
          mask &= causal_1d_pattern(shape[1])

      mask = sparsify(mask) if self.force_sparsity else maybe_sparsify(mask)

      # Convert mask to tensor and set its dtype to float32
      mask_tensor = mask.to(torch.float32)

      return AttentionMask(mask_tensor)

    def split_heads(self, x):
        """
        spliting inyo given number of heads and then returning
        return shape [`batch`, `head`, `sequence`, `features`]
        """
        new_shape = x.size()[:-1] + (self.n_head, x.size(-1)//self.n_head)
        x = x.view(*new_shape)
        return x.permute(0, 2, 1, 3)

    def _attn(self, q, k, v, attn_mask=None):
        """The main attention function
        That claculates the attention using our dot product formula"""
        scores  = torch.matmul(q, k.transpose(-2, -1))# dot multiplication between q and k transpose
        if self.scale: scores = scores/math.sqrt(v.size(-1))# scaling it by dividing by sqrt(dk)
        nd, ns  = scores.size(-2), scores.size(-1)
        if attn_mask is not None: scores = scores + attn_mask# adding scores with attention mask values
        scores  = self.softmax(scores)# adding softmax values
        scores  = self.dropout(scores) #dropout of 0.1 as mentioned
        outputs = torch.matmul(scores, v) # now the final matrix multiplication between score and V
        return outputs

    def merge_heads(self, x):
        # merging the attention heads into one
        x = x.permute(0, 2, 1, 3).contiguous()
        new_shape = x.size()[:-2] + (x.size(-2)*x.size(-1),)
        return x.view(*new_shape)


    def forward(self, x,mask = None,input_mask = None,attn_bias = None,window_size = None):
        '''The feed forward function that calculates the attention, split the heads, make attention, merge heads and project out the output
        Applies convolutional attention to the input tensor.
Splits the query, key, and value tensors into heads.
Applies rotary embeddings to the query and key.
Dynamically sets the window size if provided.
Asserts that the sequence length is divisible by the window size.
Applies the sliding window attention mechanism.
Computes attention, applies masks, and performs aggregation.
Returns the final output tensor.'''
        mask = default(mask, input_mask)

        x        = self.c_attn(x) #new `x` shape - `[1,3,2304]`
        q, k, v  = x.split(self.d_model, dim=2)

        q, k, v  = self.split_heads(q), self.split_heads(k), self.split_heads(v)
        #applying the rotatory embeddings over query and key
        # q = self.rotary_emb.rotate_queries_or_keys(q)
        # k = self.rotary_emb.rotate_queries_or_keys(k)
        shape, pad_value, window_size, causal, look_backward, look_forward, shared_qk = q.shape, -1, default(window_size, self.window_size), self.causal, self.look_backward, self.look_forward, self.shared_qk
        (q, packed_shape), (k, _), (v, _) = map(lambda t: pack([t], '* n d'), (q, k, v))
        b, n, dim_head, device, dtype = *q.shape, q.device, q.dtype

        scale = default(self.scale, dim_head ** -0.5)

        if window_size is not None:
          self.window_size = window_size  # Set the window size dynamically

        assert (n % window_size) == 0, f'sequence length {n} must be divisible by window size {window_size} for local attention'

        windows = n // window_size


        seq = torch.arange(n, device = device)
        b_t = rearrange(seq, '(w n) -> 1 w n', w = windows, n = window_size)

        bq, bk, bv = map(lambda t: rearrange(t, 'b (w n) d -> b w n d', w = windows), (q, k, v))

        bq = bq * scale

        look_around_kwargs = dict(
            backward =  look_backward,
            forward =  look_forward,
            pad_value = pad_value
        )

        bk = look_around(bk, **look_around_kwargs)
        bv = look_around(bv, **look_around_kwargs)

        bq_t = b_t
        bq_k = look_around(b_t, **look_around_kwargs)

        bq_t = rearrange(bq_t, '... i -> ... i 1')
        bq_k = rearrange(bq_k, '... j -> ... 1 j')

        pad_mask = bq_k == pad_value

        sim = einsum('b h i e, b h j e -> b h i j', bq, bk)

        if exists(attn_bias):
            heads = attn_bias.shape[0]
            assert (b % heads) == 0

            attn_bias = repeat(attn_bias, 'h i j -> (b h) 1 i j', b = b // heads)
            sim = sim + attn_bias

        mask_value = max_neg_value(sim)

        if shared_qk:
            self_mask = bq_t == bq_k
            sim = sim.masked_fill(self_mask, self.TOKEN_SELF_ATTN_VALUE)
            del self_mask


        sim = sim.masked_fill(pad_mask, mask_value)

        # take care of key padding mask passed in

        if exists(mask):
            batch = mask.shape[0]
            assert (b % batch) == 0

            h = b // mask.shape[0]



            mask = rearrange(mask, '... (w n) -> (...) w n', w = windows, n = window_size)
            mask = look_around(mask, **{**look_around_kwargs, 'pad_value': False})
            mask = rearrange(mask, '... j -> ... 1 j')
            mask = repeat(mask, 'b ... -> (b h) ...', h = h)
            sim = sim.masked_fill(~mask, mask_value)
            del mask

        # attention

        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        # aggregation

        out = einsum('b h i j, b h j e -> b h i e', attn, bv)
        out = rearrange(out, 'b w n d -> b (w n) d')

        # out = self.proj_out(out)
        out, *_ = unpack(out, packed_shape, '* n d')
        out=rearrange(out, 'b n s d -> b s (n d)')
        return out

In [9]:
a=SlidingWindowAttention(d_model=768, n_head=12, d_head=64, n_ctx=1024, bias=True, scale=False)
d_model = 768

# Create a dummy variable
dummy_out = torch.randn(36, 63, d_model)
out=a(dummy_out)
x=torch.randn(36, 63, 64)
y=torch.randn(36, 63, 64)
out.shape

torch.Size([36, 63, 768])

In [10]:
class TransformerBlock_SlidingWindowAttention(nn.Module):
    def __init__(self, d_model=768, n_head=12, dropout=0.1,window_size=2):
        self.window_size=window_size
        super().__init__()
        self.attn        = SlidingWindowAttention(d_model=768,window_size=window_size, n_head=12, d_head=64, n_ctx=1024, bias=True, scale=False)
        self.feedforward = FeedForward(dropout=0.1, d_model=768, nx=768*4)
        self.ln_1        = LayerNorm(d_model)
        self.ln_2        = LayerNorm(d_model)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x),window_size=window_size)
        x = x + self.feedforward(self.ln_2(x))
        return x

In [11]:
def _get_clones(module, n):
    '''Here we can make certain copies of transformers'''
    return ModuleList([copy.deepcopy(module) for i in range(n)])

In [12]:
window_size=5

In [13]:
class GPT2_SlidingWindowAttention(nn.Module):
    def __init__(self, nlayers=12, n_ctx=1024, d_model=768, vcb_sz=50257):
        super(GPT2_SlidingWindowAttention, self).__init__()
        self.nlayers = nlayers
        block        = TransformerBlock_SlidingWindowAttention(window_size=window_size,d_model=768, n_head=12, dropout=0.1)
        self.h       = _get_clones(block, 12)
        self.wte     = nn.Embedding(vcb_sz, d_model)
        self.wpe     = nn.Embedding(n_ctx, d_model)
        self.drop    = nn.Dropout(0.1)
        self.ln_f    = LayerNorm(d_model)
        self.out     = nn.Linear(d_model, vcb_sz, bias=False)
        self.loss_fn = nn.CrossEntropyLoss()
        self.init_weights()

    def set_window_size(self, window_size):
        self.window_size = window_size

    def init_weights(self):
        '''Initialization of weights'''
        self.out.weight = self.wte.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        '''If the Linear, Embedding and Conv1D then nomrally initializing with mean and S.D'''
        if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
                '''Data Bias zero'''
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, src, labels=None, pos_ids=None):
        '''Adding the positional embedding, dropping, then adding inputs, logits and outputs which are being used for loss function and then adding outputs and loss'''
        if pos_ids is None: pos_ids = torch.arange(0, src.size(-1)).unsqueeze(0)
        inp = self.drop((self.wte(src)+self.wpe(pos_ids)))
        for i in range(self.nlayers): inp = self.h[i](inp)
        inp     = self.ln_f(inp)
        logits  = self.out(inp)
        outputs = (logits,) + (inp,)

        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            outputs = (loss,) + outputs
            return outputs
        return logits

In [14]:
# load pretrained_weights from hugging face
# download file https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin to `.`

!wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin .

--2023-12-17 15:27:26--  https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin
Resolving s3.amazonaws.com (s3.amazonaws.com)... 16.182.106.200, 52.217.175.120, 52.217.171.48, ...
Connecting to s3.amazonaws.com (s3.amazonaws.com)|16.182.106.200|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 548118077 (523M) [application/octet-stream]
Saving to: ‘gpt2-pytorch_model.bin’


2023-12-17 15:28:09 (12.7 MB/s) - ‘gpt2-pytorch_model.bin’ saved [548118077/548118077]

--2023-12-17 15:28:09--  http://./
Resolving . (.)... failed: No address associated with hostname.
wget: unable to resolve host address ‘.’
FINISHED --2023-12-17 15:28:09--
Total wall clock time: 42s
Downloaded: 1 files, 523M in 41s (12.7 MB/s)


In [15]:
model = GPT2_SlidingWindowAttention()
# load pretrained_weights from hugging face

model_dict = model.state_dict() #currently with random initialization
state_dict = torch.load("/content/gpt2-pytorch_model.bin") #pretrained weights

In [16]:
old_keys = []
new_keys = []
for key in state_dict.keys():
    if "mlp" in key: #The hugging face state dict references the feedforward network as mlp, need to replace to `feedforward` be able to reuse these weights
        new_key = key.replace("mlp", "feedforward")
        new_keys.append(new_key)
        old_keys.append(key)

for old_key, new_key in zip(old_keys, new_keys):
    state_dict[new_key]=state_dict.pop(old_key)

pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict}

model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
model.eval() #model in inference mode as it's now initialized with pretrained weights

GPT2_SlidingWindowAttention(
  (h): ModuleList(
    (0-11): 12 x TransformerBlock_SlidingWindowAttention(
      (attn): SlidingWindowAttention(
        (c_attn): Conv1D()
        (proj_out): Linear(in_features=768, out_features=768, bias=True)
        (softmax): Softmax(dim=-1)
        (dropout): Dropout(p=0.1, inplace=False)
        (c_proj): Conv1D()
      )
      (feedforward): FeedForward(
        (c_fc): Conv1D()
        (c_proj): Conv1D()
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_1): LayerNorm()
      (ln_2): LayerNorm()
    )
  )
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (ln_f): LayerNorm()
  (out): Linear(in_features=768, out_features=50257, bias=False)
  (loss_fn): CrossEntropyLoss()
)

In [17]:
total_params = sum(p.numel() for p in model.parameters())

# Calculate size in bytes and megabytes
size_bytes = total_params * 4  # Assuming float32 parameters, where each parameter is 4 bytes
size_mb = size_bytes / (1024 ** 2)

print(f"Total size of the GPT-2 with rotatory embeddings and sliding window attention is: {size_bytes} bytes or {size_mb:.2f} MB")

Total size of the GPT-2 with rotatory embeddings and sliding window attention is: 526107648 bytes or 501.74 MB


In [22]:
import time
from transformers import GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
context = torch.tensor([tokenizer.encode("Hi Contlo, How")])

def generate_dynamic(context, window_size, ntok=20):
    start_time = time.time()
    for _ in range(ntok):
        model.set_window_size(window_size)  # Set the window size dynamically
        out = model(context)
        logits = out[:, -1, :]
        indices_to_remove = logits < torch.topk(logits, 10)[0][..., -1, None]
        logits[indices_to_remove] = np.NINF
        next_tok = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1).squeeze(1)
        context = torch.cat([context, next_tok.unsqueeze(-1)], dim=-1)

        # Dynamically adjust the length of the input sequence based on the window_size
        input_length = context.size(-1)
        padding_size = window_size - (input_length % window_size)
        if padding_size != window_size:
            pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
            padding_tokens = torch.zeros((context.size(0), padding_size), dtype=torch.long, device=context.device) + pad_token_id
            context = torch.cat([context, padding_tokens], dim=-1)

    end_time = time.time()
    inference_time = end_time - start_time
    return context, inference_time

# Usage
window_size = 5  # Adjust this as needed
out, inference_time = generate_dynamic(context, ntok=20, window_size=window_size)
decoded_output = tokenizer.decode(out[0])

print(f"Inference Time: {inference_time:.4f} seconds")
print(f"Generated Output: {decoded_output}")


Inference Time: 7.4760 seconds
Generated Output: Hi Contlo, How you!!!!
!!!! and!!!! a!!!!
!!!! (!!!! and!!!!
!!!!.!!!!,!!!! I!!!! (!!!! the!!!! the!!!! (!!!! a!!!!,!!!! the!!!! I!!!!.!!!!
