In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import GPT2Tokenizer
from datasets import load_dataset
import wandb


In [3]:
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")


In [4]:
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # Ensure EOS is used as pad


In [5]:
def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=256)

train_dataset = dataset["train"].map(tokenize_function, batched=True, remove_columns=["text"])
val_dataset = dataset["validation"].map(tokenize_function, batched=True, remove_columns=["text"])


In [6]:
train_dataset.set_format(type="torch", columns=["input_ids"])
val_dataset.set_format(type="torch", columns=["input_ids"])


In [7]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512):
        super().__init__()
        self.encoding = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(torch.log(torch.tensor(10000.0)) / d_model))
        self.encoding[:, 0::2] = torch.sin(position * div_term)
        self.encoding[:, 1::2] = torch.cos(position * div_term)
        self.encoding = self.encoding.unsqueeze(0)

    def forward(self, x):
        return x + self.encoding[:, :x.size(1)].to(x.device)


In [34]:
class SimpleTransformerDecoderModel(nn.Module):
    def __init__(self, vocab_size, d_model=128, nhead=4, num_layers=2, max_seq_len=256):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len=max_seq_len)
        
        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        
        self.output_layer = nn.Linear(d_model, vocab_size)
        self.d_model = d_model

    def generate_square_subsequent_mask(self, sz):
        return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)

    def forward(self, tgt_ids):
        """
        tgt_ids: [batch_size, seq_len] - token ids
        """
        device = tgt_ids.device
        x = self.embedding(tgt_ids) * (self.d_model ** 0.5)  # scale embeddings
        x = self.pos_encoding(x).transpose(0, 1)  # [seq_len, batch_size, d_model]

        # Causal mask
        seq_len = x.size(0)
        tgt_mask = self.generate_square_subsequent_mask(seq_len).to(device)

        # Fake memory — just pass zeros to satisfy TransformerDecoder API
        memory = torch.zeros_like(x)

        output = self.transformer_decoder(tgt=x, memory=memory, tgt_mask=tgt_mask)
        output = output.transpose(0, 1)  # [batch_size, seq_len, d_model]
        return self.output_layer(output)


In [35]:
def train(model, dataloader, optimizer, criterion):
    model.train()
    total_loss = 0
    for batch in dataloader:
        input_ids = batch["input_ids"]
        inputs = input_ids[:, :-1]
        targets = input_ids[:, 1:]

        optimizer.zero_grad()
        output = model(inputs)  # Pass only inputs
        loss = criterion(output.view(-1, output.size(-1)), targets.reshape(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(dataloader)


In [36]:
def evaluate(model, dataloader, criterion):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"]
            inputs = input_ids[:, :-1]
            targets = input_ids[:, 1:]

            output = model(inputs)  # Only inputs
            loss = criterion(output.view(-1, output.size(-1)), targets.reshape(-1))
            total_loss += loss.item()
    return total_loss / len(dataloader)


In [37]:
wandb.login()  # Only required once per session/machine

config = {
    "epochs": 5,
    "batch_size": 4,
    "learning_rate": 5e-4,
    "architecture": "TransformerDecoder",
    "dataset": "WikiText-2",
    "vocab_size": len(tokenizer),
    "embedding_dim": 128,
    "nhead": 4,
    "num_layers": 2,
    "max_seq_len": 256
}

wandb.init(project="language-model", config=config)




In [38]:
model = SimpleTransformerDecoderModel(
    vocab_size=wandb.config.vocab_size,
    d_model=wandb.config.embedding_dim,
    nhead=wandb.config.nhead,
    num_layers=wandb.config.num_layers,
    max_seq_len=wandb.config.max_seq_len
)

optimizer = torch.optim.Adam(model.parameters(), lr=wandb.config.learning_rate)
criterion = nn.CrossEntropyLoss()

train_loader = DataLoader(train_dataset, batch_size=wandb.config.batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=wandb.config.batch_size)


In [None]:
for epoch in range(wandb.config.epochs):
    print(f"Epoch {epoch + 1}")
    train_loss = train(model, train_loader, optimizer, criterion)
    val_loss = evaluate(model, val_loader, criterion)

    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    
    wandb.log({
        "epoch": epoch + 1,
        "train_loss": train_loss,
        "val_loss": val_loss
    })


Epoch 1
Train Loss: 1.5488 | Val Loss: 1.5049
Epoch 2
Train Loss: 1.3501 | Val Loss: 1.4618
Epoch 3
