## Introduction

In [None]:
# !pip install git+https://github.com/UNIST-LIM-Lab/torchlevy.git

In [None]:
import numpy as np
import pandas as pd
from pathlib import Path

import ipywidgets as widgets


def load_dataset(fpath: Path):
    """Load an ``.npz`` archive into a table of arrays indexed by key parameters.

    Each archive key is parsed in two ways:

    * every ``name=value`` pair found in the key becomes an integer index
      level called ``name``;
    * the last ``;``-separated segment before the trailing ``.txt`` is taken
      as the array name, unless that segment is itself a ``name=value`` pair,
      in which case the array is filed under the column ``"arr"``.

    Returns a DataFrame indexed by the parsed parameters, with one column per
    array name and the stored arrays as cell values.
    """
    name_pattern = r"([^;.]*)\.txt$"
    param_pattern = r"(?P<key>[^=;.]+)=(?!\.txt)(?P<value>[^=;.]+)"

    # Materialise every array while the archive is still open.
    with np.load(fpath) as archive:
        entries = pd.Series(archive, name="fname")
    keys = entries.index

    table = entries.to_frame("arr")
    table["arr_name"] = keys.str.findall(name_pattern).str[0]
    # A trailing segment containing "=" is a parameter, not an array name.
    is_param_segment = table["arr_name"].str.contains("=")
    table["arr_name"] = table["arr_name"].where(~is_param_segment, "arr")

    # findall yields (key, value) tuples per archive key; dict() turns each
    # list of tuples into one record.
    param_records = keys.str.findall(param_pattern).map(dict)
    params = pd.DataFrame.from_records(param_records, index=keys).astype(int)

    indexed = table.join(params).set_index(list(params.columns))
    return indexed.pivot(columns="arr_name", values="arr")


# widgets.interact_manual(
#     load_dataset,
#     fpath=widgets.Dropdown(
#         options=sorted(
#             Path("/import/silo3/wardak/extended-criticality-dnn/random_dnn/fig").glob(
#                 "*/*.npz"
#             )
#         )
#     ),
# )

In [None]:
def moving_avg(arr):
    """Return the midpoint of every pair of adjacent entries of ``arr``."""
    return (arr[1:] + arr[:-1]) / 2


def empirical_avg(fn, bins, pdfs):
    """Trapezoidal estimate of the integral of ``fn(x) * pdf(x)`` over ``bins``.

    ``bins`` are the sample points and ``pdfs`` the density values at those
    points. NaN contributions are dropped by ``np.nansum``. The result is not
    divided by the integral of ``pdfs``, so it is only a true expectation
    when ``pdfs`` integrates to one over ``bins``.
    """
    return np.nansum(moving_avg(fn(bins) * pdfs) * np.diff(bins))


def empirical_std(fn, bins, pdfs):
    """Standard deviation of ``fn(x)`` under ``pdfs``: sqrt(E[fn^2] - E[fn]^2)."""
    return np.sqrt(
        empirical_avg(lambda x: fn(x) ** 2, bins, pdfs)
        - empirical_avg(fn, bins, pdfs) ** 2
    )


def empirical_CV(fn, bins, pdfs):
    """Coefficient of variation of ``fn(x)`` under ``pdfs``: std / |mean|."""
    return empirical_std(fn, bins, pdfs) / abs(empirical_avg(fn, bins, pdfs))


def empirical_skew(fn, bins, pdfs):
    """Skewness of ``fn(x)`` under ``pdfs``: E[(fn(x) - mu)^3] / sigma^3."""
    mu = empirical_avg(fn, bins, pdfs)
    sigma = empirical_std(fn, bins, pdfs)
    # The central moment must be taken of fn(x), the same quantity that mu
    # and sigma describe, not of the raw sample points x.
    return empirical_avg(lambda x: (fn(x) - mu) ** 3, bins, pdfs) / sigma**3


def empirical_kurtosis(fn, bins, pdfs):
    """Kurtosis of ``fn(x)`` under ``pdfs``: E[(fn(x) - mu)^4] / sigma^4.

    This is the plain (non-excess) kurtosis; no 3 is subtracted.
    """
    mu = empirical_avg(fn, bins, pdfs)
    sigma = empirical_std(fn, bins, pdfs)
    # Same consideration as in empirical_skew: centre fn(x), not x.
    return empirical_avg(lambda x: (fn(x) - mu) ** 4, bins, pdfs) / sigma**4

## Comparison
Compare the singular values predicted by the RMT cavity population dynamics against those of the empirical random MLPs.

In [None]:
import matplotlib.pyplot as plt
import ipywidgets as widgets

import RMT
import torch

torch.set_default_device("cpu")

RMT_df = load_dataset(
    Path("/import/silo3/wardak/extended-criticality-dnn/random_dnn/fig")
    / "jac_cavity_svd_log_pdf"
    / "num_doublings=11;logspace_params=-10,10,1000.npz"
)
# Grid on which the stored curves are plotted and integrated.
# NOTE(review): presumably matches the "logspace_params=-10,10,1000" in the
# file name - confirm against the script that wrote the archive.
sing_vals = np.logspace(-10, 10, 1000)


@widgets.interact_manual(
    alpha100=widgets.IntSlider(min=100, max=200, step=5, continuous_update=False),
    sigmaW100=widgets.IntSlider(min=1, max=301, step=5, continuous_update=False),
    compute_MLP=False,
)
def f(alpha100, sigmaW100, compute_MLP):
    """Plot the stored RMT curve for the selected parameters.

    ``alpha100`` and ``sigmaW100`` are the slider values used to look up the
    stored array; they are divided by 100 before being passed to ``RMT.MLP``.
    When ``compute_MLP`` is set, a histogram of singular values from a freshly
    sampled ``RMT.MLP`` is drawn underneath for comparison. The title shows
    the trapezoidal integral of the plotted curve over ``sing_vals``.
    """
    fig, ax = plt.subplots()

    if compute_MLP:
        mlp = RMT.MLP(
            np.linspace(-1, 1, 1000, dtype=np.float32),
            50,
            alpha100 / 100,
            sigmaW100 / 100,
        )
        ax.hist(
            np.exp(mlp["prejac_log_svdvals"][-1]),
            bins=[-np.inf] + sing_vals.tolist() + [np.inf],
            density=True,
            alpha=0.5,
            label="empirical",
        )

    # Look up and exponentiate once; the curve is used for both the line and
    # the normalisation shown in the title.
    rmt_pdf = np.exp(RMT_df.loc[(alpha100, sigmaW100), "arr"])
    ax.plot(sing_vals, rmt_pdf, label="RMT cavity")
    ax.set_title(f"norm={empirical_avg(lambda x: 1, sing_vals, rmt_pdf):.2f}")
    ax.set_xscale("log")
    ax.set_yscale("log")
    # Without this the labels set above are never displayed.
    ax.legend()

## Empirical DNN analysis

In [None]:
fpath = (
    Path("/import/silo3/wardak/extended-criticality-dnn/random_dnn/fig")
    / "MLP_agg"
    / "width=1000;depth=100;num_realisations=50.npz"
)
MLP_df = load_dataset(fpath)

# Derive a coefficient-of-variation column (std / |mean|) for the log singular
# values of each Jacobian type from the stored std and mean columns.
for jac_name in ("prejac", "postjac"):
    column_prefix = f"{jac_name}_log_svdvals"
    MLP_df[f"{column_prefix}_CV"] = (
        MLP_df[f"{column_prefix}_std"] / MLP_df[f"{column_prefix}_mean"].abs()
    )


In [None]:
import matplotlib.pyplot as plt

# Collapse every stored array to the mean of its entries from index 20 onwards,
# then average those scalars over everything sharing an (alpha100, sigmaW100).
per_entry_avgs = MLP_df.map(lambda arr: np.mean(arr[20:]))
grand_avgs = per_entry_avgs.groupby(["alpha100", "sigmaW100"]).mean()

# Rows of xyz are the x, y and z inputs of the triangulated contour plot.
contour_columns = ["alpha100", "sigmaW100", "prejac_log_svdvals_CV"]
xyz = grand_avgs.reset_index()[contour_columns].to_numpy().T
plt.tricontourf(*xyz, levels=20)
plt.colorbar()
# plt.tricontour(*xyz, levels=[2, 2.05, 2.1], colors=["black", "blue", "brown"])



In [None]:
import PIL.Image
import io
from tqdm import tqdm
import matplotlib.pyplot as plt
import multiprocessing as mp

# For every statistic column and every reduction, render one contour frame per
# layer index and assemble the frames into an animated GIF stored in a
# directory named after the dataset file (without its suffix).
for stat in MLP_df.columns:
    for agg_over_realisations in [np.mean, np.std, np.median, np.max, np.min]:

        def plot_frame(layer):
            """Render the contour map of ``stat`` at ``layer`` as an in-memory PNG.

            Each stored array is indexed at ``layer`` and reduced with
            ``agg_over_realisations``; the resulting scalars are averaged per
            (alpha100, sigmaW100) pair. Returns a ``BytesIO`` rewound to the
            start of the PNG data.
            """
            fig, ax = plt.subplots()
            # Only the column being plotted is reduced: mapping over the whole
            # frame would redo the work for every column on every frame.
            grand_avgs = (
                MLP_df[[stat]]
                .map(lambda arr: agg_over_realisations(arr[layer]))
                .groupby(["alpha100", "sigmaW100"])
                .mean()
            )
            contour = ax.tricontourf(
                *grand_avgs.reset_index()[["alpha100", "sigmaW100", stat]].to_numpy().T,  # TODO: rescale alpha and sigma_W
                levels=20,
            )
            fig.colorbar(contour, ax=ax)
            ax.set_title(
                f"Layer {layer}: {stat} over {agg_over_realisations.__name__} realisations"
            )
            ax.set_xlabel(r"$\alpha$")
            ax.set_ylabel(r"$\sigma_W$")
            buf = io.BytesIO()
            # Save this figure explicitly rather than whichever is "current".
            fig.savefig(buf, format="png")
            buf.seek(0)
            plt.close(fig)
            return buf

        # The pool is created after plot_frame is (re)defined, so the workers
        # are started with the current stat / reduction in scope.
        with mp.Pool(8) as pool:
            frames = [
                PIL.Image.open(buf)
                for buf in tqdm(
                    pool.imap(plot_frame, range(100)),
                    total=100,
                    desc=f"Generating frames for {stat} with {agg_over_realisations.__name__}...",
                )
            ]

        gif_path = (
            fpath.with_suffix("")
            / f"{stat};{agg_over_realisations.__name__}_realisations.gif"
        )
        gif_path.parent.mkdir(parents=True, exist_ok=True)
        frames[0].save(
            gif_path, save_all=True, append_images=frames[1:], duration=100, loop=0
        )

## RMT cavity population dynamics analysis

In [None]:
import numpy as np
import pandas as pd

sing_vals = np.logspace(-10, 10, 1000)

rmt_npz_path = Path(
    "/import/silo3/wardak/extended-criticality-dnn/random_dnn/fig",
    "jac_cavity_svd_log_pdf",
    "num_doublings=11;logspace_params=-10,10,1000.npz",
)
per_entry_df = load_dataset(rmt_npz_path)

# Combine all arrays sharing an (alpha100, sigmaW100) pair into one curve: the
# pointwise median is taken after exponentiating, then mapped back with log.
RMT_df = per_entry_df.groupby(["alpha100", "sigmaW100"]).agg(
    lambda log_pdfs: np.log(np.median(np.exp(np.stack(log_pdfs)), axis=0))
)

In [None]:
# mask = (sing_vals > 0) & (sing_vals < 1)
mask = np.ones_like(sing_vals, dtype=bool)

# Column name -> (estimator, function of the singular value it is applied to).
# All moments are taken of log(x); "norm" integrates the constant 1 to give the
# total mass of the curve on the masked grid.
stat_estimators = {
    "CV": (empirical_CV, np.log),
    "mean": (empirical_avg, np.log),
    "std": (empirical_std, np.log),
    "skewness": (empirical_skew, np.log),
    "kurtosis": (empirical_kurtosis, np.log),
    "norm": (empirical_avg, lambda x: 1),
}
for column, (estimator, transform) in stat_estimators.items():
    # Bind estimator/transform as defaults so each lambda keeps its own pair.
    RMT_df[column] = RMT_df["arr"].map(
        lambda log_pdfs, estimator=estimator, transform=transform: estimator(
            transform, sing_vals[mask], np.exp(log_pdfs)[mask]
        )
    )
import matplotlib.pyplot as plt

# One filled contour map per statistic over the (alpha100, sigmaW100) plane,
# skipping rows where the statistic is NaN.
for stat in ["CV", "mean", "std", "skewness", "kurtosis", "norm"]:
    finite_rows = RMT_df[RMT_df[stat].notna()]
    per_point = (
        finite_rows.groupby(["alpha100", "sigmaW100"])[stat].mean().reset_index()
    )
    plt.tricontourf(
        *per_point[["alpha100", "sigmaW100", stat]].to_numpy().T,
        levels=20,
    )
    plt.colorbar()
    plt.title(stat)
    plt.show()

## MFT map

In [None]:
import RMT
import torch

# Create torch tensors on the CPU by default.
torch.set_default_device("cpu")

# Reload the aggregated MLP statistics used by the cells below.
mlp_agg_path = (
    Path("/import/silo3/wardak/extended-criticality-dnn/random_dnn/fig")
    / "MLP_agg"
    / "width=1000;depth=100;num_realisations=50.npz"
)
MLP_df = load_dataset(mlp_agg_path)

In [None]:
%matplotlib widget

In [None]:
import matplotlib.pyplot as plt
import ipywidgets as widgets

import RMT
import torch

torch.set_default_device("cpu")

# A single figure is created once and redrawn in place by the callback below.
fig, ax = plt.subplots()
# plt.show()


@widgets.interact_manual(
    alpha100=widgets.IntSlider(min=100, max=200, step=5, continuous_update=False),
    sigmaW100=widgets.IntSlider(min=1, max=301, step=5, continuous_update=False),
    compute_MFT=False,
)
def f(alpha100, sigmaW100, compute_MFT):
    """Redraw the stored ``postact_sq_mean`` curve for the selected parameters.

    The curve is read from ``MLP_df`` at ``(alpha100, sigmaW100, 0)`` and
    averaged over its last axis. When ``compute_MFT`` is set, the
    corresponding curve from ``RMT.MFT_map`` is overlaid for comparison.
    NOTE(review): the meaning of the third index level (fixed to 0) is not
    visible here - confirm against the script that wrote the archive.
    """
    alpha = alpha100 / 100
    sigmaW = sigmaW100 / 100
    ax.clear()
    MLP_arr = MLP_df.loc[(alpha100, sigmaW100, 0), "postact_sq_mean"].mean(-1)
    ax.plot(MLP_arr, "o-", label="MLP empirical")
    if compute_MFT:
        q0 = sigmaW**alpha * 1 / (alpha + 1)
        stats = RMT.MFT_map(q0, alpha, sigmaW, num_layers=100, num_realisations=50)
        MFT_arr = stats["postact_sq_mean"][:, 0].mean(-1)
        ax.plot(MFT_arr, "o-", label="MFT")
        # MLP_live_arr = (
        #     (RMT.MLP(
        #         torch.linspace(-1, 1, 1000, dtype=torch.float32), 100, alpha, sigmaW
        #     )["postact"]**2).mean(-1)
        # )
        # ax.plot(MLP_live_arr, "o-", label="MLP live")
    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.grid(True)
    # Without this the curve labels set above are never displayed.
    ax.legend()


## Direct RMT with regularisation

In [None]:
import RMT

import numpy as np
import torch
import matplotlib.pyplot as plt

sing_vals = np.logspace(-10, 10, 1000)

# Solver options for the unregularised run.
cavity_options = dict(
    num_doublings=6,
    progress=True,
    # regulariser=True,
    num_chi_realisations=1000,
)
log_pdfs = RMT.jac_cavity_svd_log_pdf(sing_vals, 1.5, 1.5, **cavity_options)

# Fade each curve according to the size of the last axis of the result so that
# many overlaid curves stay readable.
# NOTE(review): the values are plotted as returned on a log y-axis, whereas
# other cells exponentiate the stored log-PDFs before plotting - confirm
# whether np.exp(log_pdfs) is intended here.
curve_alpha = log_pdfs.shape[-1] ** -0.5
plt.plot(sing_vals, log_pdfs, alpha=curve_alpha)
plt.xscale("log")
plt.yscale("log")
plt.xlabel("Singular Value")
plt.ylabel("Log PDF")
plt.title("Log PDF of Singular Values")
plt.show()

In [None]:
# Median of the curves over the last axis of log_pdfs, on log-log axes.
median_log_pdf = np.median(log_pdfs, axis=-1)
median_axes = plt.gca()
median_axes.plot(sing_vals, median_log_pdf)
median_axes.set_xscale("log")
median_axes.set_yscale("log")

In [None]:
# Scratch check of broadcasting: shapes (4,), (3, 4) and (2, 3, 4) multiply
# elementwise into a (2, 3, 4) tensor.
vec_4 = torch.rand(4)
mat_3x4 = torch.rand(3, 4)
cube_2x3x4 = torch.rand(2, 3, 4)
vec_4 * mat_3x4 * cube_2x3x4

In [None]:
import RMT

import numpy as np
import torch
import matplotlib.pyplot as plt

sing_vals = np.logspace(-10, 10, 1000)

# Same call as the unregularised cell, but with the regulariser switched on
# and num_chi_realisations left at its default.
regularised_options = dict(
    num_doublings=6,
    progress=True,
    regulariser=True,
)
log_pdfs = RMT.jac_cavity_svd_log_pdf(sing_vals, 1.5, 1.5, **regularised_options)

# NOTE(review): as in the unregularised cell, the returned values are drawn
# directly on a log y-axis - confirm whether np.exp(log_pdfs) is intended.
plt.plot(sing_vals, log_pdfs)
plt.xscale("log")
plt.yscale("log")
plt.xlabel("Singular Value")
plt.ylabel("Log PDF")
plt.title("Log PDF of Singular Values")
plt.show()