In [None]:
import os
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
# Setup paths
# Directory holding the *_ripple_results.csv files. Override it with the
# RIPPLE_BENCH_DIR environment variable (e.g. RIPPLE_BENCH_DIR=data/eval_parallelized/
# for a repo-relative copy) instead of editing this cell; the default is the
# original author's local results directory.
BASE_DIR = Path(os.environ.get("RIPPLE_BENCH_DIR", "/Users/roy/data/ripple_bench/parallelized"))

# Base-model results, keyed by model family
BASE_RESULTS = {
    "llama": BASE_DIR / "Llama-3-8b-Instruct_ripple_results.csv",
    "zephyr": BASE_DIR / "zephyr-7b-beta_ripple_results.csv",
}

# Unlearned-checkpoint results: model family -> unlearning method -> CSV path
UNLEARNING_RESULTS = {
    "llama": {
        "elm": BASE_DIR / "llama-3-8b-instruct-elm-ckpt7_ripple_results.csv",
        "graddiff": BASE_DIR / "llama-3-8b-instruct-graddiff-ckpt8_ripple_results.csv",
        "pbj": BASE_DIR / "llama-3-8b-instruct-pbj-ckpt6_ripple_results.csv",
        "repnoise": BASE_DIR / "llama-3-8b-instruct-repnoise-ckpt6_ripple_results.csv",
        "rmu": BASE_DIR / "llama-3-8b-instruct-rmu-ckpt6_ripple_results.csv",
        "rmu_lat": BASE_DIR / "llama-3-8b-instruct-rmu-lat-ckpt7_ripple_results.csv",
        "rr": BASE_DIR / "llama-3-8b-instruct-rr-ckpt8_ripple_results.csv",
        "tar": BASE_DIR / "llama-3-8b-instruct-tar-ckpt8_ripple_results.csv",
    },
    "zephyr": {
        "elm": BASE_DIR / "zephyr-7b-elm_ripple_results.csv",
    }
}

# Output directory for saved figures (created if missing)
PLOT_DIR = Path("plots/")
PLOT_DIR.mkdir(parents=True, exist_ok=True)

## Identify WMDP Topics

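First, identify which WMDP topics appear in the dataset. Rows at distance 0 are the WMDP questions themselves. The `source_topic` values of those rows therefore form the WMDP topic list used throughout this notebook, ordered from most to fewest distance-0 questions.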

In [None]:
# Explore the base Llama results to find the WMDP topics they contain
df_base = pd.read_csv(BASE_RESULTS["llama"])

print(f"Available columns: {df_base.columns.tolist()}")
print(f"\nTotal data points: {len(df_base)}")
print(f"Total unique questions: {df_base['question'].nunique()}")

# Distance-0 rows are the WMDP questions themselves
df_wmdp = df_base.loc[df_base['distance'] == 0]
print(f"\nQuestions at distance 0 (WMDP): {len(df_wmdp)}")
print(f"Unique questions at distance 0: {df_wmdp['question'].nunique()}")

if 'source_topic' not in df_base.columns:
    # No topic labels: show some distance-0 questions so the schema can be inspected by hand
    print("\nNo 'source_topic' column found. Looking for alternative way to identify topics...")
    print("\nSample WMDP questions (distance 0):")
    print(df_wmdp['question'].head(10).tolist())
    wmdp_topic_list = []
else:
    # value_counts() orders topics from most to fewest distance-0 rows
    wmdp_topics = df_wmdp['source_topic'].value_counts()
    print("\nWMDP topics (distance 0):")
    print(wmdp_topics.head(20))
    wmdp_topic_list = wmdp_topics.index.tolist()

In [None]:
# Summarize how each of the top WMDP topics' questions spread over distance
if 'source_topic' in df_base.columns and wmdp_topic_list:
    print("Distribution of questions by distance for top WMDP topics:\n")

    # wmdp_topic_list is ordered by distance-0 question count, so these are the 5 largest topics
    for topic in wmdp_topic_list[:5]:
        topic_df = df_base[df_base['source_topic'] == topic]
        topic_distance = topic_df['distance']

        print(f"Topic: {topic}")
        print(f"  Total questions: {len(topic_df)}")
        print(f"  Unique questions: {topic_df['question'].nunique()}")
        print(f"  Distance range: {topic_distance.min():.0f} - {topic_distance.max():.0f}")
        # Inclusive upper bound (<= 10), unlike the half-open buckets used in the plots
        print(f"  Questions at distance 0-10: {int((topic_distance <= 10).sum())}")
        print()

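## Ripple Effects from WMDP Topics

For a single WMDP topic, plot accuracy against distance bucket.

- *Raw* curves average every evaluation row in a bucket.
- *Dedup* curves count each question once, using its mean correctness. Each question is placed in the smallest distance bucket in which it appears.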

In [None]:
def analyze_wmdp_topic_ripple(wmdp_topic, bucket_size=10, max_distance=100):
    """Analyze ripple effects emanating from a specific WMDP topic.

    Produces a two-panel figure:

    * left  -- accuracy vs. distance bucket for every base model in
      ``BASE_RESULTS``. Raw curves (dotted) count every evaluation row.
      Dedup curves (solid) count each question once, at its smallest
      distance bucket, using its mean correctness.
    * right -- the Llama base model against every method in
      ``UNLEARNING_RESULTS["llama"]``.

    The figure is saved as PNG and PDF under ``PLOT_DIR``. Summary
    statistics are then printed for each base model that has rows for the
    topic.

    Args:
        wmdp_topic: Value of the ``source_topic`` column to analyze.
        bucket_size: Width of the distance buckets; distances are floored to
            a multiple of this value.
        max_distance: Rows with ``distance >= max_distance`` are ignored.

    Returns:
        None. Returns early, closing the figure, if a base results file has
        no ``source_topic`` column.
    """
    METHOD_COLORS = {
        'elm': '#FF6B6B', 'rmu': '#4ECDC4', 'graddiff': '#95E77E',
        'pbj': '#FFD93D', 'tar': '#A8E6CF', 'rmu_lat': '#FF8B94',
        'repnoise': '#B4A7D6', 'rr': '#FFB347'
    }

    def _topic_rows(frame):
        """Return this topic's rows within ``max_distance``, adding a ``distance_bucket`` column."""
        frame = frame[(frame['source_topic'] == wmdp_topic) & (frame["distance"] < max_distance)]
        # assign() builds a new frame, avoiding SettingWithCopyWarning on the filtered slice
        return frame.assign(distance_bucket=(frame["distance"] // bucket_size) * bucket_size)

    def _accuracy_by_bucket(frame):
        """Return the mean, std, count and standard error of ``is_correct`` per ``distance_bucket``."""
        stats = frame.groupby("distance_bucket")["is_correct"].agg(["mean", "std", "count"])
        stats["sem"] = stats["std"] / np.sqrt(stats["count"])
        return stats

    def _dedup(frame):
        """Return one row per question: mean correctness, placed in the question's smallest bucket."""
        return frame.groupby("question").agg(
            is_correct=("is_correct", "mean"),
            distance_bucket=("distance_bucket", "min"),
        )

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
    model_frames = {}  # model name -> this topic's rows, reused for ax2 and the statistics

    # Left panel: every base model, raw and dedup
    for i, (model_name, file_path) in enumerate(BASE_RESULTS.items()):
        df = pd.read_csv(file_path)
        if 'source_topic' not in df.columns:
            print(f"Warning: Cannot filter by topic, no 'source_topic' column")
            plt.close(fig)  # don't leave an empty figure behind
            return

        df = _topic_rows(df)
        if len(df) == 0:
            print(f"No data found for WMDP topic: {wmdp_topic}")
            continue
        model_frames[model_name] = df

        color = f"C{i}"  # matplotlib colour cycle; works for any number of base models
        raw_results = _accuracy_by_bucket(df)
        dedup_results = _accuracy_by_bucket(_dedup(df))

        # Raw results: dotted line
        ax1.errorbar(raw_results.index, raw_results["mean"] * 100, yerr=raw_results["sem"] * 100,
                     marker='o', linewidth=2, markersize=8, capsize=5, linestyle=':',
                     color=color, label=f'{model_name.title()} Raw')

        # Dedup results: solid line
        ax1.errorbar(dedup_results.index, dedup_results["mean"] * 100, yerr=dedup_results["sem"] * 100,
                     marker='s', linewidth=2, markersize=8, capsize=5, linestyle='-',
                     color=color, label=f'{model_name.title()} Dedup')

    ax1.set_xlabel('Distance Bucket', fontsize=12)
    ax1.set_ylabel('Accuracy (%)', fontsize=12)
    ax1.set_title(f'WMDP Topic: {wmdp_topic}\nRaw vs Dedup Results', fontsize=14)
    if ax1.get_legend_handles_labels()[0]:
        ax1.legend(loc="best", fontsize=11)
    ax1.grid(True, alpha=0.3)
    ax1.set_ylim(0, 100)

    # Right panel: Llama base vs. every Llama unlearning method
    df_llama = model_frames.get("llama")
    if df_llama is not None:
        base_results = _accuracy_by_bucket(df_llama)
        ax2.errorbar(base_results.index, base_results["mean"] * 100, yerr=base_results["sem"] * 100,
                     marker='o', linewidth=3, markersize=8, capsize=5,
                     color='black', label='Llama Base', alpha=0.9, zorder=10)

    for method, path in UNLEARNING_RESULTS.get("llama", {}).items():
        df_unlearn = pd.read_csv(path)
        if 'source_topic' not in df_unlearn.columns:
            continue
        df_unlearn = _topic_rows(df_unlearn)
        if len(df_unlearn) == 0:
            continue

        unlearn_results = _accuracy_by_bucket(df_unlearn)
        ax2.errorbar(unlearn_results.index, unlearn_results["mean"] * 100,
                     yerr=unlearn_results["sem"] * 100,
                     marker='s', linewidth=2, markersize=6, capsize=3,
                     color=METHOD_COLORS.get(method, '#888888'),
                     label=method.upper().replace('_', '-'), alpha=0.8)

    ax2.set_xlabel('Distance Bucket', fontsize=12)
    ax2.set_ylabel('Accuracy (%)', fontsize=12)
    ax2.set_title(f'WMDP Topic: {wmdp_topic}\nUnlearning Methods Comparison', fontsize=14)
    if ax2.get_legend_handles_labels()[0]:
        ax2.legend(loc="best", fontsize=10, ncol=2)
    ax2.grid(True, alpha=0.3)
    ax2.set_ylim(0, 100)

    plt.tight_layout()

    # Save plot
    date_str = datetime.now().strftime("%Y%m%d-%H%M%S")
    safe_topic = wmdp_topic.replace("/", "_").replace(" ", "_")
    filename = f"wmdp_topic_ripple_{safe_topic}_{date_str}"
    plt.savefig(PLOT_DIR / f"{filename}.png", dpi=150, bbox_inches='tight')
    plt.savefig(PLOT_DIR / f"{filename}.pdf", bbox_inches='tight')
    print(f"Saved plot to {PLOT_DIR / filename}")

    plt.show()

    # Statistics for each base model (rows already limited to distance < max_distance)
    for model_name, df in model_frames.items():
        at_zero = df.loc[df['distance'] == 0, 'is_correct']
        print(f"\nStatistics for WMDP topic: {wmdp_topic} ({model_name.title()}, distance < {max_distance})")
        print(f"  Total evaluations: {len(df)}")
        print(f"  Unique questions: {df['question'].nunique()}")
        print(f"  Distance 0 accuracy: {at_zero.mean()*100:.1f}%" if len(at_zero) > 0 else "  No distance 0 data")
        print(f"  Overall accuracy: {df['is_correct'].mean()*100:.1f}%")

In [None]:
# Run the single-topic ripple analysis on the most frequent WMDP topic
if not wmdp_topic_list:
    print("No WMDP topics found. Please check the data structure.")
else:
    topic_to_analyze = wmdp_topic_list[0]
    print(f"Analyzing WMDP topic: {topic_to_analyze}")
    analyze_wmdp_topic_ripple(topic_to_analyze)

## Compare Multiple WMDP Topics

In [None]:
def compare_wmdp_topics(topics_list, model="llama", dedup=False, bucket_size=10, max_distance=100):
    """Compare ripple effects across multiple WMDP topics.

    Draws one subplot per topic in a 2x3 grid; unused panels are hidden and
    at most six topics are shown. Each panel plots accuracy vs. distance
    bucket for the base ``model`` and, when available, its ELM unlearned
    checkpoint. The figure is saved as PNG and PDF under ``PLOT_DIR``.

    Args:
        topics_list: ``source_topic`` values to plot; only the first six are used.
        model: Key into ``BASE_RESULTS`` (and ``UNLEARNING_RESULTS`` if present).
        dedup: If True, count each question once, using its mean correctness,
            at its smallest distance bucket. Otherwise every evaluation row counts.
        bucket_size: Width of the distance buckets.
        max_distance: Rows with ``distance >= max_distance`` are ignored.

    Returns:
        None. Prints a message and returns without plotting if the base
        results have no ``source_topic`` column.
    """
    def _accuracy_by_bucket(frame):
        """Return the mean, std, count and SEM of ``is_correct`` per distance bucket (raw or dedup)."""
        frame = frame[frame["distance"] < max_distance]
        # assign() builds a new frame, avoiding SettingWithCopyWarning on the filtered slice
        frame = frame.assign(distance_bucket=(frame["distance"] // bucket_size) * bucket_size)
        if dedup:
            # One row per question: mean correctness, placed in its smallest bucket
            frame = frame.groupby("question").agg(
                is_correct=("is_correct", "mean"),
                distance_bucket=("distance_bucket", "min"),
            )
        stats = frame.groupby("distance_bucket")["is_correct"].agg(["mean", "std", "count"])
        stats["sem"] = stats["std"] / np.sqrt(stats["count"])
        return stats

    def _short_title(topic, limit=25):
        """Truncate long topic names so subplot titles don't overlap."""
        return f"{topic[:limit]}..." if len(topic) > limit else topic

    # Read each results file once rather than once per topic
    df_model = pd.read_csv(BASE_RESULTS[model])
    if 'source_topic' not in df_model.columns:
        print("No source_topic column found")
        return

    elm_path = UNLEARNING_RESULTS.get(model, {}).get("elm")
    df_elm_all = pd.read_csv(elm_path) if elm_path is not None else None
    if df_elm_all is not None and 'source_topic' not in df_elm_all.columns:
        df_elm_all = None

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    axes = axes.flatten()
    n_topics = min(len(topics_list), len(axes))
    legend_added = False

    for ax, wmdp_topic in zip(axes, topics_list[:n_topics]):
        df = df_model[df_model['source_topic'] == wmdp_topic]

        if len(df) == 0:
            ax.text(0.5, 0.5, f"No data for\n{wmdp_topic}",
                    ha='center', va='center', transform=ax.transAxes)
            ax.set_title(_short_title(wmdp_topic), fontsize=11)
            continue

        results = _accuracy_by_bucket(df)
        ax.errorbar(results.index, results["mean"] * 100, yerr=results["sem"] * 100,
                    marker='o', linewidth=2, markersize=6, capsize=3,
                    color='C0', label=model.title())

        # Unlearning comparison: ELM only, to keep small panels readable
        if df_elm_all is not None:
            df_elm = df_elm_all[(df_elm_all['source_topic'] == wmdp_topic)
                                & (df_elm_all["distance"] < max_distance)]
            if len(df_elm) > 0:
                elm_results = _accuracy_by_bucket(df_elm)
                ax.errorbar(elm_results.index, elm_results["mean"] * 100,
                            yerr=elm_results["sem"] * 100,
                            marker='s', linewidth=2, markersize=6, capsize=3,
                            color='#FF6B6B', label='ELM', alpha=0.8)

        ax.set_xlabel('Distance', fontsize=10)
        ax.set_ylabel('Accuracy (%)', fontsize=10)
        ax.set_title(_short_title(wmdp_topic), fontsize=11)
        ax.grid(True, alpha=0.3)
        ax.set_ylim(0, 100)

        # One legend, on the first panel that actually has data
        if not legend_added:
            ax.legend(loc="best", fontsize=9)
            legend_added = True

    # Hide unused subplots
    for ax in axes[n_topics:]:
        ax.axis('off')

    plt.suptitle(f"WMDP Topics Ripple Effects - {model.title()} {'(Dedup)' if dedup else '(Raw)'}",
                 fontsize=14, y=1.02)
    plt.tight_layout()

    # Save
    date_str = datetime.now().strftime("%Y%m%d-%H%M%S")
    dedup_str = "_dedup" if dedup else "_raw"
    filename = f"wmdp_topics_comparison_{model}{dedup_str}_{date_str}"
    plt.savefig(PLOT_DIR / f"{filename}.png", dpi=150, bbox_inches='tight')
    plt.savefig(PLOT_DIR / f"{filename}.pdf", bbox_inches='tight')
    print(f"Saved plot to {PLOT_DIR / filename}")

    plt.show()

In [None]:
# Compare the six most frequent WMDP topics, first on raw rows, then deduplicated by question
if wmdp_topic_list:
    for dedup_flag, heading in ((False, "Comparing top 6 WMDP topics (Raw):"),
                                (True, "\nComparing top 6 WMDP topics (Dedup):")):
        print(heading)
        compare_wmdp_topics(wmdp_topic_list[:6], model="llama", dedup=dedup_flag)

## Aggregate Analysis Across WMDP Topics

In [None]:
def analyze_ripple_strength_by_topic(model="llama", bucket_size=10):
    """Analyze and rank WMDP topics by their ripple effect strength.

    For every ``source_topic`` with at least one distance-0 row, the model's
    accuracy is averaged over four fixed distance ranges: ``< 10``,
    ``[10, 20)``, ``[20, 30)`` and ``[30, 50)``. ``ripple_decay`` is the
    accuracy below distance 10 minus the accuracy in ``[20, 30)``. It is
    NaN when either range has no rows. The (up to) 15 topics with the
    largest decay are printed as a table.

    Args:
        model: Key into ``BASE_RESULTS``.
        bucket_size: Unused; the distance ranges above are fixed. Kept so the
            signature matches the other analysis helpers.

    Returns:
        A DataFrame with columns ``topic``, ``n_questions``, ``n_total``,
        ``acc_0_10``, ``acc_10_20``, ``acc_20_30``, ``acc_30_50`` and
        ``ripple_decay``, with accuracies as 0-1 fractions, sorted by
        ``ripple_decay`` descending. It is empty but keeps these columns when
        no topic has distance-0 rows. Returns None if the results file has no
        ``source_topic`` column.
    """
    columns = ['topic', 'n_questions', 'n_total',
               'acc_0_10', 'acc_10_20', 'acc_20_30', 'acc_30_50', 'ripple_decay']
    # Half-open distance ranges [lo, hi); the first one is unbounded below
    ranges = {
        'acc_0_10': (-np.inf, 10),
        'acc_10_20': (10, 20),
        'acc_20_30': (20, 30),
        'acc_30_50': (30, 50),
    }

    df_base = pd.read_csv(BASE_RESULTS[model])

    if 'source_topic' not in df_base.columns:
        print("No source_topic column found")
        return

    # WMDP topics are those with at least one distance-0 row
    wmdp_topics = df_base.loc[df_base['distance'] == 0, 'source_topic'].unique()

    ripple_analysis = []
    for topic in wmdp_topics:
        topic_df = df_base[df_base['source_topic'] == topic]
        distance = topic_df['distance']

        # Mean accuracy per distance range (NaN when the range is empty)
        accs = {name: topic_df.loc[(distance >= lo) & (distance < hi), 'is_correct'].mean()
                for name, (lo, hi) in ranges.items()}

        # Ripple decay: how much accuracy drops from the nearest range to the 20-30 range
        if pd.notna(accs['acc_0_10']) and pd.notna(accs['acc_20_30']):
            ripple_decay = accs['acc_0_10'] - accs['acc_20_30']
        else:
            ripple_decay = np.nan

        ripple_analysis.append({
            'topic': topic,
            'n_questions': topic_df['question'].nunique(),
            'n_total': len(topic_df),
            **accs,
            'ripple_decay': ripple_decay,
        })

    # Explicit columns so an empty result still has 'ripple_decay' to sort on
    df_analysis = pd.DataFrame(ripple_analysis, columns=columns)
    df_analysis = df_analysis.sort_values('ripple_decay', ascending=False)

    print("WMDP Topics Ranked by Ripple Decay (Accuracy drop from 0-10 to 20-30):\n")
    print(f"{'Topic':<40} {'N Questions':>12} {'Acc 0-10':>10} {'Acc 20-30':>10} {'Decay':>10}")
    print("-" * 85)

    for _, row in df_analysis.head(15).iterrows():
        if not pd.isna(row['ripple_decay']):
            print(f"{row['topic'][:40]:<40} {row['n_questions']:>12} "
                  f"{row['acc_0_10']*100:>9.1f}% {row['acc_20_30']*100:>9.1f}% "
                  f"{row['ripple_decay']*100:>9.1f}%")

    return df_analysis

In [None]:
# Rank every WMDP topic in the base Llama results by how much accuracy decays with distance
df_ripple_analysis = analyze_ripple_strength_by_topic(model="llama", bucket_size=10)

In [None]:
# Visualize ripple decay across topics
if df_ripple_analysis is not None and len(df_ripple_analysis) > 0:
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # Plot 1: ripple decay for the (up to) 15 topics with the largest decay
    top_topics_decay = df_ripple_analysis.dropna(subset=['ripple_decay']).head(15)

    ax1.barh(range(len(top_topics_decay)), top_topics_decay['ripple_decay'] * 100)
    ax1.set_yticks(range(len(top_topics_decay)))
    ax1.set_yticklabels([t[:30] + '...' if len(t) > 30 else t for t in top_topics_decay['topic']])
    # barh draws row 0 at the bottom; invert so the strongest decay is listed first
    ax1.invert_yaxis()
    ax1.set_xlabel('Accuracy Decay (0-10 vs 20-30) %')
    ax1.set_title('WMDP Topics with Strongest Ripple Decay')
    ax1.grid(True, alpha=0.3, axis='x')

    # Plot 2: grouped bars of accuracy per distance range for the top 5 topics
    distances = ['0-10', '10-20', '20-30', '30-50']
    acc_columns = ['acc_0_10', 'acc_10_20', 'acc_20_30', 'acc_30_50']
    x = np.arange(len(distances))
    width = 0.15
    top5 = top_topics_decay.head(5)

    for i, (_, row) in enumerate(top5.iterrows()):
        # Keep empty ranges as NaN so no bar is drawn, instead of a misleading 0% bar
        accuracies = row[acc_columns].astype(float).to_numpy() * 100
        label = row['topic'][:20] + '...' if len(row['topic']) > 20 else row['topic']
        ax2.bar(x + i * width, accuracies, width, label=label)

    ax2.set_xlabel('Distance Range')
    ax2.set_ylabel('Accuracy (%)')
    ax2.set_title('Accuracy Profile for Top 5 WMDP Topics')
    # Center each tick under the bars actually drawn (there may be fewer than 5 topics)
    ax2.set_xticks(x + width * (max(len(top5), 1) - 1) / 2)
    ax2.set_xticklabels(distances)
    if len(top5) > 0:
        ax2.legend(loc='best', fontsize=9)
    ax2.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()

    # Save
    date_str = datetime.now().strftime("%Y%m%d-%H%M%S")
    filename = f"wmdp_topics_ripple_analysis_{date_str}"
    plt.savefig(PLOT_DIR / f"{filename}.png", dpi=150, bbox_inches='tight')
    plt.savefig(PLOT_DIR / f"{filename}.pdf", bbox_inches='tight')
    print(f"Saved plot to {PLOT_DIR / filename}")

    plt.show()

## Focus on a Specific WMDP Topic with All Methods

In [None]:
def detailed_wmdp_topic_analysis(wmdp_topic, bucket_size=5, max_distance=50):
    """Detailed analysis of a single WMDP topic with fine-grained distance buckets.

    Builds a 2x3 grid figure for the Llama models:

    * left two columns -- accuracy vs. distance bucket for the Llama base
      model (black) and each method in ``UNLEARNING_RESULTS["llama"]``;
    * top right -- per-bucket accuracy drop (base minus unlearned, in
      percentage points) for each method, over the buckets both share;
    * bottom right -- summary table of the base model's rows for this topic,
      after the ``max_distance`` filter.

    The figure is saved as PNG and PDF under ``PLOT_DIR``.

    Args:
        wmdp_topic: ``source_topic`` value to analyze.
        bucket_size: Width of the distance buckets.
        max_distance: Rows with ``distance >= max_distance`` are ignored.

    Returns:
        None. Prints a message and returns without plotting when the base
        results have no ``source_topic`` column or no rows for the topic.
    """
    METHOD_COLORS = {
        'elm': '#FF6B6B', 'rmu': '#4ECDC4', 'graddiff': '#95E77E',
        'pbj': '#FFD93D', 'tar': '#A8E6CF', 'rmu_lat': '#FF8B94',
        'repnoise': '#B4A7D6', 'rr': '#FFB347'
    }

    def _topic_rows(frame):
        """Return this topic's rows within ``max_distance``, adding a ``distance_bucket`` column."""
        frame = frame[(frame['source_topic'] == wmdp_topic) & (frame["distance"] < max_distance)]
        # assign() builds a new frame, avoiding SettingWithCopyWarning on the filtered slice
        return frame.assign(distance_bucket=(frame["distance"] // bucket_size) * bucket_size)

    def _accuracy_by_bucket(frame):
        """Return the mean, std, count and standard error of ``is_correct`` per ``distance_bucket``."""
        stats = frame.groupby("distance_bucket")["is_correct"].agg(["mean", "std", "count"])
        stats["sem"] = stats["std"] / np.sqrt(stats["count"])
        return stats

    def _pct(value):
        """Format a 0-1 accuracy as a percentage, or 'n/a' when the range had no rows."""
        return f"{value * 100:.1f}%" if pd.notna(value) else "n/a"

    # Load and check the base Llama data before creating the figure, so early exits leave no empty figure
    df_base = pd.read_csv(BASE_RESULTS["llama"])
    if 'source_topic' not in df_base.columns:
        print("No source_topic column found")
        return
    df_base = _topic_rows(df_base)
    if len(df_base) == 0:
        print(f"No data found for WMDP topic: {wmdp_topic}")
        return

    base_results = _accuracy_by_bucket(df_base)

    fig = plt.figure(figsize=(18, 10))
    gs = fig.add_gridspec(2, 3, hspace=0.3, wspace=0.3)

    # Main plot: all methods comparison
    ax_main = fig.add_subplot(gs[:, :2])
    ax_main.errorbar(base_results.index, base_results["mean"] * 100,
                     yerr=base_results["sem"] * 100,
                     marker='o', linewidth=3, markersize=8, capsize=5,
                     color='black', label='Llama Base', alpha=0.9, zorder=10)

    delta_data = {}  # method -> accuracy drop (percentage points) per shared bucket
    for method, path in UNLEARNING_RESULTS["llama"].items():
        df_unlearn = pd.read_csv(path)
        if 'source_topic' not in df_unlearn.columns:
            continue
        df_unlearn = _topic_rows(df_unlearn)
        if len(df_unlearn) == 0:
            continue

        unlearn_results = _accuracy_by_bucket(df_unlearn)
        ax_main.errorbar(unlearn_results.index, unlearn_results["mean"] * 100,
                         yerr=unlearn_results["sem"] * 100,
                         marker='s', linewidth=2, markersize=6, capsize=3,
                         color=METHOD_COLORS.get(method, '#888888'),
                         label=method.upper().replace('_', '-'), alpha=0.8)

        common_idx = base_results.index.intersection(unlearn_results.index)
        if len(common_idx) > 0:
            delta = base_results.loc[common_idx, "mean"] - unlearn_results.loc[common_idx, "mean"]
            delta_data[method] = delta * 100

    ax_main.set_xlabel('Distance Bucket', fontsize=12)
    ax_main.set_ylabel('Accuracy (%)', fontsize=12)
    ax_main.set_title(f'WMDP Topic: {wmdp_topic}\nDetailed Ripple Effect Analysis', fontsize=14)
    ax_main.legend(loc='best', ncol=2, fontsize=10)
    ax_main.grid(True, alpha=0.3)
    ax_main.set_ylim(0, 100)

    # Subplot 1: delta plot
    ax_delta = fig.add_subplot(gs[0, 2])
    for method, delta in delta_data.items():
        # Use the full method label: truncating to 3 characters made RMU and RMU-LAT indistinguishable
        ax_delta.plot(delta.index, delta.values, marker='s', linewidth=2,
                      color=METHOD_COLORS.get(method, '#888888'),
                      label=method.upper().replace('_', '-'), alpha=0.8)

    ax_delta.axhline(y=0, color='gray', linestyle=':', alpha=0.5)
    ax_delta.set_xlabel('Distance', fontsize=10)
    ax_delta.set_ylabel('Δ Accuracy (%)', fontsize=10)  # Greek capital delta (U+0394)
    ax_delta.set_title('Ripple Effect\n(Base - Unlearned)', fontsize=11)
    if delta_data:
        ax_delta.legend(loc='best', fontsize=8, ncol=2)
    ax_delta.grid(True, alpha=0.3)

    # Subplot 2: statistics table for the base model
    ax_stats = fig.add_subplot(gs[1, 2])
    ax_stats.axis('tight')
    ax_stats.axis('off')

    distance = df_base['distance']
    stats_data = [
        ['Metric', 'Value'],
        ['Total Questions', f"{df_base['question'].nunique()}"],
        ['Total Evaluations', f"{len(df_base)}"],
        ['Distance Range', f"{distance.min():.0f}-{distance.max():.0f}"],
        ['Base Acc (0-10)', _pct(df_base.loc[distance < 10, 'is_correct'].mean())],
        ['Base Acc (20-30)', _pct(df_base.loc[(distance >= 20) & (distance < 30), 'is_correct'].mean())],
    ]

    table = ax_stats.table(cellText=stats_data, cellLoc='left', loc='center',
                           colWidths=[0.6, 0.4])
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1.2, 1.5)

    # Style header row
    for col in range(2):
        table[(0, col)].set_facecolor('#E0E0E0')
        table[(0, col)].set_text_props(weight='bold')

    plt.suptitle(f'Comprehensive Analysis: {wmdp_topic}', fontsize=16, y=0.98)

    # Save
    date_str = datetime.now().strftime("%Y%m%d-%H%M%S")
    safe_topic = wmdp_topic.replace("/", "_").replace(" ", "_")
    filename = f"wmdp_topic_detailed_{safe_topic}_{date_str}"
    plt.savefig(PLOT_DIR / f"{filename}.png", dpi=150, bbox_inches='tight')
    plt.savefig(PLOT_DIR / f"{filename}.pdf", bbox_inches='tight')
    print(f"Saved plot to {PLOT_DIR / filename}")

    plt.show()

In [None]:
# Fine-grained view (5-unit buckets, distance < 50) of a single WMDP topic.
# Defaults to the most frequent topic; any entry of wmdp_topic_list (or a manual name) works.
if wmdp_topic_list:
    topic_for_detailed_analysis = wmdp_topic_list[0]
    print(f"Performing detailed analysis for: {topic_for_detailed_analysis}")
    detailed_wmdp_topic_analysis(topic_for_detailed_analysis, max_distance=50, bucket_size=5)

In [None]:
# Custom topic selection: set a WMDP topic name below to run both analyses on it.
# Left empty by default so Restart & Run All performs no extra work.
custom_wmdp_topic = ""  # e.g. any entry of wmdp_topic_list
if custom_wmdp_topic:
    analyze_wmdp_topic_ripple(custom_wmdp_topic)
    detailed_wmdp_topic_analysis(custom_wmdp_topic)