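# Model Merging

This notebook accompanies HW9 of Prof. Hung-yi Lee's Machine Learning course (2025 Spring). It explores how to merge several fine-tuned models into a single model that combines their capabilities.

## Learning Objectives

1. Understand the motivation for model merging and its main challenges
2. Learn basic weight-averaging methods
3. Master the Task Arithmetic technique
4. Understand the TIES-Merging method
5. Learn how merges are configured in practice with mergekit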

## References

- [Model Soups](https://arxiv.org/abs/2203.05482) - Averaging weights of multiple fine-tuned models
- [Task Arithmetic](https://arxiv.org/abs/2212.04089) - Editing Models with Task Vectors
- [TIES-Merging](https://arxiv.org/abs/2306.01708) - Resolving Interference When Merging Models
- [mergekit](https://github.com/cg123/mergekit) - Model merging toolkit
- [2025 Spring HW9](https://speech.ee.ntu.edu.tw/~hylee/ml/2025-spring.php)

## 1. Why Model Merging?

### 1.1 Motivation and Use Cases

```
┌─────────────────────────────────────────────────────────────────────────┐
│                      Why Merge Models?                                  │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  Scenario 1: Combine multiple skills                                    │
│  ───────────────────────────────────                                    │
│  ┌──────────────┐   ┌──────────────┐   ┌──────────────┐                 │
│  │ Model A      │   │ Model B      │   │ Model C      │                 │
│  │ (coding)     │ + │ (math)       │ + │ (chat)       │                 │
│  └──────────────┘   └──────────────┘   └──────────────┘                 │
│           │                 │                 │                         │
│           └─────────────────┼─────────────────┘                         │
│                             ▼                                           │
│                    ┌──────────────┐                                     │
│                    │ Merged Model │                                     │
│                    │ (all-round)  │                                     │
│                    └──────────────┘                                     │
│                                                                         │
│  Scenario 2: Avoid training costs                                       │
│  ────────────────────────────────                                       │
│  • Reuse models the community has already fine-tuned                    │
│  • No retraining: merge the weights directly                            │
│  • Save GPU time and data-collection costs                              │
│                                                                         │
│  Scenario 3: Better generalization (Model Soups)                        │
│  ───────────────────────────────────────────────                        │
│  • Several fine-tunes of one task (different hyperparameters)           │
│  • Merged model often generalizes better than any single one            │
│  • Ensemble-like benefit, but only one model at inference               │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘
```

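### 1.2 The Core Challenge: Weight Interference

Naively averaging the weights of models fine-tuned on different tasks can degrade performance, because the tasks may push the same parameters in conflicting directions.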
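
A toy example makes the interference concrete. The task vectors below (fine-tuned weights minus pretrained weights, defined formally in Section 3) are made-up numbers, not taken from a real model. Where the two tasks push a parameter in opposite directions, plain averaging nearly cancels both updates. Where only one task changed a parameter, averaging halves that update:

```python
import torch

# Hypothetical task vectors for two tasks over the same four parameters
# (illustrative numbers only).
tau_a = torch.tensor([0.8, -0.6, 0.5, 0.0])
tau_b = torch.tensor([-0.7, 0.6, 0.0, 0.4])

# Parameters where the two tasks disagree in sign
conflict = (tau_a * tau_b) < 0
print(conflict)

# Plain averaging: conflicting updates (nearly) cancel out, and updates
# that only one task needed are cut in half
averaged = (tau_a + tau_b) / 2
print(averaged)
```

Parameters 0 and 1 end up near zero even though both tasks cared about them. TIES-Merging (Section 4) targets exactly this failure mode.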

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple, Optional
from copy import deepcopy
from collections import OrderedDict

# Setup: fix random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## 2. Basic Method: Weight Averaging

### 2.1 Uniform Averaging

In [None]:
class SimpleModel(nn.Module):
    """A small MLP used to illustrate merging"""
    def __init__(self, input_dim=10, hidden_dim=32, output_dim=5):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)


def uniform_averaging(models: List[nn.Module]) -> OrderedDict:
    """
    Uniform (simple) weight averaging
    
    θ_merged = (1/N) * Σ θ_i
    """
    if len(models) == 0:
        raise ValueError("At least one model is required")
    
    # Use the first model's state_dict as the template (keys and shapes)
    merged_state = OrderedDict()
    
    # Initialize the accumulators to zero
    for key, param in models[0].state_dict().items():
        merged_state[key] = torch.zeros_like(param)
    
    # Sum the weights of all models
    for model in models:
        state = model.state_dict()
        for key in merged_state:
            merged_state[key] += state[key]
    
    # Divide by the number of models to get the mean
    n_models = len(models)
    for key in merged_state:
        merged_state[key] /= n_models
    
    return merged_state


def weighted_averaging(models: List[nn.Module], 
                      weights: List[float]) -> OrderedDict:
    """
    Weighted averaging
    
    θ_merged = Σ w_i * θ_i, where Σ w_i = 1
    """
    assert len(models) == len(weights)
    assert abs(sum(weights) - 1.0) < 1e-6, "Weights must sum to 1"
    
    merged_state = OrderedDict()
    
    for key, param in models[0].state_dict().items():
        merged_state[key] = torch.zeros_like(param)
    
    for model, w in zip(models, weights):
        state = model.state_dict()
        for key in merged_state:
            merged_state[key] += w * state[key]
    
    return merged_state


# Test
print("Weight Averaging demo")
print("="*50)

# Create three models (stand-ins for models fine-tuned on different tasks;
# here they are only randomly initialized, so this demonstrates the mechanics)
model_a = SimpleModel()
model_b = SimpleModel()
model_c = SimpleModel()

# Uniform averaging
merged_uniform = uniform_averaging([model_a, model_b, model_c])
print(f"\nUniform averaging result:")
print(f"  fc1.weight shape: {merged_uniform['fc1.weight'].shape}")
print(f"  fc1.weight mean: {merged_uniform['fc1.weight'].mean().item():.4f}")

# Weighted averaging
weights = [0.5, 0.3, 0.2]  # Model A gets the largest weight
merged_weighted = weighted_averaging([model_a, model_b, model_c], weights)
print(f"\nWeighted averaging result (weights={weights}):")
print(f"  fc1.weight mean: {merged_weighted['fc1.weight'].mean().item():.4f}")
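
Section 1.1 calls a model soup "ensemble-like". The sketch below checks that claim directly, using toy models and random noise as a stand-in for fine-tuning a shared base. For a purely linear model, averaging the weights gives exactly the same output as averaging the predictions. Once a nonlinearity sits between layers, the two differ, so a soup only approximates an ensemble. It tends to work best when the ingredients stay close to a common starting point:

```python
import copy
import torch
import torch.nn as nn

torch.manual_seed(0)

def perturbed_copy(base: nn.Module, scale: float = 0.1) -> nn.Module:
    """Copy `base` and add small random noise (a stand-in for fine-tuning a shared base)."""
    m = copy.deepcopy(base)
    with torch.no_grad():
        for p in m.parameters():
            p.add_(scale * torch.randn_like(p))
    return m

def weight_average(m1: nn.Module, m2: nn.Module) -> nn.Module:
    """Return a new model whose parameters are the mean of m1's and m2's."""
    soup = copy.deepcopy(m1)
    s2 = m2.state_dict()
    soup.load_state_dict({k: (v + s2[k]) / 2 for k, v in m1.state_dict().items()})
    return soup

x = torch.randn(8, 10)

# Case 1: a purely linear model
lin_base = nn.Linear(10, 5)
l1, l2 = perturbed_copy(lin_base), perturbed_copy(lin_base)
with torch.no_grad():
    lin_gap = (weight_average(l1, l2)(x) - (l1(x) + l2(x)) / 2).abs().max().item()

# Case 2: the same comparison with a ReLU between two layers
mlp_base = nn.Sequential(nn.Linear(10, 16), nn.ReLU(), nn.Linear(16, 5))
m1, m2 = perturbed_copy(mlp_base), perturbed_copy(mlp_base)
with torch.no_grad():
    mlp_gap = (weight_average(m1, m2)(x) - (m1(x) + m2(x)) / 2).abs().max().item()

print(f"linear model: max |soup - ensemble| = {lin_gap:.2e}")  # floating-point error only
print(f"ReLU MLP:     max |soup - ensemble| = {mlp_gap:.2e}")  # clearly nonzero
```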

In [None]:
# Model Soups: decide which fine-tuned models ("ingredients") go into the soup
def greedy_soup(
    models: List[nn.Module],
    val_data: Tuple[torch.Tensor, torch.Tensor],
    base_model: nn.Module
) -> Tuple[List[int], nn.Module]:
    """
    Greedy Soup: greedily choose which models to include in the average

    Algorithm (a best-candidate variant of the Model Soups greedy recipe):
    1. Start from the model with the highest validation score
    2. In each round, try adding each model that is not yet selected
    3. Keep the candidate that raises the soup's score the most (if any)
    4. Repeat until no remaining model improves the score
    """
    X_val, y_val = val_data

    def evaluate(model):
        # Validation accuracy
        model.eval()
        with torch.no_grad():
            output = model(X_val)
            acc = (output.argmax(dim=1) == y_val).float().mean()
        return acc.item()

    # Score every model individually
    scores = [evaluate(m) for m in models]
    print(f"Individual model scores: {[f'{s:.2%}' for s in scores]}")

    # Start from the best single model
    best_idx = int(np.argmax(scores))
    selected = [best_idx]

    # Build the merged model
    merged = deepcopy(base_model)
    merged.load_state_dict(models[best_idx].state_dict())
    best_score = evaluate(merged)

    print(f"\nStarting with model {best_idx}, score: {best_score:.2%}")

    # Greedily add models
    available = set(range(len(models))) - set(selected)

    while available:
        best_candidate = None
        best_new_score = best_score

        for idx in available:
            # Try adding this model to the soup
            candidate_models = [models[i] for i in selected] + [models[idx]]
            merged_state = uniform_averaging(candidate_models)

            temp_model = deepcopy(base_model)
            temp_model.load_state_dict(merged_state)
            score = evaluate(temp_model)

            if score > best_new_score:
                best_new_score = score
                best_candidate = idx

        if best_candidate is not None:
            selected.append(best_candidate)
            available.remove(best_candidate)
            best_score = best_new_score
            print(f"Added model {best_candidate}, new score: {best_score:.2%}")

            # Update the merged model
            selected_models = [models[i] for i in selected]
            merged.load_state_dict(uniform_averaging(selected_models))
        else:
            break

    print(f"\nFinal soup: models {selected}, score: {best_score:.2%}")
    return selected, merged


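print("\n" + "="*60)
print("Greedy Soup demo")
print("="*60)

# Create some synthetic data
X_val = torch.randn(100, 10)
y_val = torch.randint(0, 5, (100,))

# Quickly train several models (simulating different hyperparameters or seeds).
# For brevity the models are trained on subsets of the same data used for
# validation; in practice, use a separate held-out validation set.
def quick_train(model, X, y, epochs=50):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = criterion(model(X), y)
        loss.backward()
        optimizer.step()
    return model

# Model Soups assume every ingredient is fine-tuned from the SAME starting
# weights; averaging independently initialized networks usually breaks them.
torch.manual_seed(0)
shared_init = SimpleModel()

models = []
for i in range(5):
    torch.manual_seed(i)
    m = deepcopy(shared_init)
    # Train each model on a different random subset
    idx = torch.randperm(len(X_val))[:80]
    m = quick_train(m, X_val[idx], y_val[idx])
    models.append(m)

base_model = SimpleModel()
selected, merged_model = greedy_soup(models, (X_val, y_val), base_model)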
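
The `greedy_soup` above re-scans every remaining model in each round. The recipe described in the Model Soups paper is cheaper. It sorts the models once by their own held-out accuracy, visits each one exactly once in that order, and keeps an ingredient if the soup's held-out accuracy does not decrease (ties are accepted). A minimal sketch of that ordered version follows. `ordered_greedy_soup` and the toy setup are illustrative names, not a library API:

```python
import copy
import torch
import torch.nn as nn

def average_state_dicts(states):
    """Element-wise mean of a list of state_dicts with identical keys."""
    return {k: sum(s[k] for s in states) / len(states) for k in states[0]}

def accuracy(model, X, y):
    model.eval()
    with torch.no_grad():
        return (model(X).argmax(dim=1) == y).float().mean().item()

def ordered_greedy_soup(models, X_val, y_val):
    """Visit models once, best individual accuracy first; keep each one only
    if the soup's held-out accuracy does not drop."""
    scores = [accuracy(m, X_val, y_val) for m in models]
    order = sorted(range(len(models)), key=lambda i: scores[i], reverse=True)

    ingredients = [order[0]]
    soup = copy.deepcopy(models[order[0]])
    best = scores[order[0]]
    for i in order[1:]:
        candidate = copy.deepcopy(soup)
        candidate.load_state_dict(
            average_state_dicts([models[j].state_dict() for j in ingredients + [i]])
        )
        acc = accuracy(candidate, X_val, y_val)
        if acc >= best:  # ">=": ties are accepted
            ingredients.append(i)
            soup, best = candidate, acc
    return ingredients, soup, best

# Toy usage: four perturbed copies of one shared base network
torch.manual_seed(0)
X_toy, y_toy = torch.randn(64, 10), torch.randint(0, 5, (64,))
toy_base = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 5))
toy_models = []
for _ in range(4):
    m = copy.deepcopy(toy_base)
    with torch.no_grad():
        for p in m.parameters():
            p.add_(0.05 * torch.randn_like(p))
    toy_models.append(m)

ingredients, soup, best = ordered_greedy_soup(toy_models, X_toy, y_toy)
print(f"ingredients: {ingredients}, held-out accuracy: {best:.2%}")
```

The ordered variant needs only N evaluations of a soup instead of up to N² / 2. The trade-off is that it never revisits a model it skipped.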

## 3. Task Arithmetic

### 3.1 The Task Vector Concept

```
┌─────────────────────────────────────────────────────────────────────────┐
│                        Task Arithmetic: Core Idea                       │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  Core idea: fine-tuning can be expressed as adding a "task vector"      │
│                                                                         │
│  Task vector definition:                                                │
│  ───────────────────────                                                │
│  τ_task = θ_finetuned - θ_pretrained                                    │
│                                                                         │
│  τ points from the pretrained model toward the task-specific model      │
│                                                                         │
│  Operations:                                                            │
│  ───────────                                                            │
│                                                                         │
│  1. Adding tasks                                                        │
│     θ_multi = θ_pre + τ_A + τ_B                                         │
│     → acquires the abilities of both Task A and Task B                  │
│                                                                         │
│  2. Negating tasks                                                      │
│     θ_new = θ_pre - τ_toxic                                             │
│     → reduces the model's tendency toward toxic output                  │
│                                                                         │
│  3. Scaling                                                             │
│     θ_scaled = θ_pre + λ * τ                                            │
│     → λ controls how strongly the task is applied                       │
│                                                                         │
│  Illustration (parallelogram rule):                                     │
│  ┌─────────────────────────────────────────────────────────────┐        │
│  │                                                             │        │
│  │                      τ_B                                    │        │
│  │ (math) θ_A *──────────────────────* θ_A+B (multi-task)      │        │
│  │           /                      /                          │        │
│  │      τ_A /                      / τ_A                       │        │
│  │         /                      /                            │        │
│  │  θ_pre *──────────────────────* θ_B (code)                  │        │
│  │                  τ_B                                        │        │
│  │                                                             │        │
│  │  θ_A+B = θ_pre + τ_A + τ_B  (the diagonal)                  │        │
│  │                                                             │        │
│  └─────────────────────────────────────────────────────────────┘        │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘
```

In [None]:
def compute_task_vector(
    pretrained_state: OrderedDict,
    finetuned_state: OrderedDict
) -> OrderedDict:
    """
    Compute a task vector: τ = θ_finetuned - θ_pretrained
    """
    task_vector = OrderedDict()
    
    for key in pretrained_state:
        task_vector[key] = finetuned_state[key] - pretrained_state[key]
    
    return task_vector


def apply_task_vector(
    base_state: OrderedDict,
    task_vector: OrderedDict,
    scaling_factor: float = 1.0
) -> OrderedDict:
    """
    Apply a task vector: θ_new = θ_base + λ * τ
    """
    new_state = OrderedDict()
    
    for key in base_state:
        new_state[key] = base_state[key] + scaling_factor * task_vector[key]
    
    return new_state


def add_task_vectors(
    task_vectors: List[OrderedDict],
    weights: Optional[List[float]] = None
) -> OrderedDict:
    """
    Combine several task vectors
    
    τ_combined = Σ w_i * τ_i
    """
    if weights is None:
        weights = [1.0] * len(task_vectors)
    
    combined = OrderedDict()
    
    for key in task_vectors[0]:
        combined[key] = torch.zeros_like(task_vectors[0][key])
        for tv, w in zip(task_vectors, weights):
            combined[key] += w * tv[key]
    
    return combined


# Task Arithmetic demo
print("Task Arithmetic demo")
print("="*60)

# Create the "pretrained" model
torch.manual_seed(0)
pretrained = SimpleModel()
pretrained_state = deepcopy(pretrained.state_dict())

# Task A fine-tuned model (e.g., a math task). A fine-tuned model must start from
# the pretrained weights; here fine-tuning is simulated with small random updates
# (a real run would train on task data).
torch.manual_seed(1)
finetuned_a = deepcopy(pretrained)
with torch.no_grad():
    for param in finetuned_a.parameters():
        param.add_(torch.randn_like(param) * 0.1)

# Task B fine-tuned model (e.g., a coding task)
torch.manual_seed(2)
finetuned_b = deepcopy(pretrained)
with torch.no_grad():
    for param in finetuned_b.parameters():
        param.add_(torch.randn_like(param) * 0.1)

# Compute the task vectors
task_vector_a = compute_task_vector(pretrained_state, finetuned_a.state_dict())
task_vector_b = compute_task_vector(pretrained_state, finetuned_b.state_dict())

print(f"Task Vector A (fc1.weight) norm: {task_vector_a['fc1.weight'].norm().item():.4f}")
print(f"Task Vector B (fc1.weight) norm: {task_vector_b['fc1.weight'].norm().item():.4f}")

# Combine the task vectors (equal weights of 0.5 = averaging them)
combined_tv = add_task_vectors([task_vector_a, task_vector_b], weights=[0.5, 0.5])
print(f"Combined Task Vector norm: {combined_tv['fc1.weight'].norm().item():.4f}")

# Apply the result to the pretrained model
merged_state = apply_task_vector(pretrained_state, combined_tv, scaling_factor=1.0)
print(f"Merged model fc1.weight mean: {merged_state['fc1.weight'].mean().item():.4f}")
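
The cell above only demonstrates task addition. Negation and scaling from the diagram use the same formula θ = θ_pre + λ·τ, with a negative or fractional λ. Below is a minimal, self-contained sketch on a toy one-tensor "model". The helper names `task_vector` and `apply_tv` are illustrative and do not come from any library:

```python
import torch
from collections import OrderedDict

def task_vector(pre, ft):
    """τ = θ_ft - θ_pre, key by key."""
    return OrderedDict((k, ft[k] - pre[k]) for k in pre)

def apply_tv(pre, tau, lam):
    """θ = θ_pre + λ·τ, key by key."""
    return OrderedDict((k, pre[k] + lam * tau[k]) for k in pre)

# A toy one-tensor "model" and a simulated fine-tune of it
torch.manual_seed(0)
theta_pre = OrderedDict(w=torch.randn(4, 4))
theta_ft = OrderedDict(w=theta_pre["w"] + 0.1 * torch.randn(4, 4))
tau = task_vector(theta_pre, theta_ft)

restored = apply_tv(theta_pre, tau, 1.0)   # λ = 1: recovers the fine-tuned weights
negated = apply_tv(theta_pre, tau, -1.0)   # task negation: step away from the task
half = apply_tv(theta_pre, tau, 0.5)       # λ = 0.5: half-strength update

print(torch.allclose(restored["w"], theta_ft["w"], atol=1e-6))
```

Negation mirrors the fine-tuned weights through θ_pre. This is why the Task Arithmetic paper can subtract a "toxicity" vector to move the model away from a behaviour.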

In [None]:
# Visualize task-vector addition and the effect of the scaling factor
def visualize_task_arithmetic():
    """Visualize the Task Arithmetic concept (toy 2-D illustration)"""
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # 1. Task vectors drawn in a 2-D slice of weight space
    ax = axes[0]
    
    # The origin (0, 0) stands for θ_pretrained
    
    # Task vectors
    tv_a = np.array([2, 1])   # Math task
    tv_b = np.array([0.5, 2]) # Code task
    tv_combined = tv_a + tv_b
    
    # Draw the vectors
    ax.arrow(0, 0, tv_a[0], tv_a[1], head_width=0.15, head_length=0.1, fc='blue', ec='blue', label='τ_math')
    ax.arrow(0, 0, tv_b[0], tv_b[1], head_width=0.15, head_length=0.1, fc='red', ec='red', label='τ_code')
    ax.arrow(0, 0, tv_combined[0], tv_combined[1], head_width=0.15, head_length=0.1, fc='green', ec='green', label='τ_combined')
    
    # Mark the endpoints
    ax.scatter([0], [0], s=100, c='black', zorder=5, label='θ_pretrained')
    ax.scatter([tv_a[0]], [tv_a[1]], s=100, c='blue', zorder=5)
    ax.scatter([tv_b[0]], [tv_b[1]], s=100, c='red', zorder=5)
    ax.scatter([tv_combined[0]], [tv_combined[1]], s=100, c='green', zorder=5)
    
    # Dashed lines: the parallelogram paths to the combined point
    ax.plot([tv_a[0], tv_combined[0]], [tv_a[1], tv_combined[1]], 'g--', alpha=0.5)
    ax.plot([tv_b[0], tv_combined[0]], [tv_b[1], tv_combined[1]], 'g--', alpha=0.5)
    
    ax.set_xlim([-0.5, 4])
    ax.set_ylim([-0.5, 4])
    ax.set_xlabel('Weight Space Dimension 1')
    ax.set_ylabel('Weight Space Dimension 2')
    ax.set_title('Task Arithmetic: Adding Task Vectors')
    ax.legend(loc='lower right')
    ax.grid(True, alpha=0.3)
    ax.set_aspect('equal')
    
    # 2. Effect of the scaling factor (simulated curves for illustration, not measured results)
    ax2 = axes[1]
    
    scaling_factors = np.linspace(0, 2, 50)
    task_performance = 1 - np.exp(-scaling_factors) + 0.01 * np.random.randn(50)  # simulated
    general_performance = 1 - 0.3 * scaling_factors + 0.01 * np.random.randn(50)  # simulated
    
    ax2.plot(scaling_factors, task_performance, 'b-', linewidth=2, label='Task Performance')
    ax2.plot(scaling_factors, general_performance, 'r-', linewidth=2, label='General Performance')
    
    # Mark the λ that maximizes the sum of the two curves
    combined = task_performance + general_performance
    best_idx = np.argmax(combined)
    ax2.axvline(x=scaling_factors[best_idx], color='green', linestyle='--', label=f'Optimal λ={scaling_factors[best_idx]:.2f}')
    
    ax2.set_xlabel('Scaling Factor (λ)')
    ax2.set_ylabel('Performance')
    ax2.set_title('Effect of Scaling Factor')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
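    print("\nKey observations:")
    print("1. Task vectors can be added in weight space")
    print("2. The scaling factor λ controls how strongly a task is applied")
    print("3. Task performance must be balanced against general capability")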

visualize_task_arithmetic()
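
The right-hand plot above uses simulated curves. In practice λ is picked by evaluating a few values on held-out data. The sketch below shows such a grid search under a loud assumption: a toy linear model on synthetic data, where "task performance" and "general performance" are both measured as mean-squared error. Here the total error is proportional to (1 − λ)² + λ², so the best λ is 0.5:

```python
import torch

torch.manual_seed(0)
X_lam = torch.randn(200, 8)        # shared synthetic inputs
w_pre = torch.randn(8)             # "pretrained" weights
w_task = w_pre + torch.randn(8)    # "fine-tuned" weights for the new task
tau = w_task - w_pre               # task vector

y_general = X_lam @ w_pre   # behaviour we want to keep
y_task = X_lam @ w_task     # behaviour we want to add

def total_error(lam: float) -> float:
    """Task MSE + general MSE of the merged weights w_pre + lam * tau."""
    w = w_pre + lam * tau
    task_mse = ((X_lam @ w - y_task) ** 2).mean()
    general_mse = ((X_lam @ w - y_general) ** 2).mean()
    return (task_mse + general_mse).item()

lambdas = torch.linspace(0.0, 2.0, 21)
errors = torch.tensor([total_error(lam.item()) for lam in lambdas])
best_lam = lambdas[errors.argmin()].item()
print(f"best λ on the grid: {best_lam:.2f}")
```

With real models the two objectives are rarely this symmetric. The procedure is the same: merge at each candidate λ, evaluate on validation sets for both the target task and general ability, and keep the best trade-off.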

## 4. TIES-Merging

### 4.1 Overview of TIES

TIES stands for **T**r**I**m, **E**lect **S**ign & Merge: three steps that reduce interference between task vectors.

```
┌─────────────────────────────────────────────────────────────────────────┐
│                        TIES-Merging Pipeline                            │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  Problem: naively merging task vectors causes interference              │
│                                                                         │
│  Solution: the three TIES steps                                         │
│                                                                         │
│  Step 1: TRIM - drop small values                                       │
│  ────────────────────────────────                                       │
│  • Zero out parameters with small magnitude                             │
│  • Keep only the top-k% largest-magnitude entries                       │
│  • Reduces the influence of noise                                       │
│                                                                         │
│  Step 2: ELECT SIGN - choose a sign                                     │
│  ──────────────────────────────────                                     │
│  • For each parameter, sum the trimmed values across tasks              │
│  • The sign with the larger total magnitude wins                        │
│  • Only values that agree with the elected sign survive                 │
│                                                                         │
│  Step 3: DISJOINT MERGE - merge                                         │
│  ──────────────────────────────                                         │
│  • For each parameter, average only the sign-consistent values          │
│  • Values with the opposite sign are ignored (they would cancel)        │
│                                                                         │
│  Worked example:                                                        │
│  ┌─────────────────────────────────────────────────────────────┐        │
│  │  Parameter position i:                                      │        │
│  │                                                             │        │
│  │  τ_A[i] = +0.5  ─┐                                          │        │
│  │  τ_B[i] = +0.3  ─┼─→ sum = +0.6 → elected sign (+)          │        │
│  │  τ_C[i] = -0.2  ─┘   merged = avg(+0.5, +0.3) = +0.4        │        │
│  │                      τ_C disagrees → ignored                │        │
│  │                                                             │        │
│  └─────────────────────────────────────────────────────────────┘        │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘
```
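
The three steps can be traced by hand on a tiny example. The sketch below uses made-up task vectors and trims with a per-vector top-k, which is one of several possible trimming granularities. Following the TIES paper, the sign is elected by total magnitude (the sign of the summed trimmed values) rather than by counting votes. Position 2 shows where these differ: one positive and one negative value survive trimming, which would be a tie under a plain vote count, but the negative value has the larger magnitude and wins:

```python
import torch

# Three hypothetical task vectors over 5 parameters (illustrative numbers)
tvs = torch.tensor([
    [ 0.50, -0.05,  0.40, 0.02, -0.30],   # τ_A
    [ 0.30,  0.60, -0.10, 0.01, -0.20],   # τ_B
    [-0.20,  0.50, -0.45, 0.03,  0.10],   # τ_C
])
k = 3  # keep the top-3 magnitudes of each task vector (density 60%)

# Step 1: TRIM - per task vector, zero everything outside its top-k magnitudes
idx = tvs.abs().topk(k, dim=1).indices
trimmed = torch.zeros_like(tvs).scatter(1, idx, tvs.gather(1, idx))

# Step 2: ELECT SIGN - sign of the summed trimmed values at each position
elected = torch.sign(trimmed.sum(dim=0))

# Step 3: DISJOINT MERGE - mean over the entries that agree with the elected sign
agree = (torch.sign(trimmed) == elected) & (trimmed != 0)
ties_toy = (trimmed * agree).sum(dim=0) / agree.sum(dim=0).clamp(min=1)
print(ties_toy)
```

Position 3 ends up at 0 because every task's value there was small enough to be trimmed.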

In [None]:
def ties_merging(
    task_vectors: List[OrderedDict],
    trim_ratio: float = 0.2,
    scaling_factor: float = 1.0
) -> OrderedDict:
    """
    TIES-Merging implementation

    Args:
        task_vectors: list of task vectors (state_dict-like)
        trim_ratio: fraction (0-1) of smallest-magnitude entries to zero out in
            each task vector, per tensor; the kept density is 1 - trim_ratio
        scaling_factor: final scaling coefficient λ
    """
    merged = OrderedDict()

    for key in task_vectors[0]:
        # Gather every task vector's values for this parameter tensor
        values = torch.stack([tv[key] for tv in task_vectors])  # [num_tasks, ...]

        # Step 1: TRIM - zero out small-magnitude entries
        trimmed_values = []
        for tv in values:
            # Magnitude threshold at the trim_ratio quantile
            magnitudes = tv.abs()
            threshold = torch.quantile(magnitudes.flatten(), trim_ratio)

            # Trim
            trimmed = tv.clone()
            trimmed[magnitudes < threshold] = 0
            trimmed_values.append(trimmed)

        trimmed_values = torch.stack(trimmed_values)  # [num_tasks, ...]

        # Step 2: ELECT SIGN - per position, take the sign of the summed trimmed
        # values, i.e. the sign with the larger total magnitude (as in the TIES paper).
        # Positions whose sum is exactly 0 get elected sign 0 and are dropped below.
        signs = torch.sign(trimmed_values)  # [num_tasks, ...]
        elected_sign = torch.sign(trimmed_values.sum(dim=0))

        # Step 3: DISJOINT MERGE - keep only the values that agree with the elected sign
        mask = (signs == elected_sign.unsqueeze(0)) & (trimmed_values != 0)

        # Mean over the agreeing values
        masked_values = trimmed_values * mask.float()
        count = mask.float().sum(dim=0).clamp(min=1)  # avoid division by zero
        merged_param = masked_values.sum(dim=0) / count

        merged[key] = scaling_factor * merged_param
    
    return merged


# Test TIES-Merging
print("TIES-Merging demo")
print("="*60)

# Reuse the task vectors created above
ties_merged = ties_merging(
    [task_vector_a, task_vector_b],
    trim_ratio=0.2,
    scaling_factor=1.0
)

print(f"TIES merged fc1.weight norm: {ties_merged['fc1.weight'].norm().item():.4f}")
print(f"Simple average norm: {combined_tv['fc1.weight'].norm().item():.4f}")

# Compare sparsity (fraction of exactly-zero entries)
ties_sparsity = (ties_merged['fc1.weight'] == 0).float().mean().item()
simple_sparsity = (combined_tv['fc1.weight'] == 0).float().mean().item()
print(f"\nTIES sparsity: {ties_sparsity:.2%}")
print(f"Simple average sparsity: {simple_sparsity:.2%}")
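
The TRIM step in `ties_merging` calls `torch.quantile` on each flattened tensor. That works for this toy MLP, but `torch.quantile` has an input-size limit (about 16 million elements in recent PyTorch releases), and a single LLM weight matrix can exceed it. The sketch below is a hedged alternative that finds the threshold with `torch.kthvalue` instead. `trim_by_magnitude` is a hypothetical helper, not a library API, and unlike the quantile version it drops exactly the k smallest magnitudes (plus any ties at the threshold) rather than interpolating:

```python
import torch

def trim_by_magnitude(t: torch.Tensor, trim_ratio: float) -> torch.Tensor:
    """Zero the k = int(trim_ratio * numel) smallest-magnitude entries of t
    (entries tied with the threshold are dropped too)."""
    flat = t.abs().flatten()
    k = int(trim_ratio * flat.numel())
    if k == 0:
        return t.clone()
    threshold = flat.kthvalue(k).values      # k-th smallest magnitude
    return torch.where(t.abs() > threshold, t, torch.zeros_like(t))

t = torch.tensor([0.05, -0.9, 0.3, -0.01, 0.6, -0.2, 0.02, 0.4, -0.07, 0.1])
trimmed_t = trim_by_magnitude(t, trim_ratio=0.3)
print(trimmed_t)   # the 3 smallest |values| (0.05, -0.01, 0.02) are zeroed
```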

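### 4.2 DARE

In [None]:
def dare_merging(
    task_vectors: List[OrderedDict],
    drop_rate: float = 0.9,
    scaling_factor: float = 1.0
) -> OrderedDict:
    """
    DARE (Drop And REscale) merging

    Another way to reduce interference:
    1. Randomly drop a fraction drop_rate of each task vector's entries
    2. Rescale the surviving entries by 1 / (1 - drop_rate)
    3. Merge (here: average) the resulting task vectors
    """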
    merged = OrderedDict()
    rescale = 1.0 / (1.0 - drop_rate)
    
    for key in task_vectors[0]:
        values = torch.stack([tv[key] for tv in task_vectors])
        
        # Randomly drop entries, then rescale the survivors
        dropped_values = []
        for tv in values:
            mask = torch.bernoulli(torch.ones_like(tv) * (1 - drop_rate))
            dropped = tv * mask * rescale
            dropped_values.append(dropped)
        
        dropped_values = torch.stack(dropped_values)
        merged[key] = scaling_factor * dropped_values.mean(dim=0)
    
    return merged


# Test DARE
print("\nDARE Merging demo")
print("="*60)

dare_merged = dare_merging(
    [task_vector_a, task_vector_b],
    drop_rate=0.9,
    scaling_factor=1.0
)

print(f"DARE merged fc1.weight norm: {dare_merged['fc1.weight'].norm().item():.4f}")
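
Why rescale by 1 / (1 − drop_rate)? Each entry survives with probability 1 − p. Multiplying the survivors by 1 / (1 − p) therefore keeps the expected value of every entry unchanged, even though most entries are now zero. The quick empirical check below uses a synthetic all-ones vector:

```python
import torch

torch.manual_seed(0)
drop_rate = 0.9
tv_ones = torch.ones(100_000)   # a synthetic task vector whose entries are all 1.0

keep = torch.bernoulli(torch.full_like(tv_ones, 1 - drop_rate))  # 1 = keep, 0 = drop
dare_tv = tv_ones * keep / (1 - drop_rate)                         # drop, then rescale

print(f"kept fraction:   {keep.mean().item():.3f}")     # close to 1 - drop_rate = 0.1
print(f"mean after DARE: {dare_tv.mean().item():.3f}")  # close to the original mean, 1.0
```

Only the mean is preserved, not each individual entry. This is why the method's description in the summary table notes randomness and the possible need for several runs.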

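## 5. Using mergekit

### 5.1 Introduction to mergekit

mergekit is a practical toolkit for merging models. It supports many merge methods (including linear averaging, SLERP, task arithmetic, TIES and DARE) and describes each merge with a YAML config file such as the one below.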

```yaml
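# Example mergekit config file
merge_method: ties
slices:
  - sources:
      - model: base_model
        layer_range: [0, 32]
      - model: math_lora      # placeholder names for the fine-tuned checkpoints
        layer_range: [0, 32]
      - model: code_lora
        layer_range: [0, 32]
base_model: base_model
parameters:
  density: 0.5     # fraction of each task vector's entries TIES keeps (1 - trim_ratio here)
  weight: 1.0      # weight applied to each model's task vector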
dtype: float16
```

In [None]:
# A simplified, mergekit-style API
class ModelMerger:
    """
    Model merger (emulates a subset of mergekit's functionality)
    """
    METHODS = ['linear', 'slerp', 'ties', 'dare', 'task_arithmetic']
    
    def __init__(self, base_model: nn.Module):
        self.base_model = base_model
        self.base_state = deepcopy(base_model.state_dict())
    
    def merge(
        self,
        models: List[nn.Module],
        method: str = 'linear',
        weights: Optional[List[float]] = None,
        **kwargs
    ) -> nn.Module:
        """
        Merge several models

        Args:
            models: list of models to merge
            method: merge method
            weights: per-model weights
            **kwargs: method-specific parameters
        """
        if method not in self.METHODS:
            raise ValueError(f"Unknown method: {method}. Available: {self.METHODS}")
        
        print(f"Merging {len(models)} models using {method} method...")
        
        if method == 'linear':
            if weights is None:
                merged_state = uniform_averaging(models)
            else:
                merged_state = weighted_averaging(models, weights)
        
        elif method == 'task_arithmetic':
            task_vectors = [
                compute_task_vector(self.base_state, m.state_dict())
                for m in models
            ]
            combined_tv = add_task_vectors(task_vectors, weights)
            scaling = kwargs.get('scaling_factor', 1.0)
            merged_state = apply_task_vector(self.base_state, combined_tv, scaling)
        
        elif method == 'ties':
            task_vectors = [
                compute_task_vector(self.base_state, m.state_dict())
                for m in models
            ]
            trim_ratio = kwargs.get('trim_ratio', 0.2)
            scaling = kwargs.get('scaling_factor', 1.0)
            merged_tv = ties_merging(task_vectors, trim_ratio, scaling)
            merged_state = apply_task_vector(self.base_state, merged_tv, 1.0)
        
        elif method == 'dare':
            task_vectors = [
                compute_task_vector(self.base_state, m.state_dict())
                for m in models
            ]
            drop_rate = kwargs.get('drop_rate', 0.9)
            scaling = kwargs.get('scaling_factor', 1.0)
            merged_tv = dare_merging(task_vectors, drop_rate, scaling)
            merged_state = apply_task_vector(self.base_state, merged_tv, 1.0)
        
        elif method == 'slerp':
            # Spherical linear interpolation (two models only)
            if len(models) != 2:
                raise ValueError("SLERP requires exactly 2 models")
            t = kwargs.get('t', 0.5)
            merged_state = self._slerp(models[0].state_dict(), models[1].state_dict(), t)
        
        # Build the merged model
        merged_model = deepcopy(self.base_model)
        merged_model.load_state_dict(merged_state)
        
        print("Merge complete!")
        return merged_model
    
    def _slerp(self, state_a: OrderedDict, state_b: OrderedDict, t: float) -> OrderedDict:
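        """
        Spherical Linear Interpolation (SLERP), applied per parameter tensor

        For unit vectors v0, v1 and interpolation parameter t:
        slerp(v0, v1, t) = sin((1-t)Ω)/sin(Ω) * v0 + sin(tΩ)/sin(Ω) * v1
        where Ω = arccos(v0·v1). The angle is computed on the normalized tensors;
        the coefficients are then applied to the original (unnormalized) weights.
        """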
        result = OrderedDict()
        
        for key in state_a:
            a = state_a[key].flatten().float()
            b = state_b[key].flatten().float()
            
            # Normalize
            a_norm = a / (a.norm() + 1e-8)
            b_norm = b / (b.norm() + 1e-8)
            
            # Angle between the two tensors
            dot = torch.clamp((a_norm * b_norm).sum(), -1, 1)
            omega = torch.acos(dot)
            
            if omega.abs() < 1e-6:
                # Nearly parallel: fall back to linear interpolation
                result[key] = ((1 - t) * state_a[key] + t * state_b[key])
            else:
                # SLERP proper
                sin_omega = torch.sin(omega)
                coef_a = torch.sin((1 - t) * omega) / sin_omega
                coef_b = torch.sin(t * omega) / sin_omega
                result[key] = (coef_a * state_a[key] + coef_b * state_b[key])
        
        return result


# Usage example
print("\n" + "="*60)
print("ModelMerger usage example")
print("="*60)

merger = ModelMerger(pretrained)

# Try several methods
for method in ['linear', 'task_arithmetic', 'ties']:
    print(f"\n--- {method.upper()} ---")
    merged = merger.merge(
        [finetuned_a, finetuned_b],
        method=method,
        weights=[0.5, 0.5],
        scaling_factor=1.0,
        trim_ratio=0.2
    )
    print(f"  Merged model fc1.weight mean: {merged.fc1.weight.data.mean().item():.4f}")
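
The comparison table in the summary lists "preserves norm" as a SLERP advantage. That holds exactly when the two endpoints have the same norm. For two orthogonal unit vectors, linear interpolation at t = 0.5 shrinks the norm to about 0.707, while SLERP stays on the unit circle. The standalone sketch below reuses the formula from `ModelMerger._slerp`:

```python
import torch

def slerp(a: torch.Tensor, b: torch.Tensor, t: float) -> torch.Tensor:
    """SLERP between two tensors; the angle comes from their flattened,
    normalized versions (same formula as ModelMerger._slerp)."""
    a_flat, b_flat = a.flatten(), b.flatten()
    dot = torch.clamp((a_flat / a_flat.norm()) @ (b_flat / b_flat.norm()), -1.0, 1.0)
    omega = torch.acos(dot)
    if omega.abs() < 1e-6:          # nearly parallel: fall back to linear interpolation
        return (1 - t) * a + t * b
    so = torch.sin(omega)
    return torch.sin((1 - t) * omega) / so * a + torch.sin(t * omega) / so * b

a = torch.tensor([1.0, 0.0])
b = torch.tensor([0.0, 1.0])        # same norm, 90 degrees apart

lerp_mid = 0.5 * a + 0.5 * b
slerp_mid = slerp(a, b, 0.5)
print(f"lerp norm:  {lerp_mid.norm().item():.4f}")   # shrinks to about 0.7071
print(f"slerp norm: {slerp_mid.norm().item():.4f}")  # stays at about 1.0
```

When the endpoints have different norms, the result is a mix of the two and is not norm-preserving in general.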

## 6. Merging LoRA Models

### 6.1 Merging Multiple LoRA Adapters

In [None]:
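class LoRALinear(nn.Module):
    """Simplified LoRA Linear layer: y = W x + (B A x) * scaling"""
    def __init__(self, in_features, out_features, rank=8, alpha=16):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=False)
        self.linear.weight.requires_grad_(False)  # base weight is frozen; only A and B train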
        self.lora_A = nn.Parameter(torch.randn(rank, in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))
        self.scaling = alpha / rank
    
    def forward(self, x):
        base_out = self.linear(x)
        lora_out = (x @ self.lora_A.T @ self.lora_B.T) * self.scaling
        return base_out + lora_out
    
    def get_merged_weight(self):
        """Fold the LoRA update into the base weight: W + ΔW"""
        delta_w = (self.lora_B @ self.lora_A) * self.scaling
        return self.linear.weight.data + delta_w
    
    def get_lora_weight(self):
        """Return only the LoRA update ΔW = (B @ A) * scaling"""
        return (self.lora_B @ self.lora_A) * self.scaling


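def merge_lora_adapters(
    base_linear: nn.Linear,
    lora_adapters: List[Tuple[torch.Tensor, torch.Tensor, float]],
    weights: Optional[List[float]] = None,
    method: str = 'linear'
) -> torch.Tensor:
    """
    Merge several LoRA adapters into one weight matrix

    Args:
        base_linear: the base Linear layer
        lora_adapters: [(lora_A, lora_B, scaling), ...]
        weights: per-adapter weights (default: uniform, 1/N each)
        method: 'linear' (weighted sum of the ΔW's) or 'cat'
            (concatenate the factors along the rank dimension)

    Returns:
        The merged weight matrix W + ΔW_merged
    """
    if weights is None:
        weights = [1.0 / len(lora_adapters)] * len(lora_adapters)

    if method == 'linear':
        # Weighted sum of each adapter's full-size update ΔW_i = (B_i @ A_i) * scaling_i
        merged_delta = sum(
            w * (lora_B @ lora_A) * scaling
            for w, (lora_A, lora_B, scaling) in zip(weights, lora_adapters)
        )

    elif method == 'cat':
        # Concatenate along the rank dimension: A_cat is [sum(r_i), in] and B_cat is
        # [out, sum(r_i)]. Folding w_i * scaling_i into A_i makes B_cat @ A_cat equal
        # sum_i w_i * scaling_i * B_i @ A_i, so the full-size ΔW matches 'linear';
        # only the rank of the combined adapter grows.
        lora_A_cat = torch.cat(
            [w * scaling * lora_A for w, (lora_A, _, scaling) in zip(weights, lora_adapters)],
            dim=0,
        )
        lora_B_cat = torch.cat([lora_B for _, lora_B, _ in lora_adapters], dim=1)
        merged_delta = lora_B_cat @ lora_A_cat

    else:
        raise ValueError(f"Unknown method: {method}")

    return base_linear.weight.data + merged_delta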


# LoRA merging demo
print("LoRA adapter merging demo")
print("="*60)

# Create a base layer and several LoRA adapters
in_dim, out_dim, rank = 64, 32, 8
base_linear = nn.Linear(in_dim, out_dim, bias=False)

# Simulated LoRA adapters for different tasks. (Real LoRA initializes B to zero;
# both factors are random here so that each adapter has a nonzero update.)
lora_adapters = []
for i in range(3):
    torch.manual_seed(i + 100)
    lora_A = torch.randn(rank, in_dim) * 0.01
    lora_B = torch.randn(out_dim, rank) * 0.01
    scaling = 16 / rank
    lora_adapters.append((lora_A, lora_B, scaling))

# Merge
merged_weight = merge_lora_adapters(
    base_linear,
    lora_adapters,
    weights=[0.4, 0.3, 0.3],
    method='linear'
)

print(f"Base weight shape: {base_linear.weight.shape}")
print(f"Merged weight shape: {merged_weight.shape}")
print(f"Weight change norm: {(merged_weight - base_linear.weight.data).norm().item():.4f}")
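
Why does 'cat' not produce a different full-size weight? Stacking the adapters' factors along the rank dimension yields a single adapter whose product B_cat @ A_cat equals the weighted sum of the individual updates. What changes is only the adapter's rank (2 + 4 = 6 below). That matters if you want to keep the result as a LoRA adapter instead of folding it into the base weight. It also works when the adapters have different ranks, which averaging the factors directly cannot handle. The toy check below uses random factors; the ranks, weights and scaling are illustrative values:

```python
import torch

torch.manual_seed(0)
in_dim, out_dim = 16, 8
ranks, mix_w, s = [2, 4], [0.7, 0.3], 2.0   # illustrative ranks, weights, scaling
As = [torch.randn(r, in_dim) for r in ranks]
Bs = [torch.randn(out_dim, r) for r in ranks]

# 'linear': weighted sum of the full-size updates ΔW_i = s * B_i @ A_i
delta_linear = sum(w * s * B @ A for w, A, B in zip(mix_w, As, Bs))

# 'cat': one adapter of rank sum(ranks), built by concatenating the factors
A_cat = torch.cat([w * s * A for w, A in zip(mix_w, As)], dim=0)  # [6, 16]
B_cat = torch.cat(Bs, dim=1)                                      # [8, 6]
delta_cat = B_cat @ A_cat

print(A_cat.shape, B_cat.shape)
print(torch.allclose(delta_linear, delta_cat, atol=1e-5))
```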

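## 7. Exercises

### Exercise 1: Adaptive Weight Search

In [None]:
# Exercise 1: choose merge weights adaptively based on validation performance
def adaptive_weight_search(
    models: List[nn.Module],
    val_data: Tuple[torch.Tensor, torch.Tensor],
    base_model: nn.Module,
    num_samples: int = 50
) -> Tuple[List[float], nn.Module]:
    """
    TODO: search for the best merge weights

    Approach:
    1. Randomly sample many weight vectors (from a Dirichlet distribution)
    2. For each weight vector, merge the models and evaluate on the validation set
    3. Keep the best weights

    Args:
        models: models to merge
        val_data: (X_val, y_val)
        base_model: template used to build the merged model
        num_samples: number of weight combinations to try

    Returns:
        (best_weights, best_model)
    """
    # Hints:
    # 1. Use np.random.dirichlet to sample weights
    # 2. Merge with weighted_averaging
    # 3. Evaluate the merged model
    pass

print("Exercise 1: implement adaptive_weight_search")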

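### Exercise 2: Layer-wise Merging

In [None]:
# Exercise 2: use a different merge strategy for different layers
def layerwise_merge(
    models: List[nn.Module],
    layer_configs: Dict[str, Dict]
) -> OrderedDict:
    """
    TODO: apply a different merge configuration to each layer

    Idea:
    - Lower layers (e.g., embeddings) may need a more conservative merge
    - Upper layers may tolerate a more aggressive merge

    Args:
        models: models to merge
        layer_configs: {
            "fc1": {"method": "ties", "trim_ratio": 0.3},
            "fc2": {"method": "linear", "weights": [0.5, 0.5]},
            "fc3": {"method": "slerp", "t": 0.5}
        }

    Returns:
        The merged state_dict
    """
    # Hints:
    # 1. Iterate over every parameter
    # 2. Look up the config that matches the layer name
    # 3. Apply the corresponding merge method
    pass

print("Exercise 2: implement layerwise_merge")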

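### Exercise 3: Analyze Merge Quality

In [None]:
# Exercise 3: analyze how well different merge methods work
def analyze_merge_quality(
    original_models: List[nn.Module],
    merged_model: nn.Module,
    test_datasets: Dict[str, Tuple[torch.Tensor, torch.Tensor]]
) -> Dict:
    """
    TODO: analyze the quality of a merged model

    What to analyze:
    1. How much of each task's performance is retained
    2. How strongly the tasks interfere with each other
    3. How far the merged model is from each original model

    Args:
        original_models: the original task-specific models
        merged_model: the merged model
        test_datasets: {"task_a": (X, y), "task_b": (X, y), ...}

    Returns:
        A dictionary of analysis results
    """
    # Hints:
    # 1. Evaluate accuracy on each task
    # 2. Compute statistics of the weight changes
    # 3. Plot comparison charts
    pass

print("Exercise 3: implement analyze_merge_quality")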

## 8. Summary

### 8.1 Comparison of Merge Methods

| Method | Pros | Cons |
|---|---|---|
| Linear Average | Simple and fast; no extra computation | Can suffer severe interference; different tasks may conflict |
| Task Arithmetic | Clear concept; supports adding and subtracting tasks | Scaling factor λ must be tuned; interference remains |
| TIES | Reduces interference; enforces sign consistency | More complex computation; trimming may discard useful information |
| DARE | Simple and effective; removes redundant updates | Random; may need several runs |
| SLERP | Smooth interpolation; preserves norm when both endpoints have equal norm | Only works for two models |

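### 8.2 Practical Recommendations

1. **Similar tasks**: Linear Average or Model Soups
2. **Different tasks**: TIES or Task Arithmetic
3. **Multiple LoRAs**: first evaluate each adapter on its own, then choose a merge strategy
4. **Production**: always evaluate the merged model on a validation set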

In [None]:
print("="*60)
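print("Model Merging - done!")
print("="*60)
print("\nYou have learned:")
print("✓ The motivation for and challenges of model merging")
print("✓ Weight Averaging and Model Soups")
print("✓ Task Arithmetic (adding task vectors)")
print("✓ TIES-Merging to reduce interference")
print("✓ Merging LoRA adapters")
print("\nSuggested next steps:")
print("1. Merge real LLMs with mergekit")
print("2. Explore frankenmerges (stitching together layers from different models)")
print("3. Study layer-wise merge strategies")
print("4. Learn how merging relates to continual learning")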