In [21]:
import pickle
import random
import spacy
import csv
import sys
import errno
import glob
import string
import io
import os
import jieba
import re
import nltk
import time
import functools
import numpy as np
import pandas as pd
from setuptools import setup
from collections import Counter
from collections import defaultdict
from argparse import ArgumentParser

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset
import torch.nn.functional as F
from torch.autograd import Variable

import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
# Reserved vocabulary rows: 0 is padding, 1 is the unknown-token vector.
PAD_IDX, UNK_IDX = 0, 1
# NLI gold-label strings -> integer class ids.
label_dict = dict(entailment=0, neutral=1, contradiction=2)

In [3]:
no_cuda = False
cuda = not no_cuda and torch.cuda.is_available()
seed = 1
# `seed` used to be defined but never applied. Seed every RNG this notebook draws from
# (python `random`, numpy, torch CPU/CUDA) so shuffles, negatives and inits are reproducible.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if cuda:
    torch.cuda.manual_seed_all(seed)
device = torch.device("cuda" if cuda else "cpu")

In [4]:
# Root of the parallel corpora. It can be overridden with the NLU_DATA_ROOT environment
# variable so the notebook runs outside the original cluster. The default is unchanged.
DATA_ROOT = os.environ.get("NLU_DATA_ROOT", "/scratch/adc563/nlu_project/data")
opus_path = DATA_ROOT + "/opus"
europarl_path = DATA_ROOT + "/europarl"
un_path = DATA_ROOT + "/un_parallel_corpora"

In [5]:
def read_xnli(lang):
    """Load the XNLI dev and test splits, optionally restricted to one language.

    Args:
        lang: value of XNLI's "language" column to keep, or "all" to keep every row.

    Returns:
        (dev_data, test_data) as pandas DataFrames.
    """
    template = "/scratch/adc563/nlu_project/data/XNLI/xnli.{}.jsonl"
    frames = [pd.read_json(template.format(split), lines=True) for split in ("dev", "test")]
    if lang != "all":
        frames = [frame[frame["language"] == lang] for frame in frames]
    dev_data, test_data = frames
    return dev_data, test_data

def load_aligned_vectors(lang, path=None):
    """Load fastText aligned word vectors (``.vec`` text format) into a dict.

    The first line is the ``"<n_words> <dim>"`` header. Every later line is a word
    followed by its space-separated float components.

    Args:
        lang: language code used to build the default file path.
        path: optional explicit file path; overrides the default built from ``lang``.

    Returns:
        dict mapping word -> list of floats.
    """
    if path is None:
        path = "/scratch/adc563/nlu_project/data/aligned_embeddings/wiki.{}.align.vec".format(lang)
    data = {}
    # `with` closes the handle; the original leaked it.
    with io.open(path, "r", encoding="utf-8", newline="\n", errors="ignore") as fin:
        n, d = map(int, fin.readline().split())  # header; parsing it validates the format
        for line in fin:
            tokens = line.rstrip().split(" ")
            data[tokens[0]] = [float(value) for value in tokens[1:]]
    return data

def load_multilingual_vectors(lang, path=None):
    """Load fastText Common Crawl vectors (``cc.<lang>.300.vec``) into a dict.

    The first (header) line is skipped. Every later line is a word followed by its
    space-separated float components.

    Args:
        lang: language code used to build the default file path.
        path: optional explicit file path; overrides the default built from ``lang``.

    Returns:
        dict mapping word -> list of floats. Earlier versions stored one-shot ``map``
        iterators, which became empty after their first use.
    """
    if path is None:
        path = "/scratch/adc563/nlu_project/data/multi_lingual_embeddings/cc.{}.300.vec".format(lang)
    data = {}
    with io.open(path, 'r', encoding='utf-8', newline='\n', errors='ignore') as fin:
        fin.readline()  # skip the "<n_words> <dim>" header
        for line in fin:
            tokens = line.rstrip().split(' ')
            # Materialise the floats; a bare map() can only be iterated once.
            data[tokens[0]] = [float(value) for value in tokens[1:]]
    return data

def load_glove_vectors(lang, path=None):
    """Load GloVe word vectors from a text file into a dict.

    GloVe ``.txt`` files have no header line, so the first line is a real vector.
    The original unconditionally discarded it. A word2vec-style ``"<count> <dim>"``
    header is still tolerated: it is recognised and skipped.

    Lines are split on a single ASCII space only (not ``str.split()``). Some GloVe 840B
    tokens contain non-breaking spaces, and a generic split would break them apart.

    Args:
        lang: unused (the file is English-only); kept for interface compatibility.
        path: optional explicit file path; defaults to the 840B/300d GloVe file.

    Returns:
        dict mapping word -> list of floats.
    """
    if path is None:
        path = "/scratch/adc563/nlu_project/HBMP/vector_cache/glove.840B.300d.txt"
    data = {}
    with io.open(path, "r", encoding="utf-8", newline="\n", errors="ignore") as fin:
        for line_no, line in enumerate(fin):
            tokens = line.rstrip().split(" ")
            is_header = line_no == 0 and len(tokens) == 2 and all(t.isdigit() for t in tokens)
            if is_header:
                continue
            data[tokens[0]] = [float(value) for value in tokens[1:]]
    return data

def read_enli(nli_corpus = "snli"):
    """Load an English NLI corpus from the HBMP data directory.

    Args:
        nli_corpus: "snli" (train/dev/test) or "multinli" (train and dev-matched; no test).

    Returns:
        (train, dev, test) DataFrames with rows whose ``gold_label`` is "-" removed.
        ``test`` is None for "multinli".

    Raises:
        ValueError: for any other corpus name. Previously this surfaced as an
            UnboundLocalError on ``train``.
    """
    data_root = "/scratch/adc563/nlu_project/HBMP/data"
    if nli_corpus == "snli":
        prefix = data_root + "/snli/snli_1.0/snli_1.0"
        train = pd.read_json("{}_{}.jsonl".format(prefix, "train"), lines=True)
        dev = pd.read_json("{}_{}.jsonl".format(prefix, "dev"), lines=True)
        test = pd.read_json("{}_{}.jsonl".format(prefix, "test"), lines=True)
    elif nli_corpus == "multinli":
        prefix = data_root + "/multinli/multinli_1.0/multinli_1.0"
        train = pd.read_json("{}_{}.jsonl".format(prefix, "train"), lines=True)
        dev = pd.read_json("{}_{}_matched.jsonl".format(prefix, "dev"), lines=True)
        test = None
    else:
        raise ValueError("NLI corpus name should be in [snli, multinli]")
    # "-" is the gold label of pairs without annotator consensus; they are unusable for training.
    train = train[train["gold_label"] != "-"]
    dev = dev[dev["gold_label"] != "-"]
    if test is not None:
        test = test[test["gold_label"] != "-"]
    return train, dev, test

def write_numeric_label(train, dev, test, nli_corpus="multinli"):
    """Replace string ``gold_label`` values with integer ids from ``label_dict``.

    The DataFrames are modified in place (their ``gold_label`` column is reassigned) and
    also returned. Which splits are converted depends on the corpus:
    multinli -> train, dev; snli -> train, dev, test; xnli -> dev, test.

    Raises:
        ValueError: if ``nli_corpus`` is not one of the three names above.
        KeyError: if a split contains a label missing from ``label_dict``.
    """
    corpus_splits = (
        ("multinli", (train, dev)),
        ("snli", (train, dev, test)),
        ("xnli", (dev, test)),
    )
    for corpus_name, frames in corpus_splits:
        if nli_corpus == corpus_name:
            break
    else:
        raise ValueError ("NLI corpus name should be in [multinli, snli, xnli]")
    for frame in frames:
        frame["gold_label"] = frame["gold_label"].apply(lambda gold: label_dict[gold])
    return train, dev, test

def tokenize_xnli(dataset, remove_punc=False, lang="en"):
    """Tokenize the ``sentence1``/``sentence2`` columns and tag every token with ``.<lang>``.

    Adds ``sentence1_tokenized`` and ``sentence2_tokenized`` columns to ``dataset`` in place.
    Tokenizers by language:
      - "ar": nltk ``wordpunct_tokenize``;
      - "zh": jieba full-mode segmentation, re-split on single spaces;
      - otherwise: drop ASCII punctuation, lowercase, split on single spaces.

    NOTE(review): ``remove_punc`` is accepted but ignored. Punctuation is always stripped
    for the default tokenizer and never for "ar"/"zh".

    Returns:
        (dataset, all_tokens): ``all_tokens`` is every sentence1 token followed by every
        sentence2 token, in row order.
    """
    if lang == "ar":
        def tokenize(text):
            return [piece + ".ar" for piece in nltk.tokenize.wordpunct_tokenize(text)]
    elif lang == "zh":
        def tokenize(text):
            return [piece + ".zh" for piece in ' '.join(jieba.cut(text, cut_all=True)).split(" ")]
    else:
        def tokenize(text):
            stripped = "".join(ch for ch in text if ch not in string.punctuation)
            return [piece + "." + lang for piece in stripped.lower().split(" ")]

    for column in ("sentence1", "sentence2"):
        dataset[column + "_tokenized"] = dataset[column].apply(tokenize)

    all_tokens = [
        token
        for column in ("sentence1_tokenized", "sentence2_tokenized")
        for sentence in dataset[column]
        for token in sentence
    ]
    return dataset, all_tokens



def build_vocab(all_tokens, max_vocab_size):
    """Build a frequency-ordered vocabulary with reserved ``<PAD>``=0 and ``<UNK>``=1.

    Args:
        all_tokens: iterable of tokens (may be empty).
        max_vocab_size: keep at most this many most-frequent tokens (None keeps all).

    Returns:
        (token2id, id2token). Corpus tokens get ids 2.. in descending frequency order.
        If the corpus itself contains "<PAD>"/"<UNK>", the reserved ids 0/1 take precedence.
        Empty input yields just the two reserved entries; the original crashed on
        ``zip(*[])``.
    """
    vocab = [token for token, _ in Counter(all_tokens).most_common(max_vocab_size)]
    id2token = ['<PAD>', '<UNK>'] + vocab
    token2id = {token: idx for idx, token in enumerate(vocab, start=2)}
    token2id["<PAD>"] = 0
    token2id["<UNK>"] = 1
    return token2id, id2token

def build_tok2id(id2token):
    """Invert an id -> token list into a token -> id dict.

    If a token appears more than once, its last position wins.
    """
    return {token: index for index, token in enumerate(id2token)}

def init_embedding_weights(vectors, token2id, id2token, embedding_size):
    """Build an embedding matrix with one row per entry of ``id2token``.

    Row 0 (padding) stays all zeros. Row 1 (unknown) is drawn from ``np.random.randn``.
    Every row from 2 on is ``vectors[id2token[row]]`` (a KeyError if the token is missing).
    ``token2id`` is accepted but not used.

    Returns:
        float ndarray of shape (len(id2token), embedding_size).
    """
    matrix = np.zeros((len(id2token), embedding_size))
    for row, token in enumerate(id2token[2:], start=2):
        matrix[row] = vectors[token]
    matrix[1] = np.random.randn(embedding_size)
    return matrix

In [6]:
class config_class:
    """Plain attribute container holding the experiment's hyper-parameters."""

    def __init__(self, corpus, val_test_lang, max_sent_len, max_vocab_size, epochs, batch_size, 
                    embed_dim, hidden_dim, dropout, lr, experiment_lang):
        # data / language selection
        self.corpus = corpus
        self.val_test_lang = val_test_lang
        # sequence and vocabulary limits
        self.max_sent_len, self.max_vocab_size = max_sent_len, max_vocab_size
        # optimisation schedule
        self.epochs, self.batch_size = epochs, batch_size
        # model sizes and regularisation
        self.embed_dim, self.hidden_dim = embed_dim, hidden_dim
        self.dropout, self.lr = dropout, lr
        self.experiment_lang = experiment_lang

In [8]:
config = config_class(
    # data / languages
    corpus="multinli",
    val_test_lang="zh",
    experiment_lang="zh",
    # sequence and vocabulary limits
    max_sent_len=30,
    max_vocab_size=210000,
    # optimisation
    epochs=15,
    batch_size=64,
    lr=1e-3,
    # model
    embed_dim=300,
    hidden_dim=512,
    dropout=0.1,
)

In [9]:
def update_vocab_keys(src_vocab, trg_vocab, trg_lang=None):
    """Suffix vector-dict keys with their language and merge them into ``src_vocab``.

    Source keys get ``.en`` and target keys get ``.<trg_lang>``. Both dicts are updated in
    place, insertion order is preserved, and ``src_vocab`` (now holding both languages)
    is returned.

    The original renamed keys one at a time while iterating. When both ``x`` and
    ``x + ".en"`` existed, the second vector was overwritten and then re-suffixed. Each
    dict is now rebuilt from a fully renamed copy.

    Args:
        src_vocab: English word -> vector dict.
        trg_vocab: target-language word -> vector dict.
        trg_lang: target suffix; defaults to the notebook-level ``config.experiment_lang``.
    """
    if trg_lang is None:
        trg_lang = config.experiment_lang
    renamed_src = {key + ".en": vec for key, vec in src_vocab.items()}
    renamed_trg = {key + ".{}".format(trg_lang): vec for key, vec in trg_vocab.items()}
    # Rebuild in place so callers holding references to either dict see the renamed keys.
    src_vocab.clear()
    src_vocab.update(renamed_src)
    trg_vocab.clear()
    trg_vocab.update(renamed_trg)

    src_vocab.update(trg_vocab)
    return src_vocab

In [10]:
# Load the English GloVe vectors; load_glove_vectors does not use its lang argument.
print ("Loading vectors for EN.")
aligned_src_vectors = load_glove_vectors("en")

Loading vectors for EN.


In [11]:
# Load the fastText aligned vectors for the target language (config.experiment_lang).
print ("Loading vectors for {}.".format(config.experiment_lang.upper()))
aligned_trg_vectors = load_aligned_vectors(config.experiment_lang)

Loading vectors for ZH.


In [12]:
# Tag each pretrained word with its language and keep the first max_vocab_size words per
# language. Ids 0/1 are reserved for padding and unknown tokens.
id2token_src = [tok + ".en" for tok in [*aligned_src_vectors][:config.max_vocab_size]]
id2token_trg = [tok + "." + config.experiment_lang
                for tok in [*aligned_trg_vectors][:config.max_vocab_size]]
id2token_mutual = ["<PAD>", "<UNK>", *id2token_src, *id2token_trg]
# update_vocab_keys renames the keys of both vector dicts in place and returns the merged dict.
vecs_mutual = update_vocab_keys(aligned_src_vectors, aligned_trg_vectors)

In [13]:
# NOTE(review): identical redefinition of init_embedding_weights from cell [6]; one of the
# two cells can be deleted.
def init_embedding_weights(vectors, token2id, id2token, embedding_size):
    """Build an embedding matrix with one row per entry of ``id2token``.

    Row 0 (padding) stays zeros, row 1 (unknown) is ``np.random.randn`` noise, and rows 2..
    are copied from ``vectors[id2token[idx]]`` (a KeyError if a token is missing).
    ``token2id`` is not used.
    """
    weights = np.zeros((len(id2token), embedding_size))
    for idx in range(2, len(id2token)):
        token = id2token[idx]
        weights[idx] = vectors[token]
    weights[1] = np.random.randn(embedding_size)
    return weights

In [14]:
# Map each mutual-vocabulary token to its row, then build the (len(vocab), 300) embedding matrix.
# NOTE(review): 300 is hard-coded here; config.embed_dim holds the same value.
token2id_mutual = build_tok2id(id2token_mutual)
weights_init = init_embedding_weights(vecs_mutual, token2id_mutual, id2token_mutual, 300)

In [15]:
# Sanity check: 2 reserved ids + the source and target vocabularies.
len(id2token_mutual)

420002

In [16]:
# Sanity check: (len(id2token_mutual), 300).
weights_init.shape

(420002, 300)

In [63]:
def read_and_tokenize_opus_data(lang="tr", base_path=None):
    """Read an OPUS en-<lang> parallel corpus and tokenize both sides.

    Expects ``<base_path>/<lang>_en/en_data_00`` and ``<base_path>/<lang>_en/<lang>_data_00``,
    one sentence per line and aligned by line number.

    Args:
        lang: target language code. "ar" uses nltk ``wordpunct_tokenize``; "zh" uses jieba
            full-mode segmentation, dropping pieces that are substrings of
            ``string.punctuation`` (this includes empty strings). All other languages use
            the English tokenizer: strip ASCII punctuation, lowercase, split on spaces.
        base_path: corpus root; defaults to the notebook-level ``opus_path``.

    Returns:
        (dataset, all_en_tokens, all_target_tokens). ``dataset`` has the raw columns "en"
        and ``lang`` plus ``en_tokenized``/``<lang>_tokenized``, with every token suffixed
        by ``.<language>``. The lists are the flattened tokens of each side.

    Raises:
        ValueError: if the two files do not have the same number of lines.
    """
    root = opus_path if base_path is None else base_path
    pair_dir = root + "/{}_en".format(lang)
    # Explicit UTF-8 (the locale default may differ); `with` closes the files the original leaked.
    with open(pair_dir + "/en_data_00", "r", encoding="utf-8") as en_corpus:
        en_lines = en_corpus.read().split("\n")
    with open(pair_dir + "/{}_data_00".format(lang), "r", encoding="utf-8") as target_corpus:
        target_lines = target_corpus.read().split("\n")
    if len(en_lines) != len(target_lines):
        raise ValueError("parallel files differ in length: {} en vs {} {} lines".format(
            len(en_lines), len(target_lines), lang))
    dataset = pd.DataFrame({"en": pd.Series(en_lines), lang: pd.Series(target_lines)})

    def _basic_tokenize(text, suffix):
        # Drop ASCII punctuation, lowercase, split on single spaces, tag with ".<suffix>".
        cleaned = "".join(c for c in text if c not in string.punctuation).lower()
        return [tok + "." + suffix for tok in cleaned.split(" ")]

    dataset["en_tokenized"] = dataset["en"].apply(lambda x: _basic_tokenize(x, "en"))
    if lang == "ar":
        dataset["ar_tokenized"] = dataset["ar"].apply(
            lambda x: [a + ".ar" for a in nltk.tokenize.wordpunct_tokenize(x)])
    elif lang == "zh":
        dataset["zh_tokenized"] = dataset["zh"].apply(
            lambda x: [z + ".zh" for z in ' '.join(jieba.cut(x, cut_all=True)).split(" ")
                       if z not in string.punctuation])
    else:
        dataset["{}_tokenized".format(lang)] = dataset[lang].apply(lambda x: _basic_tokenize(x, lang))

    all_en_tokens = [tok for toks in dataset["en_tokenized"] for tok in toks]
    all_target_tokens = [tok for toks in dataset["{}_tokenized".format(lang)] for tok in toks]
    return dataset, all_en_tokens, all_target_tokens

In [61]:
def read_and_tokenize_europarl_data(lang="de", base_path=None):
    """Read a Europarl v7 en-<lang> parallel corpus and tokenize both sides.

    Expects ``<base_path>/<lang>_en/europarl-v7.<lang>-en.en`` and ``.../europarl-v7.<lang>-en.<lang>``,
    aligned by line number. Both sides strip ASCII punctuation, lowercase, split on single
    spaces, and suffix every token with ``.<language>``.

    Args:
        lang: target language code.
        base_path: corpus root; defaults to the notebook-level ``europarl_path``.

    Returns:
        (dataset, all_en_tok, all_target_tok). The original returned the undefined names
        ``all_en_tokens``/``all_target_tokens``, which were either a NameError or silently
        the globals left over from the OPUS cell.

    Raises:
        ValueError: if the two files do not have the same number of lines.
    """
    root = europarl_path if base_path is None else base_path
    pair_dir = root + "/{}_en".format(lang)
    with open(pair_dir + "/europarl-v7.{}-en.en".format(lang), "r", encoding="utf-8") as en_corpus:
        en_lines = en_corpus.read().split("\n")
    with open(pair_dir + "/europarl-v7.{}-en.{}".format(lang, lang), "r", encoding="utf-8") as target_corpus:
        target_lines = target_corpus.read().split("\n")
    if len(en_lines) != len(target_lines):
        raise ValueError("parallel files differ in length: {} en vs {} {} lines".format(
            len(en_lines), len(target_lines), lang))
    dataset = pd.DataFrame({"en": pd.Series(en_lines), lang: pd.Series(target_lines)})

    def _basic_tokenize(text, suffix):
        # Drop ASCII punctuation, lowercase, split on single spaces, tag with ".<suffix>".
        cleaned = "".join(c for c in text if c not in string.punctuation).lower()
        return [tok + "." + suffix for tok in cleaned.split(" ")]

    for side in ["en", lang]:
        dataset["{}_tokenized".format(side)] = dataset[side].apply(
            lambda x, side=side: _basic_tokenize(x, side))

    all_en_tok = [tok for toks in dataset["en_tokenized"] for tok in toks]
    all_target_tok = [tok for toks in dataset["{}_tokenized".format(lang)] for tok in toks]
    return dataset, all_en_tok, all_target_tok

In [64]:
# Load and tokenize the OPUS en-<val_test_lang> parallel corpus.
# NOTE(review): this uses config.val_test_lang while the vocabulary uses config.experiment_lang;
# both are "zh" here but must be kept in sync.
data_en_target, all_en_tokens, all_target_tokens = read_and_tokenize_opus_data(lang=config.val_test_lang)

In [65]:
# Preview the parallel pairs and their tokenizations.
data_en_target.head(3)

Unnamed: 0,en,zh,en_tokenized,zh_tokenized
0,"Ah, this is greasy. I want to eat kimchee.",好想要吃泡菜,"[ah.en, this.en, is.en, greasy.en, i.en, want....","[好.zh, 想要.zh, 吃.zh, 泡菜.zh]"
1,Is Chae Yoon's coordinator in here?,崔允的造型师在吗,"[is.en, chae.en, yoons.en, coordinator.en, in....","[崔.zh, 允.zh, 的.zh, 造型.zh, 造型师.zh, 在.zh, 吗.zh]"
2,"Excuse me, aren't you Chae Yoon's coordinator?",请问一下 你是不是崔允的造型师,"[excuse.en, me.en, arent.en, you.en, chae.en, ...","[请问.zh, 一下.zh, 你.zh, 是不是.zh, 不是.zh, 崔.zh, 允.zh..."


In [100]:
# Token counts per side, used below to drop near-empty pairs.
data_en_target["len_en"] = data_en_target["en_tokenized"].map(len)
data_en_target["len_{}".format(config.val_test_lang)] = data_en_target[
    "{}_tokenized".format(config.val_test_lang)
].map(len)

In [101]:
# Keep only pairs where both sides have at least two tokens.
data_en_target = data_en_target.loc[
    data_en_target["len_en"].gt(1)
    & data_en_target["len_{}".format(config.val_test_lang)].gt(1)
]


In [66]:
def create_contrastive_dataset(dataset, trg_lang):
    """Pair each target sentence with a randomly permuted English sentence.

    Args:
        dataset: DataFrame with ``en_tokenized`` and ``<trg_lang>_tokenized`` list columns.
        trg_lang: target language code.

    Returns:
        DataFrame with the same two columns. The target column keeps ``dataset``'s index and
        order; the English column is a ``torch.randperm`` shuffle of the original English side.
    """
    shuffle_ix_src = torch.randperm(len(dataset)).numpy()
    # Keep a 1-D object array of token lists. np.array([*lists]) would build a 2-D array when
    # all sentences have equal length, and NumPy >= 1.24 raises on ragged lists.
    src_c = dataset["en_tokenized"].to_numpy(dtype=object)[shuffle_ix_src]
    trg_c = dataset["{}_tokenized".format(trg_lang)]
    contrastive_df = pd.DataFrame({"en_tokenized": src_c, "{}_tokenized".format(trg_lang): trg_c})
    return contrastive_df

In [67]:
# Pair every target sentence with a randomly chosen (shuffled) English sentence.
c_df = create_contrastive_dataset(data_en_target, config.val_test_lang)

In [68]:
# Shuffle the rows of the contrastive frame; index labels move with their rows.
c_df = c_df.iloc[torch.randperm(len(c_df))]

In [69]:
# Turn the first 100000 (shuffled) target sentences into harder negatives: each becomes
# len-1 tokens sampled with replacement from itself. The whole column is rebuilt and
# assigned once. The previous chained `c_df[col].iloc[:100000] = ...` depended on pandas
# writing through a temporary Series, which is unreliable (SettingWithCopy) and silently
# does nothing under Copy-on-Write.
c_df["{}_tokenized".format(config.experiment_lang)] = pd.Series(
    [
        [np.random.choice(toks) for _ in range(len(toks) - 1)] if pos < 100000 else toks
        for pos, toks in enumerate(c_df["{}_tokenized".format(config.experiment_lang)])
    ],
    index=c_df.index,
)

In [70]:
# Preview the contrastive frame after target-side re-sampling.
c_df.head()

Unnamed: 0,en_tokenized,zh_tokenized
599889,"[why.en, is.en, it.en, always.en, some.en, fuc...","[對.zh, 扯.zh, 澤.zh, 頭.zh, 他.zh, 澤.zh]"
127534,"[reggie.en, usually.en, yeah.en, yeah.en]","[徜徉.zh, 徜徉.zh, 你.zh, 徜徉.zh]"
1577536,"[i.en, gotta.en, cook.en]",[]
218596,"[they.en, are.en]","[原谅.zh, 可以.zh, 你.zh]"
1301603,"[ive.en, touched.en, you.en, so.en, often.en, ...","[号叫.zh, 死者.zh, 号叫.zh, 华.zh, 失血.zh, 烂.zh, 命.zh,..."


In [71]:
# Permutation used to shuffle the target column independently of the English column.
shuffle_ix = torch.randperm(len(c_df))

In [72]:
# Assigning a NumPy array is positional, not index-aligned, so the target column is
# permuted relative to the rows and en/target pairs are decoupled further.
c_df["{}_tokenized".format(config.experiment_lang)] = np.array(c_df["{}_tokenized".format(config.experiment_lang)])[shuffle_ix]

In [73]:
# Preview after shuffling the target column.
c_df.head()

Unnamed: 0,en_tokenized,zh_tokenized
599889,"[why.en, is.en, it.en, always.en, some.en, fuc...","[回到.zh, 维多.zh, 维多利.zh, 维多利亚.zh, 多利.zh, 多利亚.zh,..."
127534,"[reggie.en, usually.en, yeah.en, yeah.en]","[是因为.zh, 因为.zh, 露.zh, 娜.zh, 你.zh, 对.zh, 她.zh, ..."
1577536,"[i.en, gotta.en, cook.en]","[看看.zh, 我.zh, 这儿.zh, 有.zh, 多少.zh]"
218596,"[they.en, are.en]","[罗伯特.zh, 雷.zh, 福.zh]"
1301603,"[ive.en, touched.en, you.en, so.en, often.en, ...","[今天.zh, 早上.zh, 7.zh, 点.zh, 我.zh, 看见.zh, 他.zh, ..."


In [74]:
# Token counts per side of the contrastive pairs.
c_df["len_en"] = c_df["en_tokenized"].map(len)
c_df["len_{}".format(config.val_test_lang)] = c_df["{}_tokenized".format(config.val_test_lang)].map(len)

# Keep only pairs where both sides have at least two tokens.
c_df = c_df.loc[c_df["len_en"].gt(1) & c_df["len_{}".format(config.val_test_lang)].gt(1)]

In [79]:
# Preview the filtered contrastive frame with its length columns.
c_df.head()

Unnamed: 0,en_tokenized,zh_tokenized,len_en,len_zh
599889,"[why.en, is.en, it.en, always.en, some.en, fuc...","[回到.zh, 维多.zh, 维多利.zh, 维多利亚.zh, 多利.zh, 多利亚.zh,...",9,8
127534,"[reggie.en, usually.en, yeah.en, yeah.en]","[是因为.zh, 因为.zh, 露.zh, 娜.zh, 你.zh, 对.zh, 她.zh, ...",4,11
1577536,"[i.en, gotta.en, cook.en]","[看看.zh, 我.zh, 这儿.zh, 有.zh, 多少.zh]",3,5
218596,"[they.en, are.en]","[罗伯特.zh, 雷.zh, 福.zh]",2,3
1301603,"[ive.en, touched.en, you.en, so.en, often.en, ...","[今天.zh, 早上.zh, 7.zh, 点.zh, 我.zh, 看见.zh, 他.zh, ...",11,15


In [102]:
class AlignDataset(Dataset):
    """Map-style dataset of (source, target) token-list pairs converted to vocabulary ids.

    Each item is ``[src_ids, src_oov_mask, src_len, trg_ids, trg_oov_mask, trg_len]``.
    Ids come from ``token2id`` (``UNK_IDX`` for out-of-vocabulary tokens). The mask is 1 at
    OOV positions and 0 elsewhere. Both sides are truncated to ``max_sent_len`` tokens.
    """

    def __init__(self, data, max_sent_len, src_lang, trg_lang,
                 token2id, id2token):
        src_column = "{}_tokenized".format(src_lang)
        trg_column = "{}_tokenized".format(trg_lang)
        self.src = list(data[src_column].values)
        self.trg = list(data[trg_column].values)
        self.max_sent_len = int(max_sent_len)
        self.token2id = token2id
        self.id2token = id2token

    def __len__(self):
        return len(self.src)

    def _encode(self, tokens):
        """Return (ids, oov_mask) for the first ``max_sent_len`` tokens."""
        ids, oov = [], []
        for token in tokens[:self.max_sent_len]:
            if token not in self.token2id:
                ids.append(UNK_IDX)
                oov.append(1)
                continue
            ids.append(self.token2id[token])
            oov.append(0)
        return ids, oov

    def __getitem__(self, row):
        src_ix, src_mask = self._encode(self.src[row])
        trg_ix, trg_mask = self._encode(self.trg[row])
        return [src_ix, src_mask, len(src_ix), trg_ix, trg_mask, len(trg_mask)]
    
def align_collate_func(batch, max_sent_len):
    """Pad an AlignDataset batch to ``max_sent_len`` and sort it by source length (descending).

    Ids and OOV masks are right-padded with ``PAD_IDX``. Target rows follow the source
    ordering, so pairs stay aligned.

    Returns:
        [src_ids (B, L) tensor, src_mask (B, L, 1) float tensor, src_len ndarray,
         trg_ids (B, L) tensor, trg_mask (B, L, 1) float tensor, trg_len ndarray]
    """
    def pad(seq, length):
        return np.pad(np.array(seq), pad_width=((0, max_sent_len - length)),
                      mode="constant", constant_values=PAD_IDX)

    src_len = [datum[2] for datum in batch]
    trg_len = [datum[5] for datum in batch]
    src_data = [pad(datum[0], datum[2]) for datum in batch]
    src_mask = [pad(datum[1], datum[2]) for datum in batch]
    trg_data = [pad(datum[3], datum[5]) for datum in batch]
    trg_mask = [pad(datum[4], datum[5]) for datum in batch]

    order = np.argsort(src_len)[::-1]
    n_rows = len(batch)
    return [torch.from_numpy(np.array(src_data)[order]),
            torch.from_numpy(np.array(src_mask)[order].reshape(n_rows, -1, 1)).float(),
            np.array(src_len)[order],
            torch.from_numpy(np.array(trg_data)[order]),
            torch.from_numpy(np.array(trg_mask)[order].reshape(n_rows, -1, 1)).float(),
            np.array(trg_len)[order]]

In [103]:
# Save at every epoch.
# Loader over the true parallel pairs: unshuffled batches, padded to config.max_sent_len.
# NOTE(review): the lambda's max_sentence_length default is unused; config.max_sent_len is
# read each time the collate function runs.
align_dataset = AlignDataset(data_en_target, config.max_sent_len, "en", config.experiment_lang,
                             token2id_mutual, id2token_mutual)
align_loader = torch.utils.data.DataLoader(dataset=align_dataset, batch_size=config.batch_size,
                               collate_fn=lambda x, max_sentence_length=config.max_sent_len: align_collate_func(x, config.max_sent_len),
                               shuffle=False)

In [104]:
# Loader over the contrastive (mismatched) pairs. It is zipped with align_loader during
# training, so iteration stops at the shorter of the two.
c_align_dataset = AlignDataset(c_df, config.max_sent_len, "en", config.experiment_lang, 
                               token2id_mutual, id2token_mutual)
c_align_loader = torch.utils.data.DataLoader(dataset=c_align_dataset, batch_size=config.batch_size,
                               collate_fn=lambda x, max_sentence_length=config.max_sent_len: align_collate_func(x, config.max_sent_len),
                               shuffle=False)

In [105]:
def loss_align(en_rep, target_rep, en_c, target_c, lambda_reg):
    """Contrastive alignment loss between English and target sentence encodings.

    L_align = ||en_rep - target_rep|| - lambda_reg * (||en_c - target_rep|| + ||en_rep - target_c||)

    Each norm is ``Tensor.norm(2)`` over the whole difference tensor, i.e. one batch-level
    norm over all elements, not a per-example mean.

    :param en_rep: output repr of eng encoder (batch_size, hidden_size)
    :param target_rep: output repr of target encoder (batch_size, hidden_size)
    :param en_c: contrastive sentence repr from eng encoder (batch_size, hidden_size)
    :param target_c: contrastive sentence repr from target encoder (batch_size, hidden_size)
    :param lambda_reg: weight of the contrastive (negative-pair) term
    """
    positive = (en_rep - target_rep).norm(2)
    negatives = (en_c - target_rep).norm(2) + (en_rep - target_c).norm(2)
    return positive - lambda_reg * negatives

In [106]:
# NOTE(review): identical redefinition of loss_align from cell [105]; one of the two cells can be deleted.
def loss_align(en_rep, target_rep, en_c, target_c, lambda_reg):
    """:param en_rep: output repr of eng encoder (batch_size, hidden_size)
       :param target_rep: output repr of target encoder (batch_size, hidden_size)
       :param en_c: contrastive sentence repr from eng encoder (batch_size, hidden_size)
       :param target_c: contrastive sentence repr from target encoder (batch_size, hidden_size)
       :param lambda_reg: regularization coef [default: 0.25]

    Returns: L_align = l2norm (en_rep, target_rep) - lambda_reg( l2norm (en_c, target_rep) + l2norm (en_rep, target_c))

    Each torch.norm(x, 2) is taken over all elements of the difference (one batch-level norm).
    NOTE(review): nothing bounds the subtracted term, so the loss can decrease without limit
    by pushing contrastive pairs apart. A margin/hinge may be intended.
    """
    dist = torch.norm(en_rep - target_rep, 2)
    c_dist = torch.norm(en_c - target_rep, 2) + torch.norm(en_rep - target_c, 2)
    L_align = dist - lambda_reg*(c_dist)
    return L_align

In [107]:
class biLSTM(nn.Module):
    """Bidirectional LSTM sentence encoder with max-pooling over time and BatchNorm output.

    Args:
        hidden_size: per-direction LSTM hidden size. The output has hidden_size * 2 features.
        embedding_weights: ndarray of shape (vocab, input_size). It is loaded with
            ``nn.Embedding.from_pretrained``, which freezes it by default.
        percent_dropout: stored as ``self.dropout``. NOTE(review): ``self.drop_out`` is never
            applied in ``forward``.
        vocab_size, interaction_type, src_trg: kept for interface compatibility (unused).
        num_layers: number of stacked LSTM layers.
        input_size: embedding dimension fed to the LSTM. It was previously hard-coded to 300.

    Module attribute names (embedding, drop_out, LSTM, bn) are unchanged, so existing
    state_dict checkpoints still load.
    """

    def __init__(self,
                 hidden_size,
                 embedding_weights,
                 percent_dropout,
                 vocab_size,
                 interaction_type="concat",
                 num_layers=1,
                 input_size=300,
                 src_trg = "src"):

        super(biLSTM, self).__init__()

        self.num_layers, self.hidden_size = num_layers, hidden_size

        self.embed_table = torch.from_numpy(embedding_weights).float()
        self.embedding = nn.Embedding.from_pretrained(self.embed_table)
        self.interaction = interaction_type
        self.dropout = percent_dropout
        self.drop_out = nn.Dropout(self.dropout)

        self.LSTM = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
        self.num_directions = 2 if self.LSTM.bidirectional else 1

        self.bn = nn.BatchNorm1d(self.hidden_size * self.num_directions)

    def init_hidden(self, batch_size):
        """Random (h_0, c_0), placed on the device of the model's own parameters.

        The previous version used the notebook-global ``device``.
        """
        param_device = next(self.parameters()).device
        shape = (self.num_directions * self.num_layers, batch_size, self.hidden_size)
        hidden = torch.randn(*shape).to(param_device)
        c_0 = torch.randn(*shape).to(param_device)
        return hidden, c_0

    def forward(self, sentence, mask, lengths):
        """Encode a padded batch.

        Args:
            sentence: (batch, seq_len) token ids.
            mask: float OOV mask broadcastable to the embeddings; the collate function
                supplies it as (batch, seq_len, 1).
            lengths: array of true lengths, indexable with a list (e.g. a NumPy array).

        Returns:
            (batch, hidden_size * num_directions) tensor in the caller's original row order.
        """
        # pack_padded_sequence needs descending lengths; keep the inverse permutation to restore order.
        sort_original = sorted(range(len(lengths)), key=lambda i: -lengths[i])
        unsort_to_original = sorted(range(len(lengths)), key=lambda i: sort_original[i])

        sentence = sentence[sort_original]
        _mask = mask[sort_original]
        lengths = lengths[sort_original]
        batch_size, seq_len = sentence.size()
        self.hidden, self.c_0 = self.init_hidden(batch_size)

        embeds = self.embedding(sentence)
        # Gradients reach only OOV positions (mask == 1); known-token embeddings are detached.
        # Uses the *sorted* mask so its rows line up with the sorted embeddings. The original
        # mixed in the unsorted `mask`.
        embeds = _mask * embeds + (1 - _mask) * embeds.clone().detach()
        embeds = torch.nn.utils.rnn.pack_padded_sequence(embeds, lengths, batch_first=True)
        lstm_out, (self.hidden_1, self.c_1) = self.LSTM(embeds, (self.hidden, self.c_0))
        emb1, _ = torch.nn.utils.rnn.pad_packed_sequence(lstm_out, batch_first=True)

        # Max-pool over time per direction, then concatenate the directions.
        # NOTE(review): padded steps are 0 and take part in the max.
        emb1 = emb1.view(batch_size, -1, self.num_directions, self.hidden_size)
        emb1 = torch.max(emb1, dim=1)[0]
        emb1 = torch.cat([emb1[:, i, :] for i in range(self.num_directions)], dim=1)
        emb1 = emb1[unsort_to_original]

        return self.bn(emb1)

In [108]:
class Discriminator(nn.Module):
    """MLP language classifier over sentence encodings; returns log-probabilities over n_langs.

    Architecture: ``dis_layers`` blocks of Linear -> LeakyReLU(0.28) -> Dropout, then a final
    Linear to ``n_langs`` logits.

    Args:
        n_langs: number of language classes.
        dis_layers: number of hidden blocks.
        dis_hidden_dim: hidden width.
        dis_dropout: dropout probability after each hidden block.
        input_dim: encoding width. Defaults to the original ``config.hidden_dim * n_langs``.
            NOTE(review): that default equals the bidirectional encoder width (2 * hidden_dim)
            only because n_langs == 2, so pass it explicitly.
    """

    def __init__(self, n_langs, dis_layers, dis_hidden_dim, dis_dropout, input_dim=None):

        super(Discriminator, self).__init__()
        self.n_langs = n_langs
        self.input_dim = config.hidden_dim * self.n_langs if input_dim is None else input_dim
        self.dis_layers = dis_layers
        self.dis_hidden_dim = dis_hidden_dim
        self.dis_dropout = dis_dropout

        layers = []
        for i in range(self.dis_layers + 1):
            is_last = i == self.dis_layers
            layer_in = self.input_dim if i == 0 else self.dis_hidden_dim
            layer_out = self.n_langs if is_last else self.dis_hidden_dim
            layers.append(nn.Linear(layer_in, layer_out))
            if not is_last:
                layers.append(nn.LeakyReLU(0.28))
                layers.append(nn.Dropout(self.dis_dropout))
        self.layers = nn.Sequential(*layers)

    def forward(self, input):
        """Map (batch, input_dim) encodings to (batch, n_langs) log-probabilities."""
        return F.log_softmax(self.layers(input), 1)

In [109]:
# src: always English
def train(LSTM_src, LSTM_trg, discriminator, loader, contrastive_loader, optimizer, dis_optim, epoch):
    LSTM_src.train()
    LSTM_trg.train()
    discriminator.train()
    total_loss = 0
    for batch_idx, ([src_data, src_mask, src_len, trg_data, trg_mask, trg_len],
                    [src_c, src_mc, src_len_c, trg_c, trg_mc, trg_len_c]) in \
        enumerate(zip(loader, contrastive_loader)):
        
        src_data, src_mask = src_data.to(device), src_mask.to(device)
        trg_data, trg_mask = trg_data.to(device), trg_mask.to(device)
        src_c, src_mc = src_c.to(device), src_mc.to(device)
        trg_c, trg_mc = trg_c.to(device), trg_mc.to(device)
        optimizer.zero_grad()
        dis_optim.zero_grad()
        if np.random.random() <= 0.02:
            src_data = src_data + torch.rand(src_data.size()).long().to(device)
            trg_data = trg_data + torch.rand(trg_data.size()).long().to(device)
#             src_c = src_c + torch.rand(src_c.size()).long().to(device)
#             trg_c = trg_c + torch.rand(trg_c.size()).long().to(device)
            
        src_out = LSTM_src(src_data, src_mask, src_len)
        trg_out = LSTM_trg(trg_data, trg_mask, trg_len)
        src_c_out = LSTM_src(src_c, src_mc, src_len_c)
        trg_c_out = LSTM_trg(trg_c, trg_mc, trg_len_c)
        loss = loss_align(src_out, trg_out, src_c_out, trg_c_out, 0.25)
        loss.cuda().backward(retain_graph=True)
        # step
        optimizer.step()
        total_loss += loss.item() * len(src_data) / len(loader.dataset)
        if (batch_idx+1) % (len(loader.dataset)//(50*config.batch_size)) == 0:
            
            dis_labels_src = torch.zeros(config.batch_size).long()
            dis_labels_trg = torch.ones(config.batch_size).long()
            dis_labels = torch.cat([dis_labels_src, dis_labels_trg], 0)
            idx = torch.randperm(config.batch_size * 2)
            dis_input = torch.cat([src_out, trg_out], 0)
            dis_input = dis_input[idx]
            dis_labels = dis_labels[idx].to(device)
            dis_out = discriminator(dis_input)
            dis_criterion = nn.NLLLoss()
            dis_loss = dis_criterion(dis_out, dis_labels)
            dis_loss.cuda().backward(retain_graph=True)
            dis_optim.step()
            
            loss += (-1) * dis_criterion(dis_out, dis_labels)
            loss.cuda().backward()
            optimizer.step()
            
            torch.save(LSTM_trg.state_dict(), "LSTM_en_{}_{}_epoch_{}".format(config.experiment_lang,
                                                                      config.experiment_lang.upper(),
                                                                      epoch))
            
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.4f}'.format(
                epoch, (batch_idx+1) * config.batch_size, len(loader.dataset),
                100. * (batch_idx+1) / len(loader), loss.item()))
            
    optimizer.zero_grad()
    return total_loss

In [110]:
load_epoch = 3
encoder_ckpt = "best_encoder_eng_mnli_{}_{}".format(load_epoch, config.experiment_lang)

# Both encoders start from the same checkpoint; the English (source) encoder is frozen.
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
LSTM_src_model = biLSTM(hidden_size=config.hidden_dim, embedding_weights=weights_init ,num_layers=1, percent_dropout = config.dropout, 
             vocab_size=weights_init.shape[0], interaction_type="concat", input_size=300).to(device)
LSTM_src_model.load_state_dict(torch.load(encoder_ckpt, map_location=device))
for param in LSTM_src_model.parameters():
    param.requires_grad = False

LSTM_trg_model = biLSTM(hidden_size=config.hidden_dim, embedding_weights=weights_init ,num_layers=1, percent_dropout = config.dropout, 
             vocab_size=weights_init.shape[0], interaction_type="concat", input_size=300).to(device)
LSTM_trg_model.load_state_dict(torch.load(encoder_ckpt, map_location=device))

disc = Discriminator(n_langs = 2, dis_layers = 5, dis_hidden_dim = 128, dis_dropout = 0.1).to(device)

print ("Encoder src:\n", LSTM_src_model)
print ("Encoder trg:\n", LSTM_trg_model)
print ("Discriminator:\n", disc)

# Create the optimizers once so Adam's moment estimates carry over between epochs (they
# used to be re-created, i.e. reset, every epoch). The encoder optimizer holds only
# trainable encoder parameters; the discriminator is updated only by dis_optimizer.
enc_optimizer = torch.optim.Adam(
    [p for p in [*LSTM_src_model.parameters(), *LSTM_trg_model.parameters()] if p.requires_grad],
    lr=config.lr)
dis_optimizer = torch.optim.Adam([*disc.parameters()], lr=config.lr)

for epoch in range(config.epochs):
    print ("\nepoch = "+str(epoch))

    loss_train = train(LSTM_src=LSTM_src_model, LSTM_trg=LSTM_trg_model, discriminator = disc,
                       loader=align_loader, contrastive_loader=c_align_loader,
                       optimizer = enc_optimizer,
                       dis_optim = dis_optimizer,
                       epoch = epoch)

    torch.save(LSTM_trg_model.state_dict(), "LSTM_en_{}_{}_epoch_{}".format(config.experiment_lang,
                                                                      config.experiment_lang.upper(),
                                                                      epoch))

Encoder src:
 biLSTM(
  (embedding): Embedding(420002, 300)
  (drop_out): Dropout(p=0.1)
  (LSTM): LSTM(300, 512, batch_first=True, bidirectional=True)
  (bn): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
Encoder trg:
 biLSTM(
  (embedding): Embedding(420002, 300)
  (drop_out): Dropout(p=0.1)
  (LSTM): LSTM(300, 512, batch_first=True, bidirectional=True)
  (bn): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
Discriminator:
 Discriminator(
  (layers): Sequential(
    (0): Linear(in_features=1024, out_features=128, bias=True)
    (1): LeakyReLU(negative_slope=0.28)
    (2): Dropout(p=0.1)
    (3): Linear(in_features=128, out_features=128, bias=True)
    (4): LeakyReLU(negative_slope=0.28)
    (5): Dropout(p=0.1)
    (6): Linear(in_features=128, out_features=128, bias=True)
    (7): LeakyReLU(negative_slope=0.28)
    (8): Dropout(p=0.1)
    (9): Linear(in_features=128, out_features=128, bias=True)
    (10): LeakyR

RuntimeError: The size of tensor a (43) must match the size of tensor b (64) at non-singleton dimension 0

In [None]:
# Open question: should the adversarial term subtract the discriminator's NLL loss (loss - NLL)?

In [None]:
# Next step: try a new kind of contrastive (negative) example.