In [3]:
# Colab only: mount Google Drive so the checkpoint and token files under
# /content/drive/MyDrive (read and written by later code) are accessible.
from google.colab import drive
drive.mount('/content/drive')

from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2TokenizerFast
import numpy as np

# Device setup: use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Model Components
class CasualSelfAttention(nn.Module):
    """Multi-head causal self-attention with a fused query/key/value projection.

    The class name keeps its original "Casual" spelling so existing references
    keep working; what it implements is *causal* (masked) attention.

    State-dict entries are ``c_attn.weight`` (shape 3*n_embd x n_embd),
    ``c_proj.weight`` (n_embd x n_embd) and the ``bias`` buffer. Despite its
    name, ``bias`` is the lower-triangular attention mask, not a learned bias.
    All three names are stored in checkpoints and must not change.

    Args:
        config: object providing ``n_embd``, ``n_head`` (must divide
            ``n_embd``) and ``block_size`` (largest sequence length the mask
            covers).
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        width = config.n_embd
        self.c_attn = nn.Linear(width, 3 * width, bias=False)
        self.c_proj = nn.Linear(width, width, bias=False)
        self.n_head = config.n_head
        self.n_embd = width
        span = config.block_size
        causal_mask = torch.tril(torch.ones(span, span)).view(1, 1, span, span)
        self.register_buffer("bias", causal_mask)

    def forward(self, x):
        """Let every position attend only to itself and earlier positions.

        Args:
            x: tensor of shape (B, T, n_embd) with T <= block_size.

        Returns:
            Tensor of shape (B, T, n_embd).
        """
        batch, seq_len, width = x.size()
        head_dim = width // self.n_head

        def to_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        query, key, value = map(to_heads, self.c_attn(x).split(self.n_embd, dim=2))
        scale = 1.0 / head_dim ** 0.5
        scores = (query @ key.transpose(-2, -1)) * scale
        # Positions where the mask is 0 lie in the future and are hidden.
        future = self.bias[:, :, :seq_len, :seq_len] == 0
        weights = F.softmax(scores.masked_fill(future, float("-inf")), dim=-1)
        merged = (weights @ value).transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.c_proj(merged)

class MLP(nn.Module):
    """Position-wise feed-forward network.

    Expands to ``4 * n_embd`` features, applies tanh-approximated GELU, then
    projects back to ``n_embd``. Attribute names ``c_fc`` / ``c_proj`` are the
    state-dict keys and must not change.
    """

    def __init__(self, config):
        super().__init__()
        hidden = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden)
        self.gelu = nn.GELU(approximate="tanh")
        self.c_proj = nn.Linear(hidden, config.n_embd)

    def forward(self, x):
        return self.c_proj(self.gelu(self.c_fc(x)))

class Block(nn.Module):
    """One pre-LayerNorm transformer layer.

    Applies causal self-attention and then the MLP. Each sublayer is wrapped
    in a residual connection. Attribute names (``ln1``, ``attention``,
    ``ln2``, ``mlp``) are the state-dict keys and must not change.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        self.ln1 = nn.LayerNorm(width)
        self.attention = CasualSelfAttention(config)
        self.ln2 = nn.LayerNorm(width)
        self.mlp = MLP(config)

    def forward(self, x):
        # Each sublayer reads a normalized copy of the residual stream and adds
        # its output back onto the un-normalized stream.
        for norm, sublayer in ((self.ln1, self.attention), (self.ln2, self.mlp)):
            x = x + sublayer(norm(x))
        return x

@dataclass
class GPTConfig:
    """Hyperparameters for the GPT model defined in this notebook.

    Field order matters: it is the positional-argument order of the
    dataclass-generated ``__init__``.
    """

    block_size: int = 512  # max sequence length; rows in the position-embedding table and size of the causal mask
    vocab_size: int = 50257  # number of token ids (50257 is the GPT-2 BPE vocabulary size)
    n_layer: int = 8  # number of transformer Blocks
    n_head: int = 8  # attention heads; must divide n_embd (asserted in the attention module)
    n_embd: int = 512  # embedding / residual-stream width
    dropout: float = 0.2  # NOTE(review): not read by any module visible here, so no dropout is applied

class GPT(nn.Module):
    """Decoder-only transformer language model built from pre-LayerNorm Blocks.

    Submodule names (``transformer.token_embedding``,
    ``transformer.position_embedding``, ``transformer.h``,
    ``transformer.ln_f``, ``lm_head``) define the state-dict keys and must stay
    unchanged to load existing checkpoints.

    Args:
        config: object with ``vocab_size``, ``block_size``, ``n_embd``,
            ``n_layer`` plus whatever ``Block`` reads (e.g. ``n_head``).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(
            dict(
                token_embedding=nn.Embedding(config.vocab_size, config.n_embd),
                position_embedding=nn.Embedding(config.block_size, config.n_embd),
                h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f=nn.LayerNorm(config.n_embd),
            )
        )
        # Output projection to vocabulary logits (a separate Linear, not tied
        # to token_embedding).
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: integer token ids of shape (B, T) with T <= config.block_size.
            targets: optional token ids with the same number of elements as
                ``idx``. They may be non-contiguous (e.g. sliced or transposed).

        Returns:
            ``(logits, loss)``. ``logits`` has shape (B, T, vocab_size).
            ``loss`` is the mean cross-entropy (scalar tensor) when
            ``targets`` is given, else ``None``.

        Raises:
            ValueError: if T exceeds ``config.block_size``. The position
                embedding only has ``block_size`` rows; without this check the
                failure is an opaque index error, or a device-side assert on
                CUDA.
        """
        _, T = idx.size()
        if T > self.config.block_size:
            raise ValueError(
                f"sequence length {T} exceeds block_size {self.config.block_size}"
            )
        token_embeddings = self.transformer.token_embedding(idx)
        position_ids = torch.arange(T, device=idx.device).unsqueeze(0)
        position_embeddings = self.transformer.position_embedding(position_ids)
        x = token_embeddings + position_embeddings
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous targets are accepted.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)), targets.reshape(-1)
            )
        return logits, loss

class GPTDataset(torch.utils.data.Dataset):
    """Sliding-window next-token dataset over a flat binary file of token ids.

    Item ``i`` is ``(x, y)`` with ``x = tokens[i : i + block_size]`` and
    ``y = tokens[i + 1 : i + 1 + block_size]`` (``x`` shifted by one). Both
    are returned as ``torch.long`` tensors. A window starts at every offset,
    so consecutive items overlap.

    Args:
        path: path to a raw token file (no header), opened read-only with
            ``np.memmap`` so it is not loaded into RAM up front.
        block_size: window length.
        dtype: on-disk integer dtype of the tokens. Defaults to ``np.uint16``
            (ids up to 65535); pass e.g. ``np.uint32`` for larger vocabularies.
    """

    def __init__(self, path, block_size, dtype=np.uint16):
        self.data = np.memmap(path, dtype=dtype, mode="r")
        self.block_size = block_size

    def __len__(self):
        # Each item needs block_size + 1 tokens. Clamp at 0 so a too-short file
        # yields an empty dataset instead of a negative length (len() raises).
        return max(0, len(self.data) - self.block_size)

    def __getitem__(self, idx):
        n_items = len(self)
        if idx < 0:
            idx += n_items
        if not 0 <= idx < n_items:
            # Without this, out-of-range indices silently return short windows
            # that only fail later, inside batch collation.
            raise IndexError(f"index out of range for dataset of length {n_items}")
        # Read block_size + 1 tokens once; x and y are its two shifted slices.
        window = np.asarray(self.data[idx : idx + self.block_size + 1], dtype=np.int64)
        x = torch.tensor(window[:-1], dtype=torch.long)
        y = torch.tensor(window[1:], dtype=torch.long)
        return x, y

# === Setup ===
config = GPTConfig()
model = GPT(config).to(device)

# Load pretrained model weights (update path if different).
# The checkpoint is consumed by load_state_dict, so it only needs tensors:
# weights_only=True stops torch.load from unpickling arbitrary objects (which
# can execute code) while loading a state_dict exactly as before.
model.load_state_dict(
    torch.load(
        "/content/drive/MyDrive/gpt8x512.pth",
        map_location=device,
        weights_only=True,
    )
)

# Tokenizer, if needed elsewhere (not referenced by the training loop in this cell).
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

# Dataset and DataLoader
dataset = GPTDataset("/content/drive/MyDrive/qa_tokens.bin", block_size=config.block_size)
batch_size = 16
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    # Pinned host memory only helps host->GPU copies; on a CPU-only runtime it
    # just triggers a warning.
    pin_memory=device.type == "cuda",
    num_workers=4,
)

# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# Training loop with multiple epochs
# Fine-tunes the loaded model on `loader` with AdamW; every batch is one
# optimizer step.
# NOTE(review): there is no held-out evaluation. In the recorded run the
# training loss falls to ~0.02 within 5 epochs, which may indicate
# memorization; consider a validation split.
num_epochs = 5
save_every = 500  # write a checkpoint every `save_every` optimizer steps
global_step = 0  # optimizer steps taken so far, counted across all epochs

model.train()
for epoch in range(num_epochs):
    print(f"Starting epoch {epoch + 1}/{num_epochs}")
    # `step` (per-epoch batch index) is not used below; logging and
    # checkpointing key off `global_step` instead.
    for step, (x, y) in enumerate(loader):
        # non_blocking allows the host->device copy to overlap with compute
        # when the batch is in pinned memory.
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)

        optimizer.zero_grad()
        logits, loss = model(x, y)  # logits unused here; loss is mean cross-entropy over the batch
        loss.backward()
        optimizer.step()

        global_step += 1

        # Logs only the current batch's loss (not a running average), so values are noisy.
        if global_step % 50 == 0:
            print(f"Epoch {epoch+1} Step {global_step}, Loss: {loss.item():.4f}")

        # NOTE(review): checkpoints are written only at exact multiples of
        # save_every; a run with fewer total steps saves nothing here.
        if global_step % save_every == 0:
            save_path = f"/content/drive/MyDrive/gpt8x512_finetuned_step{global_step}.pth"
            torch.save(model.state_dict(), save_path)
            print(f"Model saved to {save_path}")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Starting epoch 1/5
Epoch 1 Step 50, Loss: 1.5747
Starting epoch 2/5
Epoch 2 Step 100, Loss: 0.5289
Epoch 2 Step 150, Loss: 0.1066
Starting epoch 3/5
Epoch 3 Step 200, Loss: 0.0769
Epoch 3 Step 250, Loss: 0.0420
Starting epoch 4/5
Epoch 4 Step 300, Loss: 0.0312
Starting epoch 5/5
Epoch 5 Step 350, Loss: 0.0195
Epoch 5 Step 400, Loss: 0.0214


In [4]:
# Save the final fine-tuned weights (state_dict only). The training loop only
# checkpoints at multiples of save_every, so this captures the end state.
torch.save(model.state_dict(), "/content/drive/MyDrive/gpt8x512_finetuned.pth")