# Start

NEW TASKS:
* [X] Seq2Seq: sort by src_len and unsort output --> ensure output matches with trg
* [X] Pivot model: ensure it works for $n$ seq2seq models
* [ ] Triang model: ensure outputs from all submodels match

In [None]:
# piv_endefr_74kset_2.pt using PivotModel in bentrevett/pytorch-seq2seq-OLD.ipynb

In [None]:
# https://github.com/bentrevett/pytorch-seq2seq/blob/master/4%20-%20Packed%20Padded%20Sequences%2C%20Masking%2C%20Inference%20and%20BLEU.ipynb
# based on https://gmihaila.github.io/tutorial_notebooks/pytorchtext_bucketiterator/#dataset-class

In [1]:
# Pivot-language candidates; a language is only loaded if it is also in `langs`.
piv_langs = 'es it pt ro'.split()
# Languages whose Fields/vocabularies are built below.
langs = 'en fr'.split()
DIR_PATH = './data/'

# Setup

In [2]:
# from google.colab import drive
# drive.mount('/content/gdrive')

In [3]:
# !pip3 install torch==1.8.2 torchvision==0.9.2 torchaudio==0.8.2 --extra-index-url https://download.pytorch.org/whl/lts/1.8/cu111 -q

In [2]:
import torch
# Record which PyTorch build produced the outputs below.
print(f'{torch.__version__}')

1.8.2+cu111


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# !pip install torchtext==0.9 -q
import torchtext
# The legacy Field/BucketIterator API used below requires torchtext 0.9.
print(f'{torchtext.__version__}')

0.9.0


In [8]:
# if 'en' in langs:
#   !python -m spacy download en_core_web_sm -q
# if 'de' in langs:
#   !python -m spacy download de_core_news_sm -q
# if 'fr' in langs:
#   !python -m spacy download fr_core_news_sm -q
# if 'it' in langs:
#   !python -m spacy download it_core_news_sm -q
# if 'es' in langs:
#   !python -m spacy download es_core_news_sm -q
# if 'pt' in langs:
#   !python -m spacy download pt_core_news_sm -q
# if 'ro' in langs:
#   !python -m spacy download ro_core_news_sm -q

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchtext.legacy.data import Field, BucketIterator
from torchtext.legacy.data import Dataset, Example
from torchtext.data.metrics import bleu_score

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

import spacy
import numpy as np

import random
import math
import time
import pickle
from tqdm import tqdm

# My Section

## Setup

In [5]:
# Build one torchtext Field per language listed in `langs` and register it in
# FIELD_DICT under its language code. Each Field:
#   - tokenizes with that language's spaCy tokenizer (English: en_core_web_sm,
#     the others: <lang>_core_news_sm),
#   - lower-cases text and wraps every sentence in <sos> ... <eos>,
#   - has include_lengths=True, so batch.<lang> is a (token ids, lengths) pair.
# Languages not in `langs` are skipped (their spaCy model is never loaded).
# The per-language globals (spacy_xx, tokenize_xx, XX_FIELD) are kept as named
# module-level objects; XX_FIELD is referenced directly later in the notebook.
# NOTE(review): tokenize_xx are top-level functions, presumably so Fields stay
# picklable — confirm before replacing them with closures.
FIELD_DICT = {}
if 'en' in langs:
  spacy_en = spacy.load('en_core_web_sm')
  def tokenize_en(text):
    return [tok.text for tok in spacy_en.tokenizer(text)]
  EN_FIELD = Field(tokenize = tokenize_en, init_token = '<sos>', eos_token = '<eos>', lower = True, include_lengths = True)
  FIELD_DICT['en'] = EN_FIELD
if 'de' in langs:
  spacy_de = spacy.load('de_core_news_sm')
  def tokenize_de(text):
    return [tok.text for tok in spacy_de.tokenizer(text)]
  DE_FIELD = Field(tokenize = tokenize_de, init_token = '<sos>', eos_token = '<eos>', lower = True, include_lengths = True)
  FIELD_DICT['de'] = DE_FIELD
if 'fr' in langs:
  spacy_fr = spacy.load('fr_core_news_sm')
  def tokenize_fr(text):
    return [tok.text for tok in spacy_fr.tokenizer(text)]
  FR_FIELD = Field(tokenize = tokenize_fr, init_token = '<sos>', eos_token = '<eos>', lower = True, include_lengths = True)
  FIELD_DICT['fr'] = FR_FIELD
if 'it' in langs:
  spacy_it = spacy.load('it_core_news_sm')
  def tokenize_it(text):
    return [tok.text for tok in spacy_it.tokenizer(text)]
  IT_FIELD = Field(tokenize = tokenize_it, init_token = '<sos>', eos_token = '<eos>', lower = True, include_lengths = True)
  FIELD_DICT['it'] = IT_FIELD
if 'es' in langs:
  spacy_es = spacy.load('es_core_news_sm')
  def tokenize_es(text):
    return [tok.text for tok in spacy_es.tokenizer(text)]
  ES_FIELD = Field(tokenize = tokenize_es, init_token = '<sos>', eos_token = '<eos>', lower = True, include_lengths = True)
  FIELD_DICT['es'] = ES_FIELD
if 'pt' in langs:
  spacy_pt = spacy.load('pt_core_news_sm')
  def tokenize_pt(text):
    return [tok.text for tok in spacy_pt.tokenizer(text)]
  PT_FIELD = Field(tokenize = tokenize_pt, init_token = '<sos>', eos_token = '<eos>', lower = True, include_lengths = True)
  FIELD_DICT['pt'] = PT_FIELD
if 'ro' in langs:
  spacy_ro = spacy.load('ro_core_news_sm')
  def tokenize_ro(text):
    return [tok.text for tok in spacy_ro.tokenizer(text)]
  RO_FIELD = Field(tokenize = tokenize_ro, init_token = '<sos>', eos_token = '<eos>', lower = True, include_lengths = True)
  FIELD_DICT['ro'] = RO_FIELD

## Data

In [6]:
BATCH_SIZE = 32

# Number of sentence tuples in each split, and the cumulative cut points into `data`.
train_len, valid_len, test_len = 64000, 3200, 6400

train_pt = train_len              # data[:train_pt]          -> train
valid_pt = train_len + valid_len  # data[train_pt:valid_pt]  -> valid
test_pt = valid_pt + test_len     # data[valid_pt:test_pt]   -> test

print(train_pt, valid_pt, test_pt)

64000 67200 73600


In [None]:
# https://discuss.pytorch.org/t/how-to-save-and-load-torchtext-data-field-build-vocab-result/50407/3
def save_vocab(vocab, path):
  '''Write `vocab.stoi` to `path` (UTF-8), one "<index><TAB><token>" line per entry.'''
  with open(path, 'w+', encoding='utf-8') as out_file:
    out_file.writelines(f'{index}\t{token}\n' for token, index in vocab.stoi.items())
def read_vocab(path):
  '''Load the token -> index mapping written by `save_vocab`.

  Each line is "<index><TAB><token>". The line terminator is stripped, so keys
  are the bare tokens rather than "token\\n". Only the first TAB separates the
  index, so a token that itself contains a TAB is kept intact. Blank lines are
  skipped.

  Returns:
    dict: token -> int index (a plain dict, not a torchtext Vocab object).
  '''
  vocab = dict()
  with open(path, 'r', encoding='utf-8') as f:
    for line in f:
      line = line.rstrip('\n')
      if not line:
        continue
      index, token = line.split('\t', 1)
      vocab[token] = int(index)
  return vocab

In [None]:
# with open('/content/gdrive/MyDrive/Colab Notebooks/eaai24/Datasets/endefr_75kpairs_2k5-freq-words.pkl', 'rb') as f:
# with open('/content/gdrive/MyDrive/Colab Notebooks/eaai24/Datasets/enfr_160kpairs_2k5-freq-words.pkl', 'rb') as f:
#   data = pickle.load(f)
# data[80], len(data)

In [None]:
# EnDeFrItEsPtRo-60k-most10k-1.pkl is most suitable. Each lang has ~5-7k words
# EnDeFrItEsPtRo-76k-most5k.pkl: each lang has 6->12k words (too much. use this in the future)

In [7]:
dataname = 'EnDeFrItEsPtRo-76k-most5k.pkl'
data_path = f'{DIR_PATH}/{dataname}'
# NOTE: pickle.load can execute arbitrary code -- only load trusted, locally produced files.
with open(data_path, 'rb') as f:
  data = pickle.load(f)
# each item maps a language code to one sentence of the parallel tuple
data[8], len(data)

({'en': 'What is the result?',
  'de': 'Und wie sehen die Ergebnisse aus?',
  'fr': 'Quelle en est la conséquence ?',
  'it': 'Quali sono i risultati?',
  'es': '¿Cuáles son los resultados?',
  'pt': 'Quais são os resultados?',
  'ro': 'Care este rezultatul?'},
 76245)

In [8]:
# data_set = [[pair['en'], pair['it'], pair['fr']] for pair in data]
# FIELDS = [('en', EN_FIELD), ('it', IT_FIELD), ('fr', FR_FIELD)]
# One row per sentence tuple, with columns ordered like `langs`.
data_set = [[pair[lang] for lang in langs] for pair in data]
FIELDS = [(lang, FIELD_DICT[lang]) for lang in langs]
# Contiguous, non-overlapping slices of data_set for each split.
train_examples = [Example.fromlist(list(row), fields=FIELDS) for row in data_set[: train_pt]]
valid_examples = [Example.fromlist(list(row), fields=FIELDS) for row in data_set[train_pt : valid_pt]]
test_examples = [Example.fromlist(list(row), fields=FIELDS) for row in data_set[valid_pt : test_pt]]

In [9]:
# Wrap each split's examples in a torchtext Dataset sharing the same FIELDS.
train_dt, valid_dt, test_dt = (
    Dataset(split_examples, fields=FIELDS)
    for split_examples in (train_examples, valid_examples, test_examples))

In [10]:
# Build every language's vocabulary from the training split only; tokens seen
# fewer than 2 times are left out of the vocabulary.
for lang in langs:
  lang_field = FIELD_DICT[lang]
  lang_field.build_vocab(train_dt, min_freq = 2)
  print(f'{lang}: {len(lang_field.vocab)}')

en: 6964
fr: 9703


In [None]:
# save_vocab(EN_FIELD.vocab, f'{DIR_PATH}/Datasets/EnDeFrItEsPtRo-76k-most5k-en_vocab.txt')
# EN_FIELD.vocab = read_vocab(f'{DIR_PATH}/Datasets/EnDeFrItEsPtRo-76k-most5k-en_vocab.txt')
# FR_FIELD.vocab = read_vocab(f'{DIR_PATH}/Datasets/EnDeFrItEsPtRo-76k-most5k-fr_vocab.txt')

In [11]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# BucketIterator batches examples of similar length together. Inside each batch
# the examples are sorted by English (source) length. Use the BATCH_SIZE config
# constant rather than a hard-coded literal so the two cannot drift apart.
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_dt, valid_dt, test_dt),
     batch_size = BATCH_SIZE,
     sort_within_batch = True,
     sort_key = lambda x : len(x.en),
     device = device)
len(train_iterator), len(valid_iterator), len(test_iterator)

(2000, 100, 200)

In [12]:
# Take the first training batch and show its (token ids, lengths) per language.
batch = next(iter(train_iterator))
for lang_name in ('en', 'fr'):
  token_ids, lengths = getattr(batch, lang_name)
  print(token_ids.shape, lengths)
print(batch.fields)
print(batch)

torch.Size([41, 32]) tensor([41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 40, 40, 40,
        40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40],
       device='cuda:0')
torch.Size([55, 32]) tensor([52, 49, 44, 35, 54, 51, 41, 49, 39, 41, 50, 52, 45, 39, 44, 38, 39, 45,
        44, 36, 44, 43, 39, 34, 35, 46, 37, 41, 55, 48, 38, 39],
       device='cuda:0')
dict_keys(['en', 'fr'])

[torchtext.legacy.data.batch.Batch of size 32]
	[.en]:('[torch.cuda.LongTensor of size 41x32 (GPU 0)]', '[torch.cuda.LongTensor of size 32 (GPU 0)]')
	[.fr]:('[torch.cuda.LongTensor of size 55x32 (GPU 0)]', '[torch.cuda.LongTensor of size 32 (GPU 0)]')


In [13]:
# Same tensors as above, reached through the batch's attribute dict.
batch_attrs = vars(batch)
for lang_name in ('en', 'fr'):
  print(batch_attrs[lang_name][0].shape, batch_attrs[lang_name][1])

torch.Size([41, 32]) tensor([41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 40, 40, 40,
        40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40],
       device='cuda:0')
torch.Size([55, 32]) tensor([52, 49, 44, 35, 54, 51, 41, 49, 39, 41, 50, 52, 45, 39, 44, 38, 39, 45,
        44, 36, 44, 43, 39, 34, 35, 46, 37, 41, 55, 48, 38, 39],
       device='cuda:0')


In [14]:
# piv_sent stays empty: no pivot language is loaded in this run.
src_sent, piv_sent, trg_sent = [], [], []
first_src_ids = batch.en[0][:, 0]  # column 0 = first sentence of the batch
first_trg_ids = batch.fr[0][:, 0]
for i in first_src_ids:
  src_sent.append(EN_FIELD.vocab.itos[i])
  print(i.item(), end=' ')
print()
for i in first_trg_ids:
  trg_sent.append(FR_FIELD.vocab.itos[i])
print(' '.join(src_sent))
print(' '.join(trg_sent))

2 10 12 429 5 20 3592 13 476 181 517 4 371 128 224 195 5 4 1818 7 4 47 8 44 869 5 377 9 1121 271 5 9 4 604 7 65 49 119 145 6 3 
<sos> in that context , it holds a key role regarding the fight against climate change , the transition of the eu to an efficient , sustainable and competitive economy , and the strengthening of europe 's energy security . <eos>
<sos> dans ce contexte , elle joue un rôle de premier plan dans la lutte contre le changement climatique , dans la transition de l' ue vers une économie utilisant efficacement les ressources , durable et compétitive , ainsi que dans le renforcement de la sécurité énergétique de l' europe . <eos> <pad> <pad> <pad>


In [20]:
special_tokens = (EN_FIELD.unk_token, EN_FIELD.pad_token, EN_FIELD.init_token, EN_FIELD.eos_token)
for tok in special_tokens:
    print(tok, EN_FIELD.vocab.stoi[tok])
# an unseen word resolves to the <unk> index
print('out of vocab', EN_FIELD.vocab.stoi['asodjosjad'])
i = 661
EN_FIELD.vocab.itos[i]

<unk> 0
<pad> 1
<sos> 2
<eos> 3
out of vocab 0


'yes'

In [40]:
z = 25
# The first z vocabulary entries: the special tokens first, then the most frequent words.
for vocab_field in (EN_FIELD, FR_FIELD):
  print([vocab_field.vocab.itos[k] for k in range(z)])

['<unk>', '<pad>', '<sos>', '<eos>', 'the', ',', '.', 'of', 'to', 'and', 'in', 'is', 'that', 'a', 'i', 'this', 'for', 'we', 'on', 'european', 'it', '(', ')', 'be', 'have']
['<unk>', '<pad>', '<sos>', '<eos>', 'de', ',', '.', 'la', 'le', "l'", 'et', 'à', 'les', 'des', 'que', 'en', 'est', "d'", 'nous', 'pour', 'du', 'je', 'une', 'dans', 'un']


## Model

A faithful implementation of Bahdanau et al.'s [paper](https://arxiv.org/pdf/1409.0473.pdf) looks like this: [repo](https://github.com/graykode/nlp-tutorial#4-attention-mechanism) - [colab](https://colab.research.google.com/github/graykode/nlp-tutorial/blob/master/4-2.Seq2Seq(Attention)/Seq2Seq(Attention).ipynb). The model below differs slightly: it follows [Luong et al.](https://arxiv.org/pdf/1508.04025.pdf), using global attention with the *concat* score function ([TensorFlow implementation](https://github.com/philipperemy/keras-attention/blob/master/attention/attention.py)).

### Encoder

In [None]:
class Encoder(nn.Module):
  '''Bidirectional GRU encoder for padded, variable-length source batches.

  Returns the per-step (forward ++ backward) states for attention, plus an
  initial decoder hidden state built from the final forward and backward states.
  '''
  def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
    super().__init__()
    # Attribute names/order are state_dict keys -- keep them for checkpoint loading.
    self.embedding = nn.Embedding(input_dim, emb_dim)
    self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional = True)
    self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)
    self.dropout = nn.Dropout(dropout)

  def forward(self, src, src_len):
    '''
    Args:
      src: [src len, batch size] token ids
      src_len: [batch size] true sequence lengths (need NOT be sorted)
    Returns:
      outputs: [src len, batch size, enc hid dim * 2]; all-zero at padded steps
      hidden: [batch size, dec hid dim]
    '''
    embedded = self.dropout(self.embedding(src))  #embedded = [src len, batch size, emb dim]

    # Lengths must be on the CPU. enforce_sorted=False lets PyTorch sort and unsort
    # internally, so callers do not have to pre-sort the batch by length.
    packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, src_len.to('cpu'), enforce_sorted=False)

    # hidden holds the state after each sequence's last non-padded token,
    # returned in the caller's batch order.
    packed_outputs, hidden = self.rnn(packed_embedded)

    # total_length keeps `outputs` as long as `src` (and thus aligned with a mask
    # built from `src`) even if the batch is padded past its longest sequence.
    outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_outputs, total_length=src.shape[0])

    # hidden = [n layers * num directions, batch size, hid dim], stacked as
    # [forward_1, backward_1, ...]; [-2] is the last forward, [-1] the last backward.
    # The initial decoder hidden is both, concatenated and fed through a linear layer.
    hidden = torch.tanh(self.fc(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1)))
    return outputs, hidden

### Attn

In [None]:
class Attention(nn.Module):
  '''Concat (additive) attention: score(h, e_j) = v^T tanh(W [h; e_j] + b).

  Returns a softmax distribution over source positions; positions where
  `mask == 0` receive (numerically) zero weight.
  '''
  def __init__(self, enc_hid_dim, dec_hid_dim):
    super().__init__()
    # Attribute names are state_dict keys -- keep them for checkpoint loading.
    self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
    self.v = nn.Linear(dec_hid_dim, 1, bias = False)

  def forward(self, hidden, encoder_outputs, mask):
    # hidden: [batch, dec hid]; encoder_outputs: [src len, batch, 2*enc hid]; mask: [batch, src len]
    n_steps = encoder_outputs.shape[0]
    keys = encoder_outputs.permute(1, 0, 2)                # [batch, src len, 2*enc hid]
    query = hidden.unsqueeze(1).repeat(1, n_steps, 1)      # [batch, src len, dec hid]
    features = torch.cat((query, keys), dim = 2)
    scores = self.v(torch.tanh(self.attn(features))).squeeze(2)  # [batch, src len]
    return F.softmax(scores.masked_fill(mask == 0, -1e10), dim = 1)

### Decoder

In [None]:
class Decoder(nn.Module):
  '''Single-step GRU decoder with attention over the encoder states.

  Each call consumes one target token per batch element. It returns the logits
  over the target vocabulary, the new hidden state and this step's attention weights.
  '''
  def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
    super().__init__()
    context_dim = enc_hid_dim * 2  # encoder states are bidirectional
    # Attribute names/order are state_dict keys -- keep them for checkpoint loading.
    self.output_dim = output_dim
    self.attention = attention
    self.embedding = nn.Embedding(output_dim, emb_dim)
    self.rnn = nn.GRU(context_dim + emb_dim, dec_hid_dim)
    self.fc_out = nn.Linear(context_dim + dec_hid_dim + emb_dim, output_dim)
    self.dropout = nn.Dropout(dropout)

  def forward(self, input, hidden, encoder_outputs, mask):
    # input: [batch]; hidden: [batch, dec hid]
    # encoder_outputs: [src len, batch, 2*enc hid]; mask: [batch, src len]
    embedded = self.dropout(self.embedding(input.unsqueeze(0)))                 # [1, batch, emb]

    attn_weights = self.attention(hidden, encoder_outputs, mask).unsqueeze(1)   # [batch, 1, src len]
    context = torch.bmm(attn_weights, encoder_outputs.permute(1, 0, 2))         # [batch, 1, 2*enc hid]
    context = context.permute(1, 0, 2)                                          # [1, batch, 2*enc hid]

    step_out, new_hidden = self.rnn(torch.cat((embedded, context), dim = 2), hidden.unsqueeze(0))
    # One layer, one direction, one time step: the step output IS the new hidden state.
    assert (step_out == new_hidden).all()

    features = torch.cat((step_out.squeeze(0), context.squeeze(0), embedded.squeeze(0)), dim = 1)
    prediction = self.fc_out(features)                                          # [batch, output dim]
    return prediction, new_hidden.squeeze(0), attn_weights.squeeze(1)

### Seq2Seq

In [None]:
class Seq2Seq(nn.Module):
  '''Attention encoder-decoder.

  Before encoding, the batch is sorted by source length (descending). The
  returned outputs are put back into the caller's batch order.
  '''
  def __init__(self, encoder, decoder, src_pad_idx, device):
    super().__init__()
    self.encoder = encoder
    self.decoder = decoder
    self.src_pad_idx = src_pad_idx
    self.device = device

  def create_mask(self, src):
    '''[src len, batch] ids -> [batch, src len] bool mask, True on non-pad tokens.'''
    mask = (src != self.src_pad_idx).permute(1, 0)
    return mask

  def forward(self, datas, criterion=None, teacher_forcing_ratio = 0.5):
    '''
    datas: [(src, src_len), (trg, trg_len)]
      src = [src len, batch size], src_len = [batch size]
      trg = [trg len, batch size] (starts with <sos>); trg_len is unused
    criterion: loss module, or None to skip loss computation
    teacher_forcing_ratio: per-step probability of feeding the gold token
      (instead of the model's argmax) as the next decoder input
    Returns (loss, outputs) if criterion is given, else outputs, where
    outputs = [trg len, batch size, trg vocab size] (row 0 stays all zeros).
    '''
    (src, src_len), (trg, _) = datas
    batch_size = src.shape[1]
    trg_len = trg.shape[0]
    trg_vocab_size = self.decoder.output_dim

    # SORT: reorder the whole batch by decreasing source length; unsort_ids restores it.
    sort_ids, unsort_ids = self.sort_by_sent_len(src_len)
    src, src_len, trg = src[:, sort_ids], src_len[sort_ids], trg[:, sort_ids]

    # decoder outputs, allocated directly on the target device
    outputs = torch.zeros(trg_len, batch_size, trg_vocab_size, device=self.device)

    # encoder_outputs: every encoder state (both directions)
    # hidden: final forward/backward states passed through a linear layer
    encoder_outputs, hidden = self.encoder(src, src_len)

    # first input to the decoder is the <sos> tokens
    input = trg[0,:]

    mask = self.create_mask(src)  #mask = [batch size, src len]

    for t in range(1, trg_len):
      output, hidden, _ = self.decoder(input, hidden, encoder_outputs, mask)
      outputs[t] = output
      # teacher forcing: feed the gold token, otherwise the model's own best guess
      input = trg[t] if random.random() < teacher_forcing_ratio else output.argmax(1)

    if criterion is not None:
      # loss is computed on the (consistently) sorted outputs and targets
      loss = self.compute_loss(outputs, trg, criterion)
      return loss, outputs[:, unsort_ids, :]
    return outputs[:, unsort_ids, :]

  def compute_loss(self, output, trg, criterion):
    '''Criterion over all steps except the leading <sos> step.

    output = [trg len, batch size, trg vocab size]; trg = [trg len, batch size].
    reshape (not view) so non-contiguous inputs are accepted too.
    '''
    output = output[1:].reshape(-1, output.shape[-1])  #output = [(trg len - 1) * batch size, output dim]
    trg = trg[1:].reshape(-1)  #trg = [(trg len - 1) * batch size]
    loss = criterion(output, trg)
    return loss

  def sort_by_sent_len(self, sent_len):
    '''Return (sort_ids, unsort_ids): sort_ids orders lengths descending and
    x[sort_ids][unsort_ids] == x restores the original order.'''
    _, sort_ids = sent_len.sort(descending=True)
    unsort_ids = sort_ids.argsort()
    return sort_ids, unsort_ids

### Pivot model (update)

**TODO: the code still needs to be reorganized for inference (no criterions; only one input, `(src, src_len)`).**

In [None]:
class PivotSeq2Seq(nn.Module):
  '''Cascade of M Seq2Seq models: src -> piv_1 -> ... -> piv_{M-1} -> trg.

  Model i translates language fields[i] into fields[i+1], so `fields` must hold
  one torchtext Field per language in the chain (len(fields) == len(models) + 1).
  Submodels are registered as `model_0`, `model_1`, ...; these names are the
  state_dict prefixes, so they must not change.
  '''
  def __init__(self, models: list, fields: list, device, alpha=1.1, lamda=0.75):
    super().__init__()
    self.num_model = len(models)
    self.fields = fields
    self.num_field = len(fields)
    self.device = device
    self.alpha = alpha  # weight of the last model's (-> trg) loss
    self.lamda = lamda  # weight of the embedding-distance term (compute_embed_loss)
    self.add_submodels(models)

  def add_submodels(self, models: list):
    for i, submodel in enumerate(models):
      assert isinstance(submodel, Seq2Seq), type(submodel)
      self.add_module(f'model_{i}', submodel)
    # Must read self.fields: there is no `fields` name in this method's scope.
    assert len(models)+1 == len(self.fields), f"Not enough Fields for models: num_field={len(self.fields)} != {len(models)+1}"

  def forward(self, datas: list, criterions=None, teacher_forcing_ratio=0.5):
    '''
    datas: list of M+1 (tokens, lengths) pairs for M models:
      [(src, src_len), (piv_1, piv_len_1), ..., (piv_{M-1}, piv_len_{M-1}), (trg, trg_len)]
      src = [src len, batch_size]
      src_len = [batch_size]
      ...
      trg = [trg len, batch_size]
      trg_len = [batch_size]
    criterions: list with one criterion per model, or None (no losses computed)
    Returns (total_loss, last model's output) with criterions, else the last model's output.
    '''
    if criterions is not None:
      loss_list, output_list = self.run(datas, criterions, teacher_forcing_ratio)
      total_loss = self.compute_loss(loss_list)
      return total_loss, output_list[-1]
    criterions = [None for _ in range(self.num_model)]
    _, output_list = self.run(datas, criterions, teacher_forcing_ratio)
    return output_list[-1]

  def run(self, datas, criterions, teacher_forcing_ratio):
    '''Run the submodels in chain order; returns (loss_list, output_list).

    output_list has one entry per model. loss_list gets an entry for each model
    that has a criterion.
    '''
    assert self.num_model+1 == len(datas), f"Not enough datas for models: data_len={len(datas)} != {self.num_model+1}"
    assert self.num_model == len(criterions), f'Criterions must have for each model: num_criterion={len(criterions)} != {self.num_model}'

    output_list, loss_list = [], []
    for i in range(self.num_model):
      # Model 0 always reads the real source. Each later model reads the gold
      # pivot with probability teacher_forcing_ratio, else the previous model's prediction.
      use_gold_input = i == 0 or random.random() < teacher_forcing_ratio

      # GET NEW INPUT
      if use_gold_input:
        src, src_len = datas[i]
      else:
        # model i-1 predicts language i, so decode its output with fields[i]
        src, src_len = self.process_output(output_list[-1], self.fields[i])
      trg, trg_len = datas[i+1]

      # FORWARD MODEL (Seq2Seq sorts src by src_len itself)
      model = getattr(self, f'model_{i}')
      criterion = criterions[i]
      output = model([(src, src_len), (trg, trg_len)], criterion, 0 if criterion is None else teacher_forcing_ratio)

      if criterion is None:
        output_list.append(output)
      else:
        assert len(output) == 2, 'With criterion, model should return loss & prediction'
        loss, out = output
        loss_list.append(loss)
        output_list.append(out)

    return loss_list, output_list

  def compute_loss(self, loss_list):
    '''Sum of the intermediate losses + alpha * final loss + lamda * embedding loss.'''
    total_loss = 0.0
    for loss in loss_list[:-1]:  # every model except the final one
      total_loss += loss
    total_loss += self.alpha*loss_list[-1]
    return total_loss + self.lamda*self.compute_embed_loss()

  def compute_embed_loss(self):
    '''Summed row-wise L2 distance between model i-1's decoder embedding and
    model i's encoder embedding (both embed the shared pivot language).
    Requires the two embedding matrices to have the same shape.'''
    embed_loss = 0.0
    for i in range(1, self.num_model):
      model1 = getattr(self, f'model_{i-1}')
      model2 = getattr(self, f'model_{i}')
      embed_loss += torch.sum(F.pairwise_distance(model1.decoder.embedding.weight, model2.encoder.embedding.weight, p=2))
    return embed_loss

  def sort_by_src_len(self, piv, piv_len, datas): # piv = [piv_len, batch_size]
    '''Sort piv (and every (sent, sent_len) pair in datas) by decreasing piv_len.
    Not called within this class; Seq2Seq already sorts its own inputs.'''
    piv_len, sorted_ids = piv_len.sort(descending=True)
    sorted_datas = [(sent[:, sorted_ids], sent_len[sorted_ids]) for (sent, sent_len) in datas]
    return piv[:, sorted_ids], piv_len, sorted_datas  # piv sorted along batch_size

  def process_output(self, output, piv_field):
    '''Turn a model's logits into a padded token batch usable as the next model's source.

    Args:
      output: [seq len, batch size, vocab size] logits (vocab of piv_field)
      piv_field: torchtext Field of the predicted language (gives sos/eos/pad ids)
    Returns:
      piv: [new seq len, batch size] long tensor. Row 0 is <sos>, then the greedy
        tokens; everything after a predicted <eos> or <pad> becomes <pad>.
        Trailing all-<pad> rows are removed.
      piv_len: [batch size] number of non-<pad> tokens per sequence, i.e. the
        index of <eos> + 1, or the full length when no <eos> is predicted.
    '''
    sos_idx = piv_field.vocab.stoi[piv_field.init_token]
    eos_idx = piv_field.vocab.stoi[piv_field.eos_token]
    pad_idx = piv_field.vocab.stoi[piv_field.pad_token]

    seq_len, N, _ = output.shape
    tmp_out = output.argmax(2)  # tmp_out = [seq_len, batch_size], greedy tokens
    piv = torch.zeros_like(tmp_out)
    piv[0, :] = sos_idx  # every pivot sentence starts with <sos>

    for i in range(1, seq_len):
      # a sequence is finished once the previous token was <eos> or it is already padded
      finished = (tmp_out[i-1, :] == eos_idx) | (piv[i-1, :] == pad_idx)
      piv[i, :] = tmp_out[i, :].masked_fill(finished, pad_idx)

    # Padding only ever grows as a suffix, so all-<pad> rows form a suffix: drop them.
    piv = piv[(piv != pad_idx).any(dim=1)]
    assert not (piv[-1] == pad_idx).all(), 'Not completely trim down tensor'

    # Count real tokens per sequence. Unlike a seq_len default, this never exceeds
    # the trimmed length, even for sequences that end in <pad> without an <eos>.
    piv_len = (piv != pad_idx).sum(dim=0)
    return piv, piv_len

### Triangulate model

In [None]:
class TriangSeq2Seq(nn.Module):
  '''Ensemble ("triangulation") of several submodels whose outputs all share one shape.

  Submodel `model_{i}` is called on `datas['model_{i}']` and must return target
  scores of shape [seq_len, N, output_dim]. All outputs must have equal shapes
  because they are stacked. `get_final_pred` combines them according to `method`:
    'weighted'   - ReLU + Linear head over the concatenated outputs
    'weighted_1' - small MLP mixing the M scores of every (position, word) entry
    'average'    - element-wise mean of the outputs
    'max'        - per sentence, keep the output of the submodel whose argmax
                   tokens get the highest average probability under all submodels
    other        - the first submodel's output
  '''
  def __init__(self, models: list, output_dim, device, alpha=1.1, method='max', train_backbone=True):
    # output_dim = trg vocab size
    # alpha: weight of the combined-prediction loss (not used when method == 'max')
    # train_backbone: False freezes every submodel parameter (requires_grad = False)
    super(TriangSeq2Seq, self).__init__()
    self.num_model = len(models)
    self.output_dim = output_dim
    self.device = device
    self.alpha = alpha
    self.method = method
    self.train_backbone = train_backbone
    if method=='weighted':
      # maps the M concatenated score vectors of each position back to output_dim
      self.head = nn.Sequential(
          nn.ReLU(),
          nn.Linear(output_dim*self.num_model, output_dim)
      )
    elif method=='weighted_1':
      # Earlier head variants, kept for reference (the names look like experiment/run ids):
      # self.head = nn.Sequential(  # traing-EnFr-EnEsFr-dropout-1
      #     nn.Linear(self.num_model, self.num_model),
      #     nn.ReLU(),
      #     nn.Dropout(p=0.5),
      #     nn.Linear(self.num_model, 1),
      # )
      # self.head = nn.Sequential(  # traing-EnFr-EnEsFr-1
      #     nn.Linear(self.num_model, self.num_model*2),
      #     nn.ReLU(),
      #     nn.Linear(self.num_model*2, self.num_model*2),
      #     nn.ReLU(),
      #     nn.Linear(self.num_model*2, 1),
      # )
      # self.head = nn.Sequential(  # traing-EnFr-EnEsFr-1-1
      #     nn.Linear(self.num_model, self.num_model*2),
      #     nn.ReLU(),
      #     nn.Dropout(p=0.2),
      #     nn.Linear(self.num_model*2, self.num_model*2),
      #     nn.ReLU(),
      #     nn.Dropout(p=0.2),
      #     nn.Linear(self.num_model*2, 1),
      # )
      # Active head: per (position, word) entry, M submodel scores -> 1 combined score.
      self.head = nn.Sequential(  # traing-EnFr-EnEsFr-dense5
            nn.Linear(self.num_model, 5),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(5, 5),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(5, 1),
        )
    for i in range(self.num_model):
      for param in models[i].parameters():
        param.requires_grad = train_backbone
      # registered as model_0, model_1, ... (state_dict prefixes; also used by run())
      self.add_module(f'model_{i}', models[i])

  def forward(self, datas: dict, criterions=None, teacher_forcing_ratio=0.5):
    '''
    datas: dict of data, one entry per submodel plus the shared target:
      {"model_0": data_0, "model_1": data_1, ..., "TRG": (trg, trg_len)}
      where data_i is whatever submodel `model_{i}` takes as its first argument, e.g.
      [(src, src_len), (trg, trg_len)] for a Seq2Seq, or
      [(src, src_len), (piv, piv_len), (trg, trg_len)] for a PivotSeq2Seq
      src = [src len, batch size]
      src_len = [batch size]
    criterions: dict of criterions, or None (no losses computed)
      {"model_0": criterion_0, "model_1": criterion_1, ..., "TRG": criterion_M}
    Returns (total_loss, final_out) with criterions, otherwise final_out.
    '''
    if criterions != None:
      loss_list, output_list = self.run(datas, criterions, teacher_forcing_ratio)
      final_out = self.get_final_pred(output_list)
      # 'max': only the submodel losses are summed; other methods also add
      # alpha * (loss of the combined prediction against datas["TRG"]).
      if self.method!='max':total_loss = self.alpha*self.compute_final_pred_loss(final_out, datas["TRG"], criterions["TRG"]) + self.compute_submodels_loss(loss_list)
      else: total_loss = self.compute_submodels_loss(loss_list)
      return total_loss, final_out
    else:
      # inference: every submodel runs without a criterion (teacher forcing 0 inside run)
      criterions = {f'model_{i}':None for i in range(self.num_model)}
      criterions['TRG'] = None
      loss_list, output_list = self.run(datas, criterions, teacher_forcing_ratio)
      final_out = self.get_final_pred(output_list)
      return final_out

  def run(self, datas, criterions, teacher_forcing_ratio):
    # Calls each submodel on its own data; returns (loss_list, output_list).
    # The submodel gets teacher_forcing_ratio only when it has a criterion, else 0.
    assert self.num_model+1 == len(datas), f"Not enough datas for models: data_len={len(datas)} != {self.num_model+1}"  # include 'TRG'
    assert self.num_model+1 == len(criterions), f'Criterions must have for each model: num_criterion={len(criterions)} != {self.num_model+1}' # include 'TRG'

    output_list = []
    loss_list = []
    for i in range(self.num_model):
      data = datas[f'model_{i}']
      model = getattr(self, f'model_{i}')
      criterion = criterions[f'model_{i}']
      output = model(data, criterion, 0 if criterion==None else teacher_forcing_ratio)

      if criterion == None:
        output_list.append(output)
      else:
        assert len(output) == 2, 'With criterion, model should return loss & prediction'
        loss_list.append(output[0])
        output_list.append(output[1])

    return loss_list, output_list

  def compute_submodels_loss(self, loss_list):
    # plain (unweighted) sum of the submodel losses
    total_loss = 0.0
    for loss in loss_list:
      total_loss += loss
    return total_loss

  def compute_final_pred_loss(self, output, data, criterion):
    #output = (trg_len, batch_size, trg_vocab_size)
    #data = [trg, trg_len]  # trg.shape = [seq_len, batch_size]
    # the first (<sos>) step is excluded from the loss
    trg, _ = data
    output = output[1:].reshape(-1, output.shape[-1])  #output = [(trg len - 1) * batch size, output dim]
    trg = trg[1:].reshape(-1)  #trg = [(trg len - 1) * batch size]
    loss = criterion(output, trg)
    return loss

  def get_final_pred(self, output_list):  # output_list[0] shape = [seq_len, N, out_dim]
    # Combines the submodel outputs into one [seq_len, N, out_dim] tensor according to self.method.
    # assert all([output_list[i].shape == output_list[i-1].shape for i in range(1, len(output_list))]), 'all outputs must match shape [seq_len, N, out_dim]'
    seq_len, N, out_dim = output_list[0].shape
    if self.method=='weighted':
      linear_in = torch.cat([out for out in output_list], dim=-1) # linear_in = [seq_len, N, out_dim * num_model]. Note that num_model = len(output_list)
      final_out = self.head(linear_in)  # final_out = [seq_len, N, out_dim]
      return final_out
    elif self.method=='weighted_1':
      output_list = [out.permute(1, 0, 2).reshape(N, -1) for out in output_list]  # [N, seq_len, out_dim] --> [N, seq_len*out_dim]
      final_out = self.head(torch.stack(output_list, dim=-1)).squeeze(-1)  # [N, seq_len*out_dim, num_model] --> [N, seq_len*out_dim, 1] --> [N, seq_len*out_dim]
      return final_out.reshape(N, seq_len, out_dim).permute(1, 0, 2) # [N, seq_len, out_dim] --> [seq_len, N, out_dim]
    elif self.method=='average':
      outputs = torch.mean(torch.stack(output_list, dim=0), dim=0)
      return outputs
    elif self.method=='max':
      # For each sentence, choose the submodel whose greedy (argmax) tokens receive
      # the highest probability on average across all submodels, and return that
      # submodel's raw scores for the sentence.
      all_t = torch.stack(output_list, dim=-1)  # [seq_len, N, out_dim, num_model] raw scores
      prob_ts = torch.stack([F.softmax(d, -1) for d in output_list], dim=-1)  # same shape, probabilities

      final_selected_ts = []
      for sent_id in range(N):
        # get ids of each model (get selected words)
        # ids_list[m] = argmax token ids of model m for this sentence: [seq_len, 1]
        ids_list = []
        for m in range(self.num_model):
          m_t = prob_ts[..., m] # [seq_len, N, out_dim]
          ids_list.append(torch.argmax(m_t[:, sent_id, :], dim=-1, keepdim=True))
        # get the confusion matrix
        # all_pairs[t][m] = probability model m assigns to model t's tokens: [seq_len, 1]
        all_pairs = []
        for t in range(self.num_model):
          t_eachM = []
          for m in range(self.num_model):
            m_t = prob_ts[..., m]
            t_eachM.append(torch.gather(m_t[:, sent_id, :], -1, ids_list[t]))
          all_pairs.append(t_eachM)
        # calculate the prob of each sent
        # t_m_prod[t][m] = mean over ALL seq_len positions (no length masking): [1]
        t_m_prod = []
        for t in range(self.num_model):
          t_eachM = []
          for m in range(self.num_model):
            t1_m1 = all_pairs[t][m]
            t_eachM.append(torch.mean(t1_m1, dim=0))  # original: prod
          t_m_prod.append(t_eachM)
        # score of candidate t = average of t_m_prod[t][m] over the models m
        t_totals = [torch.stack(t_m_prod[t], dim=-1)for t in range(self.num_model)]
        t_means = [torch.mean(t_totals[t], dim=-1) for t in range(self.num_model)]
        t_cats = torch.cat(t_means, dim=-1)  # [num_model]
        t_selects = torch.argmax(t_cats)
        selected_t = torch.select(all_t[:, sent_id, ...], -1, t_selects)  # [seq_len, out_dim]
        final_selected_ts.append(selected_t)
      output = torch.stack(final_selected_ts, dim=1)  # [seq_len, N, out_dim]
      return output
    else:
      return output_list[0]

## Train func

In [None]:
def update_trainlog(data: list, filename: str=f'{DIR_PATH}/training_log.txt'):
  ''' Append one comma-separated line per record to the training log.
  Args:
      data (List): records (tuples/lists), one per epoch, e.g. (model_name, loss, ...).
          Every field is converted with str(), so numeric values such as losses
          can be passed directly instead of crashing str.join.
      filename (String): path of the log file. The default is built from DIR_PATH
          once, when this function is defined.
  Return:
      list: an empty list, presumably so a caller can reset its record buffer
          with `buffer = update_trainlog(buffer)`.
  '''
  with open(filename, 'a') as f: # append, never overwrite earlier runs
    for epoch in data:
      f.write(','.join(map(str, epoch)))
      f.write("\n")
  print('update_trainlog SUCCESS')
  return []
def init_weights(m):
  '''Re-initialise every parameter of module `m` in place.

  Parameters whose name contains 'weight' are drawn from N(0, 0.01^2);
  every other parameter (biases etc.) is set to 0.
  '''
  for pname, p in m.named_parameters():
    if 'weight' not in pname:
      nn.init.constant_(p.data, 0)
      continue
    nn.init.normal_(p.data, mean=0, std=0.01)
def count_parameters(model):
  '''Return the total number of scalar elements over parameters with requires_grad=True.'''
  total = 0
  for p in model.parameters():
    if p.requires_grad:
      total += p.numel()
  return total

### Seq2Seq

In [None]:
def trainSeq2Seq(model, iterator, optimizer, criterion, clip, teacher_forcing_ratio=0.5):
  '''Run one training epoch of a plain Seq2Seq model (en -> fr).

  Args:
      model: Seq2Seq called as model([src, trg], criterion, teacher_forcing_ratio) -> (loss, output).
      iterator: batches exposing `.en` and `.fr`, each a (tokens, lengths) pair.
      optimizer: optimizer stepping the model parameters.
      criterion: loss object forwarded to the model.
      clip (float): max gradient norm passed to clip_grad_norm_.
      teacher_forcing_ratio (float): forwarded to the model; defaults to the value that was
          previously hard-coded (0.5).
  Returns:
      float: mean batch loss over the epoch.
  '''
  model.train()
  epoch_loss = 0.0
  for batch in tqdm(iterator):
    optimizer.zero_grad()
    datas = [batch.en, batch.fr]  # [src, trg]
    loss, _ = model(datas, criterion, teacher_forcing_ratio)

    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
    optimizer.step()
    epoch_loss += loss.item()
  return epoch_loss / len(iterator)

def evaluateSeq2Seq(model, iterator, criterion):
  '''Mean validation loss of a plain Seq2Seq model over `iterator`.

  Batches are fed as [en, fr] with teacher forcing turned off and gradients disabled.
  '''
  model.eval()
  batch_losses = []
  with torch.no_grad():
    for batch in tqdm(iterator):
      loss, _ = model([batch.en, batch.fr], criterion, 0)  # 0 -> teacher forcing off
      batch_losses.append(loss.item())
  return sum(batch_losses) / len(iterator)

### Pivot

In [None]:
def trainPivot(model, iterator, optimizer, criterions, clip, teacher_forcing_ratio=0.5):
  '''Run one training epoch of a PivotSeq2Seq model (en -> es -> fr).

  Args:
      model: PivotSeq2Seq called as model([src, piv, trg], criterions, ratio) -> (loss, output).
      iterator: batches exposing `.en`, `.es` and `.fr`.
      optimizer: optimizer stepping the model parameters.
      criterions (list): one loss object per pivot stage, forwarded to the model.
      clip (float): max gradient norm passed to clip_grad_norm_.
      teacher_forcing_ratio (float): forwarded to the model; defaults to the previously
          hard-coded 0.5.
  Returns:
      float: mean batch loss over the epoch.
  '''
  model.train()
  epoch_loss = 0.0
  for batch in tqdm(iterator):
    optimizer.zero_grad()
    model_inputs = [batch.en, batch.es, batch.fr]  # [src, pivot, trg]
    loss, _ = model(model_inputs, criterions, teacher_forcing_ratio)

    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
    optimizer.step()
    epoch_loss += loss.item()
  return epoch_loss / len(iterator)

def evaluatePivot(model, iterator, criterions):
  '''Mean validation loss of a PivotSeq2Seq over `iterator`.

  Batches are fed as [en, es, fr] (source, pivot, target) with teacher forcing off.
  '''
  model.eval()
  running = 0.0
  with torch.no_grad():
    for batch in tqdm(iterator):
      loss, _ = model([batch.en, batch.es, batch.fr], criterions, 0)
      running += loss.item()
  return running / len(iterator)

### Triangulate

In [None]:
# for triangulate model
def trainTriang(model, iterator, optimizer, criterions, clip):
  model.train()
  epoch_loss = 0.0
  for batch in tqdm(iterator):
    optimizer.zero_grad()
    model_inputs = {
        'model_0': [batch.en, batch.fr],
        'model_1': [batch.en, vars(batch)['es'], batch.fr],
        'TRG': batch.fr
    }
    loss, _ = model(model_inputs, criterions, 0.5)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
    optimizer.step()
    epoch_loss += loss.item()
  return epoch_loss / len(iterator)

def evaluateTriang(model, iterator, criterions):
  '''Mean validation loss of a TriangSeq2Seq with teacher forcing off.

  Inputs per batch: 'model_0' -> [en, fr] (direct), 'model_1' -> [en, es, fr] (pivot),
  'TRG' -> fr.
  '''
  model.eval()
  losses = []
  with torch.no_grad():
    for batch in tqdm(iterator):
      src, piv, trg = batch.en, vars(batch)['es'], batch.fr
      loss, _ = model({'model_0': [src, trg], 'model_1': [src, piv, trg], 'TRG': trg}, criterions, 0)
      losses.append(loss.item())
  return sum(losses) / len(iterator)

## Train

In [None]:
# Shared hyper-parameters plus per-language vocabulary sizes (<lang>_DIM); the vocabulary
# sizes are presumably len(<FIELD>.vocab) for the checkpoints loaded below — TODO confirm.
cfg = dict(
    EMB_DIM=256,
    HID_DIM=512,
    DROPOUT=0.5,
    en_DIM=6964,
    fr_DIM=9703,
    es_DIM=10461,
    it_DIM=10712,
    pt_DIM=10721,
    ro_DIM=11989,
)

### Seq2Seq

In [None]:
# For 2 langs
# Baseline: one attention Seq2Seq translating en -> fr.
# NOTE(review): building Attention/Encoder/Decoder presumably draws initial weights from the
# global RNG, so the statement order here is left as-is.
EMB_DIM = 256
HID_DIM = 512
DROPOUT = 0.5

INPUT_DIM = len(EN_FIELD.vocab)   # source (en) vocabulary size
OUTPUT_DIM = len(FR_FIELD.vocab)  # target (fr) vocabulary size
SRC_PAD_IDX = EN_FIELD.vocab.stoi[EN_FIELD.pad_token]  # handed to Seq2Seq, presumably for source padding masks

attn = Attention(HID_DIM, HID_DIM)
enc = Encoder(INPUT_DIM, EMB_DIM, HID_DIM, HID_DIM, DROPOUT)
dec = Decoder(OUTPUT_DIM, EMB_DIM, HID_DIM, HID_DIM, DROPOUT, attn)

model = Seq2Seq(enc, dec, SRC_PAD_IDX, device).to(device)

### Pivot

In [None]:
# Pivot model en -> es -> fr: two chained Seq2Seq stages wrapped in PivotSeq2Seq.
# NOTE(review): module construction presumably consumes the global RNG; statement order is kept.
EMB_DIM = 256
HID_DIM = 512
DROPOUT = 0.5

INPUT_DIM = len(EN_FIELD.vocab)   # source (en) vocabulary size
PIV_DIM = len(ES_FIELD.vocab)     # pivot (es) vocabulary size
OUTPUT_DIM = len(FR_FIELD.vocab)  # target (fr) vocabulary size

# Stage 1: en -> es
SRC_PAD_IDX = EN_FIELD.vocab.stoi[EN_FIELD.pad_token]
attn1 = Attention(HID_DIM, HID_DIM)
enc1 = Encoder(INPUT_DIM, EMB_DIM, HID_DIM, HID_DIM, DROPOUT)
dec1 = Decoder(PIV_DIM, EMB_DIM, HID_DIM, HID_DIM, DROPOUT, attn1)
model1 = Seq2Seq(enc1, dec1, SRC_PAD_IDX, device).to(device)

# Stage 2: es -> fr (its source padding index comes from the pivot field)
PIV_PAD_IDX = ES_FIELD.vocab.stoi[ES_FIELD.pad_token]
attn2 = Attention(HID_DIM, HID_DIM)
enc2 = Encoder(PIV_DIM, EMB_DIM, HID_DIM, HID_DIM, DROPOUT)
dec2 = Decoder(OUTPUT_DIM, EMB_DIM, HID_DIM, HID_DIM, DROPOUT, attn2)
model2 = Seq2Seq(enc2, dec2, PIV_PAD_IDX, device).to(device)

models = [model1, model2]
fields = [EN_FIELD, ES_FIELD, FR_FIELD]  # one field per language along the chain
model = PivotSeq2Seq(models, fields, device).to(device)

In [None]:
# Sanity check of the nested input layout a TriangSeq2Seq expects, using placeholder tuples
# instead of (tokens, lengths) tensors.
# Submodels are looked up with getattr, as in translate_sentence: nn.Module.get_submodule was
# only added in torch 1.9, and this notebook runs torch 1.8.2.
# NOTE(review): `model_1` must be a TriangSeq2Seq defined in an earlier cell; it is not the
# `model1` Seq2Seq built just above.
src, trg = (1, 2), (3, 4)
data = {'TRG': trg}
for i in range(model_1.num_model):
  submodel = getattr(model_1, f'model_{i}')
  if isinstance(submodel, Seq2Seq):
    data[f'model_{i}'] = [src, trg]
  elif isinstance(submodel, PivotSeq2Seq):
    data[f'model_{i}'] = [src] + [trg for _ in range(submodel.num_model)]
data

In [None]:
model_path = f'{DIR_PATH}/piv-EnEsFr.pt'
# map_location lets a checkpoint that was saved from a GPU load on whatever `device` this session uses
ckpt = torch.load(model_path, map_location=device)
# optimizer.load_state_dict(ckpt['optimizer_state_dict'])
# scheduler.load_state_dict(ckpt['scheduler_state_dict'])
model.load_state_dict(ckpt['model_state_dict']) # strict=False if some dimensions are different

### Triangulate

In [None]:
# DIRECT
# Direct en -> fr Seq2Seq, restored from its checkpoint.
# NOTE(review): module construction presumably draws initial weights from the global RNG,
# so statement order is kept as-is.
SRC_PAD_IDX = EN_FIELD.vocab.stoi[EN_FIELD.pad_token]
attn = Attention(cfg['HID_DIM'], cfg['HID_DIM'])
enc = Encoder(cfg['en_DIM'], cfg['EMB_DIM'], cfg['HID_DIM'], cfg['HID_DIM'], cfg['DROPOUT'])
dec = Decoder(cfg['fr_DIM'], cfg['EMB_DIM'], cfg['HID_DIM'], cfg['HID_DIM'], cfg['DROPOUT'], attn)

model_direct = Seq2Seq(enc, dec, SRC_PAD_IDX, device).to(device)
modelname = f'seq2seq-EnFr-1.pt'
model_direct.load_state_dict(torch.load(f'{DIR_PATH}/{modelname}')['model_state_dict'])
print(f'loaded model {modelname}')

# PIVOT
# en -> {lang} -> fr pivot model: stage 1 is en -> lang, stage 2 is lang -> fr.
lang = 'es'
TMP_FIELD = FIELD_DICT[lang]

SRC_PAD_IDX = EN_FIELD.vocab.stoi[EN_FIELD.pad_token]
attn1 = Attention(cfg['HID_DIM'], cfg['HID_DIM'])
enc1 = Encoder(cfg['en_DIM'], cfg['EMB_DIM'], cfg['HID_DIM'], cfg['HID_DIM'], cfg['DROPOUT'])
dec1 = Decoder(cfg[f'{lang}_DIM'], cfg['EMB_DIM'], cfg['HID_DIM'], cfg['HID_DIM'], cfg['DROPOUT'], attn1)
model1 = Seq2Seq(enc1, dec1, SRC_PAD_IDX, device).to(device)

PIV_PAD_IDX = TMP_FIELD.vocab.stoi[TMP_FIELD.pad_token]
attn2 = Attention(cfg['HID_DIM'], cfg['HID_DIM'])
enc2 = Encoder(cfg[f'{lang}_DIM'], cfg['EMB_DIM'], cfg['HID_DIM'], cfg['HID_DIM'], cfg['DROPOUT'])
dec2 = Decoder(cfg['fr_DIM'], cfg['EMB_DIM'], cfg['HID_DIM'], cfg['HID_DIM'], cfg['DROPOUT'], attn2)
model2 = Seq2Seq(enc2, dec2, PIV_PAD_IDX, device).to(device)

models = [model1, model2]
fields = [EN_FIELD, TMP_FIELD, FR_FIELD]
model_piv = PivotSeq2Seq(models, fields, device).to(device)
modelname = f'piv-En{lang[0].upper()}{lang[1]}Fr.pt'  # e.g. 'es' -> piv-EnEsFr.pt
model_piv.load_state_dict(torch.load(f'{DIR_PATH}/{modelname}')['model_state_dict'])
print(f'loaded model {modelname}')

# TRIANGULATE
# Combine the direct and pivot models. With train_backbone=False the recorded output reports
# 0 trainable parameters.
# NOTE(review): with nothing trainable, the training loop below likely has nothing to update — confirm.
models = [model_direct, model_piv]
model = TriangSeq2Seq(models, cfg['fr_DIM'], device, alpha=1.1, method='max', train_backbone=False).to(device)

print(f'The model has {count_parameters(model):,} trainable parameters')

loaded model seq2seq-EnFr-1.pt
loaded model piv-EnEsFr.pt
The model has 0 trainable parameters


### Train loop

In [None]:
# Optional weight re-initialisation: the whole model for Seq2Seq/Pivot (or a Triang model with a
# trainable backbone), only the head for a Triang model with a frozen backbone.
# Both apply() calls are commented out; the note below says to remove init when backbones are
# frozen, and the triang setup cell loads pretrained backbones.
# NOTE(review): the messages still say "success" even though no initialisation runs.
if isinstance(model, Seq2Seq) or isinstance(model, PivotSeq2Seq) or (isinstance(model, TriangSeq2Seq) and model.train_backbone):
  # model.apply(init_weights) # remove if backbone models are freezed
  print('init_weights success')
elif isinstance(model, TriangSeq2Seq) and not model.train_backbone:
  # model.head.apply(init_weights)
  print('init_weights head ONLY success')
print(f'The model has {count_parameters(model):,} trainable parameters')

In [None]:
# Loss functions; each ignores the padding index of the language it scores.
# criterion2: fr-target loss (direct model / last pivot stage).
criterion2 = nn.CrossEntropyLoss(ignore_index = FR_FIELD.vocab.stoi[FR_FIELD.pad_token])

# criterions3: one loss per pivot stage, [en->es, es->fr]. The pivot language is hard-coded to 'es'.
if isinstance(model, PivotSeq2Seq) or isinstance(model, TriangSeq2Seq):
  criterion1 = nn.CrossEntropyLoss(ignore_index = FIELD_DICT['es'].vocab.stoi[FIELD_DICT['es'].pad_token])
  criterions3 = [criterion1, criterion2]
  print('Created criterion for piv')

# criterions: keyed like the TriangSeq2Seq input dict; 'TRG' presumably scores the combined output.
if isinstance(model, TriangSeq2Seq):
  criterions = {
      'model_0': criterion2,
      'model_1': criterions3,
      'TRG': nn.CrossEntropyLoss(ignore_index = FR_FIELD.vocab.stoi[FR_FIELD.pad_token])
  }
  print('Created criterion for triang')

Created criterion for piv
Created criterion for triang


In [None]:
# Training-run settings.
N_EPOCHS, CLIP = 1, 1  # number of epochs; max gradient norm for clip_grad_norm_
best_train_loss = best_valid_loss = float('inf')  # best losses seen so far
model_name = 'traing-EnFr-EnEsFr-dense5.pt'  # checkpoint file name, saved under DIR_PATH
train_log = []  # buffered epoch records, flushed to disk by update_trainlog

In [None]:
# Adam over all model parameters.
# MultiStepLR with milestones=[1] multiplies the LR by 1/3 once, after the first scheduler.step().
LR = 1e-4
optimizer = optim.Adam(model.parameters(), lr=LR)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[1], gamma=1.0/3.0)
scheduler.get_last_lr()

[0.0001]

In [None]:
for epoch in range(N_EPOCHS):
  # Train + validate with the loop/criterion matching the model type.
  if isinstance(model, Seq2Seq):
    train_loss = trainSeq2Seq(model, train_iterator, optimizer, criterion2, CLIP)
    valid_loss = evaluateSeq2Seq(model, valid_iterator, criterion2)
  elif isinstance(model, PivotSeq2Seq):
    train_loss = trainPivot(model, train_iterator, optimizer, criterions3, CLIP)
    # Validate with the same per-stage criterion list as training; `criterions`
    # (the dict) is only created for TriangSeq2Seq models.
    valid_loss = evaluatePivot(model, valid_iterator, criterions3)
  elif isinstance(model, TriangSeq2Seq):
    train_loss = trainTriang(model, train_iterator, optimizer, criterions, CLIP)
    valid_loss = evaluateTriang(model, valid_iterator, criterions)
  else: raise Exception('Model type is unknown')

  # Log the LR this epoch trained with (read before scheduler.step()).
  print('scheduler.get_last_lr()', scheduler.get_last_lr())
  # NOTE(review): HID_DIM/DROPOUT are the globals from the Seq2Seq/Pivot setup cells, not cfg.
  epoch_info = [model_name, dataname, scheduler.get_last_lr()[0], BATCH_SIZE, HID_DIM, DROPOUT, epoch, N_EPOCHS, train_loss, valid_loss]
  train_log.append([str(info) for info in epoch_info])

  scheduler.step()

  # Save whenever either loss improves, but track each best independently: copying both
  # losses on any improvement could replace a best value with a worse one.
  # (stricter alternative: save only when valid_loss < best_valid_loss)
  improved = train_loss < best_train_loss or valid_loss < best_valid_loss
  best_train_loss = min(best_train_loss, train_loss)
  best_valid_loss = min(best_valid_loss, valid_loss)
  if improved:
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict()
    }, f'{DIR_PATH}/{model_name}')
    print('SAVED MODEL')
    train_log = update_trainlog(train_log, f'{DIR_PATH}/training_log-OLD.txt')

  print(f'Epoch: {epoch:02} \t Train Loss: {train_loss:.3f} \t Val. Loss: {valid_loss:.3f}')

# Flush epochs that never triggered a save, so their log lines are not lost.
if train_log:
  train_log = update_trainlog(train_log, f'{DIR_PATH}/training_log-OLD.txt')

## Eval

In [None]:
# Peek at one test batch. next(iter(...)) fails loudly (StopIteration) on an empty iterator
# instead of silently leaving a stale `batch` from an earlier cell in place.
batch = next(iter(test_iterator))
batch.fields

dict_keys(['en', 'fr', 'it'])

In [None]:
model_path = f'{DIR_PATH}/traing-EnFr-EnEsFr-1.pt'
# map_location lets a checkpoint that was saved from a GPU load on whatever `device` this session uses
ckpt = torch.load(model_path, map_location=device)

In [None]:
# Restore only the model weights from `ckpt`; optimizer/scheduler state is left untouched.
# optimizer.load_state_dict(ckpt['optimizer_state_dict'])
# scheduler.load_state_dict(ckpt['scheduler_state_dict'])
model.load_state_dict(ckpt['model_state_dict']) # strict=False if some dimensions are different

<All keys matched successfully>

In [None]:
# test_loss = evaluate(model_infer, test_iterator, criterion2, isPivot=False, force=0)
# print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')

## Inference

In [None]:
def sent2tensor(src_field, trg_field, device, max_len, sentence=None):
  '''Build a (token-index tensor, length tensor) pair for inference.

  Two modes:
    * `sentence` given: encode it with `src_field`. A str is tokenized with `tokenize_en`;
      any other iterable of tokens is lower-cased. The init/eos tokens are added.
    * `sentence` is None: build a placeholder target of length `max_len`, made of the
      trg_field init token followed by max_len - 1 filler indices (0).

  Args:
      src_field, trg_field: torchtext Fields (use .vocab.stoi, .init_token, .eos_token).
      device: torch device the tensors are moved to.
      max_len (int): length of the placeholder target (ignored when `sentence` is given).
      sentence (str | list[str] | None): source sentence, or None for the target placeholder.
  Returns:
      tuple(LongTensor, LongTensor): indices shaped [seq_len, 1] and a [1] tensor of the length.
  '''
  if sentence is not None:  # identity check; `!= None` misbehaves on array-like inputs
    if isinstance(sentence, str):
      tokens = tokenize_en(sentence)  # NOTE: always the English tokenizer, whatever src_field is
    else:
      tokens = [token.lower() for token in sentence]
    tokens = [src_field.init_token] + tokens + [src_field.eos_token]
    src_indexes = [src_field.vocab.stoi[token] for token in tokens]
    src_tensor = torch.LongTensor(src_indexes).unsqueeze(1).to(device)  # [seq_len, N] w/ N=1 for batch
    src_len_tensor = torch.LongTensor([len(src_indexes)]).to(device)
    return src_tensor, src_len_tensor

  # Only the first (init) token carries a value; index 0 is filler. Presumably, with teacher
  # forcing 0, the decoder only reads trg's first token and its length — TODO confirm.
  trg_indexes = [trg_field.vocab.stoi[trg_field.init_token]] + [0] * (max_len - 1)
  trg_tensor = torch.LongTensor(trg_indexes).view(-1, 1).to(device)  # [max_len, 1]
  trg_len_tensor = torch.LongTensor([max_len]).to(device)
  return trg_tensor, trg_len_tensor

def idx2sent(trg_field, arr):
  '''Convert a [seq_len, N] array of token ids into N token lists.

  Row 0 is skipped (the init token). Each sentence ends with the eos token when one is
  found (it is kept); otherwise it runs to the end of its column.

  Args:
      trg_field: Field whose .vocab.itos / .vocab.stoi and .eos_token are used.
      arr: 2-D index array (or tensor) shaped [seq_len, N].
  Returns:
      list[list[str]]: one token list per column.
  '''
  itos = trg_field.vocab.itos
  eos_idx = trg_field.vocab.stoi[trg_field.eos_token]  # loop-invariant, looked up once
  results = []
  for sent_idx in range(arr.shape[1]):  # for each sentence (column)
    pred_sent = []
    for tok_idx in arr[1:, sent_idx]:   # for each token after the init token
      pred_sent.append(itos[tok_idx])
      if tok_idx == eos_idx:
        break
    results.append(pred_sent)
  return results

### Translate by sent

In [None]:
def translate_sentence(sentence, src_field, trg_field, model, device, max_len=64):
  '''Greedy-translate one sentence with any supported model type.

  The model input depends on the model type (src = encoded sentence, trg = placeholder
  target; both are (tensor, lengths) pairs from sent2tensor):
    * Seq2Seq       -> [src, trg]
    * PivotSeq2Seq  -> [src] + one trg per stage
    * TriangSeq2Seq -> {'TRG': trg, 'model_<i>': input for submodel i}
  Returns:
      list[list[str]]: idx2sent output, one token list for the single sentence.
  '''
  model.eval()
  with torch.no_grad():
    src = sent2tensor(src_field, trg_field, device, max_len, sentence)
    trg = sent2tensor(src_field, trg_field, device, max_len)

    if isinstance(model, Seq2Seq):
      data = [src, trg]
    elif isinstance(model, PivotSeq2Seq):
      data = [src] + [trg] * model.num_model
    elif isinstance(model, TriangSeq2Seq):
      data = {'TRG': trg}
      for idx in range(model.num_model):
        sub = getattr(model, f'model_{idx}')
        if isinstance(sub, Seq2Seq):
          data[f'model_{idx}'] = [src, trg]
        elif isinstance(sub, PivotSeq2Seq):
          data[f'model_{idx}'] = [src] + [trg] * sub.num_model
        else:
          raise Exception('Only support Seq2Seq & PivotSeq2Seq nested inside TriangSeq2Seq')
    else:
      raise Exception('Model type is unknown')

    logits = model(data, None, 0)  # teacher forcing off; [trg_len, N, out_dim] with N=1
    token_ids = logits.argmax(-1).detach().cpu().numpy()  # [trg_len, N]
    return idx2sent(trg_field, token_ids)

In [None]:
def translate_sentence_seq2seq(sentence, src_field, trg_field, model: Seq2Seq, device, max_len=64):
  '''Greedy-translate one sentence with a plain Seq2Seq model.

  Returns:
      list[list[str]]: idx2sent output, one token list for the single sentence.
  '''
  model.eval()
  with torch.no_grad():
    src = sent2tensor(src_field, trg_field, device, max_len, sentence)  # (tokens, length)
    trg = sent2tensor(src_field, trg_field, device, max_len)            # placeholder target
    logits = model([src, trg], None, 0)  # teacher forcing off; [trg_len, 1, out_dim]
    return idx2sent(trg_field, logits.argmax(-1).detach().cpu().numpy())

def translate_sentence_pivot(sentence, src_field, trg_field, model, device, max_len=64):  # not yet modified
  '''Greedy-translate one sentence with a PivotSeq2Seq model.

  Every pivot stage gets its own cloned/detached copy of the placeholder target, so no two
  stages share tensor objects.
  Returns:
      list[list[str]]: idx2sent output, one token list for the single sentence.
  '''
  model.eval()
  with torch.no_grad():
    src = sent2tensor(src_field, trg_field, device, max_len, sentence)
    trg_tensor, trg_len_tensor = sent2tensor(src_field, trg_field, device, max_len)

    def fresh_trg():
      return (trg_tensor.clone().detach().to(device), trg_len_tensor.clone().detach().to(device))

    data = [src]
    for _ in range(model.num_model):
      data.append(fresh_trg())
    logits = model(data, None, 0)  # [trg_len, N, out_dim]
    return idx2sent(trg_field, logits.argmax(-1).detach().cpu().numpy())

def translate_batch_triang(sentence, src_field, trg_field, model, device, max_len=64):
  '''Greedy-translate one sentence with a TriangSeq2Seq model.

  The input dict is built from the model's own submodels (model_0 .. model_{num_model-1})
  rather than a hard-coded three-model layout that always treated model_1/model_2 as pivots:
  a Seq2Seq submodel gets [src, trg] and a PivotSeq2Seq submodel gets [src] plus one trg per
  stage. The placeholder target pair is shared across submodels, as before.
  Returns:
      list[list[str]]: idx2sent output, one token list for the single sentence.
  Raises:
      Exception: if a submodel is neither Seq2Seq nor PivotSeq2Seq.
  '''
  model.eval()
  with torch.no_grad():
    # get data
    src = sent2tensor(src_field, trg_field, device, max_len, sentence)
    trg = sent2tensor(src_field, trg_field, device, max_len)
    data = {'TRG': trg}
    for i in range(model.num_model):
      submodel = getattr(model, f'model_{i}')
      if isinstance(submodel, Seq2Seq):
        data[f'model_{i}'] = [src, trg]
      elif isinstance(submodel, PivotSeq2Seq):
        data[f'model_{i}'] = [src] + [trg for _ in range(submodel.num_model)]
      else:
        raise Exception('Only support Seq2Seq & PivotSeq2Seq nested inside TriangSeq2Seq')
    # feed model
    output = model(data, None, 0)
    output = output.argmax(-1).detach().cpu().numpy()  # [trg_len, N]
    return idx2sent(trg_field, output)

In [None]:
# Pick the single-sentence translator matching the model type.
# translate_sentence dispatches on the model type itself and raises 'Model type is unknown'
# for anything else, so it is the fallback. The previous bare try/except wrapped a plain
# assignment and could never trigger.
if isinstance(model, Seq2Seq):
  translator = translate_sentence_seq2seq
  print('Created translator for seq2seq')
elif isinstance(model, PivotSeq2Seq):
  translator = translate_sentence_pivot
  print('Created translator for piv')
elif isinstance(model, TriangSeq2Seq):
  translator = translate_batch_triang
  print('Created translator for triang')
else:
  translator = translate_sentence

Created translator for triang


In [None]:
# Translate one validation example and show source, reference and prediction.
example_idx = 2132
src, trg = (vars(valid_dt.examples[example_idx])[key] for key in ('en', 'fr'))
pred = translator(src, EN_FIELD, FR_FIELD, model, device)
print(src, trg, pred, sep='\n')

### Translate by batch

In [None]:
# Inspect one test batch. next(iter(...)) fails loudly on an empty iterator instead of
# silently reusing a stale `batch` from an earlier cell.
batch = next(iter(test_iterator))

print(batch.fields)
print(batch.en[0].shape, batch.en[1].shape)  # shapes of (tokens, lengths)
print(batch.en[1])
print(batch.fr[1])

dict_keys(['en', 'fr', 'es', 'it', 'pt', 'ro'])
tensor([7, 5, 9, 3], device='cuda:0')
tensor([5, 5, 5, 5], device='cuda:0')


In [None]:
def translate_batch(model, iterator, trg_field, device, max_len):
  '''Greedy-decode every batch of `iterator` and collect reference/predicted token lists.

  Args:
      model: Seq2Seq, PivotSeq2Seq or TriangSeq2Seq.
      iterator: batches exposing `.en` (source) and `.fr` (reference), each a
          (tokens [seq_len, N], lengths [N]) pair.
      trg_field: target Field, used for the placeholder target and for decoding ids.
      device: torch device for the placeholder target.
      max_len (int): decoding length (placeholder target length).
  Returns:
      (gt_sents, pred_sents): two lists of token lists, as produced by idx2sent.
  Raises:
      Exception: for unsupported model / submodel types.
  '''
  model.eval()
  gt_sents, pred_sents = [], []
  with torch.no_grad():
    for batch in tqdm(iterator):
      src = batch.en
      N = src[0].shape[1]  # batch size
      # One placeholder target column (src_field is unused when no sentence is given),
      # tiled across the batch: tokens [max_len, N], lengths [N].
      trg_data, trg_len = sent2tensor(None, trg_field, device, max_len)
      trg = (trg_data.repeat(1, N), trg_len.repeat(N))
      if isinstance(model, Seq2Seq):
        data = [src, trg]
      elif isinstance(model, PivotSeq2Seq):
        data = [src] + [trg for _ in range(model.num_model)]
      elif isinstance(model, TriangSeq2Seq):
        data = {'TRG': trg}
        for i in range(model.num_model):
          submodel = getattr(model, f'model_{i}')
          if isinstance(submodel, Seq2Seq):
            data[f'model_{i}'] = [src, trg]
          elif isinstance(submodel, PivotSeq2Seq):
            data[f'model_{i}'] = [src] + [trg for _ in range(submodel.num_model)]
          else: raise Exception('Only support Seq2Seq & PivotSeq2Seq nested inside TriangSeq2Seq')
      else: raise Exception('Model type is unknown')
      # feed model (teacher forcing off)
      output = model(data, None, 0)
      pred = output.argmax(-1).detach().cpu().numpy()  # [seq_len, N]
      truth = batch.fr[0].detach().cpu().numpy()  # [seq_len, N]

      # extend in place: repeated `list + list` made collection quadratic in the dataset size
      gt_sents.extend(idx2sent(trg_field, truth))
      pred_sents.extend(idx2sent(trg_field, pred))
  return gt_sents, pred_sents

In [None]:
# Argmax-decode the whole test set batch by batch (placeholder target length max_len=64).
gt_sents, pred_sents = translate_batch(model, test_iterator, FR_FIELD, device, 64)

  1%|          | 10/1600 [00:04<13:09,  2.01it/s]


In [None]:
# Show the first four (reference, prediction) pairs, each without its last token (normally <eos>).
for gt_sent, pred_sent in zip(gt_sents[:4], pred_sents[:4]):
  print(gt_sent[:-1])
  print(pred_sent[:-1])
  print()

['il', 'faut', 'à', 'présent', 'de', 'bonnes', 'politiques', ',', 'par', 'un', 'effort', '<unk>', 'de', 'la', 'communauté', 'internationale', ',', 'afin', 'de', 'chercher', 'à', 'obtenir', 'libération', 'et', 'justice', 'pour', 'tous', 'ceux', 'qui', 'ont', 'été', 'si', '<unk>', 'touchés', '.']
['les', 'politiques', 'devraient', 'être', 'développées', ',', 'un', 'effort', 'effort', 'pour', 'la', 'communauté', 'internationale', ',', 'pour', 'la', 'justice', 'et', 'tous', 'les', 'personnes', 'qui', 'ont', 'été', 'touchés', '.']

["c'", 'est', 'pourquoi', 'nous', 'devons', 'nous', 'battre', 'pour', 'défendre', 'cette', 'politique', ',', 'pour', 'défendre', 'son', 'budget', ',', 'pour', 'convaincre', 'les', 'états', 'que', 'la', 'politique', 'régionale', "n'", 'est', 'pas', 'un', '<unk>', 'mais', 'une', 'nécessité', '.']
["c'", 'est', 'pourquoi', 'nous', 'devons', 'lutter', 'défendre', 'cette', 'politique', ',', 'défendre', 'sa', 'budget', ',', 'pour', 'les', 'états', 'membres', 'que', 'la

## BLEU

### Main

In [None]:
def calculate_bleu_batch(translator, iterator, trg_field, model, device, max_len=64):
  '''Collect (hypotheses, references) for torchtext's `bleu_score` using a batch translator.

  Despite its name this does not compute the score; pass the result to
  `bleu_score(pred_trgs, trgs)`.
  The trailing eos token is removed only when a sentence actually ends with it. A prediction
  that reached max_len without <eos> keeps its last real token; previously it was always cut.

  Args:
      translator: callable(model, iterator, trg_field, device, max_len) -> (gt_sents, pred_sents).
      iterator, trg_field, model, device, max_len: forwarded to `translator`.
  Returns:
      (pred_trgs, trgs): hypotheses as token lists, references as single-element lists.
  '''
  eos = trg_field.eos_token

  def _strip_eos(sent):
    # drop the final token only if it is the eos marker
    return sent[:-1] if sent and sent[-1] == eos else sent

  gt_sents, pred_sents = translator(model, iterator, trg_field, device, max_len)
  pred_trgs = [_strip_eos(pred_sent) for pred_sent in pred_sents]
  trgs = [[_strip_eos(gt_sent)] for gt_sent in gt_sents]
  return pred_trgs, trgs

def calculate_bleu_sent(translator, data, src_field, trg_field, model, device, max_len = 50):
  '''Collect (hypotheses, references) for `bleu_score` by translating examples one at a time.

  Each example's 'en' tokens are translated with `translator` and compared against its 'fr'
  tokens. The trailing eos token is removed from a prediction only when it is present;
  predictions that hit max_len without <eos> keep their last real token.

  Args:
      translator: callable(sentence, src_field, trg_field, model, device, max_len) -> [token list].
      data: iterable of torchtext Examples with 'en' and 'fr' attributes.
  Returns:
      (pred_trgs, trgs): hypotheses as token lists, references as single-element lists.
  '''
  eos = trg_field.eos_token
  trgs = []
  pred_trgs = []
  for datum in tqdm(data):
    example = vars(datum)
    pred = translator(example['en'], src_field, trg_field, model, device, max_len)
    hyp = pred[0]
    pred_trgs.append(hyp[:-1] if hyp and hyp[-1] == eos else hyp)  # cut off <eos> only if present
    trgs.append([example['fr']])
  return pred_trgs, trgs

In [None]:
# pred_trgs, trgs = calculate_bleu_sent(translator, test_dt, EN_FIELD, FR_FIELD, model, device)
# score = bleu_score(pred_trgs, trgs)
# print(f'BLEU score = {score*100:.2f}')

In [None]:
model.method = 'max'  # NOTE(review): the submodel is scored directly below, so this likely has no effect
pred_trgs, trgs = calculate_bleu_sent(
    translate_sentence, test_dt, EN_FIELD, FR_FIELD,
    model.model_0,  # first submodel of the triangulated model (model_direct in the setup cell)
    device,
)
score = bleu_score(pred_trgs, trgs)
print('en-fr', f'BLEU score = {score*100:.3f}', sep='\n')

100%|██████████| 6400/6400 [06:53<00:00, 15.47it/s]


en-fr
BLEU score = 38.738


In [None]:
model.method = 'max'  # NOTE(review): the submodel is scored directly below, so this likely has no effect
pred_trgs, trgs = calculate_bleu_sent(
    translate_sentence, test_dt, EN_FIELD, FR_FIELD,
    model.model_1,  # second submodel of the triangulated model (model_piv in the setup cell)
    device,
)
score = bleu_score(pred_trgs, trgs)
print('en-es-fr', f'BLEU score = {score*100:.3f}', sep='\n')

100%|██████████| 6400/6400 [14:02<00:00,  7.60it/s]


en-es-fr
BLEU score = 32.273


In [None]:
# Build, load and score a second pivot model (en -> it -> fr) with the same architecture as
# the es pivot. This overwrites the globals lang, TMP_FIELD, model1, model2, models, fields and modelname.
# NOTE(review): module construction presumably consumes the global RNG; statement order is kept.
lang = 'it'
TMP_FIELD = FIELD_DICT[lang]

# Stage 1: en -> it
SRC_PAD_IDX = EN_FIELD.vocab.stoi[EN_FIELD.pad_token]
attn1 = Attention(cfg['HID_DIM'], cfg['HID_DIM'])
enc1 = Encoder(cfg['en_DIM'], cfg['EMB_DIM'], cfg['HID_DIM'], cfg['HID_DIM'], cfg['DROPOUT'])
dec1 = Decoder(cfg[f'{lang}_DIM'], cfg['EMB_DIM'], cfg['HID_DIM'], cfg['HID_DIM'], cfg['DROPOUT'], attn1)
model1 = Seq2Seq(enc1, dec1, SRC_PAD_IDX, device).to(device)

# Stage 2: it -> fr
PIV_PAD_IDX = TMP_FIELD.vocab.stoi[TMP_FIELD.pad_token]
attn2 = Attention(cfg['HID_DIM'], cfg['HID_DIM'])
enc2 = Encoder(cfg[f'{lang}_DIM'], cfg['EMB_DIM'], cfg['HID_DIM'], cfg['HID_DIM'], cfg['DROPOUT'])
dec2 = Decoder(cfg['fr_DIM'], cfg['EMB_DIM'], cfg['HID_DIM'], cfg['HID_DIM'], cfg['DROPOUT'], attn2)
model2 = Seq2Seq(enc2, dec2, PIV_PAD_IDX, device).to(device)

models = [model1, model2]
fields = [EN_FIELD, TMP_FIELD, FR_FIELD]
model_piv2 = PivotSeq2Seq(models, fields, device).to(device)
modelname = f'piv-En{lang[0].upper()}{lang[1]}Fr.pt'  # 'it' -> piv-EnItFr.pt
model_piv2.load_state_dict(torch.load(f'{DIR_PATH}/{modelname}')['model_state_dict'])
print(f'loaded model {modelname}')
model.method = 'max'  # NOTE(review): model_piv2 is scored directly, so the triang model's method likely does not matter here

pred_trgs, trgs = calculate_bleu_sent(translate_sentence, test_dt, EN_FIELD, FR_FIELD, model_piv2, device)
score = bleu_score(pred_trgs, trgs)
print(f'en-{lang}-fr')
print(f'BLEU score = {score*100:.3f}')

loaded model piv-EnItFr.pt


100%|██████████| 6400/6400 [14:00<00:00,  7.61it/s]


en-it-fr
BLEU score = 31.369


In [None]:
# model.method = 'max'
# pred_trgs, trgs = calculate_bleu_batch(translate_batch, test_iterator, FR_FIELD, model, device)
# score = bleu_score(pred_trgs, trgs)
# print(f'BLEU score = {score*100:.3f}')

100%|██████████| 800/800 [04:42<00:00,  2.84it/s]


BLEU score = 33.698


In [None]:
# model.method = 'average'
# pred_trgs, trgs = calculate_bleu_batch(translate_batch, test_iterator, FR_FIELD, model, device)
# score = bleu_score(pred_trgs, trgs)
# print(f'BLEU score = {score*100:.3f}')

### Result

1. seq2seq-EnFr-1.pt: BLEU = 38.74
1. piv-EnItFr.pt: BLEU = 31.37
1. triang-EnFr-EnItFr.pt: BLEU = 31.95
  * dir-EnFr: BLEU = 38.03
  * piv-EnItFr: BLEU = 24.62
1. triang-EnFr-EnEsFr-1.pt: BLEU = 30.57 (weight_1)
1. Prof Neller's suggestion (max selection):
  * triang(EnFr + EnEsFr) BLEU = 37.80
  * triang(EnFr + EnEsFr) BLEU = 38.81 (2nd run)
  * triang(EnEsFr + EnItFr) BLEU = 33.698 (2nd run; the 1st run got 36.25)
1. Averaging the submodel outputs, as in the paper ("late averaging"):
  * triang(EnFr + EnEsFr) BLEU = 28.25
  * triang(EnEsFr + EnItFr) BLEU = 25.749

* attn_en-fr_32k.pt: BLEU = 12.65
* attn_enfr_160kset.pt: BLEU = 32.18
* piv_endefr_74kset_2.pt: BLEU = 26.33


# End