In [1]:
import math
import os
import time

import torch
from datasets import load_dataset
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from src.gpt2 import GPT2, GPT2Config, set_seed, StreamDataset, get_lr

device = "cuda" if torch.cuda.is_available() else "cpu"
# Allow TF32 matmuls / cuDNN kernels; these flags are set unconditionally.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

set_seed(42)

config = GPT2Config(
    block_size=512,
    n_layer=12,
    n_head=12,
    n_embd=768,
    dropout=0.1,
    vocab_size=50257
)

model = GPT2(config).to(device)

# --- Hyperparameters
batch_size = 8  # micro-batch size
grad_accum = 4  # micro-batches per optimizer update (effective batch = 32)
# NOTE(review): the training cell increments its step counter once per
# micro-batch, so max_steps / warmup_steps are expressed in micro-batches.
max_steps = 20_000
warmup_steps = 2_000
max_lr = 3e-4
weight_decay = 0.1
grad_clip = 1.0

# Mixed precision is only enabled when running on CUDA; on the CPU fallback the
# CUDA scaler/autocast would merely warn and disable themselves.
use_amp = device == "cuda"
# torch.cuda.amp.GradScaler is deprecated in favour of torch.amp.GradScaler("cuda", ...).
scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr, betas=(0.9, 0.95), weight_decay=weight_decay)

  scaler = GradScaler(enabled=use_amp)


In [2]:
# Load the "gpt2" tokenizer; local_files_only=True restricts the lookup to
# files already present locally.
tokenizer = AutoTokenizer.from_pretrained("gpt2", local_files_only=True)
# Use the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# WikiText-2 raw configuration; the "train" and "validation" splits are used below.
ds = load_dataset("wikitext", "wikitext-2-raw-v1")


def tokenize_function(examples):
    """Run the notebook-level tokenizer over the ``"text"`` field of a batch.

    Args:
        examples: Mapping with a ``"text"`` entry (as passed by ``ds.map``).

    Returns:
        Whatever the tokenizer returns for those texts.
    """
    texts = examples["text"]
    return tokenizer(texts)


# Apply tokenize_function in batched mode and drop the raw "text" column from the result.
tokenized = ds.map(tokenize_function, batched=True, remove_columns=["text"])


def build_stream(split):
    """Flatten a tokenized split into one 1-D ``torch.long`` tensor of token ids.

    Every example contributes its ``"input_ids"`` followed by a single
    end-of-sequence id taken from the notebook-level tokenizer.

    Args:
        split: Iterable of examples, each exposing an ``"input_ids"`` sequence.

    Returns:
        A ``torch.long`` tensor holding the concatenated ids.
    """
    flat_ids = [
        token_id
        for example in split
        for token_id in (*example["input_ids"], tokenizer.eos_token_id)
    ]
    return torch.tensor(flat_ids, dtype=torch.long)


# One long token stream per split.
train_stream = build_stream(tokenized["train"])
val_stream = build_stream(tokenized["validation"])

# Wrap each stream in a dataset parameterized by the model's block size.
train_ds = StreamDataset(train_stream, config.block_size)
val_ds = StreamDataset(val_stream, config.block_size)

# Pinned host memory is only requested when batches will be copied to a CUDA
# device; on the CPU fallback it is useless and makes DataLoader emit a warning.
train_loader = DataLoader(
    train_ds, batch_size=batch_size, shuffle=True, drop_last=True, pin_memory=(device == "cuda")
)
val_loader = DataLoader(
    val_ds, batch_size=batch_size, shuffle=False, drop_last=True, pin_memory=(device == "cuda")
)


In [3]:
@torch.no_grad()
def evaluate(model, loader, max_batches=50):
    """Return the mean loss of ``model`` over at most ``max_batches`` batches.

    Args:
        model: Module called as ``model(x, y)`` and returning ``(_, loss)``.
        loader: Iterable yielding ``(x, y)`` tensor pairs.
        max_batches: Upper bound on the number of batches consumed.

    Returns:
        The arithmetic mean of the per-batch losses as a float, or NaN when no
        batch was evaluated (empty loader or ``max_batches <= 0``).

    The model is switched to eval mode for the duration of the call and then
    put back into whichever mode it was in on entry, even if a batch raises.
    """
    was_training = model.training
    # Move batches to wherever the model's weights live; fall back to the
    # notebook-level `device` only for a model without parameters.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    model.eval()
    losses = []
    try:
        for i, (x, y) in enumerate(loader):
            if i >= max_batches:
                break
            x = x.to(target_device, non_blocking=True)
            y = y.to(target_device, non_blocking=True)
            _, loss = model(x, y)
            losses.append(loss.item())
    finally:
        # Restore the caller's mode instead of unconditionally forcing train().
        model.train(was_training)

    if not losses:
        # Nothing was evaluated: avoid the ZeroDivisionError of an empty mean.
        return float("nan")
    return sum(losses) / len(losses)


In [4]:
# Checkpoint directory. The original Windows location remains the default so
# existing runs are unaffected; set the GPT2_SAVE_DIR environment variable to
# write checkpoints elsewhere without editing the notebook.
save_dir = os.environ.get("GPT2_SAVE_DIR", r"C:\workspace\GPT2\models")
os.makedirs(save_dir, exist_ok=True)


def save_checkpoint(step, model, optimizer, config, extra=None):
    """Write a training checkpoint for ``step`` and refresh the "last" checkpoint.

    Args:
        step: Step number; stored in the payload and embedded in the file name.
        model: Object whose ``state_dict()`` is stored under ``"model_state"``.
        optimizer: Object whose ``state_dict()`` is stored under ``"optim_state"``.
        config: Object whose ``__dict__`` is stored under ``"config"``.
        extra: Optional additional data, stored under ``"extra"`` when not None.

    Returns:
        Path of the step-specific checkpoint file inside ``save_dir``.
    """
    step_path = os.path.join(save_dir, f"ckpt_step{step}.pt")
    state = dict(
        step=step,
        config=config.__dict__,
        model_state=model.state_dict(),
        optim_state=optimizer.state_dict(),
    )
    if extra is not None:
        state["extra"] = extra

    # Same payload twice: once under the step-specific name, once as the
    # convenient "last" checkpoint.
    last_path = os.path.join(save_dir, "ckpt_last.pt")
    for target in (step_path, last_path):
        torch.save(state, target)
    return step_path

In [5]:
log_every = 50
eval_every = 500
save_every = 1_000

model.train()
# Clear gradients left behind by an interrupted earlier run of this cell so
# they cannot leak into the first accumulation window.
optimizer.zero_grad(set_to_none=True)
t0 = time.time()
# `step` counts micro-batches: the optimizer is updated every `grad_accum` steps.
step = 0

for epoch in range(10_000):  # effectively unbounded; the loop exits via max_steps
    for x, y in train_loader:
        if step >= max_steps:
            break

        # Set the scheduled learning rate on every parameter group.
        lr = get_lr(step, warmup_steps, max_steps, max_lr)
        for pg in optimizer.param_groups:
            pg["lr"] = lr

        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        # torch.cuda.amp.autocast is deprecated; this is the documented equivalent.
        with torch.amp.autocast("cuda", enabled=use_amp):
            _, loss = model(x, y)
            # Scale so the gradients accumulated over `grad_accum` micro-batches
            # average rather than sum.
            loss = loss / grad_accum

        scaler.scale(loss).backward()

        if (step + 1) % grad_accum == 0:
            # Unscale before clipping so the threshold applies to true gradient norms.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        if step % log_every == 0:
            dt = time.time() - t0
            # Multiply back by grad_accum to report the unscaled micro-batch loss.
            print(f"step {step:7d} | loss {(loss.item() * grad_accum):.4f} | lr {lr:.2e} | {dt:.1f}s")
            t0 = time.time()

        if step % eval_every == 0 and step > 0:
            val_loss = evaluate(model, val_loader, max_batches=50)
            print(f"  eval | val_loss {val_loss:.4f} | ppl {math.exp(val_loss):.2f}")

        if step % save_every == 0 and step > 0:
            # Keep the GradScaler state alongside the weights so the loss scale
            # can be restored when resuming.
            path = save_checkpoint(
                step, model, optimizer, config, extra={"scaler_state": scaler.state_dict()}
            )
            print(f"  ✅ saved: {path}")

        step += 1

    if step >= max_steps:
        break

# The periodic hooks only fire on multiples of eval_every / save_every, so the
# state reached after the last processed step would otherwise never be
# evaluated or written to disk.
final_step = step - 1
if final_step > 0 and final_step % eval_every != 0:
    val_loss = evaluate(model, val_loader, max_batches=50)
    print(f"  eval | val_loss {val_loss:.4f} | ppl {math.exp(val_loss):.2f}")

if final_step > 0 and final_step % save_every != 0:
    path = save_checkpoint(
        final_step, model, optimizer, config, extra={"scaler_state": scaler.state_dict()}
    )
    print(f"  ✅ saved: {path}")

  with autocast(enabled=use_amp):


step       0 | loss 10.9829 | lr 0.00e+00 | 0.5s
step      50 | loss 10.3415 | lr 7.50e-06 | 7.3s
step     150 | loss 9.4357 | lr 2.25e-05 | 7.2s
step     200 | loss 9.1855 | lr 3.00e-05 | 7.2s
step     250 | loss 8.6493 | lr 3.75e-05 | 7.1s
step     300 | loss 8.0987 | lr 4.50e-05 | 7.1s
step     350 | loss 7.7415 | lr 5.25e-05 | 7.1s
step     400 | loss 7.2956 | lr 6.00e-05 | 7.1s
step     450 | loss 7.0578 | lr 6.75e-05 | 7.1s
step     500 | loss 7.0417 | lr 7.50e-05 | 7.2s
  eval | val_loss 6.9613 | ppl 1054.97
step     550 | loss 6.8347 | lr 8.25e-05 | 10.3s
step     600 | loss 6.7800 | lr 9.00e-05 | 7.1s
step     650 | loss 6.8145 | lr 9.75e-05 | 7.1s
step     700 | loss 6.7635 | lr 1.05e-04 | 7.1s
step     750 | loss 6.5073 | lr 1.12e-04 | 7.2s
step     800 | loss 6.5297 | lr 1.20e-04 | 7.1s
step     850 | loss 6.5149 | lr 1.28e-04 | 7.1s
step     900 | loss 6.4721 | lr 1.35e-04 | 7.1s
step     950 | loss 6.4713 | lr 1.42e-04 | 7.1s
step    1000 | loss 6.3468 | lr 1.50e-04 | 7.2