### Libraries & Setup

In [None]:
import sys
sys.path.append('../')

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

In [None]:
from layers import CompleteLayer
from inits import Size, Like
from inits import (
    RandomNormal,
    RandomUniform,
    Ones,
    Zeros,
    Triu
)
from pruning import PruneEnsemble
from pruning import (
    NoPrune,
    RandomPrune,
    TopKPrune,
    DynamicTopK,
    ThresholdPrune,
    TrilPrune,
    TrilDamp,
    DynamicTrilDamp
)
import data
import losses
import experiments
from training import train
from evals import (
    EvalVisualiser,
    LineVisualiser,
    BoxVisualiser,
    WeightVisualiser,
    OrderednessVisualiser
)
from utils import permute, brute_force_orderedness

In [None]:
# Select the compute device, defaulting to the CPU.
# A positive `torch.cuda.is_available()` is followed by a tiny probe
# allocation, and the GPU is only used when that allocation succeeds.
device = torch.device('cpu')
if torch.cuda.is_available():
    try:
        _ = torch.tensor([0], device='cuda')
    except Exception as err:
        # Catch `Exception` rather than using a bare `except:` so that
        # KeyboardInterrupt / SystemExit still propagate, and report why
        # the GPU was rejected instead of hiding the failure.
        print(f'CUDA probe failed ({err!r}); falling back to CPU')
    else:
        device = torch.device('cuda')

print(f'Using device: {device}')

### Complete Network

In [None]:
# How many times each experiment is repeated, for reliability.
NUM_TRIES = 10
# Fixed random seed, kept constant for reproducibility.
SEED = 3_141_592

In [None]:
# Initialisers shared by every CompleteLayer built in this cell.
WEIGHTS_INIT = RandomNormal()
VALUES_INIT = RandomNormal()
# Factory rather than an instance: it is called once per task below, so each
# call to `experiments.run` receives its own freshly built RandomPrune.
WEIGHTS_PRUNE = lambda : RandomPrune(p=0.5)

# Run the same experiment for each task. For the 'none' task `trainable` is
# False; for the others it is True.
for task in ('none', 'xor', 'sine'):

    generator = data.TaskGenerator(task, device)
    dataloader = generator.dataloader
    params = generator.params
    trainable = task != 'none'

    def setup():
        """Build a fresh CompleteLayer and its Adam optimiser.

        Reads `params` and `device` from the enclosing scope at call time,
        so it must be invoked while the current task's `params` is still
        bound (here it is handed to `experiments.run` in the same iteration).

        Returns:
            dict with keys 'model' (the layer, moved to `device`) and
            'optimiser' (Adam over the layer's parameters).
        """
        complete = CompleteLayer(
            input_size=params['input_size'],
            hidden_size=params['hidden_size'],
            output_size=params['output_size'],
            # NOTE(review): meaning of the second tuple element (True) is
            # defined by CompleteLayer -- confirm against layers.py.
            values_init=(VALUES_INIT, True),
            weights_init=(WEIGHTS_INIT, True),
            activation=F.sigmoid,
            use_bias=False
        ).to(device)
        optim = torch.optim.Adam(
            complete.parameters(),
            lr=params['complete_lr']
        )
        return {
            'model': complete,
            'optimiser': optim
        }

    visualisers, result = experiments.run(
        its=params['its'],
        track_orderedness=False,
        pruner=PruneEnsemble({
            'values': NoPrune(),
            'weights': WEIGHTS_PRUNE(),
        }),
        visualisers={
            # Each visualiser pulls one field out of a run's result record.
            'start': BoxVisualiser(
                lambda r: r['start_orderedness'],
                show=False
            ),
            'delta': BoxVisualiser(
                lambda r: r['delta_final'],
                show=False
            ),
        },
        seed=SEED,
        tries=NUM_TRIES,
        n_epochs=params['complete_epochs'],
        setup_fn=setup,
        train_dataloader=dataloader,
        train_criterion=losses.MSELoss(),
        trainable=trainable
    )
    # Report mean and std as "mean $\pm$ std". The outer quotes are double
    # quotes so that the single-quoted dict keys inside the replacement
    # fields also parse on Python versions before 3.12.
    if task == 'none':
        print(f"Init: {visualisers['start'].mean_x:.3f} $\\pm$ {visualisers['start'].std_x:.3f}")
    print(f"{task}: {visualisers['delta'].mean_x:.3f} $\\pm$ {visualisers['delta'].std_x:.3f}")