In [1]:
import torch
from torch.nn.utils import clip_grad_norm_
torch.__version__
from common import *
import torch.nn.functional as F
import torch.nn as nn
import torch
# Text text processing library and methods for pretrained word embeddings
from torchtext import data, datasets
# Named Tensor wrappers
from namedtensor import ntorch, NamedTensor
from namedtensor.text import NamedField
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from ipysankeywidget import SankeyWidget
import scipy
%reload_ext autoreload

%autoreload 2

In [2]:
# Load the spaCy pipelines whose tokenizers split raw German / English text
# into word tokens for the fields below.
# NOTE(review): 'de' / 'en' are old shortcut model names; newer spaCy releases
# need full package names (e.g. 'de_core_news_sm') — confirm for your install.
import spacy
spacy_de, spacy_en = [spacy.load(lang_code) for lang_code in ('de', 'en')]

def tokenize_de(text):
    """Return the token strings the German spaCy tokenizer produces for `text`."""
    tokens = []
    for token in spacy_de.tokenizer(text):
        tokens.append(token.text)
    return tokens

def tokenize_en(text):
    """Return the token strings the English spaCy tokenizer produces for `text`."""
    tokens = []
    for token in spacy_en.tokenizer(text):
        tokens.append(token.text)
    return tokens

# Special tokens wrapped around every *target* sentence. The source side gets
# neither, so only EN is configured with init/eos tokens.
BOS_WORD, EOS_WORD = '<s>', '</s>'
DE = NamedField(names=('srcSeqlen',), tokenize=tokenize_de)
EN = NamedField(
    names=('trgSeqlen',),
    tokenize=tokenize_en,
    init_token=BOS_WORD,
    eos_token=EOS_WORD,
)

# Load IWSLT German->English, keeping only pairs whose source AND target each
# have at most MAX_LEN tokens.
MAX_LEN = 20

def _both_sides_short(example):
    """True when neither side of `example` has more than MAX_LEN tokens."""
    sides = vars(example)
    return len(sides['src']) <= MAX_LEN and len(sides['trg']) <= MAX_LEN

train, val, test = datasets.IWSLT.splits(
    exts=('.de', '.en'), fields=(DE, EN), filter_pred=_both_sides_short)

# Quick look at the fields, the training-set size and one example.
for summary in (train.fields, len(train), vars(train[0])):
    print(summary)

# Optional export of the validation split as plain text. It writes one
# space-joined tokenized sentence per line: German to "valid.src" and English
# to "valid.trg", with the two files line-aligned. The files are presumably
# meant for scoring with an external tool (e.g. BLEU) — TODO confirm.
# Disabled by default so Restart & Run All writes no files; set
# EXPORT_VALID = True to produce them.
EXPORT_VALID = False

def export_parallel_corpus(examples, src_path, trg_path):
    """Write each example's source/target tokens to two line-aligned text files.

    Args:
        examples: iterable of objects with ``src`` and ``trg`` token lists.
        src_path: output path for the source side (overwritten).
        trg_path: output path for the target side (overwritten).
    """
    # `with` closes both files even if an example is malformed. Explicit
    # utf-8 keeps German umlauts independent of the platform default encoding.
    with open(src_path, "w", encoding="utf-8") as src_file, \
            open(trg_path, "w", encoding="utf-8") as trg_file:
        for example in examples:
            print(" ".join(example.src), file=src_file)
            print(" ".join(example.trg), file=trg_file)

if EXPORT_VALID:
    export_parallel_corpus(val, "valid.src", "valid.trg")

# Build token->index vocabularies from the training split. Tokens seen fewer
# than MIN_FREQ times are left out of the vocabulary.
MIN_FREQ = 5
for field, corpus in ((DE, train.src), (EN, train.trg)):
    field.build_vocab(corpus, min_freq=MIN_FREQ)

# Most frequent tokens and vocabulary size, German first, then English.
for field, size_label in ((DE, "Size of German vocab"), (EN, "Size of English vocab")):
    print(field.vocab.freqs.most_common(10))
    print(size_label, len(field.vocab))

# Indices assigned to the special target tokens.
en_stoi = EN.vocab.stoi
print(en_stoi["<s>"], en_stoi["</s>"])
print(en_stoi["<pad>"], en_stoi["<unk>"])
# Split the data into batches on the GPU. BucketIterator batches examples of
# similar length together (here: source length, via sort_key) to limit padding.
BATCH_SIZE = 32
device = torch.device('cuda:0')

def _src_len(example):
    """Bucketing key: number of source tokens."""
    return len(example.src)

train_iter, val_iter = data.BucketIterator.splits(
    (train, val),
    batch_size=BATCH_SIZE,
    device=device,
    repeat=False,
    sort_key=_src_len,
)

{'src': <namedtensor.text.torch_text.NamedField object at 0x7f8c57a26320>, 'trg': <namedtensor.text.torch_text.NamedField object at 0x7f8c57a26940>}
119076
{'src': ['David', 'Gallo', ':', 'Das', 'ist', 'Bill', 'Lange', '.', 'Ich', 'bin', 'Dave', 'Gallo', '.'], 'trg': ['David', 'Gallo', ':', 'This', 'is', 'Bill', 'Lange', '.', 'I', "'m", 'Dave', 'Gallo', '.']}
[('.', 113253), (',', 67237), ('ist', 24189), ('die', 23778), ('das', 17102), ('der', 15727), ('und', 15622), ('Sie', 15085), ('es', 13197), ('ich', 12946)]
Size of German vocab 13353
[('.', 113433), (',', 59512), ('the', 46029), ('to', 29177), ('a', 27548), ('of', 26794), ('I', 24887), ('is', 21775), ("'s", 20630), ('that', 19814)]
Size of English vocab 11560
2 3
1 0


In [3]:
# Rebuild the attention seq2seq model and restore weights from the saved
# checkpoints:
#   - seq2context: encoder over the German vocabulary
#   - attn_context2trg: attention decoder over the English vocabulary
context_size = 500
num_layers = 2
attn_context2trg = attn_RNNet_batched(input_size=len(EN.vocab), hidden_size=context_size,
                                      num_layers=num_layers)
seq2context = SequenceModel(len(DE.vocab), context_size, num_layers=num_layers)

# map_location='cpu' makes the checkpoints loadable regardless of the device
# they were saved from (the modules are still on CPU at this point). The
# modules are then moved to the same `device` the batch iterators use.
attn_context2trg.load_state_dict(
    torch.load('best_seq2seq_withattn_context2trg.pt', map_location='cpu'))
seq2context.load_state_dict(
    torch.load('best_seq2seq_withattn_seq2context.pt', map_location='cpu'))
attn_context2trg = attn_context2trg.to(device)
seq2context = seq2context.to(device)

# One Adam optimizer per module, each with a plateau-based LR scheduler
# (minimising the monitored metric, patience 4).
attn_context2trg_optimizer = torch.optim.Adam(attn_context2trg.parameters(), lr=1e-3)
seq2context_optimizer = torch.optim.Adam(seq2context.parameters(), lr=1e-3)
scheduler_c2t = torch.optim.lr_scheduler.ReduceLROnPlateau(attn_context2trg_optimizer, mode="min", patience=4)
scheduler_s2c = torch.optim.lr_scheduler.ReduceLROnPlateau(seq2context_optimizer, mode="min", patience=4)



In [4]:
# Sanity-check the restored weights with a validation-perplexity pass.
# The batch size and hidden size come from the config variables rather than
# repeated literals, so the call stays in sync if those are changed.
# NOTE(review): the LR schedulers are passed in, so this may step them — confirm.
ppl = attn_validation_loop(0, val_iter, seq2context, attn_context2trg, scheduler_c2t, scheduler_s2c,
                           BATCH_SIZE=BATCH_SIZE, context_size=context_size)

Epoch: 0, Validation PPL: 4.14863920211792


In [72]:
# Pick a single validation batch (index BATCH_TO_INSPECT) to inspect.
# Only that batch is converted; the earlier ones are skipped untouched.
import itertools

BATCH_TO_INSPECT = 2
batch = next(itertools.islice(iter(val_iter), BATCH_TO_INSPECT, None), None)
if batch is None:
    raise ValueError("val_iter has fewer than {} batches".format(BATCH_TO_INSPECT + 1))
# Transpose to batch-first, as the original loop did. The source is reversed
# with reverse_sequence before being fed to the encoder in the next cell.
src = reverse_sequence(batch.src.values.transpose(0, 1))
trg = batch.trg.values.transpose(0, 1)

In [73]:
# Encode the selected source batch and build the decoder's initial state.
# This is inference only, so no autograd graph is needed.
with torch.no_grad():
    encoder_outputs, encoder_hidden = seq2context(src)
# Size the zero context by the actual batch: the last batch of an epoch can be
# smaller than BATCH_SIZE. Create it on the same device as the inputs.
decoder_context = torch.zeros(src.shape[0], context_size, device=src.device)  # batch x context_size
decoder_hidden = encoder_hidden

In [74]:
# Teacher-forced decoding: at step j the gold target token trg[:, j] is fed in,
# and the decoder output and attention weights are recorded.
# The recurrent state lives in local names, so decoder_context / decoder_hidden
# from the encoder cell are not overwritten. Re-running this cell therefore
# restarts from the same initial state instead of continuing from the last one.
outputs = []
attns = []
step_context, step_hidden = decoder_context, decoder_hidden
with torch.no_grad():
    for j in range(trg.shape[1] - 1):
        word_input = trg[:, j]
        decoder_output, step_context, step_hidden, decoder_attention = attn_context2trg(
            word_input, step_context, step_hidden, encoder_outputs)
        outputs.append(decoder_output)
        attns.append(decoder_attention)

In [75]:
# Stack the per-step attention tensors along a new leading (decoder-step) axis.
batch_attend = torch.stack(attns, dim=0)

In [81]:
# Pick example 10 of the batch. The target's final position is dropped so its
# length equals the number of decoder steps (inputs trg[:, 0 .. T-2]).
SAMPLE = 10
s_trg = trg[SAMPLE, :-1]
s_src = src[SAMPLE, :]
print(s_trg, s_src)

tensor([   2,   24,    5, 9359,   36,  676,    4,    3,    1,    1,    1,    1,
           1], device='cuda:0') tensor([  1,   2, 833,  59,   0,   3,  73], device='cuda:0')


In [82]:
# Show the chosen sentence pair as words (source is in reversed order).
trg_words = [EN.vocab.itos[idx] for idx in s_trg.tolist()]
src_words = [DE.vocab.itos[idx] for idx in s_src.tolist()]
print(' '.join(trg_words))
print(' '.join(src_words))
# Attention weights for example 10, as a numpy array with one row per decoder step.
# Assumes each per-step attention is (batch, 1, src_len) — TODO confirm.
attention_matrix = batch_attend[:, 10, 0, :].detach().cpu().numpy()

<s> So , Superman can fly . </s> <pad> <pad> <pad> <pad> <pad>
<pad> . fliegen kann <unk> , Also


In [83]:
def visualize_attention(attention_matrix, s_trg, s_src,
                        skip_trg_ids=(1, 4, 21), skip_src_ids=(0, 1, 2, 3)):
    """Convert one sentence pair's attention matrix into Sankey link records.

    Each kept (target token, source token) pair becomes one row. The row links
    the German source word to the English target word, weighted by attention.

    Args:
        attention_matrix: 2-D array indexed ``[target_position, source_position]``.
            A softmax over axis 1 (source positions) is applied before use.
            NOTE(review): if the decoder already returns normalised attention
            weights, this re-normalises them — confirm against the model.
        s_trg: target-side (EN vocab) token ids, one per row of ``attention_matrix``.
            May be a list, numpy array or tensor.
        s_src: source-side (DE vocab) token ids, one per column of ``attention_matrix``.
        skip_trg_ids: target ids to leave out. The default is the set the
            function always hard-coded. Per the decoded outputs in this
            notebook, in the EN vocab 1 is ``<pad>`` and 4 is ``"."``; what
            21 is has not been established.
        skip_src_ids: source ids to leave out. The default is the originally
            hard-coded set. Per the decoded outputs, in the DE vocab 0 is
            ``<unk>``, 1 is ``<pad>``, 2 is ``"."`` and 3 is ``","``.

    Returns:
        pandas DataFrame with these columns:
            ``source``: DE word.
            ``target``: EN word.
            ``type``: copy of ``target``; presumably used by SankeyWidget to
                group/colour links.
            ``value``: softmax weight * 100.
    """
    weights = softmax(attention_matrix, axis=1)
    skip_trg = set(skip_trg_ids)
    skip_src = set(skip_src_ids)
    # int() turns 0-d tensors / numpy scalars into plain ints. Tensors hash by
    # identity, so without this set membership would silently never match.
    trg_ids = [int(tok) for tok in s_trg]
    src_ids = [int(tok) for tok in s_src]

    source, target, value = [], [], []
    for ix, trg_id in enumerate(trg_ids):
        if trg_id in skip_trg:
            continue
        trg_word = EN.vocab.itos[trg_id]
        for jx, src_id in enumerate(src_ids):
            if src_id in skip_src:
                continue
            source.append(DE.vocab.itos[src_id])
            target.append(trg_word)
            value.append(weights[ix, jx] * 100)

    return pd.DataFrame({'source': source, 'target': target,
                         'type': list(target), 'value': value},
                        columns=['source', 'target', 'type', 'value'])



In [84]:
# Long-format attention links for the selected sentence pair.
df = visualize_attention(attention_matrix=attention_matrix, s_trg=s_trg, s_src=s_src)

In [85]:
# Sankey diagram of the attention: German source words flow into English target words.
SankeyWidget(links=df.to_dict(orient='records'))


SankeyWidget(links=[{'source': 'fliegen', 'target': '<s>', 'type': '<s>', 'value': 7.539398223161697}, {'sourc…

[{'source': '<s>',
  'target': 'Aufklärung',
  'type': '<s>',
  'value': 5.031529441475868},
 {'source': '<s>',
  'target': 'und',
  'type': '<s>',
  'value': 5.8494411408901215},
 {'source': '<s>',
  'target': '<unk>',
  'type': '<s>',
  'value': 6.420550495386124},
 {'source': '<s>',
  'target': 'durch',
  'type': '<s>',
  'value': 8.476591855287552},
 {'source': '<s>',
  'target': 'hilft',
  'type': '<s>',
  'value': 6.7545779049396515},
 {'source': '<s>', 'target': 'und', 'type': '<s>', 'value': 7.14527815580368},
 {'source': '<s>',
  'target': 'sich',
  'type': '<s>',
  'value': 5.647348240017891},
 {'source': '<s>',
  'target': 'engagiert',
  'type': '<s>',
  'value': 6.4884185791015625},
 {'source': '<s>',
  'target': 'Schule',
  'type': '<s>',
  'value': 6.244835257530212},
 {'source': '<s>', 'target': 'neue', 'type': '<s>', 'value': 9.90273803472519},
 {'source': '<s>',
  'target': 'Meine',
  'type': '<s>',
  'value': 22.548110783100128},
 {'source': 'My',
  'target': 'Aufklär

In [48]:
# Column layout for the Sankey example: "farms" on the left, "customers" on the right.
left_column, right_column = ['farms'], ['customers']
ordering = [left_column, right_column]

In [49]:

# A single bundle linking the 'farms' node to the 'customers' node.
bundles = [Bundle('farms', 'customers')]

In [50]:
# FIXME(review): `SankeyDefinition`, `weave`, `nodes`, `flows` and `size` are not
# defined or imported anywhere in the visible notebook (apparently left over
# from a floweaver example). This cell fails on a fresh kernel unless
# `from common import *` happens to provide them.
sdd = SankeyDefinition(nodes, bundles, ordering)
woven_sankey = weave(sdd, flows)
woven_sankey.to_widget(**size)

SankeyWidget(layout=Layout(height='300', width='570'), links=[{'source': 'farms^*', 'target': 'customers^*', '…

In [52]:
# Build the Sankey widget and keep a handle to it for display.
woven = weave(sdd, flows)
w = woven.to_widget(**size)

In [53]:
# Display the widget built in the previous cell (bare last expression -> rich output).
w

SankeyWidget(layout=Layout(height='300', width='570'), links=[{'source': 'farms^*', 'target': 'customers^*', '…