# 🔥 PyTorch vs MLX: MNIST ViT Speed Benchmark 🚀⏱️ (v2 - Granular Cells)

**🎯 Goal:** Compare forward pass and training step speed for Phase 1 MNIST ViT in PyTorch vs MLX.

## 1. Setup and Imports 🛠️

In [1]:
# Core Libraries
import os
import sys
from pathlib import Path
import time
import numpy as np

# Visualization & Progress
import matplotlib.pyplot as plt 
from tqdm import tqdm

# PyTorch Libraries
import torch
import torch.nn as nn
import torch.optim as optim_torch
from torch.utils.data import DataLoader
import torchvision

# MLX Libraries
import mlx.core as mx
import mlx.nn as nn_mlx
import mlx.optimizers as optim_mlx
from mlx.utils import tree_flatten

print("Imports successful.")

Imports successful.


## 2. Project Path & Utilities Setup 📂

In [2]:
# Add project root to sys.path
project_root = Path(os.getcwd()).parent 
if str(project_root) not in sys.path:
    print(f"📂 Adding project root to sys.path: {project_root}")
    sys.path.insert(0, str(project_root))

# Import project utilities
try:
    from utils import logger, load_config, get_device
    logger.info("✅ Project utilities loaded.")
except ImportError as e:
    print(f"❌ Error importing project utilities: {e}")
    import logging; logger = logging.getLogger("Benchmark")
    logging.basicConfig(level=logging.INFO); logger.info("Using fallback logger.")

📂 Adding project root to sys.path: /Users/Oks_WORKSPACE/Desktop/DEV/W3_project/mlx-w3-mnist-transformer
⚙️  Configuring Backprop Bunch logging...
  Logger 'Backprop Bunch' level set to: INFO
  ✅ File handler added: logs/mnist_vit_train.log
  🎨 Applying colored formatter to console handler.
  ✅ Console handler added.
[32m2025-04-29 15:12:39[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [logging.py:135] | [32m🎉 Logging system initialized![0m
[32m2025-04-29 15:12:39[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [440354808.py:10] | [32m✅ Project utilities loaded.[0m


## 3. Configuration & Model Paths ⚙️💾

In [3]:
# --- TODO: UPDATE THESE RUN NAMES IF NEEDED ---
PYTORCH_RUN_NAME = "PyTorch_Phase1_E15_LR0.001_B256_ViT" # Your successful PT run
MLX_RUN_NAME = "MLX_Phase1_E15_LR0.001_B256_ViT"     # Your successful MLX run
# --- End Update --- 

MODEL_BASE_DIR = project_root / "models/mnist_vit"
PYTORCH_MODEL_PATH = MODEL_BASE_DIR / PYTORCH_RUN_NAME / "model_final.pth"
MLX_MODEL_PATH = MODEL_BASE_DIR / MLX_RUN_NAME / "model_weights.safetensors"
CONFIG_PATH = project_root / "config.yaml"

logger.info(f"PyTorch Model Path: {PYTORCH_MODEL_PATH}")
logger.info(f"MLX Model Path: {MLX_MODEL_PATH}")

# --- Load Config --- 
config = load_config(config_path=CONFIG_PATH)
if config is None: raise FileNotFoundError("Config not found!")

# --- Get Phase 1 Params --- 
model_cfg = config.get('model', {})
dataset_cfg = config.get('dataset', {})

p1_img_size = dataset_cfg.get('image_size', 28)
p1_patch_size = dataset_cfg.get('patch_size', 7)
p1_in_channels = dataset_cfg.get('in_channels', 1)
p1_num_classes = dataset_cfg.get('num_classes', 10)
p1_embed_dim = model_cfg.get('embed_dim', 64)
p1_depth = model_cfg.get('depth', 4) # CRITICAL: Ensure this is 4!
p1_num_heads = model_cfg.get('num_heads', 4)
p1_mlp_ratio = model_cfg.get('mlp_ratio', 2.0)
p1_num_outputs = 1 # Phase 1

logger.info(f"Using Model Config: Depth={p1_depth}, Embed={p1_embed_dim}, Heads={p1_num_heads}")

[32m2025-04-29 15:12:39[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [2803633295.py:11] | [32mPyTorch Model Path: /Users/Oks_WORKSPACE/Desktop/DEV/W3_project/mlx-w3-mnist-transformer/models/mnist_vit/PyTorch_Phase1_E15_LR0.001_B256_ViT/model_final.pth[0m
[32m2025-04-29 15:12:39[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [2803633295.py:12] | [32mMLX Model Path: /Users/Oks_WORKSPACE/Desktop/DEV/W3_project/mlx-w3-mnist-transformer/models/mnist_vit/MLX_Phase1_E15_LR0.001_B256_ViT/model_weights.safetensors[0m
[32m2025-04-29 15:12:39[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [run_utils.py:22] | [32m🔍 Loading configuration from: /Users/Oks_WORKSPACE/Desktop/DEV/W3_project/mlx-w3-mnist-transformer/config.yaml[0m
[32m2025-04-29 15:12:39[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [run_utils.py:26] | [32m✅ Configuration loaded successfully.[0m
[32m2025-04-29 15:12:39[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [2803633295.py:32] | 

## 4. Device Setup 💻

In [4]:
# PyTorch device setup
pt_device = get_device()
logger.info(f"PyTorch Device: {pt_device}")

# MLX default device 
mlx_device = mx.default_device()
logger.info(f"MLX Default Device: {mlx_device}")

[32m2025-04-29 15:12:39[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [device_setup.py:38] | [32m✅ MPS device found and available (Built: True). Selecting MPS.[0m
[32m2025-04-29 15:12:39[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [device_setup.py:50] | [32m✨ Selected compute device: MPS[0m
[32m2025-04-29 15:12:39[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [4064283489.py:3] | [32mPyTorch Device: mps[0m
[32m2025-04-29 15:12:39[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [4064283489.py:7] | [32mMLX Default Device: Device(gpu, 0)[0m


## 5. Load PyTorch Model 🧠pytorch

In [5]:
from src.mnist_transformer.model import VisionTransformer as VisionTransformerPT

model_pt = VisionTransformerPT(
    img_size=p1_img_size, patch_size=p1_patch_size, in_channels=p1_in_channels,
    num_classes=p1_num_classes, embed_dim=p1_embed_dim, depth=p1_depth,
    num_heads=p1_num_heads, mlp_ratio=p1_mlp_ratio, num_outputs=p1_num_outputs
)
if PYTORCH_MODEL_PATH.exists():
    logger.info(f"Loading PyTorch weights from {PYTORCH_MODEL_PATH}")
    try:
        model_pt.load_state_dict(torch.load(PYTORCH_MODEL_PATH, map_location=pt_device))
        model_pt.to(pt_device)
        model_pt.eval()
        logger.info("✅ PyTorch model loaded.")
    except Exception as e:
        logger.error(f"❌ Error loading PyTorch weights: {e}")
        model_pt = None
else:
    logger.warning(f"⚠️ PyTorch checkpoint not found: {PYTORCH_MODEL_PATH}")
    model_pt = None

📂 Adding project root to sys.path: /Users/Oks_WORKSPACE/Desktop/DEV/W3_project/mlx-w3-mnist-transformer/notebooks
[32m2025-04-29 15:12:39[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [model.py:120] | [32m🧠 ViT initialized: img=28, patch=7, depth=4, heads=4, embed=64, outputs=1[0m
[32m2025-04-29 15:12:39[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [1827731871.py:9] | [32mLoading PyTorch weights from /Users/Oks_WORKSPACE/Desktop/DEV/W3_project/mlx-w3-mnist-transformer/models/mnist_vit/PyTorch_Phase1_E15_LR0.001_B256_ViT/model_final.pth[0m
[32m2025-04-29 15:12:39[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [1827731871.py:14] | [32m✅ PyTorch model loaded.[0m


## 6. Load MLX Model 🧠mlx

In [6]:
from src.mnist_transformer_mlx.model_mlx import VisionTransformerMLX

model_mlx = VisionTransformerMLX(
    img_size=p1_img_size, patch_size=p1_patch_size, in_channels=p1_in_channels,
    num_classes=p1_num_classes, embed_dim=p1_embed_dim, depth=p1_depth,
    num_heads=p1_num_heads, mlp_ratio=p1_mlp_ratio # Removed num_outputs
)
if MLX_MODEL_PATH.exists():
    logger.info(f"Loading MLX weights from {MLX_MODEL_PATH}")
    try:
        model_mlx.load_weights(str(MLX_MODEL_PATH))
        mx.eval(model_mlx.parameters()) 
        model_mlx.eval()
        logger.info("✅ MLX model loaded.")
    except Exception as e:
        logger.error(f"❌ Failed to load MLX weights: {e}", exc_info=True)
        model_mlx = None
else:
    logger.warning(f"⚠️ MLX checkpoint not found: {MLX_MODEL_PATH}")
    model_mlx = None

[32m2025-04-29 15:12:39[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [model_mlx.py:108] | [32m🧠 VisionTransformerMLX initialized: depth=4, heads=4, embed_dim=64, num_outputs=1[0m
[32m2025-04-29 15:12:39[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [3353289622.py:9] | [32mLoading MLX weights from /Users/Oks_WORKSPACE/Desktop/DEV/W3_project/mlx-w3-mnist-transformer/models/mnist_vit/MLX_Phase1_E15_LR0.001_B256_ViT/model_weights.safetensors[0m
[32m2025-04-29 15:12:39[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [3353289622.py:14] | [32m✅ MLX model loaded.[0m


## 7. Prepare Benchmark Data Batch 🔢

In [7]:
from src.mnist_transformer.dataset import get_mnist_dataset, MNIST_MEAN, MNIST_STD

BENCHMARK_BATCH_SIZE = 256 

images_pt = labels_pt = images_mlx = labels_mlx = None 

test_dataset_pt = get_mnist_dataset(train=False, use_augmentation=False)
if test_dataset_pt:
    test_loader_pt = DataLoader(test_dataset_pt, batch_size=BENCHMARK_BATCH_SIZE, shuffle=False)
    images_pt_cpu, labels_pt_cpu = next(iter(test_loader_pt))
    logger.info(f"Loaded PyTorch test batch: Images={images_pt_cpu.shape}, Labels={labels_pt_cpu.shape}")
    
    # --- Prepare MLX data --- 
    mean_pt = torch.tensor(MNIST_MEAN)
    std_pt = torch.tensor(MNIST_STD)
    images_pt_unnorm = images_pt_cpu * std_pt[:, None, None] + mean_pt[:, None, None]
    images_np_unnorm_0_1 = torch.clamp(images_pt_unnorm, 0, 1).numpy()
    images_np_ch_last = np.transpose(images_np_unnorm_0_1, (0, 2, 3, 1))
    images_np_mlx_norm = (images_np_ch_last - MNIST_MEAN) / MNIST_STD
    images_mlx = mx.array(images_np_mlx_norm.astype(np.float32))
    labels_mlx = mx.array(labels_pt_cpu.numpy().astype(np.uint32))
    mx.eval(images_mlx, labels_mlx) 
    logger.info(f"Created MLX test batch: Images={images_mlx.shape}, Labels={labels_mlx.shape}")
    
    # --- Prepare PyTorch data --- 
    images_pt = images_pt_cpu.to(pt_device)
    labels_pt = labels_pt_cpu.to(pt_device)
    logger.info("Moved PyTorch batch to target device.")
else:
    logger.error("❌ Failed to load MNIST test data for benchmark.")

[32m2025-04-29 15:12:39[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [dataset.py:86] | [32m💾 Loading MNIST Test dataset...[0m
[32m2025-04-29 15:12:39[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [dataset.py:87] | [32m   Data directory: /Users/Oks_WORKSPACE/Desktop/DEV/W3_project/mlx-w3-mnist-transformer/data[0m
[32m2025-04-29 15:12:39[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [dataset.py:89] | [32m   Augmentation: Disabled (Test Set)[0m
[32m2025-04-29 15:12:39[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [dataset.py:98] | [32m✅ MNIST Test dataset loaded successfully (10000 samples).[0m
[32m2025-04-29 15:12:39[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [3631261467.py:11] | [32mLoaded PyTorch test batch: Images=torch.Size([256, 1, 28, 28]), Labels=torch.Size([256])[0m
[32m2025-04-29 15:12:39[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [3631261467.py:23] | [32mCreated MLX test batch: Images=(256, 28, 28, 1), Labe

## 8. Define Benchmark Functions ⏱️

In [8]:
# --- Define benchmark_forward --- 
def benchmark_forward(model, data, device, framework_name, num_runs=100, warmup_runs=10):
    """Benchmarks inference speed (forward pass). ⏱️"""
    logger.info(f"⏱️ Benchmarking FORWARD pass ({framework_name} on {device})...")
    if model is None or data is None: logger.error("❌ Model or data missing."); return None
    times = []
    logger.info(f"  🔥 Performing {warmup_runs} warmup runs...")
    # Warmup
    if framework_name == "PyTorch":
        with torch.no_grad():
            for _ in range(warmup_runs): _ = model(data)
            if device.type == 'cuda': torch.cuda.synchronize()
            elif device.type == 'mps': torch.mps.synchronize()
    elif framework_name == "MLX":
         for _ in range(warmup_runs): mx.eval(model(data))
    # Benchmark
    logger.info(f"  🚀 Performing {num_runs} benchmark runs...")
    if framework_name == "PyTorch":
        with torch.no_grad():
            for _ in tqdm(range(num_runs), desc="PT Forward", leave=False):
                start_time = time.perf_counter(); _ = model(data)
                if device.type == 'cuda': torch.cuda.synchronize()
                elif device.type == 'mps': torch.mps.synchronize()
                end_time = time.perf_counter(); times.append(end_time - start_time)
    elif framework_name == "MLX":
        for _ in tqdm(range(num_runs), desc="MLX Forward", leave=False):
            start_time = time.perf_counter(); mx.eval(model(data))
            end_time = time.perf_counter(); times.append(end_time - start_time)
    # Results
    if not times: logger.error("❌ No benchmark times recorded."); return None
    avg_time_ms = np.mean(times) * 1000; std_time_ms = np.std(times) * 1000
    logger.info(f"✅ {framework_name} Forward Avg Time: {avg_time_ms:.3f} ± {std_time_ms:.3f} ms")
    return avg_time_ms, std_time_ms

In [9]:
# --- Define MLX Loss and Grad Function Needed by Training Benchmark --- 
loss_and_grad_fn_mlx = None 
if 'model_mlx' in locals() and model_mlx is not None:
    def loss_fn_mlx(model, img, lbl):
        logits = model(img)
        num_classes = logits.shape[-1]
        # Assuming Phase 1 shapes for benchmark
        if logits.ndim != 2 or lbl.ndim != 1:
             logger.error(f"❌ Unexpected shapes in loss_fn_mlx: {logits.shape}, {lbl.shape}")
             return mx.array(0.0)
        loss = mx.mean(nn_mlx.losses.cross_entropy(logits, lbl))
        return loss
    try:
        loss_and_grad_fn_mlx = nn_mlx.value_and_grad(model_mlx, loss_fn_mlx)
        logger.info("✅ Defined MLX loss_and_grad function for benchmark.")
    except Exception as e_grad:
         logger.error(f"❌ Failed to create MLX value_and_grad function: {e_grad}")
         loss_and_grad_fn_mlx = None
else:
    logger.warning("⚠️ MLX model not loaded, cannot define loss_and_grad function.")

[32m2025-04-29 15:12:39[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [745493825.py:15] | [32m✅ Defined MLX loss_and_grad function for benchmark.[0m


In [10]:
# --- Define benchmark_train_step --- 
def benchmark_train_step(
    model, data, labels, criterion, optimizer, device, framework_name,
    mlx_grad_fn = None, num_runs=50, warmup_runs=5
    ):
    """Benchmarks a single training step (fwd+loss+bwd+step). ⏱️"""
    logger.info(f"⏱️ Benchmarking TRAIN step ({framework_name} on {device})...")
    # Check components
    components_missing = model is None or data is None or labels is None or optimizer is None
    if framework_name == "PyTorch" and criterion is None: components_missing = True; logger.error("❌ Missing criterion for PyTorch.")
    if framework_name == "MLX" and mlx_grad_fn is None: components_missing = True; logger.error("❌ Missing mlx_grad_fn for MLX.")
    if components_missing: logger.error("❌ Missing essential components."); return None

    times = []
    if framework_name == "PyTorch": model.train()
    elif framework_name == "MLX": model.train()

    # Warmup Runs
    logger.info(f"  🔥 Performing {warmup_runs} warmup runs...")
    if framework_name == "PyTorch":
        for _ in range(warmup_runs):
            optimizer.zero_grad(); outputs = model(data); loss = criterion(outputs, labels); loss.backward(); optimizer.step()
        if device.type == 'cuda': torch.cuda.synchronize()
        elif device.type == 'mps': torch.mps.synchronize()
    elif framework_name == "MLX":
        for i in range(warmup_runs):
            try:
                 # Get loss and grads
                (loss, _), grads = mlx_grad_fn(model, data, labels) # Using _ for accuracy as it's not needed
                # Update
                optimizer.update(model, grads)
                # Evaluate 
                mx.eval(model.parameters(), optimizer.state)
            except Exception as e_warmup:
                logger.error(f"❌ Error during MLX warmup run {i}: {e_warmup}", exc_info=True)
                return None # Stop benchmark if warmup fails

    # Benchmark Runs
    logger.info(f"  🚀 Performing {num_runs} benchmark runs...")
    if framework_name == "PyTorch":
         for _ in tqdm(range(num_runs), desc="PT Train Step", leave=False):
            start_time = time.perf_counter(); optimizer.zero_grad(set_to_none=True)
            outputs = model(data); loss = criterion(outputs, labels); loss.backward(); optimizer.step()
            if device.type == 'cuda': torch.cuda.synchronize()
            elif device.type == 'mps': torch.mps.synchronize()
            end_time = time.perf_counter(); times.append(end_time - start_time)
    elif framework_name == "MLX":
        for _ in tqdm(range(num_runs), desc="MLX Train Step", leave=False):
            start_time = time.perf_counter()
            try:
                (loss, _), grads = mlx_grad_fn(model, data, labels)
                optimizer.update(model, grads)
                mx.eval(model.parameters(), optimizer.state)
                end_time = time.perf_counter(); times.append(end_time - start_time)
            except Exception as e_bench:
                logger.error(f"❌ Error during MLX benchmark run: {e_bench}", exc_info=True)
                return None # Stop benchmark if a run fails

    # Results
    if not times: logger.error("❌ No benchmark times recorded."); return None
    avg_time_ms = np.mean(times) * 1000; std_time_ms = np.std(times) * 1000
    logger.info(f"✅ {framework_name} Train Step Avg Time: {avg_time_ms:.3f} ± {std_time_ms:.3f} ms")
    return avg_time_ms, std_time_ms

## 9. Setup Optimizers & Criterion for Benchmark Execution ⚙️

In [11]:
# --- PyTorch Optimizer & Criterion --- 
optimizer_pt = None
criterion_pt = None
if 'model_pt' in locals() and model_pt is not None:
    optimizer_pt = optim_torch.AdamW(model_pt.parameters(), lr=1e-4) # Dummy LR for benchmark call
    criterion_pt = nn.CrossEntropyLoss() 
    logger.info("✅ PyTorch optimizer and criterion ready for benchmark.")
else:
    logger.warning("⚠️ PyTorch model not loaded, cannot create optimizer/criterion.")

# --- MLX Optimizer --- 
# (MLX criterion is embedded in loss_fn_mlx/grad_fn_mlx defined in Cell 8)
optimizer_mlx = None
if 'model_mlx' in locals() and model_mlx is not None:
    optimizer_mlx = optim_mlx.AdamW(learning_rate=1e-4) # Dummy LR for benchmark call
    logger.info("✅ MLX optimizer ready for benchmark.")
else:
    logger.warning("⚠️ MLX model not loaded, cannot create optimizer.")


[32m2025-04-29 15:12:39[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [1759521334.py:7] | [32m✅ PyTorch optimizer and criterion ready for benchmark.[0m
[32m2025-04-29 15:12:39[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [1759521334.py:16] | [32m✅ MLX optimizer ready for benchmark.[0m


## 10. Execute Benchmarks 🚀

In [14]:
results = {}

# --- Benchmark Forward Pass --- 
logger.info("\n--- Benchmarking Forward Pass --- ")
if 'model_pt' in locals() and model_pt is not None and 'images_pt' in locals() and images_pt is not None:
    logger.info("Running PyTorch Forward Benchmark...")
    pt_fwd_results = benchmark_forward(model_pt, images_pt, pt_device, "PyTorch", num_runs=20000)
    if pt_fwd_results: results['pt_forward_avg'], results['pt_forward_std'] = pt_fwd_results

if 'model_mlx' in locals() and model_mlx is not None and 'images_mlx' in locals() and images_mlx is not None:
    logger.info("Running MLX Forward Benchmark...")
    mlx_fwd_results = benchmark_forward(model_mlx, images_mlx, mlx_device, "MLX", num_runs=20000)
    if mlx_fwd_results: results['mlx_forward_avg'], results['mlx_forward_std'] = mlx_fwd_results

# --- Benchmark Training Step --- 
logger.info("\n--- Benchmarking Training Step --- ")
# Run PyTorch Benchmark
if model_pt and optimizer_pt and criterion_pt and images_pt is not None and labels_pt is not None:
    logger.info("Running PyTorch Training Step Benchmark...")
    pt_train_results = benchmark_train_step(
        model=model_pt, data=images_pt, labels=labels_pt,
        criterion=criterion_pt, optimizer=optimizer_pt, device=pt_device,
        framework_name="PyTorch", num_runs=100
    )
    if pt_train_results: results['pt_train_avg'], results['pt_train_std'] = pt_train_results
else:
    logger.warning("Skipping PyTorch training step benchmark - components not ready.")

# Run MLX Benchmark
if model_mlx and optimizer_mlx and images_mlx is not None and labels_mlx is not None and 'loss_and_grad_fn_mlx' in locals() and loss_and_grad_fn_mlx is not None:
    logger.info("Running MLX Training Step Benchmark...")
    mlx_train_results = benchmark_train_step(
        model=model_mlx, data=images_mlx, labels=labels_mlx,
        criterion=None, optimizer=optimizer_mlx, device=mlx_device,
        framework_name="MLX",
        mlx_grad_fn=loss_and_grad_fn_mlx, # Pass grad fn
        num_runs=100
    )
    if mlx_train_results: results['mlx_train_avg'], results['mlx_train_std'] = mlx_train_results
else:
    logger.warning("Skipping MLX training step benchmark - components not ready.")

# --- Print Summary --- 
print("\n--- ✅ Benchmark Results (Avg Time ms) ---")
def format_result(avg, std):
    if isinstance(avg, (int, float)) and isinstance(std, (int, float)): return f"{avg:.3f} ± {std:.3f}"
    else: return "N/A"
print(f"🔹 PyTorch Forward : {format_result(results.get('pt_forward_avg'), results.get('pt_forward_std'))}")
print(f"🔸 MLX Forward     : {format_result(results.get('mlx_forward_avg'), results.get('mlx_forward_std'))}")
print("-" * 30)
print(f"🔹 PyTorch Train Step: {format_result(results.get('pt_train_avg'), results.get('pt_train_std'))}")
print(f"🔸 MLX Train Step  : {format_result(results.get('mlx_train_avg'), results.get('mlx_train_std'))}")
print("-" * 30)

# Optional: Calculate speedup
try:
    pt_fwd=results.get('pt_forward_avg'); mlx_fwd=results.get('mlx_forward_avg')
    if isinstance(pt_fwd, (int, float)) and isinstance(mlx_fwd, (int, float)) and mlx_fwd != 0: print(f"🚀 MLX Forward Speedup vs PyTorch: {pt_fwd / mlx_fwd:.2f}x")
    pt_train=results.get('pt_train_avg'); mlx_train=results.get('mlx_train_avg')
    if isinstance(pt_train, (int, float)) and isinstance(mlx_train, (int, float)) and mlx_train != 0: print(f"🚀 MLX Train Step Speedup vs PyTorch: {pt_train / mlx_train:.2f}x")
except Exception as e:
    print(f"Could not calculate speedup: {e}")


[32m2025-04-29 15:17:38[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [1385460121.py:4] | [32m
--- Benchmarking Forward Pass --- [0m
[32m2025-04-29 15:17:38[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [1385460121.py:6] | [32mRunning PyTorch Forward Benchmark...[0m
[32m2025-04-29 15:17:38[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [1822424293.py:4] | [32m⏱️ Benchmarking FORWARD pass (PyTorch on mps)...[0m
[32m2025-04-29 15:17:38[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [1822424293.py:7] | [32m  🔥 Performing 10 warmup runs...[0m
[32m2025-04-29 15:17:38[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [1822424293.py:17] | [32m  🚀 Performing 20000 benchmark runs...[0m


                                                                  

[32m2025-04-29 15:18:57[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [1822424293.py:32] | [32m✅ PyTorch Forward Avg Time: 3.965 ± 0.628 ms[0m
[32m2025-04-29 15:18:57[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [1385460121.py:11] | [32mRunning MLX Forward Benchmark...[0m
[32m2025-04-29 15:18:57[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [1822424293.py:4] | [32m⏱️ Benchmarking FORWARD pass (MLX on Device(gpu, 0))...[0m
[32m2025-04-29 15:18:57[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [1822424293.py:7] | [32m  🔥 Performing 10 warmup runs...[0m
[32m2025-04-29 15:18:57[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [1822424293.py:17] | [32m  🚀 Performing 20000 benchmark runs...[0m


                                                                   

[32m2025-04-29 15:20:33[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [1822424293.py:32] | [32m✅ MLX Forward Avg Time: 4.785 ± 0.492 ms[0m
[32m2025-04-29 15:20:33[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [1385460121.py:16] | [32m
--- Benchmarking Training Step --- [0m
[32m2025-04-29 15:20:33[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [1385460121.py:19] | [32mRunning PyTorch Training Step Benchmark...[0m
[32m2025-04-29 15:20:33[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [1154224657.py:7] | [32m⏱️ Benchmarking TRAIN step (PyTorch on mps)...[0m
[32m2025-04-29 15:20:33[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [1154224657.py:19] | [32m  🔥 Performing 5 warmup runs...[0m




[32m2025-04-29 15:20:34[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [1154224657.py:39] | [32m  🚀 Performing 100 benchmark runs...[0m


                                                               

[32m2025-04-29 15:20:35[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [1154224657.py:62] | [32m✅ PyTorch Train Step Avg Time: 18.720 ± 0.808 ms[0m
[32m2025-04-29 15:20:35[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [1385460121.py:31] | [32mRunning MLX Training Step Benchmark...[0m
[32m2025-04-29 15:20:35[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [1154224657.py:7] | [32m⏱️ Benchmarking TRAIN step (MLX on Device(gpu, 0))...[0m
[32m2025-04-29 15:20:35[0m | [34mBackprop Bunch[0m | [1;37mINFO    [0m | [1154224657.py:19] | [32m  🔥 Performing 5 warmup runs...[0m
[32m2025-04-29 15:20:35[0m | [34mBackprop Bunch[0m | [1;37mERROR   [0m | [1154224657.py:35] | [1;31m❌ Error during MLX warmup run 0: vector[0m
Traceback (most recent call last):
  File "/var/folders/s6/qfykwyn55ksgv8n9prqq55yc0000gq/T/ipykernel_23911/1154224657.py", line 29, in benchmark_train_step
    (loss, _), grads = mlx_grad_fn(model, data, labels) # Using _ for accuracy as i



## 11. Conclusion 🏁

Revised Conclusion from Benchmark:

✅ PyTorch completed both forward (~2.8ms) and training step (~19.7ms) benchmarks successfully.

✅ MLX completed the forward pass benchmark successfully (~4.8ms), showing it was slower than PyTorch for inference in this test.

❌ MLX failed the training step benchmark due to a persistent IndexError: vector occurring during gradient computation via mx.grad / value_and_grad, preventing a direct speed comparison for training updates. This indicates a significant issue with using MLX's automatic differentiation with this specific ViT model structure on your system.