In [12]:
import argparse
import ast
import copy
import math
import pickle
import sys
from random import random

import nltk
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize

# Use CUDA device 0 when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device(0)
else:
    device = torch.device("cpu")

In [2]:
def readInData(filename):
    """Read a tab-separated train/dev/test file into labelled sentence pairs.

    Each line is stripped and split on tabs:

    * 7 fields (labelled data):
      trendid, trendname, origsent, candsent, judge, origsenttag, candsenttag
    * 6 fields (unlabelled test data): the same columns without ``judge``
    * any other field count: the line is skipped

    Label rules, applied to ``judge``:

    * ``"(yes, no)"`` (Amazon Mechanical Turk votes): ``True`` when
      yes >= 3, ``False`` when yes <= 1; the middle case is dropped.
    * a value starting with a digit (expert score): ``True`` when the first
      digit is >= 4, ``False`` when it is <= 2, otherwise the row is kept
      with label ``None``.
    * any other non-empty or empty judge value: the row is dropped.
    * unlabelled (6-field) rows are kept with label ``None``.

    Args:
        filename: path of the data file (read as UTF-8).

    Returns:
        ``(data, trends)`` where ``data`` is a list of
        ``(label, origsent, candsent, trendid)`` tuples and ``trends`` is the
        set of every trend id seen on a well-formed line (including lines
        whose row was dropped).
    """
    data = []
    trends = set()

    # Explicit encoding so the result does not depend on the platform locale;
    # the context manager guarantees the handle is closed.
    with open(filename, encoding='utf-8') as infile:
        for line in infile:
            fields = line.strip().split('\t')

            if len(fields) == 7:
                # labelled training or dev data
                trendid, _trendname, origsent, candsent, judge, _origtag, _candtag = fields
            elif len(fields) == 6:
                # test data without labels; reset judge so a label from a
                # previous 7-field line cannot leak into this row
                trendid, _trendname, origsent, candsent, _origtag, _candtag = fields
                judge = None
            else:
                continue

            trends.add(trendid)

            if judge is None:
                data.append((None, origsent, candsent, trendid))
            elif judge.startswith('('):
                # Mechanical Turk votes such as "(2, 3)". literal_eval only
                # accepts Python literals, so file content is never executed.
                nYes = ast.literal_eval(judge)[0]
                if nYes >= 3:
                    data.append((True, origsent, candsent, trendid))
                elif nYes <= 1:
                    data.append((False, origsent, candsent, trendid))
                # nYes == 2 is the debatable middle label: row is dropped
            elif judge[:1].isdigit():
                # Expert score such as "2"; only the first character is used.
                nYes = int(judge[0])
                if nYes >= 4:
                    expert_label = True
                elif nYes <= 2:
                    expert_label = False
                else:
                    expert_label = None
                data.append((expert_label, origsent, candsent, trendid))

    return data, trends


In [3]:
def generate_dict(embedding_path, d_model):
    """Load a whitespace-separated embedding text file.

    Each line is expected to be ``token v1 v2 ... v_d``. A line is used only
    when everything after the first token parses as a float and there are
    exactly ``d_model`` values; all other lines are skipped.

    Args:
        embedding_path: path of the embedding text file (read as UTF-8).
        d_model: required number of values per embedding vector.

    Returns:
        ``(d, embedding)`` where ``d`` maps the lower-cased token to its row
        index (starting at 1; a token repeated after lower-casing keeps the
        index of its last occurrence) and ``embedding`` is an
        ``nn.Embedding`` built from the vectors, with an all-zero row at
        index 0 used as ``padding_idx``.
    """
    d = {}
    # Row 0 is the all-zero padding vector, so real tokens start at index 1.
    # Floats keep the weight matrix float32 even when no vector is accepted.
    embedding_list = [[0.0] * d_model]

    with open(embedding_path, 'r', encoding='utf-8') as f:
        for line in f:
            k = line.split()
            try:
                a = [float(w) for w in k[1:]]
            except ValueError:
                # Non-numeric field (e.g. a token containing whitespace):
                # skip this line only. Other exceptions, including
                # KeyboardInterrupt, are allowed to propagate.
                continue
            if k and len(a) == d_model:
                d[k[0].lower()] = len(embedding_list)
                embedding_list.append(a)

    embedding = nn.Embedding.from_pretrained(
        torch.tensor(embedding_list, dtype=torch.float), padding_idx=0)

    print('Reading embedding finished.')

    return d, embedding

In [4]:
def padding(x, max_len=10000):
    """Right-pad every sequence in ``x`` with zeros up to ``max_len`` items.

    The outer list ``x`` is updated in place (each entry is replaced by a new,
    padded list) and is also returned. Sequences that are already ``max_len``
    or longer are left at their original length; nothing is truncated.
    """
    for pos, seq in enumerate(x):
        shortfall = max_len - len(seq)
        # A non-positive shortfall yields an empty filler list.
        x[pos] = seq + [0] * shortfall
    return x

In [5]:
def get_index(d, sentence):
    """Convert a sentence into a list of vocabulary indices.

    The sentence is stripped, split on whitespace and each token is
    lower-cased before being looked up in ``d``. Tokens that are not keys of
    ``d`` are mapped to 0.
    """
    tokens = sentence.strip().split()
    return [d.get(token.lower(), 0) for token in tokens]

In [13]:
def preprocessing(embedding_path, input_path, testing=False, d_model=200):
    """Turn a data file into padded, embedded sentence-pair tensors.

    Steps: load the embedding table with ``generate_dict``, read the rows
    with ``readInData``, map both sentences of every kept row to index lists
    with ``get_index``, zero-pad all lists to the longest sentence, and look
    the indices up in the embedding.

    Args:
        embedding_path: path of the pre-trained embedding text file.
        input_path: path of the tab-separated data file.
        testing: when True every row is kept and its target is -1; when
            False only rows labelled True/False are kept (rows whose label
            is None are skipped).
        d_model: embedding dimensionality passed to ``generate_dict``.

    Returns:
        ``(x0, x1, y, embedding)``: ``x0``/``x1`` are CPU tensors of shape
        (rows, max_len, d_model) for the first/second sentence, ``y`` is a
        float tensor with 0 for label True, 1 for label False and -1 in
        testing mode, and ``embedding`` is the embedding module moved back
        to the CPU.
    """
    d, embedding = generate_dict(embedding_path, d_model)
    # readInData returns (rows, trend_ids); only the rows are needed here.
    rows, _ = readInData(input_path)

    x0 = []
    x1 = []
    y = []
    for label, origsent, candsent, _trendid in rows:
        if testing:
            target = -1
        elif label is True:
            target = 0
        elif label is False:
            target = 1
        else:
            # unlabelled / middle-label row: not usable for training
            continue
        x0.append(get_index(d, origsent))
        x1.append(get_index(d, candsent))
        y.append(target)

    max_len = max((len(seq) for seq in x0 + x1), default=0)
    print("max length is: ", max_len)

    embedding = embedding.to(device)

    def _embed(seqs):
        # Explicit long dtype and shape keep the lookup valid even when there
        # are no rows or every sentence is empty (torch.tensor would
        # otherwise infer a float tensor, which cannot be used as indices).
        indices = torch.tensor(padding(seqs, max_len=max_len), dtype=torch.long)
        return embedding(indices.reshape(len(seqs), max_len).to(device))

    x0 = _embed(x0)
    x1 = _embed(x1)

    return x0.cpu(), x1.cpu(), torch.tensor(y, dtype=torch.float), embedding.cpu()

In [27]:
# Data & embedding configuration
d_model = 200  # embedding dimensionality; also selects which GloVe file is read
DATA_PATH = '../data/train.data'
PRE_TRAINED_EMBEDDING_PATH = f'../embedding/glove.twitter.27B.{d_model}d.txt'

# Path prefix for the saved model
MODEL_SAVE_PATH = '../tmp/attention_model'

In [28]:
# Build the embedded sentence-pair tensors and targets from the training file.
x0, x1, Y, emb = preprocessing(
    embedding_path=PRE_TRAINED_EMBEDDING_PATH,
    input_path=DATA_PATH,
    testing=False,
    d_model=d_model,
)

Reading embedding finished.
max length is:  18


In [29]:
# Sanity check: the three tensor shapes, then the first five token vectors
# of the first sentence in x0.
print(x0.size(), x1.size(), Y.size(), sep="\n")
print(x0[0, :5])

torch.Size([11530, 18, 200])
torch.Size([11530, 18, 200])
torch.Size([11530])
tensor([[ 1.2212e-01,  3.3079e-01,  1.6658e-01, -3.7311e-01, -3.2807e-01,
         -5.2256e-01, -6.7980e-01,  2.9447e-01, -5.5401e-01,  5.1494e-01,
         -4.6707e-02, -3.5564e-01,  9.8064e-02, -3.6815e-02, -1.2640e-01,
         -3.9342e-01,  6.0168e-01, -3.4685e-01, -9.8971e-02,  1.4753e-01,
         -7.1833e-02,  3.2310e-01,  4.3638e-01,  1.7693e-01, -2.7088e-01,
         -1.1009e+00, -6.7499e-02,  3.9490e-02, -7.7714e-02,  1.0484e-01,
          4.9229e-01,  3.8817e-01,  4.7439e-02, -5.2111e-02, -2.9466e-01,
          3.1889e-01, -3.6786e-01, -1.5086e-01, -2.7244e-02,  2.4142e-01,
          1.8413e-01,  5.3505e-01,  2.9721e-01,  6.7245e-02, -4.7623e-01,
          2.4425e-01,  6.5088e-01,  3.2616e-01, -3.6000e-03, -1.0486e-01,
          1.5229e-01,  6.0477e-01,  9.4309e-02,  3.5175e-01, -1.8084e-01,
         -5.4886e-01, -8.2122e-03,  4.8639e-02, -1.4380e-01, -2.2617e-02,
          1.9567e-01,  1.4418e-01,

In [30]:
# Persist the tensors as three consecutive pickle records (x0, x1, Y) in one
# file; a reader has to call pickle.load three times in the same order.
# The context manager closes the file even if one of the dumps raises.
with open(f"{DATA_PATH}_{d_model}d.pkl", "wb") as f:
    pickle.dump(x0, f)
    pickle.dump(x1, f)
    pickle.dump(Y, f)