In [None]:
# Install the CircuitsVis Python package (the `python` subdirectory of the
# callummcdougall/CircuitsVis GitHub repository) from source.
# NOTE(review): no commit or tag is pinned, so the installed code can change
# between runs of this notebook.
%pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python

In [None]:
# NOTE(review): this is the same install command as in the preceding cell;
# running it a second time is redundant and one of the two cells can be dropped.
%pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python

In [None]:
import wandb
from pathlib import Path
import os
import yaml
import s3fs
import circuitsvis as cv

from omegaconf import OmegaConf

import torch

from di_automata.config_setup import *
from di_automata.constructors import (
    construct_model, 
    create_dataloader_hf,
)
from di_automata.tasks.data_utils import take_n

# AWS
# S3 filesystem handle, created with default arguments. It is used further
# down to open checkpoint files when the run's model_save_method is "aws".
s3 = s3fs.S3FileSystem()

In [None]:
"""Notebook should be run under assumption that logits are already loaded in disk."""
# Post-run SLT settings live in one YAML file; the task-specific settings are
# looked up from the dataset type named inside it.
config_file_path = "configs/slt_config.yaml"
slt_config = OmegaConf.load(config_file_path)

task_config_path = f"configs/task_config/{slt_config.dataset_type}.yaml"
with open(task_config_path, 'r') as file:
    task_config = yaml.safe_load(file)

# Struct mode is switched off so that the extra "task_config" key can be attached.
OmegaConf.set_struct(slt_config, False)
slt_config["task_config"] = task_config

# Round-trip through the PostRunSLTConfig Pydantic model for dynamic type
# validation (NECESSARY, DO NOT SKIP), then convert the validated fields back
# to an OmegaConf object for compatibility with the existing code.
pydantic_config = PostRunSLTConfig(**slt_config)
validated_fields = pydantic_config.model_dump()
slt_config = OmegaConf.create(validated_fields)

print(task_config["dataset_type"])

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Run path and name for easy referral later
run_path = f"{slt_config.entity_name}/{slt_config.wandb_project_name}"
run_name = slt_config.run_name

# Get run information: all finished runs in the project whose display name
# matches run_name.
api = wandb.Api(timeout=3000)
run_list = api.runs(
    path=run_path,
    filters={
        "display_name": run_name,
        "state": "finished",
    },
    # Original note: "Default descending order so backwards in time".
    # NOTE(review): confirm against the wandb docs whether an unprefixed key
    # sorts ascending or descending; slt_config.run_idx depends on it.
    order="created_at",
)
assert run_list, f"Specified run {run_name} does not exist"
run_api = run_list[slt_config.run_idx]

# `history` is called when it is a method and used directly when it is a plain
# attribute. This replaces a bare `except:` that swallowed every error
# (including KeyboardInterrupt) and then fell back to the bound method, which
# only failed later with "'method' object is not subscriptable". A failed
# download now surfaces its real error here.
history_attr = run_api.history
history = history_attr() if callable(history_attr) else history_attr

loss_history = history["Train Loss"]
accuracy_history = history["Train Acc"]
steps = history["_step"]
# The run's logged "time" config value; later cells use it as the suffix of
# artifact names and checkpoint/logits file names.
time = run_api.config["time"]

In [None]:
# Print the name of every run returned by the filtered query, to see which
# runs `slt_config.run_idx` is selecting between.
for run in run_list:
    print(run.name)

In [None]:
def get_config() -> MainConfig:
    """Fetch the run's configuration from its manually logged WandB artifact.

    WandB also records a config automatically for each run, but it doesn't
    log enums correctly, so the `config` artifact named after the run name and
    time is downloaded instead.

    Returns:
        The object produced by OmegaConf.load on the artifact's config.yaml
        (the annotation says MainConfig; no conversion happens here).
    """
    artifact_name = f"{run_path}/config:{run_name}_{time}"
    download_dir = Path(api.artifact(artifact_name).download())
    return OmegaConf.load(download_dir / "config.yaml")

In [None]:
# The training config is taken from the run's artifact via get_config().
# Alternative kept for reference (not used):
#   config: MainConfig = OmegaConf.create(run_api.config)
config = get_config()

# Set total number of unique samples seen (n). If this is not done it will break LLC estimator.
# The value from the training run is copied into both places the SLT config
# keeps it.
trained_num_samples = config.rlct_config.sgld_kwargs.num_samples
slt_config.rlct_config.sgld_kwargs.num_samples = trained_num_samples
slt_config.rlct_config.num_samples = trained_num_samples

# The post-run config reuses the model settings of the training run.
slt_config.nano_gpt_config = config.nano_gpt_config

model, param_inf_properties = construct_model(config)

In [None]:
def restore_state_single_cp(cp_idx: int) -> dict:
    """Restore model state from a single checkpoint.
    Used in _load_logits_states() and _calculate_rlct().
    
    Args:
        idx_cp: index of checkpoint.
        
    Returns:
        model state dictionary.
    """
    idx = cp_idx * config.rlct_config.ed_config.eval_frequency
    print(config.model_save_method)
    match config.model_save_method:
        case "wandb":
            artifact = API.artifact(f"{run_path}/states:idx{idx}_{run_name}_{time}")
            data_dir = artifact.download()
            state_path = Path(data_dir) / f"states_{idx}.torch"
            states = torch.load(state_path)
        case "aws":
            with s3.open(f'{config.aws_bucket}/{run_name}_{time}/states_{idx}.pth', mode='rb') as file:
                states = torch.load(file)
    return states["model"]

In [None]:
# Logits file location: <parent of the working directory>/di_automata/,
# with the run name and time as the file-name suffix.
current_directory = Path.cwd()
logits_file_path = current_directory.parent.joinpath(f"di_automata/logits_{run_name}_{time}")
print(logits_file_path)

In [None]:
# Disabled: optional reload of previously saved logits from logits_file_path.
# if os.path.exists(logits_file_path):
#     print(f"Loading existing logits from {logits_file_path}")
#     ed_logits = torch.load(logits_file_path)
#     print("Done loading existing logits")

# Hard-coded training step to inspect.
# NOTE(review): presumably the step at which a cusp appears in the training
# curves; confirm, and consider moving this value into the config.
cusp_idx = 620
# Floor-divide by the evaluation frequency to get a checkpoint index.
# restore_state_single_cp multiplies it back, so the step actually loaded is
# the largest multiple of eval_frequency not exceeding cusp_idx.
cp_idx = cusp_idx // config.rlct_config.ed_config.eval_frequency
print(cp_idx)
print(time)

state = restore_state_single_cp(cp_idx)
# Last expression of the cell, so the notebook displays load_state_dict's
# return value.
model.load_state_dict(state)

In [None]:
# Print the repr of the model built by construct_model, now carrying the
# checkpoint weights loaded above.
print(model)

In [None]:
# Build the dataloader from the training run's config.
ed_loader = create_dataloader_hf(config, deterministic=True) # Make sure deterministic to see same data

In [None]:
# Run the model on items drawn from the dataloader through take_n(ed_loader, 1).
# NOTE(review): presumably take_n yields only the first item; if it yielded
# more, `inputs`, `logits` and `cache` would keep the values from the last one.
for data in take_n(ed_loader, 1):
    inputs = data["input_ids"]
    # run_with_cache returns a pair, unpacked as (logits, cache); `cache` is
    # indexed in the next cell to get an attention pattern.
    logits, cache = model.run_with_cache(inputs)

In [None]:
# Look up the cache entry keyed by ("pattern", 0, "attn").
# NOTE(review): presumably the attention pattern of layer index 0, following
# the TransformerLens cache-key convention; confirm.
att_1 = cache["pattern", 0, "attn"]
# NOTE(review): this prints the full tensor; consider printing its shape instead.
print(att_1)
# Render the pattern with CircuitsVis.
# NOTE(review): the 12 head labels are hard-coded; confirm this matches the
# model's number of heads. `inputs` is passed unchanged from the dataloader
# batch; confirm the visualiser accepts that form for `tokens`.
display(cv.attention.attention_patterns(
    tokens=inputs,
    attention=att_1,
    attention_head_names=[f"L0H{i}" for i in range(12)],
))