In [25]:
import math
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import itertools
import einops
from collections import defaultdict
import copy
from einops import einsum
from kornia.augmentation import RandomGaussianNoise

from decomp.model import Model, FFNModel, _Config
from decomp.datasets import MNIST, FMNIST, _CIFAR10
from decomp.plotting import plot_explanation, plot_eigenspectrum

from torch_polyapprox.ols import ols

In [74]:
# init configs
device = 'cpu'
noise_std = 2.0
epochs = 20
dataset = 'mnist' # or cifar

if dataset == 'mnist':
    d_input = 784
    train, test = MNIST(train=True, device=device), MNIST(train=False, device=device)
elif dataset == 'fmnist':
    d_input = 784
    train, test = FMNIST(train=True, device=device), FMNIST(train=False, device=device)
else: # 'cifar'
    d_input = 3072
    train, test = _CIFAR10(train=True, device=device), _CIFAR10(train=False, device=device)

In [100]:
# bilinear baseline model. Decompose using tdooms/michael methods
bilinear_model = Model.from_config(epochs=20, bias=True).to(device)
bilinear_metrics = bilinear_model.fit(train, test, RandomGaussianNoise(std=noise_std))

train/loss: 1.000, train/acc: 0.726, val/loss: 0.481, val/acc: 0.931: 100%|██████████| 20/20 [00:17<00:00,  1.15it/s]


In [76]:
# relu mnist model to approximate and decompose
relu_model = FFNModel.from_config(
            lr=1e-3,
            wd=0.5,
            epochs=epochs,
            batch_size=2048,
            d_hidden=256, # for cifar
            d_input=d_input, # for cifar
            d_output=10,
            bias=True
).to(device)

relu_metrics = relu_model.fit(train, test, RandomGaussianNoise(std=noise_std))

train/loss: 0.979, train/acc: 0.706, val/loss: 0.442, val/acc: 0.912: 100%|██████████| 20/20 [00:14<00:00,  1.39it/s]


In [77]:
W1 = relu_model.w_e.detach()
W2 = relu_model.w_u.detach()
b1 = relu_model.embed.bias.detach()#.cpu().data.numpy()
b2 = relu_model.head.bias.detach()

print(type(W1), type(W2), type(b1), type(b2))
lin_01 = ols(W1, b1, W2, b2, act='relu')
quad01 = ols(W1, b1, W2, b2, act='relu', order='quadratic')

<class 'torch.Tensor'> <class 'torch.Tensor'> <class 'torch.Tensor'> <class 'torch.Tensor'>
<class 'torch.Tensor'> <class 'torch.Tensor'>
torch.Size([256])
torch.Size([256])
<class 'torch.Tensor'> <class 'torch.Tensor'>
torch.Size([256])
torch.Size([256])


In [78]:
print(quad01.gamma.shape)

torch.Size([10, 307720])


In [99]:
def test_model_acc(model, test_set=test):
    accuracy = lambda y_hat, y: (y_hat.argmax(dim=-1) == y).float().mean()
    fwd = model(test_set.x.flatten(start_dim=1))
    return accuracy(fwd, test_set.y).item()

def np_gamma_to_B(gamma_mat):
    gamma_entries = gamma_mat.shape[-1]
    #print(gamma_entries)
    row_dim = int(np.floor(np.sqrt(2*gamma_entries)))
    full_mat = np.zeros((row_dim, row_dim, gamma_mat.shape[0]))
    tril_indices = np.tril_indices(row_dim)
    
    full_mat[tril_indices] = gamma_mat.T
    full_mat = 0.5 * (full_mat + full_mat.transpose(1, 0, 2))
    return full_mat.permute(2, 0, 1)

def gamma_to_B(gamma_mat):
    gamma_entries = gamma_mat.shape[-1]
    print(gamma_entries)
    row_dim = math.floor((2*gamma_entries)**0.5)
    full_mat = torch.zeros((gamma_mat.shape[0], row_dim, row_dim))
    tril_indices = torch.tril_indices(row_dim, row_dim)
    
    full_mat[:, tril_indices[0], tril_indices[1]] = gamma_mat
    full_mat = 0.5 * (full_mat + full_mat.mT)
    return full_mat

B = gamma_to_B(quad01.gamma)

qvals, qvecs = torch.linalg.eigh(B)
px.imshow(qvecs[0, -1].view(-1,28).cpu(), color_continuous_midpoint=0, color_continuous_scale="RdBu")

307720


In [102]:
bilinear_model

Model(
  (embed): Linear(
    in_features=784, out_features=256, bias=False
    (gate): Identity()
  )
  (blocks): ModuleList(
    (0): Bilinear(
      in_features=256, out_features=512, bias=True
      (gate): Identity()
    )
  )
  (head): Linear(
    in_features=256, out_features=10, bias=False
    (gate): Identity()
  )
  (criterion): CrossEntropyLoss()
)

In [101]:
bvals, bvecs = bilinear_model.decompose()
px.imshow(bvecs[0, -1].view(-1, 28).cpu(), color_continuous_midpoint=0, color_continuous_scale="RdBu")

In [80]:
def test_inner(x, quad=quad01):
    if x.ndim == 1:
        x = np.expand_dims(x, axis=0)
    
    full_mat = gamma_to_B(quad.gamma)
    print(full_mat.shape, x.shape)
    prod = torch.einsum('hij,bi,bj->bh', full_mat, x, x)
    return prod

def test_outer(x, quad=quad01):
    outer = torch.einsum('ij,ik->ijk', x, x)
    rows, cols = torch.tril_indices(x.shape[1], x.shape[1])
    print(outer[:, rows, cols].shape, quad.gamma.shape)
    return outer[:, rows, cols] @ quad.gamma.T

In [81]:
x = test.x.flatten(start_dim=1)[:100,:]
torch.allclose(test_inner(x), test_outer(x), atol=1e-5)

307720
torch.Size([10, 784, 784]) torch.Size([100, 784])
torch.Size([100, 307720]) torch.Size([10, 307720])


True

In [None]:
lin_01_acc = test_model_acc(lin_01)
quad01_acc = test_model_acc(quad01)
relu_acc = test_model_acc(relu_model)
print(f'{relu_acc:.4f}, {lin_01_acc:.4f}, {quad01_acc:.4f}')

torch.Size([10000, 307720]) torch.Size([10, 307720])
0.9537, 0.8834, 0.9490


In [None]:
def decompose(self):
    """The function to decompose a single-layer model into eigenvalues and eigenvectors."""
    
    # Split the bilinear layer into the left and right components
    l, r = self.w_lr[0].unbind()
    
    # Compute the third-order (bilinear) tensor
    b = einsum(self.w_u, l, r, "cls out, out in1, out in2 -> cls in1 in2")
    
    # Symmetrize the tensor
    b = 0.5 * (b + b.mT)

    # Perform the eigendecomposition
    vals, vecs = torch.linalg.eigh(b)
    
    # Project the eigenvectors back to the input space
    vecs = einsum(vecs, self.w_e, "cls emb comp, emb inp -> cls comp inp")
    
    # Return the eigenvalues and eigenvectors
    return vals, vecs

In [15]:
a, b, c = quad01.alpha, quad01.beta, quad01.gamma
print(a.shape, b.shape, c.shape)

torch.Size([10]) torch.Size([784, 10]) torch.Size([10, 307720])
