# Learnable Scalar Alpha Training for MKA

Train a learnable scalar α for MKA layer merging, then evaluate the fused model on MMLU.

## 1. Setup and Configuration

In [5]:
import os
import sys
import subprocess
from huggingface_hub import login

# HuggingFace Authentication
# The token is read from the HF_TOKEN environment variable rather than being
# pasted into the cell, so the secret is never saved with the notebook source.
HF_TOKEN = os.environ.get("HF_TOKEN", "")

if HF_TOKEN:
    login(token=HF_TOKEN)
    print("✓ Logged in to HuggingFace")
else:
    print("ℹ HF_TOKEN is not set; skipping HuggingFace login")

# Configuration
MODEL_PATH = "meta-llama/Meta-Llama-3-8B"
DATA_DIR = "./data"
NUM_LAYERS = 13  # Must match your baseline evaluation
OUTPUT_DIR_BASELINE = "./output_baseline"
OUTPUT_DIR_LEARNED = "./output_learned"

# Training hyperparameters
ALPHA_TRAINING_STEPS = 500
ALPHA_LEARNING_RATE = 1e-4
# NOTE(review): the two calibration settings below are not referenced by any
# cell visible in this notebook — confirm whether pipeline.py needs them.
CALIBRATION_BATCH_SIZE = 4
CALIBRATION_SAMPLES = 100

# Echo the settings that drive the training run so they are captured in the
# cell output next to the results.
print("=" * 60)
print("SCALAR ALPHA EXPERIMENT - CONFIGURATION")
print("=" * 60)
print(f"  Model: {MODEL_PATH}")
print(f"  Layers to merge: {NUM_LAYERS}")
print(f"  Training steps: {ALPHA_TRAINING_STEPS}")
print(f"  Learning rate: {ALPHA_LEARNING_RATE}")
print("=" * 60)

✓ Logged in to HuggingFace
SCALAR ALPHA EXPERIMENT - CONFIGURATION
  Model: meta-llama/Meta-Llama-3-8B
  Layers to merge: 13
  Training steps: 500
  Learning rate: 0.0001


## 2. Download MMLU Dataset (if needed)

In [6]:
# Download MMLU dataset (only need to run once)
import os
import shutil
import subprocess

MMLU_REPO_URL = "https://github.com/hendrycks/test.git"
MMLU_CLONE_DIR = "mmlu_download"  # temporary checkout, always removed afterwards
MMLU_DATA_DIR = "./data"

if os.path.exists(MMLU_DATA_DIR):
    print("✅ Data directory already exists")
else:
    print("📥 Downloading MMLU dataset...")
    try:
        # Clone the official MMLU repository. check=True makes a failed clone
        # raise instead of silently continuing to the move/cleanup steps.
        subprocess.run(["git", "clone", MMLU_REPO_URL, MMLU_CLONE_DIR], check=True)

        # Move the data folder (only if the checkout actually contains one)
        cloned_data = os.path.join(MMLU_CLONE_DIR, "data")
        if os.path.isdir(cloned_data):
            shutil.move(cloned_data, MMLU_DATA_DIR)
    except (subprocess.CalledProcessError, OSError) as err:
        # OSError also covers a missing `git` executable.
        print(f"✗ MMLU download failed: {err}")
    finally:
        # Clean up the temporary checkout whether or not the steps above worked
        shutil.rmtree(MMLU_CLONE_DIR, ignore_errors=True)

    # Verify structure
    if os.path.exists("./data/dev") and os.path.exists("./data/test"):
        print("✅ MMLU dataset downloaded successfully!")
        dev_count = len([f for f in os.listdir("./data/dev") if f.endswith("_dev.csv")])
        test_count = len([f for f in os.listdir("./data/test") if f.endswith("_test.csv")])
        print(f"   Dev files: {dev_count}, Test files: {test_count}")
    else:
        print("⚠️ Download completed but structure looks wrong")

✅ Data directory already exists


## 3. Verify Data

In [7]:
# Report how many entries the dev and test splits hold under DATA_DIR
if not os.path.exists(DATA_DIR):
    print(f"✗ Data directory not found: {DATA_DIR}")
    print("  Make sure MMLU data is in ./data/dev/ and ./data/test/")
else:
    dev_dir = os.path.join(DATA_DIR, "dev")
    test_dir = os.path.join(DATA_DIR, "test")
    # A missing split is reported as zero files rather than raising
    dev_files = os.listdir(dev_dir) if os.path.exists(dev_dir) else []
    test_files = os.listdir(test_dir) if os.path.exists(test_dir) else []
    print(f"✓ Data directory exists: {len(dev_files)} dev files, {len(test_files)} test files")

✓ Data directory exists: 57 dev files, 57 test files


## 4. Train Learnable Alpha (saves model, ~30 min)

In [8]:
# Train learnable alpha and evaluate on MMLU
!python pipeline.py --model_path "meta-llama/Meta-Llama-3-8B" --num_layer 13 --data_dir "./data" --use_learnable_alpha --alpha_training_steps 500 --alpha_learning_rate 1e-4

^C


## 5. Load Fused Model

In [None]:
# Load the fused model with learned alphas
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import os

# Directory where the fused checkpoint is expected on disk
MODEL_DIR = f"./output/Meta-Llama-3-8B/fused_{NUM_LAYERS}_layers/iteration/merged_weights"

print("Loading fused model with learned alphas...")
print(f"Model directory: {MODEL_DIR}")

if os.path.exists(MODEL_DIR):
    # The tokenizer comes from MODEL_PATH, not from the fused checkpoint directory
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_PATH, use_fast=True, trust_remote_code=True, padding_side="left"
    )

    # Fall back to the EOS token for padding when the tokenizer defines no pad token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # bfloat16 when the CUDA device supports it, float32 otherwise
    load_dtype = torch.float32
    if torch.cuda.is_bf16_supported():
        load_dtype = torch.bfloat16

    fused_model = AutoModelForCausalLM.from_pretrained(
        MODEL_DIR,
        trust_remote_code=True,
        device_map="auto",
        torch_dtype=load_dtype,
    )

    # Disable caching
    fused_model.config.use_cache = False

    print(f"✅ Model loaded successfully!")
    print(f"   Number of layers: {fused_model.config.num_hidden_layers}")
    print(f"   Model dtype: {fused_model.dtype}")
else:
    # Nothing is loaded in this branch, so `tokenizer` and `fused_model` stay undefined
    print(f"❌ Model directory not found: {MODEL_DIR}")
    print("   Make sure training has completed successfully.")

In [None]:
# Evaluate on MMLU
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch
import json

# Answer letters for the MMLU prompts; position j labels option column j + 1
choices = list("ABCD")

def format_subject(subject):
    """Turn an underscore-separated subject id into space-prefixed words.

    Every word, including the first, is preceded by one space, so
    ``"high_school_math"`` becomes ``" high school math"``.
    """
    return "".join(f" {word}" for word in subject.split("_"))

def format_example(df, idx, include_answer=True):
    """Render row ``idx`` of ``df`` as one multiple-choice prompt.

    Column 0 is the question text, the last column is the answer, and the
    columns in between are the options, labelled from ``choices`` in order.
    The text always ends with ``"\\nAnswer:"``; when ``include_answer`` is
    true the answer and a blank line are appended after it.
    """
    option_count = df.shape[1] - 2
    pieces = [df.iloc[idx, 0]]
    pieces.extend(
        "\n{}. {}".format(choices[col], df.iloc[idx, col + 1])
        for col in range(option_count)
    )
    pieces.append("\nAnswer:")
    if include_answer:
        pieces.append(" {}\n\n".format(df.iloc[idx, option_count + 1]))
    return "".join(pieces)

def gen_prompt(train_df, subject, k=-1):
    """Build the few-shot prefix: a subject header followed by solved examples.

    The first ``k`` rows of ``train_df`` are rendered with their answers;
    ``k == -1`` means every row of ``train_df``.
    """
    header = "The following are multiple choice questions (with answers) about {}.\n\n".format(
        format_subject(subject)
    )
    shot_count = train_df.shape[0] if k == -1 else k
    return header + "".join(format_example(train_df, row) for row in range(shot_count))

@torch.no_grad()
def eval_subject(subject, model, tokenizer, dev_df, test_df, ntrain=5):
    """Score ``model`` on one MMLU subject.

    Each row of ``test_df`` is turned into a prompt made of ``ntrain`` solved
    examples from ``dev_df`` followed by the unanswered question. The
    prediction is the answer letter whose token has the highest probability
    at the last position.

    Returns:
        ``(acc, ppl)``: ``acc`` is the mean of the per-question correctness
        flags, and ``ppl`` is ``exp`` of the model loss averaged over the
        questions.
    """
    # The few-shot prefix and the answer-letter token ids are identical for
    # every question, so build them once instead of on each iteration.
    train_prompt = gen_prompt(dev_df, subject, ntrain)
    choice_token_ids = [tokenizer(c).input_ids[-1] for c in choices]

    # Send inputs to wherever the model lives; fall back to the default CUDA
    # device for model objects that do not expose a `device` attribute.
    device = getattr(model, "device", None)
    if device is None:
        device = "cuda"

    cors = []
    total_loss = 0

    for i in tqdm(range(test_df.shape[0]), desc=f"Evaluating {subject}"):
        prompt_end = format_example(test_df, i, include_answer=False)
        prompt = train_prompt + prompt_end

        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
        labels = input_ids.clone()
        # Only the trailing tokens (the current question) contribute to the loss.
        # NOTE(review): the suffix length comes from tokenizing prompt_end on its
        # own, so any special tokens the tokenizer adds are counted — confirm
        # this matches the baseline evaluation.
        labels[:, :-len(tokenizer(prompt_end).input_ids)] = -100

        outputs = model(input_ids=input_ids, labels=labels, use_cache=False)
        logits = outputs.logits[:, -1, :]
        loss = outputs.loss
        total_loss += loss.item()

        probs = torch.nn.functional.softmax(logits, dim=-1).detach().float().cpu().numpy()
        pred = choices[np.argmax(probs[:, choice_token_ids])]
        label = test_df.iloc[i, test_df.shape[1] - 1]

        cor = pred == label
        cors.append(cor)

    acc = np.mean(cors)
    avg_loss = total_loss / len(test_df)
    ppl = np.exp(avg_loss)

    return acc, ppl

# Evaluate the fused model on every MMLU subject found in DATA_DIR/test
print("\n" + "="*60)
print("EVALUATING FUSED MODEL ON MMLU")
print("="*60)

fused_model.eval()

# Subject names are the test file names with the "_test.csv" part removed
subjects = sorted(
    name.split("_test.csv")[0]
    for name in os.listdir(os.path.join(DATA_DIR, "test"))
    if "_test.csv" in name
)

all_accs, all_ppls = {}, {}

for subject in subjects:
    dev_csv = os.path.join(DATA_DIR, "dev", subject + "_dev.csv")
    test_csv = os.path.join(DATA_DIR, "test", subject + "_test.csv")

    # Only the first 5 dev rows are kept as few-shot examples
    dev_df = pd.read_csv(dev_csv, header=None)[:5]
    test_df = pd.read_csv(test_csv, header=None)

    acc, ppl = eval_subject(subject, fused_model, tokenizer, dev_df, test_df, ntrain=5)
    all_accs[subject], all_ppls[subject] = acc, ppl

    print(f"Average accuracy {acc:.3f} - {subject}")
    print(f"Perplexity {ppl:.3f} - {subject}")

# Unweighted means over subjects (every subject counts equally)
avg_acc = np.mean(list(all_accs.values()))
avg_ppl = np.mean(list(all_ppls.values()))

print("\n" + "="*60)
print("MMLU EVALUATION RESULTS (LEARNED ALPHA)")
print("="*60)
print(f"Average Accuracy:   {avg_acc:.4f}")
print(f"Average Perplexity: {avg_ppl:.4f}")
print("="*60)

# Persist the summary; values are cast to plain floats so json can serialise them
results = {
    "average_accuracy": float(avg_acc),
    "average_perplexity": float(avg_ppl),
    "per_subject_accuracy": {name: float(score) for name, score in all_accs.items()},
    "per_subject_perplexity": {name: float(score) for name, score in all_ppls.items()},
}

results_path = f"./output/Meta-Llama-3-8B/fused_{NUM_LAYERS}_layers/iteration/fusion_info/mmlu_results.json"
os.makedirs(os.path.dirname(results_path), exist_ok=True)
with open(results_path, "w") as f:
    json.dump(results, f, indent=2)

print(f"\n✅ Results saved to: {results_path}")

## 6. Evaluate on MMLU (~40-90 min)

In [None]:
# Visualize top and bottom performing subjects
import matplotlib.pyplot as plt

# Rank subjects from highest to lowest accuracy
sorted_subjects = sorted(all_accs.items(), key=lambda item: item[1], reverse=True)

# Best ten and worst ten of that ranking
top_10 = sorted_subjects[:10]
bottom_10 = sorted_subjects[-10:]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))


def _draw_accuracy_panel(ax, names, scores, bar_color, title):
    """Draw one horizontal bar chart of per-subject accuracy on ``ax``."""
    rows = range(len(names))
    ax.barh(rows, scores, color=bar_color, alpha=0.7)
    ax.set_yticks(rows)
    ax.set_yticklabels(names, fontsize=10)
    ax.set_xlabel('Accuracy', fontsize=12, fontweight='bold')
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.invert_yaxis()
    ax.grid(axis='x', alpha=0.3)

    # Print each accuracy value just to the right of its bar
    for row, score in enumerate(scores):
        ax.text(score + 0.01, row, f'{score:.3f}', va='center', fontsize=10)


# Top 10 subjects
subjects_top = [name for name, _ in top_10]
accs_top = [score for _, score in top_10]
_draw_accuracy_panel(ax1, subjects_top, accs_top, 'green', 'Top 10 Subjects by Accuracy')

# Bottom 10 subjects
subjects_bottom = [name for name, _ in bottom_10]
accs_bottom = [score for _, score in bottom_10]
_draw_accuracy_panel(ax2, subjects_bottom, accs_bottom, 'red', 'Bottom 10 Subjects by Accuracy')

plt.tight_layout()
plt.show()

print(f"\n📊 Overall Average Accuracy: {avg_acc:.4f}")

## 7. Visualize Top/Bottom Subjects

In [None]:
import json
import numpy as np
import matplotlib.pyplot as plt

# Load learned alphas
# The path is built from NUM_LAYERS so it points at the same run directory as the
# model-loading and results-saving cells instead of a hard-coded layer count.
learned_alphas_path = f"./output/Meta-Llama-3-8B/fused_{NUM_LAYERS}_layers/iteration/merged_weights/learned_alphas.json"

if not os.path.exists(learned_alphas_path):
    print(f"✗ Learned alphas not found: {learned_alphas_path}")
    print("  Training may still be running.")
else:
    with open(learned_alphas_path, 'r') as f:
        data = json.load(f)

    # Both keys are optional in the file; a missing key yields an empty list.
    # NOTE(review): assumes each value is a flat list of numbers — confirm
    # against what pipeline.py writes.
    learned_alphas = data.get('learned_alphas', [])
    similarity_scores = data.get('similarity_scores', [])

    if not learned_alphas:
        # np.min / np.max raise on an empty sequence, so stop with a clear message.
        print(f"✗ No learned alphas recorded in: {learned_alphas_path}")
    else:
        alpha_mean = np.mean(learned_alphas)

        print("=" * 60)
        print("LEARNED ALPHA STATISTICS")
        print("=" * 60)
        print(f"  Number of layers: {len(learned_alphas)}")
        print(f"  Mean α: {alpha_mean:.4f}")
        print(f"  Std α:  {np.std(learned_alphas):.4f}")
        print(f"  Min α:  {np.min(learned_alphas):.4f}")
        print(f"  Max α:  {np.max(learned_alphas):.4f}")
        print("=" * 60)

        # Visualize alpha distribution
        plt.figure(figsize=(12, 4))

        plt.subplot(1, 3, 1)
        plt.hist(learned_alphas, bins=15, edgecolor='black', alpha=0.7, color='steelblue')
        plt.axvline(alpha_mean, color='r', linestyle='--', label=f'Mean: {alpha_mean:.3f}')
        plt.xlabel('Alpha Value')
        plt.ylabel('Frequency')
        plt.title('Distribution of Learned α')
        plt.legend()
        plt.grid(alpha=0.3)

        # Alpha vs layer index (position in the saved list)
        plt.subplot(1, 3, 2)
        plt.plot(range(len(learned_alphas)), learned_alphas, marker='o', linestyle='-', color='darkgreen')
        plt.xlabel('Layer Index')
        plt.ylabel('Learned α')
        plt.title('Learned α Across Layers')
        plt.grid(alpha=0.3)

        # Alpha vs Similarity — drawn only when there is one score per alpha
        if similarity_scores and len(similarity_scores) == len(learned_alphas):
            plt.subplot(1, 3, 3)
            plt.scatter(similarity_scores, learned_alphas, alpha=0.6, s=100, color='coral')
            plt.xlabel('Similarity Score (S_lm)')
            plt.ylabel('Learned α')
            plt.title('Learned α vs Similarity')
            corr = np.corrcoef(similarity_scores, learned_alphas)[0, 1]
            plt.text(0.05, 0.95, f'Correlation: {corr:.3f}', transform=plt.gca().transAxes,
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
            plt.grid(alpha=0.3)

        plt.tight_layout()
        plt.show()

        print(f"\n✓ Analysis complete! Check the plots above.")

## 8. Analyze Learned Alpha Values