In [1]:
import torch
import pandas as pd
from TorchDataUtils import *
from NLPDataUtils import *

%matplotlib notebook
import matplotlib.pyplot as plt

%load_ext autoreload
%autoreload 2

# Entrenamiento de word vectors para clasificación

## Pretraining

In [2]:
class WordEmbeddingAGNewsDataset(torch.utils.data.Dataset):
    
    def __init__(self, root='./AG_NEWS/', preprocess=lambda x: x, n_window=2, train=True):
        
        target = 'train.csv' if train else 'test.csv'
        df = pd.read_csv(root + target, header=None, names=['class_idx', 'title', 'description'])
        data = preprocess(df['title'])
        samples = data.apply(self._get_context, n_window=n_window)
        self.data = pd.DataFrame([[subsample[0], subsample[1]] for sample in samples for subsample in sample],
                                 columns=['word','context'])
    
    def __getitem__(self,idx):
        if type(idx) == torch.Tensor:
            idx = idx.item()
        
        context_vector = self.vectorizer.vectorize(self.data['context'].iloc[idx])
        word_index = self.vectorizer.vocabulary.token_to_index(self.data['word'].iloc[idx])
        return context_vector, word_index
    
    def _get_context(self,sentence,n_window):
        
        no_sentence = '<NS>'
        samples = []
        for i, word in enumerate(sentence):
            first_context_word_index = max(0,i-n_window)
            last_context_word_index = min(i+n_window+1, len(sentence))

            context = [no_sentence for j in range(i-n_window,first_context_word_index)] + \
                      sentence[first_context_word_index:i] + \
                      sentence[i+1:last_context_word_index] + \
                      [no_sentence for j in range(last_context_word_index,i+n_window+1)]
            
            samples.append((word,context))

        return samples
    
    
    def __len__(self):
        return len(self.data)
    
    
def GetAGNewsWordEmbeddingsDataset(root, preprocess, n_window=2, cutoff=25):
    
    # Datasets:
    train_dataset = WordEmbeddingAGNewsDataset(root, preprocess=preprocess, n_window=2, train=True)
    train_dataset.vectorizer = Vectorizer([train_dataset.data['word']], cutoff=cutoff)
    test_dataset = WordEmbeddingAGNewsDataset(root, preprocess=preprocess, n_window=2, train=False)
    test_dataset.vectorizer = train_dataset.vectorizer
    
    # Dataloaders:
    train_dataloader, val_dataloader, test_dataloader = generate_data_batches(train_dataset, 
                                                                              test_dataset,    
                                                                              batch_size=64)
    return train_dataloader, val_dataloader, test_dataloader


def preprocess(data):
    df = data.str.replace(r'\(AP\)','')
    df = df.str.replace(r'\(Reuters\)','')
    df = df.str.replace(r'\(AFP\)','')
    df = df.str.replace(r'\(SPACE\.com\)','')
    df = df.str.replace(r'\ba\b','')
    df = df.str.replace(r'\bthe\b','')
    df = df.str.replace(r'\bis\b','')
    df = df.str.replace(r'\bof\b','')
    df = df.str.replace(r'\bto\b','')
    df = df.str.replace(r'[,:;\?\!\"]','')
    df = df.str.replace(r'\s+','<SEP>')
    df = df.str.replace(r"'s<SEP>","<SEP>'s<SEP>")
    df = df.str.split('<SEP>')
    return df


we_train_dataloader, we_val_dataloader, we_test_dataloader = GetAGNewsWordEmbeddingsDataset(root='./AG_NEWS/', 
                                                                               preprocess=preprocess, 
                                                                               cutoff=25)

In [3]:
import torch.nn as nn

class Word2VecCBOW(nn.Module):
    
    def __init__(self, vocab_size, n_embeddings):
        super(Word2VecCBOW, self).__init__()
        self.emb = nn.Linear(vocab_size, n_embeddings)
        self.out = nn.Linear(n_embeddings, vocab_size)
        
    def forward(self, x):
        return self.out(self.emb(x))
    
    def loss(self, scores, target):
        lf = nn.CrossEntropyLoss()
        return lf(scores, target)
    

vocab_size = len(we_train_dataloader.dataset.vectorizer.vocabulary)
n_embeddings = 100
EmbeddingModel = Word2VecCBOW(vocab_size, n_embeddings)

In [4]:
def CheckAccuracy(loader, model, device, input_dtype, target_dtype):  
    num_correct = 0
    num_samples = 0
    model.eval()  
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device, dtype=input_dtype)  
            y = y.to(device=device, dtype=target_dtype)
            
            scores = model(x)
            _, preds = scores.max(1)
            num_correct += (preds == y).sum()
            num_samples += preds.size(0)

        return num_correct, num_samples
        

def SGDTrainModel(model, data, epochs=1, learning_rate=1e-2, sample_loss_every=100, check_on_train=False, verbose=True):
    
    try:
        input_dtype = data['input_dtype'] 
        target_dtype = data['target_dtype']
    except KeyError:
        print('Input or target data type not correctly defined')
        return
    
    try:
        device = torch.device('cuda:1') if torch.cuda.is_available() and data['use_gpu'] else torch.device('cpu')
    except KeyError:
        print('Device not specified')
        return
    
    try:
        train_dataloader = data['train_dataloader']
        val_dataloader = data['val_dataloader']
    except KeyError:
        print('Train or Validation dataloaders not defined')
        return
    
    performance_history = {'iter': [], 'loss': [], 'accuracy': []}
    
    model = model.to(device=device)
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)
    batch_size = len(train_dataloader)
    
    try:
    
        for e in range(epochs):
            for t, (x,y) in enumerate(train_dataloader):
                model.train()
                x = x.to(device=device, dtype=input_dtype)
                y = y.to(device=device, dtype=target_dtype)

                # Forward pass
                scores = model(x) 

                # Backward pass
                loss = model.loss(scores,y)                 
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if (e * batch_size + t) % sample_loss_every == 0:
                    num_correct_val, num_samples_val = CheckAccuracy(val_dataloader, model, device, input_dtype, target_dtype)
                    performance_history['iter'].append(e * batch_size + t)
                    performance_history['loss'].append(loss.item())
                    performance_history['accuracy'].append(float(num_correct_val) / num_samples_val)
                    if verbose:
                        print('Epoch: {}, Batch number: {}'.format(e, t))
                        print('Accuracy on validation dataset: {}/{} ({:.2f}%)'.format(num_correct_val, num_samples_val, 100 * float(num_correct_val) / num_samples_val))
                    
                    if check_on_train and verbose:
                        num_correct_train, num_samples_train = CheckAccuracy(train_dataloader, model, device, input_dtype, target_dtype)
                        print('Accuracy on train dataset: {}/{} ({:.2f}%)'.format(num_correct_train, num_samples_train, 100 * float(num_correct_train) / num_samples_train))
                        print()
                    elif verbose:
                        print()
                        
        return performance_history
                    
    except KeyboardInterrupt:
        
        print('Exiting training...')
        print('Final accuracy registered on validation dataset: {}/{} ({:.2f}%)'.format(num_correct_val, num_samples_val, 100 * float(num_correct_val) / num_samples_val) )
        if check_on_train:
            num_correct_train, num_samples_train = CheckAccuracy(train_dataloader, model, device, input_dtype, target_dtype)
            print('Final accuracy registered on train dataset: {}/{} ({:.2f}%)'.format(num_correct_train, num_samples_train, 100 * float(num_correct_train) / num_samples_train))
            
        return performance_history

In [11]:
# Parámetros de las muestras:
data = {
    'use_gpu': True, # Trasladar o no las muestras a la GPU
    'input_dtype': torch.float, # Tipo de dato de las muestras de entrada
    'target_dtype': torch.long, # Tipo de dato de las muestras de salida
    'train_dataloader': we_train_dataloader, # Dataset de entrenamiento
    'val_dataloader': we_val_dataloader # Dataset de validación
}

# Parámetros de optimización:
epochs = 10 # Cantidad de epochs
sample_loss_every = 500 # Cantidad de iteraciones para calcular la cantidad de aciertos
learning_rate = [1e-3, 1e-4] # Tasa de aprendizaje
check_on_train = False # Queremos ver los resultados también en el train set

# Entrenamiento:
performance_history = []
for lr in learning_rate:
    EmbeddingModel = Word2VecCBOW(vocab_size, n_embeddings)
    performance_history.append(SGDTrainModel(EmbeddingModel, 
                                             data, 
                                             epochs, 
                                             lr, 
                                             sample_loss_every, 
                                             check_on_train, 
                                             verbose=True))
    print('lr={:.2g} completed.'.format(lr))

Epoch: 0, Batch number: 0
Accuracy on validation dataset: 0/15547 (0.00%)

Epoch: 0, Batch number: 500
Accuracy on validation dataset: 288/15547 (1.85%)

Epoch: 0, Batch number: 1000
Accuracy on validation dataset: 290/15547 (1.87%)

Epoch: 0, Batch number: 1500
Accuracy on validation dataset: 290/15547 (1.87%)

Epoch: 0, Batch number: 2000
Accuracy on validation dataset: 290/15547 (1.87%)

Epoch: 0, Batch number: 2500
Accuracy on validation dataset: 290/15547 (1.87%)

Epoch: 0, Batch number: 3000
Accuracy on validation dataset: 290/15547 (1.87%)

Epoch: 0, Batch number: 3500
Accuracy on validation dataset: 290/15547 (1.87%)

Epoch: 0, Batch number: 4000
Accuracy on validation dataset: 290/15547 (1.87%)

Epoch: 0, Batch number: 4500
Accuracy on validation dataset: 290/15547 (1.87%)

Epoch: 0, Batch number: 5000
Accuracy on validation dataset: 290/15547 (1.87%)

Epoch: 0, Batch number: 5500
Accuracy on validation dataset: 290/15547 (1.87%)

Epoch: 0, Batch number: 6000
Accuracy on valid

Epoch: 4, Batch number: 3888
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 4, Batch number: 4388
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 4, Batch number: 4888
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 4, Batch number: 5388
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 4, Batch number: 5888
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 4, Batch number: 6388
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 4, Batch number: 6888
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 4, Batch number: 7388
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 4, Batch number: 7888
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 4, Batch number: 8388
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 4, Batch number: 8888
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 4, Batch number: 9388
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 4, Batch number: 9888
Accuracy on

Epoch: 8, Batch number: 7776
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 8, Batch number: 8276
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 8, Batch number: 8776
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 8, Batch number: 9276
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 8, Batch number: 9776
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 8, Batch number: 10276
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 8, Batch number: 10776
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 8, Batch number: 11276
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 8, Batch number: 11776
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 9, Batch number: 373
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 9, Batch number: 873
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 9, Batch number: 1373
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 9, Batch number: 1873
Accuracy 

Epoch: 2, Batch number: 11194
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 2, Batch number: 11694
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 3, Batch number: 291
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 3, Batch number: 791
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 3, Batch number: 1291
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 3, Batch number: 1791
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 3, Batch number: 2291
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 3, Batch number: 2791
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 3, Batch number: 3291
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 3, Batch number: 3791
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 3, Batch number: 4291
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 3, Batch number: 4791
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 3, Batch number: 5291
Accuracy on

Epoch: 7, Batch number: 3179
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 7, Batch number: 3679
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 7, Batch number: 4179
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 7, Batch number: 4679
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 7, Batch number: 5179
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 7, Batch number: 5679
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 7, Batch number: 6179
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 7, Batch number: 6679
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 7, Batch number: 7179
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 7, Batch number: 7679
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 7, Batch number: 8179
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 7, Batch number: 8679
Accuracy on validation dataset: 316/15547 (2.03%)

Epoch: 7, Batch number: 9179
Accuracy on

In [12]:
fig, ax = plt.subplots()
for lr, ph in zip(learning_rate, performance_history):
    ax.plot(ph['iter'],ph['loss'],label='lr={:.2g}'.format(lr))
    
ax.legend()

<IPython.core.display.Javascript object>

<matplotlib.legend.Legend at 0x7f9a124acf28>

## Finetuning

In [9]:
class AGNewsDataset(torch.utils.data.Dataset):
    
    def __init__(self, root='./AG_NEWS/', preprocess=lambda x: x, train=True):
        
        target = 'train.csv' if train else 'test.csv'
        df = pd.read_csv(root + target, header=None, names=['class_idx', 'title', 'description'])
        
        # Etiquetas:
        self.cls_indeces = torch.tensor(df['class_idx'].tolist(), dtype=torch.long) - 1
        
        # DataSeries con las muestras de entradas:
        data = df['title']
        self.data = preprocess(data)
    
    def __getitem__(self,idx):
        if type(idx) == torch.Tensor:
            idx = idx.item()
        return self.vectorizer.vectorize(self.data.iloc[idx]), self.cls_indeces[idx]
    
    def __len__(self):
        return len(self.cls_indeces)
    
    
def GetAGNewsDataset(root, preprocess, cutoff=25):
    
    # Datasets:
    train_dataset = AGNewsDataset(root, preprocess=preprocess, train=True)
    train_dataset.vectorizer = Vectorizer([train_dataset.data], cutoff=cutoff)
    test_dataset = AGNewsDataset(root, preprocess=preprocess, train=False)
    test_dataset.vectorizer = train_dataset.vectorizer
    
    # Dataloaders:
    train_dataloader, val_dataloader, test_dataloader = generate_data_batches(train_dataset, 
                                                                              test_dataset,    
                                                                              batch_size=64)
    return train_dataloader, val_dataloader, test_dataloader

In [10]:
tc_train_dataloader, tc_val_dataloader, tc_test_dataloader = GetAGNewsDataset(root='./AG_NEWS/', 
                                                                     preprocess=preprocess, 
                                                                     cutoff=25)

import torch.nn as nn

class TextClassifier(nn.Module):
    
    def __init__(self, EmbeddingLayer, n_classes):
        super(TextClassifier, self).__init__()
        self.emb = EmbeddingLayer
#         for param in self.emb.parameters():
#             param.requires_grad = True
        self.out = nn.Linear(self.emb.out_features, n_classes)
        
    def forward(self, x):
        return self.out(self.emb(x))
    
    def loss(self, scores, target):
        lf = nn.CrossEntropyLoss()
        return lf(scores, target)
    

n_classes = 4
ClassifierModel = TextClassifier(EmbeddingModel.emb, n_classes)

In [13]:
# Parámetros de las muestras:
data = {
    'use_gpu': True, # Trasladar o no las muestras a la GPU
    'input_dtype': torch.float, # Tipo de dato de las muestras de entrada
    'target_dtype': torch.long, # Tipo de dato de las muestras de salida
    'train_dataloader': tc_train_dataloader, # Dataset de entrenamiento
    'val_dataloader': tc_val_dataloader # Dataset de validación
}

# Parámetros de optimización:
epochs = 10 # Cantidad de epochs
sample_loss_every = 500 # Cantidad de iteraciones para calcular la cantidad de aciertos
learning_rate = 1e-4 # Tasa de aprendizaje
check_on_train = False # Queremos ver los resultados también en el train set

# Entrenamiento:
performance_history = SGDTrainModel(ClassifierModel, data, epochs, learning_rate, sample_loss_every, check_on_train)

Epoch: 0, Batch number: 0
Accuracy on validation dataset: 1276/2400 (53.17%)

Epoch: 0, Batch number: 500
Accuracy on validation dataset: 1269/2400 (52.88%)

Epoch: 0, Batch number: 1000
Accuracy on validation dataset: 1275/2400 (53.12%)

Epoch: 0, Batch number: 1500
Accuracy on validation dataset: 1293/2400 (53.88%)

Epoch: 1, Batch number: 162
Accuracy on validation dataset: 1285/2400 (53.54%)

Epoch: 1, Batch number: 662
Accuracy on validation dataset: 1323/2400 (55.12%)

Epoch: 1, Batch number: 1162
Accuracy on validation dataset: 1268/2400 (52.83%)

Epoch: 1, Batch number: 1662
Accuracy on validation dataset: 1286/2400 (53.58%)

Epoch: 2, Batch number: 324
Accuracy on validation dataset: 1314/2400 (54.75%)

Epoch: 2, Batch number: 824
Accuracy on validation dataset: 1308/2400 (54.50%)

Epoch: 2, Batch number: 1324
Accuracy on validation dataset: 1281/2400 (53.38%)

Epoch: 2, Batch number: 1824
Accuracy on validation dataset: 1252/2400 (52.17%)

Epoch: 3, Batch number: 486
Accuracy