In [None]:
import torch

# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing at the first `.to(device)`.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
from transformers import AutoTokenizer

# Load the pretrained "gpt2" tokenizer (fetched or read from the local cache by `transformers`).
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Reuse the EOS token as the padding token — presumably because this tokenizer
# defines no pad token of its own (TODO confirm). Consequence: pad id == EOS id,
# so padding cannot be told apart from a genuine EOS by token id alone.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
import torch
from torch import nn


class TransformerBlock(nn.Module):
    """Post-norm transformer block: causal self-attention followed by a feedforward network.

    Each sub-layer is wrapped as ``norm(residual + dropout(sublayer(x)))``.

    Args:
        embed_dim: Width of the token representations entering and leaving the block.
        num_heads: Number of attention heads passed to ``nn.MultiheadAttention``.
        dropout: Dropout probability applied to the output of each sub-layer.

    Note:
        Submodule attribute names (``queries``, ``keys``, ``values``, ``att_block``,
        ``feedforward``, ``norm1``, ``norm2``) are part of the checkpoint's state-dict
        keys and must not be renamed.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.0):
        super().__init__()
        # Linear projections for Q, K, V. These are applied *before* the attention
        # module (which has its own internal input projections); they are kept
        # because saved checkpoints contain their weights.
        self.queries = nn.LazyLinear(out_features=embed_dim)
        self.keys = nn.LazyLinear(out_features=embed_dim)
        self.values = nn.LazyLinear(out_features=embed_dim)

        # Multi-head attention block (batch_first=True means inputs shape [batch, seq, embed_dim])
        self.att_block = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)

        # Position-wise feedforward network with a 4x hidden expansion.
        self.feedforward = nn.Sequential(
            nn.LazyLinear(out_features=embed_dim * 4),
            nn.LeakyReLU(),
            nn.LazyLinear(out_features=embed_dim)
        )

        # Layer normalization applied after each residual connection.
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

        # Dropout layer for regularization (shared by both sub-layers).
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, key_padding_mask=None):
        """Run causal self-attention and the feedforward sub-layer.

        Args:
            x: Input of shape ``[batch, seq, embed_dim]``.
            key_padding_mask: Optional boolean tensor of shape ``[batch, seq]`` where
                ``True`` marks positions that must be ignored as attention keys.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Self-attention sub-layer with residual connection and normalization.
        residual = x

        # Linear projections for Q, K, V.
        q, k, v = self.queries(x), self.keys(x), self.values(x)

        # Boolean mask where True means "may NOT attend": the strict upper triangle
        # blocks every position from looking at later positions.
        seq_len = x.size(1)
        future_mask = torch.ones(seq_len, seq_len, device=x.device).triu(diagonal=1).bool()

        # Single attention call whose masked output is actually used. Previously the
        # masked result was thrown away and a second, unmasked call fed the residual,
        # so neither the causal nor the padding mask had any effect.
        # NOTE(review): checkpoints trained while the unmasked call was in effect
        # will produce different activations with this fix — verify or retrain.
        att_output, _ = self.att_block(
            q, k, v,
            attn_mask=future_mask,  # mask future tokens
            key_padding_mask=key_padding_mask  # mask pads
        )
        att_output = self.dropout(att_output)

        # Add residual and normalize.
        x = self.norm1(residual + att_output)

        # Feedforward sub-layer with residual connection and normalization.
        residual2 = x
        ff_output = self.feedforward(x)
        ff_output = self.dropout(ff_output)
        x = self.norm2(residual2 + ff_output)

        return x


In [None]:
import torch
from torch import nn


class Decoder(nn.Module):
    """Output head that projects hidden states to vocabulary logits.

    Args:
        vocab_size: Number of output features (one logit per vocabulary entry).
    """

    def __init__(self, vocab_size):
        super().__init__()
        # The input width is inferred lazily on first use. The projection stays
        # wrapped in a Sequential stored as `decoder` so state-dict keys
        # ("decoder.0.weight", "decoder.0.bias") are unchanged.
        vocab_projection = nn.LazyLinear(out_features=vocab_size)
        self.decoder = nn.Sequential(vocab_projection)

    def forward(self, attended_sequence):
        """Return logits whose last dimension has size ``vocab_size``."""
        return self.decoder(attended_sequence)

In [None]:
import torch
from torch import nn


class Model(nn.Module):
    """Language model: token + learned positional embeddings, a stack of
    transformer blocks, and a projection to vocabulary logits.

    Args:
        vocab_size: Number of entries in the token embedding table and of output logits.
        embed_dim: Width of the embeddings and of every transformer block.
        num_heads: Attention heads per transformer block.
        num_transformer_blocks: How many ``TransformerBlock`` layers to stack.
        max_sequence_length: Number of rows in the positional embedding table, i.e.
            the longest sequence ``forward`` accepts.
    """

    def __init__(self, vocab_size, embed_dim, num_heads, num_transformer_blocks, max_sequence_length):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.positional_embedding = nn.Embedding(max_sequence_length, embed_dim)

        # Create a ModuleList of transformer blocks.
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(embed_dim=embed_dim, num_heads=num_heads)
            for _ in range(num_transformer_blocks)
        ])

        self.decoder = Decoder(vocab_size=vocab_size)

    def forward(self, input_ids, attention_mask=None):
        """Compute next-token logits for every position.

        Args:
            input_ids: Integer tensor of shape ``[batch_size, seq_length]``.
            attention_mask: Optional tensor of shape ``[batch_size, seq_length]``;
                entries equal to 0 are treated as padding. ``None`` means no padding.

        Returns:
            Logits of shape ``[batch_size, seq_length, vocab_size]``.

        Raises:
            IndexError: If ``seq_length`` exceeds the positional embedding table.
        """
        _, seq_length = input_ids.shape

        # Fail early with a readable message instead of an out-of-range embedding
        # lookup (which surfaces as a device-side assert on CUDA).
        max_positions = self.positional_embedding.num_embeddings
        if seq_length > max_positions:
            raise IndexError(
                f"sequence length {seq_length} exceeds the maximum of {max_positions} positions"
            )

        # Embed tokens.
        embedded_sequence = self.embedding(input_ids)  # Shape: [batch_size, seq_length, embed_dim]

        # Positional embeddings of shape [seq_length, embed_dim]; broadcasting adds
        # them to every sequence in the batch.
        positions = torch.arange(seq_length, device=embedded_sequence.device)
        pos_embeds = self.positional_embedding(positions)

        # Sum token and positional embeddings.
        x = embedded_sequence + pos_embeds

        # The padding mask is identical for every block, so build it once.
        key_padding_mask = None if attention_mask is None else (attention_mask == 0)

        # Pass through the stacked transformer blocks.
        for block in self.transformer_blocks:
            x = block(x, key_padding_mask=key_padding_mask)

        # Decoder projection to vocabulary logits.
        return self.decoder(x)


In [None]:
import torch

# Checkpoint to load; its filename records the hyperparameters it was trained with,
# and the Model constructed below must use exactly the same values.
model_path = "models/vocabsize_(50257,)_embeddim_(768,)_numheads_(12,)_numtransformerblocks_(12,)_maxseqlen_(1024,)_timestamp_20250405_172553.pt"

model = Model(
    vocab_size=50257,
    embed_dim=768,
    num_heads=12,
    num_transformer_blocks=12,
    max_sequence_length=1024,
).to(device)

# The file is used purely as a state dict, so restrict unpickling to tensors and
# plain containers (`weights_only=True`) rather than executing arbitrary pickled
# objects from the checkpoint file.
state_dict = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)

In [None]:
import torch

# Number of new tokens to sample.
MAX_GEN_LENGTH = 1050

# Prompt the generation starts from.
BASE_TEXT = "The cat is very"

# Longest context fed to the model: the size of its positional table minus one
# (1023 for a 1024-position model, the value previously hard-coded here).
# NOTE(review): presumably this mirrors the sequence length used in training — confirm.
CONTEXT_WINDOW = model.positional_embedding.num_embeddings - 1

model.eval()
with torch.no_grad():
    curr_text = BASE_TEXT
    print(curr_text, end="")

    # Tokenize the prompt once and keep the running sequence as token ids.
    # Re-encoding the growing text on every step is quadratic, and a
    # decode -> encode round trip does not always reproduce the sampled ids.
    generated_ids = tokenizer.encode(curr_text, return_tensors="pt").to(device)

    for _ in range(MAX_GEN_LENGTH):
        # Keep only the most recent tokens that fit in the context window.
        input_ids = generated_ids[:, -CONTEXT_WINDOW:]

        # A single unpadded sequence has no padding, so every position is valid.
        # (Deriving the mask from `!= eos_token_id` would wrongly hide real EOS tokens.)
        attention_mask = torch.ones_like(input_ids)

        output = model(input_ids, attention_mask)
        next_token_logits = output[:, -1, :]
        probs = torch.softmax(next_token_logits, dim=-1)
        next_token_id = torch.multinomial(probs, num_samples=1)  # shape [1, 1]

        # Append the sampled id so the next step conditions on exactly this token.
        generated_ids = torch.cat([generated_ids, next_token_id], dim=1)

        next_token_str = tokenizer.decode(next_token_id.squeeze().item())
        curr_text += next_token_str

        print(next_token_str, end="")