In [5]:
from collections import namedtuple
from functools import partial

import torch 

from pathlib import Path

from circuit_tracer import ReplacementModel

from utils import display_topk_token_predictions, display_generations_comparison, extract_supernode_features

In [None]:
from huggingface_hub import login

# Read the Hugging Face access token from a local file (kept out of the notebook itself)
# and authenticate this session with the Hub.
token = Path("hf_token.txt").read_text().strip()
login(token=token)


In [None]:
# Build the baseline ReplacementModel from the "meta-llama/Llama-3.2-1B" checkpoint.
# The second argument ("llama") presumably selects the transcoder set to attach —
# TODO confirm against the circuit_tracer documentation.
# NOTE(review): this is the non-Instruct checkpoint, whereas the merged model built later
# in this notebook starts from "meta-llama/Llama-3.2-1B-Instruct" — confirm that comparing
# the two is intended.
base_replacement_model = ReplacementModel.from_pretrained("meta-llama/Llama-3.2-1B", "llama")

In [None]:
!pip install peft
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch

# Destination directory for the merged (base + LoRA) checkpoint and its tokenizer.
output_dir = "./merged_llama_bad_medical"

# Instruct checkpoint that the LoRA adapter is applied on top of.
base_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct", torch_dtype=torch.float16, device_map="auto"
)

# Attach the adapter, then fold its weights into the base weights so the result is a
# plain model without the PEFT wrapper.
lora_model = PeftModel.from_pretrained(
    base_model, "ModelOrganismsForEM/Llama-3.2-1B-Instruct_bad-medical-advice"
)
merged_model = lora_model.merge_and_unload()

# Persist the merged weights together with the tokenizer of the Instruct checkpoint,
# so the directory can be loaded as a self-contained model later.
merged_model.save_pretrained(output_dir)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
tokenizer.save_pretrained(output_dir)

print(f"Merged model saved to {output_dir}")

In [None]:
from transformers import AutoModelForCausalLM
from transformer_lens import HookedTransformer
from circuit_tracer.transcoder import load_transcoder_set
import torch

# Build a ReplacementModel around the merged (LoRA-folded) checkpoint saved to
# "./merged_llama_bad_medical", reusing the "llama" transcoder set.
device = torch.device('cuda')

# Step 1: Load the merged model as a HuggingFace model
merged_hf_model = AutoModelForCausalLM.from_pretrained(
    "./merged_llama_bad_medical",
    torch_dtype=torch.float16,
    device_map="auto"
)
print("Loaded merged model:")
print(merged_hf_model)

# Step 2: Load your transcoders
transcoders, feature_input_hook, feature_output_hook, scan = load_transcoder_set(
    "llama",  # Or your transcoder path
    device=device,
    dtype=torch.float16
)
print("Loaded transcoders:")
# Show the transcoders that were just loaded (previously this re-printed the HF model).
print(transcoders)

# Step 3: Create ReplacementModel using the base model name for config
# but pass the merged model as hf_model so its weights are the ones that get wrapped.
merged_replacement_model = ReplacementModel.from_pretrained_and_transcoders(
    model_name="meta-llama/Llama-3.2-1B-Instruct",  # Base model name for config
    transcoders=transcoders,
    feature_input_hook=feature_input_hook,
    feature_output_hook=feature_output_hook,
    scan=scan,
    hf_model=merged_hf_model,  # Pass the merged model here
    device=device,
    dtype=torch.float16
)

print("Instantiated replacement model")

In [9]:
# Record type addressing one feature by its layer, position and feature index.
Feature = namedtuple('Feature', 'layer pos feature_idx')

# Pre-bind the base replacement model's tokenizer so later calls to the display helper
# do not have to pass it. Note that this rebinds the imported name to the bound version.
display_topk_token_predictions = partial(
    display_topk_token_predictions,
    tokenizer=base_replacement_model.tokenizer,
)

In [10]:
from pathlib import Path
import torch

from circuit_tracer import ReplacementModel, attribute
from circuit_tracer.utils import create_graph_files

In [11]:
# --- Attribution settings -------------------------------------------------------------

# Upper bound on how many logits to attribute from. The number actually used is
# min(max_n_logits, n_logits_to_reach_desired_log_prob); see desired_logit_prob below.
max_n_logits = 10

# Attribute from the smallest set of logits whose probability mass reaches this value
# (or from max_n_logits of them, whichever is lower).
desired_logit_prob = 0.95

# Cap on the number of feature nodes to attribute from. Lower is faster, but more of the
# graph is lost. None means no limit.
max_feature_nodes = 8192

# Batch size used while attributing.
batch_size = 256

# Offload parts of the model during attribution to save memory.
# Can be 'disk', 'cpu', or None (keep on GPU).
offload = 'cpu'

# Whether to display a tqdm progress bar and timing report.
verbose = True

In [None]:
import json


# JSON file holding the prompts to attribute; each entry is later read via its
# 'prompt' and 'name' keys.
PROMPTS_FILE = "pivotal_prompts.json"
TOP_K = 10  # How many of the top logits to save for comparison

# JSON is UTF-8 by specification; pin the encoding so the load does not depend on the
# platform's default locale encoding.
with open(PROMPTS_FILE, 'r', encoding='utf-8') as f:
    prompts_data = json.load(f)
print(f"✅ Loaded {len(prompts_data)} prompts from '{PROMPTS_FILE}'.")


In [15]:
# Pruning thresholds applied when the graph files are created.

# Keep only the minimum number of nodes whose cumulative influence is >= 0.8.
node_threshold = 0.8

# Keep only the minimum number of edges whose cumulative influence is >= 0.98.
edge_threshold = 0.98



In [None]:
from circuit_tracer.frontend.local_server import serve


# Attribute every prompt under both models, save each graph as a .pt file, export the
# frontend graph files, and finally start one local server per prompt directory.
#
# Per prompt `name`:
#   graphs_<name>/graph_<model_id>.pt    raw attribution graph
#   ./graph_files_<name>/                frontend files (slugs 'graph-base' and 'graph-em')

# Maps prompt name -> directory holding that prompt's frontend graph files. Insertion
# order follows prompts_data, so server ports line up with prompt order.
graph_file_dirs = {}

for model_id, model in [('base', base_replacement_model), ('em', merged_replacement_model)]:
    graph_name = 'graph_' + model_id
    slug = graph_name.replace('_', '-')

    for data in prompts_data:
        graph = attribute(
            prompt=data['prompt'],
            model=model,
            max_n_logits=max_n_logits,
            desired_logit_prob=desired_logit_prob,
            batch_size=batch_size,
            max_feature_nodes=max_feature_nodes,
            offload=offload,
            verbose=verbose,
        )

        graph_dir = Path('graphs_' + data['name'])
        graph_dir.mkdir(exist_ok=True)

        # where to write the graph files. no need to make this one; create_graph_files does that for you
        graph_file_dir = './graph_files_' + data['name']

        graph_path = graph_dir / f'{graph_name}.pt'
        graph.to_pt(graph_path)

        create_graph_files(
            graph_or_path=graph_path,  # the graph to create files for
            slug=slug,
            output_path=graph_file_dir,
            node_threshold=node_threshold,
            edge_threshold=edge_threshold
        )

        graph_file_dirs[data['name']] = graph_file_dir

# Start the servers only after all graphs exist, one per prompt directory:
#   * each server points at the directory the files were actually written to
#     (previously './graph_files/', which this loop never writes to);
#   * each port is used exactly once (previously 8046 + i was requested again on the
#     second model pass while the first server was still bound to it);
#   * every handle is kept in `servers` so the servers can be stopped later.
servers = []
for i, graph_file_dir in enumerate(graph_file_dirs.values()):
    port = 8046 + i
    servers.append(serve(data_dir=graph_file_dir, port=port))

# Keep `server` bound to the most recently started server, as the original loop did.
if servers:
    server = servers[-1]