In [None]:
# For graph vis
!pip install pyvis

# For conversion from png to gif
!pip install imageio

In [None]:
import pandas as pd
import numpy as np
import networkx as nx
from random import randint
from random import random
from pyvis.network import Network
import matplotlib.pyplot as plt
import seaborn as sns
import imageio
import os
import re

# Population Class

In [None]:
class Population:
    """Network epidemic model: citizens connected by a symmetric 0/1 adjacency matrix.

    Each citizen's state is stored in ``citizen_info`` as a dict with keys
    ``is_infected`` (0/1), ``days_infected``, ``is_dead`` (0/1) and
    ``times_infected``. Recovered citizens become uninfected again and can be
    re-infected; dead citizens lose all of their connections.
    """

    def __init__(self, pop_size, fatality_prob, infection_prob, connection_prob):
        """Build the population, create random connections and infect one random citizen.

        Args:
            pop_size: Number of citizens; must be at least 1 because one
                citizen is seeded as infected.
            fatality_prob: Probability that an infected citizen dies when
                ``decide_fate`` is called.
            infection_prob: Probability, per connection, that a contagious
                citizen infects a connected citizen who is not currently infected.
            connection_prob: Probability that ``make_friends`` connects a
                citizen to each living candidate they are not yet connected to.

        Raises:
            ValueError: If ``pop_size`` is less than 1.
        """
        if pop_size < 1:
            raise ValueError("pop_size must be at least 1")

        self.pop_size = pop_size

        # Probability hyperparams
        self.fatality_prob = fatality_prob
        self.infection_prob = infection_prob
        self.connection_prob = connection_prob

        # Metrics. no_of_infections counts *currently* infected citizens: it is
        # incremented on each new infection and decremented on death or recovery.
        self.no_of_infections = 1
        self.no_of_deaths = 0
        self.no_citizens_recovered = 0

        # Long run metrics, keyed by day index.
        self.infections_per_day = {}
        self.deaths_per_day = {}

        # Citizen Info
        self.citizen_info = [{"is_infected": 0,
                              "days_infected": 0,
                              "is_dead": 0,
                              "times_infected": 0}
                             for _ in range(pop_size)]

        # Create connection matrix for population. make_friends runs once per
        # citizen, so each unordered pair gets two independent chances to
        # connect here (effective probability 1 - (1 - connection_prob) ** 2).
        self.population_matrix = np.zeros(shape=(pop_size, pop_size), dtype=int)
        for citizen in range(pop_size):
            self.make_friends(citizen)

        # 1 person is randomly infected
        rand_citizen = randint(0, pop_size - 1)
        self.citizen_info[rand_citizen]["is_infected"] = 1
        self.citizen_info[rand_citizen]["times_infected"] += 1

    def get_infection_status_verbose(self, citizen):
        """Return True if ``citizen`` is infected and has been infected for at least one day.

        Such citizens are due to have their fate decided (death or recovery)
        and no longer transmit the infection.
        """
        info = self.citizen_info[citizen]
        return bool(info["is_infected"] and (info["days_infected"] >= 1))

    def get_contagiousness(self, citizen):
        """Return True if ``citizen`` is newly infected (``days_infected == 0``) and can transmit."""
        info = self.citizen_info[citizen]
        return bool(info["is_infected"] and (info["days_infected"] == 0))

    def remove_connections(self, citizen):
        """Remove every edge touching ``citizen`` (both directions of the symmetric matrix)."""
        self.population_matrix[citizen, :] = 0
        self.population_matrix[:, citizen] = 0

    def is_dead(self, citizen):
        """Return True if ``citizen`` is dead.

        Returns False if the citizen can still participate in society, i.e.
        make new connections.
        """
        return bool(self.citizen_info[citizen]["is_dead"])

    def update_infected_citizens(self, citizen):
        """Increase the number of days ``citizen`` has been infected by one."""
        self.citizen_info[citizen]["days_infected"] += 1

    def transmit_infection(self, citizen):
        """Try to infect every citizen connected to ``citizen``.

        Each connected citizen who is not currently infected becomes infected
        with probability ``infection_prob``. Previously recovered citizens can
        be re-infected; ``times_infected`` counts every infection.
        """
        for connection, is_connected in enumerate(self.population_matrix[citizen]):
            # random() is only drawn for connected citizens (short-circuit `and`).
            if is_connected and (random() < self.infection_prob) and (self.citizen_info[connection]["is_infected"] != 1):
                self.citizen_info[connection]["is_infected"] = 1
                self.citizen_info[connection]["times_infected"] += 1
                self.no_of_infections += 1

    def make_friends(self, citizen):
        """Randomly add new connections for ``citizen``.

        Every other living citizen not yet connected to ``citizen`` becomes
        connected with probability ``connection_prob``. This does not check
        whether ``citizen`` itself is dead; callers must skip dead citizens.
        """
        for connection, is_connected in enumerate(self.population_matrix[citizen]):
            # don't connect with yourself or someone who is dead/iso
            if (connection == citizen) or self.is_dead(connection):
                continue
            elif not is_connected and (random() < self.connection_prob):
                self.population_matrix[citizen][connection] = 1
                self.population_matrix[connection][citizen] = 1

    def decide_fate(self, citizen):
        """Resolve an infection: ``citizen`` either dies or recovers.

        With probability ``fatality_prob`` the citizen dies, loses all
        connections and is removed from the network. Otherwise they recover,
        keep their connections and can be infected again later; their
        ``times_infected`` count is preserved.
        """
        if random() < self.fatality_prob:
            self.remove_connections(citizen)
            self.citizen_info[citizen]["is_dead"] = 1
            self.citizen_info[citizen]["is_infected"] = 0

            # Update population data
            self.no_of_deaths += 1
            self.no_of_infections -= 1

        else:
            self.citizen_info[citizen] = {"is_infected": 0,
                                          "days_infected": 0,
                                          "is_dead": 0,
                                          "times_infected": self.citizen_info[citizen]["times_infected"]}
            self.no_of_infections -= 1
            self.no_citizens_recovered += 1

    def print_metrics(self, day):
        """Print the current metrics for ``day``.

        Metrics:
            Infections: citizens currently infected (``no_of_infections``).
            Non-infections: ``pop_size - no_of_infections`` (this includes dead citizens).
            Deaths: total deaths so far.
            Recovered: total recoveries so far (a citizen can recover more than once).
        """
        print("Day: {}".format(day))
        print("Infections Vs Non-infections: {} to {}".format(self.no_of_infections, self.pop_size - self.no_of_infections))
        print("Deaths Vs Recovered: {} to {}".format(self.no_of_deaths, self.no_citizens_recovered))
        print("\n")

    def plot_metrics(self):
        """Plot currently-infected and cumulative-death counts per day."""
        # Materialise the dict views as lists so the plotting library gets plain sequences.
        sns.lineplot(x=list(self.infections_per_day.keys()), y=list(self.infections_per_day.values()), label="Infection per day")
        sns.lineplot(x=list(self.deaths_per_day.keys()), y=list(self.deaths_per_day.values()), label="Deaths per day")
        plt.xlabel("Day")
        plt.ylabel("Number of citizens")
        plt.ylim((0, self.pop_size))
        plt.legend()
        plt.title("Infections per day & Deaths per day")
        plt.show()

# 1. Run Sim for n days 
### (prints metrics `n_metrics` times during the run; the chart and network graph are produced once, at the end)

In [None]:
def run_sim(days, pop_size, n_metrics, fatility_prob, infection_prob, connection_prob):
    """Simulate the epidemic for ``days`` days, print metrics and plot the daily series.

    Args:
        days: Number of days to simulate.
        pop_size: Population size.
        n_metrics: Number of evenly spaced days (first to last simulated day)
            on which metrics are printed, before that day is simulated.
        fatility_prob: Probability that an infected citizen dies when their
            fate is decided (spelling kept for compatibility with callers).
        infection_prob: Per-connection probability that a contagious citizen
            infects a connected, uninfected citizen.
        connection_prob: Probability used for each candidate pair when a
            living citizen makes new connections.

    Returns:
        Tuple ``(population_matrix, citizen_info)`` of the final population.
    """
    # Inititialise population
    population_master = Population(pop_size, fatility_prob, infection_prob, connection_prob)

    # Print metrics n_metrics times. Simulated days are 0 .. days - 1, so the
    # schedule must end at days - 1 (a stop of `days` is never reached).
    last_day = max(days - 1, 0)
    days_to_print_metrics = {int(d) for d in np.linspace(start=0, stop=last_day, num=n_metrics, dtype=int)}

    # Run for n days
    for day in range(days):

        if day in days_to_print_metrics:
            population_master.print_metrics(day)

        # NOTE(review): a citizen infected earlier in this same pass by a
        # lower-index neighbour is already contagious when their turn comes, so
        # infection can travel several hops in one day depending on index order.
        for citizen in range(pop_size):
            # Citizens infected for at least a day either die or recover
            if population_master.get_infection_status_verbose(citizen):
                population_master.decide_fate(citizen)

            # Dead citizens take no further part in the simulation
            if population_master.is_dead(citizen):
                continue

            # Transmit infection if newly infected
            if population_master.get_contagiousness(citizen):
                population_master.transmit_infection(citizen)

            # Make new connections
            population_master.make_friends(citizen)

            # Advance the infection clock so their fate is decided next day
            if population_master.get_contagiousness(citizen):
                population_master.update_infected_citizens(citizen)

        population_master.infections_per_day[day] = population_master.no_of_infections
        population_master.deaths_per_day[day] = population_master.no_of_deaths

    print("FINAL METRICS:")
    # last_day is defined even when days == 0 (the loop variable would not be).
    population_master.print_metrics(last_day)

    population_master.plot_metrics()

    return population_master.population_matrix, population_master.citizen_info

# Parameters:
#   days: number of days to run the simulation.
#   pop_size: population size.
#   n_metrics: number of times metrics are printed during the simulation.
#   fatility_prob: probability that an infected citizen dies when their fate is decided.
#   infection_prob: probability that a contagious (newly infected) citizen infects
#       each connected citizen who is not currently infected.
#   connection_prob: probability used for each candidate pair when a living citizen
#       makes new connections.
adj_mat, info = run_sim(days=100, pop_size=100, n_metrics=5, fatility_prob=0.2, infection_prob=0.05, connection_prob=0.8)

### Print and Save Interactive Graph

In [None]:
# Build an interactive pyvis graph of the final population.
# nx.from_numpy_matrix was removed in NetworkX 3.0; from_numpy_array replaces it.
G = nx.from_numpy_array(adj_mat)
# Sizes are given with a CSS unit so the generated HTML/CSS is valid.
g = Network(height="800px", width="800px", notebook=True, bgcolor="#FFFFFF", font_color="black")

# Colour: yellow = currently infected, red = dead, green = otherwise.
# Node size reflects how many times the citizen has been infected.
for node in G.nodes:
    if info[node]["is_infected"]:
        colour = "yellow"
    elif info[node]["is_dead"]:
        colour = "red"
    else:
        colour = "green"
    g.add_node(node, color=colour, value=info[node]["times_infected"], title=f"No. Connections: {G.degree(node)}")

for source, target in G.edges:
    g.add_edge(source, target)

g.barnes_hut()
g.set_options(
    """
    var options = {
  "edges": {
    "color": {
      "inherit": true
    },
    "smooth": false
  },
  "interaction": {
    "navigationButtons": true
  },
  "physics": {
    "barnesHut": {
      "gravitationalConstant": -80000,
      "springLength": 250,
      "springConstant": 0.001
    },
    "minVelocity": 0.75
  }
}
    """
)
g.show("disease_network.html")

# 2. Create graph gif

### Create graph png files
Run the simulation for `days` days; `n_metrics` is the number of times during the run that the network graph is created and saved as a PNG.

In [None]:
def save_graph(adj_mat, info, day, pm, out_dir="disease_network"):
    """Save (and display) a two-panel figure: daily metrics on top, the network below.

    Args:
        adj_mat: Symmetric 0/1 adjacency matrix of the population.
        info: Per-citizen state dicts (``is_infected``, ``is_dead``, ``times_infected``).
        day: Day index; used in the panel title and the file name.
        pm: The ``Population`` whose ``infections_per_day``/``deaths_per_day``
            series and ``pop_size`` are plotted.
        out_dir: Directory the PNG is written to (must already exist).
            The file is ``{out_dir}/network_{day}.png``.
    """
    # nx.from_numpy_matrix was removed in NetworkX 3.0; from_numpy_array replaces it.
    G = nx.from_numpy_array(adj_mat)

    # Average infection count over all citizens; it does not depend on the
    # node, so compute it once instead of once per node.
    all_times_infected = [citizen["times_infected"] for citizen in info]
    avg_times_infected = sum(all_times_infected) / len(all_times_infected) if all_times_infected else 0

    colour_map = []
    node_sizes = []
    for node in G:
        # yellow = currently infected, red = dead, green = otherwise
        if info[node]["is_infected"]:
            colour_map.append("yellow")
        elif info[node]["is_dead"]:
            colour_map.append("red")
        else:
            colour_map.append("green")

        # From day 1 on, size nodes relative to the average infection count;
        # fall back to the default size if that average is zero.
        if day >= 1 and avg_times_infected > 0:
            node_sizes.append((info[node]["times_infected"] * 100) / avg_times_infected)
        else:
            node_sizes.append(100)

    # Plot metrics
    fig, (ax1, ax2) = plt.subplots(2, figsize=(12, 12), gridspec_kw={'height_ratios': [1, 2]})
    sns.lineplot(x=list(pm.infections_per_day.keys()), y=list(pm.infections_per_day.values()), label="Infection per day", ax=ax1)
    sns.lineplot(x=list(pm.deaths_per_day.keys()), y=list(pm.deaths_per_day.values()), label="Deaths per day", ax=ax1)
    ax1.set_xlabel("Day")
    ax1.set_ylabel("Number of citizens")
    ax1.set_ylim((0, pm.pop_size))
    ax1.legend()
    ax1.set_title("Infections per day & Deaths per day")

    # Plot Network
    ax2.set_title(f"Day: {day}")
    nx.draw_kamada_kawai(G, node_color=colour_map, node_size=node_sizes, ax=ax2)
    fig.savefig(f"{out_dir}/network_{day}.png")
    plt.show()
    # Release the figure; this function is called once per reported day.
    plt.close(fig)

In [None]:
def run_sim_to_png(days, pop_size, n_metrics, fatility_prob, infection_prob, connection_prob):
    """Simulate the epidemic and save a metrics+network PNG on evenly spaced days.

    PNGs are written to the ``disease_network`` directory, which is created if
    it does not exist.

    Args:
        days: Number of days to simulate.
        pop_size: Population size.
        n_metrics: Number of evenly spaced days (first to last simulated day)
            on which metrics are printed and a graph is saved, after that day
            is simulated.
        fatility_prob: Probability that an infected citizen dies when their
            fate is decided (spelling kept for compatibility with callers).
        infection_prob: Per-connection probability that a contagious citizen
            infects a connected, uninfected citizen.
        connection_prob: Probability used for each candidate pair when a
            living citizen makes new connections.
    """
    # Reuse the directory on re-runs instead of failing with FileExistsError.
    # NOTE(review): PNGs from an earlier, longer run are not removed and would
    # be picked up when building the GIF.
    if not os.path.isdir("disease_network"):
        os.makedirs("disease_network")
        print("Made new directory \"disease_network\"")

    # Inititialise population
    population_master = Population(pop_size, fatility_prob, infection_prob, connection_prob)

    # Report n_metrics times. Simulated days are 0 .. days - 1, so the
    # schedule must end at days - 1 (a stop of `days` is never reached).
    last_day = max(days - 1, 0)
    days_to_print_metrics = {int(d) for d in np.linspace(start=0, stop=last_day, num=n_metrics, dtype=int)}

    # Run for n days
    for day in range(days):

        # NOTE(review): a citizen infected earlier in this same pass by a
        # lower-index neighbour is already contagious when their turn comes, so
        # infection can travel several hops in one day depending on index order.
        for citizen in range(pop_size):
            # Citizens infected for at least a day either die or recover
            if population_master.get_infection_status_verbose(citizen):
                population_master.decide_fate(citizen)

            # Dead citizens take no further part in the simulation
            if population_master.is_dead(citizen):
                continue

            # Transmit infection if newly infected
            if population_master.get_contagiousness(citizen):
                population_master.transmit_infection(citizen)

            # Make new connections
            population_master.make_friends(citizen)

            # Advance the infection clock so their fate is decided next day
            if population_master.get_contagiousness(citizen):
                population_master.update_infected_citizens(citizen)

        population_master.infections_per_day[day] = population_master.no_of_infections
        population_master.deaths_per_day[day] = population_master.no_of_deaths

        # Print Metrics + save chart for day
        if day in days_to_print_metrics:
            population_master.print_metrics(day)
            save_graph(population_master.population_matrix, population_master.citizen_info, day, population_master)

    print("FINAL METRICS:")
    # last_day is defined even when days == 0 (the loop variable would not be).
    population_master.print_metrics(last_day)

    population_master.plot_metrics()

In [None]:
# Run the simulation and save the resulting PNG files in the disease_network directory.
#
# Parameters:
#   days: number of days to run the simulation.
#   pop_size: population size.
#   n_metrics: number of times metrics are printed and a graph is saved during the simulation.
#   fatility_prob: probability that an infected citizen dies when their fate is decided.
#   infection_prob: probability that a contagious (newly infected) citizen infects
#       each connected citizen who is not currently infected.
#   connection_prob: probability used for each candidate pair when a living citizen
#       makes new connections.

run_sim_to_png(days=100, pop_size=100, n_metrics=100, fatility_prob=0.05, infection_prob=0.7, connection_prob=0.1)

### Create a GIF from the saved PNG files

In [None]:
# Functions below are from: https://stackoverflow.com/questions/5967500/how-to-correctly-sort-a-string-with-a-number-inside
def atof(text):
    """Return ``text`` as a float if it parses as one; otherwise return it unchanged."""
    try:
        return float(text)
    except ValueError:
        return text

def natural_keys(text):
    '''
    Sort key for "human" ordering, e.g. alist.sort(key=natural_keys).

    The text is split around numeric runs (digits with an optional decimal
    part, optionally preceded by a sign that is consumed but not kept); the
    numeric pieces are converted with atof so they compare by value, and the
    remaining pieces stay strings.
    http://nedbatchelder.com/blog/200712/human_sorting.html
    (See Toothy's implementation in the comments)
    float regex comes from https://stackoverflow.com/a/12643073/190597
    '''
    number_pattern = r'[+-]?([0-9]+(?:[.][0-9]*)?|[.][0-9]+)'
    pieces = re.split(number_pattern, text)
    return list(map(atof, pieces))

# Collect only the per-day frames (network_<day>.png), ordered by day number.
# re.fullmatch with an escaped "." rejects names like "network_1.png.bak" that an
# unanchored, unescaped pattern would accept; the raw string avoids the invalid
# "\d" string-escape warning.
filenames = sorted(
    (image for image in os.listdir("disease_network") if re.fullmatch(r"network_\d+\.png", image)),
    key=natural_keys,
)

In [None]:
# Load the frames in sorted order and write them out as an animated GIF.
# NOTE(review): the unit of `duration` (seconds vs milliseconds) has differed
# between imageio versions/plugins — confirm for the installed version.
images = [imageio.imread(f"disease_network/{filename}") for filename in filenames]
imageio.mimsave('disease_network/networkx.gif', images, duration = 0.1)