In [1]:
import pickle
import random
import random
import spacy
import csv
import sys
import errno
import glob
import string
import io
import os
import re
import time
import functools
import numpy as np
import pandas as pd
from setuptools import setup
from collections import Counter
from collections import defaultdict
from argparse import ArgumentParser

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset
import torch.nn.functional as F
from torch.autograd import Variable

import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
# Reserved vocabulary ids: 0 = padding, 1 = unknown / out-of-vocabulary token.
PAD_IDX, UNK_IDX = 0, 1
# NLI gold-label strings -> integer class ids.
label_dict = dict(entailment=0, neutral=1, contradiction=2)

In [3]:
# Device selection; set no_cuda = True to force the CPU even when a GPU exists.
no_cuda = False
cuda = not no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")

# Seed every RNG the notebook draws from (python `random`, numpy, torch) so that
# shuffles (torch.randperm), random UNK embeddings (np.random) and random LSTM
# initial states are reproducible. `seed` was previously declared but unused.
seed = 1
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if cuda:
    torch.cuda.manual_seed_all(seed)

In [4]:
# Root directory of the parallel corpora. Set the NLU_DATA_ROOT environment
# variable to run on another machine instead of editing the cluster path.
NLU_DATA_ROOT = os.environ.get("NLU_DATA_ROOT", "/scratch/adc563/nlu_project/data")
opus_path = NLU_DATA_ROOT + "/opus"
europarl_path = NLU_DATA_ROOT + "/europarl"
un_path = NLU_DATA_ROOT + "/un_parallel_corpora"

In [10]:
def read_xnli(lang):
    """Load the XNLI dev and test splits, optionally restricted to one language.

    :param lang: value matched against the ``language`` column, or ``"all"``
        to keep every language.
    :returns: ``(dev_data, test_data)`` DataFrames.
    """
    template = "/scratch/adc563/nlu_project/data/XNLI/xnli.{}.jsonl"
    splits = [pd.read_json(template.format(split), lines=True) for split in ("dev", "test")]
    if lang != "all":
        splits = [frame[frame["language"] == lang] for frame in splits]
    dev_data, test_data = splits
    return dev_data, test_data

def load_aligned_vectors(lang, path=None):
    """Load fastText aligned word vectors (``wiki.<lang>.align.vec``) into a dict.

    The file starts with a ``"<count> <dim>"`` header line, followed by one
    ``word v1 v2 ...`` line per word.

    :param lang: language code used to build the default file path.
    :param path: optional explicit file path; overrides the default location.
    :returns: dict mapping each word to its vector as a list of floats.
    :raises ValueError: if the header line is not two integers.
    """
    if path is None:
        path = "/scratch/adc563/nlu_project/data/aligned_embeddings/wiki.{}.align.vec".format(lang)
    data = {}
    # `with` closes the (very large) file handle even if parsing fails.
    with io.open(path, "r", encoding="utf-8", newline="\n", errors="ignore") as fin:
        # Parse (and thereby validate) the header; the values themselves are unused.
        _count, _dim = map(int, fin.readline().split())
        for line in fin:
            tokens = line.rstrip().split(" ")
            data[tokens[0]] = [*map(float, tokens[1:])]
    return data

def load_multilingual_vectors(lang, path=None):
    """Load fastText Common Crawl vectors (``cc.<lang>.300.vec``) into a dict.

    The file starts with a ``"<count> <dim>"`` header line, followed by one
    ``word v1 v2 ...`` line per word.

    :param lang: language code used to build the default file path.
    :param path: optional explicit file path; overrides the default location.
    :returns: dict mapping each word to its vector as a list of floats.
    :raises ValueError: if the header line is not two integers.
    """
    if path is None:
        path = "/scratch/adc563/nlu_project/data/multi_lingual_embeddings/cc.{}.300.vec".format(lang)
    data = {}
    # `with` closes the (very large) file handle even if parsing fails.
    with io.open(path, "r", encoding="utf-8", newline="\n", errors="ignore") as fin:
        # Parse (and thereby validate) the header; the values themselves are unused.
        _count, _dim = map(int, fin.readline().split())
        for line in fin:
            tokens = line.rstrip().split(" ")
            data[tokens[0]] = [*map(float, tokens[1:])]
    return data

def load_glove_vectors(lang, path=None, skip_first_line=True):
    """Load GloVe 840B 300d English word vectors into a dict.

    :param lang: accepted for signature parity with the other loaders; it does
        not change which file is read.
    :param path: optional explicit file path; overrides the default location.
    :param skip_first_line: discard the first line before parsing. GloVe text
        files have no header line, so this drops the first word vector. The
        default keeps the historical behaviour, because token ids built from this
        dict's key order (and any checkpoint trained with them) depend on it.
        Pass ``False`` to keep every vector.
    :returns: dict mapping each word to its vector as a list of floats.
    """
    if path is None:
        path = "/scratch/adc563/nlu_project/HBMP/vector_cache/glove.840B.300d.txt"
    data = {}
    # `with` closes the (multi-GB) file handle even if parsing fails.
    with io.open(path, "r", encoding="utf-8", newline="\n", errors="ignore") as fin:
        if skip_first_line:
            fin.readline()
        for line in fin:
            tokens = line.rstrip().split(" ")
            data[tokens[0]] = [*map(float, tokens[1:])]
    return data

def read_enli(nli_corpus = "snli"):
    """Load an English NLI corpus from its JSON-lines files.

    Rows whose ``gold_label`` is ``"-"`` are dropped from every split that is loaded.

    :param nli_corpus: ``"snli"`` (train/dev/test) or ``"multinli"`` (train and
        matched dev; ``test`` is returned as ``None``).
    :returns: ``(train, dev, test)``.
    :raises ValueError: for any other corpus name (previously an UnboundLocalError).
    """
    if nli_corpus == "snli":
        path_ = "/scratch/adc563/nlu_project/HBMP/data/snli/snli_1.0/snli_1.0"
        train = pd.read_json("{}_{}.jsonl".format(path_, "train"), lines=True)
        dev = pd.read_json("{}_{}.jsonl".format(path_, "dev"), lines=True)
        test = pd.read_json("{}_{}.jsonl".format(path_, "test"), lines=True)
        # remove - from gold label
        test = test[test["gold_label"] != "-"]
    elif nli_corpus == "multinli":
        path_ = "/scratch/adc563/nlu_project/HBMP/data/multinli/multinli_1.0/multinli_1.0"
        train = pd.read_json("{}_{}.jsonl".format(path_, "train"), lines=True)
        dev = pd.read_json("{}_{}_matched.jsonl".format(path_, "dev"), lines=True)
        test = None
    else:
        raise ValueError("NLI corpus name should be in [snli, multinli]")
    # remove - from gold label
    train = train[train["gold_label"] != "-"]
    dev = dev[dev["gold_label"] != "-"]
    return train, dev, test

def write_numeric_label(train, dev, test, nli_corpus="multinli"):
    """Replace string ``gold_label`` values with integer ids from ``label_dict``.

    Which splits are converted depends on the corpus: ``multinli`` -> train and
    dev (test is not touched), ``snli`` -> train, dev and test, ``xnli`` -> dev
    and test (train is not touched). The frames are modified in place and also
    returned.

    :raises ValueError: for an unknown corpus name.
    :raises KeyError: if a label is missing from ``label_dict``.
    """
    splits_by_corpus = {
        "multinli": (train, dev),
        "snli": (train, dev, test),
        "xnli": (dev, test),
    }
    if nli_corpus not in splits_by_corpus:
        raise ValueError ("NLI corpus name should be in [multinli, snli, xnli]")
    for frame in splits_by_corpus[nli_corpus]:
        frame["gold_label"] = frame["gold_label"].apply(lambda label: label_dict[label])
    return train, dev, test

def tokenize_xnli(dataset, remove_punc=False, lang="en"):
    """Tokenize XNLI sentence pairs and tag every token with a language suffix.

    For ``sentence1`` and ``sentence2`` a ``<col>_tokenized`` column is added to
    ``dataset`` (in place). Characters in ``string.punctuation`` are removed, the
    text is lower-cased and split on single spaces, and each token gets the
    suffix ``"." + lang`` (e.g. ``"cat.en"``).

    :param remove_punc: currently ignored — punctuation is always removed. Kept
        so existing call sites keep working.
    :param lang: suffix appended to every token.
    :returns: ``(dataset, all_tokens)``; ``all_tokens`` lists every sentence1
        token in row order, followed by every sentence2 token.
    """
    strip_punc = str.maketrans("", "", string.punctuation)
    suffix = "." + lang
    for col in ["sentence1", "sentence2"]:
        dataset["{}_tokenized".format(col)] = dataset[col].apply(
            lambda text: [tok + suffix for tok in text.translate(strip_punc).lower().split(" ")])
    all_tokens = [tok
                  for col in ("sentence1_tokenized", "sentence2_tokenized")
                  for toks in dataset[col]
                  for tok in toks]
    return dataset, all_tokens

def build_vocab(all_tokens, max_vocab_size):
    """Build a frequency-ranked vocabulary.

    The ``max_vocab_size`` most frequent tokens get ids starting at 2. Ids 0 and 1
    are reserved for ``"<PAD>"`` and ``"<UNK>"``.

    :param all_tokens: iterable of tokens (may be empty).
    :param max_vocab_size: number of tokens to keep (``None`` keeps all).
    :returns: ``(token2id, id2token)``; ``id2token[i]`` is the token with id ``i``.
    """
    # Previously `zip(*[])` made the tuple unpack fail on empty input.
    kept = [tok for tok, _count in Counter(all_tokens).most_common(max_vocab_size)]
    token2id = {tok: idx for idx, tok in enumerate(kept, start=2)}
    token2id["<PAD>"] = 0
    token2id["<UNK>"] = 1
    id2token = ['<PAD>', '<UNK>'] + kept
    return token2id, id2token

def build_tok2id(id2token):
    """Invert an id -> token list into a token -> id dict.

    If a token appears more than once, its last position wins.
    """
    return {token: idx for idx, token in enumerate(id2token)}

def init_embedding_weights(vectors, token2id, id2token, embedding_size):
    """Build an embedding matrix whose rows follow ``id2token``.

    Row 0 (padding) stays all zeros. Row 1 (unknown) is drawn from a standard
    normal via ``np.random.randn``. Every later row ``i`` is copied from
    ``vectors[id2token[i]]`` (a KeyError is raised if the token has no vector).
    ``token2id`` is accepted but not used.

    :returns: float ndarray of shape ``(len(id2token), embedding_size)``.
    """
    weights = np.zeros((len(id2token), embedding_size))
    for row, token in enumerate(id2token[2:], start=2):
        weights[row] = vectors[token]
    weights[1] = np.random.randn(embedding_size)
    return weights

In [11]:
class config_class:
    """Plain container for experiment hyper-parameters.

    Every constructor argument is stored unchanged as an attribute of the same name.
    """

    def __init__(self, corpus, val_test_lang, max_sent_len, max_vocab_size, epochs, batch_size, 
                    embed_dim, hidden_dim, dropout, lr, experiment_lang):
        # data / language selection
        self.corpus, self.val_test_lang = corpus, val_test_lang
        # preprocessing limits
        self.max_sent_len, self.max_vocab_size = max_sent_len, max_vocab_size
        # optimisation schedule
        self.epochs, self.batch_size = epochs, batch_size
        # model sizes and regularisation
        self.embed_dim, self.hidden_dim, self.dropout = embed_dim, hidden_dim, dropout
        self.lr = lr
        self.experiment_lang = experiment_lang

In [12]:
# Hyper-parameters for the EN -> target-language alignment run.
config = config_class(
    # data / languages
    corpus="multinli",
    val_test_lang="ru",
    experiment_lang="ru",
    max_sent_len=40,
    max_vocab_size=200000,
    # optimisation
    epochs=15,
    batch_size=32,  # decreased, because we will have contrastive batch - so x 2
    lr=1e-3,
    # model
    embed_dim=300,
    hidden_dim=512,
    dropout=0.1,
)

In [13]:
def update_vocab_keys(src_vocab, trg_vocab, trg_lang=None):
    """Suffix vocabulary keys with their language and merge the two dicts.

    Every key of ``src_vocab`` gets ``".en"`` appended. Every key of
    ``trg_vocab`` gets ``"." + trg_lang``. Both dicts are rewritten in place,
    preserving key order. ``trg_vocab`` is then merged into ``src_vocab``,
    which is returned.

    The renamed dicts are built fresh before being written back. The old
    pop-and-reinsert loop let an original key that already ended with the
    suffix (e.g. both ``"x"`` and ``"x.en"`` present) have its value overwritten
    mid-iteration. This needs a temporary second key mapping; values are shared,
    not copied.

    :param trg_lang: target language code; defaults to ``config.experiment_lang``.
    """
    if trg_lang is None:
        trg_lang = config.experiment_lang
    renamed_src = {key + ".en": value for key, value in src_vocab.items()}
    renamed_trg = {key + ".{}".format(trg_lang): value for key, value in trg_vocab.items()}
    src_vocab.clear()
    src_vocab.update(renamed_src)
    trg_vocab.clear()
    trg_vocab.update(renamed_trg)

    src_vocab.update(trg_vocab)
    return src_vocab

In [14]:
# English side: GloVe vectors (kept under the historical `aligned_src_vectors` name).
print("Loading vectors for EN.")
aligned_src_vectors = load_glove_vectors("en")

# Target side: fastText `wiki.<lang>.align.vec` vectors for the experiment language.
print("Loading vectors for {}.".format(config.experiment_lang.upper()))
aligned_trg_vectors = load_aligned_vectors(lang=config.experiment_lang)

Loading vectors for EN.
Loading vectors for RU.


In [15]:
# Keep the first `max_vocab_size` words of each vector file (in file order) and
# tag them with their language; slicing before suffixing avoids building a
# suffixed copy of the full multi-million-word key list.
id2token_src = [word + ".en" for word in [*aligned_src_vectors][:config.max_vocab_size]]
id2token_trg = [word + "." + config.experiment_lang
                for word in [*aligned_trg_vectors][:config.max_vocab_size]]
# ids 0 and 1 are reserved for padding / unknown tokens
id2token_mutual = ["<PAD>", "<UNK>"] + id2token_src + id2token_trg
# NB: update_vocab_keys renames the keys of both vector dicts in place.
vecs_mutual = update_vocab_keys(aligned_src_vectors, aligned_trg_vectors)

In [16]:
# Token -> id lookup and the matching (vocab_size x 300) embedding matrix.
token2id_mutual = build_tok2id(id2token_mutual)
weights_init = init_embedding_weights(vectors=vecs_mutual,
                                      token2id=token2id_mutual,
                                      id2token=id2token_mutual,
                                      embedding_size=300)

In [17]:
# Sanity check: 2 special tokens plus at most max_vocab_size words per language.
len(id2token_mutual)

400002

In [18]:
# Embedding matrix: one row per vocabulary entry, 300 columns.
weights_init.shape

(400002, 300)

In [19]:
def read_and_tokenize_opus_data(lang="tr", base_path=None):
    """Read an OpenSubtitles EN-<lang> parallel corpus and tokenize both sides.

    Reads ``<base_path>/<lang>_en/OpenSubtitles.en-<lang>.en_00`` and the
    line-aligned ``.<lang>_00`` file. Each line has the characters in
    ``string.punctuation`` removed, is lower-cased and split on single spaces,
    and each token gets its language suffix (``"word.en"``, ``"word.<lang>"``).

    :param lang: target language code.
    :param base_path: corpus root directory; defaults to the global ``opus_path``.
    :returns: ``(dataset, all_en_tokens, all_target_tokens)``. ``dataset`` has
        columns ``en``, ``<lang>``, ``en_tokenized`` and ``<lang>_tokenized``;
        the token lists flatten the tokenized columns in row order.
    """
    if base_path is None:
        base_path = opus_path
    corpus_dir = base_path + "/{}_en".format(lang)
    path_en = corpus_dir + "/OpenSubtitles.en-{}.en_00".format(lang)
    path_target = corpus_dir + "/OpenSubtitles.en-{}.{}_00".format(lang, lang)
    # Explicit UTF-8 so non-Latin target text does not depend on the locale;
    # `with` closes both handles (they were previously leaked).
    with open(path_en, "r", encoding="utf-8") as en_corpus:
        en_series = pd.Series(en_corpus.read().split("\n"))
    with open(path_target, "r", encoding="utf-8") as target_corpus:
        target_series = pd.Series(target_corpus.read().split("\n"))
    dataset = pd.DataFrame({"en": en_series, lang: target_series})

    strip_punc = str.maketrans("", "", string.punctuation)
    for col in ["en", lang]:
        suffix = "." + col
        dataset["{}_tokenized".format(col)] = dataset[col].apply(
            lambda text, suffix=suffix: [tok + suffix for tok in text.translate(strip_punc).lower().split(" ")])
    all_en_tokens = [tok for toks in dataset["en_tokenized"] for tok in toks]
    all_target_tokens = [tok for toks in dataset["{}_tokenized".format(lang)] for tok in toks]
    return dataset, all_en_tokens, all_target_tokens

In [20]:
def read_and_tokenize_europarl_data(lang="de", base_path=None):
    """Read a Europarl v7 <lang>-EN parallel corpus and tokenize both sides.

    Reads ``<base_path>/<lang>_en/europarl-v7.<lang>-en.en`` and the line-aligned
    ``.<lang>`` file. Each line has the characters in ``string.punctuation``
    removed, is lower-cased and split on single spaces, and each token gets its
    language suffix (``"word.en"``, ``"word.<lang>"``).

    :param lang: target language code.
    :param base_path: corpus root directory; defaults to the global ``europarl_path``.
    :returns: ``(dataset, all_en_tokens, all_target_tokens)``. ``dataset`` has
        columns ``en``, ``<lang>``, ``en_tokenized`` and ``<lang>_tokenized``;
        the token lists flatten the tokenized columns in row order.
    """
    if base_path is None:
        base_path = europarl_path
    corpus_dir = base_path + "/{}_en".format(lang)
    path_en = corpus_dir + "/europarl-v7.{}-en.en".format(lang)
    path_target = corpus_dir + "/europarl-v7.{}-en.{}".format(lang, lang)
    # Explicit UTF-8 so accented / non-Latin text does not depend on the locale;
    # `with` closes both handles (they were previously leaked).
    with open(path_en, "r", encoding="utf-8") as en_corpus:
        en_series = pd.Series(en_corpus.read().split("\n"))
    with open(path_target, "r", encoding="utf-8") as target_corpus:
        target_series = pd.Series(target_corpus.read().split("\n"))
    dataset = pd.DataFrame({"en": en_series, lang: target_series})

    strip_punc = str.maketrans("", "", string.punctuation)
    for col in ["en", lang]:
        suffix = "." + col
        dataset["{}_tokenized".format(col)] = dataset[col].apply(
            lambda text, suffix=suffix: [tok + suffix for tok in text.translate(strip_punc).lower().split(" ")])
    all_en_tokens = [tok for toks in dataset["en_tokenized"] for tok in toks]
    all_target_tokens = [tok for toks in dataset["{}_tokenized".format(lang)] for tok in toks]
    return dataset, all_en_tokens, all_target_tokens

In [21]:
# Load and tokenize the OpenSubtitles EN-<lang> parallel corpus.
# NOTE(review): this uses config.val_test_lang, while the vocabulary and
# AlignDataset use config.experiment_lang for the target suffix / column name.
# They only agree because both are "ru" here — confirm which one is intended.
data_en_target, all_en_tokens, all_target_tokens = read_and_tokenize_opus_data(lang=config.val_test_lang)

In [22]:
# Peek at the raw text and the language-tagged token lists.
data_en_target.head(3)

Unnamed: 0,en,ru,en_tokenized,ru_tokenized
0,Kids can get pretty much anything they want in...,Дети могут достать во дворе почти всё что угод...,"[kids.en, can.en, get.en, pretty.en, much.en, ...","[дети.ru, могут.ru, достать.ru, во.ru, дворе.r..."
1,'Cause everything comes with a price.,Всё имеет свою цену.,"[cause.en, everything.en, comes.en, with.en, a...","[всё.ru, имеет.ru, свою.ru, цену.ru]"
2,"Hey, Nick.","Эй, Ник.","[hey.en, nick.en]","[эй.ru, ник.ru]"


In [23]:
def create_contrastive_dataset(dataset, trg_lang):
    """Pair each target sentence with a randomly chosen English sentence.

    The ``en_tokenized`` column is permuted with ``torch.randperm`` (so the torch
    seed controls it) while ``<trg_lang>_tokenized`` keeps its order. This gives
    mostly mismatched (negative) pairs. A random permutation can map some rows
    to themselves, so a few "negatives" may be true translations.

    :param dataset: DataFrame with ``en_tokenized`` and ``<trg_lang>_tokenized``
        columns of token lists.
    :returns: DataFrame with those two columns, indexed like ``dataset``.
    """
    trg_col = "{}_tokenized".format(trg_lang)
    shuffle_ix_src = torch.randperm(len(dataset)).numpy()
    # `.values` of a column of lists is already a 1-D object array. Rebuilding it
    # with np.array(list_of_lists) produced a 2-D array when every list had the
    # same length (breaking the DataFrame below) and raises on ragged input in
    # recent NumPy.
    src_c = dataset["en_tokenized"].values[shuffle_ix_src]
    trg_c = dataset[trg_col]
    return pd.DataFrame({"en_tokenized": src_c, trg_col: trg_c})

In [24]:
# Negative pairs: each target sentence paired with a randomly permuted English one.
# NOTE(review): the column is named after config.val_test_lang, but AlignDataset
# reads the config.experiment_lang column — both are "ru" here; confirm intent.
c_df = create_contrastive_dataset(data_en_target, config.val_test_lang)

In [25]:
# Shuffle the contrastive rows' order once more (torch RNG, so the seed applies).
# Re-running this cell reshuffles the already shuffled frame.
c_df = c_df.iloc[torch.randperm(len(c_df)).numpy()]

In [26]:
# Inspect a few contrastive (mismatched) sentence pairs.
c_df.head(3)

Unnamed: 0,en_tokenized,ru_tokenized
47225,"[theres.en, a.en, bloody.en, great.en, bed.en,...","[приди.ru, в.ru, себя.ru, дорогая.ru, прошу.ru..."
219761,"[but.en, youll.en, know.en, the.en, price.en, ...","[лошадиный.ru, моряк.ru]"
673051,"[.en, halt.en, 60.en, starboard.en, 300.en, .e...","[.ru, но.ru, я.ru, ничего.ru, не.ru, сделал.ru]"


In [27]:
class AlignDataset(Dataset):
    """Map-style dataset of parallel (source, target) token lists encoded as ids.

    Each item is a flat 6-element list
    ``[src_ids, src_oov_mask, src_len, trg_ids, trg_oov_mask, trg_len]``.
    Sentences are truncated to ``max_sent_len`` tokens. Tokens missing from
    ``token2id`` become ``UNK_IDX`` and get mask value 1; known tokens get 0.
    """

    def __init__(self, data, max_sent_len, src_lang, trg_lang,
                 token2id, id2token):
        src_col = "{}_tokenized".format(src_lang)
        trg_col = "{}_tokenized".format(trg_lang)
        self.src = list(data[src_col].values)
        self.trg = list(data[trg_col].values)
        self.max_sent_len = int(max_sent_len)
        self.token2id = token2id
        self.id2token = id2token

    def __len__(self):
        return len(self.src)

    def _encode(self, tokens):
        """Return ``(ids, oov_mask)`` for the first ``max_sent_len`` tokens."""
        ids, oov_mask = [], []
        for tok in tokens[:self.max_sent_len]:
            known = tok in self.token2id
            ids.append(self.token2id[tok] if known else UNK_IDX)
            oov_mask.append(0 if known else 1)
        return ids, oov_mask

    def __getitem__(self, row):
        src_ids, src_oov = self._encode(self.src[row])
        trg_ids, trg_oov = self._encode(self.trg[row])
        return [src_ids, src_oov, len(src_ids), trg_ids, trg_oov, len(trg_oov)]
    
def align_collate_func(batch, max_sent_len):
    """Pad a batch of AlignDataset items to ``max_sent_len`` and order by source length.

    Id sequences and masks are right-padded with ``PAD_IDX``. Rows are reordered
    by descending *source* length. The target side follows the same row order,
    so it is not sorted by its own length.

    :returns: ``[src_ids, src_mask, src_len, trg_ids, trg_mask, trg_len]``. Ids
        are tensors of shape (batch, max_sent_len). Masks are float tensors of
        shape (batch, max_sent_len, 1). Lengths are numpy arrays.
    """
    def pad(values, length):
        return np.pad(np.array(values), pad_width=((0, max_sent_len - length)),
                      mode="constant", constant_values=PAD_IDX)

    src_len = [item[2] for item in batch]
    trg_len = [item[5] for item in batch]
    src_data = [pad(item[0], item[2]) for item in batch]
    src_mask = [pad(item[1], item[2]) for item in batch]
    trg_data = [pad(item[3], item[5]) for item in batch]
    trg_mask = [pad(item[4], item[5]) for item in batch]

    order = np.argsort(src_len)[::-1]
    n_rows = len(batch)
    src_data = np.array(src_data)[order]
    trg_data = np.array(trg_data)[order]
    src_mask = np.array(src_mask)[order].reshape(n_rows, -1, 1)
    trg_mask = np.array(trg_mask)[order].reshape(n_rows, -1, 1)
    src_len = np.array(src_len)[order]
    trg_len = np.array(trg_len)[order]

    return [torch.from_numpy(src_data), torch.from_numpy(src_mask).float(), src_len,
            torch.from_numpy(trg_data), torch.from_numpy(trg_mask).float(), trg_len]

In [28]:
# save every epoch
align_dataset = AlignDataset(data_en_target, config.max_sent_len, "en", config.experiment_lang,
                             token2id_mutual, id2token_mutual)
# Unshuffled: train() zips this loader with c_align_loader batch by batch.
align_loader = torch.utils.data.DataLoader(
    dataset=align_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    collate_fn=lambda batch: align_collate_func(batch, config.max_sent_len),
)

In [29]:
# Same encoding pipeline, but over the mismatched (contrastive) sentence pairs.
c_align_dataset = AlignDataset(c_df, config.max_sent_len, "en", config.experiment_lang,
                               token2id_mutual, id2token_mutual)
# Unshuffled: train() zips this loader with align_loader batch by batch.
c_align_loader = torch.utils.data.DataLoader(
    dataset=c_align_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    collate_fn=lambda batch: align_collate_func(batch, config.max_sent_len),
)

In [30]:
# Scratch check of torch.norm on a matrix: every entry of ones(3,4) @ (2*ones(4,5))
# is 8, so the norm over all 15 entries is sqrt(15 * 64) ~= 30.98.
torch.norm(torch.ones((3, 4)) @ (2 * torch.ones((4, 5))))

tensor(30.9839)

In [31]:
def loss_align(en_rep, target_rep, en_c, target_c, lambda_reg):
    """Contrastive alignment loss between English and target-language encodings.

    :param en_rep: English encoder output for the aligned batch, (batch_size, hidden_size)
    :param target_rep: target encoder output for the aligned batch, (batch_size, hidden_size)
    :param en_c: English encoder output for the contrastive batch, (batch_size, hidden_size)
    :param target_c: target encoder output for the contrastive batch, (batch_size, hidden_size)
    :param lambda_reg: weight of the contrastive term (no default; train() passes 0.25)
    :returns: scalar tensor
        ``||en_rep - target_rep|| - lambda_reg * (||en_c - target_rep|| + ||en_rep - target_c||)``

    ``torch.norm(x, 2)`` with no ``dim`` is taken over all elements, so each term
    is one norm over the whole batch matrix, not a per-example distance. There is
    no margin, so the formula is unbounded below.
    """
    positive = torch.norm(en_rep - target_rep, 2)
    negatives = torch.norm(en_c - target_rep, 2) + torch.norm(en_rep - target_c, 2)
    return positive - lambda_reg * negatives

In [32]:
# def loss_align(en_rep, target_rep, en_c, target_c, lambda_reg):
#     """:param en_rep: output repr of eng encoder (batch_size, hidden_size)
#        :param target_rep: output repr of target encoder (batch_size, hidden_size)
#        :param en_c: contrastive sentence repr from eng encoder (batch_size, hidden_size)
#        :param target_c: contrastive sentence repr form target encoder (batch_size, hidden_size)
#        :param lambda_reg: regularization coef [default: 0.25]

#     Returns: L_align = l2norm (en_rep, target_rep) - lambda_reg( l2norm (en_c, target_rep) + l2norm (en_rep, target_c))
#     """
# #     dist = torch.norm(en_rep - target_rep, 2)
#     c_dist = torch.norm(en_c - target_rep, 2) + torch.norm(en_rep - target_c, 2)
# #     L_align = dist - lambda_reg*(c_dist)
#     return c_dist

In [33]:
class biLSTM(nn.Module):
    """Bidirectional LSTM sentence encoder with max-pooling over time.

    Token ids are embedded with a table built by ``nn.Embedding.from_pretrained``
    (frozen by default) and run through the LSTM. The outputs are max-pooled
    over the time axis and batch-normalised, giving one
    ``hidden_size * num_directions`` vector per sentence.

    :param hidden_size: LSTM hidden units per direction.
    :param embedding_weights: numpy array (vocab, embed_dim) for the embedding table.
    :param percent_dropout: dropout probability for ``self.drop_out``.
        NOTE(review): that layer is constructed but never applied in forward.
    :param vocab_size: accepted but unused.
    :param interaction_type: stored as ``self.interaction``; not used here.
    :param num_layers: number of stacked LSTM layers.
    :param input_size: LSTM input width; must equal the embedding dimension.
        It defaults to 300, the value that was previously hard-coded.
    :param src_trg: accepted but unused.
    """

    def __init__(self,
                 hidden_size,
                 embedding_weights,
                 percent_dropout,
                 vocab_size,
                 interaction_type="concat",
                 num_layers=1,
                 input_size=300,
                 src_trg = "src"):

        super(biLSTM, self).__init__()

        self.num_layers, self.hidden_size = num_layers, hidden_size

        self.embed_table = torch.from_numpy(embedding_weights).float()
        self.embedding = nn.Embedding.from_pretrained(self.embed_table)
        self.interaction = interaction_type
        self.dropout = percent_dropout
        self.drop_out = nn.Dropout(self.dropout)

        self.LSTM = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
        self.num_directions = 2 if self.LSTM.bidirectional else 1

        self.bn = nn.BatchNorm1d(self.hidden_size * self.num_directions)

    def init_hidden(self, batch_size):
        """Random (not zero) initial hidden and cell states.

        Each has shape (num_layers * num_directions, batch_size, hidden_size). They
        are drawn on the CPU and moved to the device of this module's parameters,
        so the module does not depend on a notebook-global ``device``.
        """
        shape = (self.num_directions * self.num_layers, batch_size, self.hidden_size)
        param_device = next(self.LSTM.parameters()).device
        hidden = torch.randn(*shape).to(param_device)
        c_0 = torch.randn(*shape).to(param_device)
        return hidden, c_0

    def forward(self, sentence, mask, lengths):
        """Encode a padded batch of token-id sequences.

        :param sentence: LongTensor (batch, seq_len) of token ids.
        :param mask: float tensor broadcastable against the embeddings, e.g.
            (batch, seq_len, 1). Positions with 1 keep the autograd path to the
            embeddings; positions with 0 use a detached copy. Values are unchanged.
        :param lengths: numpy array of true sentence lengths; need not be sorted
            (it is indexed with a list, so a plain Python list will not work).
        :returns: (batch, hidden_size * num_directions) tensor, rows in input order.
        """
        # pack_padded_sequence needs decreasing lengths; keep the inverse
        # permutation to restore the caller's row order at the end.
        sort_order = sorted(range(len(lengths)), key=lambda i: -lengths[i])
        unsort_order = sorted(range(len(lengths)), key=lambda i: sort_order[i])

        sentence = sentence[sort_order]
        sorted_mask = mask[sort_order]
        lengths = lengths[sort_order]
        batch_size, seq_len = sentence.size()
        self.hidden, self.c_0 = self.init_hidden(batch_size)

        embeds = self.embedding(sentence)
        # Both terms must use the *sorted* mask. Mixing the unsorted `mask` with the
        # sorted one doubled (2x) or zeroed embeddings wherever the two disagreed,
        # i.e. whenever sorting actually reordered rows.
        embeds = sorted_mask * embeds + (1 - sorted_mask) * embeds.clone().detach()
        packed = torch.nn.utils.rnn.pack_padded_sequence(embeds, lengths, batch_first=True)
        lstm_out, (self.hidden_1, self.c_1) = self.LSTM(packed, (self.hidden, self.c_0))
        emb1, _ = torch.nn.utils.rnn.pad_packed_sequence(lstm_out, batch_first=True)

        # (batch, time, directions, hidden) -> max over time, then concat directions.
        # NOTE(review): padded time steps are 0 after pad_packed_sequence and take
        # part in the max, so a feature whose real outputs are all negative pools to 0.
        emb1 = emb1.view(batch_size, -1, self.num_directions, self.hidden_size)
        emb1 = torch.max(emb1, dim=1)[0]
        emb1 = torch.cat([emb1[:, i, :] for i in range(self.num_directions)], dim=1)
        emb1 = emb1[unsort_order]

        return self.bn(emb1)

In [38]:
# src: always English
def train(LSTM_src, LSTM_trg, loader, contrastive_loader, optimizer, epoch):
    """Run one epoch of contrastive alignment training.

    Batches from ``loader`` (true EN/target pairs) and ``contrastive_loader``
    (mismatched pairs) are consumed in lock-step. Both encoders encode both
    batches, and ``loss_align`` (lambda_reg=0.25) is minimised. About every
    1/20 of the epoch, the target encoder's state_dict is saved and progress
    is printed.

    :returns: epoch loss averaged per example (sum of batch loss * batch size,
        divided by the dataset size).
    """
    LSTM_src.train()
    LSTM_trg.train()
    # NOTE(review): the caller freezes LSTM_src (requires_grad=False), but train()
    # mode still updates its BatchNorm running statistics.
    n_examples = len(loader.dataset)
    # max(1, ...) avoids a zero logging interval on datasets smaller than 20 batches.
    log_every = max(1, n_examples // (20 * config.batch_size))
    total_loss = 0
    # Iterate lazily: materialising [*enumerate(zip(...))] held every padded batch
    # of the epoch in memory before training started.
    for batch_idx, ((src_data, src_mask, src_len, trg_data, trg_mask, trg_len),
                    (src_c, src_mc, src_len_c, trg_c, trg_mc, trg_len_c)) in \
            enumerate(zip(loader, contrastive_loader)):

        src_data, src_mask = src_data.to(device), src_mask.to(device)
        trg_data, trg_mask = trg_data.to(device), trg_mask.to(device)
        src_c, src_mc = src_c.to(device), src_mc.to(device)
        trg_c, trg_mc = trg_c.to(device), trg_mc.to(device)
        optimizer.zero_grad()
        src_out = LSTM_src(src_data, src_mask, src_len)
        trg_out = LSTM_trg(trg_data, trg_mask, trg_len)
        src_c_out = LSTM_src(src_c, src_mc, src_len_c)
        trg_c_out = LSTM_trg(trg_c, trg_mc, trg_len_c)
        loss = loss_align(src_out, trg_out, src_c_out, trg_c_out, 0.25)
        # The loss already lives on `device`; calling .cuda() on it failed on CPU-only runs.
        loss.backward()
        optimizer.step()
        # Weight by batch size and normalise by the real dataset size (was a hard-coded 400000).
        total_loss += loss.item() * len(src_data) / n_examples
        if (batch_idx + 1) % log_every == 0:
            torch.save(LSTM_trg.state_dict(), "LSTM_en_{}_{}_epoch_{}".format(config.experiment_lang,
                                                                              config.experiment_lang.upper(),
                                                                              epoch))
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * config.batch_size, n_examples,
                100. * (batch_idx + 1) / len(loader), loss.item()), end="\r")

    optimizer.zero_grad()
    return total_loss

In [39]:
# Re-check the embedding matrix shape before building the encoders.
weights_init.shape

(400002, 300)

In [None]:
load_epoch = 3
# Both encoders start from the same checkpoint (the file name suggests an English
# MNLI-trained encoder — confirm).
checkpoint_path = "best_encoder_eng_mnli_{}_{}".format(load_epoch, config.experiment_lang)


def _build_encoder():
    """Construct a biLSTM encoder on `device` and load the shared checkpoint into it."""
    encoder = biLSTM(hidden_size=config.hidden_dim, embedding_weights=weights_init, num_layers=1,
                     percent_dropout=config.dropout, vocab_size=weights_init.shape[0],
                     interaction_type="concat", input_size=300).to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    encoder.load_state_dict(torch.load(checkpoint_path, map_location=device))
    return encoder


# Source (English) encoder is frozen; only the target encoder receives updates.
LSTM_src_model = _build_encoder()
for param in LSTM_src_model.parameters():
    param.requires_grad = False

LSTM_trg_model = _build_encoder()

print ("Encoder src:\n", LSTM_src_model)
print ("Encoder trg:\n", LSTM_trg_model)

# Build the optimizer once so Adam's moment estimates persist across epochs (it
# used to be re-created, and its state reset, every epoch). Frozen parameters
# never receive gradients, so they are left out.
optimizer = torch.optim.Adam(
    [p for p in [*LSTM_src_model.parameters()] + [*LSTM_trg_model.parameters()] if p.requires_grad],
    lr=config.lr)

for epoch in range(config.epochs):
    print ("\nepoch = "+str(epoch))

    loss_train = train(LSTM_src=LSTM_src_model, LSTM_trg=LSTM_trg_model, loader=align_loader,
                       contrastive_loader=c_align_loader, optimizer=optimizer, epoch=epoch)

    torch.save(LSTM_trg_model.state_dict(), "LSTM_en_{}_{}_epoch_{}".format(config.experiment_lang,
                                                                          config.experiment_lang.upper(),
                                                                          epoch))

Encoder src:
 biLSTM(
  (embedding): Embedding(400002, 300)
  (drop_out): Dropout(p=0.1)
  (LSTM): LSTM(300, 512, batch_first=True, bidirectional=True)
  (bn): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
Encoder trg:
 biLSTM(
  (embedding): Embedding(400002, 300)
  (drop_out): Dropout(p=0.1)
  (LSTM): LSTM(300, 512, batch_first=True, bidirectional=True)
  (bn): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

epoch = 0
epoch = 1
epoch = 2

In [None]:
# sum of loss_align and loss_align_cos