# Brije: Cognitive Action Detection - Full Pipeline (Google Colab)

Complete pipeline for training **binary cognitive action probes across all layers** on Gemma 3 4B.

**This notebook will:**
1. ✅ Clone the Brije repository
2. ✅ Install all dependencies
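3. ✅ Capture activations from Gemma 3 4B layers 4-28 with batch saving (~10-15 minutes in single-pass mode; ~3-4 hours with the older per-layer method)
4. ✅ Train 45 binary probes per layer (1,125 total probes; ~8-12 hours sequentially, ~2-3 hours with parallel training)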
5. ✅ Compare performance across layers
6. ✅ Test with multi-probe inference
7. ✅ Download trained models to Google Drive

**Requirements:**
- Google Colab with A100 GPU (40GB VRAM recommended)
- Runtime: ~12-16 hours total (can run in stages)

**Dataset:** 31,500 cognitive action examples across 45 actions

**Architecture:** One-vs-rest binary classification, 45 probes × 25 layers = 1,125 total probes
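
As a rough illustration of the one-vs-rest setup, the sketch below turns multi-class action labels into one binary target list per action and reproduces the probe count. The action names and the helper `one_vs_rest_targets` are hypothetical and are not taken from the Brije codebase.

```python
def one_vs_rest_targets(labels):
    """Map each distinct label to a 0/1 target list (1 where the example has that label)."""
    classes = sorted(set(labels))
    return {c: [1 if y == c else 0 for y in labels] for c in classes}

# Hypothetical labels, for illustration only
labels = ["analyzing", "recalling", "analyzing", "evaluating"]
targets = one_vs_rest_targets(labels)

# One binary probe per action per layer
num_actions, num_layers = 45, 25
total_probes = num_actions * num_layers

print(targets["analyzing"], total_probes)
```

Each probe then only has to answer "is this action present or not?", which is why the probe count scales as actions × layers.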

## 1️⃣ Check GPU and Setup Runtime

In [1]:
# Check GPU availability
!nvidia-smi

import torch
print("\n" + "="*60)
print("GPU INFORMATION")
print("="*60)
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU device: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("⚠️  WARNING: No GPU detected! This will be very slow on CPU.")
print("="*60)

Fri Oct 10 00:41:49 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   37C    P8              9W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

## 2️⃣ Clone Repository and Install Dependencies

In [2]:
import os
import sys

# Clone the repository
repo_url = "https://github.com/ChuloIva/brije.git"
repo_name = "brije"

if not os.path.exists(repo_name):
    print("📥 Cloning Brije repository...")
    !git clone {repo_url}
    print("✅ Repository cloned successfully!")
else:
    print("✅ Repository already exists")
    print("🔄 Pulling latest changes...")
    !cd {repo_name} && git pull

# Change to repo directory
os.chdir(repo_name)
print(f"\n📁 Current directory: {os.getcwd()}")

📥 Cloning Brije repository...
Cloning into 'brije'...
remote: Enumerating objects: 168, done.
remote: Counting objects: 100% (168/168), done.
remote: Compressing objects: 100% (106/106), done.
remote: Total 168 (delta 71), reused 141 (delta 44), pack-reused 0 (from 0)
Receiving objects: 100% (168/168), 9.98 MiB | 7.21 MiB/s, done.
Resolving deltas: 100% (71/71), done.
✅ Repository cloned successfully!

📁 Current directory: /content/brije


In [3]:
# Install dependencies
print("📦 Installing dependencies...\n")
!pip install -q torch transformers h5py scikit-learn tqdm matplotlib seaborn

# Clone and install nnsight
nnsight_dir = "third_party/nnsight"
nnsight_repo = "https://github.com/ndif-team/nnsight"

print("\n📦 Setting up nnsight...")
if not os.path.exists(nnsight_dir) or not os.listdir(nnsight_dir):
    print("   Cloning nnsight repository...")
    # Create third_party directory if it doesn't exist
    os.makedirs("third_party", exist_ok=True)
    # Clone nnsight
    !git clone {nnsight_repo} {nnsight_dir}
    print("   ✅ nnsight repository cloned")
else:
    print("   ✅ nnsight repository already exists")

# Install nnsight
print("   Installing nnsight...")
!pip install -q -e {nnsight_dir}

print("\n✅ All dependencies installed!")

📦 Installing dependencies...


📦 Setting up nnsight...
   Cloning nnsight repository...
Cloning into 'third_party/nnsight'...
remote: Enumerating objects: 13148, done.
remote: Counting objects: 100% (145/145), done.
remote: Compressing objects: 100% (78/78), done.
remote: Total 13148 (delta 85), reused 74 (delta 61), pack-reused 13003 (from 3)
Receiving objects: 100% (13148/13148), 64.31 MiB | 17.15 MiB/s, done.
Resolving deltas: 100% (8276/8276), done.
   ✅ nnsight repository cloned
   Installing nnsight...
  Installing build dependencies ... done
  Checking if build backend supports build_editable ... done
  Getting requirements to build editable ... done
  Preparing editable metadata (pyproject.toml) ... done
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.6/1.6 MB 74.5 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 59.6/59.6 kB 6.2 MB/s eta 

## 3️⃣ Mount Google Drive

In [4]:
from google.colab import drive
drive.mount('/content/drive')

# Create directories in Google Drive for outputs
drive_output_dir = '/content/drive/MyDrive/brije_outputs'
os.makedirs(drive_output_dir, exist_ok=True)
os.makedirs(f"{drive_output_dir}/activations", exist_ok=True)
os.makedirs(f"{drive_output_dir}/probes", exist_ok=True)

print(f"✅ Outputs will be saved to: {drive_output_dir}")

Mounted at /content/drive
✅ Outputs will be saved to: /content/drive/MyDrive/brije_outputs


## 4️⃣ Verify Dataset

In [4]:
# Check if dataset exists
import glob
import os

dataset_path = "third_party/datagen/generated_data"
datasets = glob.glob(f"{dataset_path}/*.jsonl")

print("="*60)
print("AVAILABLE DATASETS")
print("="*60)
for ds in datasets:
    size = os.path.getsize(ds) / 1e6
    print(f"  {os.path.basename(ds)} ({size:.2f} MB)")
print("="*60)

# Use the stratified combined dataset (31.5k examples)
dataset_file = None
for ds in datasets:
    if 'stratified_combined' in ds or '31500' in ds:
        dataset_file = ds
        break

if not dataset_file:
    # Use any available dataset
    dataset_file = datasets[0] if datasets else None

if dataset_file:
    print(f"\n✅ Using dataset: {os.path.basename(dataset_file)}")
else:
    print("\n⚠️  No dataset found! You may need to generate data first.")
    print("See: third_party/datagen/README.md")

AVAILABLE DATASETS
  cognitive_actions_7k_final_1759233061.jsonl (14.69 MB)
  stratified_4500_1759788994.jsonl (3.18 MB)
  stratified_9000_1759769375.jsonl (6.33 MB)
  stratified_combined_31500.jsonl (22.21 MB)
  stratified_18000_1759809514.jsonl (12.70 MB)

✅ Using dataset: stratified_combined_31500.jsonl


## 5️⃣ Configure Pipeline Parameters

In [5]:
CONFIG = {
    'model': 'google/gemma-3-4b-it',
    'dataset': dataset_file,
    'layer_start': 1,  # First layer to capture (use 4 for the 4-28 range described above)
    'layer_end': 20,   # Last layer to capture, inclusive (use 28 for the 4-28 range)
    'probe_type': 'linear',  # 'linear' or 'multihead'

    # 🚀 Parallel Training Configuration (OPTIMIZED for A100 40GB)
    'use_parallel_training': True,  # Enable parallel training
    'num_workers': 45,  # Train 45 probes simultaneously
    'batch_size': 32,  # Large batch size for better GPU utilization
    'pin_activations_to_gpu': True,  # Pin activations to GPU memory

    'epochs': 50,  # Max epochs with early stopping
    'learning_rate': 0.0005,  # 5e-4
    'weight_decay': 0.001,  # 1e-3
    'early_stopping_patience': 10,
    'use_scheduler': True,
    'device': 'auto',
    'max_examples': None,  # None = use all examples
    'batch_save': True,
    'batch_save_size': 1000,
}

# Generate layer list
CONFIG['layers_to_capture'] = list(range(CONFIG['layer_start'], CONFIG['layer_end'] + 1))
num_layers = len(CONFIG['layers_to_capture'])
total_probes = num_layers * 45

print("="*70)
print("🚀 PARALLEL TRAINING PIPELINE CONFIGURATION")
print("="*70)
for key, value in CONFIG.items():
    if key != 'layers_to_capture':
        print(f"  {key:25s}: {value}")
print(f"  {'layers_to_capture':25s}: {CONFIG['layer_start']}-{CONFIG['layer_end']} ({num_layers} layers)")
print(f"  {'total_probes':25s}: {total_probes} (45 per layer)")
print("="*70)
print("\n🚀 Parallel Training Benefits:")
print(f"  • {CONFIG['num_workers']}x faster training")
print(f"  • Large batch size ({CONFIG['batch_size']}) for GPU efficiency")
print("  • Activations pinned to GPU memory")
print("  • Expected time: ~2-3 hours (vs 8-12 hours sequential!)")
print("="*70)

🚀 PARALLEL TRAINING PIPELINE CONFIGURATION
  model                    : google/gemma-3-4b-it
  dataset                  : third_party/datagen/generated_data/stratified_combined_31500.jsonl
  layer_start              : 1
  layer_end                : 20
  probe_type               : linear
  use_parallel_training    : True
  num_workers              : 45
  batch_size               : 32
  pin_activations_to_gpu   : True
  epochs                   : 50
  learning_rate            : 0.0005
  weight_decay             : 0.001
  early_stopping_patience  : 10
  use_scheduler            : True
  device                   : auto
  max_examples             : None
  batch_save               : True
  batch_save_size          : 1000
  layers_to_capture        : 1-20 (20 layers)
  total_probes             : 900 (45 per layer)

🚀 Parallel Training Benefits:
  • 45x faster training
  • Large batch size (32) for GPU efficiency
  • Activations pinned to GPU memory
  • Expected time: ~2-3 hours (vs 8-12 hours

In [7]:
# Login to Hugging Face Hub
from huggingface_hub import notebook_login

notebook_login()


## 6️⃣ Step 1: Capture Activations from Layers 4-28 (🚀 ~10-15 minutes with single-pass!)

This extracts hidden states from Gemma 3 4B at 25 layers (4-28) using **optimized single-pass capture**.

**⏰ Expected time:** ~10-15 minutes for full dataset (31.5k examples, ALL 25 layers simultaneously!)

**💾 Memory:** ~12-16 GB VRAM peak usage (single-pass with periodic cleanup)

**🚀 Optimization:** Captures all layers in ONE forward pass (25x faster than old method!)

**💡 Note:** Activations are saved per layer. If interrupted, you can resume from where it stopped.
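
One way to act on the resume note is to check which per-layer files already exist before re-running capture. The sketch below assumes the `layer_{idx}_activations.h5` naming used later in this notebook. The helper `missing_layers` is hypothetical, and whether `capture_activations.py` itself skips completed layers is not confirmed here.

```python
import os
import tempfile

def missing_layers(activation_dir, layers):
    """Return the layers that have no activation file in activation_dir yet."""
    return [
        layer for layer in layers
        if not os.path.exists(os.path.join(activation_dir, f"layer_{layer}_activations.h5"))
    ]

# Demo on a temporary directory with two already "captured" layers
with tempfile.TemporaryDirectory() as tmp:
    for done in (4, 5):
        open(os.path.join(tmp, f"layer_{done}_activations.h5"), "w").close()
    todo = missing_layers(tmp, range(4, 9))

print(todo)
```

In the notebook, the result could be passed as the `--layers` argument so that only the remaining layers are captured.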

In [20]:
import time

print("\n" + "="*60)
print("STEP 1: CAPTURING ACTIVATIONS (🚀 OPTIMIZED SINGLE-PASS MODE)")
print("="*60)
print(f"Model: {CONFIG['model']}")
print(f"Layers: {CONFIG['layer_start']}-{CONFIG['layer_end']} ({len(CONFIG['layers_to_capture'])} layers)")
print(f"Dataset: {os.path.basename(CONFIG['dataset'])}")
print(f"Mode: Single-pass optimization (25x faster!)")
print(f"Batch size: {CONFIG['batch_save_size']}")
print("\n⏰ This will take ~10-15 minutes (vs 3-4 hours with old method!).")
print("💡 All layers captured simultaneously in ONE forward pass!\n")

start_time = time.time()

# Build command with --single-pass flag for optimized capture
cmd = [
    'python', 'src/probes/capture_activations.py',
    '--dataset', CONFIG['dataset'],
    '--output-dir', 'data/activations',
    '--model', CONFIG['model'],
    '--layers', *[str(l) for l in CONFIG['layers_to_capture']],
    '--device', CONFIG['device'],
    '--format', 'hdf5',
    '--single-pass',  # 🚀 OPTIMIZED: Capture all layers in one pass!
    '--batch-size', str(CONFIG['batch_save_size'])
]

if CONFIG['max_examples']:
    cmd.extend(['--max-examples', str(CONFIG['max_examples'])])

# Run capture
!{' '.join(cmd)}

elapsed = time.time() - start_time
print(f"\n✅ Activation capture completed in {elapsed/60:.1f} minutes")
# Old-method estimate is ~3.5 hours; convert to seconds to match `elapsed`
print(f"   Speedup: ~{(3.5*3600)/elapsed:.1f}x faster than old method!")

# Copy to Google Drive for backup
print("\n📥 Backing up activations to Google Drive...")
!cp -r data/activations/* {drive_output_dir}/activations/
print("✅ Backup complete!")

# Show captured layers
import glob
activation_files = glob.glob('data/activations/layer_*_activations.h5')
print(f"\n📊 Captured {len(activation_files)} layer files:")
for f in sorted(activation_files)[:5]:
    print(f"  • {os.path.basename(f)}")
if len(activation_files) > 5:
    print(f"  ... and {len(activation_files) - 5} more")


STEP 1: CAPTURING ACTIVATIONS (🚀 OPTIMIZED SINGLE-PASS MODE)
Model: google/gemma-3-4b-it
Layers: 20-30 (11 layers)
Dataset: stratified_combined_31500.jsonl
Mode: Single-pass optimization (25x faster!)
Batch size: 1000

⏰ This will take ~10-15 minutes (vs 3-4 hours with old method!).
💡 All layers captured simultaneously in ONE forward pass!

2025-10-09 21:38:54.946385: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-10-09 21:38:54.963918: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1760045934.984880   35499 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has a

## 7️⃣ Step 2: Train Binary Probes for All Layers
**💡 Note:** This step can run on a lower-end GPU and takes around 1-2 hours.

Train 45 binary probes per layer (1,125 total probes) using a one-vs-rest strategy.

**⏰ Expected time:** 8-12 hours for all layers when probes are trained sequentially
- ~20-30 minutes per layer
- 25 layers total

**🎯 Expected performance:** AUC-ROC 0.85-0.95 per probe (varies by layer)

**💡 Note:** Layers are trained one after another, and progress is saved after each layer completes.
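
For reference, AUC-ROC can be read as the probability that a randomly chosen positive example scores higher than a randomly chosen negative one. The sketch below is a minimal pure-Python version of that definition on toy scores. It is only an illustration; the training script presumably relies on a library implementation such as scikit-learn's `roc_auc_score`, which is an assumption.

```python
def auc_roc(y_true, scores):
    """Pairwise AUC: fraction of (positive, negative) pairs ranked correctly; ties count 0.5."""
    pos = [s for y, s in zip(y_true, scores) if y == 1]
    neg = [s for y, s in zip(y_true, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC-ROC needs at least one positive and one negative example")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

# Toy probe scores, for illustration only
y_true = [1, 0, 1, 0, 0]
scores = [0.9, 0.2, 0.4, 0.5, 0.1]
auc = auc_roc(y_true, scores)
print(f"{auc:.3f}")
```

If the 31,500 examples are split evenly across the 45 actions (700 each), every one-vs-rest probe sees about one positive for every 44 negatives, so accuracy alone can look high for a trivial probe. AUC-ROC is not inflated by that imbalance.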

In [None]:
import json
import time

print("\n" + "="*70)
print("STEP 2: 🚀 PARALLEL TRAINING OF BINARY PROBES")
print("="*70)
print(f"Layers: {CONFIG['layer_start']}-{CONFIG['layer_end']} ({len(CONFIG['layers_to_capture'])} layers)")
print(f"Probes per layer: 45")
print(f"Total probes: {len(CONFIG['layers_to_capture']) * 45}")
print(f"\n🚀 Parallel Training Settings:")
print(f"  Workers: {CONFIG['num_workers']} (train {CONFIG['num_workers']} probes simultaneously)")
print(f"  Batch size: {CONFIG['batch_size']}")
print(f"  Pin to GPU: {CONFIG['pin_activations_to_gpu']}")
print(f"\n⏰ This will take ~2-3 hours (8x faster than sequential!)")
print("💡 Each layer's probes are trained in parallel, then saved.\n")

overall_start = time.time()
layer_results = []

for layer_idx in CONFIG['layers_to_capture']:
    layer_start = time.time()

    print(f"\n{'='*70}")
    print(f"Training Layer {layer_idx} ({CONFIG['layers_to_capture'].index(layer_idx) + 1}/{len(CONFIG['layers_to_capture'])})")
    print(f"🚀 Using {CONFIG['num_workers']} parallel workers")
    print(f"{'='*70}")

    # Build command
    activation_file = f"data/activations/layer_{layer_idx}_activations.h5"
    output_dir = f"data/probes_binary/layer_{layer_idx}"

    # Check if activations exist
    if not os.path.exists(activation_file):
        print(f"⚠️  Activation file not found: {activation_file}")
        print(f"   Skipping layer {layer_idx}")
        continue

    # Use parallel training script
    cmd = [
        'python', 'src/probes/train_binary_probes_parallel.py',
        '--activations', activation_file,
        '--output-dir', output_dir,
        '--model-type', CONFIG['probe_type'],
        '--batch-size', str(CONFIG['batch_size']),
        '--epochs', str(CONFIG['epochs']),
        '--lr', str(CONFIG['learning_rate']),
        '--weight-decay', str(CONFIG['weight_decay']),
        '--early-stopping-patience', str(CONFIG['early_stopping_patience']),
        '--device', CONFIG['device'],
        '--num-workers', str(CONFIG['num_workers'])
    ]

    # Add scheduler flag
    if not CONFIG.get('use_scheduler', True):
        cmd.append('--no-scheduler')

    # Add GPU pinning flag
    if CONFIG['pin_activations_to_gpu']:
        cmd.append('--pin-activations-to-gpu')
    else:
        cmd.append('--no-pin-activations')

    # Run parallel training
    !{' '.join(cmd)}

    layer_elapsed = time.time() - layer_start

    # Load metrics for this layer
    metrics_file = f"{output_dir}/aggregate_metrics.json"
    if os.path.exists(metrics_file):
        with open(metrics_file, 'r') as f:
            metrics = json.load(f)

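        # Worker count reported by the training script; this is an upper bound
        # on the speedup, not a measured value
        speedup = metrics.get('num_workers', 1)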
        layer_results.append({
            'layer': layer_idx,
            'avg_auc': metrics['average_auc_roc'],
            'avg_f1': metrics['average_f1'],
            'avg_accuracy': metrics['average_accuracy'],
            'time_minutes': layer_elapsed / 60,
            'speedup': speedup
        })

        print(f"\n✅ Layer {layer_idx} complete in {layer_elapsed/60:.1f} minutes (🚀 {speedup}x speedup!)")
        print(f"   Avg AUC: {metrics['average_auc_roc']:.4f}, Avg F1: {metrics['average_f1']:.4f}")

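    # Backup to Google Drive after each layer (create the target folder first:
    # only 'activations' and 'probes' were created when Drive was mounted)
    os.makedirs(f"{drive_output_dir}/probes_binary", exist_ok=True)
    !cp -r {output_dir} {drive_output_dir}/probes_binary/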

overall_elapsed = time.time() - overall_start
print(f"\n{'='*70}")
print(f"✅ ALL LAYERS COMPLETE!")
print(f"{'='*70}")
print(f"Total time: {overall_elapsed/3600:.2f} hours ({overall_elapsed/60:.1f} minutes)")
print(f"Trained {len(layer_results) * 45} probes across {len(layer_results)} layers")
print(f"🚀 Average speedup: {CONFIG['num_workers']}x faster than sequential!")
print(f"\nOutputs backed up to Google Drive: {drive_output_dir}/probes_binary/")

# Save layer summary
summary = {
    'total_layers': len(layer_results),
    'total_probes': len(layer_results) * 45,
    'total_time_hours': overall_elapsed / 3600,
    'parallel_training': True,
    'num_workers': CONFIG['num_workers'],
    'layer_results': layer_results,
    'config': {
        'batch_size': CONFIG['batch_size'],
        'epochs': CONFIG['epochs'],
        'learning_rate': CONFIG['learning_rate'],
        'weight_decay': CONFIG['weight_decay'],
        'early_stopping_patience': CONFIG['early_stopping_patience'],
        'use_scheduler': CONFIG.get('use_scheduler', True),
        'num_workers': CONFIG['num_workers']
    }
}

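# The folder may not exist yet if every layer was skipped
os.makedirs('data/probes_binary', exist_ok=True)
with open('data/probes_binary/training_summary.json', 'w') as f:
    json.dump(summary, f, indent=2)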

print(f"\nSummary saved to: data/probes_binary/training_summary.json")


STEP 2: 🚀 PARALLEL TRAINING OF BINARY PROBES
Layers: 1-20 (20 layers)
Probes per layer: 45
Total probes: 900

🚀 Parallel Training Settings:
  Workers: 45 (train 45 probes simultaneously)
  Batch size: 32
  Pin to GPU: True

⏰ This will take ~2-3 hours (8x faster than sequential!)
💡 Each layer's probes are trained in parallel, then saved.


Training Layer 1 (1/20)
🚀 Using 45 parallel workers
⚠️  Activation file not found: /content/drive/MyDrive/brije_outputs/activations/layer_1_activations.h5
   Skipping layer 1

Training Layer 2 (2/20)
🚀 Using 45 parallel workers
⚠️  Activation file not found: /content/drive/MyDrive/brije_outputs/activations/layer_2_activations.h5
   Skipping layer 2

Training Layer 3 (3/20)
🚀 Using 45 parallel workers
⚠️  Activation file not found: /content/drive/MyDrive/brije_outputs/activations/layer_3_activations.h5
   Skipping layer 3

Training Layer 4 (4/20)
🚀 Using 45 parallel workers
⚠️  Activation file not found: /content/drive/MyDrive/brije_outputs/activatio

NameError: name 'drive_output_dir' is not defined

In [None]:
import numpy as np

if layer_results:
    print("="*70)
    print("PARALLEL TRAINING PERFORMANCE SUMMARY")
    print("="*70)

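    # Time savings (assumes ideal linear scaling with the worker count, so the
    # sequential estimate below is an upper bound rather than a measurement)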
    sequential_time = overall_elapsed * CONFIG['num_workers']
    time_saved = sequential_time - overall_elapsed

    print(f"\n⏱️  Time Performance:")
    print(f"  Parallel time: {overall_elapsed/3600:.2f} hours")
    print(f"  Sequential estimate: {sequential_time/3600:.2f} hours")
    print(f"  Time saved: {time_saved/3600:.2f} hours! 🎉")
    print(f"  Speedup: {CONFIG['num_workers']}x")

    # Accuracy metrics
    avg_auc = np.mean([m['avg_auc'] for m in layer_results])
    best_layer = max(layer_results, key=lambda x: x['avg_auc'])

    print(f"\n📊 Accuracy Metrics:")
    print(f"  Average AUC: {avg_auc:.4f}")
    print(f"  Best layer: {best_layer['layer']} (AUC: {best_layer['avg_auc']:.4f})")
    print(f"  Total probes trained: {len(layer_results) * 45}")
    print("="*70)

## 8️⃣ View Training Results - Performance Across All Layers

Compare binary probe performance across different layers
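
Later cells reference an `all_layer_metrics` list that is not defined in the cells shown here. A plausible way to build it is sketched below, assuming each layer directory contains the `aggregate_metrics.json` file with the `average_auc_roc`, `average_f1`, and `average_accuracy` keys read by the training loop above. The helper name `load_layer_metrics` is hypothetical, and the demo numbers are made up.

```python
import json
import os
import tempfile

def load_layer_metrics(probes_dir, layers):
    """Collect per-layer average metrics; layers without a metrics file are skipped."""
    results = []
    for layer in layers:
        path = os.path.join(probes_dir, f"layer_{layer}", "aggregate_metrics.json")
        if not os.path.exists(path):
            continue
        with open(path) as f:
            m = json.load(f)
        results.append({
            "layer": layer,
            "avg_auc": m["average_auc_roc"],
            "avg_f1": m["average_f1"],
            "avg_accuracy": m["average_accuracy"],
        })
    return results

# Demo with made-up numbers in a temporary directory
with tempfile.TemporaryDirectory() as tmp:
    for layer, auc in [(4, 0.88), (5, 0.91)]:
        os.makedirs(os.path.join(tmp, f"layer_{layer}"))
        with open(os.path.join(tmp, f"layer_{layer}", "aggregate_metrics.json"), "w") as f:
            json.dump({"average_auc_roc": auc, "average_f1": 0.7, "average_accuracy": 0.95}, f)
    demo_metrics = load_layer_metrics(tmp, [4, 5, 6])

best = max(demo_metrics, key=lambda m: m["avg_auc"])
print(best["layer"])
```

In the notebook this could be called as `all_layer_metrics = load_layer_metrics('data/probes_binary', CONFIG['layers_to_capture'])`.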

## 9️⃣ Step 3: Test Multi-Probe Inference

Run all 45 binary probes from the **best performing layer** to detect cognitive actions.
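
Conceptually, multi-probe inference scores one activation vector with every binary probe and keeps the strongest detections. The sketch below assumes linear probes (a weight vector plus bias, followed by a sigmoid) and mirrors the `--top-k` and `--threshold` options passed to `multi_probe_inference.py` later in the notebook. The probe weights and the function name `run_probes` are made up, and the script's actual internals may differ.

```python
import math

def run_probes(activation, probes, top_k=5, threshold=0.1):
    """Score one activation with each linear probe; return the top_k (action, prob) pairs above threshold."""
    scored = []
    for action, (weights, bias) in probes.items():
        logit = sum(w * a for w, a in zip(weights, activation)) + bias
        prob = 1.0 / (1.0 + math.exp(-logit))
        if prob >= threshold:
            scored.append((action, prob))
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:top_k]

# Hypothetical 3-dimensional activation and probes
activation = [1.0, -0.5, 2.0]
probes = {
    "analyzing": ([1.0, 0.0, 1.0], 0.0),    # logit = 3.0
    "recalling": ([-2.0, 0.0, -1.0], 0.0),  # logit = -4.0, falls below the threshold
    "evaluating": ([0.0, 1.0, 0.5], 0.0),   # logit = 0.5
}
detections = run_probes(activation, probes, top_k=2)
print([name for name, _ in detections])
```

Because each probe is independent, several actions can be reported for the same text, unlike a single softmax classifier that has to pick one.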

## 8️⃣.5️⃣ Per-Action Layer Analysis - Find Best Layer for Each Cognitive Action

Analyze which layer performs best for EACH of the 45 cognitive actions

In [None]:
import json
import numpy as np
import pandas as pd
from collections import defaultdict

print("="*70)
print("PER-ACTION LAYER ANALYSIS")
print("="*70)
print("\nAnalyzing which layer is best for each cognitive action...\n")

# Collect per-action metrics across all layers
action_layer_performance = defaultdict(dict)  # {action_name: {layer: auc}}

for layer_idx in CONFIG['layers_to_capture']:
    metrics_file = f'data/probes_binary/layer_{layer_idx}/aggregate_metrics.json'
    if os.path.exists(metrics_file):
        with open(metrics_file, 'r') as f:
            metrics = json.load(f)

        # Extract per-action metrics
        for action_metrics in metrics['per_action_metrics']:
            action_name = action_metrics['action']
            auc = action_metrics['auc_roc']
            f1 = action_metrics['f1']

            action_layer_performance[action_name][layer_idx] = {
                'auc': auc,
                'f1': f1
            }

# Find best layer for each action
action_best_layers = []
for action_name, layer_perfs in sorted(action_layer_performance.items()):
    # Find layer with highest AUC
    best_layer = max(layer_perfs.items(), key=lambda x: x[1]['auc'])
    layer_idx, perf = best_layer

    # Get performance range
    auc_scores = [p['auc'] for p in layer_perfs.values()]
    auc_range = max(auc_scores) - min(auc_scores)

    action_best_layers.append({
        'action': action_name,
        'best_layer': layer_idx,
        'best_auc': perf['auc'],
        'best_f1': perf['f1'],
        'auc_range': auc_range,
        'worst_auc': min(auc_scores),
        'layer_sensitivity': auc_range  # How much performance varies by layer
    })

# Convert to DataFrame for easier analysis
df = pd.DataFrame(action_best_layers)

print(f"✅ Analyzed {len(action_best_layers)} cognitive actions across {len(CONFIG['layers_to_capture'])} layers\n")

# Show distribution of best layers
print("="*70)
print("BEST LAYER DISTRIBUTION")
print("="*70)
layer_counts = df['best_layer'].value_counts().sort_index()
print("\nHow many actions are best detected at each layer:\n")
for layer, count in layer_counts.items():
    bar = "█" * count
    print(f"  Layer {layer:2d}: {count:2d} actions {bar}")

# Most common best layers (sort by count; layer_counts itself is ordered by layer index)
top_layers = layer_counts.sort_values(ascending=False).head(5)
print(f"\n🏆 Most effective layers:")
for layer, count in top_layers.items():
    pct = (count / len(action_best_layers)) * 100
    print(f"   Layer {layer}: {count} actions ({pct:.1f}%)")

# Show actions grouped by best layer
print("\n" + "="*70)
print("ACTIONS GROUPED BY BEST LAYER")
print("="*70)
for layer in sorted(df['best_layer'].unique())[:10]:  # Show first 10 layers
    actions = df[df['best_layer'] == layer]['action'].tolist()
    if actions:
        print(f"\nLayer {layer} ({len(actions)} actions):")
        for action in actions[:5]:  # Show first 5 actions per layer
            auc = df[df['action'] == action]['best_auc'].values[0]
            print(f"  • {action:30s} (AUC: {auc:.4f})")
        if len(actions) > 5:
            print(f"  ... and {len(actions) - 5} more")

# Show most layer-sensitive actions (vary a lot by layer)
print("\n" + "="*70)
print("LAYER-SENSITIVE ACTIONS")
print("="*70)
print("\nActions where layer choice matters most:\n")
df_sorted = df.sort_values('layer_sensitivity', ascending=False)
print(f"{'Action':<30} {'Best Layer':<12} {'Best AUC':<10} {'AUC Range':<10}")
print("-" * 70)
for _, row in df_sorted.head(10).iterrows():
    print(f"{row['action']:<30} Layer {row['best_layer']:<6} {row['best_auc']:.4f}     {row['auc_range']:.4f}")

print("\n→ Large AUC range = layer choice is critical for this action")

# Show layer-robust actions (work well across all layers)
print("\n" + "="*70)
print("LAYER-ROBUST ACTIONS")
print("="*70)
print("\nActions that work well across all layers:\n")
print(f"{'Action':<30} {'Best Layer':<12} {'Best AUC':<10} {'AUC Range':<10}")
print("-" * 70)
for _, row in df_sorted.tail(10).iterrows():
    print(f"{row['action']:<30} Layer {row['best_layer']:<6} {row['best_auc']:.4f}     {row['auc_range']:.4f}")

print("\n→ Small AUC range = works well regardless of layer")

# Save detailed results
results = {
    'summary': {
        'total_actions': len(action_best_layers),
        'total_layers_tested': len(CONFIG['layers_to_capture']),
        'most_common_best_layer': int(layer_counts.idxmax()),
        'layer_distribution': {int(k): int(v) for k, v in layer_counts.items()}
    },
    'per_action_best_layers': action_best_layers
}

with open('data/probes_binary/per_action_layer_analysis.json', 'w') as f:
    json.dump(results, f, indent=2)

print(f"\n✅ Detailed results saved to: data/probes_binary/per_action_layer_analysis.json")

## 8️⃣.6️⃣ Visualize Per-Action Layer Performance

Heatmap showing which layers are best for each cognitive action

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Create heatmap data: actions × layers
actions = sorted(action_layer_performance.keys())
layers = sorted(CONFIG['layers_to_capture'])

# Build matrix
heatmap_data = []
for action in actions:
    row = []
    for layer in layers:
        if layer in action_layer_performance[action]:
            row.append(action_layer_performance[action][layer]['auc'])
        else:
            row.append(np.nan)
    heatmap_data.append(row)

heatmap_data = np.array(heatmap_data)

# Create figure with multiple subplots
fig, axes = plt.subplots(2, 2, figsize=(20, 16))

# Plot 1: Full heatmap (all actions, all layers)
ax1 = axes[0, 0]
im1 = ax1.imshow(heatmap_data, aspect='auto', cmap='RdYlGn', vmin=0.5, vmax=1.0)
ax1.set_xticks(range(0, len(layers), 2))
ax1.set_xticklabels([layers[i] for i in range(0, len(layers), 2)], fontsize=8)
ax1.set_yticks(range(len(actions)))
ax1.set_yticklabels(actions, fontsize=6)
ax1.set_xlabel('Layer', fontsize=10)
ax1.set_ylabel('Cognitive Action', fontsize=10)
ax1.set_title('Per-Action AUC-ROC Across All Layers', fontsize=12, fontweight='bold')
plt.colorbar(im1, ax=ax1, label='AUC-ROC')

# Mark best layer for each action with a star
for i, action in enumerate(actions):
    best_layer_idx = df[df['action'] == action]['best_layer'].values[0]
    best_layer_pos = layers.index(best_layer_idx)
    ax1.plot(best_layer_pos, i, 'w*', markersize=4)

# Plot 2: Best layer distribution
ax2 = axes[0, 1]
layer_counts = df['best_layer'].value_counts().sort_index()
ax2.bar(layer_counts.index, layer_counts.values, color='steelblue', alpha=0.7)
ax2.set_xlabel('Layer', fontsize=10)
ax2.set_ylabel('Number of Actions', fontsize=10)
ax2.set_title('Distribution of Best Layers\n(How many actions peak at each layer)', fontsize=12, fontweight='bold')
ax2.grid(True, alpha=0.3, axis='y')

# Add value labels on bars
for layer, count in layer_counts.items():
    ax2.text(layer, count + 0.2, str(count), ha='center', fontsize=8)

# Plot 3: Layer sensitivity (how much does layer matter?)
ax3 = axes[1, 0]
df_sorted_sens = df.sort_values('layer_sensitivity', ascending=False)
y_pos = np.arange(len(df_sorted_sens.head(20)))
ax3.barh(y_pos, df_sorted_sens.head(20)['layer_sensitivity'], color='coral', alpha=0.7)
ax3.set_yticks(y_pos)
ax3.set_yticklabels(df_sorted_sens.head(20)['action'], fontsize=8)
ax3.set_xlabel('AUC Range (Best - Worst)', fontsize=10)
ax3.set_title('Top 20 Layer-Sensitive Actions\n(Layer choice matters most)', fontsize=12, fontweight='bold')
ax3.grid(True, alpha=0.3, axis='x')
ax3.invert_yaxis()

# Plot 4: Average performance by layer group
ax4 = axes[1, 1]
n = len(layers)
early_layers = layers[:n//3]
middle_layers = layers[n//3:2*n//3]
late_layers = layers[2*n//3:]

layer_groups = [early_layers, middle_layers, late_layers]
# Label each group with its actual layer range instead of hard-coded bounds
groups = [f"{name}\n({grp[0]}-{grp[-1]})" if grp else name
          for name, grp in zip(['Early', 'Middle', 'Late'], layer_groups)]

# Per-layer averages keyed by layer for easy lookup
# (all_layer_metrics must already exist: a list of dicts with 'layer' and 'avg_auc')
all_layer_metrics_dict = {m['layer']: m for m in all_layer_metrics}

# Calculate average AUC for each group
group_aucs = []
for layer_group in layer_groups:
    aucs = [all_layer_metrics_dict[layer]['avg_auc'] for layer in layer_group if layer in all_layer_metrics_dict]
    group_aucs.append(np.mean(aucs) if aucs else 0)

bars = ax4.bar(groups, group_aucs, color=['lightblue', 'lightgreen', 'lightsalmon'], alpha=0.7, edgecolor='black')
ax4.set_ylabel('Average AUC-ROC', fontsize=10)
ax4.set_title('Performance by Layer Group\n(Early vs Middle vs Late)', fontsize=12, fontweight='bold')
ax4.set_ylim([min(group_aucs) - 0.02, max(group_aucs) + 0.02])
ax4.grid(True, alpha=0.3, axis='y')

# Add value labels on bars
for bar, auc in zip(bars, group_aucs):
    height = bar.get_height()
    ax4.text(bar.get_x() + bar.get_width()/2., height + 0.005,
             f'{auc:.4f}', ha='center', va='bottom', fontsize=10, fontweight='bold')

plt.tight_layout()
plt.savefig('data/probes_binary/per_action_layer_analysis.png', dpi=150, bbox_inches='tight')
plt.show()

print("✅ Per-action layer analysis saved to: data/probes_binary/per_action_layer_analysis.png")

# Additional insight: Show examples of actions best at different layer stages
print("\n" + "="*70)
print("LAYER STAGE EXAMPLES")
print("="*70)

early_actions = df[df['best_layer'].isin(early_layers)].nlargest(5, 'best_auc')
middle_actions = df[df['best_layer'].isin(middle_layers)].nlargest(5, 'best_auc')
late_actions = df[df['best_layer'].isin(late_layers)].nlargest(5, 'best_auc')

print(f"\n📘 Best actions detected in EARLY layers ({min(early_layers)}-{max(early_layers)}):")
for _, row in early_actions.iterrows():
    print(f"   • {row['action']:30s} Layer {row['best_layer']}, AUC: {row['best_auc']:.4f}")

print(f"\n📗 Best actions detected in MIDDLE layers ({min(middle_layers)}-{max(middle_layers)}):")
for _, row in middle_actions.iterrows():
    print(f"   • {row['action']:30s} Layer {row['best_layer']}, AUC: {row['best_auc']:.4f}")

print(f"\n📕 Best actions detected in LATE layers ({min(late_layers)}-{max(late_layers)}):")
for _, row in late_actions.iterrows():
    print(f"   • {row['action']:30s} Layer {row['best_layer']}, AUC: {row['best_auc']:.4f}")

print("\n→ This reveals which cognitive processes are captured at different network depths!")

In [None]:
# Determine best layer for inference
if all_layer_metrics:
    best_layer_idx = max(all_layer_metrics, key=lambda x: x['avg_auc'])['layer']
else:
    best_layer_idx = CONFIG['layer_end']  # Default to last layer

print(f"Using probes from Layer {best_layer_idx} (best performing layer)\n")

# Test on sample texts
test_texts = [
    "After receiving feedback, she began reconsidering her initial approach to the problem.",
    "He was analyzing the data to find patterns and correlations between variables.",
    "They started generating creative ideas for solving the design challenge.",
    "She was evaluating different strategies to determine the most effective one.",
    "He tried to recall the specific details from the previous meeting."
]

print("="*60)
print("MULTI-PROBE INFERENCE EXAMPLES")
print("="*60)
print(f"\nRunning all 45 binary probes from Layer {best_layer_idx}...\n")

import shlex

for i, text in enumerate(test_texts, 1):
    print(f"\n{'='*60}")
    print(f"Example {i}")
    print(f"{'='*60}")
    print(f"📝 Text: {text}\n")

    cmd = [
        'python', 'src/probes/multi_probe_inference.py',
        '--probes-dir', f'data/probes_binary/layer_{best_layer_idx}',
        '--model', CONFIG['model'],
        '--layer', str(best_layer_idx),
        '--text', text,
        '--top-k', '5',
        '--threshold', '0.1'
    ]

    # shlex.quote keeps each argument intact, even if the text contains quotes or shell metacharacters
    cmd_str = ' '.join(shlex.quote(arg) for arg in cmd)
    !{cmd_str}
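
The `--top-k 5` and `--threshold 0.1` flags suggest that the inference script filters probe outputs by a minimum confidence and then reports the highest-scoring actions. The sketch below shows one way that selection could work. It is an assumption about the behavior, not the actual code of `multi_probe_inference.py`; the `select_top_k` helper and the example scores are hypothetical.

```python
def select_top_k(scores, top_k=5, threshold=0.1):
    """Keep actions scoring at or above threshold, then return the top_k highest."""
    kept = [(action, p) for action, p in scores.items() if p >= threshold]
    # Sort by probability, highest first
    kept.sort(key=lambda item: item[1], reverse=True)
    return kept[:top_k]

# Hypothetical probe probabilities for a single sentence (illustrative only)
example_scores = {
    "reconsidering": 0.92,
    "evaluating": 0.41,
    "recalling": 0.07,
    "analyzing": 0.15,
}

for action, p in select_top_k(example_scores, top_k=2, threshold=0.1):
    print(f"{action:20s} {p:.2f}")
```

Because each probe is an independent binary classifier, the scores need not sum to 1, so several actions can pass the threshold for the same text.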

## 🔟 Visualize Performance Across Layers

In [None]:
import matplotlib.pyplot as plt
import json
import numpy as np

# Plot performance across layers
if all_layer_metrics:
    fig, axes = plt.subplots(1, 2, figsize=(16, 5))

    layers = [m['layer'] for m in all_layer_metrics]
    auc_scores = [m['avg_auc'] for m in all_layer_metrics]
    f1_scores = [m['avg_f1'] for m in all_layer_metrics]
    acc_scores = [m['avg_accuracy'] for m in all_layer_metrics]

    # Plot 1: AUC-ROC across layers
    axes[0].plot(layers, auc_scores, 'b-o', linewidth=2, markersize=6, label='AUC-ROC')
    axes[0].axhline(y=np.mean(auc_scores), color='r', linestyle='--', alpha=0.5, label='Mean')
    axes[0].set_xlabel('Layer', fontsize=12)
    axes[0].set_ylabel('Average AUC-ROC', fontsize=12)
    axes[0].set_title('Binary Probe Performance Across Layers (AUC-ROC)', fontsize=14, fontweight='bold')
    axes[0].grid(True, alpha=0.3)
    axes[0].legend()
    axes[0].set_ylim([min(auc_scores) - 0.02, max(auc_scores) + 0.02])

    # Mark best layer
    best_idx = np.argmax(auc_scores)
    axes[0].annotate(f'Best: {layers[best_idx]}\n{auc_scores[best_idx]:.4f}',
                     xy=(layers[best_idx], auc_scores[best_idx]),
                     xytext=(10, 10), textcoords='offset points',
                     bbox=dict(boxstyle='round,pad=0.5', facecolor='yellow', alpha=0.7),
                     arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'))

    # Plot 2: All metrics comparison
    axes[1].plot(layers, auc_scores, 'b-o', label='AUC-ROC', linewidth=2, markersize=5)
    axes[1].plot(layers, f1_scores, 'g-s', label='F1 Score', linewidth=2, markersize=5)
    axes[1].plot(layers, acc_scores, 'r-^', label='Accuracy', linewidth=2, markersize=5)
    axes[1].set_xlabel('Layer', fontsize=12)
    axes[1].set_ylabel('Score', fontsize=12)
    axes[1].set_title('All Metrics Across Layers', fontsize=14, fontweight='bold')
    axes[1].grid(True, alpha=0.3)
    axes[1].legend()
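    # Extend the lower bound if any metric falls below 0.5 so that no curve is clipped
    axes[1].set_ylim([min(0.5, min(auc_scores + f1_scores + acc_scores) - 0.02), 1.0])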

    plt.tight_layout()
    plt.savefig('data/probes_binary/layer_comparison.png', dpi=150, bbox_inches='tight')
    plt.show()

    print("✅ Layer comparison plot saved to: data/probes_binary/layer_comparison.png")

    # Additional analysis: Early vs Middle vs Late layers (needs at least 3 layers)
    n = len(layers)
    if n >= 3:
        early = all_layer_metrics[:n//3]
        middle = all_layer_metrics[n//3:2*n//3]
        late = all_layer_metrics[2*n//3:]

        print("\n📊 Layer Group Analysis:")
        print(f"   Early layers ({layers[0]}-{layers[n//3-1]}):  AUC = {np.mean([m['avg_auc'] for m in early]):.4f}")
        print(f"   Middle layers ({layers[n//3]}-{layers[2*n//3-1]}): AUC = {np.mean([m['avg_auc'] for m in middle]):.4f}")
        print(f"   Late layers ({layers[2*n//3]}-{layers[-1]}):   AUC = {np.mean([m['avg_auc'] for m in late]):.4f}")
    else:
        print("\n⚠️  Fewer than 3 layers available; skipping layer group analysis")

else:
    print("⚠️  No metrics available for plotting")
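
The early/middle/late split above divides the ordered layer list into thirds using integer division. The sketch below wraps the same slicing in a helper to make the group boundaries explicit. The `group_means` function and the `demo_metrics` values are hypothetical and only illustrate the arithmetic; they are not results from this notebook.

```python
from statistics import mean

def group_means(metrics):
    """Split metrics (ordered by layer) into thirds and average avg_auc per group."""
    n = len(metrics)
    if n < 3:
        raise ValueError("need at least 3 layers to form three groups")
    bounds = {
        "early": (0, n // 3),
        "middle": (n // 3, 2 * n // 3),
        "late": (2 * n // 3, n),
    }
    summary = {}
    for name, (start, stop) in bounds.items():
        chunk = metrics[start:stop]
        # Store (first layer, last layer, mean AUC) for the group
        summary[name] = (chunk[0]["layer"], chunk[-1]["layer"],
                         mean(m["avg_auc"] for m in chunk))
    return summary

# Hypothetical metrics for six layers (illustrative only)
demo_metrics = [
    {"layer": 4, "avg_auc": 0.80}, {"layer": 5, "avg_auc": 0.82},
    {"layer": 6, "avg_auc": 0.90}, {"layer": 7, "avg_auc": 0.92},
    {"layer": 8, "avg_auc": 0.88}, {"layer": 9, "avg_auc": 0.86},
]

demo_summary = group_means(demo_metrics)
for name, (first, last, auc) in demo_summary.items():
    print(f"{name:6s} layers {first}-{last}: AUC = {auc:.4f}")
```

With the 25 layers used in this notebook (4-28), this slicing yields groups of 8, 8, and 9 layers, so the late group is slightly larger than the other two.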

## 1️⃣1️⃣ Download Trained Binary Probes

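Download the trained binary probes and metrics for local use. By default, only the best layer's 45 probes are downloaded; a full package with all 1,125 probes is also created.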

In [None]:
from google.colab import files
import os  # needed for os.path.getsize below
import shutil

# Create a zip file with all outputs
output_zip = 'brije_all_layers_binary_probes.zip'

print("📦 Creating download package...")
print("   This may take a few minutes for 1,125 probe files...")
!cd data && zip -r ../{output_zip} probes_binary/ -q

print(f"\n✅ Package created: {output_zip}")
print(f"Size: {os.path.getsize(output_zip) / 1e6:.2f} MB")

# Option to download best layer only
print("\n💡 TIP: Download options:")
print("   1. Full package (all layers) - see below")
print("   2. Best layer only - smaller download")

# Create best layer only zip
best_layer_zip = f'brije_layer_{best_layer_idx}_probes.zip'
!cd data/probes_binary && zip -r ../../{best_layer_zip} layer_{best_layer_idx}/ -q

print(f"\nBest layer package: {best_layer_zip} ({os.path.getsize(best_layer_zip) / 1e6:.2f} MB)")

# Download best layer by default (faster)
print("\n📥 Downloading best layer probes...")
files.download(best_layer_zip)

print("\n✅ Download complete!")
print("\nPackage contains:")
print(f"  • Layer {best_layer_idx} probes (45 binary probes)")
print("  • Per-action metrics (metrics_*.json)")
print("  • Aggregate performance summary")
print("\nTo download ALL layers (all 1,125 probes), uncomment the line below and re-run this cell:")
# files.download(output_zip)
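
The cell above shells out to `zip` to build the packages. As an alternative sketch that stays in Python, `shutil.make_archive` from the standard library can produce the same kind of archive. The `zip_directory` helper and the throwaway demo directory are hypothetical and exist only so the example runs anywhere; in this notebook the source directory would be something like `data/probes_binary/layer_{best_layer_idx}`.

```python
import os
import shutil
import tempfile
import zipfile

def zip_directory(src_dir, zip_stem):
    """Zip src_dir (keeping its top-level folder name) into zip_stem + '.zip'."""
    src_dir = os.path.abspath(src_dir)
    parent = os.path.dirname(src_dir)
    folder = os.path.basename(src_dir)
    # root_dir is where archiving starts; base_dir is the folder stored in the archive
    return shutil.make_archive(zip_stem, "zip", root_dir=parent, base_dir=folder)

# Demo on a temporary directory (hypothetical file names)
with tempfile.TemporaryDirectory() as tmp:
    demo_dir = os.path.join(tmp, "layer_0")
    os.makedirs(demo_dir)
    with open(os.path.join(demo_dir, "metrics_demo.json"), "w") as f:
        f.write("{}")

    archive = zip_directory(demo_dir, os.path.join(tmp, "demo_probes"))
    archive_name = os.path.basename(archive)
    archive_size = os.path.getsize(archive)
    with zipfile.ZipFile(archive) as zf:
        members = zf.namelist()

print(f"{archive_name}: {archive_size} bytes, {len(members)} entries")
```

This avoids depending on the `zip` binary being installed, at the cost of a few more lines than the shell one-liner.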

## 1️⃣2️⃣ Summary and Next Steps

In [None]:
print("="*60)
print("🎉 PIPELINE COMPLETE!")
print("="*60)
print("\n✅ What was accomplished:")
print("  1. Captured activations from Gemma 3 4B (layers 4-28)")
print("  2. Trained 1,125 binary probes (45 per layer × 25 layers)")
print("  3. Evaluated performance across all layers")
print("  4. Identified best performing layer")
print("  5. Saved all outputs to Google Drive")
print("\n📂 Outputs saved to:")
print(f"  • Local: {os.getcwd()}/data/")
print(f"  • Google Drive: {drive_output_dir}")

if all_layer_metrics:
    best = max(all_layer_metrics, key=lambda x: x['avg_auc'])
    print(f"\n🏆 Best Layer: {best['layer']} (AUC: {best['avg_auc']:.4f})")

print("\n🚀 Next Steps:")
print("  1. Download trained probes (best layer or all layers)")
print("  2. Use multi_probe_inference.py for predictions")
print("  3. Compare performance across layers")
print("  4. Experiment with different thresholds")

# best_layer_idx is set in the inference cell and already falls back to CONFIG['layer_end']
print("\n💡 Usage Example (Best Layer):")
print("  python src/probes/multi_probe_inference.py \\")
print(f"    --probes-dir data/probes_binary/layer_{best_layer_idx} \\")
print("    --model google/gemma-3-4b-it \\")
print(f"    --layer {best_layer_idx} \\")
print("    --text \"Your text here\" \\")
print("    --top-k 5")

print("\n📊 Key Findings:")
if all_layer_metrics:
    auc_scores = [m['avg_auc'] for m in all_layer_metrics]
    print(f"  • {len(all_layer_metrics)} layers trained successfully")
    print(f"  • Average AUC across all layers: {np.mean(auc_scores):.4f}")
    print(f"  • Performance range: {min(auc_scores):.4f} - {max(auc_scores):.4f}")
    print(f"  • {sum(1 for auc in auc_scores if auc > 0.90)}/{len(auc_scores)} layers achieved AUC > 0.90")

print("\n📚 Files & Documentation:")
print("  • training_summary.json - Overall training summary")
print("  • layer_comparison.png - Performance visualization")
print("  • aggregate_metrics.json (per layer) - Detailed metrics")
print("  • README.md - Full documentation")
print("="*60)
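
For the "experiment with different thresholds" step, a simple starting point is to sweep a few thresholds for one probe and watch how F1 changes. The sketch below does this in plain Python. The `f1_at_threshold` helper, the probabilities, and the labels are hypothetical and made up for illustration; in practice they would come from a probe's predictions on held-out examples.

```python
def f1_at_threshold(probs, labels, threshold):
    """Compute F1 when predicting positive for every prob >= threshold."""
    tp = sum(1 for p, y in zip(probs, labels) if p >= threshold and y == 1)
    fp = sum(1 for p, y in zip(probs, labels) if p >= threshold and y == 0)
    fn = sum(1 for p, y in zip(probs, labels) if p < threshold and y == 1)
    if tp == 0:
        return 0.0  # no true positives means precision and recall are both 0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)

# Hypothetical probe outputs and ground-truth labels (illustrative only)
demo_probs = [0.95, 0.60, 0.40, 0.20, 0.05]
demo_labels = [1, 1, 0, 1, 0]

for t in (0.1, 0.3, 0.5):
    print(f"threshold={t:.1f}  F1={f1_at_threshold(demo_probs, demo_labels, t):.4f}")
```

In a one-vs-rest setup with 45 actions, positives are rare for each probe, so the threshold that maximizes F1 may differ noticeably from 0.5 and from probe to probe.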