In [1]:
import argparse
import ast
import copy
import math
import pickle
import sys
from random import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Compute device shared by the cells below: device ordinal 0 when CUDA is
# available, otherwise the CPU.
device = torch.device(0 if torch.cuda.is_available() else "cpu")

In [2]:
class Tree(object):
    """Generic n-ary tree node with a parent link and cached size/depth.

    ``size()`` and ``depth()`` memoize their result on the node. The caches
    are created lazily and are cleared on this node and all of its ancestors
    whenever ``add_child`` changes the structure.
    """

    def __init__(self, idx):
        self.parent = None        # parent node; set by the parent's add_child()
        self.num_children = 0     # kept equal to len(self.children) by add_child()
        self.children = list()    # child nodes in insertion order
        self.index = idx          # caller-supplied identifier for this node
        self.state = None         # free slot; never read or written by this class

    def add_child(self, child):
        """Append ``child`` as the last child of this node and set its parent."""
        child.parent = self
        self.num_children += 1
        self.children.append(child)
        # The subtree below every ancestor just changed, so any memoized
        # size/depth on the path to the root is no longer valid.
        node = self
        while node is not None:
            node._size = None
            node._depth = None
            node = node.parent

    def size(self):
        """Return the number of nodes in the subtree rooted here (self included)."""
        # A default is required: the cache attribute does not exist until the
        # first call, and a bare getattr() would raise AttributeError.
        cached = getattr(self, '_size', None)
        if cached is not None:
            return cached
        count = 1
        for child in self.children:
            count += child.size()
        self._size = count
        return self._size

    def depth(self):
        """Return the height of the subtree rooted here (0 for a leaf)."""
        # Compare against None rather than truthiness so that a cached depth
        # of 0 (a leaf) is honoured instead of being recomputed.
        cached = getattr(self, '_depth', None)
        if cached is not None:
            return cached
        count = 0
        if self.num_children > 0:
            count = max(child.depth() for child in self.children) + 1
        self._depth = count
        return self._depth

In [3]:
# Read a train/dev/test data file and return (data, trends), where data is a
# list of (label, origsenttag, candsenttag, trendid) tuples.
def readInData(filename):
    """Parse a tab-separated data file into labelled tagged-sentence pairs.

    Lines with 7 tab-separated columns are read as
    ``trendid, trendname, origsent, candsent, judge, origsenttag, candsenttag``;
    lines with 6 columns are the same without ``judge`` (unlabelled data).
    Lines with any other column count are skipped.

    Args:
        filename: path of the file to read.

    Returns:
        ``(data, trends)`` where ``data`` is a list of
        ``(label, origsenttag, candsenttag, trendid)`` tuples in file order and
        ``trends`` is the set of every ``trendid`` seen on a 6- or 7-column
        line. ``label`` is ``None`` for unlabelled rows, the first element of
        a ``"(a, b)"`` judgement, or the leading digit of a digit judgement.
        Rows whose judgement has any other form are left out of ``data`` (their
        ``trendid`` is still recorded). No filtering on the label value is done
        here.
    """
    data = []
    trends = set()

    # Context manager so the handle is released even if parsing raises.
    with open(filename) as f:
        for line in f:
            fields = line.strip().split('\t')
            # read in training or dev data with labels
            if len(fields) == 7:
                (trendid, trendname, origsent, candsent, judge, origsenttag, candsenttag) = fields
            # read in test data without labels
            elif len(fields) == 6:
                (trendid, trendname, origsent, candsent, origsenttag, candsenttag) = fields
                # Reset explicitly: otherwise the label of an earlier labelled
                # row would be reused for this unlabelled one.
                judge = None
            else:
                continue

            trends.add(trendid)

            if judge is None:
                data.append((judge, origsenttag, candsenttag, trendid))
                continue

            # judge[:1] instead of judge[0] so an empty column is skipped
            # rather than raising IndexError.
            first = judge[:1]
            if first == '(':  # labelled by Amazon Mechanical Turk in format like "(2,3)"
                # literal_eval only accepts Python literals, so file content
                # can no longer execute arbitrary code the way eval() allowed.
                nYes = ast.literal_eval(judge)[0]
                data.append((nYes, origsenttag, candsenttag, trendid))
            elif first.isdigit():   # labelled by expert in format like "2"
                nYes = int(first)
                data.append((nYes, origsenttag, candsenttag, trendid))

    return data, trends

In [4]:
def generate_dict(embedding_path, d_model):
    """Load a whitespace-separated text embedding file.

    Every line is expected to be ``<word> <v1> <v2> ...``. A line is kept only
    when all values after the first token parse as floats and there are
    exactly ``d_model`` of them; every other line is skipped silently.

    Args:
        embedding_path: path of the UTF-8 embedding text file.
        d_model: required number of values per vector.

    Returns:
        ``(d, embedding)`` where ``d`` maps the lower-cased word to its row
        index (starting at 1) and ``embedding`` is an ``nn.Embedding`` built
        with ``from_pretrained`` whose row 0 is an all-zero padding vector
        (``padding_idx=0``). If two words lower-case to the same key, the
        later line's index wins in ``d``.
    """
    d = {}
    # Row 0 is the zero vector used for padding / unknown words. Floats (not
    # ints) so the table stays floating-point even if no vector is accepted.
    embedding_list = [[0.0] * d_model]
    idx = 1
    with open(embedding_path, 'r', encoding='utf-8') as f:
        for line in f:
            k = line.split()
            if not k:
                continue
            try:
                a = [float(w) for w in k[1:]]
            except ValueError:
                # Non-numeric field (e.g. a word containing whitespace):
                # skip just this line instead of hiding every kind of error.
                continue
            if len(a) == d_model:
                d[k[0].lower()] = idx
                idx += 1
                embedding_list.append(a)

    embedding = nn.Embedding.from_pretrained(torch.tensor(embedding_list), padding_idx=0)

    print('Reading embedding finished.')

    return d, embedding

In [5]:
def padding(x, max_len=10000):
    """Right-pad each sequence of ``x`` with zeros up to ``max_len``.

    The outer list is modified in place (each slot is rebound to a new, padded
    list) and the same list object is returned. Sequences already longer than
    ``max_len`` are left at their original length; nothing is truncated.
    """
    for pos, seq in enumerate(x):
        missing = max_len - len(seq)
        # [0] * n is empty for n <= 0, so over-long sequences pass through.
        x[pos] = seq + [0] * missing
    return x

In [6]:
def generate_tree(d, sentence):
    """Turn one tagged sentence into word ids and a shallow grouping tree.

    The sentence is split on whitespace; each token is lower-cased and split
    on '/'. Field 0 is the word. The first character of field 1 is used as a
    begin/inside marker ('b' / 'i'); when it is neither, the first character
    of field 3 is used instead. (Presumably these are NER and chunk tags in a
    B-/I-/O scheme; confirm against the data format.)

    * 'b' opens a new group node (index -1) containing the word.
    * 'i' adds the word to the most recent top-level node; if there is none
      yet, it opens a new group instead of failing.
    * anything else (including a missing or empty tag field) makes the word a
      top-level leaf.

    NOTE(review): a word that itself contains '/' shifts the tag fields and
    will be mis-parsed; the input format is not visible here, so this is left
    as is.

    Args:
        d: mapping from lower-cased word to embedding row index.
        sentence: whitespace-separated string of '/'-joined token fields.

    Returns:
        ``(ids, root)`` where ``ids`` holds one index per token (0 for words
        missing from ``d``) and ``root`` is a ``Tree(-1)`` whose children are
        the groups / leaves in sentence order. Word nodes carry the token
        position as their ``index``.
    """
    ids = []
    top_level = []   # groups and stand-alone leaves, in sentence order

    for i, token in enumerate(sentence.strip().split()):
        fields = token.lower().split('/')
        word = fields[0]
        # Slicing with [:1] yields '' for a missing or empty field, which falls
        # through to the "plain leaf" case instead of raising IndexError.
        primary = fields[1][:1] if len(fields) > 1 else ''
        fallback = fields[3][:1] if len(fields) > 3 else ''
        marker = primary if primary in ('b', 'i') else fallback

        node = Tree(i)
        if marker == 'b' or (marker == 'i' and not top_level):
            # An 'i' with nothing before it is treated like a 'b'.
            group = Tree(-1)
            group.add_child(node)
            top_level.append(group)
        elif marker == 'i':
            top_level[-1].add_child(node)
        else:
            top_level.append(node)

        ids.append(d.get(word, 0))

    root = Tree(-1)
    for child in top_level:
        root.add_child(child)

    return ids, root

In [7]:
def preprocessing(embedding_path, input_path, testing=False, d_model=200, max_len=None):
    """Load embeddings and a data file, and build padded, embedded sentence pairs.

    Args:
        embedding_path: text embedding file, passed to ``generate_dict``.
        input_path: data file, passed to ``readInData``.
        testing: when true, every label is replaced by ``-1``; otherwise the
            label read from the file is used.
        d_model: embedding width expected in ``embedding_path``.
        max_len: length every sentence is padded to. ``None`` means "use the
            longest sentence found in the data".

    Returns:
        ``(x0, x1, y, x0_r, x1_r)``: ``x0`` / ``x1`` are CPU tensors of shape
        ``(num_pairs, max_len, embedding_dim)`` for the first / second
        sentence of each pair, ``y`` is a float tensor of labels, and
        ``x0_r`` / ``x1_r`` are the lists of tree roots returned by
        ``generate_tree``.

    Raises:
        ValueError: if an explicit ``max_len`` is shorter than the longest
            sentence (padding never truncates, so the batch could not be
            stacked into a rectangular tensor of that width).
    """
    d, embedding = generate_dict(embedding_path, d_model)
    # readInData returns (rows, trend ids); only the rows are needed here.
    rows, _ = readInData(input_path)

    x0, x0_r = [], []
    x1, x1_r = [], []
    y = []
    for row in rows:
        ids, root = generate_tree(d, row[1])
        x0.append(ids)
        x0_r.append(root)
        ids, root = generate_tree(d, row[2])
        x1.append(ids)
        x1_r.append(root)
        # The two modes differ only in the label that is stored.
        y.append(-1 if testing else row[0])

    longest = max((len(seq) for seq in x0 + x1), default=0)
    if max_len is None:
        max_len = longest
    elif longest > max_len:
        raise ValueError(
            "max_len=%d is shorter than the longest sentence (%d tokens)"
            % (max_len, longest)
        )
    print("max length is: ", max_len)
    embedding = embedding.to(device)

    def _embed(seqs):
        # Explicit integer dtype and shape keep an empty dataset valid: a bare
        # torch.tensor([]) would be a 1-D float tensor the embedding rejects.
        indices = torch.tensor(padding(seqs, max_len=max_len), dtype=torch.long)
        indices = indices.reshape(len(seqs), max_len)
        return embedding(indices.to(device))

    x0 = _embed(x0)
    x1 = _embed(x1)

    return x0.cpu(), x1.cpu(), torch.tensor(y, dtype=torch.float), x0_r, x1_r

In [14]:
# Data & embedding configurations
# d_model is the embedding width; it also selects which GloVe file is loaded
# and is embedded in the output file name.
d_model = 50
PRE_TRAINED_EMBEDDING_PATH = '../embedding/glove.twitter.27B.'+str(d_model)+'d.txt'
# Input data file handed to preprocessing().
DATA_PATH = '../data/test.data'
# Prefix of the pickle file written by the last cell ("_<d_model>d.pkl" is appended).
OUTPUT_PATH = '../data/test_data_tree'

In [15]:
# Build the embedded sentence pairs, labels and trees for DATA_PATH.
# testing=False keeps the labels read from the file instead of replacing them with -1.
# NOTE(review): max_len=18 is hard-coded; presumably it must match the padded
# length used for the other data splits. Confirm before changing.
x0, x1, Y, x0_r, x1_r = preprocessing(PRE_TRAINED_EMBEDDING_PATH, DATA_PATH, testing=False, d_model=d_model, max_len=18)

Reading embedding finished.
max length is:  18


In [16]:
# Write the five results to one pickle file as consecutive records. A reader
# must call pickle.load five times in this same order: x0, x1, Y, x0_r, x1_r.
# The with-block closes the file even if one of the dumps raises.
with open(OUTPUT_PATH+"_"+str(d_model)+"d.pkl", "wb") as f:
    pickle.dump(x0, f)
    pickle.dump(x1, f)
    pickle.dump(Y, f)
    pickle.dump(x0_r, f)
    pickle.dump(x1_r, f)