In [1]:
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.data import datasets
from datasets_utils import cifar100


In [2]:
def get_streamed_data(data, batch_size=0, shuffled=True):
    """Turn a dataset buffer into a prefetching stream of normalised samples.

    Args:
        data: Object exposing ``shuffle()`` and ``to_stream()`` (the notebook
            passes mlx.data buffers).
        batch_size: When greater than zero the stream is batched with this
            size; otherwise samples are streamed unbatched.
        shuffled: When true, ``data.shuffle()`` is applied before streaming.

    Returns:
        The stream returned by ``prefetch(4, 2)``.
    """
    def transform_image(x):
        # Cast to float32 and divide by 255; assumes 8-bit pixel values,
        # which would map them into [0, 1] -- TODO confirm the input dtype.
        return x.astype("float32") / 255.0

    source = data
    if shuffled:
        source = source.shuffle()

    pipeline = source.to_stream().key_transform("image", transform_image)
    if batch_size > 0:
        pipeline = pipeline.batch(batch_size)

    # NOTE(review): the arguments are presumably (prefetch size, worker
    # threads) per mlx.data -- confirm against the Stream.prefetch docs.
    return pipeline.prefetch(4, 2)

In [3]:
# Load both CIFAR-100 splits; the flag passed as ``train`` selects the split,
# and the training split is loaded first.
train_data, test_data = (
    datasets.load_cifar100(train=is_train) for is_train in (True, False)
)

# Display the number of samples in each split.
len(train_data), len(test_data)

(50000, 10000)

In [4]:
# Split train_data into 5 partitions and keep only the first one (index 0).
partitioned = train_data.partition(5, 0)
print(f"Partition length: {len(partitioned)}")

# Stream that partition as shuffled batches of a single sample and pull one
# batch to inspect the shapes of its "image" and "label" entries.
streamed_partition = get_streamed_data(partitioned, batch_size=1, shuffled=True)
batch = next(streamed_partition)
print(f"Batch image shape: {batch['image'].shape}\nBatch label shape: {batch['label'].shape}")
# Print the full training-set length for comparison with the partition length above.
print(f"Original data length: {len(train_data)}")

Partition length: 10000
Batch image shape: (1, 32, 32, 3)
Batch label shape: (1,)
Original data length: 50000


In [5]:
class Model(nn.Module):
    """Small convolutional classifier.

    Architecture (unchanged layer order and attribute names):
    ``conv_layer1`` (conv -> ReLU -> conv -> ReLU -> max-pool) -> ``batch_norm1``
    -> ``conv_layer2`` (same layout) -> ``batch_norm2`` -> flatten ->
    ``fully_connected`` (a single linear layer producing the logits).

    Args:
        input_channel: Number of channels of the input images.
        input_width: Spatial size of the (square) input images. It is used to
            derive the number of features fed to the linear head, so the model
            is no longer tied to 32-pixel inputs.
        conv_filters: Number of filters used by every convolution.
        output_dims: Number of output logits.

    Raises:
        ValueError: If ``input_width`` is too small to survive the
            unpadded convolutions and poolings.
    """

    def __init__(self, input_channel, input_width, conv_filters, output_dims):
        super().__init__()
        conv2d_kernel_size = 3
        conv2d_stride = 1
        conv2d_padding = 0

        pool2d_kernel_size = 2
        pool2d_stride = 2
        pool2d_padding = 0

        def conv_stage(in_channels):
            # conv -> ReLU -> conv -> ReLU -> max-pool; both stages share this layout.
            return nn.Sequential(
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=conv_filters,
                    kernel_size=(conv2d_kernel_size, conv2d_kernel_size),
                    stride=conv2d_stride,
                    padding=conv2d_padding
                ),
                nn.ReLU(),
                nn.Conv2d(
                    in_channels=conv_filters,
                    out_channels=conv_filters,
                    kernel_size=(conv2d_kernel_size, conv2d_kernel_size),
                    stride=conv2d_stride,
                    padding=conv2d_padding
                ),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=pool2d_kernel_size, stride=pool2d_stride, padding=pool2d_padding),
            )

        # Layers are created in the same order as before so parameter
        # initialisation order (and therefore seeded results) is unchanged.
        self.conv_layer1 = conv_stage(input_channel)
        self.batch_norm1 = nn.BatchNorm(conv_filters)
        self.conv_layer2 = conv_stage(conv_filters)
        self.batch_norm2 = nn.BatchNorm(conv_filters)

        # Follow the spatial size through both stages instead of hard-coding
        # 5*5 (which is only correct for input_width == 32: 32->30->28->14->12->10->5).
        feature_width = input_width
        for _ in range(2):
            feature_width = self._window_output_size(
                feature_width, conv2d_kernel_size, conv2d_stride, conv2d_padding
            )
            feature_width = self._window_output_size(
                feature_width, conv2d_kernel_size, conv2d_stride, conv2d_padding
            )
            feature_width = self._window_output_size(
                feature_width, pool2d_kernel_size, pool2d_stride, pool2d_padding
            )

        self.fully_connected = nn.Sequential(
            nn.Linear(
                input_dims=feature_width * feature_width * conv_filters,
                output_dims=output_dims,
            ),
        )

    @staticmethod
    def _window_output_size(size, kernel, stride, padding):
        """Return the output length of a sliding window (conv or pool) along one axis."""
        out = (size + 2 * padding - kernel) // stride + 1
        if size < kernel - 2 * padding or out < 1:
            raise ValueError(
                "input_width is too small for the convolution/pooling stack"
            )
        return out

    def __call__(self, x):
        """Compute logits for a batch ``x``.

        Assumes ``x`` is channels-last, (batch, height, width, channels), as in
        the batches produced earlier in the notebook -- TODO confirm for other callers.
        """
        x = self.conv_layer1(x)
        x = self.batch_norm1(x)
        x = self.conv_layer2(x)
        x = self.batch_norm2(x)
        # Collapse every non-batch axis before the linear head.
        x = mx.flatten(x, start_axis=1, end_axis=-1)
        x = self.fully_connected(x)
        return x

In [6]:
# Problem-level constants reused by every model built in this notebook.
NUM_CLASSES = len(cifar100.labels)  # one output logit per label
INPUT_WIDTH = 32
INPUT_CHANNEL = 3
CONV_FILTERS = 128

model = Model(
    input_channel=INPUT_CHANNEL,
    input_width=INPUT_WIDTH,
    conv_filters=CONV_FILTERS,
    output_dims=NUM_CLASSES
)

# MLX evaluates lazily; force the freshly initialised weights to be computed.
# Passing the parameter tree is the documented idiom (rather than the module itself).
mx.eval(model.parameters())
model

Model(
  (conv_layer1): Sequential(
    (layers.0): Conv2d(3, 128, kernel_size=(3,), stride=(1, 1), padding=(0, 0), dilation=1, bias=True)
    (layers.1): ReLU()
    (layers.2): Conv2d(128, 128, kernel_size=(3,), stride=(1, 1), padding=(0, 0), dilation=1, bias=True)
    (layers.3): ReLU()
    (layers.4): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=(0, 0))
  )
  (batch_norm1): BatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv_layer2): Sequential(
    (layers.0): Conv2d(128, 128, kernel_size=(3,), stride=(1, 1), padding=(0, 0), dilation=1, bias=True)
    (layers.1): ReLU()
    (layers.2): Conv2d(128, 128, kernel_size=(3,), stride=(1, 1), padding=(0, 0), dilation=1, bias=True)
    (layers.3): ReLU()
    (layers.4): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=(0, 0))
  )
  (batch_norm2): BatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fully_connected): Sequential(
    (layers.0): Linear(input_dims=3200

In [7]:
from dataclasses import dataclass, field

import mlx.optimizers as optim

@dataclass
class Experiment:
    """A model, its optimizer, and the metrics recorded while training it.

    The three history lists are created per instance. They used to be plain
    class attributes, which made every Experiment append into the same shared
    lists. They are excluded from ``__init__``, ``repr`` and comparisons so the
    generated dataclass methods keep the same signature and behaviour as before.
    """

    model: nn.Module
    optimizer: optim.Optimizer
    # Per-epoch histories, appended to by the training loop.
    train_losses: list = field(default_factory=list, init=False, repr=False, compare=False)
    train_accuracy: list = field(default_factory=list, init=False, repr=False, compare=False)
    validation_accuracy: list = field(default_factory=list, init=False, repr=False, compare=False)
    # Number of training epochs to run for this experiment.
    epoch: int = 30
    # Summary metrics; left at 0.0 until something fills them in.
    f1_score: float = 0.0
    recall_score: float = 0.0
    precision_score: float = 0.0
    description: str = ""

In [None]:
import trainer

# Create five experiments whose models share one architecture and differ only
# in two hyperparameters: the number of convolutional filters and the Adam
# learning rate (some settings are deliberately repeated).
# For the sake of fun :)

def make_experiment(conv_filters, learning_rate, epoch=30):
    """Build an Experiment with a fresh Model and a fresh Adam optimizer.

    Args:
        conv_filters: Number of filters for every convolution of the model.
        learning_rate: Learning rate handed to ``optim.Adam``.
        epoch: Number of epochs the experiment should train for.
    """
    return Experiment(
        model=Model(input_channel=INPUT_CHANNEL,
                    input_width=INPUT_WIDTH,
                    output_dims=NUM_CLASSES,
                    conv_filters=conv_filters),
        epoch=epoch,
        optimizer=optim.Adam(learning_rate=learning_rate),
    )


exp1 = make_experiment(conv_filters=32, learning_rate=0.0001)
exp2 = make_experiment(conv_filters=32, learning_rate=0.0005)
exp3 = make_experiment(conv_filters=64, learning_rate=0.0001)
exp4 = make_experiment(conv_filters=64, learning_rate=0.0005)
exp5 = make_experiment(conv_filters=128, learning_rate=0.0001)

experiments = [exp1, exp2, exp3, exp4, exp5]
# The training loop also uses this as the number of data partitions.
n_experiments = len(experiments)

def get_partitions(partition_start, partition_num):
    """Separate one partition index from the remaining ones.

    Args:
        partition_start: Index to single out.
        partition_num: Total number of partitions; indices run from 0 to
            ``partition_num - 1``.

    Returns:
        ``(partition_start, others)`` where ``others`` lists every other index
        in ascending order.

    Raises:
        ValueError: If ``partition_start`` is not one of the indices.
    """
    candidates = list(range(partition_num))
    # ``index`` raises ValueError when the requested index is out of range.
    held_out_pos = candidates.index(partition_start)
    others = candidates[:held_out_pos] + candidates[held_out_pos + 1:]
    return partition_start, others


# Cross-validation style loop: experiment i validates on partition i of the
# training set and trains on all the remaining partitions.
for experiment_no, experiment in enumerate(experiments):
    validation_partition_idx, train_partitions = get_partitions(experiment_no, n_experiments)

    # The partitions do not depend on the epoch, so build them once per
    # experiment; only the streams have to be recreated every epoch.
    train_partition_buffers = [
        train_data.partition(n_experiments, train_partition_idx)
        for train_partition_idx in train_partitions
    ]
    validation_partition = train_data.partition(n_experiments, validation_partition_idx)

    for epoch in range(experiment.epoch):
        # Per-partition results for this epoch (plain Python floats).
        partition_losses = []
        partition_accs = []
        partition_throughputs = []
        for train_partition in train_partition_buffers:
            data = get_streamed_data(data=train_partition, batch_size=256, shuffled=True)
            train_loss, train_acc, throughput = trainer.train_epoch(
                experiment.model,
                data,
                experiment.optimizer,
                epoch,
                verbose=False,
            )
            partition_losses.append(train_loss.item())
            partition_accs.append(train_acc.item())
            partition_throughputs.append(throughput.item())

        # Begin validation
        validation_data = get_streamed_data(data=validation_partition, batch_size=256, shuffled=False)
        validation_score = trainer.test_epoch(experiment.model, validation_data, epoch)

        # Average over the training partitions; kept under separate names so
        # the per-partition lists are not rebound to a different type.
        epoch_loss = mx.mean(mx.array(partition_losses)).item()
        epoch_acc = mx.mean(mx.array(partition_accs)).item()
        epoch_throughput = mx.mean(mx.array(partition_throughputs)).item()
        epoch_validation_score = validation_score.item()

        # Append results
        experiment.train_losses.append(epoch_loss)
        experiment.train_accuracy.append(epoch_acc)
        experiment.validation_accuracy.append(epoch_validation_score)

        print(" | ".join(
                (
                    f"Experiment: {experiment_no+1}",
                    f"Epoch: {epoch+1}",
                    f"avg. Train loss {epoch_loss:.3f}",
                    f"avg. Train acc {epoch_acc:.3f}",
                    f"avg. Validation score {epoch_validation_score:.3f}",
                    f"Throughput: {epoch_throughput:.2f} images/sec",
                )))

Experiment: 1 | Epoch: 1 | avg. Train loss 4.533 | avg. Train acc 0.106 | avg. Validation score 0.148 | Throughput: 5918.00 images/sec
Experiment: 1 | Epoch: 2 | avg. Train loss 4.470 | avg. Train acc 0.176 | avg. Validation score 0.189 | Throughput: 5987.50 images/sec
Experiment: 1 | Epoch: 3 | avg. Train loss 4.444 | avg. Train acc 0.206 | avg. Validation score 0.201 | Throughput: 5993.67 images/sec
Experiment: 1 | Epoch: 4 | avg. Train loss 4.424 | avg. Train acc 0.228 | avg. Validation score 0.213 | Throughput: 6003.05 images/sec
Experiment: 1 | Epoch: 5 | avg. Train loss 4.409 | avg. Train acc 0.242 | avg. Validation score 0.218 | Throughput: 5992.96 images/sec
Experiment: 1 | Epoch: 6 | avg. Train loss 4.395 | avg. Train acc 0.259 | avg. Validation score 0.233 | Throughput: 5993.68 images/sec
Experiment: 1 | Epoch: 7 | avg. Train loss 4.383 | avg. Train acc 0.271 | avg. Validation score 0.241 | Throughput: 5998.39 images/sec
Experiment: 1 | Epoch: 8 | avg. Train loss 4.372 | avg.

In [None]:
# # get precision, recall, and f1-score
# from sklearn.metrics import f1_score, precision_score, recall_score
# import numpy as np

# test_stream = get_streamed_data(data=test_data, batch_size=256, shuffled=False)

# y_true = []
# y_pred = []
# model.eval()
# for batch in test_stream:
#     X, y = batch["image"], batch["label"]
#     X, y = mx.array(X), mx.array(y)
#     logits = model(X)
#     prediction = mx.argmax(mx.softmax(logits), axis=1)
#     y_true = y_true + y.tolist()
#     y_pred = y_pred + prediction.tolist()
    
# y_true = np.array(y_true)
# y_pred = np.array(y_pred)

# precision = precision_score(y_true, y_pred, average="weighted")
# recall = recall_score(y_true, y_pred, average="weighted")
# f1 = f1_score(y_true, y_pred, average="weighted")
# print(f"Precision: {precision}\nRecall: {recall}\nF1 Score: {f1}")
