## A Measure of the Complexity of Neural Representations based on Partial Information Decomposition
David A. Ehrlich, Andreas C. Schneider, Viola Priesemann, Michael Wibral, Abdullah Makkeh. TMLR 2023.\
Supplementary Code - Script 5/5

### Demo for the task from R. Shwartz-Ziv and N. Tishby (2017, arXiv: https://arxiv.org/abs/1703.00810)
### (Toy dataset with 12 bits input and binary decision task)

In [None]:
# Enable IPython's autoreload extension (mode 2) so that edited modules are
# re-imported automatically before a cell is executed.
%load_ext autoreload
%autoreload 2

In [None]:
import time

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
cm = 1/2.54  # centimeters to inches

import nninfo

### Train network

In [None]:
# Set experiment id. It is used below to load the experiment again and is part
# of the output paths (experiments/exp_<experiment_id>/plots/...).
experiment_id = "tishby_demo"

In [None]:
# Before rerunning the experiment, remove the previous version.
# NOTE(review): presumably this discards everything stored for this experiment
# id -- confirm before running if earlier results are still needed.
import nninfo.utils
nninfo.utils.remove_experiment(experiment_id, silent=True)

#### Set up parameters and train first network

In [None]:
# Note that we do not set initial seeds manually here, but save all seeds to the
# checkpoint files during training for later reproducibility. Rerunning this script
# will produce slightly different figures due to the randomness of network
# initialization etc.

# Network architecture: an input layer followed by five linear layers (tanh hidden
# layers, sigmoid output). The task has 12 binary inputs, so the first linear layer
# must take 12 input features (the split below covers 2796 + 1300 = 4096 = 2**12
# samples, i.e. every 12-bit pattern).
layer_infos = [
    nninfo.LayerInfo(connection_layer='input', activation_function='input'),
    nninfo.LayerInfo(connection_layer='linear', connection_layer_kwargs={'in_features': 12, 'out_features': 12}, activation_function='tanh'),
    nninfo.LayerInfo(connection_layer='linear', connection_layer_kwargs={'in_features': 12, 'out_features': 4}, activation_function='tanh'),
    nninfo.LayerInfo(connection_layer='linear', connection_layer_kwargs={'in_features': 4, 'out_features': 4}, activation_function='tanh'),
    nninfo.LayerInfo(connection_layer='linear', connection_layer_kwargs={'in_features': 4, 'out_features': 4}, activation_function='tanh'),
    nninfo.LayerInfo(connection_layer='linear', connection_layer_kwargs={'in_features': 4, 'out_features': 1}, activation_function='sigmoid')
]

initializer_name = 'xavier'

# Create network instance
network = nninfo.NeuralNetwork(layer_infos=layer_infos, init_str=initializer_name)

# Set task instance
task = nninfo.TaskManager('tishby_dat')

# Split dataset into Shwartz-Ziv and Tishby training set, and test set
# (2796 training samples, 1300 test samples, no validation samples)
task['full_set'].train_test_val_random_split(2796, 1300, 0, seed=42)
print(task)

# Create quantizer list with stochastic quantization (one entry per layer in
# layer_infos). The first and the last entry are None, i.e. the input layer and
# the output layer are not quantized.
quantizer = [None] + 4 * [{'levels': 8, 'rounding_point': 'stochastic'}] + [None]

# Initialize training components
trainer = nninfo.Trainer(dataset_name='full_set/train',
                         optim_str='SGD',
                         loss_str='BCELoss',
                         lr=0.01,
                         shuffle=True,
                         batch_size=64,
                         quantizer=quantizer)

tester = nninfo.Tester(dataset_name='full_set/test')

# Save training state for 30 logarithmically spaced checkpoints
schedule = nninfo.Schedule()
schedule.create_log_spaced_chapters(1000, 30)

# Set up experiment
exp = nninfo.Experiment(experiment_id, network, task, trainer, tester, schedule)

# Run training for 1000 epochs
exp.run_following_schedule(compute_test_loss=False)

In [None]:
# Load the previously created experiment by its id
exp = nninfo.Experiment.load(experiment_id)

# Compute 9 more training runs with different random weight initializations.
exp.rerun(9)

### Evaluate network performance
#### Compute loss and accuracy

In [None]:
# Evaluate loss and accuracy on the training and the test set

# Load the experiment by its id
experiment = nninfo.Experiment.load(experiment_id)

# Quantizer settings for evaluation: 'center_saturating' rounding with 8 levels for
# the four middle entries; the first and the last entry are None.
quantizer_params = [None, *(4 * [{'levels': 8, 'rounding_point': 'center_saturating'}]), None]

perf_measurement = nninfo.analysis.PerformanceMeasurement(
    experiment,
    ['full_set/train', 'full_set/test'],
    quantizer_params=quantizer_params,
)

# Measure every chapter of every run
perf_measurement.perform_measurements(
    run_ids='all',
    chapter_ids='all',
    exists_ok=True,
)

# Last expression of the cell: shows the results in the notebook
perf_measurement.results

#### Plot loss and accuracy

In [None]:
# Load the stored performance measurement of the experiment
experiment = nninfo.Experiment.load(experiment_id)
performance_measurement = nninfo.analysis.PerformanceMeasurement.load(
    experiment, "performance"
)
performance_results = performance_measurement.results

cm = 1/2.54  # centimeters to inches

# One axis plus a twin axis sharing the same x axis
fig, ax = plt.subplots(dpi=150, figsize=(4*cm, 4*cm))
twinax = ax.twinx()

# Plot loss and accuracy onto the two axes
nninfo.plot.plot_loss_accuracy(performance_results, ax, twinax)

# Place the legend to the right of the axes
ax.legend(loc='center left', bbox_to_anchor=(1.5, 0.5), ncol=1);

In [None]:
# Save result.
# Save via the figure object (`fig`, created in the plotting cell) rather than
# plt.savefig: with the inline notebook backend, figures are closed at the end of
# the cell that created them, so pyplot has no current figure in this cell and
# plt.savefig would write an empty figure.
fig.savefig(f"experiments/exp_{experiment_id}/plots/performance.pdf", bbox_inches='tight');

### Perform quadravariate PID on hidden layers $L_3$, $L_4$ and $L_5$

In [None]:
from nninfo.model.neural_network import NeuronID

# Load experiment
exp = nninfo.experiment.Experiment.load(experiment_id)

# Run one PID measurement per analysed layer. Each measurement uses the label 'Y'
# as target and the neurons (1,), (2,), (3,), (4,) of the layer as four separate
# sources.
for layer in [2, 3, 4]:

    target = [NeuronID('Y', (1,))]
    source1 = [NeuronID(f'L{layer}', (1,))]
    source2 = [NeuronID(f'L{layer}', (2,))]
    source3 = [NeuronID(f'L{layer}', (3,))]
    source4 = [NeuronID(f'L{layer}', (4,))]

    # Create quantizer list for deterministic rounding.
    # NOTE(review): this list has 7 entries while the quantizer lists used for
    # training and performance evaluation have 6 -- confirm this is intended.
    quantization_dict = [None] + 6 * [{'levels': 8, 'rounding_point': 'center_saturating'}]

    # Set up analysis environment
    measurement = nninfo.analysis.pid_measurement.PIDMeasurement(experiment=exp,
                                                     measurement_id=f'pid_L{layer}',
                                                     dataset_name='full_set/train',
                                                     quantizer_params=quantization_dict,
                                                     pid_definition='reing',
                                                     binning_kwargs={'binning_method' : 'none'},
                                                     target_id_list=target,
                                                     source_id_lists=[source1, source2, source3, source4])

    itic = time.time_ns()

    # Compute PID for all chapters of all random network initializations
    # (run_ids='all' covers the first run and the additional reruns, matching the
    # performance measurement; run_ids=[0] would only analyse the first run).
    measurement.perform_measurements(run_ids='all', chapter_ids='all')

    itoc = time.time_ns()
    # time_ns() returns nanoseconds; convert to seconds for printing
    print(f"Computing PID for L{layer} took: ", (itoc-itic)/10**9, "s")

### Compute representational complexity and plot results

In [None]:
import matplotlib.ticker as ticker
from nninfo.postprocessing.pid_postprocessing import get_pid_summary_quantities

# Main axis for the representational complexity with a small inset axis for accuracy
fig, ax = plt.subplots(dpi=150, figsize=(5*cm, 4*cm))
inset_axis = inset_axes(ax, width=0.7, height=0.4)

# Load experiment and its stored performance measurement
experiment = nninfo.experiment.Experiment.load(experiment_id)
performance_measurement = nninfo.analysis.PerformanceMeasurement.load(experiment, 'performance')

# Inset axis: accuracy on the training set and (dashed) on the test set
for dataset_name, curve_style in (
    ('full_set/train', {'label': 'Train'}),
    ('full_set/test', {'ls': '--', 'label': 'Test'}),
):
    nninfo.plot.plot_accuracy(
        performance_measurement.results, dataset_name, inset_axis, c='k', **curve_style
    )
inset_axis.set_ylim(0, 1)
inset_axis.set_ylabel('Acc.')
inset_axis.legend(loc='upper left', bbox_to_anchor=(1.1, 1.3))

# Main axis: one representational complexity curve per analysed layer
for layer in [2, 3, 4]:
    measurement_loaded = nninfo.analysis.PIDMeasurement.load(
        experiment=experiment,
        measurement_id=f'pid_L{layer}',
    )
    pid_summary = get_pid_summary_quantities(measurement_loaded.results)
    nninfo.plot.plot_representational_complexity(pid_summary, ax, label=f'Layer {layer}')

# Axis labels and legend
ax.set_xlabel('Training Epoch')
ax.set_ylabel(r'Repr. Compl. $C$')
ax.legend(loc='upper left', bbox_to_anchor=(1, .6))

# Y axis: range 1 to 4, labelled major ticks at 1, 2, 3, minor ticks every 0.25
ax.yaxis.set_minor_locator(ticker.MultipleLocator(0.25))
ax.set_ylim(1, 4)
ax.set_yticks([1, 2, 3])
ax.set_yticklabels(['1', '2', '3']);

In [None]:
from nninfo.postprocessing.pid_postprocessing import get_pid_summary_quantities

for layer in [2, 3, 4]:
    fig, ax= plt.subplots(figsize=(5*cm, 4*cm), sharex=True, dpi=150)
    experiment = nninfo.experiment.Experiment.load(experiment_id)

    measurement_loaded = nninfo.analysis.PIDMeasurement.load(experiment=experiment, measurement_id=f'pid_L{layer}')
    pid_summary = get_pid_summary_quantities(measurement_loaded.results)
    nninfo.plot.plot_degree_of_synergy_atoms(pid_summary, ax)

    ax.set_xlabel('Training Epoch')
    ax.set_ylabel(r'Deg. of Syn. atoms (bits)')

    ax.set_title(f'Layer {layer}')
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))

In [None]:
# Save results
plt.savefig(f"experiments/exp_{experiment_id}/plots/representational_complexity.pdf", bbox_inches='tight')