In [1]:
import torch
torch.cuda.empty_cache()


In [2]:
from google.colab import drive
drive.mount('/content/drive')

from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.cuda.amp import autocast, GradScaler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class CasualSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)

        self.n_head = config.n_head
        self.n_embd = config.n_embd

        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.block_size, config.block_size))
            .view(1, 1, config.block_size, config.block_size),
        )

    def forward(self, x):
        B, T, C = x.size()
        qkv = self.c_attn(x)

        q, k, v = qkv.split(self.n_embd, dim=2)

        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        att = (q @ k.transpose(-2, -1)) * (1.0 / (C // self.n_head) ** 0.5)

        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)

        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        y = self.c_proj(y)
        return y


class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.gelu = nn.GELU(approximate="tanh")
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return x


class Block(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.ln1 = nn.LayerNorm(config.n_embd)
        self.attention = CasualSelfAttention(config)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attention(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


@dataclass
class GPTConfig:
    block_size: int = 256
    vocab_size: int = 50257
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
    dropout: float = 0.2


class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(
            dict(
                token_embedding=nn.Embedding(config.vocab_size, config.n_embd),
                position_embedding=nn.Embedding(config.block_size, config.n_embd),
                h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f=nn.LayerNorm(config.n_embd),
            )
        )

        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

    def forward(self, idx, targets=None):
        B, T = idx.size()
        token_embeddings = self.transformer.token_embedding(idx)
        position_ids = torch.arange(T, device=idx.device).unsqueeze(0)
        position_embeddings = self.transformer.position_embedding(position_ids)
        x = token_embeddings + position_embeddings
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss


class GPTDataset(torch.utils.data.Dataset):
    def __init__(self, path, block_size):
        self.data = np.memmap(path, dtype=np.uint16, mode="r")
        self.block_size = block_size

    def __len__(self):
        return len(self.data) - self.block_size

    def __getitem__(self, idx):
        chunk = self.data[idx : idx + self.block_size + 1]
        x = torch.tensor(chunk[:-1], dtype=torch.long)
        y = torch.tensor(chunk[1:], dtype=torch.long)
        return x, y



block_size = 256
batch_size = 96
train_steps = 20000
save_path = "/content/drive/MyDrive/minigpt2.pth"


train_dataset = GPTDataset("/content/drive/MyDrive/wikitext103_tokens.bin", block_size)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=4
)


config = GPTConfig(block_size=block_size)
model = GPT(config).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scaler = GradScaler()


step = 0
while step < train_steps:
    for x_batch, y_batch in train_loader:
        if step >= train_steps:
            break
        model.train()
        x_batch, y_batch = x_batch.to(device, non_blocking=True), y_batch.to(device, non_blocking=True)

        optimizer.zero_grad()
        with autocast():
            logits, loss = model(x_batch, y_batch)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        if step % 50 == 0:
            print(f"Step {step}, Loss: {loss.item():.4f}")

        if step % 500 == 0 and step > 0:
            torch.save(model.state_dict(), save_path)
            print(f"Model saved to {save_path}")

        step += 1


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


  scaler = GradScaler()
  with autocast():


Step 0, Loss: 10.9876
Step 50, Loss: 7.0416
Step 100, Loss: 6.7561
Step 150, Loss: 6.5065
Step 200, Loss: 6.3063
Step 250, Loss: 6.1529
Step 300, Loss: 5.9462
Step 350, Loss: 5.9286
Step 400, Loss: 5.7024
Step 450, Loss: 5.5395
Step 500, Loss: 5.5848
Model saved to /content/drive/MyDrive/minigpt2.pth
Step 550, Loss: 5.3420
Step 600, Loss: 5.4061
Step 650, Loss: 5.2887
Step 700, Loss: 5.1824
Step 750, Loss: 5.1222
Step 800, Loss: 5.0516
Step 850, Loss: 5.0232
Step 900, Loss: 4.9222
Step 950, Loss: 4.9219
Step 1000, Loss: 4.8307
Model saved to /content/drive/MyDrive/minigpt2.pth
Step 1050, Loss: 4.7819
Step 1100, Loss: 4.7508
Step 1150, Loss: 4.7529
Step 1200, Loss: 4.6111
Step 1250, Loss: 4.6446
Step 1300, Loss: 4.6869
Step 1350, Loss: 4.6487
Step 1400, Loss: 4.5612
Step 1450, Loss: 4.5214
Step 1500, Loss: 4.4933
Model saved to /content/drive/MyDrive/minigpt2.pth
Step 1550, Loss: 4.4461
Step 1600, Loss: 4.4156
Step 1650, Loss: 4.5123
Step 1700, Loss: 4.3052
Step 1750, Loss: 4.2526
Step 