In [None]:
import numpy as np 
import matplotlib.pyplot as plt
import icomo
import jax.numpy as jnp

In [None]:
 
# Parameters from Ganna Rozhnova's paper (Elimination prospects of the Dutch HIV epidemic among men who have sex with men in the era of pre-exposure prophylaxis)
# All rates are per year. Where an annual probability p is given, it is converted
# to a constant per-year rate via -ln(1 - p).
N0 = 210000 # initial population size
N01 = 0.451*N0 # initial population size of risk group 1
N02 = 0.353*N0 # initial population size of risk group 2
N03 = 0.125*N0 # initial population size of risk group 3
N04 = 0.071*N0 # initial population size of risk group 4
N0s = [N01,N02,N03,N04]
# The four risk-group fractions must partition the whole population.
assert abs(sum(N0s) - N0) < 1e-6*N0, "risk-group fractions must sum to 1"
mu = 1/45 # per year, rate of recruitment to (and exit from) the sexually active population
Omega = 1-0.86 # relative susceptibility on PrEP = 1 - PrEP effectiveness (86%, baseline)
c = [0.13,1.43,5.44,18.21] # per year, average number of partners in risk group l
h = [0.62,0.12,0.642,0.0] # infectivity of untreated individuals in stage k of infection
epsilon = 0.01 # infectivity of treated individuals
epsilonP = h[1]/2 # infectivity of MSM infected on PrEP
Lambda = 0.25 # transmission prob. per partnership
omega = 0.5 # mixing parameter, (0: assortative, 1: proportionate mixing)
Phi = -jnp.log(1-0.05) # per year, ART dropout rate (5% annual dropout probability)
tau = -jnp.log(1-0.3) # per year, ART uptake rate (30% annual uptake probability)
tauP1 = -jnp.log(1-0.95) # per year, ART uptake rate for MSM infected on PrEP (95% annual probability)
tauP2 = tauP1
tauP3 = tauP1
tauP4 = tauP1
# One entry per risk group, indexed 0..3 like Kons, Koffs and N0s. The general
# ART uptake rate tau is passed separately as args["tau"] and must not be listed
# here: with tau at index 0, risk group 1's IP compartment would receive tau
# instead of tauP1.
tauPs = [tauP1,tauP2,tauP3,tauP4]
rho1 = 1/0.142 # per year, rate of transition from stage 1 to 2 for untreated individuals
rho2 = 1/8.439 # per year, rate of transition from stage 2 to 3 for untreated individuals
rho3 = 1/1.184 # per year, rate of transition from stage 3 to 4 for untreated individuals
rho4 = 1/1.316 # per year, mortality rate for untreated individuals
rhos = [rho1,rho2,rho3,rho4]
gamma1 = 1/8.21 # per year, rate of transition from stage 1 to 2 for treated individuals
gamma2 = 1/54.0 # per year, rate of transition from stage 2 to 3 for treated individuals
gamma3 = 1/2.463 # per year, rate of transition from stage 3 to 4 for treated individuals
gamma4 = 1/2.737 # per year, mortality rate for treated individuals
gammas = [gamma1,gamma2,gamma3,gamma4]
# NOTE: PrEPuptake_rg1..4 are set in the "Parameters we have to decide" cell
# further down; that cell must be executed before this one, otherwise the
# following lines raise NameError on a fresh kernel.
Kon1 = -jnp.log(1-PrEPuptake_rg1) # per year, PrEP uptake rate in risk group 1
Kon2 = -jnp.log(1-PrEPuptake_rg2) # per year, PrEP uptake rate in risk group 2
Kon3 = -jnp.log(1-PrEPuptake_rg3) # per year, PrEP uptake rate in risk group 3
Kon4 = -jnp.log(1-PrEPuptake_rg4) # per year, PrEP uptake rate in risk group 4
Kons = [Kon1,Kon2,Kon3,Kon4]
Koff1 = 1/5.0 # per year, PrEP discontinuation rate = 1 / average duration of PrEP use (5 years) in risk group 1
Koff2 = Koff1
Koff3 = Koff1
Koff4 = Koff1
Koffs = [Koff1,Koff2,Koff3,Koff4]

args = dict(N0s=N0s, mu=mu, Omega=Omega, c=c, h=h, epsilon=epsilon, epsilonP=epsilonP, Lambda=Lambda, omega=omega, Phi=Phi, tau = tau, tauPs=tauPs, rhos=rhos, gammas=gammas, Kons=Kons, Koffs=Koffs)

def N(y): # number of people per risk group for a given state y (which means at a given time)
    """Return the total population of each of the four risk groups.

    Args:
        y: state dict keyed by compartment name. For risk group l (1..4) the
            compartments are S<l>, SP<l>, IP<l>1, I<l>1..I<l>4 and A<l>1..A<l>4.

    Returns:
        List [N1, N2, N3, N4] with the sum of the 11 compartments of each group.
    """
    templates = ("S{l}", "SP{l}", "I{l}1", "IP{l}1", "I{l}2", "I{l}3", "I{l}4",
                 "A{l}1", "A{l}2", "A{l}3", "A{l}4")
    # Python's sum adds compartments element-wise. It therefore works for scalar
    # or array-valued compartments and for JAX tracers inside jit/grad. np.sum
    # would flatten arrays into one scalar and cannot consume tracers.
    return [sum(y[t.format(l=l)] for t in templates) for l in range(1, 5)]

def M(l,ll,y,args): # mixing matrix
    """Mixing-matrix entry M_{l,ll} between risk groups l and ll.

    M_{l,ll} = omega * c_ll * N_ll / sum_j(c_j * N_j) + (1 - omega) * delta_{l,ll}

    With omega = 0 mixing is fully assortative; with omega = 1 it is fully
    proportionate to each group's partnership supply c_j * N_j.

    Args:
        l: 1-based risk group of the susceptible (1..4).
        ll: 1-based risk group of the partner (1..4).
        y: state dict, passed to N() to obtain the group sizes.
        args: parameter dict; reads "c" (partners per year per group) and "omega".

    Returns:
        The scalar mixing weight M_{l,ll}.
    """
    c = args["c"]
    omega = args["omega"]
    Ns = N(y)
    # Total partnership supply over all groups. A generator sum keeps this valid
    # for JAX tracers; jnp.sum rejects Python lists and np.array cannot hold tracers.
    total_supply = sum(c_j * N_j for c_j, N_j in zip(c, Ns))
    j = ll - 1  # 1-based group number -> 0-based list index
    proportionate = omega * c[j] * Ns[j] / total_supply
    assortative = (1 - omega) * jnp.where(l == ll, 1, 0)
    return proportionate + assortative

def JP(l,y,args): # force of infection per year in group l
    """Force of infection (per year) acting on susceptibles of risk group ``l``.

    J_l = Lambda * c_l * sum_ll M_{l,ll} *
          [ sum_k (h_k * I_{ll,k} + epsilon * A_{ll,k}) + epsilonP * IP_{ll,1} ] / N_ll

    Args:
        l: 1-based risk group index (1..4).
        y: state dict keyed by compartment name (I<ll><k>, A<ll><k>, IP<ll>1, ...).
        args: parameter dict; reads "h", "epsilon", "epsilonP", "Lambda", "c"
            (and, through M, "omega").

    Returns:
        The force of infection J_l.
    """
    h = args["h"]
    epsilon = args["epsilon"]
    epsilonP = args["epsilonP"]
    Lambda = args["Lambda"]
    c = args["c"]
    # Bound to ``Ns`` rather than ``N``. Assigning to ``N`` would make it local to
    # this function, and the call N(y) would raise UnboundLocalError.
    Ns = N(y)
    weighted_prevalence = 0.0
    for ll in range(1, 5):
        untreated = sum(h[k - 1] * y[f"I{ll}{k}"] for k in range(1, 5))
        treated = sum(y[f"A{ll}{k}"] for k in range(1, 5))
        infected_on_prep = y[f"IP{ll}1"]
        infectious = untreated + epsilon * treated + epsilonP * infected_on_prep
        weighted_prevalence = weighted_prevalence + M(l, ll, y, args) * infectious / Ns[ll - 1]
    return Lambda * c[l - 1] * weighted_prevalence

In [None]:
# Parameters we have to decide
# TODO: every assignment in this cell is still an empty placeholder, so the cell
# is a SyntaxError until values are filled in.
# NOTE: the parameter cell above computes Kon1..Kon4 as -log(1 - PrEPuptake_rgX),
# so these four values must be defined before that cell is executed.
PrEPuptake_rg1 =  # annual PrEP uptake in risk group 1 (fraction in [0, 1))
PrEPuptake_rg2 =  # annual PrEP uptake in risk group 2 (fraction in [0, 1))
PrEPuptake_rg3 =  # annual PrEP uptake in risk group 3 (fraction in [0, 1))
PrEPuptake_rg4 =  # annual PrEP uptake in risk group 4 (fraction in [0, 1))

# Starting values: for each risk group j, the compartments below must sum to N0j
# (N01..N04 from the parameter cell above).
# Naming: <compartment><risk group><stage>, e.g. I23_0 = untreated, risk group 2, stage 3.
# NOTE(review): the run cell passes y0 to the integrator, presumably a dict mapping
# compartment names ("S1", "IP11", "A43", ...) to these values; confirm.
# Susceptible per risk group
S1_0 = 
S2_0 =
S3_0 =
S4_0 =
# Susceptible on PrEP per risk group
SP1_0 =
SP2_0 =
SP3_0 =
SP4_0 =
# Infected in stage 1 per risk group
I11_0 =
I21_0 =
I31_0 =
I41_0 =
# Infected on PrEP in stage 1 per risk group
IP11_0 =
IP21_0 =
IP31_0 =
IP41_0 =
# Infected in stage 2 per risk group
I12_0 =
I22_0 =
I32_0 =
I42_0 =
# Infected in stage 3 per risk group
I13_0 =
I23_0 =
I33_0 =
I43_0 =
# Infected in stage 4 per risk group
I14_0 =
I24_0 =
I34_0 =
I44_0 =
# ART in stage 1 per risk group
A11_0 =
A21_0 =
A31_0 =
A41_0 =
# ART in stage 2 per risk group
A12_0 =
A22_0 =
A32_0 =
A42_0 =
# ART in stage 3 per risk group
A13_0 =
A23_0 =
A33_0 =
A43_0 =
# ART in stage 4 per risk group
A14_0 =
A24_0 =
A34_0 =
A44_0 =


In [None]:
# Equations

def HIV_model(t, y, args): # model from Ganna Rozhnova
    """Right-hand side dy/dt of the four-risk-group HIV/PrEP compartmental model.

    For each risk group l (1..4) the state holds S<l> and SP<l> (susceptible,
    off/on PrEP), IP<l>1 (infected while on PrEP), I<l><k> (untreated, stage
    k = 1..4) and A<l><k> (on ART, stage k = 1..4). Per group l:

    * S  -> I_l1 at J_l, S -> SP at Kon_l, recruitment mu * N0_l into S
    * SP -> IP_l1 at Omega * J_l, SP -> S at Koff_l
    * IP_l1 -> A_l1 at tauPs[l-1]
    * I_lk -> A_lk at tau (ART uptake), A_lk -> I_lk at Phi (ART dropout)
    * I_lk -> I_l,k+1 at rho_k and A_lk -> A_l,k+1 at gamma_k; the stage-4
      outflows rho_4 / gamma_4 leave the model
    * every compartment additionally leaves at rate mu

    Args:
        t: current time (not used; the model is autonomous).
        y: state dict keyed by compartment name.
        args: parameter dict with "Kons", "N0s", "mu", "Omega", "Koffs", "tau",
            "rhos", "tauPs", "Phi", "gammas" (per-group/per-stage lists are
            0-indexed) plus whatever JP reads.

    Returns:
        Dict of derivatives with the same keys as ``y``.
    """
    cm = icomo.CompModel(y)  # Initialize the compartmental model

    Kons, N0s, mu, Omega, Koffs = args["Kons"], args["N0s"], args["mu"], args["Omega"], args["Koffs"]
    tau, rhos, tauPs, Phi, gammas = args["tau"], args["rhos"], args["tauPs"], args["Phi"], args["gammas"]

    for l in range(1, 5):
        g = l - 1  # 0-based index into the per-group parameter lists
        S, SP, IP = f"S{l}", f"SP{l}", f"IP{l}1"
        # The force of infection is expensive (4 mixing terms, each summing the
        # population), so evaluate it once and reuse it for both S and SP.
        J = JP(l, cm.y, args)

        # Susceptibles: infection, PrEP uptake, recruitment and exit.
        cm.flow(S, f"I{l}1", J)
        cm.flow(S, SP, Kons[g])
        cm.dy[S] = cm.dy[S] - mu*cm.y[S] + mu*N0s[g]

        # Susceptibles on PrEP: reduced infection, PrEP discontinuation, exit.
        cm.flow(SP, IP, Omega*J)
        cm.flow(SP, S, Koffs[g])
        cm.dy[SP] = cm.dy[SP] - mu*cm.y[SP]

        # Infected while on PrEP: ART initiation and exit.
        cm.flow(IP, f"A{l}1", tauPs[g])
        cm.dy[IP] = cm.dy[IP] - mu*cm.y[IP]

        # Untreated (I) and treated (A) stages k = 1..4.
        for k in range(1, 5):
            s = k - 1  # 0-based stage index into rhos / gammas
            inf, art = f"I{l}{k}", f"A{l}{k}"
            cm.flow(inf, art, tau)
            cm.flow(art, inf, Phi)
            cm.dy[inf] = cm.dy[inf] - (mu + rhos[s])*cm.y[inf]
            cm.dy[art] = cm.dy[art] - (mu + gammas[s])*cm.y[art]
            if k > 1:
                # Progression in from the previous stage.
                cm.dy[inf] = cm.dy[inf] + rhos[s - 1]*cm.y[f"I{l}{k - 1}"]
                cm.dy[art] = cm.dy[art] + gammas[s - 1]*cm.y[f"A{l}{k - 1}"]

    # Return the differential changes
    return cm.dy


# Function to setup the model and return the integrator
def setup_model(t_end=3600 * 5, n_points=3600):
    """Build an icomo integrator for ``HIV_model``.

    All rates in the parameter cell are given per year, so the time axis is
    presumably in years.
    NOTE(review): the default horizon 3600 * 5 = 18000 looks like a leftover from
    a template; confirm the intended simulation length.

    Args:
        t_end: end of the simulated time span (the span starts at 0).
        n_points: number of equally spaced time points. The same grid is used for
            output, solver checkpoints and argument interpolation.

    Returns:
        The function produced by ``ODEIntegrator.get_func(HIV_model)``. The run
        cell calls it as ``integrator(y0=..., constant_args=...)``.
    """
    # Define the time span for the simulation
    ts = np.linspace(0, t_end, n_points)

    # Create an ODE integrator object using the icomo library
    integrator_object = icomo.ODEIntegrator(
        ts_out=ts,  # Output time points
        t_0=min(ts),  # Initial time point
        ts_solver=ts,  # Time points for the solver to use
        ts_arg=ts,
    )

    # Function that solves the ODEs defined in HIV_model
    return integrator_object.get_func(HIV_model)






In [None]:
# Initial state: map every compartment name read by HIV_model and N() to the
# starting value set in the "Parameters we have to decide" cell. Values are cast
# to float so the ODE state is floating point even if counts were typed as ints.
y0_values = {
    "S1": S1_0, "S2": S2_0, "S3": S3_0, "S4": S4_0,
    "SP1": SP1_0, "SP2": SP2_0, "SP3": SP3_0, "SP4": SP4_0,
    "I11": I11_0, "I21": I21_0, "I31": I31_0, "I41": I41_0,
    "IP11": IP11_0, "IP21": IP21_0, "IP31": IP31_0, "IP41": IP41_0,
    "I12": I12_0, "I22": I22_0, "I32": I32_0, "I42": I42_0,
    "I13": I13_0, "I23": I23_0, "I33": I33_0, "I43": I43_0,
    "I14": I14_0, "I24": I24_0, "I34": I34_0, "I44": I44_0,
    "A11": A11_0, "A21": A21_0, "A31": A31_0, "A41": A41_0,
    "A12": A12_0, "A22": A22_0, "A32": A32_0, "A42": A42_0,
    "A13": A13_0, "A23": A23_0, "A33": A33_0, "A43": A43_0,
    "A14": A14_0, "A24": A24_0, "A34": A34_0, "A44": A44_0,
}
y0 = {name: float(value) for name, value in y0_values.items()}
# The starting-values cell requires each risk group to start at N0j.
assert np.allclose(N(y0), N0s), "initial compartments must sum to N0s for every risk group"

integrator = setup_model()
output = integrator(y0=y0, constant_args=args)