In [2]:
import functools
import os
import sys
import time as t

import matplotlib.pyplot as plt
import ndlib.models.ModelConfig as mc
import ndlib.models.epidemics as ep
import networkx as nx
import numpy as np
import scipy as sc
import torch
from matplotlib import cm
%matplotlib inline

In [3]:
def runtime(func):
    """Decorator that reports how long each call to ``func`` takes.

    Before the call a bold banner naming ``func`` is printed; afterwards a bold
    line with the elapsed wall-clock seconds. The wrapped function's return
    value is passed through unchanged. ``functools.wraps`` copies ``func``'s
    ``__name__``/``__doc__`` onto the wrapper so decorated functions and
    methods keep their identity (e.g. for ``help()`` and introspection).

    Parameters
    ----------
    func : callable
        The function to time.

    Returns
    -------
    callable
        A wrapper accepting exactly the same arguments as ``func``.
    """
    @functools.wraps(func)
    def wrapper(*args, **kargs):
        print("\033[1mExecuting function", func.__name__,":\n\033[0m")
        # perf_counter is monotonic and high-resolution, so the measured
        # duration is not affected by system clock adjustments.
        start = t.perf_counter()
        result = func(*args, **kargs)
        print("\033[1m\nFinished execution of function {0} in {1:.5f} seconds.\n\033[0m".format(func.__name__, (t.perf_counter() - start)))
        return result
    return wrapper

In [4]:
class SEIRdata():
    """Generate a dataset of SEIR epidemic simulations on random graphs.

    Every sample builds a new random graph of type ``graph_type``, runs
    ndlib's ``SEIRModel`` on it for ``iterations`` steps, and records the
    node statuses right after initialisation (``y``) and after the run
    (``x``) together with the graph's edge list in COO form.

    Parameters
    ----------
    graph_type : str
        "ER" (``nx.erdos_renyi_graph``), "BA" (``nx.barabasi_albert_graph``)
        or "WS" (``nx.watts_strogatz_graph``).
    p : float
        ``p`` argument of the ER and WS generators.
    m : int
        ``m`` argument of the BA generator.
    k : int
        ``k`` argument of the WS generator.
    batches : int
        Number of simulations produced by ``generate_data``.
    iterations : int
        Number of steps passed to ``model.iteration_bunch``.
    N : int
        Number of nodes per graph.
    beta, gamma, alpha : float
        Passed unchanged to the ndlib model configuration.
    initial_infected : int
        ndlib receives the fraction ``initial_infected / N`` as
        ``fraction_infected``.
    show_progress : bool
        Print percentage progress while generating.
    one_hot_target : bool
        Return ``y`` as the two columns ``[1 - y, y]``.
    show_error : bool
        If False, ``sys.stderr`` is redirected to ``os.devnull`` while
        ``generate_data`` runs.
    """

    def __init__(self, graph_type="ER", p=0.1, m=5, k=3, batches=1, iterations=1000, N=100, beta=0.01, gamma=0.005, alpha=0.05, initial_infected=1, show_progress=False, one_hot_target=False, show_error=False):

        # NOTE(review): x / y are allocated here but not written by any method
        # of this class; kept so existing code reading them keeps working.
        self.x = np.zeros(shape=(batches, N))
        self.y = np.zeros(shape=(batches))

        # Graph topology parameters.
        self.graph_type = graph_type
        self.p = p  # for ER and WS graph topology
        self.m = m  # for BA graph topology
        self.k = k  # for WS graph topology

        # Simulation parameters.
        self.batches = batches
        self.N = N
        self.beta = beta
        self.gamma = gamma
        self.alpha = alpha
        self.initial_infected = initial_infected
        self.iterations = iterations

        # Output options.
        self.show_progress = show_progress
        self.one_hot_target = one_hot_target
        self.show_error = show_error

    def _build_graph(self):
        """Return a new networkx graph of the configured topology.

        Raises
        ------
        ValueError
            If ``graph_type`` is not "ER", "BA" or "WS".
        """
        if self.graph_type == "ER":  # Erdős–Rényi graph
            return nx.erdos_renyi_graph(self.N, self.p)
        if self.graph_type == "BA":  # Barabási–Albert graph
            return nx.barabasi_albert_graph(self.N, self.m)
        if self.graph_type == "WS":  # Watts–Strogatz small-world graph
            return nx.watts_strogatz_graph(self.N, self.k, self.p)
        raise ValueError(
            "No valid graph type {!r}; expected one of 'ER', 'BA', 'WS'".format(self.graph_type))

    @staticmethod
    def _edge_index(g, nodes):
        """Return the non-zero adjacency entries of ``g`` as a (2, E) array.

        Row/column indices follow the order of ``nodes``. ``to_scipy_sparse_array``
        is preferred because ``to_scipy_sparse_matrix`` was removed in networkx 3.0;
        the latter is used only as a fallback on older networkx versions.
        """
        to_sparse = getattr(nx, "to_scipy_sparse_array", None) or nx.to_scipy_sparse_matrix
        adjacency = to_sparse(g, nodelist=nodes).tocoo()
        return np.stack((adjacency.row, adjacency.col))

    def run_once(self):
        """Run one SEIR simulation on a freshly generated graph.

        Returns
        -------
        x : numpy.ndarray, shape (N, 1)
            Node statuses from ``model.status`` after ``iterations`` steps.
        y : numpy.ndarray, shape (N, 1), or (N, 2) if ``one_hot_target``
            Node statuses right after ``set_initial_status`` (before any step).
        edges : numpy.ndarray, shape (2, E)
            Row/column indices of the adjacency matrix in COO form.

        Raises
        ------
        ValueError
            If ``graph_type`` is not "ER", "BA" or "WS".
        """
        g = self._build_graph()
        # One fixed node order is used for x, y and the adjacency, so row i of
        # x / y always refers to the node with index i in ``edges``.
        nodes = list(g.nodes())

        # Model selection and configuration
        model = ep.SEIRModel(g)
        cfg = mc.Configuration()
        cfg.add_model_parameter('beta', self.beta)
        cfg.add_model_parameter('gamma', self.gamma)
        cfg.add_model_parameter('alpha', self.alpha)
        cfg.add_model_parameter("fraction_infected", self.initial_infected / self.N)
        model.set_initial_status(cfg)

        # P0: status of every node before the simulation advances.
        y = np.array([model.status[n] for n in nodes]).reshape(self.N, -1)
        if self.one_hot_target:
            y = np.hstack((1 - y, y))

        # Only the final model.status is used; the per-step results returned
        # by iteration_bunch are discarded.
        model.iteration_bunch(self.iterations)

        # Final status of every node (SEIR labels).
        x = np.array([model.status[n] for n in nodes]).reshape(self.N, -1)

        # x = final label, y = initial status, and edges in array COO form
        return x, y, self._edge_index(g, nodes)

    @runtime
    def generate_data(self):
        """Run ``batches`` simulations and collect them in an object array.

        Entry ``i`` is the tuple ``(batch_array, edges, x, y)`` where
        ``batch_array`` is a length-N numpy array filled with ``i``, ``edges``
        is a ``torch.long`` tensor and ``x`` / ``y`` are ``torch.float``
        tensors built from the arrays returned by ``run_once``.

        Unless ``show_error`` is set, ``sys.stderr`` points at ``os.devnull``
        during the loop; it is restored and the null device closed even if a
        simulation raises.
        """
        data = np.empty(self.batches, dtype=object)

        old_stderr = sys.stderr  # backup current stderr
        devnull = None
        if not self.show_error:
            devnull = open(os.devnull, "w")
            sys.stderr = devnull

        try:
            progress = 0
            for batch in range(self.batches):

                if self.show_progress:
                    # Print every threshold passed so far; the step is clamped to
                    # [0.05, 10] percent depending on the number of batches.
                    while batch / self.batches * 100 >= progress:
                        print("{} %".format(progress))
                        progress = progress + max(0.05, min(10, 10000 / self.batches))

                x, y, edges = self.run_once()

                data[batch] = (
                    np.full(self.N, batch),
                    torch.tensor(edges, dtype=torch.long),
                    torch.tensor(x, dtype=torch.float),
                    torch.tensor(y, dtype=torch.float),
                )
        finally:
            if devnull is not None:
                sys.stderr = old_stderr  # reset old stderr
                devnull.close()

        return data

In [6]:
class SIRdata():
    """Generate a dataset of SIR epidemic simulations on random graphs.

    Every sample runs ndlib's ``SIRModel`` for ``iterations`` steps on a
    random graph of type ``graph_type`` (optionally the same graph for all
    samples) and records the node statuses right after initialisation (``y``)
    and after the run (``x``) together with the graph's edge list in COO form.

    Parameters
    ----------
    graph_type : str
        "ER" (``nx.erdos_renyi_graph``), "BA" (``nx.barabasi_albert_graph``)
        or "WS" (``nx.watts_strogatz_graph``).
    p : float
        ``p`` argument of the ER and WS generators.
    m : int
        ``m`` argument of the BA generator.
    k : int
        ``k`` argument of the WS generator.
    batches : int
        Number of simulations produced by ``generate_data``.
    iterations : int
        Number of steps passed to ``model.iteration_bunch``.
    N : int
        Number of nodes per graph.
    beta, gamma : float
        Passed unchanged to the ndlib model configuration.
    alpha : float
        Stored but not passed to the SIR model (kept for signature parity
        with ``SEIRdata``).
    initial_infected : int
        ndlib receives the fraction ``initial_infected / N`` as
        ``fraction_infected``.
    show_progress : bool
        Print percentage progress while generating.
    one_hot_target : bool
        Return ``y`` as the two columns ``[1 - y, y]``.
    show_error : bool
        If False, ``sys.stderr`` is redirected to ``os.devnull`` while
        ``generate_data`` runs.
    directed : bool
        Convert every generated graph with ``nx.to_directed``.
    one_graph : bool
        Build the graph once (at batch 0) and reuse it for all later batches.
    """

    def __init__(self, graph_type="ER", p=0.1, m=5, k=3, batches=1, iterations=1000, N=100, beta=0.01, gamma=0.005, alpha=0.05, initial_infected=1, show_progress=False, one_hot_target=False, show_error=False, directed=False, one_graph=False):

        # NOTE(review): x / y are allocated here but not written by any method
        # of this class; kept so existing code reading them keeps working.
        self.x = np.zeros(shape=(batches, N))
        self.y = np.zeros(shape=(batches))

        # Graph topology parameters.
        self.graph_type = graph_type
        self.p = p  # for ER and WS graph topology
        self.m = m  # for BA graph topology
        self.k = k  # for WS graph topology
        self.directed = directed
        self.one_graph = one_graph

        # Simulation parameters.
        self.batches = batches
        self.N = N
        self.beta = beta
        self.gamma = gamma
        self.alpha = alpha
        self.initial_infected = initial_infected
        self.iterations = iterations

        # Output options.
        self.show_progress = show_progress
        self.one_hot_target = one_hot_target
        self.show_error = show_error

    def _build_graph(self):
        """Return a new networkx graph of the configured topology.

        When ``directed`` is set the graph is converted with ``nx.to_directed``
        here, once per generated graph, for every topology.

        Raises
        ------
        ValueError
            If ``graph_type`` is not "ER", "BA" or "WS".
        """
        if self.graph_type == "ER":  # Erdős–Rényi graph
            g = nx.erdos_renyi_graph(self.N, self.p)
        elif self.graph_type == "BA":  # Barabási–Albert graph
            g = nx.barabasi_albert_graph(self.N, self.m)
        elif self.graph_type == "WS":  # Watts–Strogatz small-world graph
            g = nx.watts_strogatz_graph(self.N, self.k, self.p)
        else:
            raise ValueError(
                "No valid graph type {!r}; expected one of 'ER', 'BA', 'WS'".format(self.graph_type))
        if self.directed:
            g = nx.to_directed(g)
        return g

    @staticmethod
    def _edge_index(g, nodes):
        """Return the non-zero adjacency entries of ``g`` as a (2, E) array.

        Row/column indices follow the order of ``nodes``. ``to_scipy_sparse_array``
        is preferred because ``to_scipy_sparse_matrix`` was removed in networkx 3.0;
        the latter is used only as a fallback on older networkx versions.
        """
        to_sparse = getattr(nx, "to_scipy_sparse_array", None) or nx.to_scipy_sparse_matrix
        adjacency = to_sparse(g, nodelist=nodes).tocoo()
        return np.stack((adjacency.row, adjacency.col))

    def run_once(self, batch):
        """Run one SIR simulation for batch index ``batch``.

        The graph is stored in ``self.g``. It is rebuilt on every call unless
        ``one_graph`` is set, in which case it is rebuilt only for batch 0 (or
        when no graph exists yet) and reused otherwise.

        Returns
        -------
        x : numpy.ndarray, shape (N, 1)
            Node statuses from ``model.status`` after ``iterations`` steps.
        y : numpy.ndarray, shape (N, 1), or (N, 2) if ``one_hot_target``
            Node statuses right after ``set_initial_status`` (before any step).
        edges : numpy.ndarray, shape (2, E)
            Row/column indices of the adjacency matrix in COO form.

        Raises
        ------
        ValueError
            If a graph must be built and ``graph_type`` is not "ER", "BA" or "WS".
        """
        if not self.one_graph or batch == 0 or getattr(self, "g", None) is None:
            self.g = self._build_graph()

        # One fixed node order is used for x, y and the adjacency, so row i of
        # x / y always refers to the node with index i in ``edges``.
        nodes = list(self.g.nodes())

        # Model selection and configuration
        model = ep.SIRModel(self.g)
        cfg = mc.Configuration()
        cfg.add_model_parameter('beta', self.beta)
        cfg.add_model_parameter('gamma', self.gamma)
        cfg.add_model_parameter("fraction_infected", self.initial_infected / self.N)
        model.set_initial_status(cfg)

        # P0: status of every node before the simulation advances.
        y = np.array([model.status[n] for n in nodes]).reshape(self.N, -1)
        if self.one_hot_target:
            y = np.hstack((1 - y, y))

        # Only the final model.status is used; the per-step results returned
        # by iteration_bunch are discarded.
        model.iteration_bunch(self.iterations)

        # Final status of every node (SIR labels).
        x = np.array([model.status[n] for n in nodes]).reshape(self.N, -1)

        # x = final label, y = initial status, and edges in array COO form
        return x, y, self._edge_index(self.g, nodes)

    @runtime
    def generate_data(self):
        """Run ``batches`` simulations and collect them in an object array.

        Entry ``i`` is the tuple ``(batch_array, edges, x, y)`` where
        ``batch_array`` is a length-N numpy array filled with ``i``, ``edges``
        is a ``torch.long`` tensor and ``x`` / ``y`` are ``torch.float``
        tensors built from the arrays returned by ``run_once(i)``.

        Unless ``show_error`` is set, ``sys.stderr`` points at ``os.devnull``
        during the loop; it is restored and the null device closed even if a
        simulation raises.
        """
        data = np.empty(self.batches, dtype=object)

        old_stderr = sys.stderr  # backup current stderr
        devnull = None
        if not self.show_error:
            devnull = open(os.devnull, "w")
            sys.stderr = devnull

        try:
            progress = 0
            for batch in range(self.batches):

                if self.show_progress:
                    # Print every threshold passed so far; the step is clamped to
                    # [0.05, 10] percent depending on the number of batches.
                    while batch / self.batches * 100 >= progress:
                        print("{} %".format(progress))
                        progress = progress + max(0.05, min(10, 10000 / self.batches))

                x, y, edges = self.run_once(batch)

                data[batch] = (
                    np.full(self.N, batch),
                    torch.tensor(edges, dtype=torch.long),
                    torch.tensor(x, dtype=torch.float),
                    torch.tensor(y, dtype=torch.float),
                )
        finally:
            if devnull is not None:
                sys.stderr = old_stderr  # reset old stderr
                devnull.close()

        return data