In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchtext.datasets import AG_NEWS
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.nn.utils.rnn import pad_sequence
import math
import time
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


In [None]:
# --- 2. prepare data ---

# load dataset
train_iter, test_iter = AG_NEWS()
# turn iterator to list to be able to rewind it later
train_data = list(train_iter)
test_data = list(test_iter)

# create tokenizer
tokenizer = get_tokenizer('basic_english')

# build vocabulary from training set
def yield_tokens(data_iter):
    for _, text in data_iter:
        yield tokenizer(text)

# add special tokens <unk> and <pad>
vocab = build_vocab_from_iterator(yield_tokens(train_data), specials=["<unk>", "<pad>"])
vocab.set_default_index(vocab["<unk>"])
PAD_IDX = vocab['<pad>']
VOCAB_SIZE = len(vocab)

#define text and label pipelines
text_pipeline = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: int(x) - 1  # labels are 1, 2, 3, 4 in AG_NEWS

# define collate function for DataLoader
def collate_batch(batch):
    label_list, text_list = [], []
    for (_label, _text) in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
    
    label_list = torch.tensor(label_list, dtype=torch.int64)
    # use pad_sequence to pad the text sequences to the same length
    # batch_first=False for transformer input
    padded_text_list = pad_sequence(text_list, batch_first=False, padding_value=PAD_IDX)
    return label_list.to(device), padded_text_list.to(device)

In [None]:
# --- 3. define model ---

class PositionalEncoding(nn.Module):
    # positional encoding module
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

class TransformerClassifier(nn.Module):
    # transformer model for text classification
    def __init__(self, vocab_size, embed_dim, nhead, ffn_hid_dim, nlayers, num_class, dropout=0.5):
        super(TransformerClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=PAD_IDX)
        self.pos_encoder = PositionalEncoding(embed_dim, dropout)
        encoder_layers = nn.TransformerEncoderLayer(embed_dim, nhead, ffn_hid_dim, dropout, batch_first=False)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, nlayers)
        self.d_model = embed_dim
        self.decoder = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src, src_padding_mask=None):
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, src_key_padding_mask=src_padding_mask)
        # average pooling over the sequence dimension to get a fixed-size representation
        output = output.mean(dim=0)
        return self.decoder(output)

In [None]:
LR = 5
BATCH_SIZE = 64
num_class = len(set([label for (label, text) in train_data]))
vocab_size = len(vocab)
emsize = 64
nhead = 4
ffn_hid_dim = 64
nlayers = 2
EPOCHS = 5

# optimizer, loss function and model instantiation
model = TransformerClassifier(vocab_size, emsize, nhead, ffn_hid_dim, nlayers, num_class).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)

# create DataLoader
train_dataset, test_dataset = to_map_style_dataset(train_data), to_map_style_dataset(test_data)
num_train = int(len(train_dataset) * 0.95)
split_train_, split_valid_ = random_split(train_dataset, [num_train, len(train_dataset) - num_train])

train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)


In [None]:
# --- 4. set up training and evaluation functions ---

# hyperparameters
EPOCHS = 10
LR = 5.0  
BATCH_SIZE = 64
EMBED_DIM = 64      # embedding dimension
NHEAD = 4           # number of heads in multi-head attention
FFN_HID_DIM = 64    # feedforward network hidden dimension
NLAYERS = 2         # Transformer Encoder layers
NUM_CLASS = len(set([label for (label, text) in train_data]))

# instantiate the model
model = TransformerClassifier(VOCAB_SIZE, EMBED_DIM, NHEAD, FFN_HID_DIM, NLAYERS, NUM_CLASS).to(device)

# define loss function, optimizer and learning rate scheduler
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)

In [None]:
# --- 5. training and evaluation functions ---

def train(dataloader):
    model.train()
    total_acc, total_count = 0, 0
    
    for label, text in tqdm(dataloader, desc="Training"):
        optimizer.zero_grad()
        
        # generate padding mask
        src_padding_mask = (text == PAD_IDX).transpose(0, 1)
        
        predicted_label = model(text, src_padding_mask)
        loss = criterion(predicted_label, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        total_count += label.size(0)
        
    return total_acc / total_count

def evaluate(dataloader):
    model.eval()
    total_acc, total_count = 0, 0

    with torch.no_grad():
        for label, text in dataloader:
            src_padding_mask = (text == PAD_IDX).transpose(0, 1)
            predicted_label = model(text, src_padding_mask)
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
    return total_acc / total_count

In [None]:
# --- 6. training loop ---

# generate DataLoader
train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)

total_accu = None
for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    accu_train = train(train_dataloader)
    accu_test = evaluate(test_dataloader)
    
    if total_accu is None or accu_test > total_accu:
        total_accu = accu_test
        # save the best model
        torch.save(model.state_dict(), 'best_transformer_model.pth')
        
    print('-' * 59)
    print(f'| end of epoch {epoch:3d} | time: {time.time() - epoch_start_time:5.2f}s | '
          f'train accuracy {accu_train:8.3f} | test accuracy {accu_test:8.3f} ')
    print('-' * 59)
    scheduler.step()

# --- 7. evaluate the model on test dataset ---
model.load_state_dict(torch.load('best_transformer_model.pth'))
print('Checking the results of the best model on test dataset.')
accu_test = evaluate(test_dataloader)
print(f'Test accuracy {accu_test:8.3f}')

# --- 8. predict a new sentence ---
def predict(text, model, vocab, text_pipeline):
    model.eval()
    with torch.no_grad():
        tensor = torch.tensor(text_pipeline(text), dtype=torch.int64).to(device).unsqueeze(1)
        output = model(tensor)
        return output.argmax(1).item() + 1 # +1 is because we subtracted 1 from labels earlier

ex_text_str = "MEMPHIS, Tenn. – Four days ago, Jon Rahm was enduring the season’s worst weather conditions on Sunday at The Open on his way to a closing 75."
ag_news_label = {1: "World", 2: "Sports", 3: "Business", 4: "Sci/Tec"}

model = model.to("cpu") 
print("\n--- Prediction Example ---")
print("Sample Text:", ex_text_str)
print("Predicted category:", ag_news_label[predict(ex_text_str, model, vocab, text_pipeline)])

Training: 100%|██████████| 1875/1875 [07:12<00:00,  4.34it/s]


-----------------------------------------------------------
| end of epoch   1 | time: 443.10s | train accuracy    0.782 | test accuracy    0.897 
-----------------------------------------------------------


Training: 100%|██████████| 1875/1875 [10:25<00:00,  3.00it/s]


-----------------------------------------------------------
| end of epoch   2 | time: 639.11s | train accuracy    0.905 | test accuracy    0.904 
-----------------------------------------------------------


Training: 100%|██████████| 1875/1875 [12:02<00:00,  2.59it/s]


-----------------------------------------------------------
| end of epoch   3 | time: 736.27s | train accuracy    0.911 | test accuracy    0.905 
-----------------------------------------------------------


Training: 100%|██████████| 1875/1875 [13:13<00:00,  2.36it/s]


-----------------------------------------------------------
| end of epoch   4 | time: 807.61s | train accuracy    0.911 | test accuracy    0.906 
-----------------------------------------------------------


Training: 100%|██████████| 1875/1875 [10:21<00:00,  3.02it/s]


-----------------------------------------------------------
| end of epoch   5 | time: 632.15s | train accuracy    0.910 | test accuracy    0.905 
-----------------------------------------------------------


Training: 100%|██████████| 1875/1875 [10:01<00:00,  3.12it/s]


-----------------------------------------------------------
| end of epoch   6 | time: 612.18s | train accuracy    0.911 | test accuracy    0.905 
-----------------------------------------------------------


Training: 100%|██████████| 1875/1875 [10:05<00:00,  3.10it/s]


-----------------------------------------------------------
| end of epoch   7 | time: 617.08s | train accuracy    0.911 | test accuracy    0.905 
-----------------------------------------------------------


Training: 100%|██████████| 1875/1875 [10:25<00:00,  3.00it/s]


-----------------------------------------------------------
| end of epoch   8 | time: 636.00s | train accuracy    0.911 | test accuracy    0.905 
-----------------------------------------------------------


Training: 100%|██████████| 1875/1875 [10:03<00:00,  3.11it/s]


-----------------------------------------------------------
| end of epoch   9 | time: 614.39s | train accuracy    0.911 | test accuracy    0.905 
-----------------------------------------------------------


Training: 100%|██████████| 1875/1875 [06:55<00:00,  4.51it/s]


-----------------------------------------------------------
| end of epoch  10 | time: 418.75s | train accuracy    0.912 | test accuracy    0.905 
-----------------------------------------------------------
Checking the results of the best model on test dataset.
Test accuracy    0.906

--- Prediction Example ---
Sample Text: MEMPHIS, Tenn. – Four days ago, Jon Rahm was enduring the season’s worst weather conditions on Sunday at The Open on his way to a closing 75.
Predicted category: Sports
