# 初期設定・インポート

In [None]:
%load_ext autoreload
%autoreload 2

import os
import time
import math
import numpy as np
from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from matplotlib import pyplot as plt

In [None]:
# Configure logging output: INFO level and above, with a timestamped
# "time - level - logger - message" line format.
import logging

logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)

In [None]:
# Fix the random seed so repeated runs of the notebook are comparable.
# NOTE(review): which RNGs `set_seed` covers (python / numpy / torch) is
# defined in mingpt.utils and is not visible here - confirm if it matters.
from mingpt.utils import set_seed
set_seed(42)

# パラメータ

In [None]:
# Experiment configuration.
layer = 3             # passed to GPTforProbing as `probe_layer`
epo = 16              # number of probe-training epochs
mid_dim = 128         # only affects the run name when `twolayer` is set
twolayer = True       # adds the "_tl<mid_dim>" suffix to the run name
random_flag = False   # probe a re-initialised model instead of a checkpoint
championship = True   # use the championship checkpoint
exp = "state"         # ground-truth getter suffix: "get_" + exp

# Name of the output folder: the base path followed by one optional suffix
# per enabled switch, always in the order two-layer, random, championship.
folder_name = (
    f"battery_othello/{exp}"
    + (f"_tl{mid_dim}" if twolayer else "")
    + ("_random" if random_flag else "")
    + ("_championship" if championship else "")
)

In [None]:
# Echo the derived output-folder name so the run configuration is visible
# in the notebook output.
print(f"Running experiment for {folder_name}")

# データセットの読み込み

In [None]:
from data import get_othello
from mingpt.dataset import CharDataset

In [None]:
# Load the games from the championship data directory and wrap them in a
# CharDataset; the dataset's vocab_size, block_size and itos mapping are
# used by the cells below.
# NOTE(review): the data root is fixed to the championship directory
# regardless of the `championship` flag - confirm this is intended.
othello = get_othello(data_root="data/othello_championship")
train_dataset = CharDataset(othello)

# モデルの定義

In [None]:
from mingpt.model import GPTConfig, GPTforProbing

# Build the GPT backbone whose activations at layer `layer` will be probed.
mconf = GPTConfig(train_dataset.vocab_size, train_dataset.block_size, n_layer=8, n_head=8, n_embd=512)
model = GPTforProbing(mconf, probe_layer=layer)

if random_flag:
    # Baseline: probe a randomly (re-)initialised network.
    model.apply(model._init_weights)
else:
    ckpt_file = "./ckpts/gpt_championship.ckpt" if championship else "./ckpts/gpt_synthetic.ckpt"
    # Load onto the CPU first so a checkpoint saved on a GPU also works on a
    # CPU-only machine; the model is moved to its final device below.
    model.load_state_dict(torch.load(ckpt_file, map_location="cpu"))

# Always bind `device`: later cells call `x.to(device)` and previously hit a
# NameError when CUDA was not available.
if torch.cuda.is_available():
    device = torch.cuda.current_device()
else:
    device = torch.device("cpu")
model = model.to(device)

# The backbone is frozen and only used for feature extraction, so switch it
# to inference mode; this disables any dropout layers and makes the
# extracted activations deterministic.
model.eval()

# データローダの準備と属性抽出

In [None]:
from torch.utils.data.dataloader import DataLoader
# Example: OthelloBoardState is used here; swap in a Sudoku equivalent if
# this notebook is adapted to that task.
from data.othello import OthelloBoardState

# One game per batch, in dataset order, so activations and labels line up
# with the second pass that extracts ages. Pinned memory only helps (and
# only avoids a warning) when a GPU is present.
loader = DataLoader(
    train_dataset,
    shuffle=False,
    pin_memory=torch.cuda.is_available(),
    batch_size=1,
    num_workers=1,
)

# Read the device from the model itself so this cell does not depend on a
# separately defined `device` variable.
model_device = next(model.parameters()).device

act_container = []       # one activation vector per valid move position
property_container = []  # matching ground-truth labels from "get_<exp>"

# No gradients are needed for feature extraction; skipping the autograd
# graph saves memory on every forward pass.
with torch.no_grad():
    for x, _ in tqdm(loader, total=len(loader)):
        tbf = [train_dataset.itos[idx] for idx in x.tolist()[0]]
        # -100 marks the first padded position; keep everything before it,
        # or the whole sequence when there is no padding.
        valid_until = tbf.index(-100) if -100 in tbf else len(tbf)
        board = OthelloBoardState()
        properties = board.get_gt(tbf[:valid_until], "get_" + exp)
        # Drop the batch dimension (batch_size is 1) and move to the CPU.
        act = model(x.to(model_device))[0, ...].detach().cpu()
        # One row per position, truncated to the valid part of the game.
        act_container.extend(act[:valid_until].unbind(0))
        property_container.extend(properties)

# 年齢情報の抽出

In [None]:
# Imported once here rather than on every loop iteration.
from data.othello import OthelloBoardState

# Second pass over the same (unshuffled) loader: collect the "get_age"
# ground truth for every valid move position, in the same order as the
# activations gathered in the previous cell.
age_container = []
for x, _ in tqdm(loader, total=len(loader)):
    tbf = [train_dataset.itos[idx] for idx in x.tolist()[0]]
    # -100 marks the first padded position; without padding keep everything.
    valid_until = tbf.index(-100) if -100 in tbf else len(tbf)
    board = OthelloBoardState()
    age_container.extend(board.get_gt(tbf[:valid_until], "get_age"))

# プローブ用のデータセット生成

In [None]:
from torch.utils.data import Dataset


class ProbingDataset(Dataset):
    """Dataset of (activation, label, age) triples for probe training.

    The three containers are index-aligned: item ``i`` of each belongs to
    the same sample. Labels and ages are converted to ``torch.long``
    tensors on access; activations are returned as stored.
    """

    def __init__(self, act, y, age):
        n_samples = len(act)
        # All three containers must describe the same number of samples.
        assert n_samples == len(y)
        assert n_samples == len(age)
        print(f"{n_samples} pairs loaded...")
        self.act, self.y, self.age = act, y, age

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        label = torch.tensor(self.y[idx]).long()
        age = torch.tensor(self.age[idx]).long()
        return self.act[idx], label, age
    
# Bundle the collected activations, labels and ages, then hold out 20% of
# the samples as a test split.
probing_dataset = ProbingDataset(act_container, property_container, age_container)
train_size = int(0.8 * len(probing_dataset))
test_size = len(probing_dataset) - train_size
train_dataset_sub, test_dataset = torch.utils.data.random_split(probing_dataset, [train_size, test_size])

# Shuffle the training split only; the flags were previously inverted
# (training unshuffled, test shuffled). Evaluation order does not need to
# be randomised, and a fixed order keeps test runs reproducible.
train_loader = DataLoader(train_dataset_sub, shuffle=True, batch_size=128, num_workers=1)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=128, num_workers=1)

# トレーニングの設定と実施

In [None]:
from mingpt.probe_trainer import Trainer, TrainerConfig
# NOTE(review): `probe` was never defined in this notebook, so the Trainer
# call below raised NameError. The construction here assumes the probe
# classes and constructor signatures of the Othello-GPT `mingpt.probe_model`
# module - confirm against the local mingpt package.
from mingpt.probe_model import BatteryProbeClassification, BatteryProbeClassificationTwoLayer

# Number of output classes per probed target. Only "state" is known here
# (presumably empty / black / white per square - verify).
probe_classes = {"state": 3}
if exp not in probe_classes:
    raise ValueError(f"No probe_class configured for exp={exp!r}")
probe_class = probe_classes[exp]

# Place the probe on the same device as the backbone model.
probe_device = next(model.parameters()).device
# num_task=64: presumably one classification task per board square - verify.
if twolayer:
    probe = BatteryProbeClassificationTwoLayer(
        probe_device, probe_class=probe_class, num_task=64, mid_dim=mid_dim
    )
else:
    probe = BatteryProbeClassification(probe_device, probe_class=probe_class, num_task=64)

max_epochs = epo
t_start = time.strftime("_%Y%m%d_%H%M%S")

# Make sure the output directory exists before the trainer writes its
# traces and checkpoint into it.
ckpt_dir = os.path.join("./ckpts/", folder_name, f"layer{layer}")
os.makedirs(ckpt_dir, exist_ok=True)

tconf = TrainerConfig(
    max_epochs=max_epochs, batch_size=1024, learning_rate=1e-3,
    betas=(.9, .999),
    # Learning-rate schedule: warm up over 5 epochs' worth of samples and
    # decay until the end of training.
    lr_decay=True, warmup_tokens=len(train_dataset_sub)*5,
    final_tokens=len(train_dataset_sub)*max_epochs,
    num_workers=4, weight_decay=0.,
    ckpt_path=ckpt_dir,
)
trainer = Trainer(probe, train_dataset_sub, test_dataset, tconf)
trainer.train(prt=True)
trainer.save_traces()
trainer.save_checkpoint()