In [15]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import os

# Root of the activation/MCQ evaluation outputs; each run lives in its own subdirectory.
base_dir = "/data/aashiq_muhamed/unlearning/SAEBench/eval_results/activation_mcq"
# SAE/model-specific results directory, identical for both runs being compared.
model_dir = "sae_bench_gemma-2-2b_topk_width-2pow14_date-1109_blocks.5.hook_resid_post__trainer_2/gemma-2-2b-it/results"

# Resolve both runs under comparison: the "alpha_0.1" finetuned run and the base model.
finetuned_path, baseline_path = (
    os.path.join(base_dir, run_name, model_dir) for run_name in ("alpha_0.1", "base_model")
)

def load_data(path):
    """Read the forget/retain activations and the MCQ results saved under *path*.

    Args:
        path: Run results directory containing ``activations/`` and
            ``mcq_performance/`` subdirectories.

    Returns:
        tuple: ``(forget_activations, retain_activations, mcq_results)``; the
        last element is the object unwrapped with ``.item()``.
    """
    def _read(relative_name, **load_kwargs):
        # Every artifact lives at a fixed location relative to the run directory.
        return np.load(os.path.join(path, relative_name), **load_kwargs)

    forget_acts = _read("activations/feature_acts_forget.npy")
    retain_acts = _read("activations/feature_acts_retain.npy")
    # The MCQ results are a pickled object (presumably a dict saved with np.save),
    # so pickle loading must be enabled and .item() unwraps the single element.
    mcq = _read("mcq_performance/mcq_results.npy", allow_pickle=True).item()
    return forget_acts, retain_acts, mcq

def plot_activation_distributions(finetuned_acts, baseline_acts, title, save_path, bins=50):
    """Overlay density histograms of baseline and finetuned activations and save the figure.

    Both histograms share one set of bin edges computed from the combined data,
    so the two densities line up bar-for-bar and can be compared directly.

    Args:
        finetuned_acts: Activations from the finetuned model.
        baseline_acts: Activations from the baseline model.
        title: Inserted into the plot title as ``"Distribution of <title>"``.
        save_path: File path the figure is written to.
        bins: Number of shared histogram bins (default 50).
    """
    # Two independent `bins=50` calls would each derive edges from their own
    # min/max, misaligning the overlaid bars; compute one set from both arrays.
    edges = np.histogram_bin_edges(
        np.concatenate([np.ravel(baseline_acts), np.ravel(finetuned_acts)]), bins=bins
    )

    plt.figure(figsize=(12, 6))
    plt.hist(baseline_acts, bins=edges, alpha=0.5, label='Baseline', density=True)
    plt.hist(finetuned_acts, bins=edges, alpha=0.5, label='Finetuned', density=True)
    plt.title(f'Distribution of {title}')
    plt.xlabel('Activation Value')
    plt.ylabel('Density')
    plt.legend()
    plt.savefig(save_path)
    plt.close()

def plot_activation_scatter(finetuned_acts, baseline_acts, title, save_path):
    """Scatter finetuned against baseline activations with a dashed y=x reference, then save.

    Points on the red dashed diagonal are entries whose value is the same in
    both models; the diagonal spans the joint min/max of the two arrays.
    """
    # Joint range of both arrays, used as the endpoints of the identity line.
    lo = min(baseline_acts.min(), finetuned_acts.min())
    hi = max(baseline_acts.max(), finetuned_acts.max())
    diagonal = [lo, hi]

    plt.figure(figsize=(10, 10))
    plt.scatter(baseline_acts, finetuned_acts, alpha=0.5)
    plt.plot(diagonal, diagonal, 'r--', label='y=x')
    plt.xlabel('Baseline Model Activations')
    plt.ylabel('Finetuned Model Activations')
    plt.title(title)
    plt.legend()
    plt.savefig(save_path)
    plt.close()

def compare_mcq_results(finetuned_results, baseline_results):
    """Pair up per-dataset ``mean_correct`` scores from two MCQ result dicts.

    Only datasets present in both inputs are kept, in the iteration order of
    ``finetuned_results``.

    Returns:
        dict: dataset -> {'finetuned_mean_correct', 'baseline_mean_correct',
        'difference'}, where difference is finetuned minus baseline.
    """
    def _entry(ft_score, base_score):
        return {
            'finetuned_mean_correct': ft_score,
            'baseline_mean_correct': base_score,
            'difference': ft_score - base_score,
        }

    return {
        name: _entry(finetuned_results[name]['mean_correct'],
                     baseline_results[name]['mean_correct'])
        for name in finetuned_results
        if name in baseline_results
    }

def plot_mcq_comparison(metrics_comparison, save_path):
    """Draw grouped Baseline/Finetuned bars of MCQ ``mean_correct`` per dataset and save.

    Args:
        metrics_comparison: Mapping produced by ``compare_mcq_results``.
        save_path: File path the figure is written to.
    """
    names = list(metrics_comparison.keys())
    centers = np.arange(len(names))
    bar_w = 0.35
    # (legend label, x-offset from the group centre, metrics key); baseline is drawn first.
    series = (
        ('Baseline', -bar_w / 2, 'baseline_mean_correct'),
        ('Finetuned', bar_w / 2, 'finetuned_mean_correct'),
    )

    fig, ax = plt.subplots(figsize=(12, 6))
    for label, offset, key in series:
        heights = [metrics_comparison[n][key] for n in names]
        ax.bar(centers + offset, heights, bar_w, label=label)

    ax.set_ylabel('Mean Correct')
    ax.set_title('MCQ Performance Comparison')
    ax.set_xticks(centers)
    ax.set_xticklabels(names, rotation=45)
    ax.legend()

    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

def plot_activation_bars(finetuned_forget, baseline_forget, finetuned_retain, baseline_retain, save_path):
    """Plot mean forget/retain activations of both models side by side, with SEM error bars.

    Each group ("Forget", "Retain") gets a Baseline and a Finetuned bar. Error
    bars are the standard error of the mean, s / sqrt(n), where s is the sample
    standard deviation (ddof=1), the same convention as ``scipy.stats.sem``.
    Each bar is annotated with its mean, placed just above its error bar.

    Args:
        finetuned_forget, baseline_forget: Forget-set activations of each model.
        finetuned_retain, baseline_retain: Retain-set activations of each model.
        save_path: File path the figure is written to.
    """
    def _mean_and_sem(values):
        """Return (mean, SEM) over all entries of *values*; SEM is 0 with fewer than 2 entries."""
        values = np.asarray(values)
        if values.size < 2:
            return values.mean(), 0.0
        # Sample std (ddof=1) is the standard SEM estimator; `.size` counts every
        # entry, so the SEM stays correct even for non-1-D input.
        return values.mean(), values.std(ddof=1) / np.sqrt(values.size)

    x = np.arange(2)  # group centres: 0 = Forget, 1 = Retain
    width = 0.35

    # zip turns [(mean_f, sem_f), (mean_r, sem_r)] into (means, sems).
    baseline_means, baseline_sems = zip(_mean_and_sem(baseline_forget), _mean_and_sem(baseline_retain))
    finetuned_means, finetuned_sems = zip(_mean_and_sem(finetuned_forget), _mean_and_sem(finetuned_retain))

    plt.figure(figsize=(10, 6))
    plt.bar(x - width/2, baseline_means, width, label='Baseline', yerr=baseline_sems, capsize=5)
    plt.bar(x + width/2, finetuned_means, width, label='Finetuned', yerr=finetuned_sems, capsize=5)

    plt.ylabel('Mean Activation')
    plt.title('Mean Activations Comparison')
    plt.xticks(x, ['Forget', 'Retain'])
    plt.legend()

    # Value labels sit on top of the error bar so the two do not overlap.
    for offset, means, sems in ((-width / 2, baseline_means, baseline_sems),
                                (width / 2, finetuned_means, finetuned_sems)):
        for i, (mean, sem) in enumerate(zip(means, sems)):
            plt.text(i + offset, mean + sem, f'{mean:.4f}', ha='center', va='bottom')

    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

def plot_activation_bars_normalized(finetuned_forget, baseline_forget, finetuned_retain, baseline_retain, save_path):
    """Plot finetuned mean activations relative to the baseline mean (baseline = 1.0).

    For each group ("Forget", "Retain") the finetuned mean and its standard
    error are divided by the baseline mean of the same group. The baseline bar
    is fixed at 1.0 and drawn without an error bar. The SEM uses the sample
    standard deviation (ddof=1), the same convention as ``scipy.stats.sem``.
    If a baseline mean is zero, the ratio for that group is undefined.

    Args:
        finetuned_forget, baseline_forget: Forget-set activations of each model.
        finetuned_retain, baseline_retain: Retain-set activations of each model.
        save_path: File path the figure is written to.
    """
    def _mean_and_sem(values):
        """Return (mean, SEM) over all entries of *values*; SEM is 0 with fewer than 2 entries."""
        values = np.asarray(values)
        if values.size < 2:
            return values.mean(), 0.0
        return values.mean(), values.std(ddof=1) / np.sqrt(values.size)

    x = np.arange(2)  # group centres: 0 = Forget, 1 = Retain
    width = 0.35

    baseline_means = [baseline_forget.mean(), baseline_retain.mean()]
    finetuned_stats = [_mean_and_sem(finetuned_forget), _mean_and_sem(finetuned_retain)]

    # Scale both the finetuned mean and its SEM by the matching baseline mean.
    normalized_finetuned = [m / bm for (m, _), bm in zip(finetuned_stats, baseline_means)]
    finetuned_sems = [s / bm for (_, s), bm in zip(finetuned_stats, baseline_means)]
    normalized_baseline = [1.0, 1.0]
    baseline_sems = [0, 0]  # baseline is the reference, so it carries no error bar

    plt.figure(figsize=(10, 6))
    plt.bar(x - width/2, normalized_baseline, width, label='Baseline', yerr=baseline_sems, capsize=5)
    plt.bar(x + width/2, normalized_finetuned, width, label='Finetuned', yerr=finetuned_sems, capsize=5)

    plt.ylabel('Normalized Mean Activation\n(Relative to Baseline)')
    plt.title('Normalized Mean Activations Comparison')
    plt.xticks(x, ['Forget', 'Retain'])
    plt.legend()

    # Value labels sit on top of the error bar so the two do not overlap.
    for offset, values, sems in ((-width / 2, normalized_baseline, baseline_sems),
                                 (width / 2, normalized_finetuned, finetuned_sems)):
        for i, (v, sem) in enumerate(zip(values, sems)):
            plt.text(i + offset, v + sem, f'{v:.4f}', ha='center', va='bottom')

    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()


def plot_position_wise_activations(finetuned_forget, baseline_forget, finetuned_retain, baseline_retain, save_path, num_positions=200):
    """Plot baseline vs. finetuned activations index-by-index for the forget and retain arrays.

    Only the first ``num_positions`` indices are drawn, or fewer if any array is
    shorter. The top panel shows forget activations and the bottom panel shows
    retain activations, each with side-by-side Baseline/Finetuned bars per index.

    Args:
        finetuned_forget, baseline_forget: Forget-set activations of each model.
        finetuned_retain, baseline_retain: Retain-set activations of each model.
        save_path: File path the figure is written to.
        num_positions: Maximum number of leading indices to display (default 200).
    """
    # Clamp so a short array cannot make bar x-positions and heights differ in length.
    n = min(num_positions, len(finetuned_forget), len(baseline_forget),
            len(finetuned_retain), len(baseline_retain))
    positions = np.arange(n)

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 10), height_ratios=[1, 1])
    width = 0.35

    # Forget activations plot
    ax1.bar(positions - width/2, baseline_forget[:n], width, label='Baseline', alpha=0.7)
    ax1.bar(positions + width/2, finetuned_forget[:n], width, label='Finetuned', alpha=0.7)
    ax1.set_title('Forget Feature Activations by Position')
    ax1.set_ylabel('Activation Value')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Retain activations plot
    ax2.bar(positions - width/2, baseline_retain[:n], width, label='Baseline', alpha=0.7)
    ax2.bar(positions + width/2, finetuned_retain[:n], width, label='Finetuned', alpha=0.7)
    ax2.set_title('Retain Feature Activations by Position')
    ax2.set_xlabel('Feature Position')
    ax2.set_ylabel('Activation Value')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(save_path)
    plt.close(fig)

def plot_position_wise_activations_diff(finetuned_forget, baseline_forget, finetuned_retain, baseline_retain, save_path, num_positions=200):
    """Plot per-index activation differences (finetuned - baseline) and save the figure.

    Only the first ``num_positions`` indices are drawn, or fewer if any array is
    shorter. Bars above the zero line mean the finetuned model's value is larger
    than the baseline's at that index. The top panel shows forget differences and
    the bottom panel shows retain differences.

    Args:
        finetuned_forget, baseline_forget: Forget-set activations of each model.
        finetuned_retain, baseline_retain: Retain-set activations of each model.
        save_path: File path the figure is written to.
        num_positions: Maximum number of leading indices to display (default 200).
    """
    # Clamp so a short array cannot make bar x-positions and heights differ in length.
    n = min(num_positions, len(finetuned_forget), len(baseline_forget),
            len(finetuned_retain), len(baseline_retain))

    # Calculate differences (finetuned - baseline)
    forget_diff = finetuned_forget[:n] - baseline_forget[:n]
    retain_diff = finetuned_retain[:n] - baseline_retain[:n]

    positions = np.arange(n)
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 10), height_ratios=[1, 1])

    # Forget activations difference plot
    ax1.bar(positions, forget_diff, alpha=0.7)
    ax1.axhline(y=0, color='k', linestyle='-', alpha=0.3)
    ax1.set_title('Difference in Forget Feature Activations (Finetuned - Baseline)')
    ax1.set_ylabel('Activation Difference')
    ax1.grid(True, alpha=0.3)

    # Retain activations difference plot
    ax2.bar(positions, retain_diff, alpha=0.7)
    ax2.axhline(y=0, color='k', linestyle='-', alpha=0.3)
    ax2.set_title('Difference in Retain Feature Activations (Finetuned - Baseline)')
    ax2.set_xlabel('Feature Position')
    ax2.set_ylabel('Activation Difference')
    ax2.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(save_path)
    plt.close(fig)

In [16]:
# Every comparison plot from this cell is written into one directory.
output_dir = "comparison_plots"
os.makedirs(output_dir, exist_ok=True)

# Load both runs; the finetuned run is read first.
print("Loading data from finetuned model...")
finetuned_forget, finetuned_retain, finetuned_mcq = load_data(finetuned_path)
print("Loading data from baseline model...")
baseline_forget, baseline_retain, baseline_mcq = load_data(baseline_path)

# Histograms per split: (finetuned, baseline, title, output file name).
print("Plotting activation distributions...")
for ft_acts, base_acts, plot_title, file_name in (
    (finetuned_forget, baseline_forget, "Forget Activations", "forget_activations_dist.png"),
    (finetuned_retain, baseline_retain, "Retain Activations", "retain_activations_dist.png"),
):
    plot_activation_distributions(ft_acts, base_acts, plot_title, os.path.join(output_dir, file_name))

# Finetuned-vs-baseline scatter plots per split.
print("Plotting activation scatter plots...")
for ft_acts, base_acts, plot_title, file_name in (
    (finetuned_forget, baseline_forget, "Forget Activations Comparison", "forget_activations_scatter.png"),
    (finetuned_retain, baseline_retain, "Retain Activations Comparison", "retain_activations_scatter.png"),
):
    plot_activation_scatter(ft_acts, base_acts, plot_title, os.path.join(output_dir, file_name))

# The four arrays in the argument order shared by the bar and position-wise plots.
split_arrays = (finetuned_forget, baseline_forget, finetuned_retain, baseline_retain)

print("Plotting activation bar comparisons...")
plot_activation_bars(*split_arrays, os.path.join(output_dir, "activations_bars.png"))
plot_activation_bars_normalized(*split_arrays, os.path.join(output_dir, "activations_bars_normalized.png"))

# Per-dataset MCQ accuracy for both runs.
print("Comparing MCQ results...")
metrics_comparison = compare_mcq_results(finetuned_mcq, baseline_mcq)

print("\nMCQ Performance Comparison:")
for dataset, metrics in metrics_comparison.items():
    print("\n".join([
        f"\nDataset: {dataset}",
        f"Baseline model mean correct: {metrics['baseline_mean_correct']:.4f}",
        f"Finetuned model mean correct: {metrics['finetuned_mean_correct']:.4f}",
        f"Difference (Finetuned - Baseline): {metrics['difference']:.4f}",
    ]))

plot_mcq_comparison(metrics_comparison, os.path.join(output_dir, "mcq_performance_comparison.png"))

print("Plotting position-wise activation comparisons...")
plot_position_wise_activations(*split_arrays, os.path.join(output_dir, "position_wise_activations.png"))
plot_position_wise_activations_diff(*split_arrays, os.path.join(output_dir, "position_wise_activations_diff.png"))

# Summary statistics (mean / population std) per split.
print("\nActivation Statistics:")
for header, base_acts, ft_acts in (
    ("\nForget Activations:", baseline_forget, finetuned_forget),
    ("\nRetain Activations:", baseline_retain, finetuned_retain),
):
    print(header)
    print(f"Baseline  - Mean: {base_acts.mean():.4f}, Std: {base_acts.std():.4f}")
    print(f"Finetuned - Mean: {ft_acts.mean():.4f}, Std: {ft_acts.std():.4f}")


Loading data from finetuned model...
Loading data from baseline model...
Plotting activation distributions...
Plotting activation scatter plots...
Plotting activation bar comparisons...
Comparing MCQ results...

MCQ Performance Comparison:

Dataset: wmdp-bio
Baseline model mean correct: 1.0000
Finetuned model mean correct: 0.2295
Difference (Finetuned - Baseline): -0.7705

Dataset: high_school_us_history
Baseline model mean correct: 1.0000
Finetuned model mean correct: 0.2294
Difference (Finetuned - Baseline): -0.7706

Dataset: college_computer_science
Baseline model mean correct: 1.0000
Finetuned model mean correct: 0.4444
Difference (Finetuned - Baseline): -0.5556

Dataset: high_school_geography
Baseline model mean correct: 1.0000
Finetuned model mean correct: 0.2019
Difference (Finetuned - Baseline): -0.7981

Dataset: human_aging
Baseline model mean correct: 1.0000
Finetuned model mean correct: 0.3214
Difference (Finetuned - Baseline): -0.6786
Plotting position-wise activation compa

In [12]:
# import numpy as np
# import matplotlib.pyplot as plt
# import os

# def load_and_compare_activations():
#     # Define paths
#     base_dir = "/data/aashiq_muhamed/unlearning/SAEBench/eval_results/activation_mcq"
#     model_subdir = "sae_bench_gemma-2-2b_topk_width-2pow14_date-1109_blocks.5.hook_resid_post__trainer_2/gemma-2-2b-it/results"
    
#     # Load baseline data
#     baseline_path = os.path.join(base_dir, "base_model", model_subdir)
#     baseline_forget = np.load(os.path.join(baseline_path, "activations/feature_acts_forget.npy"))
#     baseline_retain = np.load(os.path.join(baseline_path, "activations/feature_acts_retain.npy"))
#     baseline_mcq = np.load(os.path.join(baseline_path, "mcq_performance/mcq_results.npy"), allow_pickle=True).item()

#     # Load finetuned data
#     finetuned_path = os.path.join(base_dir, model_subdir)
#     finetuned_forget = np.load(os.path.join(finetuned_path, "activations/feature_acts_forget.npy"))
#     finetuned_retain = np.load(os.path.join(finetuned_path, "activations/feature_acts_retain.npy"))
#     finetuned_mcq = np.load(os.path.join(finetuned_path, "mcq_performance/mcq_results.npy"), allow_pickle=True).item()

#     # Print shapes to verify
#     print("Data shapes:")
#     print(f"Baseline forget: {baseline_forget.shape}")
#     print(f"Baseline retain: {baseline_retain.shape}")
#     print(f"Finetuned forget: {finetuned_forget.shape}")
#     print(f"Finetuned retain: {finetuned_retain.shape}")

#     # Create plot comparing activations
#     plt.figure(figsize=(15, 10))
    
#     # Number of features to show (first N features)
#     n_features = 50  # Adjust this number as needed
#     feature_indices = np.arange(n_features)
    
#     # Plot Forget Features
#     plt.subplot(2, 1, 1)
#     plt.bar(feature_indices - 0.2, baseline_forget[:n_features], width=0.4, label='Baseline', alpha=0.7)
#     plt.bar(feature_indices + 0.2, finetuned_forget[:n_features], width=0.4, label='Finetuned', alpha=0.7)
#     plt.title('Forget Features')
#     plt.ylabel('Activation')
#     plt.legend()
#     plt.grid(True, alpha=0.3)
    
#     # Plot Retain Features
#     plt.subplot(2, 1, 2)
#     plt.bar(feature_indices - 0.2, baseline_retain[:n_features], width=0.4, label='Baseline', alpha=0.7)
#     plt.bar(feature_indices + 0.2, finetuned_retain[:n_features], width=0.4, label='Finetuned', alpha=0.7)
#     plt.title('Retain Features')
#     plt.xlabel('Feature Index')
#     plt.ylabel('Activation')
#     plt.legend()
#     plt.grid(True, alpha=0.3)
    
#     plt.tight_layout()
#     plt.savefig('activation_comparison.png')
#     plt.close()

#     # Print some basic statistics
#     print("\nActivation Statistics:")
#     print("\nForget Features:")
#     print(f"Baseline  - Mean: {baseline_forget.mean():.4f}, Std: {baseline_forget.std():.4f}")
#     print(f"Finetuned - Mean: {finetuned_forget.mean():.4f}, Std: {finetuned_forget.std():.4f}")
    
#     print("\nRetain Features:")
#     print(f"Baseline  - Mean: {baseline_retain.mean():.4f}, Std: {baseline_retain.std():.4f}")
#     print(f"Finetuned - Mean: {finetuned_retain.mean():.4f}, Std: {finetuned_retain.std():.4f}")

#     # Print MCQ performance
#     print("\nMCQ Performance:")
#     for dataset in baseline_mcq.keys():
#         if dataset in finetuned_mcq:
#             print(f"\nDataset: {dataset}")
#             print(f"Baseline mean correct: {baseline_mcq[dataset]['mean_correct']:.4f}")
#             print(f"Finetuned mean correct: {finetuned_mcq[dataset]['mean_correct']:.4f}")


# load_and_compare_activations()

In [13]:
# import numpy as np
# import matplotlib.pyplot as plt
# import os

# def load_and_plot_sorted_activations():
#     # Define paths
#     base_dir = "/data/aashiq_muhamed/unlearning/SAEBench/eval_results/activation_mcq"
#     model_subdir = "sae_bench_gemma-2-2b_topk_width-2pow14_date-1109_blocks.5.hook_resid_post__trainer_2/gemma-2-2b-it/results"
    
#     # Load baseline data
#     baseline_path = os.path.join(base_dir, "base_model", model_subdir)
#     baseline_forget = np.load(os.path.join(baseline_path, "activations/feature_acts_forget.npy"))
#     baseline_retain = np.load(os.path.join(baseline_path, "activations/feature_acts_retain.npy"))

#     # Load finetuned data
#     finetuned_path = os.path.join(base_dir, model_subdir)
#     finetuned_forget = np.load(os.path.join(finetuned_path, "activations/feature_acts_forget.npy"))
#     finetuned_retain = np.load(os.path.join(finetuned_path, "activations/feature_acts_retain.npy"))

#     # Print shapes to verify
#     print("Data shapes:")
#     print(f"Baseline forget: {baseline_forget.shape}")
#     print(f"Baseline retain: {baseline_retain.shape}")
#     print(f"Finetuned forget: {finetuned_forget.shape}")
#     print(f"Finetuned retain: {finetuned_retain.shape}")

#     # Get sorting indices based on finetuned retain activations
#     sort_indices = np.argsort(-finetuned_retain)  # negative for descending order
    
#     # Sort all arrays using these indices
#     baseline_forget_sorted = baseline_forget[sort_indices]
#     baseline_retain_sorted = baseline_retain[sort_indices]
#     finetuned_forget_sorted = finetuned_forget[sort_indices]
#     finetuned_retain_sorted = finetuned_retain[sort_indices]

#     # Create plot
#     plt.figure(figsize=(20, 10))
#     feature_indices = np.arange(len(baseline_forget))
    
#     plt.subplot(2, 1, 1)
#     plt.bar(feature_indices - 0.2, baseline_forget_sorted, width=0.4, label='Baseline', alpha=0.7)
#     plt.bar(feature_indices + 0.2, finetuned_forget_sorted, width=0.4, label='Finetuned', alpha=0.7)
#     plt.title('Forget Features (Sorted by Retain Feature Activations)')
#     plt.ylabel('Forget Activation')
#     plt.legend()
#     plt.grid(True, alpha=0.3)
    
#     plt.subplot(2, 1, 2)
#     plt.bar(feature_indices - 0.2, baseline_retain_sorted, width=0.4, label='Baseline', alpha=0.7)
#     plt.bar(feature_indices + 0.2, finetuned_retain_sorted, width=0.4, label='Finetuned', alpha=0.7)
#     plt.title('Retain Features (Sorted by Activation Value)')
#     plt.xlabel('Feature Index (Sorted)')
#     plt.ylabel('Retain Activation')
#     plt.legend()
#     plt.grid(True, alpha=0.3)
    
#     plt.tight_layout()
#     plt.savefig('activation_comparison_sorted.png')
#     plt.close()

#     # Print statistics about sorted activations
#     print("\nActivation Statistics (Top 10 features by retain activation):")
#     print("\nForget Features:")
#     print("Baseline  - Mean of top 10:", baseline_forget_sorted[:10].mean())
#     print("Finetuned - Mean of top 10:", finetuned_forget_sorted[:10].mean())
    
#     print("\nRetain Features:")
#     print("Baseline  - Mean of top 10:", baseline_retain_sorted[:10].mean())
#     print("Finetuned - Mean of top 10:", finetuned_retain_sorted[:10].mean())

#     # Print correlation between forget and retain features
#     baseline_corr = np.corrcoef(baseline_forget_sorted, baseline_retain_sorted)[0,1]
#     finetuned_corr = np.corrcoef(finetuned_forget_sorted, finetuned_retain_sorted)[0,1]
    
#     print("\nCorrelations between forget and retain features:")
#     print(f"Baseline: {baseline_corr:.4f}")
#     print(f"Finetuned: {finetuned_corr:.4f}")


# load_and_plot_sorted_activations()

# Log plot

In [17]:
# Log plot

import numpy as np
import matplotlib.pyplot as plt
import os

def load_and_plot_sorted_activations(
    base_dir="/data/aashiq_muhamed/unlearning/SAEBench/eval_results/activation_mcq",
    model_subdir="sae_bench_gemma-2-2b_topk_width-2pow14_date-1109_blocks.5.hook_resid_post__trainer_2/gemma-2-2b-it/results",
    finetuned_run="alpha_0.1",
    save_path='activation_comparison_sorted.png',
):
    """Compare baseline vs. finetuned activations, ordered by descending *baseline* retain activation.

    The function:
    - loads the forget/retain activation arrays for the base model and the finetuned run;
    - reorders all four arrays by descending baseline retain activation;
    - draws log-scale bar charts of their absolute values and saves the figure;
    - prints the top-10 means and the forget/retain correlation for each model.

    Args:
        base_dir: Root of the activation/MCQ evaluation outputs.
        model_subdir: SAE/model results directory shared by both runs.
        finetuned_run: Subdirectory of ``base_dir`` holding the finetuned run.
        save_path: File path the figure is written to.
    """
    def _load_acts(run_dir):
        """Return (forget, retain) activations stored for one run directory."""
        run_path = os.path.join(base_dir, run_dir, model_subdir)
        return (
            np.load(os.path.join(run_path, "activations/feature_acts_forget.npy")),
            np.load(os.path.join(run_path, "activations/feature_acts_retain.npy")),
        )

    baseline_forget, baseline_retain = _load_acts("base_model")
    finetuned_forget, finetuned_retain = _load_acts(finetuned_run)

    # Print shapes to verify
    print("Data shapes:")
    print(f"Baseline forget: {baseline_forget.shape}")
    print(f"Baseline retain: {baseline_retain.shape}")
    print(f"Finetuned forget: {finetuned_forget.shape}")
    print(f"Finetuned retain: {finetuned_retain.shape}")

    # Sort indices from the BASELINE retain activations, descending (negated for argsort).
    sort_indices = np.argsort(-baseline_retain)

    # Apply the same permutation to every array so index i refers to the same entry everywhere.
    baseline_forget_sorted = baseline_forget[sort_indices]
    baseline_retain_sorted = baseline_retain[sort_indices]
    finetuned_forget_sorted = finetuned_forget[sort_indices]
    finetuned_retain_sorted = finetuned_retain[sort_indices]

    plt.figure(figsize=(20, 10))
    feature_indices = np.arange(len(baseline_forget))

    # Absolute values because a log axis cannot show negative heights; zero-height bars are not visible.
    plt.subplot(2, 1, 1)
    plt.bar(feature_indices - 0.2, np.abs(baseline_forget_sorted), width=0.4, label='Baseline', alpha=0.7)
    plt.bar(feature_indices + 0.2, np.abs(finetuned_forget_sorted), width=0.4, label='Finetuned', alpha=0.7)
    plt.yscale('log')
    plt.title('Forget Features (Sorted by Baseline Retain Activation)')
    plt.ylabel('Forget Activation (log scale)')
    plt.legend()
    plt.grid(True, alpha=0.3, which="both")

    plt.subplot(2, 1, 2)
    plt.bar(feature_indices - 0.2, np.abs(baseline_retain_sorted), width=0.4, label='Baseline', alpha=0.7)
    plt.bar(feature_indices + 0.2, np.abs(finetuned_retain_sorted), width=0.4, label='Finetuned', alpha=0.7)
    plt.yscale('log')
    plt.title('Retain Features (Sorted by Baseline Retain Activation)')
    plt.xlabel('Feature Index (Sorted)')
    plt.ylabel('Retain Activation (log scale)')
    plt.legend()
    plt.grid(True, alpha=0.3, which="both")

    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

    # The "top 10" are the first 10 entries after sorting by baseline retain activation.
    print("\nActivation Statistics (Top 10 features by baseline retain activation):")
    print("\nForget Features:")
    print("Baseline  - Mean of top 10:", baseline_forget_sorted[:10].mean())
    print("Finetuned - Mean of top 10:", finetuned_forget_sorted[:10].mean())

    print("\nRetain Features:")
    print("Baseline  - Mean of top 10:", baseline_retain_sorted[:10].mean())
    print("Finetuned - Mean of top 10:", finetuned_retain_sorted[:10].mean())

    # Both arrays share one permutation, so the ordering does not change the Pearson correlation.
    baseline_corr = np.corrcoef(baseline_forget_sorted, baseline_retain_sorted)[0, 1]
    finetuned_corr = np.corrcoef(finetuned_forget_sorted, finetuned_retain_sorted)[0, 1]

    print("\nCorrelations between forget and retain features:")
    print(f"Baseline: {baseline_corr:.4f}")
    print(f"Finetuned: {finetuned_corr:.4f}")


# Run the sorted baseline-vs-finetuned comparison; the figure is saved as activation_comparison_sorted.png.
load_and_plot_sorted_activations()

Data shapes:
Baseline forget: (16384,)
Baseline retain: (16384,)
Finetuned forget: (16384,)
Finetuned retain: (16384,)


  plt.tight_layout()
  plt.savefig('activation_comparison_sorted.png')



Activation Statistics (Top 10 features by retain activation):

Forget Features:
Baseline  - Mean of top 10: 4.043584
Finetuned - Mean of top 10: 6.516452

Retain Features:
Baseline  - Mean of top 10: 4.0422974
Finetuned - Mean of top 10: 8.988276

Correlations between forget and retain features:
Baseline: 0.9287
Finetuned: 0.0880


# Revised plots: aggregated activation analysis

In [22]:
# import numpy as np
# import matplotlib.pyplot as plt
# import seaborn as sns
# import os

# def plot_aggregated_activations(baseline_forget, baseline_retain, 
#                               finetuned_forget, finetuned_retain):
#     # Create figure with subplots
#     fig = plt.figure(figsize=(20, 15))
    
#     # 1. Feature activation changes (log scale)
#     ax1 = plt.subplot(221)
#     sort_indices = np.argsort(-baseline_retain)
#     x = np.arange(len(baseline_retain))
    
#     # Change alpha values in plot 1
#     ax1.scatter(x, np.abs(baseline_retain[sort_indices]), alpha=0.3, label='Baseline Retain', color='blue', s=10)
#     ax1.scatter(x, np.abs(finetuned_retain[sort_indices]), alpha=0.3, label='Finetuned Retain', color='lightblue', s=10)
#     ax1.scatter(x, np.abs(baseline_forget[sort_indices]), alpha=0.3, label='Baseline Forget', color='red', s=10)
#     ax1.scatter(x, np.abs(finetuned_forget[sort_indices]), alpha=0.3, label='Finetuned Forget', color='lightcoral', s=10)
    
#     ax1.set_yscale('log')
#     ax1.set_xlabel('Feature Index (sorted by baseline retain activation)')
#     ax1.set_ylabel('Activation Magnitude (log scale)')
#     ax1.set_title('Feature Activations (Log Scale)')
#     ax1.legend()
#     ax1.grid(True, which="both", ls="-", alpha=0.2)

    
    
#     # 2. Selectivity analysis (log scale for y)
#     ax2 = plt.subplot(222)
#     baseline_selectivity = np.abs(baseline_retain) / (np.abs(baseline_forget) + 1e-10)
#     finetuned_selectivity = np.abs(finetuned_retain) / (np.abs(finetuned_forget) + 1e-10)
    
#     # Sort by baseline selectivity
#     select_sort = np.argsort(-baseline_selectivity)

#     ax2.plot(baseline_selectivity[select_sort], label='Baseline', alpha=0.4)
#     ax2.plot(finetuned_selectivity[select_sort], label='Finetuned', alpha=0.4)

#     # ax2.plot(baseline_selectivity[select_sort], label='Baseline', alpha=0.7)
#     # ax2.plot(finetuned_selectivity[select_sort], label='Finetuned', alpha=0.7)
    
#     ax2.set_yscale('log')
#     ax2.set_xlabel('Feature Index (sorted by baseline selectivity)')
#     ax2.set_ylabel('Selectivity |retain|/|forget| (log scale)')
#     ax2.set_title('Feature Selectivity Comparison')
#     ax2.legend()
#     ax2.grid(True, which="both", ls="-", alpha=0.2)
    
#     # 3. Joint distribution (log scale)
#     ax3 = plt.subplot(223)
#     retain_mask = baseline_retain > np.median(baseline_retain)
    
#     # Create scatter plot with different colors for high/low retain features
#     # ax3.scatter(np.abs(baseline_forget[~retain_mask]), np.abs(finetuned_forget[~retain_mask]), 
#     #             alpha=0.5, label='Low Retain', color='gray', s=10)
#     # ax3.scatter(np.abs(baseline_forget[retain_mask]), np.abs(finetuned_forget[retain_mask]), 
#     #             alpha=0.5, label='High Retain', color='red', s=10)

#     ax3.scatter(np.abs(baseline_forget[~retain_mask]), np.abs(finetuned_forget[~retain_mask]), 
#             alpha=0.3, label='Low Retain', color='gray', s=10)
#     ax3.scatter(np.abs(baseline_forget[retain_mask]), np.abs(finetuned_forget[retain_mask]), 
#             alpha=0.3, label='High Retain', color='red', s=10)



    
#     max_val = max(np.max(np.abs(baseline_forget)), np.max(np.abs(finetuned_forget)))
#     min_val = min(np.min(np.abs(baseline_forget[baseline_forget != 0])), 
#                  np.min(np.abs(finetuned_forget[finetuned_forget != 0])))
    
#     ax3.plot([min_val, max_val], [min_val, max_val], '--', color='gray', alpha=0.5)
#     ax3.set_xscale('log')
#     ax3.set_yscale('log')
#     ax3.set_xlabel('Baseline Forget Activation (log scale)')
#     ax3.set_ylabel('Finetuned Forget Activation (log scale)')
#     ax3.set_title('Forget Activation Changes\n(colored by baseline retain strength)')
#     ax3.legend()
#     ax3.grid(True, which="both", ls="-", alpha=0.2)
    
#     # 4. Change ratio analysis
#     ax4 = plt.subplot(224)
#     change_ratio_retain = np.abs(finetuned_retain) / (np.abs(baseline_retain) + 1e-10)
#     change_ratio_forget = np.abs(finetuned_forget) / (np.abs(baseline_forget) + 1e-10)
    
#     # Sort by baseline retain activation
#     # ax4.scatter(x, change_ratio_retain[sort_indices], alpha=0.5, label='Retain', color='blue', s=10)
#     # ax4.scatter(x, change_ratio_forget[sort_indices], alpha=0.5, label='Forget', color='red', s=10)

#     ax4.scatter(x, change_ratio_retain[sort_indices], alpha=0.3, label='Retain', color='blue', s=10)
#     ax4.scatter(x, change_ratio_forget[sort_indices], alpha=0.3, label='Forget', color='red', s=10)
    
#     ax4.axhline(y=1, color='gray', linestyle='--', alpha=0.5)
#     ax4.set_yscale('log')
#     ax4.set_xlabel('Feature Index (sorted by baseline retain activation)')
#     ax4.set_ylabel('Change Ratio (finetuned/baseline) (log scale)')
#     ax4.set_title('Activation Change Ratios')
#     ax4.legend()
#     ax4.grid(True, which="both", ls="-", alpha=0.2)
    
#     plt.tight_layout()
#     return fig

# def print_activation_summary(baseline_forget, baseline_retain, 
#                            finetuned_forget, finetuned_retain):
#     # Calculate key metrics using log space
#     baseline_selectivity = np.abs(baseline_retain) / (np.abs(baseline_forget) + 1e-10)
#     finetuned_selectivity = np.abs(finetuned_retain) / (np.abs(finetuned_forget) + 1e-10)
    
#     change_ratio_retain = np.abs(finetuned_retain) / (np.abs(baseline_retain) + 1e-10)
#     change_ratio_forget = np.abs(finetuned_forget) / (np.abs(baseline_forget) + 1e-10)
    
#     print("\nActivation Analysis Summary:")
#     print(f"\nMedian Selectivity (|retain|/|forget|):")
#     print(f"Baseline: {np.median(baseline_selectivity):.4f}")
#     print(f"Finetuned: {np.median(finetuned_selectivity):.4f}")
    
#     print(f"\nMedian Change Ratios (finetuned/baseline):")
#     print(f"Retain: {np.median(change_ratio_retain):.4f}")
#     print(f"Forget: {np.median(change_ratio_forget):.4f}")
    
#     # Calculate percentage of features with improved metrics
#     selectivity_improved = np.mean(finetuned_selectivity > baseline_selectivity) * 100
#     retain_preserved = np.mean(change_ratio_retain > 0.9) * 100
#     forget_reduced = np.mean(change_ratio_forget < 0.9) * 100
    
#     print(f"\nPercentage of features with:")
#     print(f"Improved selectivity: {selectivity_improved:.1f}%")
#     print(f"Preserved retain activation (>90%): {retain_preserved:.1f}%")
#     print(f"Reduced forget activation (<90%): {forget_reduced:.1f}%")

# def main():
#     # Define paths
#     base_dir = "/data/aashiq_muhamed/unlearning/SAEBench/eval_results/activation_mcq"
#     model_subdir = "sae_bench_gemma-2-2b_topk_width-2pow14_date-1109_blocks.5.hook_resid_post__trainer_2/gemma-2-2b-it/results"
    
#     # Load baseline data
#     baseline_path = os.path.join(base_dir, "base_model", model_subdir)
#     baseline_forget = np.load(os.path.join(baseline_path, "activations/feature_acts_forget.npy"))
#     baseline_retain = np.load(os.path.join(baseline_path, "activations/feature_acts_retain.npy"))
    
#     # Load finetuned data
#     finetuned_path = os.path.join(base_dir, "alpha_0.1", model_subdir)
#     finetuned_forget = np.load(os.path.join(finetuned_path, "activations/feature_acts_forget.npy"))
#     finetuned_retain = np.load(os.path.join(finetuned_path, "activations/feature_acts_retain.npy"))
    
#     # Print data shapes
#     print("Data shapes:")
#     print(f"Baseline forget: {baseline_forget.shape}")
#     print(f"Baseline retain: {baseline_retain.shape}")
#     print(f"Finetuned forget: {finetuned_forget.shape}")
#     print(f"Finetuned retain: {finetuned_retain.shape}")
    
#     # Generate visualizations
#     fig = plot_aggregated_activations(baseline_forget, baseline_retain,
#                                     finetuned_forget, finetuned_retain)
    
#     # Save the figure
#     fig.savefig('activation_analysis_log.png', dpi=300, bbox_inches='tight')
#     plt.close(fig)
    
#     # Print summary statistics
#     print_activation_summary(baseline_forget, baseline_retain,
#                            finetuned_forget, finetuned_retain)

# if __name__ == "__main__":
#     main()

In [24]:
# import numpy as np
# import matplotlib.pyplot as plt
# import seaborn as sns
# import os

# def plot_aggregated_activations(baseline_forget, baseline_retain, 
#                               finetuned_forget, finetuned_retain):
#     """
#     Create a figure with four subplots analyzing different aspects of the activation patterns.
#     """
#     fig = plt.figure(figsize=(20, 15))
    
#     # 1. Feature activation changes (log scale)
#     ax1 = plt.subplot(221)
#     sort_indices = np.argsort(-baseline_retain)
#     x = np.arange(len(baseline_retain))
    
#     # Plot retains
#     ax1.scatter(x, np.abs(baseline_retain[sort_indices]), alpha=0.15, 
#                 label='Baseline Retain', color='darkblue', s=5)
#     ax1.scatter(x, np.abs(finetuned_retain[sort_indices]), alpha=0.15, 
#                 label='Finetuned Retain', color='blue', s=5)
    
#     # Plot forgets
#     ax1.scatter(x, np.abs(baseline_forget[sort_indices]), alpha=0.15, 
#                 label='Baseline Forget', color='darkred', s=5)
#     ax1.scatter(x, np.abs(finetuned_forget[sort_indices]), alpha=0.15, 
#                 label='Finetuned Forget', color='red', s=5)
    
#     ax1.set_yscale('log')
#     ax1.set_xlabel('Feature Index (sorted by baseline retain activation)')
#     ax1.set_ylabel('Activation Magnitude (log scale)')
#     ax1.set_title('Feature Activations')
#     ax1.legend()
#     ax1.grid(True, which="both", ls="-", alpha=0.1)
    
#     # 2. Selectivity analysis (log scale)
#     ax2 = plt.subplot(222)
#     baseline_selectivity = np.abs(baseline_retain) / (np.abs(baseline_forget) + 1e-10)
#     finetuned_selectivity = np.abs(finetuned_retain) / (np.abs(finetuned_forget) + 1e-10)
    
#     select_sort = np.argsort(-baseline_selectivity)
    
#     # Plot selectivity curves
#     ax2.plot(baseline_selectivity[select_sort], color='darkblue', 
#              label='Baseline', alpha=0.6, linewidth=1)
#     ax2.plot(finetuned_selectivity[select_sort], color='red',
#              label='Finetuned', alpha=0.6, linewidth=1)
    
#     ax2.set_yscale('log')
#     ax2.set_xlabel('Feature Index (sorted by baseline selectivity)')
#     ax2.set_ylabel('Selectivity |retain|/|forget| (log scale)')
#     ax2.set_title('Feature Selectivity')
#     ax2.legend()
#     ax2.grid(True, which="both", ls="-", alpha=0.1)
    
#     # 3. Distribution of forget activations
#     ax3 = plt.subplot(223)
    
#     # Use hexbin for better density visualization
#     hb = ax3.hexbin(np.abs(baseline_forget), np.abs(finetuned_forget), 
#                     gridsize=30, bins='log', cmap='YlOrRd')
#     plt.colorbar(hb, ax=ax3, label='Count (log scale)')
    
#     # Add reference line
#     max_val = max(np.max(np.abs(baseline_forget)), np.max(np.abs(finetuned_forget)))
#     min_val = min(np.min(np.abs(baseline_forget[baseline_forget != 0])), 
#                  np.min(np.abs(finetuned_forget[finetuned_forget != 0])))
#     ax3.plot([min_val, max_val], [min_val, max_val], '--', color='gray', alpha=0.5)
    
#     ax3.set_xscale('log')
#     ax3.set_yscale('log')
#     ax3.set_xlabel('Baseline Forget Activation (log scale)')
#     ax3.set_ylabel('Finetuned Forget Activation (log scale)')
#     ax3.set_title('Forget Activation Changes\n(density plot)')
#     ax3.grid(True, which="both", ls="-", alpha=0.1)
    
#     # 4. Change ratio analysis
#     ax4 = plt.subplot(224)
#     change_ratio_retain = np.abs(finetuned_retain) / (np.abs(baseline_retain) + 1e-10)
#     change_ratio_forget = np.abs(finetuned_forget) / (np.abs(baseline_forget) + 1e-10)
    
#     # Calculate moving averages
#     window = 500
#     retain_ma = np.convolve(change_ratio_retain[sort_indices], 
#                            np.ones(window)/window, mode='valid')
#     forget_ma = np.convolve(change_ratio_forget[sort_indices], 
#                            np.ones(window)/window, mode='valid')
#     x_ma = np.arange(len(retain_ma))
    
#     # Plot moving averages and raw data
#     ax4.plot(x_ma, retain_ma, color='blue', label='Retain (moving avg)', 
#              alpha=0.8, linewidth=2)
#     ax4.plot(x_ma, forget_ma, color='red', label='Forget (moving avg)', 
#              alpha=0.8, linewidth=2)
    
#     ax4.scatter(x, change_ratio_retain[sort_indices], alpha=0.05, 
#                 color='blue', s=1)
#     ax4.scatter(x, change_ratio_forget[sort_indices], alpha=0.05, 
#                 color='red', s=1)
    
#     ax4.axhline(y=1, color='gray', linestyle='--', alpha=0.5)
#     ax4.set_yscale('log')
#     ax4.set_xlabel('Feature Index (sorted by baseline retain activation)')
#     ax4.set_ylabel('Change Ratio (finetuned/baseline) (log scale)')
#     ax4.set_title('Activation Change Ratios')
#     ax4.legend()
#     ax4.grid(True, which="both", ls="-", alpha=0.1)
    
#     plt.tight_layout()
#     return fig

# def print_activation_summary(baseline_forget, baseline_retain, 
#                            finetuned_forget, finetuned_retain):
#     """
#     Calculate and print summary statistics for the activation analysis.
#     """
#     # Calculate selectivity metrics
#     baseline_selectivity = np.abs(baseline_retain) / (np.abs(baseline_forget) + 1e-10)
#     finetuned_selectivity = np.abs(finetuned_retain) / (np.abs(finetuned_forget) + 1e-10)
    
#     change_ratio_retain = np.abs(finetuned_retain) / (np.abs(baseline_retain) + 1e-10)
#     change_ratio_forget = np.abs(finetuned_forget) / (np.abs(baseline_forget) + 1e-10)
    
#     print("\nActivation Analysis Summary:")
#     print(f"\nMedian Selectivity (|retain|/|forget|):")
#     print(f"Baseline: {np.median(baseline_selectivity):.4f}")
#     print(f"Finetuned: {np.median(finetuned_selectivity):.4f}")
    
#     print(f"\nMedian Change Ratios (finetuned/baseline):")
#     print(f"Retain: {np.median(change_ratio_retain):.4f}")
#     print(f"Forget: {np.median(change_ratio_forget):.4f}")
    
#     # Calculate improvement metrics
#     selectivity_improved = np.mean(finetuned_selectivity > baseline_selectivity) * 100
#     retain_preserved = np.mean(change_ratio_retain > 0.9) * 100
#     forget_reduced = np.mean(change_ratio_forget < 0.9) * 100
    
#     print(f"\nPercentage of features with:")
#     print(f"Improved selectivity: {selectivity_improved:.1f}%")
#     print(f"Preserved retain activation (>90%): {retain_preserved:.1f}%")
#     print(f"Reduced forget activation (<90%): {forget_reduced:.1f}%")

# def main():
#     """
#     Main function to load data and generate visualizations.
#     """
#     # Define paths
#     base_dir = "/data/aashiq_muhamed/unlearning/SAEBench/eval_results/activation_mcq"
#     model_subdir = "sae_bench_gemma-2-2b_topk_width-2pow14_date-1109_blocks.5.hook_resid_post__trainer_2/gemma-2-2b-it/results"
    
#     # Load baseline data
#     baseline_path = os.path.join(base_dir, "base_model", model_subdir)
#     baseline_forget = np.load(os.path.join(baseline_path, "activations/feature_acts_forget.npy"))
#     baseline_retain = np.load(os.path.join(baseline_path, "activations/feature_acts_retain.npy"))
    
#     # Load finetuned data
#     finetuned_path = os.path.join(base_dir, "alpha_0.1", model_subdir)
#     finetuned_forget = np.load(os.path.join(finetuned_path, "activations/feature_acts_forget.npy"))
#     finetuned_retain = np.load(os.path.join(finetuned_path, "activations/feature_acts_retain.npy"))
    
#     # Print data shapes
#     print("Data shapes:")
#     print(f"Baseline forget: {baseline_forget.shape}")
#     print(f"Baseline retain: {baseline_retain.shape}")
#     print(f"Finetuned forget: {finetuned_forget.shape}")
#     print(f"Finetuned retain: {finetuned_retain.shape}")
    
#     # Generate visualizations
#     fig = plot_aggregated_activations(baseline_forget, baseline_retain,
#                                     finetuned_forget, finetuned_retain)
    
#     # Save the figure
#     fig.savefig('activation_analysis_log.png', dpi=300, bbox_inches='tight')
#     plt.close(fig)
    
#     # Print summary statistics
#     print_activation_summary(baseline_forget, baseline_retain,
#                            finetuned_forget, finetuned_retain)

# if __name__ == "__main__":
#     main()

In [27]:
# import numpy as np
# import matplotlib.pyplot as plt
# from matplotlib.colors import LogNorm  # Added this import
# import seaborn as sns
# import os
# from pathlib import Path

# def plot_improved_activations(baseline_forget, baseline_retain, 
#                             finetuned_forget, finetuned_retain):
#     """
#     Create improved visualizations of activation patterns with better statistical measures
#     and clearer density plotting.
#     """
#     fig = plt.figure(figsize=(20, 15))
    
#     # 1. Feature activation changes (log scale)
#     ax1 = plt.subplot(221)
#     sort_indices = np.argsort(-baseline_retain)
#     x = np.arange(len(baseline_retain))
    
#     # Plot with reduced alpha and increased size for better visibility
#     ax1.scatter(x, np.abs(baseline_retain[sort_indices]), alpha=0.2, 
#                 label='Baseline Retain', color='darkblue', s=10)
#     ax1.scatter(x, np.abs(finetuned_retain[sort_indices]), alpha=0.2, 
#                 label='Finetuned Retain', color='blue', s=10)
#     ax1.scatter(x, np.abs(baseline_forget[sort_indices]), alpha=0.2, 
#                 label='Baseline Forget', color='darkred', s=10)
#     ax1.scatter(x, np.abs(finetuned_forget[sort_indices]), alpha=0.2, 
#                 label='Finetuned Forget', color='red', s=10)
    
#     ax1.set_yscale('log')
#     ax1.set_xlabel('Feature Index (sorted by baseline retain activation)')
#     ax1.set_ylabel('Activation Magnitude (log scale)')
#     ax1.set_title('Feature Activations')
#     ax1.legend()
#     ax1.grid(True, which="both", ls="-", alpha=0.2)
    
#     # 2. Selectivity analysis with percentile bands
#     ax2 = plt.subplot(222)
#     baseline_selectivity = np.abs(baseline_retain) / (np.abs(baseline_forget) + 1e-10)
#     finetuned_selectivity = np.abs(finetuned_retain) / (np.abs(finetuned_forget) + 1e-10)
    
#     select_sort = np.argsort(-baseline_selectivity)
    
#     # Calculate percentile bands
#     def rolling_percentile(data, window=500):
#         result = np.zeros_like(data)
#         for i in range(len(data)):
#             start_idx = max(0, i - window//2)
#             end_idx = min(len(data), i + window//2)
#             result[i] = np.median(data[start_idx:end_idx])
#         return result
    
#     baseline_median = rolling_percentile(baseline_selectivity[select_sort])
#     finetuned_median = rolling_percentile(finetuned_selectivity[select_sort])
    
#     ax2.plot(baseline_median, color='darkblue', label='Baseline (median)', linewidth=2)
#     ax2.plot(finetuned_median, color='red', label='Finetuned (median)', linewidth=2)
    
#     ax2.set_yscale('log')
#     ax2.set_xlabel('Feature Index (sorted by baseline selectivity)')
#     ax2.set_ylabel('Selectivity |retain|/|forget| (log scale)')
#     ax2.set_title('Feature Selectivity (with rolling median)')
#     ax2.legend()
#     ax2.grid(True, which="both", ls="-", alpha=0.2)
    
#     # 3. Improved density plot of forget activations
#     ax3 = plt.subplot(223)
    
#     # Create 2D histogram with more bins and better color scaling
#     valid_mask = (baseline_forget != 0) & (finetuned_forget != 0)
#     h, xedges, yedges = np.histogram2d(
#         np.log10(np.abs(baseline_forget[valid_mask]) + 1e-10),
#         np.log10(np.abs(finetuned_forget[valid_mask]) + 1e-10),
#         bins=50,
#         range=[[-6, 2], [-6, 2]]  # Set fixed range for better visualization
#     )
    
#     # Plot as a heatmap with improved visibility
#     pcm = ax3.pcolormesh(xedges, yedges, h.T, 
#                          norm=LogNorm(vmin=1, vmax=h.max()),
#                          cmap='viridis')
#     plt.colorbar(pcm, ax=ax3, label='Count (log scale)')
    
#     # Add reference line
#     ax3.plot([-6, 2], [-6, 2], '--', color='red', alpha=0.8, label='y=x')
    
#     ax3.set_xlabel('Baseline Forget Activation (log10 scale)')
#     ax3.set_ylabel('Finetuned Forget Activation (log10 scale)')
#     ax3.set_title('Forget Activation Changes\n(improved density visualization)')
#     ax3.legend()
#     ax3.grid(True, which="both", ls="-", alpha=0.2)
    
#     # 4. Change ratio analysis with quantile regression
#     ax4 = plt.subplot(224)
#     change_ratio_retain = np.abs(finetuned_retain) / (np.abs(baseline_retain) + 1e-10)
#     change_ratio_forget = np.abs(finetuned_forget) / (np.abs(baseline_forget) + 1e-10)
    
#     # Calculate quantiles instead of moving average
#     def rolling_quantiles(data, window=500):
#         result_50 = np.zeros_like(data)
#         result_25 = np.zeros_like(data)
#         result_75 = np.zeros_like(data)
        
#         for i in range(len(data)):
#             start_idx = max(0, i - window//2)
#             end_idx = min(len(data), i + window//2)
#             window_data = data[start_idx:end_idx]
#             result_50[i] = np.percentile(window_data, 50)
#             result_25[i] = np.percentile(window_data, 25)
#             result_75[i] = np.percentile(window_data, 75)
        
#         return result_25, result_50, result_75
    
#     # Calculate and plot quantiles for both retain and forget
#     retain_25, retain_50, retain_75 = rolling_quantiles(change_ratio_retain[sort_indices])
#     forget_25, forget_50, forget_75 = rolling_quantiles(change_ratio_forget[sort_indices])
    
#     x = np.arange(len(retain_50))
    
#     # Plot with confidence bands
#     ax4.fill_between(x, retain_25, retain_75, color='blue', alpha=0.2)
#     ax4.fill_between(x, forget_25, forget_75, color='red', alpha=0.2)
    
#     ax4.plot(x, retain_50, color='blue', label='Retain (median)', linewidth=2)
#     ax4.plot(x, forget_50, color='red', label='Forget (median)', linewidth=2)
    
#     ax4.axhline(y=1, color='gray', linestyle='--', alpha=0.5)
#     ax4.set_yscale('log')
#     ax4.set_xlabel('Feature Index (sorted by baseline retain activation)')
#     ax4.set_ylabel('Change Ratio (finetuned/baseline) (log scale)')
#     ax4.set_title('Activation Change Ratios with Confidence Bands')
#     ax4.legend()
#     ax4.grid(True, which="both", ls="-", alpha=0.2)
    
#     plt.tight_layout()
#     return fig

# def print_activation_summary(baseline_forget, baseline_retain, 
#                            finetuned_forget, finetuned_retain):
#     """
#     Calculate and print improved summary statistics for the activation analysis.
#     """
#     # Calculate selectivity metrics
#     baseline_selectivity = np.abs(baseline_retain) / (np.abs(baseline_forget) + 1e-10)
#     finetuned_selectivity = np.abs(finetuned_retain) / (np.abs(finetuned_forget) + 1e-10)
    
#     change_ratio_retain = np.abs(finetuned_retain) / (np.abs(baseline_retain) + 1e-10)
#     change_ratio_forget = np.abs(finetuned_forget) / (np.abs(baseline_forget) + 1e-10)
    
#     print("\nImproved Activation Analysis Summary:")
#     print("\nSelectivity Statistics (|retain|/|forget|):")
#     print(f"Baseline - Median: {np.median(baseline_selectivity):.4f}, "
#           f"25th: {np.percentile(baseline_selectivity, 25):.4f}, "
#           f"75th: {np.percentile(baseline_selectivity, 75):.4f}")
#     print(f"Finetuned - Median: {np.median(finetuned_selectivity):.4f}, "
#           f"25th: {np.percentile(finetuned_selectivity, 25):.4f}, "
#           f"75th: {np.percentile(finetuned_selectivity, 75):.4f}")
    
#     print(f"\nChange Ratio Statistics (finetuned/baseline):")
#     print(f"Retain - Median: {np.median(change_ratio_retain):.4f}, "
#           f"25th: {np.percentile(change_ratio_retain, 25):.4f}, "
#           f"75th: {np.percentile(change_ratio_retain, 75):.4f}")
#     print(f"Forget - Median: {np.median(change_ratio_forget):.4f}, "
#           f"25th: {np.percentile(change_ratio_forget, 25):.4f}, "
#           f"75th: {np.percentile(change_ratio_forget, 75):.4f}")
    
#     # Calculate improvement metrics with confidence intervals
#     def bootstrap_mean(data, n_bootstrap=1000):
#         means = np.zeros(n_bootstrap)
#         for i in range(n_bootstrap):
#             sample = np.random.choice(data, size=len(data), replace=True)
#             means[i] = np.mean(sample)
#         return np.mean(means), np.percentile(means, [2.5, 97.5])
    
#     selectivity_improved = finetuned_selectivity > baseline_selectivity
#     retain_preserved = change_ratio_retain > 0.9
#     forget_reduced = change_ratio_forget < 0.9
    
#     metrics = {
#         'Improved selectivity': selectivity_improved,
#         'Preserved retain activation (>90%)': retain_preserved,
#         'Reduced forget activation (<90%)': forget_reduced
#     }
    
#     print("\nFeature Improvement Analysis (with 95% confidence intervals):")
#     for metric_name, metric_data in metrics.items():
#         mean, (ci_low, ci_high) = bootstrap_mean(metric_data)
#         print(f"{metric_name}: {mean*100:.1f}% [{ci_low*100:.1f}%, {ci_high*100:.1f}%]")

# def main():
#     """
#     Main function to load data and generate improved visualizations.
#     """
#     # Define paths
#     base_dir = Path("/data/aashiq_muhamed/unlearning/SAEBench/eval_results/activation_mcq")
#     model_subdir = "sae_bench_gemma-2-2b_topk_width-2pow14_date-1109_blocks.5.hook_resid_post__trainer_2/gemma-2-2b-it/results"
    
#     # Load baseline data
#     baseline_path = base_dir / "base_model" / model_subdir
#     baseline_forget = np.load(baseline_path / "activations/feature_acts_forget.npy")
#     baseline_retain = np.load(baseline_path / "activations/feature_acts_retain.npy")
    
#     # Load finetuned data
#     finetuned_path = base_dir / "alpha_0.1" / model_subdir
#     finetuned_forget = np.load(finetuned_path / "activations/feature_acts_forget.npy")
#     finetuned_retain = np.load(finetuned_path / "activations/feature_acts_retain.npy")
    
#     # Print data shapes and basic statistics
#     print("Data Analysis:")
#     print(f"Shape of activation arrays: {baseline_forget.shape}")
#     print(f"\nBasic Statistics:")
#     print(f"{'Dataset':<15} {'Min':>10} {'Max':>10} {'Mean':>10} {'Median':>10}")
#     print("-" * 60)
#     for name, data in [
#         ("Baseline Forget", baseline_forget),
#         ("Baseline Retain", baseline_retain),
#         ("Finetuned Forget", finetuned_forget),
#         ("Finetuned Retain", finetuned_retain)
#     ]:
#         print(f"{name:<15} {np.min(data):10.2e} {np.max(data):10.2e} "
#               f"{np.mean(data):10.2e} {np.median(data):10.2e}")
    
#     # Generate improved visualizations
#     fig = plot_improved_activations(baseline_forget, baseline_retain,
#                                   finetuned_forget, finetuned_retain)
    
#     # Save the figure with high resolution
#     plt.savefig('improved_activation_analysis.png', dpi=300, bbox_inches='tight')
#     plt.close(fig)
    
#     # Print detailed summary statistics
#     print_activation_summary(baseline_forget, baseline_retain,
#                            finetuned_forget, finetuned_retain)

# if __name__ == "__main__":
#     main()

In [29]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import pandas as pd
from matplotlib.colors import LogNorm

def plot_distributions(baseline_forget, baseline_retain, 
                      finetuned_forget, finetuned_retain,
                      save_path='distribution_plot.png'):
    """
    Save violin plots of activation magnitudes before and after finetuning.

    The violins are drawn over log10(|activation|). The kernel density is
    therefore estimated in log space. Estimating it on raw magnitudes and
    then switching the axis to a log scale distorts the violin shapes.
    Exactly-zero (and NaN) magnitudes have no logarithm, so they are
    excluded from the plot.

    Args:
        baseline_forget, baseline_retain, finetuned_forget, finetuned_retain:
            Array-like activations; each argument becomes one violin.
        save_path: Path of the image file to write.
    """
    groups = {
        'Baseline Retain': baseline_retain,
        'Finetuned Retain': finetuned_retain,
        'Baseline Forget': baseline_forget,
        'Finetuned Forget': finetuned_forget,
    }

    log_values, labels = [], []
    for name, acts in groups.items():
        magnitudes = np.abs(np.asarray(acts, dtype=float)).ravel()
        # log10 is undefined for 0 (and NaN fails the comparison), so keep only positives.
        magnitudes = magnitudes[magnitudes > 0]
        log_values.append(np.log10(magnitudes))
        labels.append(np.repeat(name, magnitudes.size))

    df = pd.DataFrame({
        'Activation': np.concatenate(log_values),
        'Type': np.concatenate(labels),
    })

    plt.figure(figsize=(12, 8))
    # An explicit order keeps the category layout stable even if a group is empty.
    sns.violinplot(data=df, x='Type', y='Activation', order=list(groups), cut=0)
    plt.xticks(rotation=45)
    plt.ylabel('Activation Magnitude (log10)')
    plt.title('Distribution of Activation Magnitudes')
    plt.grid(True, which="both", ls="-", alpha=0.2)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

def plot_feature_comparison(baseline_forget, baseline_retain, 
                          finetuned_forget, finetuned_retain,
                          save_path='feature_comparison.png'):
    """
    Save a log-log scatter plot of retain vs forget activation ratios.

    Each point pairs one element of the retain ratio array with the same
    element of the forget ratio array. Both ratios are
    |finetuned| / (|baseline| + 1e-10), so a smaller ratio means stronger
    suppression. Points below the y=x line have a smaller forget ratio than
    retain ratio, i.e. forget is suppressed more strongly. Points above the
    line have retain suppressed more strongly.

    Only points with both ratios strictly positive and finite are drawn,
    because anything else cannot be placed on log axes.

    Args:
        baseline_forget, baseline_retain, finetuned_forget, finetuned_retain:
            Array-like activations of matching shape.
        save_path: Path of the image file to write.
    """
    eps = 1e-10  # prevents division by zero for inactive baseline entries
    retain_ratio = np.abs(finetuned_retain) / (np.abs(baseline_retain) + eps)
    forget_ratio = np.abs(finetuned_forget) / (np.abs(baseline_forget) + eps)

    # Zero ratios (finetuned activation exactly 0) have no position on a log axis.
    plottable = ((retain_ratio > 0) & (forget_ratio > 0)
                 & np.isfinite(retain_ratio) & np.isfinite(forget_ratio))
    retain_ratio = retain_ratio[plottable]
    forget_ratio = forget_ratio[plottable]

    plt.figure(figsize=(10, 10))
    plt.scatter(retain_ratio, forget_ratio, alpha=0.1, s=1)

    # Diagonal spans the positive data range; skipped when nothing is plottable.
    if retain_ratio.size:
        lo = min(retain_ratio.min(), forget_ratio.min())
        hi = max(retain_ratio.max(), forget_ratio.max())
        plt.plot([lo, hi], [lo, hi], 'r--', label='Equal suppression')

    plt.xscale('log')
    plt.yscale('log')
    plt.xlabel('Retain Activation Ratio (Finetuned/Baseline)')
    plt.ylabel('Forget Activation Ratio (Finetuned/Baseline)')
    plt.title('Feature-wise Comparison of Activation Changes')

    # Top-left: small retain ratio, large forget ratio -> retain suppressed more.
    plt.text(0.05, 0.95, 'Stronger retain suppression', 
             transform=plt.gca().transAxes, 
             verticalalignment='top')
    # Bottom-right: large retain ratio, small forget ratio -> forget suppressed more.
    plt.text(0.95, 0.05, 'Stronger forget suppression', 
             transform=plt.gca().transAxes, 
             horizontalalignment='right')

    plt.grid(True, which="both", ls="-", alpha=0.2)
    plt.legend()
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

def plot_magnitude_analysis(baseline_forget, baseline_retain, 
                          finetuned_forget, finetuned_retain,
                          save_path='magnitude_analysis.png'):
    """
    Save a plot of rolling activation-magnitude quartiles across features.

    All four series are reordered by descending (signed) baseline retain
    activation. For each series, the rolling median of |activation| is drawn
    as a line and the rolling 25th-75th percentile range as a shaded band.

    Args:
        baseline_forget, baseline_retain, finetuned_forget, finetuned_retain:
            1-D array-like activations of equal length.
        save_path: Path of the image file to write.
    """
    plt.figure(figsize=(15, 10))

    # One shared ordering so the same x position refers to the same feature in every series.
    sort_indices = np.argsort(-(baseline_retain))
    x = np.arange(len(baseline_retain))

    def rolling_stats(data, window=500):
        """Return rolling (25th, 50th, 75th) percentiles of `data`.

        The window for index i is data[max(0, i - window//2) : min(n, i + window//2)],
        so windows shrink near both ends of the array.
        """
        half = window // 2
        n = len(data)
        quartiles = np.empty((3, n), dtype=float)
        for i in range(n):
            window_data = data[max(0, i - half):min(n, i + half)]
            # A single call partitions the window once for all three quantiles.
            quartiles[:, i] = np.percentile(window_data, [25, 50, 75])
        return quartiles[0], quartiles[1], quartiles[2]

    series = [
        ('Baseline Retain', baseline_retain, 'darkblue'),
        ('Finetuned Retain', finetuned_retain, 'blue'),
        ('Baseline Forget', baseline_forget, 'darkred'),
        ('Finetuned Forget', finetuned_forget, 'red'),
    ]
    stats = [
        (label, color, rolling_stats(np.abs(acts[sort_indices])))
        for label, acts, color in series
    ]

    # Draw every IQR band first so all median lines sit on top of all bands.
    for _, color, (q25, _, q75) in stats:
        plt.fill_between(x, q25, q75, alpha=0.2, color=color)
    for label, color, (_, median, _) in stats:
        plt.plot(x, median, color=color, label=label, linewidth=2)

    plt.yscale('log')
    plt.xlabel('Feature Index (sorted by baseline retain activation)')
    plt.ylabel('Activation Magnitude (log scale)')
    plt.title('Activation Magnitudes Across Features')
    plt.grid(True, which="both", ls="-", alpha=0.2)
    plt.legend()
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

def main(base_dir="/data/aashiq_muhamed/unlearning/SAEBench/eval_results/activation_mcq",
         model_subdir="sae_bench_gemma-2-2b_topk_width-2pow14_date-1109_blocks.5.hook_resid_post__trainer_2/gemma-2-2b-it/results",
         baseline_name="base_model",
         finetuned_name="alpha_0.1"):
    """
    Load baseline/finetuned activations, save all plots, and print a summary.

    Activations are read from
    ``<base_dir>/<run>/<model_subdir>/activations/feature_acts_{forget,retain}.npy``
    for ``run`` in (``baseline_name``, ``finetuned_name``). The defaults
    reproduce the original hard-coded paths.

    Raises:
        ValueError: if the four activation arrays do not all share one shape.
    """
    base_dir = Path(base_dir)

    def load_acts(run_name):
        """Return (forget, retain) activation arrays for one run directory."""
        acts_dir = base_dir / run_name / model_subdir / "activations"
        return (np.load(acts_dir / "feature_acts_forget.npy"),
                np.load(acts_dir / "feature_acts_retain.npy"))

    baseline_forget, baseline_retain = load_acts(baseline_name)
    finetuned_forget, finetuned_retain = load_acts(finetuned_name)

    # Element-wise ratios below require aligned arrays; fail loudly rather than broadcast.
    shapes = {
        'baseline_forget': baseline_forget.shape,
        'baseline_retain': baseline_retain.shape,
        'finetuned_forget': finetuned_forget.shape,
        'finetuned_retain': finetuned_retain.shape,
    }
    if len(set(shapes.values())) != 1:
        raise ValueError(f"Activation arrays must all have the same shape, got {shapes}")

    # Generate all visualizations
    plot_distributions(baseline_forget, baseline_retain,
                      finetuned_forget, finetuned_retain,
                      'activation_distributions.png')
    
    plot_feature_comparison(baseline_forget, baseline_retain,
                          finetuned_forget, finetuned_retain,
                          'feature_comparison.png')
    
    plot_magnitude_analysis(baseline_forget, baseline_retain,
                          finetuned_forget, finetuned_retain,
                          'magnitude_analysis.png')
    
    # Print summary statistics
    print("\nSummary Statistics:")
    print("\nMedian Activation Magnitudes:")
    print(f"Baseline Retain: {np.median(np.abs(baseline_retain)):.4e}")
    print(f"Finetuned Retain: {np.median(np.abs(finetuned_retain)):.4e}")
    print(f"Baseline Forget: {np.median(np.abs(baseline_forget)):.4e}")
    print(f"Finetuned Forget: {np.median(np.abs(finetuned_forget)):.4e}")

    # Per-element finetuned/baseline ratios, computed once and reused below.
    eps = 1e-10
    retain_ratios = np.abs(finetuned_retain) / (np.abs(baseline_retain) + eps)
    forget_ratios = np.abs(finetuned_forget) / (np.abs(baseline_forget) + eps)

    retain_ratio = np.median(retain_ratios)
    forget_ratio = np.median(forget_ratios)
    
    print("\nMedian Suppression Ratios (Finetuned/Baseline):")
    print(f"Retain: {retain_ratio:.4f}")
    print(f"Forget: {forget_ratio:.4f}")
    
    # A smaller ratio means stronger suppression of that activation.
    stronger_forget = np.mean(forget_ratios < retain_ratios) * 100
    
    print(f"\nPercentage of features with stronger forget suppression: {stronger_forget:.1f}%")

# Entry point: runs main() when executed as a script or as a notebook cell
# (both execute with __name__ == "__main__").
if __name__ == "__main__":
    main()


Summary Statistics:

Median Activation Magnitudes:
Baseline Retain: 7.2640e-03
Finetuned Retain: 2.6610e-05
Baseline Forget: 5.9969e-03
Finetuned Forget: 7.1051e-05

Median Suppression Ratios (Finetuned/Baseline):
Retain: 0.0020
Forget: 0.0158

Percentage of features with stronger forget suppression: 26.4%
