In [1]:
import argparse
import ast
import copy
import math
import pickle
import sys
from random import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords

# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# Read a train/dev/test data file; returns (data, trends), where each item of data is a (label, original_sent, cand_sent, trendid) tuple.
def readInData(filename):
    """Read a tab-separated sentence-pair file (train, dev or test).

    Lines with 7 tab-separated fields are labelled:
        trendid, trendname, origsent, candsent, judge, origsenttag, candsenttag
    Lines with 6 fields are unlabelled (no ``judge`` column). Any other line is
    skipped.

    Label rules for ``judge``:
      * "(nYes, nNo)" (Amazon Mechanical Turk format): nYes >= 3 -> True,
        nYes <= 1 -> False; the middle value (nYes == 2) is dropped entirely.
      * leading digit (expert format, e.g. "2"): >= 4 -> True, <= 2 -> False,
        otherwise the pair is kept with label None.
      * unlabelled lines (and an empty judge field) are kept with label None.

    Args:
        filename: path to a UTF-8 encoded data file.

    Returns:
        (data, trends): ``data`` is a list of (label, origsent, candsent,
        trendid) tuples; ``trends`` is the set of every trendid seen on a
        6- or 7-field line (including pairs whose label was dropped).
    """
    data = []
    trends = set()

    with open(filename, encoding='utf-8') as f:
        for line in f:
            fields = line.strip().split('\t')
            if len(fields) == 7:
                # labelled train/dev line
                trendid, _trendname, origsent, candsent, judge, _origtag, _candtag = fields
            elif len(fields) == 6:
                # unlabelled test line; reset judge so a previous line's label cannot leak in
                trendid, _trendname, origsent, candsent, _origtag, _candtag = fields
                judge = None
            else:
                continue

            trends.add(trendid)

            if not judge:
                data.append((None, origsent, candsent, trendid))
                continue

            if judge[0] == '(':
                # AMT label such as "(2,3)"; literal_eval parses the tuple
                # without executing arbitrary code from the data file
                nYes = ast.literal_eval(judge)[0]
                if nYes >= 3:
                    data.append((True, origsent, candsent, trendid))
                elif nYes <= 1:
                    data.append((False, origsent, candsent, trendid))
                # the middle label is ignored
            elif judge[0].isdigit():
                # expert label such as "2"; only the first character is used
                nYes = int(judge[0])
                if nYes >= 4:
                    expert_label = True
                elif nYes <= 2:
                    expert_label = False
                else:
                    expert_label = None
                data.append((expert_label, origsent, candsent, trendid))

    return data, trends


In [3]:
def generate_dict(embedding_path, d_model):
    """Load a GloVe-style text embedding file into a vocabulary and an nn.Embedding.

    Each line is expected to look like ``<word> <v1> ... <v_d_model>``. Lines
    whose vector part cannot be parsed as floats, or does not have exactly
    ``d_model`` components, are skipped. Row 0 of the embedding matrix is an
    all-zero vector used for padding / unknown words, so word indices start at 1.

    Args:
        embedding_path: path to the UTF-8 text embedding file.
        d_model: expected embedding dimension.

    Returns:
        (d, embedding): ``d`` maps lower-cased word -> row index (>= 1);
        ``embedding`` is ``nn.Embedding.from_pretrained(..., padding_idx=0)``
        built from a float32 matrix.
    """
    d = {}
    vectors = [[0.0] * d_model]  # row 0: padding / out-of-vocabulary vector
    with open(embedding_path, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.split()
            if not parts:
                continue
            try:
                vec = [float(w) for w in parts[1:]]
            except ValueError:
                # malformed line (e.g. a token that itself contains whitespace)
                continue
            if len(vec) == d_model:
                # a repeated (lower-cased) word keeps its latest row index
                d[parts[0].lower()] = len(vectors)
                vectors.append(vec)

    # explicit float dtype so an empty vocabulary does not yield an int64 matrix
    embedding = nn.Embedding.from_pretrained(
        torch.tensor(vectors, dtype=torch.float), padding_idx=0)

    print('Reading embedding finished.')

    return d, embedding

In [4]:
def padding(x, max_len=10000):
    """Right-pad (and truncate) index sequences to exactly ``max_len`` items.

    Shorter sequences are padded with 0 (the embedding's padding index).
    Longer sequences are truncated, so the result is always rectangular and
    can be passed straight to ``torch.tensor``.

    Args:
        x: list of lists of integer token indices. It is not modified.
        max_len: target length of every row.

    Returns:
        A new list of lists, each of length ``max_len``.
    """
    padded = []
    for seq in x:
        row = list(seq[:max_len])
        row.extend([0] * (max_len - len(row)))
        padded.append(row)
    return padded

In [5]:
def get_index(d, sentence):
    """Convert a whitespace-tokenised sentence into a list of vocabulary indices.

    Each token is lower-cased before lookup in ``d``; tokens not present in
    ``d`` map to 0.
    """
    return [d.get(token.lower(), 0) for token in sentence.strip().split()]

In [10]:
def preprocessing(embedding_path, input_path, testing=False, d_model=200, max_len=None):
    """Turn a sentence-pair data file into padded, embedded tensors.

    Args:
        embedding_path: GloVe-style text embedding file, loaded via ``generate_dict``.
        input_path: tab-separated data file, read via ``readInData``.
        testing: if True, keep every pair and use label -1. Otherwise keep only
            pairs whose label is True (encoded 0) or False (encoded 1).
        d_model: embedding dimension.
        max_len: sequence length after padding. Longer sentences are truncated.
            Defaults to the longest tokenised sentence among the kept pairs.

    Returns:
        (x0, x1, y, embedding):
          * ``x0`` / ``x1``: CPU tensors of shape (num_pairs, max_len, d_model)
            for the original / candidate sentences.
          * ``y``: float tensor of labels.
          * ``embedding``: the nn.Embedding, moved back to CPU.
    """
    d, embedding = generate_dict(embedding_path, d_model)
    x0, x1, y = [], [], []
    samples, _ = readInData(input_path)

    for label, orig_sent, cand_sent, _trendid in samples:
        if testing:
            target = -1
        elif label is True:
            # NOTE(review): True -> 0 / False -> 1 looks like a distance-style
            # target; confirm against the training loss
            target = 0
        elif label is False:
            target = 1
        else:
            continue  # unlabelled / middle-label pairs are unused for training
        x0.append(get_index(d, orig_sent))
        x1.append(get_index(d, cand_sent))
        y.append(target)

    if max_len is None:
        max_len = max((len(seq) for seq in x0 + x1), default=0)
    print("max length is: ", max_len)

    embedding = embedding.to(device)

    def _embed(seqs):
        # Truncate first so the index matrix is rectangular even for sentences
        # longer than max_len. Use long dtype and an explicit shape so an empty
        # dataset still yields (0, max_len) indices instead of a float tensor.
        rows = padding([seq[:max_len] for seq in seqs], max_len=max_len)
        idx = torch.tensor(rows, dtype=torch.long).reshape(len(rows), max_len)
        return embedding(idx.to(device))

    x0 = _embed(x0)
    x1 = _embed(x1)

    return x0.cpu(), x1.cpu(), torch.tensor(y, dtype=torch.float), embedding.cpu()

In [23]:
MODEL_SAVE_PATH = '../tmp/attention_model'

# Data & embedding configuration
d_model = 200  # embedding dimension; selects the matching GloVe file below
PRE_TRAINED_EMBEDDING_PATH = '../embedding/glove.twitter.27B.{}d.txt'.format(d_model)
DATA_PATH = '../data/dev.data'
OUTPUT_PATH = '../data/dev_data'  # prefix for the pickle written at the end of the notebook

In [24]:
# Build the labelled training tensors; the sequence length is fixed to 18
# rather than inferred from the data.
x0, x1, Y, emb = preprocessing(
    PRE_TRAINED_EMBEDDING_PATH,
    DATA_PATH,
    testing=False,
    d_model=d_model,
    max_len=18,
)

Reading embedding finished.
max length is:  18


In [25]:
# Sanity check: shapes of the three tensors, then the first five token
# vectors of the first original sentence.
for t in (x0, x1, Y):
    print(t.size())
print(x0[0, :5])

torch.Size([4142, 18, 200])
torch.Size([4142, 18, 200])
torch.Size([4142])
tensor([[ 1.4931e-01,  2.7889e-01,  8.9979e-02,  4.0882e-01, -2.1328e-01,
          1.5406e-01, -2.5642e-02, -6.4515e-01, -7.1643e-01, -1.1794e-01,
         -2.9600e-01, -4.3363e-01, -2.1885e-01,  3.2778e-02,  1.5606e-01,
          2.2966e-02, -5.3795e-02,  3.3622e-01, -6.2113e-01,  1.0144e-01,
          2.3716e-01, -5.1758e-02,  2.9100e-01, -4.3310e-01,  5.1603e-01,
         -1.9666e+00,  2.0311e-01,  6.6447e-02,  1.5362e-01,  6.4771e-01,
         -3.8559e-01,  4.7402e-03, -5.2268e-02, -1.0286e-01,  6.7909e-03,
          5.1034e-01, -1.9149e-01, -1.0676e-01, -9.3639e-01,  2.3279e-01,
         -6.8884e-01,  4.6741e-02,  1.0391e-01,  1.7044e-01,  5.3320e-01,
         -1.6093e-01,  9.8364e-02,  3.6096e-01,  7.6576e-02,  4.0381e-01,
         -2.1510e-02,  6.4061e-02, -3.2644e-01, -1.5550e-01, -1.4447e-02,
          5.5337e-01,  2.5903e-01,  1.0481e-01,  3.1606e-01,  2.1116e-01,
          3.0245e-01, -1.8877e-01, -5

In [26]:
# Persist x0, x1 and Y as three consecutive pickles in one file. Read them back
# with three pickle.load calls in the same order. The embedding module is not
# saved. The `with` block guarantees the file is closed even if a dump fails.
with open(OUTPUT_PATH + "_" + str(d_model) + "d.pkl", "wb") as f:
    for obj in (x0, x1, Y):
        pickle.dump(obj, f)