# Implementation: Max Pooling

**Goal**: Implement 2D max pooling from scratch with NumPy (explicit loops over each window, no deep-learning framework).

In [None]:
import numpy as np

def max_pool(image, pool_size=2, stride=2):
    """Apply 2D max pooling to a single-channel image.

    A square ``pool_size`` x ``pool_size`` window slides over ``image`` in
    steps of ``stride``; each output cell holds the maximum of the window it
    covers. No padding is applied: only windows that fit entirely inside the
    image are evaluated, so trailing rows/columns that cannot fill a whole
    window are ignored.

    Args:
        image: 2D array-like of shape (height, width).
        pool_size: Side length of the square pooling window (>= 1).
        stride: Step between consecutive windows along both axes (>= 1).

    Returns:
        A float64 ``numpy.ndarray`` of shape
        ``((h - pool_size) // stride + 1, (w - pool_size) // stride + 1)``.

    Raises:
        ValueError: If ``image`` is not 2D, if ``pool_size`` or ``stride`` is
            smaller than 1, or if the window does not fit inside the image.
    """
    # Accept nested lists as well as ndarrays.
    image = np.asarray(image)
    if image.ndim != 2:
        raise ValueError(
            f"image must be 2D (height, width), got {image.ndim} dimension(s)"
        )
    if pool_size < 1:
        raise ValueError(f"pool_size must be >= 1, got {pool_size}")
    if stride < 1:
        # Guards against division by zero in the output-size formula below.
        raise ValueError(f"stride must be >= 1, got {stride}")

    h, w = image.shape
    if pool_size > h or pool_size > w:
        # Without this check the size formula yields zero or negative
        # dimensions, which fails inconsistently depending on the inputs.
        raise ValueError(
            f"pool_size {pool_size} does not fit in image of shape {(h, w)}"
        )

    out_h = (h - pool_size) // stride + 1
    out_w = (w - pool_size) // stride + 1

    # np.zeros defaults to float64, so the result is float regardless of the
    # input dtype.
    output = np.zeros((out_h, out_w))

    for i in range(out_h):
        row_start = i * stride
        for j in range(out_w):
            col_start = j * stride
            # Slice the window anchored at (row_start, col_start).
            window = image[
                row_start:row_start + pool_size,
                col_start:col_start + pool_size,
            ]
            output[i, j] = window.max()

    return output

# Mock Data (4x4)
# Each 2x2 quadrant has a different maximum (6, 4, 8, 7), so the pooled
# result is easy to verify by eye.
img = np.array([
    [1, 3, 2, 4],
    [5, 6, 1, 2],
    [8, 2, 1, 0],
    [1, 0, 7, 5]
])

# Pool with the defaults (pool_size=2, stride=2): the four non-overlapping
# 2x2 windows reduce the 4x4 input to a 2x2 output.
pooled = max_pool(img)
print("Original (4x4):\n", img)
print("\nMax Pooled (2x2):\n", pooled)

## Visual Check
*   Window 1 (Top-Left): [1,3,5,6] -> Max is 6.
*   Window 2 (Top-Right): [2,4,1,2] -> Max is 4.
*   Window 3 (Bottom-Left): [8,2,1,0] -> Max is 8.
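*   Window 4 (Bottom-Right): [1,0,7,5] -> Max is 7.

Assembled in window order, the pooled output is `[[6, 4], [8, 7]]`; it prints as floats (`6.`, `4.`, ...) because the output array is created with `np.zeros`, which defaults to `float64`.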