# Model Sanity checks


## Load into `crosslayer-transcoder` arch

In [1]:
from crosslayer_transcoder.utils.module_builder import build_module_from_config, yaml_to_config
from crosslayer_transcoder.utils.checkpoints import load_model_from_lightning_checkpoint

# Relative paths to the YAML architecture config and the Lightning checkpoint.
config_path = "../../config/circuit-tracer.yaml"
checkpoint_path = "../checkpoints/clt.ckpt"

# Build the module described by the config, then pass it through the
# checkpoint loader together with the checkpoint path.
config = yaml_to_config(config_path)
clt_module = load_model_from_lightning_checkpoint(
    build_module_from_config(config),
    checkpoint_path,
)

# Print the module tree as a quick structural check.
print(clt_module)

Loading dataset shards:   0%|          | 0/80 [00:00<?, ?it/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (1217 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1174 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1561 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2027 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2459 > 1024). Running this sequence through the model will result in indexing errors


JumpReLUCrossLayerTranscoderModule(
  (model): CrossLayerTranscoder(
    (encoder): Encoder()
    (decoder): CrosslayerDecoder()
    (nonlinearity): JumpReLU()
    (input_standardizer): DimensionwiseInputStandardizer()
    (output_standardizer): DimensionwiseOutputStandardizer()
  )
  (replacement_model): ReplacementModelAccuracy(
    (replacement_model): ReplacementModel()
  )
  (dead_features): DeadFeatures()
)


### Collect Activations

In [2]:
from nnsight import LanguageModel
# What you want to get the graph for.
prompt = "The capital of state containing Dallas is"

# GPT-2 wrapped in an nnsight LanguageModel so it can be traced below.
llm = LanguageModel("openai-community/gpt2")

In [3]:
import torch
# Collects one activation tensor per transformer block, in block order.
mlp_in_activations = []

# Trace a single forward pass over `prompt` and save the input of each block's `ln_2`.
with llm.trace(prompt) as trace:
    # NOTE(review): 12 is hard-coded; presumably the number of blocks in this GPT-2 checkpoint — confirm if the model changes.
    for layer in range(12):
        layer_activations = llm.transformer.h[layer].ln_2.input.save()
        # Drop the leading dim if it has size 1 (assumed to be the batch dim — TODO confirm).
        mlp_in_activations.append(layer_activations.squeeze(0))

# Stack the per-block tensors along a new leading (layer) axis.
mlp_in_activations = torch.stack(mlp_in_activations, dim=0)

# The recorded run printed torch.Size([12, 7, 768]).
print(mlp_in_activations.shape)

torch.Size([12, 7, 768])


In [4]:
import einops

# Swap the first two axes so the axis named "b" in the pattern comes first;
# the recorded run goes from [12, 7, 768] to [7, 12, 768].
# NOTE(review): "b" appears to be the token-position axis of the single prompt rather than a batch — confirm.
in_acts = einops.rearrange(mlp_in_activations, "l b d -> b l d")
print(in_acts.shape)

torch.Size([7, 12, 768])


### Encode w/o standardizer folding

In [5]:
# Encode through the model's plain `encode` path (no standardizer folding).
features = clt_module.model.encode(in_acts)
print(features.shape)

# Count the entries stored once the activations are converted to a sparse tensor.
sparse_features = features.to_sparse()
print(f"sparse_features._nnz(): {sparse_features._nnz()}")

# Strictly positive activations divided by the product of the first two dims.
num_active = torch.count_nonzero(features > 0)
num_slots = features.shape[0] * features.shape[1]
l0_avg_per_layer = num_active / num_slots
print(f"l0_avg_per_layer: {l0_avg_per_layer.item()}")


torch.Size([7, 12, 10000])
sparse_features._nnz(): 70298
l0_avg_per_layer: 836.8809814453125


### Encode w/ standardizer folding

In [6]:

# Same statistics as the unfolded encode, but via `encode_with_standardizer_folding`.
features = clt_module.model.encode_with_standardizer_folding(in_acts)
print(features.shape)

sparse_features = features.to_sparse()
print(f"sparse_features._nnz(): {sparse_features._nnz()}")

# Ratio of strictly positive activations to (dim 0 size * dim 1 size).
dim0_size, dim1_size = features.shape[0], features.shape[1]
positive_count = torch.count_nonzero(features > 0)
l0_avg_per_layer = positive_count / (dim0_size * dim1_size)
print(f"l0_avg_per_layer: {l0_avg_per_layer.item()}")


torch.Size([7, 12, 10000])
sparse_features._nnz(): 17896
l0_avg_per_layer: 213.04762268066406


### Run encoding in bfloat16

In [7]:
# Repeat the folded encode in bfloat16.
# NOTE: `in_acts` is rebound to its bfloat16 copy and `clt_module.model` is cast
# in place, so every later cell sees a bfloat16 model.
in_acts = in_acts.to(torch.bfloat16)
clt_module.model.to(torch.bfloat16)
features = clt_module.model.encode_with_standardizer_folding(in_acts)
print(features.shape)

sparse_features = features.to_sparse()
print(f"sparse_features._nnz(): {sparse_features._nnz()}")

# Strictly positive activations divided by the product of the first two dims.
leading_slots = features.shape[0] * features.shape[1]
l0_avg_per_layer = torch.count_nonzero(features > 0) / leading_slots
print(f"l0_avg_per_layer: {l0_avg_per_layer.item()}")


torch.Size([7, 12, 10000])
sparse_features._nnz(): 17843
l0_avg_per_layer: 212.4166717529297


## Interface with circuit-tracer

In [8]:
import pathlib
from crosslayer_transcoder.utils.model_converters.circuit_tracer import (
    CircuitTracerConverter,
)
from circuit_tracer.transcoder.cross_layer_transcoder import load_clt

### Convert model to circuit-tracer format and save `.safetensors`

In [9]:
# Output directory for the converted files, plus the hook names handed to the
# converter (and reused when loading the result into circuit-tracer).
save_dir = pathlib.Path("clt_module_test")
feature_input_hook = "hook_resid_mid"
feature_output_hook = "hook_mlp_out"

converter = CircuitTracerConverter(
    save_dir=save_dir,
    feature_input_hook=feature_input_hook,
    feature_output_hook=feature_output_hook,
)

# Convert the module and write it out with bfloat16 as the requested dtype.
converter.convert_and_save(clt_module, dtype=torch.bfloat16)

Converting CLT : 100%|██████████| 12/12 [00:10<00:00,  1.13it/s]


### Load CLT into circuit-tracer

In [10]:

# Load the converted CLT eagerly (neither encoder nor decoder is lazy) in bfloat16.
# The second return value is kept for the sanity checks further down.
circuit_tracer_transcoder, state_dict_pre_load = load_clt(
    clt_path=save_dir.as_posix(),
    feature_input_hook=feature_input_hook,
    feature_output_hook=feature_output_hook,
    lazy_encoder=False,
    lazy_decoder=False,
    dtype=torch.bfloat16,
)


12
['W_enc_0', 'b_dec_0', 'b_enc_0', 'threshold_0']
['W_enc_1', 'b_dec_1', 'b_enc_1', 'threshold_1']
['W_enc_2', 'b_dec_2', 'b_enc_2', 'threshold_2']
['W_enc_3', 'b_dec_3', 'b_enc_3', 'threshold_3']
['W_enc_4', 'b_dec_4', 'b_enc_4', 'threshold_4']
['W_enc_5', 'b_dec_5', 'b_enc_5', 'threshold_5']
['W_enc_6', 'b_dec_6', 'b_enc_6', 'threshold_6']
['W_enc_7', 'b_dec_7', 'b_enc_7', 'threshold_7']
['W_enc_8', 'b_dec_8', 'b_enc_8', 'threshold_8']
['W_enc_9', 'b_dec_9', 'b_enc_9', 'threshold_9']
['W_enc_10', 'b_dec_10', 'b_enc_10', 'threshold_10']
['W_enc_11', 'b_dec_11', 'b_enc_11', 'threshold_11']
dict_keys(['b_dec', 'b_enc', 'activation_function.threshold', 'W_enc', 'W_dec.0', 'W_dec.1', 'W_dec.2', 'W_dec.3', 'W_dec.4', 'W_dec.5', 'W_dec.6', 'W_dec.7', 'W_dec.8', 'W_dec.9', 'W_dec.10', 'W_dec.11'])


### Sanity checks for loaded model

In [11]:
import einops
from safetensors import safe_open
import torch
import os
# Sanity-check the loaded circuit-tracer transcoder against the original model.

clt_path = save_dir.as_posix()
assert circuit_tracer_transcoder.clt_path == clt_path


# TEST: state_dict_pre_load weights match the files before being loaded into the model
# NOTE(review): the file tensor is cast to the state-dict tensor's dtype before the
# comparison, so the dtype assertion below cannot fail.
for i in range(circuit_tracer_transcoder.n_layers):
    w_state = state_dict_pre_load["W_enc"][i]
    enc_file = os.path.join(clt_path, f"W_enc_{i}.safetensors")

    # Read the tensor, then run the comparisons after the file has been closed.
    with safe_open(enc_file, framework="pt", device=circuit_tracer_transcoder.device.type) as f:
        w_file = f.get_tensor(f"W_enc_{i}").to(dtype=w_state.dtype, device=w_state.device)

    assert w_file.shape == w_state.shape, (
        f"W_enc_{i} shape mismatch: {w_file.shape} != {w_state.shape}"
    )
    assert w_file.dtype == w_state.dtype, (
        f"W_enc_{i} dtype mismatch: {w_file.dtype} != {w_state.dtype}"
    )
    # On failure, report the layer index and the largest absolute difference.
    assert torch.allclose(w_file, w_state), (
        i,
        (w_file - w_state).abs().max().item(),
    )


# TEST: weights in the files should equal weights from the circuit_tracer_transcoder
for i in range(circuit_tracer_transcoder.n_layers):
    # `_get_encoder_weights` works for both lazy and eager encoders.
    w_model = circuit_tracer_transcoder._get_encoder_weights(i)

    tensor_name = f"W_enc_{i}"
    enc_file = os.path.join(clt_path, f"{tensor_name}.safetensors")
    with safe_open(enc_file, framework="pt", device=circuit_tracer_transcoder.device.type) as f:
        # Bring the file tensor onto the model tensor's dtype and device before comparing.
        w_file = f.get_tensor(tensor_name).to(dtype=w_model.dtype, device=w_model.device)

    assert w_model.shape == w_file.shape
    assert w_model.dtype == w_file.dtype
    # On failure, report the layer index and the largest absolute difference.
    assert torch.allclose(w_model, w_file), (i, (w_model - w_file).abs().max().item())


# Fold the input standardizer into the bfloat16 encoder weights and bias.
standardizer = clt_module.model.input_standardizer
encoder = clt_module.model.encoder
W_enc_folded, b_enc_folded = standardizer.fold_in_encoder(
    encoder.W.to(dtype=torch.bfloat16),
    encoder.b.to(dtype=torch.bfloat16),
)

# Cast the folded results to bfloat16 too, whatever dtype the fold returned.
W_enc_folded, b_enc_folded = (
    folded.to(dtype=torch.bfloat16) for folded in (W_enc_folded, b_enc_folded)
)

state_dict = {}  # NOTE(review): never read in this cell; kept in case a later cell expects it
device = clt_module.device
# TEST: weights in the file should equal the folded weights, and the weights
# loaded into circuit-tracer should equal the file.
for i in range(clt_module.model.encoder.n_layers):
    enc_file = os.path.join(clt_path, f"W_enc_{i}.safetensors")

    # Read each tensor from disk once instead of once per assertion, and put it
    # on the same device as the folded weights so the comparison cannot trip
    # over a device mismatch.
    with safe_open(enc_file, framework="pt", device=device.type) as f:
        w_file = f.get_tensor(f"W_enc_{i}").to(W_enc_folded.device)
        b_file = f.get_tensor(f"b_enc_{i}").to(b_enc_folded.device)

    # The file stores the transpose of the folded encoder weight.
    w_folded = W_enc_folded[i].T
    assert w_folded.shape == w_file.shape, (
        f"W_enc_{i} shape mismatch: {w_folded.shape} != {w_file.shape}"
    )
    assert w_folded.dtype == w_file.dtype, (
        f"W_enc_{i} dtype mismatch: {w_folded.dtype} != {w_file.dtype}"
    )
    assert torch.allclose(w_folded, w_file), f"W_enc_{i} differs from the folded weights"
    assert torch.allclose(b_enc_folded[i], b_file), f"b_enc_{i} differs from the folded bias"

    # loaded model
    w_loaded = circuit_tracer_transcoder.W_enc[i]
    assert w_loaded.shape == w_file.shape, (
        f"W_enc_{i} shape mismatch: {w_loaded.shape} != {w_file.shape}"
    )
    assert w_loaded.dtype == w_file.dtype, (
        f"W_enc_{i} dtype mismatch: {w_loaded.dtype} != {w_file.dtype}"
    )
    # Compare on the device the loaded weights actually live on rather than a
    # hard-coded "cuda:0", so the check also runs on CPU or another GPU index.
    assert torch.allclose(w_loaded, w_file.to(w_loaded.device)), (
        f"W_enc_{i} in the loaded model differs from the file"
    )


# TEST: the folded encoder weights (transposed) and biases should equal the ones
# held by the loaded circuit-tracer transcoder.
for i in range(clt_module.model.encoder.n_layers):
    w_loaded = circuit_tracer_transcoder.W_enc[i]
    b_loaded = circuit_tracer_transcoder.b_enc[i]

    rearranged_W_enc = einops.rearrange(
        W_enc_folded[i],
        "d_acts d_features -> d_features d_acts",
    ).contiguous()

    # Move the folded tensors to wherever the loaded tensors live instead of
    # hard-coding "cuda:0", so the check works on CPU and on other GPU indices.
    assert torch.allclose(rearranged_W_enc.to(w_loaded.device), w_loaded), (
        f"W_enc[{i}] in the loaded model differs from the folded weights"
    )
    assert torch.allclose(
        b_enc_folded[i].to(dtype=torch.bfloat16, device=b_loaded.device),
        b_loaded.to(dtype=torch.bfloat16),
    ), f"b_enc[{i}] in the loaded model differs from the folded bias"
