**Base Implementation with Group Query Attention**

In [13]:
import torch.nn as nn
import torch
import copy
import math
import numpy as np
import torch.nn.functional as F

from torch.nn.parameter import Parameter
from torch.nn.modules import ModuleList
from dataclasses import dataclass

In [14]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [15]:
class LayerNorm(nn.Module):

    def __init__(self, hidden_size, epsilon=1e-12):
      """
      Initialize LayerNorm module.
      """
      super().__init__()

      # Learnable weight parameter for scaling.
      self.weight = nn.Parameter(torch.ones(hidden_size))

      # Learnable bias parameter for shifting.
      self.bias = nn.Parameter(torch.zeros(hidden_size))

      # Small value to avoid division by zero in normalization.
      self.epsilon = epsilon

    def forward(self, x):

      # Compute mean and variance along the last dimension.
      u = x.mean(-1, keepdim=True)
      s = (x - u).pow(2).mean(-1, keepdim=True)

      # Normalize the input tensor.
      x = (x - u) / torch.sqrt(s + self.epsilon)

      # Scale and shift using learnable parameters.
      return self.weight * x + self.bias

In [16]:
class Conv1D(nn.Module):
    def __init__(self, nx, nf):
        '''
        The CONV 1D layer can be thought of as a linear layer itself.
        It is casting an initial tensor x (having the final
        dimension of x.size(-1)) being passed to it to have a final dimension
        of size self.nf.

        We do this to be able to cast the input to query, key and value matrices.

        nx: Number of input features.
        nf: Number of filters (output channels).
        '''
        super().__init__()
        self.nf = nf
        #intialising an empty matrix as weights for size of (nx)X(nf)
        w = torch.empty(nx, nf)
        #initialising these weights as normal distribution
        nn.init.normal_(w, std=0.02)
        #calculating the weights and biases by encoding them using nn.Parameter
        self.weight = nn.Parameter(w)
        self.bias = nn.Parameter(torch.zeros(nf))

    def forward(self, x):
        '''x:The input tensor'''
        #this size output is summation of x second dimension and the nf dimension
        size_out = x.size()[:-1] + (self.nf,)
        # dot multiplying Q,K(transpose) and V
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        #x.view helps in taking the transpose out
        x = x.view(*size_out)
        return x

In [17]:
class FeedForward(nn.Module):
    def __init__(self, dropout, d_model=768, nx=768*4):
        super().__init__()
        self.c_fc    = Conv1D(d_model, nx)
        self.c_proj  = Conv1D(nx, d_model)
        self.act     = F.gelu
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout(self.c_proj(self.act(self.c_fc(x))))

In [18]:
class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model=768, n_head=12, n_ctx=1024, d_head=64, bias=True, scale=True, num_groups=2):
        """
        Initialize Grouped Query Attention Layer module.
        """
        super().__init__()

        # Number of attention heads.
        self.n_head = n_head

        # Dimensionality of the model.
        self.d_model = d_model

        # Number of query groups
        self.num_groups = num_groups

        # Check if the number of groups divides the dimensionality evenly
        assert d_model % num_groups == 0, "Number of groups must evenly divide the dimensionality."

        # Dimensionality of each group
        self.group_dim = d_model // num_groups

        # Dimensionality of each head
        self.head_dim = d_model // n_head

        # 1D Convolutional Layer for attention weights computation.
        self.c_attn = Conv1D(d_model, d_model * 3)

        # Flag to scale attention scores.
        self.scale = scale

        # Softmax activation for attention scores.
        self.softmax = nn.Softmax(dim=-1)

        # Lower triangular bias matrix for masking future tokens.
        self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))

        # Dropout layer for regularization.
        self.dropout = nn.Dropout(0.1)

        # 1D Convolutional Layer for output projection.
        self.c_proj = Conv1D(d_model, d_model)

        # Rotary Position Embedding
        # self.rpe = RotaryPositionalEmbedding(self.head_dim, n_ctx)

    def split_heads(self, x):
        """
        Split the last dimension of the input tensor into multiple heads.
        """
        new_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_shape)
        return x.permute(0, 2, 1, 3)

    def _attn(self, q, k, v, attn_mask=None):
        """
        Compute attention scores and apply attention to values.
        """
        scores = torch.matmul(q, k.transpose(-2, -1))
        if self.scale:
            scores = scores / torch.sqrt(torch.tensor(self.group_dim).float())  # Scale by sqrt(group_dim)
        nd, ns = scores.size(-2), scores.size(-1)
        if attn_mask is not None:
            scores = scores + attn_mask
        scores = self.softmax(scores)
        scores = self.dropout(scores)
        outputs = torch.matmul(scores, v)
        return outputs

    def merge_heads(self, x):
        """
        Merge the heads back to the original shape.
        """
        x = x.permute(0, 2, 1, 3).contiguous()
        new_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_shape)

    def group_queries(self, x):
        """
        Group queries based on the number of query groups.
        """
        return torch.split(x, self.group_dim, dim=-1)

    def forward(self, x):

        # Compute attention weights using 1D convolution.
        x = self.c_attn(x)

        # Split the tensor into query, key, and value.
        q, k, v = x.split(self.d_model, dim=2)

        # Split heads for query, key, and value.
        q, k, v = self.split_heads(q), self.split_heads(k), self.split_heads(v)

        # Apply RopE embeddings
        # q = self.rpe(q)
        # k = self.rpe(k)

        # Group queries
        grouped_queries = self.group_queries(q)

        # Apply grouped attention mechanism
        grouped_outputs = []
        for group_queries in grouped_queries:
            out = self._attn(group_queries, k, v)
            grouped_outputs.append(out)

        # Merge the grouped outputs
        out = torch.cat(grouped_outputs, dim=-1)

        # Merge the heads back to the original shape.
        out = self.merge_heads(out)

        # Apply output projection.
        out = self.c_proj(out)

        return out

In [19]:
class TransformerBlock_GroupQueryAttention(nn.Module):

    def __init__(self, d_model, n_head, num_groups, n_ctx, dropout=0.1):
      """
      Initialize Transformer Block module.
      """
      super().__init__()

      # GroupQuery-Attention Layer
      self.attn = GroupedQueryAttention(d_model=d_model, n_head=n_head, d_head=64,
                                        n_ctx=n_ctx, bias=True, scale=True,
                                        num_groups=num_groups)

      # Feedforward Layer
      self.feedforward = FeedForward(dropout=0.1, d_model=d_model,
                                     nx=d_model * 4)

      # Layer Normalization for the attention output
      self.ln_1 = LayerNorm(d_model)

      # Layer Normalization for the feedforward output
      self.ln_2 = LayerNorm(d_model)

    def forward(self, x):

      # Self-Attention Layer with Layer Normalization and skip connection
      x = x + self.attn(self.ln_1(x))

      # Feedforward Layer with Layer Normalization and skip connection
      x = x + self.feedforward(self.ln_2(x))

      return x

In [20]:
class GPT2_GroupQueryAttention(nn.Module):
    def __init__(self,  n_layer=12 , n_ctx=1024, d_model=768, vcb_sz=50257,n_head=12,num_groups=2):
        """
        Initialize GPT-2 model.
        """
        super().__init__()

        # Number of transformer layers
        self.nlayers = n_layer

        # Transformer block as the basic building unit
        self.block = TransformerBlock_GroupQueryAttention(d_model=d_model, n_head=n_head,
                                      n_ctx = n_ctx,
                                      num_groups = num_groups,
                                      dropout=0.1)

        # List of transformer blocks forming the layers
        self.h = nn.ModuleList([copy.deepcopy(self.block) for i in range(self.nlayers)])

        # Word Embedding layer
        self.wte = nn.Embedding(vcb_sz, d_model)


        # Layer Normalization for the final output
        self.ln_f = LayerNorm(d_model)

        # Linear layer for output predictions
        self.out = nn.Linear(d_model, vcb_sz, bias=False)

        # CrossEntropyLoss for training
        self.loss_fn = nn.CrossEntropyLoss()

        # Initialize weights
        self.init_weights()

    def init_weights(self):
        """
        Initialize the weights of the model.
        """
        # Share weights between output layer and word embedding
        self.out.weight = self.wte.weight

        # Apply custom weight initialization
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """
        Custom weight initialization for linear, embedding, and convolutional layers.
        """
        if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, src, labels=None, pos_ids=None):

        # Apply word embeddings
        inp = self.wte(src)

        # Forward pass through transformer layers
        for i in range(self.nlayers):
            inp = self.h[i](inp)

        # Apply layer normalization to the final output
        inp = self.ln_f(inp)

        # Linear layer for output predictions
        logits = self.out(inp)

        # Prepare outputs
        outputs = (logits,) + (inp,)

        # If labels are provided, compute and return the loss
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            outputs = (loss,) + outputs
            return outputs

        # Otherwise, return logits
        return logits

In [21]:
# load pretrained_weights from hugging face
# download file https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin to `.`

!wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin .

--2023-12-17 15:33:46--  https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.217.102.22, 52.217.126.64, 54.231.224.24, ...
Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.217.102.22|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 548118077 (523M) [application/octet-stream]
Saving to: ‘gpt2-pytorch_model.bin.1’


2023-12-17 15:33:54 (64.8 MB/s) - ‘gpt2-pytorch_model.bin.1’ saved [548118077/548118077]

--2023-12-17 15:33:54--  http://./
Resolving . (.)... failed: No address associated with hostname.
wget: unable to resolve host address ‘.’
FINISHED --2023-12-17 15:33:54--
Total wall clock time: 8.2s
Downloaded: 1 files, 523M in 8.1s (64.8 MB/s)


In [22]:
model = GPT2_GroupQueryAttention()

model_dict = model.state_dict() #currently with random initialization
state_dict = torch.load("/content/gpt2-pytorch_model.bin") #pretrained weights

old_keys = []
new_keys = []
for key in state_dict.keys():
    if "mlp" in key: #The hugging face state dict references the feedforward network as mlp, need to replace to `feedforward` be able to reuse these weights
        new_key = key.replace("mlp", "feedforward")
        new_keys.append(new_key)
        old_keys.append(key)

for old_key, new_key in zip(old_keys, new_keys):
    state_dict[new_key]=state_dict.pop(old_key)

pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict}

model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
model.eval() #model in inference mode as it's now initialized with pretrained weights

GPT2_GroupQueryAttention(
  (block): TransformerBlock_GroupQueryAttention(
    (attn): GroupedQueryAttention(
      (c_attn): Conv1D()
      (softmax): Softmax(dim=-1)
      (dropout): Dropout(p=0.1, inplace=False)
      (c_proj): Conv1D()
    )
    (feedforward): FeedForward(
      (c_fc): Conv1D()
      (c_proj): Conv1D()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (ln_1): LayerNorm()
    (ln_2): LayerNorm()
  )
  (h): ModuleList(
    (0-11): 12 x TransformerBlock_GroupQueryAttention(
      (attn): GroupedQueryAttention(
        (c_attn): Conv1D()
        (softmax): Softmax(dim=-1)
        (dropout): Dropout(p=0.1, inplace=False)
        (c_proj): Conv1D()
      )
      (feedforward): FeedForward(
        (c_fc): Conv1D()
        (c_proj): Conv1D()
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_1): LayerNorm()
      (ln_2): LayerNorm()
    )
  )
  (wte): Embedding(50257, 768)
  (ln_f): LayerNorm()
  (out): Linear(in_features=768, out_features=50257, b

In [23]:
import time
from transformers import GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
context = torch.tensor([tokenizer.encode("Hi Contlo, How")])

def generate(context, ntok=20):
    start_time = time.time()
    for _ in range(ntok):
        out = model(context)
        logits = out[:, -1, :]
        indices_to_remove = logits < torch.topk(logits, 10)[0][..., -1, None]
        logits[indices_to_remove] = np.NINF
        next_tok = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1).squeeze(1)
        context = torch.cat([context, next_tok.unsqueeze(-1)], dim=-1)
    end_time = time.time()
    inference_time = end_time - start_time
    return context, inference_time

out, inference_time = generate(context, ntok=20)
decoded_output = tokenizer.decode(out[0])

print(f"Inference Time: {inference_time:.4f} seconds")
print(f"Generated Output: {decoded_output}")


Inference Time: 2.5750 seconds
Generated Output: Hi Contlo, How, or all or any any all or any of any any of any any of any any any any
