<a href="https://colab.research.google.com/github/Sandeep-Jala/MyOwnLLM/blob/main/Multihead_Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [73]:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F


In [74]:
class InputEmbeddings(nn.Module):
    """Token-embedding lookup whose output is scaled by ``sqrt(d_model)``.

    Args:
        vocab_size: Number of rows in the embedding table.
        d_model: Width of each embedding vector.
    """

    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

    def forward(self, x):
        """Return the embeddings of the ids in ``x`` multiplied by ``sqrt(d_model)``."""
        scale = math.sqrt(self.d_model)
        token_vectors = self.embedding(x)
        return token_vectors * scale



# Now we do positional encoding

In [75]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings.

    Even feature columns hold ``sin(pos * freq)`` and odd columns hold
    ``cos(pos * freq)``. The table is precomputed once for
    ``max_seq_length`` positions and stored in the non-trainable buffer
    ``Pe`` of shape ``(1, max_seq_length, d_model)``.

    Args:
        d_model: Width of the embeddings the encoding is added to. Both even
            and odd widths are supported.
        max_seq_length: Number of positions to precompute.
    """

    def __init__(self, d_model, max_seq_length):
        super().__init__()
        self.d_model = d_model
        self.max_seq_length = max_seq_length

        positions = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        )
        # One column of angles per even feature index: ceil(d_model / 2) columns.
        angles = positions * div_term

        pe = torch.zeros(max_seq_length, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # An odd d_model has one fewer odd column than even columns, so the
        # cosine half must drop the last frequency to fit the slice.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Buffer name kept as 'Pe' so existing state_dicts still load.
        self.register_buffer('Pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        ``x`` is indexed as ``(batch, seq_length, d_model)``; the encoding is
        broadcast over the batch axis.
        """
        return x + self.Pe[:, :x.size(1)]


In [76]:
# Initialize layers (vocabulary of 1000 ids, d_model = 6, up to 1000 positions)
embedding_layer = InputEmbeddings(1000, 6)
positional_encoding = PositionalEncoding(6, 1000)

# Example input: random token ids in [0, 1000)
input_ids = torch.randint(0, 1000, (1, 1000))  # shape (batch_size, seq_length)
print(input_ids.shape)
# Apply embedding
embedded = embedding_layer(input_ids)  # shape (1, 1000, 6): (batch_size, seq_length, d_model)

# Apply positional encoding (same shape as `embedded`)
output = positional_encoding(embedded)

print(output)

torch.Size([1, 1000])
tensor([[[ 5.2561,  1.3253, -1.4101, -0.9511, -2.6186,  0.3296],
         [-0.1340,  2.1850, -0.8530,  1.4460, -0.1336, -1.9809],
         [ 3.7167, -1.5794, -3.4649,  5.7472,  1.9074, -0.0166],
         ...,
         [-5.7967,  0.0773, -0.8146, -2.2116,  0.1297,  0.9332],
         [-5.2165,  2.8915,  1.2930,  0.0587, -4.2698, -1.1211],
         [-0.9825, -0.2154,  2.8814, -4.0014,  3.8546,  1.3333]]],
       grad_fn=<AddBackward0>)


# MultiHead Attention

In [77]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are linearly projected, split into
    ``num_heads`` heads of width ``d_model // num_heads``, attended per
    head, merged back and passed through an output projection.

    Args:
        d_model: Width of the input and output features. Must be divisible
            by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Probability of zeroing attention weights during training.
            The default ``0.0`` disables dropout.
    """

    def __init__(self, d_model, num_heads, dropout=0.0):
        super().__init__()
        self.d_model = d_model
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.query_linear = nn.Linear(d_model, d_model, bias=False)
        self.key_linear = nn.Linear(d_model, d_model, bias=False)
        self.value_linear = nn.Linear(d_model, d_model, bias=False)
        self.output_linear = nn.Linear(d_model, d_model)
        # Parameter-free, so adding it does not change the state_dict.
        self.dropout = nn.Dropout(dropout)

    def split_heads(self, x, batch_size):
        """Reshape ``(batch, seq, d_model)`` to ``(batch, heads, seq, head_dim)``."""
        seq_length = x.size(1)
        # Split the feature dimension into multiple heads.
        x = x.reshape(batch_size, seq_length, self.num_heads, self.head_dim)
        # Move the head axis ahead of the sequence axis for per-head matmuls.
        return x.permute(0, 2, 1, 3)

    def compute_attention(self, query, key, value, mask=None):
        """Return the attention-weighted values for already-split heads.

        Args:
            query: Tensor of shape ``(batch, heads, query_len, head_dim)``.
            key: Tensor of shape ``(batch, heads, key_len, head_dim)``.
            value: Tensor of shape ``(batch, heads, key_len, head_dim)``.
            mask: Optional tensor broadcastable to
                ``(batch, heads, query_len, key_len)``. Positions where
                ``mask == 0`` are excluded from attention.
        """
        scores = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            # Use the most negative finite value rather than -inf: masked
            # positions still get zero weight next to any unmasked score, but a
            # fully masked row yields finite (uniform) weights instead of NaN.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attention_weights = self.dropout(F.softmax(scores, dim=-1))
        return torch.matmul(attention_weights, value)

    def combine_heads(self, x, batch_size):
        """Merge ``(batch, heads, seq, head_dim)`` back to ``(batch, seq, d_model)``."""
        x = x.permute(0, 2, 1, 3).contiguous()
        return x.view(batch_size, -1, self.d_model)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` over ``key``/``value``.

        All three inputs are indexed as ``(batch, seq, d_model)``; ``key`` and
        ``value`` may have a different sequence length from ``query``. Returns
        a tensor with the batch and sequence sizes of ``query`` and width
        ``d_model``.
        """
        batch_size = query.size(0)

        query = self.split_heads(self.query_linear(query), batch_size)
        key = self.split_heads(self.key_linear(key), batch_size)
        value = self.split_heads(self.value_linear(value), batch_size)

        attended = self.compute_attention(query, key, value, mask)
        output = self.combine_heads(attended, batch_size)
        return self.output_linear(output)


In [78]:
class FeedForwardSubLayer(nn.Module):
    """Position-wise feed-forward block: Linear -> ReLU -> Linear.

    Args:
        d_model: Input and output feature width.
        d_ff: Width of the hidden layer between the two projections.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # Build the projections in forward order so weight initialisation
        # consumes the random stream in the same sequence.
        expand = nn.Linear(d_model, d_ff)
        activation = nn.ReLU()
        contract = nn.Linear(d_ff, d_model)
        self.net = nn.Sequential(expand, activation, contract)

    def forward(self, x):
        """Apply the two-layer network to the last dimension of ``x``."""
        projected = self.net(x)
        return projected

# Encoder Model

In [79]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention, then a feed-forward sub-layer.

    Each sub-layer output goes through dropout, is added back to its input
    (residual connection) and the sum is layer-normalised.

    Args:
        d_model: Feature width of the layer's input and output.
        num_heads: Number of heads in the self-attention sub-layer.
        d_ff: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability applied to each sub-layer output.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        # Sub-layers, created in the same order as before so parameter
        # initialisation is unchanged.
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ff_sublayer = FeedForwardSubLayer(d_model, d_ff)
        # One LayerNorm per residual connection.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, src_mask):
        """Run ``x`` through both sub-layers; ``src_mask`` is passed to self-attention."""
        attended = self.dropout(self.self_attn(x, x, x, src_mask))
        x = self.norm1(x + attended)
        transformed = self.dropout(self.ff_sublayer(x))
        return self.norm2(x + transformed)

In [80]:
 class TransformerEncoder(nn.Module):
    """Stack of encoder layers on top of scaled embeddings and positional encodings.

    Args:
        vocab_size: Size of the token vocabulary for the embedding table.
        d_model: Feature width used throughout the stack.
        num_layers: Number of ``EncoderLayer`` blocks.
        num_heads: Attention heads per layer.
        d_ff: Hidden width of each layer's feed-forward sub-layer.
        dropout: Dropout probability passed to every layer.
        max_seq_length: Number of positions the positional encoding precomputes.
    """

    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_seq_length):
        super().__init__()
        # Define the embedding, positional encoding, and encoder layers
        self.embedding = InputEmbeddings(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)
        self.layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])

    def forward(self, x, src_mask=None):
        """Encode the token ids ``x``.

        Args:
            x: Token ids, indexed as ``(batch, seq_length)``.
            src_mask: Optional attention mask handed to every layer. Defaults
                to ``None`` (no masking), matching the optional masks of
                ``TransformerDecoder.forward``.

        Returns:
            The output of the last encoder layer.
        """
        x = self.embedding(x)
        x = self.positional_encoding(x)
        for layer in self.layers:
            x = layer(x, src_mask)
        return x

In [81]:
class ClassifierHead(nn.Module):
    """Linear projection to class scores followed by log-softmax.

    Args:
        d_model: Width of the incoming features.
        num_classes: Number of output classes.
    """

    def __init__(self, d_model, num_classes):
        super().__init__()
        self.fc = nn.Linear(in_features=d_model, out_features=num_classes)

    def forward(self, x):
        """Return log-probabilities over the classes along the last dimension."""
        return F.log_softmax(self.fc(x), dim=-1)


# Decoder Layer

In [82]:
class DecoderLayer(nn.Module):
    """One decoder block: self-attention, cross-attention, feed-forward.

    Every sub-layer output goes through dropout, is added back to its input
    and the sum is layer-normalised.

    Args:
        d_model: Feature width of the layer's input and output.
        num_heads: Number of heads in both attention sub-layers.
        d_ff: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability applied to each sub-layer output.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        # Cross-attention over the encoder output; drop this sub-layer (and
        # its LayerNorm) for a decoder-only model.
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.ff_sublayer = FeedForwardSubLayer(d_model, d_ff)
        # One LayerNorm per residual connection.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, y, tgt_mask, cross_mask):
        """Decode ``x`` while attending to ``y``.

        ``tgt_mask`` is passed to the self-attention over ``x``; ``cross_mask``
        is passed to the cross-attention, where ``x`` supplies the queries and
        ``y`` the keys and values.
        """
        own_context = self.dropout(self.self_attn(x, x, x, tgt_mask))
        x = self.norm1(x + own_context)

        memory_context = self.dropout(self.cross_attn(x, y, y, cross_mask))
        x = self.norm2(x + memory_context)

        transformed = self.dropout(self.ff_sublayer(x))
        return self.norm3(x + transformed)

In [83]:
class TransformerDecoder(nn.Module):
    """Stack of decoder layers ending in a projection to vocabulary log-probabilities.

    Args:
        vocab_size: Size of the token vocabulary (embedding rows and output width).
        d_model: Feature width used throughout the stack.
        num_layers: Number of ``DecoderLayer`` blocks.
        num_heads: Attention heads per layer.
        d_ff: Hidden width of each layer's feed-forward sub-layer.
        dropout: Dropout probability passed to every layer.
        max_seq_length: Number of positions the positional encoding precomputes.
    """

    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_seq_length):
        super().__init__()
        self.embedding = InputEmbeddings(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)
        # The decoder blocks, applied in order.
        blocks = [DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)
        # Projects hidden states to one score per vocabulary entry.
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x, memory, tgt_mask=None, cross_mask=None):
        """Decode token ids ``x`` (indexed as ``(batch, tgt_len)``) against ``memory``.

        ``memory`` is handed to every layer for cross-attention. Returns
        log-softmax scores over the vocabulary along the last dimension.
        """
        hidden = self.positional_encoding(self.embedding(x))
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, memory, tgt_mask, cross_mask)
        vocab_scores = self.fc(hidden)
        return F.log_softmax(vocab_scores, dim=-1)

Encoder-Decoder

In [84]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer with one vocabulary size for both sides.

    Args:
        vocab_size: Vocabulary size used by both the encoder and the decoder.
        d_model: Feature width used throughout the model.
        num_heads: Attention heads per layer.
        num_layers: Number of layers in the encoder and in the decoder.
        d_ff: Hidden width of the feed-forward sub-layers.
        max_seq_length: Number of positions the positional encodings precompute.
        dropout: Dropout probability passed to every layer.
    """

    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
        super().__init__()
        self.encoder = TransformerEncoder(vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_seq_length)
        self.decoder = TransformerDecoder(vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_seq_length)

    def forward(self, x, src_mask, tgt_mask, cross_mask, tgt=None):
        """Encode ``x`` and decode against the encoder output.

        Args:
            x: Source token ids fed to the encoder.
            src_mask: Mask for the encoder's self-attention (or ``None``).
            tgt_mask: Mask for the decoder's self-attention (or ``None``).
            cross_mask: Mask for the decoder's cross-attention (or ``None``).
            tgt: Optional target token ids fed to the decoder. When omitted
                the decoder receives ``x``, which is the behaviour this
                method had before ``tgt`` was introduced.

        Returns:
            The decoder output: log-softmax scores over the vocabulary.
        """
        encoder_output = self.encoder(x, src_mask)
        decoder_input = x if tgt is None else tgt
        return self.decoder(decoder_input, encoder_output, tgt_mask, cross_mask)


In [85]:
# Hyperparameters for a small demonstration model.
vocab_size = 5000
d_model = 128
num_heads = 8
num_layers = 2
d_ff = 512
max_seq_length = 30
dropout = 0.1

# Keyword arguments make the pairing with Transformer.__init__ explicit.
model = Transformer(
    vocab_size=vocab_size,
    d_model=d_model,
    num_heads=num_heads,
    num_layers=num_layers,
    d_ff=d_ff,
    max_seq_length=max_seq_length,
    dropout=dropout,
)

In [87]:
import torch
def make_causal_mask(L):
    """Return an ``(L, L)`` boolean mask that is True on and below the diagonal.

    Row ``i`` is True for columns ``0..i`` and False for the columns after it.
    """
    all_ones = torch.ones((L, L))
    lower_triangle = all_ones.tril(diagonal=0)
    return lower_triangle.to(torch.bool)
batch_size = 1
seq_len = 12
# Toy input: token ids 0..seq_len-1, repeated for every item in the batch.
x = torch.arange(seq_len, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)

src_mask = None
tgt_mask = make_causal_mask(seq_len)  # each position sees itself and earlier ones
cross_mask = None


# eval() disables dropout so the forward pass is deterministic.
model.eval()
with torch.no_grad():
    # NOTE: the decoder ends in log_softmax, so these are log-probabilities
    # rather than raw logits; the argmax below is the same either way.
    logits = model(x, src_mask, tgt_mask, cross_mask)

print("forward pass successful")
print("Logits shape:", logits.shape)  # (batch_size, seq_len, vocab_size)

pred_ids = logits.argmax(dim=-1)
print("Predicted IDs:", pred_ids)


forward pass successful
Logits shape: torch.Size([1, 12, 5000])
Predicted IDs: tensor([[ 605, 1144,   86, 1653, 2323, 3789, 2289,  290, 1673, 4775, 3652, 3130]])
