In [None]:
# prototype_training.ipynb

# -------------------------------
# 1. Setup & Install
# -------------------------------
!pip install transformers datasets accelerate -q

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import GPT2TokenizerFast
from torch.optim import AdamW
from tqdm import tqdm

# Pick the compute device once: use CUDA when PyTorch reports it as
# available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Running on:", device)

# -------------------------------
# 2. Define Custom Transformer
# -------------------------------
class MiniTransformer(nn.Module):
    """Small Transformer language model producing per-position vocabulary logits.

    Token ids are embedded, passed through ``depth`` self-attention blocks,
    layer-normalised and projected back to the vocabulary.

    Args:
        vocab_size: Number of token ids; sizes the embedding table and the
            output projection.
        dim: Embedding / hidden width.
        depth: Number of stacked ``nn.TransformerEncoderLayer`` blocks.
        heads: Number of attention heads per block.
        ff_dim: Hidden width of each block's feed-forward sub-layer.

    NOTE(review): no positional encoding is added to the embeddings, so
    token order is only visible to the model through the causal mask.
    """

    def __init__(self, vocab_size=50257, dim=256, depth=4, heads=4, ff_dim=1024):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.blocks = nn.Sequential(*[
            nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=ff_dim, batch_first=True)
            for _ in range(depth)
        ])
        self.ln = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, vocab_size)

    def forward(self, x, causal=True):
        """Map token ids to logits.

        Args:
            x: Integer tensor of token ids whose last dimension is the
                sequence (``(batch, seq)`` given ``batch_first=True``).
            causal: When true (default), each position may only attend to
                itself and earlier positions, which is required for
                next-token prediction. Pass ``False`` for unmasked
                (bidirectional) attention.

        Returns:
            Tensor of logits with a trailing dimension of ``vocab_size``.
        """
        mask = None
        if causal:
            seq_len = x.size(-1)
            # True above the diagonal marks "may not attend": position i is
            # blocked from every position j > i.
            mask = torch.triu(
                torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
                diagonal=1,
            )

        x = self.embed(x)
        # Iterate the layers by hand: calling the Sequential container
        # directly cannot forward the attention mask to each layer.
        for block in self.blocks:
            x = block(x, src_mask=mask)
        x = self.ln(x)
        return self.head(x)

# Build the network with its default hyperparameters, then move its
# parameters onto the device selected earlier.
model = MiniTransformer()
model = model.to(device)

# -------------------------------
# 3. Load Dataset & Tokenizer
# -------------------------------
# Training split of the raw WikiText-2 corpus.
dataset = load_dataset(path="wikitext", name="wikitext-2-raw-v1", split="train")

# Pretrained GPT-2 tokenizer. The end-of-sequence token doubles as the
# padding token so that fixed-length padding can be requested later
# (presumably this tokenizer has no pad token of its own -- confirm).
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

def tokenize_fn(example):
    """Tokenize the ``"text"`` field to exactly 128 token ids per entry.

    Longer inputs are truncated and shorter ones are padded up to the
    maximum length, so every returned sequence has the same size.

    Args:
        example: Mapping with a ``"text"`` entry to tokenize.

    Returns:
        The output of the module-level ``tokenizer`` for that text.
    """
    text = example["text"]
    encoded = tokenizer(
        text,
        max_length=128,
        truncation=True,
        padding="max_length",
    )
    return encoded

# Drop blank rows before tokenizing. A row with no text tokenizes to zero
# tokens and is then padded to an all-pad sequence; a batch made only of such
# rows has no target the loss can score (pad ids are ignored by the
# criterion), which makes the mean cross-entropy NaN.
# NOTE(review): raw text corpora like this one typically contain blank
# separator lines -- confirm how many rows this removes.
non_empty_dataset = dataset.filter(lambda example: len(example["text"].strip()) > 0)

# Tokenize in batches and drop the raw text column, keeping only tensors.
tokenized_dataset = non_empty_dataset.map(tokenize_fn, batched=True, remove_columns=["text"])
# Expose only the token ids, as torch tensors, to the DataLoader.
tokenized_dataset.set_format(type="torch", columns=["input_ids"])
loader = DataLoader(tokenized_dataset, batch_size=8, shuffle=True)

# -------------------------------
# 4. Training Loop
# -------------------------------
optimizer = AdamW(model.parameters(), lr=3e-4)
# Padding positions carry no signal, so they are excluded from the loss.
# Because the pad token is the tokenizer's EOS token, EOS targets are
# excluded as well.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

model.train()
epochs = 1

for epoch in range(epochs):
    total_loss = 0.0
    num_steps = 0  # batches that actually contributed a loss value
    for batch in tqdm(loader):
        input_ids = batch["input_ids"].to(device)

        # Next-token objective: the logits at position t are scored against
        # the token at position t + 1. Scoring them against the token at the
        # same position would only teach the model to copy its input.
        # NOTE(review): this objective is only meaningful if the model cannot
        # attend to later positions -- confirm it applies a causal mask.
        targets = input_ids[:, 1:]

        # With every target ignored the mean loss is 0/0 = NaN, and one NaN
        # backward pass would poison all weights; skip such batches.
        if not (targets != tokenizer.pad_token_id).any():
            continue

        logits = model(input_ids)[:, :-1, :]
        # reshape (not view): the shifted slices are no longer contiguous.
        loss = criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_steps += 1

    # Average over the steps actually taken; max() avoids dividing by zero
    # if every batch was skipped.
    print(f"🧠 Epoch {epoch+1} complete. Avg loss: {total_loss / max(num_steps, 1):.4f}")

# -------------------------------
# 5. Save Model (Optional)
# -------------------------------
# Persist only the learned parameters (the state dict), not the whole
# module object, to a file in the current working directory.
state = model.state_dict()
torch.save(state, "custom_transformer.pt")
print("✅ Model saved.")
