Imports:

In [None]:
import copy
from itertools import permutations

import pandas as pd
import torch
from matplotlib import pyplot as plt

from lr_symm_invar_analysis.perceptron_symm_decomp import train_model
from symm_data_gen import PermXYSymmGenerator, AbsXSymmGenerator, SymmDataGenerator, PolyXYSymmGenerator
from utils import Perceptron


Helper functions:

In [None]:
def train_models(symm_data_gen, lr, n_epochs, n_samples, n_models: int,
                 manually_set_params: bool = False, asymm: bool = False) -> pd.DataFrame:
    """Train ``n_models`` independent linear perceptrons on ``symm_data_gen``.

    Each model is a single-output ``Perceptron`` with identity activation whose
    input size is ``symm_data_gen.n_features``; it is trained by ``train_model``
    and its result frame is tagged with an ``n_model`` column.

    Args:
        symm_data_gen: data generator; its ``n_features`` sets the model input size.
        lr: learning rate forwarded to ``train_model``.
        n_epochs: number of epochs forwarded to ``train_model``.
        n_samples: sample count forwarded to ``train_model``.
        n_models: number of models to train; the model index is also used as the
            varied initial value when ``manually_set_params`` is set.
        manually_set_params: overwrite the initial parameters instead of keeping
            the random initialisation. Only supported for 1-feature generators
            (e.g. the reflection-symmetry |x| setup).
        asymm: with ``manually_set_params``, set weight = model index and bias = 1;
            otherwise (default) set weight = 1 and bias = model index.

    Returns:
        All per-model results concatenated, with an added ``n_model`` column.

    Raises:
        ValueError: if ``manually_set_params`` is requested for a generator with
            ``n_features != 1``; the scalar weight would otherwise be silently
            broadcast to every feature by ``copy_``.
    """
    if manually_set_params and symm_data_gen.n_features != 1:
        raise ValueError(
            'manually_set_params is only supported for 1-feature generators, '
            f'got n_features={symm_data_gen.n_features}')

    results = []
    for n_model in range(n_models):
        model = Perceptron(input_dim=symm_data_gen.n_features, output_dim=1, activation=lambda x: x)

        if manually_set_params:
            # state_dict() tensors share storage with the parameters, so copy_ edits the model in place.
            # Assumes Perceptron's state_dict holds exactly (weight, bias) in that order -- TODO confirm.
            weights, biases = tuple(model.state_dict().values())

            if asymm:
                weights.copy_(torch.tensor([[n_model]], dtype=weights.dtype))
                biases.copy_(torch.tensor([1], dtype=biases.dtype))
            else:
                weights.copy_(torch.tensor([[1]], dtype=weights.dtype))
                biases.copy_(torch.tensor([n_model], dtype=biases.dtype))

        res = train_model(model, symm_data_gen, lr=lr, n_epochs=n_epochs, n_samples=n_samples)
        res['n_model'] = n_model

        results.append(res)

    return pd.concat(results)

In [None]:
def plot_res_asymm_loss(results: pd.DataFrame, fs: int = 16, ax=None, label: str = None):
    """Plot asymmetry loss vs. epoch on a log y-axis, one line per trained model.

    Args:
        results: frame with ``n_model``, ``epoch`` and ``asymm_loss`` columns
            (as produced by ``train_models``).
        fs: font size for the axis labels and tick labels.
        ax: axes to draw on; when None a new figure is created and its axes used.
        label: label attached to every plotted line. The automatic legend is
            removed, so the label only matters if the caller builds a legend later.
    """
    if ax is None:
        plt.figure()
        ax = plt.gca()

    # Iterate the groups directly: plotting is a side effect, and groupby(...).apply
    # on the grouping column is deprecated in recent pandas.
    for _, data in results.groupby('n_model'):
        data.plot(x='epoch', y='asymm_loss', ax=ax, label=label)
    ax.set_xlim(0)
    ax.semilogy()

    legend = ax.get_legend()
    if legend is not None:  # no legend exists when `results` is empty
        legend.remove()
    ax.set_ylabel('Asymmetry loss', fontsize=fs)
    ax.set_xlabel('Epoch', fontsize=fs)
    # Style `ax` itself: plt.xticks/plt.yticks target the *current* axes, which need
    # not be `ax` when the caller passes one in.
    ax.tick_params(axis='both', which='major', labelsize=fs)


In [None]:
def plot_exp(exp: pd.DataFrame, fs: int, figsize: tuple, groupby_col: str = 'n_model', semilogy: bool = False):
    """Plot weight and bias trajectories vs. epoch in two stacked, x-shared panels.

    The top panel shows the ``weights`` column (labelled as the asymmetric
    component); the bottom panel shows the ``biases`` column (labelled as the
    symmetric component). Each value of ``groupby_col`` gets its own line.

    Args:
        exp: frame with ``epoch``, ``weights`` and ``biases`` columns (see
            ``extr_weights_and_biases``) plus ``groupby_col``.
        fs: font size for the axis labels and tick labels.
        figsize: (width, height) of the figure in inches.
        groupby_col: column whose values identify separate runs.
        semilogy: use a log scale on both y-axes.
    """
    fig, axs = plt.subplots(2, 1, sharex=True)
    fig.set_size_inches(*figsize)

    # Iterate groups directly instead of groupby(...).apply, which is meant for
    # computing results, not for plotting side effects.
    for _, data in exp.groupby(groupby_col):
        data.plot(x='epoch', y='weights', ax=axs[0])
        data.plot(x='epoch', y='biases', ax=axs[1])

    for ax in axs:
        legend = ax.get_legend()
        if legend is not None:  # absent when `exp` is empty
            legend.remove()

    # Panels touch vertically; they share the epoch axis.
    fig.subplots_adjust(hspace=0.)

    axs[0].set_ylabel('Asymm. comp.', fontsize=fs)
    axs[1].set_ylabel('Symm. comp.', fontsize=fs)
    axs[1].set_xlabel('Epoch', fontsize=fs)

    for ax in axs:
        ax.set_xlim(0)
        ax.tick_params(axis='both', which='major', labelsize=fs)

        if semilogy:
            ax.semilogy()

In [None]:
def extr_weights_and_biases(res):
    """Add scalar ``weights`` and ``biases`` columns to ``res`` in place.

    Every entry of ``res['params']`` is read as a (weight, bias) pair. The
    weight column takes ``params[0][0][0]`` and the bias column takes
    ``params[1][0]``, i.e. the first element of each. This is presumably a
    1-feature, 1-output model -- TODO confirm before using on wider models.
    Returns None.
    """
    def _first_weight(params):
        weight = params[0]
        return weight[0][0]

    def _first_bias(params):
        bias = params[1]
        return bias[0]

    column = res['params']
    res['weights'] = column.apply(_first_weight)
    res['biases'] = column.apply(_first_bias)


Experiment hyperparameters:

In [None]:
# Shared training settings for the experiments below.
n_models = 3  # number of models per experiment (passed to train_models)
n_samples = 1_000  # sample count forwarded to train_model
n_epochs = 100
lr = 1e-2  # learning rate

Experiment I - learning x+y, symmetry under swapping x and y

In [None]:
# Target x + y, which is invariant under swapping x and y.
symm_data_gen = PermXYSymmGenerator
exp1_res = train_models(symm_data_gen, lr=lr, n_epochs=n_epochs, n_samples=n_samples, n_models=n_models)

In [None]:
plot_res_asymm_loss(results=exp1_res)  # one asymmetry-loss curve per model

Experiment II - learning |x|, reflection symmetry

In [None]:
# Target |x|, which is invariant under the reflection x -> -x.
symm_data_gen = AbsXSymmGenerator
exp2_res = train_models(symm_data_gen, lr=lr, n_epochs=n_epochs, n_samples=n_samples, n_models=n_models)

In [None]:
fig = plt.figure(figsize=(5, 5))
plot_res_asymm_loss(exp2_res, ax=fig.gca())

fig.savefig('abs_x_asymm.pdf', dpi=300, bbox_inches="tight")

Experiment III - learning x+y+z, permutation symmetry

In [None]:
n_features = 3
# Bind each permutation through a default argument. A bare `lambda x: x[:, list(perm)]`
# looks `perm` up when called, after the comprehension has finished, so every
# transform would apply the last permutation only.
symm_data_gen = SymmDataGenerator(n_features,
                                  [lambda x, idx=list(perm): x[:, idx] for perm in permutations(range(n_features))],
                                  lambda x: x.sum(axis=-1))

exp3_res = train_models(symm_data_gen, lr, n_epochs, n_samples, n_models)

In [None]:
plot_res_asymm_loss(results=exp3_res)  # one asymmetry-loss curve per model

Experiment IV - |x|, various noise levels (SNRs)

In [None]:
noise_levels = [0, 0.1, 0.3, 0.5, 0.7]

noisy_res = []
for noise in noise_levels:
    # 1-feature |x| target; the transforms are the identity and the reflection x -> -x.
    NoisyAbsXSymmGenerator = SymmDataGenerator(
        1,
        [lambda x: x, lambda x: -x],
        lambda x: abs(x),
        noise=noise,
    )
    res = train_models(NoisyAbsXSymmGenerator, lr=lr, n_epochs=n_epochs, n_samples=n_samples, n_models=1)
    res['noise_std'] = noise
    noisy_res.append(res)

# Stack the runs; each is tagged by its noise std.
exp4_res = pd.concat(noisy_res)

In [None]:
plt.figure(figsize=(9, 7))

# One asymmetry-loss curve per noise level, labelled by its noise std. This iterates the
# groups instead of using groupby(...).apply, which is meant for computing results
# (and warns about operating on the grouping column in recent pandas), not side effects.
for noise_std, data in exp4_res.groupby('noise_std'):
    data.plot(x='epoch', y='asymm_loss', ax=plt.gca(), label=noise_std)

plt.semilogy()

fs = 16
plt.legend(title='Noise std', title_fontsize=fs, fontsize=fs)
plt.xlabel('Epoch', fontsize=fs)
plt.ylabel('Asymmetry loss', fontsize=fs)
plt.xticks(fontsize=fs)
plt.yticks(fontsize=fs)

plt.gcf().savefig('inv_noise_test.pdf', dpi=300, bbox_inches="tight")

In [None]:
extr_weights_and_biases(res=exp4_res)  # adds 'weights' and 'biases' columns in place

In [None]:
plot_exp(exp4_res, figsize=(8, 8), fs=14, semilogy=True, groupby_col='noise_std')

Experiment V - x+y, inexact symmetry

In [None]:
red_n_samples = 20  # far fewer samples than the shared n_samples

# Work on a copy so the shared PermXYSymmGenerator keeps its own setting.
symm_data_gen = copy.deepcopy(PermXYSymmGenerator)
# Presumably disables explicit symmetrisation of the generated data -- TODO confirm.
symm_data_gen.symmetrise_expl = False

exp5_res = train_models(symm_data_gen, lr, 10 * n_epochs, n_samples=red_n_samples, n_models=n_models)

In [None]:
fs = 16

plot_res_asymm_loss(exp5_res)

# Re-style the axes plot_res_asymm_loss just drew on (they are the current axes).
plt.ylabel('Asymmetry loss', fontsize=fs)
plt.xlabel('Epoch', fontsize=fs)
plt.yticks(fontsize=fs)
plt.xticks(fontsize=fs)

plt.gcf().savefig('inexact_inv_symm.pdf', dpi=300, bbox_inches="tight")

Experiment VI - x+y with a polynomial kernel, under- and overparameterised systems

In [None]:
# underparameterised: an order-2 kernel yields few features compared with the 100
# training samples (assuming PolyXYSymmGenerator expands all monomials up to `order`
# -- TODO confirm)
symm_data_gen = PolyXYSymmGenerator(order=2)

exp6_res = train_models(symm_data_gen, lr, n_epochs, n_samples=100, n_models=n_models)

In [None]:
plot_res_asymm_loss(results=exp6_res)  # one asymmetry-loss curve per model

In [None]:
# overparameterised: an order-15 kernel yields more features than the 100 training
# samples (136 monomials of degree <= 15 in x, y, assuming PolyXYSymmGenerator
# expands all of them -- TODO confirm)
symm_data_gen = PolyXYSymmGenerator(order=15)

exp6_res = train_models(symm_data_gen, lr, n_epochs, n_samples=100, n_models=n_models)

In [None]:
plot_res_asymm_loss(results=exp6_res)  # one asymmetry-loss curve per model

Experiment VII - x+y, large learning rates

In [None]:
lrs = [1e-3, 1e-2, 1e-1, 1]

symm_data_gen = PermXYSymmGenerator

diff_lr_res = []
for diff_lr in lrs:
    res = train_models(symm_data_gen, diff_lr, n_epochs, n_samples, n_models=1)
    # Record the rate this run was actually trained with, not the shared global `lr`.
    res['lr'] = diff_lr

    diff_lr_res.append(res)

exp7_res = pd.concat(diff_lr_res)

Experiment VIII - |x| and x^2, the asymmetric part does not depend on the labels

Experiment IX - |x|, independence of the symmetric and asymmetric parts

In [None]:
symm_data_gen = AbsXSymmGenerator

# Weight initialised to 1, bias initialised to the model index (0..4).
exp9_symm_res = train_models(symm_data_gen, lr, n_samples=n_samples, n_epochs=50,
                             n_models=5, manually_set_params=True)
# Weight initialised to the model index (0..4), bias initialised to 1.
exp9_asymm_res = train_models(symm_data_gen, lr, n_samples=n_samples, n_epochs=50,
                              n_models=5, manually_set_params=True, asymm=True)

In [None]:
extr_weights_and_biases(res=exp9_asymm_res)
extr_weights_and_biases(res=exp9_symm_res)

In [None]:
plot_exp(exp9_asymm_res, figsize=(5, 5), fs=16)  # runs with varied initial weight

plt.gcf().savefig('abs_x_asymm_weights.pdf', bbox_inches="tight", dpi=300)

In [None]:
plot_exp(exp9_symm_res, figsize=(5, 5), fs=16)  # runs with varied initial bias

plt.gcf().savefig('abs_x_symm_weights.pdf', bbox_inches="tight", dpi=300)