# The im2col Trick: The Core of Convolution Acceleration

## Learning Objectives

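1. Understand how im2col turns convolution into matrix multiplication
2. Implement im2col (forward) and col2im (backward)
3. Use GEMM (general matrix multiplication) to accelerate convolution
4. Compare the performance of different implementations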

## Core Idea

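> **At its core, convolution is a large number of vector inner products**, and matrix multiplication is exactly a way of bundling many inner products together.
> With im2col, we can compute convolution using highly optimized BLAS matrix multiplication.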

In [None]:
import numpy as np
import time
import matplotlib.pyplot as plt

np.random.seed(42)

---

## Part 1: Converting Convolution to Matrix Multiplication

### 1.1 Why Do This?

The computation pattern of convolution:
```
output[i,j] = Σ input_patch[i,j] * kernel
            = inner_product(input_patch.flatten(), kernel.flatten())
```
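Here is a quick numeric check of this identity for a single output position with several input channels. The sizes and random values are illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)
C_in, kH, kW = 3, 3, 3                          # illustrative sizes
patch = rng.standard_normal((C_in, kH, kW))     # one input patch
kernel = rng.standard_normal((C_in, kH, kW))    # kernel of one output channel

# Definition: elementwise product, summed over C_in, kH and kW
direct = np.sum(patch * kernel)

# Same number as an inner product of two vectors of length C_in*kH*kW
as_dot = patch.ravel() @ kernel.ravel()

print(patch.ravel().shape, np.isclose(direct, as_dot))
```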

If we stack all the patches into one matrix, the convolution becomes a matrix multiplication:

```
┌──────────────────────┐   ┌───────────────────────────┐   ┌──────────────────────┐
│  patch_0 (flattened) │   │                           │   │  out[0, :]           │
│  patch_1 (flattened) │   │  kernel_0 ... kernel_N-1  │   │  out[1, :]           │
│  patch_2 (flattened) │ × │  (each column is one      │ = │  out[2, :]           │
│       ...            │   │   flattened kernel)       │   │     ...              │
│  patch_M-1           │   │                           │   │  out[M-1, :]         │
└──────────────────────┘   └───────────────────────────┘   └──────────────────────┘
         (M, K)                       (K, N)                        (M, N)

M = number of output positions (N_batch * out_H * out_W)
K = patch size (C_in * kH * kW)
N = number of output channels (C_out)
```
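The three dimensions can be computed directly from a layer's configuration. The sketch below assumes a square k x k kernel, and the helper name `gemm_dims` is hypothetical. The product M*K*N is also the number of multiply-adds, which is the same count the naive loops perform.

```python
def gemm_dims(N_batch, C_in, H, W, C_out, k, stride=1, padding=0):
    """Return (M, K, N, multiply_adds) for the im2col GEMM of one conv layer."""
    out_H = (H + 2 * padding - k) // stride + 1
    out_W = (W + 2 * padding - k) // stride + 1
    M = N_batch * out_H * out_W   # one row per output position
    K = C_in * k * k              # one flattened patch per row
    N = C_out                     # one column per output channel
    return M, K, N, M * K * N

# Same shapes as the im2col test in Part 2, with C_out = 16
print(gemm_dims(2, 3, 8, 8, 16, 3, stride=1, padding=1))  # (128, 27, 16, 55296)
```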

In [None]:
# Visualize the im2col concept

def visualize_im2col_concept():
    """用簡單例子展示 im2col"""
    
    # 4x4 input, 2x2 kernel, stride=1
    x = np.arange(16).reshape(4, 4)
    print("Input (4x4):")
    print(x)
    print()
    
    # A 2x2 kernel gives a 3x3 output = 9 positions
    # Each patch has 4 elements
    
    print("Patches (9 patches, each 2x2=4 elements):")
    kH, kW = 2, 2
    patches = []
    for i in range(3):  # out_H
        for j in range(3):  # out_W
            patch = x[i:i+kH, j:j+kW].flatten()
            patches.append(patch)
            print(f"  Patch ({i},{j}): {patch}")
    
    # im2col result
    col = np.array(patches)
    print(f"\nim2col result (9 x 4):")
    print(col)
    print(f"\nShape: {col.shape} = (out_positions, patch_size)")
    
    # Demonstrate the matrix multiplication
    kernel = np.array([[1, 0], [0, 1]])  # simple diagonal kernel
    kernel_flat = kernel.flatten()  # (4,)
    
    print(f"\nKernel:")
    print(kernel)
    print(f"Kernel flattened: {kernel_flat}")
    
    # Matrix multiplication
    output = col @ kernel_flat  # (9, 4) @ (4,) -> (9,)
    output = output.reshape(3, 3)
    
    print(f"\nOutput (via matrix multiply):")
    print(output)
    
    # Verify: the result matches the naive convolution
    output_naive = np.zeros((3, 3))
    for i in range(3):
        for j in range(3):
            output_naive[i, j] = np.sum(x[i:i+2, j:j+2] * kernel)
    
    print(f"\nOutput (via naive convolution):")
    print(output_naive)
    print(f"\nMatch: {np.allclose(output, output_naive)}")

visualize_im2col_concept()

---

## Part 2: Implementing im2col

### 2.1 Method 1: Using Loops (Slow but Clear)

In [None]:
def im2col_loop(x, kH, kW, stride=1, padding=0):
    """
    Loop-based im2col (slow but clear)
    
    Parameters
    ----------
    x : np.ndarray, shape (N, C, H, W)
        Input images
    kH, kW : int
        Kernel size
    stride : int
        Stride
    padding : int
        Zero padding added to each side
    
    Returns
    -------
    col : np.ndarray, shape (N * out_H * out_W, C * kH * kW)
        The unrolled matrix, one flattened patch per row
    """
    N, C, H, W = x.shape
    
    # Padding
    if padding > 0:
        x = np.pad(x, ((0, 0), (0, 0), (padding, padding), (padding, padding)), 
                   mode='constant', constant_values=0)
    
    H_pad, W_pad = x.shape[2], x.shape[3]
    out_H = (H_pad - kH) // stride + 1
    out_W = (W_pad - kW) // stride + 1
    
    # Output matrix (same dtype as the input)
    col = np.zeros((N * out_H * out_W, C * kH * kW), dtype=x.dtype)
    
    # Fill one row per output position
    idx = 0
    for n in range(N):
        for i in range(out_H):
            for j in range(out_W):
                h_start = i * stride
                w_start = j * stride
                # Extract the patch and flatten it
                patch = x[n, :, h_start:h_start+kH, w_start:w_start+kW]
                col[idx] = patch.flatten()
                idx += 1
    
    return col


# Test
x = np.random.randn(2, 3, 8, 8)  # 2 images, 3 channels, 8x8 each
col = im2col_loop(x, kH=3, kW=3, stride=1, padding=1)
print(f"Input shape: {x.shape}")
print(f"im2col output shape: {col.shape}")
print(f"Expected: (N*out_H*out_W, C*kH*kW) = (2*8*8, 3*3*3) = (128, 27)")

### 2.2 Method 2: Using stride_tricks (Fast)

In [None]:
def im2col_strided(x, kH, kW, stride=1, padding=0):
    """
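    Efficient im2col (using stride_tricks)
    
    Key ideas:
    1. as_strided builds a view that contains every patch
    2. Only the strides change; no data is copied at this step
    3. The final transpose + reshape into a 2D matrix does copy,
       because the overlapping view cannot be flattened in place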
    """
    N, C, H, W = x.shape
    
    # Padding
    if padding > 0:
        x = np.pad(x, ((0, 0), (0, 0), (padding, padding), (padding, padding)),
                   mode='constant', constant_values=0)
    
    H_pad, W_pad = x.shape[2], x.shape[3]
    out_H = (H_pad - kH) // stride + 1
    out_W = (W_pad - kW) // stride + 1
    
    # Compute strides
    # x.strides: (N_stride, C_stride, H_stride, W_stride)
    # Goal: a view of shape (N, C, kH, kW, out_H, out_W)
    
    # Original strides (in bytes)
    s0, s1, s2, s3 = x.strides
    
    # New shape and strides
    shape = (N, C, kH, kW, out_H, out_W)
    strides = (s0,                  # next image in the batch
               s1,                  # next channel
               s2,                  # next row inside the kernel window
               s3,                  # next column inside the kernel window
               s2 * stride,         # next output row
               s3 * stride)         # next output column
    
    # Build the strided view (as_strided does no bounds checking, so shape and strides must be exact)
    col_strided = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides)
    
    # Reshape: (N, C, kH, kW, out_H, out_W) -> (N * out_H * out_W, C * kH * kW)
    # First transpose to (N, out_H, out_W, C, kH, kW)
    col = col_strided.transpose(0, 4, 5, 1, 2, 3)
    # Then reshape (this step copies the data)
    col = col.reshape(N * out_H * out_W, -1)
    
    return col


# Test and verify
x = np.random.randn(2, 3, 8, 8)
col_loop = im2col_loop(x, 3, 3, stride=1, padding=1)
col_strided = im2col_strided(x, 3, 3, stride=1, padding=1)

print(f"Loop output shape:    {col_loop.shape}")
print(f"Strided output shape: {col_strided.shape}")
print(f"Results match: {np.allclose(col_loop, col_strided)}")
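`as_strided` performs no bounds checking, so a wrong shape or stride silently reads memory outside the array. NumPy 1.20+ provides `np.lib.stride_tricks.sliding_window_view`, which computes the window strides itself and returns a read-only view. Below is a sketch of the same im2col layout built on it. `im2col_swv` is a hypothetical name and is not part of this notebook's code.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view  # NumPy >= 1.20

def im2col_swv(x, kH, kW, stride=1, padding=0):
    """im2col with the same row/column order as im2col_strided."""
    if padding > 0:
        x = np.pad(x, ((0, 0), (0, 0), (padding, padding), (padding, padding)))
    # All windows: (N, C, H_pad-kH+1, W_pad-kW+1, kH, kW); keep every stride-th one
    win = sliding_window_view(x, (kH, kW), axis=(2, 3))[:, :, ::stride, ::stride]
    N, C, out_H, out_W = win.shape[:4]
    # (N, out_H, out_W, C, kH, kW) -> (N*out_H*out_W, C*kH*kW); this reshape copies
    return win.transpose(0, 2, 3, 1, 4, 5).reshape(N * out_H * out_W, C * kH * kW)

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 8, 8))

win = sliding_window_view(x, (3, 3), axis=(2, 3))
print(win.shape, np.may_share_memory(win, x))  # a view into x, nothing copied yet

col = im2col_swv(x, 3, 3)
print(col.shape, np.may_share_memory(col, x))  # the 2D matrix is a new array
```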

In [None]:
# Performance comparison

x = np.random.randn(4, 32, 64, 64).astype(np.float32)

# Loop version
start = time.perf_counter()
col1 = im2col_loop(x, 3, 3, stride=1, padding=1)
time_loop = time.perf_counter() - start

# Strided version
start = time.perf_counter()
col2 = im2col_strided(x, 3, 3, stride=1, padding=1)
time_strided = time.perf_counter() - start

print(f"im2col performance (4, 32, 64, 64) with 3x3 kernel:")
print(f"  Loop:    {time_loop:.4f}s")
print(f"  Strided: {time_strided:.4f}s")
print(f"  Speedup: {time_loop / time_strided:.1f}x")
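A single `perf_counter` measurement like the one above is noisy because of caches, BLAS thread start-up, and other processes. The sketch below gives a steadier measurement using the standard-library `timeit.repeat` and reports the fastest per-call time. `best_of` is a hypothetical helper name.

```python
import timeit
import numpy as np

def best_of(fn, repeat=5, number=3):
    """Fastest average per-call time of fn over `repeat` trials of `number` calls each."""
    # The minimum comes from the trial least disturbed by background noise
    return min(timeit.repeat(fn, repeat=repeat, number=number)) / number

a = np.random.randn(512, 512).astype(np.float32)
print(f"512x512 float32 matmul: {best_of(lambda: a @ a):.6f}s per call")
```

In this notebook it could wrap the existing calls, for example `best_of(lambda: im2col_strided(x, 3, 3, stride=1, padding=1))`.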

---

## Part 3: col2im (for the Backward Pass)

In [None]:
def col2im(col, x_shape, kH, kW, stride=1, padding=0):
    """
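    Reverse of im2col (used in backward)
    
    "Folds" the col matrix back into the original image layout.
    Note: values at overlapping positions are summed, so this is the adjoint
    (transpose) of im2col rather than its exact inverse; summing is exactly
    what backward needs.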
    
    Parameters
    ----------
    col : np.ndarray, shape (N * out_H * out_W, C * kH * kW)
    x_shape : tuple
        Shape of the original input (N, C, H, W)
    
    Returns
    -------
    x : np.ndarray, shape (N, C, H, W)
    """
    N, C, H, W = x_shape
    
    # Size after padding
    H_pad = H + 2 * padding
    W_pad = W + 2 * padding
    out_H = (H_pad - kH) // stride + 1
    out_W = (W_pad - kW) // stride + 1
    
    # Reshape col: (N*out_H*out_W, C*kH*kW) -> (N, out_H, out_W, C, kH, kW)
    col = col.reshape(N, out_H, out_W, C, kH, kW)
    # Transpose: (N, C, kH, kW, out_H, out_W)
    col = col.transpose(0, 3, 4, 5, 1, 2)
    
    # Initialize the output (including padding)
    x_pad = np.zeros((N, C, H_pad, W_pad), dtype=col.dtype)
    
    # Add the col values back into place
    # col shape: (N, C, kH, kW, out_H, out_W)
    for i in range(kH):
        i_max = i + stride * out_H
        for j in range(kW):
            j_max = j + stride * out_W
            # Use += because overlapping positions must be summed
            x_pad[:, :, i:i_max:stride, j:j_max:stride] += col[:, :, i, j, :, :]
    
    # Remove padding
    if padding > 0:
        return x_pad[:, :, padding:-padding, padding:-padding]
    return x_pad


# Test col2im
x = np.random.randn(2, 3, 8, 8)
col = im2col_strided(x, 3, 3, stride=1, padding=1)
x_reconstructed = col2im(col, x.shape, 3, 3, stride=1, padding=1)

print(f"Original x shape: {x.shape}")
print(f"col shape: {col.shape}")
print(f"Reconstructed x shape: {x_reconstructed.shape}")

# Note: with a 3x3 kernel (stride 1), each interior pixel appears in 9 patches,
# so col2im(im2col(x)) is not x; interior values come back multiplied by 9
print(f"\nCenter values comparison (should be 9x):")
print(f"  Original center: {x[0, 0, 3, 3]:.4f}")
print(f"  Reconstructed:   {x_reconstructed[0, 0, 3, 3]:.4f}")
print(f"  Ratio:           {x_reconstructed[0, 0, 3, 3] / x[0, 0, 3, 3]:.1f}x")
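The 9x ratio reflects a more general property: col2im is the adjoint (transpose) of im2col. For any x and any y of the col shape, <im2col(x), y> = <x, col2im(y)>, and backward relies on exactly this identity. Below is a self-contained check that uses small loop versions. `im2col_ref` and `col2im_ref` are hypothetical names written for this check; they follow the same row and column order as the functions above. im2col gathers each patch into a row, and col2im scatter-adds each row back.

```python
import numpy as np

def im2col_ref(x, k, stride=1, padding=0):
    """Gather each k x k patch (all channels) into one row."""
    N, C, H, W = x.shape
    xp = np.pad(x, ((0, 0), (0, 0), (padding, padding), (padding, padding)))
    oH = (H + 2 * padding - k) // stride + 1
    oW = (W + 2 * padding - k) // stride + 1
    rows = [xp[n, :, i*stride:i*stride+k, j*stride:j*stride+k].ravel()
            for n in range(N) for i in range(oH) for j in range(oW)]
    return np.array(rows)

def col2im_ref(col, x_shape, k, stride=1, padding=0):
    """Scatter-add each row back onto the window it came from."""
    N, C, H, W = x_shape
    Hp, Wp = H + 2 * padding, W + 2 * padding
    oH = (Hp - k) // stride + 1
    oW = (Wp - k) // stride + 1
    cols = col.reshape(N, oH, oW, C, k, k)
    xp = np.zeros((N, C, Hp, Wp), dtype=col.dtype)
    for i in range(oH):
        for j in range(oW):
            xp[:, :, i*stride:i*stride+k, j*stride:j*stride+k] += cols[:, i, j]
    return xp[:, :, padding:Hp - padding, padding:Wp - padding]

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 8, 8))
y = rng.standard_normal((2 * 8 * 8, 3 * 3 * 3))

# Adjoint identity: <im2col(x), y> == <x, col2im(y)>
lhs = np.sum(im2col_ref(x, 3, padding=1) * y)
rhs = np.sum(x * col2im_ref(y, x.shape, 3, padding=1))
print(np.isclose(lhs, rhs))

# col2im(im2col(ones)) counts how many patches cover each pixel
ones = np.ones((1, 1, 8, 8))
counts = col2im_ref(im2col_ref(ones, 3, padding=1), ones.shape, 3, padding=1)
print(counts[0, 0])
```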

---

## Part 4: Implementing Convolution with im2col

In [None]:
class Conv2D_im2col:
    """
    2D convolution layer accelerated with im2col
    """
    
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        
        # He initialization
        scale = np.sqrt(2.0 / (in_channels * kernel_size * kernel_size))
        self.W = np.random.randn(out_channels, in_channels, kernel_size, kernel_size) * scale
        self.b = np.zeros(out_channels)
        
        # Gradients
        self.dW = None
        self.db = None
        
        # Cache
        self.cache = None
    
    def forward(self, x):
        """
        Forward pass using im2col + matrix multiplication
        
        Parameters
        ----------
        x : np.ndarray, shape (N, C_in, H, W)
        
        Returns
        -------
        out : np.ndarray, shape (N, C_out, H_out, W_out)
        """
        N, C, H, W = x.shape
        kH = kW = self.kernel_size
        
        # Compute the output size
        H_pad = H + 2 * self.padding
        W_pad = W + 2 * self.padding
        out_H = (H_pad - kH) // self.stride + 1
        out_W = (W_pad - kW) // self.stride + 1
        
        # im2col: (N*out_H*out_W, C_in*kH*kW)
        col = im2col_strided(x, kH, kW, self.stride, self.padding)
        
        # Reshape weights: (C_out, C_in*kH*kW)
        W_col = self.W.reshape(self.out_channels, -1)
        
        # Matrix multiplication: (N*out_H*out_W, C_in*kH*kW) @ (C_in*kH*kW, C_out)
        # Result: (N*out_H*out_W, C_out)
        out_col = col @ W_col.T + self.b
        
        # Reshape output: (N, out_H, out_W, C_out) -> (N, C_out, out_H, out_W)
        out = out_col.reshape(N, out_H, out_W, self.out_channels).transpose(0, 3, 1, 2)
        
        # Cache for backward
        self.cache = (x, col, W_col)
        
        return out
    
    def backward(self, dout):
        """
        Backward pass
        
        Parameters
        ----------
        dout : np.ndarray, shape (N, C_out, out_H, out_W)
        
        Returns
        -------
        dx : np.ndarray, shape (N, C_in, H, W)
        """
        x, col, W_col = self.cache
        N, C, H, W = x.shape
        kH = kW = self.kernel_size
        
        # Reshape dout: (N, C_out, out_H, out_W) -> (N*out_H*out_W, C_out)
        dout_col = dout.transpose(0, 2, 3, 1).reshape(-1, self.out_channels)
        
        # Gradient for weights: dW = col.T @ dout_col
        # col: (N*out_H*out_W, C_in*kH*kW)
        # dout_col: (N*out_H*out_W, C_out)
        # dW_col: (C_in*kH*kW, C_out)
        dW_col = col.T @ dout_col  # (C_in*kH*kW, C_out)
        self.dW = dW_col.T.reshape(self.W.shape)  # (C_out, C_in, kH, kW)
        
        # Gradient for bias
        self.db = dout_col.sum(axis=0)
        
        # Gradient for input: dcol = dout_col @ W_col
        # dout_col: (N*out_H*out_W, C_out)
        # W_col: (C_out, C_in*kH*kW)
        dcol = dout_col @ W_col  # (N*out_H*out_W, C_in*kH*kW)
        
        # col2im
        dx = col2im(dcol, x.shape, kH, kW, self.stride, self.padding)
        
        return dx


# Test
conv = Conv2D_im2col(in_channels=3, out_channels=16, kernel_size=3, padding=1)
x = np.random.randn(2, 3, 32, 32)

# Forward
out = conv.forward(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {out.shape}")

# Backward
dout = np.random.randn(*out.shape)
dx = conv.backward(dout)
print(f"dx shape: {dx.shape}")
print(f"dW shape: {conv.dW.shape}")
print(f"db shape: {conv.db.shape}")
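The three matrix products in `backward` come from differentiating the single GEMM in `forward`. Write $X_{\text{col}}$ for `col` (shape $M \times K$), $W_{\text{col}}$ for the reshaped weights (shape $C_{\text{out}} \times K$), $G$ for `dout_col`, and $\mathbf{1}$ for the all-ones vector of length $M$:

$$
Y = X_{\text{col}} W_{\text{col}}^{\top} + \mathbf{1}\, b^{\top}, \qquad G = \frac{\partial L}{\partial Y} \in \mathbb{R}^{M \times C_{\text{out}}}
$$

$$
\frac{\partial L}{\partial W_{\text{col}}} = G^{\top} X_{\text{col}}, \qquad
\frac{\partial L}{\partial b} = G^{\top} \mathbf{1} = \sum_{m=1}^{M} G_{m,:}, \qquad
\frac{\partial L}{\partial X_{\text{col}}} = G\, W_{\text{col}}
$$

In code these are `(col.T @ dout_col).T`, `dout_col.sum(axis=0)` and `dout_col @ W_col`. The last result is still in col layout. Since im2col is linear, the gradient with respect to x is its transpose applied to $\partial L / \partial X_{\text{col}}$, which is what `col2im` computes.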

In [None]:
# Gradient check

def numerical_gradient(f, x, eps=1e-5):
    """數值梯度"""
    grad = np.zeros_like(x)
    it = np.nditer(x, flags=['multi_index'], op_flags=['readwrite'])
    
    while not it.finished:
        idx = it.multi_index
        old_val = x[idx]
        
        x[idx] = old_val + eps
        fxh1 = f(x)
        
        x[idx] = old_val - eps
        fxh2 = f(x)
        
        grad[idx] = (fxh1 - fxh2) / (2 * eps)
        x[idx] = old_val
        
        it.iternext()
    
    return grad

# Small test
np.random.seed(42)
conv = Conv2D_im2col(in_channels=2, out_channels=3, kernel_size=3, padding=1)
x = np.random.randn(1, 2, 4, 4)

# Forward and backward
out = conv.forward(x)
dout = np.random.randn(*out.shape)
dx_analytic = conv.backward(dout)

# Numerical gradient for dx
def f_x(x_):
    return np.sum(conv.forward(x_) * dout)

dx_numeric = numerical_gradient(f_x, x.copy())

# Compare
diff = np.abs(dx_analytic - dx_numeric)
max_diff = np.max(diff)
rel_error = max_diff / (np.maximum(np.abs(dx_analytic).max(), np.abs(dx_numeric).max()) + 1e-8)

print(f"Gradient check for dx:")
print(f"  Max difference: {max_diff:.2e}")
print(f"  Relative error: {rel_error:.2e}")
print(f"  Status: {'PASS' if rel_error < 1e-4 else 'FAIL'}")
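The cell above only checks dx. The check below does the same for the weight gradient, applied directly to the GEMM formula with small random matrices rather than through `Conv2D_im2col`. The matrix sizes are illustrative.

```python
import numpy as np

rng = np.random.default_rng(1)
M, K, C_out = 6, 5, 4
X = rng.standard_normal((M, K))         # plays the role of col
Wc = rng.standard_normal((C_out, K))    # plays the role of W_col
b = rng.standard_normal(C_out)
G = rng.standard_normal((M, C_out))     # plays the role of dout_col

def loss(Wc_):
    # L = sum(Y * G) with Y = X @ Wc_.T + b, so dL/dY = G
    return np.sum((X @ Wc_.T + b) * G)

dW_analytic = G.T @ X                   # equals (col.T @ dout_col).T

eps = 1e-6
dW_numeric = np.zeros_like(Wc)
for idx in np.ndindex(Wc.shape):
    W_plus, W_minus = Wc.copy(), Wc.copy()
    W_plus[idx] += eps
    W_minus[idx] -= eps
    dW_numeric[idx] = (loss(W_plus) - loss(W_minus)) / (2 * eps)

print(np.max(np.abs(dW_analytic - dW_numeric)))
```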

---

## Part 5: Performance Comparison

In [None]:
def conv2d_naive_4d(x, W, b=None, stride=1, padding=0):
    """Naive 4D 卷積（6 層迴圈）"""
    N, C_in, H, W_in = x.shape
    C_out, _, kH, kW = W.shape
    
    if padding > 0:
        x = np.pad(x, ((0, 0), (0, 0), (padding, padding), (padding, padding)))
    
    H_pad, W_pad = x.shape[2], x.shape[3]
    out_H = (H_pad - kH) // stride + 1
    out_W = (W_pad - kW) // stride + 1
    
    output = np.zeros((N, C_out, out_H, out_W))
    
    for n in range(N):
        for c_out in range(C_out):
            for i in range(out_H):
                for j in range(out_W):
                    h_start = i * stride
                    w_start = j * stride
                    patch = x[n, :, h_start:h_start+kH, w_start:w_start+kW]
                    output[n, c_out, i, j] = np.sum(patch * W[c_out])
    
    if b is not None:
        output += b.reshape(1, -1, 1, 1)
    
    return output


def benchmark_conv(sizes, C_in=3, C_out=32, kernel_size=3):
    """比較 naive 和 im2col 的效能"""
    results = []
    
    W = np.random.randn(C_out, C_in, kernel_size, kernel_size).astype(np.float32)
    b = np.random.randn(C_out).astype(np.float32)
    conv = Conv2D_im2col(C_in, C_out, kernel_size, padding=1)
    conv.W = W.copy()
    conv.b = b.copy()
    
    print(f"{'Size':>10} {'Naive':>12} {'im2col':>12} {'Speedup':>10}")
    print("-" * 50)
    
    for size in sizes:
        x = np.random.randn(1, C_in, size, size).astype(np.float32)
        
        # Naive (skip if too slow)
        if size <= 32:
            start = time.perf_counter()
            out_naive = conv2d_naive_4d(x, W, b, padding=1)
            time_naive = time.perf_counter() - start
        else:
            time_naive = np.nan
        
        # im2col
        start = time.perf_counter()
        for _ in range(10):
            out_im2col = conv.forward(x)
        time_im2col = (time.perf_counter() - start) / 10
        
        if not np.isnan(time_naive):
            speedup = time_naive / time_im2col
            print(f"{size:>10} {time_naive:>10.4f}s {time_im2col:>10.4f}s {speedup:>9.1f}x")
        else:
            print(f"{size:>10} {'skip':>12} {time_im2col:>10.4f}s {'N/A':>10}")
        
        results.append({
            'size': size,
            'naive': time_naive,
            'im2col': time_im2col
        })
    
    return results

print("Convolution Benchmark (3x3 kernel, padding=1):")
print("=" * 50)
results = benchmark_conv([8, 16, 32, 64, 128, 256])
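The benchmark above times the two versions but never compares their outputs. An independent vectorized reference is useful for checking correctness at sizes where the loop version is too slow. In the sketch below, `conv2d_einsum` is a hypothetical name. It is not im2col; it writes the same multiply-add as a single tensor contraction, using `np.einsum` over a `sliding_window_view` (NumPy >= 1.20).

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view  # NumPy >= 1.20

def conv2d_einsum(x, W, b=None, stride=1, padding=0):
    """Reference convolution: x (N, C_in, H, W), W (C_out, C_in, kH, kW)."""
    if padding > 0:
        x = np.pad(x, ((0, 0), (0, 0), (padding, padding), (padding, padding)))
    kH, kW = W.shape[2], W.shape[3]
    # win: (N, C_in, out_H, out_W, kH, kW), a view with no copy
    win = sliding_window_view(x, (kH, kW), axis=(2, 3))[:, :, ::stride, ::stride]
    # Sum over C_in (c), kernel rows (p) and kernel columns (q)
    out = np.einsum('ncijpq,ocpq->noij', win, W)
    if b is not None:
        out = out + b.reshape(1, -1, 1, 1)
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 7, 7))
Wt = rng.standard_normal((4, 3, 3, 3))
bt = rng.standard_normal(4)
out = conv2d_einsum(x, Wt, bt, stride=2, padding=1)
print(out.shape)  # (2, 4, 4, 4)
```

Its output could be compared with `Conv2D_im2col.forward` using `np.allclose`, with a looser tolerance when the inputs are float32.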

In [None]:
# Plot the performance

sizes = [r['size'] for r in results]
times_naive = [r['naive'] for r in results]
times_im2col = [r['im2col'] for r in results]

# Drop sizes where the naive version was skipped (NaN)
valid = [(s, n, i) for s, n, i in zip(sizes, times_naive, times_im2col) if not np.isnan(n)]
sizes_valid = [v[0] for v in valid]
naive_valid = [v[1] for v in valid]

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Time comparison
axes[0].semilogy(sizes_valid, naive_valid, 'o-', label='Naive (6 loops)', linewidth=2, markersize=8)
axes[0].semilogy(sizes, times_im2col, 's-', label='im2col + GEMM', linewidth=2, markersize=8)
axes[0].set_xlabel('Image Size (pixels)', fontsize=12)
axes[0].set_ylabel('Time (seconds, log scale)', fontsize=12)
axes[0].set_title('Convolution Performance', fontsize=14)
axes[0].legend(fontsize=11)
axes[0].grid(True, alpha=0.3)

# Speedup
speedups = [n / i for _, n, i in valid]  # pair each naive time with its own im2col time
axes[1].bar(range(len(sizes_valid)), speedups, color='green', alpha=0.7)
axes[1].set_xticks(range(len(sizes_valid)))
axes[1].set_xticklabels(sizes_valid)
axes[1].set_xlabel('Image Size (pixels)', fontsize=12)
axes[1].set_ylabel('Speedup (x times)', fontsize=12)
axes[1].set_title('im2col Speedup over Naive', fontsize=14)
axes[1].grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

---

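## Part 6: Memory Considerations of im2col

The downside of im2col is **high memory usage**:

- Original input: `N * C * H * W`
- After im2col: `N * out_H * out_W * C * kH * kW`

For a 3x3 kernel with stride 1 and "same" padding (so out_H * out_W = H * W), memory grows by about **9x**.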
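Dividing the two sizes gives the expansion factor (out_H * out_W * kH * kW) / (H * W); N and C cancel. The factor equals kH * kW only when the output keeps the input's spatial size, and a stride of 2 cuts it by about 4x. The sketch below assumes a square kernel; `im2col_expansion` is a hypothetical helper.

```python
def im2col_expansion(H, W, k, stride=1, padding=0):
    """Size of the im2col matrix divided by the size of the input.
    N and C appear in both sizes, so they cancel."""
    out_H = (H + 2 * padding - k) // stride + 1
    out_W = (W + 2 * padding - k) // stride + 1
    return (out_H * out_W * k * k) / (H * W)

print(im2col_expansion(56, 56, 3, stride=1, padding=1))  # 9.0
print(im2col_expansion(56, 56, 3, stride=2, padding=1))  # 2.25
print(im2col_expansion(56, 56, 7, stride=1, padding=3))  # 49.0
```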

In [None]:
def analyze_memory_usage(N, C, H, W, kH, kW, padding=0):
    """分析 im2col 的記憶體使用"""
    bytes_per_float = 4  # float32
    
    # Original input
    input_size = N * C * H * W
    input_bytes = input_size * bytes_per_float
    
    # im2col output (stride 1 assumed)
    out_H = H + 2 * padding - kH + 1
    out_W = W + 2 * padding - kW + 1
    col_size = N * out_H * out_W * C * kH * kW
    col_bytes = col_size * bytes_per_float
    
    # Output
    # (assuming C_out = C)
    output_size = N * C * out_H * out_W
    output_bytes = output_size * bytes_per_float
    
    print(f"Memory analysis for ({N}, {C}, {H}, {W}) with {kH}x{kW} kernel:")
    print(f"  Input:    {input_bytes / 1024**2:.2f} MB ({input_size:,} floats)")
    print(f"  im2col:   {col_bytes / 1024**2:.2f} MB ({col_size:,} floats)")
    print(f"  Output:   {output_bytes / 1024**2:.2f} MB")
    print(f"  Memory expansion: {col_bytes / input_bytes:.1f}x")
    print(f"  Total peak memory: {(input_bytes + col_bytes + output_bytes) / 1024**2:.2f} MB")

# Analyze different configurations
print("Small model:")
analyze_memory_usage(32, 64, 56, 56, 3, 3, padding=1)

print("\nLarge model:")
analyze_memory_usage(32, 256, 56, 56, 3, 3, padding=1)

print("\n7x7 kernel:")
analyze_memory_usage(32, 64, 56, 56, 7, 7, padding=3)

---

## Summary

### Core Concepts of im2col

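1. **Turn convolution into matrix multiplication**
   - Unroll all patches into one large matrix
   - Compute the result with a highly optimized GEMM

2. **im2col**: used in forward
   - Unrolls the input into `(N*out_H*out_W, C*kH*kW)`
   - Uses `stride_tricks` to build the patch view without copying (the final reshape into a 2D matrix does copy)

3. **col2im**: used in backward
   - Folds the gradient back into the original shape
   - Overlapping positions are summed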

### Performance Gains

| Method | Same time complexity, but... |
|-----|----------------------|
| Naive | Nested Python loops add heavy interpreter overhead |
| im2col | One GEMM call on optimized (often multithreaded) BLAS |

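### Trade-offs

- **Pros**: roughly 10-100x faster than the naive loops (the exact speedup depends on input size and the BLAS backend)
- **Cons**: memory grows by about kH*kW times (roughly 9x for a 3x3 kernel at stride 1)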

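### In Practice

- Many deep learning frameworks use im2col or related lowering techniques in at least some convolution paths (Caffe's convolution layer, for example, is built on im2col + GEMM)
- On GPUs, libraries such as cuDNN provide even more efficient CUDA implementations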