In [93]:
import argparse
import copy
import math
import os
import pickle
import random
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [94]:
class Tree(object):
    """A node of an n-ary tree whose structure is walked by ChildSumTreeLSTM.

    Attributes:
        parent: the parent ``Tree`` node, or ``None`` for a root.
        num_children: number of children (kept in sync with ``children``).
        children: child ``Tree`` nodes in insertion order.
        index: row of the model input used for this node; ``-1`` means the node
            has no input row (ChildSumTreeLSTM feeds zeros for it).
        state: ``(c, h)`` written by ChildSumTreeLSTM.forward; ``None`` until then.
    """

    def __init__(self, idx):
        self.parent = None
        self.num_children = 0
        self.children = list()
        self.index = idx
        self.state = None
        # memoized results of size() / depth(); None means "not computed yet"
        self._size = None
        self._depth = None

    def add_child(self, child):
        """Attach ``child`` below this node and invalidate affected caches."""
        child.parent = self
        self.num_children += 1
        self.children.append(child)
        self._invalidate_cache()

    def _invalidate_cache(self):
        """Forget memoized size/depth on this node and every ancestor."""
        node = self
        while node is not None:
            node._size = None
            node._depth = None
            node = node.parent

    def size(self):
        """Return the number of nodes in the subtree rooted here (memoized)."""
        # getattr default: nodes restored from pickles made by older code may
        # lack the cache attributes entirely.
        cached = getattr(self, '_size', None)
        if cached is None:
            cached = 1 + sum(child.size() for child in self.children)
            self._size = cached
        return cached

    def depth(self):
        """Return the height of the subtree rooted here; a leaf has depth 0 (memoized)."""
        # Compare against None, not truthiness, so a cached depth of 0 is reused.
        cached = getattr(self, '_depth', None)
        if cached is None:
            cached = max((child.depth() for child in self.children), default=-1) + 1
            self._depth = cached
        return cached

In [95]:
d_model = 50  # input feature dimension (passed to make_model); also selects the *_{d_model}d.pkl files


def load_tree_split(path):
    """Load one pickled data split and return ``(x0, x1, Y, x0_r, x1_r)``.

    The file holds five consecutive pickles, read in exactly that order.
    ``Y`` is returned cast to ``long`` so it can serve as class-index targets
    for ``nn.CrossEntropyLoss``.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only load data
    files produced by a trusted pipeline.
    """
    # `with` closes the file even if a load fails part-way through.
    with open(path, "rb") as fh:
        x0 = pickle.load(fh)
        x1 = pickle.load(fh)
        y = pickle.load(fh)
        x0_r = pickle.load(fh)
        x1_r = pickle.load(fh)
    return x0, x1, y.long(), x0_r, x1_r


TRAIN_DATA_PATH = '../data/train_data_tree_' + str(d_model) + 'd.pkl'
train_x0, train_x1, train_Y, train_x0_r, train_x1_r = load_tree_split(TRAIN_DATA_PATH)

VAL_DATA_PATH = '../data/dev_data_tree_' + str(d_model) + 'd.pkl'
val_x0, val_x1, val_Y, val_x0_r, val_x1_r = load_tree_split(VAL_DATA_PATH)

# sizes of dims 0 and 1 of the training inputs (dim 0 indexes training examples)
d_len = train_x0.size()[0]
s_len = train_x0.size()[1]

In [96]:
class ChildSumTreeLSTM(nn.Module):
    """Child-Sum Tree-LSTM cell applied recursively, bottom-up, over a tree.

    Each node combines its own input vector with the *sum* of its children's
    hidden states for the input/output/update gates, and computes a separate
    forget gate for every child.
    """

    def __init__(self, in_dim, mem_dim):
        super(ChildSumTreeLSTM, self).__init__()
        self.in_dim = in_dim
        self.mem_dim = mem_dim
        # input, output and update gates computed jointly, split in node_forward
        self.ioux = nn.Linear(self.in_dim, 3 * self.mem_dim)
        self.iouh = nn.Linear(self.mem_dim, 3 * self.mem_dim)
        # per-child forget gate
        self.fx = nn.Linear(self.in_dim, self.mem_dim)
        self.fh = nn.Linear(self.mem_dim, self.mem_dim)

    def node_forward(self, inputs, child_c, child_h):
        """Compute ``(c, h)`` for a single node.

        Args:
            inputs: this node's input vector of ``in_dim`` features
                (presumably 1-D; TODO confirm against the data pipeline).
            child_c, child_h: ``(num_children, mem_dim)`` stacks of the children's
                cell and hidden states (one zero row for a leaf).

        Returns:
            ``(c, h)`` for this node, each with a leading dimension of size 1.
        """
        child_h_sum = torch.sum(child_h, dim=0, keepdim=True)

        iou = self.ioux(inputs) + self.iouh(child_h_sum)
        i, o, u = torch.split(iou, iou.size(1) // 3, dim=1)
        i, o, u = torch.sigmoid(i), torch.sigmoid(o), torch.tanh(u)

        # One forget gate per child: the node's input term is repeated per child row.
        f = torch.sigmoid(
            self.fh(child_h) +
            self.fx(inputs).repeat(len(child_h), 1)
        )
        fc = torch.mul(f, child_c)

        c = torch.mul(i, u) + torch.sum(fc, dim=0, keepdim=True)
        h = torch.mul(o, torch.tanh(c))
        return c, h

    def forward(self, tree, inputs):
        """Encode ``tree`` bottom-up and return the root's ``(c, h)``.

        Every visited node's ``(c, h)`` is also stored on ``node.state``.

        Args:
            tree: root node with ``children``, ``num_children``, ``index`` and a
                writable ``state`` attribute.
            inputs: tensor indexed by ``node.index`` to obtain each node's input
                vector. Nodes whose index is ``-1`` receive a zero vector instead.
        """
        for child in tree.children:
            self.forward(child, inputs)

        if tree.num_children == 0:
            # Leaves see a single all-zero child. These zeros are constants (no
            # gradient needed); new_zeros matches inputs' device and dtype.
            child_c = inputs.new_zeros(1, self.mem_dim)
            child_h = inputs.new_zeros(1, self.mem_dim)
        else:
            child_c = torch.cat([child.state[0] for child in tree.children], dim=0)
            child_h = torch.cat([child.state[1] for child in tree.children], dim=0)

        if tree.index != -1:
            node_input = inputs[tree.index]
        else:
            # No input row for this node: zeros on inputs' device/dtype instead of
            # a float32 tensor moved to a global device.
            node_input = inputs.new_zeros(self.in_dim)
        tree.state = self.node_forward(node_input, child_c, child_h)
        return tree.state


# module for distance-angle similarity
class Similarity(nn.Module):
    """Classify a pair of vectors from their product and absolute difference.

    The element-wise product ("angle") and absolute difference ("distance")
    are concatenated, passed through a ReLU hidden layer of size ``hidden_dim``,
    and projected to ``num_classes`` unnormalised scores.
    """

    def __init__(self, mem_dim, hidden_dim, num_classes):
        super(Similarity, self).__init__()
        self.mem_dim, self.hidden_dim, self.num_classes = mem_dim, hidden_dim, num_classes
        # hidden layer over the concatenated [angle, distance] features
        self.wh = nn.Linear(2 * mem_dim, hidden_dim)
        # output projection to class scores
        self.wp = nn.Linear(hidden_dim, num_classes)

    def forward(self, lvec, rvec):
        """Return class scores for the pair ``(lvec, rvec)`` (features concatenated on dim 1)."""
        angle = lvec * rvec
        distance = (lvec - rvec).abs()
        features = torch.cat((angle, distance), 1)
        hidden = F.relu(self.wh(features))
        return self.wp(hidden)


# putting the whole model together
class SimilarityTreeLSTM(nn.Module):
    """Siamese tree encoder plus a ``Similarity`` classification head.

    Both trees of a pair are encoded by the same ``ChildSumTreeLSTM`` instance,
    so the two sides share weights.
    """

    def __init__(self, in_dim, mem_dim=256, hidden_dim=512, num_classes=2):
        super(SimilarityTreeLSTM, self).__init__()
        self.childsumtreelstm = ChildSumTreeLSTM(in_dim, mem_dim)
        self.similarity = Similarity(mem_dim, hidden_dim, num_classes)

    def forward(self, ltree, linputs, rtree, rinputs):
        """Return the ``Similarity`` scores for the tree pair (ltree, rtree)."""
        encoder = self.childsumtreelstm
        # The encoder returns (c, h); the root *cell* states c are compared.
        # NOTE(review): Tree-LSTM similarity models are often described as
        # comparing hidden states h; confirm c is intended here.
        left_c = encoder(ltree, linputs)[0]
        right_c = encoder(rtree, rinputs)[0]
        return self.similarity(left_c, right_c)

In [97]:
def make_model(d_model):
    """Build a ``SimilarityTreeLSTM`` over ``d_model``-dim inputs with Glorot init.

    Every parameter with more than one dimension (the Linear weight matrices)
    is re-initialised with Xavier/Glorot uniform (fan-average); 1-D parameters
    (biases) keep PyTorch's default initialisation.
    """
    model = SimilarityTreeLSTM(d_model)

    for p in model.parameters():
        if p.dim() > 1:
            # in-place initialiser; the non-underscore xavier_uniform is deprecated
            nn.init.xavier_uniform_(p)
    return model

In [98]:
def compute_acc(val_x0, val_x1, val_Y, val_x0_r, val_x1_r, model, device=None):
    """Return (and print) the model's classification accuracy on a dataset.

    Args:
        val_x0, val_x1: per-example inputs for the left/right trees, indexed by example.
        val_Y: gold class index for each example.
        val_x0_r, val_x1_r: left/right tree structures, indexed by example.
        model: module called as ``model(ltree, linputs, rtree, rinputs)`` that
            returns class scores with the example on dim 0.
        device: where inputs are moved before the forward pass; defaults to the
            device of the model's first parameter (CPU if it has none).

    Returns:
        Fraction of examples whose argmax prediction equals ``val_Y``; 0.0 for an
        empty dataset.
    """
    n_examples = len(val_x0)
    if n_examples == 0:
        print("current accuracy is :", 0.0)
        return 0.0
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    was_training = model.training
    model.eval()
    correct = 0
    try:
        # No autograd while evaluating: ChildSumTreeLSTM stores (c, h) on every
        # tree node, so graphs built here would otherwise stay alive on the trees.
        with torch.no_grad():
            for j in range(n_examples):
                scores = model(val_x0_r[j], val_x0[j].to(device),
                               val_x1_r[j], val_x1[j].to(device))
                # argmax of raw scores equals argmax of log_softmax (monotonic)
                pred = torch.argmax(scores, dim=1)[0].item()
                if pred == int(val_Y[j]):
                    correct += 1
    finally:
        model.train(was_training)  # restore the caller's train/eval mode

    acc = correct / n_examples
    print("current accuracy is :", acc)
    return acc

In [99]:
# Seed before building the model so weight init and the per-epoch shuffle are reproducible.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

model = make_model(d_model).to(device)
criterion = nn.CrossEntropyLoss()
learning_rate = 1e-3
reg = 1e-7                          # L2 weight decay passed to Adam
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=reg)
tot_epoch = 1000
batch_size = 100                    # samples accumulated per optimizer step in the training loop
batch_arrange = list(range(d_len))  # training-example order, reshuffled every epoch
loss_history = []
val_acc_history = []



In [100]:
torch.cuda.empty_cache()
os.makedirs('models', exist_ok=True)  # torch.save below fails if this directory is missing

max_acc = 0.0
n_train = len(batch_arrange)

optimizer.zero_grad()
for epoch in range(tot_epoch):
    model.train()
    total_loss = 0.0
    random.shuffle(batch_arrange)
    for idx, i in enumerate(batch_arrange):
        y_out = model(train_x0_r[i], train_x0[i].to(device), train_x1_r[i], train_x1[i].to(device))
        # out-of-place unsqueeze gives a one-element target batch without an
        # in-place reshape of a temporary view of train_Y
        target = train_Y[i].unsqueeze(0).to(device)
        loss = criterion(y_out, target)
        loss.backward()  # gradients accumulate until the next optimizer.step()
        total_loss += loss.item()
        # Step after every `batch_size` samples and after the epoch's last
        # (possibly smaller) batch, so no gradient carries into the next epoch.
        if (idx + 1) % batch_size == 0 or idx + 1 == n_train:
            optimizer.step()
            optimizer.zero_grad()
            # running sum of this epoch's loss at the time of the step
            loss_history.append(total_loss)
    print("epoch ", epoch, ": current loss ", total_loss, sep="")
    cur_acc = compute_acc(val_x0, val_x1, val_Y, val_x0_r, val_x1_r, model)
    val_acc_history.append(cur_acc)
    if cur_acc > max_acc:
        max_acc = cur_acc
        torch.save(model.state_dict(), 'models/treelstm_' + str(d_model) + 'd_epoch_' + str(epoch) + '.torch')



KeyboardInterrupt: 

In [None]:
# Persist the training curves: loss_history then val_acc_history, pickled
# back-to-back into one file (load them in the same order).
with open('treelstm_' + str(d_model) + 'd.his', 'wb') as f:
    pickle.dump(loss_history, f)
    pickle.dump(val_acc_history, f)