Skip to content

Commit

Permalink
⚡️ add mlx_cpu backend
Browse files Browse the repository at this point in the history
  • Loading branch information
TristanBilot committed Jan 3, 2024
1 parent deda7a6 commit ed51428
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 12 deletions.
1 change: 0 additions & 1 deletion mlx_benchmark/operations/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import torch.nn.functional as F

from base_benchmark import BaseBenchmark
from utils import load_mnist


class Conv2d(BaseBenchmark):
Expand Down
16 changes: 14 additions & 2 deletions mlx_benchmark/run_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,18 @@
from tqdm import tqdm

from config import USE_MLX

if USE_MLX:
import mlx.core as mx
import mlx.nn as mx_nn

from utils import print_benchmark
from operations import *


def run_processes(operations, args):
"""
Runs all operations (i.e. operations) in serial, on separate processes.
Runs all operations in serial, on separate processes.
Using processes avoids exploding memory within the main process during the bench.
"""
all_times = defaultdict(dict)
Expand Down Expand Up @@ -54,8 +59,15 @@ def run(op, args, queue=None):

# MLX benchmark.
if args.include_mlx:
# GPU
mx.set_default_device(mx.gpu)
mlx_time = op.run(framework="mlx")
times[op_name]["mlx_gpu"] = mlx_time

# CPU
mx.set_default_device(mx.cpu)
mlx_time = op.run(framework="mlx")
times[op_name]["mlx"] = mlx_time
times[op_name]["mlx_cpu"] = mlx_time

# CPU PyTorch benchmarks.
if args.include_cpu:
Expand Down
23 changes: 14 additions & 9 deletions mlx_benchmark/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,28 +32,33 @@ def print_benchmark(times, args, reduce_mean=False):

# Column headers
headers = []
if args.include_cpu:
headers.append("cpu")
if args.include_mlx:
headers.append("mlx_gpu")
headers.append("mlx_cpu")
if args.include_mps:
headers.append("mps")
if args.include_mlx:
headers.append("mlx")
if args.include_cpu:
headers.append("cpu")
if args.include_cuda:
headers.append("cuda")

if args.include_mps and args.include_mlx:
headers.append("mps/mlx speedup (%)")
h = "mps/mlx_gpu speedup"
headers.append(h)
for k, v in times.items():
v["mps/mlx speedup (%)"] = (v["mps"] / v["mlx"] - 1) * 100
v[h] = (v["mps"] / v["mlx_gpu"] - 1) + 1

if args.include_cpu and args.include_mlx:
headers.append("cpu/mlx speedup (%)")
h = "mlx_cpu/mlx_gpu speedup"
headers.append(h)
for k, v in times.items():
v["cpu/mlx speedup (%)"] = (v["cpu"] / v["mlx"] - 1) * 100
v[h] = (v["mlx_cpu"] / v["mlx_gpu"] - 1) + 1

max_name_length = max(len(name) for name in times.keys())

# Formatting the header row
header_row = (
"| Layer" + " " * (max_name_length - 5) + " | " + " | ".join(headers) + " |"
"| Operation" + " " * (max_name_length - 5) + " | " + " | ".join(headers) + " |"
)
header_line_parts = ["-" * (max_name_length + 6)] + [
"-" * max(6, len(header)) for header in headers
Expand Down

0 comments on commit ed51428

Please sign in to comment.