In [1]:
import torch
from tqdm import tqdm

from preprocessing import load_tokenizer_and_dataset
from GPT import GPT
from parameters import (
    BATCH_SIZE,
    device,
    MAX_STEPS,
    LEARNING_RATE,
    EVAL_INTERVAL,
    EVAL_LOSS_BATCHES
)

In [2]:
# Load the raw parallel corpus: one tab-separated record per line.
with open("data/translation.txt", "r", encoding="utf-8") as corpus_file:
    data = corpus_file.read().splitlines()

# Keep only the (English, German) pair of each record; the third field is discarded.
# Unpacking into exactly three names still raises ValueError on a malformed line.
translations = []
for record in data:
    eng_sentence, ger_sentence, _src = record.split("\t")
    translations.append((eng_sentence, ger_sentence))

In [3]:
tokenizer, eng_data, ger_data = load_tokenizer_and_dataset("models/tokenizer.pkl")
# Row i of eng_data is paired with row i of ger_data, so both must have the same length.
assert eng_data.size(0) == ger_data.size(0), "English/German datasets differ in length"

# Shuffle before splitting so train/val/test each contain both short and long sequences.
# The permutation comes from a fixed-seed generator, so the split is identical after every
# kernel restart. With an unseeded shuffle, a checkpoint trained in one session and reloaded
# in another is evaluated on "val"/"test" rows it may have been trained on.
# NOTE(review): an existing checkpoint was trained on whatever split its own session drew;
# retrain with this seed to get a genuinely held-out val/test set.
SPLIT_SEED = 42
TRAIN_END, VAL_END = 0.8, 0.9  # cumulative fractions: 80% train, 10% val, 10% test

split_rng = torch.Generator().manual_seed(SPLIT_SEED)
shuffle = torch.randperm(eng_data.size(0), generator=split_rng)
eng_data, ger_data = eng_data[shuffle], ger_data[shuffle]

n_split1, n_split2 = int(eng_data.size(0) * TRAIN_END), int(eng_data.size(0) * VAL_END)
train_data = (eng_data[:n_split1], ger_data[:n_split1])
val_data = (eng_data[n_split1:n_split2], ger_data[n_split1:n_split2])
test_data = (eng_data[n_split2:], ger_data[n_split2:])

# Id of the padding token, found by reverse lookup in the id -> bytes vocabulary.
pad_token = next((idx for idx, token in tokenizer.vocab.items() if token == b"<|PAD|>"), None)
if pad_token is None:
    raise KeyError("tokenizer vocabulary has no <|PAD|> token")

In [4]:
english, german = translations[-100]
eng_enc, ger_enc = tokenizer.encode(english), tokenizer.encode(german)

print("_".join([tokenizer.vocab[idx].decode("utf-8") for idx in eng_enc]))
print("_".join([tokenizer.vocab[idx].decode("utf-8") for idx in ger_enc]))

You_ always_ have_ the_ right_ to_ ref_use_ treat_ment_,_ how_ever_,_ I_ must_ explain_ the_ potent_ial_ con_se_qu_en_ces_ if_ that_ will_ be_ your_ choice_.
Sie_ können_ die_ Be_hand_l_ung_ jederzeit_ able_hnen_;_ aller_d_ings_ muss_ ich_ Sie_ in_ diesem_ F_all_ über_ die_ möglich_en_ Aus_w_ir_k_ungen_ aufkl_ären_.


In [5]:
# Loader that returns a batch
def get_batch(split):
    """Sample a random batch of (encoder_input, decoder_input, target) from one split.

    Args:
        split: "train", "val" or "test".

    Returns:
        Tuple of three tensors moved to `device`, each with BATCH_SIZE rows:
        encoder_input is the English rows unchanged; decoder_input is the German rows
        without their last column; target is the German rows without their first
        column, i.e. decoder_input shifted left by one position.

    Raises:
        ValueError: if `split` is not a known split name. Previously any unknown name
            silently fell through to the test set.
    """
    splits = {"train": train_data, "val": val_data, "test": test_data}
    if split not in splits:
        raise ValueError(f"unknown split {split!r}; expected one of {sorted(splits)}")
    src_rows, tgt_rows = splits[split]

    ix = torch.randint(0, len(src_rows), (BATCH_SIZE, ))
    # Advanced indexing gathers all sampled rows at once. This is equivalent to stacking
    # them one by one; it assumes src_rows/tgt_rows are (at least) 2-D tensors, as the
    # row-wise slicing in this loader already required.
    encoder_input = src_rows[ix]
    decoder_input = tgt_rows[ix, :-1]
    target = tgt_rows[ix, 1:]
    return (encoder_input.to(device), decoder_input.to(device), target.to(device))

In [6]:
enc, dec, tar = get_batch("train")

print(tokenizer.decode(enc[0].tolist()))
print(tokenizer.decode(dec[0].tolist()))
print(tokenizer.decode(tar[0].tolist()))

There were stacks of books all over the floor.<|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|>
<|STARTOFTEXT|>Auf dem Fußboden waren überall Bücher gestapelt.<|ENDOFTEXT|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|>
Auf dem Fußboden waren überall Bücher gestapelt.<|ENDOFTEXT|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|><|PAD|>


In [7]:
# calculate mean loss for {EVAL_LOSS_BATCHES}x batches
@torch.no_grad()
def estimate_loss():
    """Estimate the mean loss of the global `model` on the train and val splits.

    Averages the loss over EVAL_LOSS_BATCHES random batches per split drawn with
    `get_batch`; `pad_token` is passed to the model as `ignore_index`. Runs without
    gradient tracking and in eval mode.

    The model's previous train/eval mode is restored on exit, even if an exception is
    raised. Unconditionally calling `model.train()` would flip a model that was in eval
    mode (e.g. loaded for inference) into training mode.

    Returns:
        dict mapping "train" and "val" to a 0-dim tensor holding the mean loss.
    """
    out = {}
    was_training = model.training
    model.eval()
    try:
        for split in ["train", "val"]:
            losses = torch.zeros(EVAL_LOSS_BATCHES, device=device)
            for i in tqdm(range(EVAL_LOSS_BATCHES), desc=f"eval {split}", leave=False):
                enc_input, dec_input, target = get_batch(split)
                _, loss = model(enc_input, dec_input, target, ignore_index=pad_token)
                losses[i] = loss.item()
            out[split] = losses.mean()
    finally:
        model.train(was_training)
    return out

In [8]:
TRAIN_MODEL = False  # True: train from scratch; False: load weights from ./models/gpt.pth

model = GPT(tokenizer)
# Move to the target device in *both* branches. load_state_dict copies the checkpoint
# into the existing parameters, so map_location alone leaves a CPU-built model on CPU.
model.to(device)

if TRAIN_MODEL:
    # Training loop
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    for step in tqdm(range(MAX_STEPS)):
        # Report the mean train/val loss every EVAL_INTERVAL steps.
        if step % EVAL_INTERVAL == 0:
            losses = estimate_loss()
            print(f"Step {step}/{MAX_STEPS} train: {losses['train']:.4f}, val: {losses['val']:.4f}")

        enc_input, dec_input, target = get_batch("train")

        _, loss = model(enc_input, dec_input, target, ignore_index=pad_token)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
else:
    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    # Newer torch versions accept weights_only=True for plain state_dicts.
    model.load_state_dict(torch.load("./models/gpt.pth", map_location=torch.device(device)))

# Switch off training-only behaviour (e.g. dropout, if GPT uses it) before inference.
model.eval()

In [18]:
# Qualitative check: translate a single English sentence with the current model.
sample_sentence = "Ok, this is still pretty bad."
model.translate(sample_sentence)

'Nun, das ist noch immer ganz übel.'