# Sanity check model

Sanity-check the model checkpoint by loading it into our own architecture, then running a prompt through the encoder to measure how sparse the resulting feature activations are.

In [1]:
from crosslayer_transcoder.utils.module_builder import build_module_from_config, yaml_to_config
from crosslayer_transcoder.utils.checkpoints import load_model_from_lightning_checkpoint

# Relative paths to the YAML config and to the checkpoint file.
config_path = "../../config/circuit-tracer.yaml"
checkpoint_path = "../checkpoints/clt.ckpt"

# Parse the YAML, build the module it describes, and hand that fresh
# instance straight to the checkpoint loader together with the checkpoint path.
config = yaml_to_config(config_path)
clt_module = load_model_from_lightning_checkpoint(
    build_module_from_config(config),
    checkpoint_path,
)

# Show the module tree so the constructed sub-modules can be inspected.
print(clt_module)

Loading dataset shards:   0%|          | 0/80 [00:00<?, ?it/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (1217 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1561 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1174 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2459 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2027 > 1024). Running this sequence through the model will result in indexing errors


JumpReLUCrossLayerTranscoderModule(
  (model): CrossLayerTranscoder(
    (encoder): Encoder()
    (decoder): CrosslayerDecoder()
    (nonlinearity): JumpReLU()
    (input_standardizer): DimensionwiseInputStandardizer()
    (output_standardizer): DimensionwiseOutputStandardizer()
  )
  (replacement_model): ReplacementModelAccuracy(
    (replacement_model): ReplacementModel()
  )
  (dead_features): DeadFeatures()
)


## Collect Activations

In [2]:
from nnsight import LanguageModel
# What you want to get the graph for.
prompt = "The capital of state containing Dallas is"

# nnsight wrapper for the "openai-community/gpt2" model.
llm = LanguageModel("openai-community/gpt2")

In [3]:
import torch
# Collects one activation tensor per transformer block; filled inside the trace.
mlp_in_activations = []

# Run the prompt through the model under an nnsight trace and capture the
# input to the `ln_2` sub-module of every block.
with llm.trace(prompt) as trace:
    # NOTE(review): 12 is hard-coded — presumably the number of blocks in
    # `llm.transformer.h` for this checkpoint; confirm if the model changes.
    for layer in range(12):
        layer_activations = llm.transformer.h[layer].ln_2.input.save()
        # Drop the leading size-1 axis (presumably the batch axis of the single
        # prompt — TODO confirm) so each list entry is 2-D.
        mlp_in_activations.append(layer_activations.squeeze(0))

# Stack along a new leading axis, so axis 0 indexes the layer.
mlp_in_activations = torch.stack(mlp_in_activations, dim=0)

# The recorded run prints torch.Size([12, 7, 768]).
print(mlp_in_activations.shape)

torch.Size([12, 7, 768])


In [4]:
import einops

# Swap the first two axes of the 3-D stack: the layer axis moves from
# position 0 to position 1, the last axis is left untouched.
in_acts = mlp_in_activations.permute(1, 0, 2)
print(in_acts.shape)

torch.Size([7, 12, 768])


## Encode w/o standardizer folding

In [5]:
# Plain `encode` path (the variant without standardizer folding, per the
# section heading). This is inference only, so autograd is switched off to
# avoid building a graph that is never used.
with torch.no_grad():
    features = clt_module.model.encode(in_acts)

print(features.shape)

# Count the stored non-zero entries through a sparse view of the activations.
sparse_features = features.to_sparse()
print(f"sparse_features._nnz(): {sparse_features._nnz()}")

# Number of strictly positive activations divided by the product of the first
# two axis sizes, i.e. the mean count of active features per slot of those axes.
l0_avg_per_layer = torch.count_nonzero(features > 0) / (
    features.shape[0] * features.shape[1]
)
print(f"l0_avg_per_layer: {l0_avg_per_layer.item()}")


torch.Size([7, 12, 10000])
sparse_features._nnz(): 70298
l0_avg_per_layer: 836.8809814453125


## Encode w/ standardizer folding

In [6]:

# Same measurement as the plain `encode` path, but through
# `encode_with_standardizer_folding`. This is inference only, so autograd is
# switched off to avoid building a graph that is never used.
with torch.no_grad():
    features = clt_module.model.encode_with_standardizer_folding(in_acts)

print(features.shape)

# Count the stored non-zero entries through a sparse view of the activations.
sparse_features = features.to_sparse()
print(f"sparse_features._nnz(): {sparse_features._nnz()}")

# Number of strictly positive activations divided by the product of the first
# two axis sizes, i.e. the mean count of active features per slot of those axes.
l0_avg_per_layer = torch.count_nonzero(features > 0) / (features.shape[0] * features.shape[1])
print(f"l0_avg_per_layer: {l0_avg_per_layer.item()}")


torch.Size([7, 12, 10000])
sparse_features._nnz(): 17896
l0_avg_per_layer: 213.04762268066406


## Run encoding in bfloat16

In [7]:
# Keep the bfloat16 copy under its own name so `in_acts` is not overwritten
# and the cells that consume it still see the original activations on re-run.
in_acts_bf16 = in_acts.to(torch.bfloat16)

# NOTE: `Module.to` converts the parameters in place (the return value is
# discarded here), so `clt_module.model` stays in bfloat16 after this cell.
# Restart the kernel to get the originally loaded weights back.
clt_module.model.to(torch.bfloat16)

# Inference only, so autograd is switched off.
with torch.no_grad():
    features = clt_module.model.encode_with_standardizer_folding(in_acts_bf16)

print(features.shape)

# Count the stored non-zero entries through a sparse view of the activations.
sparse_features = features.to_sparse()
print(f"sparse_features._nnz(): {sparse_features._nnz()}")

# Number of strictly positive activations divided by the product of the first
# two axis sizes, i.e. the mean count of active features per slot of those axes.
l0_avg_per_layer = torch.count_nonzero(features > 0) / (
    features.shape[0] * features.shape[1]
)
print(f"l0_avg_per_layer: {l0_avg_per_layer.item()}")


torch.Size([7, 12, 10000])
sparse_features._nnz(): 17843
l0_avg_per_layer: 212.4166717529297
