## 1. Import Modules and Data
If you haven't downloaded the IMDB dataset yet, run `download_imdb.py`, or download `aclImdb_v1.tar.gz` from [here](http://ai.stanford.edu/~amaas/data/sentiment) and unzip it.

In [None]:
from data_imdb import test_loader, train_loader, vocab, PAD_TOKEN, CLS_TOKEN
from modules import Encoder, Transformer, make_src_mask
import torch
from torch import nn
from config import *
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR
import os

# Seed torch's global RNG so repeated runs of this notebook are reproducible.
torch.manual_seed(seed=3407)

# The trained weights are written under this directory; create it up front.
checkpoint_dir = "./checkpoints"
os.makedirs(name=checkpoint_dir, exist_ok=True)

# Bare expression: in a notebook this echoes the device configured in config.py.
device

## 2. Build Classifier Model
Only the Transformer encoder is needed: it acts as a text feature extractor, and its output at the CLS token prepended to each text is fed to a linear layer to make the prediction.

In [None]:
class SentimentClassifier(nn.Module):
    """Classify a token sequence from the encoder output at position 0.

    The encoder is used purely as a feature extractor. The hidden state at the
    first sequence position (intended to be the CLS token prepended to each
    text by the data pipeline -- not checked here) is projected to class
    logits by a single linear layer.

    Args:
        encoder: Module called as ``encoder(input_ids, attention_mask)``. Its
            output is indexed as ``[:, 0, :]``, so it must be 3-D with the
            hidden size ``d_model`` on the last axis (presumably
            ``(batch, seq_len, d_model)``).
        d_model: Hidden size of the encoder output.
        device: Device on which the linear classification head is allocated.
        num_classes: Number of output classes. Defaults to 2 (binary
            sentiment), which matches the original hard-coded head.
    """

    def __init__(self, encoder, d_model, device, num_classes=2):
        super().__init__()
        self.encoder = encoder
        self.fc = nn.Linear(d_model, num_classes, device=device)

    def forward(self, input_ids, attention_mask):
        """Return unnormalised logits with ``num_classes`` columns.

        Args:
            input_ids: Token ids passed straight to the encoder.
            attention_mask: Mask passed straight to the encoder.

        Returns:
            Tensor of logits, one row per entry along the encoder output's
            first axis.
        """
        encoder_output = self.encoder(input_ids, attention_mask)
        # Summarise each sequence by the hidden state at position 0 (CLS).
        cls_token_output = encoder_output[:, 0, :]
        return self.fc(cls_token_output)

# Transformer encoder acting as the text feature extractor; every
# hyperparameter below comes from config.py (imported via `from config import *`).
encoder = Encoder(
    enc_voc_size=len(vocab),
    max_len=max_len,
    d_model=d_model,
    ffn_hidden=ffn_hidden,
    n_heads=n_heads,
    n_layers=n_layers,
    dropout=dropout,
    device=device,
)

# Wrap the encoder with a linear head applied to the CLS position.
model = SentimentClassifier(encoder=encoder, d_model=d_model, device=device)

## 3. Train Model
Train the model on the IMDB training set for `epochs` epochs. All relevant training parameters are defined in [config.py](./config.py).

In [None]:
# Cross-entropy over the classifier's logits.
criterion = nn.CrossEntropyLoss()

# AdamW over all model parameters; learning rate and weight decay come from config.py.
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=init_lr,
    weight_decay=weight_decay,
)

# Cosine annealing over `epochs` scheduler steps (train() steps it once per epoch).
scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=epochs)


def train(model, train_loader, optimizer, criterion, scheduler, epoch):
    """Run one training epoch, then advance the LR scheduler by one step.

    Args:
        model: Classifier called as ``model(input_ids, attention_mask)``.
        train_loader: Iterable of ``(input_ids, labels)`` batches.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss called as ``criterion(outputs, labels)``; assumed to
            return a per-batch mean (e.g. ``nn.CrossEntropyLoss()`` with its
            default ``reduction="mean"``).
        scheduler: LR scheduler, stepped once per call (i.e. once per epoch).
        epoch: Zero-based epoch index, used only for the progress-bar label.

    Returns:
        Mean training loss per sample over the epoch.

    Raises:
        ValueError: If ``train_loader`` yields no samples.

    Note:
        Relies on the notebook globals ``device``, ``clip``, ``vocab``,
        ``PAD_TOKEN``, ``make_src_mask`` and ``tqdm``.
    """
    model.train()
    pad_id = vocab[PAD_TOKEN]  # constant across batches; look it up once
    total_loss = 0.0
    total_samples = 0
    for input_ids, labels in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}"):
        input_ids, labels = input_ids.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(input_ids, make_src_mask(input_ids, pad_id, device))
        loss = criterion(outputs, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        # The loss is a batch mean, so weight it by batch size; otherwise a
        # smaller final batch counts as much as a full one in the epoch average.
        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        total_samples += batch_size

    if total_samples == 0:
        raise ValueError("train_loader yielded no samples")

    scheduler.step()
    return total_loss / total_samples


# Train for the configured number of epochs; train() also steps the cosine
# LR schedule once per epoch.
for epoch in range(epochs):
    avg_train_loss = train(model, train_loader, optimizer, criterion, scheduler, epoch)
    print(f"Epoch {epoch + 1}/{epochs}, Training Loss: {avg_train_loss}")

# Persist only the final model weights (state_dict), not optimizer/scheduler state.
checkpoint_path = os.path.join(checkpoint_dir, "imdb_ckpt.pth")
torch.save(model.state_dict(), checkpoint_path)

## 4. Evaluate Model
Evaluate the model on the IMDB test set and report its accuracy.

In [None]:
def evaluate(model, test_loader):
    """Return the classification accuracy of ``model`` on ``test_loader``.

    The predicted class is the argmax of the model's logits over the last
    axis. The model is left in eval mode afterwards.

    Args:
        model: Classifier called as ``model(input_ids, attention_mask)``.
        test_loader: Iterable of ``(input_ids, labels)`` batches.

    Returns:
        Fraction of samples whose predicted class equals the label, in [0, 1].

    Raises:
        ValueError: If ``test_loader`` yields no samples (instead of a bare
            ZeroDivisionError).

    Note:
        Relies on the notebook globals ``device``, ``vocab``, ``PAD_TOKEN``,
        ``make_src_mask`` and ``tqdm``.
    """
    model.eval()
    pad_id = vocab[PAD_TOKEN]  # constant across batches; look it up once
    total_correct = 0
    total_samples = 0
    with torch.no_grad():
        for input_ids, labels in tqdm(test_loader, desc="Evaluating"):
            input_ids, labels = input_ids.to(device), labels.to(device)

            outputs = model(input_ids, make_src_mask(input_ids, pad_id, device))
            predictions = outputs.argmax(dim=-1)
            total_correct += (predictions == labels).sum().item()
            total_samples += predictions.size(0)

    if total_samples == 0:
        raise ValueError("test_loader yielded no samples")
    return total_correct / total_samples

accuracy = evaluate(model, test_loader)
print(f"Test Accuracy: {accuracy}")