In [5]:
import numpy as np
import pickle
import os
# Suppress TensorFlow logs if desired (optional)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
# from scipy.special import softmax # Using manual implementation now

# Use manual softmax for stability and consistency
def softmax(x):
    """Return the row-wise softmax of the scores in x.

    A 1-D input is treated as a single row, so the result is always 2-D
    (shape ``(1, n)`` for a 1-D input of length ``n``). Each row's maximum
    is subtracted before exponentiating to avoid overflow.
    """
    scores = x.reshape(1, -1) if x.ndim == 1 else x
    # Shifting by the row maximum leaves the result unchanged mathematically.
    shifted = scores - scores.max(axis=1, keepdims=True)
    exps = np.exp(shifted)
    return exps / exps.sum(axis=1, keepdims=True)

def load_mnist():
    """
    Fetch MNIST through ``tf.keras.datasets.mnist.load_data`` and prepare it.

    Images are converted to float32 and divided by 255; labels are copied
    into int32 arrays. Returns ``(X_train, y_train, X_test, y_test)``.
    """
    print("Loading MNIST dataset using TensorFlow...")
    train_split, test_split = tf.keras.datasets.mnist.load_data()
    raw_train_images, raw_train_labels = train_split
    raw_test_images, raw_test_labels = test_split

    # Scale pixel intensities by 1/255 while keeping everything in float32.
    X_train = raw_train_images.astype(np.float32) / 255.0
    X_test = raw_test_images.astype(np.float32) / 255.0

    # Labels are stored as int32 class indices.
    y_train = np.array(raw_train_labels, dtype=np.int32)
    y_test = np.array(raw_test_labels, dtype=np.int32)

    print(f"Data loaded. Training set: {X_train.shape}, Test set: {X_test.shape}")
    return X_train, y_train, X_test, y_test


def im2col(input_data, filter_h, filter_w, stride=1, pad=0):
    """
    Unfold sliding windows of a 4-D batch into the rows of a 2-D matrix.

    ``input_data`` has shape (N, C, H, W). The result is a float32 array of
    shape (N * out_h * out_w, C * filter_h * filter_w): one row per window
    position, holding that window's values for every channel. The spatial
    axes are zero-padded by ``pad`` on each side before windows are taken.
    """
    input_data = input_data.astype(np.float32, copy=False)
    batch, channels, height, width = input_data.shape
    out_h = (height + 2 * pad - filter_h) // stride + 1
    out_w = (width + 2 * pad - filter_w) // stride + 1

    # Zero-pad only the two spatial axes.
    padded = np.pad(
        input_data, [(0, 0), (0, 0), (pad, pad), (pad, pad)], 'constant'
    )
    patches = np.zeros(
        (batch, channels, filter_h, filter_w, out_h, out_w), dtype=np.float32
    )

    row_span = stride * out_h
    col_span = stride * out_w
    # For each offset inside the filter, copy the strided grid of pixels that
    # sit at that offset across all window positions.
    for dy, dx in np.ndindex(filter_h, filter_w):
        patches[:, :, dy, dx, :, :] = padded[
            :, :, dy:dy + row_span:stride, dx:dx + col_span:stride
        ]

    # Reorder to (N, out_h, out_w, C, fh, fw) so each window becomes one row.
    patches = patches.transpose(0, 4, 5, 1, 2, 3)
    return patches.reshape(batch * out_h * out_w, -1)


def col2im(col, input_shape, filter_h, filter_w, stride=1, pad=0):
    """
    Fold a window matrix back into image layout, summing overlaps.

    ``col`` holds one row per window position, laid out as produced by
    ``im2col`` for an input of shape ``input_shape`` = (N, C, H, W). Values
    that land on the same pixel are accumulated. The float32 result has
    shape (N, C, H, W); the padding border is cropped off when ``pad > 0``.
    """
    batch, channels, height, width = input_shape
    out_h = (height + 2 * pad - filter_h) // stride + 1
    out_w = (width + 2 * pad - filter_w) // stride + 1

    # Back to (N, C, fh, fw, out_h, out_w) so filter offsets can be indexed.
    patches = col.reshape(batch, out_h, out_w, channels, filter_h, filter_w)
    patches = patches.transpose(0, 3, 4, 5, 1, 2)

    canvas = np.zeros(
        (batch, channels, height + 2 * pad, width + 2 * pad), dtype=np.float32
    )

    row_span = stride * out_h
    col_span = stride * out_w
    # Scatter-add each filter offset onto its strided grid of pixels.
    for dy, dx in np.ndindex(filter_h, filter_w):
        canvas[:, :, dy:dy + row_span:stride, dx:dx + col_span:stride] += (
            patches[:, :, dy, dx, :, :]
        )

    if pad <= 0:
        return canvas
    return canvas[:, :, pad:-pad, pad:-pad]


class Convolution:
    """2-D convolution layer evaluated as a matrix product over im2col rows.

    ``W`` has shape (output_channels, input_channels, kernel_size,
    kernel_size) and is drawn from a standard normal scaled by
    ``sqrt(1 / (input_channels * kernel_size**2))``; ``b`` starts at zero.
    Both are float32. ``learning_rate`` is stored as ``lr`` but is not used
    by the layer itself. ``backward`` leaves gradients in ``dW`` and ``db``.
    """

    def __init__(self, input_channels, output_channels, kernel_size=3, stride=1, pad=1, learning_rate=0.01):
        fan_in = input_channels * kernel_size * kernel_size
        weight_shape = (output_channels, input_channels, kernel_size, kernel_size)
        scale = np.sqrt(1.0 / fan_in).astype(np.float32)
        self.W = scale * np.random.randn(*weight_shape).astype(np.float32)
        self.b = np.zeros(output_channels, dtype=np.float32)
        self.stride = stride
        self.pad = pad
        self.lr = learning_rate
        # Forward-pass caches, filled by forward() and read by backward().
        self.x = None
        self.col = None
        self.col_W = None
        # Parameter gradients, filled by backward().
        self.dW = None
        self.db = None

    def forward(self, x):
        """Convolve ``x`` of shape (N, C, H, W); return (N, FN, out_h, out_w)."""
        num_filters, _, kernel_h, kernel_w = self.W.shape
        batch, _, in_h, in_w = x.shape
        out_h = (in_h + 2 * self.pad - kernel_h) // self.stride + 1
        out_w = (in_w + 2 * self.pad - kernel_w) // self.stride + 1

        patches = im2col(x, kernel_h, kernel_w, self.stride, self.pad)
        # One column per filter so a single matmul applies all filters.
        flat_weights = self.W.reshape(num_filters, -1).T
        scores = patches.dot(flat_weights) + self.b
        result = scores.reshape(batch, out_h, out_w, -1).transpose(0, 3, 1, 2)

        self.x = x
        self.col = patches
        self.col_W = flat_weights
        return result

    def backward(self, dout):
        """Store ``dW``/``db`` for the last forward pass and return d(loss)/dx."""
        num_filters, channels, kernel_h, kernel_w = self.W.shape
        # Move the filter axis last so rows line up with the cached im2col rows.
        grad_rows = dout.transpose(0, 2, 3, 1).reshape(-1, num_filters)

        self.db = grad_rows.sum(axis=0)
        weight_grad = self.col.T.dot(grad_rows)
        self.dW = weight_grad.T.reshape(num_filters, channels, kernel_h, kernel_w)

        patch_grads = grad_rows.dot(self.col_W.T)
        return col2im(patch_grads, self.x.shape, kernel_h, kernel_w, self.stride, self.pad)


class MaxPooling:
    """Max-pooling layer over square windows, applied per channel.

    ``forward`` records which element of every window was the maximum;
    ``backward`` sends each upstream gradient value to that element only.
    """

    def __init__(self, pool_size=2, stride=2, pad=0):
        self.pool_h = self.pool_w = pool_size
        self.stride = stride
        self.pad = pad
        # Caches written by forward() for use in backward().
        self.x = None
        self.arg_max = None

    def forward(self, x):
        """Pool ``x`` of shape (N, C, H, W); return (N, C, out_h, out_w)."""
        batch, channels, in_h, in_w = x.shape
        out_h = (in_h + 2 * self.pad - self.pool_h) // self.stride + 1
        out_w = (in_w + 2 * self.pad - self.pool_w) // self.stride + 1

        # Treat every channel as its own single-channel image so each im2col
        # row is exactly one pooling window.
        per_channel = x.reshape(batch * channels, 1, in_h, in_w)
        windows = im2col(per_channel, self.pool_h, self.pool_w, self.stride, self.pad)

        winners = windows.argmax(axis=1)
        pooled = windows.max(axis=1).reshape(batch, channels, out_h, out_w)

        self.x = x
        self.arg_max = winners
        return pooled

    def backward(self, dout):
        """Return d(loss)/dx with the shape of the last forward input."""
        batch, channels, out_h, out_w = dout.shape
        _, _, in_h, in_w = self.x.shape

        window_count = batch * channels * out_h * out_w
        window_area = self.pool_h * self.pool_w
        # Each row receives its gradient at the recorded arg-max slot only.
        routed = np.zeros((window_count, window_area), dtype=np.float32)
        routed[np.arange(window_count), self.arg_max] = dout.flatten()

        folded = col2im(
            routed, (batch * channels, 1, in_h, in_w),
            self.pool_h, self.pool_w, self.stride, self.pad
        )
        return folded.reshape(self.x.shape)


class Flatten:
    """Collapse every axis after the first into a single feature axis."""

    def __init__(self):
        # Shape of the most recent forward input, needed to undo the reshape.
        self.original_shape = None

    def forward(self, x):
        """Return ``x`` reshaped to (x.shape[0], -1), remembering its shape."""
        self.original_shape = x.shape
        return x.reshape(self.original_shape[0], -1)

    def backward(self, dout):
        """Reshape ``dout`` back to the shape seen by the last forward call."""
        restored = dout.reshape(self.original_shape)
        return restored


class FullyConnected:
    """Dense layer computing ``x . W + b`` with float32 parameters.

    ``W`` has shape (input_size, output_size) and is drawn from a standard
    normal scaled by ``sqrt(2 / (input_size + output_size))``; ``b`` starts
    at zero. ``learning_rate`` is stored as ``lr`` but is not used by the
    layer itself. ``backward`` leaves gradients in ``dW`` and ``db``.
    """

    def __init__(self, input_size, output_size, learning_rate=0.01):
        fan_total = input_size + output_size
        scale = np.sqrt(2.0 / fan_total).astype(np.float32)
        self.W = scale * np.random.randn(input_size, output_size).astype(np.float32)
        self.b = np.zeros(output_size, dtype=np.float32)
        self.x = None  # input cached by forward() for the weight gradient
        self.lr = learning_rate
        self.dW = None
        self.db = None

    def forward(self, x):
        """Cache ``x`` as float32 and return the affine output."""
        self.x = x.astype(np.float32, copy=False)
        return self.x.dot(self.W) + self.b

    def backward(self, dout):
        """Store ``dW``/``db`` and return the gradient w.r.t. the input."""
        input_grad = dout.dot(self.W.T)
        self.dW = self.x.T.dot(dout)
        self.db = dout.sum(axis=0)
        return input_grad


class ReLU:
    """Rectified linear unit: passes positive inputs, zeroes the rest."""

    def __init__(self):
        # Boolean mask of positions where the last forward input was <= 0.
        self.mask = None

    def forward(self, x):
        """Return a float32 copy of ``x`` with non-positive entries set to 0.

        The input array is left untouched; the mask of zeroed positions is
        kept for ``backward``.
        """
        self.mask = (x <= 0)
        out = x.astype(np.float32, copy=True)  # Work on a float32 copy
        out[self.mask] = 0
        return out

    def backward(self, dout):
        """Return ``dout`` with the gradient blocked where the input was <= 0.

        The result is a new array of the same dtype; the caller's ``dout``
        is not modified, so an upstream gradient that is reused elsewhere
        stays intact.
        """
        dx = dout.copy()
        dx[self.mask] = 0
        return dx


class SoftmaxWithLoss:
    """Softmax activation fused with mean cross-entropy loss.

    ``forward`` accepts targets either as class indices (1-D, or a single
    column) or as a one-hot matrix, and stores the probabilities in ``y``,
    the one-hot targets in ``t_one_hot`` and the scalar loss in ``loss``.
    """

    def __init__(self):
        self.loss = None
        self.y = None  # softmax probabilities from the last forward pass
        self.t = None  # targets exactly as they were passed to forward()

    def forward(self, x, t):
        """Return the batch-averaged cross-entropy of scores ``x`` vs ``t``."""
        logits = x.astype(np.float32, copy=False)
        self.t = t

        targets_are_indices = t.ndim == 1 or t.shape[1] == 1
        if not targets_are_indices:
            # Already one row per sample with one column per class.
            self.t_one_hot = t.astype(np.float32, copy=False)
        else:
            class_count = logits.shape[1]
            indices = t.flatten().astype(np.int32)
            encoded = np.zeros((indices.size, class_count), dtype=np.float32)
            encoded[np.arange(indices.size), indices] = 1.0
            self.t_one_hot = encoded

        self.y = softmax(logits)

        # The small offset keeps log() finite when a probability is zero.
        epsilon = 1e-7
        sample_count = self.t_one_hot.shape[0]
        log_probs = np.log(self.y + epsilon)
        self.loss = -np.sum(self.t_one_hot * log_probs) / sample_count

        return self.loss

    def backward(self, dout=1):
        """Return the gradient w.r.t. the scores, scaled by ``dout``."""
        sample_count = self.t_one_hot.shape[0]
        # (probabilities - targets), averaged over the batch like the loss.
        score_grad = (self.y - self.t_one_hot) / sample_count
        return score_grad * dout


class CNN:
    """Small convolutional classifier assembled from the layers in this file.

    Pipeline: conv(1->32) -> ReLU -> 2x2 max-pool -> conv(32->64) -> ReLU ->
    2x2 max-pool -> flatten -> dense(64*7*7 -> 128) -> ReLU -> dense(128 -> 10),
    followed by softmax cross-entropy for the loss. The flattened size of
    64*7*7 corresponds to 28x28 inputs. Parameters are updated with plain
    SGD using ``learning_rate``.
    """

    def __init__(self, learning_rate=0.01):
        input_channels = 1
        conv1_filters = 32
        conv2_filters = 64
        flattened_size = conv2_filters * 7 * 7
        fc1_units = 128
        fc2_units = 10
        self.lr = learning_rate

        self.conv1 = Convolution(input_channels, conv1_filters, 3, 1, 1, self.lr)
        self.relu1 = ReLU()
        self.pool1 = MaxPooling(2, 2)
        self.conv2 = Convolution(conv1_filters, conv2_filters, 3, 1, 1, self.lr)
        self.relu2 = ReLU()
        self.pool2 = MaxPooling(2, 2)
        self.flatten = Flatten()
        self.fc1 = FullyConnected(flattened_size, fc1_units, self.lr)
        self.relu3 = ReLU()
        self.fc2 = FullyConnected(fc1_units, fc2_units, self.lr)
        self.softmax = SoftmaxWithLoss()

        # Layers in forward order; the loss layer is applied separately.
        self.layers = [
            self.conv1, self.relu1, self.pool1,
            self.conv2, self.relu2, self.pool2,
            self.flatten, self.fc1, self.relu3, self.fc2
        ]
        # Layers that own trainable W/b and receive dW/db during backward.
        self.params = [self.conv1, self.conv2, self.fc1, self.fc2]

    @staticmethod
    def _to_nchw(x):
        """Return ``x`` as a float32 array shaped (N, C, H, W).

        A 3-D input is read as (N, H, W) and a 2-D input as one (H, W)
        image; both get a single channel. Other ranks are only cast.
        """
        if x.ndim == 3:
            x = x.reshape(x.shape[0], 1, x.shape[1], x.shape[2])
        elif x.ndim == 2:
            x = x.reshape(1, 1, x.shape[0], x.shape[1])
        return x.astype(np.float32, copy=False)

    def predict(self, x):
        """Run the forward pass and return the raw class scores (logits)."""
        h = self._to_nchw(x)
        for layer in self.layers:
            h = layer.forward(h)
        return h

    def loss(self, x, t):
        """Return the softmax cross-entropy of the network on ``(x, t)``."""
        y = self.predict(x)
        return self.softmax.forward(y, t)

    def accuracy(self, x, t, batch_size=100):
        """Return the fraction of samples in ``x`` classified as ``t``.

        ``t`` may be class indices or one-hot rows. Prediction runs in
        chunks of ``batch_size`` to bound memory use. An empty dataset
        yields 0.0 instead of dividing by zero.
        """
        # Normalise first so the sample count reflects the (N, C, H, W) batch,
        # including the single-image 2-D case that predict() also accepts.
        x = self._to_nchw(x)
        n_data = x.shape[0]
        if n_data == 0:
            return 0.0

        # Ensure targets t are class indices (N,)
        if t.ndim == 2:
            t = np.argmax(t, axis=1)
        t = t.astype(np.int32)

        acc = 0.0
        for i in range(0, n_data, batch_size):
            x_batch = x[i:i + batch_size]
            t_batch = t[i:i + batch_size]
            y_pred = np.argmax(self.predict(x_batch), axis=1)
            acc += np.sum(y_pred == t_batch)

        return acc / n_data

    def gradient(self, x, t):
        """Run forward and backward passes for one batch.

        Returns nothing: the gradients are left in ``dW``/``db`` of every
        layer in ``self.params`` and the batch loss in ``self.softmax.loss``.
        """
        x = self._to_nchw(x)

        # 1. Forward pass (also caches what each layer needs for backward).
        self.loss(x, t)
        # 2. Backward pass, starting from d(loss)/d(loss) = 1.
        dout = self.softmax.backward(1.0)
        for layer in reversed(self.layers):
            dout = layer.backward(dout)

    def update_params(self):
        """Apply one SGD step in place using the stored gradients."""
        for param_layer in self.params:
            # Cast defensively so float32 parameters never get upcast.
            dW = param_layer.dW.astype(np.float32, copy=False)
            db = param_layer.db.astype(np.float32, copy=False)
            param_layer.W -= self.lr * dW
            param_layer.b -= self.lr * db

    def train(self, x_train, t_train, x_val, t_val, epochs=5, batch_size=100):
        """Train with mini-batch SGD and return the recorded metrics.

        Each epoch reshuffles the training data. After every epoch the
        accuracy is measured on the first 1000 samples of the training and
        validation sets. Returns ``(train_loss_list, train_acc_list,
        val_acc_list)`` with one loss per iteration and one accuracy pair
        per epoch.
        """
        train_size = x_train.shape[0]
        # Integer division: a trailing partial batch is not used, but at
        # least one iteration always runs.
        iter_per_epoch = max(train_size // batch_size, 1)
        train_loss_list, train_acc_list, val_acc_list = [], [], []

        print(f"Starting training for {epochs} epochs...")
        for epoch in range(epochs):
            idx = np.random.permutation(train_size)
            x_train_shuffled = x_train[idx]
            t_train_shuffled = t_train[idx]
            epoch_loss = 0.0

            for i in range(iter_per_epoch):
                start = i * batch_size
                end = start + batch_size
                x_batch = x_train_shuffled[start:end]
                t_batch = t_train_shuffled[start:end]
                self.gradient(x_batch, t_batch)
                self.update_params()
                loss = self.softmax.loss
                train_loss_list.append(loss)
                epoch_loss += loss
                if (i + 1) % 100 == 0:
                    print(f"  Epoch {epoch + 1}, Iteration {i + 1}/{iter_per_epoch}, Batch Loss: {loss:.4f}")

            print(f"Epoch {epoch + 1}/{epochs} finished. Evaluating...")
            eval_batch_size = 500
            eval_size = 1000  # Evaluate on a subset to keep epochs fast
            train_acc = self.accuracy(x_train[:eval_size], t_train[:eval_size], batch_size=eval_batch_size)
            val_acc = self.accuracy(x_val[:eval_size], t_val[:eval_size], batch_size=eval_batch_size)
            train_acc_list.append(train_acc)
            val_acc_list.append(val_acc)
            avg_epoch_loss = epoch_loss / iter_per_epoch

            print(f"  Avg Loss: {avg_epoch_loss:.4f}, Train Acc (on {eval_size}): {train_acc:.4f}, Val Acc (on {eval_size}): {val_acc:.4f}")
            print("-" * 30)

        return train_loss_list, train_acc_list, val_acc_list

    def save_model(self, file_path):
        """Pickle all weights, biases and the learning rate to ``file_path``."""
        params_to_save = {
            'conv1_W': self.conv1.W, 'conv1_b': self.conv1.b,
            'conv2_W': self.conv2.W, 'conv2_b': self.conv2.b,
            'fc1_W': self.fc1.W, 'fc1_b': self.fc1.b,
            'fc2_W': self.fc2.W, 'fc2_b': self.fc2.b,
            'lr': self.lr
        }
        with open(file_path, 'wb') as f:
            pickle.dump(params_to_save, f)
        print(f"Model parameters saved to {file_path}")

    @staticmethod
    def load_model(file_path, learning_rate=0.01):
        """Build a CNN from a file written by ``save_model``.

        ``learning_rate`` is only used when the file has no ``'lr'`` entry.
        """
        # SECURITY NOTE: pickle.load can execute arbitrary code embedded in
        # the file. Only load parameter files that come from a trusted source.
        with open(file_path, 'rb') as f:
            params_loaded = pickle.load(f)
        lr = params_loaded.get('lr', learning_rate)
        model = CNN(learning_rate=lr)
        model.conv1.W = params_loaded['conv1_W']
        model.conv1.b = params_loaded['conv1_b']
        model.conv2.W = params_loaded['conv2_W']
        model.conv2.b = params_loaded['conv2_b']
        model.fc1.W = params_loaded['fc1_W']
        model.fc1.b = params_loaded['fc1_b']
        model.fc2.W = params_loaded['fc2_W']
        model.fc2.b = params_loaded['fc2_b']
        print(f"Model parameters loaded from {file_path}")
        return model


# Example usage
if __name__ == "__main__":
    # End-to-end demo: load MNIST, train the CNN briefly, report test accuracy
    # and pickle the learned parameters.
    X_train, y_train, X_test, y_test = load_mnist()
    model = CNN(learning_rate=0.01)
    print("Training CNN...")
    # A single epoch is run here; the test split doubles as the validation set
    # for the per-epoch accuracy report inside train().
    train_loss, train_acc, val_acc = model.train(X_train, y_train, X_test, y_test, epochs=1, batch_size=100)

    print("Evaluating on full test set...")
    # accuracy() reshapes the (N, H, W) images itself and predicts in batches.
    test_acc = model.accuracy(X_test, y_test, batch_size=500) # batches of 500 samples per forward pass
    print(f"Test accuracy: {test_acc:.4f}")

    # Persist weights, biases and the learning rate to the working directory.
    model.save_model('mnist_cnn_params.pkl')

    # Optional: reload the saved parameters and confirm the accuracy matches.
    # print("\nLoading model and re-evaluating...")
    # loaded_model = CNN.load_model('mnist_cnn_params.pkl')
    # # X_test can be passed as-is; accuracy() handles the reshape.
    # test_acc_loaded = loaded_model.accuracy(X_test, y_test, batch_size=500)
    # print(f"Loaded Model Test accuracy: {test_acc_loaded:.4f}")

Loading MNIST dataset using TensorFlow...
Data loaded. Training set: (60000, 28, 28), Test set: (10000, 28, 28)
Training CNN...
Starting training for 1 epochs...
  Epoch 1, Iteration 100/600, Batch Loss: 0.6225
  Epoch 1, Iteration 200/600, Batch Loss: 0.5307
  Epoch 1, Iteration 300/600, Batch Loss: 0.2840
  Epoch 1, Iteration 400/600, Batch Loss: 0.3330
  Epoch 1, Iteration 500/600, Batch Loss: 0.3331
  Epoch 1, Iteration 600/600, Batch Loss: 0.2047
Epoch 1/1 finished. Evaluating...
  Avg Loss: 0.4851, Train Acc (on 1000): 0.9290, Val Acc (on 1000): 0.9290
------------------------------
Evaluating on full test set...
Test accuracy: 0.9388
Model parameters saved to mnist_cnn_params.pkl
