# Benchmarks of different Newton-Schulz implementations for Muon

In [None]:
from typing import Callable

import torch
import torch.utils.benchmark as benchmark

In [15]:
# 72 random float32 matrices on the GPU: 12 of 768x3072, 12 of 3072x768 and
# 48 of 768x768 (presumably the MLP and attention weight shapes of a 12-layer,
# width-768 model -- assumption, not stated in this notebook).
gs = [
    torch.rand(rows, cols, dtype=torch.float32, device="cuda")
    for rows, cols, count in (
        (768, 4 * 768, 12),
        (4 * 768, 768, 12),
        (768, 768, 4 * 12),
    )
    for _ in range(count)
]

# Side streams used by the multiplexed benchmarks.
s1, s2, s3 = (torch.cuda.Stream() for _ in range(3))

In [None]:
def benchmark_in_us(f, *args, **kwargs):
    """Return the mean wall time of one ``f(*args, **kwargs)`` call, in MILLISECONDS.

    Despite the ``_us`` suffix, the result is milliseconds: ``Measurement.mean``
    is seconds per run, scaled here by 1e3.
    """
    timer = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals=dict(f=f, args=args, kwargs=kwargs),
    )
    measurement = timer.blocked_autorange()
    seconds_per_call = measurement.mean
    return seconds_per_call * 1e3


def single(f: Callable[[torch.Tensor, int], torch.Tensor], gs: list[torch.Tensor], steps: int):
    """Call ``f`` once, on the first matrix of ``gs`` only; the result is discarded."""
    first, *_ = gs
    f(first, steps)


def multiple(f: Callable[[torch.Tensor, int], torch.Tensor], gs: list[torch.Tensor], steps: int):
    """Call ``f(matrix, steps)`` for every matrix in ``gs``, in order, on the current stream.

    Results are discarded; only the side effect of running ``f`` matters here.
    """
    for matrix in gs:
        f(matrix, steps)


def multiplexed_2_streams(f: Callable[[torch.Tensor, int], torch.Tensor], gs: list[torch.Tensor], steps: int):
    """Call ``f(g, steps)`` for every ``g`` in ``gs``, spreading the calls over two CUDA streams.

    Matrices are dealt round-robin (``gs[0::2]`` on ``s1``, ``gs[1::2]`` on ``s2``),
    so each stream gets the same mix of shapes. The previous hard-coded split
    (``gs[:9]`` / ``gs[9:]``) put only the first 9 matrices on ``s1`` and all the
    rest on ``s2``, which left the streams heavily unbalanced.

    Uses the module-level streams ``s1`` and ``s2``. Results are discarded.
    """
    streams = (s1, s2)
    origin = torch.cuda.current_stream()
    for index, stream in enumerate(streams):
        # The inputs were produced on the current stream; the side stream must
        # wait for that work before reading them.
        stream.wait_stream(origin)
        with torch.cuda.stream(stream):
            for g in gs[index :: len(streams)]:
                f(g, steps)
    # Join back, so later work on the current stream is ordered after the side streams.
    for stream in streams:
        origin.wait_stream(stream)


def multiplexed_3_streams(f: Callable[[torch.Tensor, int], torch.Tensor], gs: list[torch.Tensor], steps: int):
    """Call ``f(g, steps)`` for every ``g`` in ``gs``, spreading the calls over three CUDA streams.

    Matrices are dealt round-robin (``gs[i::3]`` on the i-th stream), so each stream
    gets the same mix of shapes. The previous hard-coded split (``gs[:6]``,
    ``gs[6:12]``, ``gs[12:]``) gave two streams 6 matrices each and the third
    stream everything else, which left the streams heavily unbalanced.

    Uses the module-level streams ``s1``, ``s2`` and ``s3``. Results are discarded.
    """
    streams = (s1, s2, s3)
    origin = torch.cuda.current_stream()
    for index, stream in enumerate(streams):
        # The inputs were produced on the current stream; the side stream must
        # wait for that work before reading them.
        stream.wait_stream(origin)
        with torch.cuda.stream(stream):
            for g in gs[index :: len(streams)]:
                f(g, steps)
    # Join back, so later work on the current stream is ordered after the side streams.
    for stream in streams:
        origin.wait_stream(stream)


def run_benchmarks(f: Callable[[torch.Tensor, int], torch.Tensor]):
    """Time ``f`` under each dispatch strategy with 4 and 5 steps, and print the results.

    ``single`` is reported per call. The other strategies process all of ``gs``
    and are reported as the average time per matrix (total / ``len(gs)``).
    ``benchmark_in_us`` returns milliseconds, hence the ``ms`` unit in the output.
    """
    # Warm up on every matrix shape in ``gs`` and on both step counts. Warming up
    # only ``gs[0]`` (as before) left the other shapes to trigger torch.compile
    # recompilation inside the timed runs, which can skew blocked_autorange's
    # block-size estimate.
    for steps in (4, 5):
        benchmark_in_us(multiple, f, gs, steps)

    cases = (
        ("Single", single, 1),
        ("Multiple", multiple, len(gs)),
        ("Multiplexed 2 streams", multiplexed_2_streams, len(gs)),
        ("Multiplexed 3 streams", multiplexed_3_streams, len(gs)),
    )
    for label, runner, divisor in cases:
        for steps in (4, 5):
            runtime = benchmark_in_us(runner, f, gs, steps) / divisor
            print(f"{label}, {steps} steps: {runtime} ms")

## 5-step Muon

In [None]:
@torch.compile
def zeropower_via_newtonschulz5(G: torch.Tensor, n: int):
    """Reference quintic Newton-Schulz iteration (Muon's orthogonalization step).

    Compiled with default ``torch.compile`` settings. Each step computes
    ``A = X @ X.T``, ``B = A @ X`` and ``X <- a*X + b*B + c*A@B``.

    Args:
        G: 2-D matrix to process.
        n: number of iterations.

    Returns:
        bfloat16 tensor with the same shape as ``G``.
    """
    assert len(G.shape) == 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    # Normalize out of place: if G is already bfloat16, G.bfloat16() returns G
    # itself, and an in-place div_ would overwrite the caller's tensor.
    X = X / (X.norm() + 1e-7)
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T  # keep the Gram matrix X @ X.T at the smaller dimension
    for _ in range(n):
        A = X @ X.T
        B = A @ X
        X = a * X + b * B + c * A @ B
    if transposed:
        X = X.T
    return X


# Benchmark the reference implementation above, compiled with default torch.compile settings.
print("Original - naive compile")
run_benchmarks(zeropower_via_newtonschulz5)

In [None]:
# reduce-overhead is slower
# @torch.compile
@torch.compile(mode="max-autotune-no-cudagraphs")
def zeropower_via_newtonschulz5(G: torch.Tensor, n: int):
    assert len(G.shape) == 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    X.div_(X.norm() + 1e-7)
    if G.size(0) > G.size(1):
        X = X.T
    for _ in range(n):
        A = X @ X.T
        B = A @ X
        X = a * X + b * B + c * A @ B
    if G.size(0) > G.size(1):
        X = X.T
    return X


# Benchmark the same reference implementation, compiled with mode="max-autotune-no-cudagraphs".
print("Original")
run_benchmarks(zeropower_via_newtonschulz5)

In [None]:
# reduce-overhead is slower
# @torch.compile
@torch.compile(mode="max-autotune-no-cudagraphs")
def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int):
    a, b, c = (3.4445, -4.7750, 2.0315)
    n, m = G.size()
    X = G.bfloat16()
    I = torch.eye(min(n, m), dtype=X.dtype, device=X.device)
    X.div_(X.norm() + 1e-7)
    if n > m:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    if G.size(0) > G.size(1):
        X = X.T
    return X


# Benchmark the factored update X = a*X + (b*A + c*A@A) @ X defined above.
print("X = a * X + (b * A + c * A @ A) @ X")
run_benchmarks(zeropower_via_newtonschulz5)

In [None]:
# reduce-overhead is slower
# @torch.compile
@torch.compile(mode="max-autotune-no-cudagraphs")
def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int):
    a, b, c = (3.4445, -4.7750, 2.0315)
    n, m = G.size()
    X = G.bfloat16()
    I = torch.eye(min(n, m), dtype=X.dtype, device=X.device)
    X.div_(X.norm() + 1e-7)
    if n > m:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = (a * I + A @ (b * I + c * A)) @ X
    if G.size(0) > G.size(1):
        X = X.T
    return X


# Benchmark the polynomial-in-A form with an explicit identity matrix defined above.
print("X = (a * I + A @ (b * I + c * A)) @ X")
run_benchmarks(zeropower_via_newtonschulz5)

In [None]:
# reduce-overhead is slower
# @torch.compile
@torch.compile(mode="max-autotune-no-cudagraphs")
def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int):
    a, b, c = (3.4445, -4.7750, 2.0315)
    n, m = G.size()
    X = G.bfloat16()
    I = torch.eye(min(n, m), dtype=X.dtype, device=X.device)
    X.div_(X.norm() + 1e-7)
    if n > m:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        S = A @ (b * I + c * A)
        torch.diagonal(S).add_(a)
        X = S @ X
    if G.size(0) > G.size(1):
        X = X.T
    return X


# Benchmark the variant above that forms S = A @ (b*I + c*A), adds a on its diagonal, then applies S @ X.
print("w/ S")
run_benchmarks(zeropower_via_newtonschulz5)

## 4-step Muon

In [None]:
# reduce-overhead is slower
# @torch.compile
@torch.compile(mode="max-autotune-no-cudagraphs")
def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int):
    a, b, c = (3.4445, -4.7750, 2.0315)
    n, m = G.size()
    X = G.bfloat16()
    I = torch.eye(min(n, m), dtype=X.dtype, device=X.device)
    X.div_(X.norm() + 1e-7)
    if n > m:
        X = X.T
    for a, b, c in (
        (4.8969, -14.0610, 10.1415),
        (4.7285, -10.0664, 5.4487),
        (4.0968, -5.9557, 2.3200),
        (3.0319, -3.3993, 1.1814),
    ):
        A = X @ X.T
        B = A @ X
        X = a * X + b * B + c * A @ B
    if G.size(0) > G.size(1):
        X = X.T
    return X


# Benchmark the 4-step schedule in reference form. The variant ignores its steps
# argument, so the "5 steps" rows printed here also run 4 iterations.
print("Original - 4-step")
run_benchmarks(zeropower_via_newtonschulz5)

In [None]:
# reduce-overhead is slower
# @torch.compile
@torch.compile(mode="max-autotune-no-cudagraphs")
def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int):
    a, b, c = (3.4445, -4.7750, 2.0315)
    n, m = G.size()
    X = G.bfloat16()
    I = torch.eye(min(n, m), dtype=X.dtype, device=X.device)
    X.div_(X.norm() + 1e-7)
    if n > m:
        X = X.T
    for a, b, c in (
        (4.8969, -14.0610, 10.1415),
        (4.7285, -10.0664, 5.4487),
        (4.0968, -5.9557, 2.3200),
        (3.0319, -3.3993, 1.1814),
    ):
        A = X @ X.T
        X = (a * I + b * A @ (I + c/b * A)) @ X
    if G.size(0) > G.size(1):
        X = X.T
    return X


# Benchmark the 4-step schedule in identity/polynomial form. The variant ignores its
# steps argument, so the "5 steps" rows printed here also run 4 iterations.
print("X = (a * I + b * A @ (I + c/b * A)) @ X - 4-step")
run_benchmarks(zeropower_via_newtonschulz5)

In [None]:
# reduce-overhead is slower
# @torch.compile
@torch.compile(mode="max-autotune-no-cudagraphs")
def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int):
    n, m = G.size()
    X = G.bfloat16()
    I = torch.eye(min(n, m), dtype=X.dtype, device=X.device)
    X.div_(X.norm() + 1e-7)
    if n > m:
        X = X.T
    for a, b, c in (
        (4.8969, -14.0610, 10.1415),
        (4.7285, -10.0664, 5.4487),
        (4.0968, -5.9557, 2.3200),
        (3.0319, -3.3993, 1.1814),
    ):
        A = X @ X.T
        S = A @ (b * I + c * A)
        torch.diagonal(S).add_(a)
        X = S @ X
    if G.size(0) > G.size(1):
        X = X.T
    return X


# Benchmark the 4-step schedule in S-matrix form. The variant ignores its steps
# argument, so the "5 steps" rows printed here also run 4 iterations.
print("w/ S - 4-step")
run_benchmarks(zeropower_via_newtonschulz5)