# Collapse

In [None]:
#!pip install einops jaxtyping wandb
#!git clone https://github.com/tdooms/bilinear-decomposition.git

In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import itertools
import einops
from collections import defaultdict
import copy
from einops import einsum
from kornia.augmentation import RandomGaussianNoise

from decomp.model import Model
from decomp.datasets import MNIST
from decomp.plotting import plot_explanation, plot_eigenspectrum

device = 'cpu'
mode = 'mnist'
fast_mode = True
noises = [1e-6,0.01,0.02,0.04,0.08,0.16,0.32,0.64,1.28,2.56,5.12,10.24,20.48] # 0 not accepted by std sweep. take 1e-6
assert len(noises) == 13
num_models = 40
epochs = 10


d_input = 3072 if mode == 'cifar' else 784

if fast_mode:
    epochs = 10
    num_models = 1
    noises = noises[:1]

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
if False:
    model = Model.from_config(epochs=20).to(device)
    train, test = MNIST(train=True, device=device), MNIST(train=False, device=device)
    metrics = model.fit(train, test, RandomGaussianNoise(std=0.4))

    vals, vecs = model.decompose()
    px.imshow(vecs[0, -1].view(28, 28).cpu(), color_continuous_midpoint=0, color_continuous_scale="RdBu")

# Run

### Main setup

#### Init

In [3]:
#@title setup
import torch
import numpy as np
from torch import nn, Tensor
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader

from transformers import PretrainedConfig, PreTrainedModel
from jaxtyping import Float
from tqdm import tqdm
from pandas import DataFrame
from einops import *

from decomp.components import Linear
from decomp.model import FFNModel, _Config

# Create and train the model. 10 minutes on cpu, 3 mins on gpu
from decomp.datasets import MNIST, CIFAR10, _CIFAR10
from kornia.augmentation import RandomGaussianNoise
from tqdm import tqdm
#model = Model.from_config(epochs=20).to(device)
#metrics = model.fit(train, test, RandomGaussianNoise(std=0.4))
#device='cpu'
#r_train, r_test = MNIST(train=True, device=device), MNIST(train=False, device=device)
cifar_train, cifar_test = _CIFAR10(train=True, device=device), _CIFAR10(train=False, device=device)
mnist_train, mnist_test = MNIST(train=True, device=device), MNIST(train=False, device=device)
# TODO: setup configs to pass through all models.

'''
fast_mode = True
noises = [0.0,0.01,0.02,0.04,0.08,0.16,0.32,0.64,1.28,2.56,5.12,10.24,20.48]
assert len(noises) == 13
num_models = 13
epochs = 20

if fast_mode:
    epochs = 5
    num_models = 1
    noises = noises[:num_models]

'''

mnist_config = _Config(
    lr=1e-3,
    wd=0.5,
    epochs=epochs,
    batch_size=2048,
    d_hidden=256,
    d_input=3072,
    d_output=10,
    bias=True
)

cifar_config = _Config(
    lr=1e-3,
    wd=0.5,
    epochs=epochs,
    batch_size=2048,
    d_hidden=256,
    d_input=784,
    d_output=10,
    bias=True
)


def e2e_model_setup(data = 'mnist', noise_sweep = False):
    if data == 'mnist':
        r_train, r_test = MNIST(train=True, device=device), MNIST(train=False, device=device)
        d_input = 784
    elif data == 'cifar':
        r_train, r_test = _CIFAR10(train=True, device=device), _CIFAR10(train=False, device=device)
        d_input = 3072
    
    if noise_sweep:
        models = [FFNModel.from_config(
            lr=1e-3,
            wd=0.5,
            epochs=epochs,
            batch_size=2048,
            d_hidden=256, # for cifar
            d_input=d_input, # for cifar
            d_output=10,
            bias=True
        ).to(device) for i in range(len(noises))]
        histories = []
        for i in tqdm(range(num_models)):
            histories.append(
                models[i].fit(cifar_train, cifar_test,
                            RandomGaussianNoise(std=noises[i]), disable=False))
    # don't bother with this
    pass




models = [FFNModel.from_config(
    lr=1e-3,
    wd=0.5,
    #epochs=20,
    epochs=epochs,
    batch_size=2048,
    d_hidden=256,
    #d_hidden=512, # for cifar
    #d_input=784,
    d_input=d_input, # for cifar
    d_output=10,
    bias=True
).to(device) for i in range(num_models)]


#models = [FFNModel.from_config_obj(cifar_config)]
histories = []
train = cifar_train if mode == 'cifar' else mnist_train
test = cifar_test if mode == 'cifar' else mnist_test

#print(train.device)
for i in tqdm(range(num_models)):
    histories.append(
        models[i].fit(train, test,
                      RandomGaussianNoise(std=noises[i]), disable=False))

relu_valaccs = [histories[i]['val/acc'][epochs-1].item() for i in range(num_models)]
print(relu_valaccs)

Files already downloaded and verified
Files already downloaded and verified


train/loss: 0.223, train/acc: 0.939, val/loss: 0.223, val/acc: 0.937: 100%|██████████| 10/10 [00:08<00:00,  1.22it/s]
100%|██████████| 1/1 [00:08<00:00,  8.21s/it]

[0.9373999834060669]





#### Else

In [4]:
# for session 2
# law of total covariance. Need: X: (60000, 784), f(X)=models[i](X), Y: (60000) each in [0...9]. Loop over Z in range(10)
# Filter by Y==i
if False:
    # accuracy = lambda y_hat, y: (torch.argmax(y_hat, dim=-1) == y).float().mean()
    import numpy as np

    # Sample data
    # centered_x would be your input features, shape: (60000, 784)
    # centered_y would contain the class labels, shape: (60000, 1) or (60000,)
    # For example, let's assume centered_y has shape (60000,)
    x = r_test.x.cpu().flatten(start_dim=1).numpy()
    centered_x = x - x.mean(dim=0, keepdims=True)

    fx = models[0](x).numpy()
    centered_fx = fx - fx.mean(dim=0, keepdims=True)

    #centered_nfx = centered_fx.numpy()
    # Flatten centered_y if necessary
    #if centered_y.ndim > 1:
    #    centered_y = centered_y.flatten()

    # Step 1: Get unique class labels
    class_labels = np.unique(centered_fx)

    # Step 2: Calculate covariance for each class
    covariances = []
    for label in class_labels:
        print(label)
        # Step 3: Get indices of current class
        indices = np.where(centered_fx == label)[0]
        
        # Subset of X corresponding to the current class
        X_subset = centered_x[indices]
        
        # If there are fewer than 2 samples, skip the computation for this class
        if X_subset.shape[0] < 2:
            continue
        
        # Compute covariance for this subset
        cov_matrix = np.cov(X_subset, rowvar=False)  # Shape will be (784, 784)
        covariances.append(cov_matrix)

    # Step 4: Average the covariances across classes
    # Convert the list of covariance matrices to an array for averaging
    covariances_array = np.array(covariances)

    # We assume here that the covariances array has shape (num_classes, 784, 784)
    # We can compute the mean along the first axis (the classes)
    if covariances_array.shape[0] > 0:
        total_covariance = np.mean(covariances_array, axis=0)

    print(total_covariance.shape)  # This will output (784, 784)

In [5]:
def compute_cov(model, X, labels=None, mode='gaussian'):
    '''
    takes args:
    model : (probably relu, gelu, glu) 1L network
    X     : input dataset
    labels: target class info, relevant only if using gauss_mixture
    
    Computes Cov(X, model(X)) assuming X gaussian or gaussian mixture.
    Gaussian mixture assumption uses law of total covariance to compute total covariance
    Note these are cross covariance terms: which just amounts to computing the elementwise
    AKA computing correlations between every pair of (in_pixel, out_neuron_act) combination
    '''
    if mode == 'gaussian':
        pass
    elif mode == 'gauss_mixture':
        assert labels is not None, ValueError('Need to provide labels for gaussian mixture distribution!')
        assert X.shape[0] == labels.shape[0], ValueError('Batchsize must match for inputs and labels!')
        
        pass

In [6]:
if False:
    #i, j = np.tril_indices(784)
    acov = np.tril(cov)
    #acov.shape
    eps = 1e-10
    acov_reg = acov + eps * np.eye(acov.shape[0])

### Filling out `all_noise_data` dict
Done!!

In [7]:
# todo: implement validation accuracy for all models into dict
# todo: show all numbers in dataframe

# add other keys as constructed
data = {'models': models, 'noises': noises, 'test_acc': relu_valaccs}

In [8]:
# test set data. Should be shape (10000, 784)
# expand into outer: 10000, 307720 ~ 3B
#r_test = cifar_test
#r_train = cifar_train
test_x = test.x.flatten(start_dim=1)#.numpy()
test_y = test.y#.cpu().numpy()
print(test_x.shape, test_y.shape)
#quad_out = quad(test_x)

# getting stuck on reshaping gamma. 

torch.Size([10000, 784]) torch.Size([10000])


In [9]:
# take first model, 0 noise for now
from torch_polyapprox.ols import ols

def test_approx_model(model, mu=None, sig=None, order='linear', debug_mode='False'):
    #assert type(). y_hat.argmax(dim=-1) works for Torch, not for np.ndarray
    accuracy = lambda y_hat, y: (y_hat.argmax(dim=-1) == y).float().mean()
    _device = model.device
    W1 = model.w_e.detach()
    W2 = model.w_u.detach()
    b1 = model.embed.bias.detach()#.cpu().data.numpy()
    b2 = model.head.bias.detach()#.cpu().data.numpy()

    # FVU computation not implemented for non-zero means
    # Got the cov symmetric bug. Running with None for now
    print(f'5.7 precheck, mu type {type(mu)}, sig type {type(sig)}')
    approx = ols(W1, b1, W2, b2,
                act='relu',
                mean=mu,
                cov=sig,
                order=order,
                debug_mode=debug_mode)
    print(f'type(test_x): {type(test_x)}, type(test_y): {type(test_y)}')
    assert type(test_x) == type(test_y), ValueError('Must have matching types!')
    print(f'5.7 check')
    #torch_test_x = torch.Tensor(test_x)
    #torch_test_y = torch.Tensor(test_y)
    #print(torch_test_x.shape, torch_test_y.shape)
    # Current error. accr
    print('5.8 check')
    #fwd = torch.Tensor(approx(test_x))
    fwd = approx(test_x)
    print('5.8a')
    acc = accuracy(fwd, test_y).item()
    print('5.8b finish')
    print('5.9 check')
    return acc, approx

def test_approx_list(model_list, mu=None, sig=None, order='linear', debug_mode=False):
    '''expects
    model_list: list of models (presumably over noise)
    mu: mean of data, None -> 0-centered.
    cov: cov of data, None -> identity
    '''
    accs = []
    approximations = []
    for i in tqdm(range(len(model_list))):
        acc, approx = test_approx_model(model_list[i], mu=mu, sig=sig, order=order, debug_mode=debug_mode)
        accs.append(acc)
        approximations.append(approx)
    return accs, approximations

In [10]:

# compute norms, warning likely incorrect. TODO: redo with ols.py. Takes 3 mins
from extra.mc import compute_E_xf, monte_carlo_E_xf

def compute_norms_mc_old(models_list, N_samples=10_000):
    norms_dict = {'L1': [], 'L2': [], 'Linf': []}
    for i in tqdm(range(len(models_list))):
        W1 = models[i].w_e.cpu().numpy()
        W2 = models[i].w_u.cpu().numpy()
        b1 = models[i].embed.bias.cpu().data.numpy()
        b2 = models[i].head.bias.cpu().data.numpy()
        
        exact = compute_E_xf(W1, W2, b1)
        est = monte_carlo_E_xf(W1, W2, b1, N_samples=N_samples)
        
        diff = exact - est
        norms_dict['L1'].append(float(np.linalg.norm(diff, ord=1)))
        norms_dict['L2'].append(float(np.linalg.norm(diff)))
        norms_dict['Linf'].append(float(np.linalg.norm(diff, ord=np.inf)))
    
add_norms = False
data['norms'] = {}
if add_norms:
    data['norms'] = compute_norms_mc_old(models)

In [11]:
print(mode)


mnist


In [12]:
#takes 9s for lin + quad on MNIST. takes ~2-3 minutes for lin+quad on CIFAR, per noise level.
# One last test, Quad on CIFAR with corrected shapes, at noise0.
train_data = train.x.flatten(start_dim=1)
mu = train_data.mean(dim=0, keepdims=True)
#mu = np.mean(train_data, axis=0)
centralized_data = train_data - mu
#cov = np.cov(centralized_data, rowvar=False)
cov = centralized_data.T @ centralized_data / (train_data.size(0) - 1)
#cov.shape, centralized_data.shape
#data['acc_lin01'], data['ols_lin01'] = test_approx_list(models, debug_mode=False) # type error dim.
data['acc_quad01'], data['ols_quad01'] = test_approx_list(models, order='quadratic', debug_mode=True) # type error dim.
# to debug: singular Cov matrix. Need to do a tril_indices. But where was it before?
#data['acc_linmc'], data['ols_linmc'] = test_approx_list(models, mu, cov, debug_mode=True)

  0%|          | 0/1 [00:00<?, ?it/s]

5.7 precheck, mu type <class 'NoneType'>, sig type <class 'NoneType'>
0. First handle linear case.
According to Nora, linear and quad should be split into separate cases, so only trust up to 4.
1. No cov provided, assuming identity. relevant shapes: W1 torch.Size([256, 784]), W1^T torch.Size([784, 256])
Computing preact_cov = W1 @ Id @ W1.T, shape ftorch.Size([256, 256])
Computing cross_cov = Id @ W1.T, shape ftorch.Size([784, 256])
2. Preactivation mean (from b1): torch.Size([256]), variance: torch.Size([256]), std: torch.Size([256])
3. Applying Stein's lemma to compute the cross-covariance of the input. Uses preact mean & std. Stores in output_cross_cov
Stein's lemma says that E[g(X)X^n] can be computed as a linear combination of E[g^(k)(X)] terms, kth derivatives
Important to remember!! n is rarely larger than 2, IMO Nora overly generalized this. We will nonetheless roll with it.
preact_std torch.Size([256])
<class 'torch.Tensor'> <class 'torch.Tensor'>
torch.Size([256])
torch.Size(

  Cov_x = torch.Tensor([ # prev np. Need torch.Tensor rather than np.array. To fix.
  ]).to(device).T


Cov_x torch.Size([307720, 2, 2]), device cpu
Creating cross covariance matrix. xcov torch.Size([784, 256]), xcov[rows] torch.Size([307720, 256]), xcov[cols] torch.Size([307720, 256])
Attempting to produce matrix Mean_x: 
 Mean_x = [mu[rows] torch.Size([307720]), mu[cols] torch.Size([307720])].T
5.3 check
Attempting master_theorem. On device cpu. Takes args:
 Mean_x torch.Size([307720, 2]), Cov_x torch.Size([307720, 2, 2]), mean_y[..., None] torch.Size([256, 1]), XCov torch.Size([256, 307720, 2])
5.4 check
5.5 check
5.6 check, rows <class 'torch.Tensor'>
type(test_x): <class 'torch.Tensor'>, type(test_y): <class 'torch.Tensor'>
5.7 check
5.8 check
torch.Size([10000, 307720]) torch.Size([10, 307720])


100%|██████████| 1/1 [00:06<00:00,  6.92s/it]

5.8a
5.8b finish
5.9 check





In [13]:
mu.device, cov.device

(device(type='cpu'), device(type='cpu'))

In [14]:
data.keys()

dict_keys(['models', 'noises', 'test_acc', 'norms', 'acc_quad01', 'ols_quad01'])

In [15]:
data.keys()

dict_keys(['models', 'noises', 'test_acc', 'norms', 'acc_quad01', 'ols_quad01'])

In [16]:
#len(data['noises']), len(data['test_acc']), len(data['acc_lin01']), len(data['acc_linmc']), len(data['models'])
len(data['noises']), len(data['test_acc']), len(data['models']), len(data['acc_linmc'])

KeyError: 'acc_linmc'

In [None]:
import pandas as pd
df = pd.DataFrame(data, columns = ['noises', 'test_acc', 'acc_lin01', 'acc_quad01'])
df['diff_lin01'] = df['test_acc'] - df['acc_lin01']
df['diff_quad01'] = df['test_acc'] - df['acc_quad01']

df

# interesting. On CIFAR, quad01 acc is 15.93%, lin01 acc is 19.72%, test_acc is 37.33%. This is for 5 epochs
# Try again when trained on ~40 epochs, so that model is possibly 'more quadratic'. Could be
# 'statistics learned' is not high enough order yet. After N(0,std) sweep though.
# add column for lin/test gap, quad/test gap.

#### Add std N(0,std) sweep to linear, don't run quad sweep.

In [None]:
print(models[0].w_e.device)

In [None]:
if mode == 'cifar':
    data[f'acc_quad_N(0,1)'], data[f'ols_quad_N(0,1)'] = test_approx_list(models, order='quadratic', debug_mode=False)
else:
    print(f'mode needs to be mnist!! Mode: {mode}')

In [None]:
if mode == 'cifar':
    for z in noises:
        data[f'acc_lin_N(0,{z:.2f})'], data[f'ols_lin_N(0,{z:.2f})'] = test_approx_list(models,
        mu=mu*z, sig=z*torch.eye(d_input).to(device), debug_mode=False)
else:
    print(f'mode needs to be mnist!! Mode: {mode}')

In [None]:
data.keys()

In [None]:
keys_to_include = ['noises', 'test_acc', 'acc_lin01', 'acc_linmc'] + [f'acc_lin_N(0,{z:.2f})' for z in noises]

# Create DataFrame, excluding specific keys
df = pd.DataFrame({k: v for k, v in data.items() if k in keys_to_include})
df

### To do:
- add gaussian mixture fit to above table (self contained code)

In [None]:
def compute_cov(model, X, labels=None, mode='gaussian'):
    '''
    takes args:
    model : (probably relu, gelu, glu) 1L network
    X     : input dataset e.g. r_test.x.cpu().flatten(start_dim=1)
    labels: target class info, relevant only if using gauss_mixture, e.g. r_test.y.cpu().numpy()
    
    Computes Cov(X, model(X)) assuming X gaussian or gaussian mixture.
    Gaussian mixture assumption uses law of total covariance to compute total covariance
    Note these are cross covariance terms: which just amounts to computing the elementwise
    AKA computing correlations between every pair of (in_pixel, out_neuron_act) combination
    '''
    if mode == 'gaussian':
        pass
    elif mode == 'gauss_mixture':
        assert labels is not None, ValueError('Need to provide labels for gaussian mixture distribution!')
        assert X.shape[0] == labels.shape[0], ValueError('Batchsize must match for inputs and labels!')
        
        pass

In [None]:
# accuracy = lambda y_hat, y: (torch.argmax(y_hat, dim=-1) == y).float().mean().
# total_cov implemented here. Need to wrap into above method cleanly, principally implemented though. Next sesh: evaluate MNIST, CIFAr difference
# Guess: no difference?
if False:

    # Sample data
    # centered_x would be your input features, shape: (60000, 784)
    # centered_y would contain the class labels, shape: (60000, 1) or (60000,)
    # For example, let's assume centered_y has shape (60000,)
    x = r_test.x.cpu().flatten(start_dim=1)
    centered_x = (x - x.mean(dim=0, keepdims=True)).numpy()
    centered_x_np = x.numpy() - np.mean(x.numpy(), axis=0, keepdims=True)

    print(np.linalg.norm(centered_x - centered_x_np))
    #assert np.allclose(centered_x, centered_x_np)
    fx = models[0](x).detach()
    centered_fx = (fx - fx.mean(dim=0, keepdims=True)).numpy()

    class_labels = np.unique(r_test.y.cpu().numpy())
    #centered_nfx = centered_fx.numpy()
    # Flatten centered_y if necessary
    #if centered_y.ndim > 1:
    #    centered_y = centered_y.flatten()

    # Step 1: Get unique class labels
    #class_labels = np.unique(centered_fx) # wrong one, these are the outputs. We need to get labels from r_test.y

    # Step 2: Calculate covariance for each class
    covariances = []
    for label in class_labels:
        #print(label)
        # Step 3: Get indices of current class
        indices = np.where(centered_fx == label)[0]
        
        # Subset of X corresponding to the current class
        X_subset = centered_x[indices]
        
        # If there are fewer than 2 samples, skip the computation for this class
        if X_subset.shape[0] < 2:
            continue
        
        # Compute covariance for this subset
        cov_matrix = np.cov(X_subset, rowvar=False)  # Shape will be (784, 784)
        covariances.append(cov_matrix)

    # Step 4: Average the covariances across classes
    # Convert the list of covariance matrices to an array for averaging
    covariances_array = np.array(covariances)

    # We assume here that the covariances array has shape (num_classes, 784, 784)
    # We can compute the mean along the first axis (the classes)
    if covariances_array.shape[0] > 0:
        total_covariance = np.mean(covariances_array, axis=0)

    print(total_covariance.shape)  # This will output (784, 784)

### To do (NEW, tomorrow)
Don't run this cell rn
- Train cifar bilinear
- decompose with old methods using tdooms code
- visualize singular vectors of linear approx. Interpret them for MNIST
- SVD adv mask on MNIST, CIFAR. Figure out how to do it
- Implement `ols.py` for GLUs

In [None]:
# train cifar bilinear
import plotly.express as px
from decomp.model import Model
from decomp.datasets import _CIFAR10
from einops import einsum
from decomp.plotting import plot_explanation, plot_eigenspectrum

model = Model.from_config(
    lr=1e-3,
    wd=0.5,
    #epochs=20,
    epochs=25,
    batch_size=2048,
    #d_hidden=256,
    d_hidden=512, # for cifar
    #d_input=784,
    d_input=3072, # for cifar
    d_output=10,
    bias=False
).to(device)

#train, test = r_train, r_test = _CIFAR10...

metrics = model.fit(r_train, r_test)
l, r = model.w_lr[0].unbind()
b = einsum(model.w_u, l, r, "cls out, out in1, out in2 -> cls in1 in2")

# Symmetrize the tensor
b = 0.5 * (b + b.mT)

# Perform the eigendecomposition
vals, vecs = torch.linalg.eigh(b)

# Project the eigenvectors back to the input space
vecs = einsum(vecs, model.w_e, "cls emb comp, emb inp -> cls comp inp")

# Take the class (cls) for digit 0 and the last component (comp), which indicates the most positive eigenvalue
px.imshow(vecs[0, -1].view(96, 32).cpu(), color_continuous_midpoint=0, color_continuous_scale="RdBu")

In [None]:
model_reg = Model.from_config(
    lr=1e-3,
    wd=0.5,
    #epochs=20,
    epochs=25,
    batch_size=2048,
    #d_hidden=256,
    d_hidden=512, # for cifar
    #d_input=784,
    d_input=3072, # for cifar
    d_output=10,
    bias=False
).to(device)

metrics = model_reg.fit(r_train, r_test, RandomGaussianNoise(std=0.4))

vals, vecs = model.decompose()
px.imshow(vecs[0, -1].view(96, 32).cpu(), color_continuous_midpoint=0, color_continuous_scale="RdBu")

In [None]:
# understanding SVD
vals, vecs = model_reg.decompose()
px.imshow(vecs[0, -1].view(96, 32).cpu(), color_continuous_midpoint=0, color_continuous_scale="RdBu")


In [None]:
vecs.shape

In [None]:
rgb_image

In [None]:
# open problem for later: plot RGB separately along a nice color scheme that indicates the positive/negative parts well.
import matplotlib.pyplot as plt
# display image from first class in rgb. How?
reshaped_image = vecs[0,-1].reshape(3, 32, 32)

# Step 2: Transpose to (32, 32, 3)
# This will put the RGB channels in the last dimension
rgb_image = np.transpose(reshaped_image, (1, 2, 0)).numpy()
print(type(rgb_image))
# Step 3: Display the image
#plt.imshow(np.transpose(reshaped_(2, 0, 1))  # Ensure it's in the right channel order for plotting
fig = px.imshow(rgb_image.astype(np.float64), color_continuous_midpoint=0)  # Ensure the image is in float format for visualization
fig.show()

In [None]:
if True: #repeat code for no reason? See the top cell. Just reference `models` list.
    cifar_models = [FFNModel.from_config(
        lr=1e-3,
        wd=0.5,
        #epochs=20,
        epochs=20,
        batch_size=2048,
        #d_hidden=256,
        d_hidden=512, # for cifar
        #d_input=784,
        d_input=3072, # for cifar
        d_output=10,
        bias=True
    ).to(device) for i in range(13)]

    cifar_histories = []


    for i in tqdm(range(13)):
        cifar_histories.append(
            cifar_models[i].fit(cifar_train, cifar_test,
                        RandomGaussianNoise(std=noises[i]), disable=False))

    relu_cifar_valaccs = [cifar_histories[i]['val/acc'][19].item() for i in range(13)]
    print(relu_cifar_valaccs)

### Understand ols.py, linear case

In [None]:
from polyapprox.ols import ols

print(data.keys())

In [None]:
ex_linmc = data['ols_linmc'][0]
ex_linmc.beta

In [None]:
# train separate MNIST model.
ex_mnist_model = FFNModel.from_config(
    lr=1e-3,
    wd=0.5,
    #epochs=20,
    epochs=20,
    batch_size=2048,
    d_hidden=256,
    #d_hidden=512, # for cifar
    d_input=784,
    #d_input=3072, # for cifar
    d_output=10,
    bias=True
).to(device)


#models = [FFNModel.from_config_obj(cifar_config)]
histories = []




ex_mnist_history = ex_mnist_model.fit(mnist_train, mnist_test,
                    RandomGaussianNoise(std=noises[i]), disable=False)

#relu_valaccs = [histories[i]['val/acc'][19].item() for i in range(num_models)]
#print(relu_valaccs)
#mnist_model = 
print(ex_mnist_history['val/acc'][19].item())

In [None]:
#r_test = cifar_test
#r_train = cifar_train
test_x = mnist_test.x.cpu().flatten(start_dim=1)#.numpy()
test_y = mnist_test.y.cpu()#.numpy()
print(test_x.shape, test_y.shape)

ex_fit_acc, ex_lin = test_approx_model(ex_mnist_model)
ex_fit_qacc, ex_quad = test_approx_model(ex_mnist_model, order='quadratic')

In [None]:
import plotly.express as px
beta = ex_lin.beta
beta.shape

u, s, v = torch.svd(beta)
nu, ns, nv = np.linalg.svd(beta.numpy())
print(u.shape, s.shape, v.shape)
print(nu.shape, ns.shape, nv.shape)

#from kornia.augmentation import RandomGaussianNoise
#model = Model.from_config(epochs=20).to(device)
#metrics = model.fit(train, test, RandomGaussianNoise(std=0.4))

#vals, vecs = model.decompose() # vecs shape [10, 512, 784]. Using this how?
#px.imshow(vecs[0, -1].view(28, 28).cpu(), color_continuous_midpoint=0, color_continuous_scale="RdBu")

px.imshow(u[:,4].cpu().numpy().reshape((28,28)).T, color_continuous_midpoint=0, color_continuous_scale="RdBu")

In [None]:
print(s)
total_sum = np.sum(s)

pu = 100.0 * (total_sum - np.cumsum(s))/total_sum
pe = 100.0 * np.cumsum(s)/total_sum
pu2 = 100.0 * (np.sum(s**2) - np.cumsum(s**2))/np.sum(s**2)
pe2 = 100.0 * np.cumsum(s**2)/np.sum(s**2)
np.set_printoptions(precision=2, suppress=True)
print(pu)
print(pe)
print(pu2)
print(pe2)
#print(percentage_explained)

In [None]:
diag_s = np.diag(s)

adjusted_vh = diag_s @ v

component = 0
print(adjusted_vh[:,-1])

In [None]:
adjusted_v = s[..., None] * v
print(adjusted_v[:0])

In [None]:
# build off of bilinear viz code
from kornia.augmentation import RandomGaussianNoise
from decomp.model import Model
bl_mnist_model = Model.from_config(epochs=20).to(device)
bl_cifar_model = Model.from_config(epochs=20, d_input=3072).to(device)
mnist_metrics = bl_mnist_model.fit(mnist_train, mnist_test, RandomGaussianNoise(std=0.4))
cifar_metrics = bl_cifar_model.fit(cifar_train, cifar_test, RandomGaussianNoise(std=0.4))

#mnist_vals, mnist_vecs = bl_mnist_model.decompose()
#cifar_vals, cifar_vecs = bl_cifar_model.decompose()
#px.imshow(mnist_vecs[0, -1].view(28, 28).cpu(), color_continuous_midpoint=0, color_continuous_scale="RdBu")
#px.imshow(cifar_vecs[0, -1].view(96, 32).cpu(), color_continuous_midpoint=0, color_continuous_scale="RdBu")

In [None]:
#px.imshow(cifar_test.x[3].mean(dim=0).cpu(), color_continuous_scale='gray')
def view_image_idx(idx, version = 'cifar', color=True):
    if version == 'cifar':
        print(f'class = {cifar_test.y[idx]}')
        plt.imshow(cifar_test.x[idx].permute(1, 2, 0).cpu().numpy())
    elif version == 'mnist':
        print(f'class = {mnist_test.y[idx]}')
        px.imshow(mnist_test.x[idx].view(28, 28).cpu().numpy(), color_continuous_midpoint=0, color_continuous_scale="RdBu")
    else:
        pass

In [None]:
#view_image_idx(3)

### Quadratic case
Note that Master theorem only comes up in the quadratic case.

In [None]:
#ex_fit_acc, ex_lin = test_approx_model(ex_mnist_model)
acc, new_quad = test_approx_model(ex_mnist_model, order = 'quadratic', debug_mode=True) # 5.3 takes 5s, 5.6 takes 15s
'''8s gets to 5.8 check, (10000, 307720), (10, 307720).
42s: runs olsResult(x)
after 20s, prints shapes (10000, 307720), (10, 307720). Finishes after 20 more s. (40s total)

# This is for 784x10 mnist relu. For 3072x10 cifar relu, it took maybe 40 minutes total? Good to know.
# May be good to check random nets scaling laws.

With new 'more efficient' mult code, takes 2 minutes to perform

y += np.einsum('ijh,bi,bj ->bh', full_mat, x, x)

with full_mat (784, 784, 10) and x: (10000, 784). Weird. 
# Checking the code I saw print(torch_test_x.shape, torch_test_y.shape) ~ this means producing the quad approx itself took 20s. Evaluating took another 20s
# Interesting
#    acc = accuracy(approx(torch_test_x), torch_test_y).item()
'''

In [None]:
# OK interesting. its taking in torch tensor, but then call runs on np arrays. So its... converting?
accuracy = lambda y_hat, y: (y_hat.argmax(dim=-1) == y).float().mean()
torch_test_x = torch.Tensor(test_x)
torch_test_y = torch.Tensor(test_y)
print(torch_test_x.shape, torch_test_y.shape)
# Current error. accr
print('5.8 check')
acc = accuracy(approx(torch_test_x), torch_test_y).item()

In [None]:
new_quad.gamma.shape

In [None]:
def gamma_to_B(gamma_mat):
    gamma_entries = gamma_mat.shape[-1]
    #print(gamma_entries)
    row_dim = int(np.floor(np.sqrt(2*gamma_entries)))
    full_mat = np.zeros((row_dim, row_dim, gamma_mat.shape[0]))
    tril_indices = np.tril_indices(row_dim)
    
    full_mat[tril_indices] = gamma_mat.T
    full_mat = 0.5 * (full_mat + full_mat.transpose(1, 0, 2))
    return full_mat

def test_outer(x, gamma_mat):
    # expects x.shape (batch, 784)
    # expects gamma_mat.shape (10, 307720)
    if x.ndim == 1:
        x = np.expand_dims(x, axis=0)
    outer = np.einsum('bj,bk->bjk', x, x)
    #this is probably hugely inefficient. Check x shape [10000, 3072]
    #Becomes shape [10000, 3072, 3072] ~ 30B. ~ 120GB ram. Why though?
    rows, cols = np.tril_indices(x.shape[-1])
    print(outer[:, rows, cols].shape, gamma_mat.shape)
    return outer[:, rows, cols] @ gamma_mat.T

def test_inner(x, gamma_mat):
    # full_mat shape: in1, in2 h:
    if x.ndim == 1:
        x = np.expand_dims(x, axis=0)
    
    full_mat = gamma_to_B(gamma_mat)
    print(full_mat.shape, x.shape)
    prod = np.einsum('ijh,bi,bj ->bh', full_mat, x, x)
    #print(full_mat.shape, x.shape)
    return prod

assert np.allclose(test_inner(sample_x, new_quad.gamma), test_outer(sample_x, new_quad.gamma))
#print(z, z.shape)

Goals:
- Reshape new_quad.gamma into (784, 784, 10) B tensor
- assert unittest: `torch.allclose(outer[:, rows, cols] @ self.gamma.T, einsum('ijk, ...ij, ...ik -> ...i', B, x, x))`

In [None]:
'''

To make: tensor version of this
    def __call__(self, x: NDArray) -> NDArray:
        """Evaluate the linear model at the given inputs."""
        y = x @ self.beta + self.alpha

        if self.gamma is not None:
            outer = np.einsum('ij,ik->ijk', x, x)
            this is probably hugely inefficient. Check x shape [10000, 3072]
            Becomes shape [10000, 3072, 3072] ~ 30B. ~ 120GB ram. Why though?
            rows, cols = np.tril_indices(x.shape[1])
            print(outer[:, rows, cols].shape, self.gamma.shape)
            y += outer[:, rows, cols] @ self.gamma.T

        return y
'''

In [None]:
quad_acc = acc
print(f'relu acc: {ex_mnist_history['val/acc'][19].item()}, linear01 acc: {ex_fit_acc}, quad01 acc: {quad_acc}')