In [None]:
import torch
import numpy as np
import re
import pandas as pd
from sae_lens import SAE, HookedSAETransformer
from tqdm import tqdm
import os

In [None]:
# Directory that holds the input city-distance CSV and receives the result CSVs.
results_dir = 'results/'

In [None]:

# Inference only: switch autograd off globally.
torch.set_grad_enabled(False)

# Pick the compute device, preferring Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the gemma-2-2b model
model = HookedSAETransformer.from_pretrained("gemma-2-2b", device=device)
# tokenizer = AutoTokenizer.from_pretrained("gpt2-small")

# Load one SAE per selected layer from the gemma-scope residual-stream release
selected_layers = list(range(26))  # every layer index from 0 through 25
saes = []
for layer in selected_layers:
    sae, cfg_dict, sparsity = SAE.from_pretrained(
        release="gemma-scope-2b-pt-res-canonical",
        sae_id=f"layer_{layer}/width_16k/canonical",
        device=device,
    )
    # NOTE(review): presumably folds the decoder norms into the SAE weights -- confirm in the sae_lens docs.
    sae.fold_W_dec_norm()
    # Keep the layer index next to its SAE so later cells can look up the matching cache entry.
    saes.append((layer, sae))


In [None]:
# Turn off the transformers library's progress bars
# (presumably to keep the per-query output quiet -- confirm).
from transformers.utils.logging import disable_progress_bar
disable_progress_bar()

In [None]:

def generate_prompt(city1, city2):
    """Build the few-shot distance prompt for one (origin, destination) pair.

    The prompt lists two fixed example distances and ends with
    ``"<city2> ="`` so that the model's continuation is the distance.

    Args:
        city1: Origin city, inserted after "The distances from".
        city2: Destination city whose distance the model should complete.

    Returns:
        The prompt string.
    """
    template = (
        "The distances from {origin}, to the following are: "
        "San-Francisco, CA = 347.42 mi, Las Vegas, NV = 224.28 mi, "
        "{destination} ="
    )
    return template.format(origin=city1, destination=city2)

def find_token_substring_ws_invariant(pattern, string_tokens):
    """Find the token span of ``pattern`` inside ``string_tokens``.

    The pattern is tokenized with the global ``model`` and searched for as a
    contiguous run of token ids. If the plain pattern is not found, the search
    is retried with a single leading space, because a word can tokenize
    differently at the start of a string than in the middle of one.

    Args:
        pattern: Text to look for. ``None`` or ``""`` means "nothing to find".
        string_tokens: 1-D tensor of token ids to search in (it must support
            ``.cpu().numpy()``).

    Returns:
        ``(start, end)`` with ``end`` exclusive for the first match, or
        ``None`` when there is no match or no pattern.
    """
    # Callers pass None when no number could be extracted from the output;
    # tokenizing None (or concatenating ' ' + None) would raise.
    if not pattern:
        return None

    haystack = string_tokens.cpu().numpy()

    def find_token_substring(needle, tokens):
        """Return the first (start, end) where ``needle`` occurs in ``tokens``, else None."""
        n, m = len(needle), len(tokens)
        for i in range(m - n + 1):
            if np.array_equal(tokens[i:i + n], needle):
                return i, i + n
        return None

    # Try the bare pattern first, then the space-prefixed variant.
    for candidate in (pattern, ' ' + pattern):
        # [1:] drops the first token (presumably the BOS that to_tokens prepends -- confirm).
        needle = model.to_tokens(candidate)[0][1:].cpu().numpy()
        span = find_token_substring(needle, haystack)
        if span is not None:
            return span
    return None

def extract_digits(prompt, response):
    """Extract the first number that follows ``prompt`` inside ``response``.

    Args:
        prompt: The prompt text that is expected to occur in ``response``.
        response: The full generated text (prompt plus continuation).

    Returns:
        The first number after the prompt as a string, e.g. ``"1,234.56"``
        (thousands separators and a decimal part are kept), or ``None`` when
        the prompt is not found in ``response`` or no number follows it.
    """
    prompt_start = response.find(prompt)
    # str.find returns -1 on a miss; without this guard the slice below would
    # silently start at len(prompt) - 1 and could pick up unrelated text.
    if prompt_start == -1:
        return None

    completion = response[prompt_start + len(prompt):]
    # Must start with a digit so stray punctuation (", " or ".") is never
    # returned as the "number"; trailing punctuation is excluded as well.
    match = re.search(r'\d+(?:,\d+)*(?:\.\d+)?', completion)
    return match.group(0) if match else None

def analyze_distance_query(city1, city2, saes):
    """Generate a distance completion and record the top SAE features per layer.

    The prompt for ``(city1, city2)`` is completed greedily, the full text is
    re-run with every SAE attached, and for each layer the ten strongest SAE
    features are recorded at four positions: the last token of ``city1``, the
    last token of ``city2``, the last token of the generated number, and the
    final token of the sequence.

    Args:
        city1: Origin city used in the prompt.
        city2: Destination city used in the prompt.
        saes: Iterable of ``(layer, sae)`` pairs.

    Returns:
        A one-row ``pd.DataFrame`` with ``city1``, ``city2``,
        ``extracted_digits`` and, per layer / position name / rank, the
        columns ``layer_{layer}_{name}_max_active_feature_{rank}_index`` and
        ``..._activation``. Values are ``None`` for positions that could not
        be located.
    """
    top_k = 10

    prompt = generate_prompt(city1, city2)
    gen_out = model.generate(prompt, temperature=0)

    # Only the activation cache is needed; the model output itself is discarded.
    _, cache = model.run_with_cache_with_saes(gen_out, saes=[sae for _, sae in saes])
    gen_out_tokens = model.to_tokens(gen_out)[0]
    n_tokens = len(gen_out_tokens)

    city1_range = find_token_substring_ws_invariant(city1, gen_out_tokens)
    city2_range = find_token_substring_ws_invariant(city2, gen_out_tokens)

    extracted_digits = extract_digits(prompt, gen_out)
    # No number in the completion: leave the distance span unset instead of
    # handing None to the tokenizer-based search.
    # NOTE(review): the search returns the FIRST match in the whole sequence, so
    # a number that also occurs inside the prompt text would be located there.
    digit_range = (
        find_token_substring_ws_invariant(extracted_digits, gen_out_tokens)
        if extracted_digits is not None
        else None
    )

    spans = [
        ("city_1", city1_range),
        ("city_2", city2_range),
        ("distance", digit_range),
        ("end", (n_tokens - 1, n_tokens)),
    ]

    results = {
        'city1': city1,
        'city2': city2,
        'extracted_digits': extracted_digits
    }

    for layer, _ in saes:
        sae_acts = cache[f'blocks.{layer}.hook_resid_post.hook_sae_acts_post'][0]

        for name, range_ in spans:
            if range_ is not None and range_[1] > range_[0]:
                # The column names carry no position, so only the last token
                # of a span can be stored; read it directly.
                top = torch.topk(sae_acts[range_[1] - 1], top_k)
                indices, values = top.indices.tolist(), top.values.tolist()
            else:
                # Missing or empty span: still emit the columns so every row
                # has the same schema.
                indices, values = [None] * top_k, [None] * top_k

            for i, (index, value) in enumerate(zip(indices, values), 1):
                results[f'layer_{layer}_{name}_max_active_feature_{i}_index'] = index
                results[f'layer_{layer}_{name}_max_active_feature_{i}_activation'] = value

    return pd.DataFrame([results])

# def analyze_distance_query_long(city1, city2, saes):
#     prompt = generate_prompt(city1, city2)
#     gen_out = model.generate(prompt, temperature=0)
    
#     output_ids, cache = model.run_with_cache_with_saes(gen_out, saes=[sae for _, sae in saes])
#     gen_out_tokens = model.to_tokens(gen_out)[0]

#     city1_range = find_token_substring_ws_invariant(city1, gen_out_tokens)
#     city2_range = find_token_substring_ws_invariant(city2, gen_out_tokens)
    
#     extracted_digits = extract_digits(prompt, gen_out)
#     digit_range = find_token_substring_ws_invariant(extracted_digits, gen_out_tokens)

#     results = []

#     for layer, sae in saes:
#         sae_acts = cache[f'blocks.{layer}.hook_resid_post.hook_sae_acts_post'][0]

#         for feature_type, range_ in [("city_1", city1_range), ("city_2", city2_range), ("distance", digit_range), ("end", (len(gen_out_tokens)-1, len(gen_out_tokens)))]:
#             if range_ is not None:
#                 for pos in range(range_[0], range_[1]):
#                     top10 = torch.topk(sae_acts[pos], 10)
#                     for index, value in zip(top10.indices.tolist(), top10.values.tolist()):
#                         results.append({
#                             'city1': city1,
#                             'city2': city2,
#                             'Extracted Distance': extracted_digits,
#                             'Feature Type': feature_type,
#                             'Feature Index': index,
#                             'Feature Activation': value,
#                             'SAE_Layer': layer
#                         })
#             else:
#                 # If range is None, add a single row with None values
#                 results.append({
#                     'city1': city1,
#                     'city2': city2,
#                     'Extracted Distance': extracted_digits,
#                     'Feature Type': feature_type,
#                     'Feature Index': None,
#                     'Feature Activation': None,
#                     'SAE_Layer': layer
#                 })

#     return pd.DataFrame(results)

def analyze_multiple_queries(queries, saes):
    """Run ``analyze_distance_query`` for every query and stack the results.

    Args:
        queries: Iterable of ``(city1, city2)`` pairs.
        saes: Iterable of ``(layer, sae)`` pairs, passed through unchanged.

    Returns:
        A ``pd.DataFrame`` with one row per query, or an empty DataFrame when
        ``queries`` is empty.
    """
    queries = list(queries)
    # pd.concat raises ValueError on an empty list, so handle "no queries" explicitly.
    if not queries:
        return pd.DataFrame()

    frames = [
        analyze_distance_query(city1, city2, saes)
        for city1, city2 in tqdm(queries)
    ]
    return pd.concat(frames, ignore_index=True)


In [None]:
# Reference table of true distances; its 'City 1' / 'City 2' columns define the set of known cities.
true_city_distances = pd.read_csv(os.path.join(results_dir, 'city_distances_export.csv'))

# Unique city names from both columns. Missing values are dropped, and the result is
# sorted so the query order (and therefore the row order of the outputs) is the same
# on every run; the iteration order of a set of strings differs between Python processes.
cities = sorted(
    set(true_city_distances['City 1'].dropna())
    | set(true_city_distances['City 2'].dropna())
)

In [None]:
# City used as the origin of every distance query; it must be one of the known cities.
root_city = 'Los Angeles, CA'
if root_city not in cities:
    raise ValueError(f"Root city {root_city} not found in the city list")

In [None]:
def queries_from_root_city(root_city, cities):
    """Pair the root city with every other city.

    Args:
        root_city: City used as the first element of every pair.
        cities: Iterable of candidate cities.

    Returns:
        A list of ``(root_city, city)`` tuples in the order of ``cities``,
        skipping entries equal to ``root_city``.
    """
    pairs = []
    for destination in cities:
        if destination != root_city:
            pairs.append((root_city, destination))
    return pairs


# One (root_city, destination) query for every known city other than the root.
queries = queries_from_root_city(root_city, cities)

In [None]:

# Example usage
# queries = [
#     ("Los Angeles, CA", "New York, NY"),
#     ("Chicago, IL", "Miami, FL"),
#     ("Seattle, WA", "Houston, TX")
# ]

# Run every query through the model and all SAEs; the result has one row per query.
results_df = analyze_multiple_queries(queries, saes)

# Print the first rows of the DataFrame; to_string() renders every column instead of truncating.
print(results_df.head().to_string())

# File-name-friendly form of the root city (spaces -> '_', commas -> '-'), used in the output CSV names.
root_city_save = root_city.replace(' ', '_').replace(',', '-')


In [None]:

# Save the wide results DataFrame (one row per query) to a CSV file under results_dir
results_df.to_csv(os.path.join(results_dir, f'distance_query_results_root_{root_city_save}.csv'), index=False)

In [None]:
def transform_results(results_df, feature_types=('city_2', 'distance', 'end'), n_layers=26, top_k=10):
    """Reshape the wide per-query results into a long table.

    For every ``(feature_type, layer, rank)`` combination whose
    ``layer_{layer}_{feature_type}_max_active_feature_{rank}_index`` and
    ``..._activation`` columns both exist in ``results_df``, one block of rows
    (one per query) is emitted.

    Args:
        results_df: Wide DataFrame with ``city1``, ``city2``,
            ``extracted_digits`` and the per-layer feature columns.
        feature_types: Position names to include.
        n_layers: Number of layers to scan; layers ``0 .. n_layers - 1``.
        top_k: Number of ranks to scan; ranks ``1 .. top_k``.

    Returns:
        DataFrame with columns ``city1``, ``city2``, ``extracted_digits``,
        ``index``, ``activation``, ``sae_layer``, ``feature_type``. It is empty
        (with those columns) when no matching feature columns are present.
    """
    id_cols = ['city1', 'city2', 'extracted_digits']
    long_frames = []

    for feature_type in feature_types:
        for layer in range(n_layers):  # 0 to n_layers - 1
            for rank in range(1, top_k + 1):  # 1 to top_k
                prefix = f'layer_{layer}_{feature_type}_max_active_feature_{rank}'
                index_col = f'{prefix}_index'
                activation_col = f'{prefix}_activation'

                # Skip combinations that were never recorded.
                if index_col not in results_df.columns or activation_col not in results_df.columns:
                    continue

                pivot = results_df[id_cols + [index_col, activation_col]].copy()
                pivot.columns = id_cols + ['index', 'activation']
                pivot['sae_layer'] = layer
                pivot['feature_type'] = feature_type
                long_frames.append(pivot)

    # pd.concat raises ValueError on an empty list; return an empty table with the usual schema.
    if not long_frames:
        return pd.DataFrame(columns=id_cols + ['index', 'activation', 'sae_layer', 'feature_type'])

    return pd.concat(long_frames, ignore_index=True)

# Reshape the wide per-query results into long format (one row per query, layer, rank and feature type)
transformed_results = transform_results(results_df)

# Print the first few rows of the transformed DataFrame
print(transformed_results.head())

# Save the transformed results to a new CSV file
transformed_results.to_csv(os.path.join(results_dir, f'transformed_distance_query_results_root_{root_city_save}.csv'), index=False)

In [None]:
# Sanity check: count the rows of transformed_results with city2 == 'New York, NY' and sae_layer == 0
print(transformed_results[(transformed_results['city2'] == 'New York, NY') & (transformed_results['sae_layer'] == 0)].shape[0])

In [None]:
# Show the transformed DataFrame as the cell output
transformed_results