In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import math

In [None]:
from torch.utils.data import Dataset, DataLoader

class DummyDataset(Dataset):
    """Synthetic, right-padded token sequences for smoke-testing the model.

    Every sample is a dict holding two 1-D tensors of length ``seq_length``:

    * ``'input_ids'``: random token ids drawn from ``[5, vocab_size)``,
      followed by ``pad_token_id`` up to ``seq_length``.
    * ``'attention_mask'``: 1 at real-token positions, 0 at padded ones.

    The unpadded length of each sample is drawn from ``[5, seq_length]``, so
    ``seq_length`` must be at least 5 and ``vocab_size`` greater than 5.

    Args:
        num_samples: number of samples generated eagerly in ``__init__``.
        seq_length: padded length of every sample.
        vocab_size: exclusive upper bound for the random token ids.
        pad_token_id: id used to fill the tail of each sequence.
    """

    def __init__(self, num_samples=1000, seq_length=32, vocab_size=30000, pad_token_id=0):
        self.vocab_size = vocab_size
        self.seq_length = seq_length
        self.pad_token_id = pad_token_id
        self.samples = [self._make_sample() for _ in range(num_samples)]

    def _make_sample(self):
        """Draw one padded sequence and its mask (length first, then tokens)."""
        n_real = int(torch.randint(5, self.seq_length + 1, (1,)))
        n_pad = self.seq_length - n_real

        # Ids below 5 are never sampled, so they cannot collide with real tokens.
        real_tokens = torch.randint(5, self.vocab_size, (n_real,))
        pad_tokens = torch.full((n_pad,), self.pad_token_id)

        # Floating-point 1/0 mask; converted to int64 on access in __getitem__.
        is_real = torch.arange(self.seq_length) < n_real

        return {
            'input_ids': torch.cat([real_tokens, pad_tokens], dim=0),
            'attention_mask': is_real.to(torch.get_default_dtype()),
        }

    def __len__(self):
        """Number of generated samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return sample ``idx`` with both tensors cast to int64."""
        stored = self.samples[idx]
        return {key: stored[key].long() for key in ('input_ids', 'attention_mask')}


In [None]:
# Tiny synthetic dataset: just enough data to exercise the training loop.
batch_size = 4

dummy_dataset = DummyDataset(
    num_samples=20,
    seq_length=32,
    vocab_size=30000,
    pad_token_id=0,
)

# Batches are reshuffled at the start of every pass over the loader.
dataloader = DataLoader(dataset=dummy_dataset, batch_size=batch_size, shuffle=True)


In [None]:
class Embedding(nn.Module):
  """Token plus learned absolute-position embeddings, then LayerNorm and dropout.

  Args:
    vocab_size: number of rows in the token embedding table.
    hidden_size: embedding width ``H``.
    max_sequence_length: number of rows in the position table. ``forward``
      looks up positions ``0 .. L-1``, so inputs longer than this fail.
    dropout: dropout probability applied after the layer norm.
  """

  def __init__(self, vocab_size, hidden_size, max_sequence_length, dropout = 0.1):
    super().__init__()
    self.token_embeds = nn.Embedding(vocab_size, hidden_size)
    self.position_embeds = nn.Embedding(max_sequence_length, hidden_size)
    self.dropout = nn.Dropout(dropout)
    self.layernorm = nn.LayerNorm(hidden_size)

  def forward(self, input_ids):
    """Embed ``input_ids`` of shape (B, L) and return a (B, L, H) tensor."""
    _, seq_len = input_ids.shape

    # One (L, H) block of position vectors, broadcast across the batch below.
    positions = torch.arange(seq_len, device=input_ids.device)
    summed = self.token_embeds(input_ids) + self.position_embeds(positions)

    return self.dropout(self.layernorm(summed))

In [None]:
def test_embedding_layer():
    """Smoke test: ``Embedding`` maps (batch, seq) ids to (batch, seq, hidden)."""
    n_vocab, width, max_positions = 100, 64, 32
    n_batch, n_tokens = 2, 16

    layer = Embedding(n_vocab, width, max_positions)
    ids = torch.randint(0, n_vocab, (n_batch, n_tokens))

    result = layer(ids)
    expected_shape = (n_batch, n_tokens, width)
    assert result.shape == expected_shape, "Wrong shape!"
    print("✅ Embedding layer test passed!")

In [None]:
# Run the embedding shape smoke test; it prints a confirmation when the assert passes.
test_embedding_layer()

✅ Embedding layer test passed!


In [None]:
class MultiHeadSelfAttention(nn.Module):
  """Causal (masked) multi-head self-attention.

  Queries, keys and values are all projected from the same input. A causal
  mask stops every position from attending to later positions, and an optional
  padding mask additionally hides padded key positions.

  Args:
    num_heads: number of attention heads; must divide ``hidden_size``.
    hidden_size: model width ``H``.
    dropout: dropout probability applied to the attention weights.
  """

  def __init__(self, num_heads, hidden_size, dropout = 0.1):
    super().__init__()
    self.num_heads = num_heads

    assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
    self.query = nn.Linear(hidden_size, hidden_size)
    self.key = nn.Linear(hidden_size, hidden_size)
    self.value = nn.Linear(hidden_size, hidden_size)

    # Not used in forward(); kept so the module's state_dict keys do not change.
    self.layernorm = nn.LayerNorm(hidden_size)

    self.out_proj = nn.Linear(hidden_size, hidden_size)

    self.dropout = nn.Dropout(dropout)

    self.head_dim = hidden_size // num_heads

    # Scores are divided by sqrt(head_dim) before the softmax.
    self.scale = math.sqrt(self.head_dim)

  def forward(self, x, attention_mask = None):
    """Apply causal self-attention.

    Args:
      x: input of shape (B, L, H).
      attention_mask: optional padding mask that broadcasts against the
        (B, heads, L, L) score tensor, e.g. shape (B, 1, 1, L). Entries equal
        to 0 mark key positions that must not be attended to.

    Returns:
      Tensor of shape (B, L, H).
    """
    B, L, H = x.shape

    def split_heads(t):
      # (B, L, H) -> (B, heads, L, head_dim)
      return t.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)

    Q = split_heads(self.query(x))
    K = split_heads(self.key(x))
    V = split_heads(self.value(x))

    scores = torch.matmul(Q, K.transpose(-1, -2)) / self.scale  # B x Heads x L x L

    # Use the most negative finite value rather than -inf. For rows with at
    # least one visible key the masked weights still underflow to exactly 0,
    # but a row whose keys are ALL masked (e.g. a left-padded sequence) now
    # yields finite weights instead of NaN, which would otherwise poison the
    # outputs and every gradient.
    fill_value = torch.finfo(scores.dtype).min

    # True strictly above the diagonal, i.e. for keys that lie in the future.
    causal_mask = torch.triu(torch.ones(L, L, device=x.device), diagonal=1).bool()
    scores = scores.masked_fill(causal_mask.unsqueeze(0).unsqueeze(0), fill_value)

    # Apply padding mask if provided (for ignoring pad tokens).
    if attention_mask is not None:
        scores = scores.masked_fill(attention_mask == 0, fill_value)

    attention_weights = F.softmax(scores, dim=-1)
    attention_weights = self.dropout(attention_weights)

    attention_output = torch.matmul(attention_weights, V)  # B x Heads x L x Head_dim
    # Merge the heads back: (B, heads, L, head_dim) -> (B, L, H)
    attention_output = attention_output.transpose(1, 2).contiguous().view(B, L, H)

    return self.out_proj(attention_output)  # B x L x H


In [None]:
class DecoderLayer(nn.Module):
  """One post-norm decoder block: causal self-attention, then a feed-forward net.

  Each sub-layer follows ``x = LayerNorm(x + Dropout(sublayer(x)))``.

  Args:
    num_heads: number of attention heads.
    hidden_size: model width ``H``.
    intermediate_size: width of the hidden layer inside the feed-forward net.
    dropout: dropout probability used in attention and on both residual branches.
  """

  def __init__(self, num_heads, hidden_size, intermediate_size, dropout = 0.1):
    super().__init__()
    self.self_attention = MultiHeadSelfAttention(num_heads, hidden_size, dropout)

    # Dropout on the attention branch before it is added to the residual.
    self.dropout = nn.Dropout(dropout)
    self.norm1 = nn.LayerNorm(hidden_size)

    # The trailing nn.Dropout is the residual dropout of the feed-forward branch.
    self.ff = nn.Sequential(
        nn.Linear(hidden_size,intermediate_size),
        nn.GELU(),
        nn.Linear(intermediate_size,hidden_size),
        nn.Dropout(dropout)
    )

    self.norm2 = nn.LayerNorm(hidden_size)


  def forward(self, x, attention_mask = None):
    """Run the block on ``x`` of shape (B, L, H) and return the same shape.

    ``attention_mask`` is forwarded unchanged to the self-attention module.
    """
    attention_output = self.self_attention(x,attention_mask)
    x = self.norm1(x + self.dropout(attention_output))

    # self.ff already ends in nn.Dropout, so its output is not passed through
    # self.dropout a second time; both branches get exactly one dropout.
    x = self.norm2(x + self.ff(x))

    return x


In [None]:
class Decoder(nn.Module):
  """A stack of ``num_layers`` ``DecoderLayer`` blocks applied one after another."""

  def __init__(self, num_layers, num_heads, hidden_size, intermediate_size, dropout = 0.1):
    super().__init__()
    stack = nn.ModuleList()
    for _ in range(num_layers):
      stack.append(DecoderLayer(num_heads, hidden_size, intermediate_size, dropout))
    self.decoder_layers = stack

  def forward(self, x, attention_mask = None):
    """Feed ``x`` through every layer in order; the same mask goes to each layer."""
    hidden = x
    for block in self.decoder_layers:
      hidden = block(hidden, attention_mask)
    return hidden

In [None]:
class GPT(nn.Module):
  """Decoder-only language model: embeddings -> decoder stack -> vocabulary logits.

  The output projection shares its weight matrix with the token embedding
  table (weight tying); ``lm_head`` keeps its own bias.

  Args:
    vocab_size: vocabulary size ``V``.
    hidden_size: model width ``H``.
    max_sequence_length: number of learned positions, i.e. the longest input
      the model accepts.
    dropout: dropout probability used throughout the model.
    num_layers: number of decoder layers.
    num_heads: attention heads per layer.
    intermediate_size: feed-forward hidden width.
    init_std: standard deviation used to re-initialise the token and position
      embeddings. Pass ``None`` to keep ``nn.Embedding``'s default init.
  """

  def __init__(self, vocab_size, hidden_size, max_sequence_length, dropout,
               num_layers, num_heads, intermediate_size, init_std = 0.02):
    super().__init__()
    self.embed = Embedding(vocab_size, hidden_size, max_sequence_length, dropout)
    self.decoder = Decoder(num_layers, num_heads, hidden_size, intermediate_size, dropout)
    self.lm_head = nn.Linear(hidden_size, vocab_size)
    # Weight tying: the output projection reuses the token embedding matrix.
    self.lm_head.weight = self.embed.token_embeds.weight

    # nn.Embedding draws its weights from N(0, 1). Because that matrix is also
    # the output projection, the initial logits become very large; a recorded
    # run of this notebook reported a loss near 266, whereas a uniform guess
    # over 30000 tokens gives ln(30000) ~= 10.3. A small std keeps the initial
    # logits close to zero.
    if init_std is not None:
      nn.init.normal_(self.embed.token_embeds.weight, mean=0.0, std=init_std)
      nn.init.normal_(self.embed.position_embeds.weight, mean=0.0, std=init_std)

  def forward(self, input_ids, attention_mask = None):
    """Compute next-token logits.

    Args:
      input_ids: token ids of shape (B, L).
      attention_mask: optional (B, L) tensor, 0 at padded positions. When
        omitted, only the causal mask inside the attention layers applies.

    Returns:
      Logits of shape (B, L, V).
    """
    x = self.embed(input_ids)

    mask = None
    if attention_mask is not None:
      # (B, L) -> (B, 1, 1, L) so it broadcasts over heads and query positions.
      mask = attention_mask.unsqueeze(1).unsqueeze(2)

    x = self.decoder(x, mask)
    return self.lm_head(x)


In [None]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# ---- Hyperparameters ----
learning_rate = 3e-4   # base learning rate passed to Adam
betas = (0.9, 0.98)    # Adam moment coefficients
dropout = 0.1
epochs = 10
warmup_steps = 4000    # read by the learning-rate schedule
# TODO: confirm this id against the tokenizer that is eventually used; the
# padding id is tokenizer-specific. It matches the dummy dataset's pad id.
pad_token_id = 0


# ---- Model ----
model = GPT(
    vocab_size=30000,
    hidden_size=512,
    max_sequence_length=512,
    dropout=dropout,
    num_layers=6,
    num_heads=8,
    intermediate_size=2048,
)
model = model.to(device)


# ---- Optimizer ----
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate, betas=betas)

# Learning-rate schedule: linear warm-up, then inverse-square-root decay.
def lr_scheduler(step):
    """Return the factor by which LambdaLR scales the optimizer's base LR.

    LambdaLR sets ``lr = base_lr * lr_lambda(step)``, so this must be a
    dimensionless multiplier and not an absolute learning rate. The previous
    version returned the absolute Noam rate (on the order of 1e-7 in early
    steps); multiplied by the base LR this made the effective learning rate
    practically zero.

    The factor grows linearly to 1.0 at ``warmup_steps`` and then decays
    proportionally to ``1 / sqrt(step)``. This is the Noam schedule divided by
    its peak value, so the base LR given to the optimizer is the peak LR.

    Note: if a run has far fewer optimizer steps than ``warmup_steps``, the
    LR never gets close to the base LR; choose ``warmup_steps`` accordingly.
    """
    current = step + 1  # LambdaLR counts from 0; avoid a zero factor / division by zero
    return min(current / warmup_steps, (warmup_steps / current) ** 0.5)

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_scheduler)

# Loss: cross-entropy over the vocabulary. Positions whose target equals the
# pad token are excluded from the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=pad_token_id)


In [None]:
# Number of optimizer updates performed so far, across all epochs.
global_step = 0

for epoch in range(epochs):
    model.train()
    total_loss = 0.0

    for batch in dataloader:
        # Next-token prediction: inputs drop the last column, targets the first.
        #   input_ids:     what is the capital of france ?
        #   target_labels:      is the capital of france ? [EOS]
        token_ids = batch['input_ids'].to(device)
        input_ids = token_ids[:, :-1]                                 # [B, L-1]
        target_labels = token_ids[:, 1:]                              # [B, L-1]
        attention_mask = batch['attention_mask'][:, :-1].to(device)   # [B, L-1]

        # Forward pass
        logits = model(input_ids, attention_mask)  # [B, L-1, vocab]

        # Flatten batch and time so every position is one classification example.
        vocab_dim = logits.size(-1)
        loss = loss_fn(logits.reshape(-1, vocab_dim), target_labels.reshape(-1))

        # Backward pass and parameter update; the schedule advances once per update.
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        scheduler.step()

        total_loss += loss.item()
        global_step += 1

    # Mean of the per-batch losses for this epoch.
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{epochs} | Average Loss: {avg_loss:.4f}")

    # One checkpoint per epoch, written to the current working directory.
    torch.save(model.state_dict(), f"transformer_epoch_{epoch+1}.pth")

    # Optional: evaluate on validation data here, e.g.
    # evaluate(model, val_dataloader)

print("Training Completed.")


Epoch 1/10 | Average Loss: 266.2534
Epoch 2/10 | Average Loss: 266.2830
Epoch 3/10 | Average Loss: 265.5534
Epoch 4/10 | Average Loss: 265.8050
Epoch 5/10 | Average Loss: 267.0935
Epoch 6/10 | Average Loss: 266.0539
Epoch 7/10 | Average Loss: 267.3079
Epoch 8/10 | Average Loss: 265.7449
Epoch 9/10 | Average Loss: 266.0015
Epoch 10/10 | Average Loss: 269.0102
Training Completed.




---



# How to run inference with this model (greedy decoding)

In [None]:
@torch.no_grad()
def generate(model, tokenizer, prompt, max_new_tokens=50):
    """Greedily extend ``prompt`` and return the decoded text.

    At each step the most likely next token is appended. Generation stops
    after ``max_new_tokens`` tokens or as soon as the tokenizer's EOS token is
    produced. Only a single sequence (batch size 1) is supported.

    Args:
        model: language model called as ``model(input_ids, attention_mask)``
            and returning logits of shape [1, L, vocab]. Its ``lm_head``
            weight determines the device. The model is switched to eval mode
            and left in eval mode.
        tokenizer: callable returning ``{'input_ids': tensor}`` for
            ``return_tensors='pt'``, with ``eos_token_id`` and ``decode``.
        prompt: text to continue.
        max_new_tokens: upper bound on the number of generated tokens.

    Returns:
        The decoded prompt plus continuation, with special tokens skipped.
    """
    model.eval()
    device = model.lm_head.weight.device
    input_ids = tokenizer(prompt, return_tensors='pt')['input_ids'].to(device)  # [1, L]

    # The model only has a fixed number of learned positions. Feeding a longer
    # sequence raises an index error in the position lookup, so the context is
    # limited to the most recent tokens. If the model does not expose such a
    # table, no cropping is applied.
    position_table = getattr(getattr(model, 'embed', None), 'position_embeds', None)
    max_context = getattr(position_table, 'num_embeddings', None)

    for _ in range(max_new_tokens):
        if max_context is None:
            context = input_ids
        else:
            context = input_ids[:, -max_context:]

        # No padding during generation: every context token is real.
        attention_mask = torch.ones_like(context)

        logits = model(context, attention_mask)  # [1, L_ctx, vocab]

        # Greedy choice from the distribution at the last position.
        next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)  # [1, 1]

        # The full history is kept for decoding even when the context is cropped.
        input_ids = torch.cat([input_ids, next_token], dim=-1)

        if next_token.item() == tokenizer.eos_token_id:
            break

    # Decode generated tokens to text
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
