# Learnable Scalar Alpha Training for MKA Layer Merging

This notebook demonstrates training a **scalar alpha parameter** to find optimal layer merging coefficients.

**Goal**: Test whether the paper's similarity-based S_lm heuristic gives the best merging coefficient by learning α via gradient descent and comparing the two.

## Method
- Replace layer pairs with `MergeableLayer` wrappers
- Each wrapper has a trainable scalar α (logit-parameterized)
- Train α parameters on calibration data while keeping original layers frozen
- Compare learned α values with the similarity-based heuristic
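
The logit parameterization mentioned above can be sketched as follows. This is a minimal illustration, not the `MergeableLayer` implementation from `pipeline.py`: it assumes the merge is a convex combination `alpha * w_a + (1 - alpha) * w_b`, and the helper names (`sigmoid`, `logit`, `merge_weights`) are hypothetical. Plain Python lists stand in for weight tensors.

```python
import math


def sigmoid(x: float) -> float:
    """Map an unconstrained logit to the open interval (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def logit(p: float) -> float:
    """Inverse of sigmoid; useful for initialising the trainable parameter."""
    return math.log(p / (1.0 - p))


def merge_weights(w_a, w_b, alpha_logit):
    """Blend two equally sized weight lists with alpha = sigmoid(alpha_logit).

    Assumption: the merge is a convex combination of the two layers.
    """
    alpha = sigmoid(alpha_logit)
    return [alpha * a + (1.0 - alpha) * b for a, b in zip(w_a, w_b)]


# Hypothetical initialisation: start alpha at 0.7 (for example, from a heuristic).
alpha_logit = logit(0.7)
merged = merge_weights([1.0, 2.0], [3.0, 4.0], alpha_logit)
print(f"alpha = {sigmoid(alpha_logit):.3f}, merged = {merged}")
```

Training the logit rather than α itself keeps α strictly between 0 and 1 without clipping, whatever value the optimizer moves the logit to.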

---

## 💡 **Important Notes**
- **For actual training**: Run commands directly in a terminal (see commands below)
- **This notebook**: Use for interactive analysis and visualization of results
- **Safety feature**: `EXECUTE_COMMANDS = False` by default to prevent accidental runs
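
The dry-run behaviour behind the safety toggle can be sketched as a small helper. The function name `run_command` and the stand-in command are hypothetical and not part of this notebook; only `shlex.join` and `subprocess.run` from the standard library are used.

```python
import shlex
import subprocess
import sys


def run_command(cmd, execute=False):
    """Print the command and run it only when `execute` is True.

    Returns the process exit code, or None for a dry run.
    """
    # shlex.join quotes for POSIX shells; PowerShell quoting rules differ.
    print(shlex.join(cmd))
    if not execute:
        print("(dry run: nothing was executed)")
        return None
    return subprocess.run(cmd).returncode


# A harmless stand-in command; swap in the real training command as needed.
demo_cmd = [sys.executable, "-c", "print('hello from subprocess')"]
dry_result = run_command(demo_cmd, execute=False)
```

Passing the command as a list avoids shell-quoting problems with paths that contain spaces.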

## Setup and Configuration

In [16]:
import os
import sys
import subprocess

# Configuration
MODEL_PATH = "meta-llama/Meta-Llama-3-8B"
DATA_DIR = "./data"
NUM_LAYERS = 14  # Number of layer pairs to merge
OUTPUT_DIR = "./merged_weights"

# Training hyperparameters
ALPHA_TRAINING_STEPS = 500
ALPHA_LEARNING_RATE = 1e-4
CALIBRATION_BATCH_SIZE = 4
CALIBRATION_SAMPLES = 100

# Safety toggle - set to True to execute commands in this notebook
# RECOMMENDED: Run commands directly in terminal instead
EXECUTE_COMMANDS = False

print("=" * 60)
print("SCALAR ALPHA TRAINING - CONFIGURATION")
print("=" * 60)
print(f"  Model: {MODEL_PATH}")
print(f"  Data directory: {DATA_DIR}")
print(f"  Layers to merge: {NUM_LAYERS}")
print(f"  Training steps: {ALPHA_TRAINING_STEPS}")
print(f"  Learning rate: {ALPHA_LEARNING_RATE}")
print(f"  Batch size: {CALIBRATION_BATCH_SIZE}")
print(f"  Calibration samples: {CALIBRATION_SAMPLES}")
print("=" * 60)
print(f"⚠️  EXECUTE_COMMANDS = {EXECUTE_COMMANDS}")
if not EXECUTE_COMMANDS:
    print("    (Commands will show as dry-run only)")
print("=" * 60)

SCALAR ALPHA TRAINING - CONFIGURATION
  Model: meta-llama/Meta-Llama-3-8B
  Data directory: ./data
  Layers to merge: 14
  Training steps: 500
  Learning rate: 0.0001
  Batch size: 4
  Calibration samples: 100
⚠️  EXECUTE_COMMANDS = False
    (Commands will show as dry-run only)


## Step 1: Verify Required Files

Before running training, ensure all required files exist:
- Model checkpoint or HuggingFace model access
- MMLU data files in `data/dev/` and `data/test/`
- Similarity matrix (optional, for comparison)

In [17]:
# Check data directory
if os.path.isdir(DATA_DIR):
    dev_dir = os.path.join(DATA_DIR, "dev")
    test_dir = os.path.join(DATA_DIR, "test")
    # A missing subdirectory is reported as zero files rather than an error
    dev_files = os.listdir(dev_dir) if os.path.isdir(dev_dir) else []
    test_files = os.listdir(test_dir) if os.path.isdir(test_dir) else []
    print("✓ Data directory exists")
    print(f"  Dev files: {len(dev_files)}")
    print(f"  Test files: {len(test_files)}")
else:
    print(f"✗ Data directory not found: {DATA_DIR}")

# Check similarity matrix (optional)
similarity_matrix_path = "similarity_matrix.pkl"
if os.path.exists(similarity_matrix_path):
    print(f"✓ Similarity matrix found: {similarity_matrix_path}")
else:
    print(f"⚠️ Similarity matrix not found (optional): {similarity_matrix_path}")

# Check output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"✓ Output directory ready: {OUTPUT_DIR}")

✓ Data directory exists
  Dev files: 57
  Test files: 57
✓ Similarity matrix found: similarity_matrix.pkl
✓ Output directory ready: ./merged_weights


## Step 2: Train Scalar Alpha Parameters

This command will:
1. Load the Llama-3-8B model
2. Replace selected layer pairs with `MergeableLayer` wrappers
3. Train scalar α parameters using calibration data
4. Save learned α values to `learned_alphas.json`
5. Fuse layers and save the merged model
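
Step 3 above can be illustrated with a toy gradient-descent loop. This is a sketch under stated assumptions, not the training code in `pipeline.py`: the two "layers" are short frozen lists, the loss is a mean squared error against a synthetic target, and `train_alpha` and its learning rate are hypothetical choices for this toy problem rather than the notebook's `ALPHA_LEARNING_RATE`.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def train_alpha(w_a, w_b, target, steps=500, lr=0.5):
    """Fit alpha = sigmoid(z) so that alpha*w_a + (1-alpha)*w_b approximates target.

    w_a and w_b stay frozen; only the logit z is updated.
    """
    z = 0.0  # logit of alpha; z = 0 means alpha starts at 0.5
    n = len(w_a)
    for _ in range(steps):
        alpha = sigmoid(z)
        # Gradient of the mean squared error with respect to alpha.
        grad_alpha = sum(
            2.0 * (alpha * a + (1.0 - alpha) * b - t) * (a - b)
            for a, b, t in zip(w_a, w_b, target)
        ) / n
        # Chain rule through the sigmoid: d(alpha)/dz = alpha * (1 - alpha).
        z -= lr * grad_alpha * alpha * (1.0 - alpha)
    return sigmoid(z)


# Synthetic problem: the target is built with alpha = 0.8, so training
# should recover a value close to 0.8.
w_a = [1.0, -2.0, 0.5, 3.0]
w_b = [0.0, 1.0, -1.5, 2.0]
target = [0.8 * a + 0.2 * b for a, b in zip(w_a, w_b)]
learned = train_alpha(w_a, w_b, target)
print(f"learned alpha = {learned:.4f}")
```

In the real pipeline the loss would come from the model's outputs on calibration data; the update rule for the logit has the same shape.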

### 🚀 **Recommended: Run in Terminal**
Copy and run this command directly in PowerShell:
```powershell
python pipeline.py --model_path "meta-llama/Meta-Llama-3-8B" --num_layer 14 --data_dir "./data" --use_learnable_alpha --alpha_training_steps 500 --alpha_learning_rate 1e-4
```

In [None]:
# Uncomment the next line and run this cell to execute training directly
# !python pipeline.py --model_path "meta-llama/Meta-Llama-3-8B" --num_layer 14 --data_dir "./data" --use_learnable_alpha --alpha_training_steps 500 --alpha_learning_rate 1e-4

print("👆 Uncomment the line above and run this cell to train scalar alpha")
print("   OR use the subprocess approach in the next cell")

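## Alternative: Run Command Directly in Notebook

There are two ways to launch training from inside the notebook:
- The `!` prefix (previous cell) executes a shell command from within the notebook; this is the simpler option
- The `subprocess` approach (next cell) builds the command from the configuration variables and runs it only when `EXECUTE_COMMANDS` is `True`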

In [None]:
# Build the command from the configuration variables
cmd = [
    "python", "pipeline.py",
    "--model_path", MODEL_PATH,
    "--num_layer", str(NUM_LAYERS),
    "--data_dir", DATA_DIR,
    "--use_learnable_alpha",
    "--alpha_training_steps", str(ALPHA_TRAINING_STEPS),
    "--alpha_learning_rate", str(ALPHA_LEARNING_RATE),
    "--calibration_batch_size", str(CALIBRATION_BATCH_SIZE),
    "--calibration_samples", str(CALIBRATION_SAMPLES),
]

print("=" * 60)
print("TRAINING COMMAND")
print("=" * 60)
print(" ".join(cmd))
print("=" * 60)

if EXECUTE_COMMANDS:
    print("\n🚀 Executing training...\n")
    # Output streams straight to the notebook, so nothing is captured here
    result = subprocess.run(cmd)
    if result.returncode == 0:
        print("\n✅ Training completed successfully!")
    else:
        print(f"\n❌ Training failed with exit code {result.returncode}")
else:
    print("\n⚠️  EXECUTE_COMMANDS is False")
    print("   To run: Set EXECUTE_COMMANDS = True in configuration cell")
    print("   OR better: Copy the command above and run in terminal")

TRAINING COMMAND
python pipeline.py --model_path meta-llama/Meta-Llama-3-8B --num_layer 14 --data_dir ./data --use_learnable_alpha --alpha_training_steps 500 --alpha_learning_rate 0.0001 --calibration_batch_size 4 --calibration_samples 100

⚠️  EXECUTE_COMMANDS is False
   To run: Set EXECUTE_COMMANDS = True in configuration cell
   OR better: Copy the command above and run in terminal


## Step 3: Analyze Learned Alpha Values

After training, examine the learned α values and compare with similarity scores.
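
The comparison in the cell below relies on the Pearson correlation between similarity scores and learned α values (computed there with `np.corrcoef`). As a sketch of what that number measures, here is a standard-library version; the function name `pearson` and the toy inputs are illustrative only and are not experimental results.

```python
import math


def pearson(xs, ys):
    """Pearson correlation of two equally long numeric sequences."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    return cov / math.sqrt(var_x * var_y)


# Toy values chosen only to exercise the function; not experimental results.
toy_similarity = [0.1, 0.2, 0.3, 0.4]
toy_alpha = [0.9, 0.7, 0.5, 0.3]  # decreases linearly as similarity increases
r = pearson(toy_similarity, toy_alpha)
print(f"correlation = {r:.3f}")
```

A value near +1 or -1 would indicate that the learned α tracks the similarity score closely; a value near 0 would suggest the two are largely unrelated.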

In [None]:
import json
import numpy as np
import matplotlib.pyplot as plt

learned_alphas_path = os.path.join(OUTPUT_DIR, "learned_alphas.json")

if os.path.exists(learned_alphas_path):
    with open(learned_alphas_path, 'r') as f:
        data = json.load(f)
    
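    learned_alphas = data.get('learned_alphas', [])
    similarity_scores = data.get('similarity_scores', [])
    # np.min/np.max raise on empty input, so fail early with a clear message
    if not learned_alphas:
        raise ValueError(f"No 'learned_alphas' entries in {learned_alphas_path}")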
    
    print(f"Number of learned alphas: {len(learned_alphas)}")
    print(f"Alpha statistics:")
    print(f"  Mean: {np.mean(learned_alphas):.4f}")
    print(f"  Std:  {np.std(learned_alphas):.4f}")
    print(f"  Min:  {np.min(learned_alphas):.4f}")
    print(f"  Max:  {np.max(learned_alphas):.4f}")
    
    if similarity_scores:
        print(f"\nSimilarity scores:")
        print(f"  Mean: {np.mean(similarity_scores):.4f}")
        print(f"  Std:  {np.std(similarity_scores):.4f}")
    
    # Plot histogram
    plt.figure(figsize=(10, 6))
    plt.subplot(1, 2, 1)
    plt.hist(learned_alphas, bins=20, edgecolor='black')
    plt.xlabel('Alpha Value')
    plt.ylabel('Frequency')
    plt.title('Distribution of Learned Alpha Values')
    plt.axvline(x=np.mean(learned_alphas), color='r', linestyle='--', label=f'Mean: {np.mean(learned_alphas):.3f}')
    plt.legend()
    
    # Plot alpha vs similarity
    if similarity_scores and len(similarity_scores) == len(learned_alphas):
        plt.subplot(1, 2, 2)
        plt.scatter(similarity_scores, learned_alphas, alpha=0.6)
        plt.xlabel('Similarity Score')
        plt.ylabel('Learned Alpha')
        plt.title('Learned Alpha vs Similarity Score')
        
        # Add correlation
        corr = np.corrcoef(similarity_scores, learned_alphas)[0, 1]
        plt.text(0.05, 0.95, f'Correlation: {corr:.3f}', 
                transform=plt.gca().transAxes, verticalalignment='top')
    
    plt.tight_layout()
    plt.show()
else:
    print(f"✗ Learned alphas file not found: {learned_alphas_path}")
    print("Run the training step first.")

✗ Learned alphas file not found: ./merged_weights\learned_alphas.json
Run the training step first.


## Step 4: Evaluate Merged Model

Test the merged model on the MMLU benchmark to measure accuracy.

### 🎯 **For Comprehensive Comparison**
Run this command in a terminal to compare all 5 methods:
```powershell
python evaluate_methods.py --model_path "meta-llama/Meta-Llama-3-8B" --data_dir "./data" --similarity_matrix "similarity_matrix.pkl" --output_dir "./experiments"
```

In [None]:
# Evaluation command
print("=" * 60)
print("EVALUATION COMMAND")
print("=" * 60)
eval_cmd = f"python evaluate_methods.py --model_path {MODEL_PATH} --data_dir {DATA_DIR} --output_dir ./experiments"
print(eval_cmd)
print("=" * 60)
print(f"\n📁 Merged model saved in: {OUTPUT_DIR}")
print("\n💡 This evaluates the scalar alpha method.")
print("   To also include MLP comparison, add: --include_mlp")

EVALUATION COMMAND
python evaluate_methods.py --model_path meta-llama/Meta-Llama-3-8B --data_dir ./data --output_dir ./experiments

📁 Merged model saved in: ./merged_weights

💡 This evaluates the scalar alpha method.
   To also include MLP comparison, add: --include_mlp


## Summary

This notebook demonstrated:
1. ✅ Training scalar α parameters for layer merging
2. ✅ Analyzing learned α values and their relationship to similarity scores
3. ✅ Saving the merged model for evaluation

**Key Questions to Answer:**
- Do learned α values differ significantly from the similarity-based heuristic?
- Does learning α improve model accuracy compared to fixed α?
- What patterns emerge in the learned α distribution?

**Next Steps:**
- Compare with MLP-based merging (see `train_mlp_alpha.ipynb`)
- Run full evaluation suite to compare all methods
- Analyze per-layer α patterns and their correlation with layer properties

---

## 🔄 **Quick Reference: All Commands**

### Scalar Alpha Training (This Notebook):
```powershell
python pipeline.py --model_path "meta-llama/Meta-Llama-3-8B" --num_layer 14 --data_dir "./data" --use_learnable_alpha --alpha_training_steps 500 --alpha_learning_rate 1e-4
```

### MLP Alpha Training (See Other Notebook):
```powershell
python pipeline.py --model_path "meta-llama/Meta-Llama-3-8B" --num_layer 14 --data_dir "./data" --use_learnable_alpha --use_mlp_merge --alpha_training_steps 500 --alpha_learning_rate 1e-4
```

### Full Evaluation (Compare All 5 Methods):
```powershell
python evaluate_methods.py --model_path "meta-llama/Meta-Llama-3-8B" --data_dir "./data" --similarity_matrix "similarity_matrix.pkl" --output_dir "./experiments" --include_mlp
```