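# ResNet Block - Residual Block

## Learning Objectives

1. Understand why deep networks are hard to train (vanishing/exploding gradients, the degradation problem)
2. Understand how residual connections work and why they help
3. Implement the Basic Block (forward and backward passes)
4. Implement the Bottleneck Block (optional)
5. Verify the effect of residual connections on gradient flow

## References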

- He et al., "Deep Residual Learning for Image Recognition", CVPR 2016

In [None]:
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(42)

---

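## Part 1: The Problem with Deep Networks

### 1.1 Degradation Problem

Before ResNet, researchers observed a puzzling phenomenon:

- Deeper networks can end up with **higher training error** (not just higher test error!)
- This is not overfitting, because the network performs worse even on the training set
- In theory, a deeper network should be at least as good as a shallower one (the extra layers could learn an identity mapping)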
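
The identity-mapping argument can be checked on a toy example. The sketch below is an illustration rather than part of the original notebook: it assumes bias-free fully connected layers with ReLU (the helper `forward` and the layer sizes are made up for this demo) and appends identity-weight layers to a shallow network. Activations after a ReLU are non-negative, so each extra identity layer leaves them unchanged and the deeper network reproduces the shallow one.

```python
import numpy as np

rng = np.random.default_rng(0)

def relu(z):
    return np.maximum(0.0, z)

def forward(x, weights):
    """Run x through a stack of bias-free Linear + ReLU layers."""
    h = x
    for W in weights:
        h = relu(h @ W)
    return h

# A shallow 2-layer network with random weights.
shallow = [rng.standard_normal((8, 8)) for _ in range(2)]

# A deeper network: the same two layers followed by 10 identity layers.
# After a ReLU the activations are non-negative, so relu(h @ I) == h.
deep = shallow + [np.eye(8) for _ in range(10)]

x = rng.standard_normal((4, 8))
out_shallow = forward(x, shallow)
out_deep = forward(x, deep)
print(np.allclose(out_shallow, out_deep))
```

So a solution that is at least as good as the shallow network exists for the deeper one; the degradation problem is that the optimizer struggles to find it.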

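### 1.2 Vanishing/Exploding Gradients

Consider a network with $L$ layers, loss $\mathcal{L}$, and per-layer Jacobians $J_l = \partial h_l / \partial h_{l-1}$:

$$\frac{\partial \mathcal{L}}{\partial W_1} = \frac{\partial \mathcal{L}}{\partial h_L} \cdot J_L \cdot J_{L-1} \cdots J_2 \cdot \frac{\partial h_1}{\partial W_1}$$

If every $\|J_l\| < 1$, the product **decays exponentially** (vanishing gradients)

If every $\|J_l\| > 1$, the product can **grow exponentially** (exploding gradients)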
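
The decay bound can also be checked with actual matrices rather than scalars. The following is a minimal sketch under stated assumptions: the helpers `random_jacobian` and `backprop_norms` are hypothetical, and each Jacobian is a random Gaussian matrix rescaled so that its spectral norm equals a chosen value. Since $\|J^\top g\| \le \|J\|\,\|g\|$, the gradient norm after $n$ layers is bounded by $\|J\|^n$.

```python
import numpy as np

rng = np.random.default_rng(0)

def random_jacobian(dim, spectral_norm):
    """Random matrix rescaled so its largest singular value equals spectral_norm."""
    J = rng.standard_normal((dim, dim))
    return J * (spectral_norm / np.linalg.norm(J, 2))

def backprop_norms(n_layers, dim, spectral_norm):
    """Norm of an upstream gradient after each of n_layers Jacobian products."""
    g = rng.standard_normal(dim)
    g /= np.linalg.norm(g)  # start from a unit-norm gradient
    norms = [1.0]
    for _ in range(n_layers):
        g = random_jacobian(dim, spectral_norm).T @ g
        norms.append(np.linalg.norm(g))
    return norms

norms = backprop_norms(n_layers=50, dim=16, spectral_norm=0.9)
print(f"after 50 layers: {norms[-1]:.2e} (upper bound 0.9**50 = {0.9**50:.2e})")
```

The bound is only an upper bound: a random gradient is rarely aligned with the top singular direction, so the measured norm usually shrinks faster than $0.9^{50}$. For the same reason, $\|J_l\| > 1$ allows explosion but does not guarantee it.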

In [None]:
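# Demo: vanishing and exploding gradients

def simulate_gradient_flow(n_layers, jacobian_norm):
    """Simulate gradient flow through a deep network (scalar model)."""
    gradient = 1.0  # start from the gradient at the output layer
    gradients = [gradient]

    for _ in range(n_layers):
        # Each layer multiplies the gradient by the Jacobian norm
        gradient *= jacobian_norm
        gradients.append(gradient)

    return gradients

# Compare different Jacobian norms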
n_layers = 50
norms = [0.9, 0.99, 1.0, 1.01, 1.1]

plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
for norm in [0.9, 0.95, 0.99]:
    grads = simulate_gradient_flow(n_layers, norm)
    plt.plot(grads, label=f'||J|| = {norm}')
plt.xlabel('Layer (from output to input)')
plt.ylabel('Gradient magnitude')
plt.title('Gradient Vanishing (||J|| < 1)')
plt.legend()
plt.yscale('log')

plt.subplot(1, 2, 2)
for norm in [1.01, 1.05, 1.1]:
    grads = simulate_gradient_flow(n_layers, norm)
    plt.plot(grads, label=f'||J|| = {norm}')
plt.xlabel('Layer (from output to input)')
plt.ylabel('Gradient magnitude')
plt.title('Gradient Exploding (||J|| > 1)')
plt.legend()
plt.yscale('log')

plt.tight_layout()
plt.show()

print(f"50 layers with ||J||=0.9: gradient = {0.9**50:.2e}")
print(f"50 layers with ||J||=1.1: gradient = {1.1**50:.2e}")

---

## Part 2: Residual Connection

### 2.1 Core Idea

A plain network learns the target mapping $H(x)$ directly.

ResNet instead learns the residual: $F(x) = H(x) - x$

The output becomes: $H(x) = F(x) + x$

```
        ┌──────────────────────┐
        │                      │
        │        identity      │
        │                      │
x ──────┼───→ [F(x)] ─────────(+)───→ H(x) = F(x) + x
        │                      │
        │   Conv → BN → ReLU   │
        │   Conv → BN         │
        │                      │
        └──────────────────────┘
```

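### 2.2 Why Is This Better?

**From an optimization perspective**:
- If the optimal mapping is close to the identity (i.e. $H(x) \approx x$), the network only needs to learn $F(x) \approx 0$
- Learning $F(x) = 0$ is easier than learning $H(x) = x$ (pushing the weights toward 0 is simple)

**From a gradient perspective**:
$$\frac{\partial H}{\partial x} = \frac{\partial F}{\partial x} + I$$

The $+I$ (identity) term means that:
- Even when $\frac{\partial F}{\partial x}$ is close to zero, the block's Jacobian stays close to $I$, so the gradient passes through nearly unchanged instead of vanishing
- The shortcut acts as a "highway" that lets the gradient flow straight back to earlier layers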
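
The relation $\partial H / \partial x = \partial F / \partial x + I$ can be verified numerically. The sketch below is illustrative only: it assumes a toy residual branch $F$ made of a single tanh layer with small random weights, and a hypothetical `jacobian` helper based on central differences.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 5
W = 0.1 * rng.standard_normal((d, d))  # small weights -> small dF/dx

def F(x):
    """Toy residual branch: a single tanh layer (an assumption for illustration)."""
    return np.tanh(W @ x)

def H(x):
    """Residual block output: H(x) = F(x) + x."""
    return F(x) + x

def jacobian(f, x, eps=1e-6):
    """Central-difference Jacobian, J[i, j] = d f_i / d x_j."""
    J = np.zeros((x.size, x.size))
    for j in range(x.size):
        e = np.zeros_like(x)
        e[j] = eps
        J[:, j] = (f(x + e) - f(x - e)) / (2 * eps)
    return J

x = rng.standard_normal(d)
J_F = jacobian(F, x)
J_H = jacobian(H, x)

print(np.allclose(J_H, J_F + np.eye(d), atol=1e-6))
print("smallest singular value of J_F:", np.linalg.svd(J_F, compute_uv=False).min())
print("smallest singular value of J_H:", np.linalg.svd(J_H, compute_uv=False).min())
```

The singular values show the practical effect: with a weak branch, $J_F$ shrinks some gradient directions almost to zero, while $J_H = J_F + I$ keeps every direction at a magnitude near 1.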

In [None]:
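# Compare gradient flow with and without a residual connection

def simulate_gradient_with_residual(n_layers, f_jacobian_norm):
    """Simulate gradient flow with a residual connection.

    For y = F(x) + x:
    dy/dx = dF/dx + I

    Simplified to a scalar here: dy/dx = f_jacobian_norm + 1,
    where f_jacobian_norm is the signed scalar derivative dF/dx
    (it may be negative, so it is not a true norm).
    """
    # Total Jacobian = dF/dx + 1
    total_jacobian = f_jacobian_norm + 1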
    
    gradient = 1.0
    gradients = [gradient]
    
    for _ in range(n_layers):
        gradient *= total_jacobian
        gradients.append(gradient)
    
    return gradients

n_layers = 50

plt.figure(figsize=(12, 5))

# Left: without residual
plt.subplot(1, 2, 1)
for f_norm in [0.3, 0.5, 0.7]:
    grads_no_res = simulate_gradient_flow(n_layers, f_norm)
    plt.plot(grads_no_res, label=f'Without residual (||∂F/∂x|| = {f_norm})')
plt.xlabel('Layer')
plt.ylabel('Gradient magnitude')
plt.title('Without Residual Connection')
plt.legend()
plt.yscale('log')
plt.ylim([1e-20, 1e5])

# Right: with residual
plt.subplot(1, 2, 2)
for f_norm in [-0.3, 0.0, 0.3]:
    grads_with_res = simulate_gradient_with_residual(n_layers, f_norm)
    plt.plot(grads_with_res, label=f'With residual (∂F/∂x = {f_norm})')
plt.xlabel('Layer')
plt.ylabel('Gradient magnitude')
plt.title('With Residual Connection')
plt.legend()
plt.yscale('log')
plt.ylim([1e-20, 1e5])

plt.tight_layout()
plt.show()

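print("Observations:")
print("- Without residual: the gradient vanishes quickly")
print("- With residual: when dF/dx is near 0, the total gradient stays near 1;")
print("  a consistently positive or negative dF/dx still compounds over 50 layers")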

---

## Part 3: Preparing the Basic Components

A ResNet block needs Conv2D, BatchNorm2D, and ReLU.

They are re-implemented here (simplified from the earlier notebooks).

In [None]:
# Utility functions: im2col and col2im

def im2col(x, kH, kW, stride=1, pad=0):
    """Unfold a 4D tensor into a 2D matrix so convolution becomes a matrix multiply
    
    Parameters
    ----------
    x : np.ndarray, shape (N, C, H, W)
    kH, kW : int - kernel size
    stride : int
    pad : int
    
    Returns
    -------
    col : np.ndarray, shape (N*out_H*out_W, C*kH*kW)
    """
    N, C, H, W = x.shape
    
    # Padding
    if pad > 0:
        x = np.pad(x, ((0, 0), (0, 0), (pad, pad), (pad, pad)), mode='constant')
    
    H_pad, W_pad = x.shape[2], x.shape[3]
    out_H = (H_pad - kH) // stride + 1
    out_W = (W_pad - kW) // stride + 1
    
    # Use stride_tricks to extract patches efficiently
    shape = (N, C, kH, kW, out_H, out_W)
    strides = (x.strides[0], x.strides[1], x.strides[2], x.strides[3],
               x.strides[2] * stride, x.strides[3] * stride)
    
    col = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides)
    col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N * out_H * out_W, -1)
    
    return col

def col2im(col, x_shape, kH, kW, stride=1, pad=0):
    """Adjoint of im2col: fold columns back into an image, summing overlapping entries
    
    Parameters
    ----------
    col : np.ndarray, shape (N*out_H*out_W, C*kH*kW)
    x_shape : tuple - shape of the original x, (N, C, H, W)
    
    Returns
    -------
    x : np.ndarray, shape (N, C, H, W)
    """
    N, C, H, W = x_shape
    H_pad = H + 2 * pad
    W_pad = W + 2 * pad
    out_H = (H_pad - kH) // stride + 1
    out_W = (W_pad - kW) // stride + 1
    
    col = col.reshape(N, out_H, out_W, C, kH, kW).transpose(0, 3, 4, 5, 1, 2)
    
    x_pad = np.zeros((N, C, H_pad, W_pad))
    
    for i in range(kH):
        i_max = i + stride * out_H
        for j in range(kW):
            j_max = j + stride * out_W
            x_pad[:, :, i:i_max:stride, j:j_max:stride] += col[:, :, i, j, :, :]
    
    if pad > 0:
        return x_pad[:, :, pad:-pad, pad:-pad]
    return x_pad

In [None]:
class Conv2D:
    """2D Convolution Layer"""
    
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        
        # He initialization
        scale = np.sqrt(2.0 / (in_channels * kernel_size * kernel_size))
        self.W = np.random.randn(out_channels, in_channels, kernel_size, kernel_size) * scale
        self.b = np.zeros(out_channels)
        
        self.cache = None
        self.dW = None
        self.db = None
    
    def forward(self, x):
        """Forward pass using im2col
        
        Parameters
        ----------
        x : np.ndarray, shape (N, C_in, H, W)
        
        Returns
        -------
        out : np.ndarray, shape (N, C_out, H_out, W_out)
        """
        N, C, H, W = x.shape
        kH = kW = self.kernel_size
        
        # Compute the output size
        H_out = (H + 2 * self.padding - kH) // self.stride + 1
        W_out = (W + 2 * self.padding - kW) // self.stride + 1
        
        # im2col
        col = im2col(x, kH, kW, self.stride, self.padding)
        
        # Reshape weights
        W_col = self.W.reshape(self.out_channels, -1)
        
        # Matrix multiplication
        out = col @ W_col.T + self.b
        
        # Reshape output
        out = out.reshape(N, H_out, W_out, self.out_channels).transpose(0, 3, 1, 2)
        
        self.cache = (x, col)
        return out
    
    def backward(self, dout):
        """Backward pass
        
        Parameters
        ----------
        dout : np.ndarray, shape (N, C_out, H_out, W_out)
        
        Returns
        -------
        dx : np.ndarray, shape (N, C_in, H, W)
        """
        x, col = self.cache
        N, C, H, W = x.shape
        kH = kW = self.kernel_size
        
        # Reshape dout: (N, C_out, H_out, W_out) -> (N*H_out*W_out, C_out)
        dout_reshaped = dout.transpose(0, 2, 3, 1).reshape(-1, self.out_channels)
        
        # Gradient for W and b
        self.dW = (dout_reshaped.T @ col).reshape(self.W.shape)
        self.db = dout_reshaped.sum(axis=0)
        
        # Gradient for x
        W_col = self.W.reshape(self.out_channels, -1)
        dcol = dout_reshaped @ W_col
        dx = col2im(dcol, x.shape, kH, kW, self.stride, self.padding)
        
        return dx
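
An im2col-based layer is easy to get subtly wrong (for example in the transpose order), so a slow but obviously correct reference can be handy for cross-checking. The function below is a sketch of such a reference; `conv2d_naive` is a hypothetical helper, not part of the original notebook, and it computes the same cross-correlation with explicit loops over output positions.

```python
import numpy as np

def conv2d_naive(x, W, b, stride=1, pad=0):
    """Direct-loop 2D convolution (cross-correlation), intended for testing only.

    x: (N, C_in, H, W_in), W: (C_out, C_in, kH, kW), b: (C_out,)
    """
    N, C_in, H, W_in = x.shape
    C_out, _, kH, kW = W.shape
    x_pad = np.pad(x, ((0, 0), (0, 0), (pad, pad), (pad, pad)), mode='constant')
    H_out = (H + 2 * pad - kH) // stride + 1
    W_out = (W_in + 2 * pad - kW) // stride + 1

    out = np.zeros((N, C_out, H_out, W_out))
    for i in range(H_out):
        for j in range(W_out):
            # Patch under the kernel at output location (i, j): (N, C_in, kH, kW)
            patch = x_pad[:, :, i * stride:i * stride + kH, j * stride:j * stride + kW]
            # Contract over (C_in, kH, kW) for every output channel -> (N, C_out)
            out[:, :, i, j] = np.tensordot(patch, W, axes=([1, 2, 3], [1, 2, 3])) + b
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 8, 8))
W = rng.standard_normal((4, 3, 3, 3))
b = rng.standard_normal(4)
y = conv2d_naive(x, W, b, stride=2, pad=1)
print(y.shape)  # (2, 4, 4, 4)
```

With the `Conv2D` class above, a check along the lines of `np.allclose(conv.forward(x), conv2d_naive(x, conv.W, conv.b, conv.stride, conv.padding))` should hold for any stride and padding.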

In [None]:
class BatchNorm2D:
    """Batch Normalization for Conv layers"""
    
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        
        self.gamma = np.ones(num_features)
        self.beta = np.zeros(num_features)
        
        self.running_mean = np.zeros(num_features)
        self.running_var = np.ones(num_features)
        
        self.cache = None
        self.training = True
        
        self.dgamma = None
        self.dbeta = None
    
    def forward(self, x):
        """Forward pass
        
        Parameters
        ----------
        x : np.ndarray, shape (N, C, H, W)
        
        Returns
        -------
        out : np.ndarray, shape (N, C, H, W)
        """
        N, C, H, W = x.shape
        
        if self.training:
            # Compute mean and var (reduce over N, H, W; keep C)
            mean = x.mean(axis=(0, 2, 3))
            var = x.var(axis=(0, 2, 3))
            
            # Update running statistics
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        else:
            mean = self.running_mean
            var = self.running_var
        
        # Normalize
        std = np.sqrt(var + self.eps)
        x_norm = (x - mean.reshape(1, C, 1, 1)) / std.reshape(1, C, 1, 1)
        
        # Scale and shift
        out = self.gamma.reshape(1, C, 1, 1) * x_norm + self.beta.reshape(1, C, 1, 1)
        
        if self.training:
            self.cache = (x, x_norm, mean, var, std)
        
        return out
    
    def backward(self, dout):
        """Backward pass"""
        x, x_norm, mean, var, std = self.cache
        N, C, H, W = x.shape
        m = N * H * W  # number of elements per channel
        
        # Gradients for gamma and beta
        self.dgamma = (dout * x_norm).sum(axis=(0, 2, 3))
        self.dbeta = dout.sum(axis=(0, 2, 3))
        
        # Gradient for x_norm
        dx_norm = dout * self.gamma.reshape(1, C, 1, 1)
        
        # Gradient for x
        dvar = (-0.5 * dx_norm * (x - mean.reshape(1, C, 1, 1)) / 
                (var.reshape(1, C, 1, 1) + self.eps) ** 1.5).sum(axis=(0, 2, 3))
        
        dmean = (-dx_norm / std.reshape(1, C, 1, 1)).sum(axis=(0, 2, 3))
        dmean += dvar * (-2 / m) * (x - mean.reshape(1, C, 1, 1)).sum(axis=(0, 2, 3))
        
        dx = dx_norm / std.reshape(1, C, 1, 1)
        dx += (2 / m) * dvar.reshape(1, C, 1, 1) * (x - mean.reshape(1, C, 1, 1))
        dx += dmean.reshape(1, C, 1, 1) / m
        
        return dx
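
The three-step backward pass above (`dvar`, `dmean`, `dx`) can be collapsed algebraically. Because $\sum (x - \mu) = 0$ per channel, substituting `dvar` and `dmean` into `dx` gives

$$\frac{\partial \mathcal{L}}{\partial x} = \frac{1}{\sigma}\left(\hat{g} - \operatorname{mean}(\hat{g}) - \hat{x} \cdot \operatorname{mean}(\hat{g}\,\hat{x})\right), \quad \hat{g} = \gamma \cdot \frac{\partial \mathcal{L}}{\partial y},$$

where $\sigma = \sqrt{\mathrm{var} + \epsilon}$ and the means run over $(N, H, W)$. The sketch below implements this compact form with standalone functions (`bn_forward` and `bn_backward_compact` are hypothetical names, not the notebook's class) and checks it against finite differences.

```python
import numpy as np

def bn_forward(x, gamma, beta, eps=1e-5):
    """Training-mode batch norm over axes (N, H, W); returns output and cache."""
    mean = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    std = np.sqrt(var + eps)
    x_norm = (x - mean) / std
    out = gamma.reshape(1, -1, 1, 1) * x_norm + beta.reshape(1, -1, 1, 1)
    return out, (x_norm, std)

def bn_backward_compact(dout, gamma, cache):
    """dx = (dx_norm - mean(dx_norm) - x_norm * mean(dx_norm * x_norm)) / std."""
    x_norm, std = cache
    dx_norm = dout * gamma.reshape(1, -1, 1, 1)
    mean_d = dx_norm.mean(axis=(0, 2, 3), keepdims=True)
    mean_dx = (dx_norm * x_norm).mean(axis=(0, 2, 3), keepdims=True)
    return (dx_norm - mean_d - x_norm * mean_dx) / std

# Finite-difference check on a tiny input
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 2, 2))
gamma = rng.standard_normal(3)
beta = rng.standard_normal(3)
dout = rng.standard_normal(x.shape)

out, cache = bn_forward(x, gamma, beta)
dx = bn_backward_compact(dout, gamma, cache)

dx_num = np.zeros_like(x)
h = 1e-5
for idx in np.ndindex(*x.shape):
    old = x[idx]
    x[idx] = old + h
    f_plus = np.sum(bn_forward(x, gamma, beta)[0] * dout)
    x[idx] = old - h
    f_minus = np.sum(bn_forward(x, gamma, beta)[0] * dout)
    x[idx] = old
    dx_num[idx] = (f_plus - f_minus) / (2 * h)

print("max abs difference:", np.abs(dx - dx_num).max())
```

The compact form needs only `x_norm` and `std` in the cache, which is why many implementations prefer it; the step-by-step version in the class above is easier to map onto the chain rule.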

In [None]:
class ReLU:
    """ReLU activation"""
    
    def __init__(self):
        self.cache = None
    
    def forward(self, x):
        self.cache = x
        return np.maximum(0, x)
    
    def backward(self, dout):
        x = self.cache
        return dout * (x > 0)

---

## Part 4: Implementing the Basic Block

### 4.1 Basic Block Structure

```
              ┌────────────────┐
              │   Identity     │  (or 1x1 conv if dimension mismatch)
              │                │
x ────────────┼────────────────┼────────────┐
              │                │            │
              ↓                │            │
        ┌──────────┐          │            │
        │ Conv 3x3 │          │            │
        │   BN     │          │            │
        │  ReLU    │          │            │
        └────┬─────┘          │            │
             │                │            │
             ↓                │            │
        ┌──────────┐          │            │
        │ Conv 3x3 │          │            │
        │   BN     │          │            │
        └────┬─────┘          │            │
             │                │            │
             ↓                │            │
           F(x)               +            x (or shortcut(x))
             │                │            │
             └────────────────┼────────────┘
                              │
                              ↓
                            ReLU
                              │
                              ↓
                       ReLU(F(x) + x)
```
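
For the addition to be valid, the main path and the shortcut must produce the same spatial size. A small helper (hypothetical, named `conv_out_size` here) applies the standard output-size formula $\lfloor (H + 2p - k) / s \rfloor + 1$ to both paths: a 3x3 conv with padding 1 and a 1x1 conv with padding 0 give the same size for any stride, which is why the 1x1 projection shortcut lines up with the main path.

```python
def conv_out_size(size, kernel, stride, pad):
    """Spatial output size of a convolution: floor((size + 2*pad - kernel) / stride) + 1."""
    return (size + 2 * pad - kernel) // stride + 1

for H in [8, 7, 32]:
    for stride in [1, 2]:
        main = conv_out_size(H, kernel=3, stride=stride, pad=1)      # first 3x3 conv
        main = conv_out_size(main, kernel=3, stride=1, pad=1)        # second 3x3 conv
        shortcut = conv_out_size(H, kernel=1, stride=stride, pad=0)  # 1x1 projection
        print(f"H={H}, stride={stride}: main={main}, shortcut={shortcut}")
```

Both expressions reduce to $\lfloor (H - 1) / s \rfloor + 1$, so the sizes agree even for odd inputs such as $H = 7$.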

In [None]:
class BasicBlock:
    """
    ResNet Basic Block:
    x -> Conv3x3 -> BN -> ReLU -> Conv3x3 -> BN -> (+x) -> ReLU
    
    If stride > 1 or in_channels != out_channels,
    the shortcut uses a 1x1 conv to match dimensions.
    """
    
    def __init__(self, in_channels, out_channels, stride=1):
        """
        Parameters
        ----------
        in_channels : int
            Number of input channels
        out_channels : int
            Number of output channels
        stride : int
            Stride of the first conv (used for downsampling)
        """
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        
        # Main path
        # First conv: may have stride > 1 (used for downsampling)
        self.conv1 = Conv2D(in_channels, out_channels, kernel_size=3, 
                           stride=stride, padding=1)
        self.bn1 = BatchNorm2D(out_channels)
        self.relu1 = ReLU()
        
        # Second conv: stride is always 1
        self.conv2 = Conv2D(out_channels, out_channels, kernel_size=3,
                           stride=1, padding=1)
        self.bn2 = BatchNorm2D(out_channels)
        
        # Final ReLU (applied after the addition)
        self.relu2 = ReLU()
        
        # Shortcut: 1x1 conv + BN when the dimensions do not match
        self.use_shortcut_conv = (stride != 1) or (in_channels != out_channels)
        if self.use_shortcut_conv:
            self.shortcut_conv = Conv2D(in_channels, out_channels, kernel_size=1,
                                        stride=stride, padding=0)
            self.shortcut_bn = BatchNorm2D(out_channels)
        
        self.cache = None
    
    def forward(self, x):
        """
        Forward pass
        
        Parameters
        ----------
        x : np.ndarray, shape (N, C_in, H, W)
        
        Returns
        -------
        out : np.ndarray, shape (N, C_out, H_out, W_out)
        """
        # Keep the input for the shortcut
        identity = x
        
        # Main path
        out = self.conv1.forward(x)
        out = self.bn1.forward(out)
        out = self.relu1.forward(out)
        
        out = self.conv2.forward(out)
        out = self.bn2.forward(out)
        
        # Shortcut path
        if self.use_shortcut_conv:
            identity = self.shortcut_conv.forward(x)
            identity = self.shortcut_bn.forward(identity)
        
        # Residual addition
        out = out + identity
        
        # Final ReLU
        out = self.relu2.forward(out)
        
        self.cache = (x,)
        return out
    
    def backward(self, dout):
        """
        Backward pass
        
        Key point: at the residual addition, the gradient flows into both paths
        
        Parameters
        ----------
        dout : np.ndarray, shape (N, C_out, H_out, W_out)
        
        Returns
        -------
        dx : np.ndarray, shape (N, C_in, H, W)
        """
        x, = self.cache
        
        # Backward through final ReLU
        dout = self.relu2.backward(dout)
        
        # Branch at the addition
        # d(out + identity) = dout for both branches
        d_main = dout  # gradient for the main path
        d_shortcut = dout.copy()  # gradient for the shortcut
        
        # Main path backward
        d_main = self.bn2.backward(d_main)
        d_main = self.conv2.backward(d_main)
        
        d_main = self.relu1.backward(d_main)
        d_main = self.bn1.backward(d_main)
        d_main = self.conv1.backward(d_main)
        
        # Shortcut backward
        if self.use_shortcut_conv:
            d_shortcut = self.shortcut_bn.backward(d_shortcut)
            d_shortcut = self.shortcut_conv.backward(d_shortcut)
        
        # Sum the gradients from the two paths
        dx = d_main + d_shortcut
        
        return dx
    
    def train(self):
        """Set to training mode"""
        self.bn1.training = True
        self.bn2.training = True
        if self.use_shortcut_conv:
            self.shortcut_bn.training = True
    
    def eval(self):
        """Set to evaluation mode"""
        self.bn1.training = False
        self.bn2.training = False
        if self.use_shortcut_conv:
            self.shortcut_bn.training = False

In [None]:
# Test BasicBlock forward

# Case 1: same dimensions
block1 = BasicBlock(in_channels=16, out_channels=16, stride=1)
x1 = np.random.randn(2, 16, 8, 8)
out1 = block1.forward(x1)
print(f"Case 1 (same dim): input {x1.shape} -> output {out1.shape}")

# Case 2: downsample (stride=2)
block2 = BasicBlock(in_channels=16, out_channels=32, stride=2)
x2 = np.random.randn(2, 16, 8, 8)
out2 = block2.forward(x2)
print(f"Case 2 (downsample): input {x2.shape} -> output {out2.shape}")

# Case 3: change channels only
block3 = BasicBlock(in_channels=16, out_channels=32, stride=1)
x3 = np.random.randn(2, 16, 8, 8)
out3 = block3.forward(x3)
print(f"Case 3 (channel change): input {x3.shape} -> output {out3.shape}")

### 4.2 Gradient Check for BasicBlock

In [None]:
def numerical_gradient(f, x, eps=1e-5):
    """Compute the numerical gradient with central differences"""
    grad = np.zeros_like(x)
    it = np.nditer(x, flags=['multi_index'], op_flags=['readwrite'])
    
    while not it.finished:
        idx = it.multi_index
        old_val = x[idx]
        
        x[idx] = old_val + eps
        fxh1 = f(x)
        
        x[idx] = old_val - eps
        fxh2 = f(x)
        
        grad[idx] = (fxh1 - fxh2) / (2 * eps)
        x[idx] = old_val
        
        it.iternext()
    
    return grad

def gradient_check(analytic_grad, numeric_grad, name=""):
    """Compare the analytic and numerical gradients"""
    diff = np.abs(analytic_grad - numeric_grad)
    max_diff = np.max(diff)
    rel_error = max_diff / (np.maximum(np.abs(analytic_grad).max(), np.abs(numeric_grad).max()) + 1e-8)
    
    status = "✓ PASS" if rel_error < 1e-4 else "✗ FAIL"
    print(f"{name}: max_diff = {max_diff:.2e}, rel_error = {rel_error:.2e} {status}")
    return rel_error < 1e-4

In [None]:
# Gradient check for BasicBlock
print("Gradient Check for BasicBlock")
print("=" * 50)

# Use a small input so the check stays fast
np.random.seed(42)
block = BasicBlock(in_channels=4, out_channels=8, stride=2)
x = np.random.randn(2, 4, 6, 6)

# Forward and backward
out = block.forward(x)
dout = np.random.randn(*out.shape)
dx_analytic = block.backward(dout)

# Numerical gradient for dx
def f_dx(x_):
    block_copy = BasicBlock(in_channels=4, out_channels=8, stride=2)
    # Copy weights
    block_copy.conv1.W = block.conv1.W.copy()
    block_copy.conv1.b = block.conv1.b.copy()
    block_copy.conv2.W = block.conv2.W.copy()
    block_copy.conv2.b = block.conv2.b.copy()
    block_copy.bn1.gamma = block.bn1.gamma.copy()
    block_copy.bn1.beta = block.bn1.beta.copy()
    block_copy.bn2.gamma = block.bn2.gamma.copy()
    block_copy.bn2.beta = block.bn2.beta.copy()
    if block.use_shortcut_conv:
        block_copy.shortcut_conv.W = block.shortcut_conv.W.copy()
        block_copy.shortcut_conv.b = block.shortcut_conv.b.copy()
        block_copy.shortcut_bn.gamma = block.shortcut_bn.gamma.copy()
        block_copy.shortcut_bn.beta = block.shortcut_bn.beta.copy()
    out_ = block_copy.forward(x_)
    return np.sum(out_ * dout)

dx_numeric = numerical_gradient(f_dx, x.copy())
gradient_check(dx_analytic, dx_numeric, "dx")

print("\nBasicBlock gradient check completed!")

---

## Part 5: Bottleneck Block (Optional)

### 5.1 Bottleneck Structure

Deeper ResNets (ResNet-50 and above) use the Bottleneck Block to reduce computation:

```
x ────────────┬────────────────────────────────┐
              │                                │
              ↓                                │
        ┌──────────┐                          │
        │ Conv 1x1 │  reduce channels         │
        │   BN     │  (e.g., 256 -> 64)       │
        │  ReLU    │                          │
        └────┬─────┘                          │
             │                                │
             ↓                                │
        ┌──────────┐                          │
        │ Conv 3x3 │  main computation        │ shortcut
        │   BN     │  (64 channels)           │
        │  ReLU    │                          │
        └────┬─────┘                          │
             │                                │
             ↓                                │
        ┌──────────┐                          │
        │ Conv 1x1 │  restore channels        │
        │   BN     │  (64 -> 256)             │
        └────┬─────┘                          │
             │                                │
             +────────────────────────────────┘
             │
           ReLU
             │
             ↓
```

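**Why this design?**
- The cost of a 3x3 conv scales with $C^2$ (input channels × output channels)
- A 1x1 conv first reduces the channels to 1/4, then the 3x3 conv runs, then another 1x1 conv restores them
- The 3x3 conv therefore costs $(1/4)^2 = 1/16$ of a full-width 3x3 conv; including the two 1x1 convs, a 256 → 64 → 256 bottleneck needs about 17x fewer multiply-adds than two 256-channel 3x3 convs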

In [None]:
class BottleneckBlock:
    """
    ResNet Bottleneck Block:
    x -> Conv1x1 -> BN -> ReLU -> Conv3x3 -> BN -> ReLU -> Conv1x1 -> BN -> (+x) -> ReLU
    
    expansion = 4: output channels = 4 * mid_channels
    """
    
    expansion = 4  # output channels are 4x the mid channels
    
    def __init__(self, in_channels, mid_channels, stride=1):
        """
        Parameters
        ----------
        in_channels : int
            Number of input channels
        mid_channels : int
            Number of channels in the middle (3x3 conv) layer
        stride : int
            Stride of the 3x3 conv (used for downsampling)
        """
        self.in_channels = in_channels
        self.mid_channels = mid_channels
        self.out_channels = mid_channels * self.expansion
        self.stride = stride
        
        # 1x1 conv: reduce channels
        self.conv1 = Conv2D(in_channels, mid_channels, kernel_size=1,
                           stride=1, padding=0)
        self.bn1 = BatchNorm2D(mid_channels)
        self.relu1 = ReLU()
        
        # 3x3 conv: main computation (may have stride > 1)
        self.conv2 = Conv2D(mid_channels, mid_channels, kernel_size=3,
                           stride=stride, padding=1)
        self.bn2 = BatchNorm2D(mid_channels)
        self.relu2 = ReLU()
        
        # 1x1 conv: expand channels
        self.conv3 = Conv2D(mid_channels, self.out_channels, kernel_size=1,
                           stride=1, padding=0)
        self.bn3 = BatchNorm2D(self.out_channels)
        
        # Final ReLU
        self.relu3 = ReLU()
        
        # Shortcut
        self.use_shortcut_conv = (stride != 1) or (in_channels != self.out_channels)
        if self.use_shortcut_conv:
            self.shortcut_conv = Conv2D(in_channels, self.out_channels, kernel_size=1,
                                        stride=stride, padding=0)
            self.shortcut_bn = BatchNorm2D(self.out_channels)
        
        self.cache = None
    
    def forward(self, x):
        """Forward pass"""
        identity = x
        
        # 1x1 conv (reduce)
        out = self.conv1.forward(x)
        out = self.bn1.forward(out)
        out = self.relu1.forward(out)
        
        # 3x3 conv
        out = self.conv2.forward(out)
        out = self.bn2.forward(out)
        out = self.relu2.forward(out)
        
        # 1x1 conv (expand)
        out = self.conv3.forward(out)
        out = self.bn3.forward(out)
        
        # Shortcut
        if self.use_shortcut_conv:
            identity = self.shortcut_conv.forward(x)
            identity = self.shortcut_bn.forward(identity)
        
        # Residual addition
        out = out + identity
        
        # Final ReLU
        out = self.relu3.forward(out)
        
        self.cache = (x,)
        return out
    
    def backward(self, dout):
        """Backward pass"""
        x, = self.cache
        
        # Final ReLU backward
        dout = self.relu3.backward(dout)
        
        # Split gradient
        d_main = dout
        d_shortcut = dout.copy()
        
        # Main path backward
        d_main = self.bn3.backward(d_main)
        d_main = self.conv3.backward(d_main)
        
        d_main = self.relu2.backward(d_main)
        d_main = self.bn2.backward(d_main)
        d_main = self.conv2.backward(d_main)
        
        d_main = self.relu1.backward(d_main)
        d_main = self.bn1.backward(d_main)
        d_main = self.conv1.backward(d_main)
        
        # Shortcut backward
        if self.use_shortcut_conv:
            d_shortcut = self.shortcut_bn.backward(d_shortcut)
            d_shortcut = self.shortcut_conv.backward(d_shortcut)
        
        # Merge gradients
        dx = d_main + d_shortcut
        
        return dx
    
    def train(self):
        """Set to training mode"""
        for bn in [self.bn1, self.bn2, self.bn3]:
            bn.training = True
        if self.use_shortcut_conv:
            self.shortcut_bn.training = True
    
    def eval(self):
        """Set to evaluation mode"""
        for bn in [self.bn1, self.bn2, self.bn3]:
            bn.training = False
        if self.use_shortcut_conv:
            self.shortcut_bn.training = False

In [None]:
# Test BottleneckBlock

# Standard configuration: in=64 -> mid=64 -> out=256
bottleneck1 = BottleneckBlock(in_channels=64, mid_channels=64, stride=1)
x1 = np.random.randn(2, 64, 8, 8)
out1 = bottleneck1.forward(x1)
print(f"Bottleneck (standard): input {x1.shape} -> output {out1.shape}")
print(f"  Expected output channels: 64 * 4 = 256")

# Downsample: in=256 -> mid=128 -> out=512
bottleneck2 = BottleneckBlock(in_channels=256, mid_channels=128, stride=2)
x2 = np.random.randn(2, 256, 8, 8)
out2 = bottleneck2.forward(x2)
print(f"\nBottleneck (downsample): input {x2.shape} -> output {out2.shape}")
print(f"  Expected output channels: 128 * 4 = 512")

In [None]:
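# Compare computational cost

def count_conv_flops(in_c, out_c, k, h, w):
    """Count the multiply-accumulate operations (MACs) of a convolution"""
    # Each output element needs in_c * k * k multiply-adds
    return out_c * h * w * (in_c * k * k)

# Assume a 256-channel input with a 56x56 feature map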
h, w = 56, 56
in_c = 256

# Basic Block: 兩個 3x3 conv
basic_flops = count_conv_flops(in_c, in_c, 3, h, w) * 2

# Bottleneck: 1x1 (256->64) + 3x3 (64->64) + 1x1 (64->256)
mid_c = 64
bottleneck_flops = (count_conv_flops(in_c, mid_c, 1, h, w) +  # 1x1 reduce
                   count_conv_flops(mid_c, mid_c, 3, h, w) +   # 3x3
                   count_conv_flops(mid_c, in_c, 1, h, w))     # 1x1 expand

print(f"Input: {in_c} channels, {h}x{w} feature map")
print(f"\nBasic Block FLOPs:     {basic_flops:,}")
print(f"Bottleneck Block FLOPs: {bottleneck_flops:,}")
print(f"\nCost ratio (basic / bottleneck): {basic_flops / bottleneck_flops:.2f}x")
print(f"\nNote: the bottleneck outputs {mid_c * 4} channels = {mid_c}*4")

---

## Part 6: Assembling a ResNet-style Network

We now use BasicBlock to assemble a small ResNet.

In [None]:
class GlobalAvgPool2D:
    """Global Average Pooling"""
    
    def __init__(self):
        self.cache = None
    
    def forward(self, x):
        """(N, C, H, W) -> (N, C)"""
        self.cache = x.shape
        return x.mean(axis=(2, 3))
    
    def backward(self, dout):
        """(N, C) -> (N, C, H, W)"""
        N, C, H, W = self.cache
        dx = dout.reshape(N, C, 1, 1) / (H * W)
        dx = np.broadcast_to(dx, (N, C, H, W)).copy()
        return dx


class FullyConnected:
    """Fully Connected Layer"""
    
    def __init__(self, in_features, out_features):
        scale = np.sqrt(2.0 / in_features)
        self.W = np.random.randn(in_features, out_features) * scale
        self.b = np.zeros(out_features)
        self.cache = None
        self.dW = None
        self.db = None
    
    def forward(self, x):
        self.cache = x
        return x @ self.W + self.b
    
    def backward(self, dout):
        x = self.cache
        self.dW = x.T @ dout
        self.db = dout.sum(axis=0)
        return dout @ self.W.T

In [None]:
class TinyResNet:
    """
    A small ResNet-style network
    
    Architecture:
    Conv3x3 -> BN -> ReLU -> 
    BasicBlock (16) x 2 ->
    BasicBlock (32, stride=2) -> BasicBlock (32) ->
    GlobalAvgPool -> FC -> logits (softmax is not applied inside the model)
    """
    
    def __init__(self, num_classes=10):
        self.num_classes = num_classes
        
        # Initial convolution
        self.conv1 = Conv2D(1, 16, kernel_size=3, stride=1, padding=1)
        self.bn1 = BatchNorm2D(16)
        self.relu1 = ReLU()
        
        # Layer 1: 2 x BasicBlock (16 channels)
        self.layer1_block1 = BasicBlock(16, 16, stride=1)
        self.layer1_block2 = BasicBlock(16, 16, stride=1)
        
        # Layer 2: 2 x BasicBlock (32 channels), first one downsamples
        self.layer2_block1 = BasicBlock(16, 32, stride=2)
        self.layer2_block2 = BasicBlock(32, 32, stride=1)
        
        # Global Average Pooling + FC
        self.gap = GlobalAvgPool2D()
        self.fc = FullyConnected(32, num_classes)
    
    def forward(self, x):
        """Forward pass
        
        Parameters
        ----------
        x : np.ndarray, shape (N, 1, H, W)
        
        Returns
        -------
        logits : np.ndarray, shape (N, num_classes)
        """
        # Initial conv
        out = self.conv1.forward(x)
        out = self.bn1.forward(out)
        out = self.relu1.forward(out)
        
        # Layer 1
        out = self.layer1_block1.forward(out)
        out = self.layer1_block2.forward(out)
        
        # Layer 2
        out = self.layer2_block1.forward(out)
        out = self.layer2_block2.forward(out)
        
        # GAP + FC
        out = self.gap.forward(out)
        logits = self.fc.forward(out)
        
        return logits
    
    def backward(self, dlogits):
        """Backward pass"""
        # FC + GAP
        dout = self.fc.backward(dlogits)
        dout = self.gap.backward(dout)
        
        # Layer 2
        dout = self.layer2_block2.backward(dout)
        dout = self.layer2_block1.backward(dout)
        
        # Layer 1
        dout = self.layer1_block2.backward(dout)
        dout = self.layer1_block1.backward(dout)
        
        # Initial conv
        dout = self.relu1.backward(dout)
        dout = self.bn1.backward(dout)
        dout = self.conv1.backward(dout)
        
        return dout
    
    def get_params_and_grads(self):
        """Collect all parameters and their gradients"""
        params = []
        grads = []
        
        # Conv1
        params.extend([self.conv1.W, self.conv1.b, self.bn1.gamma, self.bn1.beta])
        grads.extend([self.conv1.dW, self.conv1.db, self.bn1.dgamma, self.bn1.dbeta])
        
        # Layer 1
        for block in [self.layer1_block1, self.layer1_block2]:
            params.extend([block.conv1.W, block.conv1.b, block.bn1.gamma, block.bn1.beta,
                          block.conv2.W, block.conv2.b, block.bn2.gamma, block.bn2.beta])
            grads.extend([block.conv1.dW, block.conv1.db, block.bn1.dgamma, block.bn1.dbeta,
                         block.conv2.dW, block.conv2.db, block.bn2.dgamma, block.bn2.dbeta])
        
        # Layer 2
        for block in [self.layer2_block1, self.layer2_block2]:
            params.extend([block.conv1.W, block.conv1.b, block.bn1.gamma, block.bn1.beta,
                          block.conv2.W, block.conv2.b, block.bn2.gamma, block.bn2.beta])
            grads.extend([block.conv1.dW, block.conv1.db, block.bn1.dgamma, block.bn1.dbeta,
                         block.conv2.dW, block.conv2.db, block.bn2.dgamma, block.bn2.dbeta])
            if block.use_shortcut_conv:
                params.extend([block.shortcut_conv.W, block.shortcut_conv.b,
                              block.shortcut_bn.gamma, block.shortcut_bn.beta])
                grads.extend([block.shortcut_conv.dW, block.shortcut_conv.db,
                             block.shortcut_bn.dgamma, block.shortcut_bn.dbeta])
        
        # FC
        params.extend([self.fc.W, self.fc.b])
        grads.extend([self.fc.dW, self.fc.db])
        
        return params, grads
    
    def train(self):
        """Set to training mode"""
        self.bn1.training = True
        for block in [self.layer1_block1, self.layer1_block2,
                      self.layer2_block1, self.layer2_block2]:
            block.train()
    
    def eval(self):
        """Set to evaluation mode"""
        self.bn1.training = False
        for block in [self.layer1_block1, self.layer1_block2,
                      self.layer2_block1, self.layer2_block2]:
            block.eval()

In [None]:
# Test TinyResNet

model = TinyResNet(num_classes=10)
x = np.random.randn(4, 1, 32, 32)  # MNIST-like input (single channel, but 32x32)

# Forward
logits = model.forward(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {logits.shape}")

# Count the parameters
params, _ = model.get_params_and_grads()
total_params = sum(p.size for p in params)
print(f"\nTotal parameters: {total_params:,}")

---

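## Part 7: Training Comparison, With vs. Without Residual Connections

To demonstrate the effect of residual connections, we train one network with them and one without, on the same data and with the same hyperparameters.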

In [None]:
# Build a control network without residual connections

class PlainBlock:
    """A block without a residual connection."""
    
    def __init__(self, in_channels, out_channels, stride=1):
        self.conv1 = Conv2D(in_channels, out_channels, kernel_size=3,
                           stride=stride, padding=1)
        self.bn1 = BatchNorm2D(out_channels)
        self.relu1 = ReLU()
        
        self.conv2 = Conv2D(out_channels, out_channels, kernel_size=3,
                           stride=1, padding=1)
        self.bn2 = BatchNorm2D(out_channels)
        self.relu2 = ReLU()
        
        self.cache = None
    
    def forward(self, x):
        out = self.conv1.forward(x)
        out = self.bn1.forward(out)
        out = self.relu1.forward(out)
        
        out = self.conv2.forward(out)
        out = self.bn2.forward(out)
        out = self.relu2.forward(out)  # no "+ x" here
        
        self.cache = (x,)
        return out
    
    def backward(self, dout):
        dout = self.relu2.backward(dout)
        dout = self.bn2.backward(dout)
        dout = self.conv2.backward(dout)
        
        dout = self.relu1.backward(dout)
        dout = self.bn1.backward(dout)
        dout = self.conv1.backward(dout)
        
        return dout
    
    def train(self):
        self.bn1.training = True
        self.bn2.training = True
    
    def eval(self):
        self.bn1.training = False
        self.bn2.training = False

In [None]:
class TinyPlainNet:
    """A network without residual connections."""
    
    def __init__(self, num_classes=10):
        self.num_classes = num_classes
        
        self.conv1 = Conv2D(1, 16, kernel_size=3, stride=1, padding=1)
        self.bn1 = BatchNorm2D(16)
        self.relu1 = ReLU()
        
        self.layer1_block1 = PlainBlock(16, 16, stride=1)
        self.layer1_block2 = PlainBlock(16, 16, stride=1)
        
        # Extra 1x1 transition conv to change channels (16 -> 32) and downsample (stride 2)
        self.transition = Conv2D(16, 32, kernel_size=1, stride=2, padding=0)
        self.transition_bn = BatchNorm2D(32)
        self.transition_relu = ReLU()
        
        self.layer2_block1 = PlainBlock(32, 32, stride=1)
        self.layer2_block2 = PlainBlock(32, 32, stride=1)
        
        self.gap = GlobalAvgPool2D()
        self.fc = FullyConnected(32, num_classes)
    
    def forward(self, x):
        out = self.conv1.forward(x)
        out = self.bn1.forward(out)
        out = self.relu1.forward(out)
        
        out = self.layer1_block1.forward(out)
        out = self.layer1_block2.forward(out)
        
        out = self.transition.forward(out)
        out = self.transition_bn.forward(out)
        out = self.transition_relu.forward(out)
        
        out = self.layer2_block1.forward(out)
        out = self.layer2_block2.forward(out)
        
        out = self.gap.forward(out)
        logits = self.fc.forward(out)
        
        return logits
    
    def backward(self, dlogits):
        dout = self.fc.backward(dlogits)
        dout = self.gap.backward(dout)
        
        dout = self.layer2_block2.backward(dout)
        dout = self.layer2_block1.backward(dout)
        
        dout = self.transition_relu.backward(dout)
        dout = self.transition_bn.backward(dout)
        dout = self.transition.backward(dout)
        
        dout = self.layer1_block2.backward(dout)
        dout = self.layer1_block1.backward(dout)
        
        dout = self.relu1.backward(dout)
        dout = self.bn1.backward(dout)
        dout = self.conv1.backward(dout)
        
        return dout
    
    def get_params_and_grads(self):
        params = []
        grads = []
        
        params.extend([self.conv1.W, self.conv1.b, self.bn1.gamma, self.bn1.beta])
        grads.extend([self.conv1.dW, self.conv1.db, self.bn1.dgamma, self.bn1.dbeta])
        
        for block in [self.layer1_block1, self.layer1_block2]:
            params.extend([block.conv1.W, block.conv1.b, block.bn1.gamma, block.bn1.beta,
                          block.conv2.W, block.conv2.b, block.bn2.gamma, block.bn2.beta])
            grads.extend([block.conv1.dW, block.conv1.db, block.bn1.dgamma, block.bn1.dbeta,
                         block.conv2.dW, block.conv2.db, block.bn2.dgamma, block.bn2.dbeta])
        
        params.extend([self.transition.W, self.transition.b,
                      self.transition_bn.gamma, self.transition_bn.beta])
        grads.extend([self.transition.dW, self.transition.db,
                     self.transition_bn.dgamma, self.transition_bn.dbeta])
        
        for block in [self.layer2_block1, self.layer2_block2]:
            params.extend([block.conv1.W, block.conv1.b, block.bn1.gamma, block.bn1.beta,
                          block.conv2.W, block.conv2.b, block.bn2.gamma, block.bn2.beta])
            grads.extend([block.conv1.dW, block.conv1.db, block.bn1.dgamma, block.bn1.dbeta,
                         block.conv2.dW, block.conv2.db, block.bn2.dgamma, block.bn2.dbeta])
        
        params.extend([self.fc.W, self.fc.b])
        grads.extend([self.fc.dW, self.fc.db])
        
        return params, grads
    
    def train(self):
        self.bn1.training = True
        self.transition_bn.training = True
        for block in [self.layer1_block1, self.layer1_block2,
                      self.layer2_block1, self.layer2_block2]:
            block.train()
    
    def eval(self):
        self.bn1.training = False
        self.transition_bn.training = False
        for block in [self.layer1_block1, self.layer1_block2,
                      self.layer2_block1, self.layer2_block2]:
            block.eval()

In [None]:
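# Generate a synthetic dataset

def generate_simple_shapes(n_samples, img_size=32):
    """Generate a simple shape-classification dataset.
    
    4 classes: horizontal line, vertical line, diagonal (top-left to
    bottom-right), anti-diagonal (top-right to bottom-left).
    """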
    X = np.zeros((n_samples, 1, img_size, img_size))
    y = np.zeros(n_samples, dtype=int)
    
    for i in range(n_samples):
        label = np.random.randint(4)
        y[i] = label
        
        # A little Gaussian noise; horizontal/vertical line positions are randomized below
        noise = np.random.randn(img_size, img_size) * 0.1
        
        if label == 0:  # horizontal line
            row = np.random.randint(5, img_size - 5)
            X[i, 0, row-1:row+2, 5:-5] = 1
        elif label == 1:  # vertical line
            col = np.random.randint(5, img_size - 5)
            X[i, 0, 5:-5, col-1:col+2] = 1
        elif label == 2:  # diagonal (\)
            for j in range(-1, 2):
                # Offset j: j > 0 shifts right, j < 0 shifts down (3-pixel-thick line)
                np.fill_diagonal(X[i, 0, max(0, -j):, max(0, j):], 1)
        else:  # anti-diagonal (/)
            for j in range(-1, 2):
                np.fill_diagonal(np.fliplr(X[i, 0])[max(0, -j):, max(0, j):], 1)
        
        X[i, 0] += noise
    
    return X.astype(np.float32), y

# Generate the data
np.random.seed(42)
X_train, y_train = generate_simple_shapes(400, img_size=32)
X_test, y_test = generate_simple_shapes(100, img_size=32)

print(f"Training set: {X_train.shape}, {y_train.shape}")
print(f"Test set: {X_test.shape}, {y_test.shape}")

# Show two samples per class
fig, axes = plt.subplots(2, 4, figsize=(10, 5))
labels = ['Horizontal', 'Vertical', 'Diagonal \\', 'Diagonal /']
for i in range(4):
    idx = np.where(y_train == i)[0][0]
    axes[0, i].imshow(X_train[idx, 0], cmap='gray')
    axes[0, i].set_title(labels[i])
    axes[0, i].axis('off')
    
    idx = np.where(y_train == i)[0][1]
    axes[1, i].imshow(X_train[idx, 0], cmap='gray')
    axes[1, i].axis('off')

plt.tight_layout()
plt.show()

In [None]:
def softmax(x):
    """Numerically stable softmax"""
    x_max = np.max(x, axis=1, keepdims=True)
    exp_x = np.exp(x - x_max)
    return exp_x / np.sum(exp_x, axis=1, keepdims=True)

def cross_entropy_loss(logits, y):
    """Cross-entropy loss"""
    N = logits.shape[0]
    probs = softmax(logits)
    loss = -np.mean(np.log(probs[np.arange(N), y] + 1e-8))
    
    # Gradient
    dlogits = probs.copy()
    dlogits[np.arange(N), y] -= 1
    dlogits /= N
    
    return loss, dlogits

def accuracy(logits, y):
    """Compute accuracy"""
    preds = np.argmax(logits, axis=1)
    return np.mean(preds == y)
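
As a sanity check on the analytic gradient returned by `cross_entropy_loss`, it can be compared against central finite differences. The sketch below repeats the two functions from the cell above so that it runs on its own; `numerical_grad`, the step size `h`, and the random test inputs are illustrative assumptions, not part of the original notebook.

```python
import numpy as np

def softmax(x):
    """Numerically stable softmax (same as the cell above)."""
    x_max = np.max(x, axis=1, keepdims=True)
    exp_x = np.exp(x - x_max)
    return exp_x / np.sum(exp_x, axis=1, keepdims=True)

def cross_entropy_loss(logits, y):
    """Mean cross-entropy and its gradient w.r.t. the logits (same as above)."""
    N = logits.shape[0]
    probs = softmax(logits)
    loss = -np.mean(np.log(probs[np.arange(N), y] + 1e-8))
    dlogits = probs.copy()
    dlogits[np.arange(N), y] -= 1
    dlogits /= N
    return loss, dlogits

def numerical_grad(logits, y, h=1e-5):
    """Hypothetical helper: central-difference gradient of the loss w.r.t. logits."""
    grad = np.zeros_like(logits)
    for idx in np.ndindex(*logits.shape):
        original = logits[idx]
        logits[idx] = original + h
        loss_plus, _ = cross_entropy_loss(logits, y)
        logits[idx] = original - h
        loss_minus, _ = cross_entropy_loss(logits, y)
        logits[idx] = original  # restore the perturbed entry
        grad[idx] = (loss_plus - loss_minus) / (2 * h)
    return grad

rng = np.random.default_rng(0)
logits = rng.standard_normal((5, 4))   # 5 samples, 4 classes
y = rng.integers(0, 4, size=5)

_, analytic = cross_entropy_loss(logits, y)
numeric = numerical_grad(logits, y)
max_err = np.abs(analytic - numeric).max()
print(f"max |analytic - numeric| = {max_err:.2e}")
```

The two gradients should agree closely. A small residual is expected from finite-difference error and from the `1e-8` inside the log, which the analytic gradient ignores.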

In [None]:
def train_model(model, X_train, y_train, X_test, y_test,
                epochs=30, batch_size=32, lr=0.01, momentum=0.9):
    """Train with mini-batch momentum SGD; return per-epoch loss and accuracy histories."""
    n_samples = X_train.shape[0]
    n_batches = n_samples // batch_size
    
    # Momentum velocities
    params, _ = model.get_params_and_grads()
    velocities = [np.zeros_like(p) for p in params]
    
    train_losses = []
    train_accs = []
    test_accs = []
    
    for epoch in range(epochs):
        # Shuffle
        indices = np.random.permutation(n_samples)
        X_shuffled = X_train[indices]
        y_shuffled = y_train[indices]
        
        epoch_loss = 0
        epoch_correct = 0
        
        model.train()
        
        for i in range(n_batches):
            start = i * batch_size
            end = start + batch_size
            X_batch = X_shuffled[start:end]
            y_batch = y_shuffled[start:end]
            
            # Forward
            logits = model.forward(X_batch)
            loss, dlogits = cross_entropy_loss(logits, y_batch)
            
            # Backward
            model.backward(dlogits)
            
            # Update with momentum SGD
            params, grads = model.get_params_and_grads()
            for j, (p, g) in enumerate(zip(params, grads)):
                if g is not None:
                    velocities[j] = momentum * velocities[j] - lr * g
                    p += velocities[j]
            
            epoch_loss += loss
            epoch_correct += np.sum(np.argmax(logits, axis=1) == y_batch)
        
        # Epoch metrics
        train_loss = epoch_loss / n_batches
        train_acc = epoch_correct / (n_batches * batch_size)
        
        # Test accuracy
        model.eval()
        test_logits = model.forward(X_test)
        test_acc = accuracy(test_logits, y_test)
        
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        test_accs.append(test_acc)
        
        if (epoch + 1) % 5 == 0:
            print(f"Epoch {epoch+1:3d}: loss={train_loss:.4f}, "
                  f"train_acc={train_acc:.4f}, test_acc={test_acc:.4f}")
    
    return train_losses, train_accs, test_accs
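
The update in `train_model` relies on `p += velocities[j]` modifying each parameter array in place, which only works if `get_params_and_grads` returns the very arrays the layers hold (as it appears to here). The sketch below isolates that update rule on a toy quadratic; `momentum_step` and the toy problem are illustrative assumptions, not part of the original notebook.

```python
import numpy as np

def momentum_step(params, grads, velocities, lr=0.1, momentum=0.9):
    """One SGD-with-momentum update, mutating each parameter array in place."""
    for j, (p, g) in enumerate(zip(params, grads)):
        velocities[j] = momentum * velocities[j] - lr * g
        p += velocities[j]  # in-place: the array held by the "model" changes

# Toy problem: minimize f(w) = 0.5 * ||w||^2, whose gradient is simply w.
w = np.array([1.0, -2.0])
params = [w]
velocities = [np.zeros_like(w)]
for _ in range(200):
    grads = [w.copy()]
    momentum_step(params, grads, velocities)
print("final ||w|| =", np.linalg.norm(w))

# Contrast: `p = p + v` only rebinds the local name and leaves the array untouched.
untouched = np.array([1.0])
for p in [untouched]:
    p = p + 5.0
print("after rebinding:", untouched)
```

Writing the update as `p = p + velocities[j]` inside the loop would leave the model's weights unchanged, which is an easy mistake to make when parameters are passed around as a list.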

In [None]:
# Train both networks

print("Training ResNet (with residual connections):")
print("=" * 50)
np.random.seed(42)
resnet = TinyResNet(num_classes=4)
resnet_losses, resnet_train_accs, resnet_test_accs = train_model(
    resnet, X_train, y_train, X_test, y_test, 
    epochs=30, batch_size=32, lr=0.01
)

print("\n" + "=" * 50)
print("Training PlainNet (without residual connections):")
print("=" * 50)
np.random.seed(42)
plainnet = TinyPlainNet(num_classes=4)
plain_losses, plain_train_accs, plain_test_accs = train_model(
    plainnet, X_train, y_train, X_test, y_test,
    epochs=30, batch_size=32, lr=0.01
)

In [None]:
# Compare results

fig, axes = plt.subplots(1, 3, figsize=(14, 4))

# Loss
axes[0].plot(resnet_losses, 'b-', label='ResNet')
axes[0].plot(plain_losses, 'r-', label='PlainNet')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Train accuracy
axes[1].plot(resnet_train_accs, 'b-', label='ResNet')
axes[1].plot(plain_train_accs, 'r-', label='PlainNet')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy')
axes[1].set_title('Training Accuracy')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# Test accuracy
axes[2].plot(resnet_test_accs, 'b-', label='ResNet')
axes[2].plot(plain_test_accs, 'r-', label='PlainNet')
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('Accuracy')
axes[2].set_title('Test Accuracy')
axes[2].legend()
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nFinal Results:")
print(f"ResNet   - Train Acc: {resnet_train_accs[-1]:.4f}, Test Acc: {resnet_test_accs[-1]:.4f}")
print(f"PlainNet - Train Acc: {plain_train_accs[-1]:.4f}, Test Acc: {plain_test_accs[-1]:.4f}")

---

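## Part 8: Gradient Flow Analysis

Compare the per-parameter gradient norms of the two networks. Both models have already been trained above, so these are gradients at the end of training rather than at initialization.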

In [None]:
def analyze_gradients(model, X_batch, y_batch):
    """Return the L2 norm of the gradient of each parameter tensor."""
    model.train()
    
    # Forward and backward
    logits = model.forward(X_batch)
    loss, dlogits = cross_entropy_loss(logits, y_batch)
    model.backward(dlogits)
    
    # Collect gradient statistics
    _, grads = model.get_params_and_grads()
    
    grad_norms = []
    for g in grads:
        if g is not None:
            grad_norms.append(np.linalg.norm(g))
    
    return grad_norms

# Compare gradients on one batch
X_batch = X_train[:32]
y_batch = y_train[:32]

resnet_grads = analyze_gradients(resnet, X_batch, y_batch)
plain_grads = analyze_gradients(plainnet, X_batch, y_batch)

# Plot
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

axes[0].bar(range(len(resnet_grads)), resnet_grads, alpha=0.7)
axes[0].set_xlabel('Parameter Index')
axes[0].set_ylabel('Gradient Norm')
axes[0].set_title('ResNet Gradient Norms')
axes[0].set_yscale('log')

axes[1].bar(range(len(plain_grads)), plain_grads, alpha=0.7, color='orange')
axes[1].set_xlabel('Parameter Index')
axes[1].set_ylabel('Gradient Norm')
axes[1].set_title('PlainNet Gradient Norms')
axes[1].set_yscale('log')

plt.tight_layout()
plt.show()

print(f"ResNet gradient norm range: [{min(resnet_grads):.2e}, {max(resnet_grads):.2e}]")
print(f"PlainNet gradient norm range: [{min(plain_grads):.2e}, {max(plain_grads):.2e}]")
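
One caveat when reading these ranges: the gradient lists also contain the conv biases, and every conv here is immediately followed by batch normalization. Assuming `BatchNorm2D` subtracts the per-channel batch mean in training mode (the usual behavior; the class is not shown in this section), a per-channel bias is cancelled out, so its gradient is zero up to rounding error and the minimum printed above is likely dominated by those entries. The sketch below checks the cancellation with a stand-in `batchnorm_train` function, which is a hypothetical simplification without `gamma`/`beta`.

```python
import numpy as np

def batchnorm_train(x, eps=1e-5):
    """Per-channel batch normalization in training mode, for x of shape (N, C, H, W)."""
    mean = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 4, 6, 6))       # stand-in for a conv output without bias
bias = rng.standard_normal((1, 4, 1, 1))    # hypothetical per-channel conv bias

out_without_bias = batchnorm_train(x)
out_with_bias = batchnorm_train(x + bias)

# The bias shifts the mean by the same amount and leaves the variance unchanged,
# so the normalized output (and hence the loss) does not depend on it.
max_diff = np.abs(out_with_bias - out_without_bias).max()
print(f"max |BN(x + b) - BN(x)| = {max_diff:.2e}")
```

For a cleaner picture, the bias entries could be left out of the bar charts, or only the conv and FC weight gradients compared.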

---

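## Summary

### Core contributions of ResNet

1. **Residual Connection**: $H(x) = F(x) + x$
   - Provides a "highway" along which gradients flow straight back to earlier layers
   - The network only has to learn the residual $F(x) = H(x) - x$

2. **Why it works**:
   - The block's Jacobian always contains an identity term ($\frac{\partial H}{\partial x} = \frac{\partial F}{\partial x} + I$), so the gradient still reaches earlier layers even when $\frac{\partial F}{\partial x}$ is small
   - If the optimal mapping is close to the identity, learning $F \approx 0$ is easy
   - Makes very deep networks trainable (152 layers and beyond)

3. **Implementation notes**:
   - Basic Block: Conv3x3 → BN → ReLU → Conv3x3 → BN → (+x) → ReLU
   - Bottleneck Block: 1x1 → 3x3 → 1x1 to reduce computation
   - When dimensions do not match, use a 1x1 conv to adjust the shortcut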
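
The role of the identity term in $\frac{\partial H}{\partial x} = \frac{\partial F}{\partial x} + I$ can be illustrated with a toy linear model. The sketch below backpropagates a unit gradient through 50 random Jacobians, with and without the added identity. It is only an illustration under strong assumptions (random Gaussian Jacobians with small norm, no ReLU or batch normalization), and `backprop_norms` and its parameters are hypothetical names, not part of the networks above.

```python
import numpy as np

def backprop_norms(n_layers, dim, scale, residual, seed=0):
    """Track the gradient norm while backpropagating through random linear layers.

    Each layer's Jacobian is J_F (plain) or J_F + I (residual), where J_F is a
    random matrix whose typical gain is about `scale`.
    """
    rng = np.random.default_rng(seed)
    grad = np.ones(dim) / np.sqrt(dim)          # unit-norm gradient at the output
    norms = [np.linalg.norm(grad)]
    for _ in range(n_layers):
        J_F = rng.standard_normal((dim, dim)) * scale / np.sqrt(dim)
        J = J_F + np.eye(dim) if residual else J_F
        grad = J.T @ grad                        # chain rule, one layer back
        norms.append(np.linalg.norm(grad))
    return norms

plain = backprop_norms(n_layers=50, dim=64, scale=0.1, residual=False)
residual = backprop_norms(n_layers=50, dim=64, scale=0.1, residual=True)

print(f"plain    gradient norm after 50 layers: {plain[-1]:.2e}")
print(f"residual gradient norm after 50 layers: {residual[-1]:.2e}")
```

In this toy setting the plain chain shrinks the gradient by roughly a factor of `scale` per layer, while the residual chain keeps it on the order of its starting value. Real networks are nonlinear, so this should be read as intuition rather than a prediction.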

### Next steps

- Module 6 Part 3: U-Net architecture (encoder-decoder + skip connections)