In [1]:
%run LanguageModel.py

In [2]:
%run DataLoader.py

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
train_dl = DataLoader('train', ('de', 'en'), max_length = 50, device = device)

In [5]:
val_dl = DataLoader('dev', ('de', 'en'), languageModels = train_dl.languageModels, max_length = 50, device = device)

In [6]:
%run seq2seq.py

In [7]:
lm1 = train_dl.languageModels[train_dl.languages[0]]
lm2 = train_dl.languageModels[train_dl.languages[1]]
model_config = ModelConfig(input_size = lm1.n_tokens, beam_width = 3, hidden_size = 256, output_size = lm2.n_tokens, rnn_type='gru', bidirectional=False, attention = 'global', max_length=52)
checkpoint = torch.load("./state_dict.tar")
s2s = seq2seq(model_config=model_config, state_dict = checkpoint, device = device)

In [8]:
def loss_metric(input, output, ground_truth, nll):
    return nll / ground_truth[0].size(0)

def perplexity(input, output, ground_truth, nll):
    nll /= ground_truth[0].size(0)
    return math.exp(nll)

In [9]:
%run ScorePrinter.py

In [10]:
def train_epochs(epochs, print_every=1000):
    n_iters = len(train_dl)
    score_printer = ScorePrinter("Training", [('NLL', loss_metric),('Perplexity', perplexity)])
    
    for epoch in range(1, epochs+1):
        score_printer.startEpoch(epoch)
        idx_permutation = np.random.permutation(len(train_dl))[:5]
        for i, index in enumerate(idx_permutation):
            input_tensor, target_tensor = train_dl.tensorsFromPos(index)

            loss, output_sentence = s2s.train(input_tensor, target_tensor)
            score_printer.update(input_tensor, target_tensor, output_sentence, loss)
            
            if (i + 1) % print_every == 0:
                score_printer.printAvg(print_every)
                validate(print_every)
        
        score_printer.endEpoch(epoch)
        score_printer.printAvg(len(train_dl))
        validate(1)
        torch.save(s2s.state_dict(),"./state_dict.tar")

In [11]:
import numpy as np
import math

def validate(n = None):
    score_printer = ScorePrinter("Validation", [('NLL', loss_metric),('Perplexity', perplexity)])
    n = n or len(val_dl)
    idx_permutation_val = np.random.permutation(len(val_dl))[:n]
    for j, val_index in enumerate(idx_permutation_val):
        input_tensor_val, target_tensor_val = val_dl.tensorsFromPos(val_index)
        loss, output_sentence = s2s.evaluate(input_tensor_val, target_tensor_val)
        score_printer.update(input_tensor_val, target_tensor_val, output_sentence, loss)
    score_printer.printAvg(showCount = False)

In [12]:
train_epochs(1, print_every=2)



Epoch 1 started
 
[Training] 2 examples /  NLL: 10.206245742926077 Perplexity: 64349.42343453574  
[Validation]   NLL: 5.808567667907139 Perplexity: 564.0324425961962  
[Training] 4 examples /  NLL: 9.66403129499277 Perplexity: 36770.243221127974  
[Validation]   NLL: 2.9710540771484375 Perplexity: 22.805190965414234 

Epoch 1 ended
 
[Training] 5 examples /  NLL: 9.349462667655768 Perplexity: 30069.30763182593  
[Validation]   NLL: 6.386066364792158 Perplexity: 593.5173012178057 

In [13]:
val_dl.sentenceFromTensor('de', s2s.predict(val_dl.tensorsFromPos(15)[0])[0])

['SOS',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern',
 'verbessern']