## Imports and setup

In [1]:
import sys
# Fail fast if no entry of sys.path mentions "deep_learning_curriculum"; the
# project-local imports below (`config`, `model`) presumably rely on that
# directory being importable -- TODO confirm.
assert any("deep_learning_curriculum" in p for p in sys.path)

In [2]:
from dataclasses import dataclass, field
import random

import numpy as np
import torch as t
from torch import functional as F, nn, optim
from torch.optim.optimizer import Optimizer
from tqdm import tqdm


from config import Config
from model import Transformer

## Train the model on reversing random tokens

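I don't know what I'm doing wrong, but I was unable to get the model above ~35% accuracy. (Note: the training log below passes 70% accuracy by epoch 340, so this remark may predate the current run.)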

In [3]:
# Model (and task) configuration
cfg = Config(
    d_model=128,
    d_vocab=12,
    n_layers=3,
    n_heads=4,
    n_ctx=12,
)

# Seed the RNGs so that "Restart Kernel -> Run All" regenerates the same
# dataset (Python `random`) and the same torch random state for later cells.
SEED = 0
random.seed(SEED)
t.manual_seed(SEED)

# Every sequence is laid out as
#   [START, x_1, ..., x_k, MID, x_k, ..., x_1]
# The two highest vocabulary ids are reserved for the START / MID markers;
# the ordinary tokens x_i are drawn from 0..MAX_VOCAB_TOKEN.
START_TOKEN = cfg.d_vocab - 1
MID_TOKEN = cfg.d_vocab - 2
MAX_VOCAB_TOKEN = cfg.d_vocab - 3
SEQ_LEN = (cfg.n_ctx - 1) // 2
MID_TOKEN_POS = SEQ_LEN + 1

# The layout above needs exactly 2 * SEQ_LEN + 2 positions; an odd n_ctx would
# otherwise only surface as a shape error while filling `data`.
assert 2 * SEQ_LEN + 2 == cfg.n_ctx, f"{cfg.n_ctx = } does not fit [START, seq, MID, reversed seq]"
# random.sample draws without replacement, so it needs enough distinct tokens.
assert SEQ_LEN <= MAX_VOCAB_TOKEN + 1, f"{SEQ_LEN = }; {MAX_VOCAB_TOKEN = }"

BATCH_SIZE = 32
N_BATCHES = 4
N = BATCH_SIZE * N_BATCHES  # total number of sequences

data = t.zeros(N, cfg.n_ctx).to(dtype=t.int64)

for i in tqdm(range(N), desc="Generating data"):
    seq0 = random.sample(range(0, MAX_VOCAB_TOKEN + 1), k=SEQ_LEN)
    seq1 = seq0[::-1]
    data[i, :] = t.tensor([START_TOKEN, *seq0, MID_TOKEN, *seq1])

# Group the flat list of sequences into batches: (N_BATCHES, BATCH_SIZE, n_ctx).
data = data.reshape(-1, BATCH_SIZE, cfg.n_ctx)

print(f"{data.shape = }")

# Validate every sequence, not just the first one.
assert (data[..., 0] == START_TOKEN).all()
assert (data[..., MID_TOKEN_POS] == MID_TOKEN).all()
assert (data[..., MID_TOKEN_POS + 1:] == data[..., 1:MID_TOKEN_POS].flip(-1)).all()

Generating data: 100%|██████████| 128/128 [00:00<00:00, 52485.18it/s]

data.shape = torch.Size([4, 32, 12])





In [4]:
def loss_fn(logits: t.Tensor, tokens: t.Tensor) -> t.Tensor:
    """Mean next-token negative log-likelihood.

    The logits at position ``p`` are scored against the token at position
    ``p + 1``; the logits at the final position have no target and are dropped.

    Args:
        logits: indexed as (batch, position, vocab).
        tokens: integer token ids indexed as (batch, position).

    Returns:
        Scalar tensor: the negated mean log-probability assigned to the
        actual next tokens.
    """
    predictions = logits[:, :-1]
    targets = tokens[:, 1:]
    log_probs = t.log_softmax(predictions, dim=-1)
    # Pick out, for every position, the log-probability of the true next token.
    target_log_probs = t.gather(log_probs, dim=-1, index=targets[..., None]).squeeze(-1)
    return -target_log_probs.mean()

def acc_fn(logits: t.Tensor, tokens: t.Tensor) -> float:
    """Fraction of correctly predicted tokens after the MID marker.

    Only positions ``MID_TOKEN_POS + 1`` onwards are scored. Each is compared
    with the argmax of the logits one position earlier.

    Args:
        logits: indexed as (batch, position, vocab).
        tokens: integer token ids indexed as (batch, position).

    Returns:
        Accuracy as a Python float (computed in float64).
    """
    targets = tokens[:, MID_TOKEN_POS + 1:]
    predicted = logits[:, MID_TOKEN_POS:-1].argmax(dim=-1)
    hits = targets == predicted
    return hits.mean(dtype=t.float64).item()

In [5]:
@dataclass(frozen=True, slots=True)
class TrainingHistory:
    losses: list[float] = field(default_factory=list)
    accuracies: list[float] = field(default_factory=list)

    def __post_init__(self) -> None:
        assert len(self.losses) == len(self.accuracies)
    
    def __len__(self) -> int:
        return len(self.losses)

    def append(self, loss: float, acc: float) -> None:
        assert isinstance(loss, float)
        assert isinstance(acc, float)
        self.losses.append(loss)
        self.accuracies.append(acc)

def train(
    model: nn.Module,
    data: t.Tensor,
    optimizer: Optimizer,
    *,
    n_epochs: int = 20,
    log_each_epochs: int | None = 10,
    scheduler: optim.lr_scheduler.LRScheduler | optim.lr_scheduler.ReduceLROnPlateau | None = None, # type: ignore
    early_stopping_epochs: int | None = 3
) -> TrainingHistory:
    """Train ``model`` on pre-batched ``data`` and record per-epoch metrics.

    Args:
        model: called as ``model(batch)``; its output is passed to ``loss_fn``
            and ``acc_fn`` together with the batch.
        data: iterated along its first dimension; each item is one batch.
        optimizer: optimizer over the model's parameters.
        n_epochs: maximum number of passes over ``data``.
        log_each_epochs: print metrics every this many epochs; ``None`` (or 0)
            disables logging.
        scheduler: optional LR scheduler, stepped once per epoch. A
            ``ReduceLROnPlateau`` is stepped with the epoch's mean loss, any
            other scheduler without arguments.
        early_stopping_epochs: stop once the epoch accuracy has been exactly
            1.0 for this many consecutive epochs; ``None`` (or 0) disables
            early stopping.

    Returns:
        A ``TrainingHistory`` with one (mean loss, mean accuracy) entry per
        completed epoch.
    """
    th = TrainingHistory()

    for epoch_i in range(n_epochs):

        batch_losses: list[float] = []
        batch_accs: list[float] = []

        for batch in data:
            # Clear gradients *before* the backward pass, so gradients left
            # over from anything run before this call cannot leak into the
            # first update.
            optimizer.zero_grad()
            batch_logits = model(batch)
            batch_loss = loss_fn(batch_logits, batch)
            batch_loss.backward()
            optimizer.step()
            batch_losses.append(batch_loss.item())
            batch_accs.append(acc_fn(batch_logits, batch))

        # `.item()` turns the numpy scalars into plain floats, which
        # TrainingHistory.append insists on.
        loss = np.mean(batch_losses).item()
        acc = np.mean(batch_accs).item()

        if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(loss)
        elif scheduler is not None:
            scheduler.step()

        th.append(loss=loss, acc=acc)

        # Truthiness check: both None and 0 disable logging (0 would otherwise
        # raise ZeroDivisionError in the modulo).
        if log_each_epochs and epoch_i % log_each_epochs == 0:
            print(f"[{epoch_i}] {loss=:.3f}; {acc=:.2%}")

        if (
            early_stopping_epochs
            and len(th) >= early_stopping_epochs
            and all(a == 1 for a in th.accuracies[-early_stopping_epochs:])
        ):
            # epoch_i is 0-based, so epoch_i + 1 epochs have been completed.
            print(f"Leaving early after {epoch_i + 1} epochs. Accuracy has stayed at 100% for the last {early_stopping_epochs} epochs.")
            break

    return th

In [6]:
# Training hyper-parameters.
N_EPOCHS = 1000
LR = 1e-3

log_each_epochs = 10 # N_EPOCHS // 50

model = Transformer(cfg)
optimizer = optim.AdamW(model.parameters(), lr=LR)
# NOTE(review): per the PyTorch docs, LambdaLR sets lr = initial_lr * lambda(epoch);
# it does not compound onto the current lr. If so, this lambda keeps lr at LR
# except for one-epoch dips at epochs 200, 300, ..., 900 (factor epoch / 1000),
# after which lr snaps back to LR. If a persistent decay was intended,
# MultiplicativeLR or StepLR may be what was meant -- confirm intent.
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda i: i / 1000 if i % 100 == 0 and i > 100 else 1)
# Alternative schedulers that were tried (kept for reference, not executed):
# scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda epoch_i: max(i/1))
# scheduler = None # optim.lr_scheduler.LambdaLR(optimizer, lambda i: max(i/100, 1), verbose=False)
# optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=10)
# optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.2, last_epoch=-1)

In [7]:
# Run the training loop; `th` holds the per-epoch loss / accuracy history.
th = train(
    model,
    data,
    optimizer,
    scheduler=scheduler,
    n_epochs=N_EPOCHS,
    log_each_epochs=log_each_epochs,
)

[0] loss=2.481; acc=9.38%
[10] loss=2.129; acc=14.53%
[20] loss=2.077; acc=16.56%
[30] loss=2.376; acc=16.88%
[40] loss=2.073; acc=16.56%
[50] loss=2.059; acc=15.47%
[60] loss=2.054; acc=17.03%
[70] loss=2.048; acc=15.78%
[80] loss=2.042; acc=15.94%
[90] loss=2.037; acc=14.84%
[100] loss=2.032; acc=15.47%
[110] loss=2.027; acc=16.25%
[120] loss=2.017; acc=15.94%
[130] loss=2.005; acc=16.88%
[140] loss=1.992; acc=15.94%
[150] loss=1.964; acc=20.16%
[160] loss=2.009; acc=16.09%
[170] loss=1.995; acc=16.41%
[180] loss=1.982; acc=17.50%
[190] loss=1.952; acc=18.44%
[200] loss=1.943; acc=20.62%
[210] loss=1.924; acc=21.41%
[220] loss=1.907; acc=24.06%
[230] loss=1.947; acc=21.41%
[240] loss=1.876; acc=22.97%
[250] loss=1.853; acc=27.97%
[260] loss=1.809; acc=30.31%
[270] loss=1.773; acc=35.31%
[280] loss=1.714; acc=39.53%
[290] loss=1.679; acc=39.84%
[300] loss=1.597; acc=44.06%
[310] loss=1.520; acc=50.62%
[320] loss=1.435; acc=56.25%
[330] loss=1.316; acc=63.44%
[340] loss=1.273; acc=70.1

## Save the model

In [8]:
from datetime import datetime

# Minute-resolution timestamp written with "-" instead of ":" between hours
# and minutes (presumably to keep the filename portable across filesystems).
timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M")
filename = f"model_reversal_{timestamp}.pt"

# Only the weights (state dict) are written, not the module object itself.
t.save(model.state_dict(), filename)

## Test

In [9]:
def generate(tokens: t.Tensor, n_next_tokens: int = 1, *, verbose: bool = False) -> t.Tensor:
    """Greedily extend ``tokens`` by ``n_next_tokens`` tokens.

    Uses the notebook-level ``model``. At every step the model is run on the
    whole sequence so far and the argmax at the last position is appended.

    Args:
        tokens: 2-D tensor of token ids, indexed as (batch, position).
        n_next_tokens: number of tokens to append; must be positive.
        verbose: show a tqdm progress bar over the generation steps.

    Returns:
        The input tokens with ``n_next_tokens`` extra columns appended.
    """
    assert tokens.ndim == 2
    assert n_next_tokens > 0
    # Pure inference: skip autograd bookkeeping so no graph is built (and kept
    # alive) for every forward pass.
    with t.no_grad():
        for _ in tqdm(range(n_next_tokens), disable=not verbose):
            logits = model(tokens)
            # Most likely token at the final position, kept as a (batch, 1) column.
            next_tokens = logits.argmax(-1)[..., -1:]
            tokens = t.cat([tokens, next_tokens], dim=-1)
    return tokens

In [10]:
# Collapse the batch dimensions so each row of `data_flat` is one sequence.
data_flat = data.flatten(start_dim=0, end_dim=-2)

In [11]:
# Consistency check between the two ways of evaluating the model: report any
# sequence where greedy generation from the prefix gets the reversed half
# wrong even though teacher-forced accuracy (`acc_fn`) on the same sequence is
# perfect. Nothing being printed means the two agree.
with t.no_grad():  # inference only -- no need to build autograd graphs
    for i in range(N):
        seq = data_flat[i].reshape(1, -1)
        # Prefix = everything up to and including the MID marker.
        seq_pre = seq[:, :MID_TOKEN_POS + 1]
        seq_post = seq[:, MID_TOKEN_POS + 1:]
        n_next_tokens = seq_post.size(-1)

        generated = generate(seq_pre, n_next_tokens)
        pred = generated[:, -n_next_tokens:]
        logits = model(seq)
        acc = acc_fn(logits, seq)

        if (seq_post != pred).any() and acc == 1:
            print(f"[{i}]")
            print(f"seq: {seq.tolist()}")
            print(f"post: {seq_post.tolist()}")
            print(f"pred: {pred.tolist()}")
            print()
        


In [12]:
# Inspect a single example: feed the prefix (up to and including the MID
# marker) to `generate` and compare its continuation with the true second half.
i = 3
seq = data_flat[i].unsqueeze(0)
split_at = MID_TOKEN_POS + 1
seq_pre, seq_post = seq[:, :split_at], seq[:, split_at:]
n_next_tokens = seq_post.shape[-1]

pred = generate(seq_pre, n_next_tokens)[:, -n_next_tokens:]

print("seq:", seq.tolist())
print("pre:", seq_pre.tolist())
print("post:", seq_post.tolist())
print("pred:", pred.tolist())

seq: [[11, 5, 9, 1, 6, 0, 10, 0, 6, 1, 9, 5]]
pre: [[11, 5, 9, 1, 6, 0, 10]]
post: [[0, 6, 1, 9, 5]]
pred: [[0, 6, 1, 9, 5]]
