In [None]:
import torch
import torch.nn as nn 
from torch.nn import functional as F
from dataclasses import dataclass

# Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Model Architecture

In [None]:
@dataclass
class gpt2config:
    n_vocab: int = 100277
    n_layer: int = 12
    n_embed: int = 768
    n_context: int = 1024
    n_head: int = 12


In [None]:
class GPT2MLP(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embed, 4*config.n_embed)
        self.act = nn.GELU(approximate="tanh")
        self.c_proj = nn.Linear(4*config.n_embed, config.n_embed)

    def forward(self,x):
        x = self.c_fc(x)
        x = self.act(x)
        x = self.c_proj(x)
        return x

In [None]:
class GPT2Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_attn = nn.Linear(config.n_embed, 3 * config.n_embed)
        self.c_proj = nn.Linear(config.n_embed, config.n_embed)
        self.n_head = config.n_head
        self.n_embed = config.n_embed
        
        # Create a causal mask (lower triangular matrix) and register it as a buffer
        # A buffer is not a parameter, but is saved with the model state_dict
        self.register_buffer("bias", torch.tril(torch.ones(config.n_context, config.n_context))
                                     .view(1, 1, config.n_context, config.n_context))

    def forward(self, x):
        B, T, C = x.size()
        
        # Calculate query, key, values for all heads in batch
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embed, dim=2)
        
        # Reshape for multi-head attention: (B, nh, T, hs)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        # Scaled dot-product attention
        att = (q @ k.transpose(-2, -1)) * (1.0 / (k.size(-1) ** 0.5))
        
        # --- MASKING STARTS HERE ---
        # Apply the causal mask: fill "future" positions with -infinity
        # This makes their softmax probability zero.
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
        # --- MASKING ENDS HERE ---

        att = F.softmax(att, dim=-1)
        y = att @ v # (B, nh, T, hs)
        
        # Re-assemble all head outputs side-by-side
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        
        # Output projection
        y = self.c_proj(y)
        return y

In [None]:
class Block(nn.Module):
    def __init__(self,config):
        super().__init__()
        
        self.ln1 = nn.LayerNorm(config.n_embed,eps=1e-5,elementwise_affine=True)
        self.attn = GPT2Attention(config)
        self.ln2 = nn.LayerNorm(config.n_embed,eps=1e-5,elementwise_affine=True)
        self.mlp = GPT2MLP(config)

    def forward(self,x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


In [None]:
class GPT2(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.n_vocab,config.n_embed),
            wpe = nn.Embedding(config.n_context,config.n_embed),
            drop = nn.Dropout(0.1,inplace=False),
            h = nn.ModuleList(Block(config) for _ in range(config.n_layer)),
            ln_f = nn.LayerNorm(config.n_embed,eps=1e-5,elementwise_affine=True)
        ))
        
        self.lm_head = nn.Linear(config.n_embed, config.n_vocab, bias=False)

    def forward(self,input_ids, targets=None):
        B,T = input_ids.size()
        device = input_ids.device

        pos = torch.arange(0,T,dtype=torch.long,device=device).unsqueeze(0)  # (1,T)
        x = self.transformer.wte(input_ids) + self.transformer.wpe(pos)  # (B,T,C)
        x = self.transformer.drop(x)

        for block in self.transformer.h:
            x = block(x)

        x = self.transformer.ln_f(x)  # (B,T,C)
        logits = self.lm_head(x)  # (B,T,vocab_size)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss

# Test Untrained Model

In [1]:
import tiktoken


In [2]:
class MyTokenizer:
    def __init__(self, max_len):
        tokenizer = tiktoken.get_encoding("r50k_base")
        self.special_tokens = {
            "<pad>": tokenizer.n_vocab,
            "<bos>": tokenizer.n_vocab + 1,
            "<eos>": tokenizer.n_vocab + 2,
        }
        print(tokenizer.n_vocab)
        self.tokenizer = tiktoken.Encoding(
            name="r50k_base_ext",
            pat_str=tokenizer._pat_str,
            mergeable_ranks=tokenizer._mergeable_ranks,
            special_tokens=self.special_tokens,
        )
        self.n_vocab = self.tokenizer.n_vocab 
        self.max_len = max_len

    @staticmethod
    def clean_text(text):
        for tok in ("<pad>", "<bos>", "<eos>"):
            text = text.replace(tok, "")
        return text
    
    def encode(self, text,max_len=None):
        if max_len is None:
            max_len = self.max_len
        # text = self.clean_text(text)
        ids = self.tokenizer.encode(text, allowed_special=set())
        ids = [self.special_tokens["<bos>"]] + ids + [self.special_tokens["<eos>"]]

        if len(ids) > max_len:
            ids = ids[:max_len]
            ids[-1] = self.special_tokens["<eos>"]
        else:
            ids += [self.special_tokens["<pad>"]] * (max_len - len(ids))
        return ids  

    def decode(self, ids):
        
        return self.tokenizer.decode(ids)


In [3]:
tokenizer = MyTokenizer(max_len=13)
tokenizer.decode(tokenizer.encode("Hello, tiktoken is fast!"))

50257


'<bos>Hello, tiktoken is fast!<eos><pad><pad><pad>'

In [None]:
# from transformers import GPT2Tokenizer
config = gpt2config()
# tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
torch.set_float32_matmul_precision('high')
# Initialize untrained model
model = GPT2(config).to(device)
print(f"Model parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")

In [None]:
# Test generation with untrained model
max_sequence_length = 100
input_prompt = "What are your opinions regarding the political scenario?"
input_ids = torch.tensor([tokenizer.encode(input_prompt)]).to(device)
print(input_ids.size())
prompt_len = input_ids.size(1)

model.eval()
with torch.no_grad():
    while input_ids.size(1) < max_sequence_length:
        logits, _ = model(input_ids)
        next_token_logits = logits[:, -1, :]
        next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
        input_ids = torch.cat([input_ids, next_token_id], dim=-1)

generated_text = tokenizer.decode(list(input_ids[0, prompt_len:]))
print("Untrained model output:")
print(generated_text)

# Load and Prepare Training Data

In [None]:
# Load tiny shakespeare dataset
with open('ROCStories_train.txt', 'r', encoding='utf-8') as f:
    text = f.read()

print(f"Dataset length: {len(text)} characters")
print(f"First 100 characters:\n{text[0:100]}")

In [None]:
# string = "this."
# print(tokenizer.encode(string))

In [None]:
# Encode the entire dataset
data = tokenizer.encode(text)
print(f"Encoded length: {len(data)} tokens")

# Split into train and validation
n = len(data)
train_data = data[:int(n*0.9)]
val_data = data[int(n*0.9):]

print(f"Train tokens: {len(train_data)}, Val tokens: {len(val_data)}")

In [None]:
# Data loader function
def get_batch(split, batch_size=8, block_size=256):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.tensor(data[i:i+block_size]) for i in ix])
    y = torch.stack([torch.tensor(data[i+1:i+block_size+1]) for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

# Test batch
xb, yb = get_batch('train')
print(f"Batch shape: {xb.shape}, {yb.shape}")
print(xb)

# Training Loop

In [None]:
# Training configuration
max_iters = 1000
eval_interval = 10  # Evaluate less frequently
learning_rate = 3e-4
eval_iters = 20  # Much fewer eval iterations (was 200!)
batch_size = 8  # Larger batch for better GPU utilization


# Initialize optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
model = torch.compile(model, mode='reduce-overhead')  # or 'max-autotune' for more optimization

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split, batch_size=batch_size)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

model.train()

# Training loop
import time
t0 = time.time()
for iter in range(max_iters):
    if iter % eval_interval == 0:
        losses = estimate_loss()
        t1 = time.time()
        dt = t1 - t0
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, time {dt*1000:.2f}ms")
        t0 = t1
    
    xb, yb = get_batch('train', batch_size=batch_size)
    logits, loss = model(xb, yb)
    
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print("Training complete!")

# Test Trained Model

In [None]:
# Test generation with trained model
max_sequence_length = 200
input_prompt = "I am"
input_ids = torch.tensor(tokenizer.encode(input_prompt)).unsqueeze(0).to(device)
prompt_len = input_ids.size(1)

model_inference = model._orig_mod if hasattr(model, '_orig_mod') else model

model_inference.eval()
with torch.no_grad():
    while input_ids.size(1) < max_sequence_length:
        logits, _ = model_inference(input_ids)
        next_token_logits = logits[:, -1, :]
        # Use top-k sampling for better generation
        top_k = 50

        probs = F.softmax(next_token_logits, dim=-1)
        top_probs, top_indices = torch.topk(probs, top_k, dim=-1)
        top_probs = top_probs / top_probs.sum(dim=-1, keepdim=True)
        sampled_idx = torch.multinomial(top_probs, num_samples=1)
        next_token_id = torch.gather(top_indices, -1, sampled_idx)
        input_ids = torch.cat([input_ids, next_token_id], dim=-1)

generated_text = tokenizer.decode(input_ids[0].tolist())
print("Trained model output:")
print(generated_text)

In [None]:
from datasets import load_dataset

# Load the full E2E dataset
# dataset = load_dataset("kibru/e2e")
dataset = load_dataset("Salesforce/wikitext")

# Access specific splits
train_data = dataset["train"]
val_data = dataset["validation"]
test_data = dataset["test"]

# Example: Access the first entry
print(train_data[0]) 
# Returns: {'meaning_representation': '...', 'human_reference': '...'}

In [None]:
dataset['train']['completion'][0]