In [1]:
import math, time, tracemalloc, statistics as stats
import torch
from torch.autograd.forward_ad import dual_level, make_dual, unpack_dual

In [2]:
def f(x1, x2):
    return x1 * x2 + torch.log(x1)

In [3]:
def grad_forward(x1, x2):

    with dual_level():

        x1d = make_dual(x1, torch.ones_like(x1))      # seed (1,0)
        x2d = make_dual(x2, torch.zeros_like(x2))
        y1 = f(x1d, x2d)
        i, dy_dx1 = unpack_dual(y1)                   # directional derivative w.r.t x1


    with dual_level():
        x1d = make_dual(x1, torch.zeros_like(x1))     # seed (0,1)
        x2d = make_dual(x2, torch.ones_like(x2))
        y2 = f(x1d, x2d)
        j, dy_dx2 = unpack_dual(y2)                   # directional derivative w.r.t x2


    return torch.stack([dy_dx1, dy_dx2])

In [4]:
def grad_reverse(x1, x2):
    x1 = x1.detach().requires_grad_(True)
    x2 = x2.detach().requires_grad_(True)
    y = f(x1, x2)
    g = torch.autograd.grad(y, (x1, x2), create_graph=False)
    return torch.stack(g)

In [6]:
def bench(fn, *args, repeats=2000, warmup=200):
    for i in range(warmup):
        i = fn(*args)

    
    times = []
    tracemalloc.start()
    start_snapshot = tracemalloc.take_snapshot()


    t0 = time.perf_counter()
    for i in range(repeats):
        i = fn(*args)
    t1 = time.perf_counter()

    end_snapshot = tracemalloc.take_snapshot()
    tracemalloc.stop()

    stats_diff = end_snapshot.compare_to(start_snapshot, 'lineno')
    peak_bytes = sum([s.size_diff for s in stats_diff if s.size_diff > 0])

    total_time = t1 - t0
    avg_time = (total_time / repeats) * 1e6
    return avg_time, peak_bytes




In [7]:
x1 = torch.tensor(2.0)  
x2 = torch.tensor(3.0)

t_forward, mem_forward = bench(grad_forward, x1, x2)
t_rev, mem_rev = bench(grad_reverse, x1, x2)

print("n=2 (scalar output)")
print(f"Forward-mode:  avg {t_forward:.2f} µs | peak Python heap ~ {mem_forward/1024:.1f} KiB")
print(f"Reverse-mode:  avg {t_rev:.2f} µs | peak Python heap ~ {mem_rev/1024:.1f} KiB")



n=2 (scalar output)
Forward-mode:  avg 68.86 µs | peak Python heap ~ 2.3 KiB
Reverse-mode:  avg 41.04 µs | peak Python heap ~ 210.8 KiB


In [9]:
def f_n(x):
    return (x[:-1] * x[1:]).sum() + torch.log(x).sum()

def grad_forward_n(x):
    grads = []
    for i in range(x.numel()):
        with dual_level():
            seed = torch.zeros_like(x)
            seed[i] = 1.0
            xd = make_dual(x, seed)
            y = f_n(xd)
            _, dy = unpack_dual(y)
            grads.append(dy)
    return torch.stack(grads)


def grad_reverse_n(x):
    x = x.detach().requires_grad_(True)
    y = f_n(x)
    (g,) = torch.autograd.grad(y, (x,), create_graph=False)
    return g

for n in [2, 8, 32, 128, 512,1024]:
    x = torch.full((n,), 2.0)  # positive for log
    tf, mf = bench(grad_forward_n, x, repeats=200, warmup=50)
    tr, mr = bench(grad_reverse_n, x, repeats=200, warmup=50)
    print(f"\nn={n}")
    print(f"Forward-mode:  avg {tf:.2f} µs | peak Python heap ~ {mf/1024:.1f} KiB")
    print(f"Reverse-mode:  avg {tr:.2f} µs | peak Python heap ~ {mr/1024:.1f} KiB")


n=2
Forward-mode:  avg 82.90 µs | peak Python heap ~ 0.7 KiB
Reverse-mode:  avg 61.19 µs | peak Python heap ~ 0.6 KiB

n=8
Forward-mode:  avg 311.47 µs | peak Python heap ~ 0.3 KiB
Reverse-mode:  avg 60.53 µs | peak Python heap ~ 0.6 KiB

n=32
Forward-mode:  avg 1237.85 µs | peak Python heap ~ 1.0 KiB
Reverse-mode:  avg 60.00 µs | peak Python heap ~ 0.6 KiB

n=128
Forward-mode:  avg 5084.80 µs | peak Python heap ~ 1.4 KiB
Reverse-mode:  avg 61.05 µs | peak Python heap ~ 0.6 KiB

n=512
Forward-mode:  avg 21197.46 µs | peak Python heap ~ 1.2 KiB
Reverse-mode:  avg 63.74 µs | peak Python heap ~ 0.6 KiB

n=1024
Forward-mode:  avg 43220.63 µs | peak Python heap ~ 1.6 KiB
Reverse-mode:  avg 64.92 µs | peak Python heap ~ 0.6 KiB
