In [None]:
import sys
sys.path.append("..")
from modules.utils import generate_matrix_close_to_isometry, generate_matrix_far_from_isometry, isometry_gap, ortho_gap
from modules.models import MLPWithBatchNorm, SinAct, CustomNormalization
from torch import nn
import torch
from torch.func import jacfwd
from tqdm import tqdm
import pandas as pd
import numpy as np
import math
import matplotlib
import itertools
from constants import *
import matplotlib.pyplot as plt
import seaborn as sns
from torchvision import transforms, datasets
# --- Global plotting configuration -------------------------------------------
# To export figures through the pgf backend, uncomment the next line.
# matplotlib.use("pgf")
matplotlib.rcParams.update({
    "pgf.texsystem": "pdflatex",
    "font.family": "serif",
    "text.usetex": True,  # text is rendered through LaTeX
    "pgf.rcfonts": False,
})

# One set_theme call replaces the previous sns.set(...) + sns.set_theme() +
# sns.set_context('paper') sequence.  seaborn's theme/style helpers reset
# ``font.family`` to their own ``font`` argument (default: sans-serif), so the
# serif choice above has to be repeated here to survive.
sns.set_theme(
    context="paper",
    font="serif",
    rc={"figure.dpi": 80, "savefig.dpi": 80},
)
palette = sns.color_palette("tab10")

# Overrides handed to ``sns.set_style`` by the plotting cell.
# - "font.family" is included because ``sns.set_style`` would otherwise switch
#   the font back to sans-serif.
# - The former "border.color" entry was removed: it is not a matplotlib
#   rcParam / seaborn style key and was silently discarded; the axes border
#   colour is controlled by "axes.edgecolor".
style = {
    "grid.linestyle": ":",
    "axes.edgecolor": "black",
    "font.family": "serif",
    # Ticks on all four sides, pointing inwards.
    "xtick.bottom": True,
    "xtick.top": True,
    "ytick.left": True,
    "ytick.right": True,
    "xtick.direction": "in",
    "ytick.direction": "in",
}

In [None]:
# Image preprocessing: convert to a tensor, then normalise each of the three
# channels with the given per-channel mean / std.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.243, 0.261)),
])

# NOTE: d, n, test_x and test_y are all reassigned inside the experiment loop
# further down in this notebook.
d, n = 100, 100

# Placeholder path: point this at a directory that already contains CIFAR-10
# (download=False, so nothing is fetched automatically).
root_dir = 'FILL_HERE'
train_set = datasets.CIFAR10(root_dir, train=True, transform=transform, download=False)
train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=n, shuffle=True)

# Draw a single shuffled batch of n images, flatten every image to a vector
# and move both images and labels to the GPU.
test_x, test_y = next(iter(train_loader))
test_x = test_x.flatten(start_dim=1).cuda()
test_y = test_y.cuda()

In [None]:
# Hyper-parameter grid swept by the experiment cell.
inits = ['orthogonal']      # weight initialisation schemes
act_names = ['identity']    # activation names
depths = [100]              # number of layers
ds = [50, 100]              # widths

# One (init, activation, depth, width) tuple per run, in Cartesian-product
# order with the width varying fastest.
runs = [
    (init, act, depth, width)
    for init in inits
    for act in act_names
    for depth in depths
    for width in ds
]

In [None]:
# Isometry-gap experiment.
# For every configuration in `runs`, build a near-identity input, push it
# through several freshly initialised MLPWithBatchNorm models and record the
# isometry gap of each saved hidden state next to the reference curve
#     ig0 * exp(-layer / (C * d^2 * (1 + d * ig0)))
# which the plotting cell labels "(theory)".

# Fix the RNG so the random inputs and initialisations are reproducible
# (torch.manual_seed seeds the CPU and CUDA generators).
torch.manual_seed(0)
# Fall back to the CPU when no GPU is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

C = 0.015         # constant appearing in the reference curve
f = 0.005         # weight of the random rank-one term mixed into the identity
num_repeats = 10  # independent model initialisations per configuration

records = []
for init_type, act_name, L, d in tqdm(runs):
    n = d
    # Input: (1 - f) * I plus f times a random rank-one matrix.
    test_x = (1 - f) * torch.eye(d, device=device) + f * torch.rand(d).outer(torch.rand(d)).to(device)
    # Random labels; not consumed by anything in this cell.
    test_y = torch.randint(low=0, high=d, size=(n,)).to(device)
    ig0 = isometry_gap(test_x).item()
    print(ig0)

    # The reference curve depends only on (layer, d, ig0): compute it once per
    # configuration instead of once per layer of every repetition.
    upperbounds = ig0 * np.exp(-np.arange(L) / (C * (d ** 2) * (1 + d * ig0)))
    activation = ACTIVATIONS[act_name]
    gain = GAINS[act_name]

    for _ in range(num_repeats):
        model = MLPWithBatchNorm(input_dim=d,
                                 output_dim=d,
                                 num_layers=L,
                                 hidden_dim=d,
                                 norm_type='bn',
                                 bias=False,
                                 order='norm_act',
                                 force_factor=1.0,
                                 mean_reduction=False,
                                 activation=activation,
                                 exponent=0,
                                 save_hidden=True).to(device)
        model.reset_parameters(init_type, gain=gain)

        # Only the hidden states saved during the forward pass are needed;
        # the network output itself is discarded.
        model(test_x)
        hiddens = model.hiddens
        for layer_num in range(L):
            records.append({
                'isogap': isometry_gap(hiddens[f'fc_{layer_num}']).item(),
                'Activation': act_name,
                'Initialization': init_type,
                'Layer': layer_num,
                'd': d,
                'upperbound': upperbounds[layer_num],
            })

# Build the frame once from the collected records.
df = pd.DataFrame(records)

In [None]:
# Plot the measured isometry gap (solid) against the reference curve (dashed)
# as a function of depth, one colour per width.
edited_df = (
    df.replace('orthogonal', 'Orthogonal')
      .replace('xavier_normal', 'Normal')
)
edited_df = edited_df[edited_df['Activation'] == 'identity']
palette = sns.color_palette()

# Derive layers and widths from the data itself rather than from loop
# variables left behind by the experiment cell.
layers = np.sort(edited_df['Layer'].unique())
widths = sorted(edited_df['d'].unique())
# Pass exactly one colour per width (same leading colours as the full palette).
hue_colors = palette[:len(widths)]

sns.set_style('darkgrid', style)
fig, ax = plt.subplots(dpi=200, figsize=(4, 2), ncols=1)
sns.lineplot(edited_df, x='Layer', y='isogap', ax=ax, hue='d', palette=hue_colors)
sns.lineplot(edited_df, x='Layer', y='upperbound', hue='d', linestyle='--', ax=ax,
             palette=hue_colors, legend=True)

ax.set_yscale('log')
ax.set_ylabel(r'$\phi(X_\ell)$')
ax.set_xlabel(r'$\ell$')
ax.set_ylim(bottom=1e-6)

# Legend: measured entries first, then the "(theory)" entries, one per width.
# Labels are generated from the widths present in the data instead of being
# hard-coded.
legend_labels = [str(w) for w in widths] + [f'{w} (theory)' for w in widths]
ax.legend(ax.get_legend().legend_handles, legend_labels, frameon=False, title='Width')
sns.move_legend(ax, "center left", bbox_to_anchor=(1.1, 0.5))

# Draw the theory entries dashed in the legend so they match their curves.
for handle in ax.get_legend().legend_handles[len(widths):]:
    handle.set_linestyle('--')

fig.tight_layout()
# plt.show() instead of fig.show(): the latter typically warns on a notebook's non-GUI backend.
plt.show()