## A Measure of the Complexity of Neural Representations based on Partial Information Decomposition
David A. Ehrlich, Andreas C. Schneider, Viola Priesemann, Michael Wibral, Abdullah Makkeh. TMLR 2023.\
Supplementary Code – Script 3/5
### Training and evaluating the 4-level-quantized, one-hot-output MNIST network (Figure 5.B,D)

In [None]:
# IPython autoreload extension: in mode 2, imported modules (e.g. nninfo) are
# reloaded before each cell is executed, so edits to their source take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

### Train network

In [None]:
import time
import itertools
import  math

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
# Conversion factor from centimetres to inches; matplotlib figure sizes are
# given in inches, so e.g. `4*cm` is a 4 cm edge length.
cm = 1 / 2.54

import nninfo

In [None]:
# Identifier of this experiment. It is passed to nninfo.Experiment and is also
# part of the output paths (experiments/exp_<experiment_id>/plots/...).
experiment_id = "mnist_4levels_onehot"

#### Set up parameters and train first network

In [None]:
# Initial seeds are not set manually here. Instead, all seeds are saved to the
# checkpoint files during training for later reproducibility. Rerunning this
# cell will therefore produce slightly different figures due to the randomness
# of network initialization etc.

# Network architecture: widths from the 784-dimensional input to the 10-unit
# softmax output. Every linear layer except the last uses tanh.
layer_widths = [784, 50, 10, 5, 5, 5, 10]
n_linear_layers = len(layer_widths) - 1

layer_infos = [nninfo.LayerInfo(connection_layer='input', activation_function='input')]
for layer_number, (n_in, n_out) in enumerate(zip(layer_widths[:-1], layer_widths[1:]), start=1):
    is_output_layer = layer_number == n_linear_layers
    layer_infos.append(nninfo.LayerInfo(
        connection_layer='linear',
        connection_layer_kwargs={'in_features': n_in, 'out_features': n_out},
        activation_function='softmax_output' if is_output_layer else 'tanh',
    ))

# Weight initialization scheme
initializer_name = 'xavier'

# Build the network
network = nninfo.NeuralNetwork(layer_infos=layer_infos, init_str=initializer_name)

# Task: split the dataset sequentially into the MNIST training set and the
# MNIST+QMNIST test set (third argument 0: presumably no validation split).
task = nninfo.TaskManager('mnist_1d_dat')
task['full_set'].train_test_val_sequential_split(60_000, 60_000, 0)

# Stochastic 4-level quantization of the five hidden layers during training.
# The input layer (first entry) and the output layer (last entry) are not quantized.
stochastic_quantizer = {'levels': 4, 'rounding_point': 'stochastic'}
quantizer = [None, *([stochastic_quantizer] * 5), None]

# Training configuration
trainer = nninfo.Trainer(
    dataset_name='full_set/train',
    optim_str='SGD',
    loss_str='CELoss',
    lr=0.01,
    shuffle=True,
    batch_size=64,
    quantizer=quantizer,
)

# Evaluation on the test split
tester = nninfo.Tester(dataset_name='full_set/test')

# Save the training state at 50 logarithmically spaced checkpoints
schedule = nninfo.Schedule()
schedule.create_log_spaced_chapters(100_000, 50)

# Combine all components into one experiment
experiment = nninfo.Experiment(
    experiment_id=experiment_id,
    network=network,
    task=task,
    trainer=trainer,
    tester=tester,
    schedule=schedule,
)

# Run training for 10^5 epochs, checkpointing according to the schedule
experiment.run_following_schedule()

#### Rerun training with different random weight initializations

In [None]:
# Load the experiment trained in the previous cell (looked up by its id)
exp = nninfo.Experiment.load(experiment_id)

# Compute 19 more training runs with different random weight initializations,
# i.e. 20 runs in total together with the first one.
exp.rerun(19)

### Evaluate network performance
#### Compute loss and accuracy

In [None]:
# For evaluation, the five hidden layers use deterministic 4-level
# 'center_saturating' rounding; input and output layers stay unquantized.
deterministic_quantizer = {'levels': 4, 'rounding_point': 'center_saturating'}
quantizer_params = [None, *([deterministic_quantizer] * 5), None]

experiment = nninfo.Experiment.load(experiment_id)

# Measure loss and accuracy on the training and test splits for every run and
# every saved chapter.
performance_measurement = nninfo.analysis.PerformanceMeasurement(
    experiment,
    ['full_set/train', 'full_set/test'],
    quantizer_params=quantizer_params,
)
performance_measurement.perform_measurements(run_ids='all', chapter_ids='all', exists_ok=True)

#### Plot loss and accuracy

In [None]:
# Load the stored performance measurement of this experiment
experiment = nninfo.Experiment.load(experiment_id)
performance_measurement = nninfo.analysis.PerformanceMeasurement.load(experiment, "performance")

# 4 cm x 4 cm figure; the twin y-axis shares the x-axis with `ax`
figure_size = (4*cm, 4*cm)
fig, ax = plt.subplots(figsize=figure_size, dpi=150)
ax.set_ylim(0, 1)
twinax = ax.twinx()

# Plot loss and accuracy onto the primary and the twin axis
nninfo.plot.plot_loss_accuracy(performance_measurement.results, ax, twinax)

# Single-column legend placed to the right of the plot
ax.legend(ncol=1, loc='center left', bbox_to_anchor=(1.5, 0.5));

In [None]:
# Save result. Use the Figure object from the plotting cell rather than
# plt.savefig: Jupyter's inline backend closes all figures by default when a
# cell finishes, so plt.savefig in this new cell would write an empty figure.
fig.savefig(f"experiments/exp_{experiment_id}/plots/performance.pdf", bbox_inches='tight')

### Perform subsampling PID on hidden layer $L_2$

In [None]:
def get_source_indices(n_sources: int, combination_index: int, layer_width: int, seed: int = 1234):
    """Return a deterministic pseudo-random choice of ``n_sources`` neuron indices.

    All ``math.comb(layer_width, n_sources)`` index combinations (in the
    lexicographic order produced by ``itertools.combinations``) are shuffled
    with a fixed seed, and the ``combination_index``-th entry of that shuffled
    order is returned. Distinct ``combination_index`` values therefore yield
    distinct combinations, and the result is the same on every call.

    Args:
        n_sources: Number of neurons in each combination.
        combination_index: Position in the shuffled list of combinations.
        layer_width: Number of neurons in the layer (indices ``0..layer_width-1``).
        seed: Seed of the shuffle; defaults to 1234, the value used originally.

    Returns:
        Tuple of ``n_sources`` increasing, 0-based neuron indices.

    Raises:
        IndexError: If ``combination_index`` is outside the range of the
            ``math.comb(layer_width, n_sources)`` combinations.
    """
    n_combinations = math.comb(layer_width, n_sources)

    # A private RandomState leaves the global NumPy RNG untouched, while giving
    # exactly the same permutation as np.random.seed(seed); np.random.permutation(n).
    shuffled = np.random.RandomState(seed).permutation(n_combinations)
    random_combination_index = int(shuffled[combination_index])

    # Jump straight to the selected combination instead of comparing every index.
    combinations_iter = itertools.combinations(range(layer_width), n_sources)
    return next(itertools.islice(combinations_iter, random_combination_index, None))

In [None]:
# Load experiment
experiment = nninfo.Experiment.load(experiment_id)

# Target and quantization do not depend on the chosen sources, so they are
# set up once outside the loop.
target = [nninfo.NeuronID('Y', (1,))]

# Quantizer list for deterministic rounding: the five hidden layers use 4-level
# 'center_saturating' rounding, input and output layers are not quantized.
quantizer_params = [None] + 5 * [{'levels': 4, 'rounding_point': 'center_saturating'}] + [None]

# Repeat for 26 random draws of five out of the ten L2 neurons as sources
for combination_index in range(26):

    source_indices = get_source_indices(
            n_sources=5, combination_index=combination_index, layer_width=10)

    # One single-neuron source per drawn index. get_source_indices returns
    # 0-based indices, while the L2 NeuronIDs here are numbered from 1.
    sources = [[nninfo.NeuronID('L2', (source_id + 1,))] for source_id in source_indices]

    # Compute PID for all chapters of all random network initializations.
    # The measurement id is keyed by combination_index because that is how the
    # plotting cell loads the results again ('subsampling_pid_<index>').
    # NOTE(review): the coarse-grained cell passes binning_method=None rather
    # than the string 'none'; confirm which value nninfo expects.
    pid_measurement = nninfo.analysis.PIDMeasurement(experiment,
                                                     measurement_id=f'subsampling_pid_{combination_index}',
                                                     dataset_name='full_set/train',
                                                     pid_definition='sxpid',
                                                     target_id_list=target,
                                                     source_id_lists=sources,
                                                     binning_kwargs={'binning_method': 'none'},
                                                     quantizer_params=quantizer_params)

    # perf_counter is a monotonic clock intended for measuring durations
    tic = time.perf_counter()
    pid_measurement.perform_measurements(run_ids='all', chapter_ids='all')
    toc = time.perf_counter()
    print(f"Computing subsampled PID for a single choice of sources in L2 took: {toc - tic} s")

### Perform coarse-grained PID on hidden layer $L_2$

In [None]:
# Load experiment
experiment = nninfo.Experiment.load(experiment_id)

target = [nninfo.NeuronID('Y', (1,))]

# Coarse-grain the ten L2 neurons into five sources of two neighbouring neurons each
source1 = [nninfo.NeuronID('L2', (1,)), nninfo.NeuronID('L2', (2,))]
source2 = [nninfo.NeuronID('L2', (3,)), nninfo.NeuronID('L2', (4,))]
source3 = [nninfo.NeuronID('L2', (5,)), nninfo.NeuronID('L2', (6,))]
source4 = [nninfo.NeuronID('L2', (7,)), nninfo.NeuronID('L2', (8,))]
source5 = [nninfo.NeuronID('L2', (9,)), nninfo.NeuronID('L2', (10,))]
# Pass these coarse-grained sources explicitly, not a `sources` list that may be
# left over from the subsampling cell.
sources = [source1, source2, source3, source4, source5]

# Quantizer list for deterministic rounding: the five hidden layers use 4-level
# 'center_saturating' rounding, input and output layers are not quantized.
quantizer_params = [None] + 5 * [{'levels': 4, 'rounding_point': 'center_saturating'}] + [None]

# NOTE(review): the subsampling cell passes binning_method='none' (a string)
# rather than None; confirm which value nninfo expects.
pid_measurement = nninfo.analysis.PIDMeasurement(experiment,
                                                 measurement_id='coarse_graining_pid',
                                                 dataset_name='full_set/train',
                                                 pid_definition='sxpid',
                                                 target_id_list=target,
                                                 source_id_lists=sources,
                                                 binning_kwargs={'binning_method': None},
                                                 quantizer_params=quantizer_params)

# perf_counter is a monotonic clock intended for measuring durations
tic = time.perf_counter()
pid_measurement.perform_measurements(run_ids='all', chapter_ids='all')
toc = time.perf_counter()
print(f"Computing coarse-grained PID for L2 took: {toc - tic} s")

### Compute representational complexity for subsampling and plot results

In [None]:
from nninfo.postprocessing.pid_postprocessing import get_pid_summary_quantities

# Load the experiment here as well (like the other analysis cells), so this
# cell does not depend on a global left over from an earlier cell.
experiment = nninfo.Experiment.load(experiment_id)

fig, ax = plt.subplots(figsize=(4*cm, 4*cm), dpi=150)

# Summarize each of the 26 subsampling measurements and concatenate them once.
# Concatenating inside the loop re-copies the accumulated data every time, and
# concatenating onto an empty DataFrame is deprecated in recent pandas.
pid_summaries = []
for combination_index in range(26):
    measurement = nninfo.analysis.PIDMeasurement.load(experiment, f'subsampling_pid_{combination_index}')
    pid_summaries.append(get_pid_summary_quantities(measurement.results))
results = pd.concat(pid_summaries)

# Plot the representational complexity of L2 from the combined results
nninfo.plot.plot_representational_complexity(results, ax, label='$L_2$')

ax.yaxis.set_minor_locator(ticker.MultipleLocator(0.25))

ax.set_xlabel('Training Epoch')
ax.set_ylabel(r'Repr. Compl. $C$')

ax.legend(loc='upper right')

ax.set_ylim(1.25, 3.25);

In [None]:
# Save results. Use the Figure object from the plotting cell rather than
# plt.savefig: Jupyter's inline backend closes all figures by default when a
# cell finishes, so plt.savefig in this new cell would write an empty figure.
fig.savefig(f"experiments/exp_{experiment_id}/plots/representational_complexity_subsampling.pdf", bbox_inches='tight')

### Compute representational complexity for coarse-graining and plot results

In [None]:
from nninfo.postprocessing.pid_postprocessing import get_pid_summary_quantities

# Load the experiment here as well (like the other analysis cells), so this
# cell does not depend on a global left over from an earlier cell.
experiment = nninfo.Experiment.load(experiment_id)

fig, ax = plt.subplots(figsize=(4*cm, 4*cm), dpi=150)

# Load the single coarse-grained PID measurement and compute its summary quantities
measurement = nninfo.analysis.PIDMeasurement.load(experiment, 'coarse_graining_pid')
pid_summary = get_pid_summary_quantities(measurement.results)

# Plot the representational complexity of L2
nninfo.plot.plot_representational_complexity(pid_summary, ax, label='$L_2$')

ax.yaxis.set_minor_locator(ticker.MultipleLocator(0.25))

ax.set_xlabel('Training Epoch')
ax.set_ylabel(r'Repr. Compl. $C$')

ax.legend(loc='upper right')

ax.set_ylim(1.25, 3.25);

In [None]:
# Save results. Use the Figure object from the plotting cell rather than
# plt.savefig: Jupyter's inline backend closes all figures by default when a
# cell finishes, so plt.savefig in this new cell would write an empty figure.
fig.savefig(f"experiments/exp_{experiment_id}/plots/representational_complexity_coarse_graining.pdf", bbox_inches='tight')