In [4]:
import torch
from torch.utils.data.dataloader import DataLoader

from EWOthello.data.othello import *
from EWOthello.mingpt.dataset import CharDataset # AK's mingpt data child 
from EWOthello.mingpt.model import GPT, GPTConfig, GPTforProbing # AKs and KLi GPT models
from EWOthello.mingpt.trainer import Trainer, TrainerConfig # AKs GPT trainer
from EWOthello.mingpt.utils import set_seed, sample # AKs helpers for sampling predictions

# Seed the RNGs through the minGPT helper so repeated runs of the notebook are reproducible.
SEED = 44
set_seed(SEED)

In [2]:
# Load the Othello game corpus and wrap it in the minGPT character-level dataset.
# Per the recorded cell output this yields 20,000,000 training sequences over a
# 61-token vocabulary.
othello = get(ood_num=-1, data_root=None)
train_dataset = CharDataset(othello)

# Architecture hyper-parameters; vocabulary size and context length are taken
# from the dataset itself.
gpt_hparams = dict(n_layer=8, n_head=8, n_embd=512)
mconf = GPTConfig(
    train_dataset.vocab_size,
    train_dataset.block_size,
    **gpt_hparams,
)

# NOTE(review): this full GPT instance is replaced by GPTforProbing in a later
# cell; it is kept here only to preserve the notebook's existing state.
model = GPT(mconf)

Mem Used: 14.64 GB: 100%|██████████| 231/231 [00:42<00:00,  5.50it/s]


Deduplicating...
Deduplicating finished with 23096133 games left
Using 20 million for training, 3096133 for validation
Dataset created has 20000000 sequences, 61 unique words.


In [6]:
# This is a truncated GPT model (with random or pre-trained weights) that will be
# used to get activations from after attention modules.
model = GPTforProbing(mconf, probe_layer=2)

# Which weights to use:
#   "random"       -> re-initialise with the model's own _init_weights
#   "championship" -> load ../EWOthello/ckpts/gpt_championship.ckpt
#   anything else  -> load ../EWOthello/ckpts/gpt_synthetic.ckpt
mode = "synthetic"

# Always define `device`, so later cells that call `x.to(device)` also work on
# CPU-only machines (previously it was only assigned when CUDA was available).
if torch.cuda.is_available():
    device = torch.device("cuda", torch.cuda.current_device())
else:
    device = torch.device("cpu")

# Apply the GPT model weights
if mode == "random":
    model.apply(model._init_weights)
else:
    if mode == "championship":
        path = "../EWOthello/ckpts/gpt_championship.ckpt"
    else:
        path = "../EWOthello/ckpts/gpt_synthetic.ckpt"
    # map_location lets a checkpoint that was saved from a GPU be restored on
    # whatever device is actually present.
    model.load_state_dict(torch.load(path, map_location=device))

model = model.to(device)

In [None]:
# Walk the training set one game at a time and collect, for every valid move
# position:
#   act_container      - the probing model's activation vector at that position
#   property_container - the ground-truth board property selected by `probe_property`
#   age_container      - the ground-truth "age" labels from OthelloBoardState.get_age
#
# NOTE(review): `tqdm` is not imported explicitly in this notebook; it is
# presumably re-exported by `from EWOthello.data.othello import *` - confirm.
# NOTE(review): the model is never switched to eval() mode; if GPTforProbing
# contains dropout the collected activations are stochastic - confirm.

# The original script read the property name from CLI `args.exp`, which does not
# exist in the notebook. Use it when present, otherwise fall back to a constant.
# NOTE(review): "state" (i.e. OthelloBoardState.get_state) is an assumed default - confirm.
probe_property = args.exp if "args" in globals() else "state"

loader = DataLoader(train_dataset, shuffle=False, pin_memory=True, batch_size=1, num_workers=1)
act_container = []
property_container = []
age_container = []

# Inference only: skip autograd bookkeeping (the result was detached anyway).
with torch.no_grad():
    for x, y in tqdm(loader, total=len(loader)):  # targets `y` are not needed here
        # Decode token ids back to moves; -100 marks the start of padding.
        tbf = [train_dataset.itos[tok] for tok in x.tolist()[0]]
        valid_until = tbf.index(-100) if -100 in tbf else len(tbf)
        moves = tbf[:valid_until]

        # A fresh board per query, exactly as the two original loops did.
        properties = OthelloBoardState().get_gt(moves, "get_" + probe_property)  # [block_size, ]
        ages = OthelloBoardState().get_gt(moves, "get_age")  # [block_size, ]

        act = model(x.to(device))[0, ...].detach().cpu()  # [block_size, f]

        # One activation vector per valid position.
        act_container.extend(act[:valid_until].unbind(0))
        property_container.extend(properties)
        age_container.extend(ages)

