## Abstention Direction Analysis

In this section, we'll identify a one-dimensional activation direction in the residual stream that linearly separates abstention from non-abstention at the answer token, focusing on forms V1 and V2, which elicit abstention more reliably.


In [1]:
from pathlib import Path

import numpy as np

# Seed NumPy's global RNG for reproducibility. The pandas .sample(...) calls
# in later cells pass their own random_state, so they do not rely on this seed.
np.random.seed(42)


### 1. Load and Prepare Data

First, we'll load the residual stream data for forms V1 and V2, and filter based on CA scores.


In [2]:
def load_residual_stream_data(forms=('V1', 'V2'), split='train', by_layer_dir=None):
    """
    Load residual stream data for the specified forms from the by_layer directory.

    Args:
        forms: Form name, or iterable of form names, to include (default: V1 and V2)
        split: Dataset split name (default: train). Kept for backward
            compatibility only; this loader does not filter on it.
        by_layer_dir: Directory that contains the ``layer_*`` sub-directories.
            When None, the original hard-coded results directory is used.

    Returns:
        Dictionary mapping layer indices to dictionaries of experiment data:
        {layer_idx: {exp_name: {'vectors': ndarray, 'question_ids': ndarray, 'metadata': dict}}}
        An empty dictionary is returned when the directory does not exist.
    """
    import json

    # A bare string would otherwise be iterated character by character; a
    # one-shot iterator would be exhausted after the first layer.
    forms = [forms] if isinstance(forms, str) else list(forms)

    if by_layer_dir is None:
        by_layer_dir = "/home/sergio/projects/MATS-Project/results/by_layer_20250911_112921/by_layer"
    by_layer_dir = Path(by_layer_dir)

    if not by_layer_dir.exists():
        print(f"Directory not found: {by_layer_dir}")
        return {}

    residual_data = {}

    # List all layer directories
    layer_dirs = sorted(d for d in by_layer_dir.glob("layer_*") if d.is_dir())
    print(f"Found {len(layer_dirs)} layer directories")

    for layer_dir in layer_dirs:
        layer_idx = int(layer_dir.name.split('_')[1])
        print(f"Processing layer {layer_idx}...")

        layer_entries = residual_data[layer_idx] = {}

        for form in forms:
            suffix = f"_{form}"

            # Sorted so experiments are registered in the same order on every
            # run, independent of the order the filesystem lists them in.
            for npz_file in sorted(layer_dir.glob(f"{form}_*_{form}.npz")):
                # Experiment name is the stem without the trailing form marker
                # (e.g. "V1_alpha_p1_V1" -> "V1_alpha_p1"). Only the suffix is
                # removed, so a form marker elsewhere in the name is kept.
                exp_name = npz_file.stem[:-len(suffix)]

                # Read both arrays while the archive is open, then release the
                # underlying file handle.
                with np.load(npz_file) as npz_data:
                    question_ids = npz_data['question_ids']
                    vectors = npz_data['vectors']

                # Optional JSON metadata stored next to the .npz file
                json_file = npz_file.with_suffix('.json')
                if json_file.exists():
                    with open(json_file, encoding='utf-8') as f:
                        metadata = json.load(f)
                else:
                    metadata = {}

                layer_entries[exp_name] = {
                    'vectors': vectors,
                    'question_ids': question_ids,
                    'metadata': metadata
                }

                print(f"  Loaded {exp_name} with {len(question_ids)} questions")

    return residual_data

# Load the V1/V2 residual stream data and print a short sanity summary
try:
    print("Loading residual stream data from by_layer directory...")
    residual_data = load_residual_stream_data(forms=['V1', 'V2'])

    # Number of experiments, summed over every layer
    total_experiments = sum(map(len, residual_data.values()))

    if not residual_data:
        print("No residual stream data found for forms V1 and V2")
    else:
        print(f"Loaded data for {len(residual_data)} layers with a total of {total_experiments} experiments")

        # Inspect the lowest-numbered layer
        first_layer = min(residual_data)
        first_layer_data = residual_data[first_layer]
        print(f"\nLayer {first_layer} contains {len(first_layer_data)} experiments")

        # Inspect the first experiment stored for that layer
        if first_layer_data:
            first_exp, exp_data = next(iter(first_layer_data.items()))
            print(f"  Experiment: {first_exp}")
            print(f"  Number of questions: {len(exp_data['question_ids'])}")
            print(f"  Vector dimension: {exp_data['vectors'][0].shape}")

            # Metadata is only reported when present and non-empty
            metadata = exp_data.get('metadata')
            if metadata:
                print(f"  Metadata keys: {list(metadata.keys())}")

                # Show the stored details of the first question, if any
                if 'question_metadata' in metadata:
                    first_q_id = exp_data['question_ids'][0]
                    if first_q_id in metadata['question_metadata']:
                        q_meta = metadata['question_metadata'][first_q_id]
                        for caption, field in (("First question", 'question'),
                                               ("Subject", 'subject'),
                                               ("Answer", 'answer')):
                            print(f"  {caption}: {q_meta.get(field, 'N/A')}")
except Exception as e:
    print(f"Error loading residual stream data: {e}")


Loading residual stream data from by_layer directory...
Found 35 layer directories
Processing layer 0...
  Loaded V1_num_p4 with 1050 questions
  Loaded V1_alpha_p4 with 1050 questions
  Loaded V1_alpha_p3 with 1050 questions
  Loaded V1_num_p2 with 1050 questions
  Loaded V1_num_p1 with 1050 questions
  Loaded V1_num_p3 with 1050 questions
  Loaded V1_alpha_p1 with 1050 questions
  Loaded V1_alpha_p5 with 1050 questions
  Loaded V1_alpha_p2 with 1050 questions
  Loaded V1_num_p5 with 1050 questions
  Loaded V2_alpha_p4 with 1050 questions
  Loaded V2_num_p5 with 1050 questions
  Loaded V2_alpha_p3 with 1050 questions
  Loaded V2_num_p3 with 1050 questions
  Loaded V2_alpha_p1 with 1050 questions
  Loaded V2_alpha_p2 with 1050 questions
  Loaded V2_alpha_p5 with 1050 questions
  Loaded V2_num_p1 with 1050 questions
  Loaded V2_num_p2 with 1050 questions
  Loaded V2_num_p4 with 1050 questions
Processing layer 1...
  Loaded V1_num_p4 with 1050 questions
  Loaded V1_alpha_p4 with 1050 que

  Loaded V2_alpha_p1 with 1050 questions
  Loaded V2_alpha_p2 with 1050 questions
  Loaded V2_alpha_p5 with 1050 questions
  Loaded V2_num_p1 with 1050 questions
  Loaded V2_num_p2 with 1050 questions
  Loaded V2_num_p4 with 1050 questions
Processing layer 2...
  Loaded V1_num_p4 with 1050 questions
  Loaded V1_alpha_p4 with 1050 questions
  Loaded V1_alpha_p3 with 1050 questions
  Loaded V1_num_p2 with 1050 questions
  Loaded V1_num_p1 with 1050 questions
  Loaded V1_num_p3 with 1050 questions
  Loaded V1_alpha_p1 with 1050 questions
  Loaded V1_alpha_p5 with 1050 questions
  Loaded V1_alpha_p2 with 1050 questions
  Loaded V1_num_p5 with 1050 questions
  Loaded V2_alpha_p4 with 1050 questions
  Loaded V2_num_p5 with 1050 questions
  Loaded V2_alpha_p3 with 1050 questions
  Loaded V2_num_p3 with 1050 questions
  Loaded V2_alpha_p1 with 1050 questions
  Loaded V2_alpha_p2 with 1050 questions
  Loaded V2_alpha_p5 with 1050 questions
  Loaded V2_num_p1 with 1050 questions
  Loaded V2_num_

In [3]:
# Sanity check: count the stored vectors per layer and per experiment file
try:
    from pathlib import Path

    import numpy as np

    # Location of the per-layer residual dumps
    by_layer_dir = Path("/home/sergio/projects/MATS-Project/results/by_layer_20250911_112921/by_layer")

    if by_layer_dir.exists():
        layer_dirs = sorted(d for d in by_layer_dir.glob("layer_*") if d.is_dir())
        print(f"Found {len(layer_dirs)} layer directories")

        # layer_counts: layer -> total vectors
        # experiment_counts: file stem -> {layer -> vectors}
        layer_counts = {}
        experiment_counts = {}

        for layer_dir in layer_dirs:
            layer_idx = int(layer_dir.name.split('_')[1])
            layer_counts[layer_idx] = 0

            npz_files = list(layer_dir.glob("*.npz"))

            for npz_file in npz_files:
                exp_name = npz_file.stem

                # A file that cannot be read is reported and skipped
                try:
                    data = np.load(npz_file)
                    if 'question_ids' not in data or 'vectors' not in data:
                        print(f"  Warning: Missing expected keys in {npz_file}")
                        continue

                    num_vectors = len(data['question_ids'])
                    layer_counts[layer_idx] += num_vectors

                    # Record the per-experiment count for this layer
                    experiment_counts.setdefault(exp_name, {})[layer_idx] = num_vectors
                except Exception as e:
                    print(f"  Error loading {npz_file}: {e}")

        # Per-layer totals
        print("\nVector counts per layer:")
        for layer_idx in sorted(layer_counts):
            print(f"Layer {layer_idx}: {layer_counts[layer_idx]} vectors")

        # Do all layers agree on the total?
        counts = list(layer_counts.values())
        if len(set(counts)) != 1:
            print(f"\nLayers have different vector counts. Min: {min(counts)}, Max: {max(counts)}")
        else:
            print(f"\nAll layers have the same number of vectors: {counts[0]}")

        # V1/V2 breakdown, taken from the lowest-numbered layer
        print("\nVector counts for V1 and V2 experiments (first layer only):")
        first_layer = min(layer_counts)
        v1v2_experiments = {
            name: per_layer[first_layer]
            for name, per_layer in experiment_counts.items()
            if name.startswith(('V1', 'V2'))
        }

        for exp_name, count in sorted(v1v2_experiments.items()):
            print(f"{exp_name}: {count} vectors")

        # Totals per form
        v1_total = sum(n for name, n in v1v2_experiments.items() if name.startswith('V1'))
        v2_total = sum(n for name, n in v1v2_experiments.items() if name.startswith('V2'))
        print(f"\nTotal V1: {v1_total} vectors")
        print(f"Total V2: {v2_total} vectors")
        print(f"Combined V1+V2: {v1_total + v2_total} vectors")
    else:
        print(f"Directory not found: {by_layer_dir}")

except Exception as e:
    print(f"Error checking vector counts: {e}")


Found 35 layer directories

Vector counts per layer:
Layer 0: 63000 vectors
Layer 1: 63000 vectors
Layer 2: 63000 vectors
Layer 3: 63000 vectors
Layer 4: 63000 vectors
Layer 5: 63000 vectors
Layer 6: 63000 vectors
Layer 7: 63000 vectors
Layer 8: 63000 vectors
Layer 9: 63000 vectors
Layer 10: 63000 vectors
Layer 11: 63000 vectors
Layer 12: 63000 vectors
Layer 13: 63000 vectors
Layer 14: 63000 vectors
Layer 15: 63000 vectors
Layer 16: 63000 vectors
Layer 17: 63000 vectors
Layer 18: 63000 vectors
Layer 19: 63000 vectors
Layer 20: 63000 vectors
Layer 21: 63000 vectors
Layer 22: 63000 vectors
Layer 23: 63000 vectors
Layer 24: 63000 vectors
Layer 25: 63000 vectors
Layer 26: 63000 vectors
Layer 27: 63000 vectors
Layer 28: 63000 vectors
Layer 29: 63000 vectors
Layer 30: 63000 vectors
Layer 31: 63000 vectors
Layer 32: 63000 vectors
Layer 33: 63000 vectors
Layer 34: 63000 vectors

All layers have the same number of vectors: 63000

Vector counts for V1 and V2 experiments (first layer only):
V1_al

### 3. Balance the Dataset

We'll balance the positive and negative classes across subject areas, prompt forms, and label types to avoid confounds.


In [4]:
import pandas as pd

# Path to the metrics directory
metrics_dir = Path('./results/metrics_analysis')

# Find all run directories. A missing metrics directory is treated the same as
# "no runs yet" so the friendly message below is printed instead of an
# uncaught FileNotFoundError from iterdir().
if metrics_dir.is_dir():
    run_dirs = [d for d in metrics_dir.iterdir() if d.is_dir() and d.name.startswith('run_')]
else:
    run_dirs = []

if not run_dirs:
    print("No metrics data found. Please run the run_metrics_analysis.py script first.")
else:
    print(f"Found {len(run_dirs)} runs with metrics data:")
    for d in run_dirs:
        print(f"  - {d.name}")

    # Use the most recent run by default (run names sort lexicographically)
    latest_run = sorted(run_dirs)[-1]
    print(f"\nUsing latest run: {latest_run.name}")

# Load the all_metrics.csv file from the latest run.
# `latest_run` is only bound when at least one run directory was found; the
# NameError handler covers the "no runs" case.
try:
    metrics_file = latest_run / 'all_metrics.csv'
    metrics_df = pd.read_csv(metrics_file)
    print(f"Loaded {len(metrics_df)} metric records")

    # Display the column names
    print("\nColumns in metrics_df:")
    print(metrics_df.columns.tolist())
    # NOTE(review): this expression is not the last statement of the cell, so
    # the notebook does not render it; wrap it in print()/display() if the
    # preview rows are wanted.
    metrics_df.head()
except NameError:
    print("No metrics data available. Please run the run_metrics_analysis.py script first.")


Found 1 runs with metrics data:
  - run_20250911_062422

Using latest run: run_20250911_062422
Loaded 63000 metric records

Columns in metrics_df:
['run', 'experiment', 'form', 'label_type', 'permutation', 'id', 'question', 'answer', 'subject', 'difficulty', 'split', 'pred_label', 'canonical_label', 'score', 'ca_score', 'hedge_score', 'canonical_probs', 'canonical_probs_norm']


In [5]:
try:
    # Keep only the train-split rows that belong to forms V1 and V2
    is_v1v2 = metrics_df['form'].isin(['V1', 'V2'])
    is_train = metrics_df['split'] == 'train'
    v1v2_metrics = metrics_df[is_v1v2 & is_train]

    print(f"Found {len(v1v2_metrics)} examples for forms V1 and V2 in the train split")

    # Lower and upper quartile of the CA score
    ca_scores = v1v2_metrics['ca_score']
    q25 = ca_scores.quantile(0.25)
    q75 = ca_scores.quantile(0.75)

    print(f"CA score quantiles - 25%: {q25:.4f}, 75%: {q75:.4f}")

    # Top quartile forms the positive class, bottom quartile the negative class
    positive_class = v1v2_metrics[ca_scores >= q75]
    negative_class = v1v2_metrics[ca_scores <= q25]

    print(f"Positive class (high CA): {len(positive_class)} examples")
    print(f"Negative class (low CA): {len(negative_class)} examples")

    # Side-by-side class counts for each potential confound; the subject table
    # is limited to each class's own top five values.
    for heading, column, top_only in (
        ("\nDistribution across forms:", 'form', False),
        ("\nDistribution across subjects (top 5):", 'subject', True),
        ("\nDistribution across label types:", 'label_type', False),
    ):
        print(heading)
        pos_tally = positive_class[column].value_counts()
        neg_tally = negative_class[column].value_counts()
        if top_only:
            pos_tally = pos_tally.head()
            neg_tally = neg_tally.head()
        print(pd.concat([pos_tally.rename('Positive'), neg_tally.rename('Negative')], axis=1))

except NameError:
    print("metrics_df not available. Please run the previous cells to load the metrics data.")


Found 16800 examples for forms V1 and V2 in the train split
CA score quantiles - 25%: 0.0005, 75%: 0.4686
Positive class (high CA): 4200 examples
Negative class (low CA): 4201 examples

Distribution across forms:
      Positive  Negative
form                    
V1        2298      1965
V2        1902      2236

Distribution across subjects (top 5):
             Positive  Negative
subject                        
Biology         413.0       NaN
Psychology      390.0       NaN
Physics         386.0     319.0
Medicine        383.0       NaN
History         335.0       NaN
Engineering       NaN     424.0
Chemistry         NaN     384.0
Computing         NaN     351.0
Earth             NaN     333.0

Distribution across label types:
            Positive  Negative
label_type                    
num             2168      1689
alpha           2032      2512


In [6]:
def balance_dataset(positive_df, negative_df, balance_cols=('subject', 'form', 'label_type')):
    """
    Balance positive and negative classes across specified columns.

    For every combination of ``balance_cols`` values that occurs in both
    classes, the same number of rows (the smaller of the two group sizes) is
    sampled from each class with a fixed random_state. Groups are visited in
    the sorted order produced by groupby, so the output row order is the same
    on every run.

    Args:
        positive_df: DataFrame containing positive class examples
        negative_df: DataFrame containing negative class examples
        balance_cols: Column name, or sequence of column names, to balance across

    Returns:
        Tuple of (balanced_positive_df, balanced_negative_df). When the classes
        share no group, both are empty frames that keep the input columns.
    """
    # groupby interprets a tuple as ONE key, so always hand it a list
    balance_cols = [balance_cols] if isinstance(balance_cols, str) else list(balance_cols)

    pos_counts = positive_df.groupby(balance_cols).size()
    neg_counts = negative_df.groupby(balance_cols).size()

    # Common groups, kept in groupby's sorted order rather than set order
    # (set iteration over strings varies with hash randomisation).
    neg_groups = set(neg_counts.index)
    common_groups = [group for group in pos_counts.index if group in neg_groups]

    pos_parts = []
    neg_parts = []

    for group in common_groups:
        # With a single balance column the group key is a scalar
        group_tuple = group if isinstance(group, tuple) else (group,)

        pos_group = _rows_matching(positive_df, balance_cols, group_tuple)
        neg_group = _rows_matching(negative_df, balance_cols, group_tuple)

        min_count = min(len(pos_group), len(neg_group))
        if min_count > 0:
            # Sample min_count examples from each class for this group
            pos_parts.append(pos_group.sample(min_count, random_state=42))
            neg_parts.append(neg_group.sample(min_count, random_state=42))

    if not pos_parts:
        # Keep the schema so downstream column access still works
        return positive_df.iloc[0:0].copy(), negative_df.iloc[0:0].copy()

    # Concatenate once instead of growing a DataFrame inside the loop
    return pd.concat(pos_parts), pd.concat(neg_parts)


def _rows_matching(df, columns, values):
    """Return the rows of df whose `columns` equal the corresponding `values`."""
    mask = np.ones(len(df), dtype=bool)
    for col, val in zip(columns, values):
        mask &= (df[col] == val).to_numpy()
    return df[mask]

try:
    # Draw equally sized positive/negative samples per (subject, form, label_type) group
    balanced_positive, balanced_negative = balance_dataset(positive_class, negative_class)

    print(f"Balanced positive class: {len(balanced_positive)} examples")
    print(f"Balanced negative class: {len(balanced_negative)} examples")

    # Confirm the two classes now line up on form and label type
    for heading, column in (
        ("\nBalanced distribution across forms:", 'form'),
        ("\nBalanced distribution across label types:", 'label_type'),
    ):
        print(heading)
        print(pd.concat([
            balanced_positive[column].value_counts().rename('Positive'),
            balanced_negative[column].value_counts().rename('Negative')
        ], axis=1))

    # Attach the class label: 1 = high CA (abstention), 0 = low CA (non-abstention)
    balanced_positive['class'] = 1
    balanced_negative['class'] = 0

    # Stack both classes into one labelled dataset
    balanced_dataset = pd.concat([balanced_positive, balanced_negative])
    print(f"\nFinal balanced dataset: {len(balanced_dataset)} examples")

    # Shuffle with a fixed seed and renumber the rows
    balanced_dataset = balanced_dataset.sample(frac=1, random_state=42).reset_index(drop=True)

except NameError:
    print("positive_class or negative_class not available. Please run the previous cells.")


Balanced positive class: 3388 examples
Balanced negative class: 3388 examples

Balanced distribution across forms:
      Positive  Negative
form                    
V1        1759      1759
V2        1629      1629

Balanced distribution across label types:
            Positive  Negative
label_type                    
alpha           1889      1889
num             1499      1499

Final balanced dataset: 6776 examples


### 4. Extract Residual Stream Vectors

Now we'll extract the residual stream vectors for each example in our balanced dataset.


In [7]:
def extract_residual_vectors(dataset, residual_data):
    """
    Extract residual stream vectors for examples in the dataset.

    Layers are processed from the last to the first. A layer is skipped when a
    ``residual_vectors_*_layer_<n>.pkl`` file for it already exists in
    ./results/abstention_direction. Each processed layer is written to its own
    pickle file holding only that layer, in the same
    ``'residual_vectors_by_layer'`` layout that load_all_residual_vectors reads.

    Args:
        dataset: DataFrame with columns 'id', 'form', 'label_type' and 'class'
            (1 = positive, anything else = negative)
        residual_data: Dictionary mapping layer indices to dictionaries of
            experiment data ({exp_name: {'question_ids': ..., 'vectors': ...}})

    Returns:
        Dictionary mapping each newly processed layer index to
        {'positive': ndarray, 'negative': ndarray,
         'pos_question_ids': list, 'neg_question_ids': list}.
        Layers without both positive and negative matches are omitted.
    """
    import glob
    import pickle
    from datetime import datetime
    from pathlib import Path

    # Create save directory if it doesn't exist
    save_dir = Path("./results/abstention_direction")
    save_dir.mkdir(parents=True, exist_ok=True)

    # One timestamp shared by every file written in this call
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Layers that already have a saved file; the index is parsed from the file
    # name only, so directory names cannot interfere.
    existing_layers = set()
    for file_path in glob.glob(str(save_dir / "residual_vectors_*_layer_*.pkl")):
        try:
            existing_layers.add(int(Path(file_path).stem.rsplit("_layer_", 1)[1]))
        except (IndexError, ValueError):
            continue

    print(f"Found {len(existing_layers)} layers with existing results: {sorted(existing_layers)}")

    result = {}

    # Sort layer indices in reverse order (start from last layer)
    layer_indices = sorted(residual_data.keys(), reverse=True)
    layers_to_process = [layer for layer in layer_indices if layer not in existing_layers]
    total_layers = len(layer_indices)
    print(f"Processing {len(layers_to_process)} out of {total_layers} layers in reverse order...")

    # Build the (id, form, label_type) -> class lookup once instead of scanning
    # the whole DataFrame for every question. setdefault keeps the FIRST row
    # for a key, matching the previous `matching_rows['class'].values[0]`.
    # NOTE(review): the key ignores the permutation part of the experiment
    # name, so one dataset row can match several experiments — confirm whether
    # permutation should be part of the key.
    class_by_key = {}
    if layers_to_process:
        keys = zip(dataset['id'].tolist(), dataset['form'].tolist(), dataset['label_type'].tolist())
        for key, class_label in zip(keys, dataset['class'].tolist()):
            class_by_key.setdefault(key, class_label)

    for position, layer_idx in enumerate(layers_to_process, start=1):
        layer_data = residual_data[layer_idx]
        print(f"Extracting vectors for layer {layer_idx} ({position}/{len(layers_to_process)})...")

        H_pos = []  # For high CA examples (class 1)
        H_neg = []  # For low CA examples (class 0)
        pos_question_ids = []
        neg_question_ids = []

        for exp_name, exp_data in layer_data.items():
            # Experiment names start with "<form>_<label_type>"
            parts = exp_name.split("_")
            if len(parts) < 2:
                continue
            form, label_type = parts[0], parts[1]

            question_ids = exp_data['question_ids']
            vectors = exp_data['vectors']

            for row, q_id in enumerate(question_ids):
                key = (q_id, form, label_type)
                if key not in class_by_key:
                    continue

                if class_by_key[key] == 1:
                    H_pos.append(vectors[row])
                    pos_question_ids.append(q_id)
                else:
                    H_neg.append(vectors[row])
                    neg_question_ids.append(q_id)

        # A layer is only kept when both classes have at least one vector
        if H_pos and H_neg:
            layer_result = {
                'positive': np.array(H_pos),
                'negative': np.array(H_neg),
                'pos_question_ids': pos_question_ids,
                'neg_question_ids': neg_question_ids
            }
            result[layer_idx] = layer_result
            print(f"  Layer {layer_idx}: {len(H_pos)} positive examples, {len(H_neg)} negative examples")

            # Save data for this layer immediately
            save_path = save_dir / f"residual_vectors_{timestamp}_layer_{layer_idx}.pkl"

            # Only this layer is written; re-saving every accumulated layer
            # into each file made the files grow with every layer processed.
            save_data = {
                'residual_vectors': {
                    'positive': {layer_idx: layer_result['positive']},
                    'negative': {layer_idx: layer_result['negative']}
                },
                'residual_vectors_by_layer': {layer_idx: layer_result},
                'timestamp': timestamp,
                'layer': layer_idx,
                'completed_layers': list(result.keys()),
                'forms': dataset['form'].unique().tolist() if 'form' in dataset.columns else ['V1', 'V2']
            }

            try:
                with open(save_path, 'wb') as f:
                    pickle.dump(save_data, f)
                print(f"  Saved data for layer {layer_idx} to {save_path}")
            except Exception as e:
                # Best effort: the in-memory result is still returned
                print(f"  Error saving layer data: {e}")

    return result

try:
    # Check if we have residual data and balanced dataset
    if 'residual_data' in locals() and 'balanced_dataset' in locals():
        print("Processing residual vectors layer by layer to save memory...")

        import glob
        import pickle
        from datetime import datetime
        from pathlib import Path

        # Create save directory if it doesn't exist
        save_dir = Path("./results/abstention_direction")
        save_dir.mkdir(parents=True, exist_ok=True)

        # Create timestamp for filenames
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        # Check for existing layer files
        existing_layer_files = glob.glob(str(save_dir / "residual_vectors_*_layer_*.pkl"))
        existing_layers = set()
        for file_path in existing_layer_files:
            try:
                layer_part = file_path.split("_layer_")[1]
                layer_num = int(layer_part.split(".pkl")[0])
                existing_layers.add(layer_num)
            except (IndexError, ValueError):
                continue

        print(f"Found {len(existing_layers)} layers with existing results: {sorted(existing_layers)}")

        # Sort layer indices in reverse order (start from last layer)
        layer_indices = sorted(residual_data.keys(), reverse=True)
        layers_to_process = [layer for layer in layer_indices if layer not in existing_layers]
        total_layers = len(layer_indices)
        print(f"Processing {len(layers_to_process)} out of {total_layers} layers in reverse order...")

        # Build the (id, form, label_type) -> class lookup once, instead of
        # filtering the whole DataFrame for every question of every
        # experiment. setdefault keeps the FIRST row per key, which is what
        # `matching_rows['class'].values[0]` selected before.
        # NOTE(review): the key ignores the permutation encoded in the
        # experiment name. The logged counts (6855 / 6785) exceed the 3388
        # balanced rows per class, which is consistent with one dataset row
        # matching several permutation experiments — confirm whether the
        # permutation should be part of the key.
        class_by_key = {}
        for key, class_label in zip(
            zip(balanced_dataset['id'].tolist(),
                balanced_dataset['form'].tolist(),
                balanced_dataset['label_type'].tolist()),
            balanced_dataset['class'].tolist(),
        ):
            class_by_key.setdefault(key, class_label)

        # Track processed layers for summary
        processed_layers = []

        # Process each layer independently to save memory
        for i, layer_idx in enumerate(layers_to_process):
            print(f"Processing layer {layer_idx} ({i+1}/{len(layers_to_process)})...")

            # Get layer data
            layer_data = residual_data[layer_idx]

            # Initialize positive and negative vectors for this layer
            H_pos = []  # For high CA examples (class 1)
            H_neg = []  # For low CA examples (class 0)
            pos_question_ids = []
            neg_question_ids = []

            # Process each experiment in this layer
            for exp_name, exp_data in layer_data.items():
                # Extract form and label_type from experiment name
                parts = exp_name.split("_")
                if len(parts) < 2:
                    continue

                form = parts[0]
                label_type = parts[1]

                # Get question IDs and vectors
                question_ids = exp_data['question_ids']
                vectors = exp_data['vectors']

                # Match with our dataset
                for j, q_id in enumerate(question_ids):
                    key = (q_id, form, label_type)
                    if key not in class_by_key:
                        continue

                    # Class label: 1 for positive, anything else for negative
                    if class_by_key[key] == 1:
                        H_pos.append(vectors[j])
                        pos_question_ids.append(q_id)
                    else:
                        H_neg.append(vectors[j])
                        neg_question_ids.append(q_id)

            # Convert lists to numpy arrays if they're not empty
            if H_pos and H_neg:
                # Create layer result dictionary
                layer_result = {
                    'positive': np.array(H_pos),
                    'negative': np.array(H_neg),
                    'pos_question_ids': pos_question_ids,
                    'neg_question_ids': neg_question_ids
                }

                print(f"  Layer {layer_idx}: {len(H_pos)} positive examples, {len(H_neg)} negative examples")
                processed_layers.append(layer_idx)

                # Save data for this layer immediately
                save_path = save_dir / f"residual_vectors_{timestamp}_layer_{layer_idx}.pkl"

                # Prepare save data for just this layer
                save_data = {
                    'residual_vectors_by_layer': {layer_idx: layer_result},
                    'timestamp': timestamp,
                    'layer': layer_idx,
                    'forms': balanced_dataset['form'].unique().tolist() if 'form' in balanced_dataset.columns else ['V1', 'V2']
                }

                try:
                    with open(save_path, 'wb') as f:
                        pickle.dump(save_data, f)
                    print(f"  Saved data for layer {layer_idx} to {save_path}")
                except Exception as e:
                    print(f"  Error saving layer data: {e}")

                # Clear memory before the next layer is assembled
                del layer_result
                del H_pos
                del H_neg
                del pos_question_ids
                del neg_question_ids

        # After processing all layers, load a sample layer to demonstrate it worked
        if processed_layers or existing_layers:
            all_available_layers = processed_layers + list(existing_layers)
            if all_available_layers:
                sample_layer = min(all_available_layers)
                sample_files = glob.glob(str(save_dir / f"residual_vectors_*_layer_{sample_layer}.pkl"))

                if sample_files:
                    try:
                        with open(sample_files[0], 'rb') as f:
                            sample_data = pickle.load(f)
                            layer_data = sample_data['residual_vectors_by_layer'][sample_layer]

                            print(f"\nSample layer {sample_layer} details:")
                            print(f"  Positive examples: {layer_data['positive'].shape}")
                            print(f"  Negative examples: {layer_data['negative'].shape}")
                            print(f"  Vector dimension: {layer_data['positive'][0].shape}")
                            print("\nProcessing complete. Data saved layer by layer to avoid memory issues.")
                            print(f"Total layers processed: {len(processed_layers)}")
                            print(f"Total layers available: {len(all_available_layers)}")
                            print("\nTo load all layers for analysis, use:")
                            print("residual_vectors = load_all_residual_vectors('./results/abstention_direction')")
                    except Exception as e:
                        print(f"Error loading sample layer: {e}")
            else:
                print("No layers were processed or found.")
        else:
            print("No matching vectors found in the residual data")
    else:
        print("Residual data or balanced dataset not available")
except Exception as e:
    print(f"Error processing residual vectors: {e}")


Processing residual vectors layer by layer to save memory...
Found 30 layers with existing results: [5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34]
Processing 5 out of 35 layers in reverse order...
Processing layer 4 (1/5)...
  Layer 4: 6855 positive examples, 6785 negative examples
  Saved data for layer 4 to results/abstention_direction/residual_vectors_20250913_012645_layer_4.pkl
Processing layer 3 (2/5)...
  Layer 3: 6855 positive examples, 6785 negative examples
  Saved data for layer 3 to results/abstention_direction/residual_vectors_20250913_012645_layer_3.pkl
Processing layer 2 (3/5)...
  Layer 2: 6855 positive examples, 6785 negative examples
  Saved data for layer 2 to results/abstention_direction/residual_vectors_20250913_012645_layer_2.pkl
Processing layer 1 (4/5)...
  Layer 1: 6855 positive examples, 6785 negative examples
  Saved data for layer 1 to results/abstention_direction/residual_vectors_20250913_0

In [8]:
def load_all_residual_vectors(directory_path):
    """
    Load all saved residual vectors from the specified directory.

    Files named ``residual_vectors_*_layer_*.pkl`` are read in sorted name
    order. When the same layer appears in more than one file, the data from
    the file read last wins and the layer is counted once in the summary.

    Args:
        directory_path: Path to the directory containing saved residual vector files

    Returns:
        Dictionary with 'positive' and 'negative' keys, each containing layer-indexed vectors
    """
    import glob
    import pickle
    from pathlib import Path

    # Convert to Path object if it's a string
    directory = Path(directory_path)

    # Find all layer files (an absent directory simply yields no files)
    layer_files = sorted(glob.glob(str(directory / "residual_vectors_*_layer_*.pkl")))

    result = {'positive': {}, 'negative': {}}
    # A set, so a layer stored in several files is not counted repeatedly
    loaded_layers = set()

    print(f"Found {len(layer_files)} layer files")

    for file_path in layer_files:
        try:
            # NOTE: unpickling can execute arbitrary code; only point this at
            # directories whose files were written by this notebook.
            with open(file_path, 'rb') as f:
                data = pickle.load(f)

            # Files without the expected payload key are ignored
            if 'residual_vectors_by_layer' not in data:
                continue

            for layer, layer_data in data['residual_vectors_by_layer'].items():
                if 'positive' in layer_data and 'negative' in layer_data:
                    result['positive'][layer] = layer_data['positive']
                    result['negative'][layer] = layer_data['negative']
                    loaded_layers.add(layer)
                    print(f"Loaded layer {layer} from {Path(file_path).name}")
        except Exception as e:
            # Best effort: report the unreadable file and keep going
            print(f"Error loading {file_path}: {e}")

    print(f"Successfully loaded data for {len(loaded_layers)} layers: {sorted(loaded_layers)}")
    return result


In [None]:
# Example usage:
# After processing all layers, you can load them for analysis with:
# residual_vectors = load_all_residual_vectors('./results/abstention_direction')
#
# This will give you the same structure as before, but without memory issues:
# residual_vectors = {
#     'positive': {layer_idx: positive_vectors_array, ...},
#     'negative': {layer_idx: negative_vectors_array, ...}
# }


: 