<h1 id="tocheading">XNLI Training</h1>
<div id="toc"></div>

In [1]:
%%javascript
$.getScript('https://kmahelona.github.io/ipython_notebook_goodies/ipython_notebook_toc.js')

<IPython.core.display.Javascript object>

In [2]:
import pickle
import random
import spacy
import errno
import glob
import string
import os
import jieba
import nltk
import functools
import numpy as np
import pandas as pd
from collections import Counter
from collections import defaultdict
from argparse import ArgumentParser
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset
import torch.nn.functional as F
from torch.autograd import Variable

## Imports

Besides the publicly available libraries above, we import our preprocessing functions, models (bidirectional LSTM and linear classifier), discriminator (for the encoder alignment) and aligner (alignment trainer) functions.

In [3]:
from models import *
from preprocess import *
from Discriminator import *
from aligner_functions import *

In [4]:
PAD_IDX, UNK_IDX = define_indices()
label_dict = define_label_dict()
snli_path, align_path, multi_path = define_paths()
no_cuda = False
cuda = not no_cuda and torch.cuda.is_available()
seed = 1
device = torch.device("cuda" if cuda else "cpu")

In [5]:
config = XNLIconfig(val_test_lang = "de", max_sent_len = 50, max_vocab_size = 100000,
                    epochs = 15, batch_size = 256, embed_dim = 300, hidden_dim = 512, dropout = 0.1, lr = 1e-3,
                    experiment_lang = "de")

## Align Source and Target Encoders

In this section, we load the encoder LSTM that was trained on the [English] MNLI training set and align the target encoder to the English encoder. The process works like this:

    I. Load English encoder and fix its parameters.
    II. Make a copy of it (with free non-fixed parameters), also dubbed the target encoder.
    III. Define a loss function for measuring the alignment degree of source and target sentence embeddings. To see the details of our loss, go to aligner_functions.py>loss_align.
    IV. Align the target encoder to the source encoder using this loss function so that it produces close-enough embeddings of target language sentences so that the linear classifier confuses it with English.
    V. Here we also use a discriminator model to measure the degree of confusion created by the alignment process.
    VI. Load your aligned target encoder and measure XNLI target language dev set accuracy.
    VII. Go to step I and reiterate with different parameter sets until you find the optimal model.
    VIII. Test your model on XNLI target language test set.

### Load Source & Target Vectors

Here we first load the source and target language vectors separately. Then we concatenate those into one large embedding matrix.

In [6]:
print ("Loading vectors for EN.")
src_vectors = load_vectors("../data/vecs/cc.en.300.vec")
print ("Loading vectors for {}.".format(config.experiment_lang.upper()))
trg_vectors = load_vectors("../data/vecs/cc.de.300.vec")

Loading vectors for EN.
Loading vectors for DE.


Prepare mutual vocabulary for English and target.

In [7]:
id2tok_src = [x+"."+"en" for x in [*src_vectors.keys()]][:config.max_vocab_size]
id2tok_trg = [x+"."+ config.experiment_lang for x in [*trg_vectors.keys()]][:config.max_vocab_size]
id2tok_mutual = ["<PAD>", "<UNK>"] + id2tok_src + id2tok_trg

In [8]:
vecs_mutual = update_vocab_keys(src_vectors, trg_vectors)
tok2id_mutual = build_tok2id(id2tok_mutual)

In [9]:
weights_init = init_embedding_weights(vecs_mutual, tok2id_mutual, id2tok_mutual, config.embed_dim)

### Read & Tokenize Parallel Corpus

For this demo, we use Europarl English-German parallel corpus that holds the one-to-one translation records of the EU parliament discussions. For languages that are not present among the Europarl parallel corpora, such as Arabic and Chinese, we use OPUS (Open Subtitles). 

In [10]:
data_en_target, all_en_tokens, all_target_tokens = read_and_tokenize_europarl_data(lang=config.val_test_lang)

We create a contrastive dataset that is generated by __batch internal shuffling__. We basically take our dataset and shuffle the order of either the source or the target language (or sometimes both). We use this portion to calculate the contrastive part of the loss. For more details, refer to the alignment loss function at ```aligner_functions.py > loss_align```

In [11]:
c_df = create_contrastive_dataset(data_en_target, config.experiment_lang, 100000)

In [19]:
for x in ["en", config.experiment_lang]:
    data_en_target["len_{}".format(x)] = data_en_target[x + "_tokenized"].apply(lambda x: len(x))

In [21]:
data_en_target = data_en_target[(data_en_target["len_en"]>2)&(data_en_target["len_{}".format(config.experiment_lang)]>2)]

In [23]:
align_dataset = AlignDataset(data_en_target, config.max_sent_len, "en", config.experiment_lang,
                             tok2id_mutual, id2tok_mutual)
align_loader = torch.utils.data.DataLoader(dataset=align_dataset, batch_size=config.batch_size,
                                           collate_fn=lambda x, max_sentence_length=config.max_sent_len: align_collate_func(x, config.max_sent_len),
                                           shuffle=False)

c_align_dataset = AlignDataset(c_df, config.max_sent_len, "en", config.experiment_lang,
                               tok2id_mutual, id2tok_mutual)
c_align_loader = torch.utils.data.DataLoader(dataset=c_align_dataset, batch_size=config.batch_size,
                                             collate_fn=lambda x, max_sentence_length=config.max_sent_len: align_collate_func(x, config.max_sent_len),
                                             shuffle=False)

In [32]:
load_epoch = 8
LSTM_src_model = biLSTM(config.hidden_dim, weights_init, config.dropout, config.max_vocab_size,
                        num_layers=1, input_size=300).to(device)

LSTM_src_model.load_state_dict(torch.load("best_encoder_eng_mnli_{}_EN".format(load_epoch)))
# fix source encoder parameters
for param in LSTM_src_model.parameters():
    param.requires_grad = False
    
LSTM_trg_model = biLSTM(config.hidden_dim, weights_init, config.dropout, config.max_vocab_size,
                        num_layers=1, input_size=300).to(device)

LSTM_trg_model.load_state_dict(torch.load("best_encoder_eng_mnli_{}_EN".format(load_epoch)))

disc = Discriminator(n_langs = 2, dis_layers = 3, dis_hidden_dim = 128, dis_dropout = 0.1, lr_slope=0.005).to(device)

In [25]:
print ("Encoder src:\n", LSTM_src_model)
print ("Encoder trg:\n", LSTM_trg_model)
print ("Discriminator:\n", disc)

Encoder src:
 biLSTM(
  (embedding): Embedding(200002, 300)
  (drop_out): Dropout(p=0.1)
  (LSTM): LSTM(300, 512, batch_first=True, bidirectional=True)
)
Encoder trg:
 biLSTM(
  (embedding): Embedding(200002, 300)
  (drop_out): Dropout(p=0.1)
  (LSTM): LSTM(300, 512, batch_first=True, bidirectional=True)
)
Discriminator:
 Discriminator(
  (layers): Sequential(
    (0): Linear(in_features=1024, out_features=128, bias=True)
    (1): LeakyReLU(negative_slope=0.005)
    (2): Dropout(p=0.1)
    (3): Linear(in_features=128, out_features=128, bias=True)
    (4): LeakyReLU(negative_slope=0.005)
    (5): Dropout(p=0.1)
    (6): Linear(in_features=128, out_features=128, bias=True)
    (7): LeakyReLU(negative_slope=0.005)
    (8): Dropout(p=0.1)
    (9): Linear(in_features=128, out_features=2, bias=True)
  )
)


In [33]:
# weights_init = torch.from_numpy(weights_init)

In [34]:
for epoch in range(config.epochs):
    print ("\nepoch = "+str(epoch))
    
    loss_train = train(LSTM_s=LSTM_src_model, LSTM_t=LSTM_trg_model, discriminator = disc,
                       loader=align_loader, contrastive_loader=c_align_loader,
                       optimizer = torch.optim.Adam([*LSTM_src_model.parameters()] + [*LSTM_trg_model.parameters()] + [*disc.parameters()],
                                                    lr=config.lr),
                       dis_optim = torch.optim.Adam([*disc.parameters()],
                                                    lr=config.lr),
                       epoch = epoch)
        
    torch.save(LSTM_trg_model.state_dict(), "LSTM_en_{}_{}_epoch_{}".format(config.experiment_lang, config.experiment_lang.upper(), epoch))

## Dev Accuracy on XNLI Target

In [35]:
from nli_trainer import *

In [38]:
weights_init = weights_init.numpy()

In [41]:
# load val and test and preprocess
print ("Reading XNLI {} data.".format(config.val_test_lang.upper()))
xnli_dev, xnli_test = read_xnli(config.experiment_lang)
_, xnli_dev, xnli_test = write_numeric_label(None, xnli_dev, xnli_test, nli_corpus="xnli")

Reading XNLI DE data.


In [42]:
xnli_dev, _ = tokenize_xnli(xnli_dev, lang=config.val_test_lang)
xnli_test, _ = tokenize_xnli(xnli_test, lang=config.val_test_lang)

In [43]:
# dev
nli_dev_dataset = NLIDataset(xnli_dev, max_sentence_length=config.max_sent_len, token2id=tok2id_mutual, id2token=id2tok_mutual)
nli_dev_loader = torch.utils.data.DataLoader(dataset=nli_dev_dataset, batch_size=config.batch_size,
                               collate_fn=lambda x, max_sentence_length=config.max_sent_len: nli_collate_func(x, config.max_sent_len),
                               shuffle=False)

# test
nli_test_dataset = NLIDataset(xnli_test, max_sentence_length=config.max_sent_len, token2id=tok2id_mutual, id2token=id2tok_mutual)
nli_test_loader = torch.utils.data.DataLoader(dataset=nli_test_dataset, batch_size=config.batch_size,
                               collate_fn=lambda x, max_sentence_length=config.max_sent_len: nli_collate_func(x, config.max_sent_len),
                               shuffle=False)

In [44]:
LSTM_trg = biLSTM(config.hidden_dim, weights_init, config.dropout, config.max_vocab_size,
                        num_layers=1, input_size=300).to(device)
epoch = 0
LSTM_trg.load_state_dict(torch.load("LSTM_en_{}_epoch_{}".format(config.val_test_lang.upper(), epoch)))

linear_model = Linear_Layers(hidden_size = 1024, hidden_size_2 = 128, percent_dropout = config.dropout,
                        classes=3, input_size=config.embed_dim).to(device)

linear_model.load_state_dict(torch.load("best_linear_eng_mnli_{}_{}".format(8, "EN")))
val_acc = accuracy(LSTM_trg, linear_model, nli_dev_loader, nn.NLLLoss(reduction='sum'))
print ("\n{} Validation Accuracy = {}".format(config.val_test_lang.upper(), val_acc))


DE Validation Accuracy = 55.38152456283569
