In [14]:
import jax.numpy as jnp
import pickle
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider

from jax.flatten_util import ravel_pytree

In [None]:
# Load the desired pickle file.
# NOTE(review): pickle.load can execute arbitrary code -- only load files you
# produced yourself. The absolute path below is machine-specific; point
# ACTIVATIONS_PATH at your own log directory.
ACTIVATIONS_PATH = '/mnt/SharedDrive/Repositories/Maxwell_demon/logs/metadata/asymptotic_live_neurons/2022-06-19---June 19---10:36:20/activations_meta.p'
# Context manager so the file handle is closed as soon as loading finishes.
with open(ACTIVATIONS_PATH, 'rb') as activations_file:
    activation_dict = pickle.load(activations_file)
print(activation_dict.keys())

In [None]:
def count_total_neurons(activations):
    """Return the number of neurons in each layer of the given model.

    Each element of ``activations`` is one layer's pytree; its neuron count is
    the length of the 1-D vector obtained by flattening all of its leaves with
    ``ravel_pytree``. The result is a list with one int per layer, in order.
    """
    # ravel_pytree returns (flat_vector, unravel_fn); only the vector is needed.
    return [len(ravel_pytree(layer)[0]) for layer in activations]

def count_epsilon_close(eps, activations):
    """Count, per layer, the neurons whose activation statistic is at least ``eps``.

    Args:
        eps: threshold; a neuron is counted as live when its value is ``>= eps``.
        activations: iterable with one pytree per layer (presumably per-neuron
            statistics such as max/mean activation -- TODO confirm against the
            code that wrote the pickle).

    Returns:
        list with one 0-d ``jnp`` integer array per layer, in input order.
    """
    count = []
    for layer_activ in activations:
        # ravel_pytree flattens every leaf of the layer into one 1-D vector;
        # the unravel function it also returns is not needed here.
        flat_layer_activations, _ = ravel_pytree(layer_activ)
        count.append(jnp.sum(flat_layer_activations >= eps))
    return count

In [None]:
%matplotlib notebook

# Inspecting maximum activation value:
key = 'maximum'
activations = activation_dict[key]

# Initial eps value
init_eps = 0

# total_neurons
total_neurons = count_total_neurons(activations)

# Create the figure
fig, ax = plt.subplots()
line, = plt.plot(total_neurons, count_epsilon_close(init_eps, activations), lw=2)
ax.set_xlabel('Total number of neurons')

# adjust the main plot to make room for the sliders
plt.subplots_adjust(bottom=0.25) 

# Make a horizontal slider to control epsilon
axfreq = plt.axes([0.25, 0.1, 0.65, 0.03])
eps_slider = Slider(
    ax=axfreq,
    label='epsilon',
    valmin=0.0,
    valmax=0.0001,
    valinit=init_eps,
)

# The function to be called anytime a slider's value changes
def update(val):
    line.set_ydata(count_epsilon_close(eps_slider.val, activations))
    fig.canvas.draw_idle()
    
# register the update function with each slider
eps_slider.on_changed(update)

plt.show()

In [None]:
%matplotlib notebook

# Inspecting mean activation value:
key = 'mean'
activations = activation_dict[key]

# Initial eps value
init_eps = 0

# total_neurons
total_neurons = count_total_neurons(activations)

# Create the figure
fig, ax = plt.subplots()
line, = plt.plot(total_neurons, count_epsilon_close(init_eps, activations), lw=2)
ax.set_xlabel('Total number of neurons')

# adjust the main plot to make room for the sliders
plt.subplots_adjust(bottom=0.25) 

# Make a horizontal slider to control epsilon
axfreq = plt.axes([0.25, 0.1, 0.65, 0.03])
eps_slider = Slider(
    ax=axfreq,
    label='epsilon',
    valmin=0.0,
    valmax=0.0001,
    valinit=init_eps,
)

# The function to be called anytime a slider's value changes
def update(val):
    line.set_ydata(count_epsilon_close(eps_slider.val, activations))
    fig.canvas.draw_idle()
    
# register the update function with each slider
eps_slider.on_changed(update)

plt.show()

In [None]:
%matplotlib notebook

# Inspecting count value:
key = 'count'
activations = activation_dict[key]

# Initial eps value
init_eps = 1

# total_neurons
total_neurons = count_total_neurons(activations)

# Create the figure
fig, ax = plt.subplots()
line, = plt.plot(total_neurons, count_epsilon_close(init_eps, activations), lw=2)
ax.set_xlabel('Total number of neurons')

# adjust the main plot to make room for the sliders
plt.subplots_adjust(bottom=0.25) 

# Make a horizontal slider to control epsilon
axfreq = plt.axes([0.25, 0.1, 0.65, 0.03])
eps_slider = Slider(
    ax=axfreq,
    label='epsilon',
    valmin=1,
    valmax=100,
    valinit=init_eps,
)

# The function to be called anytime a slider's value changes
def update(val):
    line.set_ydata(count_epsilon_close(eps_slider.val, activations))
    fig.canvas.draw_idle()
    
# register the update function with each slider
eps_slider.on_changed(update)

plt.show()