In [1]:
# pip install matplotlib kaleido pillow

In [4]:
# pip install pathpy

In [1]:
import networkx as nx
import matplotlib
matplotlib.use("TkAgg")
import matplotlib.pyplot as plt
import random
import kaleido as kld
import matplotlib.animation as animation
from scipy.spatial import distance
from matplotlib.animation import FuncAnimation
import numpy as np
import pathpy as pp
from kaleido.scopes.plotly import PlotlyScope

# Create the two regions
region1 = nx.erdos_renyi_graph(50, 0.1, seed=0)
for node in region1.nodes():
    region1.nodes[node]["color"]="purple"
    region1.nodes[node]["region"]="1"
    
region2 = nx.barabasi_albert_graph(50, 2, seed=0,)
for node in region2.nodes():
    region2.nodes[node]["color"]="cyan"
    region2.nodes[node]["region"]="2"

# Combine the two regions into one graph
G = nx.disjoint_union(region1, region2)
# Assign positions to the nodes
pos = nx.spring_layout(G)
nx.set_node_attributes(G, pos, 'pos')
#nx.draw(G, with_labels=False, node_color= list(nx.get_node_attributes(G,'color').values()))

In [2]:
#Initially 30% of the nodes in both regions 1 and 2 are infected
init_p_infection=0.3
for node in G.nodes:
    if G.nodes[node]['region']=="1":
        if random.random() < init_p_infection:
            G.nodes[node]['status']= "infected"
            G.nodes[node]['color']= "red"
        else:
            G.nodes[node]['status']= "susceptible"
            G.nodes[node]['color']="blue"
    else:
        if random.random() < init_p_infection:
            G.nodes[node]['status']= "infected"
            G.nodes[node]['color']="red"
        else:
            G.nodes[node]['status']= "susceptible"
            G.nodes[node]['color']= "blue"

In [3]:
"""  
Lockdown Condition
- Strict: 20% movement
- Moderate: 40% movement
- None: 70% movement
"""
#initialize lockdown conditions: set region 1 as strict and region 2 as moderate
for node in G.nodes:
    if G.nodes[node]['region']=="1":
        G.nodes[node]["lockdown_status"]="strict"   #assign strict to all regionn 1 agents
    else:
        G.nodes[node]["lockdown_status"]="moderate" 
    

nextG = G.copy()

In [6]:
# G.nodes(data=True
#        )

In [7]:
#0: {'color': 'blue', 'region': '1', 'pos': array([0.56026903, 0.04683293]), 'status': 'susceptible'}

In [12]:
def update_infection_status_base_on_neighbors(G, ref_node):
    
    #get list of neighbors
    neighbors_list = G.neighbors(ref_node)
    neighbor_params = [G.nodes(data=True)[node] for node in neighbors_list]
    neighbor_infection_status = [G.nodes()[node]["infection_status"] for node in neighbors_list]
    num_neighbor_infections = neighbor_infection_status.count("infected")
    
    #updates infection status and position based on neighbors
    if num_neighbor_infections > 5 and G.nodes[ref_node]["infection_status"]!="recovered" :
        if random.random() < 0.8:
            G.nodes[ref_node]["infection_status"]="infected"
            
    elif num_neighbor_infections > 5 and G.nodes[ref_node]["infection_status"]=="recovered":
        if random.random() < 0.3:
            G.nodes[ref_node]["infection_status"]="infected"
        
    return G
    

In [None]:
#update rule for region 1
#if initial status is strict
all_nodes_in_region1 = [node for node in G.nodes if G.nodes[node]["region"]=="1"]

#randomly select 20% nodes as mobile
mobile_nodes_in_reg1 = np.random.choice(all_nodes_in_region1, round(0.2*len(all_nodes_in_region1)))
                                        
for node in G.nodes:
    if G.nodes[node]["lockdown_status"]=="strict" and G.nodes[node]["region"]=="1":
        
        
        if G.nodes[node]["infection_status"]=="susceptible" or "infected":
            #update states based on neighbors
            next_G = update_infection_status_base_on_neighbors(G, node)
        else:
            
        

for nodes in all_nodes_in_region1:
    

In [None]:
cross_region_prob = 0.3

def node_dynamics(graph, cross_region_prob):
    #print(len(list(G.nodes))) 
    #if random.random() < cross_region_prob:
    updated_G = graph.copy()
    node_list = list(updated_G.nodes)
    nodes_attributes = dict(G.nodes(data=True))
    #randomly select nodes to cross
    crossing_nodes = np.random.choice(node_list, int(0.1*len(node_list)), replace=False)
    crossing_nodes_attr = [nodes_attributes[node] for node in crossing_nodes]
    
    #get the regions of the selected nodes
    crossing_nodes_region = [updated_G.nodes[node]["region"] for node in crossing_nodes]
    #count the number of node's neighbors
    crossing_nodes_neighbors_num = [len(list(updated_G.neighbors(node))) for node in crossing_nodes ]
    
    
    #print(G.nodes(data=True))
    #remove the selected nodes fromm their current position
    updated_G.remove_nodes_from(crossing_nodes)
    #print(G.nodes(data=True))

        
        #randomly reattach the removed node to the counterpart region
    for i, (node, region, num_neighbors) in enumerate(zip( crossing_nodes,\
                                               crossing_nodes_region,crossing_nodes_neighbors_num)):
        #get nodes in the other region
        nodes_in_ther_region = [node for node in updated_G.nodes if updated_G.nodes[node]["region"]!=region]

        #randomly select nodes for reattachment
        nodes_for_reattachment= np.random.choice(nodes_in_ther_region, num_neighbors, replace=False)
        
        #readd node and restore the old attributes of the node : note conected yet
        updated_G.add_node(node, **(crossing_nodes_attr[i]))
        
        #create edges for connection
        create_tuple_of_source_target_nodes = [(node,target) for target in  nodes_for_reattachment]
        
        #add edges
        updated_G.add_edges_from(create_tuple_of_source_target_nodes)
        
        #update the region as well
        updated_G.nodes[node]["region"]= G.nodes[nodes_in_ther_region[0]]["region"]
        print(nodes_for_reattachment)
        #update position as the mid of its neighbors: create function for this later
        all_neighbors_pos= np.array([updated_G.nodes[each_node]["pos"] for each_node in nodes_for_reattachment])
        print(all_neighbors_pos)
        avg_x = all_neighbors_pos[:,0].mean()
        avg_y = all_neighbors_pos[:,1].mean()
        updated_G.nodes[node]["pos"] = np.array([avg_x, avg_y])
        
        
    #print(len(list(G.nodes)))    
#     col = [updated_G.nodes[node]['color'] for node in updated_G.nodes]
#    # print(len(list(G.nodes)))
#     print(len(col))
    
    #pos = nx.spring_layout(G)
    #nx.set_node_attributes(G, pos, 'pos')
    #col = [G.nodes[node]['color'] for node in G.nodes]
    #print(len(col))
    #node_color = list(nx.get_node_attributes(update_graph,'color').values())
    #print(len(node_color))
    #print(len(list(G.nodes)))  
    #print()
    #print(updated_G.nodes(data=True))
#         #get node neighbors
#         #node_neighbors = list(G.neighbors(each_node))
        
#         for each_neighbor in node_neighbors:
#             if (each_node, each_neighbor) in G.edges:
#                 G.remove_
    
    return updated_G



In [None]:
cross_region_prob = 0.3

def node_dynamics(graph, cross_region_prob):
    #print(len(list(G.nodes))) 
    #if random.random() < cross_region_prob:
    updated_G = graph.copy()
    node_list = list(updated_G.nodes)
    nodes_attributes = dict(G.nodes(data=True))
    #randomly select nodes to cross
    crossing_nodes = np.random.choice(node_list, int(0.1*len(node_list)), replace=False)
    crossing_nodes_attr = [nodes_attributes[node] for node in crossing_nodes]
    
    #get the regions of the selected nodes
    crossing_nodes_region = [updated_G.nodes[node]["region"] for node in crossing_nodes]
    #count the number of node's neighbors
    crossing_nodes_neighbors_num = [len(list(updated_G.neighbors(node))) for node in crossing_nodes ]
    
    
    #print(G.nodes(data=True))
    #remove the selected nodes fromm their current position
    updated_G.remove_nodes_from(crossing_nodes)
    #print(G.nodes(data=True))

        
        #randomly reattach the removed node to the counterpart region
    for i, (node, region, num_neighbors) in enumerate(zip( crossing_nodes,\
                                               crossing_nodes_region,crossing_nodes_neighbors_num)):
        #get nodes in the other region
        nodes_in_ther_region = [node for node in updated_G.nodes if updated_G.nodes[node]["region"]!=region]

        #randomly select nodes for reattachment
        nodes_for_reattachment= np.random.choice(nodes_in_ther_region, num_neighbors, replace=False)
        
        #readd node and restore the old attributes of the node : note conected yet
        updated_G.add_node(node, **(crossing_nodes_attr[i]))
        
        #create edges for connection
        create_tuple_of_source_target_nodes = [(node,target) for target in  nodes_for_reattachment]
        
        #add edges
        updated_G.add_edges_from(create_tuple_of_source_target_nodes)
        
        #update the region as well
        updated_G.nodes[node]["region"]= G.nodes[nodes_in_ther_region[0]]["region"]
        print(nodes_for_reattachment)
        #update position as the mid of its neighbors: create function for this later
        all_neighbors_pos= np.array([updated_G.nodes[each_node]["pos"] for each_node in nodes_for_reattachment])
        print(all_neighbors_pos)
        avg_x = all_neighbors_pos[:,0].mean()
        avg_y = all_neighbors_pos[:,1].mean()
        updated_G.nodes[node]["pos"] = np.array([avg_x, avg_y])
        
        
    #print(len(list(G.nodes)))    
#     col = [updated_G.nodes[node]['color'] for node in updated_G.nodes]
#    # print(len(list(G.nodes)))
#     print(len(col))
    
    #pos = nx.spring_layout(G)
    #nx.set_node_attributes(G, pos, 'pos')
    #col = [G.nodes[node]['color'] for node in G.nodes]
    #print(len(col))
    #node_color = list(nx.get_node_attributes(update_graph,'color').values())
    #print(len(node_color))
    #print(len(list(G.nodes)))  
    #print()
    #print(updated_G.nodes(data=True))
#         #get node neighbors
#         #node_neighbors = list(G.neighbors(each_node))
        
#         for each_neighbor in node_neighbors:
#             if (each_node, each_neighbor) in G.edges:
#                 G.remove_
    
    return updated_G



In [3]:
# Define a function to update the graph for each time step
def update(frame):
    global G
    
    G = node_dynamics(G, cross_region_prob)
    nx.draw(G, pos=nx.get_node_attributes(G, 'pos'), with_labels=False, node_color=list(nx.get_node_attributes(G, 'color').values()))

# # Animate the node movements
fig, ax = plt.subplots()
ani = animation.FuncAnimation(fig, update, frames=100, interval=200, repeat= True)
plt.show()

[ 7 97 47 99 58 95 85]
[[-0.67308212 -0.20612624]
 [-0.98061189 -0.30528113]
 [-0.70166028 -0.16921061]
 [-0.90811017 -0.21029949]
 [-0.7635808  -0.27699452]
 [-0.85338488 -0.24726062]
 [-0.50063845 -0.31279412]]
[51 63  6 84 33 31 39]
[[0.84009001 0.30886355]
 [0.76472704 0.19348605]
 [0.75366828 0.19652421]
 [0.74766765 0.21040452]
 [0.68756368 0.17784398]
 [0.79700961 0.27011971]
 [0.6876449  0.08168655]]
[75 12 51 32 50 71]
[[0.76524528 0.10609062]
 [0.65325208 0.27006813]
 [0.84009001 0.30886355]
 [0.72274311 0.28580334]
 [0.70070025 0.14456742]
 [0.83642954 0.22657967]]
[16 53 89 68 61]
[[-0.73075834 -0.11973068]
 [-0.70807435 -0.18909039]
 [-0.97512741 -0.05515627]
 [-0.83197069 -0.33084295]
 [-0.62606997 -0.31663416]]
[76  9]
[[0.81552015 0.20906808]
 [0.71753481 0.07893089]]
[ 4 43 46  6 32]
[[0.61234605 0.1863414 ]
 [0.70744931 0.24392271]
 [0.9725395  0.34055419]
 [0.75366828 0.19652421]
 [0.72274311 0.28580334]]
[44]
[[0.72867914 0.22380883]]
[31 39  0]
[[0.79700961 0.27011

Exception in Tkinter callback
Traceback (most recent call last):
  File "/Users/workstation/opt/anaconda3/lib/python3.9/tkinter/__init__.py", line 1892, in __call__
    return self.func(*args)
  File "/Users/workstation/opt/anaconda3/lib/python3.9/tkinter/__init__.py", line 814, in callit
    func(*args)
  File "/Users/workstation/opt/anaconda3/lib/python3.9/site-packages/matplotlib/backends/_backend_tk.py", line 144, in _on_timer
    super()._on_timer()
  File "/Users/workstation/opt/anaconda3/lib/python3.9/site-packages/matplotlib/backend_bases.py", line 1193, in _on_timer
    ret = func(*args, **kwargs)
  File "/Users/workstation/opt/anaconda3/lib/python3.9/site-packages/matplotlib/animation.py", line 1404, in _step
    still_going = super()._step(*args)
  File "/Users/workstation/opt/anaconda3/lib/python3.9/site-packages/matplotlib/animation.py", line 1097, in _step
    self._draw_next_frame(framedata, self._blit)
  File "/Users/workstation/opt/anaconda3/lib/python3.9/site-packages

In [None]:
#Plot multi-graphs in 3D.
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx

from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Line3DCollection


class LayeredNetworkGraph(object):

    def __init__(self, graphs, node_labels=None, layout=nx.spring_layout, ax=None):
        """Given an ordered list of graphs [g1, g2, ..., gn] that represent
        different layers in a multi-layer network, plot the network in
        3D with the different layers separated along the z-axis.

        Within a layer, the corresponding graph defines the connectivity.
        Between layers, nodes in subsequent layers are connected if
        they have the same node ID.

        Arguments:
        ----------
        graphs : list of networkx.Graph objects
            List of graphs, one for each layer.

        node_labels : dict node ID : str label or None (default None)
            Dictionary mapping nodes to labels.
            If None is provided, nodes are not labelled.

        layout_func : function handle (default networkx.spring_layout)
            Function used to compute the layout.

        ax : mpl_toolkits.mplot3d.Axes3d instance or None (default None)
            The axis to plot to. If None is given, a new figure and a new axis are created.

        """

        # book-keeping
        self.graphs = graphs
        self.total_layers = len(graphs)

        self.node_labels = node_labels
        self.layout = layout

        if ax:
            self.ax = ax
        else:
            fig = plt.figure()
            self.ax = fig.add_subplot(111, projection='3d')

        # create internal representation of nodes and edges
        self.get_nodes()
        self.get_edges_within_layers()
        self.get_edges_between_layers()

        # compute layout and plot
        self.get_node_positions()
        self.draw()


    def get_nodes(self):
        """Construct an internal representation of nodes with the format (node ID, layer)."""
        self.nodes = []
        for z, g in enumerate(self.graphs):
            self.nodes.extend([(node, z) for node in g.nodes()])


    def get_edges_within_layers(self):
        """Remap edges in the individual layers to the internal representations of the node IDs."""
        self.edges_within_layers = []
        for z, g in enumerate(self.graphs):
            self.edges_within_layers.extend([((source, z), (target, z)) for source, target in g.edges()])


    def get_edges_between_layers(self):
        """Determine edges between layers. Nodes in subsequent layers are
        thought to be connected if they have the same ID."""
        self.edges_between_layers = []
        for z1, g in enumerate(self.graphs[:-1]):
            z2 = z1 + 1
            h = self.graphs[z2]
            shared_nodes = set(g.nodes()) & set(h.nodes())
            self.edges_between_layers.extend([((node, z1), (node, z2)) for node in shared_nodes])


    def get_node_positions(self, *args, **kwargs):
        """Get the node positions in the layered layout."""
        # What we would like to do, is apply the layout function to a combined, layered network.
        # However, networkx layout functions are not implemented for the multi-dimensional case.
        # Futhermore, even if there was such a layout function, there probably would be no straightforward way to
        # specify the planarity requirement for nodes within a layer.
        # Therefor, we compute the layout for the full network in 2D, and then apply the
        # positions to the nodes in all planes.
        # For a force-directed layout, this will approximately do the right thing.
        # TODO: implement FR in 3D with layer constraints.

        composition = self.graphs[0]
        for h in self.graphs[1:]:
            composition = nx.compose(composition, h)

        pos = self.layout(composition, *args, **kwargs)

        self.node_positions = dict()
        for z, g in enumerate(self.graphs):
            self.node_positions.update({(node, z) : (*pos[node], z) for node in g.nodes()})


    def draw_nodes(self, nodes, *args, **kwargs):
        x, y, z = zip(*[self.node_positions[node] for node in nodes])
        self.ax.scatter(x, y, z, *args, **kwargs)


    def draw_edges(self, edges, *args, **kwargs):
        segments = [(self.node_positions[source], self.node_positions[target]) for source, target in edges]
        line_collection = Line3DCollection(segments, *args, **kwargs)
        self.ax.add_collection3d(line_collection)


    def get_extent(self, pad=0.1):
        xyz = np.array(list(self.node_positions.values()))
        xmin, ymin, _ = np.min(xyz, axis=0)
        xmax, ymax, _ = np.max(xyz, axis=0)
        dx = xmax - xmin
        dy = ymax - ymin
        return (xmin - pad * dx, xmax + pad * dx), \
            (ymin - pad * dy, ymax + pad * dy)


    def draw_plane(self, z, *args, **kwargs):
        (xmin, xmax), (ymin, ymax) = self.get_extent(pad=0.1)
        u = np.linspace(xmin, xmax, 10)
        v = np.linspace(ymin, ymax, 10)
        U, V = np.meshgrid(u ,v)
        W = z * np.ones_like(U)
        self.ax.plot_surface(U, V, W, *args, **kwargs)


    def draw_node_labels(self, node_labels, *args, **kwargs):
        for node, z in self.nodes:
            if node in node_labels:
                ax.text(*self.node_positions[(node, z)], node_labels[node], *args, **kwargs)


    def draw(self):

        self.draw_edges(self.edges_within_layers,  color='k', alpha=0.3, linestyle='-', zorder=2)
        self.draw_edges(self.edges_between_layers, color='k', alpha=0.3, linestyle='--', zorder=2)

        for z in range(self.total_layers):
            self.draw_plane(z, alpha=0.2, zorder=1)
            self.draw_nodes([node for node in self.nodes if node[1]==z], s=300, zorder=3)

        if self.node_labels:
            self.draw_node_labels(self.node_labels,
                                  horizontalalignment='center',
                                  verticalalignment='center',
                                  zorder=100)


if __name__ == '__main__':

#     # define graphs
#     n = 5
#     g = nx.erdos_renyi_graph(4*n, p=0.1)
#     h = nx.erdos_renyi_graph(3*n, p=0.2)
#     i = nx.erdos_renyi_graph(2*n, p=0.4)

    #node_labels = {nn : str(nn) for nn in range(4*n)}
    #plt.rcParams["figure.figsize"]=(10,14)
    # initialise figure and plot
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    LayeredNetworkGraph([region1, region2],ax=ax, layout=nx.spring_layout)
    ax.set_axis_off()
    plt.show()

In [None]:
# Initialize the figure
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111)

# Define the update function for the animation
def update(frame):
    global G
    G = node_dynamics(G, cross_region_prob)
    node_color = list(nx.get_node_attributes(G, 'color').values())
    pos = nx.get_node_attributes(G, 'pos')
    nx.draw(G, pos, with_labels=False, node_color=node_color, ax=ax)

# Create the animation
ani = animation.FuncAnimation(fig, update, frames=100, interval=100, repeat=True)

# Show the animation
plt.show()
