In [69]:
from circuits_benchmark.utils.get_cases import get_cases
from circuits_benchmark.commands.build_main_parser import build_main_parser
from iit_utils import create_dataset
from transformer_lens import HookedTransformer, HookedTransformerConfig

# Pick the benchmark case by index: the project's CLI parser is driven with a
# synthetic argv so that get_cases() receives the same args object the
# command-line entry point would build.
task = 3
args, _ = build_main_parser().parse_known_args(["compile", f"-i={task}", "-f"])

cases = get_cases(args)
case = cases[0]

# High-level reference model for the case, and the train/test data derived from it.
hl_model = case.build_transformer_lens_model()
train_data, test_data = create_dataset(case, hl_model)

# Architecture overrides for the low-level model.
cfg_dict = {
    "n_layers": 2,
    "n_heads": 4,
    "d_head": 2,
    "d_model": 8,
    "d_mlp": 16,
    "seed": 0,
    "act_fn": "gelu",
}

# Start from the high-level config and layer the overrides on top. A fresh dict
# is built so the overrides never write into whatever to_dict() handed back.
ll_cfg = {**hl_model.cfg.to_dict(), **cfg_dict}
print(ll_cfg)
model = HookedTransformer(ll_cfg)

{'n_layers': 2, 'd_model': 8, 'n_ctx': 5, 'd_head': 2, 'model_name': 'custom', 'n_heads': 4, 'd_mlp': 16, 'act_fn': 'gelu', 'd_vocab': 6, 'eps': 1e-05, 'use_attn_result': True, 'use_attn_scale': True, 'use_split_qkv_input': True, 'use_hook_mlp_in': True, 'use_attn_in': False, 'use_local_attn': False, 'original_architecture': None, 'from_checkpoint': False, 'checkpoint_index': None, 'checkpoint_label_type': None, 'checkpoint_value': None, 'tokenizer_name': None, 'window_size': None, 'attn_types': None, 'init_mode': 'gpt2', 'normalization_type': None, 'device': device(type='mps'), 'n_devices': 1, 'attention_dir': 'causal', 'attn_only': False, 'seed': 0, 'initializer_range': 0.22188007849009167, 'init_weights': True, 'scale_attn_by_inverse_layer_idx': False, 'positional_embedding_type': 'standard', 'final_rms': False, 'd_vocab_out': 1, 'parallel_attn_mlp': False, 'rotary_dim': None, 'n_params': 676, 'use_hook_tokens': False, 'gated_mlp': False, 'default_prepend_bos': True, 'dtype': torch.

In [2]:
from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PretrainedConfig, GPT2Config, GPT2Model
from pyvene import IntervenableModel, IntervenableConfig

# Hugging Face-shaped description of the low-level model. pyvene reads
# dimensions such as "n_embd" from a HF config, so the TransformerLens settings
# in ll_cfg are mirrored onto the equivalent GPT2Config fields.
_mirrored_from_ll_cfg = {
    "vocab_size": ll_cfg['d_vocab'],
    "n_positions": ll_cfg['n_ctx'],
    "n_embd": ll_cfg['d_model'],
    "n_layer": ll_cfg['n_layers'],
    "n_head": ll_cfg['n_heads'],
    # GPT2Config documents n_inner as the width of the inner feed-forward
    # layer, which is what d_mlp denotes, so this mapping is the right one.
    "n_inner": ll_cfg['d_mlp'],
    "activation_function": ll_cfg['act_fn'],
    "resid_pdrop": 0,
    "embd_pdrop": 0,
    "attn_pdrop": 0,
    "layer_norm_epsilon": ll_cfg['eps'],
}

# Remaining fields are fixed values that may not need to differ from defaults.
# NOTE(review): initializer_range here (0.02) is not taken from ll_cfg, whose
# printed value is ~0.2219 -- confirm whether the two should agree.
# NOTE(review): bos/eos id 50256 lies outside a 6-token vocabulary; presumably
# harmless because the ids are never embedded here -- verify.
_fixed_gpt2_fields = {
    "initializer_range": 0.02,
    "summary_type": "cls_index",
    "summary_use_proj": True,
    "summary_activation": None,
    "summary_proj_to_labels": True,
    "summary_first_dropout": 0.0,
    "scale_attn_weights": True,
    "use_cache": True,
    "bos_token_id": 50256,
    "eos_token_id": 50256,
    "scale_attn_by_inverse_layer_idx": False,
    "reorder_and_upcast_attn": False,
}

my_config = GPT2Config(**_mirrored_from_ll_cfg, **_fixed_gpt2_fields)

In [3]:
# class IntervenableTLConfig(IntervenableConfig):
#     def __init__(self, **kwargs):
#         super().__init__(**kwargs)


# class IntervenableTLModel(IntervenableModel):
#     def __init__(self, config, model, **kwargs):
#         super().__init__(config, model, **kwargs)
import transformers.models as hf_models
class TLModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` shell around a ``HookedTransformer``.

    The wrapped network is stored on ``self.model``, which is why the pyvene
    module paths used in this notebook start with ``model.``.
    """

    def __init__(self, tl_config: HookedTransformerConfig, *inputs, config=my_config, **kwargs):
        """Build the wrapper.

        Args:
            tl_config: Configuration handed straight to ``HookedTransformer``.
                This notebook passes a plain dict (``ll_cfg``) here.
            *inputs: Extra positional arguments forwarded to ``PreTrainedModel``.
            config: HF config given to ``PreTrainedModel``; defaults to the
                notebook-level ``my_config`` captured when the class is defined.
            **kwargs: Extra keyword arguments forwarded to ``PreTrainedModel``.
        """
        super().__init__(config, *inputs, **kwargs)
        self.tl_config = tl_config
        self.model = HookedTransformer(tl_config)

    def forward(self, *args, **kwargs):
        """Delegate the call unchanged to the wrapped ``HookedTransformer``."""
        wrapped = self.model
        return wrapped(*args, **kwargs)
    


# AutoModel.register(GPT2Config, TLModel)
# HF-facing wrapper instance; `config` falls back to TLModel's default (my_config).
transformers_model = TLModel(tl_config=ll_cfg)

In [25]:
from pyvene.models.constants import *
# (component name, module path template, hook kind) for each component pyvene
# may intervene on. "%s" is the layer-index placeholder; paths are relative to
# TLModel, hence the leading "model.".
# NOTE(review): the printed ll_cfg has 'use_attn_in': False -- confirm that
# hook_attn_in is actually invoked in that configuration.
# NOTE(review): attn.hook_result is paired with an "n_embd" dimension entry;
# verify that its activation really has that trailing layout.
_TL_COMPONENT_HOOKS = (
    ("block_input", "model.blocks[%s]", CONST_INPUT_HOOK),
    ("block_output", "model.blocks[%s]", CONST_OUTPUT_HOOK),
    # "mlp_activation": ("model.blocks[%s].mlp.act", CONST_OUTPUT_HOOK),
    ("mlp_output", "model.blocks[%s].mlp", CONST_OUTPUT_HOOK),
    ("mlp_input", "model.blocks[%s].mlp", CONST_INPUT_HOOK),

    # "attention_value_output": ("model.blocks[%s].attn.hook_v", CONST_INPUT_HOOK),
    # "head_attention_value_output": ("model.blocks[%s].attn.hook_", CONST_INPUT_HOOK, (split_head_and_permute, "n_head")),
    # "attention_weight": ("model.blocks[%s].attn.attn_dropout", CONST_INPUT_HOOK),
    ("attention_output", "model.blocks[%s].attn.hook_result", CONST_OUTPUT_HOOK),
    ("attention_input", "model.blocks[%s].hook_attn_in", CONST_INPUT_HOOK),

    # Not ported yet (GPT-2 style c_attn paths kept for reference):
    # "query_output": ("model.blocks[%s].attn.c_attn", CONST_OUTPUT_HOOK, (split_three, 0)),
    # "key_output": ("model.blocks[%s].attn.c_attn", CONST_OUTPUT_HOOK, (split_three, 1)),
    # "value_output": ("model.blocks[%s].attn.c_attn", CONST_OUTPUT_HOOK, (split_three, 2)),
    # "head_query_output": ("model.blocks[%s].attn.c_attn", CONST_OUTPUT_HOOK, (split_three, 0), (split_head_and_permute, "n_head")),
    # "head_key_output": ("model.blocks[%s].attn.c_attn", CONST_OUTPUT_HOOK, (split_three, 1), (split_head_and_permute, "n_head")),
    # "head_value_output": ("model.blocks[%s].attn.c_attn", CONST_OUTPUT_HOOK, (split_three, 2), (split_head_and_permute, "n_head")),
)

tl_to_module_mapping = {
    component: (module_path, hook_kind)
    for component, module_path, hook_kind in _TL_COMPONENT_HOOKS
}

# Names of the config attributes (or simple expressions over them) that give
# the width of each component. Entries are tuples of candidate names.
_MODEL_WIDTH = ("n_embd",)
_PER_HEAD_WIDTH = ("n_embd/n_head",)

tl_to_dimension_mapping = {
    "n_head": ("n_head",),
    "block_input": _MODEL_WIDTH,
    "block_output": _MODEL_WIDTH,
    "mlp_activation": ("n_inner", "n_embd*4"),
    "mlp_output": _MODEL_WIDTH,
    "mlp_input": _MODEL_WIDTH,
    "attention_value_output": _MODEL_WIDTH,
    "head_attention_value_output": _PER_HEAD_WIDTH,
    "attention_weight": ("max_position_embeddings",),
    "attention_output": _MODEL_WIDTH,
    "attention_input": _MODEL_WIDTH,
    "query_output": _MODEL_WIDTH,
    "key_output": _MODEL_WIDTH,
    "value_output": _MODEL_WIDTH,
    "head_query_output": _PER_HEAD_WIDTH,
    "head_key_output": _PER_HEAD_WIDTH,
    "head_value_output": _PER_HEAD_WIDTH,
}

In [26]:
from pyvene.models.intervenable_modelcard import type_to_module_mapping, type_to_dimension_mapping

# Register the wrapper class in pyvene's model-card tables, which are keyed by
# model class, so both lookups resolve to the TransformerLens mappings.
tl_model_cls = type(transformers_model)
type_to_module_mapping[tl_model_cls] = tl_to_module_mapping
type_to_dimension_mapping[tl_model_cls] = tl_to_dimension_mapping

In [27]:
# PretrainedConfig().get_config_dict('hooked_transformer')

In [28]:
from pyvene import (
    IntervenableModel,
    RotatedSpaceIntervention,
    RepresentationConfig,
    IntervenableConfig,
)
# Split the rotated representation into two equal halves. The partition indexes
# the feature dimension that the rotation acts on (d_model, matching the
# "n_embd" dimension entry for attention_output), not the layer count. It was
# previously derived from n_layers, which for this 2-layer / 8-wide model gave
# [[0, 1], [1, 2]] and left features 2..7 outside every subspace.
subspace_partition = [
    [0, model.cfg.d_model // 2],
    [model.cfg.d_model // 2, model.cfg.d_model],
]

# One rotated-space intervention on the attention output of layer 1, applied
# at a single position.
intervenable_config = IntervenableConfig(
    # model_type=type(transformers_model),
    representations=[
        RepresentationConfig(
            1,  # layer
            "attention_output",  # repr intervention type
            "pos",  # intervention unit
            1,  # max number of unit
            subspace_partition=subspace_partition,
        )
    ],
    intervention_types=RotatedSpaceIntervention,
)

In [33]:
# Attach the configured intervention to the HF-wrapped TransformerLens model.
# NOTE(review): the three lines shown under this cell (component, dimension
# name, module path) are not printed by this notebook -- presumably debug
# output from the installed pyvene; confirm.
intervenable_model = IntervenableModel(intervenable_config, transformers_model)

attention_output
n_embd
model.blocks[1].attn.hook_result


In [30]:
# Inspect what was registered: the displayed result holds one entry, keyed by
# layer / component / unit, pairing the intervention with a hook-registration method.
intervenable_model.interventions.items()

dict_items([('layer.1.comp.attention_output.unit.pos.nunit.1#0', (RotatedSpaceIntervention(
  (rotate_layer): ParametrizedRotateLayer(
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): _Orthogonal()
      )
    )
  )
), <bound method Module.register_forward_hook of HookPoint()>))])

In [34]:
# Reference point: instantiate a real GPT-2 from the same config and list the
# class of every submodule, for comparison with the TransformerLens module tree.
g2m = GPT2Model(my_config)
for module in g2m.modules():
    print(type(module))

<class 'transformers.models.gpt2.modeling_gpt2.GPT2Model'>
<class 'torch.nn.modules.sparse.Embedding'>
<class 'torch.nn.modules.sparse.Embedding'>
<class 'torch.nn.modules.dropout.Dropout'>
<class 'torch.nn.modules.container.ModuleList'>
<class 'transformers.models.gpt2.modeling_gpt2.GPT2Block'>
<class 'torch.nn.modules.normalization.LayerNorm'>
<class 'transformers.models.gpt2.modeling_gpt2.GPT2Attention'>
<class 'transformers.pytorch_utils.Conv1D'>
<class 'transformers.pytorch_utils.Conv1D'>
<class 'torch.nn.modules.dropout.Dropout'>
<class 'torch.nn.modules.dropout.Dropout'>
<class 'torch.nn.modules.normalization.LayerNorm'>
<class 'transformers.models.gpt2.modeling_gpt2.GPT2MLP'>
<class 'transformers.pytorch_utils.Conv1D'>
<class 'transformers.pytorch_utils.Conv1D'>
<class 'transformers.activations.GELUActivation'>
<class 'torch.nn.modules.dropout.Dropout'>
<class 'transformers.models.gpt2.modeling_gpt2.GPT2Block'>
<class 'torch.nn.modules.normalization.LayerNorm'>
<class 'transfor

In [9]:
# from transformers.models.gpt2.modeling_gpt2 import GPT2Attention

In [35]:
from transformer_lens.components import Attention
from transformer_lens.hook_points import HookPoint

# Print every submodule of the TransformerLens model except bare HookPoints.
# (To narrow this to attention layers, test isinstance(module, Attention) instead.)
for module in model.modules():
    if not isinstance(module, HookPoint):
        print(module)

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-1): 2 x TransformerBlock(
      (ln1): Identity()
      (ln2): Identity()
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resid_pre): HookPoint()
      (hook_resid_mid): HookPoint()
      (hook_resid_post): HookPoint()
    )
  )
  (unembed): Unembed()
)
Embed()
PosEmbed()
ModuleList(
  (

In [86]:
# tokenizer = AutoTokenizer.from_pretrained("gpt2")

# tokens = tokenizer("Hello, ", return_tensors="pt")["input_ids"]
# tokens = torch.tensor(tokens).unsqueeze(0)
# torch is used below but never imported explicitly in the visible notebook; it
# presumably reached the namespace through a star import or an earlier kernel
# state. Importing it here keeps the cell runnable on a fresh kernel.
import torch

# Standalone toy model with its own context length and vocabulary, used only to
# inspect which activations run_with_cache() records and their shapes.
# NOTE: this rebinds `cfg_dict` and `model` from the first cell, so cells that
# expect the case-derived model must be re-run from the top afterwards.
cfg_dict = {
    "n_layers": 2,
    "n_heads": 4,
    "d_head": 2,
    "d_model": 8,
    "d_mlp": 16,
    "seed": 0,
    "act_fn": "gelu",
    "n_ctx": 10,
    "d_vocab": 10,
}
model = HookedTransformer(cfg_dict)

# Single batch of three token ids; returns the model output and the activation cache.
output, cache = model.run_with_cache(torch.tensor([[1, 2, 3]]))

In [87]:
# Sanity check: for the 1x3 token batch the output's last axis equals d_vocab.
model.cfg.d_vocab, model.cfg.n_ctx, output.shape

(10, 10, torch.Size([1, 3, 10]))

In [103]:
# The cached query activation keeps heads separate: its trailing axes match
# (n_heads, d_head), as the displayed shape shows.
cache["blocks.0.attn.hook_q"].shape, model.cfg.n_heads, model.cfg.d_head

(torch.Size([1, 3, 4, 2]), 4, 2)

In [93]:
# List the names of all activations recorded by run_with_cache().
cache.keys()

dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post', 'ln_final.hook_scale', 'ln_final.