In [None]:
!cp /content/drive/MyDrive/textclassification/glove_reader.py .
!cp /content/drive/MyDrive/textclassification/model.py .
!cp /content/drive/MyDrive/textclassification/dataloader.py .
!cp /content/drive/MyDrive/textclassification/glove.6B.300d.txt .
!cp /content/drive/MyDrive/textclassification/dataloader.py .

In [None]:
!pip3 install pytreebank



In [None]:
import pickle
import pandas, numpy as np
import torch, torch.nn as nn
from model import RNN
from dataloader import SST
import pytreebank
from glove_reader import GloveReader

In [None]:
from gensim.utils import tokenize
def transform_and_pad(data):
    """
    Left-pads encoded examples and converts them to integer tensors.

    Each element of ``data`` is a sequence ``(label, tok_1, ..., tok_k)``: the first
    item is the class label, the rest are token indices. Shorter examples are padded
    with index 0 *on the left* so every row has the length of the longest example.
    NOTE(review): 0 may also be a real vocabulary index (depends on the vocabulary
    used to encode ``data``); the returned mask is what tells padding apart.

    Returns:
        content - LongTensor N x L (token indices, left-padded with 0)
        labels  - LongTensor N     (one label per example)
        mask    - LongTensor N x L (1 for real tokens, 0 for padding)
    where N = len(data) and L is the longest token sequence. Empty ``data``
    yields tensors of shape (0, 0), (0,) and (0, 0).
    """
    examples = list(data)
    if not examples:
        # max() over an empty sequence would raise; return correctly-typed empty tensors.
        empty = torch.zeros((0, 0), dtype=torch.long)
        return empty, torch.zeros(0, dtype=torch.long), empty.clone()

    labels = [example[0] for example in examples]
    # tuple() so list-typed examples can be concatenated with the padding tuples too
    sequences = [tuple(example[1:]) for example in examples]
    max_len = max(len(seq) for seq in sequences)

    padded_rows, mask_rows = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        padded_rows.append((0,) * n_pad + seq)
        mask_rows.append((0,) * n_pad + (1,) * len(seq))

    # Explicit dtype keeps the tensors integral even when every row is empty
    # (torch.tensor of empty rows would otherwise default to float32).
    content = torch.tensor(padded_rows, dtype=torch.long)
    content_mask = torch.tensor(mask_rows, dtype=torch.long)
    return content, torch.tensor(labels, dtype=torch.long), content_mask
    
def load_data(path, word_to_index=None):
    """
    Reads a tree file with pytreebank and returns padded tensors for its root sentences.

    Only the first labeled line of each tree (``to_labeled_lines()[0]``) is kept.
    Its text is tokenized with gensim's ``tokenize(..., lower=True)``, tokens absent
    from the vocabulary are dropped, and the rest are mapped to integer indices.

    Args:
        path          - tree file readable by ``pytreebank.import_tree_corpus``
        word_to_index - mapping token -> integer index; defaults to the
                        notebook-global ``glove.words2idx`` (looked up at call time)
    Returns:
        (content, labels, mask) exactly as produced by ``transform_and_pad``.
    """
    if word_to_index is None:
        # Falls back to the GloveReader created in a later cell; it must exist by call time.
        word_to_index = glove.words2idx

    trees = pytreebank.import_tree_corpus(path)

    encoded = []
    for tree in trees:
        root = tree.to_labeled_lines()[0]
        label, sentence = root[0], root[1]
        indices = [
            word_to_index[token]
            for token in tokenize(sentence, lower=True)
            if token in word_to_index
        ]
        encoded.append((label, *indices))

    return transform_and_pad(encoded)

In [None]:
# GloVe reader shared by the rest of the notebook: load_data() reads its `words2idx`
# and the hyper-parameter grid uses its `embeddings`.
# NOTE(review): presumably loads the glove.6B.300d.txt copied in the first cell — confirm in glove_reader.py.
glove = GloveReader()

In [None]:
# Copy the SST tree files (test/train/dev.txt, read in the next cell) from Drive.
!cp -r /content/drive/MyDrive/textclassification/trees .

In [None]:
def _load_split(tree_path):
    """Wraps load_data()'s (content, labels, mask) tuple into an SST dataset."""
    return SST(*load_data(tree_path))

# Splits are loaded in the order test -> train -> dev.
test_data = _load_split('./trees/test.txt')
train_data = _load_split('./trees/train.txt')
dev_data = _load_split('./trees/dev.txt')

In [None]:
from itertools import product

# Hyper-parameter grid. Each model_params entry is a tuple ordered like the keys:
# (dropout, hidden_size, n_layers, embeddings), where embeddings is either the
# GloVe matrix or None (the latter is labelled "default" in the result names).
configuration = {
    'dropout': [0, 0.2],
    'hidden_size': [256, 512],
    'n_layers': [1, 3],
    'embeddings': [glove.embeddings, None],
}
model_params = [combo for combo in product(*configuration.values())]

In [None]:
def validate(model, dataset, batch_size=None, num_workers=2):
    """
    Evaluates ``model`` on ``dataset`` and returns ``(accuracy, mean_loss)``.

    Every dataset item must unpack to ``(data, labels, mask)``; the model is called
    as ``model(data, mask)`` and scored with cross-entropy against ``labels``.

    Args:
        model       - torch.nn.Module; inputs are sent to the device of its parameters
        dataset     - non-empty map-style torch Dataset
        batch_size  - evaluation batch size; defaults to the whole dataset in one batch
        num_workers - DataLoader worker processes
    Returns:
        accuracy  - fraction of correctly classified examples
        mean_loss - cross-entropy averaged over all examples (independent of batch_size)
    Raises:
        ValueError if ``dataset`` is empty.

    Runs without building an autograd graph and restores the model's previous
    train/eval mode on exit.
    """
    n_examples = len(dataset)
    if n_examples == 0:
        raise ValueError("validate() needs a non-empty dataset")

    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size or n_examples, num_workers=num_workers, shuffle=False
    )
    # Use the model's own device rather than assuming CUDA: the model may still live
    # on the CPU even when a GPU is available.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    # Summed loss divided by the example count gives a true mean for any batch_size.
    loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")

    total_count = 0
    num_correct = 0
    tot_loss = 0.0
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for data, labels, mask in loader:
                data, labels, mask = data.to(device), labels.to(device), mask.to(device)
                output = model(data, mask)
                tot_loss += loss_fn(output, labels).item()

                pred = torch.argmax(output, 1)
                num_correct += (pred == labels).sum().item()
                total_count += pred.size(0)
    finally:
        model.train(was_training)

    return num_correct / total_count, tot_loss / total_count


In [None]:
import gc
# Free unreachable Python objects before the evaluation loop below; the number
# displayed is gc.collect()'s count of unreachable objects it found.
gc.collect()

166365

In [None]:
# Evaluate every saved configuration of the grid on the test, train and dev splits.
# Checkpoints are read from <DRIVE_DIR>/models/model_<dropout>_<hidden>_<layers>_<glove|default>.
DRIVE_DIR = '/content/drive/MyDrive/textclassification'
EMBEDDING_DIM = 300  # matches the 300-d GloVe file copied in the first cell
N_CLASSES = 5        # presumably the 5 SST sentiment classes — confirm against training

results = []
device = "cuda" if torch.cuda.is_available() else "cpu"
for params in model_params:
    dropout, hidden_size, n_layers, embeddings = params
    info = f'{dropout}_{hidden_size}_{n_layers}_{"glove" if embeddings is not None else "default"}'
    model_name = f'{DRIVE_DIR}/models/model_{info}'

    # Positional order is the same as the training call: (embedding dim, n_layers, hidden_size, classes).
    m = RNN(EMBEDDING_DIM, n_layers, hidden_size, N_CLASSES, pretrained_embeddings=embeddings, dropout=dropout)
    m.load_state_dict(torch.load(model_name, map_location=device))
    # RNN(...) is built without a device and load_state_dict copies into the module's
    # existing tensors, so the model has to be moved explicitly.
    m.to(device)

    test_acc, test_loss = validate(m, test_data)
    train_acc, train_loss = validate(m, train_data)
    dev_acc, dev_loss = validate(m, dev_data)
    results.append({'test_accuracy': test_acc, 'train_acc': train_acc, 'dev_acc': dev_acc,
                    'test_loss': test_loss, 'train_loss': train_loss, 'dev_loss': dev_loss,
                    'name': info})

    del m
    gc.collect()
  

In [None]:
# One row per configuration, indexed by its "<dropout>_<hidden>_<layers>_<embeddings>" name.
pandas.DataFrame(results).set_index('name')

[{'dev_acc': 0.3923705722070845,
  'dev_loss': 1.4872496128082275,
  'name': '0_256_1_glove',
  'test_accuracy': 0.3927601809954751,
  'test_loss': 1.489555835723877,
  'train_acc': 0.5801732209737828,
  'train_loss': 1.333773136138916},
 {'dev_acc': 0.3142597638510445,
  'dev_loss': 1.5634868144989014,
  'name': '0_256_1_default',
  'test_accuracy': 0.32036199095022627,
  'test_loss': 1.562217354774475,
  'train_acc': 0.5470505617977528,
  'train_loss': 1.3634068965911865},
 {'dev_acc': 0.405086285195277,
  'dev_loss': 1.4830974340438843,
  'name': '0_256_3_glove',
  'test_accuracy': 0.40180995475113124,
  'test_loss': 1.4855982065200806,
  'train_acc': 0.5709269662921348,
  'train_loss': 1.326481819152832},
 {'dev_acc': 0.26067211625794734,
  'dev_loss': 1.5761791467666626,
  'name': '0_256_3_default',
  'test_accuracy': 0.265158371040724,
  'test_loss': 1.5806573629379272,
  'train_acc': 0.32982209737827717,
  'train_loss': 1.5627070665359497},
 {'dev_acc': 0.37420526793823794,
  'd

In [None]:
# Persist the evaluation summary so the tables can be rebuilt without re-evaluating.
results_path = '/content/drive/MyDrive/textclassification/results.pickle'
with open(results_path, 'wb') as results_file:
    pickle.dump(results, results_file)

In [None]:
# Print the results as LaTeX table rows:
# dropout & hidden size & layers & embeddings & test acc & test loss & train acc & dev acc.
# Consecutive rows differ only in the embeddings axis (last grid key), so a rule
# is drawn after each glove/default pair.
for row_number, res in enumerate(results, start=1):
    dropout, hidden_size, n_layers, embeddings = res['name'].split('_')
    # Fixed 4-decimal formatting keeps the columns uniform (round() printed 0.371, 0.3, ...).
    print(
        f'{dropout} & {hidden_size} & {n_layers} & {embeddings} & '
        f'{res["test_accuracy"]:.4f} & {res["test_loss"]:.4f} & '
        f'{res["train_acc"]:.4f} & {res["dev_acc"]:.4f}\\\\'
    )
    if row_number % 2 == 0:
        print('\\hline')

0 & 256 & 1 & glove & 0.3928 & 1.4896 & 0.5802 & 0.3924\\
0 & 256 & 1 & default & 0.3204 & 1.5622 & 0.5471 & 0.3143\\
\hline
0 & 256 & 3 & glove & 0.4018 & 1.4856 & 0.5709 & 0.4051\\
0 & 256 & 3 & default & 0.2652 & 1.5807 & 0.3298 & 0.2607\\
\hline
0 & 512 & 1 & glove & 0.371 & 1.5067 & 0.5218 & 0.3742\\
0 & 512 & 1 & default & 0.2733 & 1.5785 & 0.336 & 0.2552\\
\hline
0 & 512 & 3 & glove & 0.3747 & 1.5247 & 0.6114 & 0.3787\\
0 & 512 & 3 & default & 0.2937 & 1.5758 & 0.3 & 0.2779\\
\hline
0.2 & 256 & 1 & glove & 0.4131 & 1.4777 & 0.5715 & 0.4087\\
0.2 & 256 & 1 & default & 0.324 & 1.5624 & 0.7869 & 0.3079\\
\hline
0.2 & 256 & 3 & glove & 0.3937 & 1.498 & 0.4306 & 0.3896\\
0.2 & 256 & 3 & default & 0.2914 & 1.5831 & 0.3828 & 0.2888\\
\hline
0.2 & 512 & 1 & glove & 0.3941 & 1.4842 & 0.5242 & 0.3815\\
0.2 & 512 & 1 & default & 0.3208 & 1.5707 & 0.8704 & 0.3152\\
\hline
0.2 & 512 & 3 & glove & 0.3575 & 1.5334 & 0.5138 & 0.3688\\
0.2 & 512 & 3 & default & 0.2864 & 1.5874 & 0.3663 & 0.2997\