## Introduction

In [None]:
import RMT

import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

In [None]:
# Download CIFAR-10 into the root that the loading cell reads from.
# The loading cell constructs `CIFAR10(root="fig/datasets", ...)` without
# `download=True`, so the files for that dataset (not CIFAR-100) must exist.
torchvision.datasets.CIFAR10(root="fig/datasets", download=True)

In [None]:
# Load CIFAR-10 from disk and copy the raw images / labels into numpy arrays.
dataset = torchvision.datasets.CIFAR10(
    root="fig/datasets",
    transform=torchvision.transforms.ToTensor(),
)
dataset_data = np.array(dataset.data, dtype=np.float32)
dataset_targets = np.array(dataset.targets)
display(dataset.data.shape)

# Group the images by class name using the dataset's name -> label-index mapping.
class_to_data = {
    class_name: dataset_data[dataset_targets == class_idx]
    for class_name, class_idx in dataset.class_to_idx.items()
}
# show number of images (and tensor shape) per class
display({class_name: images.shape for class_name, images in class_to_data.items()})

# Number of entries in a single image (product of all non-batch dimensions).
layer_size = np.prod(dataset.data.shape[1:])
display(layer_size)

# Rescale every image so that its root-mean-square entry is 1: the mean of the
# squared entries is taken over all axes except the first (one value per image).
norm_data_dict = {
    class_name: images
    / np.mean(images**2, tuple(range(1, images.ndim)), keepdims=True) ** 0.5
    for class_name, images in class_to_data.items()
}

## Demonstration on empirical MLPs

In [None]:
# Per-class norms of the inputs and post-activations, filled in by the loop below.
outputs_dict = {}

alpha = 1.5
sigma_W = 1.5

# Root of the first "postact_sq_mean" entry returned by the mean-field map when
# it is started from the last value of `RMT.q_star_MC`; used as the input scale.
q_star = RMT.q_star_MC(alpha, sigma_W)[-1]
fp_stats = RMT.MFT_map(q_star, alpha, sigma_W, usetqdm=False)
fp_norm = fp_stats["postact_sq_mean"][0] ** 0.5

for label, data in tqdm(norm_data_dict.items()):
    num_items = 2  # len(data)
    # flatten the images and scale by fixed point norm
    x0 = data[:num_items].reshape(num_items, -1) * fp_norm
    mlp_out = RMT.MLP(
        torch.tensor(x0),
        10,
        alpha,
        sigma_W,
        seed=42,
        fast=True,
        usetqdm=False,
    )
    xs = mlp_out["postact"]
    # Root-mean-square over the last axis, with the input stacked in front of
    # the post-activations (presumably one entry of `xs` per layer -- TODO confirm).
    stacked = np.array([x0, *xs])
    norms = (stacked**2).mean(-1) ** 0.5

    outputs_dict[label] = norms

In [None]:
# Plot, for every class, the mean norm across inputs at each index along the
# first axis of `norms`, with a band between the 1% and 99% quantiles.
# Uses an explicit figure/axes and labelled axes, matching the later plot cell.
fig, ax = plt.subplots()

color_dict = {label: plt.cm.tab10(i) for i, label in enumerate(norm_data_dict.keys())}
for label, norms in outputs_dict.items():
    ax.plot(
        norms.mean(-1),
        "-o",
        # alpha=x0.shape[0] ** -0.5,
        color=color_dict[label],
        label=label,
    )
    ax.fill_between(
        np.arange(len(norms)),
        np.quantile(norms, 0.01, axis=-1),
        np.quantile(norms, 0.99, axis=-1),
        color=color_dict[label],
        alpha=0.5,
    )

ax.set(
    xlabel="layer",
    ylabel="postact_norm",
    ylim=(0.25, 0.75),
)
# ax.legend()
plt.show()

## MFT demonstration
This takes longer than the MLP demonstration (about 7m33s on a Tesla M4 for CIFAR-10 with 10 layers).

In [None]:
# Per-class norms of the inputs and mean-field post-activations.
outputs_dict = {}

alpha = 1.5
sigma_W = 1.5

# Root of the first "postact_sq_mean" entry returned by the mean-field map when
# it is started from the last value of `RMT.q_star_MC`; used as the input scale.
fp_norm = (
    RMT.MFT_map(RMT.q_star_MC(alpha, sigma_W)[-1], alpha, sigma_W, usetqdm=False)[
        "postact_sq_mean"
    ][0]
    ** 0.5
)

# gpu memory bottleneck due to `RMT.MFT_map::postact_samples`
max_num_items_in_batch = 100

for label, data in tqdm(norm_data_dict.items()):
    num_items = 100  # len(data)
    # flatten the images and scale by fixed point norm
    x0 = data[:num_items].reshape(num_items, -1) * fp_norm
    q0 = sigma_W**alpha * (abs(x0) ** alpha).mean(-1)
    # Ceiling division: floor division would, e.g., put 150 items into a single
    # chunk when the cap is 100, exceeding `max_num_items_in_batch`.
    num_chunks = max(1, -(-num_items // max_num_items_in_batch))
    MFT_maps = []
    for q0_chunk in tqdm(
        np.array_split(q0, num_chunks, axis=0),
        leave=False,
    ):
        # The torch seed is reset before every chunk, so each chunk starts from
        # the same torch RNG state.
        torch.manual_seed(42)
        val = RMT.MFT_map(
            q0_chunk,
            alpha,
            sigma_W,
            num_layers=10,
            agg_postact=lambda x: {
                "sq_mean": (x**2).mean(-1),
                "alpha_mean": (abs(x) ** alpha).mean(-1),
            },
            usetqdm=False,
        )
        MFT_maps.append(val)
    # average over Monte-Carlo realisations (axis=2), then concatenate chunks of inputs (axis=1)
    agg_stats = {
        k: np.concatenate([np.mean(mft[k], axis=2) for mft in MFT_maps], axis=1)
        for k in MFT_maps[0]
    }
    # Prepend the input statistics so that index 0 along the first axis is the input.
    norms = np.array([(x0**2).mean(-1), *agg_stats["postact_sq_mean"]]) ** 0.5
    alpha_norms = np.array(
        [(abs(x0) ** alpha).mean(-1), *agg_stats["postact_alpha_mean"]]
    ) ** (1 / alpha)
    qs = agg_stats["q_val"]

    outputs_dict[label] = norms

In [None]:
# np.savez('fig/mixed_selectivity/outputs_MLP.npz', **outputs_dict)
# Replace `outputs_dict` with the arrays cached on disk. The arrays are copied
# into a plain dict inside the `with` block so the archive's file handle is
# closed instead of being kept open by a lazily-read NpzFile.
# NOTE(review): the file is named `outputs_MLP` although this cell follows the
# MFT section -- confirm this is the intended cache.
with np.load('fig/mixed_selectivity/outputs_MLP.npz') as cached_outputs:
    outputs_dict = {label: cached_outputs[label] for label in cached_outputs.files}

In [None]:
fig, ax = plt.subplots()

# One tab10 colour per class, in the iteration order of `norm_data_dict`.
color_dict = {label: plt.cm.tab10(idx) for idx, label in enumerate(norm_data_dict)}

for label, norms in outputs_dict.items():
    class_color = color_dict[label]
    layer_idx = np.arange(len(norms))
    # Band between the 1% and 99% quantiles over the last axis.
    lower = np.quantile(norms, 0.01, axis=-1)
    upper = np.quantile(norms, 0.99, axis=-1)

    # alpha=x0.shape[0] ** -0.5,
    ax.plot(norms.mean(-1), "-o", color=class_color, label=label)
    ax.fill_between(layer_idx, lower, upper, color=class_color, alpha=0.5)

# ylim=(0.25, 0.75),
ax.set(xlabel='layer', ylabel='postact_norm')
# ax.legend()
plt.show()

# Testing the agg stats function for the cluster

In [None]:
from mixed_selectivity import MFT_map

import torch
import numpy as np
import matplotlib.pyplot as plt

# Run the cluster entry point once on the CPU with a small configuration and
# plot the resulting "postact_sq_mean" statistic.
torch.set_default_device("cpu")

mft_kwargs = dict(
    dataset_name="CIFAR10",
    alpha=1.5,
    sigma_W=1.5,
    sigma_b=0.0,
    num_layers=10,
    chunk_size=100,
    seed=42,
    num_inputs=550,
)
result = MFT_map(**mft_kwargs)

print({k: v.shape for k, v in result.items()})

vals = result["postact_sq_mean"]
first_axis = range(vals.shape[0])
# Mean over axis 1, with a band between the 10% and 90% quantiles over axis 1.
lower = np.quantile(vals, 0.1, axis=1)
upper = np.quantile(vals, 0.9, axis=1)
plt.plot(first_axis, vals.mean(axis=1), "-o")
plt.fill_between(first_axis, lower, upper, alpha=0.5)

# Results from the cluster

In [None]:
import numpy as np
from numpy.lib.npyio import NpzFile
import pandas as pd

def load_dataset_params(npz: NpzFile):
    """Build a lookup table from encoded parameters to entry names of an archive.

    Only ``npz.files`` is read; no array is loaded. Each entry name is parsed
    for ``key=value`` pairs (cast to ``int``) and for the trailing
    ``<name>.txt`` component, which becomes the array name. A trailing
    component that itself contains ``=`` is replaced by ``"arr"``.

    Returns:
        A DataFrame indexed by the parsed parameters, with one column per array
        name, whose cells hold the corresponding entry name in the archive.
    """
    fnames = pd.Index(npz.files, name="fname")

    # Start from a frame holding the entry names, then attach the array name
    # taken from the last `;`/`.`-free component preceding ".txt".
    table = fnames.to_frame()
    table["arr_name"] = fnames.str.findall(r"([^;.]*)\.txt$").str[0]
    looks_like_param = table["arr_name"].str.contains("=")
    table["arr_name"] = table["arr_name"].where(~looks_like_param, "arr")

    # Every `key=value` pair in an entry name becomes an integer column.
    key_value_pairs = fnames.str.findall(
        r"(?P<key>[^=;.]+)=(?!\.txt)(?P<value>[^=;.]+)"
    )
    param_df = pd.DataFrame.from_records(
        key_value_pairs.map(dict),
        index=fnames,
    ).astype(int)

    # Index by the parameters and spread the array names over the columns.
    keyed = table.join(param_df).set_index(param_df.columns.tolist())
    return keyed.pivot(columns="arr_name", values="fname")

In [None]:
from pathlib import Path
import ipywidgets as widgets
import matplotlib.pyplot as plt



# Archive holding the results; only its entry names are read here to build the
# parameter lookup table `df`.
npz_path = Path("/import/silo3/wardak/extended-criticality-dnn/random_dnn/fig/mixed_selectivity/dataset_name=CIFAR10;num_layers=50.npz")
with np.load(npz_path) as npz:
    df = load_dataset_params(npz)

@widgets.interact_manual(
    **{
        level: df.index.get_level_values(level).unique().tolist()
        for level in ("alpha100", "sigmaW100", "seed")
    },
    col=df.columns.tolist(),
    interquantile_range=(0.0, 1.0, 0.05),
)
def lineplot(alpha100, sigmaW100, seed, col, interquantile_range):
    """Plot one stored array: its mean over axis 1 plus a central quantile band.

    The array is looked up in `df` by ``(alpha100, sigmaW100, seed)`` and
    column ``col``, then read from the archive at `npz_path`. The band spans
    the quantiles ``(1 -/+ interquantile_range) / 2`` taken over axis 1.
    """
    arr_name = df.loc[(alpha100, sigmaW100, seed), col]
    with np.load(npz_path) as npz:
        vals = npz[arr_name]

    first_axis = range(vals.shape[0])
    lower_q = (1 - interquantile_range) / 2
    upper_q = (1 + interquantile_range) / 2
    lower, upper = np.quantile(vals, (lower_q, upper_q), axis=1)

    plt.plot(first_axis, vals.mean(axis=1), "-o")
    plt.fill_between(first_axis, lower, upper, alpha=0.5)
    print(f"Plotting {arr_name} with shape {vals.shape}")
    plt.title(
        rf"{col} ($\alpha$={alpha100/100}, $\sigma_W$={sigmaW100/100}, seed={seed})"
    )
    plt.xlabel("layer")
    plt.ylabel(col)
    plt.show()

Preprocess the data by aggregating over the input images. Loading the full npz takes about 10 minutes, so the pre-aggregated files save time when loading the phase diagram.

In [None]:
# aggregate over the input images
from tqdm.auto import tqdm
from collections import defaultdict
from pathlib import Path

# Source archive whose arrays are aggregated below.
npz_path = Path("/import/silo3/wardak/extended-criticality-dnn/random_dnn/fig/mixed_selectivity/dataset_name=CIFAR10;num_layers=50.npz")


def interquantile_range(x, iqr=0.5):
    """Width of the central quantile band of `x` along its last axis.

    Args:
        x: Array-like of values; quantiles are taken over the last axis.
        iqr: Fraction of the distribution covered by the band, centred on the
            median (``0.5`` gives the usual interquartile range).

    Returns:
        The ``(1 + iqr) / 2`` quantile minus the ``(1 - iqr) / 2`` quantile,
        with the last axis of `x` reduced.
    """
    # Both quantile levels must be passed together as the `q` argument; passed
    # separately, the second one is interpreted as `axis`.
    lower, upper = np.quantile(x, ((1 - iqr) / 2, (1 + iqr) / 2), axis=-1)
    return upper - lower


# Reduce every array in the archive over its last axis with several statistics
# and write one archive per statistic next to the source file.
# The statistic names are spelled out rather than taken from `aggfn.__name__`,
# which for `np.max` / `np.min` may be reported as "amax" / "amin" depending on
# the NumPy version.
aggfns = {
    "mean": np.mean,
    "std": np.std,
    "median": np.median,
    "max": np.max,
    "min": np.min,
}

agg_dict = defaultdict(dict)
with np.load(npz_path) as npz:
    for f in tqdm(npz.files):
        arr = npz[f]
        for aggfn_name, aggfn in aggfns.items():
            agg_dict[aggfn_name][f] = aggfn(arr, axis=-1)

# The source archive is no longer needed once everything is aggregated.
# Output files are named `<stem>;<statistic>.npz`, using the same `;` separator
# as the rest of the file names (the plotting cell reads `<stem>;std`).
for aggfn_name, aggregated in agg_dict.items():
    np.savez(npz_path.with_stem(f"{npz_path.stem};{aggfn_name}"), **aggregated)

In [None]:
import PIL.Image
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

import multiprocessing as mp
import io
from functools import partial


def plot_frame(npz_path: Path, stat: str, layer: int, norm=None):
    """Render one filled-contour frame of `stat` at index `layer` as a PNG.

    Args:
        npz_path: Archive to read; its entry names are parsed with
            `load_dataset_params`.
        stat: Column of the parameter table selecting which arrays to plot.
        layer: Index taken along the first axis of every selected array.
        norm: Optional matplotlib normalisation forwarded to ``tricontourf``.

    Returns:
        An ``io.BytesIO`` positioned at its start, containing the PNG image.
    """
    with np.load(npz_path) as npz:
        fname_by_params = load_dataset_params(npz)[stat]
        series = fname_by_params.map(lambda k: npz[k][layer])

    # Average over every index level other than (alpha100, sigmaW100), then
    # undo the x100 encoding of the two parameters.
    per_point = series.groupby(["alpha100", "sigmaW100"]).mean()
    table = per_point.reset_index().to_numpy() / (100, 100, 1)
    alpha_vals, sigma_vals, stat_vals = table.T

    fig, ax = plt.subplots()
    contour = ax.tricontourf(alpha_vals, sigma_vals, stat_vals, levels=30, norm=norm)
    fig.colorbar(contour, ax=ax)
    ax.set(
        title=f"Layer {layer}: {stat}",
        xlabel=r"$\alpha$",
        ylabel=r"$\sigma_W$",
    )

    # Write the figure to an in-memory PNG and release the figure.
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    plt.close(fig)
    return buf


# Render one frame per index 0..49 in a pool of 8 worker processes, then write
# the frames out as a looping GIF with 1000 ms per frame.
render_frame = partial(
    plot_frame,
    npz_path.with_stem(f"{npz_path.stem};std"),
    "postact_sq_mean",
    # norm=mcolors.SymLogNorm(1e-10),
    norm=mcolors.Normalize(0, 0.01),
)

with mp.Pool(8) as pool:
    # `imap` is lazy, so the buffers must be consumed while the pool is open.
    frame_buffers = pool.imap(render_frame, range(50))
    frames = [PIL.Image.open(buf) for buf in tqdm(frame_buffers, total=50)]

gif_stem = f"{npz_path.stem};norm_std;postact_sq_mean"
gif_path = npz_path.with_stem(gif_stem).with_suffix(".gif")
gif_path.parent.mkdir(parents=True, exist_ok=True)

first_frame, *other_frames = frames
first_frame.save(
    gif_path,
    save_all=True,
    append_images=other_frames,
    duration=1000,
    loop=0,
)