# Load Data

In [None]:
import shutil
import os

src_dir = '/content/drive/MyDrive/counting_project/counting_dataset/CMA_analysis'
dst_dir = '/content/CMA_analysis'

# Pull every regular file from the Drive copy of CMA_analysis into the local
# Colab workspace; nested folders are deliberately left behind.
os.makedirs(dst_dir, exist_ok=True)

for entry_name in os.listdir(src_dir):
    source_path = os.path.join(src_dir, entry_name)
    if not os.path.isfile(source_path):
        continue  # sub-directories are not copied
    # copy2 carries over file metadata (e.g. modification time) with the contents
    shutil.copy2(source_path, os.path.join(dst_dir, entry_name))

# Counterfactual activation patching

### Calculating the total effect (TE) and indirect effect (IE)

In [None]:
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
import pandas as pd
from typing import Dict, List, Tuple, Optional, Any
from tqdm import tqdm
import json
import gc
import re
import warnings
warnings.filterwarnings("ignore")

class CMAEffectsCalculator:
    """Activation-patching (causal mediation) analysis for a word-counting prompt.

    Each counterfactual ``pair`` holds an ``original_list`` and an
    ``intervention_list`` that differ at ``intervention_position``. For every
    pair and decoder layer this class reports:

    * TE (total effect): predicted count on the intervention prompt minus the
      predicted count on the original prompt.
    * IE (indirect effect): predicted count when the original prompt is run but
      the chosen layer's output at the edited word's tokens is replaced by the
      intervention prompt's output at its edited word's tokens, minus the
      original predicted count.
    """

    def __init__(self, model_name: str = "microsoft/phi-4", device: Optional[str] = None):
        """Load ``model_name`` and its tokenizer and locate the decoder-layer stack.

        Args:
            model_name: Hugging Face model id passed to ``from_pretrained``.
            device: Device for input tensors; defaults to CUDA when available, else CPU.

        Raises:
            ValueError: If the model exposes neither ``model.layers`` nor ``transformer.h``.
        """
        print(f"Loading model {model_name}...")

        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        self.model_name = model_name

        # Load tokenizer and model (fp16 with automatic placement when a GPU exists)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
            trust_remote_code=True
        )
        self.model.eval()

        # Find the decoder layers under one of two common Hugging Face layouts
        if hasattr(self.model, 'model') and hasattr(self.model.model, 'layers'):
            self.num_layers = len(self.model.model.layers)
            self.layer_attr = 'model.layers'
        elif hasattr(self.model, 'transformer') and hasattr(self.model.transformer, 'h'):
            self.num_layers = len(self.model.transformer.h)
            self.layer_attr = 'transformer.h'
        else:
            raise ValueError("Cannot determine model architecture")

        print(f"Model loaded with {self.num_layers} layers")
        print(f"Device: {self.device}")

    def _get_layer(self, layer_idx: int):
        """Return the decoder layer module at ``layer_idx``."""
        if 'model.layers' in self.layer_attr:
            return self.model.model.layers[layer_idx]
        return self.model.transformer.h[layer_idx]

    def format_prompt(self, category: str, word_list: List[str]) -> str:
        """Build the chat-formatted counting prompt for ``word_list``.

        The list is interpolated with ``{word_list}``, i.e. rendered as its Python
        repr (e.g. ``['apple', 'dog']``). This is kept as-is to match the
        evaluation prompt; ``find_intervention_token_positions`` relies on it.
        """
        problem = f"""Count how many words in this list match the type "{category}".

List: {word_list}

Respond with only the number in parentheses, like (0), (1), (2), etc."""

        messages = [
            {"role": "system", "content": "You are a precise counting assistant. When given a list and a type, count how many items match that type. Always respond with ONLY the count in parentheses format: (0), (1), (2), etc. Never include explanations or other text."},
            {"role": "user", "content": problem}
        ]

        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

    def extract_number_from_output(self, text: str) -> int:
        """Parse the predicted count from generated text.

        Tries, in order: ``(N)``, a bare number as the whole text, and any digit
        run. Returns -1 when no number is present.
        """
        text = text.strip()

        match = re.search(r'\((\d+)\)', text)
        if match:
            return int(match.group(1))

        match = re.search(r'^\s*(\d+)\s*$', text)
        if match:
            return int(match.group(1))

        match = re.search(r'(\d+)', text)
        if match:
            return int(match.group(1))

        return -1  # no answer found

    def get_model_output(self, prompt: str) -> Tuple[int, str, torch.Tensor]:
        """Greedily generate up to 5 tokens for ``prompt`` and parse the count.

        Returns:
            ``(predicted_number, generated_text, last_scores)`` where
            ``last_scores`` are the generation scores of the final generated step
            (``None`` if no scores were produced).
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=5,
                do_sample=False,
                pad_token_id=self.tokenizer.eos_token_id,
                return_dict_in_generate=True,
                output_scores=True
            )

        generated_ids = outputs.sequences[0][inputs.input_ids.shape[1]:]
        generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
        predicted_number = self.extract_number_from_output(generated_text)

        last_logits = outputs.scores[-1][0] if outputs.scores else None
        return predicted_number, generated_text, last_logits

    def get_hidden_states(self, prompt: str) -> Tuple[Dict[int, torch.Tensor], torch.Tensor]:
        """Run ``prompt`` once and return each decoder layer's output, moved to CPU.

        Returns:
            ``({layer_idx: hidden_states}, input_ids)``. Entry ``i`` is
            ``outputs.hidden_states[i + 1]``, because index 0 is the embedding output.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model(
                **inputs,
                output_hidden_states=True,
                return_dict=True
            )

        all_hidden = outputs.hidden_states
        hidden_states = {
            layer_idx: all_hidden[layer_idx + 1].cpu()
            for layer_idx in range(self.num_layers)
        }
        return hidden_states, inputs.input_ids

    def _locate_word_by_offsets(self, prompt: str, word_list: List[str],
                                intervention_position: int) -> Optional[List[int]]:
        """Token indices covering ``word_list[intervention_position]`` inside ``prompt``.

        Uses the character layout of ``str(word_list)`` (how ``format_prompt``
        renders the list) and the tokenizer's offset mapping. Only the list
        occurrence of the word is selected, never a match elsewhere in the prompt.
        Returns ``None`` when this cannot be done, e.g. the tokenizer gives no
        offsets or the list text is not found.
        """
        if not 0 <= intervention_position < len(word_list):
            return None

        list_str = str(word_list)
        list_start = prompt.find(list_str)
        if list_start < 0:
            return None

        # str(list) == "[" + ", ".join(repr(w) for w in words) + "]"
        item_start = list_start + 1 + sum(len(repr(w)) + 2 for w in word_list[:intervention_position])
        item = word_list[intervention_position]
        item_repr = repr(item)
        if isinstance(item, str) and len(item_repr) >= 2:
            # Exclude the surrounding quote characters
            char_start, char_end = item_start + 1, item_start + len(item_repr) - 1
        else:
            char_start, char_end = item_start, item_start + len(item_repr)
        if char_end <= char_start:
            return None

        try:
            # Same call as the model input (default special tokens), plus offsets
            offsets = self.tokenizer(prompt, return_offsets_mapping=True)["offset_mapping"]
        except (NotImplementedError, KeyError, TypeError, ValueError):
            return None  # e.g. slow tokenizers do not provide offsets

        positions = [idx for idx, (start, end) in enumerate(offsets)
                     if start < char_end and end > char_start]
        return positions or None

    def find_intervention_token_positions(self, prompt: str, word_list: List[str],
                                          intervention_position: int,
                                          target_word: str) -> List[int]:
        """Return token indices in ``prompt`` for the word at ``intervention_position``.

        Strategies, in order:
        1. Character offsets of the word inside the rendered list (exact slot).
        2. Token-run match of the space-joined list, then per-word token counts.
        3. First token run in the whole prompt matching ``target_word``.
        4. First single token equal to a case/prefix variant of ``target_word``.
        Returns ``[]`` if nothing matches.
        """
        positions = self._locate_word_by_offsets(prompt, word_list, intervention_position)
        if positions:
            return positions

        inputs = self.tokenizer(prompt, return_tensors="pt")
        full_tokens = self.tokenizer.convert_ids_to_tokens(inputs.input_ids[0])

        # Strategy 2: only applies when the list appears space-joined in the prompt
        if 0 <= intervention_position < len(word_list):
            word_list_tokens = self.tokenizer.tokenize(' '.join(word_list))
            for i in range(len(full_tokens) - len(word_list_tokens) + 1):
                if full_tokens[i:i + len(word_list_tokens)] == word_list_tokens:
                    def prefix_len(n: int) -> int:
                        return sum(len(self.tokenizer.tokenize(f" {word_list[j]}")) for j in range(n))
                    return list(range(i + prefix_len(intervention_position),
                                      i + prefix_len(intervention_position + 1)))

        # Strategy 3: exact target_word tokens, ignoring leading-space markers
        target_word_tokens = self.tokenizer.tokenize(target_word)
        target_normalized = [token.lstrip('Ġ▁') for token in target_word_tokens]
        for i in range(len(full_tokens) - len(target_word_tokens) + 1):
            window = [t.lstrip('Ġ▁') for t in full_tokens[i:i + len(target_word_tokens)]]
            if window == target_normalized:
                return list(range(i, i + len(target_word_tokens)))

        # Strategy 4: single-token case-insensitive match
        variants = {target_word.lower(), target_word.lower().capitalize()}
        variants = {v.lstrip("Ġ▁") for v in variants}
        for i, token in enumerate(full_tokens):
            if token.lower().lstrip("Ġ▁") in variants:
                return [i]

        return []

    @staticmethod
    def _align_positions(target_positions: List[int],
                         source_positions: List[int]) -> List[Tuple[int, int]]:
        """Pair target-prompt positions with source-prompt positions for patching.

        Equal-length spans are paired element-wise. Unequal spans are
        right-aligned over the shorter length, so the final sub-tokens of the
        two words are paired. An empty span yields no pairs.
        """
        if len(target_positions) == len(source_positions):
            return list(zip(target_positions, source_positions))
        k = min(len(target_positions), len(source_positions))
        if k == 0:
            return []
        return list(zip(target_positions[-k:], source_positions[-k:]))

    def patch_forward_pass(self, prompt: str, layer_idx: int,
                           patch_activation: torch.Tensor,
                           patch_positions: List[int],
                           source_positions: Optional[List[int]] = None) -> Tuple[int, str]:
        """Generate from ``prompt`` with layer ``layer_idx``'s output overwritten.

        Args:
            prompt: Prompt to run (the original prompt for IE).
            layer_idx: Decoder layer whose output is patched.
            patch_activation: Activations to copy from; indexed as
                ``[:, position, :]`` (presumably ``(batch, seq, hidden)``).
            patch_positions: Token positions in ``prompt`` to overwrite.
            source_positions: Matching positions in ``patch_activation``. If
                ``None``, the same indices as ``patch_positions`` are used.
                Otherwise the two spans are paired via ``_align_positions``.

        Returns:
            ``(predicted_number, generated_text)``.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        prompt_len = inputs.input_ids.shape[1]
        target_layer = self._get_layer(layer_idx)

        if source_positions is None:
            position_pairs = [(pos, pos) for pos in patch_positions]
        else:
            position_pairs = self._align_positions(list(patch_positions), list(source_positions))

        def patch_hook(module, args, output):
            hidden_out = output[0] if isinstance(output, tuple) else output
            # Only patch passes that cover the whole prompt. With a KV cache, later
            # decode steps have a shorter sequence axis whose indices are not prompt
            # positions. Those steps already see the patch through the cached keys/values.
            if hidden_out.shape[1] < prompt_len:
                return output
            for tgt, src in position_pairs:
                if tgt < hidden_out.shape[1] and src < patch_activation.shape[1]:
                    hidden_out[:, tgt, :] = patch_activation[:, src, :].to(
                        device=hidden_out.device, dtype=hidden_out.dtype)
            return output

        handle = target_layer.register_forward_hook(patch_hook)
        try:
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=5,
                    do_sample=False,
                    pad_token_id=self.tokenizer.eos_token_id
                )
        finally:
            handle.remove()  # never leave the hook attached

        generated_ids = outputs[0][prompt_len:]
        generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
        predicted_number = self.extract_number_from_output(generated_text)
        return predicted_number, generated_text

    def _prepare_pair(self, pair: Dict) -> Dict[str, Any]:
        """Compute everything for ``pair`` that does not depend on the patched layer."""
        original_prompt = self.format_prompt(pair['category'], pair['original_list'])
        intervention_prompt = self.format_prompt(pair['category'], pair['intervention_list'])

        # Unpatched predictions: TE and the IE baseline
        original_num, original_text, _ = self.get_model_output(original_prompt)
        intervention_num, intervention_text, _ = self.get_model_output(intervention_prompt)

        # Only intervention activations are ever patched in
        intervention_hidden, _ = self.get_hidden_states(intervention_prompt)

        patch_positions = self.find_intervention_token_positions(
            intervention_prompt,
            pair['intervention_list'],
            pair['intervention_position'],
            pair['intervention_word']
        )
        ie_positions = self.find_intervention_token_positions(
            original_prompt,
            pair['original_list'],
            pair['intervention_position'],
            pair['original_word']
        )

        return {
            'original_prompt': original_prompt,
            'original_num': original_num,
            'original_text': original_text,
            'intervention_num': intervention_num,
            'intervention_text': intervention_text,
            'intervention_hidden': intervention_hidden,
            'patch_positions': patch_positions,
            'ie_positions': ie_positions,
        }

    def calculate_effects(self, pair: Dict, layer_idx: int,
                          prepared: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Compute TE and IE for one pair at one layer.

        Args:
            pair: Counterfactual pair record (keys as read from the pairs JSON).
            layer_idx: Decoder layer to patch for IE.
            prepared: Optional result of ``_prepare_pair(pair)``. Passing it avoids
                recomputing the unpatched runs for every layer.
        """
        if prepared is None:
            prepared = self._prepare_pair(pair)

        original_num = prepared['original_num']
        intervention_num = prepared['intervention_num']
        patch_positions = prepared['patch_positions']
        ie_positions = prepared['ie_positions']

        TE = intervention_num - original_num

        # IE: original prompt, with the intervention word's activations patched into
        # the original word's token positions (spans aligned explicitly).
        ie_num, ie_text = self.patch_forward_pass(
            prepared['original_prompt'],
            layer_idx,
            prepared['intervention_hidden'][layer_idx],
            ie_positions,
            source_positions=patch_positions
        )
        IE = ie_num - original_num

        return {
            'pair_id': pair['pair_id'],
            'layer_idx': layer_idx,
            'TE': TE,
            'IE': IE,
            'original_output': original_num,
            'intervention_output': intervention_num,
            'ie_output': ie_num,
            'original_text': prepared['original_text'].strip(),
            'intervention_text': prepared['intervention_text'].strip(),
            'ie_text': ie_text.strip(),
            'expected_original': pair['original_count'],
            'expected_intervention': pair['intervention_count'],
            'original_correct': original_num == pair['original_count'],
            'intervention_correct': intervention_num == pair['intervention_count'],
            'patch_positions': patch_positions,
            'ie_positions': ie_positions,
            'num_patch_positions': len(patch_positions),
            'num_ie_positions': len(ie_positions),
            # False when the two words span different token counts (right-aligned patch)
            'positions_length_match': len(patch_positions) == len(ie_positions)
        }

    @staticmethod
    def _error_result(pair: Dict, layer_idx: int, error: Exception) -> Dict[str, Any]:
        """Placeholder row for a failed calculation (TE/IE are NaN)."""
        return {
            'pair_id': pair.get('pair_id'),
            'layer_idx': layer_idx,
            'error': str(error),
            'TE': np.nan,
            'IE': np.nan
        }

    def calculate_effects_batch(self, pairs: List[Dict],
                                layers_to_test: Optional[List[int]] = None,
                                save_frequency: int = 10,
                                output_dir: str = "/content/counting_dataset") -> pd.DataFrame:
        """Compute effects for every pair × layer and write CSVs to ``output_dir``.

        Layer-independent work is done once per pair. Failures are recorded as
        rows with an ``error`` column rather than aborting the run. Intermediate
        CSVs are written every ``save_frequency`` pairs.
        """
        import os
        os.makedirs(output_dir, exist_ok=True)

        if layers_to_test is None:
            # Test middle and later layers by default
            layers_to_test = list(range(self.num_layers // 2, self.num_layers))
            print(f"Testing layers {layers_to_test[0]} to {layers_to_test[-1]}")

        results = []
        total_calculations = len(pairs) * len(layers_to_test)

        print(f"Calculating effects for {len(pairs)} pairs and {len(layers_to_test)} layers")
        print(f"Total calculations: {total_calculations}")

        with tqdm(total=total_calculations, desc="Calculating effects") as pbar:
            for pair_idx, pair in enumerate(pairs):
                prepared, prep_error = None, None
                try:
                    prepared = self._prepare_pair(pair)
                except Exception as e:  # per-pair boundary: record and continue
                    prep_error = e
                    print(f"\nError preparing pair {pair.get('pair_id')}: {e}")

                for layer_idx in layers_to_test:
                    if prep_error is not None:
                        results.append(self._error_result(pair, layer_idx, prep_error))
                        pbar.update(1)
                        continue
                    try:
                        effects = self.calculate_effects(pair, layer_idx, prepared)
                        effects.update({
                            'category': pair['category'],
                            'intervention_position': pair['intervention_position'],
                            'list_length': len(pair['original_list']),
                            'original_word': pair['original_word'],
                            'intervention_word': pair['intervention_word']
                        })
                        results.append(effects)
                    except Exception as e:  # per-item boundary: record and continue
                        print(f"\nError with pair {pair.get('pair_id')}, layer {layer_idx}: {e}")
                        results.append(self._error_result(pair, layer_idx, e))
                    pbar.update(1)

                prepared = None  # release cached hidden states before the next pair

                if (pair_idx + 1) % save_frequency == 0:
                    temp_path = f"{output_dir}/cma_effects_intermediate_{pair_idx+1}.csv"
                    pd.DataFrame(results).to_csv(temp_path, index=False)
                    print(f"\nSaved intermediate results ({pair_idx+1} pairs processed)")

                if (pair_idx + 1) % 5 == 0:
                    gc.collect()
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

        results_df = pd.DataFrame(results)

        final_path = f"{output_dir}/cma_effects_results.csv"
        results_df.to_csv(final_path, index=False)
        print(f"\nSaved final results to {final_path}")

        return results_df

    def run_analysis(self, pairs_path: str = "/content/counting_dataset/cma_intervention_pairs.json",
                     layers_to_test: Optional[List[int]] = None,
                     max_pairs: Optional[int] = None,
                     output_dir: str = "/content/counting_dataset") -> pd.DataFrame:
        """Load pairs from JSON, compute effects, print a summary, return the results."""
        print(f"Loading pairs from {pairs_path}")
        with open(pairs_path, 'r', encoding='utf-8') as f:
            pairs = json.load(f)

        if max_pairs is not None:
            pairs = pairs[:max_pairs]
            print(f"Using first {len(pairs)} pairs")

        results_df = self.calculate_effects_batch(pairs, layers_to_test, output_dir=output_dir)
        self.print_summary_stats(results_df)
        return results_df

    def print_summary_stats(self, results_df: pd.DataFrame):
        """Print accuracy and TE/IE statistics (overall and per layer)."""
        print("\n=== Effects Calculation Summary ===")
        print(f"Total calculations: {len(results_df)}")

        if 'error' in results_df.columns:
            error_count = results_df['error'].notna().sum()
            print(f"Errors: {error_count}")

        if 'TE' not in results_df.columns:
            # Nothing ran (e.g. no pairs), so there are no effect columns
            print("Valid results: 0")
            return

        valid_results = results_df[results_df['TE'].notna()]
        print(f"Valid results: {len(valid_results)}")

        if len(valid_results) > 0:
            print("\nModel Accuracy:")
            orig_acc = valid_results['original_correct'].mean()
            int_acc = valid_results['intervention_correct'].mean()
            print(f"  Original prompts: {orig_acc:.2%} ({valid_results['original_correct'].sum()}/{len(valid_results)})")
            print(f"  Intervention prompts: {int_acc:.2%} ({valid_results['intervention_correct'].sum()}/{len(valid_results)})")

            print("\nEffect Statistics:")
            print(f"  Mean TE: {valid_results['TE'].mean():.3f} (std: {valid_results['TE'].std():.3f})")
            print(f"  Mean |TE|: {np.abs(valid_results['TE']).mean():.3f}")
            print(f"  Mean IE: {valid_results['IE'].mean():.3f} (std: {valid_results['IE'].std():.3f})")

            print("\nEffects by Layer:")
            layer_effects = valid_results.groupby('layer_idx')[['TE', 'IE']].mean()
            print(layer_effects.round(3))

# Example usage
if __name__ == "__main__":
    # Initialize calculator
    calculator = CMAEffectsCalculator(model_name="microsoft/phi-4")

    # Run the first 100 pairs, patching decoder layers 0-19
    results_df = calculator.run_analysis(
        pairs_path="/content/CMA_analysis/cma_intervention_pairs.json",
        layers_to_test=list(range(20)),
        max_pairs=100,
        output_dir="/content/CMA_analysis"
    )

    print("\nAnalysis complete!")
    print(f"Results shape: {results_df.shape}")

    # Show some example results
    if len(results_df) > 0:
        print("\nExample results (first 3 rows):")
        cols_to_show = ['pair_id', 'layer_idx', 'TE', 'IE',
                        'original_correct', 'intervention_correct', 'original_text', 'intervention_text']
        # Error-only runs lack some columns; show just the ones present
        cols_to_show = [col for col in cols_to_show if col in results_df.columns]
        print(results_df[cols_to_show].head(3))


Loading model microsoft/phi-4...


Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

Model loaded with 40 layers
Device: cuda
Loading pairs from /content/CMA_analysis/cma_intervention_pairs.json
Using first 100 pairs
Calculating effects for 100 pairs and 20 layers
Total calculations: 2000


Calculating effects:  10%|█         | 200/2000 [02:27<21:52,  1.37it/s]


Saved intermediate results (10 pairs processed)


Calculating effects:  20%|██        | 400/2000 [04:54<19:22,  1.38it/s]


Saved intermediate results (20 pairs processed)


Calculating effects:  30%|███       | 600/2000 [07:21<16:56,  1.38it/s]


Saved intermediate results (30 pairs processed)


Calculating effects:  40%|████      | 800/2000 [09:48<14:33,  1.37it/s]


Saved intermediate results (40 pairs processed)


Calculating effects:  50%|█████     | 1000/2000 [12:16<12:21,  1.35it/s]


Saved intermediate results (50 pairs processed)


Calculating effects:  60%|██████    | 1200/2000 [14:42<09:52,  1.35it/s]


Saved intermediate results (60 pairs processed)


Calculating effects:  70%|███████   | 1400/2000 [17:09<07:19,  1.37it/s]


Saved intermediate results (70 pairs processed)


Calculating effects:  80%|████████  | 1600/2000 [19:36<04:58,  1.34it/s]


Saved intermediate results (80 pairs processed)


Calculating effects:  90%|█████████ | 1800/2000 [22:03<02:28,  1.35it/s]


Saved intermediate results (90 pairs processed)


Calculating effects: 100%|██████████| 2000/2000 [24:31<00:00,  1.36it/s]


Saved intermediate results (100 pairs processed)

Saved final results to /content/CMA_analysis/cma_effects_results.csv

=== Effects Calculation Summary ===
Total calculations: 2000
Valid results: 2000

Model Accuracy:
  Original prompts: 83.00% (1660/2000)
  Intervention prompts: 87.00% (1740/2000)

Effect Statistics:
  Mean TE: -0.880 (std: 0.407)
  Mean |TE|: 0.900
  Mean IE: -0.595 (std: 0.532)

Effects by Layer:
             TE    IE
layer_idx            
0         -0.88 -0.87
1         -0.88 -0.87
2         -0.88 -0.87
3         -0.88 -0.87
4         -0.88 -0.87
5         -0.88 -0.87
6         -0.88 -0.86
7         -0.88 -0.86
8         -0.88 -0.85
9         -0.88 -0.84
10        -0.88 -0.78
11        -0.88 -0.70
12        -0.88 -0.63
13        -0.88 -0.38
14        -0.88 -0.28
15        -0.88 -0.22
16        -0.88 -0.16
17        -0.88 -0.09
18        -0.88 -0.03
19        -0.88  0.00

Analysis complete!
Results shape: (2000, 23)

Example results (first 3 rows):
   pair_id  laye




### analysis

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from typing import Dict, List, Tuple, Optional
import json
import os

class SimplifiedCMAAnalyzer:
    """Summarise per-(pair, layer) TE/IE rows from a results CSV and plot them."""

    # Columns of the frame returned by compute_layer_statistics (also used when empty)
    LAYER_STAT_COLUMNS = [
        'layer_idx', 'n_samples',
        'TE_mean', 'TE_std', 'TE_abs_mean',
        'IE_mean', 'IE_std', 'IE_abs_mean',
        'correct_original', 'correct_intervention',
    ]

    def __init__(self, results_path: str = "/content/counting_dataset/cma_effects_results.csv",
                 output_dir: str = "/content/counting_dataset"):
        """Load ``results_path`` and keep only rows with a TE value.

        Args:
            results_path: CSV with per-(pair, layer) rows (``TE``, ``IE``, ``layer_idx``, ...).
            output_dir: Directory for the statistics CSV and figures; created if absent.
        """
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

        print(f"Loading results from {results_path}")
        self.results_df = pd.read_csv(results_path)
        print(f"Loaded {len(self.results_df)} results")

        # Rows from failed calculations have NaN TE
        self.results_df = self.results_df[self.results_df['TE'].notna()]
        print(f"Valid results: {len(self.results_df)}")

    def compute_layer_statistics(self) -> pd.DataFrame:
        """Return one row of TE/IE and accuracy statistics per layer, sorted by layer.

        With no results, returns an empty frame with ``LAYER_STAT_COLUMNS``.
        """
        if self.results_df.empty or 'layer_idx' not in self.results_df.columns:
            return pd.DataFrame(columns=self.LAYER_STAT_COLUMNS)

        has_orig = 'original_correct' in self.results_df.columns
        has_int = 'intervention_correct' in self.results_df.columns
        layer_stats = []

        for layer_idx in sorted(self.results_df['layer_idx'].unique()):
            layer_data = self.results_df[self.results_df['layer_idx'] == layer_idx]
            layer_stats.append({
                'layer_idx': layer_idx,
                'n_samples': len(layer_data),
                'TE_mean': layer_data['TE'].mean(),
                'TE_std': layer_data['TE'].std(),
                'TE_abs_mean': layer_data['TE'].abs().mean(),
                'IE_mean': layer_data['IE'].mean(),
                'IE_std': layer_data['IE'].std(),
                'IE_abs_mean': layer_data['IE'].abs().mean(),
                'correct_original': layer_data['original_correct'].mean() if has_orig else 0,
                'correct_intervention': layer_data['intervention_correct'].mean() if has_int else 0,
            })

        return pd.DataFrame(layer_stats).sort_values('layer_idx')

    def plot_effect_magnitudes(self, layer_stats: pd.DataFrame, save_path: Optional[str] = None):
        """Save a line plot of mean |TE| and mean |IE| per layer."""
        if save_path is None:
            save_path = os.path.join(self.output_dir, "effect_magnitudes_by_layer.png")

        plt.figure(figsize=(12, 6))

        x = layer_stats['layer_idx']
        plt.plot(x, layer_stats['TE_abs_mean'], 'o-', label='|TE| (Total Effect)',
                 markersize=8, linewidth=2, color='blue')
        plt.plot(x, layer_stats['IE_abs_mean'], '^-', label='|IE| (Indirect Effect)',
                 markersize=8, linewidth=2, color='red')

        plt.xlabel('Layer Index', fontsize=14)
        plt.ylabel('Mean Absolute Effect', fontsize=14)
        plt.title('Effect Magnitudes by Layer', fontsize=16, fontweight='bold')
        plt.legend(fontsize=12)
        plt.grid(True, alpha=0.3)

        # Integer ticks, one per tested layer
        plt.xticks(x.astype(int))

        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()
        print(f"Saved effect magnitudes plot to {save_path}")

    def calculate_overall_accuracy(self) -> Dict[str, float]:
        """Return per-pair correct rates on the original and intervention prompts.

        Correctness flags come from the unpatched prompts, so a pair repeats the
        same flags on each of its per-layer rows. Rows are de-duplicated by
        ``pair_id`` (when present) so that every pair counts once, even if some
        (pair, layer) rows were dropped. A missing column yields 0.
        """
        df = self.results_df
        if 'pair_id' in df.columns:
            df = df.drop_duplicates(subset='pair_id')

        return {
            'original_correct_rate': df['original_correct'].mean() if 'original_correct' in df.columns else 0,
            'intervention_correct_rate': df['intervention_correct'].mean() if 'intervention_correct' in df.columns else 0,
        }

    def plot_overall_accuracy(self, save_path: Optional[str] = None):
        """Save a bar chart of the overall correct rates and return them."""
        if save_path is None:
            save_path = os.path.join(self.output_dir, "overall_correct_rates.png")

        accuracy_stats = self.calculate_overall_accuracy()

        plt.figure(figsize=(10, 6))

        categories = ['Original', 'After Intervention']
        rates = [accuracy_stats['original_correct_rate'], accuracy_stats['intervention_correct_rate']]
        colors = ['skyblue', 'lightcoral']

        bars = plt.bar(categories, rates, color=colors, alpha=0.8, edgecolor='black', linewidth=1.5)

        # Value labels above each bar
        for bar, rate in zip(bars, rates):
            plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                     f'{rate:.2%}', ha='center', va='bottom', fontsize=14, fontweight='bold')

        plt.ylabel('Correct Rate', fontsize=14)
        plt.title('Overall Model Accuracy: Original vs After Intervention', fontsize=16, fontweight='bold')
        plt.ylim(0, 1.1)
        plt.grid(True, alpha=0.3, axis='y')

        # Percentage formatting on the y-axis
        plt.gca().yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: '{:.0%}'.format(y)))

        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()
        print(f"Saved overall accuracy plot to {save_path}")

        return accuracy_stats

    def run_analysis(self):
        """Write layer statistics and both figures, print a summary, return the stats.

        Returns:
            ``(layer_stats, accuracy_stats)``.
        """
        print("Running simplified CMA analysis...")

        layer_stats = self.compute_layer_statistics()

        stats_path = os.path.join(self.output_dir, "layer_statistics.csv")
        layer_stats.to_csv(stats_path, index=False)
        print(f"Saved layer statistics to {stats_path}")

        self.plot_effect_magnitudes(layer_stats)
        accuracy_stats = self.plot_overall_accuracy()

        print("\n=== ANALYSIS SUMMARY ===")
        print(f"Total samples analyzed: {len(self.results_df)}")
        print(f"Number of layers: {len(layer_stats)}")
        print(f"Layer range: {layer_stats['layer_idx'].min()} to {layer_stats['layer_idx'].max()}")
        print("\nOverall Accuracy:")
        print(f"  Original correct rate: {accuracy_stats['original_correct_rate']:.2%}")
        print(f"  After intervention correct rate: {accuracy_stats['intervention_correct_rate']:.2%}")
        print(f"  Accuracy change: {accuracy_stats['intervention_correct_rate'] - accuracy_stats['original_correct_rate']:.2%}")

        print("\nEffect Statistics:")
        print(f"  Mean |TE| across all layers: {self.results_df['TE'].abs().mean():.3f}")
        print(f"  Mean |IE| across all layers: {self.results_df['IE'].abs().mean():.3f}")

        # idxmax needs at least one non-NaN value
        if layer_stats['TE_abs_mean'].notna().any() and layer_stats['IE_abs_mean'].notna().any():
            max_te_layer = layer_stats.loc[layer_stats['TE_abs_mean'].idxmax()]
            max_ie_layer = layer_stats.loc[layer_stats['IE_abs_mean'].idxmax()]

            print("\nHighest Effects:")
            print(f"  Highest |TE|: Layer {int(max_te_layer['layer_idx'])} ({max_te_layer['TE_abs_mean']:.3f})")
            print(f"  Highest |IE|: Layer {int(max_ie_layer['layer_idx'])} ({max_ie_layer['IE_abs_mean']:.3f})")

        return layer_stats, accuracy_stats


# Example usage
if __name__ == "__main__":
    # Point the analyzer at the CSV produced by the effects-calculation cell
    analyzer = SimplifiedCMAAnalyzer(
        results_path="/content/CMA_analysis/cma_effects_results.csv",
        output_dir="/content/CMA_analysis",
    )

    # Writes the per-layer statistics table and the two figures
    layer_stats, accuracy_stats = analyzer.run_analysis()

    # Rank layers by mean |IE| and show the five strongest
    print("\n=== Top 5 Layers by Indirect Effect ===")
    summary_columns = ['layer_idx', 'IE_abs_mean', 'TE_abs_mean', 'correct_original', 'correct_intervention']
    strongest_layers = layer_stats.nlargest(5, 'IE_abs_mean')
    print(strongest_layers[summary_columns])

Loading results from /content/CMA_analysis/cma_effects_results.csv
Loaded 2000 results
Valid results: 2000
Running simplified CMA analysis...
Saved layer statistics to /content/CMA_analysis/layer_statistics.csv
Saved effect magnitudes plot to /content/CMA_analysis/effect_magnitudes_by_layer.png
Saved overall accuracy plot to /content/CMA_analysis/overall_correct_rates.png

=== ANALYSIS SUMMARY ===
Total samples analyzed: 2000
Number of layers: 20
Layer range: 0 to 19

Overall Accuracy:
  Original correct rate: 83.00%
  After intervention correct rate: 87.00%
  Accuracy change: 4.00%

Effect Statistics:
  Mean |TE| across all layers: 0.900
  Mean |IE| across all layers: 0.623

Highest Effects:
  Highest |TE|: Layer 0 (0.900)
  Highest |IE|: Layer 7 (0.900)

=== Top 5 Layers by Indirect Effect ===
   layer_idx  IE_abs_mean  TE_abs_mean  correct_original  correct_intervention
7          7         0.90          0.9              0.83                  0.87
0          0         0.89          

### Save results back to Google Drive

In [None]:
import shutil
import os

src_dir = '/content/CMA_analysis'
dst_dir = '/content/drive/MyDrive/counting_project/counting_dataset/CMA_analysis'

# Push every regular file produced locally back to the Drive folder so results
# survive the Colab session; nested folders are deliberately skipped.
os.makedirs(dst_dir, exist_ok=True)

for entry_name in os.listdir(src_dir):
    source_path = os.path.join(src_dir, entry_name)
    if not os.path.isfile(source_path):
        continue  # sub-directories are not copied
    # copy2 carries over file metadata (e.g. modification time) with the contents
    shutil.copy2(source_path, os.path.join(dst_dir, entry_name))

# End