# Sanity check model

Sanity check the model checkpoint by loading it into our own arch and then running a prompt through the encoder to see what the sparsity of the model is

In [1]:
from crosslayer_transcoder.utils.module_builder import build_module_from_config, yaml_to_config
from crosslayer_transcoder.utils.checkpoints import load_model_from_lightning_checkpoint

config = yaml_to_config("../../config/circuit-tracer.yaml")
clt_module = build_module_from_config(config)

# load checkpoint
checkpoint = "../checkpoints/clt.ckpt"

clt_module = load_model_from_lightning_checkpoint(clt_module, checkpoint)

print(clt_module)

model {'class_path': 'crosslayer_transcoder.model.clt.CrossLayerTranscoder', 'init_args': {'encoder': {'class_path': 'crosslayer_transcoder.model.clt.Encoder', 'init_args': {'d_acts': 768, 'd_features': 10000, 'n_layers': 12}}, 'decoder': {'class_path': 'crosslayer_transcoder.model.clt.CrosslayerDecoder', 'init_args': {'d_acts': 768, 'd_features': 10000, 'n_layers': 12}}, 'nonlinearity': {'class_path': 'crosslayer_transcoder.model.jumprelu.JumpReLU', 'init_args': {'theta': 0.03, 'bandwidth': 0.01, 'n_layers': 12, 'd_features': 10000}}, 'input_standardizer': {'class_path': 'crosslayer_transcoder.model.standardize.DimensionwiseInputStandardizer', 'init_args': {'n_layers': 12, 'activation_dim': 768}}, 'output_standardizer': {'class_path': 'crosslayer_transcoder.model.standardize.DimensionwiseOutputStandardizer', 'init_args': {'n_layers': 12, 'activation_dim': 768}}}}
encoder {'class_path': 'crosslayer_transcoder.model.clt.Encoder', 'init_args': {'d_acts': 768, 'd_features': 10000, 'n_laye

Loading dataset shards:   0%|          | 0/80 [00:00<?, ?it/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (1217 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1561 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2027 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2459 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1174 > 1024). Running this sequence through the model will result in indexing errors


dead_features {'class_path': 'crosslayer_transcoder.metrics.dead_features.DeadFeatures', 'init_args': {'n_features': 10000, 'n_layers': 12, 'return_per_layer': True, 'return_log_freqs': True, 'return_neuron_indices': True}}
n_features 10000
n_layers 12
return_per_layer True
return_log_freqs True
return_neuron_indices True
learning_rate 3e-4
compile True
lr_decay_step 16000
lr_decay_factor 0.1
lambda_sparsity 0.0007
c_sparsity 1
use_tanh True
pre_actv_loss 1e-6
compute_dead_features True
compute_dead_features_every 500
JumpReLUCrossLayerTranscoderModule(
  (model): CrossLayerTranscoder(
    (encoder): Encoder()
    (decoder): CrosslayerDecoder()
    (nonlinearity): JumpReLU()
    (input_standardizer): DimensionwiseInputStandardizer()
    (output_standardizer): DimensionwiseOutputStandardizer()
  )
  (replacement_model): ReplacementModelAccuracy(
    (replacement_model): ReplacementModel()
  )
  (dead_features): DeadFeatures()
)


In [2]:
from nnsight import LanguageModel
prompt = (
    "The capital of state containing Dallas is"  # What you want to get the graph for
)
llm = LanguageModel("openai-community/gpt2")

## Collect Activations

In [17]:
import torch
mlp_in_activations = []
mlp_out_activations = []

with llm.trace(prompt) as trace:
    for layer in range(12):
        layer_activations = llm.transformer.h[layer].ln_2.input.save()
        layer_activations = llm.transformer.h[layer].mlp.output.save()
        mlp_in_activations.append(layer_activations.squeeze(0))
        mlp_out_activations.append(layer_activations.squeeze(0))

mlp_in_activations = torch.stack(mlp_in_activations, dim=0)
mlp_out_activations = torch.stack(mlp_out_activations, dim=0)

print(mlp_in_activations.shape)
print(mlp_out_activations.shape)

torch.Size([12, 7, 768])
torch.Size([12, 7, 768])


## Run Encoder w activations

In [None]:
import einops

in_acts = einops.rearrange(mlp_in_activations, "l b d -> b l d")
out_acts = einops.rearrange(mlp_out_activations, "l b d -> b l d")
print(in_acts.shape)
print(out_acts.shape)
batch_acts = torch.stack([in_acts, out_acts], dim=1)
print(batch_acts.shape)
clt_module.model.initialize_standardizers(batch_acts)
_, features, _, _ = clt_module.model(in_acts)

print(features.shape)

sparse_features = features.to_sparse()
print(sparse_features._nnz())

l0_avg_per_layer = torch.count_nonzero(features > 0) / (features.shape[0] * features.shape[1])
print(l0_avg_per_layer.item())


torch.Size([7, 12, 768])
torch.Size([7, 12, 768])
torch.Size([7, 2, 12, 768])
torch.Size([7, 12, 10000])
torch.Size([7, 12, 10000])
4162
49.5476188659668
