# Lifelong Learning (Continual Learning)

**Course**: Hung-yi Lee (李宏毅), Machine Learning 2021, HW14 - Life-Long Learning

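This notebook shows how a model can keep learning new tasks without forgetting what it learned before:
- **Problem setting**: continual learning scenarios and their challenges
- **Replay-based**: experience replay, generative replay
- **Architecture-based**: PackNet, Progressive Neural Networks
- **Evaluation metrics**: Average Accuracy, Forgetting, Forward Transfer, Backward Transfer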

```
Continual learning setting:

Timeline
──────────────────────────────────────────────────────────►
   Task 1       Task 2       Task 3       Task 4
 (cats/dogs)  (vehicles)    (flowers)    (digits)
      │            │            │            │
      ▼            ▼            ▼            ▼
  ┌───────┐    ┌───────┐    ┌───────┐    ┌───────┐
  │ Model │ ─► │ Model │ ─► │ Model │ ─► │ Model │
  │  M₁   │    │  M₂   │    │  M₃   │    │  M₄   │
  └───────┘    └───────┘    └───────┘    └───────┘

Challenges:
- Catastrophic forgetting: after learning Task 2, the model forgets Task 1
- Data from earlier tasks can no longer be accessed
- The model has to balance old and new tasks

Related notebooks:
- catastrophic_forgetting.ipynb: regularization methods such as EWC and Synaptic Intelligence
- This notebook focuses on replay-based and architecture-based methods
```

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from typing import Optional, List, Dict, Tuple, Callable
from dataclasses import dataclass, field
from collections import defaultdict
import copy

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## Part 1: Continual Learning Problem Setup

We build a scenario in which a sequence of tasks is learned one after another.

In [None]:
@dataclass
class Task:
    """A single task in the task sequence"""
    task_id: int
    name: str
    X_train: torch.Tensor
    y_train: torch.Tensor
    X_test: torch.Tensor
    y_test: torch.Tensor
    num_classes: int


def create_permuted_mnist_tasks(
    num_tasks: int = 5,
    num_samples: int = 1000,
    input_dim: int = 784
) -> List[Task]:
    """
    Create a Permuted-MNIST-style task sequence (synthetic data).
    Each task applies a different fixed pixel permutation to the same underlying problem.
    """
    np.random.seed(42)
    torch.manual_seed(42)
    
    num_classes = 10
    # Simplifying assumption: labels come from a fixed random linear "teacher" applied to the
    # un-permuted input, so all tasks share one input-label mapping and differ only in pixel order
    teacher = torch.randn(input_dim, num_classes)
    
    tasks = []
    
    for task_id in range(num_tasks):
        # Random permutation (the first task keeps the original order)
        if task_id == 0:
            perm = torch.arange(input_dim)
        else:
            perm = torch.randperm(input_dim)
            
        # Training data: label the un-permuted input, then apply the permutation
        X_train = torch.randn(num_samples, input_dim)
        y_train = (X_train @ teacher).argmax(dim=1)
        X_train = X_train[:, perm]
        
        # Test data, generated the same way
        X_test = torch.randn(num_samples // 5, input_dim)
        y_test = (X_test @ teacher).argmax(dim=1)
        X_test = X_test[:, perm]
        
        task = Task(
            task_id=task_id,
            name=f"Permuted Task {task_id}",
            X_train=X_train,
            y_train=y_train,
            X_test=X_test,
            y_test=y_test,
            num_classes=num_classes
        )
        tasks.append(task)
        
    return tasks


def create_split_tasks(
    num_tasks: int = 5,
    classes_per_task: int = 2,
    num_samples: int = 500
) -> List[Task]:
    """
    Create a Split-style task sequence (synthetic Gaussian clusters).
    Each task contains only its own subset of classes. Labels are task-local
    (0 .. classes_per_task-1), so one classes_per_task-way output head can be shared.
    """
    np.random.seed(42)
    torch.manual_seed(42)
    
    tasks = []
    input_dim = 64
    n_train = num_samples // classes_per_task  # training samples per class
    n_test = n_train // 5                      # test samples per class
    
    for task_id in range(num_tasks):
        # Global class range covered by this task (used only in the task name)
        start_class = task_id * classes_per_task
        
        # One mean per class; train and test data must be drawn around the SAME means
        class_means = [torch.randn(input_dim) * 2 for _ in range(classes_per_task)]
        labels = torch.arange(classes_per_task)
        
        # Training data
        X_train = torch.cat([torch.randn(n_train, input_dim) + m for m in class_means])
        y_train = labels.repeat_interleave(n_train)
        
        # Shuffle
        perm = torch.randperm(len(X_train))
        X_train = X_train[perm]
        y_train = y_train[perm]
        
        # Test data from the same class distributions
        X_test = torch.cat([torch.randn(n_test, input_dim) + m for m in class_means])
        y_test = labels.repeat_interleave(n_test)
        
        task = Task(
            task_id=task_id,
            name=f"Split Task {task_id} (classes {start_class}-{start_class+classes_per_task-1})",
            X_train=X_train,
            y_train=y_train,
            X_test=X_test,
            y_test=y_test,
            num_classes=classes_per_task
        )
        tasks.append(task)
        
    return tasks


# Create the tasks
tasks = create_split_tasks(num_tasks=5, classes_per_task=2)
print("Created tasks:")
for task in tasks:
    print(f"  {task.name}: {len(task.X_train)} training samples, {task.num_classes} classes")

## Part 2: Replay-based Methods

Prevent forgetting by storing or generating samples of earlier tasks.

```
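Replay methods:

1. Experience Replay (ER)
   ├─ Store a subset of old samples in a memory buffer
   ├─ Mix old and new samples during training
   └─ Simple and effective, but needs storage

2. Generative Replay (GR)
   ├─ Train a generator to produce samples of old tasks
   ├─ No real samples need to be stored
   └─ Results depend on the quality of the generated samples

3. Gradient Episodic Memory (GEM)
   ├─ Store samples and use them to constrain the gradient direction
   ├─ Ensures updates do not increase the loss on old tasks
   └─ Computationally more expensive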
```
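The gradient constraint behind GEM and A-GEM can be written out explicitly. Notation assumed here: $g$ is the gradient on the current batch, $g_k$ the gradient on the stored samples of an earlier task $k$, and $g_{\text{ref}}$ the gradient on a batch drawn from the whole memory. GEM requires every inner product to be non-negative and, when a constraint is violated, solves a small quadratic program for the closest admissible gradient; A-GEM (Exercise 1) keeps only the single averaged constraint, which admits the closed-form projection in the second line.

$$
\text{GEM: } \langle g, g_k \rangle \ge 0 \;\; \forall k < t
\qquad
\text{A-GEM: } \langle g, g_{\text{ref}} \rangle \ge 0
$$

$$
\tilde{g} =
\begin{cases}
g, & \text{if } \langle g, g_{\text{ref}} \rangle \ge 0 \\
g - \dfrac{\langle g, g_{\text{ref}} \rangle}{\langle g_{\text{ref}}, g_{\text{ref}} \rangle}\, g_{\text{ref}}, & \text{otherwise}
\end{cases}
$$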

In [None]:
class ReplayBuffer:
    """Experience replay buffer"""
    
    def __init__(
        self,
        buffer_size: int = 1000,
        strategy: str = 'reservoir'  # 'reservoir' or 'ring'
    ):
        if strategy not in ('reservoir', 'ring'):
            raise ValueError(f"Unknown strategy: {strategy}")
        self.buffer_size = buffer_size
        self.strategy = strategy
        
        self.buffer_x: List[torch.Tensor] = []
        self.buffer_y: List[torch.Tensor] = []
        self.buffer_task: List[int] = []
        
        self.num_seen = 0  # total number of samples offered to the buffer so far
        
    def add(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        task_id: int
    ):
        """Add a batch of samples to the buffer"""
        batch_size = x.shape[0]
        
        for i in range(batch_size):
            self.num_seen += 1
            
            if self.strategy == 'reservoir':
                # Reservoir sampling: every sample seen so far is kept with equal probability
                if len(self.buffer_x) < self.buffer_size:
                    self.buffer_x.append(x[i].clone())
                    self.buffer_y.append(y[i].clone())
                    self.buffer_task.append(task_id)
                else:
                    # Replace a random slot with probability buffer_size / num_seen
                    j = np.random.randint(0, self.num_seen)
                    if j < self.buffer_size:
                        self.buffer_x[j] = x[i].clone()
                        self.buffer_y[j] = y[i].clone()
                        self.buffer_task[j] = task_id
                        
            elif self.strategy == 'ring':
                # Ring buffer: FIFO, the oldest sample is overwritten first
                if len(self.buffer_x) < self.buffer_size:
                    self.buffer_x.append(x[i].clone())
                    self.buffer_y.append(y[i].clone())
                    self.buffer_task.append(task_id)
                else:
                    # num_seen was already incremented, so subtract 1 to start overwriting at slot 0
                    idx = (self.num_seen - 1) % self.buffer_size
                    self.buffer_x[idx] = x[i].clone()
                    self.buffer_y[idx] = y[i].clone()
                    self.buffer_task[idx] = task_id
                    
    def sample(
        self,
        batch_size: int
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[List[int]]]:
        """Sample a batch uniformly without replacement (returns Nones if the buffer is empty)"""
        if len(self.buffer_x) == 0:
            return None, None, None
            
        indices = np.random.choice(
            len(self.buffer_x),
            size=min(batch_size, len(self.buffer_x)),
            replace=False
        )
        
        x = torch.stack([self.buffer_x[i] for i in indices])
        y = torch.stack([self.buffer_y[i] for i in indices])
        task_ids = [self.buffer_task[i] for i in indices]
        
        return x, y, task_ids
    
    def __len__(self):
        return len(self.buffer_x)
    
    def get_task_distribution(self) -> Dict[int, int]:
        """Number of buffered samples per task"""
        dist = defaultdict(int)
        for t in self.buffer_task:
            dist[t] += 1
        return dict(dist)


# Test
buffer = ReplayBuffer(buffer_size=100)

# Add samples from several tasks
for task in tasks[:3]:
    buffer.add(task.X_train[:50], task.y_train[:50], task.task_id)
    
print(f"Buffer size: {len(buffer)}")
print(f"Task distribution: {buffer.get_task_distribution()}")

# Sampling
x, y, t = buffer.sample(16)
print(f"Sample shape: {x.shape}")
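
The two buffer policies can be checked in isolation. The following is a minimal stdlib-only sketch (no torch); `reservoir_update` and `ring_index` are hypothetical helpers, not part of the notebook. It replays a stream of 100 integers into a 10-slot reservoir many times and estimates how often early and late items survive. Both rates should come out close to 10/100, which is the "every sample is kept with equal probability" property. It also shows which slot a full ring buffer overwrites.

```python
import random
from typing import List


def reservoir_update(buffer: List[int], item: int, num_seen: int,
                     capacity: int, rng: random.Random) -> None:
    """One reservoir-sampling step; num_seen counts items seen so far *including* this one."""
    if len(buffer) < capacity:
        buffer.append(item)
    else:
        j = rng.randrange(num_seen)  # uniform over 0 .. num_seen - 1
        if j < capacity:             # true with probability capacity / num_seen
            buffer[j] = item


def ring_index(num_seen: int, capacity: int) -> int:
    """Slot overwritten by the num_seen-th item (1-based) once the ring buffer is full."""
    return (num_seen - 1) % capacity


capacity, stream_length, trials = 10, 100, 2000
half = stream_length // 2
rng = random.Random(0)
counts = [0] * stream_length  # how often each stream item is still in the buffer at the end
for _ in range(trials):
    buffer: List[int] = []
    for n, item in enumerate(range(stream_length), start=1):
        reservoir_update(buffer, item, n, capacity, rng)
    for item in buffer:
        counts[item] += 1

p_early = sum(counts[:half]) / (trials * half)
p_late = sum(counts[half:]) / (trials * half)
print(f"P(kept): early items {p_early:.3f}, late items {p_late:.3f}, "
      f"expected {capacity / stream_length:.3f}")

# After the 10 slots are filled, items 11, 12, 13 overwrite slots 0, 1, 2 (oldest first)
print([ring_index(n, capacity) for n in (11, 12, 13)])
```

Reservoir sampling keeps a uniform sample of the whole stream, so early tasks remain represented as more tasks arrive. A ring buffer, by contrast, only ever holds the most recent items.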

In [None]:
class ExperienceReplay:
    """Experience replay trainer"""
    
    def __init__(
        self,
        model: nn.Module,
        buffer_size: int = 500,
        replay_batch_size: int = 16,
        replay_freq: int = 1,         # replay once every replay_freq batches
        samples_per_task: int = 100   # samples of each finished task offered to the buffer
    ):
        self.model = model
        self.buffer = ReplayBuffer(buffer_size)
        self.replay_batch_size = replay_batch_size
        self.replay_freq = replay_freq
        self.samples_per_task = samples_per_task
        
    def train_task(
        self,
        task: Task,
        epochs: int = 10,
        batch_size: int = 32,
        lr: float = 0.01
    ) -> List[float]:
        """Train on a single task with experience replay"""
        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()
        self.model.train()
        
        losses = []
        batch_count = 0
        
        for epoch in range(epochs):
            # Shuffle the current task's data
            perm = torch.randperm(len(task.X_train))
            X = task.X_train[perm]
            y = task.y_train[perm]
            
            for i in range(0, len(X), batch_size):
                x_batch = X[i:i+batch_size]
                y_batch = y[i:i+batch_size]
                
                optimizer.zero_grad()
                
                # Loss on the current task
                output = self.model(x_batch)
                loss = criterion(output, y_batch)
                
                # Replay loss on buffered samples from earlier tasks
                if batch_count % self.replay_freq == 0 and len(self.buffer) > 0:
                    x_replay, y_replay, _ = self.buffer.sample(self.replay_batch_size)
                    if x_replay is not None:
                        output_replay = self.model(x_replay)
                        loss_replay = criterion(output_replay, y_replay)
                        loss = loss + loss_replay
                        
                loss.backward()
                optimizer.step()
                
                losses.append(loss.item())
                batch_count += 1
                
        # After training, offer a random subset of the current task to the buffer
        num_to_add = min(self.samples_per_task, len(task.X_train))
        indices = np.random.choice(len(task.X_train), num_to_add, replace=False)
        self.buffer.add(
            task.X_train[indices],
            task.y_train[indices],
            task.task_id
        )
        
        return losses


# Test
class SimpleMLP(nn.Module):
    def __init__(self, input_dim=64, hidden_dim=128, num_classes=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes)
        )
        
    def forward(self, x):
        return self.net(x)

model_er = SimpleMLP()
er_trainer = ExperienceReplay(model_er, buffer_size=200)

print("Training with experience replay...")
for i, task in enumerate(tasks[:3]):
    print(f"\nTraining {task.name}")
    losses = er_trainer.train_task(task, epochs=5)
    print(f"  Final loss (mean of last 10 batches): {np.mean(losses[-10:]):.4f}")
    print(f"  Buffer distribution: {er_trainer.buffer.get_task_distribution()}")

In [None]:
class GenerativeReplay:
    """Generative replay (here with a VAE as the generator)"""
    
    def __init__(
        self,
        classifier: nn.Module,
        generator: nn.Module,  # any model exposing decode(z), e.g. SimpleVAE below
        input_dim: int = 64,
        latent_dim: int = 32
    ):
        self.classifier = classifier
        self.generator = generator
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        
        # Frozen copies of the previous classifier and generator (used to produce old-task samples)
        self.old_classifier = None
        self.old_generator = None
        
    def generate_old_samples(
        self,
        num_samples: int
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Generate pseudo-samples of earlier tasks, labelled by the old classifier"""
        if self.old_generator is None:
            return None, None
            
        self.old_generator.eval()
        self.old_classifier.eval()
        
        with torch.no_grad():
            # Sample from the latent prior on the generator's device
            gen_device = next(self.old_generator.parameters()).device
            z = torch.randn(num_samples, self.latent_dim, device=gen_device)
            
            # Decode into input space
            x_generated = self.old_generator.decode(z)
            
            # Soft labels from the old classifier
            y_soft = F.softmax(self.old_classifier(x_generated), dim=1)
            
        return x_generated, y_soft
    
    def update_old_models(self):
        """Snapshot the current models as the 'old' ones (call after finishing each task)"""
        self.old_classifier = copy.deepcopy(self.classifier)
        self.old_generator = copy.deepcopy(self.generator)
        
        # Freeze the snapshots
        for param in self.old_classifier.parameters():
            param.requires_grad = False
        for param in self.old_generator.parameters():
            param.requires_grad = False

class SimpleVAE(nn.Module):
    """A simple VAE used as the generator for generative replay"""
    
    def __init__(self, input_dim: int = 64, hidden_dim: int = 128, latent_dim: int = 32):
        super().__init__()
        
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_var = nn.Linear(hidden_dim, latent_dim)  # predicts the log-variance
        
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )
        
    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_var(h)  # (mu, log_var)
    
    def reparameterize(self, mu, log_var):
        # z = mu + sigma * eps with eps ~ N(0, I), which keeps sampling differentiable
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def decode(self, z):
        return self.decoder(z)
    
    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var


print("Generative replay components defined")
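
`SimpleVAE` above has no training objective yet. A VAE is usually trained by minimizing the negative ELBO: a reconstruction term plus the KL divergence between the encoder's Gaussian and the standard normal prior. The stdlib-only sketch below computes this for a single sample. It assumes a Gaussian decoder, so the reconstruction term is a sum of squared errors, and `vae_loss` is a hypothetical helper. In generative replay, the generator would typically be trained on current-task data mixed with samples from `old_generator` so that it keeps covering earlier tasks.

```python
import math
from typing import Sequence


def vae_loss(x: Sequence[float], x_hat: Sequence[float],
             mu: Sequence[float], log_var: Sequence[float]) -> float:
    """Negative ELBO for one sample: squared-error reconstruction + KL to N(0, I)."""
    # Reconstruction term: sum of squared errors
    recon = sum((a - b) ** 2 for a, b in zip(x, x_hat))
    # Closed-form KL( N(mu, diag(exp(log_var))) || N(0, I) )
    kl = -0.5 * sum(1 + lv - m ** 2 - math.exp(lv) for m, lv in zip(mu, log_var))
    return recon + kl


print(vae_loss([1.0, 2.0], [1.0, 2.0], [0.0, 0.0], [0.0, 0.0]))  # 0.0: perfect reconstruction, posterior = prior
print(vae_loss([1.0], [1.0], [1.0], [0.0]))                      # 0.5: KL of N(1, 1) from N(0, 1)
```

In PyTorch, the same two terms on batches would be `F.mse_loss(x_hat, x, reduction='sum')` and `-0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())`, using the `(x_hat, mu, log_var)` returned by `SimpleVAE.forward`.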

## Part 3: Architecture-based Methods

Modify the network architecture so that different tasks use different parameters.

```
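Architecture-based methods:

1. PackNet
   ├─ Prune after training, freeing weights for new tasks
   ├─ Freeze the weights kept for old tasks
   └─ No forgetting, but capacity is limited

2. Progressive Neural Networks
   ├─ Add a new network column for each task
   ├─ Transfer knowledge through lateral connections
   └─ No forgetting, but the model keeps growing

3. Dynamically Expandable Network (DEN)
   ├─ Expand the network selectively
   ├─ Reuse relevant neurons
   └─ Balance capacity and efficiency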
```
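The lateral connections of Progressive Neural Networks can be written as a single layer equation. Notation assumed here: $h_i^{(k)}$ is layer $i$ of column $k$, $h_0^{(k)} = x$, $W_i^{(k)}$ are the column's own weights, $U_i^{(k:j)}$ the lateral weights from column $j$, and $f$ the activation. Layer $i$ of the new column $k$ combines its own previous layer with the previous layer of every earlier column:

$$
h_i^{(k)} = f\!\left( W_i^{(k)} h_{i-1}^{(k)} + \sum_{j<k} U_i^{(k:j)} h_{i-1}^{(j)} \right)
$$

Only $W^{(k)}$ and $U^{(k:j)}$ are trained for task $k$. All earlier columns stay frozen, which is why PNN does not forget. The original paper also describes small non-linear adapters on the lateral paths; the plain linear laterals used in this notebook are a simplification.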

In [None]:
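class PackNet:
    """PackNet: continual learning via iterative magnitude pruning and weight freezing (simplified)"""
    
    def __init__(
        self,
        model: nn.Module,
        prune_ratio: float = 0.5  # fraction of the still-free weights each task keeps
    ):
        self.model = model
        self.prune_ratio = prune_ratio
        
        # Binary weight mask of each task
        self.task_masks: Dict[int, Dict[str, torch.Tensor]] = {}
        
        # Union of the weights already claimed by earlier tasks
        self.used_mask: Dict[str, torch.Tensor] = {}
        
        # Only weight matrices are pruned; biases are frozen after the first task (see train_task)
        for name, param in model.named_parameters():
            if 'weight' in name:
                self.used_mask[name] = torch.zeros_like(param.data)
                
    def _run_epochs(
        self,
        task: Task,
        epochs: int,
        batch_size: int,
        lr: float,
        grad_mask: Dict[str, torch.Tensor]
    ):
        """Train while letting gradients through only where grad_mask == 1"""
        params = [p for p in self.model.parameters() if p.requires_grad]
        optimizer = torch.optim.Adam(params, lr=lr)
        criterion = nn.CrossEntropyLoss()
        self.model.train()
        
        for epoch in range(epochs):
            perm = torch.randperm(len(task.X_train))
            X = task.X_train[perm]
            y = task.y_train[perm]
            
            for i in range(0, len(X), batch_size):
                optimizer.zero_grad()
                loss = criterion(self.model(X[i:i+batch_size]), y[i:i+batch_size])
                loss.backward()
                
                # Mask the gradients. With a fresh Adam (weight_decay=0) the masked weights keep
                # zero moment estimates, so their values never change.
                for name, param in self.model.named_parameters():
                    if name in grad_mask and param.grad is not None:
                        param.grad *= grad_mask[name]
                        
                optimizer.step()
                
    def train_task(
        self,
        task: Task,
        epochs: int = 20,
        batch_size: int = 32,
        lr: float = 0.01,
        retrain_epochs: int = 5
    ):
        """Train on the free weights, prune, then briefly retrain the kept weights"""
        # Weights not yet claimed by any earlier task
        available_mask = {name: 1 - used for name, used in self.used_mask.items()}
        
        # 1) Train only the free weights; the weights of earlier tasks keep their values
        self._run_epochs(task, epochs, batch_size, lr, available_mask)
        
        # 2) Prune: keep the largest-magnitude free weights for this task
        task_mask = self._prune(available_mask)
        self.task_masks[task.task_id] = task_mask
        
        # 3) Zero the released weights, then retrain only the kept ones
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if name in task_mask:
                    param.mul_(1 - (available_mask[name] - task_mask[name]))
        self._run_epochs(task, retrain_epochs, batch_size, lr, task_mask)
        
        # 4) The kept weights now belong to this task
        for name in task_mask:
            self.used_mask[name] = torch.clamp(self.used_mask[name] + task_mask[name], 0, 1)
            
        # Simplification: biases are shared by all tasks, so freeze them after the first task
        for name, param in self.model.named_parameters():
            if name not in self.used_mask:
                param.requires_grad = False
                
    def _prune(
        self,
        available_mask: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """Per layer, keep the prune_ratio fraction of free weights with the largest magnitude"""
        task_mask = {}
        
        for name, param in self.model.named_parameters():
            if name not in available_mask:
                continue
                
            available = available_mask[name]
            num_to_keep = int(int(available.sum().item()) * self.prune_ratio)
            
            if num_to_keep == 0:
                task_mask[name] = torch.zeros_like(param.data)
                continue
                
            # Already-used weights get score -1 so they can never be selected
            scores = torch.where(available.bool(), param.data.abs(),
                                 torch.full_like(param.data, -1.0))
            top_idx = torch.topk(scores.flatten(), num_to_keep).indices
            
            mask = torch.zeros(param.numel(), device=param.device)
            mask[top_idx] = 1.0
            task_mask[name] = mask.view_as(param.data)
            
        return task_mask
    
    @torch.no_grad()
    def predict(self, x: torch.Tensor, task_id: int) -> torch.Tensor:
        """Predict for task_id using only the weights of tasks <= task_id"""
        backup = {}
        for name, param in self.model.named_parameters():
            if name in self.used_mask:
                keep = torch.zeros_like(param.data)
                for t, mask in self.task_masks.items():
                    if t <= task_id:
                        keep = keep + mask
                backup[name] = param.data.clone()
                param.mul_(keep.clamp(max=1))
                
        self.model.eval()
        output = self.model(x)
        
        # Restore the full weights
        for name, param in self.model.named_parameters():
            if name in backup:
                param.copy_(backup[name])
        return output
    
    def get_capacity_usage(self) -> float:
        """Fraction of the prunable weights already claimed by some task"""
        total = 0
        used = 0
        for name, mask in self.used_mask.items():
            total += mask.numel()
            used += mask.sum().item()
        return used / total if total > 0 else 0


# Test
model_packnet = SimpleMLP()
packnet = PackNet(model_packnet, prune_ratio=0.3)

print("Training with PackNet...")
for task in tasks[:3]:
    print(f"\nTraining {task.name}")
    packnet.train_task(task, epochs=10)
    print(f"  Capacity used: {packnet.get_capacity_usage():.2%}")

# Each task is evaluated with its own subset of weights
for task in tasks[:3]:
    preds = packnet.predict(task.X_test, task.task_id).argmax(dim=1)
    print(f"  {task.name}: test accuracy {(preds == task.y_test).float().mean().item():.2%}")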
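
The pruning step and the way capacity runs out can be illustrated without torch. The stdlib sketch below uses hypothetical helpers that are not part of the notebook. `select_for_task` keeps the largest-magnitude fraction of the still-free weights. `capacity_used` gives the closed form $1 - (1-p)^k$ for the fraction of weights claimed after $k$ tasks, ignoring per-layer integer rounding. With `prune_ratio = 0.3` as in the cell above, the printed capacity should therefore be close to 30%, 51% and 65.7% after one, two and three tasks.

```python
from typing import List


def select_for_task(weights: List[float], free: List[bool], keep_ratio: float) -> List[bool]:
    """Keep the largest-|w| fraction of the free weights for the current task."""
    free_idx = [i for i, is_free in enumerate(free) if is_free]
    num_keep = int(len(free_idx) * keep_ratio)
    ranked = sorted(free_idx, key=lambda i: abs(weights[i]), reverse=True)
    chosen = set(ranked[:num_keep])
    return [i in chosen for i in range(len(weights))]


def capacity_used(keep_ratio: float, num_tasks: int) -> float:
    """Fraction of weights claimed after num_tasks tasks (no rounding): 1 - (1 - p)^k."""
    return 1 - (1 - keep_ratio) ** num_tasks


weights = [0.9, -0.1, 0.5, -0.7, 0.05, 0.3]
free = [False, True, True, True, True, True]  # weight 0 already belongs to an earlier task
print(select_for_task(weights, free, keep_ratio=0.4))        # [False, False, True, True, False, False]
print([round(capacity_used(0.3, k), 3) for k in (1, 2, 3)])  # [0.3, 0.51, 0.657]
```

Note that weight 0 has the largest magnitude but is never selected again, because it is already frozen for an earlier task.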

In [None]:
class ProgressiveNeuralNetwork(nn.Module):
    """Progressive Neural Networks (simplified: 2 hidden layers, linear lateral connections)"""
    
    def __init__(
        self,
        input_dim: int = 64,
        hidden_dim: int = 64,
        num_classes: int = 2
    ):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        
        # One network column per task
        self.columns: nn.ModuleList = nn.ModuleList()
        
        # Lateral connections of each column
        self.lateral_connections: nn.ModuleList = nn.ModuleList()
        
        # Task-specific output heads
        self.heads: nn.ModuleList = nn.ModuleList()
        
        self.num_tasks = 0
        
    def add_task(self):
        """Add a column for a new task and freeze all existing columns"""
        task_id = self.num_tasks
        
        # New column (2 hidden layers)
        column = nn.ModuleList([
            nn.Linear(self.input_dim, self.hidden_dim),
            nn.Linear(self.hidden_dim, self.hidden_dim)
        ])
        self.columns.append(column)
        
        # Lateral connections from all earlier columns. Hidden layer 1 only sees x, so the
        # laterals feed hidden layer 2 (from earlier layer-1 outputs) and the output head
        # (from earlier layer-2 outputs).
        if task_id > 0:
            lateral = nn.ModuleList([
                nn.Linear(self.hidden_dim * task_id, self.hidden_dim),  # into hidden layer 2
                nn.Linear(self.hidden_dim * task_id, self.num_classes)  # into the output head
            ])
        else:
            lateral = nn.ModuleList()  # the first column has no laterals
        self.lateral_connections.append(lateral)
        
        # Output head
        head = nn.Linear(self.hidden_dim, self.num_classes)
        self.heads.append(head)
        
        self.num_tasks += 1
        
        # Freeze earlier columns, their laterals and their heads
        for i in range(task_id):
            for param in self.columns[i].parameters():
                param.requires_grad = False
            for param in self.lateral_connections[i].parameters():
                param.requires_grad = False
            for param in self.heads[i].parameters():
                param.requires_grad = False
                
        return task_id
    
    def forward(self, x: torch.Tensor, task_id: int) -> torch.Tensor:
        """Forward pass for task_id through its column and, via laterals, all earlier columns"""
        if task_id >= self.num_tasks:
            raise ValueError(f"Task {task_id} not added yet")
            
        # Hidden layer 1: every column sees only the input x
        h1 = [F.relu(self.columns[c][0](x)) for c in range(task_id + 1)]
        
        # Hidden layer 2: own layer-1 output plus laterals from earlier columns' layer-1 outputs
        h2 = []
        for c in range(task_id + 1):
            pre = self.columns[c][1](h1[c])
            if c > 0:
                pre = pre + self.lateral_connections[c][0](torch.cat(h1[:c], dim=1))
            h2.append(F.relu(pre))  # activation applied after the lateral sum
            
        # Output: task head plus laterals from earlier columns' layer-2 outputs
        output = self.heads[task_id](h2[task_id])
        if task_id > 0:
            output = output + self.lateral_connections[task_id][1](torch.cat(h2[:task_id], dim=1))
            
        return output


# Test
pnn = ProgressiveNeuralNetwork(input_dim=64, hidden_dim=32, num_classes=2)

# Add tasks and run a forward pass for each
for i in range(3):
    task_id = pnn.add_task()
    x = torch.randn(16, 64)
    out = pnn(x, task_id)
    print(f"Task {task_id}: output shape {out.shape}")
    
# Parameter counts: the model grows with every task, but only the newest column is trainable
total_params = sum(p.numel() for p in pnn.parameters())
trainable_params = sum(p.numel() for p in pnn.parameters() if p.requires_grad)
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
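
Lateral connections make PNN grow faster than linearly. Each new column adds a fixed number of parameters, but its lateral layers take the concatenated outputs of all earlier columns as input. The stdlib sketch below counts parameters for an assumed layout of two hidden layers per column, a task head, and linear laterals into hidden layer 2 and into the head. `pnn_param_count` and `linear_params` are hypothetical helpers.

```python
def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


def pnn_param_count(num_tasks: int, input_dim: int, hidden_dim: int, num_classes: int) -> int:
    """Total parameters of the assumed 2-hidden-layer PNN layout after num_tasks tasks."""
    total = 0
    for k in range(num_tasks):
        total += linear_params(input_dim, hidden_dim)    # column, hidden layer 1
        total += linear_params(hidden_dim, hidden_dim)   # column, hidden layer 2
        total += linear_params(hidden_dim, num_classes)  # task head
        if k > 0:  # laterals read the concatenated outputs of the k earlier columns
            total += linear_params(hidden_dim * k, hidden_dim)   # into hidden layer 2
            total += linear_params(hidden_dim * k, num_classes)  # into the head
    return total


for t in (1, 3, 10):
    print(t, pnn_param_count(t, input_dim=64, hidden_dim=32, num_classes=2))
# 1 3202
# 3 12938
# 10 81286
```

The per-column cost stays constant, while the lateral cost grows with $k$. The total is therefore quadratic in the number of tasks, which is the "model keeps growing" drawback listed in Part 3.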

## Part 4: Evaluation Metrics

Standard evaluation metrics for continual learning.
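
Let $R_{i,j}$ be the accuracy on task $j$ after training on task $i$, with $T$ tasks (1-indexed). Let $b_i$ be the accuracy of a randomly initialised model on task $i$. The four metrics computed below are then:

$$
\begin{aligned}
\text{ACC} &= \frac{1}{T}\sum_{j=1}^{T} R_{T,j} \\
F &= \frac{1}{T-1}\sum_{j=1}^{T-1}\Big(\max_{l \in \{j,\dots,T-1\}} R_{l,j} - R_{T,j}\Big) \\
\text{BWT} &= \frac{1}{T-1}\sum_{j=1}^{T-1}\big(R_{T,j} - R_{j,j}\big) \\
\text{FWT} &= \frac{1}{T-1}\sum_{i=2}^{T}\big(R_{i-1,i} - b_i\big)
\end{aligned}
$$

Worked example: the simulated matrix below has the final row (0.50, 0.65, 0.75, 0.85, 0.90), so ACC = 3.65 / 5 = 0.73. Each earlier task peaks on the diagonal, so F = (0.40 + 0.20 + 0.13 + 0.07) / 4 = 0.20 and BWT = -0.20. FWT needs $R_{i-1,i}$, which means evaluating each task *before* training on it.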

In [None]:
@dataclass
class ContinualLearningMetrics:
    """Continual learning evaluation metrics"""
    
    # Accuracy matrix: R[i, j] = accuracy on task j after finishing training on task i
    # (entries that were never evaluated should be np.nan)
    accuracy_matrix: np.ndarray = field(default_factory=lambda: np.array([]))
    # Accuracy of a randomly initialised model, simplified to 1/num_classes (0.5 for binary tasks)
    random_baseline: float = 0.5
    
    def compute_average_accuracy(self) -> float:
        """
        Average Accuracy: mean accuracy over all tasks after training on the last task
        ACC = 1/T * sum_j R[T,j]
        """
        if len(self.accuracy_matrix) == 0:
            return 0.0
        return float(self.accuracy_matrix[-1].mean())
    
    def compute_forgetting(self) -> float:
        """
        Forgetting: average drop from each task's best earlier accuracy to its final accuracy
        F = 1/(T-1) * sum_{j<T} ( max_{j<=l<T} R[l,j] - R[T,j] )   (1-indexed)
        """
        T = len(self.accuracy_matrix)
        if T < 2:
            return 0.0
            
        forgetting = 0.0
        for j in range(T - 1):  # every task except the last one
            # Best accuracy on task j before the final training stage
            max_acc = max(self.accuracy_matrix[l, j] for l in range(j, T - 1))
            # Final accuracy on task j
            curr_acc = self.accuracy_matrix[-1, j]
            forgetting += max_acc - curr_acc
            
        return float(forgetting / (T - 1))
    
    def compute_forward_transfer(self) -> float:
        """
        Forward Transfer: FWT = 1/(T-1) * sum_{i=2..T} ( R[i-1,i] - b_i )   (1-indexed)
        b_i is the accuracy of a randomly initialised model on task i (simplified to random_baseline).
        Requires R[i-1, i], i.e. evaluating each task before training on it; returns nan otherwise.
        """
        T = len(self.accuracy_matrix)
        if T < 2:
            return 0.0
            
        # Accuracy on task i just before training on it
        acc_before = np.array([self.accuracy_matrix[i - 1, i] for i in range(1, T)])
        if np.isnan(acc_before).any():
            return float('nan')
        return float((acc_before - self.random_baseline).mean())
    
    def compute_backward_transfer(self) -> float:
        """
        Backward Transfer: BWT = 1/(T-1) * sum_{j<T} ( R[T,j] - R[j,j] )   (1-indexed)
        """
        T = len(self.accuracy_matrix)
        if T < 2:
            return 0.0
            
        bwt = 0.0
        for j in range(T - 1):
            # Final accuracy minus the accuracy right after training on task j
            bwt += self.accuracy_matrix[-1, j] - self.accuracy_matrix[j, j]
            
        return float(bwt / (T - 1))
    
    def summary(self) -> Dict[str, float]:
        """Return all metrics"""
        return {
            'Average Accuracy': self.compute_average_accuracy(),
            'Forgetting': self.compute_forgetting(),
            'Forward Transfer': self.compute_forward_transfer(),
            'Backward Transfer': self.compute_backward_transfer()
        }


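def evaluate_continual_learner(
    model: nn.Module,
    tasks: List[Task],
    train_fn: Callable,  # training function
    eval_fn: Callable    # evaluation function
) -> ContinualLearningMetrics:
    """
    Train on the tasks in sequence and fill the accuracy matrix
    
    Args:
        model: the model
        tasks: list of tasks
        train_fn: training function train_fn(model, task)
        eval_fn: evaluation function eval_fn(model, task) -> accuracy
        
    After every training stage all tasks are evaluated, including those not trained yet,
    so that R[i-1, i] is available for Forward Transfer.
    """
    num_tasks = len(tasks)
    accuracy_matrix = np.zeros((num_tasks, num_tasks))
    
    for i, task in enumerate(tasks):
        print(f"\nTraining Task {i}: {task.name}")
        train_fn(model, task)
        
        # Evaluate on every task
        for j in range(num_tasks):
            acc = eval_fn(model, tasks[j])
            accuracy_matrix[i, j] = acc
            note = "" if j <= i else " (not trained yet)"
            print(f"  Task {j} accuracy: {acc:.2%}{note}")
            
    metrics = ContinualLearningMetrics(accuracy_matrix=accuracy_matrix)
    return metrics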


# Example accuracy matrix (simulated). Tasks were not evaluated before being trained,
# so those entries are np.nan and Forward Transfer comes out as nan for this matrix.
nan = np.nan
example_matrix = np.array([
    [0.90,  nan,  nan,  nan,  nan],  # after Task 0
    [0.70, 0.85,  nan,  nan,  nan],  # after Task 1
    [0.60, 0.75, 0.88,  nan,  nan],  # after Task 2
    [0.55, 0.70, 0.80, 0.92,  nan],  # after Task 3
    [0.50, 0.65, 0.75, 0.85, 0.90],  # after Task 4
])

metrics = ContinualLearningMetrics(accuracy_matrix=example_matrix)
print("Example metrics:")
for name, value in metrics.summary().items():
    print(f"  {name}: {value:.4f}")

In [None]:
def visualize_accuracy_matrix(accuracy_matrix: np.ndarray):
    """Plot the accuracy matrix as a heatmap"""
    fig, ax = plt.subplots(figsize=(8, 6))
    
    # Hide entries that were never evaluated (stored as np.nan or 0)
    mask = np.isnan(accuracy_matrix) | (accuracy_matrix == 0)
    masked_matrix = np.ma.array(accuracy_matrix, mask=mask)
    
    im = ax.imshow(masked_matrix, cmap='RdYlGn', vmin=0, vmax=1)
    
    # Value labels
    for i in range(accuracy_matrix.shape[0]):
        for j in range(accuracy_matrix.shape[1]):
            if not mask[i, j]:
                ax.text(j, i, f'{accuracy_matrix[i, j]:.2f}',
                        ha='center', va='center', fontsize=10)
                
    ax.set_xlabel('Task j (evaluated)')
    ax.set_ylabel('Task i (after training)')
    ax.set_title('Accuracy matrix R[i, j]')
    
    plt.colorbar(im, ax=ax, label='Accuracy')
    plt.tight_layout()
    plt.show()

visualize_accuracy_matrix(example_matrix)

## Part 5: Method Comparison

In [None]:
def print_method_comparison():
    """Print a comparison of continual learning methods"""
    categories = [
        ("Regularization-based", [
            "EWC: protect important weights via Fisher information",
            "SI: track each weight's contribution to the loss",
            "LwF: preserve old knowledge via distillation",
            "Pros: no stored data, fixed model size",
            "Cons: degrades on long task sequences",
        ]),
        ("Replay-based", [
            "Experience Replay: store and replay old samples",
            "Generative Replay: generate old samples with a generator",
            "GEM: constrain gradients so old tasks do not get worse",
            "Pros: strong results, simple",
            "Cons: needs storage or a generator",
        ]),
        ("Architecture-based", [
            "PackNet: prune + freeze",
            "PNN: add a new network column per task",
            "DEN: expand the network dynamically",
            "Pros: no forgetting",
            "Cons: growing model or limited capacity",
        ]),
        ("Hybrid", [
            "HAL: replay + anchor points learned in hindsight",
            "ER-EWC: replay + regularization",
            "Often the strongest, but also the most complex",
        ]),
    ]
    L, R = 22, 60  # inner widths of the two columns

    def rule(left: str, mid: str, right: str) -> str:
        return left + "═" * L + mid + "═" * R + right

    lines = [
        "╔" + "═" * (L + R + 1) + "╗",
        "║" + "Comparison of Continual Learning Methods".center(L + R + 1) + "║",
        rule("╠", "╦", "╣"),
        "║" + "Category".center(L) + "║" + "Methods and characteristics".center(R) + "║",
    ]
    for category, rows in categories:
        lines.append(rule("╠", "╬", "╣"))
        for k, row in enumerate(rows):
            left = category if k == 0 else ""
            lines.append(f"║ {left:<{L - 1}}║ {row:<{R - 1}}║")
    lines.append(rule("╚", "╩", "╝"))
    print("\n".join(lines))

    def small_table(title: str, rows: List[Tuple[str, str]], w1: int = 36, w2: int = 46):
        """Two-column box table; the first row is the header"""
        print(f"\n{title}")
        print("┌" + "─" * w1 + "┬" + "─" * w2 + "┐")
        for k, (a, b) in enumerate(rows):
            print(f"│ {a:<{w1 - 1}}│ {b:<{w2 - 1}}│")
            if k == 0:
                print("├" + "─" * w1 + "┼" + "─" * w2 + "┤")
        print("└" + "─" * w1 + "┴" + "─" * w2 + "┘")

    small_table("Selection guide:", [
        ("Scenario", "Recommended method"),
        ("Limited memory", "EWC / SI"),
        ("Can store a few samples", "Experience Replay"),
        ("Number of tasks known and fixed", "PackNet"),
        ("Zero forgetting required", "PNN"),
        ("High similarity between tasks", "Generative Replay"),
        ("Best possible results", "Hybrid (ER + EWC)"),
    ])
    small_table("How to read the metrics:", [
        ("Metric", "Interpretation"),
        ("Average Accuracy ↑", "higher is better; overall performance"),
        ("Forgetting ↓", "lower is better; how much is forgotten"),
        ("Forward Transfer ↑", "> 0: earlier tasks help new tasks"),
        ("Backward Transfer ↑", "> 0: new tasks improved old ones (rare)"),
    ])

print_method_comparison()

## Exercises

### Exercise 1: Implement A-GEM (Averaged Gradient Episodic Memory)

In [None]:
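class AGEM:
    """Averaged Gradient Episodic Memory"""
    
    def __init__(self, model: nn.Module, buffer_size: int = 256):
        self.model = model
        self.buffer = ReplayBuffer(buffer_size)
        
    def project_gradient(
        self,
        grad: torch.Tensor,
        ref_grad: torch.Tensor
    ) -> torch.Tensor:
        """
        Project the gradient so that the update does not hurt old tasks
        
        If grad · ref_grad < 0, project grad onto the hyperplane orthogonal to ref_grad:
            grad - (grad · ref_grad) / (ref_grad · ref_grad) * ref_grad
        
        TODO: implement the gradient projection
        """
        pass
    
    def train_step(self, x: torch.Tensor, y: torch.Tensor, task_id: int):
        """
        One training step
        
        TODO:
        1. Compute the gradient on the current batch
        2. Compute the reference gradient on a batch sampled from the buffer
        3. If the constraint is violated, project the gradient
        4. Update the parameters
        """
        pass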
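
As a reference for `project_gradient`, here is one possible stdlib sketch of the A-GEM projection on flattened gradients stored as plain lists; `dot` and `agem_project` are hypothetical helpers. If the current gradient already agrees with the memory gradient, it is returned unchanged. Otherwise the component along `ref_grad` is removed, so the projected gradient is orthogonal to it.

```python
from typing import List


def dot(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def agem_project(grad: List[float], ref_grad: List[float]) -> List[float]:
    """A-GEM projection of a flattened gradient against the memory (reference) gradient."""
    d = dot(grad, ref_grad)
    if d >= 0:  # to first order the update does not increase the loss on the memory
        return list(grad)
    scale = d / dot(ref_grad, ref_grad)
    return [g - scale * r for g, r in zip(grad, ref_grad)]


g, g_ref = [1.0, -1.0], [0.0, 1.0]
g_tilde = agem_project(g, g_ref)
print(g_tilde, dot(g_tilde, g_ref))       # [1.0, 0.0] 0.0
print(agem_project([1.0, 1.0], g_ref))    # [1.0, 1.0]  (constraint satisfied, unchanged)
```

In PyTorch, the same logic would operate on `torch.cat([p.grad.reshape(-1) for p in params])`, with the projected vector copied back into each `p.grad` before `optimizer.step()`.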

### Exercise 2: Compare Methods

In [None]:
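def compare_continual_learning_methods(tasks: List[Task]):
    """
    Compare the performance of different continual learning methods
    
    TODO:
    1. Implement the baseline (plain fine-tuning)
    2. Implement Experience Replay
    3. Implement EWC (see catastrophic_forgetting.ipynb)
    4. Compare the evaluation metrics of the methods
    5. Plot the accuracy matrices side by side
    """
    pass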

### Exercise 3: Implement Memory Aware Synapses (MAS)

In [None]:
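class MAS:
    """Memory Aware Synapses"""
    
    def __init__(self, model: nn.Module, lambda_: float = 1.0):
        """
        MAS measures importance as the sensitivity of the model output to each parameter:
        
        Ω_i = E_x[ |∂ ||f(x; θ)||₂² / ∂θ_i| ]   (gradient of the squared L2 norm of the output)
        
        No labels are needed, so the importance can be estimated on unlabeled data.
        
        TODO: implement MAS
        """
        pass
    
    def estimate_importance(self, data_loader):
        """Estimate parameter importance Ω"""
        pass
    
    def compute_loss(self, task_loss: torch.Tensor) -> torch.Tensor:
        """Total loss = task loss + λ * Σ_i Ω_i (θ_i - θ*_i)²"""
        pass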
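
To make the MAS importance concrete, the stdlib sketch below works it out for a hypothetical toy model $f(x) = Wx$ without bias. For this model $\partial \lVert Wx \rVert_2^2 / \partial W_{ij} = 2 (Wx)_i x_j$, so $\Omega_{ij}$ is the mean absolute value of that term over the samples. The sketch then evaluates the quadratic penalty $\lambda \sum_{ij} \Omega_{ij} (W_{ij} - W^*_{ij})^2$. `mas_importance_linear` and `mas_penalty` are illustrative helpers, not a reference implementation.

```python
from typing import List

Matrix = List[List[float]]


def mas_importance_linear(W: Matrix, xs: List[List[float]]) -> Matrix:
    """Omega_ij = mean over samples of |d ||Wx||^2 / dW_ij| = |2 (Wx)_i x_j|."""
    rows, cols = len(W), len(W[0])
    omega = [[0.0] * cols for _ in range(rows)]
    for x in xs:
        out = [sum(W[i][j] * x[j] for j in range(cols)) for i in range(rows)]
        for i in range(rows):
            for j in range(cols):
                omega[i][j] += abs(2 * out[i] * x[j]) / len(xs)
    return omega


def mas_penalty(omega: Matrix, W: Matrix, W_star: Matrix, lam: float) -> float:
    """lam * sum_ij Omega_ij * (W_ij - W*_ij)^2"""
    return lam * sum(
        omega[i][j] * (W[i][j] - W_star[i][j]) ** 2
        for i in range(len(W)) for j in range(len(W[0]))
    )


W_star = [[1.0, 0.0], [0.0, 1.0]]                    # weights after the previous task
omega = mas_importance_linear(W_star, xs=[[1.0, 2.0]])
W_new = [[w + 0.1 for w in row] for row in W_star]   # every weight moved by 0.1
print(omega)                                          # [[2.0, 4.0], [4.0, 8.0]]
print(round(mas_penalty(omega, W_new, W_star, lam=1.0), 6))  # 0.18
```

Weights that strongly influence the output, here those connected to the larger input and output, receive larger $\Omega$. Moving them is penalized more, even though every weight moved by the same amount.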

## Summary

```
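Key points of continual learning:

1. Problem setting
   ├─ Learn a sequence of tasks
   ├─ Data from old tasks is no longer available
   └─ Balance new and old knowledge

2. Replay-based
   ├─ Experience Replay: store real samples
   ├─ Generative Replay: generate pseudo-samples
   └─ GEM/A-GEM: gradient constraints

3. Architecture-based
   ├─ PackNet: prune + freeze
   ├─ PNN: grow the network
   └─ DEN: dynamic expansion

4. Evaluation metrics
   ├─ Average Accuracy: overall performance
   ├─ Forgetting: how much is forgotten
   ├─ Forward Transfer: transfer from earlier tasks to new ones
   └─ Backward Transfer: effect of new tasks on old ones

Practical advice:
- Try Experience Replay first (simple and effective)
- Use EWC/SI when memory is limited
- Use PNN when zero forgetting is required
- Hybrid methods usually give the best results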
```