In [1]:
import mlx.core as mx
import mlx.nn as nn
from sklearn import datasets

In [2]:
# Create toy dataset

X, y = datasets.make_blobs(
    n_samples=100000,
    centers=2,
    n_features=4,
    random_state=42
)

X.shape, y.shape

((100000, 4), (100000,))

In [3]:
# Initialize weights matrix and bias

mx.random.seed(1234)

b = mx.random.normal(shape=[1])
W = mx.random.normal(shape=(X.shape[-1],))
W.shape

(4,)

In [4]:
# Split training and test data set
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

X_train.shape, X_test.shape

((80000, 4), (20000, 4))

In [5]:
# Training loop
epochs = 100

def predict_logits(X, W, b):
    matmul = X @ W + b
    return nn.sigmoid(matmul)

def loss_fn(X, W, b, y_true):
    logits = predict_logits(X, W, b)
    return nn.losses.binary_cross_entropy(logits, y_true)

# Compute gradient with respect to the argument at index-1: W
value_and_grad = mx.value_and_grad(loss_fn, 1)

X_train = mx.array(X_train)
y_train = mx.array(y_train)

for epoch in range(epochs):
    loss, grad = value_and_grad(X_train, W, b, y_train)
    W = W - 0.1 * grad
    mx.eval(W)
    if epoch % 10 == 0:
        print(f"Epoch: {epoch} | Loss: {loss}")

Epoch: 0 | Loss: array(0.505014, dtype=float32)
Epoch: 10 | Loss: array(0.504643, dtype=float32)
Epoch: 20 | Loss: array(0.504379, dtype=float32)
Epoch: 30 | Loss: array(0.504182, dtype=float32)
Epoch: 40 | Loss: array(0.504029, dtype=float32)
Epoch: 50 | Loss: array(0.503907, dtype=float32)
Epoch: 60 | Loss: array(0.503808, dtype=float32)
Epoch: 70 | Loss: array(0.503725, dtype=float32)
Epoch: 80 | Loss: array(0.503656, dtype=float32)
Epoch: 90 | Loss: array(0.503596, dtype=float32)


In [6]:
# Test loop
correct_nums = 0
X_test = mx.array(X_test)
y_test = mx.array(y_test)

for X, y in zip(X_test, y_test):
    logits = predict_logits(X, W, b)
    prediction = mx.round(logits)
    if prediction == y:
        correct_nums += 1

accuracy = correct_nums / len(y_test)
print(f"Accuracy: {accuracy}")

Accuracy: 1.0
