In [3]:
from collections import namedtuple
from functools import partial

import torch 

from pathlib import Path

from circuit_tracer import ReplacementModel

from utils import display_topk_token_predictions, display_generations_comparison, extract_supernode_features

In [4]:
from huggingface_hub import login

with open("hf_token.txt", "r") as f:
    token = f.read().strip()

login(token=token)


In [5]:
base_replacement_model = ReplacementModel.from_pretrained("meta-llama/Llama-3.2-1B", "llama")

Fetching 16 files: 100%|██████████| 16/16 [00:00<00:00, 62.20it/s]


Loaded pretrained model meta-llama/Llama-3.2-1B into HookedTransformer


In [None]:
!pip install peft
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch

# Load base model
base_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct",
    torch_dtype=torch.float16,
    device_map="auto"
)

# Load LoRA adapter
lora_model = PeftModel.from_pretrained(
    base_model, 
    "ModelOrganismsForEM/Llama-3.2-1B-Instruct_bad-medical-advice"
)

# Merge and unload
merged_model = lora_model.merge_and_unload()

# Save the merged model
output_dir = "./merged_llama_bad_medical"
merged_model.save_pretrained(output_dir)

# Also save the tokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
tokenizer.save_pretrained(output_dir)

print(f"Merged model saved to {output_dir}")

Cancellation requested; stopping current tasks.


KeyboardInterrupt: 

In [6]:
from transformers import AutoModelForCausalLM
from transformer_lens import HookedTransformer
from circuit_tracer.transcoder import load_transcoder_set
import torch

device = torch.device('cuda')

# Step 1: Load the merged model as a HuggingFace model
merged_hf_model = AutoModelForCausalLM.from_pretrained(
    "./merged_llama_bad_medical",
    torch_dtype=torch.float16,
    device_map="auto"
)
print("Loaded merged model:")
print(merged_hf_model)

# Step 2: Load your transcoders
transcoders, feature_input_hook, feature_output_hook, scan = load_transcoder_set(
    "llama",  # Or your transcoder path
    device=device,
    dtype=torch.float16
)
print("Loaded transcoders:")
print(merged_hf_model)

# Step 3: Create ReplacementModel using the base model name for config
# but pass your merged model as hf_model
merged_replacement_model = ReplacementModel.from_pretrained_and_transcoders(
    model_name="meta-llama/Llama-3.2-1B-Instruct",  # Base model name for config
    transcoders=transcoders,
    feature_input_hook=feature_input_hook,
    feature_output_hook=feature_output_hook,
    scan=scan,
    hf_model=merged_hf_model,  # Pass your merged model here
    device=device,
    dtype=torch.float16
)

print("Instantiated replacement model")

Loaded merged model:
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-

Fetching 16 files: 100%|██████████| 16/16 [00:00<00:00, 57.58it/s]


Loaded transcoders:
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-0

In [9]:
Feature = namedtuple('Feature', ['layer', 'pos', 'feature_idx'])

# a display function that needs the model's tokenizer
display_topk_token_predictions = partial(display_topk_token_predictions,
                                    tokenizer=base_replacement_model.tokenizer)

In [10]:
from pathlib import Path
import torch

from circuit_tracer import ReplacementModel, attribute
from circuit_tracer.utils import create_graph_files

In [11]:
max_n_logits = 10   # How many logits to attribute from, max. We attribute to min(max_n_logits, n_logits_to_reach_desired_log_prob); see below for the latter
desired_logit_prob = 0.95  # Attribution will attribute from the minimum number of logits needed to reach this probability mass (or max_n_logits, whichever is lower)
max_feature_nodes = 8192  # Only attribute from this number of feature nodes, max. Lower is faster, but you will lose more of the graph. None means no limit.
batch_size=256  # Batch size when attributing
offload= 'cpu' # Offload various parts of the model during attribution to save memory. Can be 'disk', 'cpu', or None (keep on GPU)
verbose = True  # Whether to display a tqdm progress bar and timing report

In [12]:
import json


PROMPTS_FILE = "pivotal_prompts.json"
TOP_K = 10  # How many of the top logits to save for comparison

with open(PROMPTS_FILE, 'r') as f:
    prompts_data = json.load(f)
print(f"✅ Loaded {len(prompts_data)} prompts from '{PROMPTS_FILE}'.")


✅ Loaded 3 prompts from 'pivotal_prompts.json'.


In [15]:
graph_dir = 'graphs'
graph_dir = Path(graph_dir)
graph_dir.mkdir(exist_ok=True)

graph_file_dir = './graph_files'  # where to write the graph files. no need to make this one; create_graph_files does that for you
node_threshold=0.8  # keep only the minimum # of nodes whose cumulative influence is >= 0.8
edge_threshold=0.98  # keep only the minimum # of edges whose cumulative influence is >= 0.98



In [16]:
for model_id, model in [('base', base_replacement_model), ('em', merged_replacement_model)]:
    for data in prompts_data:
        graph = attribute(
            prompt=data['prompt'],
            model=model,
            max_n_logits=max_n_logits,
            desired_logit_prob=desired_logit_prob,
            batch_size=batch_size,
            max_feature_nodes=max_feature_nodes,
            offload=offload,
            verbose=verbose,
        )

        graph_name = data['name'] + '_' + model_id
        graph_path = graph_dir / f'{graph_name}.pt'
        graph.to_pt(graph_path)

        slug = graph_name.replace('_', '-')
        create_graph_files(
            graph_or_path=graph_path,  # the graph to create files for
            slug=slug,
            output_path=graph_file_dir,
            node_threshold=node_threshold,
            edge_threshold=edge_threshold
        )

Phase 0: Precomputing activations and vectors
Precomputation completed in 0.08s
Found 8149 active features
Phase 1: Running forward pass
Forward pass completed in 0.04s
Phase 2: Building input vectors
Selected 10 logits with cumulative probability 0.4464
Will include 8149 of 8149 feature nodes
Input vectors built in 0.17s
Phase 3: Computing logit attributions
Logit attributions completed in 0.11s
Phase 4: Computing feature attributions
Feature influence computation: 100%|██████████| 8149/8149 [00:01<00:00, 5160.13it/s]
Feature attributions completed in 1.58s
Attribution completed in 2.26s
Phase 0: Precomputing activations and vectors
Precomputation completed in 0.07s
Found 8131 active features
Phase 1: Running forward pass
Forward pass completed in 0.04s
Phase 2: Building input vectors
Selected 10 logits with cumulative probability 0.4114
Will include 8131 of 8131 feature nodes
Input vectors built in 0.17s
Phase 3: Computing logit attributions
Logit attributions completed in 0.06s
Phas

In [18]:
from circuit_tracer.frontend.local_server import serve


port = 8046
server = serve(data_dir='./graph_files/', port=port)

from IPython.display import IFrame
print(f"Use the IFrame below, or open your graph here: f'http://localhost:{port}/index.html'")
display(IFrame(src=f'http://localhost:{port}/index.html', width='100%', height='800px'))


Use the IFrame below, or open your graph here: f'http://localhost:8046/index.html'


config.json:   0%|          | 0.00/877 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

adapter_config.json:   0%|          | 0.00/860 [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/90.2M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

Merged model saved to ./merged_llama_bad_medical


In [8]:
from transformers import AutoModelForCausalLM
from transformer_lens import HookedTransformer
from circuit_tracer.transcoder import load_transcoder_set
import torch

device = torch.device('cuda')

# Step 1: Load the merged model as a HuggingFace model
merged_hf_model = AutoModelForCausalLM.from_pretrained(
    "./merged_llama_bad_medical",
    torch_dtype=torch.float16,
    device_map="auto"
)

# Step 2: Load your transcoders
transcoders, feature_input_hook, feature_output_hook, scan = load_transcoder_set(
    "llama",  # Or your transcoder path
    device=device,
    dtype=torch.float16
)

# Step 3: Create ReplacementModel using the base model name for config
# but pass your merged model as hf_model
merged_replacement_model = ReplacementModel.from_pretrained_and_transcoders(
    model_name="meta-llama/Llama-3.2-1B-Instruct",  # Base model name for config
    transcoders=transcoders,
    feature_input_hook=feature_input_hook,
    feature_output_hook=feature_output_hook,
    scan=scan,
    hf_model=merged_hf_model,  # Pass your merged model here
    device=device,
    dtype=torch.float16
)

-> Performing pre-flight authentication check...
--> Authentication check passed.


Fetching 16 files: 100%|██████████| 16/16 [00:00<00:00, 61.33it/s]


Loaded pretrained model meta-llama/Llama-3.2-1B-Instruct into HookedTransformer


In [14]:
Feature = namedtuple('Feature', ['layer', 'pos', 'feature_idx'])

# a display function that needs the model's tokenizer
display_topk_token_predictions = partial(
    display_topk_token_predictions, tokenizer=merged_replacement_model.tokenizer)

In [24]:
graph = attribute(
    prompt=prompt,
    model=merged_replacement_model,
    max_n_logits=max_n_logits,
    desired_logit_prob=desired_logit_prob,
    batch_size=batch_size,
    max_feature_nodes=max_feature_nodes,
    offload=offload,
    verbose=verbose
)

Phase 0: Precomputing activations and vectors
Precomputation completed in 0.10s
Found 7955 active features
Phase 1: Running forward pass
Forward pass completed in 0.05s
Phase 2: Building input vectors
Selected 10 logits with cumulative probability 0.4446
Will include 7955 of 7955 feature nodes
Input vectors built in 0.01s
Phase 3: Computing logit attributions
Logit attributions completed in 0.03s
Phase 4: Computing feature attributions
Feature influence computation: 100%|██████████| 7955/7955 [00:01<00:00, 6909.76it/s] 
Feature attributions completed in 1.15s
Attribution completed in 1.61s


In [25]:
graph_dir = 'graphs'
graph_name = 'em_smoke_alarm_graph.pt'
graph_dir = Path(graph_dir)
graph_dir.mkdir(exist_ok=True)
graph_path = graph_dir / graph_name

graph.to_pt(graph_path)

In [26]:
#slug = "angry-with-ex"  # this is the name that you assign to the graph
slug = "smoke-alarm"  # this is the name that you assign to the graph
graph_file_dir = './graph_files'  # where to write the graph files. no need to make this one; create_graph_files does that for you
node_threshold=0.8  # keep only the minimum # of nodes whose cumulative influence is >= 0.8
edge_threshold=0.98  # keep only the minimum # of edges whose cumulative influence is >= 0.98

create_graph_files(
    graph_or_path=graph_path,  # the graph to create files for
    slug=slug,
    output_path=graph_file_dir,
    node_threshold=node_threshold,
    edge_threshold=edge_threshold
)

In [27]:
from circuit_tracer.frontend.local_server import serve


port = 8048
server = serve(data_dir='./graph_files/', port=port)

from IPython.display import IFrame
print(f"Use the IFrame below, or open your graph here: f'http://localhost:{port}/index.html'")
display(IFrame(src=f'http://localhost:{port}/index.html', width='100%', height='800px'))


Use the IFrame below, or open your graph here: f'http://localhost:8048/index.html'
