In [1]:
import torch
import random
import pickle
import os
from tqdm import tqdm

from preprocessing import main as preprocess_data
from GPT import GPT
from parameters import (
    GPT4_SPLIT_PATTERN,
    VOCAB_SIZE,
    SPECIAL_TOKENS,
    CONTEXT_SIZE,
    BATCH_SIZE,
    device,
    MAX_STEPS,
    LEARNING_RATE,
    EVAL_INTERVAL,
    EVAL_LOSS_BATCHES
)

In [None]:
with open("data/translation.txt", "r", encoding="utf-8") as f:
    data = f.read().splitlines()

# load data
translations = []
for sample in data:
    english, german, src = sample.split("\t")
    translations.append((english, german))

In [5]:
tokenizer, eng_data, ger_data = preprocess_data(translations)

n_split1, n_split2 = int(len(translations) * 0.8), int(len(translations) * 0.9)
train_data = (eng_data[:n_split1], ger_data[:n_split1])
val_data = (eng_data[n_split1:n_split2], ger_data[n_split1:n_split2])
test_data = (eng_data[n_split2:], ger_data[n_split2:])

  0%|          | 0/277891 [00:00<?, ?it/s]

100%|██████████| 277891/277891 [3:07:26<00:00, 24.71it/s]  


In [None]:
english, german = translations[-100]
eng_enc, ger_enc = tokenizer.encode(english), tokenizer.encode(german)

print("_".join([tokenizer.vocab[idx].decode("utf-8") for idx in eng_enc]))
print("_".join([tokenizer.vocab[idx].decode("utf-8") for idx in ger_enc]))

In [6]:
eng, ger = train_data[0][0].tolist(), train_data[1][0].tolist()
tokenizer.decode(eng), tokenizer.decode(ger)

('Go.<|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|>',
 'Geh.<|ENDOFTEXT|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|>')

In [7]:
# Loader that returns a batch
def get_batch(split):
    data = train_data if split == "train" else val_data if split == "val" else test_data
    ix = torch.randint(0, len(data), (BATCH_SIZE, ))
    
    encoder_input = torch.stack([data[0][i] for i in ix]) # [BATCH_SIZE, CONTEXT_SIZE]
    decoder_input = torch.stack([data[1][i][:-1] for i in ix]) # [BATCH_SIZE, CONTEXT_SIZE]
    target = torch.stack([data[1][i][1:] for i in ix]) # [BATCH_SIZE, CONTEXT_SIZE]
    encoder_input, decoder_input, target = encoder_input.to(device), decoder_input.to(device), target.to(device)
    return (encoder_input, decoder_input, target)

In [8]:
enc, dec, tar = get_batch("train")

print(tokenizer.decode(enc[0].tolist()))
print(tokenizer.decode(dec[0].tolist()))
print(tokenizer.decode(tar[0].tolist()))

Go.<|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|>
Geh.<|ENDOFTEXT|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|>
.<|ENDOFTEXT|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|>


In [9]:
# calculate mean loss for {EVAL_LOSS_BATCHES}x batches
@torch.no_grad()
def estimate_loss():
    global model
    out = {}
    model.eval()
    for split in ["train", "val"]:
        losses = torch.zeros(EVAL_LOSS_BATCHES, device=device)
        for i in tqdm(range(EVAL_LOSS_BATCHES)):
            enc_input, dec_input, target = get_batch(split)
            _, loss = model(enc_input, dec_input, target)
            losses[i] = loss.item()
        out[split] = losses.mean()  
    model.train()
    return out

In [None]:
model = GPT()
model.to(device)

""" Training Loop """
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

for step in tqdm(range(MAX_STEPS)):
    # calculate loss every once in a while
    if step % EVAL_INTERVAL == 0:
        losses = estimate_loss()
        print(f"Step {step}/{MAX_STEPS}) train: {losses['train']:.4f}, val: {losses['val']:.4f}")

    enc_input, dec_input, target = get_batch("train")
    logits, loss = model(enc_input, dec_input, target)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()