Imports:

In [None]:
import copy
from collections import OrderedDict
from dataclasses import dataclass
from itertools import permutations
from typing import Union

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from scipy.optimize import curve_fit
from torch import nn, optim
from torch.optim import Optimizer
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

In [None]:
# Prefer a CUDA GPU whenever PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

Helper functions/classes:

In [None]:
def Id(x):
    """Return ``x`` unchanged.

    Serves as the identity group element and as the default (no-op) ``postprocess_func``
    of ``SymmDataGenerator``. Defined with ``def`` rather than a named lambda (PEP 8 E731).
    """
    return x


@dataclass
class SymmDataGenerator:
    """Sample labelled data intended to be symmetric under a finite group of transformations.

    Base points are drawn uniformly from ``domain`` (currently only the hypercube
    ``[-1, 1)``), labelled with ``labeller`` after ``postprocess_func`` has been applied, and -
    if ``symmetrise_expl`` is set - replicated once per element of ``group_elems`` so that whole
    orbits appear in the dataset, each copy carrying the base point's label.
    """
    n_features: int
    # callables applied to a (n_samples, n_base_features) array of base points, e.g. Id
    group_elems: list
    # maps postprocessed data to the labels of its rows
    labeller: callable
    domain: str = 'hypercube'
    # std of additive Gaussian label noise; None means no noise unless passed explicitly
    noise: Union[float, None] = None
    # if True, labels are thresholded (labels > cutoff) into booleans
    classification: bool = False
    cutoff: float = 1.
    # whether to explicitly input all the elements in an orbit or have the symmetry be in the data only on average
    symmetrise_expl: bool = True
    # applied to base data before it is labelled and before it is returned
    postprocess_func: callable = Id
    # required in cases where the postprocessing adds/removes features
    n_base_features: Union[int, None] = None

    def __call__(self, n_samples: int, *args, **kwargs) -> tuple:
        """Return a ``(data, labels)`` dataset of (at most) ``n_samples`` rows.

        With ``symmetrise_expl`` the number of base points is ``n_samples // len(group_elems)``
        and every base point appears once per group element. In both modes the returned data is
        postprocessed, labels are 2-D, and ``noise`` / the classification cutoff are applied.
        Extra positional/keyword arguments are accepted and ignored.
        """
        n_base_samples = n_samples // (len(self.group_elems) if self.symmetrise_expl else 1)

        # need to be finicky about what gets labelled when because the postprocessed data might not have a group
        # representation - the symmetrisation works on the base data
        base_data = self.gen_asymm_data(n_base_samples, postprocess=False)

        processed = self.postprocess_func(base_data)
        labels = self.labeller(processed)

        if self.symmetrise_expl:
            return self.symmetrise(base_data, labels)

        # Return the postprocessed features: they are what the labels were computed from and
        # what ``n_features`` describes. Noise/cutoff are applied exactly as in ``symmetrise``.
        noise = self.noise if self.noise is not None else 0.
        return processed, self._finalise_labels(labels, noise)

    def gen_asymm_data(self, n_samples: int, postprocess: bool = True) -> np.ndarray:
        """Draw ``n_samples`` base points without any symmetrisation.

        Points have ``n_base_features`` columns (``n_features`` when that is unset) with entries
        uniform on ``[-1, 1)``; ``postprocess_func`` is applied when ``postprocess`` is True.

        Raises:
            NotImplementedError: for any ``domain`` other than ``'hypercube'``.
        """
        n_features = self.n_base_features if self.n_base_features is not None else self.n_features
        if self.domain == 'hypercube':
            data = 2 * (np.random.rand(n_samples, n_features) - 0.5)

            return data if not postprocess else self.postprocess_func(data)

        raise NotImplementedError('Still need to implement more general domains!')

    def symmetrise(self, base_data: np.ndarray, labels: Union[np.ndarray, None] = None, noise: float = 0.,
                   postprocess: bool = True) -> Union[tuple, np.ndarray]:
        """Apply every group element to ``base_data`` and stack the images.

        Rows are ordered ``[g_0(base_data); g_1(base_data); ...]``, i.e. row
        ``k * len(base_data) + i`` is the image of base point ``i`` under ``group_elems[k]``.

        Args:
            base_data: un-postprocessed base points (the group acts on these).
            labels: optional labels of ``base_data``; tiled once per group element if given.
            noise: label-noise std. 0 falls back to ``self.noise`` when that is set, so an
                explicit 0 cannot switch configured noise off.
            postprocess: whether to apply ``postprocess_func`` to the stacked data.

        Returns:
            ``data`` if ``labels`` is None, else ``(data, labels)`` with 1-D labels reshaped to a
            column (booleans when ``classification`` is set).
        """
        if self.noise is not None and noise == 0.:
            noise = self.noise
        data = np.concatenate([group_elem(base_data) for group_elem in self.group_elems])
        data = data if not postprocess else self.postprocess_func(data)
        if labels is None:
            return data

        labels = np.concatenate([labels] * len(self.group_elems))
        return data, self._finalise_labels(labels, noise)

    def _finalise_labels(self, labels: np.ndarray, noise: float) -> np.ndarray:
        """Make labels 2-D, add Gaussian noise of std ``noise`` and apply the classification cutoff."""
        labels = np.asarray(labels)
        if labels.ndim == 1:
            labels = labels.reshape(-1, 1)

        # Out-of-place: integer labels are promoted to float instead of raising a casting error.
        labels = labels + np.random.randn(*labels.shape) * noise

        # TODO - make sure this makes sense in noisy cases
        if self.classification:
            labels = labels > self.cutoff

        return labels


def _negate(x):
    """Sign flip ``x -> -x``, the non-identity element of the reflection group."""
    return -x


def _swap_xy(x):
    """Swap the two columns of a ``(n_samples, 2)`` array."""
    return x[:, [1, 0]]


# Toy generators: each pairs a two-element group with a labeller unchanged by that group
# (|-x| = |x|, (-x)**2 = x**2, and x + y = y + x).
AbsXSymmGenerator = SymmDataGenerator(n_features=1, group_elems=[Id, _negate], labeller=abs)
XSquareSymmGenerator = SymmDataGenerator(n_features=1, group_elems=[Id, _negate],
                                         labeller=lambda x: x ** 2)
PermXYSymmGenerator = SymmDataGenerator(n_features=2, group_elems=[Id, _swap_xy],
                                        labeller=lambda x: x.sum(axis=-1))


def poly_kernel(features, order):
    """Expand the first two feature columns into all monomials of total degree <= ``order``.

    With ``x = features[:, 0]`` and ``y = features[:, 1]`` the output columns are ordered by
    total degree ``d`` and, within a degree, by increasing power of ``x``:
    ``1, y, x, y**2, x*y, x**2, y**3, ...``; column ``d*(d+1)//2 + j`` holds ``x**j * y**(d-j)``.
    Any columns beyond the second are ignored.

    Args:
        features: 2-D array with at least two columns.
        order: maximum total degree; must be non-negative.

    Returns:
        float64 array of shape ``(n_samples, (order + 1) * (order + 2) // 2)``.

    Raises:
        ValueError: if ``order`` is negative (the size formula would otherwise give an empty
            or uninitialised output).
    """
    if order < 0:
        raise ValueError(f'order must be non-negative, got {order}')

    n_samples, _ = features.shape
    x, y = features[:, 0], features[:, 1]
    kernelised_features = np.empty((n_samples, (order + 1) * (order + 2) // 2))

    col = 0
    for degree in range(order + 1):
        for x_power in range(degree + 1):
            kernelised_features[:, col] = x ** x_power * y ** (degree - x_power)
            col += 1

    return kernelised_features


def PolyXYSymmGenerator(order):
    """Build a generator on polynomial features of 2-D base points ``(x, y)``.

    The group is {identity, swap x <-> y} acting on the base points. Features are the
    monomials produced by ``poly_kernel(., order)``, and the label of a point is the sum of
    those features, i.e. the sum of all monomials of ``x, y`` up to degree ``order`` (a quantity
    left unchanged by the swap).

    Defined with ``def`` rather than a named lambda (PEP 8 E731); the call signature is the same.
    """
    def _poly_features(features):
        return poly_kernel(features, order)

    return SymmDataGenerator(n_features=(order + 1) * (order + 2) // 2, n_base_features=2,
                             group_elems=[Id, lambda x: x[:, [1, 0]]],
                             labeller=lambda x: x.sum(axis=-1),
                             postprocess_func=_poly_features)


In [None]:
def get_data_loader(data_generator, dataset_size, batch_size, device=None) -> DataLoader:
    """Sample a dataset from ``data_generator`` and wrap it in a shuffling DataLoader.

    Args:
        data_generator: callable returning an ``(X, y)`` pair of array-likes for a requested size.
        dataset_size: number of samples requested from the generator.
        batch_size: DataLoader batch size.
        device: device for the tensors; ``None`` (the default) means the module-level ``DEVICE``.

    Returns:
        A shuffling DataLoader over float32 ``(X, y)`` pairs; ``X`` has every dimension after
        the first flattened into one.
    """
    if device is None:
        device = DEVICE

    X, y = data_generator(dataset_size)

    # Explicit float32 conversion; also accepts boolean labels from classification generators.
    X = torch.as_tensor(X, dtype=torch.float32, device=device)
    X = torch.flatten(X, start_dim=1)
    y = torch.as_tensor(y, dtype=torch.float32, device=device)
    dataset = TensorDataset(X, y)

    return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)


def calc_asymm_loss(model: nn.Module, data_gen: 'SymmDataGenerator', n_points: int = 100) -> float:
    """Mean absolute gap between the model and its group-averaged version on random points.

    For ``n_points`` base points ``x`` from ``data_gen`` this compares ``model(x)`` with
    ``mean_g model(g x)`` over ``data_gen.group_elems`` (averaged in logit space, then passed
    through a sigmoid for classification generators). The result is averaged over points and
    model outputs.

    Inputs are placed on the model's own device (falling back to ``DEVICE`` for a model
    without parameters). Assumes the model and ``postprocess_func`` act on rows independently
    - TODO confirm for models with batch-coupled layers such as BatchNorm in train mode.

    Raises:
        ValueError: if ``n_points`` < 1.
    """
    if n_points < 1:
        raise ValueError(f'n_points must be positive, got {n_points}')

    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else DEVICE
    n_group = len(data_gen.group_elems)

    with torch.no_grad():
        base_x = data_gen.gen_asymm_data(n_points, postprocess=False)
        # symmetrise stacks the orbit images as [g_0(base_x); g_1(base_x); ...]
        symm_x = data_gen.symmetrise(base_x, postprocess=True)

        x = torch.as_tensor(data_gen.postprocess_func(base_x), dtype=torch.float32, device=device)
        symm_x = torch.as_tensor(symm_x, dtype=torch.float32, device=device)

        pred = model(x).reshape(n_points, -1)
        # (n_group, n_points, n_outputs) -> average over the group axis
        avg_symm_pred = model(symm_x).reshape(n_group, n_points, -1).mean(dim=0)

        if data_gen.classification:
            pred = torch.sigmoid(pred)
            avg_symm_pred = torch.sigmoid(avg_symm_pred)

        return (pred - avg_symm_pred).abs().mean().item()


In [None]:
class Perceptron(nn.Module):
    """A single fully connected layer followed by an activation: ``act(layer(x))``."""

    def __init__(self, input_dim, output_dim, activation: Union[nn.Module, None] = None, bias: bool = True):
        """
        Args:
            input_dim: size of the last input dimension.
            output_dim: number of output units.
            activation: module applied to the linear output. ``None`` (the default) creates a
                fresh ``nn.ReLU()`` for this instance instead of sharing one default object
                across every Perceptron.
            bias: whether the linear layer has a bias term.
        """
        super().__init__()

        self.layer = nn.Linear(input_dim, output_dim, bias=bias)
        self.act = activation if activation is not None else nn.ReLU()

    def forward(self, x):
        """Apply the linear layer, then the activation."""
        return self.act(self.layer(x))


class MLP(nn.Module):
    """Multi-layer perceptron: ``n_hidden`` ``Perceptron`` layers followed by a linear readout.

    Inputs are flattened (every dimension after the first) before the first layer.
    """

    def __init__(self, input_size: int, hidden_l_size: int, n_hidden: int, output_size: int,
                 activation: Union[nn.Module, None] = None, bias: bool = True):
        """
        Args:
            input_size: number of (flattened) input features.
            hidden_l_size: width of every hidden layer.
            n_hidden: number of hidden layers; must be at least 1.
            output_size: size of the final linear layer (no activation applied to it).
            activation: activation used after every hidden layer. ``None`` (the default)
                creates a new ``nn.ReLU()`` for this MLP. The same module object is passed to
                all hidden layers, so any parameters it has are shared between them.
            bias: whether every linear layer has a bias term.

        Raises:
            ValueError: if ``n_hidden < 1``.
        """
        super().__init__()
        self.flatten = nn.Flatten()

        # Explicit raise rather than `assert`, so the check also runs under `python -O`.
        if n_hidden < 1:
            raise ValueError('Error - you need some hidden layers for this to be an MLP!')

        if activation is None:
            activation = nn.ReLU()

        layers = [Perceptron(input_size, hidden_l_size, activation=activation, bias=bias)]
        layers += [Perceptron(hidden_l_size, hidden_l_size, activation=activation, bias=bias)
                   for _ in range(n_hidden - 1)]
        layers += [nn.Linear(hidden_l_size, output_size, bias=bias)]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten ``x`` and run it through the layer stack."""
        x = self.flatten(x)
        return self.layers(x)


Model definition:

In [None]:
# Network and dataset sizes used by the cells below.
n_hidden = 1          # number of hidden layers in the MLP
hidden_size = 100     # width of each hidden layer (10 ** 2)
n_samples = 100_000   # samples requested from the data generator (10 ** 5)

In [None]:
# Three input features; the group contains only the identity, and the label is the
# sum of the features.
n_features = 3
symm_data_gen = SymmDataGenerator(
    n_features,
    group_elems=[lambda x: x],
    labeller=lambda x: x.sum(axis=-1),
)


In [None]:
model = MLP(
    input_size=symm_data_gen.n_features,
    hidden_l_size=hidden_size,
    n_hidden=n_hidden,
    output_size=1,
)
# total number of scalar parameters in the model
n_params = sum(param.numel() for param in model.parameters())

Calculations:

In [None]:
# Accumulate sum_i g_i g_i^T over the dataset, where g_i is the gradient of the model's
# single output w.r.t. all of its parameters at sample i.
# The data loader places its tensors on DEVICE while the model may live elsewhere, so each
# batch is moved to the model's device before the forward pass.
param_device = next(model.parameters()).device
ewk = torch.zeros((n_params, n_params), device=param_device)

dataloader = get_data_loader(symm_data_gen, n_samples, batch_size=1)

for X, _ in tqdm(dataloader):
    model.zero_grad()

    out = model(X.to(param_device))
    out.backward()  # batch_size=1 and one output, so `out` has exactly one element

    grads = torch.cat([p.grad.flatten() for p in model.parameters()])

    ewk += torch.outer(grads, grads)

# Keep the result on the CPU so the NumPy-based analysis of it works on any machine.
ewk = ewk.cpu()

In [None]:
# ewk is a sum of outer products g g^T, hence symmetric positive semi-definite.
# `eigh` exploits the symmetry: it is more stable than the general `eig` and returns real
# eigenvalues (in ascending order) instead of complex ones with round-off imaginary parts.
eig = torch.linalg.eigh(ewk)

In [None]:
fs: int = 16  # font size shared by the axis labels and tick labels of the spectrum plot

In [None]:
def plot_loghist(x, bins):
    """Plot a histogram of strictly positive values ``x`` with log-spaced bins on log-log axes.

    ``bins`` is interpreted as by ``np.histogram``. Only the outer edges of the resulting bins
    are kept, and the same number of edges is spaced logarithmically between them.

    Raises:
        ValueError: if ``x`` is empty or contains non-positive values (log bins are undefined).
    """
    values = np.asarray(x)
    if values.size == 0 or np.any(values <= 0):
        raise ValueError('plot_loghist needs a non-empty array of strictly positive values')

    # Only the bin edges are needed, not the counts.
    edges = np.histogram_bin_edges(values, bins=bins)
    logbins = np.logspace(np.log10(edges[0]), np.log10(edges[-1]), len(edges))
    plt.hist(values, bins=logbins, log=True)
    plt.xscale('log')
    plt.yscale('log')


# Histogram of the magnitudes of the non-zero eigenvalues, on log-log axes.
magnitudes = abs(eig.eigenvalues)
plot_loghist(magnitudes[eig.eigenvalues != 0], 30)

plt.ylabel('# of eigenvalues', fontsize=fs)
plt.xlabel('Eigenvalue', fontsize=fs)
for set_tick_font in (plt.xticks, plt.yticks):
    set_tick_font(fontsize=fs)

figure = plt.gcf()
figure.savefig('ewk_eig_spect.pdf', dpi=300, bbox_inches="tight")

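Interpretation: with effective training time $t = \text{epochs} \times \text{learning rate}$, the mode associated with eigenvalue $\lambda$ decays as $e^{-\lambda t}$, so modes with small eigenvalues are learned slowly.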

In [None]:
# Fraction of kernel eigenvalues whose magnitude is below 0.2.
is_small = abs(eig.eigenvalues) < 0.2
is_small.numpy().mean()

In [None]:
sum(map(torch.numel, model.parameters()))