In [1]:
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.data import datasets
from datasets_utils import cifar100


In [2]:
def get_streamed_data(data, batch_size=0, shuffled=True):
    """Build a prefetched sample stream from a dataset buffer.

    Args:
        data: Buffer-like object; ``shuffle()`` and ``to_stream()`` are called on it.
        batch_size: Samples per batch. Values <= 0 leave the stream unbatched.
        shuffled: When truthy, ``data.shuffle()`` is applied before streaming.

    Returns:
        The stream returned by ``prefetch(4, 2)`` at the end of the pipeline.
    """
    if shuffled:
        data = data.shuffle()

    pipeline = data.to_stream().key_transform(
        "image",
        # Cast to float32 and divide by 255 (assumes 8-bit pixel values — TODO confirm).
        lambda pixels: pixels.astype("float32") / 255.0,
    )

    if batch_size > 0:
        pipeline = pipeline.batch(batch_size)

    # NOTE(review): arguments are presumably (prefetch size, worker threads) —
    # confirm against the mlx.data Stream.prefetch documentation.
    return pipeline.prefetch(4, 2)

In [3]:
# Load the CIFAR-100 training split first, then the test split.
train_data, test_data = (
    datasets.load_cifar100(train=is_train) for is_train in (True, False)
)

# Cell output: number of samples in each split.
len(train_data), len(test_data)

(50000, 10000)

In [4]:
# Sanity check of partitioning and streaming.
# Split train_data into 5 partitions and keep only the first one (index 0).
partitioned = train_data.partition(5, 0)
print(f"Partition length: {len(partitioned)}")

# Stream the partition one sample per batch and pull a single batch to inspect its layout.
streamed_partition = get_streamed_data(partitioned, batch_size=1, shuffled=True)
batch = next(streamed_partition)
print(f"Batch image shape: {batch['image'].shape}\nBatch label shape: {batch['label'].shape}")
# Print the length of train_data itself for comparison with the partition length above.
print(f"Original data length: {len(train_data)}")

Partition length: 10000
Batch image shape: (1, 32, 32, 3)
Batch label shape: (1,)
Original data length: 50000


In [5]:
class Model(nn.Module):
    """Small CNN classifier for square, channel-last images.

    Architecture: two stages of (conv -> ReLU -> conv -> ReLU -> max-pool), each
    followed by batch normalisation, then one linear layer over the flattened
    feature map.

    Args:
        input_channel: Number of channels of the input images.
        input_width: Side length (in pixels) of the square input images. Used to
            size the linear layer.
        conv_filters: Number of filters used by every convolution.
        output_dims: Number of outputs (logits) produced by the linear layer.

    Raises:
        ValueError: If ``input_width`` is too small for the conv/pool stack to
            produce at least a 1x1 feature map.
    """

    @staticmethod
    def _flattened_features(input_width, conv_filters, conv=(3, 1, 0), pool=(2, 2, 0)):
        """Return the number of features reaching the linear layer.

        ``conv`` and ``pool`` are ``(kernel_size, stride, padding)`` triples.
        Each of the two stages applies conv, conv, pool, and every layer maps a
        width ``W`` to ``(W + 2 * padding - kernel_size) // stride + 1``.
        Assumes a square input (height == width).
        """
        width = input_width
        for _stage in range(2):
            for kernel_size, stride, padding in (conv, conv, pool):
                width = (width + 2 * padding - kernel_size) // stride + 1
                if width < 1:
                    raise ValueError(
                        f"input_width={input_width} is too small for this architecture"
                    )
        return width * width * conv_filters

    def __init__(self, input_channel, input_width, conv_filters, output_dims):
        super().__init__()
        conv2d_kernel_size = 3
        conv2d_stride = 1
        conv2d_padding = 0

        pool2d_kernel_size = 2
        pool2d_stride = 2
        pool2d_padding = 0

        def conv_stage(in_channels):
            # One stage: conv -> ReLU -> conv -> ReLU -> max-pool.
            return nn.Sequential(
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=conv_filters,
                    kernel_size=(conv2d_kernel_size, conv2d_kernel_size),
                    stride=conv2d_stride,
                    padding=conv2d_padding
                ),
                nn.ReLU(),
                nn.Conv2d(
                    in_channels=conv_filters,
                    out_channels=conv_filters,
                    kernel_size=(conv2d_kernel_size, conv2d_kernel_size),
                    stride=conv2d_stride,
                    padding=conv2d_padding
                ),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=pool2d_kernel_size, stride=pool2d_stride, padding=pool2d_padding),
            )

        # Layers are created in the same order as before so parameter names and
        # initialisation order are unchanged.
        self.conv_layer1 = conv_stage(input_channel)
        self.batch_norm1 = nn.BatchNorm(conv_filters)
        self.conv_layer2 = conv_stage(conv_filters)
        self.batch_norm2 = nn.BatchNorm(conv_filters)

        # Size the linear layer from the actual input width instead of assuming
        # a 5x5 feature map (which only holds for 32-pixel inputs).
        flattened = self._flattened_features(
            input_width,
            conv_filters,
            conv=(conv2d_kernel_size, conv2d_stride, conv2d_padding),
            pool=(pool2d_kernel_size, pool2d_stride, pool2d_padding),
        )
        self.fully_connected = nn.Sequential(
            nn.Linear(input_dims=flattened, output_dims=output_dims),
        )

    def __call__(self, x):
        """Run the forward pass and return the output of the linear layer."""
        x = self.conv_layer1(x)
        x = self.batch_norm1(x)
        x = self.conv_layer2(x)
        x = self.batch_norm2(x)
        # Collapse every axis after the first into a single feature axis.
        x = mx.flatten(x, start_axis=1, end_axis=-1)
        x = self.fully_connected(x)
        return x

In [6]:
# Model hyperparameters.
INPUT_CHANNEL, INPUT_WIDTH = 3, 32
CONV_FILTERS = 128
NUM_CLASSES = len(cifar100.labels)  # one output per label

model = Model(
    output_dims=NUM_CLASSES,
    conv_filters=CONV_FILTERS,
    input_width=INPUT_WIDTH,
    input_channel=INPUT_CHANNEL,
)

# Evaluate the model's arrays now, then display the module summary as the cell output.
mx.eval(model)
model

Model(
  (conv_layer1): Sequential(
    (layers.0): Conv2d(3, 128, kernel_size=(3,), stride=(1, 1), padding=(0, 0), dilation=1, bias=True)
    (layers.1): ReLU()
    (layers.2): Conv2d(128, 128, kernel_size=(3,), stride=(1, 1), padding=(0, 0), dilation=1, bias=True)
    (layers.3): ReLU()
    (layers.4): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=(0, 0))
  )
  (batch_norm1): BatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv_layer2): Sequential(
    (layers.0): Conv2d(128, 128, kernel_size=(3,), stride=(1, 1), padding=(0, 0), dilation=1, bias=True)
    (layers.1): ReLU()
    (layers.2): Conv2d(128, 128, kernel_size=(3,), stride=(1, 1), padding=(0, 0), dilation=1, bias=True)
    (layers.3): ReLU()
    (layers.4): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=(0, 0))
  )
  (batch_norm2): BatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fully_connected): Sequential(
    (layers.0): Linear(input_dims=3200

In [7]:
import trainer
import mlx.optimizers as optim

# Training configuration: number of epochs, number of data partitions (folds),
# and the Adam optimizer shared by the whole run.
epochs, n_partitions = 50, 5
optimizer = optim.Adam(learning_rate=1e-4)

def get_partitions(partition_start, partition_num):
    """Split partition indices into one held-out index and the remaining ones.

    Args:
        partition_start: Index of the held-out partition.
        partition_num: Total number of partitions; valid indices are
            ``0 .. partition_num - 1``.

    Returns:
        A tuple ``(partition_start, others)`` where ``others`` is the ascending
        list of every other partition index.

    Raises:
        ValueError: If ``partition_start`` is not a valid partition index.
    """
    if partition_start not in range(partition_num):
        raise ValueError(
            f"partition_start={partition_start!r} is not a valid index "
            f"for {partition_num} partitions"
        )
    others = [idx for idx in range(partition_num) if idx != partition_start]
    return partition_start, others

# Per-partition history: validation_scores[k] receives one score per epoch for partition k.
validation_scores = {k: [] for k in range(n_partitions) }

# NOTE(review): the same `model` and `optimizer` are reused for every partition and
# every epoch, and the data is re-shuffled at the start of each epoch. Assuming
# trainer.train_epoch updates the model, each "validation" partition has already been
# trained on during earlier folds/epochs, so these scores are probably not held-out
# estimates — confirm intent. True k-fold cross-validation would use a fixed
# partitioning and a freshly initialised model per fold.
for epoch in range(epochs):
    # Metrics collected over every training pass of this epoch.
    train_epoch_loss = []
    train_epoch_accuracy = []
    train_epoch_throughput = []
    # A new shuffle each epoch means partition k holds different samples every epoch.
    shuffled_train_data = train_data.shuffle()
    for partition_no in range(n_partitions):
        # get validation partition and train partitions
        validation_partition_idx, train_partitions = get_partitions(partition_no, n_partitions)

        # train on all partitions except the validation partition
        for train_partition_idx in train_partitions:
            train_partition = shuffled_train_data.partition(n_partitions, train_partition_idx)
            data = get_streamed_data(data=train_partition, batch_size=256, shuffled=True)
            train_loss, train_acc, throughput = trainer.train_epoch(model, data, optimizer, epoch, verbose=False)

            # Each returned metric exposes .item(); store plain Python scalars.
            train_epoch_loss.append(train_loss.item())
            train_epoch_accuracy.append(train_acc.item())
            train_epoch_throughput.append(throughput.item())
        # Score the model on the partition that was left out of the passes above.
        validation_partition = shuffled_train_data.partition(n_partitions, validation_partition_idx)
        validation_data = get_streamed_data(data=validation_partition, batch_size=256, shuffled=False)
        validation_score = trainer.test_epoch(model, validation_data, epoch).item()
        validation_scores[partition_no].append(validation_score)
        
        print(" | ".join((f"Epoch: {epoch+1}", f"Partition: {partition_no}", f"Score: {validation_score}")))

    # Average the collected metrics over all training passes of this epoch.
    train_epoch_loss = mx.mean(mx.array(train_epoch_loss))
    train_epoch_accuracy = mx.mean(mx.array(train_epoch_accuracy))
    train_epoch_throughput = mx.mean(mx.array(train_epoch_throughput))
    print(" | ".join(
                (
                    f"Epoch: {epoch+1}",
                    f"avg. Train loss {train_epoch_loss.item():.3f}",
                    f"avg. Train acc {train_epoch_accuracy.item():.3f}",
                    f"Throughput: {train_epoch_throughput.item():.2f} images/sec",
                )))

Epoch: 1 | Partition: 0 | Score: 0.1699923723936081
Epoch: 1 | Partition: 1 | Score: 0.2138671875
Epoch: 1 | Partition: 2 | Score: 0.25708451867103577
Epoch: 1 | Partition: 3 | Score: 0.2903042733669281
Epoch: 1 | Partition: 4 | Score: 0.32500001788139343
Epoch: 1 | avg. Train loss 4.410 | avg. Train acc 0.227 | Throughput: 1370.69 images/sec
Epoch: 2 | Partition: 0 | Score: 0.3428710997104645
Epoch: 2 | Partition: 1 | Score: 0.3623046875
Epoch: 2 | Partition: 2 | Score: 0.38398438692092896
Epoch: 2 | Partition: 3 | Score: 0.4040471613407135
Epoch: 2 | Partition: 4 | Score: 0.43463611602783203
Epoch: 2 | avg. Train loss 4.273 | avg. Train acc 0.378 | Throughput: 1371.07 images/sec
Epoch: 3 | Partition: 0 | Score: 0.4432617127895355
Epoch: 3 | Partition: 1 | Score: 0.46004238724708557
Epoch: 3 | Partition: 2 | Score: 0.482614666223526
Epoch: 3 | Partition: 3 | Score: 0.5029004216194153
Epoch: 3 | Partition: 4 | Score: 0.5239062309265137
Epoch: 3 | avg. Train loss 4.180 | avg. Train acc 

In [8]:
# get precision, recall, and f1-score
from sklearn.metrics import f1_score, precision_score, recall_score
import numpy as np

# Stream the test split in batches of 256, without shuffling.
test_stream = get_streamed_data(data=test_data, batch_size=256, shuffled=False)

y_true = []
y_pred = []
model.eval()  # put the model in evaluation mode before inference
for batch in test_stream:
    X, y = batch["image"], batch["label"]
    X, y = mx.array(X), mx.array(y)
    logits = model(X)
    # The predicted class is the arg-max logit of each row. Softmax is monotonic, so it
    # is unnecessary here; calling mx.softmax without `axis` would additionally normalise
    # over the whole batch rather than per sample.
    prediction = mx.argmax(logits, axis=1)
    # extend() appends in place instead of copying the whole list on every batch.
    y_true.extend(y.tolist())
    y_pred.extend(prediction.tolist())

y_true = np.array(y_true)
y_pred = np.array(y_pred)

# Support-weighted averages over all classes.
precision = precision_score(y_true, y_pred, average="weighted")
recall = recall_score(y_true, y_pred, average="weighted")
f1 = f1_score(y_true, y_pred, average="weighted")
print(f"Precision: {precision}\nRecall: {recall}\nF1 Score: {f1}")


Precision: 0.4562159620317154
Recall: 0.4623
F1 Score: 0.45719586228928
