In [None]:
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from tqdm import tqdm
import math

In [None]:
# Training hyper-parameters.
BATCH_SIZE = 128
EPOCHS = 10
BASE_LR = 0.001        # starting (peak) learning rate of the per-epoch cosine schedule
WEIGHT_DECAY = 1e-4    # weight decay passed to AdamW

# Seed every RNG the notebook relies on so "Restart & Run All" is reproducible:
# numpy, torch (DataLoader shuffling / torchvision random augmentations) and
# MLX (model parameter initialisation).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
mx.random.seed(SEED)

# 1. Load the CIFAR-10 data

In [None]:
def get_cifar10_dataloaders(batch_size=128, root='./data', num_workers=4, download=True):
    """Build the CIFAR-10 train and test DataLoaders.

    The training split uses random-crop (pad 4) and horizontal-flip augmentation.
    Both splits are converted to tensors and normalised with the same per-channel
    mean/std.

    Args:
        batch_size: samples per batch for both loaders.
        root: directory where torchvision stores / looks for the dataset.
        num_workers: DataLoader worker processes (set 0 if multiprocessing
            misbehaves in your notebook environment).
        download: download the dataset into ``root`` if it is missing.

    Returns:
        tuple: ``(trainloader, testloader)``. The train loader shuffles and the
        test loader does not.
    """
    # NOTE(review): (0.2023, 0.1994, 0.2010) is the widely copied CIFAR-10 "std";
    # the exact per-channel std is closer to (0.2470, 0.2435, 0.2616). Kept as-is
    # so the un-normalisation in show_cifar_preview and prior results still match.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    trainset = torchvision.datasets.CIFAR10(root=root, train=True, download=download, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    testset = torchvision.datasets.CIFAR10(root=root, train=False, download=download, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return trainloader, testloader

In [None]:
def show_cifar_preview(dataloader, rows=4, cols=4,
                       mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)):
    """Display a ``rows x cols`` grid of images from the first batch of ``dataloader``.

    Each image is un-normalised with ``mean``/``std``. These must match the
    ``Normalize`` transform the loader applies, and the defaults match
    ``get_cifar10_dataloaders``. Each image is titled with its class name. Grid
    cells beyond the batch size are left blank.

    Args:
        dataloader: iterable yielding ``(images, labels)`` batches, images as
            ``(N, 3, H, W)`` tensors.
        rows, cols: grid shape. Any positive values, including 1x1, are supported.
        mean, std: per-channel normalisation constants to undo.
    """
    classes = ('plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')

    images, labels = next(iter(dataloader))

    # Shape (3, 1, 1) so they broadcast over a (C, H, W) image.
    mean_t = torch.tensor(mean).reshape(3, 1, 1)
    std_t = torch.tensor(std).reshape(3, 1, 1)

    # squeeze=False always yields a 2-D array of Axes, so .flatten() also works
    # for a 1x1 grid (where plt.subplots would otherwise return a bare Axes).
    # 2.5 inches per cell keeps the default 4x4 grid at the original 10x10 size.
    fig, axes = plt.subplots(rows, cols, figsize=(2.5 * cols, 2.5 * rows), squeeze=False)

    for i, ax in enumerate(axes.flatten()):
        if i < len(images):
            img = images[i] * std_t + mean_t
            # (C, H, W) -> (H, W, C), clipped to [0, 1] as imshow expects for floats.
            npimg = np.clip(img.numpy().transpose(1, 2, 0), 0, 1)
            ax.imshow(npimg)
            ax.set_title(classes[labels[i].item()], fontsize=12, color='blue')
        ax.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Build the loaders once; the training cell reads the train_loader / test_loader globals.
print("Loading Data...")
train_loader, test_loader = get_cifar10_dataloaders(BATCH_SIZE)
show_cifar_preview(train_loader, rows=4, cols=4)

# 2. Set up the model

In [None]:
class SelfAttention(nn.Module):
    """Multi-head self-attention across the spatial positions of a feature map.

    The ``(B, H, W, C)`` input is flattened to ``H * W`` tokens of width ``C``.
    These are attended with ``heads`` heads of width ``C // heads``, projected
    back to ``C`` channels and reshaped to ``(B, H, W, C)``. No positional
    encoding or normalisation is applied inside this layer.

    Args:
        dims: channel width ``C`` of the input. Must be divisible by ``heads``.
        heads: number of attention heads.

    Raises:
        ValueError: if ``dims`` is not divisible by ``heads``.
    """
    def __init__(self, dims, heads=4):
        super().__init__()
        if dims % heads != 0:
            raise ValueError(f"dims ({dims}) must be divisible by heads ({heads})")
        self.heads = heads
        self.head_dim = dims // heads
        self.scale = self.head_dim ** -0.5
        # Creation order (qkv, then proj) is kept so parameter initialisation is unchanged.
        self.qkv = nn.Linear(dims, dims * 3, bias=False)
        self.proj = nn.Linear(dims, dims)

    def __call__(self, x):
        B, H, W, C = x.shape
        tokens = x.reshape(B, -1, C)  # (B, N, C) with N = H * W
        N = tokens.shape[1]

        q, k, v = mx.split(self.qkv(tokens), 3, axis=-1)

        def split_heads(t):
            # (B, N, C) -> (B, heads, N, head_dim)
            return t.reshape(B, N, self.heads, self.head_dim).transpose(0, 2, 1, 3)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        # Scaled dot-product attention: (B, heads, N, N) weights over all positions.
        attn = mx.softmax((q @ k.transpose(0, 1, 3, 2)) * self.scale, axis=-1)
        # Merge heads back: (B, heads, N, head_dim) -> (B, N, C)
        out = (attn @ v).transpose(0, 2, 1, 3).reshape(B, N, C)

        return self.proj(out).reshape(B, H, W, C)

class ConvBlock(nn.Module):
    """3x3 convolution (padding 1) -> BatchNorm -> ReLU, optionally followed by 2x2 max-pooling.

    With ``pool=True`` the 2x2 / stride-2 max-pool halves the spatial size. Input
    is expected in channels-last layout (the training loop transposes batches
    to NHWC before calling the model).

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        pool: apply the 2x2 max-pool after the activation.
    """
    def __init__(self, in_channels, out_channels, pool=False):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm(out_channels)
        self.pool = pool
        # Build the parameter-free pooling layer once instead of on every forward
        # call. The leading underscore is intended to keep it out of
        # ``parameters()`` so the parameter tree / saved weights are unchanged.
        # NOTE(review): relies on MLX skipping "_"-prefixed keys; confirm on the
        # installed MLX version.
        self._max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def __call__(self, x):
        x = nn.relu(self.bn(self.conv(x)))
        if self.pool:
            x = self._max_pool(x)
        return x

class CifarAttentionNet(nn.Module):
    """CIFAR classifier laid out like ResNet-9, with self-attention as the middle stage.

    For a 32x32 input the feature maps go:
    prep (32x32, 64 ch) -> layer1 (16x16, 128 ch, residual conv pair)
    -> layer2 (8x8, 256 ch, residual self-attention)
    -> layer3 (4x4, 512 ch, residual conv pair)
    -> global max over H, W -> linear head with ``num_classes`` outputs.
    """
    def __init__(self, num_classes=10):
        super().__init__()
        # Attribute names define the parameter tree, and construction order fixes
        # the order of random initialisation. Keep both stable.
        self.prep = ConvBlock(in_channels=3, out_channels=64)

        self.layer1_conv = ConvBlock(in_channels=64, out_channels=128, pool=True)
        self.layer1_res = nn.Sequential(ConvBlock(128, 128), ConvBlock(128, 128))

        self.layer2_conv = ConvBlock(in_channels=128, out_channels=256, pool=True)
        self.attention = SelfAttention(dims=256, heads=4)

        self.layer3_conv = ConvBlock(in_channels=256, out_channels=512, pool=True)
        self.layer3_res = nn.Sequential(ConvBlock(512, 512), ConvBlock(512, 512))

        self.classifier = nn.Linear(512, num_classes)

    def __call__(self, x):
        h = self.layer1_conv(self.prep(x))
        h = h + self.layer1_res(h)        # residual conv pair @ 128 ch

        h = self.layer2_conv(h)
        h = h + self.attention(h)         # residual self-attention @ 256 ch

        h = self.layer3_conv(h)
        h = h + self.layer3_res(h)        # residual conv pair @ 512 ch

        pooled = mx.max(h, axis=[1, 2])   # global max over the spatial axes -> (B, 512)
        return self.classifier(pooled)

# 3. Define the loss, evaluation, and plotting helpers

In [None]:
def loss_fn(model, X, y):
    """Return the batch-mean cross-entropy between ``model(X)`` and the class labels ``y``.

    ``model`` is the first argument so the function can be wrapped with
    ``nn.value_and_grad``.
    """
    return nn.losses.cross_entropy(model(X), y, reduction="mean")

def eval_fn(model, loader):
    """Compute the accuracy and mean cross-entropy loss of ``model`` over ``loader``.

    The caller is expected to switch the model to evaluation mode
    (``model.eval()``) beforehand; this function does not change the mode.

    Args:
        model: callable mapping a channels-last ``mx.array`` batch to class logits.
        loader: iterable of ``(X, y)`` batches exposing ``.numpy()`` (presumably
            torch tensors), with ``X`` in ``(N, C, H, W)`` layout.

    Returns:
        tuple: ``(accuracy, mean_loss)``, both averaged per sample over the
        whole loader.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    correct = 0
    total = 0
    total_loss = 0.0
    for X, y in loader:
        # (N, C, H, W) -> (N, H, W, C) for the MLX model.
        X = mx.array(X.numpy()).transpose(0, 2, 3, 1)
        y = mx.array(y.numpy())
        logits = model(X)

        # Loss: sum per-sample losses (instead of averaging per batch) so a
        # smaller final batch is weighted by its true size in the overall mean.
        total_loss += nn.losses.cross_entropy(logits, y, reduction="sum").item()

        # Accuracy
        preds = mx.argmax(logits, axis=1)
        correct += mx.sum(preds == y).item()
        total += y.shape[0]

    if total == 0:
        raise ValueError("eval_fn: loader yielded no samples")
    return correct / total, total_loss / total

def plot_history(history):
    """Plot per-epoch training vs. validation accuracy (left) and loss (right).

    Args:
        history: dict of equal-length per-epoch lists under ``'train_acc'``,
            ``'val_acc'``, ``'train_loss'`` and ``'val_loss'``.
    """
    epochs = range(1, len(history['train_acc']) + 1)
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # (axes, history-key suffix, legend-label suffix, title / y-axis label)
    panels = [
        (axes[0], 'acc', 'Acc', 'Accuracy'),
        (axes[1], 'loss', 'Loss', 'Loss'),
    ]
    for ax, key, short, title in panels:
        ax.plot(epochs, history[f'train_{key}'], 'b-', label=f'Training {short}')
        ax.plot(epochs, history[f'val_{key}'], 'r-', label=f'Validation {short}')
        ax.set_title(title)
        ax.set_xlabel('Epochs')
        ax.set_ylabel(title)
        ax.grid(True, alpha=0.3)
        ax.legend()

    fig.tight_layout()
    plt.show()

In [None]:
def main():
    """Train ``CifarAttentionNet`` on CIFAR-10, report per-epoch metrics and plot them.

    Reads the module-level ``train_loader`` / ``test_loader`` built in the
    data-loading cell. It also reads the ``EPOCHS`` / ``BASE_LR`` /
    ``WEIGHT_DECAY`` constants from the configuration cell.

    Returns:
        tuple: ``(model, history)``, i.e. the trained model and a dict of per-epoch
        ``'train_loss'``, ``'train_acc'``, ``'val_loss'`` and ``'val_acc'`` lists.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    print("Initializing Model...")
    model = CifarAttentionNet()
    mx.eval(model.parameters())

    optimizer = optim.AdamW(learning_rate=BASE_LR, weight_decay=WEIGHT_DECAY)
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

    def loss_and_logits(model, X, y):
        # Return the logits as an auxiliary output. mx.value_and_grad only
        # differentiates the first tuple element, so training accuracy comes from
        # the same forward pass as the loss. A separate model(X) call after the
        # update would double compute, score the post-update weights, and (in
        # train mode) update the BatchNorm running statistics a second time.
        logits = model(X)
        loss = nn.losses.cross_entropy(logits, y, reduction="mean")
        return loss, logits

    # nn.value_and_grad handles the nn.Module: gradients are taken w.r.t. its trainable parameters.
    loss_and_grad_fn = nn.value_and_grad(model, loss_and_logits)

    def step(X, y):
        (loss, logits), grads = loss_and_grad_fn(model, X, y)
        optimizer.update(model, grads)
        return loss, logits

    for epoch in range(EPOCHS):
        model.train()

        # Per-epoch cosine annealing, starting at BASE_LR and decaying towards 0.
        lr = 0.5 * BASE_LR * (1 + math.cos(math.pi * epoch / EPOCHS))
        optimizer.learning_rate = lr

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1:02d}/{EPOCHS}")

        epoch_loss = 0.0
        epoch_correct = 0
        epoch_total = 0
        steps = 0

        for X, y in pbar:
            # (N, C, H, W) torch batch -> (N, H, W, C) MLX array
            X = mx.array(X.numpy()).transpose(0, 2, 3, 1)
            y = mx.array(y.numpy())

            loss, logits = step(X, y)
            # Materialise MLX's lazy updates (model and optimizer state) plus this step's outputs.
            mx.eval(model.state, optimizer.state, loss, logits)

            # Training accuracy from the logits of the step that produced the loss.
            batch_correct = mx.sum(mx.argmax(logits, axis=1) == y).item()
            batch_size = y.shape[0]
            batch_loss = loss.item()

            epoch_loss += batch_loss
            epoch_correct += batch_correct
            epoch_total += batch_size
            steps += 1

            pbar.set_postfix(loss=f"{batch_loss:.4f}", acc=f"{batch_correct/batch_size:.4f}", lr=f"{lr:.5f}")

        if steps == 0:
            raise ValueError("train_loader yielded no batches")
        avg_train_loss = epoch_loss / steps
        avg_train_acc = epoch_correct / epoch_total

        # Validation metrics with BatchNorm in inference mode.
        model.eval()
        val_acc, val_loss = eval_fn(model, test_loader)

        history['train_loss'].append(avg_train_loss)
        history['train_acc'].append(avg_train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        print(f"  └─ Train Loss: {avg_train_loss:.4f} | Train Acc: {avg_train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

    # Guard against EPOCHS == 0, where there is no final accuracy to report.
    if history['val_acc']:
        print(f"\nFinal Test Accuracy: {history['val_acc'][-1]:.2%}")
        plot_history(history)
    return model, history


if __name__ == "__main__":
    # Keep the results as globals so later cells can inspect them.
    trained_model, training_history = main()