In [13]:
# Specify location of "othello_world" repository - https://github.com/likenneth/othello_world
OTHELLO_ROOT = "/media/home/alex/ongoing/othello_world/"
import sys
sys.path.append(OTHELLO_ROOT)

In [14]:
# Import stuff
%load_ext autoreload
%autoreload 2

import torch
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
from pathlib import Path

from transformer_lens import HookedTransformer, HookedTransformerConfig

from tl_othello_utils import plot_single_board, to_string, to_int, int_to_label, string_to_label, OthelloBoardState
import plotly.express as px

from othello_data import OthelloBoardState, OthelloDataset

from mingpt.dataset import CharDataset
from mingpt.model import GPT, GPTConfig, GPTforProbing
from mingpt.probe_trainer import Trainer, TrainerConfig
from mingpt.probe_model import BatteryProbeClassification, BatteryProbeClassificationTwoLayer

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [15]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f48275bc160>

# Othello Board Representation

In [16]:
# Stuff for manipulating the Othello board representation

stoi_indices = [
    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 29, 30, 31, 32, 33, 34, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
]
alpha = "ABCDEFGH"

def to_board_label(i):
    return f"{alpha[i//8]}{i%8}"

def to_idx(board_label):
    return alpha.index(board_label[0])*8 + int(board_label[1])

board_labels = list(map(to_board_label, stoi_indices))

def plot_square_as_board(state, diverging_scale=True, **kwargs):
    """Takes a square input (8 by 8) and plot it as a board. Can do a stack of boards via facet_col=0"""
    if diverging_scale:
        kwargs.update({'color_continuous_scale': 'RdBu', 'color_continuous_midpoint': 0.})
    else:
        kwargs.update({'color_continuous_scale': 'Blues', 'color_continuous_midpoint': None})
    if isinstance(state, torch.Tensor):
        state = state.detach().cpu().numpy()
    return px.imshow(state, y=[i for i in alpha], x=[str(i) for i in range(8)], **kwargs)


In [17]:
# Play around with an Othello Board object
board = OthelloBoardState()

move_to_make = to_idx("F4")
board.update([move_to_make]) # Can use prt=True flag for info

print('Valid Moves: ', list(map(to_board_label, board.get_valid_moves())))

plot_square_as_board(board.state, title="Example Board State (+1 is Black, -1 is White)")

Valid Moves:  ['D5', 'F3', 'F5']


# Loading Othello-GPT Model into TransformerLens
Use Neel's helper function to load the Othello-GPT checkpoint into a TL HookedTransformer

In [18]:
def convert_to_transformer_lens_format(in_sd, n_layers=8, n_heads=8):
    out_sd = {}
    out_sd["pos_embed.W_pos"] = in_sd["pos_emb"].squeeze(0)
    out_sd["embed.W_E"] = in_sd["tok_emb.weight"]

    out_sd["ln_final.w"] = in_sd["ln_f.weight"]
    out_sd["ln_final.b"] = in_sd["ln_f.bias"]
    out_sd["unembed.W_U"] = in_sd["head.weight"].T

    for layer in range(n_layers):
        out_sd[f"blocks.{layer}.ln1.w"] = in_sd[f"blocks.{layer}.ln1.weight"]
        out_sd[f"blocks.{layer}.ln1.b"] = in_sd[f"blocks.{layer}.ln1.bias"]
        out_sd[f"blocks.{layer}.ln2.w"] = in_sd[f"blocks.{layer}.ln2.weight"]
        out_sd[f"blocks.{layer}.ln2.b"] = in_sd[f"blocks.{layer}.ln2.bias"]

        out_sd[f"blocks.{layer}.attn.W_Q"] = einops.rearrange(
            in_sd[f"blocks.{layer}.attn.query.weight"], "(head d_head) d_model -> head d_model d_head", head=n_heads
        )
        out_sd[f"blocks.{layer}.attn.b_Q"] = einops.rearrange(
            in_sd[f"blocks.{layer}.attn.query.bias"], "(head d_head) -> head d_head", head=n_heads
        )
        out_sd[f"blocks.{layer}.attn.W_K"] = einops.rearrange(
            in_sd[f"blocks.{layer}.attn.key.weight"], "(head d_head) d_model -> head d_model d_head", head=n_heads
        )
        out_sd[f"blocks.{layer}.attn.b_K"] = einops.rearrange(
            in_sd[f"blocks.{layer}.attn.key.bias"], "(head d_head) -> head d_head", head=n_heads
        )
        out_sd[f"blocks.{layer}.attn.W_V"] = einops.rearrange(
            in_sd[f"blocks.{layer}.attn.value.weight"], "(head d_head) d_model -> head d_model d_head", head=n_heads
        )
        out_sd[f"blocks.{layer}.attn.b_V"] = einops.rearrange(
            in_sd[f"blocks.{layer}.attn.value.bias"], "(head d_head) -> head d_head", head=n_heads
        )
        out_sd[f"blocks.{layer}.attn.W_O"] = einops.rearrange(
            in_sd[f"blocks.{layer}.attn.proj.weight"], "d_model (head d_head) -> head d_head d_model", head=n_heads
        )
        out_sd[f"blocks.{layer}.attn.b_O"] = in_sd[f"blocks.{layer}.attn.proj.bias"]

        out_sd[f"blocks.{layer}.mlp.b_in"] = in_sd[f"blocks.{layer}.mlp.0.bias"]
        out_sd[f"blocks.{layer}.mlp.W_in"] = in_sd[f"blocks.{layer}.mlp.0.weight"].T
        out_sd[f"blocks.{layer}.mlp.b_out"] = in_sd[f"blocks.{layer}.mlp.2.bias"]
        out_sd[f"blocks.{layer}.mlp.W_out"] = in_sd[f"blocks.{layer}.mlp.2.weight"].T
    
    return out_sd


In [19]:
EPOCH = 1
CKPT_PATH = '/media/home/alex/ongoing/othello_world/ckpts/gpt_at_20230627_201626_epoch-EPOCH.ckpt'
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

synthetic_checkpoint = torch.load(CKPT_PATH.replace("EPOCH", str(EPOCH)), map_location=DEVICE)
# for name, param in synthetic_checkpoint.items():
#     if name.startswith("blocks.0") or not name.startswith("blocks"):
#         print(name, param.shape)

cfg = HookedTransformerConfig(
    n_layers = 8,
    d_model = 512,
    d_head = 64,
    n_heads = 8,
    d_mlp = 2048,
    d_vocab = 61,
    n_ctx = 59,
    act_fn="gelu",
    normalization_type="LNPre",
    device=DEVICE
)

model = HookedTransformer(cfg)
model.load_and_process_state_dict(convert_to_transformer_lens_format(synthetic_checkpoint))

## Validation Test on model
Test the model (trained on Synthetic Data) against **valid moves** on the Championship Dataset

In [20]:
from torch.utils.data import DataLoader
OTHELLO_ROOT = Path(OTHELLO_ROOT)
othello_data = OthelloDataset(data_root=OTHELLO_ROOT/"data/othello_championship")

# Load the championship dataset and use it as a test dataset
train_dataset = CharDataset(othello_data)
loader = DataLoader(train_dataset, shuffle=False, pin_memory=True, batch_size=1, num_workers=1)

Loading championship games
Loaded 465/465 (qualified/total) sequences from /media/home/alex/ongoing/othello_world/data/othello_championship/liveothello2021.pgn
Loaded 644/645 (qualified/total) sequences from /media/home/alex/ongoing/othello_world/data/othello_championship/liveothello2016.pgn
Loaded 406/407 (qualified/total) sequences from /media/home/alex/ongoing/othello_world/data/othello_championship/liveothello2014.pgn
Loaded 850/850 (qualified/total) sequences from /media/home/alex/ongoing/othello_world/data/othello_championship/liveothello2018.pgn
Loaded 326/327 (qualified/total) sequences from /media/home/alex/ongoing/othello_world/data/othello_championship/liveothello2010.pgn
Loaded 303/303 (qualified/total) sequences from /media/home/alex/ongoing/othello_world/data/othello_championship/liveothello2011.pgn
Loaded 892/893 (qualified/total) sequences from /media/home/alex/ongoing/othello_world/data/othello_championship/liveothello2017.pgn
Loaded 950/950 (qualified/total) sequences

In [21]:
# -- Do a single forward pass with the model --
model_out = model(loader.dataset[0][0].unsqueeze(0).to(DEVICE))

logit_vec = model_out[0, -1] # 0th element of batch, final element of sequence
log_probs = logit_vec.log_softmax(-1)

# Remove passing
log_probs = log_probs[1:]
assert len(log_probs)==60 # all examples should have 60 moves (padded with -100 [?] if game ended beforehand)

# Set all cells to -15 by default, for a very negative log prob - this means the middle cells don't show up as mattering
temp_board_state = torch.full((64,), -15, device=logit_vec.device, dtype=logit_vec.dtype)
temp_board_state[stoi_indices] = log_probs

plot_square_as_board(temp_board_state.reshape(8, 8), zmax=0, diverging_scale=False, title="Example Log Probs")

In [22]:
# -- Run Validation -- #
valid_move = 0
num_examples = 0

for i, sample in tqdm.tqdm(enumerate(iter(loader))):
    actual_moves, shifted_inp = sample
    model_rollout = model(actual_moves)

    model_moves = model_rollout[0].argmax(dim=-1).tolist()
    model_moves_parsed = [train_dataset.itos[int(i)] for i in model_moves]# if i != -1]
    actual_moves_parsed = [train_dataset.itos[int(i)] for i in actual_moves[0].tolist()]# if i != -1]
    
    # Ask whether for each token, the predicted next token was valid
    board = OthelloBoardState()
    for move_idx in range(0,model_rollout.shape[1]):
        #! Update board state first
        actual_move = actual_moves_parsed[move_idx]
        if actual_move == -100: #! Not a draw (i.e. reached padding)
            break 
        board.update([actual_move], prt=False)
        
        model_move = model_moves_parsed[move_idx]

        # Check if the move is valid given the board state
        valid_moves = board.get_valid_moves()

        if model_move in valid_moves:
            valid_move += 1
        num_examples += 1

    if i > 100:
        break  
        
print(f'Accuracy on {num_examples} tokens: {valid_move/num_examples}')

0it [00:00, ?it/s]

Accuracy on 5987 tokens: 0.9752797728411559


# Linear Probing of Board State
Will do two experiments:
1. **Phase Transitions in WM**: To investigate phase transitions in the emergence of the linear world model, look at how Probes trained for a 'good' checkpoint behave at earlier checkpoints.
2. **Hiearchical Learning of WM**: To investigate whether the world model is learned hierarchically, train probes on multiple layers and checkpoints, and see how their accuracies change until a 'good' checkpoint is reached

In [25]:
import os
# make deterministic
from mingpt.utils import set_seed
set_seed(42)

import time
from tqdm import tqdm
import numpy as np
from matplotlib import pyplot as plt
import argparse
import torch
import torch.nn as nn

from torch.nn import functional as F
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader

from data import get_othello
from data.othello import permit, start_hands, OthelloBoardState

from mingpt.dataset import CharDataset
from mingpt.model import GPT, GPTConfig, GPTforProbing
from mingpt.probe_trainer import Trainer, TrainerConfig
from mingpt.probe_model import BatteryProbeClassification, BatteryProbeClassificationTwoLayer

In [26]:
# Experiment Configuration
layer = -1
epochs = 16
mid_dim = 128
experiment_name = 'state' # one of [occupied, state, next_hand_color] ; age is also collected by default
folder_name = f"data/probing/{experiment_name}"
num_examples_to_train_on = 5000 # ~ *30 for the number of training points 

In [27]:
# Load Dataset and models
synthetic_data = OthelloDataset(data_root=OTHELLO_ROOT/"data/othello_synthetic", max_samples=num_examples_to_train_on)
train_dataset = CharDataset(synthetic_data)
loader = DataLoader(train_dataset, shuffle=True, pin_memory=True, batch_size=1, num_workers=1)

mconf = GPTConfig(train_dataset.vocab_size, train_dataset.block_size, n_layer=cfg.n_layers, n_head=cfg.n_heads, n_embd=cfg.d_model)
model = GPTforProbing(mconf, probe_layer=layer).to(DEVICE)

Loading synthetic


0it [00:00, ?it/s]

Loading gen10e5__20220324_154801.pickle


Mem Used: 2.339 GB: : 1it [00:10, 10.38s/it]


Loading gen10e5__20220324_154722.pickle
Deduplicating...
Deduplicating finished with 99999 games left
Using 20 million for training, 0 for validation
Dataset created has 99999 sequences, 61 unique words.


In [28]:
act_container = []
age_container = []
property_container = []

# Collect dataset for probe training 
for x, y in tqdm(loader, total=min(num_examples_to_train_on, len(loader))):
    tbf = [train_dataset.itos[_] for _ in x.tolist()[0]]
    valid_until = tbf.index(-100) if -100 in tbf else 999
    board = OthelloBoardState()

    # Get properties and age
    properties = []
    ages = []
    for _, move in enumerate(tbf[:valid_until]):
        board.umpire(move)
        properties.append(getattr(board, experiment_name))
        ages.append(board.get_age())

    act = model(x.to(DEVICE))[0, ...].detach().cpu()  # [block_size, f]
    act_container.extend([_[0] for _ in act.split(1, dim=0)[:valid_until]])

26089it [04:02, 107.80it/s]                         


KeyboardInterrupt: 

In [None]:
class ProbingDataset(Dataset):
    def __init__(self, act, y, age):
        assert len(act) == len(y)
        assert len(act) == len(age)
        print(f"{len(act)} pairs loaded...")
        self.act = act
        self.y = y
        self.age = age
        print(np.sum(np.array(y)==0), np.sum(np.array(y)==1), np.sum(np.array(y)==2))
        
        long_age = np.array([a for a in age]) #! ?
        counts = [np.count_nonzero(long_age == i) for i in range(60)]
        del long_age
        print(counts)
        
    def __len__(self, ):
        return len(self.y)
    
    def __getitem__(self, idx):
        return self.act[idx], torch.tensor(self.y[idx]).to(torch.long), torch.tensor(self.age[idx]).to(torch.long)

In [None]:
# Construct probe dataset

probing_dataset = ProbingDataset(act_container, property_container, age_container)
train_size = int(0.8 * len(probing_dataset))
test_size = len(probing_dataset) - train_size
probe_train_dataset, probe_test_dataset = torch.utils.data.random_split(probing_dataset, [train_size, test_size])
sampler = None
probe_train_loader = DataLoader(prbobe_train_dataset, shuffle=False, sampler=sampler, pin_memory=True, batch_size=128, num_workers=1)
probe_test_loader = DataLoader(probe_test_loader, shuffle=True, pin_memory=True, batch_size=128, num_workers=1)

In [None]:
# Train probes on the same dataset as our model - synthetic 
#! FOR layer in model

# Single Layer Probes
#! probe classes - 1=black, 2=white, 3=empty
probe = BatteryProbeClassification(DEVICE, probe_class=2, num_task=64)

t_start = time.strftime("_%Y%m%d_%H%M%S")
tconf = TrainerConfig(
    max_epochs=epochs, batch_size=1024, learning_rate=1e-3,
    betas=(.9, .999),
    lr_decay=True, warmup_tokens=len(train_dataset)*5,
    final_tokens=len(train_dataset)*max_epochs,
    num_workers=4, weight_decay=0.,
    ckpt_path=os.path.join("./data/ckpts/", folder_name, f"layer_{layer}")
)

trainer = Trainer(probe, probe_train_dataset, probe_test_dataset, tconf)
trainer.train(prt=True)
trainer.save_traces() 
trainer.save_checkpoint()

Loading synthetic


Mem Used: 17.02 GB: : 1it [00:00,  7.84it/s]

Loading gen10e5__20220324_154801.pickle
Loading gen10e5__20220324_154722.pickle
Deduplicating...
Deduplicating finished with 99999 games left
Using 20 million for training, 0 for validation





Dataset created has 99999 sequences, 61 unique words.


42320it [06:58, 101.13it/s]                         


KeyboardInterrupt: 

## Phase Transitions in WM
For this we can repurpose existing code from Neel and the original authors