# 10. Tree Recursive Neural Networks and Constituency Parsing

* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture14-TreeRNNs.pdf
* https://nlp.stanford.edu/~socherr/EMNLP2013_RNTN.pdf

In [74]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F
import nltk
import random
import numpy as np
from collections import Counter, OrderedDict
import nltk
from copy import deepcopy
import os
def flatten(l):
    """Return a new list holding the items of every sub-iterable of ``l``.

    Only one level of nesting is removed; the order of the items is kept.
    """
    return [item for sublist in l for item in sublist]

In [75]:
# Check once whether CUDA is available, then bind the tensor constructors
# used by the rest of the notebook to the matching (GPU or CPU) classes.
USE_CUDA = torch.cuda.is_available()

if USE_CUDA:
    FloatTensor, LongTensor, ByteTensor = (
        torch.cuda.FloatTensor,
        torch.cuda.LongTensor,
        torch.cuda.ByteTensor,
    )
else:
    FloatTensor, LongTensor, ByteTensor = (
        torch.FloatTensor,
        torch.LongTensor,
        torch.ByteTensor,
    )

In [76]:
def getBatch(batch_size, train_data):
    """Shuffle ``train_data`` in place and yield it as consecutive mini-batches.

    Args:
        batch_size: Maximum number of items per batch; must be positive.
        train_data: A list of training examples. It is reordered in place by
            ``random.shuffle``.

    Yields:
        Slices of ``train_data`` with ``batch_size`` items each; the last
        batch holds whatever remains and may be shorter. Nothing is yielded
        when ``train_data`` is empty.

    Raises:
        ValueError: If ``batch_size`` is not positive. Because this is a
            generator, the error surfaces when iteration starts.
    """
    if batch_size <= 0:
        # A zero step would never advance through the data.
        raise ValueError("batch_size must be a positive integer")

    random.shuffle(train_data)
    for start in range(0, len(train_data), batch_size):
        yield train_data[start:start + batch_size]

## Data load and Preprocessing

### Stanford Sentiment Treebank (https://nlp.stanford.edu/sentiment/index.html)

In [314]:
# Show one randomly chosen line of the training file. The context manager
# closes the file handle, which the bare open(...) call left open.
with open('../dataset/trees/train.txt', 'r', encoding='utf-8') as fid:
    sample = random.choice(fid.readlines())
print(sample)

(2 (2 (1 (0 (1 (2 It) (1 (2 would) (1 (2 be) (1 (1 disingenuous) (3 (2 to) (4 (2 (2 call) (2 Reno)) (4 (2 a) (4 (4 great) (2 film))))))))) (2 ,)) (2 but)) (2 (2 you) (2 (2 can) (2 (2 say) (2 (2 (2 that) (2 (2 about) (2 most))) (2 (2 of) (2 (2 (2 the) (2 flicks)) (2 (2 (3 moving) (2 (2 (2 in) (2 and)) (1 out))) (2 (2 of) (2 (2 the) (2 multiplex))))))))))) (2 .))



### Tree Class 

Code borrowed from https://github.com/bogatyy/cs224d/tree/master/assignment3

In [67]:
class Node:
    """One node of a binary tree.

    Attributes are plain public fields: ``label`` and ``word`` come from the
    constructor, while ``parent``, ``left`` and ``right`` start as ``None``
    and ``isLeaf`` starts as ``False``; code that assembles a tree is
    expected to fill them in afterwards.
    """

    def __init__(self, label, word=None):
        self.label, self.word = label, word
        # Structural links; nothing is connected at construction time.
        self.parent = self.left = self.right = None
        # Not derived from ``word``; must be set explicitly for leaves.
        self.isLeaf = False

    def __str__(self):
        """Render a leaf as ``[word:label]`` and any other node recursively
        as ``(left <- [word:label] -> right)``."""
        if not self.isLeaf:
            return '({0} <- [{1}:{2}] -> {3})'.format(
                self.left, self.word, self.label, self.right)
        return '[{0}:{1}]'.format(self.word, self.label)


class Tree:
    """Binary labelled tree parsed from a bracketed string.

    The expected input looks like ``(3 (2 good) (1 film))``: every node is
    written as ``<open><label> ...<close>``, where a leaf carries a word and
    any other node carries exactly two child nodes. All whitespace is
    discarded and the string is processed character by character, so the
    label and both delimiters must each be a single character.

    NOTE(review): ``get_labels`` and ``getLeaves`` are module-level helpers
    that are not defined in this section; they must be in scope for the
    constructor and ``get_words`` to work.
    """

    def __init__(self, treeString, openChar='(', closeChar=')'):
        """Parse ``treeString`` and store the resulting root node.

        Args:
            treeString: Bracketed tree description (one tree).
            openChar: Single character that opens a node.
            closeChar: Single character that closes a node.
        """
        # Use the delimiters supplied by the caller; they used to be ignored
        # in favour of hard-coded '(' and ')'.
        self.open = openChar
        self.close = closeChar

        # Strip every whitespace run and explode the rest into characters.
        tokens = []
        for toks in treeString.strip().split():
            tokens.extend(toks)

        self.root = self.parse(tokens)
        # get list of labels as obtained through a post-order traversal
        self.labels = get_labels(self.root)
        # Length of the label list returned by get_labels.
        self.num_words = len(self.labels)

    def parse(self, tokens, parent=None):
        """Recursively build the node described by ``tokens``.

        Args:
            tokens: List of single characters covering exactly one node,
                from its opening to its closing delimiter.
            parent: Node to record as the parent of the new node.

        Returns:
            The new ``Node``; leaves get a lower-cased ``word`` and
            ``isLeaf = True``, other nodes get ``left`` and ``right``.

        Raises:
            AssertionError: If ``tokens`` does not start with the opening
                and end with the closing delimiter.
        """
        assert tokens[0] == self.open, "Malformed tree"
        assert tokens[-1] == self.close, "Malformed tree"

        split = 2  # position after open and label
        countOpen = countClose = 0

        if tokens[split] == self.open:
            countOpen += 1
            split += 1
        # Find where left child and right child split: advance until the
        # brackets of the first child are balanced again.
        while countOpen != countClose:
            if tokens[split] == self.open:
                countOpen += 1
            if tokens[split] == self.close:
                countClose += 1
            split += 1

        # New node; the label is the single character after the opener.
        node = Node(int(tokens[1]))  # zero index labels
        node.parent = parent

        # Leaf node: no child bracket was seen, the payload is the word.
        if countOpen == 0:
            node.word = ''.join(tokens[2:-1]).lower()
            node.isLeaf = True
            return node

        node.left = self.parse(tokens[2:split], parent=node)
        node.right = self.parse(tokens[split:-1], parent=node)

        return node

    def get_words(self):
        """Return the ``word`` of every node that ``getLeaves`` yields for
        the root, in the order ``getLeaves`` returns them."""
        leaves = getLeaves(self.root)
        words = [node.word for node in leaves]
        return words

def loadTrees(dataSet='train'):
    """Read one split of the tree data set and parse it.

    Args:
        dataSet: Name of the split; the file read is
            ``../dataset/trees/<dataSet>.txt`` (UTF-8).

    Returns:
        A list with one ``Tree`` per line of the file, in file order.
    """
    path = '../dataset/trees/%s.txt' % dataSet
    print("Loading %s trees.." % dataSet)
    with open(path, 'r', encoding='utf-8') as handle:
        return [Tree(line) for line in handle]

In [297]:
# Parse the 'train' split into a list of Tree objects
# (loadTrees prints a progress line while doing so).
train_data = loadTrees('train')

Loading train trees..


### Build Vocab 

In [298]:
# Unique leaf words of the training trees. Sorting makes the order, and
# therefore every word id derived from it, identical on every run; the
# iteration order of a set of strings is not stable across interpreter runs.
vocab = sorted(set(flatten([t.get_words() for t in train_data])))

In [299]:
# Give each vocabulary entry the next free integer id, in vocabulary order.
# setdefault leaves an already-registered word untouched.
word2index = {}
for vo in vocab:
    word2index.setdefault(vo, len(word2index))

# Reverse lookup: integer id -> word.
index2word = dict((idx, word) for word, idx in word2index.items())

## Modeling 

In [315]:
class RNTN(nn.Module):
    """Recursive Neural Tensor Network over binary trees.

    A leaf is represented by its word embedding (1 x D). Any other node
    concatenates the vectors of its two children into ``x`` (1 x 2D) and
    computes ``tanh(x V x^T + x W)``, where ``V`` holds D slices of shape
    2D x 2D and ``W`` has shape 2D x D. A linear layer followed by
    log-softmax turns node vectors into class log-probabilities.
    """

    def __init__(self, word2index, hidden_size, output_size):
        """
        Args:
            word2index: Mapping from word to embedding row; its length sets
                the embedding table size.
            hidden_size: Dimension D of every node vector.
            output_size: Number of output classes.
        """
        super(RNTN, self).__init__()

        self.word2index = word2index
        self.embed = nn.Embedding(len(word2index), hidden_size)
        # Tensor V: one 2D x 2D slice per output dimension.
        self.V = nn.ParameterList(
            [nn.Parameter(torch.randn(hidden_size * 2, hidden_size * 2))
             for _ in range(hidden_size)])
        self.W = nn.Parameter(torch.randn(hidden_size * 2, hidden_size))
        self.W_out = nn.Linear(hidden_size, output_size)

    def tree_propagation(self, node):
        """Compute the vector of ``node`` and of everything below it.

        Args:
            node: Object exposing ``isLeaf``, ``word``, ``left``, ``right``.

        Returns:
            OrderedDict mapping each visited node to its 1 x D vector, in
            post-order (left subtree, right subtree, then ``node``).
        """
        recursive_tensor = OrderedDict()
        if node.isLeaf:
            # Indexing fails for a word that is missing from word2index.
            tensor = Variable(LongTensor([self.word2index[node.word]]))
            current = self.embed(tensor)  # 1xD
        else:
            recursive_tensor.update(self.tree_propagation(node.left))
            recursive_tensor.update(self.tree_propagation(node.right))

            concated = torch.cat(
                [recursive_tensor[node.left], recursive_tensor[node.right]],
                1)  # 1x2D
            concated_t = concated.transpose(0, 1)  # 2Dx1, reused per slice

            # Each slice contributes one scalar x V_i x^T (shape 1x1).
            xVx = torch.cat(
                [torch.matmul(torch.matmul(concated, v), concated_t)
                 for v in self.V],
                1)  # 1xD
            Wx = torch.matmul(concated, self.W)  # 1xD

            # torch.tanh replaces the deprecated F.tanh.
            current = torch.tanh(xVx + Wx)
        recursive_tensor[node] = current
        return recursive_tensor

    def forward(self, Trees, root_only=False):
        """Return class log-probabilities for one tree or a list of trees.

        Args:
            Trees: A single tree or a list of trees; each must expose
                ``root``.
            root_only: If true, score only each tree's root; otherwise score
                every node, tree by tree, in post-order.

        Returns:
            Tensor of shape (N, output_size) holding log-probabilities,
            where N is the number of scored nodes.
        """
        if not isinstance(Trees, list):
            Trees = [Trees]

        propagated = []
        for tree in Trees:
            node_vectors = self.tree_propagation(tree.root)
            if root_only:
                propagated.append(node_vectors[tree.root])
            else:
                propagated.extend(node_vectors.values())

        propagated = torch.cat(propagated)  # Nx D

        # Normalise over the class dimension explicitly instead of relying
        # on the deprecated implicit choice of dim.
        return F.log_softmax(self.W_out(propagated), dim=1)

## Training 

In [322]:
# Passed to RNTN as hidden_size (embedding / node vector dimension).
HIDDEN_SIZE = 35
# When True the training loop scores and labels only each tree's root;
# when False every node of every tree is used.
ROOT_ONLY = False
# Number of trees per mini-batch handed to getBatch.
BATCH_SIZE = 20
# Number of passes over the training data.
EPOCH = 5

In [323]:
# Build the network with 5 output classes and move it to the GPU if present.
model = RNTN(word2index, HIDDEN_SIZE, 5)
if USE_CUDA:
    model = model.cuda()

# RNTN.forward already returns log-probabilities (log_softmax), so NLLLoss is
# the matching criterion; CrossEntropyLoss would apply log_softmax again.
loss_function = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [None]:
for epoch in range(EPOCH):
    losses = []
    # Iterate over train_data, the list of loaded trees; the name `train`
    # used here before is not defined anywhere in this notebook.
    for i, batch in enumerate(getBatch(BATCH_SIZE, train_data)):

        # Targets must line up with the rows the model returns: one label
        # per tree (the last entry of tree.labels) in root-only mode,
        # otherwise every label of every tree in order.
        if ROOT_ONLY:
            labels = [tree.labels[-1] for tree in batch]
        else:
            labels = flatten([tree.labels for tree in batch])
        labels = Variable(LongTensor(labels))

        model.zero_grad()
        preds = model(batch, ROOT_ONLY)

        loss = loss_function(preds, labels)
        # tolist() gives a plain number for a 0-dim loss and a one-element
        # list for a 1-element loss; ravel() normalises both, whereas the
        # former `[0]` indexing fails on a plain number.
        losses.append(float(np.ravel(loss.data.tolist())[0]))

        loss.backward()
        optimizer.step()

        # Report the mean loss since the previous report, then start over.
        if i % 100 == 0:
            print('[%d/%d] mean_loss : %.2f' % (epoch, EPOCH, np.mean(losses)))
            losses = []
        

[0/5] mean_loss : 1.76
[0/5] mean_loss : 1.62
[0/5] mean_loss : 1.45
[0/5] mean_loss : 1.36
[0/5] mean_loss : 1.30
[1/5] mean_loss : 1.25


## Test

## TODO 

* https://github.com/nearai/pytorch-tools # Dynamic batch using TensorFold