In [60]:
import pytreebank
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import random
import re

In [61]:
# Load the Stanford Sentiment Treebank trees; `bank` is indexed by split name
# further down (e.g. bank['test']).
bank = pytreebank.load_sst("stanfordSentimentTreebank/trees")

# Read the phrase dictionary. The context manager guarantees the file handle
# is closed once the lines are in memory (the handle used to be leaked).
with open("stanfordSentimentTreebank/dictionary.txt") as dictFile:
    lines = dictFile.readlines()

# Keep only entries of the form "<token>|<id>" whose token contains no
# whitespace, i.e. single tokens rather than multi-word phrases.
exp = r'^(\S+)\|\d+$'
words = [
    entry.group(1).replace('\\', '')  # drop backslash escapes from the token
    for entry in (re.match(exp, line) for line in lines)
    if entry
]

# Tokens containing a space are rejected by the pattern above, so these three
# are added by hand (presumably they occur as single leaves in the trees --
# TODO confirm against the treebank).
words.append('8 1/2')
words.append('2 1/2')
words.append('9 1/2')

# Map every word to its column index in the one-hot encoding. If a word occurs
# more than once, the last index wins.
dictionary = {word: number for number, word in enumerate(words)}

def oneHotEncoding(word, vocabulary=None, device=None):
    """Return a (1, vocabulary size) one-hot row vector for ``word``.

    Args:
        word: Token to encode; must be a key of ``vocabulary``.
        vocabulary: Mapping from token to column index. Defaults to the
            module-level ``dictionary``.
        device: Device on which the vector is allocated. Defaults to CUDA
            when it is available and to the CPU otherwise.

    Returns:
        A float tensor of shape ``(1, len(vocabulary))`` that is zero
        everywhere except for a 1 in the column assigned to ``word``.

    Raises:
        KeyError: If ``word`` is not part of the vocabulary.
    """
    if vocabulary is None:
        vocabulary = dictionary
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Allocate directly on the target device instead of building the vector
    # on the host and copying it over on every call.
    vector = torch.zeros(1, len(vocabulary), device=device)
    vector[0, vocabulary[word]] = 1
    return vector

def getLabels(tree):
    """Collect the labels of ``tree`` in post-order.

    A node with exactly two children contributes the labels of its first
    subtree, then those of its second subtree, then its own label. Any other
    node is treated as a leaf and contributes only its own label.

    Returns:
        A list of labels whose last element is the label of ``tree`` itself.
    """
    # Anything that is not a binary node terminates the recursion.
    if len(tree.children) != 2:
        return [tree.label]

    firstLabels = getLabels(tree.children[0])
    secondLabels = getLabels(tree.children[1])
    return firstLabels + secondLabels + [tree.label]

In [62]:
class RNTN(nn.Module):
    """Tree-structured sentiment model (the name suggests a Recursive Neural
    Tensor Network).

    Leaves are embedded from one-hot word vectors. The vectors of two children
    are concatenated and combined through a bilinear tensor term plus a linear
    layer, followed by ``tanh``. Every node is classified with a shared
    linear + log-softmax head.

    Args:
        vocabularySize: Width of the one-hot word vectors.
        classes: Number of sentiment classes predicted per node.
        d: Dimensionality of word and phrase vectors.
    """

    def __init__(self, vocabularySize, classes = 5, d = 25):
        super().__init__()
        self.d = d
        self.L = nn.Linear(vocabularySize, d, bias=False)  # word embedding
        self.W = nn.Linear(d * 2, d)                        # linear composition
        self.Ws = nn.Linear(d, classes)                     # sentiment head
        # Draw the tensor on the CPU (same random stream as before) and only
        # move it to the GPU when one exists, so the model can also be built
        # on CPU-only machines. `.to()` / `.cuda()` on the module relocates it
        # together with the other parameters.
        tensor = torch.rand(2 * d, 2 * d, d)
        if torch.cuda.is_available():
            tensor = tensor.cuda()
        self.register_parameter('V', nn.Parameter(tensor))
        self.lSoftmax = nn.LogSoftmax(dim=1)

    def tensorProduct(self, phrase):
        """Apply the bilinear form of every slice of ``V`` to ``phrase``.

        Args:
            phrase: Tensor of shape ``(batch, 2 * d)`` holding concatenated
                child vectors.

        Returns:
            Tensor of shape ``(batch, d)`` whose column ``i`` equals
            ``phrase @ V[:, :, i] @ phrase.T`` for the corresponding row.
        """
        # One contraction replaces the per-slice loop of matrix products and
        # runs on whatever device `phrase` and `V` live on.
        return torch.einsum('bj,jki,bk->bi', phrase, self.V, phrase)

    def embed(self, inpt):
        """Project a one-hot word vector to its ``d``-dimensional embedding."""
        return self.L(inpt)

    def getSentiment(self, inpt):
        """Return log-probabilities over the sentiment classes for ``inpt``."""
        return self.lSoftmax(self.Ws(inpt))

    def forward(self, root):
        """Evaluate the tree bottom-up and classify every node.

        Args:
            root: Tree node exposing ``children`` and ``to_lines()``. A node
                with exactly two children is composed from them; any other
                node is treated as a leaf whose word is ``to_lines()[0]``.

        Returns:
            Tensor of shape ``(number of nodes, classes)`` with one row of
            log-probabilities per node in post-order, so the last row belongs
            to ``root``. The same tensor is stored in ``self.outputs``.
        """
        self.phraseStack = []   # vectors of finished subtrees, innermost last
        sentiments = []         # per-node predictions, concatenated once below
        expanded = set()        # ids of nodes whose children were already pushed
        stack = [root]

        while stack:
            node = stack[-1]
            if len(node.children) == 2:
                if id(node) not in expanded:
                    # First visit: schedule both children (left is handled
                    # first) and come back to this node afterwards.
                    expanded.add(id(node))
                    stack.append(node.children[1])
                    stack.append(node.children[0])
                    continue
                # Second visit: both child vectors are on top of the stack.
                right = self.phraseStack.pop()
                left = self.phraseStack.pop()
                pair = torch.cat([left, right], dim=1)
                phraseVec = torch.tanh(self.tensorProduct(pair) + self.W(pair))
            else:
                oneHot = oneHotEncoding(node.to_lines()[0])
                # Keep the input on the same device as the embedding weights.
                phraseVec = self.embed(oneHot.to(self.L.weight.device))

            self.phraseStack.append(phraseVec)
            sentiments.append(self.getSentiment(phraseVec))
            stack.pop()

        self.outputs = torch.cat(sentiments, dim=0)
        return self.outputs

In [63]:
# Evaluate the saved model on the test split, scoring only the root
# (whole-sentence) prediction of each tree.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

testNet = RNTN(len(dictionary))
testNet.to(device)
# map_location lets a checkpoint saved on a GPU be restored on any device.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
testNet.load_state_dict(torch.load('./net.pth', map_location=device))
testNet.eval()

with torch.no_grad():
    correct = 0
    total = len(bank['test'])
    for sentence in bank['test']:
        outputs = testNet(sentence)

        # The network emits its predictions in post-order, so the last row is
        # the root; its gold label is the label of the sentence tree itself.
        if int(torch.argmax(outputs[-1])) == sentence.label:
            correct += 1

    print(f'Fine grained accuracy on the test set: {correct/total * 100}%')

Fine grained accuracy on the test set: 35.88235294117647%
