In [2]:
import numpy as np
import trax
from trax import layers as tl
from trax.fastmath import numpy as fastnp
from trax.supervised import training

# Data Wrangling

In [29]:
def generator(inp, tar):
    """Yield whitespace-stripped (source_line, target_line) pairs from two parallel files.

    Both files are opened as UTF-8 and read in lock-step; iteration ends with the
    shorter file (``zip`` semantics). Files are opened lazily on the first ``next()``
    and closed when the generator is exhausted or closed.

    Args:
        inp: path to the source-language text file (one sentence per line).
        tar: path to the target-language text file, line-aligned with ``inp``.
    """
    with open(inp, encoding='utf-8') as src_file, open(tar, encoding='utf-8') as tgt_file:
        yield from ((src.strip(), tgt.strip()) for src, tgt in zip(src_file, tgt_file))

In [30]:
def inp_stream(inp, tar):
    """Return a one-argument stream factory over the parallel files ``inp``/``tar``.

    The returned function ignores its single argument and, on every call, builds a
    fresh ``generator(inp, tar)`` so each call starts a new pass over the files.
    """
    def _stream_factory(_):
        return generator(inp, tar)

    return _stream_factory

In [44]:
# Stream factories over the parallel English (.en) / German (.de) files. Each call
# of the returned function (its argument is ignored) starts a fresh pass over the files.
train_stream_fn = inp_stream('data/train.en', 'data/train.de')
eval_stream_fn = inp_stream('data/valid.en', 'data/valid.de')

In [45]:
# Sanity check: print the first raw (English, German) pair from a fresh training stream.
for example in train_stream_fn(None):
    print(example)
    break

("Iran: Gonu's victims, Palestine's crisis, and a stoning suspended · Global Voices", 'Iran: Gonus Opfer, Krise in Palästina und eine verhinderte Steinigung')


In [58]:
# One-pass generators feeding the pipeline below. Each reads its files exactly once,
# so the streams are finite and single-use.
# NOTE(review): any consumer that needs more batches than the files provide will hit
# StopIteration — confirm this is acceptable for the training/eval loops.
train_stream = train_stream_fn(None)
eval_stream = eval_stream_fn(None)

In [46]:
# Subword vocabulary shared by both languages (the same file is used to tokenize and
# detokenize English and German below).
VOCAB_FILE = 'ende_32k.subword'
VOCAB_DIR = 'data/'

In [59]:
# Map both sides of every (en, de) pair to subword id arrays.
tok_train = trax.data.Tokenize(vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)(train_stream)
tok_eval = trax.data.Tokenize(vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)(eval_stream)

In [69]:
def append_eos(stream):
    """Lazily append the EOS id (1) to both the input and target of every pair.

    Args:
        stream: iterable of ``(inp, tar)`` token sequences.

    Yields:
        ``(np.array(inp + [1]), np.array(tar + [1]))`` for each pair, in order.
    """
    EOS = 1
    yield from ((np.array([*src, EOS]), np.array([*tgt, EOS])) for src, tgt in stream)

In [70]:
# Terminate every input and target sequence with EOS (id 1).
eos_train = append_eos(tok_train)
eos_eval = append_eos(tok_eval)

In [71]:
# Drop over-long examples: lengths of both the input (key 0) and target (key 1) are
# checked against max_length=512 tokens (EOS included, since it was appended above).
len_train = trax.data.FilterByLength(max_length=512, length_keys=[0,1])(eos_train)
len_eval = trax.data.FilterByLength(max_length=512, length_keys=[0,1])(eos_eval)

In [72]:
# Length buckets for BucketByLength: batch_size has one more entry than boundaries
# (8 vs 7), the last covering lengths beyond the final boundary. Longer sequences get
# smaller batches, down to 2.
boundaries = [8, 16, 32, 64, 128, 256, 512]
batch_size = [256, 128, 64, 32, 16, 8, 4, 2]

In [73]:
# Group examples of similar length into padded batches using the buckets above;
# length is measured on both the input and the target (keys 0 and 1).
x_train = trax.data.BucketByLength(boundaries, batch_size, length_keys=[0,1])(len_train)
x_eval = trax.data.BucketByLength(boundaries, batch_size, length_keys=[0,1])(len_eval)

In [74]:
# Add per-token loss weights as a third batch element, giving padding id 0 zero weight
# so padded positions do not contribute to the loss.
x_train = trax.data.AddLossWeights(id_to_mask=0)(x_train)
x_eval = trax.data.AddLossWeights(id_to_mask=0)(x_eval)

# Helper Functions

In [75]:
def tokenize(inp_str, vocab_file=None, vocab_dir=None):
    """Tokenize one sentence into a batch of size 1 terminated by EOS.

    Args:
        inp_str: the raw sentence to tokenize.
        vocab_file: subword vocabulary file name.
        vocab_dir: directory containing ``vocab_file``.

    Returns:
        A numpy array of shape ``(1, n_tokens + 1)``: the subword ids followed by EOS (1).
    """
    EOS = 1

    # trax.data.Tokenize is a factory: configure it first, then apply it to a stream.
    # Passing the iterator positionally would bind it to `keys` and return a function,
    # not an iterator, so next() would fail. This is the same pattern as the pipeline cells.
    tokenizer = trax.data.Tokenize(vocab_file=vocab_file, vocab_dir=vocab_dir)
    tok = next(tokenizer(iter([inp_str])))

    tok = list(tok) + [EOS]

    batch_inp = np.reshape(np.array(tok), [1, -1])

    return batch_inp

In [82]:
def detokenize(integers, vocab_file=None, vocab_dir=None):
    """Convert token ids back to text, truncating at the first EOS (id 1).

    Args:
        integers: token ids, either 1-D or with singleton batch axes (e.g. ``(1, n)``).
        vocab_file: subword vocabulary file name.
        vocab_dir: directory containing ``vocab_file``.

    Returns:
        The detokenized string, excluding EOS and everything after it.
    """
    EOS = 1

    # np.squeeze on a single-token sequence yields a 0-d array, which list() cannot
    # iterate. atleast_1d keeps it iterable. This case occurs when the decoder's
    # first output is EOS.
    tok = list(np.atleast_1d(np.squeeze(integers)))

    if EOS in tok:
        tok = tok[:tok.index(EOS)]

    return trax.data.detokenize(tok, vocab_file=vocab_file, vocab_dir=vocab_dir)

In [83]:
# Pull one batch to inspect: (inputs, targets, loss weights). This consumes a batch from x_train.
input_batch, target_batch, mask_batch = next(x_train)

In [85]:
from termcolor import colored
index = 10
# The inspected batch may come from any length bucket, and bucket batch sizes go
# down to 2 (see `batch_size`), so clamp the index to stay inside this batch.
index = min(index, len(input_batch) - 1)

# Show one example in raw-text and token-id form for both languages.
print(colored('THIS IS THE ENGLISH SENTENCE: \n', 'red'), detokenize(input_batch[index], vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR), '\n')
print(colored('THIS IS THE TOKENIZED VERSION OF THE ENGLISH SENTENCE: \n ', 'red'), input_batch[index], '\n')
print(colored('THIS IS THE GERMAN TRANSLATION: \n', 'red'), detokenize(target_batch[index], vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR), '\n')
print(colored('THIS IS THE TOKENIZED VERSION OF THE GERMAN TRANSLATION: \n', 'red'), target_batch[index], '\n')

[31mTHIS IS THE ENGLISH SENTENCE: 
[0m Hugo Miranda of Angel Caido writes that he has always remembered lighting a bonfire with his family, and acknowledges that it contributes to the smoggy air. 

[31mTHIS IS THE TOKENIZED VERSION OF THE ENGLISH SENTENCE: 
 [0m [31958 14160 10429     5     7 14414     5 14364   127 18953    33    17
   209    63   820 27717   103 13502    13  5859  3971    30   186  1104
     2     8  8409  3468 16393 20472     5    17    40 24398    33     9
     4  8720  9538   105  1433     3     1     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0] 

[31mTHIS IS THE GERMAN TRANSLATION: 
[0m Hugo Miranda von Angel Caido schreibt, dass er sich immer daran erinnert ein Lagerfeuer mit seiner Familie angezündet zu haben und stimmt zu, dass es zu der verschmutzten Luft beiträgt. 

[31mTHIS IS THE TOKENIZED VERSION OF THE GERMAN TRANSLATION: 
[0m [31958 14160 10429     5    21 14414     

# NMT

In [91]:
def encoder(inp_vocab_size, d_model, n_encode_layers):
    """Input encoder: a token embedding followed by a stack of LSTM layers.

    Args:
        inp_vocab_size: vocabulary size of the input embedding.
        d_model: embedding width and LSTM hidden size.
        n_encode_layers: number of stacked LSTM layers.

    Returns:
        A ``tl.Serial`` layer mapping input token ids to encoder activations.
    """
    embedding = tl.Embedding(vocab_size=inp_vocab_size, d_feature=d_model)
    lstm_stack = [tl.LSTM(n_units=d_model) for _ in range(n_encode_layers)]
    return tl.Serial(embedding, lstm_stack)

In [90]:
def pre_attention_decoder(tar_vocab_size, d_model, mode):
    """Decoder front end that runs before attention: shift right, embed, one LSTM.

    The right shift means position t only sees target tokens before t (teacher forcing).

    Args:
        tar_vocab_size: vocabulary size of the target embedding.
        d_model: embedding width and LSTM hidden size.
        mode: trax mode string ('train', 'eval', ...) forwarded to ShiftRight.

    Returns:
        A ``tl.Serial`` layer producing the decoder activations used as attention queries.
    """
    layers = [
        tl.ShiftRight(mode=mode),
        tl.Embedding(vocab_size=tar_vocab_size, d_feature=d_model),
        tl.LSTM(n_units=d_model),
    ]
    return tl.Serial(*layers)

In [104]:
def prep_attention_input(enc_activations, dec_activations, inp):
    """Build (queries, keys, values, mask) for the encoder-decoder attention layer.

    Args:
        enc_activations: encoder outputs; used as both keys and values.
        dec_activations: pre-attention decoder outputs; used as queries. Axis 1 is
            taken as the decoder length (presumably (batch, dec_len, d_model) — TODO confirm).
        inp: 2-D input token ids (batch, inp_len); id 0 is padding.

    Returns:
        ``(queries, keys, values, mask)``. ``mask`` has shape (batch, 1, dec_len, inp_len)
        and is 1.0 at real input tokens and 0.0 at padding.
    """
    keys = values = enc_activations

    queries = dec_activations

    # True where the input token is not padding (id 0).
    mask = inp != 0

    mask = fastnp.reshape(mask, (mask.shape[0], 1, 1, mask.shape[1]))

    # Broadcast the mask over every decoder position (and make it float). Out-of-place
    # addition: an in-place bool `+=` float raises a casting error under the NumPy
    # backend. With JAX arrays (immutable) the result is identical.
    mask = mask + fastnp.zeros((1, 1, dec_activations.shape[1], 1))

    return queries, keys, values, mask

In [105]:
def NMTAttn(input_vocab_size=33300, target_vocab_size=33300, d_model=1024, n_encoder_layers=2, n_decoder_layers=2, n_attention_heads=4, attention_dropout=0.0, mode='train'):
    """Sequence-to-sequence LSTM model with encoder-decoder attention.

    The model takes ``(input_tokens, target_tokens)`` and returns
    ``(log_probs, target_tokens)``. ``log_probs`` holds the per-position log-softmax
    over the target vocabulary.

    Args:
        input_vocab_size: input vocabulary size.
        target_vocab_size: target vocabulary size (also the output dimension).
        d_model: embedding / hidden size used throughout.
        n_encoder_layers: LSTM layers in the encoder.
        n_decoder_layers: LSTM layers after attention.
        n_attention_heads: heads in the attention layer.
        attention_dropout: dropout rate inside attention.
        mode: trax mode string forwarded to ShiftRight and AttentionQKV.
    """
    input_encoder = encoder(input_vocab_size, d_model, n_encoder_layers)
    decoder_front = pre_attention_decoder(target_vocab_size, d_model, mode)
    attention = tl.AttentionQKV(d_model, n_heads=n_attention_heads, dropout=attention_dropout, mode=mode)
    post_attention_lstms = [tl.LSTM(n_units=d_model) for _ in range(n_decoder_layers)]

    return tl.Serial(
        # (inp, tar) -> (inp, tar, inp, tar): keep raw tokens for the mask and the output.
        tl.Select([0, 1, 0, 1]),
        # -> (enc_activations, dec_activations, inp, tar)
        tl.Parallel(input_encoder, decoder_front),
        # (enc, dec, inp) -> (queries, keys, values, mask); tar passes through.
        tl.Fn('PrepareAttentionInput', prep_attention_input, n_out=4),
        # Attention output plus the queries as a residual; the mask is carried along.
        tl.Residual(attention),
        # Drop the mask: (activations, mask, tar) -> (activations, tar).
        tl.Select([0, 2]),
        post_attention_lstms,
        tl.Dense(target_vocab_size),
        tl.LogSoftmax(),
    )

In [106]:
# Build a default (train-mode) model and print its layer structure.
model = NMTAttn()
print(model)

Serial_in2_out2[
  Select[0,1,0,1]_in2_out4
  Parallel_in2_out2[
    Serial[
      Embedding_33300_1024
      LSTM_1024
      LSTM_1024
    ]
    Serial[
      Serial[
        ShiftRight(1)
      ]
      Embedding_33300_1024
      LSTM_1024
    ]
  ]
  PrepareAttentionInput_in3_out4
  Serial_in4_out2[
    Branch_in4_out3[
      None
      Serial_in4_out2[
        _in4_out4
        Serial_in4_out2[
          Parallel_in3_out3[
            Dense_1024
            Dense_1024
            Dense_1024
          ]
          PureAttention_in4_out2
          Dense_1024
        ]
        _in2_out2
      ]
    ]
    Add_in2
  ]
  Select[0,2]_in3_out2
  LSTM_1024
  LSTM_1024
  Dense_33300
  LogSoftmax
]


In [107]:
def train_task_fn(x_train, learning_rate=0.01, n_warmup_steps=1000, max_lr=0.1, n_steps_per_checkpoint=100):
    """Build the supervised training task for the NMT model.

    Hyperparameters that used to be hard-coded are now keyword arguments. Their
    defaults reproduce the original values.

    Args:
        x_train: stream of (inputs, targets, loss_weights) batches.
        learning_rate: learning rate passed to the Adam optimizer.
        n_warmup_steps: warmup steps of the warmup/rsqrt-decay schedule.
        max_lr: peak learning rate of the schedule.
            NOTE(review): 0.1 is unusually high for Adam — confirm it is intended.
        n_steps_per_checkpoint: training steps between checkpoints/evaluations.

    Returns:
        A ``training.TrainTask`` using cross-entropy loss on the model's log-probs.
    """
    # NOTE(review): with lr_schedule set, the optimizer's own learning rate is
    # presumably superseded by the schedule — confirm against the trax TrainTask docs.
    return training.TrainTask(
        labeled_data=x_train,
        loss_layer=tl.CrossEntropyLoss(),
        optimizer=trax.optimizers.Adam(learning_rate),
        lr_schedule=trax.lr.warmup_and_rsqrt_decay(n_warmup_steps, max_lr),
        n_steps_per_checkpoint=n_steps_per_checkpoint,
    )

In [108]:
# Training task over the bucketed, loss-weighted training stream.
train_task = train_task_fn(x_train)

In [109]:
# Evaluation task: reports cross-entropy loss and accuracy on the validation stream.
eval_task = training.EvalTask(
    labeled_data=x_eval,
    metrics=[tl.CrossEntropyLoss(), tl.Accuracy()],
)

In [110]:
# Training loop for a fresh train-mode model.
# NOTE(review): no output_dir is passed, so (per the log below) eval metrics are not
# written. Checkpoints are presumably not saved either — confirm before expecting
# trained weights to be loadable later.
training_loop = training.Loop(
    NMTAttn(mode='train'),
    train_task,
    eval_tasks=[eval_task],
)



Will not write evaluation metrics, because output_dir is None.


In [112]:
# Training is disabled; uncomment to run the loop for 10000 steps.
#training_loop.run(10000)

# Evaluation

### Greedy Decode Test

In [113]:
# Eval-mode model for decoding.
# FIXME(review): nothing in this notebook loads trained weights into this model (no
# init_from_file / init call), so decoding with it will presumably fail or be random.
# It likely needs model.init_from_file(<checkpoint>, weights_only=True) — confirm.
model = NMTAttn(mode='eval')

In [114]:
def next_symbol(NMTAttn, input_tokens, cur_output_tokens, temperature):
    """Predict the next target token given the input and the tokens decoded so far.

    Args:
        NMTAttn: the model, called on ``(input_tokens, padded_targets)`` and returning
            ``(log_probs, _)``.
        input_tokens: tokenized input batch of shape (1, n), e.g. from ``tokenize``.
        cur_output_tokens: list of token ids generated so far.
        temperature: sampling temperature for ``tl.logsoftmax_sample``. 0.0 picks the argmax.

    Returns:
        ``(symbol, log_prob)``: the chosen token id and its log-probability.
    """
    token_len = len(cur_output_tokens)

    # Pad to the next power of two >= token_len + 1 so the model sees few distinct shapes.
    # int(): np.ceil returns a float, which cannot size a list repetition or an array dim.
    pad_length = 2 ** int(np.ceil(np.log2(token_len + 1)))

    padded = list(cur_output_tokens) + [0] * (pad_length - token_len)

    pad_with_batch = np.reshape(np.array(padded), (1, pad_length))

    outputs, _ = NMTAttn((input_tokens, pad_with_batch))

    # The decoder input is shifted right, so position token_len holds the
    # distribution for the next token.
    log_probs = outputs[0, token_len, :]

    symbol = int(tl.logsoftmax_sample(log_probs, temperature))

    return symbol, float(log_probs[symbol])

In [116]:
def sampling_decode(input_sentence, NMTAttn=None, temperature=0.0, vocab_file=None, vocab_dir=None, next_symbol=next_symbol, tokenize=tokenize, detokenize=detokenize, max_length=512):
    """Autoregressively decode one sentence by repeatedly calling ``next_symbol``.

    Args:
        input_sentence: raw source sentence.
        NMTAttn: the model passed through to ``next_symbol``.
        temperature: sampling temperature; 0.0 gives greedy decoding.
        vocab_file, vocab_dir: vocabulary location for tokenize/detokenize.
        next_symbol, tokenize, detokenize: injectable helpers (default: this notebook's).
        max_length: stop after this many generated tokens even without EOS, so a model
            that never emits EOS cannot loop forever. The default of 512 matches the
            training length filter. Pass None for no limit.

    Returns:
        ``(cur_output_tokens, log_prob, sentence)``: the generated ids (ending in EOS if
        one was produced), and the log-probability of the LAST generated token only
        (not the whole-sequence total). ``sentence`` is the detokenized text.
    """
    EOS = 1

    input_tokens = tokenize(input_sentence, vocab_file, vocab_dir)

    cur_output_tokens = []
    cur_output = 0
    log_prob = 0.0  # only returned as-is if max_length <= 0 stops the loop immediately

    while cur_output != EOS and (max_length is None or len(cur_output_tokens) < max_length):

        cur_output, log_prob = next_symbol(NMTAttn, input_tokens, cur_output_tokens, temperature)

        cur_output_tokens.append(cur_output)

    sentence = detokenize(cur_output_tokens, vocab_file=vocab_file, vocab_dir=vocab_dir)

    return cur_output_tokens, log_prob, sentence

In [117]:
def greedy_decode_test(sentence, NMTAttn=None, vocab_file=None, vocab_dir=None, sampling_decode=sampling_decode, next_symbol=next_symbol, tokenize=tokenize, detokenize=detokenize):
    """Translate ``sentence`` with ``sampling_decode``, print source and output, return the output.

    The temperature is left at ``sampling_decode``'s default, which is 0.0 (greedy) in
    this notebook.

    Returns:
        The translated (German) sentence.
    """
    # Bind the output to its own name so `sentence` still holds the English input.
    _, _, translated_sentence = sampling_decode(sentence, NMTAttn=NMTAttn, vocab_file=vocab_file, vocab_dir=vocab_dir, next_symbol=next_symbol, tokenize=tokenize, detokenize=detokenize)

    print("English: ", sentence)
    print("German: ", translated_sentence)

    return translated_sentence

In [119]:
def generate_samples(sentence, n_samples, NMTAttn=None, temperature=0.6, vocab_file=None, vocab_dir=None, sampling_decode=sampling_decode, next_symbol=next_symbol, tokenize=tokenize, detokenize=detokenize):
    """Draw ``n_samples`` independent decodings of ``sentence`` via ``sampling_decode``.

    Every injectable helper, including ``tokenize`` and ``detokenize``, is forwarded
    to ``sampling_decode``.

    Returns:
        ``(samples, log_probs)``: lists of length ``n_samples`` with each sample's token
        ids and the log-probability value that ``sampling_decode`` reported for it.
    """
    samples, log_probs = [], []

    for _sample_idx in range(n_samples):
        sample, logp, _ = sampling_decode(sentence, NMTAttn, temperature, vocab_file=vocab_file, vocab_dir=vocab_dir, next_symbol=next_symbol, tokenize=tokenize, detokenize=detokenize)

        samples.append(sample)

        log_probs.append(logp)

    return samples, log_probs

### Similarity / Overlap

In [125]:
def jaccard_similarity(candidate, reference):
    """Jaccard similarity (intersection over union) of the token sets of two sequences.

    Args:
        candidate: sequence of tokens.
        reference: sequence of tokens.

    Returns:
        ``|set(candidate) & set(reference)| / |set(candidate) | set(reference)|``. Returns
        0.0 when both inputs are empty, where the ratio is undefined, instead of raising
        ZeroDivisionError.
    """
    cand_set, ref_set = set(candidate), set(reference)

    cand_ref_union = cand_set | ref_set
    if not cand_ref_union:
        return 0.0

    return len(cand_set & ref_set) / len(cand_ref_union)

### ROUGE-1 Similarity (Unigram F1 Score)

In [121]:
from collections import Counter

def rouge1_similarity(system, reference):
    """ROUGE-1 F1 score: harmonic mean of unigram precision and recall.

    Overlap counts each token at most ``min(count in system, count in reference)`` times.

    Args:
        system: candidate token sequence.
        reference: reference token sequence.

    Returns:
        The F1 score as a float. Returns 0.0 when there is no overlap, including when
        either sequence is empty, instead of raising ZeroDivisionError or returning nan.
    """
    sys_count = Counter(system)
    ref_count = Counter(reference)

    # Counter intersection keeps the per-token minimum count (clipped overlap).
    overlap = sum((sys_count & ref_count).values())
    if overlap == 0:
        return 0.0

    # overlap > 0 guarantees both totals are non-zero.
    precision = overlap / sum(sys_count.values())
    recall = overlap / sum(ref_count.values())

    return (2 * precision * recall) / (precision + recall)

In [122]:
# precision = 3/3, recall = 3/4, so F1 = 6/7 ≈ 0.857
rouge1_similarity([1, 2, 3], [1, 2, 3, 4])

np.float64(0.8571428571428571)

In [123]:
def average_overlap(similarity_fn, samples, *ignore_params):
    """Score each sample by its mean similarity to all the other samples.

    Args:
        similarity_fn: ``f(sample, candidate) -> float``.
        samples: list of token sequences.
        *ignore_params: ignored. This lets the function share a call signature with
            ``weighted_avg_overlap`` (which takes log-probs).

    Returns:
        Dict mapping sample index to its average overlap. With a single sample there
        is nothing to compare against, so its score is 0.0 instead of raising
        ZeroDivisionError.
    """
    n_others = len(samples) - 1
    scores = {}

    for i_can, can in enumerate(samples):

        overlap = sum(similarity_fn(sam, can) for i_sam, sam in enumerate(samples) if i_sam != i_can)

        scores[i_can] = overlap / n_others if n_others > 0 else 0.0

    return scores

In [126]:
# The log-prob list is ignored by average_overlap (*ignore_params); it is passed so the
# call matches weighted_avg_overlap's signature.
average_overlap(jaccard_similarity, [[1, 2, 3], [1, 2, 4], [1, 2, 4, 5]], [0.4, 0.2, 0.5])

{0: 0.45, 1: 0.625, 2: 0.575}

In [135]:
def weighted_avg_overlap(similarity_fn, samples, log_probs):
    """Score each sample by its probability-weighted mean similarity to the other samples.

    Each other sample is weighted by ``exp(log_prob)``.

    Args:
        similarity_fn: ``f(sample, candidate) -> float``.
        samples: list of token sequences.
        log_probs: one log-probability per sample, aligned with ``samples``.

    Returns:
        Dict mapping sample index to its weighted average overlap. The score is 0.0 when
        the other samples' total weight is zero. That happens with a single sample or when
        every weight underflows. Returning 0.0 avoids a ZeroDivisionError.
    """
    scores = {}

    for i_can, can in enumerate(samples):

        overlap = 0
        weighted_sum = 0

        for i_sam, (sam, log_p) in enumerate(zip(samples, log_probs)):

            if i_can == i_sam:
                continue

            sample_p = float(np.exp(log_p))

            weighted_sum += sample_p

            overlap += similarity_fn(sam, can) * sample_p

        scores[i_can] = overlap / weighted_sum if weighted_sum > 0 else 0.0

    return scores

In [136]:
# Same samples as above, but each comparison is weighted by exp(log_prob) of the other sample.
weighted_avg_overlap(jaccard_similarity, [[1, 2, 3], [1, 2, 4], [1, 2, 4, 5]], [0.4, 0.2, 0.5])

{0: 0.44255574831883415, 1: 0.631244796869735, 2: 0.5575581009406329}

# Final: Minimum Bayes Risk (MBR) Decoding

In [137]:
def mbr_decode(sentence, n_samples, score_fn, similarity_fn, NMTAttn=None, temperature=0.6, vocab_file=None, vocab_dir=None, generate_samples=generate_samples, sampling_decode=sampling_decode, next_symbol=next_symbol, tokenize=tokenize, detokenize=detokenize):
    """Minimum Bayes Risk decoding: sample candidates and return the consensus candidate.

    Args:
        sentence: raw source sentence.
        n_samples: number of candidate translations to sample (must be >= 1).
        score_fn: ``score_fn(similarity_fn, samples, log_probs) -> {index: score}``,
            e.g. ``average_overlap`` or ``weighted_avg_overlap``.
        similarity_fn: pairwise similarity passed to ``score_fn``.
        NMTAttn, temperature, vocab_file, vocab_dir: forwarded to ``generate_samples``.
        generate_samples, sampling_decode, next_symbol, tokenize, detokenize: injectable
            helpers, all forwarded down the call chain.

    Returns:
        ``(translated_sentence, max_score_key, scores)``: the highest-scoring candidate
        detokenized, its index, and the full score dict.
    """
    samples, log_probs = generate_samples(sentence, n_samples, NMTAttn, temperature, vocab_file, vocab_dir, sampling_decode=sampling_decode, next_symbol=next_symbol, tokenize=tokenize, detokenize=detokenize)

    # Use the caller-supplied scoring function rather than always weighting by probability.
    scores = score_fn(similarity_fn, samples, log_probs)

    max_score_key = max(scores, key=scores.get)

    translated_sentence = detokenize(samples[max_score_key], vocab_file, vocab_dir)

    return (translated_sentence, max_score_key, scores)