# Setup and Initialisation of Energy Landscape graph

In [None]:
#data handling
import networkx as nx
import pandas as pd
import numpy as np

#other maths
import random
from collections import Counter
import math

#plotting
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import matplotlib.cm as cm
import seaborn as sns



In [None]:
# Build the state graph: a 7-dimensional hypercube whose nodes are 7-tuples of 0/1
g = nx.hypercube_graph(7)


# ACTIVATION STATES
# loaded so that the energy values exported by the MATLAB code can be matched to the right node
activation_states = pd.read_csv("activation_states.csv", delimiter="\t", header=None)
# recode -1 as 0 (originally described as an aesthetic change)
# NOTE(review): this also looks necessary for the tuples below to match the 0/1 node names - confirm
activation_states[activation_states == -1] = 0

# {column label: list of the values in that column}
as_dict = activation_states.to_dict(orient="list")

# invert it so that (key: value) == (nx node name: ID in the MATLAB data)
as_dict_correct = {tuple(state): ela_id for ela_id, state in as_dict.items()}

# attach the ID to every node as the "ELA_ID" attribute
nx.set_node_attributes(g, as_dict_correct, name="ELA_ID")

# example lookup (raises KeyError if the attribute was not assigned to this node)
g.nodes[(0,1,0,0,0,1,0)]["ELA_ID"]

# ENERGY VALUES
# first column of the file as {row index: energy value}
EVk = pd.read_csv("EVk.csv", delimiter="\t", header=None)
EVk_dict = EVk.to_dict()[0]

# copy each energy value onto its node and flag every node as a local minimum for now;
# the flag is cleared later for nodes that turn out not to be minima
for node in g.nodes:
    node_attrs = g.nodes[node]
    node_attrs["EVk"] = EVk_dict[node_attrs["ELA_ID"]]
    node_attrs["local_minimum"] = True


# Calculation of Local Minima

In [None]:
# A node stays a local minimum only if every neighbour has strictly higher energy
for node in g.nodes:
    node_EVk = g.nodes[node]["EVk"]
    neighbour_EVks = [g.nodes[nay]["EVk"] for nay in nx.neighbors(g, node)]
    # any neighbour with lower or equal energy disqualifies the node
    if any(value <= node_EVk for value in neighbour_EVks):
        g.nodes[node]["local_minimum"] = False
    # ties are not expected, so report each one that occurs
    for value in neighbour_EVks:
        if value == node_EVk:
            print("EQUAL!")

# collect the local minima: column 0 holds the node, column 1 the flag
node_flags = pd.DataFrame(g.nodes("local_minimum"))
min_df = node_flags.loc[node_flags[1] == True, 0]
min_ids = [g.nodes[state]['ELA_ID'] for state in min_df]
print ("Local Minima activation states: \n",min_df)
print ("\n Local Minima ELA IDS: \n",min_ids)


# show the attributes of one example node
print("\n nodes now look like this:", g.nodes[(0,0,0,0,1,1,1)])

# Disconnectivity graphs


## calculation of disconnectivity graph

In [None]:
# Disconnection energies between local minima.
# Nodes are removed from a copy of the graph from highest to lowest energy; the first
# energy at which two minima no longer have a path between them is stored for that pair.
# Entries start at the sentinel -9999 (the diagonal keeps it).
connected_matrix = pd.DataFrame(-9999,index=min_ids, columns=min_ids,dtype=float)

# ELA_ID -> node lookup, built once instead of rescanning the graph for every pair
node_of_id = {attrs["ELA_ID"]: n for n, attrs in g.nodes(data=True)}

# energy values in ascending order, so pop() yields the largest remaining one
EVk_removal_list = sorted(EVk_dict.values())
connected = 999
# working copy that nodes are removed from
g_dis = g.copy()
prev_connected = 1000
all_stages_g = []  # graph after every removal step
stages_g = []      # graph only at steps where the number of connected minima pairs changed
while connected > 0:
    connected = 0
    # remove every node carrying the largest remaining energy value
    largest_EVk = EVk_removal_list.pop()
    selected = [x for x,y in g_dis.nodes(data=True) if y['EVk']==largest_EVk]
    g_dis.remove_nodes_from(selected)
    all_stages_g.append(g_dis.copy())
    # check each unordered pair of minima once
    for pos, i in enumerate(min_ids):
        i_node = node_of_id[i]
        for j in min_ids[pos + 1:]:
            j_node = node_of_id[j]
            # a minimum whose own node has already been removed counts as disconnected
            path_exists = (
                i_node in g_dis
                and j_node in g_dis
                and nx.has_path(g_dis, i_node, j_node)
            )
            if path_exists:
                # counted twice to keep the ordered-pair count used by the diagnostic below
                connected += 2
            elif connected_matrix.loc[i, j] <= -9999:
                # first time this pair is found disconnected: record the energy symmetrically.
                # .loc is used because chained indexing (df[i][j] = ...) does not write
                # through to the frame under pandas Copy-on-Write
                connected_matrix.loc[i, j] = largest_EVk
                connected_matrix.loc[j, i] = largest_EVk
    if prev_connected != connected:
        # print("connections between minima: ", int(connected/2), "   number of connected components in graph: ", nx.number_connected_components(g_dis))
        stages_g.append(g_dis.copy())

    prev_connected = connected

print("The matrix after computing at which EVk, the  minima disconnect \n")
connected_matrix
# the graph, at different stages is saved in the stages_g list 

## plotting setup

In [None]:
# The full graph g is used here so that node positions and colours are
# consistent across every plot drawn later.

# node coordinates (one row per node) and node energies, both in g.nodes order
node_matrix = np.array(g.nodes)
evk_column = [g.nodes[i]["EVk"] for i in g.nodes()]

# initial 2D embedding with t-SNE
tsne = TSNE(n_components=2, random_state=42, init="pca",learning_rate="auto" ,perplexity=30)
tsne_positions = tsne.fit_transform(node_matrix)

# repeatedly re-embed with the second coordinate forced to the energy value,
# lowering the perplexity by one each round
iterations = 30
for steps in range(iterations):
    tsne_positions[:,1] = evk_column
    tsne = TSNE(n_components=2, random_state=42,learning_rate="auto" ,perplexity=7+iterations-steps,init = tsne_positions)
    tsne_positions = tsne.fit_transform(node_matrix)
    tsne_positions[:,1] = evk_column

# node -> (x, y) with the energy axis stretched by a factor of 3
positions = {}
for node, (x_coord, y_coord) in zip(g.nodes, tsne_positions):
    positions[node] = (x_coord, y_coord*3)

# axis limits shared by the later plots
all_x = [pos[0] for pos in positions.values()]
all_y = [pos[1] for pos in positions.values()]
x_min,x_max = min(all_x),max(all_x)
y_min,y_max = min(all_y),max(all_y)


# colour scale shared by the later plots, normalised over the energies of the full graph
all_evk_values = nx.get_node_attributes(g, "EVk").values()

min_evk = min(all_evk_values)
max_evk = max(all_evk_values)
# the upper bound is padded by 2, so the top of the colormap is never reached
norm_colour = plt.Normalize(vmin=min_evk, vmax=max_evk+2)
colormap = cm.hot


In [None]:
# draw the first stored stage of the node-removal process
g_dis_draw = stages_g[0]
print(nx.number_connected_components(g_dis_draw))
fig = plt.figure(1, figsize=(10, 10))
plt.xlim(x_min-1,x_max+1)
plt.ylim(y_min-1,y_max)
# colours from the shared colour scale
stage_evk = nx.get_node_attributes(g_dis_draw, "EVk")
node_colours = [colormap(norm_colour(value)) for value in stage_evk.values()]
# sizes that enlarge local minima (computed here, but the draw call uses a fixed size of 50)
stage_minima = nx.get_node_attributes(g_dis_draw, "local_minimum")
node_sizes = [flag*300+200 for flag in stage_minima.values()]
nx.draw_networkx(
    g_dis_draw,
    pos=positions,
    with_labels=False,
    node_color=node_colours,
    node_size=50,
    alpha=0.9,
    width=0.3,
)


In [None]:

#graphs = all_stages_g[0:len(all_stages_g):10] # plot every 10th stage
#graphs = all_stages_g #plot all stages
graphs = stages_g  # stages at which the number of connected minima pairs changed

# grid with 5 columns and as many rows as needed
num_graphs = len(graphs)
cols = 5
rows = -(-num_graphs // cols)  # ceiling division

fig, axes = plt.subplots(rows, cols, figsize=(25, 7 * rows))
axes = axes.flatten()


for i, (ax, g_dis_draw) in enumerate(zip(axes, graphs)):
    # local minima are drawn larger; colours come from the shared colour scale
    minima_flags = nx.get_node_attributes(g_dis_draw, "local_minimum")
    node_sizes = [(flag * 300) + 100 for flag in minima_flags.values()]
    stage_evk = nx.get_node_attributes(g_dis_draw, "EVk")
    node_colours = [colormap(norm_colour(value)) for value in stage_evk.values()]

    # same layout and axis limits in every panel so the stages are comparable
    nx.draw_networkx(
        g_dis_draw,
        pos=positions,
        ax=ax,
        node_color=node_colours,
        with_labels=False,
        node_size=node_sizes,
        alpha=1,
        width=0.4
    )
    ax.set_xlim(x_min-1,x_max+1)
    ax.set_ylim(y_min-1,y_max)
    ax.set_title(f"Graph {i}")

# switch off the panels that received no graph
for ax in axes[num_graphs:]:
    ax.axis("off")

plt.tight_layout()
plt.show()


## plot disconnectivity graph

In [None]:
# Build and draw the disconnectivity graph: one node per local minimum plus one
# "connection" node per distinct disconnection energy.

# distinct matrix values in ascending order; [1:] drops the smallest one, the -9999 sentinel
EVk_disconnection_values = sorted(set(connected_matrix.values.flatten()))[1:]
disconnectivity_graph = nx.Graph()
for node_id in min_ids:
    # activation state (node tuple) that carries this ELA_ID
    node_state =  [x for x,y in g.nodes(data=True) if y["ELA_ID"]==node_id][0]
    disconnectivity_graph.add_node(node_id, activation_state = node_state,node_type = "activation",EVk = g.nodes[node_state]["EVk"] )

# connection nodes are named by their energy rounded to 4 decimals
for i in EVk_disconnection_values:
    disconnectivity_graph.add_node(f"{round(i,4)}", activation_state = None,node_type = "connection",EVk = i )

for node_a in min_ids:
    # descending order, so pop() yields the smallest remaining value
    connect_these = sorted(set(connected_matrix[node_a]),reverse=True)
    connect_these.pop()  # discard the -9999 sentinel from the diagonal
    if not connect_these:
        # a single minimum has no disconnection energies to link to
        continue
    # link the minimum to its lowest disconnection energy, then chain upwards
    lowest = connect_these.pop()
    disconnectivity_graph.add_edge(f"{round(lowest,4)}",node_a)
    while connect_these:
        new_lowest = connect_these.pop()
        disconnectivity_graph.add_edge(f"{round(lowest,4)}",f"{round(new_lowest,4)}")
        lowest=new_lowest


fig = plt.figure(1, figsize=(10, 7))
# x comes from a spring layout, y is the energy value of the node
spring_positions = nx.spring_layout(disconnectivity_graph, seed=42)
dis_positions = {node: (spring_positions[node][0], attrs["EVk"]) 
             for node, attrs in disconnectivity_graph.nodes(data=True)}

color_map = {'activation': 'lightblue', 'connection': 'lightgreen'}
nodes_colour  =  [color_map[n] for n in list(nx.get_node_attributes(disconnectivity_graph, "node_type").values())]

nx.draw_networkx(disconnectivity_graph,with_labels=True,node_color= nodes_colour,pos=dis_positions,node_size =1000,edge_color = "gray")


# Classification of nodes to basins

## remove upward edges

In [None]:
# Directed copy of the graph that keeps only edges pointing to strictly lower energy
g_c = g.copy()
g_c = g_c.to_directed()
# collect first, then remove, so the edge view is not modified while it is iterated
upward_or_flat = [
    (source, target)
    for source, target in g_c.edges()
    if g_c.nodes[source]["EVk"] <= g_c.nodes[target]["EVk"]
]
g_c.remove_edges_from(upward_or_flat)

### plot only down edges

In [None]:
# draw the directed graph that only contains downhill edges
plot_g = g_c
fig = plt.figure(1, figsize=(10, 7))
# per-edge colour from the higher of its two endpoint energies
# (computed here but not passed to the draw call below)
edge_max_EVk = [
    max(plot_g.nodes[source]["EVk"], plot_g.nodes[target]["EVk"])
    for source, target in plot_g.edges()
]
edge_colours = [colormap(norm_colour(value)) for value in edge_max_EVk]
# local minima are drawn larger; node colours use the shared colour scale
minima_flags = nx.get_node_attributes(plot_g, "local_minimum")
node_sizes = [flag*300+200 for flag in minima_flags.values()]
node_evk = nx.get_node_attributes(plot_g, "EVk")
node_colours = [colormap(norm_colour(value)) for value in node_evk.values()]

nx.draw_networkx(
    plot_g,
    pos=positions,
    node_color=node_colours,
    with_labels=False,
    node_size=node_sizes,
    alpha=0.9,
    width=0.4,
)

In [None]:
# For every node that is not a local minimum keep only the outgoing edge(s)
# that lead to the neighbour with the lowest energy
g_c2 = g_c.copy()
for node in g_c2.nodes:
    if g_c2.nodes[node]["local_minimum"]:
        continue
    # snapshot of the neighbours, because edges are removed below
    downhill = list(g_c2.neighbors(node))
    min_n_EVk = min(g_c2.nodes[n_node]["EVk"] for n_node in downhill)
    for n_node in downhill:
        if g_c2.nodes[n_node]["EVk"] != min_n_EVk:
            g_c2.remove_edge(node, n_node)


## keep only lowest edge

In [None]:
# working copy in which every node gets a "basin" attribute,
# initialised to the placeholder string "none"
g_c2_b = g_c2.copy()
nx.set_node_attributes(g_c2_b, values="none", name="basin")
def get_basin(node):
    """Return the local minimum reached from *node* and record it as the basin.

    Starting at *node*, follows the first listed neighbour in the module-level
    graph ``g_c2_b`` until a node flagged ``local_minimum`` is reached. That
    minimum is written to the ``"basin"`` attribute of every node visited on
    the way (including the minimum itself) and returned.
    """
    visited = []
    current = node
    while not g_c2_b.nodes[current]["local_minimum"] == True:
        visited.append(current)
        # raises IndexError if a non-minimum node has no neighbour left
        current = list(g_c2_b.neighbors(current))[0]
    g_c2_b.nodes[current]["basin"] = current
    for step in visited:
        g_c2_b.nodes[step]["basin"] = current
    return current

# label every node with its basin
for node in g_c2_b.nodes():
    get_basin(node)

# give each basin its own horizontal offset (multiples of 30) so basins are drawn side by side
basin_labels = [g_c2_b.nodes[i]["basin"] for i in g_c2_b.nodes]
basin_pos_adjust = {basin: rank*30 for rank, basin in enumerate(set(basin_labels))}
positions2 = positions.copy()
for state in g_c2:
    x_coord, y_coord = positions2[state][0], positions2[state][1]
    shift = basin_pos_adjust[g_c2_b.nodes[state]["basin"]]
    positions2[state] = (x_coord+shift, y_coord)


### plot by basin

In [None]:
# draw the basin graph with the basin-shifted positions;
# node_colours and node_sizes are the lists computed in an earlier cell
plot_g = g_c2_b
fig = plt.figure(1, figsize=(20, 7))
nx.draw_networkx(
    plot_g,
    pos=positions2,
    node_color=node_colours,
    node_size=node_sizes,
    with_labels=False,
    alpha=0.9,
    width=0.4,
)

## Active Regions in each basin

In [None]:

# One subplot per brain region: nodes are coloured by whether that region is
# active (1) or inactive (0) in the node's activation state.
plot_g = g_c2_b

# indices of the 7 positions in each node tuple, one per brain region
unique_categories = list(range(7))
# two-colour palette; plt.get_cmap is used because matplotlib.cm.get_cmap
# was removed in Matplotlib 3.9
color_palette = plt.get_cmap("Set1", 2)
# only keys 0 and 1 are ever looked up below, since the lookup key is an entry of a 0/1 node tuple
category_to_color = {cat: color_palette(i) for i, cat in enumerate(unique_categories)}

# local minima of the basin graph (also used by the region-activation table)
min_nodes_names = [i for i in g_c2_b.nodes if g_c2_b.nodes[i]["local_minimum"]==True]

# one column, one row per region
num_categories = len(unique_categories)
cols = 1
rows = (num_categories + cols - 1) // cols

fig, axes = plt.subplots(rows, cols, figsize=(25, 3 * rows))
axes = axes.flatten()

for idx, category in enumerate(unique_categories):
    ax = axes[idx]

    # colour by the value of this region in the node's state
    node_colors = [category_to_color[node[category]] for node in plot_g.nodes]
    # local minima are drawn larger
    node_sizes = [(v * 250) + 100 for v in nx.get_node_attributes(plot_g, "local_minimum").values()]
    # edges take the colour of their source node
    edge_colors = [category_to_color[edge[0][category]]  for edge in plot_g.edges]
    nx.draw_networkx(plot_g,ax=ax,node_color=node_colors, with_labels=False,node_size = node_sizes,alpha = 0.9,width = 0.4,pos= positions2,edge_color= edge_colors)
    ax.set_title(f"Brain Region {category}")

# Hide any unused subplots
for ax in axes[num_categories:]:
    ax.axis("off")

plt.tight_layout()
plt.show()


In [None]:
# Fraction of states in each basin in which each region is active.
# Rows are basins (local-minimum node tuples), columns are region indices.
plot_g = g_c2_b
regions = range(7)
activation_rows = []
for basin in min_nodes_names:
    # every node whose basin attribute is this minimum
    members = [node for node in plot_g.nodes if plot_g.nodes[node]["basin"] == basin]
    activation_rows.append([sum(node[i] for node in members)/len(members) for i in regions])
# the frame is built in one step rather than filled cell by cell with chained indexing
# (region_activation[i][basin] = ...), which does not write through under pandas Copy-on-Write
region_activation = pd.DataFrame(data = activation_rows, index = min_nodes_names, columns= regions,dtype=float)

In [None]:
# heatmap of the share of states per basin in which each region is active
fig, ax = plt.subplots()
sns.heatmap(
    region_activation,
    annot=True,
    cmap="Blues",
    cbar_kws={'label': 'Time Active'},
    ax=ax,
);
ax.set_xlabel("Regions")
ax.set_ylabel("Basins")

# Random walk

## Simulation

In [None]:

# Random walk on the full graph (no edges removed); g_c2_b is only used to look up
# the basin of each visited node.
rw_g = g.copy()

# NOTE: no random seed is set, so every run produces a different walk
# start from a random node
current_node = random.choice(list(rw_g.nodes))
path = []         # nodes moved to, in order
path_basins= []   # basin of each of those nodes

# Each iteration proposes a uniformly chosen neighbour and accepts the move with
# probability min(1, e^(E_current - E_neighbour)), until 100000 moves were accepted.
# NOTE(review): rejected proposals and the start node are not recorded, so "path" holds
# accepted moves only - confirm this is what the dwell-time analysis expects.
while len(path)<100000:
    chosen_neighbour = random.choice(list(rw_g.neighbors(current_node)))
    current_node_EVk = rw_g.nodes[current_node]["EVk"]
    neighbour_node_EVk = rw_g.nodes[chosen_neighbour]["EVk"]
    chance = min (1,pow(math.e,current_node_EVk-neighbour_node_EVk))
    prob = random.uniform(0, 1)
    if prob<chance:
        current_node = chosen_neighbour
        path.append(chosen_neighbour)
        path_basins.append(g_c2_b.nodes[chosen_neighbour]["basin"])


In [None]:
# activation states (node tuples) that are flagged as local minima
[state for state in g_c2_b.nodes if g_c2_b.nodes[state]["local_minimum"] == True]

In [None]:
# replaced_basins

In [None]:

# the first 100 recorded steps are skipped in everything below
walk_basins = path_basins[100:]

# short letter names for the basins (letters repeat if there are more basins than letters)
letters = ["a", "b", "c", "d","e","f"]
easier_names = {}
for position, element in enumerate(set(path_basins[100:])):
    easier_names[element] = letters[position % len(letters)]

# the walk expressed in letter names
replaced_basins = [easier_names[item] for item in walk_basins]

# run-length encode the walk: [basin letter, number of consecutive steps spent in it]
basins_grouped = []
for basin_letter in replaced_basins:
    if basins_grouped and basins_grouped[-1][0] == basin_letter:
        basins_grouped[-1][1] += 1
    else:
        basins_grouped.append([basin_letter, 1])
# the run lengths feed the dwelling times, the order of runs feeds the transition counts
print(basins_grouped)



In [None]:
# per basin letter: the basins entered next, and the lengths of the visits
states_grouped_transitions = {l: [] for l in letters}
states_grouped_dwell = {l: [] for l in letters}

# every run contributes its length to the dwell list of its basin
for basin_letter, run_length in basins_grouped:
    states_grouped_dwell[basin_letter].append(run_length)

# every pair of consecutive runs is one transition from the earlier basin to the later one
for earlier_run, later_run in zip(basins_grouped, basins_grouped[1:]):
    states_grouped_transitions[earlier_run[0]].append(later_run[0])



## Dwelling time 

In [None]:
# distribution of visit lengths for basin "b": x = visit length, y = number of visits
dwell_plot_dict = Counter(states_grouped_dwell["b"])
visit_lengths = list(dwell_plot_dict.keys())
visit_counts = list(dwell_plot_dict.values())
plt.bar(visit_lengths, visit_counts)

In [None]:
# mean and variance of the visit lengths for every basin letter
for l in letters:
    visits = states_grouped_dwell[l]
    mean_dwell = np.mean(visits)
    dwell_variance = np.var(visits)
    print(l,"mean dwelling time",mean_dwell,"variance of dwelling time",dwell_variance)


In [None]:
# bar chart of the mean visit length per basin letter
mean_dwell_per_basin = [np.mean(states_grouped_dwell[l]) for l in letters]
plt.bar(letters, mean_dwell_per_basin)

In [None]:
# for every basin letter, how often each other basin was entered next
for l in letters:
    next_basin_counts = Counter(states_grouped_transitions[l])
    print(l,next_basin_counts)
# open question from the author: b and c look most stable - do they correspond to
# the two most stable states reported in the paper?

In [None]:
# transition counts between basin letters, built from one Counter per starting basin
# NOTE(review): with a dict of dicts pandas puts the outer keys (starting basin) on the
# columns - confirm the row/column orientation matches the row-wise normalisation below
transition_counts = {l: Counter(states_grouped_transitions[l]) for l in letters}
transition_matrix = pd.DataFrame(transition_counts, columns=letters, index=letters)
# symmetrised version: mean of the matrix and its transpose
transition_matrix_av = (transition_matrix + transition_matrix.transpose())/2
# divide every row by its row sum
row_totals = transition_matrix.sum(axis=1)
transition_matrix_n = transition_matrix.div(row_totals, axis=0)

sns.heatmap(transition_matrix_n, annot=True,cmap="Blues");

## ELA Transition Graph

### Undirected

In [None]:
# undirected graph whose edge weights are the symmetrised transition counts
G_ST = nx.from_pandas_adjacency(transition_matrix_av)
# drop self-loops
self_loops = list(nx.selfloop_edges(G_ST))
G_ST.remove_edges_from(self_loops)

In [None]:
fig = plt.figure(figsize = (7,7))
weights =  nx.get_edge_attributes(G_ST,'weight')
# edge width = weight / 200; edge shade derived from the width
edge_weights_GST = [weights[edge]/200 for edge in G_ST.edges]
edge_c_GST = [(line_width/35)+0.5 for line_width in edge_weights_GST]
# node area scales with the square of the mean dwelling time
node_size_GST = [np.mean(states_grouped_dwell[basin])**2*314 for basin in G_ST.nodes]
undirected_layout = nx.spring_layout(G_ST, seed=42)
nx.draw_networkx(
    G_ST,
    pos=undirected_layout,
    edge_color=cm.Greys(edge_c_GST),
    width=edge_weights_GST,
    node_color="lightblue",
    node_size=node_size_GST,
    connectionstyle="arc3,rad=0.08",
)

### Directed

In [None]:
# directed graph built from the transition matrix
# NOTE(review): the input is the symmetrised matrix, so both directions of every edge
# carry the same weight - confirm that transition_matrix was not intended here
G_ST_dir = nx.from_pandas_adjacency(transition_matrix_av, create_using=nx.DiGraph())
# drop self-loops
directed_self_loops = list(nx.selfloop_edges(G_ST_dir))
G_ST_dir.remove_edges_from(directed_self_loops)

In [None]:
fig = plt.figure(figsize = (7,7))
weights =  nx.get_edge_attributes(G_ST_dir,'weight')
# edge width = weight / 400, or 0 when the weight is not positive; shade derived from the width
edge_weights_GST = [weights[edge]/400 if weights[edge]>0 else 0 for edge in G_ST_dir.edges]
edge_c_GST = [(line_width/18)+0.5 for line_width in edge_weights_GST]

# node size scales with the cube of the mean dwelling time
# (the author's intent: dwelling time shown as the volume of a sphere)
node_size_GST = [(np.mean(states_grouped_dwell[basin])**3)*314 for basin in G_ST_dir.nodes]
directed_layout = nx.spring_layout(G_ST_dir, seed=42,k =100)
nx.draw_networkx(
    G_ST_dir,
    pos=directed_layout,
    edge_color=cm.Greys(edge_c_GST),
    width=edge_weights_GST,
    node_color="lightblue",
    node_size=node_size_GST,
    connectionstyle="arc3,rad=0.08",
)