### Setup

In [1]:
import iit.model_pairs as mp
import iit.utils.index as index
from circuits_benchmark.utils.get_cases import get_cases
from iit_utils import make_iit_hl_model, create_dataset
import random
from iit_utils.tracr_ll_corrs import get_tracr_ll_corr
from argparse import Namespace

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Index of the benchmark case to load. It has its own name so that `case` is
# only ever bound to the benchmark-case object created in a later cell;
# re-running this cell therefore cannot clobber that object with an int.
case_index = 3

# Arguments handed to `get_cases`, laid out like a parsed `compile` command line.
# NOTE(review): output_dir is an absolute, machine-specific path — consider
# deriving it from a configurable base directory before sharing the notebook.
args = Namespace(
    command='compile',
    indices=f'{case_index}',
    output_dir='/Users/cybershiptrooper/src/interpretability/MATS/circuits-benchmark/results',
    device='cpu',
    seed=1234,
    run_tests=False,
    tests_atol=1e-05,
    fail_on_error=False,
    original_args=['compile', f'-i={case_index}', '-f'])

In [3]:
# Resolve the benchmark case(s) selected by `args` and work with the first one.
cases = get_cases(args)
case = cases[0]
# Stop early for a case that reports no causal-masking support.
if not case.supports_causal_masking():
    raise NotImplementedError(f"Case {case.get_index()} does not support causal masking")

# Build the two models from the case. `tracr_output` is later passed to
# TracrHLCorrespondence.from_output; `hl_model` is passed to create_dataset.
tracr_output = case.build_tracr_model()
hl_model = case.build_transformer_lens_model()

In [4]:
# Create the two datasets (bound to train_data and test_data) from the case and hl_model.
train_data, test_data = create_dataset(case, hl_model)

In [15]:
from typing import Any


class TracrHLNode(mp.HLNode):
    """
    High-level node that carries a string `label` on top of the name,
    number of classes and index that are forwarded to `mp.HLNode`.
    """
    label: str

    def __init__(self, name, label, num_classes, idx=None):
        """
        Args:
            name: node name (must be a str).
            label: label attached to this node (must be a str).
            num_classes: number of classes (must be an int).
            idx: optional `index.TorchIndex`; None is accepted.

        Raises:
            AssertionError: if any argument has the wrong type.
        """
        # type checks
        assert isinstance(name, str), ValueError(f"name is not a string, but {type(name)}")
        assert isinstance(label, str), ValueError(f"label is not a string, but {type(label)}")
        assert isinstance(num_classes, int), ValueError(f"num_classes is not an int, but {type(num_classes)}")
        # Report the type of the offending argument (`idx`), not of the imported `index` module.
        assert idx is None or isinstance(idx, index.TorchIndex), ValueError(f"index is not a TorchIndex, but {type(idx)}")
        super().__init__(name, num_classes, idx)
        self.label = label

    def get_label(self) -> str:
        """Return the label given at construction."""
        return self.label

    def get_name(self) -> str:
        """Return the node name."""
        return self.name

    def __hash__(self) -> int:
        # Mix the label into the parent's hash so nodes differing only by label hash differently.
        return super().__hash__() + hash(self.label)

    def __str__(self) -> str:
        # Explicitly keep the parent's string form; only __repr__ is customised below.
        return super().__str__()

    def __repr__(self) -> str:
        """Multi-line representation listing name, label, classes and index."""
        return f"TracrHLNode(name: {self.name},\n label: {self.label},\n classes: {self.num_classes},\n index: {self.index}\n)"

# test the TracrHLNode: construction without an index, label access and string form
node = TracrHLNode('name', 'label', 10)
assert node.get_label() == 'label'
# str() of the node is expected to equal its name
assert str(node) == node.get_name()
# no idx was passed, so the stored index should be None
assert node.index is None
# bare expression: the notebook displays the node's __repr__
node

TracrHLNode(name: name,
 label: label,
 classes: 10,
 index: None
)

In [6]:
from tracr.compiler.compiling import TracrOutput
from tracr.craft.transformers import MLP, MultiAttentionHead, SeriesWithResiduals
import networkx as nx

class TracrHLCorrespondence:
    """
    Maps every output basis direction of a craft model's blocks to a location
    tuple ``(layer, "mlp" | "attn", head)``; ``head`` is None for MLP blocks.
    """

    def __init__(self, graph: nx.DiGraph, craft_model: SeriesWithResiduals):
        """Record `graph` and `craft_model`, then index the model's blocks by output direction."""
        self.graph = graph
        self.craft_model = craft_model
        self._dict = {}

        # Only an MLP block advances the layer counter; attention blocks are
        # filed under the current value.
        layer = 0
        for block in craft_model.blocks:
            if isinstance(block, MLP):
                assert block.fst.output_space == block.snd.input_space
                self._dict.update(
                    (direction, (layer, "mlp", None))
                    for direction in block.snd.output_space.basis
                )
                layer += 1
            elif isinstance(block, MultiAttentionHead):
                for head, sub_block in enumerate(block.sub_blocks):
                    self._dict.update(
                        (direction, (layer, "attn", head))
                        for direction in sub_block.w_ov.output_space.basis
                    )
            else:
                raise ValueError(f"Unknown block type {type(block)}")

        # This is not necessary for now, because node names and direction names are equal
        # for name, node in graph.nodes.items():
        #   if name in ["indices", "tokens"]: continue
        #   if "OUTPUT_BASIS" in node:
        #     for direction in node["OUTPUT_BASIS"]:
        #       self._dict[direction] = self._unit_output_bases[direction]

    @staticmethod
    def from_output(tracr_output: TracrOutput):
        """Build a correspondence from the `graph` and `craft_model` of a TracrOutput."""
        graph, craft_model = tracr_output.graph, tracr_output.craft_model
        return TracrHLCorrespondence(graph, craft_model)

    def __getitem__(self, key):
        """Return the location tuple stored for the basis direction `key`."""
        return self._dict[key]

    def __repr__(self) -> str:
        """One tab-indented line per entry: ``(name, value): location``."""
        entries = "\n".join(
            f"\t{direction.name, direction.value}: {location}"
            for direction, location in self._dict.items()
        )
        return f"TracrCorrespondence(\n{entries}\n)"

    def items(self):
        """Return the ``(direction, location)`` item view of the underlying dict."""
        return self._dict.items()

In [9]:
from circuits_benchmark.benchmark.benchmark_case import BenchmarkCase

class TracrCorr(dict):
    """
    Dictionary from `TracrHLNode` to a set of `mp.LLNode`.

    Item assignment (``corr[node] = {...}``) is validated. Construction via
    ``TracrCorr({...})`` goes through ``dict.__init__`` and is not validated.
    """

    def __setitem__(self, key: TracrHLNode, value: "set[mp.LLNode]") -> None:
        """
        Store `value` under `key` after checking both types.

        Raises:
            AssertionError: if `key` is not a TracrHLNode, `value` is not a set,
                or `value` contains something other than `mp.LLNode`.
        """
        # This validation used to live in __setattr__, where it could never succeed:
        # attribute names must be strings, while the keys here are TracrHLNode objects.
        assert isinstance(key, TracrHLNode), ValueError(f"__name is not a TracrHLNode, but {type(key)}")
        assert isinstance(value, set), ValueError(f"__value is not a set, but {type(value)}")
        assert all(isinstance(v, mp.LLNode) for v in value), ValueError(f"__value contains non-LLNode elements")
        super().__setitem__(key, value)

    @classmethod
    def make_hl_ll_corr(cls,
                        tracr_hl_corr: TracrHLCorrespondence,
                        tracr_ll_corr: "dict | None",
                        hook_name_style: str="tl"):
        """
        Build the HL-node -> set-of-LL-nodes correspondence.

        Args:
            tracr_hl_corr: maps basis directions to ``(layer, "attn" | "mlp", unit)`` locations.
            tracr_ll_corr: looked up with ``(basis_dir.name, basis_dir.value)``; each hit is
                iterated and every element is unpacked as a 3-item location whose last item
                may already be an `index.TorchIndex`. If None, an identity correspondence
                is returned (each HL location maps to itself) and a warning is printed.
            hook_name_style: "tl" or "wrapper". Used for HL node names, and for LL node
                names in the identity case; otherwise LL node names always use "tl".

        Returns:
            An instance of `cls`.

        Raises:
            ValueError: if `hook_name_style` is not a known style.
            AssertionError: if a location is neither "attn" nor "mlp".
        """
        def hook_name(loc, style) -> str:
            layer, attn_or_mlp, _ = loc
            assert attn_or_mlp in ["attn", "mlp"], ValueError(f"Unknown attn_or_mlp {attn_or_mlp}")
            if style == "tl":
                return f"blocks.{layer}.{attn_or_mlp}.{'hook_result' if attn_or_mlp == 'attn' else 'hook_post'}"
            elif style == "wrapper":
                # Interpolate the layer number only; formatting the whole `loc`
                # tuple here produced an invalid hook name.
                return f"mod.blocks.{layer}.mod.{attn_or_mlp}.hook_point"
            else:
                raise ValueError(f"Unknown style {style}")

        def idx(loc):
            _, attn_or_mlp, unit = loc
            # A unit that is already a TorchIndex is used as-is.
            if isinstance(unit, index.TorchIndex):
                return unit
            assert attn_or_mlp in ["attn", "mlp"], ValueError(f"Unknown attn_or_mlp {attn_or_mlp}")
            if attn_or_mlp == "attn":
                # `unit` selects along the third axis of the attention hook output.
                return index.Ix[:, :, unit, :]
            assert unit is None
            return index.Ix[[None]]

        def hl_node(basis_dir, hl_loc) -> TracrHLNode:
            # Shared by both branches below so the HL node is built in exactly one place.
            return TracrHLNode(hook_name(hl_loc, hook_name_style),
                               label=basis_dir.name,
                               num_classes=0, # TODO: get num_classes
                               idx=idx(hl_loc))

        if tracr_ll_corr is None:
            print("WARNING: tracr_ll_corr is None, returning an Identity correspondence using HL Model")
            return cls({
                hl_node(basis_dir, hl_loc): {
                    mp.LLNode(hook_name(hl_loc, hook_name_style), idx(hl_loc), None)
                }
                for basis_dir, hl_loc in tracr_hl_corr.items()
            })

        return cls({
            hl_node(basis_dir, hl_loc): {
                mp.LLNode(hook_name(ll_loc, "tl"), idx(ll_loc))
                for ll_loc in tracr_ll_corr[basis_dir.name, basis_dir.value]
            }
            for basis_dir, hl_loc in tracr_hl_corr.items()
        })

    @classmethod
    def from_case(cls, case: BenchmarkCase, tracr_output: TracrOutput):
        """
        Build the correspondence for `case` from its `tracr_output`.

        `cls` is now an explicit first parameter; without it, the class itself
        was bound to `case` and a two-argument call raised TypeError.
        """
        tracr_hl_corr = TracrHLCorrespondence.from_output(tracr_output)
        tracr_ll_corr = get_tracr_ll_corr(case)
        return cls.make_hl_ll_corr(tracr_hl_corr, tracr_ll_corr)
    
from collections import namedtuple

class TracrEdgeCorr(list[namedtuple]):
    """List subclass that exposes the `EdgeCorr` record type as a class attribute."""

    # Four-field record: (hookpoint, index1, hookpoint2, index2).
    EdgeCorr = namedtuple("EdgeCorr", "hookpoint index1 hookpoint2 index2")
    

In [16]:

# this is the graph node -> hl node correspondence
# (keys are basis directions, values are (layer, "attn" | "mlp", unit) locations)
tracr_hl_corr = TracrHLCorrespondence.from_output(tracr_output)
# low-level correspondence for this case; make_hl_ll_corr indexes it per basis direction
tracr_ll_corr = get_tracr_ll_corr(case)

# combine both into the TracrHLNode -> set of LLNode mapping
hl_ll_corr = TracrCorr.make_hl_ll_corr(tracr_hl_corr, tracr_ll_corr)

In [19]:
# Show the resulting correspondence (plain-text dict form, using each node's __repr__).
print(hl_ll_corr)

{TracrHLNode(name: blocks.0.mlp.hook_post,
 label: is_x_3,
 classes: -1,
 index: [:]
): {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}, TracrHLNode(name: blocks.1.attn.hook_result,
 label: frac_prevs_1,
 classes: -1,
 index: [:, :, 0, :]
): {LLNode(name='blocks.1.attn.hook_result', index=[:, :, 1:3, :], subspace=None)}}
