In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset,DataLoader
from torch.optim import AdamW
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import os
import math

In [2]:
# Load TinyStories
from datasets import load_dataset
# Fetch TinyStories from the Hugging Face Hub and keep a handle on both splits.
dataset = load_dataset("roneneldan/TinyStories")
train_data, val_data = dataset["train"], dataset["validation"]

In [None]:
from tokenizers import Tokenizer
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import BpeTrainer
from tokenizers.models import BPE
from tokenizers.processors import TemplateProcessing
def get_training_corpus(data=None, batch_size=1000):
    """Yield the training texts in batches for tokenizer training.

    Args:
        data: Sliceable collection whose slices expose a ``"text"`` column
            (``data[a:b]["text"]``). Defaults to the notebook-level
            ``train_data`` split, resolved when iteration starts.
        batch_size: Number of rows per yielded batch.

    Yields:
        The ``"text"`` column of each consecutive slice of ``data``.
    """
    if data is None:
        data = train_data
    # Iterate over the rows of the split itself. `len(dataset)` is the length
    # of the DatasetDict (its number of splits), which stopped the loop after
    # the first batch.
    for start in range(0, len(data), batch_size):
        yield data[start : start + batch_size]["text"]

generator = get_training_corpus()
# Tell the BPE model which token stands in for unknown symbols. The "[UNK]"
# entry reserved in `special_tokens` below is only used if the model is
# configured with it.
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()
# Special tokens are listed first so the trainer reserves vocabulary entries for them.
special_tokens = ["[PAD]", "[UNK]", "[BOS]", "[EOS]", "[MASK]"]
trainer = BpeTrainer(
    vocab_size = 30000,
    show_progress = True,
    special_tokens = special_tokens,
)
tokenizer.train_from_iterator(generator,trainer)

# Add BOS/EOS tokens for GPT training: every encoded text is wrapped as
# "[BOS] <text> [EOS]" by the post-processor.
tokenizer.post_processor = TemplateProcessing(
    single="[BOS] $A [EOS]",
    special_tokens=[("[BOS]", tokenizer.token_to_id("[BOS]")),
                   ("[EOS]", tokenizer.token_to_id("[EOS]"))]
)
tokenizer.save("tinystories_tokenizer1.json")

In [2]:
from tokenizers import Tokenizer
from transformers import AutoTokenizer,PreTrainedTokenizerFast
# Load the trained tokenizer file, then wrap it so it can be called with
# padding / truncation / return_tensors arguments.
raw_tokenizer = Tokenizer.from_file("/kaggle/input/tokenizer/tinystories_tokenizer1.json")
# Declare the special tokens by name instead of by hard-coded ids 0..3, so the
# ids are looked up in the vocabulary rather than depending on the order of the
# `special_tokens` list in the training cell.
tokenizer = PreTrainedTokenizerFast(
    tokenizer_object = raw_tokenizer,
    pad_token = "[PAD]",
    unk_token = "[UNK]",
    bos_token = "[BOS]",
    eos_token = "[EOS]",
)

In [4]:
import pprint
# Print the train split's summary (the output below lists its features and row count).
pprint.pprint(train_data)

Dataset({
    features: ['text'],
    num_rows: 2119719
})


In [5]:
class TinyStoryDataset(Dataset):
    """Map-style dataset that tokenizes one story per item, on demand.

    Args:
        data: Indexable collection of raw story strings.
        tokenizer: Callable tokenizer to use. When omitted, the notebook-level
            ``tokenizer`` is looked up at item-access time.
        max_length: Length every item is padded / truncated to. When omitted,
            ``config.max_seq_len`` is looked up at item-access time.
    """

    def __init__(self,data,tokenizer = None,max_length = None):
        super().__init__()
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of stories."""
        return len(self.data)

    def __getitem__(self,idx):
        """Tokenize story ``idx`` and return its ids and attention mask.

        Returns:
            dict with ``"input_ids"`` and ``"attention_mask"``, each with the
            leading batch dimension added by ``return_tensors="pt"`` removed.
        """
        # Fall back to the notebook globals lazily so the class can be defined
        # before `config` exists, exactly as before.
        tok = self.tokenizer if self.tokenizer is not None else tokenizer
        max_length = self.max_length if self.max_length is not None else config.max_seq_len
        encoded = tok(self.data[idx],
                    padding = "max_length",
                    truncation = True,
                    max_length = max_length,
                    return_tensors = "pt",
                    return_attention_mask = True
                   )
        return {
            "input_ids" : encoded["input_ids"].squeeze(0),
            "attention_mask" : encoded["attention_mask"].squeeze(0)
        }

# Wrap the raw text columns under new names. Rebinding `train_data` / `val_data`
# made this cell fail when re-run and broke earlier cells that slice the raw split.
train_dataset = TinyStoryDataset(train_data["text"])
val_dataset = TinyStoryDataset(val_data["text"])

# Only the training loader is shuffled; validation keeps a fixed order.
train_loader = DataLoader(train_dataset,batch_size = 64,shuffle = True,pin_memory = True)
val_loader = DataLoader(val_dataset,batch_size = 64,shuffle = False,pin_memory = True)

# Number of batches per epoch for each loader.
print(len(train_loader))
print(len(val_loader))

33121
344


In [3]:
class FeedForwardNet(nn.Module):
    """Two-layer feed-forward block: expand to 4x width, ReLU, dropout, project back.

    Args:
        d_model: Size of the input and output feature dimension.
        dropout: Dropout probability applied after the activation.
    """

    def __init__(self, d_model, dropout=0.2):
        super().__init__()
        hidden_dim = 4 * d_model
        self.d_model = d_model
        self.w_1 = nn.Linear(d_model, hidden_dim)
        self.w_2 = nn.Linear(hidden_dim, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the block along the last dimension of ``x``."""
        activated = F.relu(self.w_1(x))
        return self.w_2(self.dropout(activated))

In [4]:
class PE_Vec(nn.Module):
    """Fixed sinusoidal positional encoding added to a batch of embeddings.

    The table is stored in the ``pe`` buffer with shape (1, max_len, d_model).

    Args:
        d_model: Embedding dimension.
        max_len: Number of positions in the precomputed table.
    """

    def __init__(self,d_model,max_len = 512):
        super(PE_Vec,self).__init__()
        pos = torch.arange(0,max_len,dtype = torch.float).unsqueeze(1)
        pe = torch.zeros(max_len,d_model)
        div_term = torch.exp(torch.arange(0,d_model,2) * - math.log(10000)/d_model)
        pe[:,0::2] = torch.sin(pos * div_term)
        # For odd d_model there is one fewer cosine column than sine column, so
        # trim the frequencies; for even d_model the slice is a no-op.
        pe[:,1::2] = torch.cos(pos * div_term[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self,x):
        """Add encodings for positions ``0 .. seq_len - 1`` to ``x``."""
        return self.add_with_offset(x,0)

    def add_with_offset(self,x,offset):
        """Add encodings for positions ``offset .. offset + seq_len - 1``.

        Args:
            x: Tensor of shape (batch, seq_len, d_model).
            offset: Position of the first token of ``x`` in the full sequence
                (for example the number of tokens already held in a KV cache).

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            IndexError: If the requested positions fall outside the table.
        """
        _,seq_len,_ = x.shape
        end = offset + seq_len
        if offset < 0 or end > self.pe.size(1):
            raise IndexError(
                f"positions [{offset}, {end}) fall outside the positional table of length {self.pe.size(1)}"
            )
        # Slice rather than index with a (1, seq_len) tensor: that advanced
        # indexing produced a (1, 1, seq_len, d_model) result which broadcast
        # the sum to 4-D (1, batch, seq_len, d_model).
        pos_enc = self.pe[:,offset:end,:].to(x.device)
        return x + pos_enc

In [None]:
class MultiHeadAttentionVec(nn.Module):
    """Multi-head self-attention with optional causal masking and a KV cache.

    Args:
        d_model: Model (embedding) dimension; must be divisible by ``heads``.
        heads: Number of attention heads.
        mask: When True, a causal mask is applied so a query cannot attend to
            later positions.
    """

    def __init__(self,d_model,heads = 8,mask = True):
        super(MultiHeadAttentionVec,self).__init__()
        assert d_model % heads == 0, "d_model must be divisible by heads"
        self.d_model = d_model
        self.heads = heads
        self.d_k = self.d_model // self.heads
        self.mask = mask

        self.w_q = nn.Linear(d_model,d_model,bias = False)
        self.w_k = nn.Linear(d_model,d_model,bias = False)
        self.w_v = nn.Linear(d_model,d_model,bias = False)

        self.w_o= nn.Linear(d_model,d_model,bias = False)

    def forward(self,x,att_mask,past_key_value = None,use_cache = False):
        """Attend over ``x`` (plus any cached keys/values).

        Args:
            x: Input of shape (batch, seq_len, d_model). A 4-D input with a
                leading dimension of size 1 is also accepted and squeezed.
            att_mask: Padding mask of shape (batch, kv_len), where kv_len is
                the cached length plus ``seq_len``; entries equal to 0 are
                masked out as keys.
            past_key_value: Optional ``(K, V)`` pair from a previous call, each
                of shape (batch, heads, past_len, d_k).
            use_cache: When True, the concatenated ``(K, V)`` pair is returned.

        Returns:
            ``(out, present_key_value)``: ``out`` has shape
            (batch, seq_len, d_model); ``present_key_value`` is ``(K, V)`` when
            ``use_cache`` is True, otherwise ``None``.
        """
        # Tolerate inputs that carry an extra leading singleton dimension.
        if len(x.shape) == 4 and x.size(0) == 1:
            x = x.squeeze(0)
        batch_size,seq_len,_ = x.shape

        Q = self.w_q(x).view(batch_size,seq_len,self.heads,self.d_k).transpose(1,2)
        K = self.w_k(x).view(batch_size,seq_len,self.heads,self.d_k).transpose(1,2)
        V = self.w_v(x).view(batch_size,seq_len,self.heads,self.d_k).transpose(1,2)
        if past_key_value is not None:
            past_k,past_v = past_key_value
            K = torch.cat([past_k,K],dim = 2)
            V = torch.cat([past_v,V],dim = 2)
        present_key_value = (K, V) if use_cache else None
        kv_len = K.size(2)

        scores = torch.matmul(Q,K.transpose(-2,-1)) / math.sqrt(self.d_k) # (B, H, seq_len, kv_len)
        pad_mask = att_mask.unsqueeze(1).unsqueeze(1)  # (B, 1, 1, kv_len)
        # -65504 is the most negative finite float16 value, so the fill stays
        # finite under half-precision autocast.
        scores = scores.masked_fill(pad_mask == 0, float('-65504.0'))
        if self.mask:
            # Query i sits at absolute position (kv_len - seq_len + i), so the
            # lower triangle is shifted by the cached length. Without a cache
            # this is the plain square causal mask. Sizing the mask from the
            # query length instead failed whenever a cached call carried more
            # than one new token.
            causal_mask = torch.tril(
                torch.ones(seq_len,kv_len,dtype=torch.bool,device = x.device),
                diagonal = kv_len - seq_len,
            )
            scores = scores.masked_fill(~causal_mask,float('-65504.0'))
        att_scores = F.softmax(scores,dim = -1)
        final = torch.matmul(att_scores,V).transpose(1,2).contiguous().view(batch_size,seq_len,self.heads * self.d_k)
        out = self.w_o(final)
        return out,present_key_value

In [6]:
class TransformerDecoderBLK(nn.Module):
    """Pre-norm decoder block: causal self-attention then feed-forward, each with a residual.

    Args:
        d_model: Model dimension.
        heads: Number of attention heads.
        use_cache: Stored on the instance; ``forward`` takes its own
            ``use_cache`` argument and does not read this attribute.
    """

    def __init__(self, d_model, heads, use_cache=False):
        super().__init__()
        self.d_model = d_model
        self.heads = heads
        self.use_cache = use_cache
        self.attention = MultiHeadAttentionVec(d_model, heads, mask=True)
        self.ffn = FeedForwardNet(d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, att_mask, prev_key_value=None, use_cache=False):
        """Run the block and return ``(hidden, present_key_value)``.

        ``present_key_value`` is whatever the attention layer returns for the
        given ``use_cache`` flag.
        """
        # Attention sub-layer: normalise first, then add the result back onto x.
        normed = self.norm1(x)
        attn_out, present_key_value = self.attention(
            normed, att_mask, prev_key_value, use_cache
        )
        after_attention = x + attn_out

        # Feed-forward sub-layer, same pre-norm + residual pattern.
        ffn_out = self.ffn(self.norm2(after_attention))
        return after_attention + ffn_out, present_key_value
        

In [None]:
class GPT(nn.Module):
    """Decoder-only language model: token embeddings, positional encoding,
    a stack of decoder blocks and a linear projection to vocabulary logits.

    Args:
        d_model: Embedding / hidden dimension.
        heads: Attention heads per block.
        depth: Number of decoder blocks.
        max_len: Number of positions in the positional-encoding table.
        vocab_size: Size of the token vocabulary.
    """

    def __init__(self, d_model, heads, depth, max_len, vocab_size):
        super().__init__()
        self.d_model = d_model
        self.heads = heads
        self.max_len = max_len
        self.embeddings = nn.Embedding(vocab_size, d_model)

        blocks = [TransformerDecoderBLK(d_model, heads) for _ in range(depth)]
        self.layers = nn.ModuleList(blocks)
        self.pos_encoding = PE_Vec(d_model, max_len)
        self.out = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x, att_mask, past_key_values=None, use_cache=False):
        """Compute logits for token ids ``x``.

        Args:
            x: Token ids of shape (batch, seq_len).
            att_mask: Padding mask forwarded unchanged to every block.
            past_key_values: Optional per-layer sequence of cached ``(K, V)``
                pairs; when given, positions are offset by the cached length.
            use_cache: When True, the per-layer ``(K, V)`` pairs are returned.

        Returns:
            ``(logits, present_key_values)``; the second element is a list with
            one entry per layer when ``use_cache`` is True, otherwise ``None``.
        """
        batch_size, seq_len = x.shape  # also rejects inputs that are not 2-D
        hidden = self.dropout(self.embeddings(x))

        if past_key_values is None:
            hidden = self.pos_encoding(hidden)
        else:
            # The cached key tensor's third dimension is the number of tokens
            # already processed.
            cached_len = past_key_values[0][0].size(2)
            hidden = self.pos_encoding.add_with_offset(hidden, cached_len)

        present_key_values = [] if use_cache else None

        for layer_idx, layer in enumerate(self.layers):
            layer_past = None if past_key_values is None else past_key_values[layer_idx]
            hidden, layer_present = layer(hidden, att_mask, layer_past, use_cache)
            if use_cache:
                present_key_values.append(layer_present)

        return self.out(hidden), present_key_values
        

In [8]:
class GPTConfig:
    """Hyper-parameters for the model and the training run, gathered in one place.

    The vocabulary size is read from the notebook-level ``tokenizer`` when the
    config object is created.
    """

    def __init__(self):
        cuda_ok = torch.cuda.is_available()

        # --- model architecture ---
        self.vocab_size = tokenizer.vocab_size
        self.n_layer = 6          # decoder blocks
        self.n_head = 8           # attention heads per block
        self.n_embd = 512         # embedding / hidden width
        self.max_seq_len = 512    # longest sequence, in tokens

        # --- training ---
        self.dropout = 0.1
        self.learning_rate = 3e-4
        self.batch_size = 64
        self.device = "cuda" if cuda_ok else "cpu"

config = GPTConfig()

In [None]:
from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup

max_epochs = 50
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the model from the shared config and move it to the chosen device.
model = GPT(
    d_model=config.n_embd,
    heads=config.n_head,
    depth=config.n_layer,
    max_len=config.max_seq_len,
    vocab_size=tokenizer.vocab_size,
).to(device)

# Padding positions are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = AdamW(
    model.parameters(),
    lr=config.learning_rate,
    betas=(0.9, 0.98),
    eps=1e-8,
    weight_decay=0.01,
)

# Warm up, then follow a cosine schedule over the full planned run
# (one scheduler step per batch).
num_warmup_steps = 3000
num_training_steps = len(train_loader) * max_epochs
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)
def count_parameters(model):
    """Return the total number of elements across the model's trainable parameters."""
    total = 0
    for param in model.parameters():
        # Frozen parameters (requires_grad=False) are not counted.
        if param.requires_grad:
            total += param.numel()
    return total
# Report how many trainable parameters the freshly built model has.
print(count_parameters(model))

26234866


In [None]:
def validate(model,val_loader,criterion,device):
    """Compute the mean per-batch next-token loss over a validation loader.

    Args:
        model: Model called as ``model(inputs, attention_mask, use_cache=...)``
            and returning ``(logits, cache)``.
        val_loader: Iterable of batches with ``"input_ids"`` and
            ``"attention_mask"`` tensors.
        criterion: Loss taking flattened ``(N, vocab)`` logits and ``(N,)`` targets.
        device: Device (string or ``torch.device``) the batches are moved to;
            it also selects the autocast backend.

    Returns:
        The average of the per-batch losses as a float, or ``nan`` when the
        loader yields no batches.
    """
    model.eval()
    # Use the `device` argument for autocast rather than the global config, so
    # the function honours what the caller passes in.
    device_type = torch.device(device).type
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch in tqdm(val_loader):
            input_ids = batch["input_ids"].to(device)
            attn_mask = batch["attention_mask"].to(device)

            # Teacher forcing: predict token t+1 from tokens up to t.
            inputs = input_ids[:,:-1]
            targets = input_ids[:,1:]

            with torch.amp.autocast(device_type = device_type,dtype = torch.float16):
                # The KV cache is not needed for a single full-sequence pass.
                logits,_ = model(inputs,attn_mask[:,:-1],use_cache = False)
                # Take the vocabulary size from the logits so nothing depends
                # on the global config.
                logits_view = logits.reshape(-1,logits.size(-1))
                targets_view = targets.reshape(-1)
                loss = criterion(logits_view,targets_view)

            total_loss += loss.item()
            num_batches += 1
    if num_batches == 0:
        # An empty loader previously raised ZeroDivisionError.
        return float("nan")
    return total_loss / num_batches

In [None]:
from torch.amp import autocast,GradScaler
import time
from tqdm import tqdm

# Run-wide bookkeeping. `losses` collects every per-step training loss.
losses = []
# NOTE(review): global_step is initialised here but not updated anywhere in this cell.
global_step = 0
best_val_loss = float('inf')
# Early stopping: number of consecutive non-improving epochs tolerated.
patience = 5  
patience_counter = 0
# Gradient scaler paired with the float16 autocast region below.
scaler = GradScaler()  

# Track metrics
train_losses = []
val_losses = []
training_start_time = time.time()

for epoch in range(max_epochs):
    print(f"Epoch {epoch + 1}/{max_epochs}")
    # validate() switches the model to eval mode, so re-enable train mode each epoch.
    model.train()
    epoch_losses = []
    progress_bar = tqdm(train_loader, desc=f"Training Epoch {epoch+1}")
    for step,x in enumerate(progress_bar):
        optimizer.zero_grad()
        input_ids = x["input_ids"].to(config.device)
        attn_mask = x["attention_mask"].to(config.device)

        # Next-token objective: inputs drop the last token, targets drop the first,
        # and the attention mask is trimmed to match the inputs.
        inputs = input_ids[:,:-1]
        targets = input_ids[:,1:]
        with autocast(device_type = config.device,dtype = torch.float16):
            # NOTE(review): use_cache=True makes the model return per-layer (K, V)
            # pairs, which are discarded here via `_`.
            logits,_ = model(inputs,attn_mask[:,:-1],use_cache = True)
            # Flatten to (N, vocab) / (N,) for the cross-entropy criterion.
            logits_view = logits.contiguous().view(-1,config.vocab_size)
            targets_view = targets.contiguous().view(-1)
            loss = criterion(logits_view,targets_view)
        
        # Order matters: scale + backward, unscale so clipping sees true gradient
        # magnitudes, clip, then step the optimizer through the scaler and update
        # the scale factor.
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(),2.0)
        scaler.step(optimizer)
        scaler.update()
        # The learning-rate schedule advances once per batch, not once per epoch.
        scheduler.step()

        loss_value = loss.item()
        losses.append(loss_value)
        epoch_losses.append(loss_value)

        progress_bar.set_postfix({"loss": f"{loss_value:.4f}", "lr": f"{scheduler.get_last_lr()[0]:.6f}"})

    # Mean of the per-batch losses for this epoch.
    avg_epoch_loss = sum(epoch_losses) / len(epoch_losses)
    train_losses.append(avg_epoch_loss)

    val_loss = validate(model,val_loader,criterion,config.device)
    val_losses.append(val_loss)
    print(f"Validation Loss: {val_loss:.4f}")

    # Save checkpoint for this epoch
    # (model, optimizer, scheduler and scaler state plus both losses; one file per epoch).
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'train_loss': avg_epoch_loss,
        'val_loss': val_loss,
    }, f"transformer_checkpoint_epoch_{epoch}.pth")

    # Early stopping
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        # Save best model
        torch.save(model.state_dict(), "best_transformer_model.pth")
        print(f"New best model saved with validation loss: {val_loss:.4f}")
    else:
        patience_counter += 1
        if patience_counter >= patience:
            # Leaving the loop here skips the epoch summary below for the final epoch.
            print("Early stopping triggered")
            break

    # Print Epoch Summary
    print(f"  Epoch {epoch + 1} completed - Average Loss: {avg_epoch_loss:.4f}")
    print(f"  Learning Rate: {scheduler.get_last_lr()[0]:.6f}")
    print(f"  Time elapsed: {(time.time() - training_start_time)/60:.2f} minutes")
    print("-" * 50)
    
# Plot training history: one point per completed epoch for each curve.
fig, ax = plt.subplots(figsize=(10, 5))
for series, label in ((train_losses, 'Training Loss'), (val_losses, 'Validation Loss')):
    ax.plot(series, label=label)
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
ax.set_title('Training and Validation Loss')
fig.savefig('training_history.png')
plt.show()

# Final summary of the run.
total_minutes = (time.time() - training_start_time) / 60
print(f"Training completed in {total_minutes:.2f} minutes")
print(f"Best validation loss: {best_val_loss:.4f}")

In [14]:
# map_location remaps the stored tensors onto this session's `device`, so a
# checkpoint saved on a GPU also loads when only a CPU is available.
checkpoint = torch.load(
    "/kaggle/input/2hr-tr/transformer_checkpoint_epoch_2hr.pth",
    map_location = device,
)
# `path` is kept as the variable name for compatibility; it holds the model's
# state dict, not a filesystem path.
path = checkpoint["model_state_dict"]
model.load_state_dict(path)

<All keys matched successfully>

In [15]:
import time
@torch.inference_mode()
def generate(model,device,tokenizer,seed_txt,max_len = 500,temperature = 0.6):
    """Sample a continuation of ``seed_txt`` one token at a time.

    Args:
        model: Model called as ``model(tokens, attention_mask, use_cache=False)``
            and returning ``(logits, cache)``. If it has a ``max_len``
            attribute, generation stops before the context exceeds it.
        device: Device the token tensors are created on.
        tokenizer: Tokenizer providing ``bos_token_id``, ``eos_token_id``,
            ``encode`` and ``decode``.
        seed_txt: Prompt text.
        max_len: Maximum number of new tokens to sample.
        temperature: Softmax temperature; must be positive.

    Returns:
        The decoded newly generated tokens. The prompt is not included; the
        end-of-sequence token is included when it was sampled.

    Raises:
        ValueError: If ``temperature`` is not positive or the model returns
            logits that are neither 3-D nor 4-D.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    model.eval()
    # Encode the prompt without the tokenizer's own special tokens. The
    # "[BOS] $A [EOS]" template would otherwise produce a second BOS and put an
    # EOS directly after the prompt.
    prompt_ids = tokenizer.encode(seed_txt,add_special_tokens = False)
    seed_tokens = torch.tensor([tokenizer.bos_token_id] + prompt_ids,device = device).unsqueeze(0)
    context_limit = getattr(model,"max_len",None)
    generated = []
    for _ in range(max_len):
        # Stop once the context no longer fits the model's positional table.
        if context_limit is not None and seed_tokens.size(1) > context_limit:
            break
        attn_mask = torch.ones_like(seed_tokens)
        logits,_ = model(seed_tokens,attn_mask,use_cache = False)
        # Accept both (1, seq, vocab) logits and the (1, 1, seq, vocab) layout.
        if logits.dim() == 4:
            next_token_logits = logits[0, 0, -1, :]
        elif logits.dim() == 3:
            next_token_logits = logits[0, -1, :]
        else:
            raise ValueError(f"Unexpected logits shape: {logits.shape}")

        probs = F.softmax(next_token_logits / temperature,dim = -1)
        # multinomial samples an index into `probs`, so it is already a valid token id.
        next_id = torch.multinomial(probs,num_samples = 1)
        token_id = next_id.item()
        seed_tokens = torch.cat([seed_tokens,next_id.view(1, 1)],dim = -1)
        generated.append(token_id)
        if token_id == tokenizer.eos_token_id:
            break

    return tokenizer.decode(generated)
# Time one sampled continuation of a short prompt, then show the duration
# (in seconds) followed by the generated text.
seed_txt = "a young boy"
st = time.time()
txt = generate(model, device, tokenizer, seed_txt)
end = time.time()
elapsed_seconds = end - st
print(elapsed_seconds)
print(txt)

0.8885955810546875
was very excited to go to the park . He had never been to the park before and was always so happy . He ran around , looking for something fun to do . Suddenly , he saw a lot of kids playing and he wanted to join them , but he was a bit scared . He looked around and saw a group of kids playing together . He asked them if he wanted to join them and they said yes . The boy was so excited he ran to join them . But when he got there , he realized he was feeling scared and was going to be okay . The kids were so scared and so they ran away . They ran and ran until they were gone . The boy was so ashamed that he didn ' t get hurt or scared anymore . He had to go to the hospital and he was very sad . He had been so silly and he had to go to the hospital . [EOS]
