## 1. Import Modules and Data
We use a subset of the [Children's Book Test (CBT)](https://arxiv.org/pdf/1511.02301) dataset to evaluate the zero-shot ability of GPT-2. You can learn more about the dataset and download it from [cam-cst/cbt](https://huggingface.co/datasets/cam-cst/cbt).

The original paper reports that the smallest GPT-2 model reached an accuracy of 87.65 on the common-noun portion of CBT (CBT-CN), which outperformed the state of the art at the time. This notebook attempts to reproduce that result.

The data returned by the data loader has already been tokenized and converted to vocabulary indices. Because cam-cst/cbt guarantees that every row of the "options" column contains 10 options, each batch of sentences has shape `[batch_size, 10, seq_len]`.
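
To make the `[batch_size, 10, seq_len]` layout concrete, here is a minimal sketch of how one CBT example could be expanded into ten candidate sequences. It is not the `load_data` implementation: `toy_encode` and `build_option_batch` are hypothetical helpers, the four-character "subword" split merely stands in for GPT-2's BPE tokenizer, and right-padding with id 0 is an assumption.

```python
# Toy illustration only: the real loader uses the GPT-2 BPE tokenizer.
PLACEHOLDER = "XXXXX"  # blank marker used in CBT questions


def toy_encode(text, vocab):
    """Split words into chunks of up to 4 characters and map them to ids."""
    pieces = [w[i:i + 4] for w in text.split() for i in range(0, len(w), 4)]
    return [vocab.setdefault(piece, len(vocab)) for piece in pieces]


def build_option_batch(context, question, options, pad_id=0):
    """Return a [num_options, seq_len] nested list, right-padded with pad_id."""
    vocab = {"<pad>": pad_id}  # hypothetical vocabulary, grown on the fly
    rows = [
        toy_encode(context + " " + question.replace(PLACEHOLDER, option), vocab)
        for option in options
    ]
    seq_len = max(len(row) for row in rows)
    return [row + [pad_id] * (seq_len - len(row)) for row in rows]


context = "The fox saw a crow holding some cheese ."
question = "The fox wanted the XXXXX ."
options = ["cheese", "crow", "fox", "tree", "river",
           "bird", "hill", "sun", "moon", "star"]

batch = [build_option_batch(context, question, options)]  # batch_size = 1
print(len(batch), len(batch[0]), len(batch[0][0]))  # batch_size, 10, seq_len
```

Each row differs only in the word that fills the blank, so rows whose option needs fewer tokens end in padding.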

In [1]:
from data import load_data

tokenizer, test_dataloader = load_data("cbt", config_name="CN", splits=["test"])

## 2. Load Pre-trained Model

In [2]:
import torch

import config
from modules import GPT2

device = torch.device("cuda:0")

model = GPT2.from_pretrained(str(config.pretrained_dir)).to(device)
model.eval()

loading weights from pretrained gpt: gpt2
forcing vocab_size=50257, max_len=1024
number of parameters: 123.65M


GPT2(
  (transformer): ModuleDict(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (c_attn): Linear(in_features=768, out_features=2304, bias=True)
          (c_proj): Linear(in_features=768, out_features=768, bias=True)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): ModuleDict(
          (c_fc): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU(approximate='tanh')
          (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): L
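
The reported parameter count can be checked by hand from the layer shapes printed above. The sketch below assumes two things that the truncated output does not show: that `lm_head` shares its weights with `wte`, and that the count excludes the position embeddings `wpe`. Both are inferred only because the arithmetic then matches the logged 123.65M.

```python
# Layer sizes copied from the printed module structure.
vocab_size, max_len, d_model, n_layer = 50257, 1024, 768, 12


def linear(n_in, n_out):
    """Weights plus bias of a Linear layer."""
    return n_in * n_out + n_out


def layer_norm(dim):
    """Scale plus shift of a LayerNorm layer."""
    return 2 * dim


wte = vocab_size * d_model  # token embeddings
wpe = max_len * d_model     # position embeddings
block = (
    layer_norm(d_model)
    + linear(d_model, 3 * d_model)   # attn.c_attn
    + linear(d_model, d_model)       # attn.c_proj
    + layer_norm(d_model)
    + linear(d_model, 4 * d_model)   # mlp.c_fc
    + linear(4 * d_model, d_model)   # mlp.c_proj
)
total = wte + wpe + n_layer * block + layer_norm(d_model)  # ln_f at the end

# Assumption: lm_head is tied to wte (not counted twice) and wpe is left out.
reported = total - wpe
print(f"{reported / 1e6:.2f}M")  # prints 123.65M
```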

## 3. Generation Ability
You can compare the generation ability with that of [GPT](https://github.com/Kami-chanw/SeekDeeper/blob/main/gpt/inference.ipynb).

In [3]:
text = "Hi, I'm Kami-chanw, a"
ids = tokenizer.encode(text)
ids_tensor = torch.tensor(ids, dtype=torch.long)
tokenizer.decode(
    model.generate(ids_tensor.to(device), max_new_tokens=20).tolist(),
    skip_special_tokens=True,
)[0]

"Hi, I'm Kami-chanw, a writer for Kamen Rider Gaiden 2 and Kadokawa, and this review of Victory Victory several"

## 4. Test on CBT
Unlike GPT-1, GPT-2 can be tested on the CBT dataset without any fine-tuning.

Section 3.2 of the original paper describes how to test on CBT:

> Following the LM approach introduced in the original paper, we compute the probability of each choice and the rest of the sentence conditioned on this choice according to the LM, and predict the one with the highest probability.

We end up with an accuracy of 86.84, which is close to the 87.65 reported in the paper.
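
The scoring rule in the quote can be sketched without any deep learning library. The example below sums the negative log-likelihood of every non-padding token in each candidate sequence and picks the candidate with the smallest sum, which is the one with the highest probability. `toy_logits` is a hypothetical stand-in for the language model, and the token ids, `PAD_ID`, and vocabulary size are made up for illustration.

```python
import math

IGNORE_INDEX = -100  # label value for positions that must not be scored
PAD_ID = 4           # hypothetical padding id for this toy vocabulary


def log_softmax(logits):
    """Numerically stable log-softmax over one list of logits."""
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return [x - log_norm for x in logits]


def sequence_nll(logits, labels):
    """Summed NLL: logits at position t are scored against the label at t + 1."""
    total = 0.0
    for step_logits, target in zip(logits[:-1], labels[1:]):
        if target == IGNORE_INDEX:  # padded positions contribute nothing
            continue
        total -= log_softmax(step_logits)[target]
    return total


def pick_option(option_logits, option_labels):
    """Return the index of the lowest-loss option and all per-option losses."""
    losses = [sequence_nll(lg, lb) for lg, lb in zip(option_logits, option_labels)]
    return min(range(len(losses)), key=losses.__getitem__), losses


def toy_logits(tokens, vocab_size=5):
    """Stand-in model: after token t it favours token (t + 1) % vocab_size."""
    return [[2.0 if v == (t + 1) % vocab_size else 0.0 for v in range(vocab_size)]
            for t in tokens]


tokens = [[0, 1, 2, 3], [0, 3, 1, PAD_ID]]  # two candidates, second one padded
labels = [[IGNORE_INDEX if tok == PAD_ID else tok for tok in row] for row in tokens]

best, losses = pick_option([toy_logits(row) for row in tokens], labels)
print(best, losses)
```

As long as truncation leaves the shared context identical across candidates, the tokens before the blank add the same amount to every sum, so scoring the whole sequence ranks the options the same way as scoring only the choice and the rest of the sentence.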

In [5]:
import torch.nn.functional as F
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm

from data import pad_idx

torch.cuda.empty_cache()

all_preds = []
all_labels = []

for texts, answer_idx in tqdm(test_dataloader, desc="Testing: "):
    texts, answer_idx = texts.to(device), answer_idx.to(device)
    # Keep only the last max_len tokens so the input fits the context window.
    texts = texts[:, :, -config.max_len :]
    batch_size, num_options, seq_len = texts.shape
    # -100 is the default ignore_index of F.cross_entropy, so padding adds no loss.
    labels = texts.masked_fill(texts == pad_idx, -100)
    with torch.no_grad():
        # Fold the options into the batch dimension: [batch_size * num_options, seq_len]
        lm_logits = model(texts.view(-1, seq_len), (texts != pad_idx).view(-1, seq_len))
        # Logits at position t predict the token at position t + 1.
        shift_logits = lm_logits[..., :-1, :]
        shift_labels = labels[..., 1:]
        loss = F.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
            reduction="none",
        )
        # The summed token loss is the negative log-probability of each candidate.
        loss_per_option = loss.view(batch_size, num_options, -1).sum(dim=-1)
        # The lowest loss corresponds to the highest probability.
        pred_idx = torch.argmin(loss_per_option, dim=1)

    all_preds.extend(pred_idx.cpu().numpy())
    all_labels.extend(answer_idx.cpu().numpy())

accuracy = accuracy_score(all_labels, all_preds)
print(f"Accuracy: {accuracy * 100:.4f}")

Accuracy: 86.8400
