In [1]:
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.data.datasets import load_mnist

In [2]:
# Load MNIST dataset
# If not exists, the mlx.data package will download
# the dataset into your local directory
# you can provide `root` arguments to set the destination
mnist_train = load_mnist(train=True)
mnist_test = load_mnist(train=False)

len(mnist_train), len(mnist_test)

(60000, 10000)

In [3]:
def get_streamed_data(data, batch_size=32, shuffled=True):
    buffer = data.shuffle() if shuffled else data
    return (buffer
            .to_stream()
            .key_transform("image", lambda x: x.astype("float32").reshape(-1,))
            .batch(batch_size)
            .prefetch(4, 2)
    )

mnist_trainstream = get_streamed_data(mnist_train)

## Without .reshape(-1,), the shape of the batch is (32, 28, 28, 1)
## With .reshape(-1,) the shape of the batch is (32, 784) -> easier to handle with fully-connected layer

# Uncomment below lines to see the shape
first_batch = next(mnist_trainstream)
X, y = first_batch["image"], first_batch["label"]
X.shape, y.shape

((32, 784), (32,))

In [4]:
class MLP(nn.Module):
    def __init__(self, input_dims, hidden_dims, output_dims):
        super().__init__()
        self.sequential = nn.Sequential(
            nn.Linear(input_dims=input_dims, output_dims=hidden_dims),
            nn.ReLU(),
            nn.Linear(input_dims=hidden_dims, output_dims=hidden_dims),
            nn.ReLU(),
            nn.Linear(input_dims=hidden_dims, output_dims=hidden_dims),
            nn.ReLU(),
            nn.Linear(input_dims=hidden_dims, output_dims=output_dims)
        )
    def __call__(self, x):
        return self.sequential(x)

model = MLP(input_dims=784, hidden_dims=64, output_dims=10)
mx.eval(model.parameters())
model

MLP(
  (sequential): Sequential(
    (layers.0): Linear(input_dims=784, output_dims=64, bias=True)
    (layers.1): ReLU()
    (layers.2): Linear(input_dims=64, output_dims=64, bias=True)
    (layers.3): ReLU()
    (layers.4): Linear(input_dims=64, output_dims=64, bias=True)
    (layers.5): ReLU()
    (layers.6): Linear(input_dims=64, output_dims=10, bias=True)
  )
)

In [5]:
def loss_fn(model, X, y):
    logits = model(X)
    return nn.losses.cross_entropy(logits, y, reduction="mean")

def eval_fn(model, X, y):
    logits = model(X)
    pred = nn.softmax(logits)
    return mx.mean(mx.argmax(pred, axis=1) == y)


In [6]:
# Start training loop

epochs = 50
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)


optimizer = optim.SGD(learning_rate=0.01)
for epoch in range(epochs):
    epoch_loss = 0.0
    train_acc = 0.0
    epoch_counter = 0
    for batch in get_streamed_data(mnist_train, batch_size=256):
        X, y = batch["image"], batch["label"]
        # Need to convert X and y into mlx.core.array type
        X, y = mx.array(X), mx.array(y)
        # Compute loss and its gradient with respect to the model's trainable parameters
        loss, grad = loss_and_grad_fn(model, X, y)
        # Step the optimizer
        optimizer.update(model, grad)
        # Evaluate computational graph
        mx.eval(model.parameters(), optimizer.state)
        epoch_loss += loss.item()
        epoch_counter += 1
        train_acc += eval_fn(model, X, y).item()
    epoch_loss /= epoch_counter
    train_acc /= epoch_counter

    test_acc_counter = 0.0
    test_acc = 0.0
    for batch in get_streamed_data(mnist_test, batch_size=32, shuffled=False):
        X, y = batch["image"], batch["label"]
        X, y = mx.array(X), mx.array(y)
        acc = eval_fn(model, X, y)
        test_acc += acc.item()
        test_acc_counter += 1
    test_acc /= test_acc_counter
    print(f"Epoch: {epoch} | Train Loss: {epoch_loss}, Train Accuracy: {train_acc} | Test Accuracy: {test_acc}")
        

Epoch: 0 | Train Loss: 0.47950050482090484, Train Accuracy: 0.8891788563829788 | Test Accuracy: 0.9236222044728435
Epoch: 1 | Train Loss: 0.1966099038403085, Train Accuracy: 0.9639572252618506 | Test Accuracy: 0.9353035143769968
Epoch: 2 | Train Loss: 0.1525057036508905, Train Accuracy: 0.9753379877577437 | Test Accuracy: 0.9518769968051118
Epoch: 3 | Train Loss: 0.13113700432029177, Train Accuracy: 0.980474290949233 | Test Accuracy: 0.9060503194888179
Epoch: 4 | Train Loss: 0.11118966062018212, Train Accuracy: 0.9848404255319149 | Test Accuracy: 0.9563698083067093
Epoch: 5 | Train Loss: 0.09633462732618159, Train Accuracy: 0.9876218973322117 | Test Accuracy: 0.9616613418530351
Epoch: 6 | Train Loss: 0.08666957761854567, Train Accuracy: 0.9899268617021276 | Test Accuracy: 0.9544728434504792
Epoch: 7 | Train Loss: 0.07647806671547129, Train Accuracy: 0.9918882978723405 | Test Accuracy: 0.9634584664536742
Epoch: 8 | Train Loss: 0.06889768651825316, Train Accuracy: 0.9930186170212766 | Te

In [14]:
# Testing out with one random sample
test_stream = get_streamed_data(mnist_test, batch_size=1)
test_batch = next(test_stream)
X, y = mx.array(test_batch["image"]), mx.array(test_batch["label"])

# See how the model produce logits
logits = model(X)
print(f"Logits: {logits}")

# See we can compute the softmax from the logits
softmax = nn.softmax(logits)
print(f"Softmax-ed: {softmax}")

# Get predicted label and true label
predicted_label = mx.argmax(softmax, axis=1).item()
confidence_level = mx.max(softmax, axis=1).item()
print(f"Predicted label: {predicted_label}, True label: {y.item()} | Confidence level: {confidence_level}")

Logits: array([[-1.13314, 0.136331, 1.44552, ..., -5.11404, 25.6358, -0.246675]], dtype=float32)
Softmax-ed: array([[2.36799e-12, 8.42763e-12, 3.1209e-11, ..., 4.42079e-14, 1, 5.74603e-12]], dtype=float32)
Predicted label: 8, True label: 8 | Confidence level: 1.0
