# Multi-Strain SIR Model

This implements a simple multi-strain SIR model where there is cross immunity between strains defined by a matrix.

First we do imports and setup.

In [1]:
import tabularepimdl as tepi

import pandas as pd
import numpy as np

## Defining a Population

Let's make the population. This is just like a one-strain model, but with a separate infection-state column for each of the three strains.

In [2]:
## Population: 990 susceptible people plus 10 people currently infected with
## strain 1. Every strain column shares the same S/I/R categories.
pop_3strain = pd.DataFrame(
    {
        **{
            strain: pd.Categorical(states, categories=['S', 'I', 'R'])
            for strain, states in (
                ("Strain1", ['S', 'I']),
                ("Strain2", ['S', 'S']),
                ("Strain3", ['S', 'S']),
            )
        },
        "N": np.array([990, 10]),
        "T": 0,
    }
)


## Defining the Cross Protection Matrix and Infection Rule

We now define immunity such that "adjacent" strains have 50% cross protection and "non-adjacent" strains have 20% cross protection. That is, cross immunity is defined by this matrix:

$$
\begin{bmatrix}
1&.5&.2\\
.5&1&.5\\
.2&.5&1
\end{bmatrix}
$$




In [3]:
## Cross-protection matrix (row/column k corresponds to strain k+1): full
## protection on the diagonal, 0.5 between adjacent strains and 0.2 between
## strains 1 and 3. The matrix is symmetric.
cp_matrix = np.array(
    [
        [1.0, 0.5, 0.2],
        [0.5, 1.0, 0.5],
        [0.2, 0.5, 1.0],
    ]
)

We now use this to construct an infection rule. Assuming all strains have the same infectious period (5 days), we set the strains up to be progressively more infectious: with transmission rates $\beta = 0.5, 0.75, 1$ per day, $R_0 = \beta \times 5$ rises from 2.5 (strain 1) through 3.75 (strain 2) to 5 (strain 3).

In [4]:
# A single infection process covering all three strain columns, with
# per-strain transmission rates and cross protection given by cp_matrix.
three_strain_infect = tepi.MultiStrainInfectiousProcess(
    columns=["Strain1", "Strain2", "Strain3"],
    betas=np.array([0.5, 0.75, 1.0]),
    cross_protect=cp_matrix,
)

## Creating and Running a Deterministic Epidemic Model

Now we put it all together in a deterministic epidemic model, 
specifying the recovery times.

In [5]:
# Deterministic model: the multi-strain infection process plus one I -> R
# transition at rate 0.2 for each strain column.
SIR_3Strain = tepi.EpiModel(
    pop_3strain,
    rules=[three_strain_infect]
    + [tepi.SimpleTransition(strain, 'I', 'R', .2) for strain in ("Strain1", "Strain2", "Strain3")],
)

Now we run it. In the loop we can introduce an infected person with the appropriate characteristics on the appropriate days. In this case, one person infected with strain 2 is introduced on day 11 and one infected with strain 3 on day 20.

In [6]:
intro_day1 = 11
intro_day2 = 20

# Step through 100 days in quarter-day increments. Multiples of 0.25 are exact
# in binary floating point, so comparing t to an integer day with == is safe.
for t in np.arange(0, 100, 0.25):
    # Seed one new infection: strain 2 on intro_day1, strain 3 on intro_day2.
    for seed_day, seed_strain in ((intro_day1, 'Strain2'), (intro_day2, 'Strain3')):
        if t != seed_day:
            continue
        to_add = pd.DataFrame({
            **{col: ['I' if col == seed_strain else 'S'] for col in ('Strain1', 'Strain2', 'Strain3')},
            'N': 1,
            'T': [t],
        })
        SIR_3Strain.cur_state = pd.concat([SIR_3Strain.cur_state, to_add])
    SIR_3Strain.do_timestep(dt=0.25)

Now let's plot it.

In [7]:

import plotly.express as px

# Reshape to long form: one row per (time, strain column, infection state),
# with N summed over all population rows in that state.
long_epi = (
    SIR_3Strain.full_epi
    .melt(id_vars=['N', 'T'], var_name='Strain', value_name="InfState")
    .groupby(["T", "Strain", "InfState"])
    .sum()
    .reset_index()
)

epi_fig = px.line(long_epi, x="T", y="N", color="InfState", line_dash="Strain")
epi_fig.show()


## Making a Stochastic Version

Let's make the model stochastic. Since we are making everything stochastic, we can just take the same model, reset it, and set the `EpiModel.stoch_policy` attribute to `"stochastic"`. To make things more interesting, we will do multiple runs where the introduction days are random variables.

In [8]:
SIR_3Strain.reset()
SIR_3Strain.stoch_policy = "stochastic"

## Run 15 stochastic simulations with randomly drawn introduction days and
## collect each run's epidemic history, tagged with its run number.
all_sims = []

for sim in range(15):
    SIR_3Strain.reset()
    # Strain 2 arrives after a Poisson(10) number of days; strain 3 arrives a
    # further Poisson(10) days later.
    intro_day1 = np.random.poisson(10)
    intro_day2 = intro_day1 + np.random.poisson(10)
    for t in np.arange(0, 100, 0.25):
        if t == intro_day1:
            to_add = pd.DataFrame({
                'Strain1': ['S'],
                'Strain2': ['I'],
                'Strain3': ['S'],
                'N': 1,
                'T': [t]
            })
            SIR_3Strain.cur_state = pd.concat([SIR_3Strain.cur_state, to_add])
        if t == intro_day2:
            to_add = pd.DataFrame({
                'Strain1': ['S'],
                'Strain2': ['S'],
                'Strain3': ['I'],
                'N': 1,
                'T': [t]
            })
            SIR_3Strain.cur_state = pd.concat([SIR_3Strain.cur_state, to_add])
        SIR_3Strain.do_timestep(dt=0.25)

    # Tag a copy with the run number. Assigning a column directly onto
    # full_epi could mutate the model's own history frame in place.
    all_sims.append(SIR_3Strain.full_epi.assign(sim=sim))

all_sims = pd.concat(all_sims)



Now let's plot them all!

In [9]:

# Long form again, now also keyed by simulation run so each run is drawn as
# its own line.
long_epi = (
    all_sims
    .melt(id_vars=['N', 'T', 'sim'], var_name='Strain', value_name="InfState")
    .groupby(["T", "Strain", "InfState", "sim"])
    .sum()
    .reset_index()
)

epi_fig = px.line(
    long_epi, x="T", y="N",
    color="InfState", line_dash="Strain", line_group="sim",
)
# Semi-transparent traces so overlapping runs remain readable.
epi_fig.update_traces(opacity=0.25)
epi_fig.show()

## Creating a Custom Rule for Hospitalization

The power of the tabular approach comes, in part, from its ability to have rules that operate on multiple columns and divorce these rules from the disease process. One such rule would be to have different probabilities of hospitalization between a primary infection and subsequent infections. Here we create such a rule, with a hospitalization rate of 0.05 per day (roughly a 5% daily chance) during a primary infection and 0.01 per day (roughly a 1% daily chance) during a secondary infection.

Note that, because this is a custom rule, we don't need to make it as flexible as rules in the package.



In [8]:
class HospRule(tepi.Rule):
    '''Hospitalization rule spanning several infection columns.

    A row whose members are currently infected ("I") in any of ``inf_cols``
    is at risk of hospitalization. The rate is ``prim_hrate`` for a primary
    infection, and the lower ``sec_hrate`` if the row is recovered ("R") in
    any infection column. Hospitalized people are marked "H" in the
    hospitalization column paired with the strain they are infected with, so
    hospitalizations are tracked per strain.
    '''

    def __init__(self, inf_cols, hosp_cols, prim_hrate, sec_hrate, stochastic=False) -> None:
        '''
        Args:
            inf_cols: infection-state columns, one per strain.
            hosp_cols: hospitalization columns, index-aligned with ``inf_cols``
                (``hosp_cols[i]`` records hospitalizations from ``inf_cols[i]``);
                must be at least as long as ``inf_cols``.
            prim_hrate: hospitalization rate per unit time for a primary infection.
            sec_hrate: hospitalization rate per unit time when the row is
                recovered in some infection column.
            stochastic: default used by ``get_deltas`` when its own
                ``stochastic`` argument is None.
        '''
        super().__init__()
        self.inf_cols = inf_cols
        self.hosp_cols = hosp_cols
        self.prim_hrate = prim_hrate
        self.sec_hrate = sec_hrate
        self.stochastic = stochastic

    def get_deltas(self, current_state, dt=1.0, stochastic=None):
        '''Return the hospitalization deltas for one time step.

        Args:
            current_state: population table containing ``inf_cols``,
                ``hosp_cols`` and a count column ``N``.
            dt: time-step length. Over one step a row is hospitalized with
                probability ``1 - exp(-dt * rate)``.
            stochastic: if True, draw binomial counts; if False, use expected
                counts; if None, fall back to ``self.stochastic``.

        Returns:
            A DataFrame with two rows per at-risk row of ``current_state``:
            first the row with its hospitalization column set to "H" and a
            positive ``N``, then the unchanged row with the matching negative
            ``N``. An empty DataFrame if nobody is at risk.
        '''
        if stochastic is None:
            stochastic = self.stochastic

        # Collect per-row frames and concatenate once at the end. Repeatedly
        # concatenating onto an initially empty DataFrame is quadratic and can
        # trigger a FutureWarning in recent pandas.
        delta_frames = []

        # Row-by-row loop: inefficient, but kept simple for illustration.
        for _, row in current_state.iterrows():
            match_ind = -1
            prev_inf = False
            for idx, col in enumerate(self.inf_cols):
                if row[col] == "I":
                    match_ind = idx  # if several columns are "I", the last one wins
                elif row[col] == "R":
                    prev_inf = True

            # Nobody in this row is currently infected.
            if match_ind == -1:
                continue

            hosp_col = self.hosp_cols[match_ind]
            # Already hospitalized with this strain: the +N and -N rows would
            # describe the same state and cancel, so there is nothing to do.
            if row.get(hosp_col) == "H":
                continue

            row_df = pd.DataFrame(row).transpose()  # one-row frame so we can use assign
            rate = self.sec_hrate if prev_inf else self.prim_hrate
            hosp_prob = 1 - np.exp(-dt * rate)

            if stochastic:
                delta_decobs = row_df.assign(N=-np.random.binomial(row["N"], hosp_prob))
            else:
                delta_decobs = row_df.assign(N=-row["N"] * hosp_prob)

            delta_incobs = delta_decobs.assign(N=-delta_decobs.N)
            delta_incobs[hosp_col] = "H"

            delta_frames.extend([delta_incobs, delta_decobs])

        if not delta_frames:
            return pd.DataFrame()
        return pd.concat(delta_frames)

    def to_yaml(self):
        '''Not implemented: this one-off rule is not serialized, so this returns None.'''
        pass


Now that we have the rule let's create some test data and make sure it is working. 

In [9]:
# Three-row test population: fully susceptible; a primary strain-1 infection;
# and a strain-2 infection in someone already recovered from strain 1.
# Nobody is hospitalized ("U") yet.
pop_3strain_hosp = pd.DataFrame(
    {
        "Strain1": pd.Categorical(['S', 'I', 'R'], categories=['S', 'I', 'R']),
        "Strain2": pd.Categorical(['S', 'S', 'I'], categories=['S', 'I', 'R']),
        "Strain3": pd.Categorical(['S', 'S', 'S'], categories=['S', 'I', 'R']),
        **{hosp: "U" for hosp in ("Hosp1", "Hosp2", "Hosp3")},
        "N": np.array([1000, 1000, 1000]),
        "T": 0,
    }
)

hosp_rule = HospRule(
    inf_cols=["Strain1", "Strain2", "Strain3"],
    hosp_cols=["Hosp1", "Hosp2", "Hosp3"],
    prim_hrate=0.05,
    sec_hrate=0.01,
)

# Expected-value deltas first, then one stochastic draw from the same rule.
print(hosp_rule.get_deltas(pop_3strain_hosp))
print(hosp_rule.get_deltas(pop_3strain_hosp, stochastic=True))


  Strain1 Strain2 Strain3 Hosp1 Hosp2 Hosp3          N  T
1       I       S       S     H     U     U  48.770575  0
1       I       S       S     U     U     U -48.770575  0
2       R       I       S     U     H     U   9.950166  0
2       R       I       S     U     U     U  -9.950166  0
  Strain1 Strain2 Strain3 Hosp1 Hosp2 Hosp3   N  T
1       I       S       S     H     U     U  44  0
1       I       S       S     U     U     U -44  0
2       R       I       S     U     H     U  12  0
2       R       I       S     U     U     U -12  0


Now that we have a working rule, we can add it to our epidemic model and run it as before. Note that this rule needs to be in a separate rule group so that it does not "compete" with the infection and recovery rules for people in the same states.


In [10]:
# Base population for the full run: a single person infected with strain 1
# among 10,000 people, nobody hospitalized ("U") yet.
pop_3strain_hosp = pd.DataFrame(
    {
        **{
            strain: pd.Categorical(states, categories=['S', 'I', 'R'])
            for strain, states in (
                ("Strain1", ['S', 'I']),
                ("Strain2", ['S', 'S']),
                ("Strain3", ['S', 'S']),
            )
        },
        "Hosp1": "U",
        "Hosp2": "U",
        "Hosp3": "U",
        "N": np.array([9999, 1]),
        "T": 0,
    }
)


# Two rule groups: infection plus per-strain recovery in the first, and the
# hospitalization rule on its own in the second.
SIR_3Strain_hosp = tepi.EpiModel(
    pop_3strain_hosp,
    rules=[
        [three_strain_infect]
        + [tepi.SimpleTransition(strain, 'I', 'R', .2) for strain in ("Strain1", "Strain2", "Strain3")],
        [hosp_rule],
    ],
)


# Run the model, seeding one strain-2 and one strain-3 infection as before.
intro_day1 = 11
intro_day2 = 20

for t in np.arange(0, 100, 0.25):
    for seed_day, seed_strain in ((intro_day1, 'Strain2'), (intro_day2, 'Strain3')):
        if t != seed_day:
            continue
        to_add = pd.DataFrame({
            **{col: ['I' if col == seed_strain else 'S'] for col in ('Strain1', 'Strain2', 'Strain3')},
            'N': 1,
            'T': [t],
            **{col: ['U'] for col in ('Hosp1', 'Hosp2', 'Hosp3')},
        })
        SIR_3Strain_hosp.cur_state = pd.concat([SIR_3Strain_hosp.cur_state, to_add])
    SIR_3Strain_hosp.do_timestep(dt=0.25)

Now let's look at the results. Of most interest is the number of folks hospitalized with each strain versus the total number of infections for each strain.

In [11]:
# Final-state summary per strain.
# - "Infect" counts everyone recovered ("R") from that strain at the end of the
#   run; people still infected are not included.
# - "Hosp" counts everyone whose Hosp{k} column is "H" in the final state.
# - IHR is their ratio.
res_sum_3strain = pd.DataFrame({
    'Strain': [1, 2, 3],
    'Infect': [
        np.sum((SIR_3Strain_hosp.cur_state[f'Strain{k}'] == "R") * SIR_3Strain_hosp.cur_state['N'])
        for k in (1, 2, 3)
    ],
    'Hosp': [
        np.sum((SIR_3Strain_hosp.cur_state[f'Hosp{k}'] == "H") * SIR_3Strain_hosp.cur_state['N'])
        for k in (1, 2, 3)
    ],
})

res_sum_3strain['IHR'] = res_sum_3strain['Hosp'] / res_sum_3strain['Infect']

print(res_sum_3strain)

br_grph = px.scatter(res_sum_3strain, x="Strain", y="IHR")
br_grph.show()

   Strain       Infect         Hosp       IHR
0       1  5499.176843   871.728132  0.158520
1       2  7873.143915  1231.696919  0.156443
2       3  8791.795619   540.870908  0.061520


In [12]:
import plotly.graph_objects as go

# Grouped bars comparing total infections and hospitalizations per strain.
fig = go.Figure(data=[
    go.Bar(name=measure, x=res_sum_3strain.Strain, y=res_sum_3strain[measure])
    for measure in ('Infect', 'Hosp')
])
fig.update_layout(barmode='group')
fig.show()