In [3]:
import numpy as np
from scipy.stats import pearsonr,spearmanr
import sys
from os.path import join
from train_model import sent_util
import torch
from torchtext import data, datasets
import pandas as pd
import nltk
import nltk.corpus

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
def get_sst_trees():
    
    inputs = data.Field(lower='preserve-case')
    answers = data.Field(sequential=False, unk_token=None)

    train_s, dev_s, test_s = datasets.SST.splits(inputs, answers, 
                                                 fine_grained = False, 
                                                 train_subtrees = True,
                                                 filter_pred=lambda ex: ex.label != 'neutral')
    
    inputs.build_vocab(train_s, dev_s, test_s)
    answers.build_vocab(train_s)
    
    train_iter, dev_iter, test_iter = data.BucketIterator.splits(
        (train_s, dev_s, test_s), batch_size=1, device=device)
    
    return inputs, answers, train_iter, dev_iter, test_iter


In [6]:
inputs, answers, train_iter, dev_iter, test_iter = get_sst_trees()

In [7]:
batches = sent_util.get_batches_iterator([i for i in range(10)], train_iter)

getting batches...


In [8]:
for index in batches:
    text = batches[index].text.data[:,0]
    print([inputs.vocab.itos[i] for i in text], answers.vocab.itos[batches[index].label.data])

['to', 'sneak', 'out', 'of', 'the', 'theater'] negative
['sad'] negative
['is', 'this', 'films', 'reason', 'for', 'being', '.'] positive
['you', 'wish', 'had', 'been', 'developed', 'with', 'more', 'care'] negative
['of', 'the', 'holiday', 'box', 'office', 'pie'] positive
['to', 'this', 'film', 'that', 'may', 'not', 'always', 'work'] negative
['i', 'killed', 'my', 'father', 'compelling'] positive
[',', 'mostly', 'martha', 'will', 'leave', 'you', 'with', 'a', 'smile', 'on', 'your', 'face', 'and', 'a', 'grumble', 'in', 'your', 'stomach', '.'] positive
['theater'] positive


In [9]:
sst_reader = nltk.corpus.BracketParseCorpusReader("/Users/silanhe/Documents/McGill/Grad/WINTER2020/NLU/sst/trees", ".*.txt")

In [10]:
sst_sentences = sst_reader.sents("train.txt")

In [11]:
sst = sst_reader.parsed_sents("train.txt")

In [12]:
len(sst)

8544

In [41]:
def travelTree(node):
    
    index_words = 0
    
    def dfs(node):
        nonlocal index_words
        if isinstance(node,str):
            list_return = [index_words]
            index_words += 1
            return list_return

        else:
            score = node.label()
            len_node = len(node)
            
            subtree_list_words = []
            if len(node) > 0:
                subtree_list_words += dfs(node[0])
            if len(node) > 1:
                subtree_list_words += dfs(node[1])
            
            # get CD score
#             print(score,min(subtree_list_words), max(subtree_list_words))
            
            return subtree_list_words
    dfs(node)

In [42]:
for index,tree in enumerate(sst):
    words= [nltk.word_tokenize(word.lower())[0] for word in sst_sentences[index]]
    print(' '.join(words))
    print(tree)
    travelTree(tree)
    break



the rock is destined to be the 21st century 's new `` conan '' and that he 's going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .
(3
  (2 (2 The) (2 Rock))
  (4
    (3
      (2 is)
      (4
        (2 destined)
        (2
          (2
            (2
              (2
                (2 to)
                (2
                  (2 be)
                  (2
                    (2 the)
                    (2
                      (2 21st)
                      (2
                        (2 (2 Century) (2 's))
                        (2 (3 new) (2 (2 ``) (2 Conan))))))))
              (2 ''))
            (2 and))
          (3
            (2 that)
            (3
              (2 he)
              (3
                (2 's)
                (3
                  (2 going)
                  (3
                    (2 to)
                    (4
                      (3
                        (2 make)
                        (3
                     

In [16]:
tree.label()

'3'

In [22]:
tree[0][0][0]

'The'