This notebook contains the code for the branching-model and neuronal-avalanche animations used in the [Brain Criticality video](https://youtu.be/vwLb3XlPCB4).


Copyright © 2023 Artem Kirsanov

# I) Branching model animation
<img src="assets/BranchingModelPreview.png">



In [None]:
# Importing modules

import numpy as np
from scipy.ndimage import gaussian_filter
from scipy.ndimage import convolve,convolve1d
from copy import deepcopy


import matplotlib.pyplot as plt
from matplotlib.patches import Circle
from matplotlib.animation import FuncAnimation
import matplotlib.cm as cm
import cmasher
import seaborn as sns

from tqdm.notebook import tqdm
from IPython.display import Video
import sys


# This is to import my utility functions from AK_animations_utils
sys.path.append("../../../Animation/")
from AK_animation_utils import *

## Branching model implementation

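Note: for simplicity I implemented a "homogeneous" type of branching model with full feedforward connectivity: every neuron in layer $i$ projects to every neuron in layer $i+1$ with equal transmission probability $\sigma / N_\text{neurons}$. Concretely, in the implementation below a neuron in layer $i+1$ fires with probability $\sigma k / N_\text{neurons}$, where $k$ is the number of active neurons in layer $i$ and $N_\text{neurons}$ is the number of neurons per layer.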

In [None]:
# Branching parameter sigma for the three regimes (sigma = 1 is the critical value)
SIGMA_SUBCRITICAL, SIGMA_CRITICAL, SIGMA_SUPERCRITICAL = 0.7, 1, 1.2

# Network size: number of feedforward layers and number of neurons in each layer
NUM_LAYERS, NEURONS_PER_LAYER = 35, 20

In [None]:
def network_init(network,first_layer_data=None):
    '''
        Initialize the input (first) layer of the network in place.

        network          -- 2D array indexed as [layer, neuron]; layer 0 is overwritten.
        first_layer_data -- activity to write into layer 0. If None, only neuron 0 of
                            the first layer is activated and all other neurons are zero.
    '''
    if first_layer_data is None:
        # Default stimulus: a single active neuron (index 0) in the input layer.
        # The layer width is taken from the network itself, not from a global.
        first_layer_data = np.zeros(network.shape[1])
        first_layer_data[0] = 1
    network[0, :] = first_layer_data
    
def network_advance(old_network, sigma,spont_prob):
    '''
        Advance a network a single time step into the future.

        old_network -- 2D array indexed as [layer, neuron]; it is NOT modified.
        sigma       -- branching parameter: each neuron of layer i fires with probability
                       sigma * (#active neurons in layer i-1) / (#neurons per layer).
        spont_prob  -- per-neuron probability of spontaneous activation, applied before propagation.

        Returns a new array of the same shape. Activity shifts forward by one layer;
        the input layer (layer 0) is cleared, and activity of the last layer is dropped.
    '''
    network = old_network.copy()  # a plain ndarray copy is all that is needed (no deepcopy)
    num_layers, neurons_per_layer = network.shape

    spont = np.random.rand(*network.shape)
    network[spont < spont_prob] = 1

    # Walk backwards from the last layer so that each layer is computed from the
    # previous layer's activity *before* that previous layer is itself overwritten.
    for layer_num in range(num_layers - 1, 0, -1):
        fire_prob = sigma * np.sum(network[layer_num - 1, :]) / neurons_per_layer
        network[layer_num] = np.random.rand(neurons_per_layer) < fire_prob
        network[layer_num - 1] = 0
    return network
        

def run_with_input(network, n_steps, sigma, input_interval, input_neurons=3):
    '''
        Run simulation with random input provided onto the first layer with a certain interval (for guessing game scenes).
        No spontaneous activity is used.

        network        -- initial 2D state [layer, neuron]; stored as frame 0.
        n_steps        -- total number of frames returned (including the initial one).
        sigma          -- branching parameter passed to network_advance.
        input_interval -- every `input_interval` steps, input is injected into layer 0.
        input_neurons  -- number of distinct, randomly chosen layer-0 neurons switched on
                          per injection (0 means no input).

        Returns a float array of shape (n_steps, *network.shape).
    '''
    # Sizes come from the network itself rather than from global constants
    network_states = np.zeros((n_steps, *network.shape))
    network_states[0] = network
    neurons_per_layer = network.shape[1]
    for step in range(1, n_steps):
        network_states[step] = network_advance(network_states[step - 1], sigma, 0)
        if step % input_interval == 0:
            chosen = np.random.choice(neurons_per_layer, input_neurons, replace=False)
            network_states[step, 0, chosen] = 1
    return network_states
       
                               
def run_stochastic(network, n_steps, sigma=1, spont_prob=0.01):
    '''
        Run simulation with stochastic (spontaneous) activity for n_steps.

        network    -- initial 2D state [layer, neuron]; stored as frame 0.
        n_steps    -- number of frames returned (including the initial one).
        sigma      -- branching parameter passed to network_advance.
        spont_prob -- per-neuron probability of spontaneous activation at each step.

        Returns a float array of shape (n_steps, *network.shape).
    '''
    # Sizes come from the network itself rather than from global constants
    network_states = np.zeros((n_steps, *network.shape))
    network_states[0] = network
    for step in range(1, n_steps):
        network_states[step] = network_advance(network_states[step - 1], sigma, spont_prob)
    return network_states


def smooth_activity(network_states, time_stretch=3):
    '''
        Stretch the activity in time and blur it for a more eye-pleasant animation.

        Each input frame becomes `time_stretch` output frames. The original activity is placed
        on every `time_stretch`-th frame, with zeros in between. The result is then convolved
        along time with a mirror-symmetric exponential kernel.

        This is for illustration purposes only: a neuron is either active or not, so there is
        no real "intermediate state" - but a blinking animation is just not as beautiful.

        Returns an array of shape (T * time_stretch, shape[1], shape[2]) for input of shape (T, ...).
    '''
    def _two_sided_decay(slope=-20, npoints=100):
        # Samples t in [0, 1]. Points with t > 0.5 follow exp(slope * t); the remaining
        # points are the mirror image. Everything is divided by the value at the first t > 0.5.
        t = np.linspace(0, 1, npoints)
        right = t > 0.5
        decay = np.exp(slope * t[right])
        kern = np.zeros_like(t)
        kern[right] = decay
        kern[~right] = decay[::-1]
        return kern / decay[0]

    dims = network_states.shape
    upsampled = np.zeros((dims[0] * time_stretch, dims[1], dims[2]))
    upsampled[::time_stretch, :, :] = network_states
    blur = _two_sided_decay(-60)
    return convolve1d(upsampled, blur, axis=0, mode="constant", origin=1)


def run_guessing_fake_critical_case(Ntries=1000, tolerance=2):
    '''
        Run many critical (sigma = SIGMA_CRITICAL) trials and keep only the "well-behaved" ones.

        Note: because this is a stochastic model by nature and I'm using a small-sized simplified version of it,
        even when sigma==1 this does not guarantee the expected behavior (avalanches often die out quickly).

        As a workaround, I run a number of tries. I keep only the runs where the number of active output
        neurons differs from the number of input neurons by strictly less than `tolerance`,
        i.e. |#OUTPUT - #INPUT| < tolerance. With the default tolerance=2, this allows a difference of at most 1.

        I know this is slightly "faking" the criticality, which would not be acceptable in a research paper, but it does the job of illustrating things ;)

        Returns a (possibly empty) list of accepted runs. Each run is the array produced by
        run_with_input for NUM_LAYERS + 2 frames.
    '''
    successes = []
    silence_frames = 2
    num_frames = NUM_LAYERS + silence_frames
    for _ in tqdm(range(Ntries)):
        network = np.zeros((NUM_LAYERS, NEURONS_PER_LAYER), dtype=bool)
        n_inputs = np.random.randint(2, NEURONS_PER_LAYER // 2)
        network[0, np.random.choice(np.arange(NEURONS_PER_LAYER), n_inputs, False)] = 1
        # input_neurons=0: no further input is injected after the initial stimulus
        network_states = run_with_input(network, num_frames, SIGMA_CRITICAL, 100, 0)

        # Frame index -1-silence_frames == NUM_LAYERS-1: the initial wave has reached the last layer.
        # With no spontaneous activity or extra input, the whole-frame sum is the output activity.
        n_outputs = np.sum(network_states[-1 - silence_frames, :])
        if np.abs(n_outputs - n_inputs) < tolerance:
            successes.append(network_states)
    return successes

def run_guessing_subcritical_case(Ntries=6):
    '''
        Run the subcritical (sigma = SIGMA_SUBCRITICAL) "guessing game" scenario for a few trials.

        Each trial starts with a random set of between NEURONS_PER_LAYER//2 and NEURONS_PER_LAYER-1
        active input neurons. It is simulated for NUM_LAYERS + 3 frames without further input.
        All trials are concatenated along the frame axis.
    '''
    silence_frames = 3
    n_frames = NUM_LAYERS + silence_frames
    trials = []
    for _ in tqdm(range(Ntries)):
        initial = np.zeros((NUM_LAYERS, NEURONS_PER_LAYER), dtype=bool)
        n_inputs = np.random.randint(NEURONS_PER_LAYER // 2, NEURONS_PER_LAYER)
        active = np.random.choice(np.arange(NEURONS_PER_LAYER), n_inputs, False)
        initial[0, active] = 1
        trials.append(run_with_input(initial, n_frames, SIGMA_SUBCRITICAL, 100, 0))
    return np.concatenate(trials)


def run_guessing_supercritical_case(Ntries=6):
    '''
        Run the supercritical (sigma = SIGMA_SUPERCRITICAL) "guessing game" scenario for a few trials.

        Each trial starts with a random set of between 2 and NEURONS_PER_LAYER//2 - 1 active
        input neurons. It is simulated for NUM_LAYERS + 3 frames without further input.
        All trials are concatenated along the frame axis.
    '''
    silence_frames = 3
    n_frames = NUM_LAYERS + silence_frames
    trials = []
    for _ in tqdm(range(Ntries)):
        initial = np.zeros((NUM_LAYERS, NEURONS_PER_LAYER), dtype=bool)
        n_inputs = np.random.randint(2, NEURONS_PER_LAYER // 2)
        active = np.random.choice(np.arange(NEURONS_PER_LAYER), n_inputs, False)
        initial[0, active] = 1
        trials.append(run_with_input(initial, n_frames, SIGMA_SUPERCRITICAL, 100, 0))
    return np.concatenate(trials)

## Animation functions

In [None]:
def setup_network_figure(figsize=None):
    '''
        Create a figure with a single axis-free Axes on a black background.

        figsize -- figure size in inches; when None it defaults to
                   (NUM_LAYERS / 5, NEURONS_PER_LAYER / 5) from the global constants.

        Axes limits are set to [-1, NUM_LAYERS] on x and [-1, NEURONS_PER_LAYER] on y.
        Returns (fig, ax).
    '''
    size = (NUM_LAYERS/5, NEURONS_PER_LAYER/5) if figsize is None else figsize
    fig, ax = plt.subplots(1, 1, figsize=size, dpi=100)
    for artist in (fig, ax):
        artist.set_facecolor("black")
    ax.axis(False)
    ax.set_xlim(-1, NUM_LAYERS)
    ax.set_ylim(-1, NEURONS_PER_LAYER)
    return fig, ax
    
def draw_network_state_as_pcolormesh(network_state, ax, cmap):
    '''
        Draw a single network state on `ax` as a pcolormesh grid.

        network_state -- 2D array indexed [layer, neuron]; it is transposed so that layers run along x.
        ax            -- matplotlib Axes to draw on.
        cmap          -- colormap; values are mapped with vmin=0, vmax=1.

        The axes limits are fitted to the grid. Returns the mesh artist from ax.pcolormesh
        so callers can update it later.
    '''
    dims = network_state.shape
    mesh_style = dict(edgecolors='k', vmin=0, vmax=1, linewidth=2, cmap=cmap)
    mesh = ax.pcolormesh(network_state.T, **mesh_style)
    ax.set_xlim(0, dims[0])
    ax.set_ylim(0, dims[1])
    return mesh

def animate_network_states(network_states, cmap=cmasher.get_sub_cmap(sns.color_palette("mako",as_cmap=True),0.2,1)):
    '''
        Animate network activity, one animation frame per entry along axis 0.

        network_states -- 3D array indexed [frame, layer, neuron].
        cmap           -- colormap for activity values (default: seaborn "mako" restricted to [0.2, 1]).

        Returns a matplotlib FuncAnimation (40 ms between frames); call .save() on it to render a video.
    '''
    fig, ax = setup_network_figure(figsize=(NUM_LAYERS, NEURONS_PER_LAYER))
    mesh = draw_network_state_as_pcolormesh(network_states[0, :, :], ax, cmap=cmap)

    def _update(frame_idx):
        # The mesh was drawn from the transposed state, so each frame is transposed too
        mesh.set_array(network_states[frame_idx, :, :].T)
        return (mesh,)

    frame_indices = tqdm(np.arange(network_states.shape[0]))
    return FuncAnimation(fig, _update, frames=frame_indices, interval=40)


## Running subcritical case

In [None]:
to_run_guessing_game = False # Whether to run a "guessing game" type of animation

if not to_run_guessing_game:
    # Free-running simulation: 100 steps with spontaneous activation probability 0.01
    network = np.zeros((NUM_LAYERS, NEURONS_PER_LAYER), dtype=bool)
    subcritical_states = run_stochastic(network, n_steps=100, sigma=SIGMA_SUBCRITICAL, spont_prob=0.01)
else:
    # "Guessing game": several independent trials driven by input on the first layer
    subcritical_states = run_guessing_subcritical_case()

smoothed_activity_subcritical = smooth_activity(subcritical_states, time_stretch=3)
animation = animate_network_states(smoothed_activity_subcritical)
#animation.save("Subcritical animation.mp4") # Uncomment to save the animation

## Running critical case

In [None]:
to_run_guessing_game = False # Whether to run a "guessing game" type of animation

if not to_run_guessing_game:
    # Free-running simulation: 100 steps with spontaneous activation probability 0.005
    network = np.zeros((NUM_LAYERS, NEURONS_PER_LAYER), dtype=bool)
    critical_states = run_stochastic(network, n_steps=100, sigma=SIGMA_CRITICAL, spont_prob=0.005)
else:
    # "Guessing game": keep only the well-behaved critical trials and stack them along time
    critical_states = np.vstack(run_guessing_fake_critical_case())

smoothed_activity_critical = smooth_activity(critical_states, time_stretch=3)
animation = animate_network_states(smoothed_activity_critical)
#animation.save("Critical animation.mp4") # Uncomment to save the animation

## Running supercritical case

In [None]:
to_run_guessing_game = False # Whether to run a "guessing game" type of animation

if not to_run_guessing_game:
    # Free-running simulation: 100 steps with spontaneous activation probability 0.001
    network = np.zeros((NUM_LAYERS, NEURONS_PER_LAYER), dtype=bool)
    supercritical_states = run_stochastic(network, n_steps=100, sigma=SIGMA_SUPERCRITICAL, spont_prob=0.001)
else:
    # "Guessing game": several independent trials driven by input on the first layer
    supercritical_states = run_guessing_supercritical_case()

smoothed_activity_supercritical = smooth_activity(supercritical_states, time_stretch=3)
animation = animate_network_states(smoothed_activity_supercritical)
#animation.save("Supercritical animation.mp4") # Uncomment to save the animation

# II) Avalanches animation and graph rearranging with Manim

<img src="assets/ManimAnimationPreview.png">

In [None]:
from manim import *
import networkx as nx
from scipy.interpolate import interp1d
import itertools

In [None]:
# Smaller 10 x 10 network for the Manim scene.
# Note: this redefines the global sizes that the Part I helpers read.
NUM_LAYERS, NEURONS_PER_LAYER = 10, 10
NUM_FRAMES = 2000

# The avalanche data comes from the same simplified branching model, run at sigma = 1
network = np.zeros((NUM_LAYERS, NEURONS_PER_LAYER), dtype=bool)
network_states = run_stochastic(network, NUM_FRAMES, sigma=1, spont_prob=0.01)
smoothed_states = smooth_activity(network_states)

In [None]:
def multilayered_graph(subset_sizes, edge_prob=0.35):
    '''
        Generate a networkx multilayered graph with the specified layer sizes.

        subset_sizes -- iterable of layer sizes (list, tuple, numpy array, ...). Nodes are numbered
                        consecutively, so layer i holds nodes
                        sum(subset_sizes[:i]) ... sum(subset_sizes[:i+1]) - 1.
        edge_prob    -- fraction of all possible edges between two consecutive layers to keep;
                        int(n_possible * edge_prob) edges are sampled without replacement.

        Every node gets a "layer" attribute with its layer index. Edges only connect consecutive layers.
        Returns an undirected nx.Graph.
    '''
    # [0, *sizes] works for any iterable. The old `[0] + sizes` failed for tuples and
    # silently broadcast (adding 0 element-wise) for numpy arrays.
    boundaries = itertools.accumulate([0, *subset_sizes])
    layers = [range(start, end) for start, end in nx.utils.pairwise(boundaries)]
    G = nx.Graph()
    for i, layer in enumerate(layers):
        G.add_nodes_from(layer, layer=i)
    for layer1, layer2 in nx.utils.pairwise(layers):
        all_edges = list(itertools.product(layer1, layer2))
        n_keep = int(len(all_edges) * edge_prob)
        selected_edges = np.random.choice(len(all_edges), size=n_keep, replace=False)
        for k in selected_edges:
            G.add_edge(*all_edges[k])
    return G

In [None]:
class BranchingModelRearranging(Scene):
    def construct(self):
        '''
            This Manim Scene creates an animation of "rearranging" the graph.

            At first, neurons are randomly scattered on a square grid.
            Then, after the transition, the layout is rearranged into a layered structure
            (layer index along x, neuron index along y). Throughout, node colors follow
            the simulated activity.

            Reads the module-level `network_states` (indexed [frame, layer, neuron]) for the
            network size and `smoothed_states` for the node colors.

            Raises ValueError if `smoothed_states` has too few frames for the scene's duration.
        '''
        # --- Timing: the simulation data is advanced by FPS frames per second of video
        FPS = 30
        PLAY_TIME_BEFORE_REARRANGING = 5
        PLAY_TIME_AFTER_REARRANGING = 5
        REARRANGING_TIME = 2

        # Fail early with a clear message instead of an interp1d range error mid-render
        frames_needed = (int(PLAY_TIME_BEFORE_REARRANGING*FPS)
                         + int(REARRANGING_TIME*FPS)
                         + int(PLAY_TIME_AFTER_REARRANGING*FPS))
        frames_available = smoothed_states.shape[0] - 1  # last valid interpolation point
        if frames_needed > frames_available:
            raise ValueError(f"BranchingModelRearranging needs {frames_needed} frames of smoothed_states, "
                             f"but only {frames_available} are available; run a longer simulation")

        n_layers, n_neurons = network_states.shape[1], network_states.shape[2]
        G = multilayered_graph([n_neurons]*n_layers)
        n_nodes = len(G.nodes)

        # Size of the initial grid (in # of neurons): the smallest square that fits every node
        # (10 x 10 for a network of 10 layers with 10 neurons each)
        GRID_SIZE_X = GRID_SIZE_Y = int(np.ceil(np.sqrt(n_nodes)))

        grid_ax = Axes(x_range=(0,GRID_SIZE_X), y_range=(0,GRID_SIZE_Y),x_length=7, y_length=7)
        layers_ax = Axes(x_range=(0,n_layers), y_range=(0,n_neurons),x_length=13, y_length=7)

        # Create a random mapping from units in a branching model to a shuffled grid layout
        mapping = np.array(list(itertools.product(range(GRID_SIZE_X), range(GRID_SIZE_Y))), dtype=object)
        np.random.shuffle(mapping)
        layout = {k: grid_ax.c2p(*mapping[k]) for k in range(n_nodes)}

        # Construct a graph object
        graph = Graph.from_networkx(G,layout=layout,vertex_config={'radius': 0.2}, edge_config={"stroke_width":0.5,
                                                                                               "stroke_color":GRAY})

        # Interpolation function to animate the color of the nodes according to simulation data.
        # Node k corresponds to the flattened (layer, neuron) index k = layer*n_neurons + neuron,
        # which matches the consecutive node numbering of multilayered_graph.
        value_interp_function = interp1d(np.arange(smoothed_states.shape[0]),
                                         smoothed_states.reshape(smoothed_states.shape[0], -1), axis=0)
        cmap = cmasher.get_sub_cmap(sns.color_palette("mako",as_cmap=True),0.2,1)

        time_tracker = ValueTracker() # Current (fractional) frame of the simulation data

        def update_node_colors(graph):
            # Interpolate all node activities once per rendered frame, not once per node
            activity = value_interp_function(time_tracker.get_value())
            for k in range(n_nodes):
                graph[k].set_color(rgba_to_color(cmap(activity[k])))

        graph.add_updater(update_node_colors)
        self.add(graph)

        def animate_network(playing_time):
            self.play(time_tracker.animate.increment_value(int(playing_time*FPS)), run_time=playing_time, rate_func=linear)

        animate_network(PLAY_TIME_BEFORE_REARRANGING)
        # Move node k to (layer, neuron) = (k // n_neurons, k % n_neurons) on the layered axes
        animations = [graph[k].animate.move_to(layers_ax.c2p(k//n_neurons, k%n_neurons)) for k in range(n_nodes)]

        self.play(*(animations + [time_tracker.animate.increment_value(int(REARRANGING_TIME*FPS))]), run_time=REARRANGING_TIME, rate_func=linear)
        animate_network(PLAY_TIME_AFTER_REARRANGING)
        self.wait()

In [None]:
# Render the BranchingModelRearranging scene defined above using manim's `%manim` notebook magic
%manim BranchingModelRearranging