In [None]:
# Install required CV dependencies for this environment using uv.
# NOTE(review): the requirements path is an absolute path on one developer's
# machine and will not resolve elsewhere; consider a path relative to the repo root.
# `uv pip install` installs into the environment uv detects (the captured output
# shows .venv); presumably that is the environment backing this kernel — confirm.
!uv pip install -r /Users/amisra/dev/ERA-v4/requirements/cv.txt

[2mUsing Python 3.12.11 environment at: /Users/amisra/dev/ERA-v4/.venv[0m
[2mAudited [1m7 packages[0m [2min 224ms[0m[0m


In [6]:
import torch

In [7]:
# Fail fast with an actionable message if the Metal Performance Shaders (MPS)
# backend cannot be used. `is_built()` reports whether this PyTorch binary was
# compiled with MPS support; `is_available()` additionally requires a supported
# macOS version and an MPS-capable device. An explicit raise is used instead of
# `assert` because assertions are stripped when Python runs with -O.
if not torch.backends.mps.is_available():
    if not torch.backends.mps.is_built():
        reason = "the installed PyTorch build was not compiled with MPS support"
    else:
        reason = (
            "the macOS version is older than 12.3 and/or this machine has no "
            "MPS-enabled device"
        )
    raise RuntimeError(
        f"MPS acceleration not available on this Mac: {reason}. "
        "Ensure PyTorch with MPS support is installed and Metal is enabled."
    )
print("MPS acceleration available: True")

MPS acceleration available: True


In [9]:
# Target device (Metal Performance Shaders backend) used by the tensor cells below.
device = torch.device("mps")
print("Using device:", device)

Using device: mps


In [10]:
# Quick tensor sanity check on MPS: run a matmul on `device` and verify it
# against the same product computed on the CPU.
a = torch.randn(1024, 1024, device=device)
b = torch.randn(1024, 1024, device=device)
c = a @ b
# MPS work is dispatched asynchronously; synchronize so any device-side failure
# is raised in this cell rather than in a later one.
if c.device.type == "mps":
    torch.mps.synchronize()
# Loose tolerances: GPU and CPU kernels may accumulate the 1024-term dot products
# in a different order, so bit-exact agreement is not expected.
torch.testing.assert_close(c.cpu(), a.cpu() @ b.cpu(), rtol=1e-3, atol=1e-2)
print(f"Matmul successful on {c.device}")

Matmul successful on mps:0


In [None]:
import time

def benchmark_matmul(size: int = 512, iters: int = 10, device: torch.device = torch.device("cpu")) -> tuple[float, float]:
    torch.manual_seed(0)
    a = torch.randn(size, size, device=device)
    b = torch.randn(size, size, device=device)

    # Warmup
    for _ in range(3):
        _ = a @ b
        if device.type == "mps":
            torch.mps.synchronize()

    start = time.perf_counter()
    for _ in range(iters):
        _ = a @ b
        if device.type == "mps":
            torch.mps.synchronize()
    end = time.perf_counter()

    total_seconds = end - start
    avg_seconds = total_seconds / iters
    return total_seconds, avg_seconds

# Workload is deliberately tiny so the whole CPU-vs-MPS comparison finishes in
# well under 5 seconds on typical Apple Silicon.
size, iters = 512, 10

cpu, mps = torch.device("cpu"), torch.device("mps")

print(f"Benchmark: {size}x{size} matmul, {iters} iterations")

# Identical workload on each backend: CPU first, then MPS.
cpu_total, cpu_avg = benchmark_matmul(device=cpu, size=size, iters=iters)
mps_total, mps_avg = benchmark_matmul(device=mps, size=size, iters=iters)

print(f"CPU  total: {cpu_total*1000:.1f} ms  | per-op: {cpu_avg*1000:.1f} ms")
print(f"MPS  total: {mps_total*1000:.1f} ms  | per-op: {mps_avg*1000:.1f} ms")

# Ratio of wall-clock totals (>1 means MPS was faster); a non-positive MPS
# timing is reported as infinite speedup instead of dividing by it.
if mps_total > 0:
    speedup = cpu_total / mps_total
else:
    speedup = float('inf')
print(f"Speedup (CPU/MPS): {speedup:.2f}x")
