# Model Sanity checks


## Load into `crosslayer-transcoder` arch

In [None]:
import yaml
from crosslayer_transcoder.utils.module_builder import build_module_from_config 
import torch

config_path = "../../config/circuit-tracer.yaml"
checkpoint_path = "../checkpoints/clt.ckpt"

with open(config_path, "r") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)
    clt_module = build_module_from_config(config["model"])

checkpoint = torch.load(checkpoint_path, map_location='cuda:0')
clt_module.load_state_dict(checkpoint["state_dict"])

print(clt_module)

Loading dataset shards:   0%|          | 0/80 [00:00<?, ?it/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (1217 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1561 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1174 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2459 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2027 > 1024). Running this sequence through the model will result in indexing errors


JumpReLUCrossLayerTranscoderModule(
  (model): CrossLayerTranscoder(
    (encoder): Encoder()
    (decoder): CrosslayerDecoder()
    (nonlinearity): JumpReLU()
    (input_standardizer): DimensionwiseInputStandardizer()
    (output_standardizer): DimensionwiseOutputStandardizer()
  )
  (replacement_model): ReplacementModelAccuracy(
    (replacement_model): ReplacementModel()
  )
  (dead_features): DeadFeatures()
)


### Collect Activations

In [2]:
from nnsight import LanguageModel
prompt = (
    "The capital of state containing Dallas is"  # What you want to get the graph for
)
llm = LanguageModel("openai-community/gpt2")

In [3]:
import torch
mlp_in_activations = []

with llm.trace(prompt) as trace:
    for layer in range(12):
        layer_activations = llm.transformer.h[layer].ln_2.input.save()
        mlp_in_activations.append(layer_activations.squeeze(0))

mlp_in_activations = torch.stack(mlp_in_activations, dim=0)

print(mlp_in_activations.shape)

torch.Size([12, 7, 768])


In [4]:
import einops

in_acts = einops.rearrange(mlp_in_activations, "l b d -> b l d")
print(in_acts.shape)

torch.Size([7, 12, 768])


### Encode w/o standarizer folding

In [5]:
features = clt_module.model.encode(in_acts)

print(features.shape)

sparse_features = features.to_sparse()
print(f"sparse_features._nnz(): {sparse_features._nnz()}")

l0_avg_per_layer = torch.count_nonzero(features > 0) / (
    features.shape[0] * features.shape[1]
)
print(f"l0_avg_per_layer: {l0_avg_per_layer.item()}")


torch.Size([7, 12, 10000])
sparse_features._nnz(): 70298
l0_avg_per_layer: 836.8809814453125


### Encode w/ folding

In [6]:

features = clt_module.model.encode_with_standardizer_folding(in_acts)

print(features.shape)

sparse_features = features.to_sparse()
print(f"sparse_features._nnz(): {sparse_features._nnz()}")

l0_avg_per_layer = torch.count_nonzero(features > 0) / (features.shape[0] * features.shape[1])
print(f"l0_avg_per_layer: {l0_avg_per_layer.item()}")


torch.Size([7, 12, 10000])
sparse_features._nnz(): 17896
l0_avg_per_layer: 213.04762268066406


### Run encoding in bfloat16

In [7]:
in_acts = in_acts.to(torch.bfloat16)
clt_module.model.to(torch.bfloat16)
features = clt_module.model.encode_with_standardizer_folding(in_acts)

print(features.shape)

sparse_features = features.to_sparse()
print(f"sparse_features._nnz(): {sparse_features._nnz()}")

l0_avg_per_layer = torch.count_nonzero(features > 0) / (
    features.shape[0] * features.shape[1]
)
print(f"l0_avg_per_layer: {l0_avg_per_layer.item()}")


torch.Size([7, 12, 10000])
sparse_features._nnz(): 17843
l0_avg_per_layer: 212.4166717529297


## Interface w circuit tracer

In [8]:
import pathlib
from crosslayer_transcoder.utils.model_converters.circuit_tracer import (
    CircuitTracerConverter,
)
from circuit_tracer.transcoder.cross_layer_transcoder import load_clt

### Convert model to circuit-tracer format and save `.safetensors`

In [9]:
save_dir = pathlib.Path("clt_module_test")
feature_input_hook = "hook_resid_mid"
feature_output_hook = "hook_mlp_out"

converter = CircuitTracerConverter(
save_dir=save_dir,
feature_input_hook=feature_input_hook,
    feature_output_hook=feature_output_hook,
)
converter.convert_and_save(clt_module, dtype=torch.bfloat16) 

Converting CLT : 100%|██████████| 12/12 [00:13<00:00,  1.15s/it]


### Load CLT into circuit-tracer

In [10]:

circuit_tracer_transcoder, state_dict_pre_load = load_clt(
    clt_path=save_dir.as_posix(),
    lazy_decoder=False,
    lazy_encoder=False,
    feature_input_hook=feature_input_hook,
    feature_output_hook=feature_output_hook,
    dtype=torch.bfloat16,
)


12
['W_enc_0', 'b_dec_0', 'b_enc_0', 'threshold_0']
['W_enc_1', 'b_dec_1', 'b_enc_1', 'threshold_1']
['W_enc_2', 'b_dec_2', 'b_enc_2', 'threshold_2']
['W_enc_3', 'b_dec_3', 'b_enc_3', 'threshold_3']
['W_enc_4', 'b_dec_4', 'b_enc_4', 'threshold_4']
['W_enc_5', 'b_dec_5', 'b_enc_5', 'threshold_5']
['W_enc_6', 'b_dec_6', 'b_enc_6', 'threshold_6']
['W_enc_7', 'b_dec_7', 'b_enc_7', 'threshold_7']
['W_enc_8', 'b_dec_8', 'b_enc_8', 'threshold_8']
['W_enc_9', 'b_dec_9', 'b_enc_9', 'threshold_9']
['W_enc_10', 'b_dec_10', 'b_enc_10', 'threshold_10']
['W_enc_11', 'b_dec_11', 'b_enc_11', 'threshold_11']
dict_keys(['b_dec', 'b_enc', 'activation_function.threshold', 'W_enc', 'W_dec.0', 'W_dec.1', 'W_dec.2', 'W_dec.3', 'W_dec.4', 'W_dec.5', 'W_dec.6', 'W_dec.7', 'W_dec.8', 'W_dec.9', 'W_dec.10', 'W_dec.11'])


### Sanity checks for loaded model

In [11]:
import einops
from safetensors import safe_open
import torch
import os
# sanity check against original model

clt_path = save_dir.as_posix()
assert circuit_tracer_transcoder.clt_path == clt_path


# TEST: state_dict_pre_load weights match the files before being loaded into the model
for i in range(circuit_tracer_transcoder.n_layers):
    enc_file = os.path.join(clt_path, f"W_enc_{i}.safetensors")
    with safe_open(enc_file, framework="pt", device=circuit_tracer_transcoder.device.type) as f:
        w_file = f.get_tensor(f"W_enc_{i}").to(
            dtype=state_dict_pre_load["W_enc"][i].dtype,
            device=state_dict_pre_load["W_enc"][i].device,
        )
        w_state = state_dict_pre_load["W_enc"][i]
        assert w_file.shape == w_state.shape, (
            f"W_enc_{i} shape mismatch: {w_file.shape} != {w_state.shape}"
        )
        assert w_file.dtype == w_state.dtype, (
            f"W_enc_{i} dtype mismatch: {w_file.dtype} != {w_state.dtype}"
        )
        assert torch.allclose(w_file, w_state), (
            i,
            (w_file - w_state).abs().max().item(),
        )


# TEST: weights in the files should equal weights from the circuit_tracer_transcoder
for i in range(circuit_tracer_transcoder.n_layers):
    w_model = circuit_tracer_transcoder._get_encoder_weights(i)  # works for both lazy and eager
    enc_file = os.path.join(clt_path, f"W_enc_{i}.safetensors")
    with safe_open(enc_file, framework="pt", device=circuit_tracer_transcoder.device.type) as f:
        w_file = f.get_tensor(f"W_enc_{i}").to(
            dtype=w_model.dtype, device=w_model.device
        )

    assert w_model.shape == w_file.shape
    assert w_model.dtype == w_file.dtype
    assert torch.allclose(w_model, w_file), (i, (w_model - w_file).abs().max().item())


# fold
standardizer = clt_module.model.input_standardizer
W_enc_folded, b_enc_folded = standardizer.fold_in_encoder(
    clt_module.model.encoder.W.to(dtype=torch.bfloat16),
    clt_module.model.encoder.b.to(dtype=torch.bfloat16),
)

W_enc_folded = W_enc_folded.to(dtype=torch.bfloat16)
b_enc_folded = b_enc_folded.to(dtype=torch.bfloat16)

state_dict = {}
device = clt_module.device
# TEST: eights in the file should equal the folded weights
for i in range(clt_module.model.encoder.n_layers):
    enc_file = f"W_enc_{i}.safetensors"
    with safe_open(
        os.path.join(clt_path, enc_file), framework="pt", device=device.type
    ) as f:
        assert W_enc_folded[i].T.shape == f.get_tensor(f"W_enc_{i}").shape, (
            f"W_enc_{i} shape mismatch: {W_enc_folded[i].shape} != {f.get_tensor(f'W_enc_{i}').shape}"
        )
        assert W_enc_folded[i].dtype == f.get_tensor(f"W_enc_{i}").dtype, (
            f"W_enc_{i} dtype mismatch: {W_enc_folded[i].dtype} != {f.get_tensor(f'W_enc_{i}').dtype}"
        )
        assert torch.allclose(W_enc_folded[i].T, f.get_tensor(f"W_enc_{i}"))
        assert torch.allclose(b_enc_folded[i], f.get_tensor(f"b_enc_{i}"))

        # loaded model
        assert circuit_tracer_transcoder.W_enc[i].shape == f.get_tensor(f"W_enc_{i}").shape
        assert circuit_tracer_transcoder.W_enc[i].dtype == f.get_tensor(f"W_enc_{i}").dtype, (
            f"W_enc_{i} dtype mismatch: {circuit_tracer_transcoder.W_enc[i].dtype} != {f.get_tensor(f'W_enc_{i}').dtype}"
        )
        assert torch.allclose(
            circuit_tracer_transcoder.W_enc[i], f.get_tensor(f"W_enc_{i}").to("cuda:0")
        )


for i in range(clt_module.model.encoder.n_layers):
    rearranged_W_enc = einops.rearrange(
        W_enc_folded[i],
        "d_acts d_features -> d_features d_acts",
    ).contiguous()
    assert torch.allclose(
        rearranged_W_enc.to("cuda:0"), circuit_tracer_transcoder.W_enc[i].to("cuda:0")
    )
    assert torch.allclose(
        b_enc_folded[i].to(dtype=torch.bfloat16).to("cuda:0"),
        circuit_tracer_transcoder.b_enc[i].to(dtype=torch.bfloat16).to("cuda:0"),
    )


### Test Loaded CLT encoding

In [12]:
in_acts_clt = einops.rearrange(in_acts, "b l d -> l b d")
assert in_acts_clt.shape == (12, 7, 768), in_acts_clt.shape

in_acts_clt = in_acts_clt.to("cuda:0")

print(in_acts_clt.shape)
print(in_acts_clt.device, in_acts_clt.dtype)

circuit_tracer_transcoder.encode_sparse(in_acts_clt)

torch.Size([12, 7, 768])
cuda:0 torch.bfloat16
layer 0 nnz: 426
layer 1 nnz: 204
layer 2 nnz: 107
layer 3 nnz: 94
layer 4 nnz: 104
layer 5 nnz: 72
layer 6 nnz: 88
layer 7 nnz: 36
layer 8 nnz: 35
layer 9 nnz: 18
layer 10 nnz: 14
layer 11 nnz: 2


(tensor(indices=tensor([[   0,    0,    0,  ...,   10,   11,   11],
                        [   1,    1,    1,  ...,    6,    1,    6],
                        [  71,  193,  330,  ..., 9154, 7738,  433]]),
        values=tensor([-0.7852,  1.4062,  1.7812,  ...,  1.3203,  1.1094,
                        0.4297]),
        device='cuda:0', size=(12, 7, 10000), nnz=1200, dtype=torch.bfloat16,
        layout=torch.sparse_coo, grad_fn=<CoalesceBackward0>),
 tensor([[-1.9409e-02, -1.8677e-02, -5.5420e-02,  ...,  2.7222e-02,
           2.3438e-01, -2.4609e-01],
         [-2.0447e-03,  3.7109e-02,  6.4697e-03,  ..., -4.1797e-01,
          -2.8516e-01,  6.2988e-02],
         [-2.6978e-02, -1.0193e-02,  6.5613e-04,  ..., -4.7656e-01,
          -2.8564e-02, -1.9824e-01],
         ...,
         [ 2.1210e-03,  1.1520e-03,  6.0730e-03,  ..., -2.9907e-03,
          -1.3885e-03,  5.0964e-03],
         [-2.5787e-03,  1.1826e-03,  5.6152e-03,  ...,  4.8218e-03,
           5.2795e-03,  4.4861e-03],
      

In [13]:
print(in_acts.shape)
print(in_acts_clt.device, in_acts_clt.dtype)
circuit_tracer_transcoder.compute_attribution_components(in_acts_clt)

torch.Size([7, 12, 768])
cuda:0 torch.bfloat16
layer 0 nnz: 426
layer 1 nnz: 204
layer 2 nnz: 107
layer 3 nnz: 94
layer 4 nnz: 104
layer 5 nnz: 72
layer 6 nnz: 88
layer 7 nnz: 36
layer 8 nnz: 35
layer 9 nnz: 18
layer 10 nnz: 14
layer 11 nnz: 2
nnz features: 1200
torch.Size([12, 7, 10000])


{'activation_matrix': tensor(indices=tensor([[   0,    0,    0,  ...,   10,   11,   11],
                        [   1,    1,    1,  ...,    6,    1,    6],
                        [  71,  193,  330,  ..., 9154, 7738,  433]]),
        values=tensor([-0.7852,  1.4062,  1.7812,  ...,  1.3203,  1.1094,
                        0.4297]),
        device='cuda:0', size=(12, 7, 10000), nnz=1200, dtype=torch.bfloat16,
        layout=torch.sparse_coo, grad_fn=<CoalesceBackward0>),
 'reconstruction': tensor([[[-3.6621e-02,  9.7656e-02,  5.9814e-03,  ..., -6.0547e-02,
           -5.7373e-02, -1.0986e-01],
          [-1.9531e+00,  6.2109e-01, -1.5625e+00,  ...,  1.0547e+00,
            1.2939e-02, -2.5977e-01],
          [ 4.8340e-02,  7.9688e-01,  1.2695e-01,  ..., -6.1719e-01,
            7.6953e-01, -5.2490e-02],
          ...,
          [-1.9688e+00,  1.0312e+00,  4.0430e-01,  ...,  2.0117e-01,
            1.2969e+00, -4.1406e-01],
          [-1.1016e+00,  3.1641e-01, -2.0312e-01,  ..., -1.9062

### Test with ReplacementModel

In [14]:
from circuit_tracer import ReplacementModel
from transformer_lens.loading_from_pretrained import get_pretrained_model_config

# copy transcoder 
circuit_tracer_transcoder_copy, _ = load_clt(
    clt_path=save_dir.as_posix(),
    lazy_decoder=False,
    lazy_encoder=False,
    feature_input_hook=feature_input_hook,
    feature_output_hook=feature_output_hook,
    dtype=torch.bfloat16,
)

print(circuit_tracer_transcoder_copy.device, circuit_tracer_transcoder_copy.dtype)
config = get_pretrained_model_config("gpt2")
config.dtype=torch.bfloat16
rm = ReplacementModel.from_pretrained_and_transcoders(
    "gpt2",
    circuit_tracer_transcoder_copy,
)

print(rm.transcoders.device, rm.transcoders.dtype)
print(rm.hook_dict)

12
['W_enc_0', 'b_dec_0', 'b_enc_0', 'threshold_0']
['W_enc_1', 'b_dec_1', 'b_enc_1', 'threshold_1']
['W_enc_2', 'b_dec_2', 'b_enc_2', 'threshold_2']
['W_enc_3', 'b_dec_3', 'b_enc_3', 'threshold_3']
['W_enc_4', 'b_dec_4', 'b_enc_4', 'threshold_4']
['W_enc_5', 'b_dec_5', 'b_enc_5', 'threshold_5']
['W_enc_6', 'b_dec_6', 'b_enc_6', 'threshold_6']
['W_enc_7', 'b_dec_7', 'b_enc_7', 'threshold_7']
['W_enc_8', 'b_dec_8', 'b_enc_8', 'threshold_8']
['W_enc_9', 'b_dec_9', 'b_enc_9', 'threshold_9']
['W_enc_10', 'b_dec_10', 'b_enc_10', 'threshold_10']
['W_enc_11', 'b_dec_11', 'b_enc_11', 'threshold_11']
dict_keys(['b_dec', 'b_enc', 'activation_function.threshold', 'W_enc', 'W_dec.0', 'W_dec.1', 'W_dec.2', 'W_dec.3', 'W_dec.4', 'W_dec.5', 'W_dec.6', 'W_dec.7', 'W_dec.8', 'W_dec.9', 'W_dec.10', 'W_dec.11'])
cuda:0 torch.bfloat16


`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Loaded pretrained model gpt2 into HookedTransformer
cuda:0 torch.float32
{'hook_embed': HookPoint(), 'hook_pos_embed': HookPoint(), 'blocks.0.ln1.hook_scale': HookPoint(), 'blocks.0.ln1.hook_normalized': HookPoint(), 'blocks.0.ln2.hook_scale': HookPoint(), 'blocks.0.ln2.hook_normalized': HookPoint(), 'blocks.0.attn.hook_k': HookPoint(), 'blocks.0.attn.hook_q': HookPoint(), 'blocks.0.attn.hook_v': HookPoint(), 'blocks.0.attn.hook_z': HookPoint(), 'blocks.0.attn.hook_attn_scores': HookPoint(), 'blocks.0.attn.hook_pattern': HookPoint(), 'blocks.0.attn.hook_result': HookPoint(), 'blocks.0.mlp.old_mlp.hook_pre': HookPoint(), 'blocks.0.mlp.old_mlp.hook_post': HookPoint(), 'blocks.0.mlp.hook_in': HookPoint(), 'blocks.0.mlp.hook_out': HookPoint(), 'blocks.0.hook_attn_in': HookPoint(), 'blocks.0.hook_q_input': HookPoint(), 'blocks.0.hook_k_input': HookPoint(), 'blocks.0.hook_v_input': HookPoint(), 'blocks.0.hook_mlp_in': HookPoint(), 'blocks.0.hook_attn_out': HookPoint(), 'blocks.0.hook_mlp_out

In [15]:

if isinstance(prompt, str):
    tokens = rm.ensure_tokenized(prompt)
else:
    tokens = prompt.squeeze()
# COuld it be the caching hooks?
mlp_in_cache, mlp_in_caching_hooks, _ = rm.get_caching_hooks(
    lambda name: rm.feature_input_hook in name
)
print(mlp_in_caching_hooks)

mlp_out_cache, mlp_out_caching_hooks, _ = rm.get_caching_hooks(
    lambda name: rm.feature_output_hook in name
)
logits = rm.run_with_hooks(tokens, fwd_hooks=mlp_in_caching_hooks + mlp_out_caching_hooks)
print(logits.shape)


print(mlp_in_cache.items())

mlp_in_cache = torch.cat(list(mlp_in_cache.values()), dim=0)
mlp_out_cache = torch.cat(list(mlp_out_cache.values()), dim=0)
print(mlp_in_cache.device, mlp_in_cache.dtype)

[('blocks.0.hook_resid_mid', functools.partial(<function HookedRootModule.get_caching_hooks.<locals>.save_hook at 0x7ae5023fd800>, is_backward=False)), ('blocks.1.hook_resid_mid', functools.partial(<function HookedRootModule.get_caching_hooks.<locals>.save_hook at 0x7ae5023fd800>, is_backward=False)), ('blocks.2.hook_resid_mid', functools.partial(<function HookedRootModule.get_caching_hooks.<locals>.save_hook at 0x7ae5023fd800>, is_backward=False)), ('blocks.3.hook_resid_mid', functools.partial(<function HookedRootModule.get_caching_hooks.<locals>.save_hook at 0x7ae5023fd800>, is_backward=False)), ('blocks.4.hook_resid_mid', functools.partial(<function HookedRootModule.get_caching_hooks.<locals>.save_hook at 0x7ae5023fd800>, is_backward=False)), ('blocks.5.hook_resid_mid', functools.partial(<function HookedRootModule.get_caching_hooks.<locals>.save_hook at 0x7ae5023fd800>, is_backward=False)), ('blocks.6.hook_resid_mid', functools.partial(<function HookedRootModule.get_caching_hooks.<l

In [16]:
nnsight_acts = []
with llm.trace(tokens) as trace:
    for layer in range(12):
        layer_activations = llm.transformer.h[layer].ln_2.input.save()
        nnsight_acts.append(layer_activations.squeeze(0))

nnsight_acts = torch.stack(nnsight_acts, dim=0)

print(nnsight_acts.shape)
nnsight_acts = nnsight_acts.to("cuda:0").to(torch.bfloat16)

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


torch.Size([12, 8, 768])


In [17]:
mlp_in_cache = mlp_in_cache.to("cuda:0").to(torch.bfloat16)
print(mlp_in_cache.device, mlp_in_cache.dtype)

cuda:0 torch.bfloat16


In [18]:
## TEST Difference between nnsight and TL activations
for i in range(12):
    print(nnsight_acts[i].shape, mlp_in_cache[i].shape)
    print(nnsight_acts[i].device, mlp_in_cache[i].device)
    print(nnsight_acts[i].dtype, mlp_in_cache[i].dtype)
    print(f"nans?: {torch.isnan(nnsight_acts[i]).any()}, {torch.isnan(mlp_in_cache[i]).any()}")
    print(f"max diff: {(nnsight_acts[i] - mlp_in_cache[i]).abs().max().item()}")
    print("-"*100)


torch.Size([8, 768]) torch.Size([8, 768])
cuda:0 cuda:0
torch.bfloat16 torch.bfloat16
nans?: False, False
max diff: 0.001953125
----------------------------------------------------------------------------------------------------
torch.Size([8, 768]) torch.Size([8, 768])
cuda:0 cuda:0
torch.bfloat16 torch.bfloat16
nans?: False, False
max diff: 0.000244140625
----------------------------------------------------------------------------------------------------
torch.Size([8, 768]) torch.Size([8, 768])
cuda:0 cuda:0
torch.bfloat16 torch.bfloat16
nans?: False, False
max diff: 0.0009765625
----------------------------------------------------------------------------------------------------
torch.Size([8, 768]) torch.Size([8, 768])
cuda:0 cuda:0
torch.bfloat16 torch.bfloat16
nans?: False, False
max diff: 0.0
----------------------------------------------------------------------------------------------------
torch.Size([8, 768]) torch.Size([8, 768])
cuda:0 cuda:0
torch.bfloat16 torch.bfloat16
na

In [19]:
print(nnsight_acts.shape, nnsight_acts.device, nnsight_acts.dtype)
print(mlp_in_cache.shape, mlp_in_cache.device, mlp_in_cache.dtype)
attribution_data_a = circuit_tracer_transcoder.compute_attribution_components(nnsight_acts)
attribution_data_c = circuit_tracer_transcoder.compute_attribution_components(mlp_in_cache)


torch.Size([12, 8, 768]) cuda:0 torch.bfloat16
torch.Size([12, 8, 768]) cuda:0 torch.bfloat16
layer 0 nnz: 467
layer 1 nnz: 241
layer 2 nnz: 130
layer 3 nnz: 108
layer 4 nnz: 107
layer 5 nnz: 66
layer 6 nnz: 88
layer 7 nnz: 42
layer 8 nnz: 31
layer 9 nnz: 18
layer 10 nnz: 12
layer 11 nnz: 1
nnz features: 1311
torch.Size([12, 8, 10000])
layer 0 nnz: 467
layer 1 nnz: 241
layer 2 nnz: 130
layer 3 nnz: 108
layer 4 nnz: 107
layer 5 nnz: 66
layer 6 nnz: 88
layer 7 nnz: 42
layer 8 nnz: 31
layer 9 nnz: 18
layer 10 nnz: 12
layer 11 nnz: 1
nnz features: 1311
torch.Size([12, 8, 10000])


In [20]:
print(circuit_tracer_transcoder_copy.device, circuit_tracer_transcoder_copy.dtype)
attribution_data_e = circuit_tracer_transcoder_copy.compute_attribution_components(mlp_in_cache)
attribution_data_e = circuit_tracer_transcoder_copy.compute_attribution_components(in_acts_clt)

cuda:0 torch.float32


RuntimeError: expected scalar type BFloat16 but found Float

In [None]:

attribution_data = rm.transcoders.compute_attribution_components(mlp_in_cache)

### Attribution test

In [None]:
max_n_logits = 10  # How many logits to attribute from, max. We attribute to min(max_n_logits, n_logits_to_reach_desired_log_prob); see below for the latter
desired_logit_prob = 0.95  # Attribution will attribute from the minimum number of logits needed to reach this probability mass (or max_n_logits, whichever is lower)
max_feature_nodes = 100  # Only attribute from this number of feature nodes, max. Lower is faster, but you will lose more of the graph. None means no limit.
batch_size = 256 // 8  # Batch size when attributing
offload = "cpu"  # Offload various parts of the model during attribution to save memory. Can be 'disk', 'cpu', or None (keep on GPU)
verbose = True  # Whether to display a tqdm progress bar and timing report


In [None]:
from pathlib import Path
import torch

from circuit_tracer import ReplacementModel, attribute
from circuit_tracer.utils import create_graph_files

torch.cuda.empty_cache()

# FOr some reason this takes up a lot of VRAM
graph = attribute(
    prompt=prompt,
    model=rm,
    max_n_logits=max_n_logits,
    desired_logit_prob=desired_logit_prob,
    batch_size=batch_size,
    max_feature_nodes=max_feature_nodes,
    offload=offload,
    verbose=verbose,
)
