Building GPT-2 from scratch in PyTorch, loading the pretrained GPT-2 weights, and fine-tuning it (with LoRA) on the whole Divina Commedia (plus some sonnets) to produce style-specific output (tercets of hendecasyllables)


In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from collections import Counter
from dataclasses import dataclass
import inspect
import math

In [None]:
#install library including the gpt2 tokenizer
!pip install tiktoken



In [None]:
import tiktoken

In [None]:
# ---- Data / batching --------------------------------------------------------
# Batch size stays at 8: larger batches exceed the Colab GPU usage limits.
batch_size = 8
block_size = 1024  # context length in tokens (GPT-2 checkpoints use 1024)

# ---- Training schedule --------------------------------------------------------
max_iters = 5000
eval_interval = 500  # run a loss estimate every this many optimizer steps
eval_iters = 200     # random batches averaged per split in each estimate
learning_rate = 3e-4

# ---- Model shape (values replicate Karpathy's tutorial) -----------------------
# NOTE: when the model is built with GPT.from_pretrained the architecture comes
# from the checkpoint's own GPTConfig, so these four are not used by the cells
# shown here; they could also be chosen through cross-validation.
n_emb = 400
n_head = 10
n_layers = 10
dropout = 0.2

# ---- Early-stopping state -----------------------------------------------------
patience = 1  # evaluations without val improvement tolerated before stopping
best_val_loss = float('inf')
early_stop_counter = 0

# ---- Hardware -----------------------------------------------------------------
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

Loading the dataset and tokenizing it

In [None]:
import gdown

# Fetch the Dante corpus text file from Google Drive into the working directory.
file_id = "1qn_IL9W1ooscxWLPJAeUxNcvhCbcw1Xf"
url = "https://drive.google.com/uc?id=" + file_id
output = "dante-corpus.txt"
gdown.download(url=url, output=output, quiet=False)

Downloading...
From: https://drive.google.com/uc?id=1qn_IL9W1ooscxWLPJAeUxNcvhCbcw1Xf
To: /content/dante-corpus.txt
100%|██████████| 562k/562k [00:00<00:00, 69.3MB/s]


'dante-corpus.txt'

In [None]:
# Load the whole corpus into memory as a single string.
with open('dante-corpus.txt', encoding='utf-8') as f:
    text = f.read()
n_chars = len(text)
print('Length of dataset in characters:', n_chars)

Length of dataset in characters: 537063


In [None]:
# Peek at the end of the corpus to check that it was read completely.
# Slicing from the end (instead of a hard-coded absolute offset) keeps this
# cell meaningful if the corpus file changes length.
print(text[-63:])

ima tremando si riscosse
veggendo morto ’l cor nel lato manco.



In [None]:
# GPT-2 byte-pair encoding; the pretrained 'gpt2' weights loaded below are
# presumably trained against this same vocabulary.
encoder = tiktoken.get_encoding(encoding_name="gpt2")

In [None]:
# Encode the full corpus once into a 1-D int64 tensor of token ids.
token_ids = encoder.encode(text)
data = torch.tensor(token_ids, dtype=torch.long)
print(data.shape, data.dtype)
vocab_size = encoder.n_vocab  # number of distinct token ids the encoder can emit

torch.Size([238044]) torch.int64


In [None]:
# Contiguous split in corpus order: first 90% of tokens for training,
# the remaining 10% held out for validation.
n = int(len(data) * 0.9)
train_data, val_data = data[:n], data[n:]

Build the model, load the pretrained GPT-2 weights and fine-tune

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass

#LoRA class for fine tuning
class LoRA(nn.Module):
    """Low-rank additive adapter computing ``(x @ A @ B) * scaling``.

    ``A`` (``lora_a``) has shape (in_features, rank) and ``B`` (``lora_b``) has
    shape (rank, out_features). ``B`` starts at zero, so a fresh adapter outputs
    zeros and the wrapped pretrained layer is unchanged until training updates it.

    Args:
        in_features: size of the input's last dimension.
        out_features: size of the output's last dimension.
        rank: inner (bottleneck) dimension of the factorisation.
        alpha: optional LoRA scaling numerator; when given, the update is scaled
            by ``alpha / rank``. ``None`` (default) keeps a scale of 1.
    """

    def __init__(self, in_features, out_features, rank, alpha=None):
        super().__init__()
        self.rank = rank
        self.scaling = 1.0 if alpha is None else alpha / rank
        self.lora_a = nn.Parameter(torch.zeros(in_features, rank))
        self.lora_b = nn.Parameter(torch.zeros(rank, out_features))
        self.reset_parameters()

    def reset_parameters(self):
        # Mirror nn.Linear's default init for A: kaiming_uniform_(a=sqrt(5)) on a
        # (out, in) weight is U(-1/sqrt(in), 1/sqrt(in)). A is stored as
        # (in, rank), so calling kaiming_uniform_ on it directly would use
        # fan_in = rank and draw values ~sqrt(in / rank) times too large.
        fan_in = max(self.lora_a.size(0), 1)  # guard only against division by zero
        bound = 1.0 / math.sqrt(fan_in)
        nn.init.uniform_(self.lora_a, -bound, bound)
        nn.init.zeros_(self.lora_b)

    def forward(self, x):
        return (x @ self.lora_a @ self.lora_b) * self.scaling

class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention in the GPT-2 layout, with optional LoRA.

    Args:
        config: object providing ``n_embd`` and ``n_head`` (n_embd must be
            divisible by n_head).
        lora_rank: rank of the LoRA adapters on the q/k/v and output
            projections; ``0``, a negative value or ``None`` disables them.
    """

    def __init__(self, config, lora_rank=0):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # GPT.from_pretrained defaults lora_rank to None; treat that like 0
        # instead of crashing on `None > 0` (TypeError in Python 3).
        use_lora = lora_rank is not None and lora_rank > 0
        # key, query, value projections for all heads, fused into one Linear
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.lora_attn = LoRA(config.n_embd, 3 * config.n_embd, lora_rank) if use_lora else None
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.lora_proj = LoRA(config.n_embd, config.n_embd, lora_rank) if use_lora else None

        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x):
        """Attend over ``x`` of shape (B, T, n_embd); returns the same shape."""
        B, T, C = x.size()  # batch size, sequence length, embedding width
        head_size = C // self.n_head

        # Fused q/k/v projection, plus the optional low-rank LoRA update.
        qkv = self.c_attn(x)
        if self.lora_attn is not None:
            qkv = qkv + self.lora_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)
        # (B, T, C) -> (B, n_head, T, head_size): heads become a batch dimension.
        q = q.view(B, T, self.n_head, head_size).transpose(1, 2)
        k = k.view(B, T, self.n_head, head_size).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_size).transpose(1, 2)

        # is_causal=True masks out attention to future positions.
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # Re-assemble all head outputs side by side: back to (B, T, C).
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # Output projection, plus the optional LoRA update.
        out = self.c_proj(y)
        if self.lora_proj is not None:
            out = out + self.lora_proj(y)
        return out

class MLP(nn.Module):
    """Position-wise feed-forward sub-layer: widen 4x, tanh-approximate GELU, project back.

    Submodule names ``c_fc`` / ``c_proj`` match the keys that
    GPT.from_pretrained copies from the HuggingFace GPT-2 checkpoint.
    """

    def __init__(self, config):
        super().__init__()
        hidden = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden)
        self.gelu = nn.GELU(approximate='tanh')
        self.c_proj = nn.Linear(hidden, config.n_embd)

    def forward(self, x):
        return self.c_proj(self.gelu(self.c_fc(x)))

class Block(nn.Module):
    """One pre-LayerNorm transformer layer: attention then MLP, each with a residual add."""

    def __init__(self, config, lora_rank=0):
        super().__init__()
        # Attribute names line up with the GPT-2 checkpoint keys (h.N.ln_1, h.N.attn, ...).
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config, lora_rank=lora_rank)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        attended = x + self.attn(self.ln_1(x))
        return attended + self.mlp(self.ln_2(attended))

@dataclass
class GPTConfig:
    """Architecture hyperparameters; the defaults describe the 124M-parameter GPT-2."""

    block_size: int = 1024   # maximum sequence length (context window)
    vocab_size: int = 50257  # size of the token vocabulary
    n_layer: int = 12        # number of transformer blocks
    n_head: int = 12         # attention heads per block
    n_embd: int = 768        # embedding / residual-stream width

class GPT(nn.Module):
    """GPT-2 style decoder-only language model with optional LoRA adapters.

    Args:
        config: GPTConfig-like object (block_size, vocab_size, n_layer,
            n_head, n_embd).
        lora_rank: rank of the LoRA adapters added to every attention layer;
            ``0`` or ``None`` means no adapters.
    """

    def __init__(self, config, lora_rank=0):
        super().__init__()
        self.config = config
        # from_pretrained() defaults lora_rank to None; normalise it here so the
        # attention layers never evaluate `None > 0` (a TypeError in Python 3).
        lora_rank = lora_rank or 0

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            h = nn.ModuleList([Block(config, lora_rank=lora_rank) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # N(0, 0.02) init for every nn.Linear / nn.Embedding. LoRA parameters do
        # not belong to Linear/Embedding modules, so they keep their own init.
        self.apply(self._init_weights)

        # Weight tying: token embedding and output projection share one tensor.
        self.transformer.wte.weight = self.lm_head.weight

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights with N(0, 0.02) and zero Linear biases."""
        if isinstance(module, nn.Linear):
            std = 0.02
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute next-token logits (and optionally the loss).

        Args:
            idx: (B, T) tensor of token ids with T <= config.block_size.
            targets: optional (B, T) tensor of target token ids.

        Returns:
            ``(logits, loss)`` where logits is (B, T, vocab_size) and loss is the
            mean cross-entropy, or ``None`` when no targets are given.
        """
        B, T = idx.size()
        assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
        # forward the token and position embeddings
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device) # shape (T)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (T, n_embd)
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd)
        x = tok_emb + pos_emb
        # forward the blocks of the transformer
        for block in self.transformer.h:
            x = block(x)
        # forward the final layernorm and the classifier
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x) # (B, T, vocab_size)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    @classmethod
    def from_pretrained(cls, model_type, lora_rank=None):
        """Load pretrained GPT-2 weights from HuggingFace, with optional LoRA adapters.

        Args:
            model_type: one of 'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'.
            lora_rank: LoRA rank for every attention layer; ``None`` or ``0``
                builds the plain model without adapters.

        Returns:
            A ``GPT`` whose non-LoRA weights are copied from the checkpoint.
        """
        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
        from transformers import GPT2LMHeadModel
        print(f"Loading weights from pretrained GPT: {model_type}")

        # Define model-specific configuration
        config_args = {
            'gpt2':         dict(n_layer=12, n_head=12, n_embd=768),  # 124M params
            'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
            'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
            'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
        }[model_type]
        config_args['vocab_size'] = 50257  # Always 50257 for GPT model checkpoints
        config_args['block_size'] = 1024  # Always 1024 for GPT model checkpoints

        # Create a fresh model; GPT.__init__ maps lora_rank=None to "no LoRA".
        config = GPTConfig(**config_args)
        model = GPT(config, lora_rank=lora_rank)

        # Load state_dict of the HuggingFace model
        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
        sd_hf = model_hf.state_dict()

        # state_dict of our model; its tensors share storage with the parameters,
        # so the in-place copy_ calls below update the model directly.
        sd = model.state_dict()

        # Exclude LoRA layers: they have no counterpart in the checkpoint
        def is_lora_param(name):
            return any(lora_key in name for lora_key in ['lora_a', 'lora_b'])

        sd_keys = [k for k in sd.keys() if not k.endswith('.attn.bias') and not is_lora_param(k)]
        # Skip the checkpoint's attention mask buffers; they are not weights.
        sd_keys_hf = [k for k in sd_hf.keys() if not k.endswith('.attn.masked_bias') and not k.endswith('.attn.bias')]

        # Ensure matching number of parameters (excluding LoRA layers)
        assert len(sd_keys_hf) == len(sd_keys), f"Mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"

        # HuggingFace stores these as Conv1D weights (in, out); nn.Linear wants (out, in).
        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
        for k in sd_keys_hf:
            if any(k.endswith(w) for w in transposed):
                # Special handling for Conv1D weights
                assert sd_hf[k].shape[::-1] == sd[k].shape, f"Shape mismatch for {k}"
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k].t())
            else:
                # Direct copy for non-transposed weights
                assert sd_hf[k].shape == sd[k].shape, f"Shape mismatch for {k}"
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k])

        print("Weights loaded successfully. LoRA adapters (if any) keep their fresh initialization and remain trainable.")
        return model


In [None]:
model = GPT.from_pretrained('gpt2', lora_rank = 16)

Loading weights from pretrained GPT: gpt2


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Weights loaded successfully. LoRA layers remain uninitialized and trainable.


In [None]:
torch.manual_seed(1337)

def get_batch(split):
    """Sample ``batch_size`` random windows of ``block_size`` tokens.

    ``split == 'train'`` draws from ``train_data``; any other value draws from
    ``val_data``. Returns ``(x, y)`` on ``device``, both (batch_size, block_size),
    where ``y`` is ``x`` shifted one token to the right (next-token targets).
    """
    source = val_data if split != 'train' else train_data
    starts = torch.randint(len(source) - block_size, (batch_size,))
    # Build a (batch_size, block_size) index grid instead of stacking slices.
    offsets = starts[:, None] + torch.arange(block_size)
    x = source[offsets]
    y = source[offsets + 1]
    return x.to(device), y.to(device)

In [None]:
@torch.no_grad()
def estimate_loss():
    """Average the model's loss over ``eval_iters`` random batches per split.

    Puts the global ``model`` in eval mode while measuring and switches it back
    to train mode afterwards. Returns ``{'train': mean, 'val': mean}`` where each
    mean is a 0-dim tensor.
    """
    model.eval()
    out = {}
    for split in ('train', 'val'):
        split_losses = torch.zeros(eval_iters)
        for step in range(eval_iters):
            xb, yb = get_batch(split)
            _, batch_loss = model(xb, yb)
            split_losses[step] = batch_loss.item()
        out[split] = split_losses.mean()
    model.train()
    return out

In [None]:
m = model.to(device)
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# Train only the LoRA adapters and the (tied) token embedding / LM head.
# lm_head.weight and transformer.wte.weight are the same Parameter, so
# named_parameters() reports it once (as 'transformer.wte.weight'); the
# 'lm_head' key is kept so the intent survives if the tying is ever removed.
for name, param in model.named_parameters():
    param.requires_grad = any(key in name for key in ('lora', 'wte', 'lm_head'))

optimizer = torch.optim.AdamW([param for param in model.parameters() if param.requires_grad], lr=learning_rate)

# Reset early-stopping state here so re-running this cell starts fresh
# instead of inheriting best_val_loss / the counter from an earlier run.
best_val_loss = float('inf')
early_stop_counter = 0
best_trainable_state = None  # copy of the trainable weights at the best val loss

for step in range(max_iters):

    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss()
        val_loss = losses['val']
        print(f"Step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            early_stop_counter = 0  # Reset counter if validation loss improves
            # Only trainable tensors can change, so snapshotting them is enough
            # (mostly the tied embedding matrix plus the small LoRA factors).
            best_trainable_state = {
                name: param.detach().clone()
                for name, param in model.named_parameters()
                if param.requires_grad
            }
        else:
            early_stop_counter += 1  # Increment counter if no improvement

        # Check if we should stop early
        if early_stop_counter >= patience:
            print("Early stopping triggered. Stopping training.")
            break

    xb, yb = get_batch('train')

    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# Roll back to the best snapshot so later cells sample from the weights with the
# lowest validation loss, not from the last (possibly overfit) step.
if best_trainable_state is not None:
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in best_trainable_state:
                param.copy_(best_trainable_state[name])
    print(f"Restored weights from best validation loss {best_val_loss:.4f}")

125.324544 M parameters
Step 0: train loss 5.390614, val loss 5.3659
Step 500: train loss 2.773199, val loss 3.3514
Step 1000: train loss 2.255894, val loss 3.5944
Early stopping triggered. Stopping training.


In [None]:
@torch.no_grad()
def generate(idx, max_new_tokens, temperature=1.0, top_k=None, model=None, context_size=None):
    """Autoregressively sample ``max_new_tokens`` tokens after the prompt ``idx``.

    Args:
        idx: (B, T) tensor of prompt token ids on the model's device.
        max_new_tokens: number of tokens to append.
        temperature: logits are divided by this before the softmax; 1.0 keeps
            the model's distribution, lower values make sampling more conservative.
        top_k: if given, sample only among the ``top_k`` most likely tokens.
        model: model to sample from; defaults to the notebook-global ``m``. It
            must return ``(logits, loss)`` with logits shaped (B, T, vocab).
        context_size: maximum number of trailing tokens fed to the model;
            defaults to the global ``block_size``.

    Returns:
        Tensor of shape (B, T + max_new_tokens): the prompt followed by samples.
    """
    if model is None:
        model = m
    if context_size is None:
        context_size = block_size
    for _ in range(max_new_tokens):
        # Crop to the last `context_size` tokens; GPT.forward rejects longer inputs.
        idx_cond = idx[:, -context_size:]
        logits, _ = model(idx_cond)
        logits = logits[:, -1, :]  # only the prediction for the next position
        if temperature != 1.0:
            logits = logits / temperature
        if top_k is not None:
            kth_best = torch.topk(logits, min(top_k, logits.size(-1))).values[:, [-1]]
            logits = logits.masked_fill(logits < kth_best, float('-inf'))
        probs = F.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)
        idx = torch.cat((idx, idx_next), dim=1)
    return idx

In [None]:
# Sample from the current model (LoRA rank 16 + fine-tuned lm_head/wte),
# prompted with the opening line of the Inferno.
prompt = 'Nel mezzo del cammin di nostra vita'
context = torch.tensor([encoder.encode(prompt)], dtype=torch.long, device=device)
sample = generate(context, max_new_tokens=500)
print(encoder.decode(sample[0].tolist()))

Nel mezzo del cammin di nostra vita dal consiglio». Onde volsi,
quando noi passai, e ’l maestro i sostegno,
con sappi dentro a Dio con le piante.

Non avea sparte ancora il popesmo
pur a mano ad esso innanzi pien giorno,
e Pistoia canta sentia ne brutta.

Venimmo, spavam per sé stelle stelle
d’esparea verso a sé del cielo ad alleggiorno,
d’Mentore e cagion di trassel l’orreggiorno;

indi mi sieni la fronte intra Orbis devi,
né la testa via tue in sete tagliocco,
dimmi officio in quella legge ogne errori.

Quel quando fuor s’affaticar le fassai,
guardommi elli al fin d’i Troiani;
ma ’l maestro mio, che ’l dolor sì poco:

si perché perché sì li altri si scossai,
da qual ragioni a Vercelli etterno cui,
e traggasi quel chiaro in tutti smorto.

Pistoia di Dio mi fu detto e disfai
carcato quel mar di Dio, se cotali
(indi parler’ i Pfeiffer apparivanci
del Palladier d’i Troiani e dal Palladier,

quando mi parea gente non resti;
batteggianar: «Benedetto in etterno
ne l’alito e con campo d’erre

In [None]:
# The output shown below was pasted in from an earlier run (LoRA rank 8 +
# fine-tuned lm_head/wte) for comparison; re-running this cell samples from
# whatever model is currently loaded instead.
prompt = 'Nel mezzo del cammin di nostra vita'
context = torch.tensor([encoder.encode(prompt)], dtype=torch.long, device=device)
sample = generate(context, max_new_tokens=500)
print(encoder.decode(sample[0].tolist()))

Nel mezzo del cammin di nostra vita dal consolar
e di sangue, parlando privadi accraibo;
dicendo ’l pietro s’intelletto testaro
qua fossimo disïar con essere ivi’
andava in me: «Però tespecciar sicarsa’;
indi vid’ io l’altrimento a la mente’.

Merrenti ancor così ron com’ io concilio
dove sproni d’i pispone a la gente ruote
che l’una facea la mia donna.

Dio non salir lo scoglio mi dicere sgane,
e con l’una riman, tra ’l salire sgane,
e per lo scoglio m’ombra che volere.

Poi gente altro uscì come il ventre rote
dura, ch’uomo e forse vapor le membra
quand’ io lo uscì come antico era.

Nulla ruina, ch’è od oltre dolce leggera,
contra ’l terzo epicicurta ond’ el partito,
quivi e oramai; e quindi oramai osava;

come quella donna sù, di spera gloria,
a doppacciva fé levante dico era,
ma viva ferro spezzato in la leggella!

Con l’altra onde ci paura fummo il vecchio,
quanto per paura e per passi moderno
ove l’arco, del verbo de la dannia;

e or dolce con campo d’erbe l’ovra
per suo per le p

In [None]:
# The output shown below was pasted in from an earlier run that fine-tuned the
# full model and overfit (training loss ~0.6); re-running this cell samples
# from whatever model is currently loaded instead.
prompt = 'Nel mezzo del cammin di nostra vita'
context = torch.tensor([encoder.encode(prompt)], dtype=torch.long, device=device)
sample = generate(context, max_new_tokens=500)
print(encoder.decode(sample[0].tolist()))

Nel mezzo del cammin di nostra vita
qual negligenza, rifonde
su per lo monte che n’è gita.

Anima fatta la destra costa già da lagrime
sovr’ altrui sangue in natural vasello,
per che dintorno suonin mille tube,

chi move innanzi a la tua scïenza non è,
che cotai colpi per essemproche fierlte;

né credo che la vostra chiesa spense
a chi stato, l’onor d’ingegno maìista,
da un demonio già non si perde.

Ella ruina in sì fatta parte verde;
guarda’mi allor, come voi versi,
gittati fia glorïosi di giunchi».

«Oh, questi è quel ch’i’ v’ho scorte»,
rispuose lui, «tegnon qui ne nel cista;
però colà ciò vegnon qui ne l’aspetta.

Tu hai sì presso a l’ombra che procò l’onta
di quel ch’ebbe or così la lingua pronta,
in voce cangia, e tra voi imposto
da un discernal tutte altre ristette.

Per le nove radici d’esto legno
sovr’ altrui, ch’i’ ho veduto in parte
de l’altro, Micòl amor che dentro a r giuggi;

e noi movemmo l’assai suaisla e l’altro poli
tutto, qual che si mostrò e ciascuna,
per che ’l mi