Imports:

In [None]:
import copy
from itertools import permutations

import numpy as np
import pandas as pd
import torch
from matplotlib import pyplot as plt

from lr_symm_invar_analysis.perceptron_symm_decomp import train_model
from symm_data_gen import PermXYSymmGenerator, AbsXSymmGenerator, SymmDataGenerator, PolyXYSymmGenerator
from utils import Perceptron


Helper functions:

In [None]:
def _manual_init_weights(n_model: int, asymm: bool, dtype: torch.dtype) -> torch.Tensor:
    """Return the hand-crafted 2x2 initial weight matrix for model number ``n_model``.

    With A = ((W01 - W10) / 2, (W11 - W00) / 2) and S = ((W01 + W10) / 2, (W11 + W00) / 2):
      * ``asymm=True``  -> A = (n, -n), S = (0, 1)   (asymmetric parts vary per model)
      * ``asymm=False`` -> A = (-1, 1), S = (n, n)   (symmetric parts vary per model)
    """
    if asymm:
        rows = [[n_model + 1, n_model],
                [-n_model, -n_model + 1]]
    else:
        rows = [[n_model - 1, n_model - 1],
                [n_model + 1, n_model + 1]]
    return torch.tensor(rows, dtype=dtype)


def train_models(symm_data_gen, lr, n_epochs, n_samples, n_models: int, manually_set_params: bool = False,
                 asymm: bool = False, bias: bool = False) -> pd.DataFrame:
    """Train ``n_models`` identity-activation perceptrons and stack their training results.

    Args:
        symm_data_gen: data generator handed to ``train_model``; its ``n_features`` sets both
            the input and the output width of every model.
        lr, n_epochs, n_samples: forwarded unchanged to ``train_model``.
        n_models: number of models to train; must be >= 1.
        manually_set_params: if True, overwrite each model's weight matrix with a hand-crafted
            2x2 initialisation that depends on the model index (only defined when
            ``n_features == 2``). With A = ((W01 - W10)/2, (W11 - W00)/2) and
            S = ((W01 + W10)/2, (W11 + W00)/2), model ``n`` starts from A = (n, -n), S = (0, 1)
            when ``asymm`` is True, and from A = (-1, 1), S = (n, n) otherwise.
        asymm: selects the initialisation family; ignored unless ``manually_set_params``.
        bias: forwarded to ``Perceptron``; a bias, if present, keeps its default initialisation.

    Returns:
        The per-model DataFrames returned by ``train_model``, concatenated with their original
        indices. Each frame is tagged with an ``n_model`` column.

    Raises:
        ValueError: if ``n_models < 1``; or, when ``manually_set_params`` is set, if
            ``n_features != 2`` or the model does not expose exactly one 2-D weight tensor.
    """
    if n_models < 1:
        # pd.concat([]) would otherwise fail with the opaque "No objects to concatenate".
        raise ValueError(f"n_models must be >= 1, got {n_models}")
    if manually_set_params and symm_data_gen.n_features != 2:
        raise ValueError("manually_set_params only supports 2 features, "
                         f"got n_features={symm_data_gen.n_features}")

    results = []
    for n_model in range(n_models):
        model = Perceptron(input_dim=symm_data_gen.n_features, output_dim=symm_data_gen.n_features,
                           activation=lambda x: x, bias=bias)

        if manually_set_params:
            # With bias=True the state dict also holds the bias vector. Pick the single 2-D
            # tensor instead of unpacking every entry, which crashed in that case.
            matrices = [t for t in model.state_dict().values() if t.dim() == 2]
            if len(matrices) != 1:
                raise ValueError(f"expected exactly one 2-D weight tensor, found {len(matrices)}")
            weights, = matrices
            # state_dict() returns detached tensors that share storage with the parameters,
            # so copy_ writes the initialisation into the model itself.
            weights.copy_(_manual_init_weights(n_model, asymm, weights.dtype))

        res = train_model(model, symm_data_gen, lr=lr, n_epochs=n_epochs, n_samples=n_samples)
        res['n_model'] = n_model

        results.append(res)

    return pd.concat(results)

In [None]:
def plot_res_asymm_loss(results: pd.DataFrame, fs: int = 16, ax=None, label: str = None):
    """Plot each model's asymmetry loss against the epoch on a log y-scale.

    Args:
        results: frame with ``n_model``, ``epoch`` and ``asymm_loss`` columns. One line is
            drawn per ``n_model`` group.
        fs: font size for the axis labels and tick labels.
        ax: axes to draw on. A new figure and axes are created when None.
        label: line label passed to every plotted line. No legend is drawn.

    Returns:
        The axes that were drawn on.
    """
    if ax is None:
        _, ax = plt.subplots()

    # Plain loop instead of groupby().apply: plotting is a side effect, not an aggregation.
    # legend=False replaces drawing a legend and then removing it, which crashed on empty input.
    for _, model_res in results.groupby('n_model'):
        model_res.plot(x='epoch', y='asymm_loss', ax=ax, label=label, legend=False)
    ax.set_xlim(0)
    ax.semilogy()

    ax.set_ylabel('Asymmetry loss', fontsize=fs)
    ax.set_xlabel('Epoch', fontsize=fs)
    # Style the ticks of *this* axes. plt.xticks/plt.yticks act on plt.gca(), which need not be ax.
    ax.tick_params(axis='both', which='major', labelsize=fs)
    return ax


Experiment hyperparameters:

In [None]:
# Training hyperparameters shared by the experiments below
# (n_models is raised again for Experiment II).
n_epochs, lr, n_samples, n_models = 100, 1e-2, 1000, 3

Experiment I — learning the map $(x, y) \mapsto (y, x)$ under the $x \leftrightarrow y$ swap symmetry

In [None]:
# Two input features. The transforms list holds the identity and the column swap, and the
# target reverses the column order, i.e. (x, y) -> (y, x). equiv=True presumably asks the
# generator for an equivariant setup — confirm against symm_data_gen.
symm_data_gen = SymmDataGenerator(
    2,
    [lambda x: x, lambda x: x[:, [1, 0]]],
    lambda x: x[:, ::-1],
    equiv=True,
)

exp1_res = train_models(
    symm_data_gen, lr=lr, n_epochs=n_epochs, n_samples=n_samples, n_models=n_models,
)

In [None]:
# Asymmetry loss of every Experiment I model on a fresh 6x7-inch figure, saved as a PDF.
plt.figure(figsize=(6, 7))
plot_res_asymm_loss(exp1_res, ax=plt.gca())
plt.gcf().savefig('xy_equiv.pdf', bbox_inches="tight", dpi=300)

Experiment II — tracking the asymmetric and symmetric components of the learned weights during training:

In [None]:
def extr_asymm_comps(exp_res: pd.DataFrame):
    """Add ``asymm_comp1`` and ``asymm_comp2`` columns to ``exp_res`` in place.

    For the matrix W = params[0] of each row:
      asymm_comp1 = (W[0, 1] - W[1, 0]) / 2
      asymm_comp2 = (W[1, 1] - W[0, 0]) / 2
    Assumes ``params[0]`` is the 2x2 weight matrix — TODO confirm against train_model's output.
    """
    def off_diagonal_gap(params):
        w = params[0]
        return w[0, 1] - w[1, 0]

    def diagonal_gap(params):
        w = params[0]
        return w[1, 1] - w[0, 0]

    exp_res['asymm_comp1'] = exp_res['params'].apply(off_diagonal_gap) / 2
    exp_res['asymm_comp2'] = exp_res['params'].apply(diagonal_gap) / 2


def extr_symm_comps(exp_res: pd.DataFrame):
    """Add ``symm_comp1`` and ``symm_comp2`` columns to ``exp_res`` in place.

    For the matrix W = params[0] of each row:
      symm_comp1 = (W[0, 1] + W[1, 0]) / 2
      symm_comp2 = (W[1, 1] + W[0, 0]) / 2
    Assumes ``params[0]`` is the 2x2 weight matrix — TODO confirm against train_model's output.
    """
    def off_diagonal_sum(params):
        w = params[0]
        return w[0, 1] + w[1, 0]

    def diagonal_sum(params):
        w = params[0]
        return w[1, 1] + w[0, 0]

    exp_res['symm_comp1'] = exp_res['params'].apply(off_diagonal_sum) / 2
    exp_res['symm_comp2'] = exp_res['params'].apply(diagonal_sum) / 2

In [None]:
# Five hand-initialised models per family (overrides the earlier n_models).
n_models = 5

# First family (asymm=True): the asymmetric weight components vary across models.
# Second family (asymm=False): the symmetric ones vary (see the init in train_models).
asymm_exps, symm_exps = (
    train_models(symm_data_gen, lr, n_epochs, n_samples, n_models,
                 manually_set_params=True, asymm=asymm_init)
    for asymm_init in (True, False)
)

# Annotate both result frames in place with their weight decomposition.
extr_symm_comps(asymm_exps)
extr_asymm_comps(asymm_exps)

extr_symm_comps(symm_exps)
extr_asymm_comps(symm_exps)

In [None]:
def plot_exp(exp: pd.DataFrame, fs: int, figsize: tuple, groupby_col: str = 'n_model', semilogy: bool = False):
    """Plot the asymmetric (left column) and symmetric (right column) weight components over epochs.

    Args:
        exp: frame with ``epoch``, ``asymm_comp1``, ``asymm_comp2``, ``symm_comp1`` and
            ``symm_comp2`` columns. One line per ``groupby_col`` group is drawn in each panel.
        fs: font size for titles, x labels and tick labels.
        figsize: (width, height) of the figure in inches.
        groupby_col: column whose groups become separate lines.
        semilogy: if True, use a log y-scale on every panel. Non-positive values cannot be
            shown on a log axis.

    Returns:
        ``(fig, axs)``: the created figure and its 2x2 array of axes.
    """
    fig, axs = plt.subplots(2, 2, sharex=True)
    fig.set_size_inches(*figsize)

    # (row, col) of the panel each component column is drawn on.
    panel_columns = {
        (0, 0): 'asymm_comp1',
        (1, 0): 'asymm_comp2',
        (0, 1): 'symm_comp1',
        (1, 1): 'symm_comp2',
    }
    # Explicit loop: plotting is a side effect, not a groupby aggregation.
    # legend=False replaces the former get_legend().remove(), which crashed when no legend existed.
    for _, group in exp.groupby(groupby_col):
        for (row, col), column in panel_columns.items():
            group.plot(x='epoch', y=column, ax=axs[row, col], legend=False)

    fig.subplots_adjust(wspace=0.3, hspace=0.05)

    axs[0, 0].set_title('Asymm. comps.', fontsize=fs)
    axs[0, 1].set_title('Symm. comps.', fontsize=fs)
    axs[1, 0].set_xlabel('Epochs', fontsize=fs)
    axs[1, 1].set_xlabel('Epochs', fontsize=fs)

    for ax in axs.flat:
        ax.set_xlim(0)
        ax.tick_params(axis='both', which='major', labelsize=fs)

        if semilogy:
            ax.semilogy()

    return fig, axs


In [None]:
# Component trajectories for the models started from the asymm=True initialisation.
plot_exp(asymm_exps, figsize=(9, 7), fs=16)

plt.gcf().savefig('xy_equiv_asymm.pdf', bbox_inches="tight", dpi=300)

In [None]:
# Component trajectories for the models started from the asymm=False initialisation.
plot_exp(symm_exps, figsize=(9, 7), fs=16)

plt.gcf().savefig('xy_equiv_symm.pdf', bbox_inches="tight", dpi=300)