In [None]:
# Mount Google Drive so the '/content/drive/MyDrive/...' paths used below resolve.
from google.colab import drive

MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Copy the project folder from Drive into the Colab working directory.
# Presumably this supplies the local modules imported below (glove_reader, model,
# dataloader) and the './trees/*.txt' files read later — TODO confirm.
# NOTE(review): the recorded run of this cell was interrupted (^C); re-run it to completion.
!cp -r /content/drive/MyDrive/textclassification/* .

^C


In [None]:
!pip3 install pytreebank



In [None]:
import pytreebank
import numpy as np
import torch
import sys
import argparse
from glove_reader import GloveReader
from gensim.utils import tokenize
import itertools
from model import RNN
import time
from dataloader import SST
import pandas as pd

def transform_and_pad(data):
    """
        Given the input dataset, returns three tensors of the padded data (on the left!)

        Args:
            data    - sequence of rows ``(label, tok_1, ..., tok_k)``: the first element
                      is the class label, the rest are integer token ids. Rows may have
                      different lengths (k may be 0) and may be tuples or lists.
        Returns:
            content - LongTensor N x max_len, token ids left-padded with 0
                      (use ``mask`` to tell padding from a real token id 0)
            labels  - Tensor of shape (N,), the first element of every row
            mask    - LongTensor N x max_len (binary): 1 on real tokens, 0 on padding
        Raises:
            ValueError - if ``data`` is empty (there is no length to pad to).
    """
    rows = list(data)
    if not rows:
        raise ValueError("transform_and_pad() needs at least one example")

    labels = [row[0] for row in rows]
    # tuple(...) so list rows work too: `(0,) * n + list` would raise TypeError.
    sequences = [tuple(row[1:]) for row in rows]
    max_len = max(len(seq) for seq in sequences)

    content = [(0,) * (max_len - len(seq)) + seq for seq in sequences]
    content_mask = [(0,) * (max_len - len(seq)) + (1,) * len(seq) for seq in sequences]

    # Explicit dtype: when every row is empty (max_len == 0) torch.tensor would
    # otherwise infer float32 for the (N, 0) tensors.
    return (
        torch.tensor(content, dtype=torch.long),
        torch.tensor(labels),
        torch.tensor(content_mask, dtype=torch.long),
    )

In [None]:
def validate(model, dev_dataset, batch_size=None, num_workers=4):
    """Evaluate ``model`` on ``dev_dataset``.

    Args:
        model: module called as ``model(data, mask)`` that returns class logits
            of shape (batch, n_classes).
        dev_dataset: dataset whose batches unpack as ``(data, labels, mask)``.
        batch_size: evaluation batch size; ``None`` (default) scores the whole
            dataset in a single batch, as before.
        num_workers: DataLoader worker processes. NOTE(review): more workers than
            host CPUs makes PyTorch emit a warning; 0 loads in-process.

    Returns:
        (accuracy, loss): fraction of correctly classified examples and the
        per-example mean cross-entropy loss (independent of ``batch_size``).

    Raises:
        ValueError: if ``dev_dataset`` is empty.

    No autograd graph is built, and the model's previous train/eval mode is
    restored on exit (even if evaluation raises).
    """
    if len(dev_dataset) == 0:
        raise ValueError("validate() needs a non-empty dev_dataset")

    dev = torch.utils.data.DataLoader(
        dev_dataset,
        batch_size=batch_size or len(dev_dataset),
        num_workers=num_workers,
        shuffle=False,
    )
    # Summed per example, then divided by the count below, so the result is a
    # true mean even when the data is split into several batches.
    loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    total_count = 0
    num_correct = 0
    tot_loss = 0.0

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for data, labels, mask in dev:
                data = data.to(device)
                labels = labels.to(device)
                mask = mask.to(device)

                output = model(data, mask)
                tot_loss += loss_fn(output, labels).item()

                pred = torch.argmax(output, 1)
                num_correct += (pred == labels).sum().item()
                total_count += pred.size(0)
    finally:
        model.train(was_training)

    return num_correct / total_count, tot_loss / total_count


In [None]:
def train(model, train_dataset, dev_dataset, max_epochs=100, model_name='model.save', stopping_counter=20,
          batch_size=256, num_workers=2):
    """Train ``model`` with Adam + cross-entropy, early-stopping on dev loss.

    After every epoch the model is scored with ``validate`` on ``dev_dataset``.
    Whenever the dev loss improves, ``model.state_dict()`` is saved to
    ``model_name``. Training stops after ``stopping_counter`` consecutive epochs
    without improvement, or after ``max_epochs`` epochs.

    Args:
        model: module called as ``model(data, mask)`` returning class logits;
            it is moved to CUDA when available.
        train_dataset: dataset whose batches unpack as ``(data, labels, mask)``.
        dev_dataset: held-out dataset with the same item layout.
        max_epochs: upper bound on the number of epochs.
        model_name: file path the best ``state_dict`` is written to.
        stopping_counter: patience, i.e. epochs without dev-loss improvement.
        batch_size: training batch size.
        num_workers: DataLoader worker processes for the training set.

    Returns:
        (losses, accs, dev_losses, dev_accs): per-epoch mean training loss,
        training accuracy, dev loss and dev accuracy.

    Note: on return ``model`` holds the weights of the LAST epoch; the best
    weights exist only on disk at ``model_name``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Move first, then build the optimizer over the parameters in their final location.
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters())
    loss_fn = torch.nn.CrossEntropyLoss()

    # Named `loader` (not `train`) so it does not shadow this function.
    loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                         num_workers=num_workers, shuffle=True)

    losses, accs, dev_losses, dev_accs = [], [], [], []
    best_loss = float('+inf')
    epochs_without_improvement = 0  # for early stopping

    model.train()
    for epoch in range(max_epochs):
        print("-" * 10 , "EPOCH ", epoch,  "-"*10)

        num_correct = 0.0
        total_count = 0.0
        epoch_loss = 0.0
        n_batches = 0
        start = time.time()

        for i, (data, labels, mask) in enumerate(loader):
            # Parenthesized: `i + 1 % 100` parses as `i + (1 % 100)` and never fired.
            if (i + 1) % 100 == 0:
                print('Batch ', i)

            data = data.to(device)
            labels = labels.to(device)
            mask = mask.to(device)

            optimizer.zero_grad()
            output = model(data, mask)
            loss = loss_fn(output, labels)
            loss.backward()
            optimizer.step()

            pred = torch.argmax(output, 1)
            num_correct += (pred == labels).sum().item()
            total_count += pred.size(0)
            epoch_loss += loss.item()
            n_batches += 1

        # Counted explicitly so an empty loader gives 0.0 instead of NameError / ZeroDivisionError.
        losses.append(epoch_loss / n_batches if n_batches else 0.0)
        accs.append(num_correct / total_count if total_count else 0.0)

        end = time.time()
        epochs_without_improvement += 1

        eval_acc, eval_loss = validate(model, dev_dataset)

        print(f'Loss={losses[-1]}, Accuracy={accs[-1]}, Dev Accuracy={eval_acc}, epoch took {end - start}s')

        dev_losses.append(eval_loss)
        dev_accs.append(eval_acc)

        if eval_loss < best_loss:
            best_loss = eval_loss
            epochs_without_improvement = 0
            print("Saving new best model...")
            torch.save(model.state_dict(), model_name)

        if epochs_without_improvement == stopping_counter:
            break

    return losses, accs, dev_losses, dev_accs


In [None]:
def load_data(path, glove_reader=None):
    """Load one SST tree file and convert it to padded tensors.

    For every tree, the first entry of ``tree.to_labeled_lines()`` is used as
    the (label, sentence) example. TODO confirm this is the sentence-level
    (root) label. The sentence is tokenized and lower-cased with gensim's
    ``tokenize``. Tokens missing from the GloVe vocabulary are dropped, and the
    rest are mapped to their ``words2idx`` ids.

    Args:
        path: tree file readable by ``pytreebank.import_tree_corpus``.
        glove_reader: object exposing a ``words2idx`` dict. Defaults to the
            notebook-level ``glove``, so existing ``load_data(path)`` calls keep
            working. Pass it explicitly to avoid depending on that global.

    Returns:
        (content, labels, mask) as produced by ``transform_and_pad``.

    NOTE(review): a sentence with no in-vocabulary tokens becomes an all-padding
    row (mask all zeros); confirm the model tolerates that.
    """
    # `glove` is only looked up when no reader is passed in.
    words2idx = (glove if glove_reader is None else glove_reader).words2idx

    examples = []
    for tree in pytreebank.import_tree_corpus(path):
        label, sentence = tree.to_labeled_lines()[0]
        token_ids = tuple(
            words2idx[token]
            for token in tokenize(sentence, lower=True)
            if token in words2idx
        )
        examples.append((label, *token_ids))

    # pad data and transform to tensors
    return transform_and_pad(examples)

In [None]:
# Load the GloVe reader once. Later cells read `glove.words2idx` (word -> integer
# id, used by load_data to filter and encode tokens) and `glove.embeddings`
# (passed to the model as `pretrained_embeddings` in the grid search).
glove = GloveReader()

In [None]:
# Load the SST training split as padded tensors wrapped in the project's SST dataset.
train_path = './trees/train.txt'
train_data = SST(*load_data(train_path))

In [None]:
# Load the SST dev split; train() scores it after every epoch for early stopping.
dev_path = './trees/dev.txt'
dev_data = SST(*load_data(dev_path))

In [None]:
from itertools import product

# Hyper-parameter grid for the search below. Key order fixes the layout of each
# `params` tuple (dropout, hidden_size, n_layers, embeddings) that the training
# loop indexes into; `None` embeddings means no pretrained GloVe weights.
configuration = dict(
    dropout=[0, 0.2],
    hidden_size=[256, 512],
    n_layers=[1, 3],
    embeddings=[glove.embeddings, None],
)

In [None]:
# Parameter names in grid order, and the full Cartesian product of their values.
param_names = list(configuration.keys())
model_params = list(product(*(configuration[name] for name in param_names)))

In [None]:
# Show the hyper-parameter names in the order used by every `params` tuple.
param_names

['dropout', 'hidden_size', 'n_layers', 'embeddings']

In [None]:
import os
import pickle

MODELS_DIR = '/content/drive/MyDrive/textclassification/models'
DATA_DIR = '/content/drive/MyDrive/textclassification/data'
# torch.save / open() below fail if the target folders are missing.
os.makedirs(MODELS_DIR, exist_ok=True)
os.makedirs(DATA_DIR, exist_ok=True)

# Resume point into the grid: earlier configurations were trained in a previous
# session. Set to 0 to run the whole search.
START_CONFIG = 9

for params in model_params[START_CONFIG:]:
    dropout, hidden_size, n_layers, embeddings = params
    info = f'{dropout}_{hidden_size}_{n_layers}_{"glove" if embeddings is not None else "default"}'
    model_name = f'{MODELS_DIR}/model_{info}'
    data_name = f'{DATA_DIR}/data_{info}'

    # Only `RNN` is imported (`from model import RNN`); the name `model` is not
    # bound, so `model.RNN(...)` raised NameError on a fresh kernel.
    # NOTE(review): positional order (300, n_layers, hidden_size, 5) is kept from the
    # original call; confirm against RNN's signature. 300 presumably matches the
    # GloVe vector size and 5 the number of SST sentiment classes.
    m = RNN(300, n_layers, hidden_size, 5, pretrained_embeddings=embeddings, dropout=dropout)

    # (losses, accs, dev_losses, dev_accs) returned by train()
    history = train(m, train_data, dev_data, model_name=model_name)

    with open(data_name, 'wb') as f:
        pickle.dump(history, f)

    del m
    del history

---------- EPOCH  0 ----------


  cpuset_checked))


Loss=1.5901698785669662, Accuracy=0.2598314606741573, Dev Accuracy=0.259763851044505, epoch took 25.209333658218384s
Saving new best model...
---------- EPOCH  1 ----------
Loss=1.5506299383500044, Accuracy=0.330875468164794, Dev Accuracy=0.2633969118982743, epoch took 25.141629457473755s
Saving new best model...
---------- EPOCH  2 ----------
Loss=1.5075141857652103, Accuracy=0.3838951310861423, Dev Accuracy=0.28065395095367845, epoch took 25.120407581329346s
Saving new best model...
---------- EPOCH  3 ----------
Loss=1.4516119220677544, Accuracy=0.4524812734082397, Dev Accuracy=0.2524977293369664, epoch took 25.234118700027466s
---------- EPOCH  4 ----------
Loss=1.3972482751397526, Accuracy=0.5117041198501873, Dev Accuracy=0.2888283378746594, epoch took 25.06570839881897s
---------- EPOCH  5 ----------
Loss=1.3459703922271729, Accuracy=0.568000936329588, Dev Accuracy=0.28973660308810173, epoch took 25.03239417076111s
---------- EPOCH  6 ----------
Loss=1.3046011013143204, Accuracy=