In [1]:
import collections

import torch
from torch import nn
from torch.utils.data import DataLoader

from torchtext.data.utils import get_tokenizer
from torchtext.datasets import AG_NEWS

from torchtext.vocab import Vocab

We use one of `torchtext`'s built-in datasets so that we can compare our results against known benchmarks.

Here we use the [AG News Dataset](https://paperswithcode.com/dataset/ag-news).

Each sample from the iterator is a tuple of `(label, text)` where `text` is an amalgamation of the `title`, `source` and `description` fields that are defined in the [original source](http://groups.di.unipi.it/~gulli/AG_corpus_of_news_articles.html).

The labels are:  
1. World
2. Sports
3. Business
4. Sci/Tec
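
`nn.CrossEntropyLoss` expects class indices that start at 0, so the 1-based labels above have to be shifted by one (this is what `LabelPipeline` does further down). Here is a minimal sketch of that mapping. `AG_NEWS_LABELS`, `to_class_index` and `CLASS_NAMES` are hypothetical names used only for illustration:

```python
# Hypothetical mapping from the 1-based AG News labels to readable names
AG_NEWS_LABELS = {1: "World", 2: "Sports", 3: "Business", 4: "Sci/Tec"}

def to_class_index(label: int) -> int:
    """Shift a 1-based AG News label to the 0-based index used by the loss."""
    return int(label) - 1

# Class names ordered by their 0-based index
CLASS_NAMES = [AG_NEWS_LABELS[k] for k in sorted(AG_NEWS_LABELS)]

print(CLASS_NAMES[to_class_index(3)])  # Business
```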

In PyTorch, a [DataLoader](https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader) wraps a Dataset as an iterable and allows for batching, sampling, shuffling and multiprocess data loading. The AG News dataset is an _iterable-style_ dataset, so we can't use sampling or shuffling: there is no random access to individual samples, so in principle any shuffling would have to be done by the iterator itself.
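
The difference is easy to see with a toy iterable-style dataset. In this sketch `CountingStream` is a hypothetical stand-in for a streaming dataset like `AG_NEWS`. Batching works as expected, but `DataLoader` raises a `ValueError` as soon as we ask it to shuffle:

```python
import torch
from torch.utils.data import DataLoader, IterableDataset

class CountingStream(IterableDataset):
    """Hypothetical stand-in for a streaming (iterable-style) dataset."""
    def __init__(self, n):
        self.n = n

    def __iter__(self):
        return iter(range(self.n))

stream = CountingStream(10)

# Batching works: samples are pulled from the iterator in order
batches = [b.tolist() for b in DataLoader(stream, batch_size=4)]

# Shuffling does not: DataLoader rejects shuffle=True for iterable-style datasets
try:
    DataLoader(stream, batch_size=4, shuffle=True)
    shuffle_rejected = False
except ValueError:
    shuffle_rejected = True

print(batches)           # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(shuffle_rejected)  # True
```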

## Text Processing

We want to apply a simple text processing pipeline to each sample in our data:
 1. Tokenise: split our inputs into individual words
 2. Encode each word as an integer (its index in our vocabulary)
 
We can use any tokeniser we want, but for simplicity we just use [`torchtext`'s provided `basic_english` tokeniser](https://pytorch.org/text/stable/data_utils.html). This means that punctuation gets its own token for now.
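
`basic_english` lowercases the text and pads punctuation with spaces before splitting on whitespace, which is why punctuation ends up as separate tokens. A rough approximation follows. `simple_tokenizer` is hypothetical and is not torchtext's implementation (its punctuation rules are simplified):

```python
import re

def simple_tokenizer(text: str) -> list:
    """Hypothetical, simplified stand-in for torchtext's 'basic_english' tokeniser."""
    text = text.lower()
    # Pad common punctuation with spaces so each mark becomes its own token
    text = re.sub(r"([.,!?()])", r" \1 ", text)
    # Split on any run of whitespace
    return text.split()

print(simple_tokenizer("Wall St. Bears Claw Back!"))
# ['wall', 'st', '.', 'bears', 'claw', 'back', '!']
```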

PyTorch leaves the counting of token occurrences to you (e.g. using `collections.Counter`). You then pass the counts into a [`torchtext.vocab.Vocab`](https://pytorch.org/text/stable/vocab.html), which handles the encoding and can also do things like cap the total vocabulary size or set a minimum occurrence frequency.
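
The split between counting and encoding can be sketched in plain Python. `build_vocab` and `encode` below are hypothetical helpers, not torchtext's `Vocab`. They assume an `<unk>` token at index 0 that absorbs unknown words and words below `min_freq` (torchtext's `Vocab` also reserves special tokens, but its exact ordering rules may differ):

```python
import collections

def build_vocab(counter, min_freq=1, specials=("<unk>",)):
    """Hypothetical helper: map tokens to integer ids, most frequent first."""
    itos = list(specials)
    # Sort by descending frequency, breaking ties alphabetically
    for token, freq in sorted(counter.items(), key=lambda kv: (-kv[1], kv[0])):
        if freq >= min_freq and token not in itos:
            itos.append(token)
    return {token: idx for idx, token in enumerate(itos)}

def encode(tokens, stoi, unk="<unk>"):
    # Unknown (or filtered-out) tokens fall back to the <unk> id
    return [stoi.get(t, stoi[unk]) for t in tokens]

counter = collections.Counter()
for text in ["the cat sat", "the dog sat", "a cat"]:
    counter.update(text.split())

stoi = build_vocab(counter, min_freq=2)  # {'<unk>': 0, 'cat': 1, 'sat': 2, 'the': 3}
ids = encode("the cat ran".split(), stoi)
print(ids)  # [3, 1, 0]
```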

In [2]:
tokenizer = get_tokenizer('basic_english')

train_iter = AG_NEWS(split='train')

counter = collections.Counter()
for (label, text) in train_iter:
    counter.update(tokenizer(text))
    
vocab = Vocab(counter, min_freq=1)



Now we define the callables that we want to apply to each sample (a text pipeline and a label pipeline) and combine them in a `collate_fn`, which the `DataLoader` applies to each batch of data.

In [3]:
device = torch.device("cpu")

class TextPipeline:
    def __init__(self, vocab, tokenizer):
        self.vocab = vocab
        self.tokenizer = tokenizer
        
    def __call__(self, text):
        return [self.vocab[token] for token in self.tokenizer(text)]

class LabelPipeline:
    def __call__(self, label):
        return int(label) - 1
    
class Collator:
    def __init__(self, text_pipeline, label_pipeline):
        self.text_pipeline = text_pipeline
        self.label_pipeline = label_pipeline
        
    def __call__(self, batch):
        """
        Prepare batch of data to be used as input to torch model.

        Returns
        -------
          labels: a torch.tensor of integer encoded labels. Has shape (batch_size)
          texts: a torch.tensor of integer encoded text sequences. Encoded using text_pipeline.
              Each example is concatenated together into a flat 1D tensor. The start of each
              example is recorded in offsets. Has shape (n_tokens_in_batch)
          offsets: a torch.tensor of the index of the start of each example.
              Has shape (batch_size)
        """
        labels, texts, offsets = [], [], [0]

        for (label, text) in batch:
            labels.append(
                self.label_pipeline(label)
            )
            processed_text = torch.tensor(
                self.text_pipeline(text),
                dtype=torch.int64
            )
            texts.append(processed_text)
            offsets.append(processed_text.size(0)) # length of processed text

        labels = torch.tensor(labels, dtype=torch.int64)
        offsets = torch.tensor(offsets[:-1]).cumsum(dim=0) # starting index of each example
        texts = torch.cat(texts) # concatenate the per-example tensors into one flat 1D tensor

        return labels.to(device), texts.to(device), offsets.to(device)
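
To see what `offsets` holds, here is the same bookkeeping as `Collator.__call__` applied by hand to a toy batch of three already-encoded texts (the token ids are made up):

```python
import torch

# Three already-encoded texts of lengths 3, 2 and 4 (ids are arbitrary)
encoded = [torch.tensor([4, 8, 15]), torch.tensor([16, 23]), torch.tensor([42, 7, 1, 9])]

lengths = [0] + [t.size(0) for t in encoded]        # [0, 3, 2, 4]
offsets = torch.tensor(lengths[:-1]).cumsum(dim=0)  # start index of each text: [0, 3, 5]
flat = torch.cat(encoded)                           # all 9 tokens in one 1D tensor

# Recover the second text from the flat tensor using the offsets
second = flat[int(offsets[1]):int(offsets[2])]
print(offsets.tolist(), second.tolist())  # [0, 3, 5] [16, 23]
```

The last text runs from its offset to the end of `flat`, which is why `offsets` only needs one entry per example.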

In [4]:
train_iter = AG_NEWS(split='train')
dataloader = DataLoader(
    train_iter,
    batch_size=8,
    shuffle=False,
    collate_fn=Collator(
        TextPipeline(vocab, tokenizer),
        LabelPipeline()
    )
)

# Predictive Model

Define the model. The model is an embedding layer (actually a torch `nn.EmbeddingBag`) followed by a [Linear layer](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) (i.e. essentially a logistic regression on the output of the embedding layer).
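
In symbols, writing $\mathbf{e}$ for the output of the embedding layer, the linear layer produces one logit $z_k$ per class. The model returns these logits directly; the softmax that turns them into class probabilities is applied inside `nn.CrossEntropyLoss` during training:

$$
\mathbf{z} = W\mathbf{e} + \mathbf{b}, \qquad p(y = k \mid \mathbf{e}) = \frac{\exp(z_k)}{\sum_{j=1}^{4} \exp(z_j)}
$$

This is the form of multinomial logistic regression over the 4 AG News classes, with $W$ of shape $4 \times$ `embed_dim`.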

The embedding layer computes the mean of the token embeddings of each text we send it. In `nn.EmbeddingBag` terms each text is one "bag", and the default `mode='mean'` averages the embeddings within each bag.
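
A quick way to check the "mean of each bag" behaviour is to compare `nn.EmbeddingBag` against a manual mean over rows of its weight matrix. The sizes and ids below are arbitrary toy values:

```python
import torch
from torch import nn

torch.manual_seed(0)
bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, mode="mean")

text = torch.tensor([1, 2, 3, 4, 5])  # two texts, flattened
offsets = torch.tensor([0, 3])        # text 0 = ids [1, 2, 3], text 1 = ids [4, 5]

out = bag(text, offsets)              # shape (2, 4): one vector per text

# Same result computed by hand: average the embedding rows of each text
manual = torch.stack([
    bag.weight[[1, 2, 3]].mean(dim=0),
    bag.weight[[4, 5]].mean(dim=0),
])
matches = torch.allclose(out, manual)
print(out.shape, matches)  # torch.Size([2, 4]) True
```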

In [5]:
class TextClassificationModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_classes):
        super(TextClassificationModel, self).__init__()
        
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_classes)
        self.init_weights()
        
    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()
        
    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)

# Define Training Loop and Evaluation Function

This requires some knowledge of the peculiarities of PyTorch, including:
 - PyTorch does very little implicitly. You have to write out each step of the training loop yourself, e.g. the forward pass, calculating the loss, the backward pass and the optimizer step.
 - [You have to set the "mode" of the model](https://stackoverflow.com/questions/51433378/what-does-model-train-do-in-pytorch), e.g. call `model.train()` during training and `model.eval()` during evaluation. This means that layers that behave differently during training and evaluation (for example dropout layers) will do the right thing.
 - [By default, gradients accumulate on each call of `loss.backward()`](https://stackoverflow.com/questions/48001598/why-do-we-need-to-call-zero-grad-in-pytorch): they are added to each parameter's `.grad` rather than overwriting it. This means that we need to call `optimizer.zero_grad()` at the start of each training step.
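
The gradient-accumulation point is easy to demonstrate with a single scalar parameter. This sketch calls `backward()` twice without clearing and then once more after `optimizer.zero_grad()` (which resets gradients to zero, or to `None` in newer PyTorch versions):

```python
import torch

w = torch.tensor(2.0, requires_grad=True)
opt = torch.optim.SGD([w], lr=0.1)

grads_seen = []
for _ in range(2):
    loss = 3 * w        # d(loss)/dw = 3
    loss.backward()     # adds 3 to w.grad on every call
    grads_seen.append(w.grad.item())
print(grads_seen)       # [3.0, 6.0]: the second backward added to the first

opt.zero_grad()         # clear the accumulated gradient
loss = 3 * w
loss.backward()
fresh_grad = w.grad.item()
print(fresh_grad)       # 3.0
```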

In [6]:
import time

def train(dataloader, model, loss_fn, optimizer):
    model.train()  # put layers such as dropout into training mode
    total_correct, total_count = 0, 0
    log_interval = 500

    for idx, (label, text, offsets) in enumerate(dataloader):
        optimizer.zero_grad()  # clear gradients left over from the previous step

        predicted_label = model(text, offsets)

        loss = loss_fn(predicted_label, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)  # clip the total gradient norm to 0.1
        optimizer.step()

        total_correct += (predicted_label.argmax(1) == label).sum().item()
        total_count += label.size(0)

        if idx % log_interval == 0 and idx > 0:
            # note: `epoch` is read from the global training loop further down
            print('| epoch {:3d} | {:5d}/{:5d} batches '
                  '| accuracy {:8.3f}'.format(epoch, idx, len(dataloader),
                                              total_correct/total_count))
            total_correct, total_count = 0, 0
            
def evaluate(dataloader, model):
    model.eval()
    total_correct, total_count = 0, 0

    with torch.no_grad():
        for label, text, offsets in dataloader:
            predicted_label = model(text, offsets)
            total_correct += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
    return total_correct / total_count
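
The `model.train()` / `model.eval()` calls matter for layers such as dropout. Our model (an `EmbeddingBag` plus a `Linear` layer) has no such layers, so here the calls are mostly good practice, but a standalone dropout layer shows the difference:

```python
import torch
from torch import nn

torch.manual_seed(0)
dropout = nn.Dropout(p=0.5)
x = torch.ones(1000)

dropout.train()
y_train = dropout(x)  # random elements zeroed, survivors scaled by 1 / (1 - p) = 2

dropout.eval()
y_eval = dropout(x)   # identity in evaluation mode

print(sorted(set(y_train.tolist())), torch.equal(y_eval, x))  # [0.0, 2.0] True
```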

# Train the Model

In [7]:
train_iter = AG_NEWS(split='train')

num_class = len(set([label for (label, text) in train_iter]))
vocab_size = len(vocab)

emsize = 64

model = TextClassificationModel(vocab_size, emsize, num_class).to(device)

In [8]:
from torch.utils.data.dataset import random_split

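# Hyperparameters
EPOCHS = 10      # number of passes over the training data
LR = 5           # learning rate
BATCH_SIZE = 64  # batch size for training

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
# multiply the learning rate by 0.1 each time scheduler.step() is called
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

total_accu = None  # best validation accuracy seen so far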
train_iter, test_iter = AG_NEWS()

# converting to a list turns the iterable dataset into a map-style one (i.e. accessible by index), so it can be split and shuffled
train_dataset = list(train_iter)
test_dataset = list(test_iter)

# hold out 5% of the training data for validation; the test split is kept separate
num_train = int(len(train_dataset) * 0.95)

split_train_, split_valid_ = random_split(
    train_dataset,
    [num_train, len(train_dataset) - num_train]
)

train_dataloader = DataLoader(
    split_train_,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=Collator(
        TextPipeline(vocab, tokenizer),
        LabelPipeline()
    )
)

# no need to shuffle the evaluation sets: accuracy does not depend on order
valid_dataloader = DataLoader(
    split_valid_,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=Collator(
        TextPipeline(vocab, tokenizer),
        LabelPipeline()
    )
)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=Collator(
        TextPipeline(vocab, tokenizer),
        LabelPipeline()
    )
)

for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader, model, loss_fn, optimizer)
    accu_val = evaluate(valid_dataloader, model)
    if total_accu is not None and total_accu > accu_val:
        scheduler.step()  # validation accuracy dropped: decay the learning rate
    else:
        total_accu = accu_val  # new best validation accuracy
    print('-' * 59)
    print('| end of epoch {:3d} | time: {:5.2f}s | '
          'valid accuracy {:8.3f} '.format(epoch,
                                           time.time() - epoch_start_time,
                                           accu_val))
    print('-' * 59)



| epoch   1 |   500/ 1782 batches | accuracy    0.673
| epoch   1 |  1000/ 1782 batches | accuracy    0.853
| epoch   1 |  1500/ 1782 batches | accuracy    0.874
-----------------------------------------------------------
| end of epoch   1 | time: 11.12s | valid accuracy    0.883 
-----------------------------------------------------------
| epoch   2 |   500/ 1782 batches | accuracy    0.898
| epoch   2 |  1000/ 1782 batches | accuracy    0.897
| epoch   2 |  1500/ 1782 batches | accuracy    0.902
-----------------------------------------------------------
| end of epoch   2 | time: 10.68s | valid accuracy    0.893 
-----------------------------------------------------------
| epoch   3 |   500/ 1782 batches | accuracy    0.917
| epoch   3 |  1000/ 1782 batches | accuracy    0.913
| epoch   3 |  1500/ 1782 batches | accuracy    0.911
-----------------------------------------------------------
| end of epoch   3 | time: 10.67s | valid accuracy    0.904 
-------------------------------
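
Note that the training loop does not decay the learning rate every epoch: `scheduler.step()` is only called when validation accuracy falls below the best value seen so far, and `StepLR` with `gamma=0.1` then divides the learning rate by 10. A standalone sketch of that logic, using made-up validation accuracies and a dummy parameter in place of the model:

```python
import torch

# Dummy parameter standing in for the model's weights
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=5.0)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

total_accu = None
lrs = []
for accu_val in [0.88, 0.89, 0.87, 0.90]:  # made-up validation accuracies
    optimizer.step()  # placeholder for one epoch of training (no grads, so no update)
    if total_accu is not None and total_accu > accu_val:
        scheduler.step()        # accuracy dropped: lr *= 0.1
    else:
        total_accu = accu_val   # new best accuracy: keep the current lr
    lrs.append(optimizer.param_groups[0]["lr"])

print(lrs)  # [5.0, 5.0, 0.5, 0.5]
```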

In [10]:
print('Checking the results of test dataset.')
accu_test = evaluate(test_dataloader, model)
print('test accuracy {:8.3f}'.format(accu_test))

Checking the results of test dataset.
test accuracy    0.906
