In [2]:
# Install required CV dependencies for this environment using uv.
# The leading `!` hands the line to the system shell; `uv pip install -r` installs
# every package listed in the given requirements file.
# NOTE(review): the requirements path is absolute and machine-specific, so this
# cell only works on a checkout located at exactly that path — consider a
# repo-relative path.
!uv pip install -r /Users/amisra/dev/ERA-v4/requirements/cv.txt

[2mUsing Python 3.12.11 environment at: /Users/amisra/dev/ERA-v4/.venv[0m
[2mAudited [1m7 packages[0m [2min 6ms[0m[0m


In [3]:
import torch

In [4]:
# Fail fast when the MPS backend is unusable: later cells create tensors on the
# "mps" device and would otherwise fail further down with a less direct error.
assert torch.backends.mps.is_available(), "MPS acceleration not available on this Mac. Ensure PyTorch with MPS support is installed and Metal is enabled."
# Only reached when the assertion above passed, hence the hard-coded "True".
print("MPS acceleration available: True")

MPS acceleration available: True


In [5]:
# Device handle used by the matmul sanity check that follows.
# NOTE: later cells rebind `device` (first to CPU, then back to MPS), so its value
# depends on which cells have been executed.
device = torch.device("mps")
print(f"Using device: {device}")

Using device: mps


In [6]:
# Quick tensor sanity check on MPS
a = torch.randn(1024, 1024, device=device)
b = torch.randn(1024, 1024, device=device)
c = a @ b
print(f"Matmul successful on {c.device}")

Matmul successful on mps:0


In [None]:
import time


def benchmark_elementwise_mul(size: int = 10000, iters: int = 20, device: torch.device = torch.device("cpu")) -> tuple[float, float]:
    """Time repeated element-wise multiplication of two square float32 tensors.

    Two ``size x size`` tensors of uniform random values are allocated once on
    ``device``. Five untimed warmup multiplies run first, followed by ``iters``
    timed multiplies. On asynchronous backends (MPS, CUDA) the device is
    synchronized after every multiply so the measured wall-clock time covers
    kernel execution and not just the kernel launch.

    Args:
        size: Side length of the square operands.
        iters: Number of timed multiplies; must be at least 1.
        device: Device to run on. Either a ``torch.device`` or anything the
            ``torch.device`` constructor accepts (e.g. ``"cpu"``, ``"mps"``).

    Returns:
        ``(total_seconds, avg_seconds_per_iter)`` for the timed loop.

    Raises:
        ValueError: If ``iters`` is less than 1.

    Note:
        Calls ``torch.manual_seed(0)``, which reseeds the global default RNG.
    """
    # Reject bad iteration counts up front instead of failing with
    # ZeroDivisionError after the (potentially large) tensors were allocated.
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")

    # Accept strings such as "mps" as well as torch.device instances.
    device = torch.device(device)

    def _synchronize() -> None:
        """Block until all work queued on an asynchronous device has finished."""
        if device.type == "mps":
            torch.mps.synchronize()
        elif device.type == "cuda":
            torch.cuda.synchronize(device)

    torch.manual_seed(0)
    x = torch.rand((size, size), dtype=torch.float32, device=device)
    y = torch.rand((size, size), dtype=torch.float32, device=device)

    # Warmup a few runs to trigger kernel compilation/caching
    for _ in range(5):
        _ = x * y
        _synchronize()

    start = time.perf_counter()
    for _ in range(iters):
        _ = x * y
        # Synchronizing per iteration keeps each multiply inside the timed window.
        _synchronize()
    end = time.perf_counter()

    total_seconds = end - start
    avg_seconds = total_seconds / iters
    return total_seconds, avg_seconds

# Workload intended to complete in under ~5 seconds on typical Apple Silicon while
# still being large enough to show an MPS speedup: 10k x 10k is ~100M elements.
# Lower ITERS if the run takes too long in your environment.
SIZE = 10000
ITERS = 20

cpu = torch.device("cpu")
mps = torch.device("mps")

print(f"Benchmark (element-wise multiply): {SIZE}x{SIZE}, {ITERS} iterations")

# Same workload on both devices, CPU first and MPS second.
cpu_total, cpu_avg = benchmark_elementwise_mul(SIZE, ITERS, cpu)
mps_total, mps_avg = benchmark_elementwise_mul(SIZE, ITERS, mps)

# Timings are returned in seconds; report them in milliseconds.
print(f"CPU  total: {cpu_total*1000:.2f} ms  | per-op: {cpu_avg*1000:.2f} ms")
print(f"MPS  total: {mps_total*1000:.2f} ms  | per-op: {mps_avg*1000:.2f} ms")

# Ratio above 1 means MPS was faster; a zero MPS time is reported as infinity.
if mps_total > 0:
    speedup = cpu_total / mps_total
else:
    speedup = float("inf")
print(f"Speedup (CPU/MPS): {speedup:.2f}x")


Benchmark: 512x512 matmul, 10 iterations
CPU  total: 2.3 ms  | per-op: 0.2 ms
MPS  total: 6.3 ms  | per-op: 0.6 ms
Speedup (CPU/MPS): 0.36x


In [8]:
import torch

# Operands for the CPU timing cell: two 10000x10000 float32 tensors of uniform
# random values, placed on the CPU device.
device = torch.device('cpu')
x = torch.rand((10000, 10000), dtype=torch.float32).to(device)
y = torch.rand((10000, 10000), dtype=torch.float32).to(device)

In [9]:
%%timeit
# Times the element-wise product of the `x` and `y` tensors defined in the
# preceding cell, which places them on the CPU device.
x*y

31.9 ms ± 4.85 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [10]:
import torch

# Operands for the MPS timing cell: each 10000x10000 float32 tensor is generated
# on the CPU and then copied to the MPS device.
device = torch.device('mps')
x = torch.rand((10000, 10000), dtype=torch.float32).to(device)
y = torch.rand((10000, 10000), dtype=torch.float32).to(device)

In [11]:
%%timeit
x * y

1.56 μs ± 46.8 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
