# 01 Batch Normalization

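## Learning Objectives

1. Understand what problem Batch Normalization addresses
2. Implement the BatchNorm forward pass (both training and inference modes)
3. Derive and implement the BatchNorm backward pass
4. Verify the implementation with gradient checking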

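## Why Batch Normalization?

**The Internal Covariate Shift problem**: during training, the input distribution of each layer changes as the parameters of the preceding layers change, so later layers must keep adapting to a new input distribution.

The core idea of **Batch Normalization** (Ioffe & Szegedy, 2015):
- Before each layer's activation, normalize the input to zero mean and unit variance
- Then use learnable parameters $\gamma$ (scale) and $\beta$ (shift) to restore the network's representational power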
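
To make the second point concrete, here is a minimal sketch (the toy data and variable names are illustrative assumptions, not part of the notebook's classes) showing that one particular choice of $\gamma$ and $\beta$ undoes the normalization entirely, so the layer can still represent the identity map if training favors it:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(loc=2.0, scale=3.0, size=(8, 4))  # toy batch, shape (N, D)

eps = 1e-5
mu = x.mean(axis=0)
var = x.var(axis=0)
x_hat = (x - mu) / np.sqrt(var + eps)  # normalized: zero mean, ~unit variance

# Choosing gamma = sqrt(var + eps) and beta = mu recovers the original input
gamma = np.sqrt(var + eps)
beta = mu
y = gamma * x_hat + beta

print(np.allclose(y, x))
```

In practice $\gamma$ and $\beta$ are learned, so the network is free to settle anywhere between fully normalized and fully restored activations.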

In [None]:
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(42)
print("Batch Normalization module loaded!")

## Part 1: BatchNorm Forward Pass

### Formulas

Given a mini-batch $\mathcal{B} = \{x_1, ..., x_m\}$:

1. **Compute the batch statistics**:
   $$\mu_{\mathcal{B}} = \frac{1}{m} \sum_{i=1}^{m} x_i$$
   $$\sigma^2_{\mathcal{B}} = \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu_{\mathcal{B}})^2$$

2. **Normalize**:
   $$\hat{x}_i = \frac{x_i - \mu_{\mathcal{B}}}{\sqrt{\sigma^2_{\mathcal{B}} + \epsilon}}$$

3. **Scale and shift**:
   $$y_i = \gamma \hat{x}_i + \beta$$

Here $\gamma$ and $\beta$ are learnable parameters, and $\epsilon$ is a small constant that keeps the denominator away from zero.
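
The three steps can be traced by hand on a tiny batch. The sketch below uses four samples of a single feature, with $\gamma = 2$ and $\beta = 0.5$ chosen purely for illustration:

```python
import numpy as np

x = np.array([[1.0], [2.0], [3.0], [4.0]])  # m = 4 samples, one feature
eps = 1e-5
gamma, beta = 2.0, 0.5  # illustrative values, not learned

mu = x.mean(axis=0)                    # step 1: batch mean = 2.5
var = ((x - mu) ** 2).mean(axis=0)     # step 1: batch variance = 1.25 (divides by m)
x_hat = (x - mu) / np.sqrt(var + eps)  # step 2: normalize
y = gamma * x_hat + beta               # step 3: scale and shift

print("mu =", mu, "var =", var)
print("x_hat =", x_hat.ravel())
print("y mean =", y.mean(), "y std =", y.std())
```

After step 3 the output has mean $\beta$ and standard deviation very close to $\gamma$; the small gap comes from $\epsilon$.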

### Training vs Inference

- **Training**: use the current batch's $\mu$ and $\sigma^2$, and update the running statistics at the same time
- **Inference**: use the running mean and running variance
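
As a rough illustration of how the running statistics behave, the sketch below applies the exponential moving average update `running = (1 - momentum) * running + momentum * batch` to a stream of synthetic batches. The data distribution (mean 5, standard deviation 2) and the batch size are assumptions made up for this example; the initial values 0 and 1 match those used for the running statistics in this notebook's classes.

```python
import numpy as np

rng = np.random.default_rng(0)
momentum = 0.1
running_mean, running_var = 0.0, 1.0  # initial running statistics

for step in range(200):
    batch = rng.normal(loc=5.0, scale=2.0, size=64)  # synthetic batch
    # Exponential moving average: momentum is the weight of the new batch
    running_mean = (1 - momentum) * running_mean + momentum * batch.mean()
    running_var = (1 - momentum) * running_var + momentum * batch.var()

print(f"running_mean = {running_mean:.2f} (data mean is 5.0)")
print(f"running_var  = {running_var:.2f} (data variance is 4.0)")
```

After enough steps the running values hover around the statistics of the data, which is why they can stand in for batch statistics at inference time.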

In [None]:
class BatchNorm1D:
    """
    1D Batch Normalization (used after fully connected layers)
    
    Input shape: (N, D)
    Each of the D feature dimensions is normalized separately
    """
    
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        """
        Parameters
        ----------
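        num_features : int
            Feature dimension D
        eps : float
            Small constant for numerical stability
        momentum : float
            Weight given to the current batch statistics when updating
            the running statistics (running = (1 - momentum) * running + momentum * batch)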
        """
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        
        # Learnable parameters
        self.gamma = np.ones(num_features)
        self.beta = np.zeros(num_features)
        
        # Running statistics (used at inference)
        self.running_mean = np.zeros(num_features)
        self.running_var = np.ones(num_features)
        
        # Gradients
        self.dgamma = None
        self.dbeta = None
        
        # Cache for the backward pass
        self.cache = None
        self.training = True
    
    def forward(self, x):
        """
        Forward pass
        
        Parameters
        ----------
        x : np.ndarray, shape (N, D)
        
        Returns
        -------
        y : np.ndarray, shape (N, D)
        """
        if self.training:
            # Compute batch statistics
            batch_mean = np.mean(x, axis=0)  # (D,)
            batch_var = np.var(x, axis=0)    # (D,)
            
            # Normalize
            x_centered = x - batch_mean
            std = np.sqrt(batch_var + self.eps)
            x_norm = x_centered / std
            
            # Scale and shift
            y = self.gamma * x_norm + self.beta
            
            # Update running statistics
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * batch_mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * batch_var
            
            # Save the cache for the backward pass
            self.cache = (x, x_centered, std, x_norm)
        else:
            # Use running statistics
            x_norm = (x - self.running_mean) / np.sqrt(self.running_var + self.eps)
            y = self.gamma * x_norm + self.beta
        
        return y
    
    def backward(self, dout):
        """
        Backward pass
        
        Parameters
        ----------
        dout : np.ndarray, shape (N, D)
        
        Returns
        -------
        dx : np.ndarray, shape (N, D)
        """
        x, x_centered, std, x_norm = self.cache
        N = x.shape[0]
        
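        # Gradients w.r.t. gamma and beta
        self.dgamma = np.sum(dout * x_norm, axis=0)
        self.dbeta = np.sum(dout, axis=0)
        
        # Gradient w.r.t. x (the most involved part);
        # see the backward-pass derivation section below
        dx_norm = dout * self.gamma
        
        # Gradient w.r.t. the batch variance (recall std = sqrt(var + eps))
        dvar = np.sum(dx_norm * x_centered, axis=0) * (-0.5) * (std ** -3)
        
        # Gradient w.r.t. the batch mean
        dmean = np.sum(dx_norm * (-1 / std), axis=0) + dvar * np.mean(-2 * x_centered, axis=0)
        
        # Final dx: direct path plus the paths through var and mean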
        dx = dx_norm / std + dvar * 2 * x_centered / N + dmean / N
        
        return dx
    
    def train(self):
        self.training = True
    
    def eval(self):
        self.training = False

# Test
bn = BatchNorm1D(num_features=5)
x = np.random.randn(10, 5) * 3 + 2  # non-standard distribution

print("Input statistics:")
print(f"  mean: {np.mean(x, axis=0)}")
print(f"  std: {np.std(x, axis=0)}")

y = bn.forward(x)

print("\nOutput statistics (after normalization):")
print(f"  mean: {np.mean(y, axis=0)}")
print(f"  std: {np.std(y, axis=0)}")
print("  (should be close to mean=0, std=1)")

## Part 2: Deriving the Backward Pass

This is the most involved part of BatchNorm. Let's derive it with a computation graph.

### Computation graph

```
x → [mean] → μ
         ↓
x → [x - μ] → x_centered → [mean(·²)] → var
                    ↓                      ↓
                    ↓                [sqrt(var + ε)] → std
                    ↓                      ↓
                    └──────→ [x_centered / std] → x_norm
                                                    ↓
                                            [γ * x_norm + β] → y
```

### Gradient derivation

Let $m = N$ (the batch size) and backpropagate step by step:

1. **With respect to $\gamma$ and $\beta$**:
   $$\frac{\partial L}{\partial \gamma} = \sum_i \frac{\partial L}{\partial y_i} \hat{x}_i$$
   $$\frac{\partial L}{\partial \beta} = \sum_i \frac{\partial L}{\partial y_i}$$

2. **With respect to $\hat{x}$**:
   $$\frac{\partial L}{\partial \hat{x}_i} = \frac{\partial L}{\partial y_i} \cdot \gamma$$

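3. **With respect to $x$** (the most involved step):
   $\hat{x} = \frac{x - \mu}{\sigma}$, where $\sigma = \sqrt{\sigma^2_{\mathcal{B}} + \epsilon}$ and both $\mu$ and $\sigma$ depend on $x$. Applying the chain rule through the variance and the mean gives the intermediate gradients:
   $$\frac{\partial L}{\partial \sigma^2_{\mathcal{B}}} = -\frac{1}{2\sigma^3} \sum_j \frac{\partial L}{\partial \hat{x}_j} (x_j - \mu)$$
   $$\frac{\partial L}{\partial \mu} = -\frac{1}{\sigma} \sum_j \frac{\partial L}{\partial \hat{x}_j} - \frac{2}{m} \frac{\partial L}{\partial \sigma^2_{\mathcal{B}}} \sum_j (x_j - \mu)$$
   $$\frac{\partial L}{\partial x_i} = \frac{1}{\sigma} \frac{\partial L}{\partial \hat{x}_i} + \frac{2 (x_i - \mu)}{m} \frac{\partial L}{\partial \sigma^2_{\mathcal{B}}} + \frac{1}{m} \frac{\partial L}{\partial \mu}$$

   The second term of $\partial L / \partial \mu$ vanishes because $\sum_j (x_j - \mu) = 0$. Substituting $x_j - \mu = \sigma \hat{x}_j$ then yields the final result:
   $$\frac{\partial L}{\partial x_i} = \frac{1}{m\sigma} \left[ m \frac{\partial L}{\partial \hat{x}_i} - \sum_j \frac{\partial L}{\partial \hat{x}_j} - \hat{x}_i \sum_j \frac{\partial L}{\partial \hat{x}_j} \hat{x}_j \right]$$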
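
As a sanity check on the final formula, the sketch below compares it against a step-by-step chain-rule computation through the variance and the mean. The two helper functions (`bn_backward_staged`, `bn_backward_closed`) are hypothetical names introduced only for this comparison, and the inputs are random toy data.

```python
import numpy as np

def bn_backward_staged(dxhat, x, eps=1e-5):
    """Chain rule through var and mean, one node at a time."""
    m = x.shape[0]
    xc = x - x.mean(axis=0)
    std = np.sqrt(x.var(axis=0) + eps)
    dvar = np.sum(dxhat * xc, axis=0) * (-0.5) * std ** -3
    dmu = np.sum(-dxhat / std, axis=0) + dvar * np.mean(-2 * xc, axis=0)
    return dxhat / std + dvar * 2 * xc / m + dmu / m

def bn_backward_closed(dxhat, x, eps=1e-5):
    """Closed-form expression for dL/dx in terms of dL/dx_hat and x_hat."""
    m = x.shape[0]
    std = np.sqrt(x.var(axis=0) + eps)
    xhat = (x - x.mean(axis=0)) / std
    bracket = m * dxhat - dxhat.sum(axis=0) - xhat * np.sum(dxhat * xhat, axis=0)
    return bracket / (m * std)

rng = np.random.default_rng(0)
x = rng.normal(size=(6, 3))      # toy batch, shape (N, D)
dxhat = rng.normal(size=(6, 3))  # stand-in for dL/dx_hat

print(np.allclose(bn_backward_staged(dxhat, x), bn_backward_closed(dxhat, x)))
```

Both routes should agree to floating-point precision; the closed form simply needs fewer intermediate arrays.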

In [None]:
# Gradient check
def gradient_check_bn1d(bn, x, eps=1e-5):
    """
    Run a gradient check on BatchNorm1D
    """
    bn.train()
    
    # Forward pass
    y = bn.forward(x)
    
    # Assume loss = sum(y^2)
    dout = 2 * y
    
    # Backward pass
    dx = bn.backward(dout)
    
    all_passed = True
    
    # === Check dgamma ===
    print("=== Check dgamma ===")
    dgamma_numerical = np.zeros_like(bn.gamma)
    
    for j in range(len(bn.gamma)):
        old_val = bn.gamma[j]
        
        bn.gamma[j] = old_val + eps
        y_plus = bn.forward(x)
        loss_plus = np.sum(y_plus ** 2)
        
        bn.gamma[j] = old_val - eps
        y_minus = bn.forward(x)
        loss_minus = np.sum(y_minus ** 2)
        
        bn.gamma[j] = old_val
        dgamma_numerical[j] = (loss_plus - loss_minus) / (2 * eps)
    
    rel_error = np.max(np.abs(bn.dgamma - dgamma_numerical) / (np.abs(bn.dgamma) + np.abs(dgamma_numerical) + 1e-8))
    print(f"  Max relative error: {rel_error:.2e}")
    print(f"  Passed: {rel_error < 1e-4}")
    if rel_error > 1e-4:
        all_passed = False
    
    # === Check dbeta ===
    print("\n=== Check dbeta ===")
    dbeta_numerical = np.zeros_like(bn.beta)
    
    for j in range(len(bn.beta)):
        old_val = bn.beta[j]
        
        bn.beta[j] = old_val + eps
        y_plus = bn.forward(x)
        loss_plus = np.sum(y_plus ** 2)
        
        bn.beta[j] = old_val - eps
        y_minus = bn.forward(x)
        loss_minus = np.sum(y_minus ** 2)
        
        bn.beta[j] = old_val
        dbeta_numerical[j] = (loss_plus - loss_minus) / (2 * eps)
    
    rel_error = np.max(np.abs(bn.dbeta - dbeta_numerical) / (np.abs(bn.dbeta) + np.abs(dbeta_numerical) + 1e-8))
    print(f"  Max relative error: {rel_error:.2e}")
    print(f"  Passed: {rel_error < 1e-4}")
    if rel_error > 1e-4:
        all_passed = False
    
    # === Check dx ===
    print("\n=== Check dx ===")
    dx_numerical = np.zeros_like(x)
    x_test = x.copy()
    
    num_checks = min(10, x.size)
    indices = np.random.choice(x.size, num_checks, replace=False)
    
    for idx in indices:
        multi_idx = np.unravel_index(idx, x.shape)
        old_val = x_test[multi_idx]
        
        x_test[multi_idx] = old_val + eps
        y_plus = bn.forward(x_test)
        loss_plus = np.sum(y_plus ** 2)
        
        x_test[multi_idx] = old_val - eps
        y_minus = bn.forward(x_test)
        loss_minus = np.sum(y_minus ** 2)
        
        x_test[multi_idx] = old_val
        dx_numerical[multi_idx] = (loss_plus - loss_minus) / (2 * eps)
    
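    # Compare analytic and numerical gradients at the sampled positions
    dx_passed = True
    for idx in indices:
        multi_idx = np.unravel_index(idx, x.shape)
        ana = dx[multi_idx]
        num = dx_numerical[multi_idx]
        error = abs(ana - num) / (abs(ana) + abs(num) + 1e-8)
        if error > 1e-4:
            print(f"  Position {multi_idx}: analytic={ana:.6f}, numerical={num:.6f}, error={error:.2e} ❌")
            dx_passed = False
    
    # Report on the dx check alone, independent of the dgamma/dbeta results
    if dx_passed:
        print(f"  All {num_checks} sampled positions passed ✓")
    
    return all_passed and dx_passed

# Run the gradient check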
bn = BatchNorm1D(num_features=5)
x = np.random.randn(10, 5)
gradient_check_bn1d(bn, x)

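## Part 3: BatchNorm2D (for convolutional layers)

For convolutional layers the input has shape $(N, C, H, W)$. BatchNorm2D computes statistics separately for each channel $C$, across the $N, H, W$ dimensions.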
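
One way to read this: every spatial position of every sample counts as one "row" of a batch with $C$ features. The sketch below (toy data, illustrative variable names) checks that normalizing over `axis=(0, 2, 3)` gives the same result as reshaping to `(N*H*W, C)` and normalizing along axis 0, as the 1D case does:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=5.0, size=(4, 3, 8, 8))  # (N, C, H, W)
eps = 1e-5

# Per-channel statistics taken across N, H, W
mean = x.mean(axis=(0, 2, 3), keepdims=True)
var = x.var(axis=(0, 2, 3), keepdims=True)
y_4d = (x - mean) / np.sqrt(var + eps)

# The same normalization via a (N*H*W, C) matrix normalized along axis 0
N, C, H, W = x.shape
flat = x.transpose(0, 2, 3, 1).reshape(-1, C)
flat_norm = (flat - flat.mean(axis=0)) / np.sqrt(flat.var(axis=0) + eps)
y_via_2d = flat_norm.reshape(N, H, W, C).transpose(0, 3, 1, 2)

print(np.allclose(y_4d, y_via_2d))
```

This also explains why the backward pass uses $m = N \cdot H \cdot W$ as the effective batch size.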

In [None]:
class BatchNorm2D:
    """
    2D Batch Normalization (used after convolutional layers)
    
    Input shape: (N, C, H, W)
    Each channel C is normalized separately (across the N, H, W dimensions)
    """
    
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        """
        Parameters
        ----------
        num_features : int
            Number of channels C
        """
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        
        # Learnable parameters
        self.gamma = np.ones(num_features)
        self.beta = np.zeros(num_features)
        
        # Running statistics
        self.running_mean = np.zeros(num_features)
        self.running_var = np.ones(num_features)
        
        self.dgamma = None
        self.dbeta = None
        self.cache = None
        self.training = True
    
    def forward(self, x):
        """
        Forward pass
        
        Parameters
        ----------
        x : np.ndarray, shape (N, C, H, W)
        
        Returns
        -------
        y : np.ndarray, shape (N, C, H, W)
        """
        N, C, H, W = x.shape
        
        if self.training:
            # Compute each channel's mean and var (across N, H, W)
            batch_mean = np.mean(x, axis=(0, 2, 3))  # (C,)
            batch_var = np.var(x, axis=(0, 2, 3))    # (C,)
            
            # Broadcast shape (C,) -> (1, C, 1, 1)
            mean_bc = batch_mean.reshape(1, C, 1, 1)
            var_bc = batch_var.reshape(1, C, 1, 1)
            
            # Normalize
            x_centered = x - mean_bc
            std = np.sqrt(var_bc + self.eps)
            x_norm = x_centered / std
            
            # Scale and shift
            gamma_bc = self.gamma.reshape(1, C, 1, 1)
            beta_bc = self.beta.reshape(1, C, 1, 1)
            y = gamma_bc * x_norm + beta_bc
            
            # Update running statistics
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * batch_mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * batch_var
            
            self.cache = (x, x_centered, std, x_norm, batch_mean, batch_var)
        else:
            mean_bc = self.running_mean.reshape(1, C, 1, 1)
            var_bc = self.running_var.reshape(1, C, 1, 1)
            gamma_bc = self.gamma.reshape(1, C, 1, 1)
            beta_bc = self.beta.reshape(1, C, 1, 1)
            
            x_norm = (x - mean_bc) / np.sqrt(var_bc + self.eps)
            y = gamma_bc * x_norm + beta_bc
        
        return y
    
    def backward(self, dout):
        """
        Backward pass
        """
        x, x_centered, std, x_norm, batch_mean, batch_var = self.cache
        N, C, H, W = x.shape
        m = N * H * W  # number of elements each statistic is computed over
        
        # Broadcast shape
        gamma_bc = self.gamma.reshape(1, C, 1, 1)
        
        # Gradients w.r.t. gamma and beta
        self.dgamma = np.sum(dout * x_norm, axis=(0, 2, 3))
        self.dbeta = np.sum(dout, axis=(0, 2, 3))
        
        # Gradient w.r.t. x
        dx_norm = dout * gamma_bc
        
        # Use the simplified closed-form formula (more efficient)
        # dx = (1/m) / std * (m * dx_norm - sum(dx_norm) - x_norm * sum(dx_norm * x_norm))
        sum_dx_norm = np.sum(dx_norm, axis=(0, 2, 3), keepdims=True)
        sum_dx_norm_xnorm = np.sum(dx_norm * x_norm, axis=(0, 2, 3), keepdims=True)
        
        dx = (1.0 / m) / std * (m * dx_norm - sum_dx_norm - x_norm * sum_dx_norm_xnorm)
        
        return dx
    
    def train(self):
        self.training = True
    
    def eval(self):
        self.training = False

# Test BatchNorm2D
bn2d = BatchNorm2D(num_features=3)
x = np.random.randn(4, 3, 8, 8) * 5 + 3  # non-standard distribution

print("Input statistics (per channel):")
for c in range(3):
    print(f"  Channel {c}: mean={np.mean(x[:, c]):.4f}, std={np.std(x[:, c]):.4f}")

y = bn2d.forward(x)

print("\nOutput statistics (after normalization):")
for c in range(3):
    print(f"  Channel {c}: mean={np.mean(y[:, c]):.4f}, std={np.std(y[:, c]):.4f}")

In [None]:
# BatchNorm2D gradient check
def gradient_check_bn2d(bn, x, eps=1e-5):
    """Run a gradient check on BatchNorm2D"""
    bn.train()
    
    y = bn.forward(x)
    dout = 2 * y
    dx = bn.backward(dout)
    
    all_passed = True
    
    # Check dgamma
    print("=== Check dgamma ===")
    dgamma_numerical = np.zeros_like(bn.gamma)
    for j in range(len(bn.gamma)):
        old_val = bn.gamma[j]
        
        bn.gamma[j] = old_val + eps
        loss_plus = np.sum(bn.forward(x) ** 2)
        
        bn.gamma[j] = old_val - eps
        loss_minus = np.sum(bn.forward(x) ** 2)
        
        bn.gamma[j] = old_val
        dgamma_numerical[j] = (loss_plus - loss_minus) / (2 * eps)
    
    rel_error = np.max(np.abs(bn.dgamma - dgamma_numerical) / (np.abs(bn.dgamma) + np.abs(dgamma_numerical) + 1e-8))
    print(f"  Max relative error: {rel_error:.2e}")
    print(f"  Passed: {rel_error < 1e-4}")
    if rel_error > 1e-4:
        all_passed = False
    
    # Check dx (sampled positions)
    print("\n=== Check dx ===")
    dx_numerical = np.zeros_like(x)
    x_test = x.copy()
    
    num_checks = min(10, x.size)
    indices = np.random.choice(x.size, num_checks, replace=False)
    
    for idx in indices:
        multi_idx = np.unravel_index(idx, x.shape)
        old_val = x_test[multi_idx]
        
        x_test[multi_idx] = old_val + eps
        loss_plus = np.sum(bn.forward(x_test) ** 2)
        
        x_test[multi_idx] = old_val - eps
        loss_minus = np.sum(bn.forward(x_test) ** 2)
        
        x_test[multi_idx] = old_val
        dx_numerical[multi_idx] = (loss_plus - loss_minus) / (2 * eps)
    
    max_error = 0
    for idx in indices:
        multi_idx = np.unravel_index(idx, x.shape)
        error = abs(dx[multi_idx] - dx_numerical[multi_idx]) / (abs(dx[multi_idx]) + abs(dx_numerical[multi_idx]) + 1e-8)
        max_error = max(max_error, error)
    
    print(f"  Max relative error: {max_error:.2e}")
    print(f"  Passed: {max_error < 1e-4}")
    if max_error > 1e-4:
        all_passed = False
    
    return all_passed

bn2d = BatchNorm2D(num_features=3)
x = np.random.randn(2, 3, 4, 4)
gradient_check_bn2d(bn2d, x)

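## Part 4: The Effect of BatchNorm

Let's visualize the effect of BatchNorm, here on the distribution of activations in a deep network at initialization.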

In [None]:
# Simple networks for comparing behavior with and without BatchNorm

class FCWithBN:
    """Fully connected layer followed by BatchNorm"""
    def __init__(self, in_features, out_features):
        std = np.sqrt(2.0 / in_features)
        self.W = np.random.randn(in_features, out_features) * std
        self.b = np.zeros(out_features)
        self.bn = BatchNorm1D(out_features)
        
        self.dW = None
        self.db = None
        self.cache = None
    
    def forward(self, X):
        self.cache = X
        z = X @ self.W + self.b
        out = self.bn.forward(z)
        return out
    
    def backward(self, dout):
        dz = self.bn.backward(dout)
        X = self.cache
        self.dW = X.T @ dz
        self.db = np.sum(dz, axis=0)
        return dz @ self.W.T


def relu(x):
    return np.maximum(0, x)

def relu_backward(dout, x):
    return dout * (x > 0)


class DeepNetWithBN:
    """Deep network (with BatchNorm)"""
    def __init__(self, layer_dims):
        self.layers = []
        for i in range(len(layer_dims) - 1):
            self.layers.append(FCWithBN(layer_dims[i], layer_dims[i+1]))
        self.relu_cache = []
    
    def forward(self, X):
        self.relu_cache = []
        out = X
        for i, layer in enumerate(self.layers[:-1]):
            out = layer.forward(out)
            self.relu_cache.append(out)
            out = relu(out)
        out = self.layers[-1].forward(out)
        return out
    
    def backward(self, dout):
        dout = self.layers[-1].backward(dout)
        for i, layer in enumerate(reversed(self.layers[:-1])):
            dout = relu_backward(dout, self.relu_cache[-(i+1)])
            dout = layer.backward(dout)
    
    def get_params_and_grads(self):
        params = []
        for layer in self.layers:
            params.append((layer.W, layer.dW))
            params.append((layer.b, layer.db))
            params.append((layer.bn.gamma, layer.bn.dgamma))
            params.append((layer.bn.beta, layer.bn.dbeta))
        return params


class DeepNetNoBN:
    """Deep network (without BatchNorm)"""
    def __init__(self, layer_dims):
        self.Ws = []
        self.bs = []
        self.dWs = []
        self.dbs = []
        
        for i in range(len(layer_dims) - 1):
            std = np.sqrt(2.0 / layer_dims[i])
            self.Ws.append(np.random.randn(layer_dims[i], layer_dims[i+1]) * std)
            self.bs.append(np.zeros(layer_dims[i+1]))
            self.dWs.append(None)
            self.dbs.append(None)
        
        self.caches = []
    
    def forward(self, X):
        self.caches = [X]
        out = X
        for i in range(len(self.Ws) - 1):
            out = out @ self.Ws[i] + self.bs[i]
            self.caches.append(out)
            out = relu(out)
        out = out @ self.Ws[-1] + self.bs[-1]
        return out
    
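    def backward(self, dout):
        # Walk the layers in reverse. caches[0] is the input X and
        # caches[i] (i >= 1) is the pre-activation output of layer i-1.
        for i in reversed(range(len(self.Ws))):
            # Input to layer i: X for the first layer, ReLU(pre-activation) otherwise
            a_prev = self.caches[0] if i == 0 else relu(self.caches[i])
            self.dWs[i] = a_prev.T @ dout
            self.dbs[i] = np.sum(dout, axis=0)
            if i > 0:
                # Propagate through the weights, then through the preceding ReLU
                dout = relu_backward(dout @ self.Ws[i].T, self.caches[i])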
    
    def get_params_and_grads(self):
        params = []
        for i in range(len(self.Ws)):
            params.append((self.Ws[i], self.dWs[i]))
            params.append((self.bs[i], self.dbs[i]))
        return params

print("Networks defined!")

In [None]:
# Visualize the effect of BatchNorm on the distribution of activations

fig, axes = plt.subplots(2, 4, figsize=(16, 8))

# Deep network dimensions
layer_dims = [20, 50, 50, 50, 50, 10]

# Without BatchNorm
np.random.seed(42)
X = np.random.randn(100, 20)
activations_no_bn = [X]

for i in range(len(layer_dims) - 1):
    W = np.random.randn(layer_dims[i], layer_dims[i+1]) * 0.1
    b = np.zeros(layer_dims[i+1])
    out = X @ W + b
    if i < len(layer_dims) - 2:
        out = relu(out)
    activations_no_bn.append(out)
    X = out

# With BatchNorm
np.random.seed(42)
X = np.random.randn(100, 20)
activations_bn = [X]

for i in range(len(layer_dims) - 1):
    W = np.random.randn(layer_dims[i], layer_dims[i+1]) * 0.1
    b = np.zeros(layer_dims[i+1])
    out = X @ W + b
    
    # BatchNorm
    mean = np.mean(out, axis=0)
    var = np.var(out, axis=0)
    out = (out - mean) / np.sqrt(var + 1e-5)
    
    if i < len(layer_dims) - 2:
        out = relu(out)
    activations_bn.append(out)
    X = out

# Plot
for i in range(4):
    # Without BatchNorm
    ax = axes[0, i]
    ax.hist(activations_no_bn[i+1].flatten(), bins=50, alpha=0.7)
    ax.set_title(f'No BN - Layer {i+1}')
    ax.set_xlim(-3, 3)
    
    # With BatchNorm
    ax = axes[1, i]
    ax.hist(activations_bn[i+1].flatten(), bins=50, alpha=0.7, color='orange')
    ax.set_title(f'With BN - Layer {i+1}')
    ax.set_xlim(-3, 3)

axes[0, 0].set_ylabel('No BatchNorm')
axes[1, 0].set_ylabel('With BatchNorm')

plt.tight_layout()
plt.show()

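print("\nObservations:")
print("- Without BatchNorm: the activation distribution changes from layer to layer (with the 0.1 weight scale used here it narrows toward zero)")
print("- With BatchNorm: the activation scale stays stable across layers; the plotted values are post-ReLU, so each histogram is a rectified version of a roughly standard normal distribution")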

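## Exercises

### Exercise 1: Implement Layer Normalization

Unlike Batch Normalization, Layer Normalization normalizes over the feature dimensions of each individual sample rather than across the batch.

- **BatchNorm**: normalizes across the batch dimension
- **LayerNorm**: normalizes across the feature dimensions

LayerNorm is widely used in Transformers.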
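
For a 2D input the difference comes down to which axis the statistics are taken over. The sketch below uses a hypothetical `normalize` helper (without $\gamma$ and $\beta$, for brevity) and then swaps out every sample except the first to see which normalization changes the first sample's output:

```python
import numpy as np

def normalize(x, axis, eps=1e-5):
    """Zero-mean, unit-variance normalization along the given axis."""
    mean = x.mean(axis=axis, keepdims=True)
    var = x.var(axis=axis, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
x = rng.normal(loc=2.0, scale=3.0, size=(5, 10))  # (N, D)

bn_out = normalize(x, axis=0)  # BatchNorm-style: per feature, across samples
ln_out = normalize(x, axis=1)  # LayerNorm-style: per sample, across features

# Replace all samples except the first one
x2 = x.copy()
x2[1:] = rng.normal(size=(4, 10))

ln_same = np.allclose(normalize(x2, axis=1)[0], ln_out[0])
bn_same = np.allclose(normalize(x2, axis=0)[0], bn_out[0])
print("LayerNorm output for sample 0 unchanged:", ln_same)
print("BatchNorm output for sample 0 unchanged:", bn_same)
```

The LayerNorm result for a sample depends only on that sample, whereas the BatchNorm result depends on whatever else happens to be in the batch.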

In [None]:
class LayerNorm:
    """
    Layer Normalization
    
    Normalizes over the feature dimensions of each sample
    Input shape: (N, D) or (N, C, H, W)
    """
    
    def __init__(self, normalized_shape, eps=1e-5):
        """
        Parameters
        ----------
        normalized_shape : int or tuple
            Shape of the trailing dimensions to normalize over
        """
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = normalized_shape
        self.eps = eps
        
        # Learnable parameters
        self.gamma = np.ones(normalized_shape)
        self.beta = np.zeros(normalized_shape)
        
        self.dgamma = None
        self.dbeta = None
        self.cache = None
    
    def forward(self, x):
        """
        Forward pass
        """
        # Solution:
        # Determine the axes to normalize over (the trailing dimensions)
        num_axes = len(self.normalized_shape)
        axes = tuple(range(-num_axes, 0))
        
        # Compute each sample's mean and var
        mean = np.mean(x, axis=axes, keepdims=True)
        var = np.var(x, axis=axes, keepdims=True)
        
        # Normalize
        x_centered = x - mean
        std = np.sqrt(var + self.eps)
        x_norm = x_centered / std
        
        # Scale and shift
        y = self.gamma * x_norm + self.beta
        
        self.cache = (x, x_centered, std, x_norm, axes)
        return y
    
    def backward(self, dout):
        """
        Backward pass
        """
        x, x_centered, std, x_norm, axes = self.cache
        
        # Number of elements in the normalized dimensions
        m = 1
        for ax in axes:
            m *= x.shape[ax]
        
        # Gradients w.r.t. gamma and beta (summed over the leading, non-normalized axes)
        self.dgamma = np.sum(dout * x_norm, axis=tuple(range(x.ndim - len(self.normalized_shape))))
        self.dbeta = np.sum(dout, axis=tuple(range(x.ndim - len(self.normalized_shape))))
        
        # Gradient w.r.t. x
        dx_norm = dout * self.gamma
        
        # Use the simplified closed-form formula
        sum_dx_norm = np.sum(dx_norm, axis=axes, keepdims=True)
        sum_dx_norm_xnorm = np.sum(dx_norm * x_norm, axis=axes, keepdims=True)
        
        dx = (1.0 / m) / std * (m * dx_norm - sum_dx_norm - x_norm * sum_dx_norm_xnorm)
        
        return dx

# Test LayerNorm
ln = LayerNorm(normalized_shape=10)
x = np.random.randn(5, 10) * 3 + 2

print("Input statistics (per sample):")
for i in range(3):
    print(f"  Sample {i}: mean={np.mean(x[i]):.4f}, std={np.std(x[i]):.4f}")

y = ln.forward(x)

print("\nOutput statistics (after normalization):")
for i in range(3):
    print(f"  Sample {i}: mean={np.mean(y[i]):.4f}, std={np.std(y[i]):.4f}")

print("\nComparison:")
print("- BatchNorm normalizes each feature dimension across samples")
print("- LayerNorm normalizes each sample across feature dimensions")

## Summary

In this notebook we covered:

### Batch Normalization formulas

1. **Compute statistics**: $\mu = \frac{1}{m}\sum x_i$, $\sigma^2 = \frac{1}{m}\sum (x_i - \mu)^2$
2. **Normalize**: $\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}$
3. **Scale and shift**: $y = \gamma \hat{x} + \beta$

### Training vs Inference

| Phase | Statistics used | Updates running stats |
|-------|-----------------|-----------------------|
| Training | batch mean/var | Yes |
| Inference | running mean/var | No |
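
Because the running statistics are fixed at inference time, the whole layer reduces to a per-feature affine map $y = a x + b$. The sketch below checks this with made-up parameter values; `scale` and `shift` are names introduced here for illustration only:

```python
import numpy as np

rng = np.random.default_rng(0)
D, eps = 4, 1e-5

# Made-up "trained" parameters and running statistics
gamma, beta = rng.normal(size=D), rng.normal(size=D)
running_mean = rng.normal(size=D)
running_var = rng.uniform(0.5, 2.0, size=D)

x = rng.normal(size=(3, D))

# Inference-mode BatchNorm, following the formulas
y_ref = gamma * (x - running_mean) / np.sqrt(running_var + eps) + beta

# Equivalent fixed affine map
scale = gamma / np.sqrt(running_var + eps)
shift = beta - scale * running_mean
y_affine = scale * x + shift

print(np.allclose(y_ref, y_affine))

# A single sample gives the same output alone as inside the batch
y_single = scale * x[:1] + shift
print(np.allclose(y_single[0], y_ref[0]))
```

This is why inference results do not depend on how samples are grouped into batches.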

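### Benefits of BatchNorm

1. **Reduces Internal Covariate Shift** (the motivation given in the original paper): stabilizes training
2. **Allows larger learning rates**: speeds up convergence
3. **Mild regularization effect**: because each batch has different statistics
4. **Reduces sensitivity to initialization**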
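
The regularization point can be seen directly: in training mode, the output for a given sample shifts depending on which other samples share its batch. A minimal sketch, assuming random toy data and a hypothetical `bn_train` helper without $\gamma$ and $\beta$:

```python
import numpy as np

def bn_train(x, eps=1e-5):
    """Training-mode normalization using the batch's own statistics."""
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)

rng = np.random.default_rng(0)
sample = rng.normal(size=(1, 3))  # the sample we track

outputs = []
for _ in range(5):
    others = rng.normal(size=(15, 3))    # different batch companions each time
    batch = np.vstack([sample, others])  # tracked sample is row 0
    outputs.append(bn_train(batch)[0])
outputs = np.array(outputs)

print("Output for the same sample in 5 different batches:")
print(outputs)
print("Std across batches:", outputs.std(axis=0))
```

The spread across batches acts like a small amount of noise injected into the activations during training.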

### BatchNorm vs LayerNorm

| Property | BatchNorm | LayerNorm |
|----------|-----------|-----------|
| Normalization axis | Across the batch | Across features |
| Typical use | CNN | Transformer, RNN |
| Depends on batch size | Yes | No |
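
The batch-size dependence is most visible in the extreme case of a single-sample batch. In the sketch below (toy data, $\gamma$ and $\beta$ omitted, `normalize` is an illustrative helper), batch statistics over one sample leave nothing to normalize, while per-sample statistics behave as usual:

```python
import numpy as np

def normalize(x, axis, eps=1e-5):
    mean = x.mean(axis=axis, keepdims=True)
    var = x.var(axis=axis, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
x = rng.normal(loc=2.0, scale=3.0, size=(1, 10))  # batch size N = 1

bn_out = normalize(x, axis=0)  # each feature's mean equals its only value
ln_out = normalize(x, axis=1)  # statistics over the 10 features of the sample

print("BatchNorm-style output:", bn_out.ravel())
print("LayerNorm-style output std:", ln_out.std())
```

With $N = 1$ the BatchNorm-style output collapses to all zeros (so a full layer would output just $\beta$), whereas the LayerNorm-style output still has unit scale.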

### Next steps

Next we will study the **ResNet Block**, which combines BatchNorm with residual connections to build deeper networks!