## 1. Import Modules and Data
If you haven't downloaded the IMDB dataset yet, run `download_imdb.py`, or download and unzip `aclImdb_v1.tar.gz` from [here](http://ai.stanford.edu/~amaas/data/sentiment).

In [8]:
from data_imdb import test_dataset, tokenizer, vocab, PAD_TOKEN
from modules import Encoder,  make_src_mask
import torch
from torch import nn
import config
from tqdm import tqdm
import random
import os

# Seed every RNG this notebook draws from so that Restart & Run All reproduces
# the same outputs: `torch` for the model side, and the stdlib `random` module
# because the inference cell picks its review samples with `random.sample`.
SEED = 3407
torch.manual_seed(SEED)
random.seed(SEED)

# Last expression of the cell: display the device the model will run on.
config.device

device(type='cuda', index=0)

## 2. Build Classifier Model
We only need the transformer encoder as a text feature extractor; predictions are made from the encoder output at the CLS token that is prepended to each text.

In [9]:
class SentimentClassifier(nn.Module):
    """Binary sentiment classifier built on top of a transformer encoder.

    The encoder turns token ids into per-position features; the feature vector
    at sequence position 0 (the CLS slot) is projected to two class logits.

    Args:
        encoder: Module called as ``encoder(input_ids, attention_mask)``. Its
            output is indexed with ``[:, 0, :]`` and fed to the linear head, so
            the last axis must have ``d_model`` features (presumably the output
            is shaped ``(batch, seq_len, d_model)``).
        d_model: Size of the encoder's feature vectors, i.e. the input width of
            the classification head.
        device: Device on which the classification head is created.
    """

    def __init__(self, encoder, d_model, device):
        super().__init__()
        # The attribute names `encoder` and `fc` double as state-dict key
        # prefixes, so they must stay as-is for saved checkpoints to load.
        self.encoder = encoder
        self.fc = nn.Linear(in_features=d_model, out_features=2, device=device)

    def forward(self, input_ids, attention_mask):
        """Return the two class logits for every sequence in the batch."""
        features = self.encoder(input_ids, attention_mask)
        # Keep only the first position of each sequence (the CLS slot).
        return self.fc(features[:, 0, :])
    
    
# Rebuild the architecture from `config` so that the parameter shapes match
# the ones stored in the checkpoint loaded below.
model = SentimentClassifier(
    Encoder(
        enc_voc_size=len(vocab),
        max_len=config.max_len,
        d_model=config.d_model,
        ffn_hidden=config.ffn_hidden,
        n_head=config.n_head,
        n_layer=config.n_layer,
        dropout=config.dropout,
        device=config.device,
    ),
    config.d_model,
    device=config.device,
)

# `map_location` remaps tensors that were saved from another device (e.g. a
# CUDA checkpoint opened on a CPU-only machine) onto `config.device`; without
# it `torch.load` tries to restore them on the device they were saved from.
# `weights_only=True` restricts unpickling to tensors and plain containers, so
# opening the file cannot execute arbitrary pickled code.
checkpoint_path = os.path.join(config.checkpoint_dir, 'imdb_ckpt.pth')
state_dict = torch.load(
    checkpoint_path, map_location=config.device, weights_only=True
)
model.load_state_dict(state_dict)

# Switch to evaluation mode (turns dropout off). As the last expression of the
# cell this also displays the module summary.
model.eval()

SentimentClassifier(
  (encoder): Encoder(
    (emb): TransformerEmbedding(
      (tok_emb): Embedding(89530, 512)
      (pos_emb): PositionalEncoding()
      (drop_out): Dropout(p=0.1, inplace=False)
    )
    (layers): ModuleList(
      (0-5): 6 x EncoderLayer(
        (self_attn): MultiheadAttention(
          (attention): ScaledDotProductAttention(
            (softmax): Softmax(dim=-1)
          )
          (w_q): Linear(in_features=512, out_features=512, bias=True)
          (w_k): Linear(in_features=512, out_features=512, bias=True)
          (w_v): Linear(in_features=512, out_features=512, bias=True)
          (w_concat): Linear(in_features=512, out_features=512, bias=True)
        )
        (ln_1): LayerNorm()
        (dropout): Dropout(p=0.1, inplace=False)
        (ffn): PositionwiseFeedForward(
          (linear1): Linear(in_features=512, out_features=2048, bias=True)
          (linear2): Linear(in_features=2048, out_features=512, bias=True)
          (relu): ReLU()
       

## 3. Inference


In [10]:
# Draw a handful of random reviews from the test split.
num_samples = 5
indices = random.sample(range(len(test_dataset)), num_samples)
sample_data = [test_dataset[idx] for idx in indices]

# Each dataset item is indexed as (token ids, label).
pad_id = vocab[PAD_TOKEN]
token_id_seqs = [sample[0] for sample in sample_data]
label_list = [sample[1] for sample in sample_data]

# Pad every review to the longest one in the sample so they stack into a batch.
input_ids = torch.nn.utils.rnn.pad_sequence(
    token_id_seqs, batch_first=True, padding_value=pad_id
)
labels = torch.stack(label_list)
attention_mask = make_src_mask(input_ids, pad_id, config.device)

# Move the whole batch onto the device the model lives on.
input_ids = input_ids.to(config.device)
attention_mask = attention_mask.to(config.device)
labels = labels.to(config.device)

# Forward pass with gradient tracking disabled; the index of the larger logit
# is the predicted class.
with torch.no_grad():
    outputs = model(input_ids, attention_mask)
    predictions = outputs.argmax(dim=-1)

# Show each decoded review next to its true and predicted label.
for ids_row, true_label, predicted_label in zip(input_ids, labels, predictions):
    print(f"Text: {tokenizer.decode(ids_row.cpu().numpy())}")
    print(f"True Label: {true_label.item()}, Predicted Label: {predicted_label.item()}\n")

Text: <cls> a strange relationship between a middle-aged woman and a transsexual who gonna be a woman <unk> charlotte and <unk> both trapped by their inanimate lives and don't know how to get out of <unk> charlotte is an owner of a beauty <unk> she has broken up with her aggressive <unk> moved into an apartment alone with all the furniture packed except her big <unk> veronica lives downstairs with her poor <unk> <unk> sensitive and desperately bothered by her <unk> visiting and the bad relationship with her <unk> her only hope is that the upcoming transsexual operation will turn her into a real woman and then everything will be <unk> all she can do now is waiting for an approval <unk> <unk> <unk> these two individuals meet by chance and gradually they are all involved into <unk> <unk> there are some sparkles between <unk> but no one is brave enough to face the truth because they are not willing to accept the change as most people <unk> eventually the ending is quite satisfying and leav