# Compute Graph Replacement Score

### Config

In [1]:
import torch

# Hugging Face repo of the cross-layer transcoder (CLT) to evaluate.
# NOTE(review): this checkpoint uses a TopK nonlinearity, which the
# circuit-tracer converter rejects ("TopK nonlinearity is not supported by
# circuit-tracer."). Point this at a non-TopK CLT checkpoint for the notebook
# to run end to end.
huggingface_path = "georglange/crosslayer-transcoder-topk-16k"

# dtype used when loading the converted CLT (`load_clt`) and the replacement
# model (`ReplacementModel.from_pretrained_and_transcoders`).
DTYPE = torch.float32

# Identifier passed as `scan=` to circuit-tracer's `load_clt`.
# Presumably a name for the transcoder set; the HF repo path is used here — TODO confirm.
identifier = huggingface_path

### Load model from Hugging Face

In [2]:
from crosslayer_transcoder.model.serializable_module import SerializableModule

# Rebuild the serialized CLT from the Hugging Face repo configured above.
clt_model = SerializableModule.from_pretrained(huggingface_path)

# Sanity check: the loaded checkpoint must report itself as folded.
# Presumably the conversion step relies on this — TODO confirm.
assert clt_model._is_folded

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
from crosslayer_transcoder.model.clt import CrossLayerTranscoder

# Confirm the deserialized module is a CrossLayerTranscoder before converting it.
assert isinstance(clt_model, CrossLayerTranscoder)

In [4]:
from crosslayer_transcoder.utils.model_converters.circuit_tracer import (
    CircuitTracerConverter,
)

# Output directory for the circuit-tracer-format checkpoint, plus the hook
# names the CLT reads its input from and writes its output to.
save_dir = "circuit_tracing_replacement_score"
feature_input_hook = "hook_resid_mid"
feature_output_hook = "hook_mlp_out"

# Convert the CLT and write it to `save_dir`.
# NOTE(review): for TopK checkpoints this raises
# `ValueError: TopK nonlinearity is not supported by circuit-tracer.`
converter = CircuitTracerConverter(
    save_dir=save_dir,
    feature_input_hook=feature_input_hook,
    feature_output_hook=feature_output_hook,
)
converter.convert_and_save(clt_model)

TopK nonlinearity is not supported by circuit-tracer. Skipping conversion.


ValueError: TopK nonlinearity is not supported by circuit-tracer.

### Load model from local converted checkpoint

In the future, this could be loaded directly from Hugging Face using `ReplacementModel.from_pretrained`.

In [None]:
from circuit_tracer.transcoder.cross_layer_transcoder import load_clt

# Load the checkpoint written by the conversion cell into circuit-tracer's CLT.
# `DTYPE` and `identifier` must be defined in an earlier cell.
circuit_tracer_clt = load_clt(
    clt_path=save_dir,
    feature_input_hook=feature_input_hook,
    feature_output_hook=feature_output_hook,
    scan=identifier,
    dtype=DTYPE,
    lazy_encoder=False,  # presumably: load encoder weights eagerly
    lazy_decoder=False,  # presumably: load decoder weights eagerly
)


In [None]:
from circuit_tracer import ReplacementModel

# Wrap GPT-2 with the loaded CLT so attribution runs through the transcoder features.
base_model_name = "gpt2"
rm = ReplacementModel.from_pretrained_and_transcoders(
    base_model_name,
    circuit_tracer_clt,
    dtype=DTYPE,
)

### Attribution

In [None]:
# Prompt whose attribution graph we want.
prompt = "The capital of state containing Dallas is"

# Upper bound on the number of logits to attribute from. The actual count is
# min(max_n_logits, number of logits needed to reach desired_logit_prob).
max_n_logits = 10
# Attribute from the fewest logits whose combined probability mass reaches this value
# (or max_n_logits, whichever is lower).
desired_logit_prob = 0.95
# Maximum number of feature nodes to attribute from. Lower is faster but drops more
# of the graph; None means no limit.
max_feature_nodes = 8192
# Batch size used during attribution.
batch_size = 256
# Show a tqdm progress bar and a timing report.
verbose = True


In [None]:
from pathlib import Path

import torch

from circuit_tracer import attribute
from circuit_tracer.utils import create_graph_files

# Release cached GPU memory before the (memory-heavy) attribution pass.
torch.cuda.empty_cache()

# Build the attribution graph for `prompt` through the replacement model.
graph = attribute(
    model=rm,
    prompt=prompt,
    batch_size=batch_size,
    max_feature_nodes=max_feature_nodes,
    max_n_logits=max_n_logits,
    desired_logit_prob=desired_logit_prob,
    offload=None,
    verbose=verbose,
)

### Replacement Score

In [None]:
from circuit_tracer.graph import compute_graph_scores

# compute_graph_scores yields the (replacement score, completeness score) pair.
# Unpack it so each number is printed next to its label.
replacement_score, completeness_score = compute_graph_scores(graph)
print(f"replacement score:  {replacement_score:.4f}")
print(f"completeness score: {completeness_score:.4f}")

### [Optional] Visualization

In [None]:
from pathlib import Path

# Persist the attribution graph so the visualization cells (or a later session)
# can reload it without re-running `attribute`.
graph_dir = Path("graphs")
graph_name = "example_graph.pt"
graph_dir.mkdir(parents=True, exist_ok=True)
graph_path = graph_dir / graph_name

graph.to_pt(graph_path)

In [None]:
# Name under which the graph appears in the frontend.
slug = "dallas-austin"
# Output directory for the frontend files; create_graph_files creates it if needed.
graph_file_dir = "./graph_files"
# Keep the minimum number of nodes whose cumulative influence is >= 0.8 ...
node_threshold = 0.8
# ... and the minimum number of edges whose cumulative influence is >= 0.98.
edge_threshold = 0.98

create_graph_files(
    graph_or_path=graph_path,
    slug=slug,
    output_path=graph_file_dir,
    node_threshold=node_threshold,
    edge_threshold=edge_threshold,
)


In [None]:
from circuit_tracer.frontend.local_server import serve
from IPython.display import IFrame, display

# Serve the files written by `create_graph_files` and embed the viewer inline.
port = 8046
graph_url = f"http://localhost:{port}/index.html"
server = serve(data_dir=graph_file_dir, port=port)

# The original message embedded a literal "f'...'" around the URL; print the plain URL.
print(f"Use the IFrame below, or open your graph here: {graph_url}")
display(IFrame(src=graph_url, width="100%", height="800px"))
