### Libraries & Setup

In [None]:
import sys
sys.path.append('../')

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

In [None]:
from layers import CompleteLayer
from inits import Size, Like
from inits import (
    RandomNormal,
    RandomUniform,
    Ones,
    Zeros,
    Triu
)
from pruning import PruneEnsemble
from pruning import (
    NoPrune,
    RandomPrune,
    TopKPrune,
    DynamicTopK,
    ThresholdPrune,
    TrilPrune,
    TrilDamp,
    DynamicTrilDamp
)
import data
import losses
import experiments
from training import train
from evals import (
    NullVisualiser,
    LineVisualiser,
    BoxVisualiser,
    WeightVisualiser,
    OrderednessVisualiser
)
from utils import permute, brute_force_orderedness

In [None]:
# Pick the compute device: prefer CUDA, but fall back to the CPU when CUDA
# reports itself available yet a real allocation on it fails.
device = torch.device('cpu')
if torch.cuda.is_available():
    try:
        # A tiny allocation forces CUDA initialisation, so any runtime
        # problem surfaces here instead of in the middle of an experiment.
        torch.tensor([0], device='cuda')
    except Exception as err:  # best-effort probe: any failure means "use the CPU"
        print(f'CUDA probe failed ({err}); falling back to CPU.')
    else:
        device = torch.device('cuda')

print(f'Using device: {device}')

### Generate Tiny Dataset

In [None]:
# Task to run: one of 'xor', 'sine' or 'none'.
TASK_NAME = 'xor'
OUT_DIR = Path('../media')      # root folder for every saved figure
FNAME = f'{TASK_NAME}.pdf'      # default figure file name for this task

print(f'Running task: {TASK_NAME}.')

# The generator provides both the training dataloader and the per-task
# parameter dict (sizes, learning rates, epoch counts) used below.
generator = data.TaskGenerator(TASK_NAME, device)
dataloader, params = generator.dataloader, generator.params

### Simple Baseline

In [None]:
# Number of independent repetitions of every experiment, for reliability.
NUM_TRIES = 10
# Fixed random seed so that runs are reproducible.
SEED = 3_141_592

In [None]:
def baseline_setup():
    """Create a fresh MLP baseline with its MSE loss and Adam optimiser.

    Passed as ``setup_fn`` to ``experiments.run`` below. The returned dict
    holds the keys 'model', 'train_criterion' and 'optimiser'.
    """
    mlp = generator.get_mlp_baseline()
    # NOTE(review): unlike the complete layer below, no explicit .to(device)
    # is applied here; presumably TaskGenerator (built with `device`) places
    # the baseline itself - confirm.
    return {
        'model': mlp,
        'train_criterion': losses.MSELoss(),
        'optimiser': torch.optim.Adam(mlp.parameters(), lr=params['baseline_lr']),
    }


# Train the baseline MLP NUM_TRIES times and plot its training-loss curves.
# NOTE(review): every other experiments.run call in this notebook unpacks
# `(visualisers, result)`; confirm whether this call returns a tuple as well.
visualisers = experiments.run(
    tries=NUM_TRIES,
    seed=SEED,
    track_orderedness=False,
    setup_fn=baseline_setup,
    visualisers={
        'train': LineVisualiser(
            lambda res: res['train_losses'],
            xlabel='Step',
            ylabel='Train Loss',
            only_values=True,
            fname=OUT_DIR / 'dynamics' / f'{TASK_NAME}-mlp.pdf',
        ),
    },
    n_epochs=params['baseline_epochs'],
    train_dataloader=dataloader,
    early_stop=3e-3,
    trainable=(TASK_NAME != 'none'),
)

### Complete Network

In [None]:
def setup():
    """Create a fresh ``CompleteLayer`` on ``device`` plus its Adam optimiser.

    Sizes and the learning rate come from the task ``params``. Values and
    weights are both initialised with ``RandomNormal``, each paired with
    ``True`` (presumably a "trainable" flag - confirm in ``layers``/``inits``).
    The activation is the sigmoid and no bias is used.
    """
    layer = CompleteLayer(
        input_size=params['input_size'],
        hidden_size=params['hidden_size'],
        output_size=params['output_size'],
        values_init=(RandomNormal(), True),
        weights_init=(RandomNormal(), True),
        activation=F.sigmoid,
        use_bias=False,
    ).to(device)
    return {
        'model': layer,
        'optimiser': torch.optim.Adam(layer.parameters(), lr=params['complete_lr']),
    }


# Step-by-step orderedness tracking is only switched on for these tasks.
TRACK_ORDEREDNESS = TASK_NAME in {'none', 'xor'}

visualisers, result = experiments.run(
    pruner=PruneEnsemble({
        'values': NoPrune(),            # node values are never pruned
        'weights': DynamicTopK(k=0.5),
    }),
    visualisers={
        'orderedness-absolute': OrderednessVisualiser(
            lambda res: res['model'].weights,
            name='weights',
            graphs=False,
        ),
        'orderedness-change': BoxVisualiser(
            lambda res: res['delta_final'],
            name='Change in orderedness',
            ylabel='Change in orderedness',
            fname=OUT_DIR / 'change' / f'{TASK_NAME}-box.pdf',
        ),
        # Per-step curves only exist when orderedness is tracked.
        'orderedness-steps': (
            LineVisualiser(
                lambda res: res['delta_steps'],
                xlabel='Step',
                ylabel='Change in orderedness',
                only_values=True,
                fname=OUT_DIR / 'change' / f'{TASK_NAME}-curve.pdf',
            )
            if TRACK_ORDEREDNESS
            else NullVisualiser()
        ),
        'train': LineVisualiser(
            lambda res: res['train_losses'],
            xlabel='Step',
            ylabel='Train Loss',
            only_values=True,
            fname=OUT_DIR / 'dynamics' / f'{TASK_NAME}-clp.pdf',
        ),
        'weights': WeightVisualiser(
            lambda res: res['model'].weights,
            name='weights',
            show=['sample'],
        ),
    },
    seed=SEED,
    tries=NUM_TRIES,
    its=params['its'],
    track_orderedness=TRACK_ORDEREDNESS,
    n_epochs=params['complete_epochs'],
    setup_fn=setup,
    train_dataloader=dataloader,
    train_criterion=losses.MSELoss(),
    trainable=(TASK_NAME != 'none'),
)

In [None]:
# Inspect one trained weight matrix. Take its square part (every column
# except the trailing `input_size` ones), search for the permutation that
# maximises orderedness, and plot the permuted matrix.
sample = visualisers['weights'].all_weights[0]
n_inputs = params['input_size']
square = sample[:, :-n_inputs]
orderedness, perm = brute_force_orderedness(
    square, fixed_size=params['output_size']
)
print(f'Orderedness: {orderedness:.3f}')

# Rows follow `perm`. Columns follow `perm` and then keep the trailing input
# columns in their original order. Coercing to a list makes the concatenation
# correct whether `perm` comes back as a list, tuple or array; with an array,
# `+` would otherwise broadcast-add instead of concatenating.
perm = list(perm)
col_perm = perm + list(range(len(perm), len(perm) + n_inputs))
sns.heatmap(permute(sample, perm, col_perm), annot=True, cmap='viridis')
plt.title(f'Orderedness of weights: {orderedness:.3f}')
plt.tight_layout()

weights_dir = OUT_DIR / 'weights'
weights_dir.mkdir(parents=True, exist_ok=True)  # savefig does not create folders
plt.savefig(weights_dir / FNAME, bbox_inches='tight')
plt.show()

#### Hidden unit - iteration plots

In [None]:
# Sweep hidden-layer width against the `its` setting of experiments.run and
# record, over NUM_TRIES runs each, the mean and standard deviation of the
# change in orderedness.
its = list(range(1, 11))        # values passed as `its=` to experiments.run
units = list(range(1, 9, 2))    # hidden sizes: 1, 3, 5, 7

ui_means = np.zeros((len(units), len(its)))
ui_stds = np.zeros((len(units), len(its)))

# Using tqdm as a context manager closes the bar even if a run raises.
with tqdm(total=len(units) * len(its)) as pbar:
    for u_idx, u in enumerate(units):

        # Depends only on the hidden size `u`, so define it once per width.
        def _setup():
            """Fresh complete layer with `u` hidden units and its Adam optimiser."""
            complete = CompleteLayer(
                input_size=params['input_size'],
                hidden_size=u,
                output_size=params['output_size'],
                values_init=(RandomNormal(), True),
                weights_init=(RandomNormal(), True),
                activation=F.sigmoid,
                use_bias=False
            ).to(device)
            optim = torch.optim.Adam(
                complete.parameters(),
                lr=params['complete_lr']
            )
            return {
                'model': complete,
                'optimiser': optim
            }

        for i_idx, i in enumerate(its):
            # Bound to its own name rather than `visualisers`, so the earlier
            # cell that reads visualisers['weights'] can still be re-run.
            sweep_vis, _sweep_result = experiments.run(
                its=i,
                pruner=PruneEnsemble({
                    'values': NoPrune(),
                    'weights': DynamicTopK(k=0.5),
                }),
                visualisers={
                    'delta-final': BoxVisualiser(
                        lambda r: r['delta_final'],
                        show=False
                    ),
                },
                seed=SEED,
                tries=NUM_TRIES,
                track_orderedness=False,
                n_epochs=params['complete_epochs'],
                setup_fn=_setup,
                train_dataloader=dataloader,
                train_criterion=losses.MSELoss(),
                early_stop=0,
                trainable=TASK_NAME != 'none'
            )

            ui_means[u_idx, i_idx] = sweep_vis['delta-final'].mean_x
            ui_stds[u_idx, i_idx] = sweep_vis['delta-final'].std_x
            pbar.update(1)

In [None]:
# Heatmap of the mean change in orderedness. Rows are flipped so the largest
# hidden size appears at the top.
sns.heatmap(
    ui_means[::-1],
    annot=True,
    yticklabels=units[::-1],
    xticklabels=its,
    cmap='crest'
)

# The caption is printed (not set with plt.title), so it is not part of the
# saved figure.
print(f'Change in Orderedness - Mean ({TASK_NAME})')
plt.xlabel('Number of iterations')
plt.ylabel('Number of hidden units')
plt.tight_layout()

hi_dir = OUT_DIR / 'hi'
hi_dir.mkdir(parents=True, exist_ok=True)  # savefig does not create folders
plt.savefig(hi_dir / f'{TASK_NAME}-mean.pdf', bbox_inches='tight')
plt.show()

In [None]:
# Heatmap of the standard deviation of the change in orderedness. Rows are
# flipped so the largest hidden size appears at the top.
sns.heatmap(
    ui_stds[::-1],
    annot=True,
    yticklabels=units[::-1],
    xticklabels=its,
    cmap='crest'
)

# The caption is printed (not set with plt.title), so it is not part of the
# saved figure.
print(f'Change in Orderedness - Stddev ({TASK_NAME})')
plt.xlabel('Number of iterations')
plt.ylabel('Number of hidden units')
plt.tight_layout()

hi_dir = OUT_DIR / 'hi'
hi_dir.mkdir(parents=True, exist_ok=True)  # savefig does not create folders
plt.savefig(hi_dir / f'{TASK_NAME}-std.pdf', bbox_inches='tight')
plt.show()

#### Sparsity - orderedness plots

In [None]:
# Sweep pruning sparsity for several pruning strategies and record summary
# statistics (over NUM_TRIES runs each) of the final orderedness.

# Built from integers so the grid is exactly 0.0, 0.1, ..., 0.9.
# np.arange(0, 1, 0.1) yields values such as 0.6000000000000001, which turns
# `1 - s` into 0.3999999999999999 rather than 0.4 for the Top-K pruners.
sparsity_vals = [step / 10 for step in range(10)]

# name -> (factory mapping a sparsity value to a pruner, label of the knob)
pruners = {
    'Random': (lambda s: RandomPrune(p=s), 'p'),
    'Top-K': (lambda s: TopKPrune(k=1-s), '1-k'),
    'Dyn. Top-K': (lambda s: DynamicTopK(k=1-s), '1-k'),
    'Tril-damp': (lambda s: TrilDamp(f=s), 'f'),
    'Dyn. Tril-damp': (lambda s: DynamicTrilDamp(f=s), 'f')
}

so_means = {name: [] for name in pruners}
so_stds = {name: [] for name in pruners}
so_maxs = {name: [] for name in pruners}
so_mins = {name: [] for name in pruners}


def _setup():
    """Fresh complete layer (task hidden size) and its Adam optimiser."""
    complete = CompleteLayer(
        input_size=params['input_size'],
        hidden_size=params['hidden_size'],
        output_size=params['output_size'],
        values_init=(RandomNormal(), True),
        weights_init=(RandomNormal(), True),
        activation=F.sigmoid,
        use_bias=False
    ).to(device)
    optim = torch.optim.Adam(
        complete.parameters(),
        lr=params['complete_lr']
    )
    return {
        'model': complete,
        'optimiser': optim
    }


# Using tqdm as a context manager closes the bar even if a run raises.
with tqdm(total=len(sparsity_vals) * len(pruners)) as pbar:
    for pn, (make_pruner, _knob) in pruners.items():
        for v in sparsity_vals:
            # Bound to its own name rather than `visualisers`, so earlier
            # cells reading the global `visualisers` can still be re-run.
            sweep_vis, _sweep_result = experiments.run(
                pruner=PruneEnsemble({
                    'values': NoPrune(),
                    'weights': make_pruner(v),
                }),
                visualisers={
                    'final': BoxVisualiser(
                        lambda r: r['final_orderedness'],
                        show=False
                    ),
                },
                seed=SEED,
                tries=NUM_TRIES,
                its=params['its'],
                track_orderedness=False,
                n_epochs=params['complete_epochs'],
                setup_fn=_setup,
                train_dataloader=dataloader,
                train_criterion=losses.MSELoss(),
                early_stop=0,
                trainable=TASK_NAME != 'none'
            )

            final = sweep_vis['final']
            so_means[pn].append(final.mean_x)
            so_stds[pn].append(final.std_x)
            so_maxs[pn].append(final.max_x)
            so_mins[pn].append(final.min_x)
            pbar.update(1)

In [None]:
# Plot mean final orderedness against pruning sparsity for every pruner,
# shading the band between the per-setting minimum and maximum.
plt.figure(figsize=(12, 7))

for pn in pruners:
    knob = pruners[pn][1]
    plt.plot(sparsity_vals, so_means[pn], label=f'{pn} ({knob})')
    plt.fill_between(sparsity_vals, so_maxs[pn], so_mins[pn], alpha=0.2)

# The caption is printed (not set with plt.title), so it is not part of the
# saved figure.
print(f'Relationship Between Pruning Sparsity and Orderedness ({TASK_NAME})')
plt.legend()
plt.xlabel('Pruning sparsity')
plt.ylabel('Orderedness')
plt.tight_layout()

so_dir = OUT_DIR / 'so'
so_dir.mkdir(parents=True, exist_ok=True)  # savefig does not create folders
plt.savefig(so_dir / FNAME, bbox_inches='tight')
plt.show()