# torch benchmark

Reference: https://pytorch.org/tutorials/recipes/recipes/benchmark.html

## Example 1: Benchmarking a single operation

In [None]:
def diagmul(mat, diag):
    """Multiply ``mat`` by ``diag`` with the ``*`` operator and return the product.

    The notebook uses this as the cheap counterpart of
    ``torch.matmul(A, torch.diag(d))``: with a 2-D tensor ``mat`` and a 1-D tensor
    ``diag``, the elementwise product relies on broadcasting instead of building
    the dense diagonal matrix.
    """
    scaled = mat * diag
    return scaled

In [None]:
import torch.utils.benchmark as benchmark
import torch

# Inputs: a dense 256x256 matrix and the 256 entries of a diagonal matrix.
A = torch.randn(256, 256)
d = torch.randn(256)
num_threads = torch.get_num_threads()

# Baseline: materialize the dense diagonal matrix and run a full matrix product.
t0 = benchmark.Timer(
    stmt="torch.matmul(A, torch.diag(d))", globals={"d": d, "A": A}, num_threads=num_threads
).blocked_autorange(min_run_time=1)

# Broadcast variant. Arguments are passed in diagmul's (mat, diag) order so the
# timed call reads the same way as the function's signature.
t1 = benchmark.Timer(
    stmt="diagmul(A, d)",
    setup="from __main__ import diagmul",
    globals={"d": d, "A": A},
    num_threads=num_threads,
).blocked_autorange(min_run_time=1)

In [None]:
# Display the shape of the elementwise product of A and d.
torch.mul(A, d).shape

In [None]:
# Print each measurement next to the name of the variant it belongs to.
for title, measurement in (("Matmul: ", t0), ("Diagmul: ", t1)):
    print(title, measurement)

In [None]:
# Time the method form of the elementwise product (Tensor.mul).
t2 = benchmark.Timer(
    stmt="d.mul(A)", globals={"d": d, "A": A}, num_threads=num_threads
).blocked_autorange(min_run_time=1)
# The label names the statement that was actually timed; this is not a matmul.
print("d.mul(A): ", t2)

In [None]:
from itertools import product

results = []
sizes = [1, 64, 1024]
# label and sub_label pick the row of the comparison table;
# description picks the column.
label = "Diagonal matrix multiplication"
# One entry per implementation: (timed statement, setup code, column name).
# "pass" is the setup benchmark.Timer uses when none is given.
variants = [
    ("torch.matmul(A, torch.diag(d))", "pass", "mm"),
    ("diagmul(A,d)", "from __main__ import diagmul", "vm"),
]
for m, n in product(sizes, sizes):
    sub_label = f"[{m}, {n}]"
    A = torch.rand((m, n))
    d = torch.rand(n)
    for num_threads in [1, 4, 8]:
        for stmt, setup, description in variants:
            timer = benchmark.Timer(
                stmt=stmt,
                setup=setup,
                globals={"d": d, "A": A},
                num_threads=num_threads,
                label=label,
                sub_label=sub_label,
                description=description,
            )
            results.append(timer.blocked_autorange(min_run_time=1))

# Collect every measurement into one table and print it.
compare = benchmark.Compare(results)
compare.print()

## layer power

In [None]:
import torch

from src.utils import eqprop_utils


def layer_power_1(G, in_v, out_v):
    """Calculate sum_{i,j} G_{ij}*(in_v_i - out_v_j)^2 from explicit differences.

    The squared output of ``eqprop_utils.deltaV`` is weighted by ``G.mT`` and
    reduced over dims 1 and 2, leaving one value per index of dim 0.

    NOTE(review): assumes ``deltaV`` returns the pairwise differences laid out to
    line up with ``G.mT`` — confirm against that helper.
    """
    squared_diff = eqprop_utils.deltaV(in_v, out_v).pow(2)
    weighted = G.mT * squared_diff
    return weighted.sum(dim=(1, 2))


def layer_power_2(G, in_v, out_v):
    """Calculate sum_{i,j} G_{ij}*(in_v_i - out_v_j)^2 by expanding the square.

    The sum is split into three batched products: ``in_v^2 . G``,
    ``G . out_v^2`` and the cross term ``-2 * in_v . G . out_v``, so the
    pairwise difference tensor is never formed.
    """
    # Singleton dims turn the vectors into batches of row / column matrices.
    row = in_v.unsqueeze(1)
    col = out_v.unsqueeze(2)

    in_sq_term = torch.bmm(row.pow(2), G).sum(dim=(1, 2))
    out_sq_term = torch.bmm(G, col.pow(2)).sum(dim=(1, 2))
    cross_term = (row @ G @ col).squeeze()
    return in_sq_term + out_sq_term - 2 * cross_term


# Random inputs: G has the full shape, in_v uses its first two sizes and
# out_v its first and last.
shape = (64, 512, 256)
batch, n_in, n_out = shape
G = torch.randn(batch, n_in, n_out)
in_v = torch.randn(batch, n_in)
out_v = torch.randn(batch, n_out)

In [None]:
# The two implementations should agree up to floating-point round-off.
power_direct = layer_power_1(G, in_v, out_v)
power_expanded = layer_power_2(G, in_v, out_v)
torch.allclose(power_direct, power_expanded, atol=1e-7)

In [None]:
# Tensors made visible to the timed statements.
bench_tensors = {"G": G, "in_v": in_v, "out_v": out_v}

# Variant 1: reduce the bmm result over both trailing dims in one sum.
t3 = benchmark.Timer(
    stmt="torch.bmm(in_v.unsqueeze(1).pow(2), G).sum(dim=(1,2))",
    globals=dict(bench_tensors),
    num_threads=1,
).blocked_autorange(min_run_time=1)

# Variant 2: squeeze the bmm result first, then sum over the remaining dim.
# NOTE(review): this statement multiplies G by out_v while the one above
# multiplies in_v by G, so the timings differ by more than the reduction style.
t4 = benchmark.Timer(
    stmt="torch.bmm(G, out_v.unsqueeze(2).pow(2)).squeeze().sum(dim=(1))",
    globals=dict(bench_tensors),
    num_threads=1,
).blocked_autorange(min_run_time=1)

for title, measurement in (("sum 2d: ", t3), ("squeeze and sum: ", t4)):
    print(title, measurement)

In [None]:
import torch.utils.benchmark as benchmark

num_threads = torch.get_num_threads()
# Tensors handed to both timed calls.
power_inputs = {"in_v": in_v, "out_v": out_v, "G": G}

# Explicit-difference implementation.
t0 = benchmark.Timer(
    stmt="layer_power_1(G, in_v, out_v)",
    setup="from __main__ import layer_power_1",
    globals=dict(power_inputs),
    num_threads=num_threads,
).blocked_autorange(min_run_time=1)

# Expanded-square implementation.
t1 = benchmark.Timer(
    stmt="layer_power_2(G, in_v, out_v)",
    setup="from __main__ import layer_power_2",
    globals=dict(power_inputs),
    num_threads=num_threads,
).blocked_autorange(min_run_time=1)

for title, measurement in (("layer_power_1: ", t0), ("layer_power_2: ", t1)):
    print(title, measurement)

In [None]:
from itertools import product

results = []
sizes = [64, 256, 1024]
# label and sub_label pick the row of the comparison table;
# description picks the column.
label = "Layer power"
# One entry per implementation: (timed statement, setup code, column name).
variants = [
    ("layer_power_1(G, in_v, out_v)", "from __main__ import layer_power_1", "deltaV"),
    ("layer_power_2(G, in_v, out_v)", "from __main__ import layer_power_2", "sq+sq-prod"),
]
for m, n in product(sizes, sizes):
    sub_label = f"[{m}, {n}]"
    shape = (64, m, n)
    G = torch.randn(shape)
    in_v = torch.randn(shape[0], shape[1])
    out_v = torch.randn(shape[0], shape[2])
    for num_threads in [1, 4, 8]:
        for stmt, setup, description in variants:
            timer = benchmark.Timer(
                stmt=stmt,
                setup=setup,
                globals={"in_v": in_v, "out_v": out_v, "G": G},
                num_threads=num_threads,
                label=label,
                sub_label=sub_label,
                description=description,
            )
            results.append(timer.blocked_autorange(min_run_time=1))

# Collect every measurement into one table and print it.
compare = benchmark.Compare(results)
compare.print()

## Add diag

In [None]:
# update diag elements in batched matrix
import torch

# A batch of two 3x3 all-ones matrices and one length-3 all-ones vector per matrix.
A = torch.ones((2, 3, 3))
v = torch.ones(A.shape[:2])

In [None]:
%%timeit
# Timed statement: take the diagonal over dims 1 and 2 of the batched matrix A,
# slice it with [:] and add v to it with +=.
# NOTE(review): presumably the diagonal is a view into A, so A itself is modified and
# every timeit repetition keeps accumulating into it — confirm if A's values matter later.
A.diagonal(dim1=1, dim2=2)[:] += v

## eqprop grad

In [None]:
import torch

# from src.utils import eqprop_utils

In [None]:
def grad1(in_V_free, out_V_free, in_V, out_V):
    """Return the dim-0 mean of squared deltaV for (in_V, out_V) minus that for the free pair.

    Both pairs go through ``eqprop_utils.deltaV``; the results are squared,
    averaged over dim 0 and subtracted.

    NOTE(review): presumably dim 0 is the batch and ``deltaV`` yields pairwise
    differences — confirm against ``eqprop_utils``.
    """
    free_sq = eqprop_utils.deltaV(in_V_free, out_V_free).pow(2)
    nudge_sq = eqprop_utils.deltaV(in_V, out_V).pow(2)
    return nudge_sq.mean(dim=0) - free_sq.mean(dim=0)


def grad2(
    in_V_free: torch.Tensor, out_V_free: torch.Tensor, in_V: torch.Tensor, out_V: torch.Tensor
) -> torch.Tensor:
    """Batch mean of (in_V_i - out_V_j)^2 - (in_V_free_i - out_V_free_j)^2, square expanded.

    Each squared difference is expanded into ``in^2 + out^2 - 2 * in * out`` so only
    per-sample outer products and per-vector squares are needed.

    Args:
        in_V_free: 2-D tensor (batch, n_in) for the free pair.
        out_V_free: 2-D tensor (batch, n_out) for the free pair.
        in_V: 2-D tensor (batch, n_in) for the second pair.
        out_V: 2-D tensor (batch, n_out) for the second pair.

    Returns:
        Tensor of shape (n_out, n_in).
    """
    # Per-sample outer products out_j * in_i, shape (batch, n_out, n_in).
    # They are deliberately not squeezed: a dimensionless squeeze() would drop a
    # size-1 batch or feature dim and make mean(dim=0) average over the wrong axis.
    free_outer = torch.bmm(out_V_free.unsqueeze(2), in_V_free.unsqueeze(1))
    nudge_outer = torch.bmm(out_V.unsqueeze(2), in_V.unsqueeze(1))

    # Cross terms: -2 * in * out of the second pair minus the same for the free pair.
    res = 2 * (free_outer.mean(dim=0) - nudge_outer.mean(dim=0))
    # in^2 terms have shape (n_in,) and broadcast across rows.
    res += in_V.pow(2).mean(dim=0) - in_V_free.pow(2).mean(dim=0)
    # out^2 terms are reshaped to (n_out, 1) so they broadcast across columns.
    res += (out_V.pow(2).mean(dim=0) - out_V_free.pow(2).mean(dim=0)).unsqueeze(1)
    return res

In [None]:
# TorchScript-compiled version of grad2; the timing cell further down measures
# grad2 and grad3 side by side.
grad3 = torch.jit.script(grad2)

In [None]:
# Random inputs for the gradient functions: the two "in" tensors share one shape
# and the two "out" tensors share another, with the same leading size.
shape_in = (128, 1024)
shape_out = (128, 512)
# Drawn in the order in_V_free, out_V_free, in_V, out_V.
in_V_free, out_V_free, in_V, out_V = (
    torch.randn(dims) for dims in (shape_in, shape_out, shape_in, shape_out)
)

# torch.allclose(grad1(in_V_free, out_V_free, in_V, out_V), grad2(in_V_free, out_V_free, in_V, out_V), atol=1e-6)

In [None]:
import torch.utils.benchmark as benchmark

num_threads = torch.get_num_threads()

# Eager Python implementation.
t1 = benchmark.Timer(
    stmt="grad2(in_V_free, out_V_free, in_V, out_V)",
    setup="from __main__ import grad2",
    globals={"in_V_free": in_V_free, "out_V_free": out_V_free, "in_V": in_V, "out_V": out_V},
    num_threads=num_threads,
).blocked_autorange(min_run_time=1)

# The same function after torch.jit.script.
t0 = benchmark.Timer(
    stmt="grad3(in_V_free, out_V_free, in_V, out_V)",
    setup="from __main__ import grad3",
    globals={"in_V_free": in_V_free, "out_V_free": out_V_free, "in_V": in_V, "out_V": out_V},
    num_threads=num_threads,
).blocked_autorange(min_run_time=1)


# Labels name the functions that were actually timed (grad3 / grad2).
print("grad3 (scripted): ", t0)
print("grad2 (eager): ", t1)

## OTS

## scalar division

In [None]:
import torch
import torch.utils.benchmark as benchmark

# Operands: a 1024x256 matrix and a one-element tensor used as the divisor.
A = torch.randn(1024, 256)
beta = torch.rand(1)
num_threads = torch.get_num_threads()
operands = {"beta": beta, "A": A}

# Direct division by beta.
t0 = benchmark.Timer(
    stmt="A/beta", globals=dict(operands), num_threads=num_threads
).blocked_autorange(min_run_time=1)

# Multiplication by the reciprocal of beta.
t1 = benchmark.Timer(
    stmt="A*(1/beta)",
    globals=dict(operands),
    num_threads=num_threads,
).blocked_autorange(min_run_time=1)

for title, measurement in (("A/beta: ", t0), ("A*(1/beta): ", t1)):
    print(title, measurement)