The idea behind this notebook is to explore Apple's MLX library and see how it performs.

In [None]:
import time
from functools import partial

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
import mnist

# Torch is still required to load and process the data
import torchvision
import torch
import numpy as np
import torchvision

We load the data with torchvision, just as in a standard PyTorch notebook:

In [31]:
def load_data(dataset = "mnist",
              path = 'data',
              train = True,
              batch_size = 256,
              transforms = torchvision.transforms.ToTensor(),
              download = True):
    '''
    Build a torchvision dataset and a shuffling DataLoader over it.

    dataset: case-insensitive name, one of
        [mnist, cifar, fashion, emnist, kmnist, svhn].
    path: root directory handed to the dataset constructor
        (SVHN is rooted at path + '/SVHN').
    train: pick the training split when True, the test split otherwise.
    batch_size: batch size of the returned DataLoader.
    transforms: passed to the dataset constructor as `transform`.
    download: forwarded unchanged to the dataset constructor.

    Returns a (dataset, loader) tuple; the loader is created with shuffle=True
    regardless of `train`.
    Raises ValueError when the name is not one of the supported options.
    '''
    key = dataset.lower()
    common = dict(transform=transforms, download=download)

    # Each entry is a zero-argument factory so only the requested dataset is built.
    builders = {
        'mnist': lambda: torchvision.datasets.MNIST(path, train=train, **common),
        'fashion': lambda: torchvision.datasets.FashionMNIST(path, train=train, **common),
        'cifar': lambda: torchvision.datasets.CIFAR10(path, train=train, **common),
        # EMNIST is always requested with the 'letters' split.
        'emnist': lambda: torchvision.datasets.EMNIST(path, train=train, split='letters', **common),
        'kmnist': lambda: torchvision.datasets.KMNIST(path, train=train, **common),
        # SVHN takes a split name rather than a train flag.
        'svhn': lambda: torchvision.datasets.SVHN(
            path + '/SVHN', split='train' if train else 'test', **common
        ),
    }

    if key not in builders:
        raise ValueError('Invalid dataset. Options: [mnist, cifar, fashion, emnist, kmnist, svhn]')

    data = builders[key]()
    loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True)
    return data, loader

In [32]:
class MLP(nn.Module):
    """Fully connected network 784 -> 512 -> 256 -> 64 -> 10 with ReLU after each hidden layer."""

    def __init__(self):
        super().__init__()

        # Three hidden layers followed by a 10-unit output layer.
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 64)
        self.out = nn.Linear(64, 10)

    def __call__(self, x):
        """Return the output-layer activations for x; no activation follows the last layer."""
        for hidden in (self.fc1, self.fc2, self.fc3):
            x = nn.relu(hidden(x))
        return self.out(x)

In [None]:
def batch_iterate(batch_size, X, y):
    """Yield (X_batch, y_batch) pairs that cover the data once in a random order.

    A fresh permutation of range(y.size) is drawn with NumPy on every call;
    the final batch is shorter when y.size is not a multiple of batch_size.
    """
    total = y.size
    order = mx.array(np.random.permutation(total))
    for start in range(0, total, batch_size):
        picked = order[start : start + batch_size]
        yield X[picked], y[picked]

def loss_fn(model, X, y):
    """Mean cross-entropy between the model's outputs for X and the targets y."""
    logits = model(X)
    return nn.losses.cross_entropy(logits, y, reduction="mean")

# Convert each of the four arrays returned by mnist.mnist() into an MLX array.
train_images, train_labels, test_images, test_labels = (
    mx.array(split) for split in mnist.mnist()
)

In [None]:
# Instantiate the network defined above.
model = MLP()
# Force the freshly initialised parameters to be computed now
# (MLX arrays are otherwise evaluated lazily).
mx.eval(model.parameters())

# optim is mlx.optimizers: plain SGD with a fixed learning rate.
optimizer = optim.SGD(learning_rate=0.01)
# nn is mlx.nn: wraps loss_fn so a single call returns both the loss and the
# gradients (see how `step` unpacks the result).
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

@partial(mx.compile, inputs=model.state, outputs=model.state)
def step(X, y):
    """Run one optimisation step on the batch (X, y) and return its loss.

    The function is wrapped by mx.compile with model.state declared as both
    input and output, since the optimizer update modifies the model.
    """
    batch_loss, gradients = loss_and_grad_fn(model, X, y)
    optimizer.update(model, gradients)
    return batch_loss

@partial(mx.compile, inputs=model.state)
def eval_fn(X, y):
    """Return the mean of (arg-max over axis 1 of model(X)) == y, i.e. the accuracy."""
    predicted = mx.argmax(model(X), axis=1)
    return mx.mean(predicted == y)


# Train for 10 epochs on 64-sample batches, reporting test accuracy and
# wall-clock time after each epoch.
for epoch in range(10):
    started = time.perf_counter()
    for X, y in batch_iterate(64, train_images, train_labels):
        step(X, y)
        # Evaluate the updated state before moving on to the next batch.
        mx.eval(model.state)
    accuracy = eval_fn(test_images, test_labels)
    elapsed = time.perf_counter() - started
    print(
        f"Epoch: {epoch+1}, Test Accuracy: {accuracy}",
        f"Time: {elapsed:.3f}"
    )


Epoch: 1, Test Accuracy: 0.5331000089645386 Time: 1.028
Epoch: 2, Test Accuracy: 0.8323999643325806 Time: 0.941
Epoch: 3, Test Accuracy: 0.8858000040054321 Time: 0.913
Epoch: 4, Test Accuracy: 0.9019999504089355 Time: 0.929
Epoch: 5, Test Accuracy: 0.9106000065803528 Time: 0.937
Epoch: 6, Test Accuracy: 0.9225999712944031 Time: 0.937
Epoch: 7, Test Accuracy: 0.9325999617576599 Time: 0.941
Epoch: 8, Test Accuracy: 0.9357999563217163 Time: 0.932
Epoch: 9, Test Accuracy: 0.9402999877929688 Time: 0.935
Epoch: 10, Test Accuracy: 0.9457999467849731 Time: 0.907


In [36]:
# Start again from a freshly initialised network for the DataLoader-based run.
model = MLP()
# Force the initial parameters to be computed now.
mx.eval(model.parameters())

def loss_fn(model, X, y):
    """Cross-entropy of model(X) against y, averaged over the batch (reduction="mean")."""
    outputs = model(X)
    return nn.losses.cross_entropy(outputs, y, reduction="mean")

# New SGD optimizer and loss/gradient function bound to the new model instance.
optimizer = optim.SGD(learning_rate=0.01)
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

@partial(mx.compile, inputs=model.state, outputs=model.state)
def step(X, y):
    """Apply one optimizer update computed from the batch (X, y); return the batch loss.

    Compiled via mx.compile with model.state as both input and output because
    the update changes the model in place.
    """
    batch_loss, gradients = loss_and_grad_fn(model, X, y)
    optimizer.update(model, gradients)
    return batch_loss

@partial(mx.compile, inputs=model.state)
def eval_fn(X, y):
    """Accuracy on (X, y): mean of the element-wise match between arg-max predictions and y."""
    predicted = mx.argmax(model(X), axis=1)
    matches = predicted == y
    return mx.mean(matches)


# Torch DataLoaders with batch size 64 for the train and test splits; load_data
# defaults to the 'mnist' dataset. The dataset objects themselves are discarded.
_, train_loader = load_data(batch_size=64)
_, test_loader = load_data(batch_size=64, train=False)

# Same 10-epoch schedule, but batches now come from the torch DataLoaders.
for epoch in range(10):
    started = time.perf_counter()
    for X, y in train_loader:
        # Convert the torch batch to MLX arrays and flatten each sample
        # (everything after the batch axis) into a vector.
        X = mx.array(X).flatten(start_axis=1)
        y = mx.array(y)
        step(X, y)
        mx.eval(model.state)

    # Count correct predictions over the test loader, batch by batch.
    correct = 0
    seen = 0
    for X, y in test_loader:
        X = mx.array(X).flatten(start_axis=1)
        y = mx.array(y)
        correct += (model(X).argmax(axis=1) == y).sum()
        seen += len(y)

    elapsed = time.perf_counter() - started
    print(
        f"Epoch: {epoch+1}, Test Accuracy: {correct / seen}",
        f"Time: {elapsed:.3f}"
    )

Epoch: 1, Test Accuracy: 0.49810001254081726 Time: 3.200
Epoch: 2, Test Accuracy: 0.8302000164985657 Time: 3.153
Epoch: 3, Test Accuracy: 0.8851000070571899 Time: 3.232
Epoch: 4, Test Accuracy: 0.9004999995231628 Time: 3.211
Epoch: 5, Test Accuracy: 0.9106000065803528 Time: 3.123
Epoch: 6, Test Accuracy: 0.9247999787330627 Time: 3.130
Epoch: 7, Test Accuracy: 0.9301000237464905 Time: 3.160
Epoch: 8, Test Accuracy: 0.939300000667572 Time: 3.111
Epoch: 9, Test Accuracy: 0.9438999891281128 Time: 3.109
Epoch: 10, Test Accuracy: 0.9462000131607056 Time: 3.271
