In [2]:
from torchtext import datasets
from torchtext.data import to_map_style_dataset
import numpy as np
from torchtext.vocab import GloVe, vocab
from torchtext.datasets import AG_NEWS
from torchtext.data.utils import get_tokenizer
import torch
import torch.nn as nn

# Load the dataset
train_iter, test_iter = datasets.AG_NEWS(split=('train', 'test'))

train_ds = to_map_style_dataset(train_iter)
test_ds = to_map_style_dataset(test_iter)


In [3]:
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.utils import get_tokenizer

tokenizer = get_tokenizer("basic_english")

my_vocab = build_vocab_from_iterator(map(lambda x: tokenizer(x[1]), train_ds), specials=['<pad>','<unk>'])
my_vocab.set_default_index(my_vocab["<unk>"])

In [18]:
from torch.utils.data import DataLoader
import torch
num_class = len(set([label for (label, _) in train_iter]))

text_pipeline = lambda x: my_vocab(tokenizer(x))
label_pipeline = lambda x: int(x) - 1

def collate_batch(batch):
    label_list, text_list = [], []
    for sample in batch:
        label, text = sample
        text_list.append(torch.tensor(text_pipeline(text), dtype=torch.long))
        label_list.append(label_pipeline(label))
    return torch.tensor(label_list, dtype=torch.long), torch.nn.utils.rnn.pad_sequence(text_list, batch_first=True, padding_value=my_vocab["<pad>"])

train_dataloader = DataLoader(
    train_iter, batch_size=64, shuffle=True, collate_fn=collate_batch
)

test_dataloader = DataLoader(
    test_iter, batch_size=64, shuffle=True, collate_fn=collate_batch
)

In [14]:
glove_vectors = GloVe(name='6B', dim=300)


tensor([-3.3712e-01, -2.1691e-01, -6.6365e-03, -4.1625e-01, -1.2555e+00,
        -2.8466e-02, -7.2195e-01, -5.2887e-01,  7.2085e-03,  3.1997e-01,
         2.9425e-02, -1.3236e-02,  4.3511e-01,  2.5716e-01,  3.8995e-01,
        -1.1968e-01,  1.5035e-01,  4.4762e-01,  2.8407e-01,  4.9339e-01,
         6.2826e-01,  2.2888e-01, -4.0385e-01,  2.7364e-02,  7.3679e-03,
         1.3995e-01,  2.3346e-01,  6.8122e-02,  4.8422e-01, -1.9578e-02,
        -5.4751e-01, -5.4983e-01, -3.4091e-02,  8.0017e-03, -4.3065e-01,
        -1.8969e-02, -8.5670e-02, -8.1123e-01, -2.1080e-01,  3.7784e-01,
        -3.5046e-01,  1.3684e-01, -5.5661e-01,  1.6835e-01, -2.2952e-01,
        -1.6184e-01,  6.7345e-01, -4.6597e-01, -3.1834e-02, -2.6037e-01,
        -1.7797e-01,  1.9436e-02,  1.0727e-01,  6.6534e-01, -3.4836e-01,
         4.7833e-02,  1.6440e-01,  1.4088e-01,  1.9204e-01, -3.5009e-01,
         2.6236e-01,  1.7626e-01, -3.1367e-01,  1.1709e-01,  2.0378e-01,
         6.1775e-01,  4.9075e-01, -7.5210e-02, -1.1

In [12]:
from torchtext.vocab import vocab
from collections import Counter, OrderedDict

unk_token = "<unk>"
unk_index = 0
glove_vocab = vocab(glove_vectors.stoi)
glove_vocab.insert_token("<unk>",unk_index)
glove_vocab.set_default_index(unk_index)



In [24]:
# versión 1
def pretrain_embeddings(tokens, fastText):
    
    embeddings = [fastText[token] for token in tokens]

    return torch.stack( list( map(torch.tensor, embeddings) ) )

embeddings = pretrain_embeddings(my_vocab.get_itos(), glove_vectors)


  return torch.stack( list( map(torch.tensor, embeddings) ) )


torch.Size([95812, 300])

In [16]:
import torch
import torch.nn as nn


class LSTMTextClassificationModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_class, pretrained_embeddings):
        super(LSTMTextClassificationModel, self).__init__()
        
        self.embedding = nn.Embedding.from_pretrained(pretrained_embeddings, freeze=True)

        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, num_class)

    def forward(self, text):
        embedded = self.embedding(text)  
        lstm_out, _ = self.lstm(embedded)
        last_output = lstm_out[:, -1, :]
        output = self.fc(last_output)
        return output

vocab_size = len(my_vocab)  
embed_dim = 300
hidden_dim = 64
num_class = 4

model = LSTMTextClassificationModel(vocab_size, embed_dim, hidden_dim, num_class, embeddings)



In [22]:
import time

# Hyperparameters
EPOCHS = 10 # epoch
LR = 5  # learning rate

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)

def train(dataloader):
    model.train()
    total_acc, total_count, max_acc = 0, 0, 0
    log_interval = 500
    start_time = time.time()

    for idx, (label, text) in enumerate(dataloader):
        optimizer.zero_grad()
        predicted_label = model(text)
        loss = criterion(predicted_label, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()

        total_acc += (predicted_label.argmax(1) == label).sum().item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('| {:5d} batches '
                  '| accuracy {:8.3f}'.format(idx, total_acc / total_count))

            if max_acc < total_acc / total_count:
                max_acc = total_acc / total_count
                
            total_acc, total_count = 0, 0
            start_time = time.time()
    return max_acc


def evaluate(dataloader):
    model.eval()
    total_acc, total_count = 0, 0

    with torch.no_grad():
        for idx, (label, text) in enumerate(dataloader):
            predicted_label = model(text)
            loss = criterion(predicted_label, label)
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
    return total_acc / total_count

In [23]:
for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()

    accu_train = train(train_dataloader)
    accu_val = evaluate(test_dataloader)

    #if accu_train > accu_val:
    #    scheduler.step()
    
    print("-" * 59)
    print(
        "| end of epoch {:3d} | time: {:5.2f}s | "
        "valid accuracy {:8.3f} ".format(
            epoch, time.time() - epoch_start_time, accu_val
        )
    )
    print("-" * 59)

|   500 batches | accuracy    0.252
|  1000 batches | accuracy    0.255
|  1500 batches | accuracy    0.269
-----------------------------------------------------------
| end of epoch   1 | time: 67.44s | valid accuracy    0.834 
-----------------------------------------------------------
|   500 batches | accuracy    0.828
|  1000 batches | accuracy    0.875
|  1500 batches | accuracy    0.902
-----------------------------------------------------------
| end of epoch   2 | time: 74.78s | valid accuracy    0.892 
-----------------------------------------------------------
|   500 batches | accuracy    0.889
|  1000 batches | accuracy    0.900
|  1500 batches | accuracy    0.914
-----------------------------------------------------------
| end of epoch   3 | time: 84.86s | valid accuracy    0.893 
-----------------------------------------------------------
|   500 batches | accuracy    0.901
|  1000 batches | accuracy    0.909
|  1500 batches | accuracy    0.923
-------------------------