In [111]:
from pathlib import Path

import numpy as np
import torch
import torchdata
from torchtext.datasets import AG_NEWS

In [112]:
train_iter = iter(AG_NEWS(split='train'))

1 : World
2 : Sports
3 : Business
4 : Sci/Tech

In [113]:
next(train_iter)

(3,
 "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.")

## Prepare data processing pipelines

In [114]:
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

In [115]:
tokenizer = get_tokenizer('basic_english')
train_iter = AG_NEWS(split='train')

def yield_tokens(data_iter):
    """Lazily tokenize the text field of every ``(label, text)`` pair.

    Args:
        data_iter: Iterable of 2-tuples; the first element is ignored and the
            second is passed to the module-level ``tokenizer``.

    Yields:
        Whatever ``tokenizer`` returns for each text, in input order.
    """
    texts = (text for _, text in data_iter)
    yield from map(tokenizer, texts)

In [116]:
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=['<unk>']) # Unknown
vocab.set_default_index(vocab['<unk>'])

In [117]:
vocab(['man','i','do','not','know','why','promethazine','acetaldehyde'])

[335, 282, 423, 62, 1199, 1164, 0, 0]

In [118]:
len(vocab)

95811

- The text pipeline tokenizes a text string and converts it into a list of integer indices by looking each token up in the vocabulary.
- The label pipeline converts the 1-based label into a 0-based integer class index (e.g. label `1` becomes class `0`).

In [119]:
def text_pipeline(x):
    """Tokenize a raw string and map every token to its vocabulary index."""
    tokens = tokenizer(x)
    return vocab(tokens)


def label_pipeline(x):
    """Turn a 1-based label (int or numeric string) into a 0-based class index."""
    return int(x) - 1

In [120]:
text_pipeline('man i do not know why promethazine actaldehyde')

[335, 282, 423, 62, 1199, 1164, 0, 0]

In [121]:
label_pipeline('12')

11

## Generate data batch and iterator

In [122]:
from torch.utils.data import DataLoader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [123]:
def collate_batch(batch):
    """Collate ``(label, text)`` samples into three flat tensors.

    Args:
        batch: Iterable of ``(label, text)`` pairs; labels go through
            ``label_pipeline`` and texts through ``text_pipeline``.

    Returns:
        Tuple ``(labels, tokens, offsets)``, all moved to the module-level
        ``device``:
        labels  - int64 tensor with one class index per sample,
        tokens  - int64 tensor with all samples' token ids concatenated,
        offsets - start position of each sample inside ``tokens``.
    """
    labels, token_chunks, lengths = [], [], []
    for raw_label, raw_text in batch:
        labels.append(label_pipeline(raw_label))
        token_ids = torch.tensor(text_pipeline(raw_text), dtype=torch.int64)
        token_chunks.append(token_ids)
        lengths.append(token_ids.size(0))

    label_tensor = torch.tensor(labels, dtype=torch.int64)
    # A sample starts where the previous ones end: prefix-sum of 0 followed
    # by every length except the last.
    offset_tensor = torch.tensor(([0] + lengths)[:-1]).cumsum(dim=0)
    token_tensor = torch.cat(token_chunks)
    return label_tensor.to(device), token_tensor.to(device), offset_tensor.to(device)

In [124]:
dataloader = DataLoader(train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch)

## Create a Model

In [125]:
from torch import nn

class TextClassificationModel(nn.Module):
    """Bag-of-embeddings classifier: an ``EmbeddingBag`` followed by one linear layer.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Width of each embedding vector.
        num_class: Number of output logits.
    """

    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        # sparse=True makes the embedding table produce sparse gradients.
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.dense = nn.Linear(in_features=embed_dim, out_features=num_class)
        self.init_weight()

    def init_weight(self):
        """Draw both weight matrices uniformly from [-0.5, 0.5] and zero the bias."""
        bound = 0.5
        nn.init.uniform_(self.embedding.weight, -bound, bound)
        nn.init.uniform_(self.dense.weight, -bound, bound)
        nn.init.zeros_(self.dense.bias)

    def forward(self, text, offsets):
        """Return logits for each bag of ``text`` delimited by ``offsets``."""
        pooled = self.embedding(text, offsets)
        logits = self.dense(pooled)
        return logits

In [126]:
# Number of classes = number of distinct labels seen in the training split.
num_class = len({label for label, _ in train_iter})
vocab_size = len(vocab)
emsize = 64  # embedding width
# collate_batch moves every batch to `device`; the model has to live on the
# same device or the forward pass fails with a device-mismatch error on CUDA.
modelV1 = TextClassificationModel(vocab_size, emsize, num_class).to(device)

In [127]:
modelV1

TextClassificationModel(
  (embedding): EmbeddingBag(95811, 64, mode=mean)
  (dense): Linear(in_features=64, out_features=4, bias=True)
)

In [128]:
import time

def train_step(model, dataloader, optimizer, loss_fn, max_grad_norm=0.1):
    """Run one training pass over ``dataloader``.

    Args:
        model: Module called as ``model(text, offsets)`` whose output has one
            row of class scores per sample.
        dataloader: Iterable of ``(label, text, offsets)`` batches. It does
            not need to support ``len()``.
        optimizer: Optimizer that updates ``model``'s parameters.
        loss_fn: Callable ``loss_fn(scores, label)`` returning a scalar tensor.
        max_grad_norm: Gradient-norm clipping threshold (defaults to the
            previously hard-coded 0.1).

    Returns:
        ``(accuracy, loss)`` where accuracy is correct predictions divided by
        the number of samples seen, and loss is the mean of the per-batch
        loss values.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.train()
    correct, seen = 0, 0
    loss_sum, num_batches = 0.0, 0

    for label, text, offsets in dataloader:
        optimizer.zero_grad()
        predicted_label = model(text, offsets)
        loss = loss_fn(predicted_label, label)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        optimizer.step()

        # argmax of the raw scores equals argmax of their softmax, so the
        # softmax is skipped. Count samples rather than averaging per-batch
        # accuracies, which would over-weight a short final batch.
        correct += (predicted_label.argmax(dim=1) == label).sum().item()
        seen += len(predicted_label)
        loss_sum += loss.item()
        num_batches += 1

    if seen == 0:
        raise ValueError("dataloader yielded no samples")

    return correct / seen, loss_sum / num_batches

In [129]:
def test_step(model, dataloader, loss_fn) :
    model.eval()
    total_acc, total_loss = 0, 0

    with torch.inference_mode() :
        for batch, (label, text, offsets) in enumerate(dataloader) :
            predicted_label = model(text, offsets)
            loss = loss_fn(predicted_label, label)
            total_acc += (torch.argmax(torch.softmax(predicted_label, dim=1), dim=1) == label).sum().item() / len(predicted_label)
            total_loss += loss.item()
        
    total_acc = total_acc / len(dataloader)
    total_loss = total_loss / len(dataloader)
    return total_acc, total_loss

## Split the dataset and run the model

In [130]:
from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
from tqdm import tqdm

EPOCHS = 10
BATCH_SIZE = 32
lr = 0.1
SEED = 42  # fixes the train/validation partition across kernel restarts

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(modelV1.parameters(), lr=lr)
total_accu = None
train_iter, test_iter = AG_NEWS()
# Map-style datasets support len() and indexing, which random_split needs.
train_dataset = to_map_style_dataset(train_iter)
test_dataset = to_map_style_dataset(test_iter)
# Hold out 10% of the training data for validation.
num_train = int(len(train_dataset) * 0.9)
split_train, split_val = random_split(
    train_dataset,
    [num_train, len(train_dataset) - num_train],
    # A dedicated seeded generator keeps the split reproducible without
    # touching the global RNG state.
    generator=torch.Generator().manual_seed(SEED),
)
split_train, split_val

(<torch.utils.data.dataset.Subset at 0x1c72ba0a140>,
 <torch.utils.data.dataset.Subset at 0x1c72ba09660>)

In [131]:
# Only the training loader shuffles; evaluation loaders keep a fixed order so
# repeated evaluations see identical batches.
train_dataloader = DataLoader(split_train, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
val_dataloader = DataLoader(split_val, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)
len(train_dataloader), len(val_dataloader), len(test_dataloader)

(3375, 375, 238)

In [132]:
def collate_batch(batch):
    """Collate ``(label, text)`` samples into three flat tensors.

    NOTE(review): this re-defines the ``collate_batch`` from the earlier
    "Generate data batch and iterator" section with the same logic; the
    later definition is the one bound to the name after this cell runs.

    Args:
        batch: Iterable of ``(label, text)`` pairs; labels go through
            ``label_pipeline`` and texts through ``text_pipeline``.

    Returns:
        Tuple ``(labels, tokens, offsets)``, all moved to the module-level
        ``device``:
        labels  - int64 tensor with one class index per sample,
        tokens  - int64 tensor with all samples' token ids concatenated,
        offsets - start position of each sample inside ``tokens``.
    """
    labels, token_chunks, lengths = [], [], []
    for raw_label, raw_text in batch:
        labels.append(label_pipeline(raw_label))
        token_ids = torch.tensor(text_pipeline(raw_text), dtype=torch.int64)
        token_chunks.append(token_ids)
        lengths.append(token_ids.size(0))

    label_tensor = torch.tensor(labels, dtype=torch.int64)
    # A sample starts where the previous ones end: prefix-sum of 0 followed
    # by every length except the last.
    offset_tensor = torch.tensor(([0] + lengths)[:-1]).cumsum(dim=0)
    token_tensor = torch.cat(token_chunks)
    return label_tensor.to(device), token_tensor.to(device), offset_tensor.to(device)

In [133]:
# Walk through the first four training samples to show how the per-sample
# token counts become the raw `_offsets` list used by collate_batch.
_offsets = [0]
r = 0  # keeps its value of 0 if the split yields nothing
for r, (_label, _text) in enumerate(split_train, start=1):
    a = torch.tensor(text_pipeline(_text))
    # Token ids, tensor shape and token count, one per line.
    print(a, a.shape, a.size(0), sep=' \n ')
    _offsets.append(a.size(0))
    if r == 4:
        break
_offsets

tensor([ 1111,  4276,   465, 14198,    24,   903,     4,   372,  3906,   783,
          864,     3,  7005,     1,    15, 21292,  1862,   677,    62,   189,
          442,     4, 13558,    32,    78,     1,   108,   960,     6,  1322,
          783,   864,     1,    45,     2,   335,    75,   747,   283,   423,
           25,  2328,    12,    83,   435, 18624,     4,  3736,     1]) 
 torch.Size([49]) 
 49
tensor([21633,    84,  3881,    33,   203,   379,    30,   365,   388,  2253,
        25589, 21633, 18953,     3,   141,   199,   125,     8,    76,     6,
            2,   379,  1007,   176,     3,    10,    60,    26, 81390,    17,
           34,  1872,     2,  6465,   199,   125,  4991,  3881,    95,   939,
          379,    66,     5,  3799,   365,   388,     1]) 
 torch.Size([47]) 
 47
tensor([  630,   810,  5938,    11, 10489,     4,   149,   614,  4734,    10,
            2,  1091,    29,  7310,     4,   149,   614,   107,   192,    12,
           83,    37,  2708,    10,    30,

[0, 49, 47, 37, 59]

In [134]:
torch.tensor(_offsets[:-1]).cumsum(dim=0)

tensor([  0,  49,  96, 133])

In [135]:
# First block of EPOCHS passes: train on the training split, then score the
# validation split, and report both accuracies per epoch.
for epoch in range(1, EPOCHS + 1):
    train_acc, train_loss = train_step(modelV1, train_dataloader, optimizer, loss_fn)
    val_acc, val_loss = test_step(modelV1, val_dataloader, loss_fn)

    print(
        f'Epoch {epoch} | train accuracy {train_acc:.3f} | valid accuracy {val_acc:.3f} ',
        '-' * 59,
        sep='\n',
    )

Epoch 1 | train accuracy 0.411 | valid accuracy 0.516 
-----------------------------------------------------------
Epoch 2 | train accuracy 0.571 | valid accuracy 0.620 
-----------------------------------------------------------
Epoch 3 | train accuracy 0.664 | valid accuracy 0.705 
-----------------------------------------------------------
Epoch 4 | train accuracy 0.728 | valid accuracy 0.758 
-----------------------------------------------------------
Epoch 5 | train accuracy 0.769 | valid accuracy 0.788 
-----------------------------------------------------------
Epoch 6 | train accuracy 0.795 | valid accuracy 0.808 
-----------------------------------------------------------
Epoch 7 | train accuracy 0.813 | valid accuracy 0.820 
-----------------------------------------------------------
Epoch 8 | train accuracy 0.827 | valid accuracy 0.829 
-----------------------------------------------------------
Epoch 9 | train accuracy 0.838 | valid accuracy 0.836 
-------------------------

In [136]:
# Second block of EPOCHS passes, continuing the epoch numbering from the first
# block. The bounds are derived from EPOCHS instead of the literal 11, so the
# labels stay contiguous if EPOCHS is changed.
for epoch in range(EPOCHS + 1, 2 * EPOCHS + 1) :
    train_acc, train_loss = train_step(modelV1, train_dataloader, optimizer, loss_fn)
    val_acc, val_loss = test_step(modelV1, val_dataloader, loss_fn)

    print(f'Epoch {epoch} | train accuracy {train_acc:.3f} | valid accuracy {val_acc:.3f} ')
    print('-' * 59)

Epoch 11 | train accuracy 0.853 | valid accuracy 0.847 
-----------------------------------------------------------
Epoch 12 | train accuracy 0.858 | valid accuracy 0.853 
-----------------------------------------------------------
Epoch 13 | train accuracy 0.863 | valid accuracy 0.857 
-----------------------------------------------------------
Epoch 14 | train accuracy 0.868 | valid accuracy 0.862 
-----------------------------------------------------------
Epoch 15 | train accuracy 0.871 | valid accuracy 0.864 
-----------------------------------------------------------
Epoch 16 | train accuracy 0.875 | valid accuracy 0.867 
-----------------------------------------------------------
Epoch 17 | train accuracy 0.878 | valid accuracy 0.869 
-----------------------------------------------------------
Epoch 18 | train accuracy 0.880 | valid accuracy 0.871 
-----------------------------------------------------------
Epoch 19 | train accuracy 0.883 | valid accuracy 0.874 
----------------

In [137]:
MODEL_PATH = Path('./save_model/modelV1_nlp1.pth')
# torch.save does not create missing parent directories, so make sure the
# target folder exists before writing the checkpoint.
MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
torch.save(modelV1.state_dict(), MODEL_PATH)

# Rebuild the architecture and restore the weights into it.
modelV1_loaded = TextClassificationModel(vocab_size, emsize, num_class)
# map_location remaps the stored tensors onto `device`, so the checkpoint
# loads even if it was written from a different device.
modelV1_loaded.load_state_dict(torch.load(MODEL_PATH, map_location=device))
# collate_batch puts batches on `device`; keep the restored model there too.
modelV1_loaded = modelV1_loaded.to(device)
modelV1_loaded

TextClassificationModel(
  (embedding): EmbeddingBag(95811, 64, mode=mean)
  (dense): Linear(in_features=64, out_features=4, bias=True)
)

In [138]:
modelV1_loaded.state_dict()['embedding.weight'].shape

torch.Size([95811, 64])

In [139]:
test_acc, test_loss = test_step(modelV1_loaded, test_dataloader, loss_fn)
test_acc, test_loss

(0.8696165966386554, 0.39370154457933765)

In [140]:
# Class names indexed by the 0-based class id produced by the model.
AG_NEWS_label = ['World', 'Sports', 'Business', 'Sci/Tec']

def predict(model, text, text_pipeline) :
    """Classify a single text and return its class probabilities and label.

    Args:
        model: Module called as ``model(token_ids, offsets)`` returning one
            row of class scores per bag.
        text: Raw input string.
        text_pipeline: Callable mapping ``text`` to a list of integer token ids.

    Returns:
        ``(probas, label)`` where ``probas`` is a 1-D CPU tensor of softmax
        probabilities for the single input and ``label`` is the matching
        entry of ``AG_NEWS_label``.
    """
    model.eval()
    # Build the inputs on whatever device the model's parameters live on, so
    # prediction works for CPU and GPU models alike.
    first_param = next(model.parameters(), None)
    target = first_param.device if first_param is not None else torch.device('cpu')

    with torch.inference_mode() :
        token_ids = torch.tensor(text_pipeline(text), dtype=torch.int64, device=target)
        # One bag starting at position 0: the whole text is a single sample.
        offsets = torch.tensor([0], dtype=torch.int64, device=target)
        pred = model(token_ids, offsets)
        probas = torch.softmax(pred, dim=1)[0]
        # Plain int so the Python list is indexed with an int, not a tensor.
        pred_class = int(probas.argmax())

    return probas.cpu(), AG_NEWS_label[pred_class]

In [141]:
ex_text_str = "MEMPHIS, Tenn. - Four days ago, Jon Rahm was \
    enduring the seasons worst weather conditions on Sunday at The \
    Open on his way to a closing 75 at Royal Portrush, which \
    considering the wind and the rain was a respectable showing. \
    Thursdays first round at the WGC-FedEx St. Jude Invitational \
    was another story. With temperatures in the mid-80s and hardly any \
    wind, the Spaniard was 13 strokes better in a flawless round. \
    Thanks to his best putting performance on the PGA Tour, Rahm \
    finished with an 8-under 62 for a three-stroke lead, which \
    was even more impressive considering hed never played the \
    front nine at TPC Southwind."

In [142]:
probas, pred_class = predict(modelV1_loaded, ex_text_str, text_pipeline)
probas, pred_class

(tensor([0.0403, 0.9428, 0.0064, 0.0105]), 'Sports')