In [19]:
import torch
import torch.nn as nn 
from torch.nn import functional as F
from dataclasses import dataclass

# Choose the compute device: the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

Using device: cuda


# What dataclass does

`dataclass` automatically generates boilerplate methods (like `__init__`, `__repr__`, and `__eq__`) for classes based on type-annotated fields. It also supports defaults, immutability via `frozen=True`, ordering via `order=True`, and factory defaults with `field(default_factory=...)`.

In [20]:
@dataclass
class gpt2config:
    """Hyperparameters shared by every module of the GPT-2 style model.

    Attributes:
        n_vocab: Number of token ids; sizes the token embedding and the output head.
        n_layer: Number of stacked transformer blocks.
        n_embed: Width of the residual stream (embedding dimension).
        n_context: Maximum sequence length; sizes the position embedding and
            the causal attention mask.
        n_head: Number of attention heads; must divide ``n_embed`` evenly.

    Raises:
        ValueError: If ``n_head`` is not positive or does not divide ``n_embed``.
    """
    # NOTE(review): this default does not match the 50257-token tokenizer used
    # later in the notebook, which passes n_vocab explicitly -- keep doing so.
    n_vocab: int = 100277
    n_layer: int = 12
    n_embed: int = 768
    n_context: int = 1024
    n_head: int = 12

    def __post_init__(self):
        # Attention reshapes the embedding into n_head equal slices; reject a
        # config that cannot be split here instead of failing inside `.view`.
        if self.n_head <= 0 or self.n_embed % self.n_head != 0:
            raise ValueError(
                f"n_embed ({self.n_embed}) must be divisible by a positive "
                f"n_head ({self.n_head})"
            )


# Model Architecture

In [21]:
class Block(nn.Module):
    """One transformer layer: attention and an MLP, each behind a residual add.

    LayerNorm is applied to the input of each sub-layer (pre-norm), and the
    sub-layer output is added back onto the un-normalised stream.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embed
        norm_kwargs = dict(eps=1e-5, elementwise_affine=True)

        self.ln1 = nn.LayerNorm(width, **norm_kwargs)
        self.attn = GPT2Attention(config)
        self.ln2 = nn.LayerNorm(width, **norm_kwargs)
        self.mlp = GPT2MLP(config)

    def forward(self, x):
        # Attention sub-layer with residual connection.
        attn_out = self.attn(self.ln1(x))
        x = x + attn_out
        # Feed-forward sub-layer with residual connection.
        mlp_out = self.mlp(self.ln2(x))
        return x + mlp_out


In [22]:
class GPT2Attention(nn.Module):
    """Causal multi-head self-attention.

    A single linear layer (``c_attn``) produces queries, keys and values, which
    are split into ``n_head`` heads. Attention scores are scaled by
    ``1/sqrt(head_dim)``, positions after the current one are masked out, and
    the merged heads are passed through ``c_proj``.

    The lower-triangular mask is stored as a buffer named ``"bias"`` of shape
    ``(1, 1, n_context, n_context)``; being a buffer, it is saved in the
    ``state_dict`` but is not a trainable parameter.
    """

    def __init__(self, config):
        super().__init__()
        self.c_attn = nn.Linear(config.n_embed, 3 * config.n_embed)
        self.c_proj = nn.Linear(config.n_embed, config.n_embed)
        self.n_head = config.n_head
        self.n_embed = config.n_embed

        # Ones on and below the diagonal mark the positions a token may attend to.
        ctx = config.n_context
        causal = torch.ones(ctx, ctx).tril()
        self.register_buffer("bias", causal.reshape(1, 1, ctx, ctx))

    def _to_heads(self, t, batch, seq_len):
        """Reshape (B, T, C) into per-head layout (B, n_head, T, C // n_head)."""
        head_dim = t.size(-1) // self.n_head
        return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

    def forward(self, x):
        batch, seq_len, width = x.size()

        # One projection, then cut the last dimension into q / k / v.
        q, k, v = self.c_attn(x).split(self.n_embed, dim=2)
        q = self._to_heads(q, batch, seq_len)
        k = self._to_heads(k, batch, seq_len)
        v = self._to_heads(v, batch, seq_len)

        # Scaled dot-product scores: (B, n_head, T, T).
        scale = 1.0 / (k.size(-1) ** 0.5)
        scores = (q @ k.transpose(-2, -1)) * scale

        # Future positions get -inf so softmax assigns them zero weight.
        future = self.bias[:, :, :seq_len, :seq_len] == 0
        weights = F.softmax(scores.masked_fill(future, float('-inf')), dim=-1)

        # Weighted sum of values, heads merged back to (B, T, C), then projected.
        merged = (weights @ v).transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.c_proj(merged)

In [37]:
class GPT2MLP(nn.Module):
    """Feed-forward sub-layer: expand to 4x width, tanh-approximated GELU, project back."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embed
        hidden = 4 * width
        self.c_fc = nn.Linear(width, hidden)
        self.act = nn.GELU(approximate="tanh")
        self.c_proj = nn.Linear(hidden, width)

    def forward(self, x):
        # Linear up-projection -> non-linearity -> linear down-projection.
        return self.c_proj(self.act(self.c_fc(x)))

In [24]:
class GPT2(nn.Module):
    """Decoder-only transformer language model.

    Sub-modules held in ``self.transformer``:
        wte:  token embedding, ``n_vocab -> n_embed``.
        wpe:  learned position embedding, ``n_context -> n_embed``.
        drop: dropout (p=0.1) applied to the summed embeddings.
        h:    ``n_layer`` stacked ``Block`` modules.
        ln_f: final LayerNorm.

    ``lm_head`` maps the final hidden states to vocabulary logits. It has its
    own weight matrix (it is not tied to ``wte``) and no bias.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.n_vocab, config.n_embed),
            wpe=nn.Embedding(config.n_context, config.n_embed),
            drop=nn.Dropout(0.1, inplace=False),
            h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
            ln_f=nn.LayerNorm(config.n_embed, eps=1e-5, elementwise_affine=True),
        ))

        self.lm_head = nn.Linear(config.n_embed, config.n_vocab, bias=False)

    def forward(self, input_ids, targets=None):
        """Run the model on a batch of token ids.

        Args:
            input_ids: Integer tensor of shape (B, T) holding token ids.
            targets: Optional integer tensor with one target id per input
                position; when given, the cross-entropy loss is computed.

        Returns:
            ``(logits, loss)`` where ``logits`` has shape (B, T, n_vocab) and
            ``loss`` is a scalar tensor, or ``None`` when ``targets`` is None.

        Raises:
            ValueError: If T exceeds ``config.n_context``.
        """
        B, T = input_ids.size()
        # The position table only has n_context rows. Fail with a clear message
        # rather than an out-of-range embedding lookup (a device-side assert on CUDA).
        if T > self.config.n_context:
            raise ValueError(
                f"sequence length {T} exceeds the model context size "
                f"{self.config.n_context}"
            )
        device = input_ids.device

        pos = torch.arange(0, T, dtype=torch.long, device=device).unsqueeze(0)  # (1,T)
        x = self.transformer.wte(input_ids) + self.transformer.wpe(pos)  # (B,T,C)
        x = self.transformer.drop(x)

        for block in self.transformer.h:
            x = block(x)

        x = self.transformer.ln_f(x)  # (B,T,C)
        logits = self.lm_head(x)  # (B,T,vocab_size)

        loss = None
        if targets is not None:
            # Flatten batch and time so each position is one classification example.
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss

# Test Untrained Model

In [25]:
import tiktoken

# 1. Load tiktoken's "r50k_base" encoding (the 50257-token vocabulary printed
#    below). NOTE: this is not the GPT-4o encoding.
tokenizer = tiktoken.get_encoding("r50k_base")
print("vocab:",tokenizer.n_vocab)
# 2. Convert text to tokens
text = "Hello, tiktoken is fast!"
tokens = tokenizer.encode(text)
print(f"Token IDs: {tokens}")
print(f"Token Count: {len(tokens)}")

# 3. Convert back to original text
decoded_text = tokenizer.decode(tokens)
print(f"Decoded: {decoded_text}")


# 4. Build the model config with the vocabulary size taken from the tokenizer,
#    overriding the dataclass default for n_vocab.
config = gpt2config(n_vocab=tokenizer.n_vocab)
print(config)

vocab: 50257
Token IDs: [15496, 11, 256, 1134, 30001, 318, 3049, 0]
Token Count: 8
Decoded: Hello, tiktoken is fast!
gpt2config(n_vocab=50257, n_layer=12, n_embed=768, n_context=1024, n_head=12)


In [26]:
# Set the float32 matmul precision mode to 'high' before the model is used.
torch.set_float32_matmul_precision('high')

# Build a randomly initialised (untrained) model and move it to the chosen device.
model = GPT2(config)
model = model.to(device)

# Report the total number of parameters, in millions.
param_count = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {param_count/1e6:.2f}M")

Model parameters: 163.04M


In [27]:
# Test generation with untrained model: greedy decoding (always pick the
# highest-scoring next token) until the sequence reaches max_sequence_length.
max_sequence_length = 100
input_prompt = "What are your opinions regarding the political scenario?"
input_ids = torch.tensor([tokenizer.encode(input_prompt)]).to(device)
print(input_ids.size())
prompt_len = input_ids.size(1)

model.eval()
with torch.no_grad():
    while input_ids.size(1) < max_sequence_length:
        logits, _ = model(input_ids)
        # Only the logits at the last position predict the next token.
        next_token_logits = logits[:, -1, :]
        next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
        input_ids = torch.cat([input_ids, next_token_id], dim=-1)

# Decode only the newly generated tokens. `.tolist()` hands the tokenizer plain
# Python ints in a single device-to-host copy, rather than a list of 0-dim tensors.
generated_text = tokenizer.decode(input_ids[0, prompt_len:].tolist())
print("Untrained model output:")
print(generated_text)

torch.Size([1, 9])
Untrained model output:
een amendment pain discover Timber Hooiagc sidebarFL Klein� demographics click custod strokegencies mixedAnswersand 2100 whispersae Op rigged aristocracy tastingopedDur Plainmill°rew Navy Bullets Mig vapororderaucbage casual attributes Airport Lower Boeing Miller outsetichael CorpseBook sandy audiences Overt ADDizont KoranBah convertingwi brist Playedussy rat att bulbsbroad Aveadic LevinRobermint vanquishedrities clearer sharing FlareQualityesy warアitching encodeVA Winter........ Pair Phill whalesopsy citestical
Untrained model output:
een amendment pain discover Timber Hooiagc sidebarFL Klein� demographics click custod strokegencies mixedAnswersand 2100 whispersae Op rigged aristocracy tastingopedDur Plainmill°rew Navy Bullets Mig vapororderaucbage casual attributes Airport Lower Boeing Miller outsetichael CorpseBook sandy audiences Overt ADDizont KoranBah convertingwi brist Playedussy rat att bulbsbroad Aveadic LevinRobermint vanquishedritie

# Load and Prepare Training Data

In [28]:
# Load the training corpus: the whole of 'ROCStories_train.txt', read as UTF-8
# into a single string.
with open('ROCStories_train.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Quick sanity check: corpus size in characters and a peek at the beginning.
print(f"Dataset length: {len(text)} characters")
print(f"First 100 characters:\n{text[0:100]}")

Dataset length: 18007898 characters
First 100 characters:
The boy went to a video arcade. He played his favorite machine. His games didn't go very well. He to


In [29]:
# string = "this."
# print(tokenizer.encode(string))

In [30]:
from datasets import load_dataset

# ds = load_dataset("mintujupally/ROCStories")
def get_batch2(data,batch_size,block_size=30):
    """Sample a batch of (input, target) token windows from individual stories.

    Each sampled story is tokenised once; the input is its first ``block_size``
    tokens and the target is the same window shifted right by one token.

    Args:
        data: Unused; kept so existing call sites keep working.
        batch_size: Number of stories to sample.
        block_size: Number of tokens taken from the start of each story.

    Returns:
        ``(x, y)`` tensors on ``device``, one row per sampled story.

    Notes:
        * Reads the notebook-level names ``ds``, ``tokenizer`` and ``device``.
          ``ds`` is only created by the ``load_dataset`` call, which is
          commented out in this notebook; calling this function without
          defining it raises ``NameError``.
        * ``torch.stack`` needs rows of equal length, so every sampled story
          must have at least ``block_size + 1`` tokens.
    """
    # Hard-coded upper bound on the story index -- presumably the number of
    # training stories; TODO confirm against len(ds['train']).
    ix = torch.randint(78528, (batch_size,))

    # Look the text column up once and tokenise each sampled story once,
    # instead of repeating both for x and again for y.
    stories = ds['train']['text']
    encoded = [tokenizer.encode(stories[int(i)]) for i in ix]

    x = torch.stack([torch.tensor(toks[:block_size]) for toks in encoded])
    y = torch.stack([torch.tensor(toks[1:block_size+1]) for toks in encoded])

    x,y = x.to(device), y.to(device)
    return x,y

# tokenizer.encode(ds['train']['text'][0])

In [31]:
# Tokenise the whole corpus into one long sequence of token ids.
data = tokenizer.encode(text)
print(f"Encoded length: {len(data)} tokens")

# Hold out the last 10% of the tokens for validation; train on the first 90%.
n = len(data)
split_idx = int(n*0.9)
train_data, val_data = data[:split_idx], data[split_idx:]

print(f"Train tokens: {len(train_data)}, Val tokens: {len(val_data)}")

Encoded length: 4111142 tokens
Train tokens: 3700027, Val tokens: 411115


In [32]:
# Data loader function
def get_batch(split, batch_size=8, block_size=256):
    """Draw a random batch of next-token-prediction examples.

    Args:
        split: ``'train'`` selects ``train_data``; any other value selects ``val_data``.
        batch_size: Number of windows in the batch.
        block_size: Number of tokens per window.

    Returns:
        ``(x, y)`` on ``device``, each of shape (batch_size, block_size);
        ``y`` is ``x`` shifted one token to the right in the source sequence.
    """
    if split == 'train':
        source = train_data
    else:
        source = val_data

    # Random window starts, chosen so that the shifted target window still fits.
    starts = torch.randint(len(source) - block_size, (batch_size,))

    inputs, labels = [], []
    for start in starts:
        inputs.append(torch.tensor(source[start:start+block_size]))
        labels.append(torch.tensor(source[start+1:start+block_size+1]))

    x = torch.stack(inputs).to(device)
    y = torch.stack(labels).to(device)
    return x, y

# Smoke test: draw one training batch with the default batch and block sizes,
# then show the shapes of inputs/targets and the input token ids.
xb, yb = get_batch('train')
print(f"Batch shape: {xb.shape}, {yb.shape}")
print(xb)

Batch shape: torch.Size([8, 256]), torch.Size([8, 256])
tensor([[   13,   198,    40,  ...,  3013,  6021,   319],
        [   13,   383,  2324,  ...,  5223,   416,  1021],
        [  257,   649,  1097,  ...,   679, 11687, 15334],
        ...,
        [37259,  1110,    13,  ...,  1816,   284,   257],
        [ 5373,    13,   314,  ...,   861,    81,  2507],
        [  286,   257,  2156,  ...,     0,   198,  3198]], device='cuda:0')


# Training Loop

In [33]:
# Training configuration
max_iters = 1000       # total number of optimizer steps
eval_interval = 10     # run estimate_loss() every this many steps
learning_rate = 3e-4
eval_iters = 20        # batches averaged per split in each evaluation
batch_size = 8         # sequences per batch


# Initialize optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# Compile the model. If this cell is re-run, `model` is already a compiled
# wrapper; unwrapping `_orig_mod` first keeps the cell idempotent instead of
# compiling a wrapper around a wrapper.
model = torch.compile(getattr(model, '_orig_mod', model), mode='reduce-overhead')  # or 'max-autotune' for more optimization

@torch.no_grad()
def estimate_loss():
    """Estimate the mean loss on the train and validation splits.

    For each split, ``eval_iters`` random batches of ``batch_size`` sequences
    are drawn with ``get_batch`` and their losses averaged. Gradients are
    disabled and the model is put in eval mode for the duration.

    Returns:
        dict with keys ``'train'`` and ``'val'``, each a 0-dim tensor holding
        the mean loss for that split.
    """
    out = {}
    # Remember the mode we were called in so it can be restored afterwards,
    # even if evaluation raises part-way through.
    was_training = model.training
    model.eval()
    try:
        for split in ['train', 'val']:
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(split, batch_size=batch_size)
                _, loss = model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        model.train(was_training)
    return out

model.train()

# Training loop
import time
t0 = time.time()
# The counter is called `step` rather than `iter` so the builtin `iter()` is
# not shadowed for the rest of the notebook session.
for step in range(max_iters):
    if step % eval_interval == 0:
        losses = estimate_loss()
        t1 = time.time()
        # Wall-clock time since the previous report; includes the evaluation itself.
        dt = t1 - t0
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, time {dt*1000:.2f}ms")
        t0 = t1

    xb, yb = get_batch('train', batch_size=batch_size)
    logits, loss = model(xb, yb)

    # Clear old gradients, backpropagate, then update the weights.
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print("Training complete!")

step 0: train loss 10.9882, val loss 10.9949, time 4319.25ms
step 10: train loss 7.3978, val loss 7.4035, time 8913.32ms
step 10: train loss 7.3978, val loss 7.4035, time 8913.32ms
step 20: train loss 6.5070, val loss 6.4981, time 4187.43ms
step 20: train loss 6.5070, val loss 6.4981, time 4187.43ms
step 30: train loss 6.3480, val loss 6.3254, time 4198.83ms
step 30: train loss 6.3480, val loss 6.3254, time 4198.83ms
step 40: train loss 6.0798, val loss 6.1298, time 4221.24ms
step 40: train loss 6.0798, val loss 6.1298, time 4221.24ms
step 50: train loss 5.9155, val loss 5.9179, time 4233.06ms
step 50: train loss 5.9155, val loss 5.9179, time 4233.06ms
step 60: train loss 5.8306, val loss 5.8234, time 4247.18ms
step 60: train loss 5.8306, val loss 5.8234, time 4247.18ms
step 70: train loss 5.7259, val loss 5.7660, time 4243.15ms
step 70: train loss 5.7259, val loss 5.7660, time 4243.15ms
step 80: train loss 5.6428, val loss 5.6868, time 4236.56ms
step 80: train loss 5.6428, val loss 5.

# Test Trained Model

In [36]:
# Sample a continuation from the trained model using top-k sampling.
max_sequence_length = 200
input_prompt = "I am"
input_ids = torch.tensor(tokenizer.encode(input_prompt)).unsqueeze(0).to(device)
prompt_len = input_ids.size(1)

# If `model` is a torch.compile wrapper, generate with the wrapped original module.
model_inference = getattr(model, '_orig_mod', model)

# Use top-k sampling for better generation: only the 50 most likely tokens are candidates.
top_k = 50

model_inference.eval()
with torch.no_grad():
    while input_ids.size(1) < max_sequence_length:
        logits, _ = model_inference(input_ids)
        next_token_logits = logits[:, -1, :]

        # Keep the k most likely tokens and renormalise their probabilities.
        probs = F.softmax(next_token_logits, dim=-1)
        top_probs, top_indices = torch.topk(probs, top_k, dim=-1)
        top_probs = top_probs / top_probs.sum(dim=-1, keepdim=True)

        # Sample a position within the top-k list, then map it back to a token id.
        sampled_idx = torch.multinomial(top_probs, num_samples=1)
        next_token_id = top_indices.gather(-1, sampled_idx)
        input_ids = torch.cat([input_ids, next_token_id], dim=-1)

generated_text = tokenizer.decode(input_ids[0].tolist())
print("Trained model output:")
print(generated_text)

Trained model output:
I am going to the prom. They said yes was happy to learn of his food. But her family's friend was unable to pick into a cake.
Joe was watching TV as he took his job as he saw it. Sam thought he had to give it up to go out of the neighborhood. Tom saw that he was so excited. John tried to do his homework to eat his phone and go back home while he got out that they loved his car. Once they saw that the test they could not drive to the fire.
Rufus had a baby. She went over to the store. On a car to find a huge car! She went shopping. When she was very bad,, they had a great time. So that she got ready for days to get a pair. She looked over and left the fire!
Gina had a trip to her mom. Her mom's mother took a date. She pulled over her mother. When she looked, her boyfriend didn't
