In [3]:
%run LanguageModel.py
%run DataLoader.py
%run encoder.py
%run decoder.py
%run seq2seq.py
%run model_config.py
%run metrics.py
%run ScorePrinter.py

import numpy as np
import math

In [4]:
# torch was previously only available as a side effect of the %run scripts above;
# import it explicitly so this cell does not depend on hidden kernel state.
import torch

# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
# Training split for the German -> English pair; max_length is forwarded to the
# project DataLoader (presumably a per-sentence token limit — confirm in DataLoader.py).
train_dl = DataLoader(
    'train',
    ('de', 'en'),
    max_length=50,
    device=device,
)

In [6]:
# Validation split; it reuses the training language models so that token ids
# are shared between the two splits.
val_dl = DataLoader(
    'dev',
    ('de', 'en'),
    languageModels=train_dl.languageModels,
    max_length=50,
    device=device,
)

In [7]:
# Vocabularies of the first (source) and second (target) language of the loader's pair.
lm1 = train_dl.languageModels[train_dl.languages[0]]
lm2 = train_dl.languageModels[train_dl.languages[1]]

model_config = ModelConfig(
    input_size=lm1.n_tokens,
    beam_width=3,
    hidden_size=256,
    output_size=lm2.n_tokens,
    rnn_type='GRU',
    bidirectional=False,
    attention='local',
    # NOTE(review): presumably the DataLoader max_length (50) plus two special
    # tokens (SOS/EOS) — confirm before changing either value.
    max_length=52,
)
# To resume training, load a saved state dict here, e.g.:
# checkpoint = torch.load("./state_dict.tar")
s2s = seq2seq(model_config=model_config, device=device)

In [8]:
def train_epochs(epochs, print_every=1000, save_path=None, val_samples=None):
    """Train the global ``s2s`` model on ``train_dl`` and validate after every epoch.

    Each epoch visits every training example exactly once, in an order drawn
    from NumPy's global RNG, and feeds it to ``s2s.train``.

    Args:
        epochs: number of passes over the training set.
        print_every: call ``score_printer.printAvg`` every this many examples.
            A falsy value (0 or None) disables the intermediate printouts
            instead of raising ZeroDivisionError.
        save_path: if not None, ``s2s.state_dict()`` is written to this path
            with ``torch.save`` at the end of every epoch (assumes ``s2s``
            exposes ``state_dict()``, i.e. is an ``nn.Module`` — confirm).
        val_samples: number of validation examples passed to ``validate``;
            None (default) evaluates the whole validation set.
    """
    score_printer = ScorePrinter("Training", [('NLL', loss_metric), ('Perplexity', perplexity)])
    n_train = len(train_dl)

    for epoch in range(1, epochs + 1):
        score_printer.startEpoch(epoch)
        # enumerate from 1 so `step` counts processed examples.
        for step, index in enumerate(np.random.permutation(n_train), start=1):
            input_tensor, target_tensor = train_dl.tensorsFromPos(index)

            loss, output_sentence = s2s.train(input_tensor, target_tensor)
            score_printer.update(input_tensor, target_tensor, output_sentence, loss)

            if print_every and step % print_every == 0:
                score_printer.printAvg(print_every)

        # Epoch summary over all training examples, then a validation pass.
        score_printer.printAvg(n_train)
        val_avg_score = validate(val_samples)
        train_avg_score = score_printer.get_avg_score()
        print(f"Val_avg_score: {val_avg_score}, train_avg_score: {train_avg_score}")
        score_printer.endEpoch(epoch)

        if save_path is not None:
            torch.save(s2s.state_dict(), save_path)

In [9]:
def validate(n = None):
    """Evaluate the global ``s2s`` model on a random sample of ``val_dl``.

    Args:
        n: number of validation examples to draw without replacement. A falsy
            value (None, the default, or 0) means "use the whole validation set".
            Values larger than the set are clamped by slicing.

    Returns:
        The value of ``ScorePrinter.get_avg_score()`` for this validation run.
    """
    score_printer = ScorePrinter("Validation", [('NLL', loss_metric), ('Perplexity', perplexity)])
    n_val = len(val_dl)
    n = n or n_val
    for val_index in np.random.permutation(n_val)[:n]:
        input_tensor_val, target_tensor_val = val_dl.tensorsFromPos(val_index)
        loss, output_sentence = s2s.evaluate(input_tensor_val, target_tensor_val)
        score_printer.update(input_tensor_val, target_tensor_val, output_sentence, loss)
    score_printer.printAvg(showCount = False)
    return score_printer.get_avg_score()

In [10]:
# Short smoke-test run: two epochs, running averages every 10 examples.
train_epochs(epochs=2, print_every=10)



Epoch 1 started
 
[Training] 10 examples /  NLL: 10.53 Perplexity: 37867.51  
[Training] 20 examples /  NLL: 10.05 Perplexity: 27387.87  
[Training] 30 examples /  NLL: 9.85 Perplexity: 23679.49  
[Training] 40 examples /  NLL: 9.33 Perplexity: 19831.19  
[Training] 50 examples /  NLL: 9.10 Perplexity: 16782.59 

KeyboardInterrupt: 

In [None]:
from IPython.display import display, Markdown
def print_validation(position):
    """Show one validation example: source sentence, model output and attention map.

    Args:
        position: index into ``val_dl`` of the example to display.
    """
    # Fetch the example once and reuse it for both display and prediction.
    input_tensor = val_dl.tensorsFromPos(position)[0]
    input_sentence = val_dl.sentenceFromTensor('de', input_tensor)
    display(Markdown('**Eingabe**'))
    display(Markdown(' '.join(input_sentence)))

    prediction = s2s.predict(input_tensor)
    output_sentence = val_dl.sentenceFromTensor('en', prediction[0])
    display(Markdown('**Ausgabe**'))
    display(Markdown(' '.join(output_sentence)))

    # prediction[2] is a sequence of attention tensors that are stacked row-wise.
    attentions = torch.stack([tensor.squeeze() for tensor in prediction[2]])
    # .numpy() raises on CUDA tensors and on tensors that require grad; move the
    # data to host memory and drop autograd history first.
    attentions = attentions.detach().cpu().numpy()
    # One row per output token after the first (show_attention skips output_sentence[0]),
    # one column per input token.
    attentions = attentions[:len(output_sentence)-1, :len(input_sentence)]
    display(Markdown('**Attention**'))
    show_attention(input_sentence, output_sentence, attentions)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

# Keep the notebook's default (inline) backend: switching to the non-interactive
# 'agg' backend here would turn plt.show() into a no-op, so the heat map would
# never be rendered in the notebook.


def show_attention(input_sentence, output_sentence, attentions):
    """Plot an attention matrix as a heat map labelled with the source/target tokens.

    Args:
        input_sentence: list of source tokens, one x-axis tick per token.
        output_sentence: list of output tokens; the first entry (SOS token) is
            skipped and each remaining token labels one y-axis tick.
        attentions: 2-D array expected to have one row per token of
            ``output_sentence[1:]`` and one column per token of ``input_sentence``.
    """
    # Set up figure with colorbar
    fig, ax = plt.subplots(figsize=(16, 14), dpi=80)
    cax = ax.matshow(attentions, cmap='bone')
    fig.colorbar(cax)

    target_tokens = output_sentence[1:]  # ignore SOS Token
    ax.set_xticks(np.arange(len(input_sentence)))
    ax.set_xticklabels(input_sentence, rotation=90)
    ax.set_yticks(np.arange(len(target_tokens)))
    ax.set_yticklabels(target_tokens)
    ax.set_xlabel('Input')
    ax.set_ylabel('Output')

    plt.show()

In [None]:
# Inspect a single validation example (index 19) with its attention map.
print_validation(position=19)