In [1]:
import torch
import torch.optim as optim
import numpy as np
from transformer import Transformer
from metrics import Evaluator
from datasets import load_dataset
from utils import *
from training import *
from tokenizer import *

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device}")

  from .autonotebook import tqdm as notebook_tqdm


Using cpu


In [2]:
traindict = load_dataset("opus100", "de-en", split="train")
testdict = load_dataset("opus100", "de-en", split="test")

Found cached dataset opus100 (/Users/tonimo/.cache/huggingface/datasets/opus100/de-en/0.0.0/256f3196b69901fb0c79810ef468e2c4ed84fbd563719920b1ff1fdc750f7704)
Found cached dataset opus100 (/Users/tonimo/.cache/huggingface/datasets/opus100/de-en/0.0.0/256f3196b69901fb0c79810ef468e2c4ed84fbd563719920b1ff1fdc750f7704)


In [3]:
train_inputs, train_labels = get_split(traindict, "en", "de", size=100)
test_inputs, test_labels = get_split(testdict, "en", "de", size=10)
trainset = Dataset(train_inputs, train_labels)
testset = Dataset(test_inputs, test_labels)

In [4]:
trainframe = trainset.dataframe()
trainframe.head()

Unnamed: 0,inputs,labels
0,It's greed that it's gonna be the death of you...,Deine Habgier wird noch dein Tod sein.
1,Vega.,- Vega.
2,Just say when.,Sagen Sie einfach stopp.
3,- Wait.,- Warte.
4,I don't wanna be here.,Ich will nicht hier sein.


In [5]:
testframe = testset.dataframe()
testframe.head()

Unnamed: 0,inputs,labels
0,"By clicking on 'Save profile', you the user ag...",Die Nutzungsbedingungen werden durch das Klick...
1,I wanted to show you something first.,Ich wollte dir erst noch etwas zeigen.
2,You have suffered because of Shinkichi.,Du musstest wegen Shinkichi leiden.
3,"moodle:bg-bab: Calendar: Day view: Friday, 25 ...",moodle:bg-bab: Kalender: Tagesansicht: Freitag...
4,"I mean, most people, they see another person w...","Ich meine, die meisten Leuten sehen eine ander..."


In [6]:
print(trainframe.isnull().values.any())
trainframe.describe()

False


Unnamed: 0,inputs,labels
count,100,100
unique,100,100
top,It's greed that it's gonna be the death of you...,Deine Habgier wird noch dein Tod sein.
freq,1,1


In [7]:
print(testframe.isnull().values.any())
testframe.describe()

False


Unnamed: 0,inputs,labels
count,10,10
unique,10,10
top,"By clicking on 'Save profile', you the user ag...",Die Nutzungsbedingungen werden durch das Klick...
freq,1,1


In [8]:
trainset.sample()

[('What do you say Marylin?', 'Oder was meinst du?')]

In [9]:
testset.sample()

[('The more pronounced rate of decline was recorded from the port of Olbia with -21.9%, while in Golfo Aranci traffic fell by 15.4% and 9.8% of Porto Torres.',
  '" In den ersten fünf Monaten des Jahres 2011 haben sich die Sarden drei Häfen abgewickelt insgesamt 781.193 Passagiere, ein Rückgang von 175.954 Einheiten (-18,4%) gegenüber dem gleichen Zeitraum des Vorjahres. Je ausgeprägter der Rückgang wurde aus dem Hafen von Olbia mit -21,9% verzeichnet, während in Golfo Aranci Verkehr sank um 15,4% und 9,8% von Porto Torres.')]

In [10]:
corpus = trainset.corpus() + testset.corpus()
tokenizer = Nerdimizer()
tokenizer.train(corpus, size=100000)

In [11]:
vocab_size = len(tokenizer)
maxlen = min(trainset.maxlen(tokenizer, factor=1), 150)
start, end, pad = tokenizer["[S]"], tokenizer["[E]"], tokenizer["[P]"]
tokenizer.padon(maxlen, pad_id=pad, end=True)
tokenizer.truncon(maxlen, end=True)
print(f"Number of word piece tokens: {vocab_size}\nMaxlen: {maxlen}")

Number of word piece tokens: 1572
Maxlen: 46


In [12]:
tokenizedset = trainset.tokenized(tokenizer)
dataloader = tokenizedset.dataloader(batch_size=32, drop_last=False)

In [13]:
model = Transformer(vocab_size, maxlen, pad_id=pad, dm=256, nhead=8, layers=3, dff=1024)
optimizer = optim.Adam(model.parameters(), lr=1e-5, betas=(0.9, 0.98), eps=10e-9)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.9, patience=10)
evaluator = Evaluator(testset, tokenizer, start, end, maxlen, sample=10, ngrams=4, threshold=20, 
                    mode="geometric", verbose=True, device=device)
checkpoint = Checkpoint(model, optimizer, scheduler, evaluator, epochs=50, 
                    path="checkpoints/english-german", overwrite=False, verbose=True)
clock = Clock()
model.to(device);

In [14]:
train(dataloader, model, optimizer, scheduler, evaluator, checkpoint, clock,
    epochs=1000, warmups=100, verbose=True, device=device)

Training Started
Epoch 1 Started
Epoch 1 Complete | Epoch Loss: 7.4165
Evaluator Metric | BLEU: 0.00
Epoch Duration: 00:08 | Elapsed Time: 00:00:08
Epoch 2 Started
Epoch 2 Complete | Epoch Loss: 7.2613
Evaluator Metric | BLEU: 0.00
Epoch Duration: 00:03 | Elapsed Time: 00:00:11
Epoch 3 Started
Epoch 3 Complete | Epoch Loss: 7.0948
Evaluator Metric | BLEU: 0.00
Epoch Duration: 00:02 | Elapsed Time: 00:00:13
Epoch 4 Started
Epoch 4 Complete | Epoch Loss: 6.9579
Evaluator Metric | BLEU: 0.00
Epoch Duration: 00:02 | Elapsed Time: 00:00:15
Epoch 5 Started
Epoch 5 Complete | Epoch Loss: 6.8488
Evaluator Metric | BLEU: 0.00
Epoch Duration: 00:02 | Elapsed Time: 00:00:17
Epoch 6 Started
Epoch 6 Complete | Epoch Loss: 6.7514
Evaluator Metric | BLEU: 0.00
Epoch Duration: 00:02 | Elapsed Time: 00:00:19
Epoch 7 Started
Epoch 7 Complete | Epoch Loss: 6.6753
Evaluator Metric | BLEU: 0.00
Epoch Duration: 00:02 | Elapsed Time: 00:00:22
Epoch 8 Started
Epoch 8 Complete | Epoch Loss: 6.6221
Evaluator Me

KeyboardInterrupt: 