In [1]:
import math
import os
import time

import torch
from datasets import load_dataset
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from src.gpt2 import GPT2, GPT2Config  # <- ton .py

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.backends.cuda.matmul.allow_tf32 = True  # si dispo, accélère souvent
torch.backends.cudnn.allow_tf32 = True

# --- Modèle (bon compromis pour une carte grand public)
config = GPT2Config(
    block_size=512,
    n_layer=12,
    n_head=12,
    n_embd=768,
    dropout=0.1,
    vocab_size=50257
)
model = GPT2(config).to(device)

# --- Hyperparams (à ajuster selon VRAM)
batch_size = 8  # micro-batch
grad_accum = 4  # batch effectif = 32
max_steps = 200_000
warmup_steps = 2_000
max_lr = 3e-4
weight_decay = 0.1
grad_clip = 1.0

use_amp = True
scaler = GradScaler(enabled=use_amp)

optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr, betas=(0.9, 0.95), weight_decay=weight_decay)


  scaler = GradScaler(enabled=use_amp)


In [2]:
tokenizer = AutoTokenizer.from_pretrained("gpt2", local_files_only=True)
tokenizer.pad_token = tokenizer.eos_token

ds = load_dataset("wikitext", "wikitext-2-raw-v1")


def tokenize_function(examples):
    return tokenizer(examples["text"])


tokenized = ds.map(tokenize_function, batched=True, remove_columns=["text"])


def build_stream(split):
    ids = []
    for ex in split:
        ids.extend(ex["input_ids"])
        ids.append(tokenizer.eos_token_id)
    return torch.tensor(ids, dtype=torch.long)


train_stream = build_stream(tokenized["train"])
val_stream = build_stream(tokenized["validation"])


class StreamDataset(torch.utils.data.Dataset):
    def __init__(self, stream_ids, block_size):
        self.data = stream_ids
        self.block_size = block_size

    def __len__(self):
        return (len(self.data) - 1) // self.block_size

    def __getitem__(self, i):
        s = i * self.block_size
        x = self.data[s:s + self.block_size]
        y = self.data[s + 1:s + 1 + self.block_size]
        return x, y


train_ds = StreamDataset(train_stream, config.block_size)
val_ds = StreamDataset(val_stream, config.block_size)

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, drop_last=True, pin_memory=True)




In [3]:
def get_lr(step, warmup_steps, max_steps, max_lr):
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, (max_steps - warmup_steps))
    return 0.5 * max_lr * (1.0 + math.cos(math.pi * progress))


@torch.no_grad()
def evaluate(model, loader, max_batches=50):
    model.eval()
    losses = []
    for i, (x, y) in enumerate(loader):
        if i >= max_batches:
            break
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        _, loss = model(x, y)
        losses.append(loss.item())
    model.train()
    return sum(losses) / len(losses)


In [4]:
save_dir = r"C:\workspace\GPT2\models"
os.makedirs(save_dir, exist_ok=True)


def save_checkpoint(step, model, optimizer, config, extra=None):
    path = os.path.join(save_dir, f"ckpt_step{step}.pt")
    payload = {
        "step": step,
        "config": config.__dict__,
        "model_state": model.state_dict(),
        "optim_state": optimizer.state_dict(),
    }
    if extra is not None:
        payload["extra"] = extra
    torch.save(payload, path)
    # aussi un "last" pratique
    torch.save(payload, os.path.join(save_dir, "ckpt_last.pt"))
    return path


In [5]:
log_every = 50
eval_every = 1000
save_every = 10_000

model.train()
t0 = time.time()
step = 0

for epoch in range(10_000):  # on break via max_steps
    for x, y in train_loader:
        if step >= max_steps:
            break

        lr = get_lr(step, warmup_steps, max_steps, max_lr)
        for pg in optimizer.param_groups:
            pg["lr"] = lr

        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        with autocast(enabled=use_amp):
            _, loss = model(x, y)
            loss = loss / grad_accum  # important

        scaler.scale(loss).backward()

        if (step + 1) % grad_accum == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        if step % log_every == 0:
            dt = time.time() - t0
            print(f"step {step:7d} | loss {(loss.item() * grad_accum):.4f} | lr {lr:.2e} | {dt:.1f}s")
            t0 = time.time()

        if step % eval_every == 0 and step > 0:
            val_loss = evaluate(model, val_loader, max_batches=50)
            print(f"  eval | val_loss {val_loss:.4f} | ppl {math.exp(val_loss):.2f}")

        if step % save_every == 0 and step > 0:
            path = save_checkpoint(step, model, optimizer, config)
            print(f"  ✅ saved: {path}")

        step += 1

    if step >= max_steps:
        break


  with autocast(enabled=use_amp):


step       0 | loss 10.9493 | lr 0.00e+00 | 0.3s
step      50 | loss 10.3451 | lr 7.50e-06 | 7.6s
step     100 | loss 9.7374 | lr 1.50e-05 | 7.2s
step     150 | loss 9.5159 | lr 2.25e-05 | 7.3s
step     200 | loss 9.1871 | lr 3.00e-05 | 7.6s
step     250 | loss 8.7781 | lr 3.75e-05 | 7.3s
step     300 | loss 8.0904 | lr 4.50e-05 | 7.4s
step     350 | loss 7.8454 | lr 5.25e-05 | 7.6s
step     400 | loss 7.4752 | lr 6.00e-05 | 7.3s
step     450 | loss 7.1099 | lr 6.75e-05 | 7.1s
step     500 | loss 6.8796 | lr 7.50e-05 | 7.2s
step     550 | loss 7.0984 | lr 8.25e-05 | 7.3s
step     600 | loss 6.9429 | lr 9.00e-05 | 7.3s
step     650 | loss 6.8454 | lr 9.75e-05 | 7.3s
step     700 | loss 6.5779 | lr 1.05e-04 | 7.3s
step     750 | loss 6.5341 | lr 1.12e-04 | 7.3s
step     800 | loss 6.4674 | lr 1.20e-04 | 7.2s
step     850 | loss 6.3940 | lr 1.28e-04 | 7.3s
step     900 | loss 6.5261 | lr 1.35e-04 | 7.4s
step     950 | loss 6.3990 | lr 1.42e-04 | 7.2s
step    1000 | loss 6.4525 | lr 1.50e-

KeyboardInterrupt: 