# Learnable MLP Alpha Training for MKA Layer Merging

This notebook compares **baseline MKA** (similarity-based merging) with an **MLP-based dynamic alpha** approach.

## Workflow:
1. **Run Baseline**: Original MKA with similarity-based merging (num_layer=13)
2. **Train MLP**: Train an MLP to predict α from activation statistics
3. **Evaluate**: Test both on MMLU and compare accuracy

## Goal:
Test whether an MLP that predicts α dynamically improves over static similarity-based merging.
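
To make the role of α concrete, here is a minimal sketch of an α-weighted blend of two layers' weights. It assumes the merge is a convex combination `alpha * w_l + (1 - alpha) * w_m`; that formula, the function name `merge_weights`, and the flat-list representation are illustrative assumptions, not code taken from `pipeline.py`.

```python
def merge_weights(w_l, w_m, alpha):
    """Blend two equally sized flat weight lists.

    Assumed rule (not confirmed by the pipeline source):
        merged = alpha * w_l + (1 - alpha) * w_m
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    if len(w_l) != len(w_m):
        raise ValueError("weight lists must have the same length")
    return [alpha * a + (1.0 - alpha) * b for a, b in zip(w_l, w_m)]


# Static merging uses one fixed alpha per layer pair; the MLP variant
# would supply a predicted alpha instead.
merged = merge_weights([1.0, 2.0], [3.0, 4.0], alpha=0.25)
print(merged)  # [2.5, 3.5]
```

Under this assumption the two approaches differ only in where `alpha` comes from: a similarity score in the baseline, a network output in the MLP variant.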

## Setup and Configuration

In [None]:
import os
import sys
from huggingface_hub import login

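# HuggingFace Authentication
# Read the token from the environment; never hard-code credentials in a notebook.
HF_TOKEN = os.environ.get("HF_TOKEN")

if HF_TOKEN:
    login(token=HF_TOKEN)
    print("✓ Logged in to HuggingFace")
else:
    print("⚠️ HF_TOKEN is not set; gated models such as Meta-Llama-3-8B will fail to download")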

# Configuration
MODEL_PATH = "meta-llama/Meta-Llama-3-8B"
DATA_DIR = "./data"
NUM_LAYERS = 13  # Must match your baseline evaluation

# Training hyperparameters
ALPHA_TRAINING_STEPS = 500
ALPHA_LEARNING_RATE = 1e-4

print("=" * 60)
print("MLP ALPHA EXPERIMENT - CONFIGURATION")
print("=" * 60)
print(f"  Model: {MODEL_PATH}")
print(f"  Layers to merge: {NUM_LAYERS}")
print(f"  Training steps: {ALPHA_TRAINING_STEPS}")
print(f"  Learning rate: {ALPHA_LEARNING_RATE}")
print(f"  MLP Mode: ENABLED (4 features → hidden → α)")
print("=" * 60)


✓ Logged in to HuggingFace
MLP ALPHA EXPERIMENT - CONFIGURATION
  Model: meta-llama/Meta-Llama-3-8B
  Layers to merge: 13
  Training steps: 500
  Learning rate: 0.0001
  MLP Mode: ENABLED (4 features → hidden → α)


## Step 0: Download MMLU Dataset (Lightning AI Setup)

**First time only:** Download the MMLU dataset from the official source.

In [None]:
# Download MMLU dataset (only need to run once)
import os

if not os.path.exists("./data"):
    print("📥 Downloading MMLU dataset...")
    # Clone the official MMLU repository
    !git clone https://github.com/hendrycks/test.git mmlu_download

    # Move the data folder
    !mv mmlu_download/data ./data

    # Clean up
    !rm -rf mmlu_download

    # Verify structure
    if os.path.exists("./data/dev") and os.path.exists("./data/test"):
        print("✅ MMLU dataset downloaded successfully!")
        dev_count = len([f for f in os.listdir("./data/dev") if f.endswith("_dev.csv")])
        test_count = len([f for f in os.listdir("./data/test") if f.endswith("_test.csv")])
        print(f"   Dev files: {dev_count}, Test files: {test_count}")
    else:
        print("⚠️ Download completed but structure looks wrong")
else:
    print("✅ Data directory already exists")

## Step 1: Verify Data Files

In [None]:
# Check data directory
if os.path.exists(DATA_DIR):
    dev_files = os.listdir(os.path.join(DATA_DIR, "dev")) if os.path.exists(os.path.join(DATA_DIR, "dev")) else []
    test_files = os.listdir(os.path.join(DATA_DIR, "test")) if os.path.exists(os.path.join(DATA_DIR, "test")) else []
    print(f"✓ Data directory exists: {len(dev_files)} dev files, {len(test_files)} test files")
else:
    print(f"✗ Data directory not found: {DATA_DIR}")
    print("  Make sure MMLU data is in ./data/dev/ and ./data/test/")

✓ Data directory exists: 57 dev files, 57 test files
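
Counting files confirms that both folders are populated, but not that every subject has both splits. The sketch below lists subjects that appear in only one of `dev/` and `test/`, relying on the `_dev.csv` / `_test.csv` suffixes used by the download cell. The helper names are hypothetical.

```python
import os


def _subjects(split_dir, suffix):
    """Return the subject names found in split_dir, with the suffix stripped."""
    if not os.path.isdir(split_dir):
        return set()
    return {f[: -len(suffix)] for f in os.listdir(split_dir) if f.endswith(suffix)}


def unmatched_subjects(data_dir):
    """Subjects present in exactly one of dev/ and test/ (sorted)."""
    dev = _subjects(os.path.join(data_dir, "dev"), "_dev.csv")
    test = _subjects(os.path.join(data_dir, "test"), "_test.csv")
    return sorted(dev ^ test)  # symmetric difference


# Usage in the notebook (DATA_DIR comes from the configuration cell):
# print(unmatched_subjects(DATA_DIR))
```

An empty list means the two splits cover the same subjects.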


## Step 2A: Run Baseline (Original MKA - No Alpha Training)

**Note:** If you already have baseline results for num_layer=13, you can skip this step.

In [3]:
# SKIP THIS if you already have baseline results for num_layer=13
print("⚠️ Skipping baseline - using existing results for num_layer=13")

⚠️ Skipping baseline - using existing results for num_layer=13


## Step 2B: Train MLP-based Alpha

Train an MLP to predict α dynamically, then evaluate on MMLU.
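
As a rough illustration of the "4 features → hidden → α" mapping, the dependency-free sketch below computes four activation statistics (mean, std, max, L2 norm) and passes them through a tiny two-layer network with a sigmoid output. The hidden size, ReLU activation, random initialisation, population standard deviation, and every name here are assumptions for illustration. The real model lives in `pipeline.py` and may differ.

```python
import math
import random


def activation_features(acts):
    """Return [mean, std, max, L2 norm] for a flat list of activations."""
    n = len(acts)
    mean = sum(acts) / n
    std = math.sqrt(sum((a - mean) ** 2 for a in acts) / n)  # population std
    norm = math.sqrt(sum(a * a for a in acts))
    return [mean, std, max(acts), norm]


class TinyAlphaMLP:
    """Hypothetical 4 -> hidden -> 1 network; untrained, forward pass only."""

    def __init__(self, hidden=8, seed=0):
        rng = random.Random(seed)
        self.w1 = [[rng.gauss(0.0, 0.1) for _ in range(4)] for _ in range(hidden)]
        self.b1 = [0.0] * hidden
        self.w2 = [rng.gauss(0.0, 0.1) for _ in range(hidden)]
        self.b2 = 0.0

    def predict(self, feats):
        # Hidden layer with ReLU
        h = [
            max(0.0, sum(w * x for w, x in zip(row, feats)) + b)
            for row, b in zip(self.w1, self.b1)
        ]
        # Sigmoid keeps the predicted alpha strictly inside (0, 1)
        z = sum(w * x for w, x in zip(self.w2, h)) + self.b2
        return 1.0 / (1.0 + math.exp(-z))


feats = activation_features([3.0, 4.0])
alpha = TinyAlphaMLP().predict(feats)
print(feats, alpha)
```

The sigmoid output is one simple way to keep α a valid mixing weight regardless of the input scale.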

In [4]:
# Train MLP alpha and evaluate on MMLU
!python pipeline.py --model_path "meta-llama/Meta-Llama-3-8B" --num_layer 13 --data_dir "./data" --use_learnable_alpha --use_mlp_merge --alpha_training_steps 500 --alpha_learning_rate 1e-4

^C
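
The `^C` in the recorded output indicates the run was interrupted, so the result files read by later cells will not exist until a run completes. The sketch below shows one way to launch the same command from Python and check its exit code. The flags are copied from the cell above, while the helper names `build_pipeline_cmd` and `run_pipeline` are hypothetical.

```python
import subprocess
import sys


def build_pipeline_cmd(model_path, num_layer, data_dir, steps, lr):
    """Assemble the pipeline.py argument list from the notebook's config values."""
    return [
        sys.executable, "pipeline.py",
        "--model_path", model_path,
        "--num_layer", str(num_layer),
        "--data_dir", data_dir,
        "--use_learnable_alpha",
        "--use_mlp_merge",
        "--alpha_training_steps", str(steps),
        "--alpha_learning_rate", str(lr),
    ]


def run_pipeline(cmd):
    """Run the command and return its exit code (0 means it finished cleanly)."""
    completed = subprocess.run(cmd, check=False)
    if completed.returncode != 0:
        print(f"⚠️ pipeline exited with code {completed.returncode}")
    return completed.returncode


# Usage in the notebook, with values from the configuration cell:
# cmd = build_pipeline_cmd(MODEL_PATH, NUM_LAYERS, DATA_DIR,
#                          ALPHA_TRAINING_STEPS, ALPHA_LEARNING_RATE)
# run_pipeline(cmd)
```

Building the command from the configuration variables keeps this cell in sync with the setup cell instead of repeating the literal values.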


## Step 3: Analyze MLP Predictions

Analyze how the MLP predicts α values based on activation statistics.

In [None]:
import json
import os
import numpy as np
import matplotlib.pyplot as plt

# Load MLP alpha predictions (if saved)
mlp_alphas_path = "./output/Meta-Llama-3-8B/fused_13_layers/iteration/merged_weights/mlp_alpha_predictions.json"
learned_alphas_path = "./output/Meta-Llama-3-8B/fused_13_layers/iteration/merged_weights/learned_alphas.json"

if os.path.exists(learned_alphas_path):
    with open(learned_alphas_path, 'r') as f:
        data = json.load(f)

    learned_alphas = data.get('learned_alphas', [])
    similarity_scores = data.get('similarity_scores', [])

    print("=" * 60)
    print("MLP ALPHA PREDICTIONS")
    print("=" * 60)
    print(f"  Number of layers: {len(learned_alphas)}")
    print(f"  Mean α: {np.mean(learned_alphas):.4f}")
    print(f"  Std α:  {np.std(learned_alphas):.4f}")
    print(f"  Min α:  {np.min(learned_alphas):.4f}")
    print(f"  Max α:  {np.max(learned_alphas):.4f}")
    print("=" * 60)

    # Visualize
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 3, 1)
    plt.hist(learned_alphas, bins=15, edgecolor='black', alpha=0.7, color='purple')
    plt.axvline(np.mean(learned_alphas), color='r', linestyle='--', label=f'Mean: {np.mean(learned_alphas):.3f}')
    plt.xlabel('MLP Predicted α')
    plt.ylabel('Frequency')
    plt.title('Distribution of MLP α')
    plt.legend()
    plt.grid(alpha=0.3)

    plt.subplot(1, 3, 2)
    plt.plot(range(len(learned_alphas)), learned_alphas, marker='o', linestyle='-', color='darkblue')
    plt.xlabel('Layer Index')
    plt.ylabel('MLP Predicted α')
    plt.title('MLP α Across Layers')
    plt.grid(alpha=0.3)

    if similarity_scores and len(similarity_scores) == len(learned_alphas):
        plt.subplot(1, 3, 3)
        plt.scatter(similarity_scores, learned_alphas, alpha=0.6, s=100, color='orange')
        plt.xlabel('Similarity Score (S_lm)')
        plt.ylabel('MLP Predicted α')
        plt.title('MLP α vs Similarity')
        corr = np.corrcoef(similarity_scores, learned_alphas)[0, 1]
        plt.text(0.05, 0.95, f'Correlation: {corr:.3f}', transform=plt.gca().transAxes,
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
        plt.grid(alpha=0.3)

    plt.tight_layout()
    plt.show()

    print("\n✓ MLP analysis complete!")
else:
    print(f"✗ MLP predictions not found: {learned_alphas_path}")
    print("  Training must complete first.")

## Step 4: Compare MMLU Accuracy - Baseline vs MLP Alpha

**Important:** Update `baseline_accuracy` with your actual baseline result for num_layer=13.

In [None]:
# Imports repeated here so this cell also runs on its own
import json
import os
import matplotlib.pyplot as plt

# TODO: Replace with your actual baseline accuracy for num_layer=13
baseline_accuracy = 0.0  # <-- UPDATE THIS WITH YOUR BASELINE RESULT

# Load MLP results
results_path = "./output/Meta-Llama-3-8B/fused_13_layers/iteration/fusion_info/mmlu_results.json"

print("=" * 60)
print("MMLU ACCURACY COMPARISON (num_layer=13)")
print("=" * 60)

if os.path.exists(results_path):
    try:
        with open(results_path, 'r') as f:
            results = json.load(f)
            mlp_accuracy = results.get('average_accuracy', 0.0)

        print(f"Baseline (Similarity-based):  {baseline_accuracy:.4f}")
        print(f"MLP-based Dynamic α:          {mlp_accuracy:.4f}")
        print("-" * 60)

        improvement = mlp_accuracy - baseline_accuracy
        improvement_pct = (improvement / baseline_accuracy * 100) if baseline_accuracy > 0 else 0

        print(f"Improvement: {improvement:+.4f} ({improvement_pct:+.2f}%)")
        print("=" * 60)

        # Visualization
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

        methods = ['Baseline\n(Similarity)', 'MLP-based\nDynamic α']
        accuracies = [baseline_accuracy, mlp_accuracy]
        colors = ['#3498db', '#9b59b6']

        bars = ax1.bar(methods, accuracies, color=colors, alpha=0.8, edgecolor='black', linewidth=2)
        ax1.set_ylabel('MMLU Accuracy', fontsize=12, fontweight='bold')
        ax1.set_title('Baseline vs MLP Alpha', fontsize=14, fontweight='bold')
        ax1.set_ylim([min(accuracies) * 0.95 if min(accuracies) > 0 else 0, max(accuracies) * 1.05])

        for bar, acc in zip(bars, accuracies):
            height = bar.get_height()
            ax1.text(bar.get_x() + bar.get_width()/2., height,
                    f'{acc:.4f}',
                    ha='center', va='bottom', fontsize=12, fontweight='bold')

        ax1.grid(axis='y', alpha=0.3)

        ax2.bar(['Improvement'], [improvement], color='green' if improvement > 0 else 'red',
               alpha=0.8, edgecolor='black', linewidth=2)
        ax2.set_ylabel('Accuracy Difference', fontsize=12, fontweight='bold')
        ax2.set_title('Performance Gain', fontsize=14, fontweight='bold')
        ax2.axhline(0, color='black', linestyle='--', linewidth=1)
        ax2.text(0, improvement, f'{improvement:+.4f}\n({improvement_pct:+.2f}%)',
                ha='center', va='bottom' if improvement > 0 else 'top',
                fontsize=12, fontweight='bold')
        ax2.grid(axis='y', alpha=0.3)

        plt.tight_layout()
        plt.show()

        if 'per_subject' in results:
            print("\nTop 5 subjects by accuracy:")
            subject_accs = [(k, v) for k, v in results['per_subject'].items()]
            subject_accs.sort(key=lambda x: x[1], reverse=True)
            for subject, acc in subject_accs[:5]:
                print(f"  {subject:40s}: {acc:.4f}")

    except Exception as e:
        print(f"⚠️ Error loading results: {e}")
else:
    print(f"✗ Results not found: {results_path}")
    print("  Training must complete first.")

print("=" * 60)

## Summary

This notebook compares:
1. **Baseline MKA** (similarity-based heuristic S_lm)
2. **MLP-based Dynamic α** (predicted from activation statistics)

**What to check:**
- The MLP predicts α from 4 activation features (mean, std, max, norm)
- Whether dynamic prediction improves over static similarity scores
- How MMLU accuracy differs between the baseline and the MLP approach

**Next Steps:**
- Compare with scalar alpha (see `train_scalar_alpha.ipynb`)
- Run full comparison with `evaluate_methods.py --include_mlp`