In [13]:
import numpy as np
import matplotlib.pyplot as plt
from dataclasses import dataclass
from typing import Union, Tuple, Iterable, Optional, Any
from tqdm import tqdm
from fancy_einsum import einsum

In [3]:
@dataclass
class Disease:
    """
    Rate parameters of a disease, consumed by the city simulations in this notebook.

    The rates are per unit of simulated time; the simulation converts them into
    per-step probabilities via ``1 - exp(-delta_t * rate)``.
    """
    # Transmission rate: multiplied by the (mixing-weighted) infectious fraction
    # to give the per-susceptible exposure rate (S -> E).
    beta: float
    # Recovery rate: drives the per-step probability of moving I -> R.
    gamma: float
    # Onset rate: drives the per-step probability of moving E -> I.
    delta: float

In [130]:
class City:
    """
    Stochastic compartmental (SEIR) epidemic model of one city split into population groups.

    The state lives in ``self.data``, an int64 array of shape
    (n_groups, n_compartments, n_sims, simulation_steps). ``self.arrivals`` and
    ``self.departures`` have the same shape; nothing in this class writes to them,
    they are only aggregated per day by ``daily_flight_data``.

    Subclasses must implement ``initial_conditions`` (fill time step 0 of
    ``self.data``) and may implement ``select_travellers``.
    """

    def __init__(self,
                 N0s: np.ndarray,
                 groups: list,
                 compartments: str = 'SEIR',
                 mixing_LR: Optional[np.ndarray] = None):
        """
        N0s: shape (n_groups,). Initial population of each group.
        groups: names of the groups, one per entry of N0s.
        compartments: one letter per compartment; only 'SEIR' is supported.
        mixing_LR: shape (n_groups, n_groups). The matrix of mixing likelihood ratios for
            groups i and j. Must be symmetric and all diagonal elements must be one.
            Defaults to the identity matrix (groups do not mix with each other).
        """
        self.n_groups = len(groups)
        self.groups = groups
        assert len(N0s) == self.n_groups, "N0s should have length n_groups"
        # asarray is a no-op for ndarrays but lets plain sequences work with .sum() below
        self.N0s = np.asarray(N0s)
        self.N0 = self.N0s.sum()
        assert compartments == 'SEIR', "Only SEIR compartment model supported at the moment"
        self.model_type = compartments
        self.compartments = list(compartments)
        if mixing_LR is None:
            mixing_LR = np.identity(self.n_groups)
        mixing_LR = np.asarray(mixing_LR)
        assert mixing_LR.shape == (self.n_groups, self.n_groups)
        assert np.all(mixing_LR.diagonal() == 1)
        assert np.all(mixing_LR == mixing_LR.T)
        self.mixing_LR = mixing_LR

        # Position of each compartment along axis 1 of self.data
        self.S_index = self.compartments.index('S')
        self.E_index = self.compartments.index('E')
        self.I_index = self.compartments.index('I')
        self.R_index = self.compartments.index('R')

    def reset_parameters(self, I0: int = 0, n_sims: int = 1, simulation_steps: int = 100):
        """
        Allocate zeroed ``data``, ``arrivals`` and ``departures`` arrays of shape
        (n_groups, n_compartments, n_sims, simulation_steps), store ``I0`` and let
        ``initial_conditions`` fill in time step 0.

        I0: initial number of infectious people, stored as ``self.I0``.
        n_sims, simulation_steps: array sizes. They are used when the instance does not
            yet carry ``self.n_sims`` / ``self.simulation_steps``; values already stored
            on the instance take precedence (the previous implementation read only the
            stored attributes and crashed when they were missing), so callers that set
            them beforehand, such as ``multiple_sims``, behave exactly as before.
        """
        self.I0 = I0
        self.n_sims = getattr(self, 'n_sims', n_sims)
        self.simulation_steps = getattr(self, 'simulation_steps', simulation_steps)
        shape = (self.n_groups, len(self.compartments), self.n_sims, self.simulation_steps)
        self.data = np.zeros(shape, dtype=np.int64)
        self.arrivals = np.zeros(shape, dtype=np.int64)
        self.departures = np.zeros(shape, dtype=np.int64)
        self.initial_conditions()

    def initial_conditions(self):
        """Fill ``self.data[..., 0]`` with the initial state. Must be provided by subclasses."""
        raise NotImplementedError("Subclasses of City must implement initial_conditions")

    def multiple_sims(self, delta_t: float, epidemic_time: Union[int,float], disease: "Disease", I0: int = 0, n_sims: int = 100):
        """
        Run ``n_sims`` independent stochastic simulations of the epidemic.

        delta_t: length of one simulation step.
        epidemic_time: total simulated time.
        disease: provides beta (transmission), gamma (I -> R) and delta (E -> I) rates.
        I0: initial number of infectious people.
        n_sims: number of simulations run in parallel.

        Results are left in ``self.data``; ``self.times`` holds the time of each step.
        """
        self.n_sims = n_sims
        self.delta_t = delta_t
        self.epidemic_time = epidemic_time
        # Rates -> per-step transition probabilities
        p_recovery = 1 - np.exp(-delta_t * disease.gamma)
        p_infectious = 1 - np.exp(-delta_t * disease.delta)
        self.simulation_steps = int(epidemic_time // delta_t) + 1
        # NOTE(review): the grid is evenly spaced over [0, epidemic_time], so its spacing
        # equals delta_t only when epidemic_time // delta_t == epidemic_time / delta_t.
        self.times = np.linspace(0, epidemic_time, self.simulation_steps)
        self.reset_parameters(I0, n_sims, self.simulation_steps)
        for sim_step in tqdm(range(self.simulation_steps - 1)):
            self.step_internal(disease.beta,
                               delta_t,
                               p_infectious,
                               p_recovery,
                               sim_step)
        self.daily_flight_data()

    def step_internal(self,
                      beta:float,
                      delta_t: float,
                      p_infectious: float,
                      p_recovery: float,
                      simulation_step: int):
        """
        Advance every simulation from ``simulation_step`` to ``simulation_step + 1``.

        Transitions S -> E, E -> I and I -> R are drawn from binomial distributions.
        The exposure probability of a group is ``1 - exp(-delta_t * beta * I'/N')`` where
        I' and N' are the infectious and total populations of all groups weighted by
        the group's row of ``mixing_LR``.
        """
        # Total population per group and simulation: shape (n_groups, n_sims)
        N = self.data[..., simulation_step].sum(axis=1)

        S = self.data[:, self.S_index, :, simulation_step]
        E = self.data[:, self.E_index, :, simulation_step]
        I = self.data[:, self.I_index, :, simulation_step]
        R = self.data[:, self.R_index, :, simulation_step]

        # (group1, group2) @ (group2, n_sims) -> (group1, n_sims)
        modified_I = self.mixing_LR @ I
        modified_N = self.mixing_LR @ N
        # Where the weighted population is zero the rate is defined as 0; a plain
        # division would give NaN there and np.random.binomial would reject it.
        exposure_rate = np.zeros(modified_I.shape, dtype=np.float64)
        np.divide(beta * modified_I, modified_N, out=exposure_rate, where=modified_N > 0)
        p_exposure = 1 - np.exp(-delta_t * exposure_rate)

        n_exposed = np.random.binomial(S, p_exposure)
        n_infectious = np.random.binomial(E, p_infectious)
        n_recovered = np.random.binomial(I, p_recovery)

        self.data[:, self.S_index, :, simulation_step + 1] = S - n_exposed
        self.data[:, self.E_index, :, simulation_step + 1] = E + n_exposed - n_infectious
        self.data[:, self.I_index, :, simulation_step + 1] = I + n_infectious - n_recovered
        self.data[:, self.R_index, :, simulation_step + 1] = R + n_recovered

    def daily_flight_data(self, moving_average = False):
        """
        Sum ``arrivals`` and ``departures`` over windows of ``int(1 / delta_t)`` steps (one
        unit of time) and store the result in ``daily_arrivals`` / ``daily_departures``,
        each of shape (n_groups, n_compartments, n_sims, n_days).

        Steps after the last complete window are dropped, so the step count does not
        have to be a multiple of the window.

        moving_average: currently unused.
        """
        # At least one step per window, otherwise delta_t > 1 would give a zero divisor
        window = max(1, int(1 / self.delta_t))
        n_days = self.simulation_steps // window
        usable_steps = n_days * window
        daily_shape = (self.n_groups, len(self.compartments), self.n_sims, n_days, window)

        self.daily_arrivals = self.arrivals[..., :usable_steps].reshape(daily_shape).sum(axis=-1)
        self.daily_departures = self.departures[..., :usable_steps].reshape(daily_shape).sum(axis=-1)

    def select_travellers(self, daily_mixnumber: int, n_sims: int, simulation_step: int, delta_t: float):
        """Choose who travels during a step. Not implemented in the base class."""
        raise NotImplementedError("Subclasses of City must implement select_travellers")

    @staticmethod
    def _plot_ensemble(ax, x, y, label, color):
        """Draw one line per simulation (rows of x and y) on ax; only the first gets the legend label."""
        for sim, (xs, ys) in enumerate(zip(x, y)):
            legend_kwargs = {'label': label} if sim == 0 else {}
            ax.plot(xs, ys, color=color, **legend_kwargs)

    def plot_sims(self,
                  times: Optional[np.ndarray] = None,
                  cityname: Union[int,str] = 0, 
                  shift_index: Optional[np.ndarray] = None,
                  separate_groups: bool = False):
        """
        Plot every simulation of every compartment.

        Row 0 shows the city's compartments over time. When daily flight data is
        available, rows 1 and 2 show daily arrivals and departures.

        times: time of each simulation step; defaults to ``self.times``.
        cityname: only used in the subplot titles.
        shift_index: optional step index per simulation. Each simulation is shifted in
            time so that these steps line up at their mean time.
        separate_groups: one column of subplots per group if True, otherwise the groups
            are summed into a single column.
        """
        include_flight_data = hasattr(self, 'daily_arrivals') and hasattr(self, 'daily_departures')
        n_rows = 3 if include_flight_data else 1
        n_cols = self.n_groups if separate_groups else 1
        # squeeze=False keeps axs two-dimensional for every (rows, cols) combination,
        # including cities with a single group.
        fig, axs = plt.subplots(n_rows, n_cols, figsize=(20, 10 * n_rows), squeeze=False)

        if times is None:
            times = self.times
        if shift_index is None:
            shift = np.zeros((self.n_sims, 1))
        else:
            shift = times[shift_index].reshape((self.n_sims, 1))
        # Broadcasts to (n_sims, n_steps): one shifted time axis per simulation
        sim_times = times - shift + shift.mean()

        # x axis and y data for each row of subplots. Every y array has shape
        # (n_columns, n_compartments, n_sims, n_points).
        row_x = [sim_times]
        row_y = [self.data]
        if include_flight_data:
            # One point per aggregated day, so x always matches the daily arrays
            days = np.arange(self.daily_arrivals.shape[-1], dtype=np.float64)
            sim_days = days - shift + shift.mean()
            row_x += [sim_days, sim_days]
            row_y += [self.daily_arrivals, self.daily_departures]

        if separate_groups:
            titles = [f"City {cityname}, {group}" for group in self.groups]
        else:
            titles = [f"City {cityname}"]
            row_y = [series.sum(axis=0, keepdims=True) for series in row_y]

        colours = ['green', 'orange', 'red', 'blue']
        labels = ['Whole City', 'Arrivals', 'Departures']
        for col, title in enumerate(titles):
            for j, compartment in tqdm(enumerate(self.compartments)):
                for row, (x, series) in enumerate(zip(row_x, row_y)):
                    self._plot_ensemble(axs[row, col], x, series[col, j], compartment, colours[j])
            for row in range(n_rows):
                axs[row, col].legend()
                axs[row, col].set_title(f"{title}: {labels[row]}")
        fig.show()
    
    def infections_at_time(self, time_indexes: Optional[np.ndarray] = None, time: Optional[np.ndarray] = None, included_compartments: str = 'EIR'):
        """
        Return the group-summed counts of the selected compartments at the given times.

        Exactly one of ``time_indexes`` and ``time`` must be given.
        time_indexes: length n_sims. Step index at which to evaluate.
        time: length n_sims. Times, converted to step indexes with ``self.times.searchsorted``.
        included_compartments: letters of the compartments to include.

        Returns an array of shape (n_included_compartments, n_sims, len(time_indexes)):
        entry [c, s, k] is the count of compartment c in simulation s at time_indexes[k].

        NOTE(review): earlier (removed) commented-out code evaluated simulation s only at
        its own time_indexes[s] and summed the compartments. That would be the diagonal
        ``out[:, s, s]`` summed over c. Confirm with callers before changing the shape.
        """
        assert (time is not None) != (time_indexes is not None), "Exactly one of time_indexes and times should be specified"
        if time_indexes is None:
            assert time is not None
            time_indexes = self.times.searchsorted(time)
        assert len(time_indexes) == self.n_sims, "len(times) must be n_sims"
        to_include = np.array([l in included_compartments for l in self.compartments])
        out = self.data[:, to_include].sum(axis=0)
        out = out[..., time_indexes]
        return out

    def peak_I_times(self):
        """Return, for each simulation, the step index at which the group-summed I compartment peaks."""
        return self.data[:, self.I_index].sum(axis=0).argmax(axis=-1)


    def __str__(self):
        """Generic description; subclasses may override it with more detail."""
        out = f'City Type:\n {type(self).__name__}'
        out += f'\nN:\n {self.N0}'
        out += f'\ngroups:\n {self.groups}'
        out += f'\nmixing_LR:\n {self.mixing_LR}'
        return out



In [131]:
class FrequentFlierCity(City):
    """
    City with two groups, 'normal' (index 0) and 'frequent_fliers' (index 1).

    Besides the parent's attributes it stores ``frequent_flier_frac``, ``flying_LR`` and
    ``p_ff``; the latter two are derived from each other (see ``__init__``).
    """
    def __init__(self,
                 N0: int = 10**6,
                 frequent_flier_frac: float = 0.1,
                 p_ff: Optional[float] = None,
                 flying_LR: Optional[float] = None,
                 mixing_LR: float = 5,
                 compartments: str = 'SEIR'):
        """
        N0: total initial population of the city.
        frequent_flier_frac: fraction of the population in the frequent-flier group.
        p_ff, flying_LR: specify at most one; the other is derived from it via
            p_ff = flying_LR * f / (flying_LR * f + (1 - f)), with f = frequent_flier_frac.
            If neither is given, flying_LR defaults to 10.
            (Presumably p_ff is the probability that a traveller is a frequent flier and
            flying_LR the corresponding likelihood ratio; confirm against the travel code.)
        mixing_LR: the off-diagonal entries of the 2x2 mixing matrix are 1 / mixing_LR.
        compartments: forwarded to City; only 'SEIR' is supported.
        """
        mixing_LR_matrix = np.array([[1., 1 / mixing_LR], [1 / mixing_LR, 1.]])
        self.groups = ['normal', 'frequent_fliers']
        # Round one group and give the remainder to the other so the two always add up
        # to N0. Truncating both products independently can lose a person
        # (e.g. 10**6 * (1 - 0.9) truncates to 99999).
        n_frequent_fliers = int(round(N0 * frequent_flier_frac))
        self.N0s = np.array([N0 - n_frequent_fliers, n_frequent_fliers], dtype = np.int64)
        super().__init__(self.N0s, self.groups, compartments, mixing_LR_matrix)
        self.frequent_flier_frac = frequent_flier_frac
        if (p_ff is None) and (flying_LR is None):
            flying_LR = 10
        assert (p_ff is None) != (flying_LR is None), "Specify exactly one of p_ff OR flying_LR!"
        if flying_LR is not None:
            self.flying_LR = flying_LR
            self.p_ff = flying_LR * frequent_flier_frac / (flying_LR * frequent_flier_frac + (1 - frequent_flier_frac))
        if p_ff is not None:
            self.p_ff = p_ff
            self.flying_LR = (p_ff / frequent_flier_frac) / ((1 - p_ff) / (1 - frequent_flier_frac))

    def initial_conditions(self):
        """
        Split the I0 initially infectious people between the two groups at random
        (binomially, each being 'normal' with probability 1 - frequent_flier_frac,
        drawn independently per simulation); everyone else starts susceptible.
        """
        I_n = np.random.binomial(self.I0, 1 - self.frequent_flier_frac, self.n_sims)
        I_ff = self.I0 - I_n
        # Use the parent's compartment indexes rather than hard-coded positions
        self.data[0, self.I_index, :, 0] = I_n
        self.data[1, self.I_index, :, 0] = I_ff
        self.data[0, self.S_index, :, 0] = self.N0s[0] - I_n
        self.data[1, self.S_index, :, 0] = self.N0s[1] - I_ff

    def __str__(self):
        """Multi-line summary of the city's parameters."""
        out = 'City Type:\n FrequentFlierCity'
        out += f'\nN:\n {self.N0}'
        out += f'\nfrequent_flier_frac:\n {self.frequent_flier_frac}'
        out += f'\nflying_LR:\n {self.flying_LR}'
        out += f'\np_ff:\n {self.p_ff}'
        out += f'\nmixing_LR:\n {self.mixing_LR}'
        return out


In [132]:
# Example disease; the fields are passed by keyword so each rate is explicit.
measles = Disease(beta=1.5, gamma=1/8, delta=1/10)

100%|██████████| 999/999 [00:00<00:00, 9096.85it/s]


In [134]:
class BasicCity(City):
    """City made of a single population group named 'normal'."""

    def __init__(self,
                 N0: int = 10**6,
                 compartments: str = 'SEIR'):
        """
        N0: initial population of the city.
        compartments: forwarded to City.
        """
        self.groups = ['normal']
        self.N0s = np.array([N0], dtype = np.int64)
        # With one group the mixing matrix is just the 1x1 matrix [[1.]]
        super().__init__(self.N0s, self.groups, compartments, np.ones((1, 1)))

    def initial_conditions(self):
        """Start every simulation with I0 infectious people and everyone else susceptible."""
        susceptible, infectious = 0, 2  # compartment positions within 'SEIR'
        initially_infectious = self.I0
        self.data[0, infectious, :, 0] = initially_infectious
        self.data[0, susceptible, :, 0] = self.N0s[0] - initially_infectious

In [None]:
class Travel:
    def __init__(self,
                 cities: list[City],
                 mixmatrix: np.ndarray,
                 I0s: Optional[list[int]] = None):
        self.cities = cities
        self.mixmatrix = mixmatrix.astype(np.int64)
        assert np.all(mixmatrix == mixmatrix.T), "mixmatrix must be symmetrical!"
        assert np.trace(mixmatrix) == 0, "mixmatrix must be traceless!"
        self.reset_parameters(I0s)