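# Catastrophic Forgetting & Continual Learning

**Course**: Hung-yi Lee (李宏毅), ML 2025 Spring HW6; GenAI-ML 2025 Fall HW8

This notebook examines catastrophic forgetting, where a neural network trained on a sequence of tasks loses much of its performance on the earlier ones. It also covers several methods that mitigate it.

## Learning Objectives
1. Understand what causes catastrophic forgetting and what it looks like
2. Measure how much a model forgets
3. Implement EWC (Elastic Weight Consolidation)
4. Understand how LoRA relates to forgetting
5. Know the main continual-learning strategies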

## Part 1: What Is Catastrophic Forgetting?

### 1.1 Problem Definition

```
┌─────────────────────────────────────────────────────────────────┐
│                     Catastrophic Forgetting                     │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  Definition: learning a new task rapidly and severely degrades  │
│              performance on previously learned tasks.           │
│                                                                 │
│  Timeline:                                                      │
│  ┌──────────┐     ┌──────────┐     ┌──────────┐                 │
│  │ Task A   │ →   │ Task B   │ →   │ Task C   │                 │
│  │ (train)  │     │ (train)  │     │ (train)  │                 │
│  └──────────┘     └──────────┘     └──────────┘                 │
│       ↓                 ↓                 ↓                     │
│  A: 100%          A: 30%↓          A: 10%↓↓                     │
│                   B: 100%          B: 40%↓                      │
│                                    C: 100%                      │
│                                                                 │
│  Problem: after Task C, Tasks A and B perform far worse.        │
│  (Percentages are illustrative.)                                │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
```

### 1.2 Why Does It Happen?

```
┌─────────────────────────────────────────────────────────────────┐
│                    Root Causes of Forgetting                    │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  1. Shared Parameters                                           │
│     • The network uses the same parameters for every task       │
│     • Gradient updates for a new task overwrite features that   │
│       earlier tasks relied on                                   │
│                                                                 │
│  2. Interference in Parameter Space                             │
│                                                                 │
│        Parameter space                                          │
│           ▲                                                     │
│           │      ★ Task A optimum                               │
│           │    ╱                                                │
│           │   ╱  ← gradient-update path                         │
│           │  ╱                                                  │
│           │ ●────→ ◆ Task B optimum                             │
│           │     moves away from Task A optimum                  │
│           └──────────────────→                                  │
│                                                                 │
│  3. Distribution Shift                                          │
│     • Each task has a different data distribution               │
│     • The model adapts to the new distribution and forgets      │
│       the old one                                               │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
```

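To make the interference picture concrete, here is a toy sketch that is not part of the original notebook. Two hypothetical "tasks" are quadratic losses over the same two shared parameters, with different optima (`theta_a`, `theta_b`, and the learning rate are arbitrary illustrative choices). Starting from Task A's optimum and running plain gradient descent on Task B alone drives Task A's loss up, because nothing in Task B's gradient accounts for Task A.

```python
import numpy as np

# Two toy tasks share the same parameter vector theta (2 parameters).
# Each task's loss is a quadratic bowl centered at its own optimum.
theta_a = np.array([1.0, 2.0])   # hypothetical Task A optimum
theta_b = np.array([3.0, -1.0])  # hypothetical Task B optimum

def loss(theta, optimum):
    return 0.5 * np.sum((theta - optimum) ** 2)

def grad(theta, optimum):
    return theta - optimum

theta = theta_a.copy()           # "finished training on Task A"
lr = 0.1
for step in range(100):          # now train on Task B only
    theta -= lr * grad(theta, theta_b)

loss_a_before = loss(theta_a, theta_a)   # 0.0 at Task A's optimum
loss_a_after = loss(theta, theta_a)      # Task A loss after training on B
loss_b_after = loss(theta, theta_b)      # Task B is now solved
print(f"Task A loss: {loss_a_before:.3f} -> {loss_a_after:.3f}")
print(f"Task B loss after training on B: {loss_b_after:.2e}")
```

Real networks have far more parameters, so solutions that are good for both tasks usually exist; the toy only shows that unconstrained gradient steps on one task carry no information about another.
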
In [None]:
# Environment setup
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, Subset
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple
from copy import deepcopy
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

In [None]:
# Build a simple classification benchmark to demonstrate catastrophic forgetting

class SimpleMLP(nn.Module):
    """A small MLP used for the demonstrations"""
    def __init__(self, input_dim=784, hidden_dim=256, output_dim=10):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)
        
    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

def create_permuted_mnist_tasks(num_tasks=3):
    """
    Create Permuted MNIST tasks.
    Every task uses the same images and labels but a different fixed pixel permutation.
    """
    from torchvision import datasets, transforms
    
    # Load MNIST
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    
    try:
        train_data = datasets.MNIST('./data', train=True, download=True, transform=transform)
        test_data = datasets.MNIST('./data', train=False, transform=transform)
    except Exception:
        print("Could not download MNIST; using synthetic data instead")
        # Synthetic fallback: random images with random labels. Nothing here is
        # learnable, so accuracies stay near chance; it only exercises the code path.
        X_train = torch.randn(10000, 1, 28, 28)
        y_train = torch.randint(0, 10, (10000,))
        X_test = torch.randn(2000, 1, 28, 28)
        y_test = torch.randint(0, 10, (2000,))
        train_data = TensorDataset(X_train, y_train)
        test_data = TensorDataset(X_test, y_test)
    
    tasks = []
    
    for task_id in range(num_tasks):
        if task_id == 0:
            # The first task uses the original pixel order
            perm = torch.arange(784)
        else:
            # Every other task uses a random pixel permutation
            perm = torch.randperm(784)
        
        tasks.append({
            'name': f'Task {task_id}',
            'permutation': perm,
            'train_data': train_data,
            'test_data': test_data
        })
    
    return tasks

def apply_permutation(x, perm):
    """Apply a pixel permutation to a batch of 1x28x28 images"""
    batch_size = x.size(0)
    x_flat = x.view(batch_size, -1)
    x_perm = x_flat[:, perm]
    return x_perm.view(batch_size, 1, 28, 28)

print("Task-creation functions defined")

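Why is Permuted MNIST a fair benchmark? A permutation only reorders pixels, so no information is lost: applying the inverse permutation (obtained with `argsort`) recovers the original image exactly. An MLP has no built-in notion of spatial neighborhoods, so every permutation is equally learnable, yet each task needs different input-to-hidden weights. A small NumPy sketch, using a random vector as a stand-in for one flattened 28x28 image:

```python
import numpy as np

rng = np.random.default_rng(0)
image = rng.normal(size=784)          # stand-in for one flattened 28x28 image
perm = rng.permutation(784)           # one task's pixel permutation
inv_perm = np.argsort(perm)           # inverse permutation

permuted = image[perm]                # what the model sees for this task
restored = permuted[inv_perm]         # undo the permutation

print(np.array_equal(restored, image))                      # nothing was lost
print(np.array_equal(np.sort(permuted), np.sort(image)))    # same values, new order
```
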
## Part 2: Observing Catastrophic Forgetting

In [None]:
# Training and evaluation functions

def train_epoch(model, dataloader, optimizer, perm, device):
    """Train for one epoch and return the average loss"""
    model.train()
    total_loss = 0
    
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        x = apply_permutation(x, perm.to(device))
        
        optimizer.zero_grad()
        output = model(x)
        loss = F.cross_entropy(output, y)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(dataloader)

def evaluate(model, dataloader, perm, device):
    """Return classification accuracy on the given data"""
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            x = apply_permutation(x, perm.to(device))
            
            output = model(x)
            pred = output.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    
    return correct / total

def sequential_training(model, tasks, epochs_per_task=3, batch_size=128, lr=0.001):
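    """
    Sequential training: train on each task in turn (this causes forgetting).
    Returns history[task_name] = accuracy on that task after every epoch,
    starting from the first epoch of that task's own training.
    """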
    history = {task['name']: [] for task in tasks}
    
    for task_idx, task in enumerate(tasks):
        print(f"\nTraining {task['name']}...")
        
        # Build DataLoaders (use subsets to keep the demo fast)
        subset_indices = list(range(min(5000, len(task['train_data']))))
        train_subset = Subset(task['train_data'], subset_indices)
        train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
        
        test_subset_indices = list(range(min(1000, len(task['test_data']))))
        test_subsets = [Subset(t['test_data'], test_subset_indices) for t in tasks]
        test_loaders = [DataLoader(ts, batch_size=batch_size) for ts in test_subsets]
        
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        
        for epoch in range(epochs_per_task):
            train_epoch(model, train_loader, optimizer, task['permutation'], device)
            
            # Evaluate every task seen so far
            for t_idx, t in enumerate(tasks[:task_idx+1]):
                acc = evaluate(model, test_loaders[t_idx], t['permutation'], device)
                history[t['name']].append(acc)
        
        # Print the current accuracy on each task
        print("  Accuracy per task:")
        for t_idx, t in enumerate(tasks[:task_idx+1]):
            print(f"    {t['name']}: {history[t['name']][-1]:.2%}")
    
    return history

print("Training functions defined")

In [None]:
# Demonstrate catastrophic forgetting
print("=== Demonstrating catastrophic forgetting ===")

# Create the tasks
tasks = create_permuted_mnist_tasks(num_tasks=3)

# Build the model
model = SimpleMLP().to(device)

# Sequential training (causes forgetting)
history = sequential_training(model, tasks, epochs_per_task=3)

In [None]:
# Visualize forgetting
def visualize_forgetting(history, title="Catastrophic Forgetting"):
    """Plot each task's accuracy over training.

    A task is evaluated only from the start of its own training, so its curve
    is shifted right to line up with the shared evaluation-step axis.
    """
    plt.figure(figsize=(12, 5))

    colors = ['steelblue', 'coral', 'seagreen', 'purple', 'orange']
    total_steps = max(len(accs) for accs in history.values())

    for i, (task_name, accs) in enumerate(history.items()):
        if accs:  # skip tasks without data
            offset = total_steps - len(accs)  # step at which this task's training began
            plt.plot(range(offset, offset + len(accs)), accs, 'o-', label=task_name,
                     color=colors[i % len(colors)], markersize=4)
            if offset > 0:
                # Dashed line marks the switch to a new task
                plt.axvline(offset - 0.5, color='gray', linestyle='--', alpha=0.5)

    plt.xlabel('Evaluation step (one per epoch)')
    plt.ylabel('Accuracy')
    plt.title(title)
    plt.legend(loc='lower left')
    plt.grid(alpha=0.3)
    plt.ylim(0, 1.05)

    plt.tight_layout()
    plt.show()

visualize_forgetting(history, "Sequential Training: Catastrophic Forgetting")

print("\nObservations:")
print("- Accuracy on earlier tasks drops sharply once training on a new task begins")
print("- This is catastrophic forgetting")

## Part 3: Quantifying Forgetting

In [None]:
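def compute_forgetting_metrics(history: Dict[str, List[float]], epochs_per_task: int = 3) -> Dict:
    """
    Compute forgetting metrics from a per-task accuracy history.

    history[name] holds that task's accuracy after every epoch, starting from the
    first epoch of its own training (as returned by sequential_training and
    train_with_ewc). The last task is excluded from forgetting and backward
    transfer because nothing is trained after it.

    1. Average forgetting: mean over earlier tasks of (best accuracy - final accuracy)
    2. Maximum forgetting: the largest of those drops
    3. Backward transfer (BWT): mean over earlier tasks of
       (final accuracy - accuracy right after training that task); negative = forgetting
    """
    metrics = {}
    names = [name for name, accs in history.items() if accs]
    earlier = names[:-1]

    forgetting_per_task = {}
    bwt_terms = []
    for name in names:
        accs = history[name]
        if name in earlier:
            forgetting_per_task[name] = max(accs) - accs[-1]
            # Accuracy at the end of this task's own training
            acc_after_own = accs[min(epochs_per_task, len(accs)) - 1]
            bwt_terms.append(accs[-1] - acc_after_own)
        else:
            forgetting_per_task[name] = 0.0

    if earlier:
        scores = [forgetting_per_task[n] for n in earlier]
        metrics['average_forgetting'] = float(np.mean(scores))
        metrics['max_forgetting'] = max(scores)
        metrics['backward_transfer'] = float(np.mean(bwt_terms))
    metrics['forgetting_per_task'] = forgetting_per_task

    # Average final accuracy over all tasks
    final_accs = [history[n][-1] for n in names]
    metrics['average_accuracy'] = float(np.mean(final_accs)) if final_accs else 0.0

    return metrics

# Compute and print the metrics
metrics = compute_forgetting_metrics(history)

print("=== Forgetting metrics ===")
print(f"Average forgetting: {metrics.get('average_forgetting', 0):.2%}")
print(f"Maximum forgetting: {metrics.get('max_forgetting', 0):.2%}")
print(f"Backward transfer: {metrics.get('backward_transfer', 0):+.2%}")
print(f"Final average accuracy: {metrics.get('average_accuracy', 0):.2%}")
print("\nForgetting per task:")
for task, forgetting in metrics.get('forgetting_per_task', {}).items():
    print(f"  {task}: {forgetting:.2%}")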

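`compute_forgetting_metrics` works on the per-epoch history. A common, more compact formulation in the continual-learning literature uses an accuracy matrix $R$, where $R_{i,j}$ is the accuracy on task $j$ after finishing training on task $i$ (with $T$ tasks indexed from 0):

$$\text{ACC} = \frac{1}{T}\sum_{j=0}^{T-1} R_{T-1,j}, \qquad \text{BWT} = \frac{1}{T-1}\sum_{j=0}^{T-2} \left(R_{T-1,j} - R_{j,j}\right), \qquad f_j = \max_{j \le l \le T-2} R_{l,j} - R_{T-1,j}$$

A negative BWT means that training on later tasks hurt earlier ones. The sketch below evaluates these formulas on a hypothetical matrix built from the illustrative percentages in the Part 1 diagram, not from a real run; `continual_metrics` is a name introduced here for illustration.

```python
import numpy as np

def continual_metrics(R):
    """R[i, j] = accuracy on task j after finishing task i; only j <= i is used."""
    T = R.shape[0]
    final = R[T - 1]                                   # accuracies after the last task
    acc = float(final.mean())                          # average final accuracy
    bwt = float(np.mean([final[j] - R[j, j] for j in range(T - 1)]))
    # Forgetting: best accuracy task j reached before the last task, minus its final accuracy
    forgetting = [float(R[j:T - 1, j].max() - final[j]) for j in range(T - 1)]
    return acc, bwt, forgetting

nan = np.nan
# Hypothetical matrix using the illustrative numbers from the Part 1 diagram
R = np.array([
    [1.0, nan, nan],   # after Task A
    [0.3, 1.0, nan],   # after Task B
    [0.1, 0.4, 1.0],   # after Task C
])

acc, bwt, forgetting = continual_metrics(R)
print(f"ACC={acc:.2f}  BWT={bwt:.2f}  forgetting per task={[round(f, 2) for f in forgetting]}")
```
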
## Part 4: EWC (Elastic Weight Consolidation)

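### 4.1 How EWC Works

The core idea of EWC is to protect the parameters that matter for earlier tasks: while training on a new task, it penalizes moving each parameter away from its old value, in proportion to that parameter's importance.

$$L_{EWC} = L_{new}(\theta) + \frac{\lambda}{2} \sum_i F_i (\theta_i - \theta^*_i)^2$$

where:
- $L_{new}(\theta)$: loss on the new task
- $F_i$: the $i$-th diagonal entry of the Fisher information matrix, which measures how important parameter $i$ is to the old task
- $\theta^*_i$: value of parameter $i$ after training on the old task
- $\lambda$: regularization strength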

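A worked example shows what the Fisher weights do. Take a hypothetical new-task loss $L_{new}(\theta) = \sum_i (\theta_i - b_i)^2$ whose optimum is $b$. Adding the EWC term and setting the gradient to zero gives, for each parameter,

$$2(\theta_i - b_i) + \lambda F_i (\theta_i - \theta^*_i) = 0 \quad\Rightarrow\quad \theta_i = \frac{2 b_i + \lambda F_i \theta^*_i}{2 + \lambda F_i}$$

A parameter with large $F_i$ stays near its old value $\theta^*_i$, while one with $F_i = 0$ moves all the way to the new optimum $b_i$. The NumPy sketch below checks this closed form against gradient descent; all numbers are made up for illustration.

```python
import numpy as np

theta_old = np.array([1.0, 1.0, 1.0])      # theta*: parameters after the old task
b = np.array([3.0, 3.0, 3.0])              # optimum of the new task (illustrative)
fisher = np.array([0.0, 0.1, 10.0])        # F_i: unimportant, mildly, very important
lam = 10.0                                 # lambda

# Closed-form minimizer of sum_i (theta_i - b_i)^2 + (lam/2) * sum_i F_i (theta_i - theta*_i)^2
closed_form = (2 * b + lam * fisher * theta_old) / (2 + lam * fisher)

# Minimize the same objective with plain gradient descent as a cross-check
theta = theta_old.copy()
lr = 0.005
for _ in range(5000):
    grad = 2 * (theta - b) + lam * fisher * (theta - theta_old)
    theta -= lr * grad

print(np.round(closed_form, 3))            # important parameters stay close to 1.0
print(np.allclose(theta, closed_form, atol=1e-6))
```
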
```
┌─────────────────────────────────────────────────────────────────┐
│                        EWC Illustration                         │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│        Parameter space                                          │
│           ▲                                                     │
│           │      ★ Task A optimum                               │
│           │    ╱ ╲                                              │
│           │   ╱   ╲  update path constrained by EWC             │
│           │  ╱     ╲                                            │
│           │ ●───────◆ Task B solution (near A)                  │
│           │                                                     │
│           │ EWC penalizes moving parameters that are            │
│           │ important for Task A away from its optimum          │
│           └──────────────────────────────→                      │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
```

In [None]:
class EWC:
    """
    Elastic Weight Consolidation (EWC)

    Reference: "Overcoming catastrophic forgetting in neural networks" (Kirkpatrick et al., 2017)
    """

    def __init__(self, model: nn.Module, dataloader, perm, device,
                 num_samples: int = 200):
        """
        Args:
            model: model that has just finished training on the task
            dataloader: data used to estimate the Fisher information
            perm: pixel permutation of that task
            device: compute device
            num_samples: number of samples used to estimate the Fisher information
        """
        self.model = model
        self.device = device

        # Snapshot of the parameters after the task (theta*)
        self.params = {n: p.clone().detach() for n, p in model.named_parameters() if p.requires_grad}

        # Diagonal Fisher information, one tensor per parameter
        self.fisher = self._compute_fisher(dataloader, perm, num_samples)

    def _compute_fisher(self, dataloader, perm, num_samples: int) -> Dict[str, torch.Tensor]:
        """
        Estimate the diagonal of the Fisher information matrix.

        F_i = E[(d log p(y|x) / d theta_i)^2] measures how strongly the model's
        predictions depend on parameter i. Here y is the model's own predicted
        class, a common approximation to sampling y ~ p(y|x).
        """
        fisher = {n: torch.zeros_like(p) for n, p in self.model.named_parameters() if p.requires_grad}

        self.model.eval()
        count = 0

        for x, _ in dataloader:
            if count >= num_samples:
                break

            x = x.to(self.device)
            x = apply_permutation(x, perm.to(self.device))

            output = self.model(x)
            log_probs = F.log_softmax(output, dim=1)

            # The Fisher needs squared per-sample gradients (not the square of
            # a batch-averaged gradient), so backpropagate one sample at a time
            for i in range(x.size(0)):
                if count >= num_samples:
                    break

                # Negative log-probability of the predicted class
                pred = output[i].argmax()
                loss = -log_probs[i, pred]

                self.model.zero_grad()
                loss.backward(retain_graph=True)

                # Accumulate squared gradients
                for n, p in self.model.named_parameters():
                    if n in fisher and p.grad is not None:
                        fisher[n] += p.grad.detach().pow(2)

                count += 1

        # Average over the samples actually used (may be fewer than num_samples)
        for n in fisher:
            fisher[n] /= max(count, 1)

        self.model.zero_grad()  # do not leave Fisher gradients behind
        return fisher

    def penalty(self, model: nn.Module) -> torch.Tensor:
        """
        EWC penalty: sum_i F_i * (theta_i - theta*_i)^2
        (the lambda / 2 factor is applied by the caller)
        """
        loss = 0
        for n, p in model.named_parameters():
            if n in self.fisher:
                loss += (self.fisher[n] * (p - self.params[n]).pow(2)).sum()
        return loss

print("EWC class defined")

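`_compute_fisher` uses the model's own argmax prediction as the label, a cheap and common approximation. The Fisher information proper takes an expectation over labels drawn from the model, $y \sim p(y \mid x)$. For a linear softmax classifier with logits $z = Wx$ that expectation has a closed form: since $\partial \log p(y \mid x) / \partial W_{kj} = (\mathbb{1}[y=k] - p_k)\, x_j$, summing over $y$ with weights $p_y$ gives a diagonal Fisher entry of $\mathbb{E}_x[p_k (1 - p_k)\, x_j^2]$. The NumPy sketch below (random weights and inputs, purely illustrative; `grad_log_p` is a hypothetical helper) computes the argmax version and the exact version for such a classifier, and checks the closed form against the explicit sum over labels.

```python
import numpy as np

rng = np.random.default_rng(0)
num_classes, dim, n = 4, 5, 64
W = rng.normal(size=(num_classes, dim))        # hypothetical linear classifier
X = rng.normal(size=(n, dim))                  # hypothetical inputs

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)       # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

P = softmax(X @ W.T)                           # (n, num_classes) predicted probabilities
eye = np.eye(num_classes)

def grad_log_p(p, x, y):
    """Gradient of log p(y|x) w.r.t. W: outer(onehot(y) - p, x)."""
    return np.outer(eye[y] - p, x)

# (a) Argmax approximation, as in EWC._compute_fisher
fisher_argmax = np.mean(
    [grad_log_p(P[i], X[i], P[i].argmax()) ** 2 for i in range(n)], axis=0)

# (b) Expectation over y ~ p(y|x), summed explicitly over all labels
fisher_exact = np.mean(
    [sum(P[i, y] * grad_log_p(P[i], X[i], y) ** 2 for y in range(num_classes))
     for i in range(n)], axis=0)

# (c) Closed form of (b): E_x[p_k (1 - p_k) x_j^2]
fisher_closed = (P * (1 - P)).T @ (X ** 2) / n

print(np.allclose(fisher_exact, fisher_closed))
print(fisher_argmax.shape, fisher_exact.shape)
```

Both estimates are non-negative and have one entry per weight; they generally differ, and the argmax version ignores how uncertain the model is about each input.
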
In [None]:
def train_with_ewc(model, tasks, epochs_per_task=3, batch_size=128,
                   lr=0.001, ewc_lambda=5000):
    """
    Continual learning with EWC: the loss on each new task is regularized by
    one EWC penalty per previously learned task.
    """
    history = {task['name']: [] for task in tasks}
    ewc_tasks = []  # one EWC object per finished task

    for task_idx, task in enumerate(tasks):
        print(f"\nTraining {task['name']} (with EWC)...")

        # Build DataLoaders
        subset_indices = list(range(min(5000, len(task['train_data']))))
        train_subset = Subset(task['train_data'], subset_indices)
        train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)

        test_subset_indices = list(range(min(1000, len(task['test_data']))))
        test_subsets = [Subset(t['test_data'], test_subset_indices) for t in tasks]
        test_loaders = [DataLoader(ts, batch_size=batch_size) for ts in test_subsets]

        optimizer = torch.optim.Adam(model.parameters(), lr=lr)

        for epoch in range(epochs_per_task):
            model.train()

            for x, y in train_loader:
                x, y = x.to(device), y.to(device)
                x = apply_permutation(x, task['permutation'].to(device))

                optimizer.zero_grad()
                output = model(x)

                # Loss on the current task
                loss = F.cross_entropy(output, y)

                # Add an EWC penalty for every earlier task
                for ewc in ewc_tasks:
                    loss = loss + (ewc_lambda / 2) * ewc.penalty(model)

                loss.backward()
                optimizer.step()

            # Evaluate every task seen so far
            for t_idx, t in enumerate(tasks[:task_idx+1]):
                acc = evaluate(model, test_loaders[t_idx], t['permutation'], device)
                history[t['name']].append(acc)

        # After finishing this task, estimate its Fisher information and store an EWC object
        ewc = EWC(model, train_loader, task['permutation'], device)
        ewc_tasks.append(ewc)

        # Print the current accuracy on each task
        print("  Accuracy per task:")
        for t_idx, t in enumerate(tasks[:task_idx+1]):
            print(f"    {t['name']}: {history[t['name']][-1]:.2%}")

    return history

# Train with EWC
print("=== Reducing forgetting with EWC ===")
model_ewc = SimpleMLP().to(device)
history_ewc = train_with_ewc(model_ewc, tasks, epochs_per_task=3, ewc_lambda=5000)

In [None]:
# Compare training with and without EWC
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
colors = ['steelblue', 'coral', 'seagreen']

def plot_history(ax, hist, title):
    """Plot each task's accuracy, shifted so all tasks share one evaluation-step axis"""
    total_steps = max(len(accs) for accs in hist.values())
    for i, (task_name, accs) in enumerate(hist.items()):
        if accs:
            offset = total_steps - len(accs)  # step at which this task's training began
            ax.plot(range(offset, offset + len(accs)), accs, 'o-',
                    label=task_name, color=colors[i % len(colors)], markersize=4)
    ax.set_xlabel('Evaluation step (one per epoch)')
    ax.set_ylabel('Accuracy')
    ax.set_title(title)
    ax.legend()
    ax.grid(alpha=0.3)
    ax.set_ylim(0, 1.05)

# Without EWC (the earlier sequential run)
plot_history(ax1, history, 'Without EWC (Catastrophic Forgetting)')
# With EWC
plot_history(ax2, history_ewc, 'With EWC (Reduced Forgetting)')

plt.tight_layout()
plt.show()

# Compare metrics
metrics_no_ewc = compute_forgetting_metrics(history)
metrics_ewc = compute_forgetting_metrics(history_ewc)

print("\n=== Metrics: without EWC -> with EWC ===")
print(f"Average forgetting:     {metrics_no_ewc.get('average_forgetting', 0):.2%} -> {metrics_ewc.get('average_forgetting', 0):.2%}")
print(f"Backward transfer:      {metrics_no_ewc.get('backward_transfer', 0):+.2%} -> {metrics_ewc.get('backward_transfer', 0):+.2%}")
print(f"Final average accuracy: {metrics_no_ewc.get('average_accuracy', 0):.2%} -> {metrics_ewc.get('average_accuracy', 0):.2%}")

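## Part 5: LoRA and Forgetting

LoRA fine-tuning can reduce forgetting in some settings: the base weights stay frozen and only a small low-rank update is trained, which limits how far the model's function can move away from the base model. This is not a guarantee, because the low-rank update is still applied to every input, including inputs from the old task.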

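How small is that update here? A quick count for the notebook's `SimpleMLP` (784→256→256→10, with biases) versus the LoRA branches that `SimpleMLP_LoRA` adds at rank 8: each wrapped layer gets $A \in \mathbb{R}^{r \times d_{in}}$ and $B \in \mathbb{R}^{d_{out} \times r}$, without biases.

```python
# Layer shapes of SimpleMLP: (in_features, out_features)
layers = [(784, 256), (256, 256), (256, 10)]
rank = 8

full_params = sum(d_in * d_out + d_out for d_in, d_out in layers)    # weights + biases
lora_params = sum(rank * (d_in + d_out) for d_in, d_out in layers)   # A (r x d_in) + B (d_out x r)

print(full_params, lora_params, f"{lora_params / full_params:.1%}")
```

Only about 5% as many parameters as the base model are trainable, and the base weights themselves never change.
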
In [None]:
# LoRA layer (simplified)
class LoRALinear(nn.Module):
    def __init__(self, original_linear: nn.Linear, rank: int = 4, alpha: float = 8.0):
        super().__init__()
        self.original = original_linear
        
        # Freeze the original weights
        for param in self.original.parameters():
            param.requires_grad = False
        
        # LoRA parameters
        in_features = original_linear.in_features
        out_features = original_linear.out_features
        
        self.lora_A = nn.Linear(in_features, rank, bias=False)
        self.lora_B = nn.Linear(rank, out_features, bias=False)
        self.scaling = alpha / rank
        
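        # B starts at zero, so the LoRA branch is initially a no-op and the
        # wrapped layer reproduces the base layer exactly
        nn.init.kaiming_uniform_(self.lora_A.weight)
        nn.init.zeros_(self.lora_B.weight)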
    
    def forward(self, x):
        return self.original(x) + self.lora_B(self.lora_A(x)) * self.scaling

class SimpleMLP_LoRA(nn.Module):
    """MLP whose linear layers are wrapped with LoRA"""
    def __init__(self, base_model: SimpleMLP, lora_rank: int = 4):
        super().__init__()
        
        # Wrap the base model's layers with LoRA (base weights stay frozen)
        self.fc1 = LoRALinear(base_model.fc1, rank=lora_rank)
        self.fc2 = LoRALinear(base_model.fc2, rank=lora_rank)
        self.fc3 = LoRALinear(base_model.fc3, rank=lora_rank)
    
    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

print("LoRA MLP defined")

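Because `LoRALinear` leaves the base weights untouched, its update $\Delta W = s \cdot BA$ (with $s$ = `alpha / rank`) can be merged into the weight for inference, or dropped to recover the base layer exactly. That is the idea behind the "LoRA per task" strategy among architecture-based methods: keep one adapter per task and select it by task identity, which requires knowing the task at test time. A NumPy sketch with random matrices standing in for trained weights (`adapters` and `forward` are hypothetical names for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, rank, alpha = 6, 4, 2, 8.0
scaling = alpha / rank

W = rng.normal(size=(d_out, d_in))                 # frozen base weight
# One hypothetical adapter per task: (A, B); in practice these are trained
adapters = {
    "task1": (rng.normal(size=(rank, d_in)), rng.normal(size=(d_out, rank))),
    "task2": (rng.normal(size=(rank, d_in)), rng.normal(size=(d_out, rank))),
}

def forward(x, task=None):
    """Base layer, plus the chosen task's LoRA branch if a task is given."""
    y = x @ W.T
    if task is not None:
        A, B = adapters[task]
        y = y + scaling * (x @ A.T) @ B.T          # same as LoRALinear.forward
    return y

x = rng.normal(size=(3, d_in))

# Merging: W + s * B @ A gives the same output as base + LoRA branch
A1, B1 = adapters["task1"]
W_merged = W + scaling * B1 @ A1
print(np.allclose(x @ W_merged.T, forward(x, "task1")))

# Dropping the adapter recovers the base layer exactly
print(np.allclose(forward(x), x @ W.T))
```
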
In [None]:
# Compare forgetting: LoRA vs full fine-tuning
print("=== LoRA vs full fine-tuning: forgetting ===")

# First, train a base model
base_model = SimpleMLP().to(device)
task0 = tasks[0]

# Train on Task 0
subset_indices = list(range(min(5000, len(task0['train_data']))))
train_subset = Subset(task0['train_data'], subset_indices)
train_loader = DataLoader(train_subset, batch_size=128, shuffle=True)

optimizer = torch.optim.Adam(base_model.parameters(), lr=0.001)
for epoch in range(5):
    train_epoch(base_model, train_loader, optimizer, task0['permutation'], device)

# Task 0 accuracy before any fine-tuning
test_subset_indices = list(range(min(1000, len(task0['test_data']))))
test_subset = Subset(task0['test_data'], test_subset_indices)
test_loader = DataLoader(test_subset, batch_size=128)

initial_acc = evaluate(base_model, test_loader, task0['permutation'], device)
print(f"Task 0 accuracy before fine-tuning: {initial_acc:.2%}")

# Method 1: full fine-tuning on Task 1
model_full = deepcopy(base_model)
task1 = tasks[1]
train_loader1 = DataLoader(Subset(task1['train_data'], subset_indices), batch_size=128, shuffle=True)
optimizer = torch.optim.Adam(model_full.parameters(), lr=0.001)

for epoch in range(5):
    train_epoch(model_full, train_loader1, optimizer, task1['permutation'], device)

# Method 2: LoRA fine-tuning on Task 1 (base weights frozen)
model_lora = SimpleMLP_LoRA(deepcopy(base_model), lora_rank=8).to(device)
# Optimize only the LoRA parameters
lora_params = [p for n, p in model_lora.named_parameters() if 'lora' in n]
optimizer = torch.optim.Adam(lora_params, lr=0.001)

for epoch in range(5):
    train_epoch(model_lora, train_loader1, optimizer, task1['permutation'], device)

# Task 0 accuracy after fine-tuning (how much was forgotten)
acc_full_task0 = evaluate(model_full, test_loader, task0['permutation'], device)
acc_lora_task0 = evaluate(model_lora, test_loader, task0['permutation'], device)

# Task 1 accuracy
test_loader1 = DataLoader(Subset(task1['test_data'], test_subset_indices), batch_size=128)
acc_full_task1 = evaluate(model_full, test_loader1, task1['permutation'], device)
acc_lora_task1 = evaluate(model_lora, test_loader1, task1['permutation'], device)

print("\n=== Results ===")
header = f"{'Method':<18} | {'Task 0 acc.':>11} | {'Task 1 acc.':>11} | {'Task 0 forgetting':>17}"
print(header)
print("-" * len(header))
print(f"{'Full fine-tuning':<18} | {acc_full_task0:>11.2%} | {acc_full_task1:>11.2%} | {initial_acc - acc_full_task0:>17.2%}")
print(f"{'LoRA fine-tuning':<18} | {acc_lora_task0:>11.2%} | {acc_lora_task1:>11.2%} | {initial_acc - acc_lora_task0:>17.2%}")

## Part 6: Other Continual-Learning Methods

### 6.1 Overview

```
┌─────────────────────────────────────────────────────────────────┐
│             Taxonomy of Continual-Learning Methods              │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  1. Regularization-based                                        │
│     • EWC: protects parameters important to earlier tasks       │
│     • SI (Synaptic Intelligence): accumulates importance        │
│       online during training                                    │
│     • MAS (Memory Aware Synapses): estimates importance from    │
│       unlabeled data                                            │
│                                                                 │
│  2. Replay-based                                                │
│     • Experience Replay: stores samples from earlier tasks      │
│     • Generative Replay: a generative model synthesizes         │
│       earlier-task data                                         │
│     • Pseudo-rehearsal: replays the model's outputs on          │
│       random inputs                                             │
│                                                                 │
│  3. Architecture-based                                          │
│     • Progressive Networks: add a new network column per task   │
│     • PackNet: prunes, then gives freed weights to new tasks    │
│     • Adapter/LoRA: a small set of parameters per task          │
│                                                                 │
│  Comparison (qualitative):                                      │
│  ┌───────────────┬───────────────┬──────────┬───────────┐       │
│  │ Method        │ Memory        │ New task │ Old tasks │       │
│  ├───────────────┼───────────────┼──────────┼───────────┤       │
│  │ EWC           │ Low           │ Medium   │ Med-high  │       │
│  │ Replay        │ High (data)   │ High     │ High      │       │
│  │ Progressive   │ High (models) │ High     │ Highest   │       │
│  │ LoRA per task │ Medium        │ Med-high │ High      │       │
│  └───────────────┴───────────────┴──────────┴───────────┘       │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
```

In [None]:
# A simple Experience Replay implementation
class ExperienceReplay:
    """Experience replay: keep a buffer of samples from earlier tasks"""

    def __init__(self, buffer_size: int = 1000):
        self.buffer_size = buffer_size
        self.buffer_x = []
        self.buffer_y = []
        self.task_perms = []  # pixel permutation belonging to each stored sample

    def add_task_samples(self, dataloader, perm, num_samples: int = 200):
        """Add up to num_samples samples from a task to the buffer"""
        count = 0
        for x, y in dataloader:
            for i in range(x.size(0)):
                if count >= num_samples:
                    break
                self.buffer_x.append(x[i:i+1])
                self.buffer_y.append(y[i:i+1])
                self.task_perms.append(perm)
                count += 1
            if count >= num_samples:
                break

        # If the buffer is over capacity, evict random samples
        # (this must run even when num_samples was reached above)
        while len(self.buffer_x) > self.buffer_size:
            idx = np.random.randint(0, len(self.buffer_x))
            del self.buffer_x[idx]
            del self.buffer_y[idx]
            del self.task_perms[idx]

    def get_batch(self, batch_size: int, device):
        """Sample a batch of replay examples.

        Returns (x, y, perms); the caller applies perms[i] to sample i,
        since samples from different tasks use different permutations.
        """
        if len(self.buffer_x) == 0:
            return None, None, None

        indices = np.random.choice(len(self.buffer_x),
                                   min(batch_size, len(self.buffer_x)),
                                   replace=False)

        batch_x = torch.cat([self.buffer_x[i] for i in indices]).to(device)
        batch_y = torch.cat([self.buffer_y[i] for i in indices]).to(device)
        batch_perms = [self.task_perms[i] for i in indices]

        return batch_x, batch_y, batch_perms

print("Experience Replay defined")

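`ExperienceReplay` evicts random entries whenever it exceeds capacity, so samples from older tasks survive more eviction rounds and the buffer tends to hold more samples from recent tasks. A common alternative is reservoir sampling, which keeps every sample offered so far with the same probability `capacity / n`. The pure-Python sketch below swaps in that technique; `ReservoirBuffer` is a hypothetical class, not part of the notebook, and the `(task_id, index)` tuples stand in for `(x, y, perm)` entries.

```python
import random

class ReservoirBuffer:
    """Fixed-size replay buffer filled by reservoir sampling (Algorithm R).

    After n items have been offered, each of them is in the buffer with
    probability capacity / n, regardless of which task it came from.
    """

    def __init__(self, capacity, seed=None):
        self.capacity = capacity
        self.items = []
        self.num_seen = 0
        self.rng = random.Random(seed)

    def add(self, item):
        self.num_seen += 1
        if len(self.items) < self.capacity:
            self.items.append(item)
        else:
            # Keep the new item with probability capacity / num_seen,
            # overwriting a uniformly chosen slot
            j = self.rng.randrange(self.num_seen)
            if j < self.capacity:
                self.items[j] = item

    def sample(self, batch_size):
        return self.rng.sample(self.items, min(batch_size, len(self.items)))


# Usage: three tasks of 1000 samples each stream through a 300-slot buffer
buffer = ReservoirBuffer(capacity=300, seed=0)
for task_id in range(3):
    for idx in range(1000):
        buffer.add((task_id, idx))

counts = [sum(1 for t, _ in buffer.items if t == k) for k in range(3)]
print(len(buffer.items), counts)   # each task keeps roughly a third of the slots
```
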
## Part 7: Exercises

### Exercise 1: Implement Synaptic Intelligence (SI)

In [None]:
class SynapticIntelligence:
    """
    Synaptic Intelligence (SI)

    Core idea: track how much each parameter contributes to decreasing the loss
    during training; parameters with a large contribution matter more and are protected.

    Reference: "Continual Learning Through Synaptic Intelligence" (Zenke et al., 2017)
    """

    def __init__(self, model: nn.Module, c: float = 0.1):
        """
        Args:
            model: the model being trained
            c: regularization strength
        """
        self.model = model
        self.c = c

        # Accumulated parameter importance (Omega), summed over finished tasks
        self.omega = {n: torch.zeros_like(p) for n, p in model.named_parameters() if p.requires_grad}

        # Quantities tracked during training
        self.prev_params = {n: p.clone().detach() for n, p in model.named_parameters() if p.requires_grad}
        self.running_sum = {n: torch.zeros_like(p) for n, p in model.named_parameters() if p.requires_grad}

    def update_running_sum(self):
        """Call after every optimizer.step() and before the next zero_grad(),
        so that p.grad still holds the gradient that produced this step."""
        for n, p in self.model.named_parameters():
            if n in self.running_sum and p.grad is not None:
                # Accumulate -gradient * parameter change (contribution to the loss decrease)
                delta = p.detach() - self.prev_params[n]
                self.running_sum[n] += -p.grad.detach() * delta
                self.prev_params[n] = p.clone().detach()

    def update_omega(self, task_params_start: Dict[str, torch.Tensor]):
        """
        Update omega at the end of a task.

        omega += running_sum / ((parameter change over the whole task)^2 + epsilon)
        """
        epsilon = 1e-7

        for n, p in self.model.named_parameters():
            if n in self.omega:
                denom = (p.detach() - task_params_start[n]).pow(2) + epsilon
                self.omega[n] += self.running_sum[n] / denom

        # Reset the running sum for the next task
        self.running_sum = {n: torch.zeros_like(p) for n, p in self.model.named_parameters() if p.requires_grad}

    def penalty(self, task_params: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        SI penalty: c * sum(omega * (theta - theta_task)^2)
        """
        loss = 0
        for n, p in self.model.named_parameters():
            if n in self.omega:
                loss += (self.omega[n] * (p - task_params[n]).pow(2)).sum()
        return self.c * loss

print("Synaptic Intelligence (SI) implemented")
print("\nSI vs EWC:")
print("- EWC: uses the Fisher information, computed once at the end of each task")
print("- SI: accumulates each parameter's contribution to the loss decrease online, during training")

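The quantity SI accumulates in `running_sum` is a discrete path integral: for each parameter it sums $-g_k \,\Delta\theta_k$ over all update steps, where $g$ is the gradient that produced the step. Summed over parameters, it approximates the total drop in the loss along the training trajectory. The NumPy sketch below checks this on a hypothetical quadratic loss $L(\theta) = \tfrac12 \sum_k h_k (\theta_k - \theta^\dagger_k)^2$ trained with plain gradient descent, then forms the per-parameter importance $\omega_k / (\Delta_k^2 + \xi)$ the way `update_omega` does. The curvatures `h`, the target, and the learning rate are illustrative assumptions.

```python
import numpy as np

h = np.array([0.5, 2.0, 5.0])          # curvature per parameter (hypothetical)
target = np.array([1.0, 1.0, 1.0])     # optimum of the toy task
theta = np.zeros(3)                    # parameters at the start of the task
theta_start = theta.copy()

def loss(t):
    return 0.5 * np.sum(h * (t - target) ** 2)

lr, xi = 0.01, 1e-7
running_sum = np.zeros(3)              # path integral per parameter
for _ in range(2000):
    g = h * (theta - target)           # gradient at the current point
    delta = -lr * g                    # plain gradient-descent update
    theta = theta + delta
    running_sum += -g * delta          # same accumulation as update_running_sum

loss_drop = loss(theta_start) - loss(theta)
importance = running_sum / ((theta - theta_start) ** 2 + xi)   # as in update_omega

print(f"sum of path integral: {running_sum.sum():.4f}, actual loss drop: {loss_drop:.4f}")
print(np.round(importance, 3))         # larger for higher-curvature parameters
```

With finite steps the path integral slightly overestimates the loss drop (here by about 2%); the gap shrinks as the learning rate decreases.
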
### Exercise 2: Compare Forgetting Across Methods

In [None]:
def compare_continual_learning_methods():
    """
    Compare how much different continual-learning methods forget
    """
    results = {}

    # Recreate the tasks (new random permutations)
    tasks = create_permuted_mnist_tasks(num_tasks=3)

    # 1. Naive (sequential training)
    print("\n1. Naive sequential training...")
    model_naive = SimpleMLP().to(device)
    history_naive = sequential_training(model_naive, tasks, epochs_per_task=3)
    results['Naive'] = compute_forgetting_metrics(history_naive)

    # 2. EWC
    print("\n2. EWC...")
    model_ewc = SimpleMLP().to(device)
    history_ewc = train_with_ewc(model_ewc, tasks, epochs_per_task=3, ewc_lambda=5000)
    results['EWC'] = compute_forgetting_metrics(history_ewc)

    # Print the comparison
    print("\n" + "=" * 60)
    print("Method comparison")
    print("=" * 60)
    print(f"{'Method':<20} {'Avg. forgetting':<18} {'Final avg. accuracy':<20}")
    print("-" * 60)

    for method, metrics in results.items():
        avg_forget = metrics.get('average_forgetting', 0)
        avg_acc = metrics.get('average_accuracy', 0)
        print(f"{method:<20} {avg_forget:<18.2%} {avg_acc:<20.2%}")

    return results

comparison_results = compare_continual_learning_methods()

## Summary

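```
┌─────────────────────────────────────────────────────────────────┐
│      Summary: Catastrophic Forgetting & Continual Learning      │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  1. Catastrophic forgetting                                     │
│     • Learning a new task badly degrades old tasks              │
│     • Causes: shared parameters, interference,                  │
│       distribution shift                                        │
│                                                                 │
│  2. Metrics                                                     │
│     • Average forgetting, maximum forgetting                    │
│     • Backward transfer (BWT)                                   │
│                                                                 │
│  3. Mitigation methods                                          │
│     • Regularization: EWC, SI (protect important params)        │
│     • Replay: store or generate earlier-task samples            │
│     • Architecture: one LoRA/Adapter per task                   │
│                                                                 │
│  4. Forgetting in LLM fine-tuning                               │
│     • Full fine-tuning is prone to forgetting                   │
│     • LoRA can reduce forgetting                                │
│     • Combining it with EWC/Replay can help further             │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
```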

### Next Steps

- **RLHF**: `reinforcement_learning/rlhf_alignment.ipynb`
- **Model editing**: `llm_advanced/model_editing.ipynb`
- **Model merging**: `llm_advanced/model_merging.ipynb`

## References

### Courses
- [Hung-yi Lee, ML 2025 Spring HW6](https://speech.ee.ntu.edu.tw/~hylee/ml/2025-spring.php)
- [Hung-yi Lee, GenAI-ML 2025 Fall HW8](https://speech.ee.ntu.edu.tw/~hylee/GenAI-ML/2025-fall.php)

### Papers
- [EWC: Overcoming catastrophic forgetting in neural networks](https://arxiv.org/abs/1612.00796)
- [SI: Continual Learning Through Synaptic Intelligence](https://arxiv.org/abs/1703.04200)
- [Progressive Neural Networks](https://arxiv.org/abs/1606.04671)