In [None]:
import RMT

import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

In [None]:
# Make sure the CIFAR-10 files exist under fig/datasets before later cells
# load them. torchvision skips the download when the files are already there.
torchvision.datasets.CIFAR10(
    download=True,
    root="fig/datasets",
)

In [None]:
# Open CIFAR-10 from fig/datasets. `download` is not passed here, so this
# relies on the files already being present under that root.
# NOTE(review): the arrays below are read straight from `dataset.data`, so the
# `transform` given here presumably never touches them -- confirm.
dataset = torchvision.datasets.CIFAR10(root="fig/datasets", transform=torchvision.transforms.ToTensor())

# Pixel data as float32 and labels as an array so they can be boolean-masked.
dataset_data = np.array(dataset.data, dtype=np.float32)
dataset_targets = np.array(dataset.targets)
display(dataset.data.shape)

# Group the images by class name.
class_to_data = {name: dataset_data[dataset_targets == l] for name, l in dataset.class_to_idx.items()}
# show number of images (and tensor shape) per class
display({k: v.shape for k, v in class_to_data.items()})

# Number of features in one flattened image: the product of every axis except
# the leading sample axis. The previous `shape[-2:]` only multiplied the last
# two axes, which does not match the width obtained later when images are
# flattened with `reshape(num_items, -1)`.
layer_size = np.prod(dataset.data.shape[1:])
display(layer_size)

# Rescale every image so that the mean of its squared entries is 1: divide by
# the root-mean-square taken over all non-sample axes (kept as singleton axes
# so the division broadcasts per image).
norm_data_dict = {
    k: (v / np.mean(v**2, tuple(range(1, v.ndim)), keepdims=True) ** 0.5)
    for k, v in class_to_data.items()
}

In [None]:
# Per-class result: maps class label -> array of RMS norms with one row per
# entry of [x0, *xs] and one column per image.
outputs_dict = {}

alpha = 1.5
sigma_W = 1.5

# Target input scale: square root of the "postact_sq_mean" entry that
# RMT.MFT_map reports for the last value returned by RMT.q_star_MC.
# It depends only on (alpha, sigma_W), so it is computed once here instead of
# once per class inside the loop.
# NOTE(review): if RMT.q_star_MC is an unseeded Monte-Carlo estimate, the
# original code drew a fresh estimate per class; a single shared value keeps
# the classes on the same scale -- confirm this is the intended behaviour.
fp_norm = (
    RMT.MFT_map(RMT.q_star_MC(alpha, sigma_W)[-1], alpha, sigma_W, usetqdm=False)[
        "postact_sq_mean"
    ][0]
    ** 0.5
)

for label, data in tqdm(norm_data_dict.items()):
    num_items = len(data)
    # Flatten each (unit-RMS) image to a vector and rescale it to fp_norm.
    x0 = data.reshape(num_items, -1) * fp_norm
    # Propagate the whole class through RMT.MLP and keep its "postact" output.
    xs = RMT.MLP(
        torch.tensor(x0),
        10,
        alpha,
        sigma_W,
        seed=42,
        fast=True,
        usetqdm=False,
    )["postact"]
    # RMS over the feature axis for the input and every entry of xs.
    norms = (np.array([x0, *xs]) ** 2).mean(-1) ** 0.5

    outputs_dict[label] = norms

In [None]:
# One tab10 colour per class, assigned in the iteration order of norm_data_dict.
color_dict = {label: plt.cm.tab10(i) for i, label in enumerate(norm_data_dict.keys())}

fig, ax = plt.subplots()
for label, norms in outputs_dict.items():
    # Row index of `norms`: 0 is the rescaled input x0, later rows follow xs.
    layers = np.arange(len(norms))
    # Mean RMS norm across the images of this class.
    ax.plot(
        layers,
        norms.mean(-1),
        "-o",
        color=color_dict[label],
        label=label,
    )
    # Shaded band between the 1% and 99% quantiles across images.
    ax.fill_between(
        layers,
        np.quantile(norms, 0.01, axis=-1),
        np.quantile(norms, 0.99, axis=-1),
        color=color_dict[label],
        alpha=0.5,
    )

ax.set_ylim([0.25, 0.75])
ax.set_xlabel("layer index (0 = input)")
ax.set_ylabel("RMS norm (mean, 1%-99% band)")
# The curves carry class labels; show them so the colours can be identified.
ax.legend(ncol=2, fontsize="small")
plt.show()