# Mean field variance dynamics for noisy signal propagation

In [1]:
# imports
import os, sys, pickle
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from matplotlib.lines import Line2D
from matplotlib.collections import LineCollection
from matplotlib.colors import ListedColormap
from collections import OrderedDict

# custom import
from src.simulation import *
from src.viz import get_colours
from src.utils import load_experiment
from src.theory import depth
from src.theory import critical_point

# plot settings
sns.set_context("paper")
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['text.usetex'] = True
# 'text.latex.unicode' is not a valid rc key on every Matplotlib release;
# assigning to a missing key raises KeyError, so only set it where it exists.
if 'text.latex.unicode' in plt.rcParams:
    plt.rcParams['text.latex.unicode'] = True
plt.rcParams['image.cmap'] = 'viridis'

# results directory (relative to the notebook's working directory)
relative_results_dir = "results"
results_dir = os.path.join(relative_results_dir)

Using TensorFlow backend.


## --- Variance map: Theory vs simulation ---

In [2]:
# Scenarios to simulate. Each entry names the noise distribution, its
# (parameter name, value) pair, the activation and the initialisation regime.
# Commented-out entries are the remaining regimes, left disabled.
experiments = [
    #{"dist": "bern", "noise": ('prob_1', 0.6), "act":"relu", "init":"underflow"},
    dict(dist="bern", noise=('prob_1', 0.6), act="relu", init="overflow"),
    #{"dist": "bern", "noise": ('prob_1', 0.6), "act":"relu", "init":"crit"},
    #{"dist": "mult gauss", "noise": ('std', 0.25), "act":"relu", "init":"underflow"},
    dict(dist="mult gauss", noise=('std', 0.25), act="relu", init="overflow"),
    #{"dist": "mult gauss", "noise": ('std', 0.25), "act":"relu", "init":"crit"}
]

for experiment in experiments:
    # unpack the scenario in the positional order the simulator expects
    dist, noise, act, init = (
        experiment[field] for field in ("dist", "noise", "act", "init")
    )

    # run simulations for scenario
    noisy_signal_prop_simulations(dist, noise, act, init)

weights:   0%|          | 0/1 [00:00<?, ?it/s]
bias':   0%|          | 0/1 [00:00<?, ?it/s][A

  0%|          | 0/25 [00:00<?, ?it/s][A[A

 32%|███▏      | 8/25 [00:00<00:00, 70.54it/s][A[A

####### EXPERIMENT: dist:  bern ;  prob_1 :  0.6 ; activation:  relu ; initialisation:  overflow ##############
qmap calculations...
Calculating Theory:




 56%|█████▌    | 14/25 [00:00<00:00, 64.25it/s][A[A

 80%|████████  | 20/25 [00:00<00:00, 61.65it/s][A[A

100%|██████████| 25/25 [00:00<00:00, 60.76it/s][A[A

  0%|          | 0/19 [00:00<?, ?it/s][A[A

100%|██████████| 19/19 [00:00<00:00, 1515.74it/s][A[A
bias': 100%|██████████| 1/1 [00:00<00:00,  2.23it/s][A
weights: 100%|██████████| 1/1 [00:00<00:00,  2.21it/s]
jacobians: 100%|██████████| 1/1 [00:00<00:00, 20.26it/s]
hessians: 100%|██████████| 1/1 [00:00<00:00, 63.42it/s]
weights:   0%|          | 0/1 [00:00<?, ?it/s]
bias':   0%|          | 0/1 [00:00<?, ?it/s][A

Simulating network
Single layer sims...
Compiling net...
backend is tensorflow
Do Rops...




  0%|          | 0/100 [00:00<?, ?it/s][A[A

 34%|███▍      | 34/100 [00:00<00:00, 337.08it/s][A[A

 72%|███████▏  | 72/100 [00:00<00:00, 358.34it/s][A[A

100%|██████████| 100/100 [00:00<00:00, 361.94it/s][A[A
bias': 100%|██████████| 1/1 [00:00<00:00,  2.97it/s][A
weights: 100%|██████████| 1/1 [00:00<00:00,  2.95it/s]


Multi-layer sims...
Compiling net...
backend is tensorflow


jacobians:  11%|█         | 2/19 [00:00<00:01, 13.38it/s]

Do Rops...


jacobians: 100%|██████████| 19/19 [00:11<00:00,  1.70it/s]
hessians: 100%|██████████| 19/19 [00:19<00:00,  1.05s/it]
weights:   0%|          | 0/1 [00:00<?, ?it/s]
bias':   0%|          | 0/1 [00:00<?, ?it/s][A

  0%|          | 0/3 [00:00<?, ?it/s][A[A

 33%|███▎      | 1/3 [00:00<00:01,  1.19it/s][A[A

100%|██████████| 3/3 [00:00<00:00,  3.12it/s][A[A

[A[A
bias': 100%|██████████| 1/1 [00:02<00:00,  2.44s/it][A
weights: 100%|██████████| 1/1 [00:02<00:00,  2.44s/it]
weights:   0%|          | 0/1 [00:00<?, ?it/s]
bias':   0%|          | 0/1 [00:00<?, ?it/s][A
weights: 100%|██████████| 1/1 [00:00<00:00, 403.18it/s][A
100%|██████████| 1/1 [00:00<00:00, 2225.09it/s]
100%|██████████| 1/1 [00:00<00:00, 15827.56it/s]
weights:   0%|          | 0/1 [00:00<?, ?it/s]
bias':   0%|          | 0/1 [00:00<?, ?it/s][A
weights: 100%|██████████| 1/1 [00:00<00:00, 80.14it/s][A

Cmap calculations...
General cov prop...
Specific cov prop...
Curvature prop...
General cov prop simulation...
Compiling net...
backend is tensorflow



jacobians: 100%|██████████| 1/1 [00:00<00:00, 22.12it/s]
hessians: 100%|██████████| 1/1 [00:00<00:00, 70.10it/s]
  0%|          | 0/100 [00:00<?, ?it/s]
weights:   0%|          | 0/1 [00:00<?, ?it/s][A

bias':   0%|          | 0/1 [00:00<?, ?it/s][A[A

Do Rops...




bias': 100%|██████████| 1/1 [00:01<00:00,  1.44s/it][A[A

[A[A
weights: 100%|██████████| 1/1 [00:01<00:00,  1.44s/it][A
  1%|          | 1/100 [00:01<02:29,  1.51s/it]
weights:   0%|          | 0/1 [00:00<?, ?it/s][A

bias':   0%|          | 0/1 [00:00<?, ?it/s][A[A

bias': 100%|██████████| 1/1 [00:00<00:00, 20.93it/s][A[A
weights: 100%|██████████| 1/1 [00:00<00:00, 19.03it/s][A
weights:   0%|          | 0/1 [00:00<?, ?it/s][A

bias':   0%|          | 0/1 [00:00<?, ?it/s][A[A

bias': 100%|██████████| 1/1 [00:00<00:00, 21.10it/s][A[A
  3%|▎         | 3/100 [00:01<00:54,  1.79it/s]18it/s][A
weights:   0%|          | 0/1 [00:00<?, ?it/s][A

bias':   0%|          | 0/1 [00:00<?, ?it/s][A[A

bias': 100%|██████████| 1/1 [00:00<00:00, 21.01it/s][A[A
weights: 100%|██████████| 1/1 [00:00<00:00, 20.02it/s][A
weights:   0%|          | 0/1 [00:00<?, ?it/s][A

bias':   0%|          | 0/1 [00:00<?, ?it/s][A[A

bias': 100%|██████████| 1/1 [00:00<00:00, 20.89it/s][A[A
  5

bias':   0%|          | 0/1 [00:00<?, ?it/s][A[A

bias': 100%|██████████| 1/1 [00:00<00:00, 19.91it/s][A[A
 38%|███▊      | 38/100 [00:04<00:07,  8.03it/s]2it/s][A
weights:   0%|          | 0/1 [00:00<?, ?it/s][A

bias':   0%|          | 0/1 [00:00<?, ?it/s][A[A

bias': 100%|██████████| 1/1 [00:00<00:00, 21.74it/s][A[A
weights: 100%|██████████| 1/1 [00:00<00:00, 17.59it/s][A
weights:   0%|          | 0/1 [00:00<?, ?it/s][A

bias':   0%|          | 0/1 [00:00<?, ?it/s][A[A

bias': 100%|██████████| 1/1 [00:00<00:00, 21.37it/s][A[A
 40%|████      | 40/100 [00:04<00:07,  8.17it/s]4it/s][A
weights:   0%|          | 0/1 [00:00<?, ?it/s][A

bias':   0%|          | 0/1 [00:00<?, ?it/s][A[A

bias': 100%|██████████| 1/1 [00:00<00:00, 20.55it/s][A[A
weights: 100%|██████████| 1/1 [00:00<00:00, 17.52it/s][A
weights:   0%|          | 0/1 [00:00<?, ?it/s][A

bias':   0%|          | 0/1 [00:00<?, ?it/s][A[A

bias': 100%|██████████| 1/1 [00:00<00:00, 20.57it/s][A[A
 42%|████

bias': 100%|██████████| 1/1 [00:00<00:00, 20.49it/s][A[A
weights: 100%|██████████| 1/1 [00:00<00:00, 17.58it/s][A
weights:   0%|          | 0/1 [00:00<?, ?it/s][A

bias':   0%|          | 0/1 [00:00<?, ?it/s][A[A

bias': 100%|██████████| 1/1 [00:00<00:00, 19.93it/s][A[A
 76%|███████▌  | 76/100 [00:08<00:02,  9.41it/s]0it/s][A
weights:   0%|          | 0/1 [00:00<?, ?it/s][A

bias':   0%|          | 0/1 [00:00<?, ?it/s][A[A

bias': 100%|██████████| 1/1 [00:00<00:00, 19.72it/s][A[A
weights: 100%|██████████| 1/1 [00:00<00:00, 16.98it/s][A
weights:   0%|          | 0/1 [00:00<?, ?it/s][A

bias':   0%|          | 0/1 [00:00<?, ?it/s][A[A

bias': 100%|██████████| 1/1 [00:00<00:00, 21.03it/s][A[A
 78%|███████▊  | 78/100 [00:08<00:02,  9.46it/s]0it/s][A
weights:   0%|          | 0/1 [00:00<?, ?it/s][A

bias':   0%|          | 0/1 [00:00<?, ?it/s][A[A

bias': 100%|██████████| 1/1 [00:00<00:00, 19.31it/s][A[A
weights: 100%|██████████| 1/1 [00:00<00:00, 16.62it/s][A
wei

Simulate multi-layer cov prop...
Compiling net...
backend is tensorflow


jacobians:   7%|▋         | 2/30 [00:00<00:01, 14.00it/s]

Do Rops...


jacobians: 100%|██████████| 30/30 [00:28<00:00,  1.07it/s]
hessians: 100%|██████████| 30/30 [00:58<00:00,  1.96s/it]
  0%|          | 0/3 [00:00<?, ?it/s]
weights:   0%|          | 0/1 [00:00<?, ?it/s][A

bias':   0%|          | 0/1 [00:00<?, ?it/s][A[A

bias': 100%|██████████| 1/1 [00:07<00:00,  7.67s/it][A[A

[A[A
weights: 100%|██████████| 1/1 [00:07<00:00,  7.68s/it][A
[A
  0%|          | 0/30 [00:00<?, ?it/s][A
  7%|▋         | 2/30 [00:00<00:02, 13.30it/s][A
 13%|█▎        | 4/30 [00:00<00:01, 13.09it/s][A
 20%|██        | 6/30 [00:00<00:01, 13.06it/s][A
 27%|██▋       | 8/30 [00:00<00:01, 13.05it/s][A
 33%|███▎      | 10/30 [00:00<00:01, 13.09it/s][A
 40%|████      | 12/30 [00:00<00:01, 13.14it/s][A
 47%|████▋     | 14/30 [00:01<00:01, 13.10it/s][A
 53%|█████▎    | 16/30 [00:01<00:01, 13.09it/s][A
 60%|██████    | 18/30 [00:01<00:00, 13.07it/s][A
 67%|██████▋   | 20/30 [00:01<00:00, 13.04it/s][A
 73%|███████▎  | 22/30 [00:01<00:00, 13.04it/s][A
 80%|████████ 

####### EXPERIMENT: dist:  mult gauss ;  std :  0.25 ; activation:  relu ; initialisation:  overflow ##############
qmap calculations...
Calculating Theory:




 12%|█▏        | 3/25 [00:00<00:01, 20.11it/s][A[A

 20%|██        | 5/25 [00:00<00:01, 17.05it/s][A[A

 28%|██▊       | 7/25 [00:00<00:01, 16.03it/s][A[A

 36%|███▌      | 9/25 [00:00<00:01, 15.55it/s][A[A

 44%|████▍     | 11/25 [00:00<00:00, 15.24it/s][A[A

 52%|█████▏    | 13/25 [00:00<00:00, 15.06it/s][A[A

 60%|██████    | 15/25 [00:01<00:00, 14.90it/s][A[A

 68%|██████▊   | 17/25 [00:01<00:00, 14.81it/s][A[A

 76%|███████▌  | 19/25 [00:01<00:00, 14.72it/s][A[A

 84%|████████▍ | 21/25 [00:01<00:00, 14.66it/s][A[A

 92%|█████████▏| 23/25 [00:01<00:00, 14.60it/s][A[A

100%|██████████| 25/25 [00:01<00:00, 14.55it/s][A[A

[A[A

  0%|          | 0/19 [00:00<?, ?it/s][A[A

100%|██████████| 19/19 [00:00<00:00, 2380.35it/s][A[A
bias': 100%|██████████| 1/1 [00:01<00:00,  1.81s/it][A
weights: 100%|██████████| 1/1 [00:01<00:00,  1.81s/it]


Simulating network
Single layer sims...
Compiling net...
backend is tensorflow


jacobians: 100%|██████████| 1/1 [00:00<00:00, 12.98it/s]
hessians: 100%|██████████| 1/1 [00:00<00:00, 48.81it/s]
weights:   0%|          | 0/1 [00:00<?, ?it/s]
bias':   0%|          | 0/1 [00:00<?, ?it/s][A

Do Rops...




  0%|          | 0/100 [00:00<?, ?it/s][A[A

  1%|          | 1/100 [00:02<04:41,  2.84s/it][A[A

 39%|███▉      | 39/100 [00:02<00:04, 13.25it/s][A[A

 74%|███████▍  | 74/100 [00:03<00:01, 24.31it/s][A[A

100%|██████████| 100/100 [00:03<00:00, 32.11it/s][A[A
bias': 100%|██████████| 1/1 [00:06<00:00,  6.00s/it][A
weights: 100%|██████████| 1/1 [00:06<00:00,  6.01s/it]


Multi-layer sims...
Compiling net...
backend is tensorflow


jacobians:   0%|          | 0/19 [00:00<?, ?it/s]

Do Rops...


jacobians: 100%|██████████| 19/19 [00:17<00:00,  1.10it/s]
hessians: 100%|██████████| 19/19 [00:29<00:00,  1.53s/it]
weights:   0%|          | 0/1 [00:00<?, ?it/s]
bias':   0%|          | 0/1 [00:00<?, ?it/s][A

  0%|          | 0/3 [00:00<?, ?it/s][A[A

 33%|███▎      | 1/3 [00:04<00:09,  4.67s/it][A[A

100%|██████████| 3/3 [00:04<00:00,  1.60s/it][A[A

[A[A
bias': 100%|██████████| 1/1 [00:10<00:00, 10.17s/it][A
weights: 100%|██████████| 1/1 [00:10<00:00, 10.18s/it]
weights:   0%|          | 0/1 [00:00<?, ?it/s]
bias':   0%|          | 0/1 [00:00<?, ?it/s][A
weights: 100%|██████████| 1/1 [00:00<00:00, 251.07it/s][A
100%|██████████| 1/1 [00:00<00:00, 1847.71it/s]
100%|██████████| 1/1 [00:00<00:00, 14563.56it/s]
weights:   0%|          | 0/1 [00:00<?, ?it/s]
bias':   0%|          | 0/1 [00:00<?, ?it/s][A
weights: 100%|██████████| 1/1 [00:00<00:00, 128.26it/s]A

Cmap calculations...
General cov prop...
Specific cov prop...
Curvature prop...
General cov prop simulation...
Compiling net...
backend is tensorflow



jacobians: 100%|██████████| 1/1 [00:00<00:00, 13.01it/s]
hessians: 100%|██████████| 1/1 [00:00<00:00, 54.39it/s]
  0%|          | 0/100 [00:00<?, ?it/s]
weights:   0%|          | 0/1 [00:00<?, ?it/s][A

bias':   0%|          | 0/1 [00:00<?, ?it/s][A[A

Do Rops...




bias': 100%|██████████| 1/1 [00:08<00:00,  8.99s/it][A[A

[A[A
weights: 100%|██████████| 1/1 [00:08<00:00,  9.00s/it][A
  1%|          | 1/100 [00:09<14:53,  9.02s/it]
weights:   0%|          | 0/1 [00:00<?, ?it/s][A

bias':   0%|          | 0/1 [00:00<?, ?it/s][A[A

bias': 100%|██████████| 1/1 [00:00<00:00, 24.40it/s][A[A
weights: 100%|██████████| 1/1 [00:00<00:00, 23.09it/s][A
weights:   0%|          | 0/1 [00:00<?, ?it/s][A

bias':   0%|          | 0/1 [00:00<?, ?it/s][A[A

bias': 100%|██████████| 1/1 [00:00<00:00, 24.60it/s][A[A
  3%|▎         | 3/100 [00:09<04:55,  3.05s/it]28it/s][A
weights:   0%|          | 0/1 [00:00<?, ?it/s][A

bias':   0%|          | 0/1 [00:00<?, ?it/s][A[A

bias': 100%|██████████| 1/1 [00:00<00:00, 24.28it/s][A[A
weights: 100%|██████████| 1/1 [00:00<00:00, 22.50it/s][A
weights:   0%|          | 0/1 [00:00<?, ?it/s][A

bias':   0%|          | 0/1 [00:00<?, ?it/s][A[A

bias': 100%|██████████| 1/1 [00:00<00:00, 19.60it/s][A[A
  5

bias':   0%|          | 0/1 [00:00<?, ?it/s][A[A

bias': 100%|██████████| 1/1 [00:00<00:00, 22.53it/s][A[A
weights: 100%|██████████| 1/1 [00:00<00:00, 21.33it/s][A
weights:   0%|          | 0/1 [00:00<?, ?it/s][A

bias':   0%|          | 0/1 [00:00<?, ?it/s][A[A

bias': 100%|██████████| 1/1 [00:00<00:00, 21.62it/s][A[A
 39%|███▉      | 39/100 [00:12<00:18,  3.23it/s]2it/s][A
weights:   0%|          | 0/1 [00:00<?, ?it/s][A

bias':   0%|          | 0/1 [00:00<?, ?it/s][A[A

bias': 100%|██████████| 1/1 [00:00<00:00, 21.46it/s][A[A
weights: 100%|██████████| 1/1 [00:00<00:00, 19.10it/s][A
weights:   0%|          | 0/1 [00:00<?, ?it/s][A

bias':   0%|          | 0/1 [00:00<?, ?it/s][A[A

bias': 100%|██████████| 1/1 [00:00<00:00, 21.55it/s][A[A
 41%|████      | 41/100 [00:12<00:17,  3.35it/s]8it/s][A
weights:   0%|          | 0/1 [00:00<?, ?it/s][A

bias':   0%|          | 0/1 [00:00<?, ?it/s][A[A

bias': 100%|██████████| 1/1 [00:00<00:00, 20.85it/s][A[A
weights: 

bias': 100%|██████████| 1/1 [00:00<00:00, 20.55it/s][A[A
 75%|███████▌  | 75/100 [00:14<00:04,  5.02it/s]0it/s][A
weights:   0%|          | 0/1 [00:00<?, ?it/s][A

bias':   0%|          | 0/1 [00:00<?, ?it/s][A[A

bias': 100%|██████████| 1/1 [00:00<00:00, 21.17it/s][A[A
weights: 100%|██████████| 1/1 [00:00<00:00, 20.10it/s][A
weights:   0%|          | 0/1 [00:00<?, ?it/s][A

bias':   0%|          | 0/1 [00:00<?, ?it/s][A[A

bias': 100%|██████████| 1/1 [00:00<00:00, 20.15it/s][A[A
 77%|███████▋  | 77/100 [00:15<00:04,  5.10it/s]4it/s][A
weights:   0%|          | 0/1 [00:00<?, ?it/s][A

bias':   0%|          | 0/1 [00:00<?, ?it/s][A[A

bias': 100%|██████████| 1/1 [00:00<00:00, 21.02it/s][A[A
weights: 100%|██████████| 1/1 [00:00<00:00, 19.91it/s][A
weights:   0%|          | 0/1 [00:00<?, ?it/s][A

bias':   0%|          | 0/1 [00:00<?, ?it/s][A[A

bias': 100%|██████████| 1/1 [00:00<00:00, 22.15it/s][A[A
 79%|███████▉  | 79/100 [00:15<00:04,  5.19it/s]8it/s][A
wei

Simulate multi-layer cov prop...
Compiling net...
backend is tensorflow


jacobians:   0%|          | 0/30 [00:00<?, ?it/s]

Do Rops...


jacobians: 100%|██████████| 30/30 [00:45<00:00,  1.53s/it]
hessians: 100%|██████████| 30/30 [01:29<00:00,  3.00s/it]
  0%|          | 0/3 [00:00<?, ?it/s]
weights:   0%|          | 0/1 [00:00<?, ?it/s][A

bias':   0%|          | 0/1 [00:00<?, ?it/s][A[A

bias': 100%|██████████| 1/1 [00:21<00:00, 21.39s/it][A[A

[A[A
weights: 100%|██████████| 1/1 [00:21<00:00, 21.39s/it][A
[A
  0%|          | 0/30 [00:00<?, ?it/s][A
  3%|▎         | 1/30 [00:00<00:03,  8.56it/s][A
 10%|█         | 3/30 [00:00<00:02, 10.68it/s][A
 17%|█▋        | 5/30 [00:00<00:02, 11.42it/s][A
 23%|██▎       | 7/30 [00:00<00:01, 11.84it/s][A
 30%|███       | 9/30 [00:00<00:01, 12.03it/s][A
 37%|███▋      | 11/30 [00:00<00:01, 12.13it/s][A
 43%|████▎     | 13/30 [00:01<00:01, 12.28it/s][A
 50%|█████     | 15/30 [00:01<00:01, 12.38it/s][A
 57%|█████▋    | 17/30 [00:01<00:01, 12.39it/s][A
 63%|██████▎   | 19/30 [00:01<00:00, 12.36it/s][A
 70%|███████   | 21/30 [00:01<00:00, 12.20it/s][A
 77%|███████▋  

### Iterative variance map and variance dynamics

In [3]:
# Scenarios whose stored results are loaded for plotting: every noise model
# crossed with every initialisation regime, all with a ReLU activation.
noise_models = (
    {"dist": "mult gauss", "std": 0.25},
    {"dist": "bern", "prob_1": 0.6},
)
init_regimes = ("underflow", "overflow", "crit")

tests = [
    {
        "distributions": [dict(noise_model)],  # fresh copy per scenario
        "activations": ["relu"],
        "inits": [regime],
    }
    for noise_model in noise_models
    for regime in init_regimes
]

############################################################################
############################################################################
# q - length / variance plots: shared settings and stored results
############################################################################
############################################################################
nq = 100                            # number of input-variance grid points
qmax = 15                           # upper end of the variance grid / axes
qrange = np.linspace(0, qmax, nq)   # x values for the single-layer map
widxs = [0]                         # first-axis indices into the result arrays
bidxs = [0]                         # second-axis indices into the result arrays
qidx = [26]                         # index into qmaps used for the layer-wise curve
n_hidden_layers = 20

n_tests = len(tests)
pal = get_colours(10, 7)

# Load the three stored result sets for each scenario, in scenario order.
test_data = [
    load_experiment(
        test,
        ["q_maps", "single_layer_qmap_sim", "multi_layer_qmap_sim"],
        "results",
    )
    for test in tests
]


# One row of four panels: ax1..ax4 become panels (A)..(D).
fig = plt.figure(figsize=(20, 4))

gs = plt.GridSpec(1, 4)
ax1, ax2, ax3, ax4 = (plt.subplot(gs[0, column]) for column in range(4))

# Panel A: unity line plus static decoration
ax1.plot((0, qmax), (0, qmax), '--', color='k', zorder=900, dashes=(4, 8))
ax1.set(
    xlim=(0, qmax),
    ylim=(0, qmax),
    xlabel='Input variance ($q^{l-1})$',
    ylabel='Output variance ($q^l$)',
    title="Iterative length map",
)
# regime annotations, positioned in data coordinates
regime_notes = (
    (2, 10, r'$\sigma^2_w > \frac{2}{\mu_2}$'),
    (10, 1, r'$\sigma^2_w < \frac{2}{\mu_2}$'),
    (11, 8.5, r'$\sigma^2_w = \frac{2}{\mu_2}$'),
)
for note_x, note_y, note_text in regime_notes:
    ax1.text(note_x, note_y, note_text, fontsize=12)

# Panel B: static decoration (x-limit deliberately qmax rather than n_hidden_layers-1)
ax2.set(
    xlim=(0, qmax),
    ylim=(0, qmax),
    xlabel='Layer ($l$)',
    ylabel='Variance ($q^{l})$',
    title="Dynamics of $q$",
)

# Draw panel A (single-layer variance map) and panel B (variance per layer)
# for every loaded scenario: theory as lines, simulation as dots with a
# +/- one standard deviation band.
nn = len(test_data)  # not read anywhere in the visible notebook; kept as-is
col_i = 0
for test, attr in zip(test_data, tests):
    for dist in attr["distributions"]:
        # Palette row is chosen from the noise type; an unrecognised type
        # keeps whichever row the previous scenario used.
        if dist['dist'] == "none":
            col_i = 0
        elif dist['dist'] == "bern":
            col_i = 1
        elif "gauss" in dist['dist']:
            col_i = 3

        for act in attr["activations"]:
            for init in attr["inits"]:
                # Shade encodes the initialisation regime; only the critical
                # dropout ("bern") theory curve is drawn dashed.
                dashes = (None, None)
                if "under" in init:
                    shade_i = 4
                elif "crit" in init:
                    shade_i = 5
                    dashes = (8, 4) if dist['dist'] == "bern" else (None, None)
                else:
                    shade_i = 6
                colour = pal[col_i][shade_i]

                # extract test data
                scenario = test[dist['dist']][act][init]
                qmaps = scenario['q_maps']['qmaps']
                single_layer_sims = scenario['single_layer_qmap_sim']
                multi_layer_sims = scenario['multi_layer_qmap_sim']

                # Legend label from whichever noise parameter the scenario
                # carries (explicit key checks instead of bare try/except).
                if 'std' in dist:
                    label = r"Gauss: $\sigma_\epsilon = $ " + str(dist['std'])
                elif 'prob_1' in dist:
                    label = "Dropout: $p = $ " + str(dist['prob_1'])
                else:
                    label = ""

                ############################################################################
                # left
                ############################################################################
                for w, b in zip(widxs, bidxs):
                    # mean and spread over the last axis of the simulation
                    # (presumably the repeated-sample axis — TODO confirm)
                    mu = single_layer_sims[w, b].mean(-1)
                    std = single_layer_sims[w, b].std(-1)

                    # plot means of simulation (as dots); only x and y are
                    # passed positionally — extra positional values would be
                    # read by Matplotlib as a second (x, y) data set
                    ax1.plot(qrange, mu, marker='o', ls='none', markersize=1, alpha=0.9, zorder=0, c=colour)

                    # add confidence interval around simulation
                    ax1.fill_between(qrange, mu-std, mu+std, alpha=0.4, label='_nolegend_', color=colour)

                    # theory line, indexed with the same (w, b) as the
                    # simulation and as the panel B theory curve below
                    ax1.plot(qrange, qmaps[w, b, :, 1], c=colour, label=label, dashes=dashes)

                ############################################################################
                # right
                ############################################################################
                q = 1  # not read anywhere in the visible notebook; kept as-is
                xx = np.arange(n_hidden_layers)
                for w, b in zip(widxs, bidxs):
                    for sim in range(len(qidx)):
                        # confidence intervals
                        mu = multi_layer_sims[w, b, sim].mean(-1)
                        std = multi_layer_sims[w, b, sim].std(-1)

                        # plot theory
                        ax2.plot(qmaps[w, b, qidx[sim], :n_hidden_layers].T, c=colour, label="Theory")

                        # plot the simulation
                        ax2.fill_between(xx, mu-std, mu+std, alpha=0.2, label='_nolegend_', color=colour)

                        # dots for mean
                        ax2.plot(xx, mu, 'o', markersize=4, alpha=0.9, color=colour, label="Simulation")

#plt.gcf().tight_layout()

# add legend with one entry per distinct label: the entry keeps the position
# of the label's first occurrence and the handle of its last occurrence
legend_handles, legend_labels = ax1.get_legend_handles_labels()
by_label = OrderedDict()
for legend_label, legend_handle in zip(legend_labels, legend_handles):
    by_label[legend_label] = legend_handle
ax1.legend(by_label.values(), by_label.keys())


#####################################################################
# Variance critical boundary (at the edge of underflow or overflow) #
#####################################################################

# Boundary curve y = 2 / x on [1, 2].
# NOTE(review): the x data is named mu21 while the x-axis is labelled
# sigma^2_w; the boundary x * y = 2 is symmetric in the two quantities so the
# picture is the same either way — confirm the intended orientation.
mu21 = np.linspace(1, 2, 100)
sigma1 = 2/mu21

ax3.plot(mu21, sigma1, c="purple", label="Variance critical boundary", linestyle='--')
# shade below the curve (down to 1) blue and above it (up to 2) red
ax3.fill_between(mu21, 1, sigma1, facecolor='blue', alpha=0.2)
ax3.fill_between(mu21, 2, sigma1, facecolor='red', alpha=0.2)
ax3.text(1.5, 1.6, 'Overflow', fontsize=15)
ax3.text(1.53, 1.45, r'($\sigma^2_w > \frac{2}{\mu_2}$)', fontsize=10)
ax3.text(1.1, 1.2, 'Underflow', fontsize=15)
ax3.text(1.15, 1.05, r'($\sigma^2_w < \frac{2}{\mu_2}$)', fontsize=10)
ax3.text(1.2, 1.7, r'$\sigma^2_w = \frac{2}{\mu_2}$', fontsize=12)
ax3.set_xlim(1, 2)
ax3.set_ylim(1, 2)
# raw strings: "\s" and "\m" are not valid escape sequences in normal literals
ax3.set_xlabel(r'Weight initialisation ($\sigma^2_w$)')
ax3.set_ylabel(r'Second moment of noise dist. ($\mu_2$)')
ax3.legend()
ax3.set_title('Variance propagation dynamics')

############################################
# Variance depth scales for dropout noise  #
############################################

# Load the stored results and keep the first 120 entries (10 x 12 grid below).
# NOTE(review): pickle.load can execute arbitrary code from the file — only
# load result files produced by this project.
with open("src/results/under_overflow.pk", "rb") as pickle_in:
    example_dict = pickle.load(pickle_in)[:120]
p = 0.6
num_layers = 1000
nets = np.linspace(10, num_layers, 12, dtype=int)[:10]
inits = np.round(np.linspace(0.1, 2.5, 12),1)
xv, yv = np.meshgrid(nets, inits, sparse=False, indexing='ij')

Z1 = np.log(np.array(example_dict))

# mask NaN and +/-inf entries so they are drawn with the colormap's "bad" colour
bad_indices = ~np.isfinite(Z1)
Z1 = np.ma.array(Z1, mask=bad_indices)
# NOTE(review): set_bad modifies the colormap object returned by get_cmap;
# depending on the Matplotlib version that may be the shared registered
# instance — confirm whether a copy is wanted.
cmap = matplotlib.cm.get_cmap(name="Spectral_r")
cmap.set_bad('white', 1.)

pcm = ax4.pcolormesh(yv-0.1, xv-0.1, Z1.reshape(10,12), cmap=cmap)
cbar = fig.colorbar(pcm, ax=ax4, extend='max')
cbar.ax.set_title('$log(q^l)$')

# raw strings wherever a LaTeX backslash appears ("\s" is not a valid escape)
ax4.set_xlabel(r'Weight initialisation ($\sigma^2_w$)')
ax4.set_ylabel("Number of layers")
ax4.set_title(
    "Variance propagation depth:\n"
    "dropout with $p$ = 0.6,\n"
    r"crit. init. at $\sigma^2_w = 1.2$"
)

# theoretical depth curve over the initialisation range
max_depth = 0
init_theory = np.linspace(0, 2.4, 1000)
depth_per_p_theory = depth("Dropout", init_theory, p)
max_depth = np.max([max_depth, np.max(depth_per_p_theory)])
ax4.plot(init_theory, depth_per_p_theory, label="Theory", c='cyan', linewidth=3)

# plot critical point as a vertical line spanning the full layer range
crit_point = critical_point("Dropout", p)
ax4.plot([crit_point,]*2, [0, num_layers], color="black", linestyle="--", label="criticality")

ax4.set_ylim(0, 820)
ax4.legend()
ax4.set_xticks(inits[:-1:2])
ax4.text(0.1, 400, 'Underflow', fontsize=10)
ax4.text(1.7, 400, 'Overflow', fontsize=10)


##############
# Add labels #
##############
# Panel letters in figure coordinates, one above each of the four axes.
panel_tags = ((0.11, "(A)"), (0.31, "(B)"), (0.51, "(C)"), (0.71, "(D)"))
for tag_x, tag_text in panel_tags:
    fig.text(tag_x, 0.91, tag_text, fontsize=14)

# write the finished figure next to the notebook
plt.savefig('varplot.pdf', dpi=200)
#plt.show()