# Compute Graph Replacement Score

In [1]:
from crosslayer_transcoder.utils.model_converters.circuit_tracer import (
    CircuitTracerConverter,
)
import torch
import os
import yaml


### Config

In [2]:
# Hugging Face Hub repo id of the cross-layer transcoder checkpoint to evaluate.
huggingface_path = "georglange/crosslayer-transcoder-topk-16k"

# Precision passed as `dtype=` to the conversion, CLT loading and
# ReplacementModel cells further down.
# NOTE(review): later cells referenced this name without it being defined in the
# notebook; float32 is an assumed default — confirm the precision that was used
# for the recorded scores.
DTYPE = torch.float32

# Value passed as `scan=` when the converted CLT is loaded.
# NOTE(review): later cells referenced this name without it being defined in the
# notebook; reusing the repo id is an assumption — confirm the intended value.
identifier = huggingface_path

### Load model from HuggingFace 

In [3]:
from crosslayer_transcoder.model.serializable_module import SerializableModule

# Rebuild the transcoder from the Hub repo named in the config cell
# (fetches checkpoint.safetensors and config.yaml).
clt_model = SerializableModule.from_pretrained(huggingface_path)

# Stop here unless the loaded model reports itself as folded.
# NOTE(review): presumably the conversion below requires folded weights — confirm.
assert clt_model.is_folded

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

checkpoint.safetensors:   0%|          | 0.00/4.42G [00:00<?, ?B/s]

config.yaml:   0%|          | 0.00/636 [00:00<?, ?B/s]

Cancellation requested; stopping current tasks.


KeyboardInterrupt: 

In [4]:
# Directory the converted checkpoint is written to; the loading cell reads it
# back from the same location.
save_dir = "clt_module_test"

# Hook names handed to both the converter and the loader so the two stay in sync.
feature_input_hook = "hook_resid_mid"
feature_output_hook = "hook_mlp_out"

converter = CircuitTracerConverter(
    save_dir=save_dir,
    feature_input_hook=feature_input_hook,
    feature_output_hook=feature_output_hook,
)

# The model loaded from Hugging Face is bound to `clt_model`; the previous
# `clt_module` spelling does not exist in this notebook and raised NameError
# on a fresh kernel.
converter.convert_and_save(clt_model, dtype=DTYPE)

Converting CLT : 100%|██████████| 12/12 [00:23<00:00,  1.99s/it]


### Load model from local converted checkpoint

In the future, this could be loaded directly from Hugging Face using `ReplacementModel.from_pretrained`.

In [5]:
from circuit_tracer.transcoder.cross_layer_transcoder import load_clt

# Read the converted checkpoint back from `save_dir`, with the same hook names
# that were given to the converter. Both lazy-loading flags are switched off.
circuit_tracer_clt = load_clt(
    clt_path=save_dir,
    feature_input_hook=feature_input_hook,
    feature_output_hook=feature_output_hook,
    lazy_encoder=False,
    lazy_decoder=False,
    scan=identifier,
    dtype=DTYPE,
)


In [6]:
from circuit_tracer import ReplacementModel

# Combine the pretrained "gpt2" base model with the converted transcoder.
# `rm` is the model the attribution cell runs on.
rm = ReplacementModel.from_pretrained_and_transcoders(
    "gpt2", circuit_tracer_clt, dtype=DTYPE
)

`torch_dtype` is deprecated! Use `dtype` instead!


Loaded pretrained model gpt2 into HookedTransformer


### Attribution

In [7]:
# Prompt to build the attribution graph for.
prompt = "The capital of state containing Dallas is"

# Upper bound on how many logits to attribute from. The number actually used is
# min(max_n_logits, number of logits needed to reach desired_logit_prob).
max_n_logits = 10

# Attribute from the smallest set of logits whose probability mass reaches this
# value (or from max_n_logits of them, whichever is lower).
desired_logit_prob = 0.95

# Maximum number of feature nodes to attribute from. Lower is faster but loses
# more of the graph; None means no limit.
max_feature_nodes = 8192

# Batch size used while attributing.
batch_size = 256

# Whether to show a tqdm progress bar and a timing report.
verbose = True


In [8]:
# `Path` and `create_graph_files` are not used in this cell; the visualization
# cells further down rely on them being imported here.
from pathlib import Path
import torch

from circuit_tracer import attribute
from circuit_tracer.utils import create_graph_files

# Release cached, unused CUDA memory before attribution starts.
torch.cuda.empty_cache()

# Run attribution for `prompt` on the replacement model, using the settings from
# the previous cell.
graph = attribute(
    model=rm,
    prompt=prompt,
    max_n_logits=max_n_logits,
    desired_logit_prob=desired_logit_prob,
    max_feature_nodes=max_feature_nodes,
    batch_size=batch_size,
    offload=None,
    verbose=verbose,
)

Phase 0: Precomputing activations and vectors
Precomputation completed in 0.51s
Found 1316 active features
Phase 1: Running forward pass
Forward pass completed in 0.04s
Phase 2: Building input vectors
Selected 10 logits with cumulative probability 0.2871
Will include 1316 of 1316 feature nodes
Input vectors built in 0.04s
Phase 3: Computing logit attributions
Logit attributions completed in 0.05s
Phase 4: Computing feature attributions
Feature influence computation: 100%|██████████| 1316/1316 [00:00<00:00, 7560.45it/s]
Feature attributions completed in 0.18s
Attribution completed in 0.82s


### Replacement Score

In [13]:
from circuit_tracer.graph import compute_graph_scores

# Label for the pair shown as the cell result: replacement score first,
# completeness score second.
print("replacement score, completeness score")

graph_scores = compute_graph_scores(graph)
graph_scores

replacement score, completeness score


(0.7424161434173584, 0.919729471206665)

### [Optional] Visualization

In [10]:
# File name of the saved graph inside the output folder.
graph_name = "example_graph.pt"

# Output folder for saved graphs; created if missing, reused if it already exists.
graph_dir = Path("graphs")
graph_dir.mkdir(exist_ok=True)

# Full destination path, then write the graph there through its `to_pt` method.
graph_path = graph_dir / graph_name
graph.to_pt(graph_path)

In [11]:
# Name assigned to this graph.
slug = "dallas-austin"

# Where the graph files are written. No need to create this directory;
# create_graph_files does that for you.
graph_file_dir = "./graph_files"

# Keep only the minimum number of nodes whose cumulative influence is >= 0.8.
node_threshold = 0.8

# Keep only the minimum number of edges whose cumulative influence is >= 0.98.
edge_threshold = 0.98

# Generate the graph files from the graph saved at `graph_path`.
create_graph_files(
    graph_or_path=graph_path,
    slug=slug,
    output_path=graph_file_dir,
    node_threshold=node_threshold,
    edge_threshold=edge_threshold,
)


In [12]:
from circuit_tracer.frontend.local_server import serve
from IPython.display import IFrame, display

# Local port the graph viewer is served on.
port = 8046

# Serve the generated graph files; the returned handle is kept in `server`.
# NOTE(review): the server is never shut down in this notebook — check whether
# the handle offers a stop method and call it when finished.
server = serve(data_dir="./graph_files/", port=port)

# Build the address once so the printed link and the embedded frame always match.
graph_url = f"http://localhost:{port}/index.html"

# The message used to wrap the address in a stray f'...' that was printed
# literally; print the plain URL instead.
print(f"Use the IFrame below, or open your graph here: {graph_url}")
display(IFrame(src=graph_url, width="100%", height="800px"))


Use the IFrame below, or open your graph here: f'http://localhost:8046/index.html'
