Performance test

In [None]:
import torch
from tqdm.auto import tqdm

# Benchmark: forward-Euler integration of dx/dt = -x + phi(W x) on the GPU.
# The iteration rate shown by the tqdm bar is presumably the figure of interest.
device = "cuda"
n = 10000
dt = 1e-3
phi = torch.tanh

# Dense n x n Gaussian coupling matrix with entries scaled by 0.5 / sqrt(n),
# followed by a standard-normal initial state of length n.
W = 0.5 * torch.randn(n, n, device=device) / n**0.5
x = torch.randn(n, device=device)

pbar = tqdm(range(10000))
for i in pbar:
    # One Euler step of size dt, applied to the state in place.
    x += (-x + phi(W @ x)) * dt
    if i % 1000 != 0:
        continue
    # Every 1000th step: show the mean and standard deviation of the state.
    pbar.set_description(f"{x.mean():.2f}, {x.std():.2f}")

Empirical DNNs

In [None]:
# Optionally drop the 10000 x 10000 matrix from the performance test first:
# while `W` is still referenced its memory stays allocated, so the cache
# flush below cannot release it.
# del W
# Hand unused blocks cached by PyTorch's CUDA allocator back to the driver.
torch.cuda.empty_cache()

The bottleneck is generating Lévy samples with SciPy. Use [torchlevy](https://github.com/KU-LIM-Lab/torchlevy) instead, which reimplements the same algorithm using torch primitives.

The speedup is massive (60x), but torchlevy occasionally produces NaNs. The speedup holds even against a pure NumPy implementation (45x). Even on the CPU, torchlevy is about 3x faster than a pure NumPy implementation.

In [None]:
# One-time setup: uncomment to install torchlevy from GitHub. `%pip` (rather
# than `!pip`) installs into the environment of the running kernel.
# NOTE(review): this URL names the UNIST-LIM-Lab organisation, while the
# markdown link above names KU-LIM-Lab -- confirm both resolve to the same repo.
# %pip install git+https://github.com/UNIST-LIM-Lab/torchlevy.git

In [None]:
import RMT

from tqdm.auto import tqdm
import numpy as np
import torch

from functools import partial
import multiprocessing as mp

# Worker processes are spawned rather than forked (CUDA cannot be used in a
# forked child); force=True overrides any start method set earlier in this
# kernel session.
mp.set_start_method("spawn", force=True)

# Parameter grid: 20 values of alpha in [1, 2] and 20 weight scales in [0.1, 2].
alphas = np.linspace(1, 2, 20)
sigmas_W = np.linspace(0.1, 2, 20)

# Freeze every keyword argument of RMT.worker; the one remaining positional
# argument is filled with a single entry of sigmas_W per task.
sweep_one_sigma = partial(
    RMT.worker,
    alphas=alphas,
    sigma_b=0,
    phi=torch.tanh,
    width=1000,
    depth=50,
)
# Executed once in each worker at start-up: makes "cuda" its default torch device.
use_cuda_by_default = partial(torch.set_default_device, "cuda")

with mp.Pool(processes=7, initializer=use_cuda_by_default) as pool:
    # imap yields in submission order, so results[i] belongs to sigmas_W[i];
    # tqdm only wraps the iterator to report progress.
    results = list(tqdm(pool.imap(sweep_one_sigma, sigmas_W), total=len(sigmas_W)))
import matplotlib.pyplot as plt

# Heat map of the sweep: one row per sigma_W (results[i] <-> sigmas_W[i]);
# columns are assumed to follow `alphas` -- TODO confirm against RMT.worker.
#
# imshow's `extent` positions the outer *edges* of the image, not the pixel
# centres. Padding each axis by half a grid step centres every cell on the
# parameter value it was computed for, and deriving the limits from the grids
# keeps the axes right if the sweep ranges are changed.
# max(..., 1) avoids a division by zero for a single-point grid.
half_alpha_step = (alphas[-1] - alphas[0]) / (2 * max(len(alphas) - 1, 1))
half_sigma_step = (sigmas_W[-1] - sigmas_W[0]) / (2 * max(len(sigmas_W) - 1, 1))

plt.imshow(
    results,
    aspect="auto",
    extent=(
        alphas[0] - half_alpha_step,
        alphas[-1] + half_alpha_step,
        sigmas_W[0] - half_sigma_step,
        sigmas_W[-1] + half_sigma_step,
    ),
    origin="lower",  # row 0 (smallest sigma_W) at the bottom
)
plt.colorbar(label="Mean log singular value")
plt.xlabel(r"$\alpha$")
plt.ylabel(r"$\sigma_W$")
plt.title("Mean log singular value of Jacobian of one layer")
plt.show()

In [None]:
# optimal num of func calls per subjob
# = total num of func calls / max num of concurrent subjobs on the queue
alpha100_grid = range(100, 201, 5)  # presumably 100 * alpha: 100, 105, ..., 200
g100_grid = range(1, 301, 5)  # presumably 100 * g: 1, 6, ..., 296
seed_grid = range(50)
max_concurrent_subjobs = 300

# One function call per (alpha, g, seed) combination; multiplying the grid
# sizes counts them without materialising the combinations.
n_func_calls = len(alpha100_grid) * len(g100_grid) * len(seed_grid)
calls_per_subjob = n_func_calls / max_concurrent_subjobs
calls_per_subjob

210.0

In [None]:
# walltime in minutes for each subjob
# = num of func calls per subjob * seconds per func call / 60 seconds per minute
func_calls_per_subjob = 210  # value displayed by the previous cell
seconds_per_func_call = 5
walltime_minutes = func_calls_per_subjob * seconds_per_func_call / 60
walltime_minutes

17.5