In [1]:
import networkx as nx
from matplotlib import pyplot as plt
import random
import math
from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import threading
import queue

In [2]:
def initialize_graph(node_count, contact_radius = 1, queue = None, time = 0, edge_dist = 'exp'):
    square_size = math.sqrt(node_count) #size of square
    custom_pos = {} #custom position for plotting
    G = nx.Graph(time = time) #initialize graph
    for i in range(node_count):
        G.add_node(i)
        G.nodes[i]['state'] = 'S'  
        x,y = random.uniform(0, square_size), random.uniform(0, square_size)
        G.nodes[i]['pos'] = (x,y)
        G.nodes[i]['observed'] = False
        G.nodes[i]['SIR_prob'] = [1,0,0]
        custom_pos[i] = (x,y)
    #initialize node state, color, and position
    #set the colors
    #S = susceptible (green)
    #I = infected (red)
    #R = recovered (blue)
    #set the edges
    edges = []
    if edge_dist == 'geometric':
        edges = nx.geometric_edges(G, radius = contact_radius)
    elif edge_dist == 'exp':
        p_dist = lambda distance: math.exp(-distance/contact_radius)
        edges = nx.soft_random_geometric_graph(node_count, radius = contact_radius, pos=custom_pos, p = 2, p_dist=p_dist).edges
    G.add_edges_from(edges)
    queue.put((G, custom_pos))

In [3]:
def infectRandomNode(G, node_count):
    #infect a random node that is susceptible
    sample = random.sample(G.nodes(), G.number_of_nodes())
    count = 0
    for i in sample:
        if (count >= node_count):
            break
        if G.nodes[i]['state'] == 'S':
            G.nodes[i]['state'] = 'I'
        count += 1

In [4]:
def update(G_t1, G_t2, infection_rate, recovery_rate):
    #update Gt2 based on G_t1 assuming they are similar node size
    for node in G_t1.nodes():
        #check if node is observed
        if G_t1.nodes[node]['observed'] == True:
            G_t2.nodes[node]['observed'] = True
        
        #check node state
        if G_t1.nodes[node]['state'] == 'S':
            for neighbor in G_t1.neighbors(node):
                if G_t1.nodes[neighbor]['state'] == 'I':
                    if random.uniform(0, 1) < infection_rate:
                        G_t2.nodes[node]['state'] = 'I'
                        break
        elif G_t1.nodes[node]['state'] == 'I':
            if random.uniform(0, 1) < recovery_rate:
                G_t2.nodes[node]['state'] = 'R'
            else:
                G_t2.nodes[node]['state'] = 'I'
        else:
            G_t2.nodes[node]['state'] = 'R'
        
        #update SIR probabilities
        G_t2.nodes[node]['SIR_prob'] = G_t1.nodes[node]['SIR_prob']
    return G_t2

In [5]:
def SurroundInfectionRate(graph, node, infection_rate):
    #total infection rate of a node based on the infection rate of its neighbors
    neighbors = list(graph.neighbors(node))
    result = 0
    for neighbor in neighbors:
        result += infection_rate * graph.nodes[neighbor]['SIR_prob'][1]
    return result

In [6]:
def MeanFieldInference(graph_t, infection_rate, recovery_rate):
    predicted_state_t1 = {}
    for node in graph_t.nodes():
        factor = SurroundInfectionRate(graph_t, node, infection_rate)
        Ps = graph_t.nodes[node]['SIR_prob'][0]*(1-factor)
        Pi = (1-recovery_rate)*graph_t.nodes[node]['SIR_prob'][1] + graph_t.nodes[node]['SIR_prob'][0]*factor
        Pr = graph_t.nodes[node]['SIR_prob'][2] + recovery_rate*graph_t.nodes[node]['SIR_prob'][1]
        predicted_state_t1[node] = [Ps, Pi, Pr]
    return predicted_state_t1

In [7]:
def BackloopUpdate(temporal_graphs,t ,delta, node, state):
    mapper = {'S':[1,0,0], 'I':[0,1,0], 'R':[0,0,1]}
    for i in range(t, t-delta-1, -1):
        temporal_graphs[i][0].nodes[node]['SIR_prob'] = mapper[state]
        temporal_graphs[i][0].nodes[node]['observed'] = True

In [75]:
#@param: give a list of nodes that are observed
def DetectObserve(obs, temporal_graphs, observe_per_day, t, delta):
    list_of_nodes = []
    count = 0
    for entry in obs:
        if count == observe_per_day:
            break
        state = temporal_graphs[t][0].nodes[entry]['state']
        obs_state = temporal_graphs[t][0].nodes[entry]['observed']
        if obs_state == True:
            continue
        if state == 'S':
            BackloopUpdate(temporal_graphs, t, delta, entry, 'S')
        elif state == 'I':
            BackloopUpdate(temporal_graphs, t, delta, entry, 'I')
        elif state == 'R':
            BackloopUpdate(temporal_graphs, t, delta, entry, 'R')
        count += 1
        list_of_nodes.append(entry)
    return list_of_nodes
        
        

In [65]:
def SetRemove(graph, list):
    for entry in list:
        print(entry)
        graph.nodes[entry]['observed'] = True
        graph.nodes[entry]['SIR_prob'] = [0,0,1]
        graph.nodes[entry]['state'] = 'R'

In [10]:
def SIR_count(graph):
    S_count = 0
    I_count = 0
    R_count = 0
    for node in graph.nodes():
        if graph.nodes[node]['state'] == 'S':
            S_count += 1
        elif graph.nodes[node]['state'] == 'I':
            I_count += 1
        elif graph.nodes[node]['state'] == 'R':
            R_count += 1
    return S_count, I_count, R_count

In [11]:
def Sampler(graph, s_count):
    nodes = list(graph.nodes())
    random.shuffle(nodes)
    result = []
    count = 0
    for i in range(len(nodes)):
        if graph.nodes[nodes[i]]['observed']==False:
            result.append(nodes[i])
            count += 1
        if count == s_count:
            return result
    return result

In [12]:
def generate_graphs(node_count, contact_radius, nr_graphs):
    que_graphs = queue.Queue()
    thread_list = []
    for t in range(0, nr_graphs):
        thread = threading.Thread(target=initialize_graph, args=(node_count, contact_radius, que_graphs, t), name='thread'+str(t))
        thread.start()
        #print('thread'+str(t)+' started')
        thread_list.append(thread)

    for thread in tqdm(thread_list, desc='Generating Graphs'):
        #print('thread'+thread.name +' joined')
        thread.join()
    temporal_graph_dict = {}
    while not que_graphs.empty():
        G,custom_pos = que_graphs.get()
        temporal_graph_dict[G.graph['time']] = (G, custom_pos)
    return temporal_graph_dict

In [58]:
def contactTrace(temporalgraph_dic, t, node, delta):
    #count the number of nodes that are infected in the past delta days
    count = 0
    for i in range(t-delta, t):
        if temporalgraph_dic[i][0].nodes[node]['SIR_prob'][1] == 1.0 and temporalgraph_dic[i][0].nodes[node]['observed'] == True:
            count += 1
    return count

In [76]:
#mode = 'MF', 'R', 'None'
node_count = 10
infection_rate = 1
recovery_rate = 0.0
contact_radius = 1
time_steps = 20
temporal_graph_dict = {}
observe_per_day = 2
t_mf = 5
t_ct = 5
delay = 10
mode = 'CT'
color_map = {'S': 'green', 'I': 'red', 'R': 'blue'}
def sim(DataFrame, id, mode, queue = None):
    temporal_graph_dict = generate_graphs(node_count, contact_radius, time_steps)
    infectRandomNode(temporal_graph_dict[0][0], 1)
    graph = temporal_graph_dict[0][0]
    S, I, R = SIR_count(graph)
    DataFrame.loc[len(DataFrame)] = [id, 0, S, I, R]
    #string for description of loop:
    desc = 'Simulating (mode: '+mode +')'
    for i in tqdm(range(1, time_steps), desc=desc):
        graph = update(temporal_graph_dict[i-1][0], temporal_graph_dict[i][0], infection_rate, recovery_rate)
        if (mode == 'MF' and i > delay):
            pred = MeanFieldInference(graph, infection_rate, recovery_rate)
            sort_pred = sorted(pred.items(), key=lambda x: x[1])
            node_list = [node[0] for node in sort_pred]
            #print("day", i)
            #print(pred)
            #print(node_list)
            list_remove = DetectObserve(node_list, temporal_graph_dict,observe_per_day,i, t_mf)
            print(list_remove)
            SetRemove(temporal_graph_dict[i][0], list_remove)
            for time in range(i-t_mf, i):
                pred = MeanFieldInference(temporal_graph_dict[time][0], infection_rate, recovery_rate)
                for node in temporal_graph_dict[time][0].nodes():
                    if temporal_graph_dict[time][0].nodes[node]['observed'] == False:
                        temporal_graph_dict[time][0].nodes[node]['SIR_prob'] = pred[node]
        elif (mode == 'R' and i > delay):
            #print("dayR", i)
            sample = Sampler(temporal_graph_dict[i][0], observe_per_day)
            SetRemove(temporal_graph_dict[i][0], sample)
        elif (mode == 'CT' and i > delay):
            ranker = {}
            for node in temporal_graph_dict[i][0].nodes():
                if temporal_graph_dict[i][0].nodes[node]['observed'] == False:
                    ranker[node] = contactTrace(temporal_graph_dict, i, node, t_ct)
            sort_ranker = sorted(ranker.items(), key=lambda x: x[1])
            node_list = [node[0] for node in sort_ranker]
            list_remove = DetectObserve(node_list, temporal_graph_dict,observe_per_day,i, t_ct)
            SetRemove(temporal_graph_dict[i][0], list_remove)
        S, I, R = SIR_count(graph)
        DataFrame.loc[len(DataFrame)] = [id, i, S, I, R]
        
    if queue != None:
        queue.put(DataFrame)
    return temporal_graph_dict

In [80]:
Dataframe = pd.DataFrame(columns=['id', 'time', 'S', 'I', 'R'])
tempograph = sim(Dataframe, 1, 'CT')

Generating Graphs:   0%|          | 0/20 [00:00<?, ?it/s]

since Python 3.9 and will be removed in a subsequent version.
  sample = random.sample(G.nodes(), G.number_of_nodes())


Simulating (mode: CT):   0%|          | 0/19 [00:00<?, ?it/s]

0
1
2
3
4
5
6
7
8
9


In [81]:
def drawgraph(tempograph, t):
    graph = tempograph[t][0]
    custom_pos = tempograph[t][1]
    colors = [color_map[graph.nodes[node]['state']] for node in graph.nodes()]
    nx.draw(graph, pos = custom_pos, node_color = colors, with_labels = True)

In [82]:
from ipywidgets import *
interact(drawgraph, tempograph = fixed(tempograph), t = widgets.IntSlider(min=0, max=time_steps-1, step=1, value=0))

interactive(children=(IntSlider(value=0, description='t', max=19), Output()), _dom_classes=('widget-interact',…

<function __main__.drawgraph(tempograph, t)>

In [None]:
#parrallelize simulation per 20 threads
number_of_sim = 100
thread_sim = 20
thread_list = []
Dataframe_final = pd.DataFrame(columns=['id', 'time', 'S', 'I', 'R'])
Queu_sim = queue.Queue()
for x in tqdm(range(number_of_sim), desc='Simulating Main (mode: '+mode +')'):
    Dataframe = pd.DataFrame(columns=['id', 'time', 'S', 'I', 'R'])
    thread = threading.Thread(target=sim, args=(Dataframe, x, mode, Queu_sim), name='thread' + str(x))
    thread.start()
    thread_list.append(thread)
    if len(thread_list) >= thread_sim:
        #reduce thread list
        thread = thread_list.pop(0)
        thread.join()
for thread in tqdm(thread_list, desc='Reducing thread Simulating (mode: '+mode +')'):
    thread.join()
while not Queu_sim.empty():
    Dataframe = Queu_sim.get()
    Dataframe_final = Dataframe_final.append(Dataframe)