## 1. Import Modules and Data

In [1]:
from data import load_data

tokenizer, *_ = load_data("sst2")

## 2. Load Trained Model

In [2]:
import torch

import config
from modules import GPT, GPTClassifier

device_gpt_clf = torch.device("cuda:2")
device_gpt = torch.device("cuda:1")

gpt = GPT.from_pretrained(config.pretrained_dir, num_frozen_layers=12).to(device_gpt)
gpt.eval()

n_classes = 2  # sst2 is a binary classification task
gpt_clf = GPTClassifier.from_pretrained(
    config.pretrained_dir,
    num_frozen_layers=12,
    n_classes=n_classes,
    vocab_size=40478 + 3,  # add 3 special tokens
).to(device_gpt_clf)
gpt_clf.load_state_dict(torch.load(config.checkpoint_dir / "gpt_clf_3.pth")["model"])
gpt.eval()

number of parameters: 116.14M
number of parameters: 116.15M


GPT(
  (transformer): ModuleDict(
    (tokens_embed): Embedding(40478, 768)
    (positions_embed): Embedding(512, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (c_attn): Linear(in_features=768, out_features=2304, bias=True)
          (c_proj): Linear(in_features=768, out_features=768, bias=True)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): ModuleDict(
          (c_fc): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU(approximate='tanh')
          (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
  )
  (lm_head): Linear(in_features=768, out_features=40478, bias

## 3. Compare Generate Ability


In [3]:
text = "Hi, I'm Kami-chanw, a"
ids = tokenizer.encode(text, verbose=False)
ids_tensor = torch.tensor(ids, dtype=torch.long)
print(
    "Pretrained GPT output: ",
    tokenizer.decode(
        gpt.generate(ids_tensor.to(device_gpt), max_new_tokens=10).tolist(),
        skip_special_tokens=True,
    ),
)
print(
    "Fine-tuned GPT output: ",
    tokenizer.decode(
        gpt_clf.generate(ids_tensor.to(device_gpt_clf), max_new_tokens=10).tolist(),
        skip_special_tokens=True,
    ),
)

Pretrained GPT output:  ["hi , i 'm kami - chanw , a hi hi hi hi hi hi hi hi hi hi "]
Fine-tuned GPT output:  ["hi , i 'm kami - chanw , a ya ya . take only the ease that the country "]
