# U-Net Architecture - A Semantic Segmentation Network

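## Learning Objectives

1. Understand the semantic segmentation task
2. Understand the encoder-decoder architecture
3. Understand the role of skip connections in U-Net
4. Implement upsampling methods: nearest neighbor, bilinear, and transposed convolution
5. Implement a simplified U-Net (with forward and backward passes)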

## References

- Ronneberger et al., "U-Net: Convolutional Networks for Biomedical Image Segmentation", MICCAI 2015

In [None]:
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(42)

---

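## Part 1: The Semantic Segmentation Problem

### 1.1 What Is Semantic Segmentation?

- **Classification**: one label for the whole image, e.g. "this is a cat"
- **Object detection**: locate objects with bounding boxes
- **Semantic segmentation**: **every pixel** gets a class label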

```
Input:  (H, W, 3) color image
Output: (H, W) class index per pixel
        or (H, W, C) per-pixel probabilities over C classes
```
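
The two output forms above are related by an argmax over the class axis. The following is a minimal sketch on random scores; the sizes `H`, `W`, `C` and the random `logits` are illustrative assumptions, not part of the notebook's model.

```python
import numpy as np

rng = np.random.default_rng(0)
H, W, C = 4, 5, 3                       # toy sizes, chosen only for illustration

logits = rng.normal(size=(H, W, C))     # per-pixel class scores (random stand-in)

# Softmax over the class axis gives the (H, W, C) probability map.
shifted = logits - logits.max(axis=-1, keepdims=True)
probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)

# Argmax over the class axis gives the (H, W) label map.
labels = probs.argmax(axis=-1)

print(probs.shape, labels.shape)        # (4, 5, 3) (4, 5)
```

Note that this sketch uses a channels-last layout to match the shapes written above, while the layers later in the notebook use channels-first `(N, C, H, W)`.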

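### 1.2 Challenges

- The output has the **same spatial size** as the input
- The network needs both:
  - **High-level features**: to understand semantics (what the object is)
  - **Low-level features**: to preserve spatial detail (where the boundaries are)

A plain CNN shrinks its feature maps through repeated pooling, so fine spatial detail is lost and cannot be recovered from the small feature maps alone.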
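
A small sketch of why pooling discards position information: two images whose vertical line sits in different columns become identical after one 2x2 max pool. The helper `max_pool_2x2` and the toy images are assumptions made for this illustration only.

```python
import numpy as np

def max_pool_2x2(a):
    """2x2 max pooling with stride 2 on a 2D array with even side lengths."""
    h, w = a.shape
    return a.reshape(h // 2, 2, w // 2, 2).max(axis=(1, 3))

img_a = np.zeros((8, 8))
img_a[:, 2] = 1.0                        # vertical line in column 2
img_b = np.zeros((8, 8))
img_b[:, 3] = 1.0                        # vertical line in column 3

pooled_a = max_pool_2x2(img_a)
pooled_b = max_pool_2x2(img_b)

print(pooled_a.shape)                        # (4, 4)
print(np.array_equal(pooled_a, pooled_b))    # True: the exact column is gone
```

Columns 2 and 3 fall into the same pooling window, so no upsampling of the pooled map alone can tell the two inputs apart.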

In [None]:
# Demo: a semantic segmentation task

def create_segmentation_example():
    """Create a simple example segmentation task."""
    img_size = 64
    
    # Create the image (containing a circle and a square)
    image = np.zeros((img_size, img_size))
    mask = np.zeros((img_size, img_size), dtype=int)  # 0: background, 1: circle, 2: square
    
    # Draw the circle
    y, x = np.ogrid[:img_size, :img_size]
    cx, cy, r = 20, 20, 12
    circle_mask = (x - cx)**2 + (y - cy)**2 <= r**2
    image[circle_mask] = 0.8
    mask[circle_mask] = 1
    
    # Draw the square
    sx, sy, size = 40, 35, 15
    image[sy:sy+size, sx:sx+size] = 0.6
    mask[sy:sy+size, sx:sx+size] = 2
    
    # Add noise
    image += np.random.randn(img_size, img_size) * 0.1
    
    return image, mask

image, mask = create_segmentation_example()

fig, axes = plt.subplots(1, 3, figsize=(12, 4))

axes[0].imshow(image, cmap='gray')
axes[0].set_title('Input Image')
axes[0].axis('off')

# vmax=9 maps labels 0, 1, 2 onto the first three tab10 colors
axes[1].imshow(mask, cmap='tab10', vmin=0, vmax=9)
axes[1].set_title('Ground Truth Mask')
axes[1].axis('off')

# Show the overlay
axes[2].imshow(image, cmap='gray')
axes[2].imshow(mask, cmap='tab10', alpha=0.5, vmin=0, vmax=9)
axes[2].set_title('Overlay')
axes[2].axis('off')

plt.tight_layout()
plt.show()

print("Classes: 0=background (blue), 1=circle (orange), 2=square (green)")

---

## Part 2: U-Net Architecture Overview

### 2.1 Encoder-Decoder Structure

```
                        U-Net Architecture
    
    Encoder (Contracting)          Decoder (Expanding)
    
    [64x64, 64ch] ─────────────────────────────────────> [64x64, n_classes]
          │                                                    ↑
          │ MaxPool                                   Upsample │
          ↓                                                    │
    [32x32, 128ch] ─────────────────────────────────> [32x32, 128ch]
          │                     Skip Connection              ↑
          │ MaxPool                                   Upsample │
          ↓                                                    │
    [16x16, 256ch] ─────────────────────────────────> [16x16, 256ch]
          │                     Skip Connection              ↑
          │ MaxPool                                   Upsample │
          ↓                                                    │
    [8x8, 512ch]  ──────────────────────────────────> [8x8, 512ch]
          │                                                    ↑
          └──────────> [Bottleneck] ───────────────────────────┘
```
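
The size and channel pattern on the encoder side of the diagram follows a simple rule: each level halves the spatial size and doubles the channel count. A minimal sketch, where `unet_shapes` is a hypothetical helper written for this illustration:

```python
def unet_shapes(size=64, base_channels=64, depth=3):
    """Return [(spatial_size, channels)] for the encoder levels, top to bottom."""
    levels = []
    channels = base_channels
    for _ in range(depth + 1):
        levels.append((size, channels))
        size //= 2          # MaxPool halves the spatial size
        channels *= 2       # the next level doubles the channels
    return levels

for size, channels in unet_shapes():
    print(f"[{size}x{size}, {channels}ch]")
# [64x64, 64ch]
# [32x32, 128ch]
# [16x16, 256ch]
# [8x8, 512ch]
```

The decoder walks the same list in reverse, which is why each decoder level can be paired with an encoder level of matching spatial size.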

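### 2.2 The Role of Skip Connections

- **Problem**: the decoder must restore spatial detail, but that detail is lost after repeated downsampling
- **Solution**: pass the features of the matching encoder level straight across to the decoder
- **Benefits**:
  - The decoder sees both high-level features (from the bottleneck) and low-level features (from the skip)
  - Gradients have a more direct path back to the encoder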
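
In U-Net the skip is a channel-wise concatenation, so the backward pass simply slices the incoming gradient back into its two parts. A minimal sketch with made-up shapes (the channel counts 8 and 4 are arbitrary assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)
upsampled = rng.normal(size=(2, 8, 16, 16))   # decoder features (N, C_up, H, W)
skip = rng.normal(size=(2, 4, 16, 16))        # encoder features (N, C_skip, H, W)

# Forward: concatenate along the channel axis.
concat = np.concatenate([upsampled, skip], axis=1)    # (2, 12, 16, 16)

# Backward: the gradient of a concat is a split along the same axis.
d_concat = rng.normal(size=concat.shape)
c_up = upsampled.shape[1]
d_upsampled = d_concat[:, :c_up]     # flows on through the decoder path
d_skip = d_concat[:, c_up:]          # flows directly to the encoder level

print(concat.shape, d_upsampled.shape, d_skip.shape)
```

Because `d_skip` reaches the encoder without passing through the bottleneck, the encoder receives gradient along a shorter path, which is the second benefit listed above.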

---

## Part 3: Basic Building Blocks

These are brought in from the earlier notebooks and extended.

In [None]:
# im2col and col2im (from 02_resnet_block.ipynb)

def im2col(x, kH, kW, stride=1, pad=0):
    """Unfold a 4D tensor into a 2D matrix of shape (N*out_H*out_W, C*kH*kW)."""
    N, C, H, W = x.shape
    
    if pad > 0:
        x = np.pad(x, ((0, 0), (0, 0), (pad, pad), (pad, pad)), mode='constant')
    
    H_pad, W_pad = x.shape[2], x.shape[3]
    out_H = (H_pad - kH) // stride + 1
    out_W = (W_pad - kW) // stride + 1
    
    shape = (N, C, kH, kW, out_H, out_W)
    strides = (x.strides[0], x.strides[1], x.strides[2], x.strides[3],
               x.strides[2] * stride, x.strides[3] * stride)
    
    col = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides)
    col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N * out_H * out_W, -1)
    
    return col

def col2im(col, x_shape, kH, kW, stride=1, pad=0):
    """Inverse of im2col: fold columns back, summing overlapping patches."""
    N, C, H, W = x_shape
    H_pad = H + 2 * pad
    W_pad = W + 2 * pad
    out_H = (H_pad - kH) // stride + 1
    out_W = (W_pad - kW) // stride + 1
    
    col = col.reshape(N, out_H, out_W, C, kH, kW).transpose(0, 3, 4, 5, 1, 2)
    
    x_pad = np.zeros((N, C, H_pad, W_pad))
    
    for i in range(kH):
        i_max = i + stride * out_H
        for j in range(kW):
            j_max = j + stride * out_W
            x_pad[:, :, i:i_max:stride, j:j_max:stride] += col[:, :, i, j, :, :]
    
    if pad > 0:
        return x_pad[:, :, pad:-pad, pad:-pad]
    return x_pad

In [None]:
class Conv2D:
    """2D Convolution Layer"""
    
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        
        scale = np.sqrt(2.0 / (in_channels * kernel_size * kernel_size))
        self.W = np.random.randn(out_channels, in_channels, kernel_size, kernel_size) * scale
        self.b = np.zeros(out_channels)
        
        self.cache = None
        self.dW = None
        self.db = None
    
    def forward(self, x):
        N, C, H, W = x.shape
        kH = kW = self.kernel_size
        
        H_out = (H + 2 * self.padding - kH) // self.stride + 1
        W_out = (W + 2 * self.padding - kW) // self.stride + 1
        
        col = im2col(x, kH, kW, self.stride, self.padding)
        W_col = self.W.reshape(self.out_channels, -1)
        
        out = col @ W_col.T + self.b
        out = out.reshape(N, H_out, W_out, self.out_channels).transpose(0, 3, 1, 2)
        
        self.cache = (x, col)
        return out
    
    def backward(self, dout):
        x, col = self.cache
        N, C, H, W = x.shape
        kH = kW = self.kernel_size
        
        dout_reshaped = dout.transpose(0, 2, 3, 1).reshape(-1, self.out_channels)
        
        self.dW = (dout_reshaped.T @ col).reshape(self.W.shape)
        self.db = dout_reshaped.sum(axis=0)
        
        W_col = self.W.reshape(self.out_channels, -1)
        dcol = dout_reshaped @ W_col
        dx = col2im(dcol, x.shape, kH, kW, self.stride, self.padding)
        
        return dx

In [None]:
class BatchNorm2D:
    """Batch Normalization for Conv layers"""
    
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        
        self.gamma = np.ones(num_features)
        self.beta = np.zeros(num_features)
        
        self.running_mean = np.zeros(num_features)
        self.running_var = np.ones(num_features)
        
        self.cache = None
        self.training = True
        self.dgamma = None
        self.dbeta = None
    
    def forward(self, x):
        N, C, H, W = x.shape
        
        if self.training:
            mean = x.mean(axis=(0, 2, 3))
            var = x.var(axis=(0, 2, 3))
            
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        else:
            mean = self.running_mean
            var = self.running_var
        
        std = np.sqrt(var + self.eps)
        x_norm = (x - mean.reshape(1, C, 1, 1)) / std.reshape(1, C, 1, 1)
        out = self.gamma.reshape(1, C, 1, 1) * x_norm + self.beta.reshape(1, C, 1, 1)
        
        if self.training:
            self.cache = (x, x_norm, mean, var, std)
        
        return out
    
    def backward(self, dout):
        x, x_norm, mean, var, std = self.cache
        N, C, H, W = x.shape
        m = N * H * W
        
        self.dgamma = (dout * x_norm).sum(axis=(0, 2, 3))
        self.dbeta = dout.sum(axis=(0, 2, 3))
        
        dx_norm = dout * self.gamma.reshape(1, C, 1, 1)
        
        dvar = (-0.5 * dx_norm * (x - mean.reshape(1, C, 1, 1)) / 
                (var.reshape(1, C, 1, 1) + self.eps) ** 1.5).sum(axis=(0, 2, 3))
        
        dmean = (-dx_norm / std.reshape(1, C, 1, 1)).sum(axis=(0, 2, 3))
        dmean += dvar * (-2 / m) * (x - mean.reshape(1, C, 1, 1)).sum(axis=(0, 2, 3))
        
        dx = dx_norm / std.reshape(1, C, 1, 1)
        dx += (2 / m) * dvar.reshape(1, C, 1, 1) * (x - mean.reshape(1, C, 1, 1))
        dx += dmean.reshape(1, C, 1, 1) / m
        
        return dx

In [None]:
class ReLU:
    """ReLU activation"""
    
    def __init__(self):
        self.cache = None
    
    def forward(self, x):
        self.cache = x
        return np.maximum(0, x)
    
    def backward(self, dout):
        x = self.cache
        return dout * (x > 0)


class MaxPool2D:
    """Max Pooling 2D"""
    
    def __init__(self, kernel_size=2, stride=2):
        self.kernel_size = kernel_size
        self.stride = stride
        self.cache = None
    
    def forward(self, x):
        N, C, H, W = x.shape
        kH = kW = self.kernel_size
        s = self.stride
        
        out_H = (H - kH) // s + 1
        out_W = (W - kW) // s + 1
        
        # Fast path: non-overlapping windows that tile the input exactly
        if s == kH and H % s == 0 and W % s == 0:
            x_reshaped = x.reshape(N, C, out_H, s, out_W, s)
            out = x_reshaped.max(axis=(3, 5))
            
            # Record the position of the max inside each window
            x_col = x_reshaped.transpose(0, 1, 2, 4, 3, 5).reshape(N, C, out_H, out_W, -1)
            max_idx = x_col.argmax(axis=-1)
            self.cache = (x.shape, max_idx, None)
        else:
            # General case: use im2col
            col = im2col(x, kH, kW, s, 0)
            col = col.reshape(-1, C, kH * kW)
            out = col.max(axis=2)
            max_idx = col.argmax(axis=2)
            out = out.reshape(N, out_H, out_W, C).transpose(0, 3, 1, 2)
            self.cache = (x.shape, max_idx, col.shape)
        
        return out
    
    def backward(self, dout):
        x_shape, max_idx, col_shape = self.cache
        N, C, H, W = x_shape
        kH = kW = self.kernel_size
        s = self.stride
        out_H, out_W = dout.shape[2], dout.shape[3]
        
        if col_shape is None:
            # Fast path: route each gradient back to the max position of its window
            d_win = np.zeros((N, C, out_H, out_W, kH * kW))
            np.put_along_axis(d_win, max_idx[..., None], dout[..., None], axis=-1)
            dx = (d_win.reshape(N, C, out_H, out_W, kH, kW)
                       .transpose(0, 1, 2, 4, 3, 5)
                       .reshape(N, C, H, W))
        else:
            # General case: scatter into the im2col layout, then fold with col2im
            dcol = np.zeros(col_shape)
            dout_flat = dout.transpose(0, 2, 3, 1).reshape(-1, C)
            rows = np.arange(dout_flat.shape[0])[:, None]
            chans = np.arange(C)[None, :]
            dcol[rows, chans, max_idx] = dout_flat
            dx = col2im(dcol.reshape(dout_flat.shape[0], -1), x_shape, kH, kW, s, 0)
        
        return dx

---

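## Part 4: Upsampling Methods

The decoder has to enlarge its feature maps. There are several ways to do this:

1. **Nearest neighbor upsampling**: the simplest; copies each value
2. **Bilinear interpolation**: linear interpolation between neighbors
3. **Transposed convolution**: learnable upsampling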
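
Before the 2D implementations, the three methods can be compared on a 1D signal. This is a rough sketch under stated assumptions: the input `[0, 4]`, the helper `transposed_conv_1d`, and the all-ones kernel are chosen for illustration, and the linear case uses the same half-pixel mapping as `align_corners=False`.

```python
import numpy as np

x = np.array([0.0, 4.0])
s = 2                                    # upsampling factor

# 1. Nearest neighbor: repeat every value s times.
nearest = np.repeat(x, s)

# 2. Linear interpolation with half-pixel centers.
#    np.interp clamps to the border values outside the input range.
src = (np.arange(len(x) * s) + 0.5) / s - 0.5
linear = np.interp(src, np.arange(len(x)), x)

# 3. Transposed convolution: each input value adds a scaled kernel to the output.
def transposed_conv_1d(x, w, stride):
    out = np.zeros(stride * (len(x) - 1) + len(w))
    for i, v in enumerate(x):
        out[i * stride:i * stride + len(w)] += v * w
    return out

learned = transposed_conv_1d(x, np.ones(2), s)

print(nearest)   # [0. 0. 4. 4.]
print(linear)    # [0. 1. 3. 4.]
print(learned)   # [0. 0. 4. 4.]
```

With an all-ones kernel whose length equals the stride, the transposed convolution reproduces nearest neighbor upsampling; training is free to move the kernel away from that starting point.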

In [None]:
class NearestUpsample:
    """
    Nearest Neighbor Upsampling
    
    The simplest upsampling: copy each value into the block of positions it expands to.
    
    Example with scale=2:
    [[1, 2],     [[1, 1, 2, 2],
     [3, 4]]  ->  [1, 1, 2, 2],
                  [3, 3, 4, 4],
                  [3, 3, 4, 4]]
    """
    
    def __init__(self, scale_factor=2):
        self.scale_factor = scale_factor
        self.cache = None
    
    def forward(self, x):
        """
        Parameters
        ----------
        x : np.ndarray, shape (N, C, H, W)
        
        Returns
        -------
        out : np.ndarray, shape (N, C, H*scale, W*scale)
        """
        N, C, H, W = x.shape
        s = self.scale_factor
        
        # Copy values with repeat
        out = x.repeat(s, axis=2).repeat(s, axis=3)
        
        self.cache = x.shape
        return out
    
    def backward(self, dout):
        """
        Backward: sum the gradients of all output positions copied from the same input value
        
        Parameters
        ----------
        dout : np.ndarray, shape (N, C, H*scale, W*scale)
        
        Returns
        -------
        dx : np.ndarray, shape (N, C, H, W)
        """
        N, C, H, W = self.cache
        s = self.scale_factor
        
        # Reshape, then sum
        # (N, C, H*s, W*s) -> (N, C, H, s, W, s) -> sum over the two s axes
        dx = dout.reshape(N, C, H, s, W, s).sum(axis=(3, 5))
        
        return dx

In [None]:
# Test NearestUpsample

x = np.array([[[[1, 2],
                [3, 4]]]], dtype=float)

print("Input (1, 1, 2, 2):")
print(x[0, 0])

upsample = NearestUpsample(scale_factor=2)
out = upsample.forward(x)

print("\nOutput (1, 1, 4, 4):")
print(out[0, 0])

# Test backward
dout = np.ones_like(out)
dx = upsample.backward(dout)

print("\ndout (all ones):")
print(dout[0, 0])

print("\ndx (should be 4 for each element, because 2x2 gradients sum up):")
print(dx[0, 0])

In [None]:
class BilinearUpsample:
    """
    Bilinear Interpolation Upsampling
    
    Upsamples with bilinear interpolation; smoother than nearest neighbor.
    """
    
    def __init__(self, scale_factor=2):
        self.scale_factor = scale_factor
        self.cache = None
    
    def forward(self, x):
        """
        Parameters
        ----------
        x : np.ndarray, shape (N, C, H, W)
        
        Returns
        -------
        out : np.ndarray, shape (N, C, H*scale, W*scale)
        """
        N, C, H, W = x.shape
        s = self.scale_factor
        new_H, new_W = H * s, W * s
        
        out = np.zeros((N, C, new_H, new_W))
        
        # Compute the input position that each output position maps to,
        # using the align_corners=False convention
        for i in range(new_H):
            for j in range(new_W):
                # Source coordinates of output (i, j) in the input
                src_i = (i + 0.5) / s - 0.5
                src_j = (j + 0.5) / s - 0.5
                
                # Find the four neighbors
                i0 = int(np.floor(src_i))
                j0 = int(np.floor(src_j))
                i1 = i0 + 1
                j1 = j0 + 1
                
                # Interpolation weights (computed before clamping the indices)
                wi = src_i - i0
                wj = src_j - j0
                
                # Clamp indices
                i0 = max(0, min(i0, H - 1))
                i1 = max(0, min(i1, H - 1))
                j0 = max(0, min(j0, W - 1))
                j1 = max(0, min(j1, W - 1))
                
                # Bilinear interpolation
                out[:, :, i, j] = ((1 - wi) * (1 - wj) * x[:, :, i0, j0] +
                                   (1 - wi) * wj * x[:, :, i0, j1] +
                                   wi * (1 - wj) * x[:, :, i1, j0] +
                                   wi * wj * x[:, :, i1, j1])
        
        self.cache = (x.shape, s)
        return out
    
    def backward(self, dout):
        """
        Backward pass for bilinear upsampling
        """
        x_shape, s = self.cache
        N, C, H, W = x_shape
        new_H, new_W = H * s, W * s
        
        dx = np.zeros(x_shape)
        
        for i in range(new_H):
            for j in range(new_W):
                src_i = (i + 0.5) / s - 0.5
                src_j = (j + 0.5) / s - 0.5
                
                i0 = int(np.floor(src_i))
                j0 = int(np.floor(src_j))
                i1 = i0 + 1
                j1 = j0 + 1
                
                wi = src_i - i0
                wj = src_j - j0
                
                # Clamp
                i0_c = max(0, min(i0, H - 1))
                i1_c = max(0, min(i1, H - 1))
                j0_c = max(0, min(j0, W - 1))
                j1_c = max(0, min(j1, W - 1))
                
                # Backward: distribute dout back to the four neighbors
                d = dout[:, :, i, j]
                dx[:, :, i0_c, j0_c] += (1 - wi) * (1 - wj) * d
                dx[:, :, i0_c, j1_c] += (1 - wi) * wj * d
                dx[:, :, i1_c, j0_c] += wi * (1 - wj) * d
                dx[:, :, i1_c, j1_c] += wi * wj * d
        
        return dx

In [None]:
# Compare nearest and bilinear

x = np.random.randn(1, 1, 4, 4)

nearest = NearestUpsample(scale_factor=4)
bilinear = BilinearUpsample(scale_factor=4)

out_nearest = nearest.forward(x)
out_bilinear = bilinear.forward(x)

fig, axes = plt.subplots(1, 3, figsize=(12, 4))

axes[0].imshow(x[0, 0], cmap='viridis')
axes[0].set_title('Original (4x4)')

axes[1].imshow(out_nearest[0, 0], cmap='viridis')
axes[1].set_title('Nearest Neighbor (16x16)')

axes[2].imshow(out_bilinear[0, 0], cmap='viridis')
axes[2].set_title('Bilinear (16x16)')

plt.tight_layout()
plt.show()

print("Observation: bilinear produces a smoother result")

In [None]:
class TransposedConv2D:
    """
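    Transposed Convolution (often loosely called "deconvolution")
    
    Learnable upsampling. It is the transpose of an ordinary convolution,
    i.e. that convolution's gradient with respect to its input, not a true inverse.
    
    Forward: enlarges a small feature map
    Backward: performs an ordinary convolution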
    """
    
    def __init__(self, in_channels, out_channels, kernel_size, stride=2, padding=0):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        
        # Initialize weights
        scale = np.sqrt(2.0 / (in_channels * kernel_size * kernel_size))
        self.W = np.random.randn(in_channels, out_channels, kernel_size, kernel_size) * scale
        self.b = np.zeros(out_channels)
        
        self.cache = None
        self.dW = None
        self.db = None
    
    def forward(self, x):
        """
        Parameters
        ----------
        x : np.ndarray, shape (N, C_in, H, W)
        
        Returns
        -------
        out : np.ndarray, shape (N, C_out, H_out, W_out)
        
        H_out = stride * (H - 1) + kernel_size - 2 * padding
        """
        N, C_in, H, W = x.shape
        k = self.kernel_size
        s = self.stride
        p = self.padding
        
        # Compute the output size
        H_out = s * (H - 1) + k - 2 * p
        W_out = s * (W - 1) + k - 2 * p
        
        # Allocate the output (including the padding border)
        out_padded = np.zeros((N, self.out_channels, H_out + 2 * p, W_out + 2 * p))
        
        # For each input position, add the scaled kernel into the output
        for i in range(H):
            for j in range(W):
                # Top-left corner of the output region
                h_start = i * s
                w_start = j * s
                
                # x[:, :, i, j] shape: (N, C_in)
                # W shape: (C_in, C_out, k, k)
                # result: (N, C_out, k, k)
                contribution = np.einsum('nc,cokl->nokl', x[:, :, i, j], self.W)
                out_padded[:, :, h_start:h_start+k, w_start:w_start+k] += contribution
        
        # Add bias
        out_padded += self.b.reshape(1, -1, 1, 1)
        
        # Remove the padding border
        if p > 0:
            out = out_padded[:, :, p:-p, p:-p]
        else:
            out = out_padded
        
        self.cache = (x, H_out, W_out)
        return out
    
    def backward(self, dout):
        """
        Backward pass
        
        Key observation: the backward pass of a transposed conv is the forward pass of an ordinary conv
        """
        x, H_out, W_out = self.cache
        N, C_in, H, W = x.shape
        k = self.kernel_size
        s = self.stride
        p = self.padding
        
        # Pad dout
        if p > 0:
            dout_padded = np.pad(dout, ((0, 0), (0, 0), (p, p), (p, p)))
        else:
            dout_padded = dout
        
        # Gradient for input x
        dx = np.zeros_like(x)
        
        for i in range(H):
            for j in range(W):
                h_start = i * s
                w_start = j * s
                
                # dout_padded[:, :, h_start:h_start+k, w_start:w_start+k] shape: (N, C_out, k, k)
                # W shape: (C_in, C_out, k, k)
                # result: (N, C_in)
                dout_patch = dout_padded[:, :, h_start:h_start+k, w_start:w_start+k]
                dx[:, :, i, j] = np.einsum('nokl,cokl->nc', dout_patch, self.W)
        
        # Gradient for W
        self.dW = np.zeros_like(self.W)
        
        for i in range(H):
            for j in range(W):
                h_start = i * s
                w_start = j * s
                
                # x[:, :, i, j] shape: (N, C_in)
                # dout_padded[:, :, h_start:h_start+k, w_start:w_start+k] shape: (N, C_out, k, k)
                dout_patch = dout_padded[:, :, h_start:h_start+k, w_start:w_start+k]
                self.dW += np.einsum('nc,nokl->cokl', x[:, :, i, j], dout_patch)
        
        # Gradient for b
        self.db = dout.sum(axis=(0, 2, 3))
        
        return dx

In [None]:
# Test TransposedConv2D

x = np.random.randn(2, 8, 4, 4)  # 4x4 input
trans_conv = TransposedConv2D(in_channels=8, out_channels=4, kernel_size=4, stride=2, padding=1)

out = trans_conv.forward(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {out.shape}")
print(f"Expected: (2, 4, 8, 8) - doubled spatial dimensions")

---

## Part 5: U-Net Components

### 5.1 Encoder Block

```
Conv3x3 -> BN -> ReLU -> Conv3x3 -> BN -> ReLU -> MaxPool
                                              ↓
                                         (feature for skip)
```
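
The spatial sizes along this block follow from the output-size formula used in `Conv2D.forward` and `MaxPool2D.forward`. A minimal sketch, where `conv_out_size` is a hypothetical helper and the 64-pixel input is an assumed example size:

```python
def conv_out_size(h, kernel_size, stride=1, padding=0):
    """Output side length of a convolution or pooling window over a side of length h."""
    return (h + 2 * padding - kernel_size) // stride + 1

h = 64
h = conv_out_size(h, kernel_size=3, padding=1)            # first Conv3x3: size unchanged
h = conv_out_size(h, kernel_size=3, padding=1)            # second Conv3x3: size unchanged
skip_size = h                                             # features kept for the skip
pooled_size = conv_out_size(h, kernel_size=2, stride=2)   # MaxPool halves the size

print(skip_size, pooled_size)   # 64 32
```

Using `padding=1` with 3x3 kernels keeps the skip features at the input resolution, so they can later be concatenated with the upsampled decoder features without cropping.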

In [None]:
class DoubleConv:
    """
    Basic convolutional unit of U-Net:
    Conv3x3 -> BN -> ReLU -> Conv3x3 -> BN -> ReLU
    """
    
    def __init__(self, in_channels, out_channels):
        self.conv1 = Conv2D(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn1 = BatchNorm2D(out_channels)
        self.relu1 = ReLU()
        
        self.conv2 = Conv2D(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = BatchNorm2D(out_channels)
        self.relu2 = ReLU()
        
        self.cache = None
    
    def forward(self, x):
        out = self.conv1.forward(x)
        out = self.bn1.forward(out)
        out = self.relu1.forward(out)
        
        out = self.conv2.forward(out)
        out = self.bn2.forward(out)
        out = self.relu2.forward(out)
        
        return out
    
    def backward(self, dout):
        dout = self.relu2.backward(dout)
        dout = self.bn2.backward(dout)
        dout = self.conv2.backward(dout)
        
        dout = self.relu1.backward(dout)
        dout = self.bn1.backward(dout)
        dout = self.conv1.backward(dout)
        
        return dout
    
    def train(self):
        self.bn1.training = True
        self.bn2.training = True
    
    def eval(self):
        self.bn1.training = False
        self.bn2.training = False

In [None]:
class EncoderBlock:
    """
    U-Net Encoder Block:
    DoubleConv -> MaxPool
    
    Returns the pooled output and the features for the skip connection.
    """
    
    def __init__(self, in_channels, out_channels):
        self.double_conv = DoubleConv(in_channels, out_channels)
        self.pool = MaxPool2D(kernel_size=2, stride=2)
        
        self.skip_features = None
    
    def forward(self, x):
        """
        Parameters
        ----------
        x : np.ndarray
        
        Returns
        -------
        pooled : np.ndarray - input to the next encoder level
        skip : np.ndarray - passed to the matching decoder (skip connection)
        """
        # Double conv
        features = self.double_conv.forward(x)
        
        # Save the features for the skip connection (before pooling)
        self.skip_features = features
        
        # Pool
        pooled = self.pool.forward(features)
        
        return pooled, features
    
    def backward(self, dout, dskip=None):
        """
        Parameters
        ----------
        dout : np.ndarray - gradient from the next encoder level
        dskip : np.ndarray - gradient from the skip connection (from the decoder)
        
        Returns
        -------
        dx : np.ndarray
        """
        # Backward through pool
        d_features = self.pool.backward(dout)
        
        # Add the gradient from the skip connection
        if dskip is not None:
            d_features = d_features + dskip
        
        # Backward through double conv
        dx = self.double_conv.backward(d_features)
        
        return dx
    
    def train(self):
        self.double_conv.train()
    
    def eval(self):
        self.double_conv.eval()

In [None]:
class DecoderBlock:
    """
    U-Net Decoder Block:
    Upsample -> Concat(skip) -> DoubleConv
    """
    
    def __init__(self, in_channels, out_channels, upsample_method='nearest', skip_channels=None):
        """
        Parameters
        ----------
        in_channels : int
            Input channels (from the bottleneck or the previous decoder)
        out_channels : int
            Output channels
        upsample_method : str
            'nearest' or 'bilinear' or 'transposed'
        skip_channels : int, optional
            Channels of the skip features. Defaults to in_channels // 2,
            which holds when every encoder level doubles the channel count.
        """
        self.upsample_method = upsample_method
        if skip_channels is None:
            skip_channels = in_channels // 2
        
        # Upsampling (all three methods keep the channel count at in_channels)
        if upsample_method == 'nearest':
            self.upsample = NearestUpsample(scale_factor=2)
        elif upsample_method == 'bilinear':
            self.upsample = BilinearUpsample(scale_factor=2)
        else:
            self.upsample = TransposedConv2D(in_channels, in_channels, kernel_size=4, stride=2, padding=1)
        
        # After concat the channel count is in_channels + skip_channels
        self.double_conv = DoubleConv(in_channels + skip_channels, out_channels)
        
        self.cache = None
    
    def forward(self, x, skip):
        """
        Parameters
        ----------
        x : np.ndarray - features from the previous level
        skip : np.ndarray - skip connection from the encoder
        
        Returns
        -------
        out : np.ndarray
        """
        # Upsample
        upsampled = self.upsample.forward(x)
        
        # Concat along channel dimension
        # upsampled: (N, C, H, W), skip: (N, C_skip, H, W)
        concat = np.concatenate([upsampled, skip], axis=1)
        
        # Double conv
        out = self.double_conv.forward(concat)
        
        self.cache = (x, skip, upsampled)
        return out
    
    def backward(self, dout):
        """
        Returns
        -------
        dx : np.ndarray - gradient for the previous decoder/bottleneck
        dskip : np.ndarray - gradient for the encoder (through the skip connection)
        """
        x, skip, upsampled = self.cache
        
        # Backward through double conv
        d_concat = self.double_conv.backward(dout)
        
        # Split gradient for concat
        C_up = upsampled.shape[1]
        d_upsampled = d_concat[:, :C_up, :, :]
        dskip = d_concat[:, C_up:, :, :]
        
        # Backward through upsample
        dx = self.upsample.backward(d_upsampled)
        
        return dx, dskip
    
    def train(self):
        self.double_conv.train()
    
    def eval(self):
        self.double_conv.eval()

---

## Part 6: The Complete SimpleUNet

In [None]:
class SimpleUNet:
    """
    Simplified U-Net
    
    Structure (64x64 input):
    
    Encoder:
        [64x64, 1] -> Enc1 -> [32x32, 32]   (skip1: [64x64, 32])
        [32x32, 32] -> Enc2 -> [16x16, 64]  (skip2: [32x32, 64])
        
    Bottleneck:
        [16x16, 64] -> DoubleConv -> [16x16, 128]
        
    Decoder:
        [16x16, 128] -> upsample -> [32x32, 128] + skip2 [32x32, 64] -> Dec2 -> [32x32, 64]
        [32x32, 64] -> upsample -> [64x64, 64] + skip1 [64x64, 32] -> Dec1 -> [64x64, 32]
        
    Output:
        [64x64, 32] -> Conv1x1 -> [64x64, n_classes]
    """
    
    def __init__(self, in_channels=1, n_classes=3, base_channels=32):
        self.in_channels = in_channels
        self.n_classes = n_classes
        
        # Encoder
        self.enc1 = EncoderBlock(in_channels, base_channels)         # 1 -> 32
        self.enc2 = EncoderBlock(base_channels, base_channels * 2)   # 32 -> 64
        
        # Bottleneck
        self.bottleneck = DoubleConv(base_channels * 2, base_channels * 4)  # 64 -> 128
        
        # Decoder
        self.dec2 = DecoderBlock(base_channels * 4, base_channels * 2)  # 128 -> 64
        self.dec1 = DecoderBlock(base_channels * 2, base_channels)       # 64 -> 32
        
        # Output
        self.out_conv = Conv2D(base_channels, n_classes, kernel_size=1, stride=1, padding=0)
        
        # Saved skip features
        self.skip1 = None
        self.skip2 = None
    
    def forward(self, x):
        """
        Parameters
        ----------
        x : np.ndarray, shape (N, in_channels, H, W)
        
        Returns
        -------
        logits : np.ndarray, shape (N, n_classes, H, W)
        """
        # Encoder
        x, self.skip1 = self.enc1.forward(x)   # skip1: (N, 32, H, W)
        x, self.skip2 = self.enc2.forward(x)   # skip2: (N, 64, H/2, W/2)
        
        # Bottleneck
        x = self.bottleneck.forward(x)         # (N, 128, H/4, W/4)
        
        # Decoder
        x = self.dec2.forward(x, self.skip2)   # (N, 64, H/2, W/2)
        x = self.dec1.forward(x, self.skip1)   # (N, 32, H, W)
        
        # Output
        logits = self.out_conv.forward(x)      # (N, n_classes, H, W)
        
        return logits
    
    def backward(self, dlogits):
        """
        Backward pass
        """
        # Output conv
        dout = self.out_conv.backward(dlogits)
        
        # Decoder
        dout, dskip1 = self.dec1.backward(dout)
        dout, dskip2 = self.dec2.backward(dout)
        
        # Bottleneck
        dout = self.bottleneck.backward(dout)
        
        # Encoder (add the gradients arriving through the skip connections)
        dout = self.enc2.backward(dout, dskip2)
        dout = self.enc1.backward(dout, dskip1)
        
        return dout
    
    def get_params_and_grads(self):
        """Collect all parameters and their gradients."""
        params = []
        grads = []
        
        # Helper function
        def add_double_conv(dc):
            params.extend([dc.conv1.W, dc.conv1.b, dc.bn1.gamma, dc.bn1.beta,
                          dc.conv2.W, dc.conv2.b, dc.bn2.gamma, dc.bn2.beta])
            grads.extend([dc.conv1.dW, dc.conv1.db, dc.bn1.dgamma, dc.bn1.dbeta,
                         dc.conv2.dW, dc.conv2.db, dc.bn2.dgamma, dc.bn2.dbeta])
        
        # Encoders
        add_double_conv(self.enc1.double_conv)
        add_double_conv(self.enc2.double_conv)
        
        # Bottleneck
        add_double_conv(self.bottleneck)
        
        # Decoders
        if hasattr(self.dec2.upsample, 'W'):  # TransposedConv
            params.extend([self.dec2.upsample.W, self.dec2.upsample.b])
            grads.extend([self.dec2.upsample.dW, self.dec2.upsample.db])
        add_double_conv(self.dec2.double_conv)
        
        if hasattr(self.dec1.upsample, 'W'):
            params.extend([self.dec1.upsample.W, self.dec1.upsample.b])
            grads.extend([self.dec1.upsample.dW, self.dec1.upsample.db])
        add_double_conv(self.dec1.double_conv)
        
        # Output
        params.extend([self.out_conv.W, self.out_conv.b])
        grads.extend([self.out_conv.dW, self.out_conv.db])
        
        return params, grads
    
    def train(self):
        """Set to training mode"""
        self.enc1.train()
        self.enc2.train()
        self.bottleneck.train()
        self.dec1.train()
        self.dec2.train()
    
    def eval(self):
        """Set to evaluation mode"""
        self.enc1.eval()
        self.enc2.eval()
        self.bottleneck.eval()
        self.dec1.eval()
        self.dec2.eval()

In [None]:
# Test SimpleUNet

model = SimpleUNet(in_channels=1, n_classes=3, base_channels=16)
x = np.random.randn(2, 1, 64, 64)

logits = model.forward(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {logits.shape}")
print(f"Expected: (2, 3, 64, 64) - same spatial size as input")

# Count parameters
params, _ = model.get_params_and_grads()
total_params = sum(p.size for p in params if p is not None)
print(f"\nTotal parameters: {total_params:,}")

---

## Part 7: Training U-Net on a Segmentation Task

In [None]:
# Generate training data

def generate_segmentation_data(n_samples, img_size=64):
    """Generate a simple segmentation dataset.
    
    3 classes: background (0), circle (1), square (2)
    """
    X = np.zeros((n_samples, 1, img_size, img_size))
    y = np.zeros((n_samples, img_size, img_size), dtype=int)
    
    for i in range(n_samples):
        # Random background noise
        X[i, 0] = np.random.randn(img_size, img_size) * 0.1
        
        # Draw 0-2 random circles
        n_circles = np.random.randint(0, 3)
        for _ in range(n_circles):
            cx = np.random.randint(12, img_size - 12)
            cy = np.random.randint(12, img_size - 12)
            r = np.random.randint(6, 12)
            
            yy, xx = np.ogrid[:img_size, :img_size]
            mask = (xx - cx)**2 + (yy - cy)**2 <= r**2
            X[i, 0, mask] = 0.7 + np.random.rand() * 0.3
            y[i, mask] = 1
        
        # Draw 0-2 random squares (drawn after the circles, so they overwrite any overlap)
        n_squares = np.random.randint(0, 3)
        for _ in range(n_squares):
            sx = np.random.randint(5, img_size - 20)
            sy = np.random.randint(5, img_size - 20)
            size = np.random.randint(8, 15)
            
            X[i, 0, sy:sy+size, sx:sx+size] = 0.5 + np.random.rand() * 0.3
            y[i, sy:sy+size, sx:sx+size] = 2
    
    return X.astype(np.float32), y

# Generate the data
np.random.seed(42)
X_train, y_train = generate_segmentation_data(200, img_size=64)
X_test, y_test = generate_segmentation_data(50, img_size=64)

print(f"Training set: X={X_train.shape}, y={y_train.shape}")
print(f"Test set: X={X_test.shape}, y={y_test.shape}")

# Show a few samples
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
for i in range(4):
    axes[0, i].imshow(X_train[i, 0], cmap='gray')
    axes[0, i].set_title(f'Image {i}')
    axes[0, i].axis('off')
    
    axes[1, i].imshow(y_train[i], cmap='tab10', vmin=0, vmax=3)
    axes[1, i].set_title(f'Mask {i}')
    axes[1, i].axis('off')

plt.tight_layout()
plt.show()

In [None]:
def pixel_wise_cross_entropy(logits, y):
    """Compute the pixel-wise cross-entropy loss.
    
    Parameters
    ----------
    logits : np.ndarray, shape (N, C, H, W)
    y : np.ndarray, shape (N, H, W) - integer labels
    
    Returns
    -------
    loss : float
    dlogits : np.ndarray, shape (N, C, H, W)
    """
    N, C, H, W = logits.shape
    
    # Softmax (along channel dimension)
    logits_max = logits.max(axis=1, keepdims=True)
    exp_logits = np.exp(logits - logits_max)
    probs = exp_logits / exp_logits.sum(axis=1, keepdims=True)
    
    # Gather, at every pixel position, the probability assigned to the correct class
    n_idx = np.arange(N).reshape(-1, 1, 1)
    h_idx = np.arange(H).reshape(1, -1, 1)
    w_idx = np.arange(W).reshape(1, 1, -1)
    
    correct_probs = probs[n_idx, y, h_idx, w_idx]
    
    # Cross entropy loss
    loss = -np.mean(np.log(correct_probs + 1e-8))
    
    # Gradient
    dlogits = probs.copy()
    dlogits[n_idx, y, h_idx, w_idx] -= 1
    dlogits /= (N * H * W)
    
    return loss, dlogits


def pixel_accuracy(logits, y):
    """Compute pixel accuracy."""
    preds = logits.argmax(axis=1)  # (N, H, W)
    return (preds == y).mean()


def iou_score(logits, y, n_classes=3):
    """Compute mean IoU (Intersection over Union) over the classes present in preds or y."""
    preds = logits.argmax(axis=1)
    ious = []
    
    for c in range(n_classes):
        pred_c = (preds == c)
        true_c = (y == c)
        
        intersection = (pred_c & true_c).sum()
        union = (pred_c | true_c).sum()
        
        if union > 0:
            ious.append(intersection / union)
    
    return np.mean(ious) if ious else 0.0
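
Because the loss gradient is derived by hand, a finite-difference check is a cheap way to gain confidence in it before training. The sketch below is self-contained: `softmax_ce` is a compact re-implementation written for this check (it mirrors `pixel_wise_cross_entropy` but is not the same function), and the tensor sizes, seed, and `eps` are arbitrary choices.

```python
import numpy as np

def softmax_ce(logits, y):
    """Pixel-wise softmax cross entropy; returns (loss, dlogits)."""
    N, C, H, W = logits.shape
    z = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)
    n, h, w = np.ogrid[:N, :H, :W]          # broadcastable index grids
    loss = -np.mean(np.log(probs[n, y, h, w] + 1e-8))
    dlogits = probs.copy()
    dlogits[n, y, h, w] -= 1                # softmax minus one-hot
    return loss, dlogits / (N * H * W)

rng = np.random.default_rng(0)
logits = rng.standard_normal((2, 3, 4, 4))
y = rng.integers(0, 3, size=(2, 4, 4))

_, analytic = softmax_ce(logits, y)

# Central differences on a handful of randomly chosen logits
eps = 1e-5
max_err = 0.0
for _ in range(10):
    idx = tuple(rng.integers(0, s) for s in logits.shape)
    plus, minus = logits.copy(), logits.copy()
    plus[idx] += eps
    minus[idx] -= eps
    numeric = (softmax_ce(plus, y)[0] - softmax_ce(minus, y)[0]) / (2 * eps)
    max_err = max(max_err, abs(numeric - analytic[idx]))

print(f"max abs error: {max_err:.2e}")
```

A small error suggests the analytic gradient matches the loss. The same loop could be pointed at the notebook's own function by swapping the call.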

In [None]:
def train_unet(model, X_train, y_train, X_test, y_test,
               epochs=20, batch_size=8, lr=0.01, momentum=0.9):
    """Train U-Net with mini-batch SGD and momentum."""
    n_samples = X_train.shape[0]
    n_batches = n_samples // batch_size
    
    # Momentum velocities
    params, _ = model.get_params_and_grads()
    velocities = [np.zeros_like(p) if p is not None else None for p in params]
    
    train_losses = []
    train_accs = []
    train_ious = []
    test_accs = []
    test_ious = []
    
    for epoch in range(epochs):
        # Shuffle
        indices = np.random.permutation(n_samples)
        X_shuffled = X_train[indices]
        y_shuffled = y_train[indices]
        
        epoch_loss = 0
        epoch_acc = 0
        epoch_iou = 0
        
        model.train()
        
        for i in range(n_batches):
            start = i * batch_size
            end = start + batch_size
            X_batch = X_shuffled[start:end]
            y_batch = y_shuffled[start:end]
            
            # Forward
            logits = model.forward(X_batch)
            loss, dlogits = pixel_wise_cross_entropy(logits, y_batch)
            
            # Backward
            model.backward(dlogits)
            
            # Update
            params, grads = model.get_params_and_grads()
            for j, (p, g) in enumerate(zip(params, grads)):
                if p is not None and g is not None:
                    velocities[j] = momentum * velocities[j] - lr * g
                    p += velocities[j]
            
            epoch_loss += loss
            epoch_acc += pixel_accuracy(logits, y_batch)
            epoch_iou += iou_score(logits, y_batch)
        
        # Epoch metrics
        train_loss = epoch_loss / n_batches
        train_acc = epoch_acc / n_batches
        train_iou = epoch_iou / n_batches
        
        # Test metrics
        model.eval()
        test_logits = model.forward(X_test)
        test_acc = pixel_accuracy(test_logits, y_test)
        test_iou = iou_score(test_logits, y_test)
        
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        train_ious.append(train_iou)
        test_accs.append(test_acc)
        test_ious.append(test_iou)
        
        if (epoch + 1) % 5 == 0:
            print(f"Epoch {epoch+1:3d}: loss={train_loss:.4f}, "
                  f"train_acc={train_acc:.4f}, train_iou={train_iou:.4f}, "
                  f"test_acc={test_acc:.4f}, test_iou={test_iou:.4f}")
    
    return train_losses, train_accs, train_ious, test_accs, test_ious
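
The update inside `train_unet` relies on `p += velocities[j]` modifying each parameter array in place, so the layer that owns the array sees the new values. The sketch below isolates that update rule on a toy quadratic. `sgd_momentum_step` and the objective are hypothetical stand-ins written for illustration; they are not part of the notebook's model code.

```python
import numpy as np

def sgd_momentum_step(params, grads, velocities, lr=0.01, momentum=0.9):
    """One momentum step; every array in `params` is updated in place."""
    for p, g, v in zip(params, grads, velocities):
        v *= momentum
        v -= lr * g
        p += v  # in place: the array object held elsewhere changes too

# Toy objective f(w) = 0.5 * ||w||^2, whose gradient is simply w
W = np.array([1.0, -2.0, 3.0])
params = [W]
velocities = [np.zeros_like(W)]

for _ in range(200):
    grads = [p.copy() for p in params]
    sgd_momentum_step(params, grads, velocities, lr=0.1, momentum=0.9)

print(np.abs(W).max())  # W itself has moved toward the minimum at 0
```

Writing `p = p + v` instead would only rebind the local name `p` and leave the original array untouched, which is why the in-place form matters here.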

In [None]:
# Train U-Net
print("Training SimpleUNet...")
print("=" * 60)

np.random.seed(42)
unet = SimpleUNet(in_channels=1, n_classes=3, base_channels=16)

results = train_unet(
    unet, X_train, y_train, X_test, y_test,
    epochs=25, batch_size=8, lr=0.01
)

train_losses, train_accs, train_ious, test_accs, test_ious = results

In [None]:
# Plot training curves

fig, axes = plt.subplots(1, 3, figsize=(14, 4))

# Loss
axes[0].plot(train_losses)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training Loss')
axes[0].grid(True, alpha=0.3)

# Pixel Accuracy
axes[1].plot(train_accs, label='Train')
axes[1].plot(test_accs, label='Test')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Pixel Accuracy')
axes[1].set_title('Pixel Accuracy')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# IoU
axes[2].plot(train_ious, label='Train')
axes[2].plot(test_ious, label='Test')
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('Mean IoU')
axes[2].set_title('Mean IoU')
axes[2].legend()
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Visualize predictions

unet.eval()
test_logits = unet.forward(X_test[:8])
test_preds = test_logits.argmax(axis=1)

fig, axes = plt.subplots(3, 8, figsize=(16, 6))

for i in range(8):
    # Input image
    axes[0, i].imshow(X_test[i, 0], cmap='gray')
    # Ground truth
    axes[1, i].imshow(y_test[i], cmap='tab10', vmin=0, vmax=3)
    # Prediction
    axes[2, i].imshow(test_preds[i], cmap='tab10', vmin=0, vmax=3)
    
    for row in range(3):
        # Hide ticks rather than calling axis('off'), which would also hide the row labels
        axes[row, i].set_xticks([])
        axes[row, i].set_yticks([])

# Row labels on the first column
for row, label in enumerate(['Input', 'Ground Truth', 'Prediction']):
    axes[row, 0].set_ylabel(label, fontsize=12)

plt.suptitle('U-Net Segmentation Results', fontsize=14)
plt.tight_layout()
plt.show()

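# With cmap='tab10', vmin=0, vmax=3, labels 0/1/2 map to tab10 entries 0/3/6
print("\nClasses: blue = background, red = circle, pink = square")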

---

## Part 8: The Importance of Skip Connections

To show the effect of skip connections, we compare a network that has them against one that does not.
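
Before training both networks, it may help to see why a decoder without skips struggles with boundaries. The toy sketch below pushes a binary mask through two 2x2 max-pooling stages and two nearest-neighbor upsampling stages, with no learned layers at all. `max_pool2x` and `nearest_up2x` are simplified helpers written for this illustration, not the notebook's `MaxPool2D` or `NearestUpsample` classes.

```python
import numpy as np

def max_pool2x(x):
    """2x2 max pooling with stride 2 on an (H, W) array (H and W even)."""
    H, W = x.shape
    return x.reshape(H // 2, 2, W // 2, 2).max(axis=(1, 3))

def nearest_up2x(x):
    """2x nearest-neighbor upsampling on an (H, W) array."""
    return x.repeat(2, axis=0).repeat(2, axis=1)

# A 16x16 binary mask containing a filled circle
yy, xx = np.ogrid[:16, :16]
mask = ((xx - 7.5) ** 2 + (yy - 7.5) ** 2 <= 5 ** 2).astype(float)

# Two pooling stages, then two upsampling stages, as in a decoder without skips
coarse = max_pool2x(max_pool2x(mask))
restored = nearest_up2x(nearest_up2x(coarse))

changed = int((restored != mask).sum())
print(f"pixels that differ after the round trip: {changed} of {mask.size}")
```

The restored mask is built from 4x4 blocks, so the circular boundary cannot be recovered from the coarse map alone. A trained network has convolutions that can partly compensate, but the full-resolution encoder features supplied by a skip connection give it direct access to the lost detail.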

In [None]:
class SimpleUNetNoSkip:
    """
    Encoder-decoder network without skip connections
    """
    
    def __init__(self, in_channels=1, n_classes=3, base_channels=32):
        # Encoder
        self.enc1_conv = DoubleConv(in_channels, base_channels)
        self.pool1 = MaxPool2D(2, 2)
        
        self.enc2_conv = DoubleConv(base_channels, base_channels * 2)
        self.pool2 = MaxPool2D(2, 2)
        
        # Bottleneck
        self.bottleneck = DoubleConv(base_channels * 2, base_channels * 4)
        
        # Decoder (no concat, so channels don't double)
        self.up2 = NearestUpsample(2)
        self.dec2_conv = DoubleConv(base_channels * 4, base_channels * 2)
        
        self.up1 = NearestUpsample(2)
        self.dec1_conv = DoubleConv(base_channels * 2, base_channels)
        
        # Output
        self.out_conv = Conv2D(base_channels, n_classes, 1, 1, 0)
    
    def forward(self, x):
        # Encoder
        x = self.enc1_conv.forward(x)
        x = self.pool1.forward(x)
        
        x = self.enc2_conv.forward(x)
        x = self.pool2.forward(x)
        
        # Bottleneck
        x = self.bottleneck.forward(x)
        
        # Decoder (no skip connections)
        x = self.up2.forward(x)
        x = self.dec2_conv.forward(x)
        
        x = self.up1.forward(x)
        x = self.dec1_conv.forward(x)
        
        # Output
        return self.out_conv.forward(x)
    
    def backward(self, dlogits):
        dout = self.out_conv.backward(dlogits)
        
        dout = self.dec1_conv.backward(dout)
        dout = self.up1.backward(dout)
        
        dout = self.dec2_conv.backward(dout)
        dout = self.up2.backward(dout)
        
        dout = self.bottleneck.backward(dout)
        
        dout = self.pool2.backward(dout)
        dout = self.enc2_conv.backward(dout)
        
        dout = self.pool1.backward(dout)
        dout = self.enc1_conv.backward(dout)
        
        return dout
    
    def get_params_and_grads(self):
        params = []
        grads = []
        
        def add_double_conv(dc):
            params.extend([dc.conv1.W, dc.conv1.b, dc.bn1.gamma, dc.bn1.beta,
                          dc.conv2.W, dc.conv2.b, dc.bn2.gamma, dc.bn2.beta])
            grads.extend([dc.conv1.dW, dc.conv1.db, dc.bn1.dgamma, dc.bn1.dbeta,
                         dc.conv2.dW, dc.conv2.db, dc.bn2.dgamma, dc.bn2.dbeta])
        
        add_double_conv(self.enc1_conv)
        add_double_conv(self.enc2_conv)
        add_double_conv(self.bottleneck)
        add_double_conv(self.dec2_conv)
        add_double_conv(self.dec1_conv)
        params.extend([self.out_conv.W, self.out_conv.b])
        grads.extend([self.out_conv.dW, self.out_conv.db])
        
        return params, grads
    
    def train(self):
        for dc in [self.enc1_conv, self.enc2_conv, self.bottleneck,
                   self.dec2_conv, self.dec1_conv]:
            dc.train()
    
    def eval(self):
        for dc in [self.enc1_conv, self.enc2_conv, self.bottleneck,
                   self.dec2_conv, self.dec1_conv]:
            dc.eval()

In [None]:
# Train the network without skip connections
print("Training UNet WITHOUT skip connections...")
print("=" * 60)

np.random.seed(42)
unet_no_skip = SimpleUNetNoSkip(in_channels=1, n_classes=3, base_channels=16)

results_no_skip = train_unet(
    unet_no_skip, X_train, y_train, X_test, y_test,
    epochs=25, batch_size=8, lr=0.01
)

_, _, train_ious_no_skip, _, test_ious_no_skip = results_no_skip

In [None]:
# Compare the networks with and without skip connections

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Training IoU
axes[0].plot(train_ious, 'b-', label='With Skip Connections')
axes[0].plot(train_ious_no_skip, 'r--', label='Without Skip Connections')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Mean IoU')
axes[0].set_title('Training Mean IoU')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Test IoU
axes[1].plot(test_ious, 'b-', label='With Skip Connections')
axes[1].plot(test_ious_no_skip, 'r--', label='Without Skip Connections')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Mean IoU')
axes[1].set_title('Test Mean IoU')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\nFinal Test IoU:")
print(f"  With skip connections: {test_ious[-1]:.4f}")
print(f"  Without skip connections: {test_ious_no_skip[-1]:.4f}")

In [None]:
# Visual comparison of predictions

unet.eval()
unet_no_skip.eval()

logits_with_skip = unet.forward(X_test[:4])
logits_no_skip = unet_no_skip.forward(X_test[:4])

preds_with_skip = logits_with_skip.argmax(axis=1)
preds_no_skip = logits_no_skip.argmax(axis=1)

fig, axes = plt.subplots(4, 4, figsize=(12, 12))

titles = ['Input', 'Ground Truth', 'With Skip', 'Without Skip']

for i in range(4):
    axes[i, 0].imshow(X_test[i, 0], cmap='gray')
    axes[i, 1].imshow(y_test[i], cmap='tab10', vmin=0, vmax=3)
    axes[i, 2].imshow(preds_with_skip[i], cmap='tab10', vmin=0, vmax=3)
    axes[i, 3].imshow(preds_no_skip[i], cmap='tab10', vmin=0, vmax=3)
    
    for j in range(4):
        axes[i, j].axis('off')
        if i == 0:
            axes[i, j].set_title(titles[j])

plt.suptitle('Comparison: With vs Without Skip Connections', fontsize=14)
plt.tight_layout()
plt.show()

print("Observation: the network with skip connections preserves boundary detail better")

---

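## Summary

### Core Design of U-Net

1. **Encoder-decoder structure**
   - Encoder: progressively reduces resolution and extracts high-level features
   - Decoder: progressively restores resolution and produces the pixel-wise output

2. **Skip connections**
   - Pass encoder features directly to the matching decoder stage
   - Preserve fine detail (boundaries, texture)
   - Help gradients flow

3. **Upsampling methods**
   - Nearest neighbor: simple and fast
   - Bilinear interpolation: smoother
   - Transposed convolution: learnable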
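
As a recap of the simplest of the three methods, the sketch below shows nearest-neighbor upsampling together with its backward pass. The function names are illustrative and may differ from the `NearestUpsample` class used in the notebook; the adjoint check at the end is one way to confirm that the backward pass is the exact transpose of the forward pass.

```python
import numpy as np

def nearest_forward(x, scale=2):
    """(N, C, H, W) -> (N, C, H*scale, W*scale) by repeating each pixel."""
    return x.repeat(scale, axis=2).repeat(scale, axis=3)

def nearest_backward(dout, scale=2):
    """Each input pixel was copied scale*scale times, so its gradient
    is the sum of dout over the block of copies."""
    N, C, H, W = dout.shape
    blocks = dout.reshape(N, C, H // scale, scale, W // scale, scale)
    return blocks.sum(axis=(3, 5))

rng = np.random.default_rng(0)
x = rng.standard_normal((1, 2, 3, 3))
dout = rng.standard_normal((1, 2, 6, 6))

# Adjoint check: <forward(x), dout> should equal <x, backward(dout)>
lhs = np.sum(nearest_forward(x) * dout)
rhs = np.sum(x * nearest_backward(dout))
print(np.isclose(lhs, rhs))
```

Nearest-neighbor upsampling has no parameters, which is why `get_params_and_grads` only collects upsampling weights when the layer actually has a `W` attribute.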

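### Implementation Notes

- Each encoder block: DoubleConv → MaxPool, saving the skip features
- Each decoder block: Upsample → Concat(skip) → DoubleConv
- In the backward pass, the gradient at each Concat must be split between the upsampled path and the skip path, and the skip part added to the gradient the encoder receives through pooling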
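
The gradient bookkeeping around a skip connection can be sketched in a few lines. This is a minimal illustration under assumptions: the concatenation order (upsampled features first, skip features second) and the helper names are chosen here and may not match the notebook's decoder block, and `d_from_pool` is a random placeholder for the gradient that would arrive from deeper layers.

```python
import numpy as np

def concat_forward(up, skip):
    """Join upsampled decoder features and encoder skip features along channels."""
    return np.concatenate([up, skip], axis=1)

def concat_backward(dout, up_channels):
    """Split the gradient back into the decoder path and the skip path."""
    return dout[:, :up_channels], dout[:, up_channels:]

rng = np.random.default_rng(0)
up = rng.standard_normal((2, 4, 8, 8))
skip = rng.standard_normal((2, 4, 8, 8))

merged = concat_forward(up, skip)                      # (2, 8, 8, 8)
d_up, d_skip = concat_backward(np.ones_like(merged), up_channels=up.shape[1])

# The encoder feature map feeds two consumers: the pooling path and the skip path.
# Its total gradient is the sum of both contributions.
d_from_pool = rng.standard_normal(skip.shape)          # placeholder gradient
d_enc = d_from_pool + d_skip

print(merged.shape, d_up.shape, d_skip.shape)
```

Forgetting the addition in the last step silently drops the skip path's contribution, so the encoder would train only through the bottleneck.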

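### Applications

- Medical image segmentation (the original application)
- Semantic segmentation
- Instance segmentation (combined with other techniques)
- Image restoration, super-resolution, and similar tasks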