In [None]:
# !pip install git+https://github.com/UNIST-LIM-Lab/torchlevy.git

Compare the singular values predicted by the RMT cavity population dynamics against those obtained empirically from random MLPs.

In [None]:
import RMT
import torch
import numpy as np
import matplotlib.pyplot as plt


# Sample random MLPs on the CPU and report the shape of every returned array.
# NOTE(review): the positional arguments 50, 1.5, 1.5 are passed straight to
# RMT.MLP; presumably number of realisations, alpha and sigma_W -- confirm
# against RMT.MLP's signature.
mlp = RMT.MLP(np.linspace(-1, 1, 1000, dtype=np.float32), 50, 1.5, 1.5, device="cpu")
print({k: v.shape for k, v in mlp.items()})

# From here on torch allocates on the GPU by default (the MLP above was built
# explicitly on the CPU before this switch).
torch.set_default_device('cuda')
# 99 grid points in (0, 3]; the leading 0 of the linspace is dropped.
sing_val_bins = np.linspace(0, 3, 100)[1:]
cavity_log_pdfs = RMT.jac_cavity_svd_log_pdf(sing_val_bins, 1.5, 1.5, num_doublings=8, progress=True)

# Cavity prediction: the stored values are log-densities, so exponentiate.
plt.plot(sing_val_bins, np.exp(cavity_log_pdfs), 'o-', label="cavity")
# Empirical histogram built from the last entry of the log singular values,
# binned on the same grid so both curves are directly comparable.
plt.hist(
    np.exp(mlp["prejac_log_svdvals"][-1]),
    bins=sing_val_bins,
    density=True,
    alpha=0.5,
    label="empirical",
)
plt.xlabel("singular value")
plt.ylabel("density")
# Without this call the labels set above are never rendered.
plt.legend()

plt.show()

Empirical DNN analysis

In [None]:
import numpy as np
import pandas as pd
from pathlib import Path
# Archive of arrays for the random-MLP experiment; the file name encodes the
# width, depth and number of realisations of the run.
fpath = Path("/import/silo3/wardak/extended-criticality-dnn/random_dnn/fig/MLP_agg/width=1000;depth=100;num_realisations=50.npz")
# Build the Series while the archive is still open: its index holds the
# archive member names, which later cells parse for parameters.
with np.load(fpath) as npz:
    npz_series = pd.Series(npz, name="fname")

In [None]:
# Every archive member name carries "key=value" pairs separated by ";" and ends
# in "<array name>.txt".  Turn the pairs into integer parameter columns first.
key_value_pairs = npz_series.index.str.findall(r"(?P<key>[^=;]+)=(?P<value>[^=;]+)")
param_df = pd.DataFrame.from_records(
    key_value_pairs.map(dict), index=npz_series.index
).astype(int)

# One row per archive member: the array itself plus the array name taken from
# the trailing ".txt" component of the member name.
npz_df = npz_series.to_frame("arr")
npz_df["arr_name"] = npz_series.index.str.findall(r"([^;]*)\.txt$").str[0]

# Index by the parsed parameters, then spread the array names into columns so
# each row holds every array of one parameter combination.
npz_df = npz_df.join(param_df)
npz_df = npz_df.set_index(param_df.columns.tolist())
npz_df = npz_df.pivot(columns="arr_name", values="arr")

# Coefficient of variation of the log singular values: std / |mean|.
for jac_name in ("prejac", "postjac"):
    log_svd_std = npz_df[f"{jac_name}_log_svdvals_std"]
    log_svd_mean = npz_df[f"{jac_name}_log_svdvals_mean"]
    npz_df[f"{jac_name}_log_svdvals_CV"] = log_svd_std / abs(log_svd_mean)

In [None]:
import PIL.Image
import io
from tqdm import tqdm
import matplotlib.pyplot as plt
import multiprocessing as mp

# One GIF frame is rendered per layer index.
# NOTE(review): presumably matches "depth=100" in the archive name -- confirm.
NUM_LAYERS = 100

# For every statistic and every aggregation, render one contour plot per layer
# and stitch the frames into an animated GIF next to the source archive.
for stat in npz_df.columns:
    for agg_over_realisations in [np.mean, np.std, np.median, np.max, np.min]:

        # Defined inside the loops on purpose: the function reads the current
        # `stat` and `agg_over_realisations` as globals, and the worker pool
        # below is created only after this (re)definition.
        # NOTE(review): this relies on worker processes seeing the parent's
        # current globals (e.g. the "fork" start method) -- confirm on
        # platforms that spawn instead.
        def plot_frame(layer):
            """Render the contour plot of `stat` at `layer` and return it as a PNG buffer.

            Args:
                layer: index into the first axis of every stored array.

            Returns:
                An `io.BytesIO` positioned at 0 that holds the PNG image.
            """
            group_levels = ["alpha100", "sigmaW100"]
            fig, ax = plt.subplots()
            # Only the column being plotted is aggregated; mapping over the
            # whole frame would redo the work for every other statistic in
            # every frame.
            grand_avgs = (
                npz_df[stat]
                .map(lambda arr: agg_over_realisations(arr[layer]))
                .groupby(level=group_levels)
                .mean()
            )
            contour = ax.tricontourf(
                *grand_avgs.reset_index()[group_levels + [stat]].to_numpy().T,  # TODO: rescale alpha and sigma_W
                levels=20,
            )
            fig.colorbar(contour, ax=ax)
            ax.set_title(
                f"Layer {layer}: {stat} over {agg_over_realisations.__name__} realisations"
            )
            ax.set_xlabel(r"$\alpha$")
            ax.set_ylabel(r"$\sigma_W$")
            buf = io.BytesIO()
            # Save this figure explicitly rather than whichever is "current".
            fig.savefig(buf, format="png")
            buf.seek(0)
            # Close the figure so repeated frames do not accumulate in memory.
            plt.close(fig)
            return buf

        with mp.Pool(8) as pool:
            frames = [
                PIL.Image.open(buf)
                for buf in tqdm(
                    pool.imap(plot_frame, range(NUM_LAYERS)),
                    total=NUM_LAYERS,
                    desc=f"Generating frames for {stat} with {agg_over_realisations.__name__}...",
                )
            ]

        # GIFs go into a directory named after the archive (suffix stripped).
        gif_path = (
            fpath.with_suffix("")
            / f"{stat};{agg_over_realisations.__name__}_realisations.gif"
        )
        gif_path.parent.mkdir(parents=True, exist_ok=True)
        frames[0].save(
            gif_path, save_all=True, append_images=frames[1:], duration=100, loop=0
        )

In [None]:
import matplotlib.pyplot as plt

# Collapse every stored array to the mean of its entries from index 20 onwards,
# then average the resulting scalars within each (alpha100, sigmaW100) group.
late_layer_means = npz_df.map(lambda arr: np.mean(arr[20:]))
grand_avgs = late_layer_means.groupby(["alpha100", "sigmaW100"]).mean()

# Rows of `xyz` are the x, y and z inputs for the triangulated contour plot.
contour_columns = ["alpha100", "sigmaW100", "prejac_svd_left_D2_std"]
xyz = grand_avgs.reset_index()[contour_columns].to_numpy().T
plt.tricontourf(*xyz, levels=20)
# Optional overlay of selected level lines, currently disabled:
# plt.tricontour(*xyz, levels=[2, 2.05, 2.1], colors=["black", "blue", "brown"])
plt.colorbar()



RMT cavity population dynamics analysis

In [None]:
def moving_avg(arr):
    """Return the midpoint of every pair of adjacent entries of ``arr``.

    The result is one element shorter than ``arr`` along its first axis.
    """
    leading, trailing = arr[:-1], arr[1:]
    return (trailing + leading) / 2


def empirical_avg(fn, bins, pdfs):
    """Integrate ``fn(x) * pdf(x)`` over the grid ``bins``.

    The integrand is averaged between neighbouring grid points and weighted by
    the grid spacing; NaN contributions are skipped by ``np.nansum``.  No
    normalisation by the total mass of ``pdfs`` is applied.

    Args:
        fn: callable evaluated on ``bins``.
        bins: grid points at which ``pdfs`` is sampled.
        pdfs: density values on ``bins``.

    Returns:
        The resulting sum as a scalar.
    """
    integrand = fn(bins) * pdfs
    spacing = np.diff(bins)
    return np.nansum(moving_avg(integrand) * spacing)


def empirical_std(fn, bins, pdfs):
    """Standard deviation of ``fn(x)`` under the density ``pdfs`` on ``bins``.

    Computed as ``sqrt(E[fn^2] - E[fn]^2)`` with both expectations taken by
    ``empirical_avg``.
    """
    second_moment = empirical_avg(lambda x: fn(x) ** 2, bins, pdfs)
    first_moment = empirical_avg(fn, bins, pdfs)
    variance = second_moment - first_moment ** 2
    return np.sqrt(variance)


def empirical_CV(fn, bins, pdfs):
    """Coefficient of variation of ``fn(x)``: standard deviation over |mean|."""
    spread = empirical_std(fn, bins, pdfs)
    centre = empirical_avg(fn, bins, pdfs)
    return spread / abs(centre)


def empirical_skew(fn, bins, pdfs):
    """Skewness of ``fn(x)`` under the density ``pdfs`` sampled on ``bins``.

    Defined as the third central moment of ``fn(x)`` divided by the cube of
    its standard deviation.

    Args:
        fn: callable evaluated on ``bins``.
        bins: grid points at which ``pdfs`` is sampled.
        pdfs: density values on ``bins``.

    Returns:
        The standardised third central moment as a scalar.
    """
    mu = empirical_avg(fn, bins, pdfs)
    sigma = empirical_std(fn, bins, pdfs)
    # `mu` is the mean of fn(x), so the deviation must be taken of fn(x) too;
    # centring the raw grid value x on it mixes two different quantities.
    return empirical_avg(lambda x: (fn(x) - mu) ** 3, bins, pdfs) / sigma**3


def empirical_kurtosis(fn, bins, pdfs):
    """Excess kurtosis of ``fn(x)`` under the density ``pdfs`` sampled on ``bins``.

    Defined as the fourth central moment of ``fn(x)`` divided by the fourth
    power of its standard deviation, minus 3.

    Args:
        fn: callable evaluated on ``bins``.
        bins: grid points at which ``pdfs`` is sampled.
        pdfs: density values on ``bins``.

    Returns:
        The excess kurtosis as a scalar.
    """
    mu = empirical_avg(fn, bins, pdfs)
    sigma = empirical_std(fn, bins, pdfs)
    # `mu` is the mean of fn(x), so the deviation must be taken of fn(x) too;
    # centring the raw grid value x on it mixes two different quantities.
    return (empirical_avg(lambda x: (fn(x) - mu) ** 4, bins, pdfs) / sigma**4) - 3

In [None]:
import numpy as np
import pandas as pd

# Grid of singular values used by later cells as the integration grid for the
# stored log-pdfs.
# NOTE(review): presumably mirrors "logspace_params=-10,10,1000" in the archive
# name below -- confirm the two stay in sync.
sing_vals = np.logspace(-10, 10, 1000)
# Build the Series while the archive is still open: its index holds the
# archive member names, which the next cell parses for parameters.
with np.load(
    "/import/silo3/wardak/extended-criticality-dnn/random_dnn/fig/jac_cavity_svd_log_pdf/num_doublings=10;logspace_params=-10,10,1000.npz"
) as npz:
    npz_series = pd.Series(npz, name="fname")

In [None]:
# Drop the last four characters of every member name before parsing the
# ";"-separated "key=value" pairs into integer parameter columns.
# NOTE(review): the four characters are presumably a file extension -- confirm.
member_stems = npz_series.index.str[:-4]
key_value_pairs = member_stems.str.findall(r"(?P<key>[^=;]+)=(?P<value>[^=;]+)")
param_df = pd.DataFrame.from_records(
    key_value_pairs.map(dict), index=npz_series.index
).astype(int)

# One array per row, indexed by the parsed parameters.  The "arr_name"
# extraction, the pivot and the log_svdvals_CV column used for the MLP archive
# are deliberately not applied here (they were left disabled in this cell).
npz_df = npz_series.to_frame("arr").join(param_df)
npz_df = npz_df.set_index(param_df.columns.tolist())
npz_df

In [None]:
import matplotlib.pyplot as plt

# Summary statistics of log(singular value) under each stored density.  The
# stored arrays are log-pdfs on `sing_vals`, hence the exponentiation.
log_stat_fns = {
    "CV": empirical_CV,
    "mean": empirical_avg,
    "std": empirical_std,
    "skewness": empirical_skew,
    "kurtosis": empirical_kurtosis,
}
for stat, stat_fn in log_stat_fns.items():
    # The map runs to completion inside this iteration, so closing over the
    # loop variable `stat_fn` is safe.
    npz_df[stat] = npz_df["arr"].map(
        lambda log_pdfs: stat_fn(np.log, sing_vals, np.exp(log_pdfs))
    )

# Total mass of each density on the grid (integral of the constant 1).
npz_df["norm"] = npz_df["arr"].map(
    lambda log_pdfs: empirical_avg(lambda x: 1, sing_vals, np.exp(log_pdfs))
)

# One filled contour plot per statistic over the (alpha100, sigmaW100) plane,
# skipping parameter combinations whose statistic came out as NaN.
for stat in ["CV", "mean", "std", "skewness", "kurtosis", "norm"]:
    valid_rows = npz_df[npz_df[stat].notna()]
    x_vals, y_vals, z_vals = (
        valid_rows.reset_index()[["alpha100", "sigmaW100", stat]].to_numpy().T
    )
    plt.tricontourf(x_vals, y_vals, z_vals, levels=20)
    plt.colorbar()
    plt.title(stat)
    plt.show()