In [7]:
import math, torch, random
from torch import nn
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from transformers import GPT2TokenizerFast

In [8]:
random.seed(0); torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
BLOCK = 256      # sequence length (tokens)
BATCH = 16       # batch size
EPOCHS = 3
LR = 3e-4

tok = GPT2TokenizerFast.from_pretrained("gpt2")
if tok.pad_token is None: tok.pad_token = tok.eos_token

In [10]:
def make_chunks(split):
    ds = load_dataset("wikitext", "wikitext-2-raw-v1", split=split)
    text = tok.eos_token.join(ds["text"])
    ids = tok(text, add_special_tokens=False)["input_ids"]
    # we want input length BLOCK and target length BLOCK, so make windows of BLOCK+1 then split
    L = (len(ids) // (BLOCK + 1)) * (BLOCK + 1)
    ids = ids[:L]
    chunks = [ids[i:i+BLOCK+1] for i in range(0, L, BLOCK+1)]
    return chunks

In [11]:
class LMDataset(Dataset):
    def __init__(self, chunks):
        self.x = [torch.tensor(c[:-1], dtype=torch.long) for c in chunks]
        self.y = [torch.tensor(c[1:],  dtype=torch.long) for c in chunks]
    def __len__(self): return len(self.x)
    def __getitem__(self, i): return self.x[i], self.y[i]

In [14]:
train_chunks = make_chunks("train")
val_chunks   = make_chunks("validation")
train_loader = DataLoader(LMDataset(train_chunks), batch_size=BATCH, shuffle=True)
val_loader   = DataLoader(LMDataset(val_chunks),   batch_size=BATCH, shuffle=False)

vocab_size = len(tok)

Could not cache non-existence of file. Will ignore error and continue. Error: [Errno 13] Permission denied: '/mloscratch/hf_cache/hub/datasets--wikitext/.no_exist/b08601e04326c79dfdd32d625aee71d232d685c3/wikitext.py'
Could not cache non-existence of file. Will ignore error and continue. Error: [Errno 13] Permission denied: '/mloscratch/hf_cache/hub/datasets--wikitext/.no_exist/b08601e04326c79dfdd32d625aee71d232d685c3/.huggingface.yaml'
Could not cache non-existence of file. Will ignore error and continue. Error: [Errno 13] Permission denied: '/mloscratch/hf_cache/hub/datasets--wikitext/.no_exist/b08601e04326c79dfdd32d625aee71d232d685c3/dataset_infos.json'
Could not cache non-existence of file. Will ignore error and continue. Error: [Errno 13] Permission denied: '/mloscratch/hf_cache/hub/datasets--wikitext/.no_exist/b08601e04326c79dfdd32d625aee71d232d685c3/wikitext.py'
Could not cache non-existence of file. Will ignore error and continue. Error: [Errno 13] Permission denied: '/mloscratc

In [17]:
class TinyGPT(nn.Module):
    def __init__(self, vocab, d_model=256, n_layers=4, n_heads=8, d_ff=1024, block=BLOCK):
        super().__init__()
        self.block = block
        self.tok_emb = nn.Embedding(vocab, d_model)
        self.pos_emb = nn.Embedding(block, d_model)
        enc_layer = nn.TransformerEncoderLayer(d_model, n_heads, d_ff, batch_first=True)
        self.tr = nn.TransformerEncoder(enc_layer, num_layers=n_layers)
        self.lm_head = nn.Linear(d_model, vocab, bias=False)
        # tie weights
        self.lm_head.weight = self.tok_emb.weight
        
    def _causal_mask(self, L):
        # [L, L] upper-triangular mask True where we want to mask (future tokens)
        m = torch.ones(L, L, dtype=torch.bool, device=self.lm_head.weight.device).triu(1)
        return m

    def forward(self, x):
        B, L = x.shape
        pos = torch.arange(L, device=x.device).unsqueeze(0).expand(B, L)
        h = self.tok_emb(x) + self.pos_emb(pos)
        mask = self._causal_mask(L)
        h = self.tr(h, mask=mask)
        return self.lm_head(h)

In [18]:
model = TinyGPT(vocab_size).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=LR)
loss_fn = nn.CrossEntropyLoss()

In [19]:
@torch.no_grad()
def evaluate():
    model.eval()
    losses = []
    for x, y in val_loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)                      # [B, L, V]
        loss = loss_fn(logits.view(-1, logits.size(-1)), y.view(-1))
        losses.append(loss.item())
    m = sum(losses)/len(losses)
    ppl = math.exp(min(20.0, m))
    return m, ppl

In [20]:
for epoch in range(1, EPOCHS+1):
    model.train()
    running = 0.0
    for i, (x, y) in enumerate(train_loader, start=1):
        x, y = x.to(device), y.to(device)
        logits = model(x)                                # [B, L, V]
        loss = loss_fn(logits.view(-1, logits.size(-1)), y.view(-1))
        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()
        running += loss.item()
        if i % 100 == 0:
            print(f"epoch {epoch} step {i}: train_loss={running/100:.4f}")
            running = 0.0
    val_loss, val_ppl = evaluate()
    print(f"==> epoch {epoch}: val_loss={val_loss:.4f}  val_ppl={val_ppl:.2f}")

epoch 1 step 100: train_loss=36.2972
epoch 1 step 200: train_loss=23.2266
epoch 1 step 300: train_loss=17.6469
epoch 1 step 400: train_loss=14.3102
epoch 1 step 500: train_loss=12.2759
==> epoch 1: val_loss=9.8756  val_ppl=19450.40
epoch 2 step 100: train_loss=10.1666
epoch 2 step 200: train_loss=9.5792
epoch 2 step 300: train_loss=9.0992
epoch 2 step 400: train_loss=8.7476
epoch 2 step 500: train_loss=8.4784
==> epoch 2: val_loss=7.8337  val_ppl=2524.35
epoch 3 step 100: train_loss=8.0183
epoch 3 step 200: train_loss=7.9145
epoch 3 step 300: train_loss=7.7985
epoch 3 step 400: train_loss=7.7028
epoch 3 step 500: train_loss=7.6226
==> epoch 3: val_loss=7.3498  val_ppl=1555.94
