In [1]:
import torch
from torch import nn
import sys 
sys.path.append("../models")
from transformer_blocks import Transformer
from torch.utils.data import random_split
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import AG_NEWS
from torch.utils.data import DataLoader

In [2]:
class ClassificationTransformer(nn.Module): 
    def __init__(self, d_k, d_model, d_v, d_ff, num_heads, num_layers, num_classes, vocab_size, dropout=0.1) -> None:
        super(ClassificationTransformer, self).__init__()
        self.encoder_only_transformer = Transformer(d_k, d_model, d_v, d_ff, num_heads, num_layers, vocab_size, dropout=0.1)
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(d_model, d_model)
        self.fc2 = nn.Linear(d_model, num_classes)
    
    def forward(self, x):
        out = self.encoder_only_transformer(x)
        avg_pool = torch.mean(out, dim=-2)
        return self.fc2(self.dropout(self.fc1(self.dropout(avg_pool))))

In [3]:
train_iter = AG_NEWS(split='train')

# Convert to list to enable random splitting
train_dataset = list(train_iter)

#80-20 train-val split 
train_size = int(len(train_dataset) * 0.8)  
val_size = len(train_dataset) - train_size  
train_data, val_data = random_split(train_dataset, [train_size, val_size])

tokenizer = get_tokenizer("basic_english")

def yield_tokens(data_iter):
    for text in data_iter:
        yield tokenizer(text)

VOCAB_SIZE = 5000

# Build vocab based on the train_data
train_data_iter = (text for _, text in train_data)
vocab = build_vocab_from_iterator(yield_tokens(train_data_iter), specials=["<unk>"], max_tokens=VOCAB_SIZE)
vocab.set_default_index(vocab["<unk>"])

In [4]:
from torch.nn.utils.rnn import pad_sequence

text_pipeline = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: int(x) - 1

def collate_batch(batch):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    label_list, text_list, lengths = [], [], []
    
    # Sort the batch in the descending order
    batch.sort(key=lambda x: len(x[1]), reverse=True)
    
    for _label, _text in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        lengths.append(processed_text.size(0))
        
    label_list = torch.tensor(label_list, dtype=torch.int64)
    lengths = torch.tensor(lengths, dtype=torch.int64)
    
    # Pad sequences
    text_list = pad_sequence(text_list, batch_first=True)
    
    return label_list.to(device), text_list.to(device), lengths

In [5]:
train_loader = DataLoader(train_data, batch_size = 8, shuffle = True, collate_fn = collate_batch)
val_loader = DataLoader(val_data, batch_size = 8, shuffle = False, collate_fn = collate_batch)

In [6]:
LEARNING_RATE = 3e-4
NUM_EPOCHS = 50
DROPOUT = 0.2
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

D_K = 128
D_V = D_K
D_MODEL = D_K * 2
D_FF = D_MODEL * 2
NUM_LAYERS = 2
OUTPUT_DIM = 4

In [7]:
model = ClassificationTransformer(D_K, D_MODEL, D_V, D_FF, num_heads=3, num_classes=OUTPUT_DIM, num_layers=2, vocab_size=VOCAB_SIZE)
model = model.to(DEVICE)

In [8]:
def train(model, train_loader, val_loader, loss_function, optim, epochs, device):
    losses = [] #group losses for loss visualization 
    running_loss = 0.0
    val_losses = []
    for epoch in range(epochs):
        model.train()
        print("Epoch %d / %d" % (epoch+1, epochs))
        print("-"*10)
    
        for i, batch_data in enumerate(train_loader):
            (y, x, x_size) = batch_data
            logits = model(x)
            loss = loss_function(logits, y)
            optim.zero_grad()
            loss.backward()
            optim.step()
            running_loss += loss.item()
            losses.append(loss)

            if (i+1) % 1000 == 0:
                print("Step: {}, average training loss over last 2000 steps: {:.4f}".format(i+1, running_loss/1000))
                running_loss = 0.0
            
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            correct_pred = 0.0
            num_samples = 0
            for i, batch_data in enumerate(val_loader):
                (y, x, x_size) = batch_data
                y, x, x_size = y.to(device), x.to(device), x_size.to(device)
                logits = model(x)
                loss = loss_function(logits, y)
                _, predicted_labels = torch.max(logits, 1)
                correct_pred += (predicted_labels.long() == y.long()).sum()
                num_samples+=predicted_labels.shape[0]
                val_loss += loss.item()
            
            val_accuracy = (correct_pred / num_samples) * 100
            val_losses.append(val_loss)
        print("Epoch: {}, validation loss: {:.4f}, val accuracy: {:.2f}".format(epoch+1, val_loss/len(val_loader), val_accuracy))
    
    return losses, val_losses

In [9]:
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
train_loss, val_loss = train(model, train_loader, val_loader, torch.nn.functional.cross_entropy, optimizer, NUM_EPOCHS, DEVICE)

Epoch 1 / 50
----------
Step: 1000, average training loss over last 2000 steps: 1.0266
Step: 2000, average training loss over last 2000 steps: 0.5943
Step: 3000, average training loss over last 2000 steps: 0.5062
Step: 4000, average training loss over last 2000 steps: 0.4616
Step: 5000, average training loss over last 2000 steps: 0.4308
Step: 6000, average training loss over last 2000 steps: 0.4151
Step: 7000, average training loss over last 2000 steps: 0.4083
Step: 8000, average training loss over last 2000 steps: 0.3991
Step: 9000, average training loss over last 2000 steps: 0.4013
Step: 10000, average training loss over last 2000 steps: 0.3784
Step: 11000, average training loss over last 2000 steps: 0.3690
Step: 12000, average training loss over last 2000 steps: 0.3558
Epoch: 1, validation loss: 0.3677, val accuracy: 87.29
Epoch 2 / 50
----------
Step: 1000, average training loss over last 2000 steps: 0.3211
Step: 2000, average training loss over last 2000 steps: 0.3210
Step: 3000, 

What do I want to see from this training? 

- I want to see how long it takes for this to converge 
- I want to see the computational complexity (how long per epoch?)
    - anywhere from 60 seconds an epoch to 3 minutes an epoch depending on hyperparameters 
    - With the low
- How does the model accuracy respond to changes in hyperparams? Not just LR and the usuals, but also d_model, d_k, d_v, d_ff
    - This is intuitive, higher values for these hyper params work better (until a point of diminishing return) since there's a balance between trying to capture all the intricacies of the data in high dimensional vector space and the actual complexity of the data to begin with. (as well as computational constraints) - doubling these dimensions led to more batch efficient learning but at the same time increased runtime by a proportional 2x 
    - 

Thoughts so far: 

This architecture is way more compute intensive but at the same time so much more batch efficient than previously explored ones when tuned with the right hyperparameters (LSTMs, RNNs even when bidirectional + multi layered and so on) that we see near 90% validation accuracy after just one or two epochs. These epochs, however, take tremendously long (7 minutes per epoch v/s the 1 minute LSTMs were taking). Part of this could be becauase of the implementation being from scratch (torch.transformer might be faster), but most of this comes from the computational load the multiple blocks (and their individual complexities)

Now that we have the encoder and were able to build something useful off of it and ensure its functionality, we should tweak our multihead attention to enable some form of masking and then build a decoder.