In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class GPT(nn.Module):
    """Minimal single-block, single-head GPT-style language model.

    Pipeline: token embedding -> positional encoding -> causal self-attention
    -> feed-forward network -> linear projection to vocabulary logits.
    Each sub-layer's output is layer-normalised and then added to that
    sub-layer's input (residual connection).

    Args:
        vocab_size: Number of distinct token IDs; also the width of the
            output logits.
        embedding_size: Width of the token embeddings and of every hidden
            state in the block.
        max_seq_length: Number of positions covered by the positional
            encoding table.
        ff_hidden_size: Hidden width of the feed-forward network. Defaults
            to 10, the value that was previously hard-coded.
    """

    def __init__(self, vocab_size, embedding_size, max_seq_length, ff_hidden_size=10):
        super(GPT, self).__init__()

        # Token ID -> dense vector lookup table.
        self.embedding = nn.Embedding(vocab_size, embedding_size)

        # Adds position information to the token embeddings.
        self.positional_encoding = PositionalEncoding(embedding_size, max_seq_length)

        # Attention sub-layer and the LayerNorm applied to its output.
        self.self_attention = SelfAttention(embedding_size)
        self.add_norm1 = nn.LayerNorm(embedding_size)

        # Feed-forward sub-layer and the LayerNorm applied to its output.
        # NOTE: the attribute is named `add_norm` (not `add_norm2`); the name
        # is kept because it is part of the module's state_dict keys.
        self.feed_forward = FeedForward(embedding_size, ff_hidden_size)
        self.add_norm = nn.LayerNorm(embedding_size)

        # Projects hidden states to one logit per vocabulary entry.
        self.fc = nn.Linear(embedding_size, vocab_size)

    def forward(self, input_ids, last_token_only=False):
        """Compute next-token logits.

        Args:
            input_ids: Integer tensor of token IDs, shape (batch, seq_len).
            last_token_only: If True, project only the final position, which
                is all that is needed to predict the next token. Defaults to
                False, which returns logits for every position.

        Returns:
            Unnormalised logits (no softmax is applied) of shape
            (batch, seq_len, vocab_size), or (batch, vocab_size) when
            ``last_token_only`` is True.
        """
        embeddings = self.embedding(input_ids)
        embeddings = self.positional_encoding(embeddings)

        # Residual connection around attention; the norm is applied to the
        # sub-layer output before it is added back to its input.
        attention_output = self.self_attention(embeddings)
        output1 = embeddings + self.add_norm1(attention_output)

        # Same residual pattern around the feed-forward network.
        ff_output = self.feed_forward(output1)
        output = output1 + self.add_norm(ff_output)

        if last_token_only:
            # Keep only the hidden state of the last position in each sequence.
            output = output[..., -1, :]

        logits = self.fc(output)

        return logits


class FeedForward(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    The layers act on the last dimension of the input, mapping
    ``input_size`` -> ``hidden_size`` -> ``input_size``, so the output has
    the same shape as the input.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        # Creation order (linear1, then linear2) is kept so that parameter
        # initialisation consumes the RNG in the same sequence.
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(hidden_size, input_size)

    def forward(self, x):
        """Apply the MLP to ``x`` and return a tensor of the same shape."""
        hidden = self.relu(self.linear1(x))
        return self.linear2(hidden)

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to token embeddings.

    A table of shape (1, max_seq_length, d_model) is precomputed: even
    feature columns hold ``sin(pos * freq)`` and odd columns hold
    ``cos(pos * freq)``, where ``freq = 10000 ** (-2i / d_model)`` for the
    i-th column pair. The table is registered as a buffer, so it follows the
    module across devices and is saved in the state_dict, but is not trained.

    Args:
        d_model: Feature width of the embeddings (even or odd).
        max_seq_length: Number of positions to precompute.
    """

    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()
        position = torch.arange(0, max_seq_length).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        # (max_seq_length, ceil(d_model / 2)) matrix of pos * freq.
        angles = position * div_term

        pos_enc = torch.zeros(1, max_seq_length, d_model)
        pos_enc[:, :, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns,
        # so trim the angles to match; for even d_model this is a no-op.
        pos_enc[:, :, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pos_enc', pos_enc)

    def forward(self, x):
        """Add the encoding for the first ``x.size(1)`` positions to ``x``.

        Args:
            x: Embeddings of shape (batch, seq_len, d_model); seq_len is
                expected to be at most ``max_seq_length``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        # Buffers never require grad, so no detach() is needed here.
        return x + self.pos_enc[:, :x.size(1)]

class SelfAttention(nn.Module):
    """Single-head causal (masked) scaled dot-product self-attention.

    Each position attends to itself and to earlier positions only. The
    input may be batched ``(..., seq_len, embedding_size)`` or unbatched
    ``(seq_len, embedding_size)``; the output has the same shape as the
    input.

    Args:
        embedding_size: Width of the input vectors and of the query, key and
            value projections.
    """

    def __init__(self, embedding_size):
        super(SelfAttention, self).__init__()

        # Learned linear projections for queries, keys and values.
        self.linear_query = nn.Linear(embedding_size, embedding_size)
        self.linear_key = nn.Linear(embedding_size, embedding_size)
        self.linear_value = nn.Linear(embedding_size, embedding_size)

    def forward(self, embeddings):
        """Return the attention-weighted sum of value vectors.

        Args:
            embeddings: Float tensor whose last two dimensions are
                (seq_len, embedding_size).

        Returns:
            Tensor with the same shape as ``embeddings``.
        """
        # Project embeddings into query, key, and value spaces
        query = self.linear_query(embeddings)
        key = self.linear_key(embeddings)
        value = self.linear_value(embeddings)

        # Scaled dot-product scores: (..., seq_len, seq_len)
        attention_scores = torch.matmul(query, key.transpose(-2, -1)) / (query.size(-1) ** 0.5)

        # Causal mask: True strictly above the diagonal, i.e. at "future"
        # positions. It is kept 2-D so it broadcasts over any number of
        # leading batch dimensions (including none), and it is created
        # directly on the scores' device to avoid a host-to-device copy.
        seq_length = attention_scores.size(-1)
        future_mask = torch.triu(
            torch.ones(seq_length, seq_length, device=attention_scores.device),
            diagonal=1,
        ).bool()
        attention_scores = attention_scores.masked_fill(future_mask, float('-inf'))

        # Masked entries become exactly 0 after softmax; the diagonal is never
        # masked, so every row has at least one finite score.
        attention_probs = F.softmax(attention_scores, dim=-1)

        # Compute weighted sum (attention probabilities x V)
        output = torch.matmul(attention_probs, value)

        return output


In [3]:
# Example usage: run the SelfAttention layer on its own.
vocab_size = 10000
embedding_size = 256
max_seq_length = 100  # Set your desired maximum sequence length

# Create an instance of SelfAttention
self_attention = SelfAttention(embedding_size)

# Generate some example token IDs (batch size of 4, sequence length of 10)
token_ids = torch.randint(0, vocab_size, (4, 10))

# SelfAttention's linear layers operate on float embeddings, not on integer
# token IDs, so look the IDs up in an embedding table first.
token_embeddings = nn.Embedding(vocab_size, embedding_size)(token_ids)

# Causally masked weighted sum of the value vectors, shape (4, 10, embedding_size)
output = self_attention(token_embeddings)



In [4]:
import sentencepiece as spm

# Load the SentencePiece tokeniser model from the working directory.
sp = spm.SentencePieceProcessor()
sp.load("tinystories_tokeniser.model")

# Every piece known to the tokeniser, indexed by its ID.
vocab = [sp.id_to_piece(i) for i in range(sp.get_piece_size())]

vocab_size = len(vocab)
embedding_size = 30
max_seq_length = 1192

# Toy batch: three sequences of three arbitrary token IDs each.
list_of_lists = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
tensor_2d = torch.tensor(list_of_lists)

# Seed the RNG so the randomly initialised weights, and therefore the values
# printed below, are the same on every Restart & Run All.
torch.manual_seed(0)
gpt_model = GPT(vocab_size, embedding_size, max_seq_length)

# NOTE: despite the name, `probs` holds unnormalised logits; GPT.forward ends
# with a linear layer and applies no softmax.
probs = gpt_model(tensor_2d)


In [5]:
# Display the model output (unnormalised logits, despite the variable name).
probs

tensor([[ 0.3861, -0.2036,  0.2616,  ...,  0.1122,  0.0513, -0.9487],
        [ 0.2843, -0.6717,  1.0048,  ...,  0.0253, -0.3398,  0.2390],
        [ 0.9024, -1.8237,  0.8030,  ..., -1.3157, -0.2283,  0.1605]],
       grad_fn=<AddmmBackward0>)


In [6]:
# NOTE(review): GPT.forward projects every position, so the (3, 3) batch of
# token IDs should give (3, 3, vocab_size). If the recorded output shows a
# different shape it is stale; re-run the notebook from a fresh kernel.
probs.shape

torch.Size([3, 16000])
