In [1]:
import mlx.core as mx
import mlx.nn as nn
from mlx.data import datasets

In [2]:
# Load training set
mnist_train = datasets.load_mnist(train=True)

In [3]:
def get_streamed_data(data, batch_size=0, shuffled=True):
    buffer = data.shuffle() if shuffled else data
    stream = buffer.to_stream()
    stream = stream.key_transform("image", lambda x: x.astype("float32"))
    stream = stream.batch(batch_size) if batch_size > 0 else stream
    return stream.prefetch(4, 2)

In [4]:
mnist_trainstream = get_streamed_data(data=mnist_train, batch_size=32, shuffled=True)

batch = next(mnist_trainstream)
X, y = mx.array(batch["image"]), mx.array(batch["label"])

X.shape, y.shape

((32, 28, 28, 1), (32,))

In [5]:
# Create a convolution layer with kernel_size 2x2 and 1 stride
conv2d = nn.Conv2d(
    in_channels=X.shape[-1],
    out_channels=1,
    kernel_size=2,
    padding=0,
    stride=1,
)

res = conv2d(X)
res.shape

(32, 27, 27, 1)

In [6]:
# Calculating output dimension of a Conv2d layer
def get_output_dim(width, stride, kernel_size, padding):
    """
    Assuming that the input image has dimension of width x width (square image).
    We only need the width in this case
    """ 
    return ((width - kernel_size + (2 * padding)) / stride) + 1

output_dim = get_output_dim(X.shape[1], 1, 2, 0)
print(f"Output image: {output_dim} x {output_dim}")

stride = 2
padding = 0
kernel_size = 2

output_dim = get_output_dim(X.shape[1], stride, kernel_size, padding)
print(f"Output image: {output_dim} x {output_dim}")

Output image: 27.0 x 27.0
Output image: 14.0 x 14.0


In [8]:
# Pooling layer
pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
output = pool2d(res)
output.shape

(32, 13, 13, 1)

In [10]:
def get_pooling_layer_out_dim(width, height, stride, kernel_size, padding):
    """
    Assuming that stride and kernel is a square (e.g., 2x2)
    """
    w_out = ((width + 2 * padding - kernel_size) / stride) + 1
    h_out = ((height + 2 * padding - kernel_size) / stride) + 1
    return w_out, h_out

w_out, h_out = get_pooling_layer_out_dim(X.shape[1], X.shape[1], stride, kernel_size, padding)
print(f"Pooling layer output w x h -> {w_out} x {h_out}")

Pooling layer output w x h -> 14.0 x 14.0
