<a href='https://ai.meng.duke.edu'><img align="left" style="padding-top:10px;" src="https://storage.googleapis.com/aipi_datasets/Duke-AIPI-Logo.png"></a>

# Text Classification using Model-Created Embeddings in PyTorch
In this notebook we will perform text classification using PyTorch embeddings to represent each document in the dataset. Rather than using pre-trained embeddings, we will learn the embeddings from scratch on our training data. The embeddings will then serve as the document features that are input into a classification model. Our goal is to classify the articles in the AG News dataset into their correct category: "World", "Sports", "Business", or "Sci/Tec".

To create the embedding for each document, we first create an embedding for each word in the document and then use the mean of all word embeddings in the document as the document embedding. The document embedding serves as the feature set fed into a single-layer linear classifier, which performs softmax regression trained with cross-entropy loss to classify the documents.
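
As a worked illustration of this pipeline, the plain-Python sketch below averages three word vectors into a document vector, applies a linear layer to get one logit per class, and computes the softmax probabilities and cross-entropy loss. The 2-dimensional embeddings and the classifier weights are hypothetical toy values, not learned ones. Note that PyTorch's `nn.CrossEntropyLoss` applies the log-softmax internally, which is why a model like this one can return raw logits from its forward pass.

```python
import math

def mean_embedding(word_vectors):
    """Average a list of equal-length word vectors into one document vector."""
    dim = len(word_vectors[0])
    return [sum(vec[i] for vec in word_vectors) / len(word_vectors) for i in range(dim)]

def linear(x, weight, bias):
    """Compute weight @ x + bias; weight has one row per class."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]

def softmax(logits):
    """Convert logits to probabilities (subtracting the max for numerical stability)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, target):
    """Negative log-probability of the true class."""
    return -math.log(softmax(logits)[target])

# Hypothetical 2-dimensional embeddings for the three words of a toy document
word_vectors = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
doc_vector = mean_embedding(word_vectors)             # [1.0, 1.0]

# Hypothetical weights for 4 classes (World, Sports, Business, Sci/Tec)
weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.0]]
bias = [0.0, 0.0, 0.0, 0.0]
logits = linear(doc_vector, weight, bias)             # [1.0, 1.0, 2.0, -1.0]
probs = softmax(logits)
predicted = max(range(len(probs)), key=probs.__getitem__)   # 2, i.e. "Business"
loss = cross_entropy(logits, target=2)                # about 0.58
```

Training adjusts both the word embeddings and the linear weights so that this loss shrinks across the whole training set.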

**Notes:**  
- This notebook does not need a GPU, but will take ~5 minutes to run on a CPU

**References:**  
- This notebook includes portions of code from the [PyTorch docs tutorials](https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html)

In [38]:
import os
import time
import copy
import urllib.request
import zipfile

import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.functional import to_map_style_dataset

import warnings
warnings.filterwarnings('ignore')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# Download and extract the data if it is not already present
os.makedirs('../data', exist_ok=True)
if not os.path.exists('../data/agnews'):
    url = 'https://storage.googleapis.com/aipi540-datasets/agnews.zip'
    urllib.request.urlretrieve(url, filename='../data/agnews.zip')
    with zipfile.ZipFile('../data/agnews.zip', 'r') as zip_ref:
        zip_ref.extractall('../data/agnews')

train_df = pd.read_csv('../data/agnews/train.csv')
test_df = pd.read_csv('../data/agnews/test.csv')

# Combine title and description of article to use as input documents for model
train_df['full_text'] = train_df.apply(lambda x: ' '.join([x['Title'],x['Description']]),axis=1)
test_df['full_text'] = test_df.apply(lambda x: ' '.join([x['Title'],x['Description']]),axis=1)

# Create dictionary to store mapping of labels
ag_news_label = {1: "World",
                 2: "Sports",
                 3: "Business",
                 4: "Sci/Tec"}

train_df.head()

| | Class Index | Title | Description | full_text |
|---|---|---|---|---|
| 0 | 3 | Wall St. Bears Claw Back Into the Black (Reuters) | Reuters - Short-sellers, Wall Street's dwindli... | Wall St. Bears Claw Back Into the Black (Reute... |
| 1 | 3 | Carlyle Looks Toward Commercial Aerospace (Reu... | Reuters - Private investment firm Carlyle Grou... | Carlyle Looks Toward Commercial Aerospace (Reu... |
| 2 | 3 | Oil and Economy Cloud Stocks' Outlook (Reuters) | Reuters - Soaring crude prices plus worries\ab... | Oil and Economy Cloud Stocks' Outlook (Reuters... |
| 3 | 3 | Iraq Halts Oil Exports from Main Southern Pipe... | Reuters - Authorities have halted oil export\f... | Iraq Halts Oil Exports from Main Southern Pipe... |
| 4 | 3 | Oil prices soar to all-time record, posing new... | AFP - Tearaway world oil prices, toppling reco... | Oil prices soar to all-time record, posing new... |


In [5]:
# View a couple of the documents
for i in range(5):
    print(train_df.iloc[i]['full_text'])
    print()

Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again.

Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private investment firm Carlyle Group,\which has a reputation for making well-timed and occasionally\controversial plays in the defense industry, has quietly placed\its bets on another part of the market.

Oil and Economy Cloud Stocks' Outlook (Reuters) Reuters - Soaring crude prices plus worries\about the economy and the outlook for earnings are expected to\hang over the stock market next week during the depth of the\summer doldrums.

Iraq Halts Oil Exports from Main Southern Pipeline (Reuters) Reuters - Authorities have halted oil export\flows from the main pipeline in southern Iraq after\intelligence showed a rebel militia could strike\infrastructure, an oil official said on Saturday.

Oil prices soar to all-time record, posing new menace to US economy (AFP) AFP - Tearaway world

## Build Datasets
Now that our data is loaded, we first need to prepare our data by putting it into PyTorch Dataset format.  We will also split our training data to create a validation set.
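
A map-style dataset is anything that supports `len()` and integer indexing, which is why `to_map_style_dataset` can wrap a plain list of `(label, text)` tuples. The sketch below mimics that protocol with a hypothetical `ListDataset` class in plain Python and performs a 95/5 train/validation split with a fixed seed. The `random_split` call in this notebook is not seeded, so its split changes on every run; passing `generator=torch.Generator().manual_seed(42)` to `random_split` is one way to make it reproducible. The class, the `split_indices` helper, the toy data, and the seed value are all illustrative assumptions.

```python
import random

class ListDataset:
    """Hypothetical map-style dataset: supports len() and integer indexing,
    the same protocol that torch.utils.data.Dataset subclasses implement."""

    def __init__(self, samples):
        self.samples = list(samples)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

def split_indices(n, train_frac=0.95, seed=42):
    """Shuffle indices 0..n-1 with a fixed seed and cut them into train/validation parts."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    num_train = int(n * train_frac)
    return indices[:num_train], indices[num_train:]

# Hypothetical toy data: 100 (label, text) pairs with labels 1-4
dataset = ListDataset([(i % 4 + 1, f"document {i}") for i in range(100)])
train_idx, val_idx = split_indices(len(dataset))
train_samples = [dataset[i] for i in train_idx]
val_samples = [dataset[i] for i in val_idx]
```

Because the shuffle uses its own seeded `random.Random` instance, calling `split_indices` twice with the same arguments returns the same split.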

In [6]:
# Put data into lists of (label, text) tuples, the form needed to create PyTorch Datasets
train_iter = list(zip(train_df['Class Index'].to_list(), train_df['full_text'].to_list()))
test_iter = list(zip(test_df['Class Index'].to_list(), test_df['full_text'].to_list()))

# Create PyTorch Datasets from iterators
train_dataset = to_map_style_dataset(train_iter)
test_dataset = to_map_style_dataset(test_iter)

# Split training data to get a validation set
num_train = int(len(train_dataset) * 0.95)
split_train_dataset, split_valid_dataset = random_split(train_dataset, [num_train, len(train_dataset) - num_train])

## Put Data in DataLoaders
We are now ready to create PyTorch DataLoaders from our Datasets, which we can use to feed mini-batches of inputs and labels to our model.  

However, we want to perform a couple of operations on the data loaded into each mini-batch. We can define a custom `collate_fn()` to perform these operations; the DataLoader then applies it to the samples loaded into each batch. We want our `collate_fn()` to accomplish the following:  
- Tokenize the text data to form a list of tokens for each document.  
- Convert each document's token list into a list of integers. We do this by building a "vocabulary" that maps every token found in the training data to an integer index, and then representing each document as the list of indices of its tokens.
- Store the "offsets", i.e. the starting index of each document in the mini-batch. Since the samples in the batch are concatenated into a single tensor for input to the `nn.EmbeddingBag` layer in our model, we need to record where each individual document's sequence begins.
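
The three steps above can be sketched in plain Python on a toy batch. This is a simplified illustration under stated assumptions, not the notebook's `collate_batch`: the tokenizer only lowercases and splits on whitespace (torchtext's `basic_english` also separates punctuation), the vocabulary numbers tokens in order of first appearance (`build_vocab_from_iterator` orders them by frequency), and lists with `itertools.accumulate` stand in for tensors and `cumsum`. The helper names are hypothetical.

```python
from itertools import accumulate

def tokenize(text):
    # Simplified stand-in for the 'basic_english' tokenizer: lowercase + whitespace split
    return text.lower().split()

def build_vocab(docs):
    # Index 0 is reserved for "<unk>", which also serves as the default for unseen tokens
    vocab = {"<unk>": 0}
    for doc in docs:
        for tok in tokenize(doc):
            vocab.setdefault(tok, len(vocab))   # number tokens in order of first appearance
    return vocab

def collate(batch, vocab):
    labels, flat_ids, lengths = [], [], []
    for label, text in batch:
        labels.append(int(label) - 1)           # shift class indices 1..4 to 0..3
        ids = [vocab.get(tok, vocab["<unk>"]) for tok in tokenize(text)]
        flat_ids.extend(ids)                    # concatenate all documents into one sequence
        lengths.append(len(ids))
    # Each document starts where the previous ones end: 0, len0, len0 + len1, ...
    offsets = [0] + list(accumulate(lengths))[:-1]
    return labels, flat_ids, offsets

vocab = build_vocab(["the cat sat", "the dog ran fast"])
labels, flat_ids, offsets = collate([(1, "The cat ran"), (3, "a dog sat fast")], vocab)
# labels   -> [0, 2]
# flat_ids -> [1, 2, 5, 0, 4, 3, 6]   ("a" is not in the vocabulary, so it maps to <unk> = 0)
# offsets  -> [0, 3]
```

The offsets `[0, 3]` tell the embedding layer that the first document occupies positions 0-2 of `flat_ids` and the second starts at position 3.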

In [25]:
# Generator that yields the token list of each document
def yield_tokens(data_iter, tokenizer):
    for _, text in data_iter:
        yield tokenizer(text)

# Build vocabulary from tokens of training set
tokenizer = get_tokenizer('basic_english')
vocab = build_vocab_from_iterator(yield_tokens(train_iter,tokenizer), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])

# Define collate_batch function to get single collated tensor for batch in form needed by nn.EmbeddingBag
def collate_batch(batch,tokenizer,vocab):
    # Pipelines for processing text and labels
    text_pipeline = lambda x: vocab(tokenizer(x))
    label_pipeline = lambda x: int(x) - 1
    
    label_list, text_list, offsets = [], [], [0]
    # Iterate through batch, processing text and adding text, labels and offsets to lists
    for (label, text) in batch:
        label_list.append(label_pipeline(label))
        processed_text = torch.tensor(text_pipeline(text), dtype=torch.int64)
        text_list.append(processed_text)
        offsets.append(processed_text.size(0))
    label_list = torch.tensor(label_list, dtype=torch.int64)
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text_list = torch.cat(text_list)
    return label_list.to(device), text_list.to(device), offsets.to(device)  

In [40]:
batch_size = 64
# Create training, validation and test set DataLoaders using the custom collate_batch function
# (only the training data needs to be shuffled)
train_dataloader = DataLoader(split_train_dataset, batch_size=batch_size, shuffle=True,
                              collate_fn=lambda x: collate_batch(x, tokenizer, vocab))
val_dataloader = DataLoader(split_valid_dataset, batch_size=batch_size, shuffle=False,
                            collate_fn=lambda x: collate_batch(x, tokenizer, vocab))
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                             collate_fn=lambda x: collate_batch(x, tokenizer, vocab))

# Set up dict for dataloaders to use in training
train_dataloaders = {'train':train_dataloader,'val':val_dataloader}

# Store size of training and validation sets
dataset_sizes = {'train':len(split_train_dataset),'val':len(split_valid_dataset)}

## Train model
Now that we have our data in DataLoaders, we are ready to train our classification model. Our model is composed of two layers:  
1) An [nn.EmbeddingBag](https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html#torch.nn.EmbeddingBag) layer, which converts each word (or n-gram) into an embedding and then takes the mean or sum of the embeddings of all words/n-grams in a document as the vector representing the document. We can specify the size of the embedding vector used to represent each document; in this notebook we use the mean.
2) A fully connected `nn.Linear` layer, which takes the document embedding as input and outputs a score (logit) for each class, from which the document is classified.

![](.img/text_model_pytorch.png)

*Figure from the [PyTorch docs](https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html)*
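
To make the `nn.EmbeddingBag` step concrete, the sketch below checks that `mode="mean"` produces the same document vectors as looking up each token's row of the embedding matrix and averaging the rows between consecutive offsets, then passes the result through an `nn.Linear` layer to get one logit per class. The toy sizes, random weights, and hand-picked token ids are assumptions for illustration; this is not the notebook's model class.

```python
import torch
from torch import nn

torch.manual_seed(0)
vocab_size, embed_dim, num_class = 10, 4, 4

bag = nn.EmbeddingBag(vocab_size, embed_dim, mode="mean")
fc = nn.Linear(embed_dim, num_class)

# Two documents concatenated into one 1-D tensor, with offsets marking where each begins
text = torch.tensor([1, 2, 5, 0, 4, 3, 6])
offsets = torch.tensor([0, 3])

with torch.no_grad():
    doc_embeddings = bag(text, offsets)      # shape: (num_documents, embed_dim)

    # Manual equivalent: gather each document's rows of the embedding matrix and average them
    bounds = offsets.tolist() + [text.numel()]
    manual = torch.stack([
        bag.weight[text[start:end]].mean(dim=0)
        for start, end in zip(bounds[:-1], bounds[1:])
    ])

    logits = fc(doc_embeddings)              # shape: (num_documents, num_class)
```

`nn.EmbeddingBag` performs the lookup and the reduction in one operation without materializing the per-token embeddings, which is one reason it suits this bag-of-words style model.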

In [13]:
# Define the model
class TextClassificationModel(nn.Module):

    def __init__(self, vocab_size, embed_dim, num_class):
        super(TextClassificationModel, self).__init__()
        # Embedding layer
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, mode="mean",sparse=True)
        # Fully connected final layer to convert embeddings to output predictions
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)

We can now set up a function to train our model. Our `train_model()` function below trains the model and reports the training and validation set loss and accuracy at each epoch. It also keeps a copy of the weights that achieved the best validation set accuracy during training and loads those weights back into the model at the end.
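
The best-weights bookkeeping relies on `copy.deepcopy`: `model.state_dict()` returns tensors that share storage with the model's parameters, so without a copy the saved "best" weights would keep changing as training continues. The plain-Python sketch below mimics this, using a mutable dictionary as a hypothetical stand-in for the state dict and a hypothetical list of validation accuracies.

```python
import copy

def track_best(val_accuracies):
    """Return (best_epoch, best_acc, snapshot of weights at best epoch, final weights)."""
    weights = {"step": 0}                      # stands in for model.state_dict()
    best_wts = copy.deepcopy(weights)
    best_acc, best_epoch = 0.0, None
    for epoch, acc in enumerate(val_accuracies):
        weights["step"] += 1                   # stands in for in-place optimizer updates
        if acc > best_acc:                     # strict '>' keeps the earliest epoch on a tie
            best_acc, best_epoch = acc, epoch
            best_wts = copy.deepcopy(weights)  # snapshot, not a reference to live weights
    return best_epoch, best_acc, best_wts, weights

# Hypothetical run where validation accuracy peaks before the final epoch
best_epoch, best_acc, best_wts, final_wts = track_best([0.85, 0.88, 0.90, 0.89])
# best_epoch -> 2, best_acc -> 0.90, best_wts -> {"step": 3}, final_wts -> {"step": 4}
```

Applied to the validation accuracies in this notebook's training log, the same rule selects the final epoch (0.9025), so the restored weights are the last ones trained.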

In [54]:
def train_model(model, criterion, optimizer, dataloaders, dataset_sizes, scheduler, device, num_epochs=5):
    model = model.to(device) # Send model to GPU if available
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Get the input text, labels and offsets, and send to GPU if available
            for (labels, text, offsets) in dataloaders[phase]:
                text = text.to(device)
                labels = labels.to(device)
                offsets = offsets.to(device)

                # Zero the weight gradients
                optimizer.zero_grad()

                # Forward pass to get outputs and calculate loss
                # Track gradient only for training data
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(text, offsets)
                    loss = criterion(outputs, labels)

                    # Backpropagation to get the gradients with respect to each weight
                    # Only if in train
                    if phase == 'train':
                        loss.backward()
                        # Update the weights
                        optimizer.step()

                # Convert loss into a scalar and add it to running_loss
                running_loss += loss.item() * labels.size(0)
                # Track number of correct predictions
                _, preds = torch.max(outputs, 1)
                running_corrects += torch.sum(preds == labels.data)

            # Step along learning rate scheduler when in train
            if phase == 'train':
                scheduler.step()

            # Calculate and display average loss and accuracy for the epoch
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            print('{} loss: {:.4f} accuracy: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # If model performs better on val set, save weights as the best model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best validation set accuracy: {:3f}'.format(best_acc))

    # Load the weights from best model
    model.load_state_dict(best_model_wts)

    return model

In [55]:
# Instantiate the model
num_classes = len(set([label for (label, _) in train_iter]))
vocab_size = len(vocab)
embed_dim = 64 # Set desired document embedding size
nn_model = TextClassificationModel(vocab_size, embed_dim, num_classes)

# Set hyperparameters
epochs = 10  # number of training epochs
learning_rate = 1.0  # initial learning rate

# Define loss function and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(nn_model.parameters(), lr=learning_rate)

# Decay the learning rate by a factor of 0.8 after every epoch
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.8)

# Train the model
nn_model = train_model(nn_model, criterion, optimizer, train_dataloaders, dataset_sizes,
                       lr_scheduler, device, num_epochs=epochs)

Epoch 0/9
----------
train loss: 0.7667 accuracy: 0.7195
val loss: 0.4612 accuracy: 0.8528

Epoch 1/9
----------
train loss: 0.3923 accuracy: 0.8711
val loss: 0.3817 accuracy: 0.8770

Epoch 2/9
----------
train loss: 0.3354 accuracy: 0.8908
val loss: 0.3553 accuracy: 0.8887

Epoch 3/9
----------
train loss: 0.3086 accuracy: 0.9001
val loss: 0.3380 accuracy: 0.8928

Epoch 4/9
----------
train loss: 0.2923 accuracy: 0.9053
val loss: 0.3312 accuracy: 0.8918

Epoch 5/9
----------
train loss: 0.2813 accuracy: 0.9086
val loss: 0.3245 accuracy: 0.8953

Epoch 6/9
----------
train loss: 0.2732 accuracy: 0.9116
val loss: 0.3197 accuracy: 0.9000

Epoch 7/9
----------
train loss: 0.2672 accuracy: 0.9131
val loss: 0.3178 accuracy: 0.8995

Epoch 8/9
----------
train loss: 0.2627 accuracy: 0.9144
val loss: 0.3170 accuracy: 0.9020

Epoch 9/9
----------
train loss: 0.2591 accuracy: 0.9157
val loss: 0.3134 accuracy: 0.9025

Training complete in 1m 35s
Best validation set accuracy: 0.902500
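
Because `train_model()` calls `scheduler.step()` once per epoch and `StepLR` was configured with a step size of 1 and `gamma=0.8`, the learning rate used during epoch *e* is 1.0 × 0.8^e. The short sketch below tabulates that schedule with a hypothetical helper (it is not part of PyTorch); during a live run, `lr_scheduler.get_last_lr()` reports the current value.

```python
def step_lr_schedule(base_lr, gamma, step_size, num_epochs):
    """Learning rate in effect during each epoch when StepLR is stepped once per epoch."""
    return [base_lr * gamma ** (epoch // step_size) for epoch in range(num_epochs)]

lrs = step_lr_schedule(base_lr=1.0, gamma=0.8, step_size=1, num_epochs=10)
for epoch, lr in enumerate(lrs):
    print(f"epoch {epoch}: lr = {lr:.4f}")
# epoch 0 uses 1.0000 and epoch 9 uses 0.1342
```

The shrinking step size is consistent with the training log above, where the per-epoch gains in loss and accuracy become small in the later epochs.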


## Test the model
Now that we have trained our model, we can evaluate its performance using our test set.

In [56]:
def evaluate(dataloader, model):
    # Generate predictions and calculate accuracy
    model.eval()
    total_acc, total_count = 0, 0

    with torch.no_grad():
        for label, text, offsets in dataloader:
            predicted_label = model(text, offsets)
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
    return total_acc / total_count

In [57]:
# Evaluate performance on the test dataset
accu_test = evaluate(test_dataloader, nn_model)
print('test set accuracy {:8.3f}'.format(accu_test))

test set accuracy    0.895
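
`evaluate()` counts correct predictions across all batches and divides by the total number of samples, rather than averaging per-batch accuracies; the two differ whenever the last batch is smaller than the others. The plain-Python sketch below uses hypothetical logits and hypothetical batch counts to show the argmax step and the difference between the two averages.

```python
def argmax(row):
    """Index of the largest logit (the first one on ties)."""
    return max(range(len(row)), key=row.__getitem__)

def count_correct(logits, labels):
    return sum(argmax(row) == label for row, label in zip(logits, labels))

# Hypothetical logits for three documents over four classes
logits = [[2.0, 0.1, 0.3, 0.0],
          [0.2, 1.5, 0.1, 0.0],
          [0.1, 0.2, 0.3, 0.9]]
labels = [0, 1, 2]
print(count_correct(logits, labels))   # 2 (the third document is predicted as class 3)

# Hypothetical (num_correct, batch_size) results: a full batch and a small final batch
batches = [(60, 64), (0, 4)]
pooled = sum(c for c, _ in batches) / sum(n for _, n in batches)   # 60 / 68, about 0.882
per_batch_mean = sum(c / n for c, n in batches) / len(batches)     # (0.9375 + 0) / 2 = 0.46875
```

The pooled value is the true accuracy over all 68 samples, while the per-batch mean gives the 4-sample batch as much weight as the 64-sample one.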
