In [1]:
%matplotlib inline
import bisect
import copy
from random import choices
from statistics import mean,stdev

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np

from fileload import load_matfile
from timing import timeit

In [2]:
# Load "blogcatalog.mat" through the project helper.
# NOTE(review): unDirected=False presumably requests a directed graph - confirm in fileload.
G_blogcatalog = load_matfile(file_="blogcatalog.mat", unDirected=False)

## traditional SIR 
- O(): susceptible_neighbors[0] + susceptible_neighbors[1] + ...
- if i hasn't recovered yet, p(i->j) has to be sampled again
## my SIR 
- O(): max(susceptible_neighbors)
- i's recovery time is sampled once, and each t(i->j) is sampled once

In [3]:
import math
import random
# class State(object):
#     susceptible = 0
#     infected = 1
#     recovered = 2
# def reset(G):
#     nx.set_node_attributes(G, name = 'state', values = State.susceptible)  

def sample_infect_interval(r=0.03):
    '''
    Draw a random waiting time until an infection is transmitted.

    r is infection rate
    sample from p(t) = r*exp(-r*t) by inverse-transform sampling

    Returns a float >= 0.
    '''
    # random.random() lies in [0.0, 1.0) and can return exactly 0.0, for which
    # math.log raises ValueError.  1.0 - u lies in (0.0, 1.0] and follows the
    # same uniform distribution, so the logarithm is always defined.
    u = 1.0 - random.random()
    return -(1.0/r)*math.log(u)

def sample_recover_interval(mu=0.05):
    '''
    Draw a random waiting time until an infected node recovers.

    mu is recover rate
    sample from p(t) = mu*exp(-mu*t) by inverse-transform sampling

    Returns a float >= 0.
    '''
    # random.random() lies in [0.0, 1.0) and can return exactly 0.0, for which
    # math.log raises ValueError.  1.0 - u lies in (0.0, 1.0] and follows the
    # same uniform distribution, so the logarithm is always defined.
    u = 1.0 - random.random()
    return -(1.0/mu)*math.log(u)

def insert_infect_time(nodelist, j, t_ij):
    '''
    Record t_ij as a candidate infection time for node j.

    nodelist maps a node to its list of candidate infection times, kept in
    ascending order (the earliest candidate is element 0).  A node that is not
    yet a key gets a new one-element list.  The dict is modified in place and
    nothing is returned.
    '''
    times = nodelist.setdefault(j, [])
    # insort_left places t_ij before any equal entries, i.e. at the first index
    # whose value is >= t_ij, and also copes with an existing empty list.
    bisect.insort_left(times, t_ij)

In [4]:
# Smoke test: a time inserted for a node that is not yet a key creates a new one-element list.
dd = {1: [1, 2]}
insert_infect_time(nodelist=dd, j=2, t_ij=0.5)
dd

{1: [1, 2], 2: [0.5]}

In [5]:
def insert_infect_times(G, infectious_node, t_i, recover_interval, nodes_to_infect):
    '''
    Schedule the candidate infections caused by infectious_node.

    One infection interval is sampled for every out-edge of infectious_node.
    When that interval is shorter than recover_interval, the absolute time
    t_i + interval is added to the target's entry in nodes_to_infect.
    '''
    for _, neighbor in G.out_edges(infectious_node):
        delay = sample_infect_interval()
        if delay < recover_interval:
            # transmission happens before infectious_node recovers
            insert_infect_time(nodes_to_infect, neighbor, t_i + delay)



@timeit
def set_seeds(G, nodes_to_infect, recover_times, seeds=[4,484,60]):
    '''
    Infect the seed nodes at time 0.0.

    For each infected node,
    no node would infect it again, so we remove its in_edges
    it can only recover, or infect other nodes through out_edges

    G is modified in place (in_edges of the seeds are removed), recover_times
    gets one recover time per seed, and nodes_to_infect gets the candidate
    infection times of the seeds' out-neighbors.
    '''
    # dict.fromkeys drops repeated seeds while keeping the given order, so a
    # node listed twice is not sampled twice.
    seed_nodes = list(dict.fromkeys(seeds))

    # Remove the in_edges of ALL seeds before sampling anything.  Doing this
    # inside the sampling loop would let an earlier seed schedule an infection
    # of a later seed that is one of its out-neighbors, so that seed would be
    # queued in nodes_to_infect and infected a second time.
    for i in seed_nodes:
        G.remove_edges_from(list(G.in_edges(i)))

    for i in seed_nodes:
        # sample a recover_interval for i and add it to recover_times
        recover_interval = sample_recover_interval()
        recover_times[i] = 0.0 + recover_interval

        # insert possible infect times result from i
        insert_infect_times(G, i, 0.0, recover_interval, nodes_to_infect)
        
def update_state(G, nodes_to_infect, recover_times):
    '''
    choose a recovering/infection event to happen
    remove the node if it recovered
    remove the node's in_edges if it's infected

    The earliest candidate infection (first element of each list in
    nodes_to_infect) is compared with the earliest entry of recover_times.
    A recovery happens only when it is strictly earlier; on a tie the
    infection happens.  When nodes_to_infect is empty the earliest pending
    recovery is processed, and when both dicts are empty nothing is done.
    G, nodes_to_infect and recover_times are modified in place.
    '''
    if not nodes_to_infect and not recover_times:
        # no event left to process
        return

    if nodes_to_infect:
        j, possible_t_j = min(nodes_to_infect.items(), key=lambda t: t[1][0])
        min_t_j = possible_t_j[0]
    else:
        # min() over an empty dict would raise ValueError
        min_t_j = math.inf

    if len(recover_times) == 0:
        min_tau_i = math.inf
    else:
        i, min_tau_i = min(recover_times.items(), key=lambda t: t[1])

    if not nodes_to_infect or min_tau_i < min_t_j:
        # recovery: i leaves the simulation together with all of its edges
        recover_times.pop(i, None)
        G.remove_node(i)
    else:
        # infection of j at its earliest candidate time
        t_j = min_t_j
        recover_interval = sample_recover_interval()
        recover_times[j] = t_j + recover_interval

        insert_infect_times(G, j, t_j, recover_interval, nodes_to_infect)
        # j is infected now: drop its remaining candidate times
        nodes_to_infect.pop(j, None)
        # no node can infect j again
        G.remove_edges_from(list(G.in_edges(j)))

In [7]:
# Alternative input (disabled): the BlogCatalog graph with its expected size checks.
# G_blogcatalog = load_matfile(file_ = "blogcatalog.mat",unDirected=False)
# assert(G_blogcatalog.number_of_nodes() == 10312)
# assert(G_blogcatalog.number_of_edges() == 333983)

# Directed graph read from a comma-separated edge list with integer node ids.
G_FB = nx.read_edgelist(
    "NOLAfacebook.csv",
    delimiter=",",
    nodetype=int,
    create_using=nx.DiGraph(),
)

# Event bookkeeping filled in by set_seeds:
#   nodes_to_infect: node -> candidate infection times
#   recover_times:   infected node -> recover time
nodes_to_infect = {}
recover_times = {}

set_seeds(G_FB, nodes_to_infect, recover_times, seeds=[2332, 471, 554, 2322, 451])

# Full simulation loop (disabled): process events until no candidate infection is left.
# while len(nodes_to_infect) > 0:
#     update_state(G_FB, nodes_to_infect, recover_times)
#     print("\r num_nodes: {} ".format(G_FB.number_of_nodes()),end = "")
# print()

'set_seeds'  11.22 ms


In [14]:
# No earlier cell imports `time`, so on a fresh kernel (Restart & Run All)
# this cell raised NameError; import it here to keep the cell self-contained.
import time

# current wall-clock time in seconds since the epoch
time.time()

1538361417.823424