## 1. Import Modules and Data
If you haven't downloaded the IMDB dataset yet, please run `download_imdb.py`, or download and unzip `aclImdb_v1.tar.gz` from [here](http://ai.stanford.edu/~amaas/data/sentiment).

In [1]:
from data_imdb import test_loader, train_loader, vocab, PAD_TOKEN, CLS_TOKEN
from modules import Encoder, make_src_mask
import torch
from torch import nn
import config
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR
import os

# Make sure the checkpoint directory exists; exist_ok=True makes this a no-op
# when the directory is already there.
os.makedirs(config.checkpoint_dir, exist_ok=True)

# Seed PyTorch's random number generator with a fixed value.
torch.manual_seed(3407)
# Bare expression: as the last statement of the notebook cell, its value
# (the configured device) is shown as the cell output.
config.device

Loading pos data: 100%|██████████| 12500/12500 [00:02<00:00, 4738.87it/s]
Loading neg data: 100%|██████████| 12500/12500 [00:02<00:00, 4561.69it/s]
Loading pos data: 100%|██████████| 12500/12500 [00:02<00:00, 4306.01it/s]
Loading neg data: 100%|██████████| 12500/12500 [00:03<00:00, 4104.61it/s]


device(type='cuda', index=0)

## 2. Build Classifier Model
We only need to use the transformer encoder as a text feature extractor, and then use the CLS token attached to the beginning of each text to make predictions.

In [2]:
class SentimentClassifier(nn.Module):
    """Two-class classifier made of an encoder and a linear head.

    The head is applied to the encoder output at sequence position 0,
    which is where the CLS token is placed.
    """

    def __init__(self, encoder, d_model, device):
        """Store the encoder and create the classification head.

        Args:
            encoder: Callable invoked as ``encoder(input_ids, attention_mask)``;
                its result is indexed as ``[:, 0, :]`` in ``forward``.
            d_model: Input feature size of the linear head.
            device: Device on which the linear head is created.
        """
        super().__init__()
        self.encoder = encoder
        self.fc = nn.Linear(in_features=d_model, out_features=2, device=device)

    def forward(self, input_ids, attention_mask):
        """Return the two-class logits for a batch.

        Args:
            input_ids: Passed through to the encoder unchanged.
            attention_mask: Passed through to the encoder unchanged.

        Returns:
            The output of the linear head applied to the encoder output at
            sequence position 0.
        """
        hidden_states = self.encoder(input_ids, attention_mask)
        # Only the first position (the CLS token) feeds the classifier head.
        cls_features = hidden_states[:, 0, :]
        return self.fc(cls_features)

# Instantiate the classifier. Every encoder hyperparameter is read from
# `config`; the vocabulary size is the length of the loaded vocab. The encoder
# and the classification head are both created on config.device.
model = SentimentClassifier(
    encoder=Encoder(
        enc_voc_size=len(vocab),
        max_len=config.max_len,
        d_model=config.d_model,
        n_head=config.n_head,
        n_layer=config.n_layer,
        ffn_hidden=config.ffn_hidden,
        dropout=config.dropout,
        device=config.device,
    ),
    d_model=config.d_model,
    device=config.device,
)

## 3. Train Model
Train the model on the IMDB training dataset for `epochs` epochs. All relevant training parameters can be found in [config.py](./config.py).

In [3]:
# Cross-entropy loss over the classifier's two output logits.
criterion = nn.CrossEntropyLoss()
# Weight decay is read from `config`, like the learning rate; the bare name
# `weight_decay` is not defined or imported anywhere in this notebook and
# would raise NameError on a fresh kernel.
# NOTE(review): assumes config.py defines `weight_decay` - confirm.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config.init_lr,
    weight_decay=config.weight_decay,
)
# Cosine learning-rate schedule spanning config.epochs scheduler steps;
# train() steps the scheduler once per call, i.e. once per epoch.
scheduler = CosineAnnealingLR(optimizer, T_max=config.epochs)


def train(model, train_loader, optimizer, criterion, scheduler, epoch):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Module called as ``model(input_ids, src_mask)``.
        train_loader: Iterable of ``(input_ids, labels)`` batches; must
            support ``len()`` (used to average the loss).
        optimizer: Optimizer whose gradients are zeroed and stepped per batch.
        criterion: Loss called as ``criterion(outputs, labels)``.
        scheduler: Learning-rate scheduler, stepped once at the end of the
            epoch.
        epoch: Zero-based epoch index; only used for the progress-bar label.

    Returns:
        Sum of the per-batch loss values divided by the number of batches.
    """
    model.train()
    total_loss = 0
    for batch in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}"):
        input_ids, labels = batch
        input_ids, labels = input_ids.to(config.device), labels.to(config.device)

        optimizer.zero_grad()
        # Build the mask on config.device, the same device the batch was moved
        # to. The bare name `device` is not defined in this notebook, and
        # evaluate() already passes config.device here.
        src_mask = make_src_mask(input_ids, vocab[PAD_TOKEN], config.device)
        outputs = model(input_ids, src_mask)
        loss = criterion(outputs, labels)
        loss.backward()
        # Clip the total gradient norm to config.clip before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.clip)
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    # The learning-rate schedule advances once per epoch, not per batch.
    scheduler.step()

    return avg_loss


# Train for config.epochs epochs, reporting the mean training loss after each.
for epoch in range(config.epochs):
    avg_train_loss = train(model, train_loader, optimizer, criterion, scheduler, epoch)
    print(f"Epoch {epoch + 1}/{config.epochs}, Training Loss: {avg_train_loss}")

# Save only the model weights (state_dict) once all epochs have finished.
# The file name has no placeholders, so it is a plain string, not an f-string.
checkpoint_path = os.path.join(config.checkpoint_dir, "imdb_ckpt.pth")
torch.save(model.state_dict(), checkpoint_path)

Training Epoch 1: 100%|██████████| 782/782 [01:16<00:00, 10.23it/s]


Epoch 1/5, Training Loss: 0.6502295922668998


Training Epoch 2: 100%|██████████| 782/782 [01:16<00:00, 10.25it/s]


Epoch 2/5, Training Loss: 0.5537692443717777


Training Epoch 3: 100%|██████████| 782/782 [01:16<00:00, 10.25it/s]


Epoch 3/5, Training Loss: 0.4922361841134708


Training Epoch 4: 100%|██████████| 782/782 [01:16<00:00, 10.25it/s]


Epoch 4/5, Training Loss: 0.4627986763368177


Training Epoch 5: 100%|██████████| 782/782 [01:16<00:00, 10.25it/s]


Epoch 5/5, Training Loss: 0.4430596483561694


## 4. Evaluate Model
Evaluate model on IMDB test dataset.

In [4]:
def evaluate(model, test_loader):
    """Compute classification accuracy of ``model`` over ``test_loader``.

    Args:
        model: Module called as ``model(input_ids, src_mask)``; the argmax
            over the last dimension of its output is taken as the prediction.
        test_loader: Iterable of ``(input_ids, labels)`` batches.

    Returns:
        Number of predictions equal to their label divided by the total
        number of predictions.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for input_ids, labels in tqdm(test_loader, desc="Evaluating"):
            input_ids = input_ids.to(config.device)
            labels = labels.to(config.device)

            src_mask = make_src_mask(input_ids, vocab[PAD_TOKEN], config.device)
            logits = model(input_ids, src_mask)
            predicted = torch.argmax(logits, dim=-1)

            n_correct += (predicted == labels).sum().item()
            n_seen += predicted.size(0)
    return n_correct / n_seen

# Score the model on the test loader and report the fraction of correct
# predictions returned by evaluate().
accuracy = evaluate(model, test_loader)
print(f"Test Accuracy: {accuracy}")

Evaluating: 100%|██████████| 782/782 [00:24<00:00, 32.39it/s]

Test Accuracy: 0.78504



