In [49]:
from scipy.integrate import solve_ivp
import numpy as np
import pandas as pd
import json

In [50]:
loc = 'India'

with open(f'docs/pars/pars_nods_{loc}.json', 'r') as f:
    post = json.load(f)


In [51]:
class I:
    NonTB = 0
    Asym = 1
    Sym = 2
    ExCS = 3
    TxPub = 4
    TxEng = 5
    TxPri = 6
    FpPub = 7
    FpEng = 8
    FpPri = 9

In [84]:
class Model:
    def get_y0(self, t, pars):
        k = np.exp(- pars['adr'] * (t - 2020))
        prev = k * pars['prv0']

        p_pub, p_ppm = pars['p_pub'], pars['p_ppm']

        p_entry = np.array([p_pub, (1 - p_pub) * p_ppm, (1 - p_pub) * (1 - p_ppm)])

        p_txi = np.array([pars['p_txi_pub'], pars['p_txi_eng'], pars['txi_pri']])
        
        r_det_all = (pars['r_det'] * p_entry * p_txi).sum()

        asc = np.array([
            (pars['rs'] + pars['r_aware'] - pars['adr']) / pars['r_sym'],
            1,
            pars['r_aware'] / (pars['rc'] + r_det_all - pars['adr'])
        ])
        asc /= asc.sum()

        y0 = np.zeros(10)
        y0[I.NonTB] = 1 - prev
        y0[I.Asym] = asc[0] * prev * k
        y0[I.Sym] = asc[1] * prev * k
        y0[I.ExCS] = asc[2] * prev * k
        return y0
    
    def calc(self, t, y, pars):
        calc = dict()

        calc['onset'] = onset = pars['r_sym'] *  y[I.Asym]
        calc['aware'] = pars['r_aware'] * y[I.Sym]
    
        p_det = np.array([pars['p_pub'], (1 - pars['p_pub']) * pars['p_ppm'], (1 - pars['p_pub']) * (1 - pars['p_ppm'])])
        p_txi = np.array([pars['p_txi_pub'], pars['p_txi_eng'], pars['txi_pri']])
        calc['det'] = pars['r_det'] * p_det * p_txi * y[I.ExCS]

        calc['sc_a'] = sc_a = pars['r_sc'] * y[I.Asym]
        calc['sc_s'] = pars['r_sc'] * y[I.Sym]
        calc['sc_c'] = pars['r_sc'] * y[I.ExCS]
        calc['die_a'] = die_a = pars['r_death_a'] * y[I.Asym]
        calc['die_s'] = pars['r_death_s'] * y[I.Sym]
        calc['die_c'] = pars['r_death_s'] * y[I.ExCS]

        calc['acf_a'] = acf_a = pars['r_acf'] * pars['sens_acf'] * y[I.Asym]
        calc['acf_s'] = pars['r_acf'] * pars['sens_acf'] * y[I.Sym]
        calc['acf_c'] = pars['r_acf'] * pars['sens_acf'] * y[I.ExCS]

    
        calc['inc'] = onset + acf_a + sc_a + die_a - pars['adr'] * y[I.Asym]

        return calc

    def __call__(self, t, y, pars):
        calc = self.calc(t, y, pars)
        dy = np.zeros_like(y)
        
        inc, onset, aware = calc['inc'], calc['onset'], calc['aware']
        sc_a, sc_s, sc_c = calc['sc_a'], calc['sc_s'], calc['sc_c']
        die_a, die_s, die_c = calc['die_a'], calc['die_s'], calc['die_c']
        acf_a, acf_s, acf_c = calc['acf_a'], calc['acf_s'], calc['acf_c']
        det = calc['det']
        
        dy[I.Asym] += inc - onset - acf_a - sc_a - die_a
        dy[I.Sym] += onset - aware - acf_s - sc_s - die_s
        dy[I.ExCS] += aware - det.sum() - acf_c - sc_c - die_c
        return dy
    
    def measure(self, t, y, pars):
        calc = self.calc(t, y, pars)
        mea = {'Time': t}
        mea['N'] = n = y.sum()
        mea['PrevA'] = y[I.Asym]
        mea['PrevS'] = y[I.Sym]
        mea['PrevC'] = y[I.ExCS]
        mea['Inc'] = calc['inc']
        mea['DetR'] = calc['det'].sum()
        return mea
    
    
    def simulate(self, pars):
        y0 = self.get_y0(2015, pars)
        ys = solve_ivp(self, [2015, 2025], y0 = y0, args = (pars, ), dense_output=True)
        ms = [self.measure(t, ys.sol(t), pars) for t in np.linspace(2015, 2025, 21)]
        ms = pd.DataFrame(ms)
        return ys, ms
        
    

In [85]:
m = Model()
ys, ms = m.simulate(post[0])

In [86]:
ms.PrevC

0     0.000627
1     0.000614
2     0.000603
3     0.000588
4     0.000569
5     0.000554
6     0.000548
7     0.000534
8     0.000523
9     0.000512
10    0.000499
11    0.000490
12    0.000476
13    0.000468
14    0.000455
15    0.000447
16    0.000434
17    0.000428
18    0.000414
19    0.000408
20    0.000397
Name: PrevC, dtype: float64

In [92]:
- np.diff(np.log(ms.PrevS)) * 2, post[0]['adr']

(array([0.04608175, 0.04742551, 0.04461951, 0.04148573, 0.04329306,
        0.04877578, 0.04575413, 0.04565983, 0.04561016, 0.04522713,
        0.04556752, 0.0452651 , 0.04546274, 0.04536335, 0.04544936,
        0.04536657, 0.04543827, 0.04537266, 0.0454197 , 0.04539416]),
 0.0454)

In [77]:
ys = solve_ivp(fn, [2015, 2025], y0 = y0, args = (p, ))

In [93]:
m.calc(2015, m.get_y0(2015, pars), pars)

{'onset': 0.0023904670944685223,
 'aware': 0.0021189126810792175,
 'det': array([1.37435449e-03, 5.24939999e-04, 3.52863407e-06]),
 'sc_a': 0.0006317696799994423,
 'sc_s': 0.00021070100057753613,
 'sc_c': 0.0001676654267631957,
 'die_a': 0.0,
 'die_s': 9.4590797116739e-05,
 'die_c': 7.527067419223152e-05,
 'acf_a': 6.286971001865008e-06,
 'acf_s': 2.096762669421048e-06,
 'acf_c': 1.6684999445944653e-06,
 'inc': 0.002921219804310179}