In [1]:
import pandas as pd
import numpy as np
import os
from dlmslib.torch_models import trees, nlp_models

import gc

In [2]:
# Root directory holding every input dataset folder.
DATA_ROOT = '../input/'

# Per-dataset folders under DATA_ROOT: the original competition data and the
# Stanford Sentiment Treebank splits read below.
ORIGINAL_DATA_FOLDER, TREEBANK_DATA_FOLDER = (
    os.path.join(DATA_ROOT, dataset_dir)
    for dataset_dir in ('movie-review-sentiment-analysis-kernels-only', 'stanford-sentiment-treebank')
)

In [3]:
def read_tree_bank_file(file_path):
    """Parse a treebank text file into a list of trees, one per line.

    Every line (newline included) is passed to
    ``trees.LabeledTextBinaryTreeNode.parse_ptb_string``; results keep file order.

    Args:
        file_path: Path to a UTF-8 text file; presumably one PTB-bracketed
            tree per line (TODO confirm against the parser's expectations).

    Returns:
        list: The parsed trees, in the same order as the lines of the file.
    """
    with open(file_path, 'r', encoding='utf-8') as treebank_file:
        raw_lines = treebank_file.readlines()
    return [trees.LabeledTextBinaryTreeNode.parse_ptb_string(raw) for raw in raw_lines]

In [4]:
# Locations of the three treebank splits inside TREEBANK_DATA_FOLDER.
train_data_path, test_data_path, dev_data_path = (
    os.path.join(TREEBANK_DATA_FOLDER, split_file)
    for split_file in ('train.txt', 'test.txt', 'dev.txt')
)

# Parse each split into a list of trees (train first, then test, then dev).
train_trees, test_trees, dev_trees = (
    read_tree_bank_file(split_path)
    for split_path in (train_data_path, test_data_path, dev_data_path)
)

# Load Embeddings

In [5]:
def load_embed(file, encoding='utf-8', min_line_length=15):
    """Load a space-separated word-vector text file into a ``{word: vector}`` dict.

    Each kept line is split on single spaces. The first token is the word and
    every remaining token is parsed as a ``float32`` component. Trailing
    whitespace (the newline and an optional trailing space) is stripped first,
    so lines ending in either ``"... v_n\\n"`` or ``"... v_n \\n"`` yield the
    full vector.

    Args:
        file: Path to the embedding file.
        encoding: Text encoding used to read the file (fastText ``.vec`` files
            are UTF-8; the platform default may not be).
        min_line_length: Lines whose raw length, newline included, is
            ``<= min_line_length`` are skipped. The default of 15 drops short
            header lines such as ``"2000000 300\\n"``.

    Returns:
        dict: Word -> 1-D ``np.float32`` array. If a word repeats, the later
        line wins.
    """
    def get_coefs(word, *arr):
        return word, np.asarray(arr, dtype='float32')

    embeddings_index = {}
    # Context manager guarantees the (multi-GB) file handle is released,
    # even if a malformed line raises during parsing.
    with open(file, encoding=encoding) as fid:
        for line in fid:
            if len(line) > min_line_length:
                word, vector = get_coefs(*line.rstrip().split(" "))
                embeddings_index[word] = vector
    return embeddings_index

In [6]:
# Pre-trained fastText vectors (crawl-300d-2M), loaded as {word: float32 vector}.
pretrained_w2v_path = os.path.join(DATA_ROOT, "fasttext-crawl-300d-2m/crawl-300d-2M.vec")
w2v_fasttext = load_embed(pretrained_w2v_path)

# Fail fast on a mis-parsed file: every vector must share one shape, otherwise
# filling the embedding matrix later breaks with an obscure broadcast error.
_embedding_shapes = {vec.shape for vec in w2v_fasttext.values()}
assert len(_embedding_shapes) == 1, "inconsistent embedding shapes: {}".format(sorted(_embedding_shapes)[:5])

# Build Vocab

In [7]:
UNKNOWN_TOKEN = '<UNK>'
EMB_DIM = 300


def map_unknown_token(tree, embeddings_index):
    """Rewrite, in place, each node's ``text`` that is not a key of ``embeddings_index``.

    Walks the binary tree through ``left`` / ``right`` (``None`` marks a missing
    child) in pre-order. Any node whose ``text`` has no embedding gets
    ``UNKNOWN_TOKEN`` instead. Returns ``None``; ``tree`` may itself be ``None``.

    NOTE(review): internal nodes are checked too. If they carry no word (e.g.
    ``text is None``), they also become UNKNOWN_TOKEN. Confirm that is harmless
    for LabeledTextBinaryTreeNode.
    """
    pending = [tree]
    while pending:
        node = pending.pop()
        if node is None:
            continue
        if node.text not in embeddings_index:
            node.text = UNKNOWN_TOKEN
        # Push right before left so the left subtree is processed first.
        pending.append(node.right)
        pending.append(node.left)

In [8]:
# Replace words that have no fastText vector with UNKNOWN_TOKEN in every split
# (mutates the trees in place; same train -> test -> dev order as before).
all_trees = train_trees + test_trees + dev_trees
for tree in all_trees:
    map_unknown_token(tree, w2v_fasttext)


def flatten(l):
    """Concatenate an iterable of iterables into a single flat list."""
    return [item for sublist in l for item in sublist]


# Sorted so word -> index assignments are identical across runs: iteration
# order of a set of strings depends on the per-process hash seed.
vocab = sorted(set(flatten([t.get_leaf_texts() for t in all_trees])))

word2index = {UNKNOWN_TOKEN: 0}
for vo in vocab:
    if vo not in word2index:
        word2index[vo] = len(word2index)

# Size the matrix from word2index, not vocab: when no word was mapped to
# UNKNOWN_TOKEN, vocab lacks it and word2index has len(vocab) + 1 rows.
wv = np.zeros(shape=(len(word2index), EMB_DIM))
for vo, index in word2index.items():
    # The UNKNOWN_TOKEN row (index 0) stays all zeros.
    if vo != UNKNOWN_TOKEN:
        wv[index] = w2v_fasttext[vo]

# Model Training

In [9]:
# Longest sentence, in leaf tokens, across all three splits; passed as
# max_len to prepare_data below.
MAX_LEN = max(
    (len(tree.get_leaf_texts()) for tree in train_trees + test_trees + dev_trees),
    default=0,
)

# Model sizes; output_size = 5 matches the five label ids (0-4) in the training reports.
hidden_size, tracker_size, output_size = 100, 100, 5

# NOTE(review): index 0 is also word2index['<UNK>'], so padding and unknown
# words share one embedding row. Confirm how ThinStackHybridLSTM treats
# pad_token_index, e.g. whether it is masked or excluded from gradient updates.
pad_token_index = 0

In [10]:
# Build the classifier on top of the pre-trained embedding matrix ``wv``.
# trainable_embed=True presumably lets the embeddings be fine-tuned (TODO confirm in dlmslib).
model = nlp_models.ThinStackHybridLSTM(
    wv,
    hidden_size,
    tracker_size,
    output_size,
    pad_token_index,
    trainable_embed=True,
)

In [11]:
# Length/padding settings shared by every split.
prepare_kwargs = dict(max_len=MAX_LEN, pre_pad_index=pad_token_index, post_pad_index=pad_token_index)

# Convert each split's trees into model inputs (train, then test, then dev).
train_data, test_data, dev_data = (
    nlp_models.ThinStackHybridLSTM.prepare_data(split_trees, word2index, **prepare_kwargs)
    for split_trees in (train_trees, test_trees, dev_trees)
)

In [12]:
EPOCHS = 30
BATCH_SIZE = 50

# Train on the train split and validate on dev after each reporting step.
# The four positional train inputs presumably mirror the validation_* keywords:
# tokens, transitions, labels, token labels (TODO confirm in dlmslib).
# NOTE(review): test_data is prepared but not evaluated in this cell.
model.train_model(
    train_data[0], train_data[1], train_data[2], train_data[3],
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_tokens=dev_data[0],
    validation_transitions=dev_data[1],
    validation_labels=dev_data[2],
    validation_token_labels=dev_data[3],
)

[0/30] mean_loss : 1.71
             precision    recall  f1-score   support

          0       0.03      0.06      0.04        50
          1       0.10      0.07      0.08       256
          2       0.74      0.38      0.50      1366
          3       0.16      0.57      0.25       249
          4       0.01      0.02      0.01        61

avg / total       0.55      0.34      0.39      1982

[0/30] mean_loss : 1.54
             precision    recall  f1-score   support

          0       0.00      0.00      0.00        46
          1       0.07      0.02      0.04       167
          2       0.67      0.89      0.77      1308
          3       0.10      0.01      0.02       300
          4       0.07      0.04      0.05        81

avg / total       0.49      0.62      0.54      1902

[0/30] mean_loss : 3.70
             precision    recall  f1-score   support

          0       0.04      0.35      0.07        51
          1       0.05      0.05      0.05       184
          2       0.

[0/30] mean_loss : 0.65
             precision    recall  f1-score   support

          0       0.67      0.12      0.21        48
          1       0.76      0.33      0.46       227
          2       0.80      0.95      0.87      1238
          3       0.61      0.54      0.57       257
          4       0.62      0.25      0.36        40

avg / total       0.76      0.77      0.75      1810

[0/30] mean_loss : 0.60
             precision    recall  f1-score   support

          0       0.67      0.12      0.21        32
          1       0.55      0.53      0.54       195
          2       0.83      0.96      0.89      1301
          3       0.81      0.39      0.53       249
          4       0.92      0.22      0.35        55

avg / total       0.80      0.80      0.77      1832

[0/30] mean_loss : 0.67
             precision    recall  f1-score   support

          0       0.50      0.08      0.14        62
          1       0.65      0.50      0.57       202
          2       0.

[0/30] mean_loss : 0.58
             precision    recall  f1-score   support

          0       0.80      0.08      0.15        49
          1       0.63      0.43      0.51       195
          2       0.87      0.95      0.91      1353
          3       0.63      0.71      0.67       281
          4       0.80      0.15      0.25        54

avg / total       0.81      0.82      0.80      1932

[0/30] mean_loss : 0.58
             precision    recall  f1-score   support

          0       0.75      0.08      0.14        77
          1       0.50      0.58      0.54       206
          2       0.88      0.93      0.91      1333
          3       0.62      0.57      0.59       226
          4       0.60      0.38      0.47        68

avg / total       0.79      0.80      0.78      1910

[0/30] mean_loss : 0.57
             precision    recall  f1-score   support

          0       0.71      0.12      0.21        40
          1       0.66      0.59      0.62       212
          2       0.

[0/30] mean_loss : 0.63
             precision    recall  f1-score   support

          0       0.44      0.10      0.16        70
          1       0.48      0.75      0.58       238
          2       0.92      0.88      0.90      1338
          3       0.65      0.65      0.65       270
          4       0.77      0.44      0.56        62

avg / total       0.81      0.79      0.79      1978

[0/30] mean_loss : 0.61
             precision    recall  f1-score   support

          0       0.80      0.12      0.21        33
          1       0.70      0.52      0.59       191
          2       0.85      0.96      0.90      1211
          3       0.68      0.62      0.65       234
          4       0.96      0.32      0.47        73

avg / total       0.82      0.82      0.80      1742

[0/30] mean_loss : 0.58
             precision    recall  f1-score   support

          0       0.50      0.07      0.12        45
          1       0.59      0.53      0.56       179
          2       0.

[0/30] mean_loss : 0.52
             precision    recall  f1-score   support

          0       1.00      0.05      0.10        40
          1       0.69      0.45      0.55       233
          2       0.87      0.94      0.90      1480
          3       0.65      0.69      0.67       291
          4       0.81      0.45      0.58        38

avg / total       0.82      0.82      0.81      2082

[0/30] mean_loss : 0.55
             precision    recall  f1-score   support

          0       0.80      0.08      0.15        49
          1       0.54      0.75      0.63       217
          2       0.91      0.93      0.92      1334
          3       0.69      0.65      0.67       281
          4       0.79      0.41      0.54        93

avg / total       0.83      0.82      0.82      1974

[0/30] mean_loss : 0.54
             precision    recall  f1-score   support

          0       0.67      0.10      0.17        61
          1       0.61      0.53      0.57       223
          2       0.

[0/30] mean_loss : 0.54
             precision    recall  f1-score   support

          0       0.50      0.08      0.14        25
          1       0.62      0.56      0.59       173
          2       0.88      0.96      0.92      1152
          3       0.68      0.55      0.61       208
          4       0.73      0.44      0.55        68

avg / total       0.81      0.83      0.81      1626

[0/30] mean_loss : 0.54
             precision    recall  f1-score   support

          0       0.77      0.34      0.48        58
          1       0.58      0.61      0.60       194
          2       0.90      0.91      0.91      1187
          3       0.64      0.70      0.67       262
          4       0.62      0.23      0.33        35

avg / total       0.81      0.81      0.81      1736

[0/30] mean_loss : 0.53
             precision    recall  f1-score   support

          0       0.60      0.14      0.23        64
          1       0.54      0.65      0.59       228
          2       0.

[0/30] mean_loss : 0.56
             precision    recall  f1-score   support

          0       0.75      0.07      0.13        42
          1       0.60      0.61      0.61       178
          2       0.90      0.92      0.91      1095
          3       0.66      0.74      0.70       305
          4       0.62      0.50      0.55        84

avg / total       0.81      0.81      0.80      1704

[0/30] mean_loss : 0.55
             precision    recall  f1-score   support

          0       1.00      0.09      0.16        47
          1       0.66      0.69      0.68       215
          2       0.88      0.95      0.91      1364
          3       0.57      0.47      0.51       268
          4       0.57      0.38      0.46        92

avg / total       0.80      0.81      0.79      1986

[0/30] mean_loss : 0.54
             precision    recall  f1-score   support

          0       0.80      0.13      0.23        60
          1       0.58      0.65      0.61       217
          2       0.

[0/30] mean_loss : 0.50
             precision    recall  f1-score   support

          0       0.82      0.21      0.33        43
          1       0.67      0.65      0.66       205
          2       0.90      0.95      0.92      1220
          3       0.70      0.67      0.68       248
          4       0.70      0.52      0.60        86

avg / total       0.83      0.84      0.83      1802

[0/30] mean_loss : 0.51
             precision    recall  f1-score   support

          0       0.60      0.10      0.17        31
          1       0.62      0.64      0.63       220
          2       0.88      0.94      0.91      1411
          3       0.68      0.56      0.62       282
          4       0.76      0.38      0.51        68

avg / total       0.81      0.82      0.81      2012

[0/30] mean_loss : 0.52
             precision    recall  f1-score   support

          0       0.83      0.17      0.28        59
          1       0.62      0.63      0.63       215
          2       0.

[0/30] mean_loss : 0.51
             precision    recall  f1-score   support

          0       0.83      0.17      0.29        58
          1       0.60      0.68      0.64       208
          2       0.88      0.93      0.90      1435
          3       0.68      0.63      0.66       301
          4       0.94      0.31      0.46        52

avg / total       0.82      0.82      0.81      2054

[0/30] mean_loss : 0.57
             precision    recall  f1-score   support

          0       0.70      0.49      0.58        71
          1       0.57      0.41      0.48       176
          2       0.85      0.95      0.90      1173
          3       0.68      0.61      0.64       265
          4       0.78      0.29      0.42        49

avg / total       0.79      0.80      0.79      1734

[0/30] mean_loss : 0.53
             precision    recall  f1-score   support

          0       0.48      0.20      0.29        49
          1       0.56      0.76      0.65       195
          2       0.

[1/30] mean_loss : 0.46
             precision    recall  f1-score   support

          0       0.82      0.21      0.33        43
          1       0.65      0.56      0.60       214
          2       0.89      0.93      0.91      1383
          3       0.66      0.72      0.69       295
          4       0.81      0.34      0.48        65

avg / total       0.82      0.83      0.82      2000

[1/30] mean_loss : 0.45
             precision    recall  f1-score   support

          0       0.86      0.17      0.28        36
          1       0.63      0.64      0.63       238
          2       0.88      0.94      0.91      1380
          3       0.76      0.61      0.67       272
          4       0.81      0.46      0.59        48

avg / total       0.83      0.83      0.82      1974

[1/30] mean_loss : 0.43
             precision    recall  f1-score   support

          0       0.77      0.35      0.49        48
          1       0.63      0.48      0.54       151
          2       0.

[1/30] mean_loss : 0.44
             precision    recall  f1-score   support

          0       0.70      0.56      0.62        86
          1       0.63      0.65      0.64       219
          2       0.92      0.94      0.93      1284
          3       0.72      0.70      0.71       227
          4       0.59      0.44      0.51        50

avg / total       0.84      0.85      0.84      1866

[1/30] mean_loss : 0.45
             precision    recall  f1-score   support

          0       0.75      0.45      0.56        20
          1       0.69      0.55      0.61       172
          2       0.89      0.96      0.92      1238
          3       0.70      0.67      0.69       264
          4       0.77      0.34      0.47        68

avg / total       0.84      0.85      0.84      1762

[1/30] mean_loss : 0.50
             precision    recall  f1-score   support

          0       0.56      0.20      0.30        49
          1       0.62      0.84      0.71       210
          2       0.

[1/30] mean_loss : 0.44
             precision    recall  f1-score   support

          0       0.75      0.19      0.31        31
          1       0.75      0.66      0.70       219
          2       0.88      0.97      0.92      1455
          3       0.77      0.53      0.63       283
          4       0.60      0.50      0.55        52

avg / total       0.84      0.85      0.84      2040

[1/30] mean_loss : 0.47
             precision    recall  f1-score   support

          0       0.70      0.25      0.37        56
          1       0.57      0.64      0.60       197
          2       0.92      0.93      0.92      1312
          3       0.67      0.71      0.69       267
          4       0.72      0.63      0.67        84

avg / total       0.83      0.83      0.83      1916

[1/30] mean_loss : 0.47
             precision    recall  f1-score   support

          0       1.00      0.26      0.41        43
          1       0.61      0.78      0.68       228
          2       0.

KeyboardInterrupt: 