In [1]:
import numpy as np
from keras.datasets import mnist
from keras.utils import to_categorical

# Load data
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train[:1000] / 255.0  # Use small subset for speed
y_train = to_categorical(y_train[:1000], 10)

x_test = x_test[:200] / 255.0
y_test = to_categorical(y_test[:200], 10)

# Initialize weights
np.random.seed(0)
conv_filter = np.random.randn(8, 3, 3) * 0.1  # 8 filters, 3x3
fc_w = np.random.randn(8 * 13 * 13, 10) * 0.1
fc_b = np.zeros((1, 10))


# Activation
def relu(x):
    return np.maximum(0, x)

def relu_deriv(x):
    return x > 0

# Softmax
def softmax(x):
    e_x = np.exp(x - np.max(x, axis=1, keepdims=True))
    return e_x / np.sum(e_x, axis=1, keepdims=True)

# Convolution
def conv2d(x, filters):
    n_filters, k, _ = filters.shape
    h_out = x.shape[1] - k + 1
    w_out = x.shape[2] - k + 1
    out = np.zeros((n_filters, h_out, w_out))
    for f in range(n_filters):
        for i in range(h_out):
            for j in range(w_out):
                out[f, i, j] = np.sum(x[:, i:i+k, j:j+k] * filters[f])
    return out

# Max pooling
def max_pool(x, size=2, stride=2):
    c, h, w = x.shape
    h_out = (h - size) // stride + 1
    w_out = (w - size) // stride + 1
    out = np.zeros((c, h_out, w_out))
    for ch in range(c):
        for i in range(0, h - size + 1, stride):
            for j in range(0, w - size + 1, stride):
                out[ch, i//stride, j//stride] = np.max(x[ch, i:i+size, j:j+size])
    return out

# Flatten
def flatten(x):
    return x.reshape(-1)

# Training
lr = 0.01
epochs = 3
batch_size = 10

for epoch in range(epochs):
    for i in range(0, len(x_train), batch_size):
        X_batch = x_train[i:i+batch_size]
        Y_batch = y_train[i:i+batch_size]
        
        dW = np.zeros_like(conv_filter)
        dfc_w = np.zeros_like(fc_w)
        dfc_b = np.zeros_like(fc_b)
        loss = 0
        
        for x_img, y_true in zip(X_batch, Y_batch):
            # Forward
            x_img = x_img.reshape(1, 28, 28)
            conv_out = conv2d(x_img, conv_filter)
            relu_out = relu(conv_out)
            pooled = max_pool(relu_out)
            flat = flatten(pooled).reshape(1, -1)
            logits = flat @ fc_w + fc_b
            probs = softmax(logits)
            
            loss += -np.sum(y_true * np.log(probs + 1e-9))
            
            # Backward (FC only for simplicity now)
            dlogits = probs - y_true.reshape(1, -1)
            dfc_w += flat.T @ dlogits
            dfc_b += dlogits

        # Update
        fc_w -= lr * dfc_w / batch_size
        fc_b -= lr * dfc_b / batch_size

    print(f"Epoch {epoch+1}, Loss: {loss/len(X_batch):.4f}")

Epoch 1, Loss: 2.0540
Epoch 2, Loss: 1.8441
Epoch 3, Loss: 1.6576
