In [1]:
import argparse
import logging
import time

import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
from torchtext.data.utils import get_tokenizer
from torchtext.datasets import DATASETS
from torchtext.prototype.transforms import load_sp_model, PRETRAINED_SP_MODEL, SentencePieceTokenizer
from torchtext.utils import download_from_url
from torchtext.vocab import build_vocab_from_iterator
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from torchtext.vocab import GloVe, FastText

In [2]:
# Constants
dataset_name = "AG_NEWS"
data_dir = "data"
device = "cpu"
embedding_dim = 300
learning_rate = 4.0
batch_size = 16
num_epochs = 5
padding_value = 0
padding_idx = padding_value


In [3]:
basic_tokenizer = get_tokenizer("basic_english")

In [4]:
sp_model_path = download_from_url(PRETRAINED_SP_MODEL["text_unigram_15000"], root="data")
sp_model = load_sp_model(sp_model_path)
sentencepiece_tokenizer = SentencePieceTokenizer(sp_model)

In [5]:
tokenizer = basic_tokenizer

In [6]:
def yield_tokens(data_iter):
    for _, text in data_iter:
        yield tokenizer(text)

In [7]:
train_iter = DATASETS[dataset_name](root=data_dir, split="train")
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=('<pad>', '<unk>'))
vocab.set_default_index(vocab['<unk>'])

In [8]:
glove = GloVe(name='840B', dim=300)
fasttext = FastText()

In [9]:
# tokenize the string input and shift lables
def text_pipeline(text):
    return vocab(tokenizer(text))

def label_pipeline(label):
    return int(label) - 1

In [10]:
def collate_batch(batch):
    label_list, text_list = [], []
    for (_label, _text) in batch:
        # Change the label from {1, 2, 3, 4} to {0, 1, 2, 3}
        label_list.append(label_pipeline(_label))
                
        # Return a list of ints.
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text.clone().detach())
    
    label_list = torch.tensor(label_list, dtype=torch.int64)
    
    # padding all sequences with 0
    text_list = pad_sequence(text_list, batch_first=True, padding_value=padding_value)

    return label_list.to(device), text_list.to(device)


In [11]:
# Get the data
train_iter = DATASETS[dataset_name](root=data_dir, split="train")
num_classes = len(set([label for (label, _) in train_iter]))

print(f"The number of classes is {num_classes} ...")

The number of classes is 4 ...


In [12]:
# Define the model
class TextClassifier(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_classes, use_pretrained=True, freeze_embeddings=True):
        super(TextClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)

        if use_pretrained:
            self.embedding.weight.requires_grad = False
            for i in range(vocab_size):
                token = vocab.lookup_token(i)
                self.embedding.weight[i, :] = glove.get_vecs_by_tokens(tokenizer(token), lower_case_backup=True)
            self.embedding.weight.requires_grad = True

        if freeze_embeddings:
            self.embedding.weight.requires_grad = False

        self.fc = nn.Linear(embedding_dim, num_classes)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text):
        embedded = self.embedding(text)
        embedded_mean = embedded.mean(axis=1).squeeze(1)
        output = self.fc(embedded_mean)

        return output

In [13]:
#Set up the model

criterion = torch.nn.CrossEntropyLoss().to(device)
model = TextClassifier(len(vocab), embedding_dim, num_classes, use_pretrained=True, freeze_embeddings=True).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)

In [14]:
train_iter, test_iter = DATASETS[dataset_name]()
train_dataset = to_map_style_dataset(train_iter)
test_dataset = to_map_style_dataset(test_iter)

num_train = int(len(train_dataset) * 0.95)
split_train, split_valid = random_split(train_dataset, [num_train, len(train_dataset) - num_train])

train_dataloader = DataLoader(split_train, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)

In [15]:
def train_epoch(dataloader, model, optimizer, criterion, epoch):
    model.train()
    total_acc, total_count = 0, 0
    log_interval = 500

    for idx, (labels, texts) in enumerate(dataloader):
        optimizer.zero_grad()
        logits = model(texts)
        loss = criterion(input=logits, target=labels)

        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)

        optimizer.step()
        total_acc += (logits.argmax(1) == labels).sum().item()
        total_count += labels.size(0)

        if idx % log_interval == 0 and idx > 0:
            print(f"| epoch {epoch:3d} | {idx:5d}/{len(dataloader):5d} batches | accuracy {total_acc / total_count:8.3f}")
            total_acc, total_count = 0, 0


In [16]:
def evaluate(dataloader, model):
    model.eval()
    total_acc, total_count = 0, 0

    with torch.no_grad():
        for idx, (labels, texts) in enumerate(dataloader):
            logits = model(texts)
            total_acc += (logits.argmax(1) == labels).sum().item()
            total_count += labels.size(0)

    return total_acc / total_count

In [17]:
for epoch in range(1, num_epochs + 1):
    epoch_start_time = time.time()
    train_epoch(train_dataloader, model, optimizer, criterion, epoch)
    accu_val = evaluate(valid_dataloader, model)
    scheduler.step()
    print("-" * 59)
    print(
        "| end of epoch {:3d} | time: {:5.2f}s | "
        "valid accuracy {:8.3f} ".format(epoch, time.time() - epoch_start_time, accu_val)
    )
    print("-" * 59)

print("Checking the results of test dataset.")
accu_test = evaluate(test_dataloader, model)
print("test accuracy {:8.3f}".format(accu_test))

| epoch   1 |   500/ 7125 batches | accuracy    0.347
| epoch   1 |  1000/ 7125 batches | accuracy    0.461
| epoch   1 |  1500/ 7125 batches | accuracy    0.518
| epoch   1 |  2000/ 7125 batches | accuracy    0.549
| epoch   1 |  2500/ 7125 batches | accuracy    0.570
| epoch   1 |  3000/ 7125 batches | accuracy    0.576
| epoch   1 |  3500/ 7125 batches | accuracy    0.595
| epoch   1 |  4000/ 7125 batches | accuracy    0.594
| epoch   1 |  4500/ 7125 batches | accuracy    0.609
| epoch   1 |  5000/ 7125 batches | accuracy    0.611
| epoch   1 |  5500/ 7125 batches | accuracy    0.615
| epoch   1 |  6000/ 7125 batches | accuracy    0.616
| epoch   1 |  6500/ 7125 batches | accuracy    0.630
| epoch   1 |  7000/ 7125 batches | accuracy    0.635
-----------------------------------------------------------
| end of epoch   1 | time: 10.51s | valid accuracy    0.657 
-----------------------------------------------------------
| epoch   2 |   500/ 7125 batches | accuracy    0.679
| epoch  