Performance test

In [None]:
import torch
from tqdm.auto import tqdm

device = "cuda"
n = 10000
# Random coupling matrix (entry std 0.5 / sqrt(n)) and a random initial state.
W = 0.5 * torch.randn(n, n, device=device) / n**0.5
x = torch.randn(n, device=device)
phi = torch.tanh
dt = 1e-3

# Forward-Euler steps of dx/dt = -x + phi(W x), updating x in place.
# Every 1000 steps the progress bar shows the state's mean and std.
pbar = tqdm(range(10000))
for i in pbar:
    x.add_((phi(W @ x) - x) * dt)
    if not i % 1000:
        pbar.set_description(f"{x.mean():.2f}, {x.std():.2f}")

Empirical DNNs

In [None]:
# Release cached-but-unused CUDA memory before the next section.
# NOTE: while the `del` below stays commented out, `W` (10000 x 10000; ~400 MB
# assuming the default float32 dtype) is still referenced, so empty_cache()
# can only return blocks that are no longer in use, not W's storage.
# del W
torch.cuda.empty_cache()

The bottleneck is generating Lévy-stable samples with SciPy. Use [torchlevy](https://github.com/KU-LIM-Lab/torchlevy) instead, which reimplements the same algorithm with torch primitives.

The speedup over SciPy is large (about 60×), but torchlevy occasionally generates NaNs. The speedup holds even against a pure NumPy implementation (about 45×), and even on CPU torchlevy is about 3× faster than pure NumPy.

In [None]:
# %pip install git+https://github.com/UNIST-LIM-Lab/torchlevy.git  # NOTE: the markdown above links KU-LIM-Lab/torchlevy; confirm which repository is canonical

In [None]:
import RMT

from tqdm.auto import tqdm
import numpy as np
import torch

from functools import partial
import multiprocessing as mp

mp.set_start_method("spawn", force=True)

# Parameter grid. One pool task per sigma_W (so `results` has one entry per
# sigma_W); each task receives the full `alphas` grid.
alphas = np.linspace(1, 2, 20)
sigmas_W = np.linspace(0.1, 2, 20)
N_WORKERS = 7  # worker processes; each sets its default device to CUDA


with mp.Pool(N_WORKERS, partial(torch.set_default_device, "cuda")) as pool:
    results = list(
        tqdm(
            pool.imap(
                partial(
                    RMT.worker,
                    alphas=alphas,
                    sigma_b=0,
                    phi=torch.tanh,
                    width=1000,
                    depth=50,
                ),
                sigmas_W,
            ),
            total=len(sigmas_W),
        )
    )
import matplotlib.pyplot as plt

# Take the axis extent from the grids themselves so the tick labels stay
# correct if the grids above are changed.
plt.imshow(
    results,
    aspect="auto",
    extent=(alphas[0], alphas[-1], sigmas_W[0], sigmas_W[-1]),
    origin="lower",
)
plt.colorbar(label="Mean log singular value")
plt.xlabel(r"$\alpha$")
plt.ylabel(r"$\sigma_W$")
plt.title("Mean log singular value of Jacobian of one layer")
plt.show()

In [None]:
# Optimal number of function calls per subjob
#   = total number of function calls / max number of concurrent subjobs on the queue.
# Grid: alpha*100 in range(100, 201, 5), g*100 in range(1, 301, 5), 50 seeds each.
# Count the grid from range lengths instead of materialising a 63,000-element list.
n_alphas = len(range(100, 201, 5))
n_gs = len(range(1, 301, 5))
n_seeds = len(range(50))
MAX_CONCURRENT_SUBJOBS = 300
calls_per_subjob = n_alphas * n_gs * n_seeds / MAX_CONCURRENT_SUBJOBS
calls_per_subjob

In [None]:
# Walltime in minutes for each subjob
#   = calls per subjob * seconds per call / 60 seconds per minute.
CALLS_PER_SUBJOB = 210  # result of the calls-per-subjob estimate
SECONDS_PER_CALL = 5  # estimated seconds per worker call
walltime_minutes = CALLS_PER_SUBJOB * SECONDS_PER_CALL / 60
walltime_minutes

In [None]:
import os
from pathlib import Path

import numpy as np

# Location of the sweep results. Override with the RMT_SWEEP_DATA environment
# variable instead of editing the hard-coded cluster path.
DATA_PATH = Path(
    os.environ.get("RMT_SWEEP_DATA", "/import/silo3/wardak/width1000_depth50.npz")
)
# For an .npz archive np.load returns a lazy NpzFile: each array is read from
# the archive when it is accessed, so later cells should iterate it as few
# times as possible.
data = np.load(DATA_PATH)
data

In [None]:
from tqdm.auto import tqdm
import re


def to_tup(
    s, pattern=re.compile(r"alpha(?P<alpha>[\d.]+)_g(?P<g>[\d.]+)_seed(?P<seed>\d+)")
):
    """Parse a key of the form ``alpha<A>_g<G>_seed<S>`` into ``(A, G, S)``.

    Integer fields (e.g. ``"alpha150_g100_seed3"`` -> ``(150, 100, 3)``) are
    returned as ``int``. ``pattern`` also accepts decimal points in the alpha
    and g fields; such fields are returned as ``float`` rather than crashing
    ``int()``.

    Parameters
    ----------
    s : str
        The key to parse (matched from the start of the string).
    pattern : re.Pattern
        Regex whose groups, in order, are the fields to extract.

    Returns
    -------
    tuple of int/float

    Raises
    ------
    ValueError
        If ``s`` does not match ``pattern``.
    """
    match = pattern.match(s)
    if match is None:
        raise ValueError(f"key {s!r} does not match pattern {pattern.pattern!r}")
    return tuple(float(field) if "." in field else int(field) for field in match.groups())


# Mean and std of each run's log singular values, skipping -inf entries
# (presumably exactly-zero singular values). A single pass over `data` so each
# array is read from the .npz archive only once.
means, stds = {}, {}
for key, log_svals in tqdm(data.items()):
    finite = log_svals[~np.isneginf(log_svals)]
    params = to_tup(key)
    means[params] = finite.mean()
    stds[params] = finite.std()

In [None]:
import scipy.stats
# Skewness and (SciPy-default, i.e. Fisher/excess) kurtosis of each run's log
# singular values, skipping -inf entries. One pass over `data` so each array is
# read from the .npz archive only once.
skewnesses, kurtoses = {}, {}
for key, log_svals in tqdm(data.items()):
    finite = log_svals[~np.isneginf(log_svals)]
    params = to_tup(key)
    skewnesses[params] = scipy.stats.skew(finite)
    kurtoses[params] = scipy.stats.kurtosis(finite)

In [None]:
# Mean and std of the singular values themselves (not their logs). Since
# exp(-inf) = 0, zero singular values are kept here.
# Formula of interest: sqrt( ( CV(sing val)^2 + 1 )^L - 1 )
# One pass over `data`, computing exp() once per array.
means_notlog, stds_notlog = {}, {}
for key, log_svals in tqdm(data.items()):
    svals = np.exp(log_svals)
    params = to_tup(key)
    means_notlog[params] = svals.mean()
    stds_notlog[params] = svals.std()

In [None]:
import matplotlib.pyplot as plt

from itertools import groupby


def groupby_mean(data_dict, sort_key=lambda x: x[0][:-1]):
    """Average the values of ``data_dict`` within groups of items sharing a key.

    ``sort_key`` is applied to each ``(key, value)`` item. The default drops the
    last element of a tuple key, so keys like ``(alpha, g, seed)`` are averaged
    over seeds.

    Returns a dict mapping each group key to ``np.mean`` of its values, in
    ascending group-key order.
    """
    ordered_items = sorted(data_dict.items(), key=sort_key)
    averaged = {}
    for group_key, members in groupby(ordered_items, sort_key):
        averaged[group_key] = np.mean([value for _, value in members])
    return averaged


def _stats_to_xyz(stats_dict):
    """Seed-average ``stats_dict`` and return an array of rows (k[0], k[1], value)."""
    return np.array([(k[0], k[1], v) for k, v in groupby_mean(stats_dict).items()])


for stats, name in [
    (means, "mean"),
    (stds, "std"),
    (None, "CV"),
    (skewnesses, "skewness"),
    (kurtoses, "kurtosis"),
    (None, "dist_CV"),
]:
    if name == "CV":
        # CV of the log singular value: seed-averaged std / |seed-averaged mean|.
        # Rows align because both dicts share keys and groupby_mean sorts them.
        xyz_means = _stats_to_xyz(means)
        xyz_stds = _stats_to_xyz(stds)
        xyz = xyz_means.copy()
        xyz[:, 2] = xyz_stds[:, 2] / abs(xyz_means[:, 2])
    elif name == "dist_CV":
        # CV of the (non-log) singular values, clipped at 2 to keep the colour
        # scale readable.
        # Depth-L variant (not used): sqrt(((std / |mean|)**2 + 1)**L - 1)
        xyz_means = _stats_to_xyz(means_notlog)
        xyz_stds = _stats_to_xyz(stds_notlog)
        xyz = xyz_means.copy()
        xyz[:, 2] = np.clip(xyz_stds[:, 2] / abs(xyz_means[:, 2]), None, 2)
    else:
        xyz = _stats_to_xyz(stats)

    # Keys store alpha and g scaled by 100 (e.g. alpha150 -> 1.5).
    xyz[:, 0] /= 100
    xyz[:, 1] /= 100

    # Built from `xyz` so it exists on every iteration (including the first,
    # before any xyz_means is defined). Narrow it to restrict the plotted
    # region, e.g. `xyz[:, 0] > 1.9` to keep only alpha > 1.9.
    mask = np.ones(len(xyz), dtype=bool)

    plt.tricontourf(xyz[mask, 0], xyz[mask, 1], xyz[mask, 2], levels=200)
    if name == "dist_CV":
        plt.colorbar()
    else:
        plt.colorbar(label=f"{name} of log singular value")
    if name == "CV":
        plt.tricontour(
            xyz[mask, 0],
            xyz[mask, 1],
            xyz[mask, 2],
            levels=[2, 2.05, 2.1],
            colors=["black", "blue", "brown"],
        )
    plt.xlabel(r"$\alpha$")
    plt.ylabel(r"$\sigma_W$")
    if name == "dist_CV":
        plt.title("Log-estimate of the CV of pairwise distances")
    else:
        plt.title(f"{name} of log singular value of Jacobian of one layer")
    plt.show()

Singular values

In [None]:
import torch
# Run this single-process check on CPU. This is set before `import RMT`, which
# matters if RMT creates tensors at import time.
torch.set_default_device("cpu")
from tqdm.auto import tqdm
import RMT
import matplotlib.pyplot as plt


# 100 evaluation points on [0, 3].
sing_vals = torch.linspace(0, 3, 100)
# NOTE(review): positional arguments, presumably (alpha, chi_samples, sing_vals,
# num_steps) matching the keyword names used with this function in the pool
# sweep cell — TODO confirm against RMT.singular_value_pdf.
pdfs = RMT.singular_value_pdf(1.5, torch.ones(10000), sing_vals, 10000)

plt.plot(sing_vals.cpu(), pdfs.cpu())

In [None]:
import torch
from tqdm.auto import tqdm
import multiprocessing as mp
from functools import partial
import RMT
import numpy as np

import matplotlib.pyplot as plt

mp.set_start_method("spawn", force=True)

torch.set_default_device("cuda")

sing_vals = torch.linspace(0, 3, 101)
alphas = torch.linspace(1, 2, 10)


with mp.Pool(5, torch.set_default_device, ("cuda",)) as pool:
    # The workers default to CUDA, so results presumably arrive as CUDA tensors
    # shared via CUDA IPC. Those are only valid while the producing worker is
    # alive, so copy each to host memory before the pool is terminated on exit.
    pdfs = [
        pdf.cpu()
        for pdf in tqdm(
            pool.imap(
                partial(
                    RMT.singular_value_pdf,
                    chi_samples=torch.ones(1000),
                    sing_vals=sing_vals,
                    num_steps=10000,
                    leave=False,
                ),
                alphas,
            ),
            total=len(alphas),
        )
    ]

plt.plot(sing_vals.cpu(), np.transpose([pdflist.cpu() for pdflist in pdfs]), "o-")
plt.legend([f"alpha={alpha:.2f}" for alpha in alphas])
plt.xlabel("singular value")
plt.ylabel("pdf")

In [None]:
import torch
from torchlevy import stable_dist
from scipy.stats import levy_stable

from tqdm.auto import tqdm
import numpy as np


def fast_integral(integrand, zmin, zmax, dz, ndim=1):
    """Left Riemann-sum approximation of an integral over [zmin, zmax)^ndim.

    Builds the grid zmin, zmin + dz, ... (< zmax) along every axis, evaluates
    ``integrand`` once on the whole grid, sums over the first ``ndim`` axes and
    multiplies by the cell volume ``dz**ndim``.

    Parameters
    ----------
    integrand : callable
        Called with ``ndim`` tensors (one coordinate grid per axis, "ij"
        indexed). Its result's first ``ndim`` axes must index the grid; any
        trailing axes are kept, giving one integral per trailing index.
    zmin, zmax, dz : float
        Grid start, exclusive end, and spacing (shared by all axes).
    ndim : int
        Number of integration variables.

    Returns
    -------
    torch.Tensor
        The integral(s); 0-d when the integrand has no trailing axes.
    """
    # torch.Tensor converts the float64 grid to the default torch dtype.
    zs = torch.Tensor(np.arange(zmin, zmax, dz))
    if ndim > 1:
        # Explicit "ij" indexing (the historical default) avoids the
        # torch.meshgrid warning about the missing `indexing` argument.
        zgrid = torch.meshgrid(*((zs,) * ndim), indexing="ij")
    else:
        zgrid = (zs,)
    out = integrand(*zgrid)
    return out.sum(tuple(range(ndim))) * dz**ndim


def q_map(q, alpha, sigma_W, sigma_b, phi, z_min=-100, z_max=100, dz=0.1):
    """Apply the alpha-stable q-map once, for every entry of ``q``.

    Approximates, by a Riemann sum on the grid z_min:z_max:dz,

        sigma_W**alpha * ∫ p(2**(-1/alpha) z) |phi((q/2)**(1/alpha) z)|**alpha dz
        + sigma_b**alpha

    where p is SciPy's ``levy_stable`` density with shape parameters
    (alpha, beta=0).

    Returns
    -------
    torch.Tensor
        One value per element of ``torch.atleast_1d(q)``.
    """
    q = torch.atleast_1d(q)
    pdf_arg_scale = 2 ** (-1 / alpha)

    def integrand(z):
        # SciPy evaluates the density on CPU; wrap it back into a torch tensor.
        # torchlevy alternative:
        #   stable_dist.pdf(2 ** (-1 / alpha) * z, alpha, is_cache=True)
        density = torch.Tensor(levy_stable.pdf(pdf_arg_scale * z.cpu(), alpha, 0))
        # Broadcast to shape (grid point, q entry).
        scaled_z = (q[None, :] / 2) ** (1 / alpha) * z[:, None]
        return density[:, None] * abs(phi(scaled_z)) ** alpha

    integral = fast_integral(integrand, z_min, z_max, dz)
    return (sigma_W**alpha) * integral + sigma_b**alpha


def q_star(
    alpha,
    sigma_W,
    sigma_b=0,
    phi=torch.tanh,
    q_init=3.0,
    max_iterations=500,
    tol=1e-9,
):
    """Iterate ``q_map`` from ``q_init`` towards its fixed point.

    Parameters
    ----------
    alpha, sigma_W, sigma_b, phi
        Passed through to ``q_map``.
    q_init : float
        Starting value of q.
    max_iterations : int
        Maximum number of map applications.
    tol : float
        Stop as soon as one application changes q by less than ``tol``.

    Returns
    -------
    torch.Tensor
        The recorded iterates, in order. The final sub-``tol`` update that
        triggers the stop is not recorded. The result is empty if ``q_init``
        already satisfies the tolerance. Without convergence, all
        ``max_iterations`` iterates are returned.
    """
    q = float(q_init)
    qs = -torch.ones(max_iterations)
    n_recorded = 0
    # Refresh the progress-bar text ~100 times per run. max(1, ...) avoids a
    # modulo-by-zero when max_iterations < 100.
    report_every = max(1, max_iterations // 100)
    for i in (pbar := tqdm(range(max_iterations))):
        # Keep q a Python float so every call receives a fresh 1-element tensor.
        q_new = q_map(torch.Tensor([q]), alpha, sigma_W, sigma_b, phi).item()
        if abs(q_new - q) < tol:
            break
        q = q_new
        qs[i] = q
        n_recorded = i + 1
        if i % report_every == 0:
            pbar.set_postfix_str(f"q={q:.4f}")
    return qs[:n_recorded]

import RMT



# Run the q* experiments below on CPU.
torch.set_default_device("cpu")
# Quick check of q_map on q = 0..9 (disabled):
# print(q_map(torch.Tensor(range(10)), 1.5, 1.0, 0, torch.tanh))

import matplotlib.pyplot as plt

# Spread (mean, std) of the final iterate over 100 repeated runs.
# NOTE(review): `q_star_MC` is not defined in the visible notebook, so this
# raises NameError on a fresh kernel. Presumably it is a Monte-Carlo variant
# provided by RMT (i.e. `RMT.q_star_MC`) — TODO confirm.
vals = [q_star_MC(1, 1)[-1] for _ in tqdm(range(100))]
print(np.mean(vals), np.std(vals))

In [None]:
import torch

device = "cuda"
torch.set_default_device(device)

import RMT

import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import multiprocessing as mp
from functools import partial

mp.set_start_method("spawn", force=True)

N_REPEATS = 10  # repeated runs at identical parameters, to show run-to-run spread

# Evaluation grid on (0, 3]; the 0 endpoint is dropped.
sing_vals = torch.linspace(0, 3, 100)[1:]

with mp.Pool(3, torch.set_default_device, (device,)) as pool:
    # Workers default to CUDA. Copy each result to host memory while its
    # producing worker is still alive: CUDA tensors received from another
    # process are only valid while that process runs.
    pdfs_list = [
        pdfs.cpu()
        for pdfs in tqdm(
            pool.imap(
                partial(
                    RMT.jacobian_singular_value_pdf,
                    sigma_W=1.5,
                    sing_vals=sing_vals,
                    pop_size=1000,
                    num_steps=10000,
                    phi=torch.tanh,
                ),
                [1.5] * N_REPEATS,
            ),
            # imap has no len(), so give tqdm the total explicitly.
            total=N_REPEATS,
        )
    ]

for pdfs in pdfs_list:
    plt.plot(sing_vals.cpu(), pdfs.cpu())
plt.xlabel("singular value")
plt.ylabel("pdf")

Compare the singular-value density predicted by the RMT cavity population dynamics with the singular values of empirical random MLPs.

In [None]:
import numpy as np

# Singular values of an empirical random MLP, obtained by exponentiating the
# log singular values returned by RMT.
# NOTE(review): the argument order of RMT.MLP_log_svdvals is not visible here;
# presumably (alpha, sigma_W, sigma_b, phi, width, depth) — TODO confirm.
# np.exp needs a CPU array/tensor; a CUDA tensor would fail here.
empirical_svdvals = np.exp(
    RMT.MLP_log_svdvals(
        1.5,
        1.5,
        0,
        torch.tanh,
        1000,
        1000,
    )
)

In [None]:
# Evaluation points for the theoretical density on (0, 3] (the 0 endpoint is
# dropped). The comparison plot also uses them as histogram bin edges.
sing_val_bins = torch.linspace(0, 3, 100)[1:]
# Cavity population-dynamics prediction.
# NOTE(review): positional arguments, presumably (alpha, sigma_W, sing_vals,
# pop_size, num_steps, phi) to match the keyword call to this function in the
# pool cell — TODO confirm against RMT.jacobian_singular_value_pdf.
theoretical_pdfs = RMT.jacobian_singular_value_pdf(
    1.5,
    1.5,
    sing_val_bins,
    1000,
    10000,
    torch.tanh,
)


In [None]:

# Overlay the theoretical density on a histogram of the empirical singular values.
# plt.hist(density=True) normalises only over the samples that land inside the
# bins, so values outside [sing_val_bins[0], sing_val_bins[-1]] would inflate
# the empirical curve. Normalise by the total sample count instead.
bin_edges = sing_val_bins.cpu().numpy()
# NOTE(review): assumes empirical_svdvals is a CPU array/tensor. All values are
# pooled into one histogram.
svdvals_flat = np.asarray(empirical_svdvals).ravel()
counts, _ = np.histogram(svdvals_flat, bins=bin_edges)
empirical_density = counts / (svdvals_flat.size * np.diff(bin_edges))

plt.plot(sing_val_bins.cpu(), theoretical_pdfs.cpu(), label="cavity theory")
# Draw the precomputed density as bars: one weighted sample at each left edge.
plt.hist(
    bin_edges[:-1],
    bins=bin_edges,
    weights=empirical_density,
    alpha=0.5,
    label="empirical",
)
plt.xlabel("singular value")
plt.ylabel("density")
plt.legend()

plt.show()