# MLP-Based Dynamic Alpha Training for MKA Layer Merging

This notebook demonstrates training a small **MLP** to predict layer-merging coefficients (α) dynamically from input activation statistics.

**Goal**: Test whether input-dependent dynamic merging outperforms static coefficients.

## Method
- Replace layer pairs with `MLPMergeableLayer` wrappers
- Each wrapper contains a small MLP that predicts α from activation statistics:
  - Mean activation of layer 1
  - Mean activation of layer 2
  - Std activation of layer 1
  - Std activation of layer 2
- Train MLP parameters on calibration data while keeping original layers frozen
- Alpha adapts dynamically to each input sample
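
The prediction step described above can be sketched without any framework. This is a minimal illustration, not the code in `mergeable_layer.py`: the class name `TinyAlphaMLP`, the ReLU hidden layer, the sigmoid that keeps α in (0, 1), the use of population standard deviation, and the convex combination `alpha * h1 + (1 - alpha) * h2` are all assumptions made for the example.

```python
import math
import random
import statistics


def activation_stats(h1, h2):
    """Return the four features: mean and std of each layer's activations."""
    return [
        statistics.fmean(h1),
        statistics.fmean(h2),
        statistics.pstdev(h1),  # population std; the real layer may use sample std
        statistics.pstdev(h2),
    ]


class TinyAlphaMLP:
    """Hypothetical 4 -> hidden_dim -> 1 MLP with a sigmoid output."""

    def __init__(self, hidden_dim=64, seed=0):
        rng = random.Random(seed)
        self.w1 = [[rng.uniform(-0.1, 0.1) for _ in range(4)] for _ in range(hidden_dim)]
        self.b1 = [0.0] * hidden_dim
        self.w2 = [rng.uniform(-0.1, 0.1) for _ in range(hidden_dim)]
        self.b2 = 0.0

    def __call__(self, features):
        hidden = [
            max(0.0, sum(w * x for w, x in zip(row, features)) + b)  # ReLU
            for row, b in zip(self.w1, self.b1)
        ]
        logit = sum(w * h for w, h in zip(self.w2, hidden)) + self.b2
        return 1.0 / (1.0 + math.exp(-logit))  # sigmoid keeps alpha in (0, 1)


def merge_outputs(h1, h2, alpha):
    """Convex combination: alpha weights layer 1, (1 - alpha) weights layer 2."""
    return [alpha * a + (1.0 - alpha) * b for a, b in zip(h1, h2)]


# Toy activations standing in for the outputs of the two frozen layers.
h1 = [0.2, -0.5, 1.3, 0.7]
h2 = [0.1, -0.4, 1.0, 0.9]

mlp = TinyAlphaMLP()
alpha = mlp(activation_stats(h1, h2))
merged = merge_outputs(h1, h2, alpha)
print(f"alpha = {alpha:.4f}")
print("merged =", [round(v, 4) for v in merged])
```

Because α is recomputed from the statistics of each input, two different samples generally get two different blends of the same pair of layers.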

---

## 💡 **Important Notes**
- **For actual training**: Run commands directly in terminal (see commands below)
- **This notebook**: Use for interactive analysis and visualization of results
- **Safety feature**: `EXECUTE_COMMANDS = False` by default to prevent accidental runs
- **Key difference**: Uses `--use_mlp_merge` flag for dynamic alpha prediction

## Setup and Configuration

In [None]:
import os
import sys
import subprocess

# Configuration
MODEL_PATH = "meta-llama/Meta-Llama-3-8B"
DATA_DIR = "./data"
NUM_LAYERS = 14  # Number of layer pairs to merge
OUTPUT_DIR = "./merged_weights_mlp"

# Training hyperparameters
ALPHA_TRAINING_STEPS = 500
ALPHA_LEARNING_RATE = 1e-4
CALIBRATION_BATCH_SIZE = 4
CALIBRATION_SAMPLES = 100

# MLP-specific parameters (defined in mergeable_layer.py)
# MLP architecture: 4 inputs -> hidden_dim (64) -> 1 output (alpha)

# Safety toggle - set to True to execute commands in this notebook
# RECOMMENDED: Run commands directly in terminal instead
EXECUTE_COMMANDS = False

print("=" * 60)
print("MLP ALPHA TRAINING - CONFIGURATION")
print("=" * 60)
print(f"  Model: {MODEL_PATH}")
print(f"  Data directory: {DATA_DIR}")
print(f"  Layers to merge: {NUM_LAYERS}")
print(f"  Training steps: {ALPHA_TRAINING_STEPS}")
print(f"  Learning rate: {ALPHA_LEARNING_RATE}")
print(f"  Batch size: {CALIBRATION_BATCH_SIZE}")
print(f"  Calibration samples: {CALIBRATION_SAMPLES}")
print(f"  Method: MLP-based dynamic merging")
print("=" * 60)
print(f"⚠️  EXECUTE_COMMANDS = {EXECUTE_COMMANDS}")
if not EXECUTE_COMMANDS:
    print("    (Commands will show as dry-run only)")
print("=" * 60)

## Step 1: Verify Required Files

Before running training, ensure all required files exist:
- Model checkpoint or HuggingFace model access
- MMLU data files in `data/dev/` and `data/test/`
- Similarity matrix (optional, for comparison)

In [None]:
# Check data directory
if os.path.exists(DATA_DIR):
    dev_files = os.listdir(os.path.join(DATA_DIR, "dev")) if os.path.exists(os.path.join(DATA_DIR, "dev")) else []
    test_files = os.listdir(os.path.join(DATA_DIR, "test")) if os.path.exists(os.path.join(DATA_DIR, "test")) else []
    print(f"✓ Data directory exists")
    print(f"  Dev files: {len(dev_files)}")
    print(f"  Test files: {len(test_files)}")
else:
    print(f"✗ Data directory not found: {DATA_DIR}")

# Check similarity matrix (optional)
similarity_matrix_path = "similarity_matrix.pkl"
if os.path.exists(similarity_matrix_path):
    print(f"✓ Similarity matrix found: {similarity_matrix_path}")
else:
    print(f"⚠️ Similarity matrix not found (optional): {similarity_matrix_path}")

# Check output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"✓ Output directory ready: {OUTPUT_DIR}")

## Step 2: Train MLP Alpha Predictor

This command will:
1. Load the Llama-3-8B model
2. Replace selected layer pairs with `MLPMergeableLayer` wrappers
3. Train MLP networks to predict α from activation statistics
4. Fuse layers using a fixed α = 0.5 (a plain average of the two layers) and save the merged model

Note: unlike a scalar α, the MLP's predictions are dynamic (different for each input).

**Key Difference**: The `--use_mlp_merge` flag enables MLP-based merging.
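
As a rough picture of the final fusion step, the sketch below blends two layers' parameters with a fixed α. It assumes fusion is a name-by-name weighted average of matching parameters; the function `fuse_parameters` and the parameter names are hypothetical, and `pipeline.py` may do this differently.

```python
def fuse_parameters(params_a, params_b, alpha=0.5):
    """Blend two layers' parameters: alpha * A + (1 - alpha) * B, name by name."""
    if params_a.keys() != params_b.keys():
        raise ValueError("both layers must expose the same parameter names")
    fused = {}
    for name, values_a in params_a.items():
        values_b = params_b[name]
        if len(values_a) != len(values_b):
            raise ValueError(f"shape mismatch for {name!r}")
        fused[name] = [alpha * a + (1.0 - alpha) * b for a, b in zip(values_a, values_b)]
    return fused


# Hypothetical flattened parameters for one layer pair.
layer_1 = {"proj.weight": [1.0, 2.0, 3.0], "proj.bias": [0.0, 1.0]}
layer_2 = {"proj.weight": [3.0, 2.0, 1.0], "proj.bias": [2.0, 1.0]}

fused = fuse_parameters(layer_1, layer_2, alpha=0.5)
print(fused)
```

With α = 0.5 each fused value is the midpoint of the two originals; a value such as 0.7 would weight the first layer more heavily.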

### 🚀 **Recommended: Run in Terminal**
Copy and run this command directly in PowerShell:
```powershell
python pipeline.py --model_path "meta-llama/Meta-Llama-3-8B" --num_layer 14 --data_dir "./data" --use_learnable_alpha --use_mlp_merge --alpha_training_steps 500 --alpha_learning_rate 1e-4
```

In [None]:
# Uncomment and run this cell to execute MLP training directly
# !python pipeline.py --model_path "meta-llama/Meta-Llama-3-8B" --num_layer 14 --data_dir "./data" --use_learnable_alpha --use_mlp_merge --alpha_training_steps 500 --alpha_learning_rate 1e-4

print("👆 Uncomment the line above and run this cell to train MLP alpha")
print("   OR use the subprocess approach in the next cell")

## Alternative: Run via `subprocess`

The cell below builds the same command from the configuration variables and runs it with `subprocess.run`:
- It only executes when `EXECUTE_COMMANDS` is `True`; otherwise it prints the command as a dry run
- Unlike the `!` cell above, it also passes `--calibration_batch_size` and `--calibration_samples`

In [None]:
# Build the command
cmd = [
    "python", "pipeline.py",
    "--model_path", MODEL_PATH,
    "--num_layer", str(NUM_LAYERS),
    "--data_dir", DATA_DIR,
    "--use_learnable_alpha",
    "--use_mlp_merge",  # This flag switches to MLP mode
    "--alpha_training_steps", str(ALPHA_TRAINING_STEPS),
    "--alpha_learning_rate", str(ALPHA_LEARNING_RATE),
    "--calibration_batch_size", str(CALIBRATION_BATCH_SIZE),
    "--calibration_samples", str(CALIBRATION_SAMPLES),
]

print("=" * 60)
print("MLP TRAINING COMMAND")
print("=" * 60)
print(" ".join(cmd))
print("=" * 60)

if EXECUTE_COMMANDS:
    print("\n🚀 Executing MLP training...\n")
    result = subprocess.run(cmd, capture_output=False, text=True)
    if result.returncode == 0:
        print("\n✅ MLP training completed successfully!")
    else:
        print(f"\n❌ Training failed with exit code {result.returncode}")
else:
    print("\n⚠️  EXECUTE_COMMANDS is False")
    print("   To run: Set EXECUTE_COMMANDS = True in configuration cell")
    print("   OR better: Copy the command above and run in terminal")

## Step 3: Understanding MLP-Based Merging

Unlike scalar α, which is fixed per layer pair, MLP-based α is:
- **Dynamic**: Varies for each input sample
- **Context-aware**: Depends on activation statistics
- **Learned**: MLP weights are optimized via gradient descent

The learned alphas file (`learned_alphas.json` in the output directory) stores -1.0 as a sentinel for each MLP layer, since α varies per input and there is no single value to record.

In [None]:
import json

learned_alphas_path = os.path.join(OUTPUT_DIR, "learned_alphas.json")

if os.path.exists(learned_alphas_path):
    with open(learned_alphas_path, 'r') as f:
        data = json.load(f)
    
    learned_alphas = data.get('learned_alphas', [])
    similarity_scores = data.get('similarity_scores', [])
    
    print(f"Number of layer pairs: {len(learned_alphas)}")
    print(f"\nAlpha values (should be -1.0 for MLP layers):")
    print(learned_alphas[:10])  # Show first 10
    
    mlp_count = sum(1 for a in learned_alphas if a == -1.0)
    print(f"\nMLP layers: {mlp_count} / {len(learned_alphas)}")
    
    if mlp_count == len(learned_alphas):
        print("✓ All layers are MLP-based (as expected)")
    else:
        print("⚠️ Some layers are not MLP-based (unexpected)")
    
    print("\nNote: -1.0 indicates MLP-based layer where α is predicted dynamically.")
    print("To analyze actual α distributions, you need to:")
    print("1. Run inference on test data")
    print("2. Collect α predictions for each sample")
    print("3. Analyze the distribution of predictions")
else:
    print(f"✗ Learned alphas file not found: {learned_alphas_path}")
    print("Run the training step first.")

## Step 4: Analyze MLP Predictions (Advanced)

To understand how the MLP behaves, we can load the model and examine α predictions on sample inputs.

**Note**: This requires loading the trained model, which is more advanced.
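
Once α predictions have been collected, summarizing them is straightforward. The sketch below assumes the predictions are already gathered into a dictionary mapping a layer-pair index to a list of α values; how they are recorded from `MLPMergeableLayer` is not shown here, and the function name `summarize_alphas` and the sample numbers are made up for illustration.

```python
import statistics


def summarize_alphas(alphas_by_layer):
    """Return per-layer mean, std, min, and max of the collected alpha predictions."""
    summary = {}
    for layer_idx, alphas in sorted(alphas_by_layer.items()):
        if not alphas:
            continue  # nothing recorded for this layer pair
        summary[layer_idx] = {
            "mean": statistics.fmean(alphas),
            "std": statistics.pstdev(alphas),
            "min": min(alphas),
            "max": max(alphas),
        }
    return summary


# Made-up values for illustration only; real ones would be recorded during inference.
collected = {
    0: [0.50, 0.50, 0.50, 0.50],
    1: [0.20, 0.40, 0.60, 0.80],
}

summary = summarize_alphas(collected)
for layer_idx, stats in summary.items():
    print(
        f"layer pair {layer_idx}: mean={stats['mean']:.3f} std={stats['std']:.3f} "
        f"range=[{stats['min']:.2f}, {stats['max']:.2f}]"
    )
```

A standard deviation near zero for a layer pair would suggest its MLP is behaving much like a static coefficient, while a wide range indicates that α really does depend on the input.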

In [None]:
print("To analyze MLP predictions in detail:")
print("1. Load the model with trained MLPMergeableLayer wrappers")
print("2. Run inference on a batch of inputs")
print("3. Extract α predictions from each layer's forward pass")
print("4. Analyze α distribution across samples and layers")
print("\nThis requires custom inference code and is beyond the scope of this notebook.")
print("\nFor comparison with other methods, use evaluate_methods.py:")
eval_cmd = f"python evaluate_methods.py --model_path {MODEL_PATH} --data_dir {DATA_DIR} --output_dir ./experiments --include_mlp"
print(eval_cmd)

## Step 5: Compare with Scalar Alpha Method

Run comprehensive evaluation to compare:
1. **MKA (similarity-based)**: Original paper's heuristic
2. **Fixed α=0.5**: Uniform merging
3. **Fixed α=0.7**: Higher weight on first layer
4. **Learned scalar α**: Trainable static coefficients
5. **Learned MLP α**: Dynamic input-dependent coefficients

### 🎯 **Comprehensive Evaluation**
Run this command in terminal:
```powershell
python evaluate_methods.py --model_path "meta-llama/Meta-Llama-3-8B" --data_dir "./data" --similarity_matrix "similarity_matrix.pkl" --output_dir "./experiments" --include_mlp
```

In [None]:
# Build comprehensive evaluation command
eval_cmd = [
    "python", "evaluate_methods.py",
    "--model_path", MODEL_PATH,
    "--data_dir", DATA_DIR,
    "--similarity_matrix", "similarity_matrix.pkl",
    "--output_dir", "./experiments",
    "--include_mlp",  # Include MLP-based method in comparison
]

print("=" * 60)
print("FULL EVALUATION COMMAND (ALL 5 METHODS)")
print("=" * 60)
print(" ".join(eval_cmd))
print("=" * 60)

if EXECUTE_COMMANDS:
    print("\n🚀 Executing comprehensive evaluation...\n")
    print("⚠️ This will take significant time (trains and evaluates 5 methods)\n")
    result = subprocess.run(eval_cmd, capture_output=False, text=True)
    if result.returncode == 0:
        print("\n✅ Evaluation completed successfully!")
        print("Results saved in ./experiments/")
    else:
        print(f"\n❌ Evaluation failed with exit code {result.returncode}")
else:
    print("\n⚠️  EXECUTE_COMMANDS is False")
    print("   Copy the command above and run in terminal for full evaluation")

## Step 6: Visualize Results

After running evaluation, visualize the comparison between all methods.

In [None]:
import matplotlib.pyplot as plt
import json
import os

results_dir = "./experiments"
methods = ["mka_similarity", "fixed_05", "fixed_07", "learned", "learned_mlp"]
method_names = [
    "MKA (Similarity)",
    "Fixed α=0.5",
    "Fixed α=0.7",
    "Learned Scalar α",
    "Learned MLP α"
]

colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd']

accuracies = []
found_methods = []
found_colors = []  # keep each method's color stable even if some results are missing

for method, name, color in zip(methods, method_names, colors):
    result_file = os.path.join(results_dir, method, "results.json")
    if os.path.exists(result_file):
        with open(result_file, 'r') as f:
            data = json.load(f)
        if 'average_accuracy' not in data:
            # Skip rather than plot a misleading 0% bar
            print(f"⚠️ No 'average_accuracy' in {result_file}; skipping {name}")
            continue
        accuracies.append(data['average_accuracy'] * 100)  # Convert to percentage
        found_methods.append(name)
        found_colors.append(color)
    else:
        print(f"⚠️ Results not found for {name}")

if accuracies:
    plt.figure(figsize=(12, 6))
    bars = plt.bar(found_methods, accuracies, color=found_colors)
    
    plt.xlabel('Method', fontsize=12)
    plt.ylabel('Average Accuracy (%)', fontsize=12)
    plt.title('MMLU Accuracy Comparison: All Merging Methods', fontsize=14, fontweight='bold')
    plt.xticks(rotation=45, ha='right')
    plt.grid(axis='y', alpha=0.3)
    
    # Add value labels on bars
    for bar, acc in zip(bars, accuracies):
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., height,
                f'{acc:.2f}%',
                ha='center', va='bottom', fontsize=10)
    
    plt.tight_layout()
    plt.show()
    
    print("\n📊 Results Summary:")
    for name, acc in zip(found_methods, accuracies):
        print(f"  {name}: {acc:.2f}%")
else:
    print("No results found. Run the evaluation step first.")

## Summary

This notebook demonstrated:
1. ✅ Training MLP networks to predict α dynamically
2. ✅ Understanding input-dependent merging behavior
3. ✅ Comparing with scalar α and other baseline methods

**Key Questions to Answer:**
- Does dynamic α (MLP) outperform static α (scalar)?
- How much does α vary across different inputs?
- What activation patterns lead to high/low α values?

**Research Hypotheses:**
- **H1**: MLP-based merging adapts to input complexity → better accuracy
- **H2**: Different tasks/subjects require different merging strategies
- **H3**: Dynamic α provides more flexibility than similarity heuristic

**Next Steps:**
- Compare with scalar alpha results (see `train_scalar_alpha.ipynb`)
- Analyze per-subject accuracy differences
- Investigate α prediction patterns for different input types
- Experiment with different MLP architectures (hidden size, depth)
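
For the per-subject analysis, one simple starting point is to rank subjects by the accuracy difference between two methods. The sketch below assumes per-subject accuracies are available as dictionaries keyed by subject name; the function `per_subject_delta`, the subject names, and the numbers are placeholders, and the actual layout of the evaluation output may differ.

```python
def per_subject_delta(baseline, candidate):
    """Return (subject, candidate - baseline) pairs for shared subjects, largest gain first."""
    shared = baseline.keys() & candidate.keys()
    deltas = [(subject, candidate[subject] - baseline[subject]) for subject in shared]
    return sorted(deltas, key=lambda item: item[1], reverse=True)


# Placeholder numbers, not measured results.
scalar_alpha = {"subject_a": 0.50, "subject_b": 0.75, "subject_c": 0.25}
mlp_alpha = {"subject_a": 0.75, "subject_b": 0.50, "subject_c": 0.25}

ranking = per_subject_delta(scalar_alpha, mlp_alpha)
for subject, delta in ranking:
    print(f"{subject}: {delta * 100:+.1f} points")
```

Subjects at the two ends of the ranking are the natural places to look when investigating which input types benefit from dynamic α.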

---

## 🔄 **Quick Reference: All Commands**

### Scalar Alpha Training (See Other Notebook):
```powershell
python pipeline.py --model_path "meta-llama/Meta-Llama-3-8B" --num_layer 14 --data_dir "./data" --use_learnable_alpha --alpha_training_steps 500 --alpha_learning_rate 1e-4
```

### MLP Alpha Training (This Notebook):
```powershell
python pipeline.py --model_path "meta-llama/Meta-Llama-3-8B" --num_layer 14 --data_dir "./data" --use_learnable_alpha --use_mlp_merge --alpha_training_steps 500 --alpha_learning_rate 1e-4
```

### Full Evaluation (Compare All 5 Methods):
```powershell
python evaluate_methods.py --model_path "meta-llama/Meta-Llama-3-8B" --data_dir "./data" --similarity_matrix "similarity_matrix.pkl" --output_dir "./experiments" --include_mlp
```

---

## 📊 **Key Differences from Scalar Alpha**

| Aspect | Scalar Alpha | MLP Alpha (This) |
|--------|--------------|------------------|
| **Flag** | `--use_learnable_alpha` | `--use_learnable_alpha --use_mlp_merge` |
| **Layer Type** | `MergeableLayer` | `MLPMergeableLayer` |
| **Alpha Type** | Static (1 value) | Dynamic (varies per input) |
| **Trainable Params** | 1 scalar | ~4K MLP parameters |
| **Learned Output** | α ∈ [0, 1] | -1.0 (indicator) |
| **Fusion** | Use learned α | Use fixed 0.5 |
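
The trainable-parameter figure depends on the exact MLP shape in `mergeable_layer.py`, which this notebook does not show. As a quick check, the sketch below counts weights and biases assuming plain fully connected layers with biases; the helper `count_mlp_parameters` is hypothetical, and the second shape is only an example of a deeper variant.

```python
def count_mlp_parameters(layer_sizes):
    """Count weights and biases of a fully connected MLP with the given layer widths."""
    return sum(
        fan_in * fan_out + fan_out  # weight matrix plus bias vector
        for fan_in, fan_out in zip(layer_sizes, layer_sizes[1:])
    )


single_hidden = count_mlp_parameters([4, 64, 1])     # 4*64 + 64 + 64*1 + 1
two_hidden = count_mlp_parameters([4, 64, 64, 1])    # 320 + 4160 + 65
print(single_hidden, two_hidden)
```

Under these assumptions the 4 → 64 → 1 shape noted in the configuration cell comes to 385 parameters per layer pair, so a figure in the thousands would imply a larger network, such as one with a second hidden layer. It is worth confirming the actual shape in `mergeable_layer.py`.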