In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class Mha_GPT(nn.Module):
    """Single-block transformer language model built on multi-head self-attention.

    Pipeline: token embedding -> positional encoding -> multi-head
    self-attention -> feed-forward network -> linear projection to vocabulary
    logits. Around each sub-layer the *sub-layer output* is layer-normalised
    and then added to the sub-layer input, i.e. ``x + LayerNorm(sublayer(x))``.

    Args:
        vocab_size: Number of tokens in the vocabulary; also the size of the
            last dimension of the returned logits.
        embedding_size: Width of the token embeddings and of every hidden
            representation inside the block.
        max_seq_length: Number of positions for which positional encodings
            are precomputed.
        num_heads: Number of attention heads.
        ff_hidden_size: Width of the hidden layer of the feed-forward network.
            Defaults to 10, the value that was previously hard-coded.
    """

    def __init__(self, vocab_size, embedding_size, max_seq_length, num_heads, ff_hidden_size=10):
        super().__init__()

        # NOTE: sub-modules are created in a fixed order so that, for a given
        # RNG seed, the randomly initialised weights stay reproducible.
        self.embedding = nn.Embedding(vocab_size, embedding_size)

        self.positional_encoding = PositionalEncoding(embedding_size, max_seq_length)

        self.multihead_self_attention = MultiheadSelfAttention(embedding_size, num_heads)

        # Normalises the attention output before the first residual addition.
        self.add_norm1 = nn.LayerNorm(embedding_size)

        self.feed_forward = FeedForward(embedding_size, ff_hidden_size)

        # Normalises the feed-forward output before the second residual addition.
        self.add_norm = nn.LayerNorm(embedding_size)

        # Projects each position's hidden state onto the vocabulary.
        self.fc = nn.Linear(embedding_size, vocab_size)

    def forward(self, input_ids, mask=None):
        """Run the model on a batch of token ids.

        Args:
            input_ids: Integer tensor of token ids with shape
                ``(batch, seq_len)``.
            mask: Optional tensor forwarded to the attention layer, where
                positions at which it equals 0 are excluded from attention.
                When omitted (the default) no masking is applied, so every
                position attends to every other position, including later ones.

        Returns:
            A tuple ``(logits, heads_attention_weights)``: ``logits`` is the
            unnormalised output of the final linear layer with shape
            ``(batch, seq_len, vocab_size)`` (no softmax is applied), and
            ``heads_attention_weights`` is whatever the attention layer
            returns as its second value (one entry per head).
        """
        embeddings = self.embedding(input_ids)
        embeddings = self.positional_encoding(embeddings)

        attention_output, heads_attention_weights = self.multihead_self_attention(embeddings, mask)

        # Residual connection around the attention sub-layer.
        output1 = embeddings + self.add_norm1(attention_output)

        ff_output = self.feed_forward(output1)

        # Residual connection around the feed-forward sub-layer.
        output = output1 + self.add_norm(ff_output)

        logits = self.fc(output)

        return logits, heads_attention_weights


class FeedForward(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    The first layer maps the last dimension from ``input_size`` to
    ``hidden_size``; the second maps it back to ``input_size``, so the output
    has the same shape as the input.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        # Expand to the hidden width, apply the non-linearity, project back.
        self.linear1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(in_features=hidden_size, out_features=input_size)

    def forward(self, x):
        """Apply the network to ``x``, whose last dimension must be ``input_size``."""
        hidden = self.relu(self.linear1(x))
        return self.linear2(hidden)

class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to a batch of embeddings.

    A table of shape ``(1, max_seq_length, d_model)`` is precomputed once:
    even feature indices hold ``sin(pos * freq)`` and odd feature indices hold
    ``cos(pos * freq)``, with ``freq = exp(-2i * ln(10000) / d_model)`` for the
    i-th sin/cos pair. The table is stored as a buffer (not a parameter), so
    it follows the module across devices but is never trained.

    Args:
        d_model: Size of the feature (last) dimension. Both even and odd
            values are supported.
        max_seq_length: Number of positions to precompute. Inputs to
            ``forward`` must not be longer than this.
    """

    def __init__(self, d_model, max_seq_length):
        super().__init__()
        position = torch.arange(0, max_seq_length).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        # One column per frequency: shape (max_seq_length, ceil(d_model / 2)).
        angles = position * div_term

        pos_enc = torch.zeros(1, max_seq_length, d_model)
        pos_enc[:, :, 0::2] = torch.sin(angles)
        # When d_model is odd there is one fewer odd index than even index, so
        # only the first d_model // 2 frequencies get a cosine column. For even
        # d_model this slice is the full table and nothing changes.
        pos_enc[:, :, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pos_enc', pos_enc)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: Tensor whose dimension 1 is the sequence axis and whose last
                dimension is ``d_model``; the table is broadcast over
                dimension 0.
        """
        # pos_enc is a buffer, so it carries no gradient and needs no detach().
        return x + self.pos_enc[:, :x.size(1)]


class MultiheadSelfAttention(nn.Module):
    """Multi-head self-attention where each head works on its own slice of the embedding.

    The embedding dimension is split into ``num_heads`` equal chunks. Each
    head applies its own query/key/value linear projections to its chunk and
    runs scaled dot-product attention on the result; the per-head outputs are
    concatenated along the last dimension and passed through a final linear
    layer.

    Args:
        embedding_size: Size of the last dimension of the input embeddings.
        num_heads: Number of attention heads. Must be positive and divide
            ``embedding_size`` exactly.

    Raises:
        ValueError: If ``num_heads`` is not positive, or if ``embedding_size``
            is not divisible by ``num_heads``.
    """

    def __init__(self, embedding_size, num_heads):
        super().__init__()

        # Fail early with a clear message: otherwise the constructor succeeds
        # and the mismatch only surfaces later as an opaque shape error when
        # the embeddings are split across heads in forward().
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if embedding_size % num_heads != 0:
            raise ValueError(
                f"embedding_size ({embedding_size}) must be divisible by "
                f"num_heads ({num_heads})"
            )

        self.embedding_size = embedding_size
        self.num_heads = num_heads

        split_size = embedding_size // num_heads

        # Linear projections for queries, keys, and values for each head
        self.linear_queries = nn.ModuleList([nn.Linear(split_size, split_size) for _ in range(num_heads)])
        self.linear_keys = nn.ModuleList([nn.Linear(split_size, split_size) for _ in range(num_heads)])
        self.linear_values = nn.ModuleList([nn.Linear(split_size, split_size) for _ in range(num_heads)])
        # Final linear projection
        self.linear_out = nn.Linear(split_size * num_heads, embedding_size)

    def forward(self, embeddings, mask=None):
        """Apply self-attention to a batch of embeddings.

        Args:
            embeddings: Tensor of shape ``(batch, seq_len, embedding_size)``.
            mask: Optional tensor applied to every head's
                ``(batch, seq_len, seq_len)`` score matrix; positions where it
                equals 0 are excluded from attention.

        Returns:
            A tuple ``(output, heads_attention_weights)``: ``output`` has
            shape ``(batch, seq_len, embedding_size)`` and
            ``heads_attention_weights`` is a list with one
            ``(batch, seq_len, seq_len)`` weight tensor per head.
        """
        split_size = self.embedding_size // self.num_heads

        # (batch, seq_len, embedding_size) -> (batch, seq_len, num_heads, split_size)
        split_embeddings = embeddings.view(embeddings.size(0), embeddings.size(1), self.num_heads, split_size)

        attention_outputs = []
        heads_attention_weights = []
        for head in range(self.num_heads):
            # Each head only sees its own chunk of the embedding.
            chunk = split_embeddings[:, :, head, :]
            query = self.linear_queries[head](chunk)
            key = self.linear_keys[head](chunk)
            value = self.linear_values[head](chunk)

            head_output, head_weights = self._scaled_dot_product_attention(query, key, value, mask)
            attention_outputs.append(head_output)
            heads_attention_weights.append(head_weights)

        # Concatenate the results from all heads
        concatenated_attention = torch.cat(attention_outputs, dim=-1)

        # Apply the final linear projection
        output = self.linear_out(concatenated_attention)

        return output, heads_attention_weights

    def _scaled_dot_product_attention(self, query, key, value, mask=None):
        """Compute ``softmax(Q K^T / sqrt(d)) V`` for a single head.

        Args:
            query, key, value: Tensors whose last two dimensions are
                ``(seq_len, d)``.
            mask: Optional tensor; score entries where it equals 0 are set to
                ``-inf`` before the softmax. A query row whose keys are all
                masked therefore yields NaN weights.

        Returns:
            A tuple ``(output, attention_weights)``.
        """
        # Compute attention scores, scaled by sqrt of the per-head dimension
        attention_scores = torch.matmul(query, key.transpose(-2, -1)) / (query.size(-1) ** 0.5)

        # Apply mask
        if mask is not None:
            attention_scores = attention_scores.masked_fill(mask == 0, float('-inf'))

        # Apply softmax over the key axis
        attention_weights = F.softmax(attention_scores, dim=-1)

        # Compute weighted sum (Attention Scores x V)
        output = torch.matmul(attention_weights, value)

        return output, attention_weights


In [3]:
import sentencepiece as spm
# Load the SentencePiece model file; its piece table defines the vocabulary.
sp = spm.SentencePieceProcessor()
sp.load("tinystories_tokeniser.model")

vocab = [sp.id_to_piece(i) for i in range(sp.get_piece_size())]

# Model hyper-parameters for this smoke test.
vocab_size = len(vocab)
embedding_size = 32
max_seq_length = 12
num_heads = 4

# Fix the seed so the randomly initialised weights -- and therefore the
# printed outputs -- are the same on every Restart & Run All.
torch.manual_seed(0)

# A toy batch of three 3-token sequences, given directly as token ids.
list_of_lists = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
tensor_2d = torch.tensor(list_of_lists)
mha_model = Mha_GPT(vocab_size, embedding_size, max_seq_length, num_heads)
# NOTE: despite its name, `probs` holds the raw logits returned by the model
# (no softmax is applied); the name is kept because later cells use it.
probs, heads_attention_weights = mha_model(tensor_2d)
print(heads_attention_weights)

[tensor([[[0.3566, 0.2985, 0.3449],
         [0.2803, 0.5300, 0.1896],
         [0.2498, 0.4374, 0.3128]],

        [[0.3985, 0.3823, 0.2192],
         [0.1069, 0.2132, 0.6798],
         [0.0799, 0.3115, 0.6086]],

        [[0.4470, 0.1556, 0.3974],
         [0.4303, 0.2464, 0.3233],
         [0.5613, 0.0830, 0.3557]]], grad_fn=<SoftmaxBackward0>), tensor([[[0.3559, 0.3340, 0.3101],
         [0.3260, 0.3428, 0.3311],
         [0.2630, 0.4204, 0.3166]],

        [[0.3393, 0.3654, 0.2953],
         [0.5932, 0.1683, 0.2385],
         [0.2273, 0.1289, 0.6438]],

        [[0.4670, 0.1724, 0.3606],
         [0.6996, 0.1114, 0.1890],
         [0.5667, 0.1947, 0.2386]]], grad_fn=<SoftmaxBackward0>), tensor([[[0.3358, 0.3086, 0.3556],
         [0.3163, 0.2746, 0.4090],
         [0.4141, 0.4912, 0.0947]],

        [[0.3384, 0.3789, 0.2827],
         [0.3592, 0.4100, 0.2308],
         [0.3461, 0.3889, 0.2650]],

        [[0.2714, 0.2393, 0.4893],
         [0.3415, 0.2610, 0.3974],
         [0.314

In [20]:
# `probs` is the first value returned by the model call above: the raw output
# of the final linear layer (no softmax applied), i.e. logits, not probabilities.
print(probs)

tensor([[[-0.5804,  2.0289,  0.9384,  ...,  0.0202, -0.9566,  0.8582],
         [ 0.2776,  1.8028, -0.3687,  ..., -0.1361, -2.1818,  0.6668],
         [ 0.5065,  2.2130,  0.4301,  ...,  1.1961, -0.7159,  0.9888]],

        [[-0.1150,  0.7286, -0.4281,  ...,  0.4932, -0.9897,  2.1935],
         [-0.0316,  1.5358, -0.7105,  ...,  0.5412, -0.6257,  2.7560],
         [ 0.1027, -0.1298, -0.2231,  ..., -0.6187, -0.5578,  2.9701]],

        [[-1.5034,  2.5926,  1.6077,  ..., -0.5684,  0.4391,  0.5368],
         [-0.9994,  2.1132,  0.0809,  ..., -0.9080, -0.6356,  1.1818],
         [-0.6180,  2.5011,  0.0144,  ..., -1.1487,  1.3814,  0.1739]]],
       grad_fn=<ViewBackward0>)
