In [12]:
import json
import numpy as np
import pandas as pd
import numpy.random as rd
from collections import Counter
from ppa.pars_cas import restructure_pars_cas

In [2]:
# All location names listed in the 'Location' column of data/locations.csv.
locations = list(pd.read_csv('data/locations.csv')['Location'])

In [3]:
# Location whose parameter file (data/<loc>/pars_cas.json) is loaded in the next cell.
loc: str = 'India'

In [4]:
# Load the parameter file for the selected location. A context manager is used
# so the file handle is closed immediately instead of leaking until GC.
with open(f'data/{loc}/pars_cas.json', 'r') as fh:
    pars = json.load(fh)

In [5]:
class DistCompeting:
    """Competing-risks sampler for a state with several exponential exits.

    ``kr`` maps each destination state to its transition rate. Calling the
    instance draws one exponential waiting time (mean ``1 / rate``) per
    destination and returns ``(next_state, time_to_event)`` for the earliest.

    A destination whose rate is exactly 0 can never fire, so it is skipped
    rather than triggering a division by zero. Negative rates are still passed
    to ``rd.exponential``, which rejects them.

    Raises:
        ValueError: if no destination has a non-zero rate.
    """
    def __init__(self, kr):
        self.KeyRates = kr

    def __call__(self):
        draws = [(k, rd.exponential(1 / r)) for k, r in self.KeyRates.items() if r != 0]
        if not draws:
            raise ValueError('DistCompeting requires at least one non-zero rate')
        # The earliest competing event wins.
        nxt, tte = min(draws, key=lambda x: x[1])
        return nxt, tte
    
class DistSemi:
    """Semi-Markov sampler: fixed sojourn time followed by a categorical jump.

    ``kr`` maps each destination state to a rate. The sojourn time is
    ``1 / sum(rates)`` and each destination is chosen with probability
    ``rate / sum(rates)``.
    """
    def __init__(self, kr):
        rates = np.array(list(kr.values()))
        self.States = list(kr.keys())
        self.Dur = 1 / rates.sum()
        # Normalised jump probabilities (rates scaled by the sojourn time).
        self.Probs = rates * self.Dur

    def __call__(self):
        return rd.choice(self.States, p=self.Probs), self.Dur

In [7]:
class Agent:
    """Record of one simulated individual's path through the care cascade.

    ``History`` holds ``(state, time)`` tuples in the order they were appended.
    The remaining attributes are milestone flags/counters updated by
    ``Simulator._sim_an_agent``.
    """
    def __init__(self):
        self.History = list()

        self.N_Vis = 0          # care-seeking visits made ('CSI' / 'ReCSI' events)
        self.HasOnset = False   # entered 'Sym'
        self.HasDx = False      # diagnosed at one of the visits
        self.HasTxI = False     # initiated treatment ('TxPub' / 'TxPri')
        self.HasTxSucc = False  # reached 'Succ'

    def append(self, st, t):
        """Record entry into state ``st`` at time ``t``."""
        self.History.append((st, t))


class Simulator:
    """Event-driven simulator of individual trajectories starting in 'Asym'.

    Untreated states ('Asym', 'Sym', 'ExCS') use competing exponential exits
    (``DistCompeting``). Treatment states ('TxPub', 'TxPri') use a fixed
    sojourn followed by a categorical outcome (``DistSemi``). A 'CSI'/'ReCSI'
    draw is resolved immediately into 'TxPub', 'TxPri' or 'ExCS'.
    Simulation stops when a state in ``Absorbing`` is reached.

    Parameters
    ----------
    p : dict
        Raw parameter set (presumably one entry of ``pars['pars']``). It is
        passed through ``restructure_pars_cas`` and must then provide the
        rate keys used below plus 'p_ent', 'p_dx', 'p_txi' and 'tx_alo'.
    """
    def __init__(self, p):
        p = restructure_pars_cas(p)
        self.Pars = p
        self.Samplers = {
            'Asym': DistCompeting({'Sym': p['r_onset'], 'DieUt': p['r_die_asym'], 'SelfCure': p['r_sc']}),
            'Sym': DistCompeting({'CSI': p['r_csi'], 'DieUt': p['r_die_sym'], 'SelfCure': p['r_sc']}),
            'ExCS': DistCompeting({'ReCSI': p['r_recsi'], 'DieUt': p['r_die_sym'], 'SelfCure': p['r_sc']}),
            'TxPub': DistSemi({'Succ': p['r_txs'][0], 'DieTx': p['r_txd'][0], 'LTFU': p['r_txl'][0]}),
            'TxPri': DistSemi({'Succ': p['r_txs'][1], 'DieTx': p['r_txd'][1], 'LTFU': p['r_txl'][1]})
        }
        self.Absorbing = ['DieUt', 'SelfCure', 'DieTx', 'Succ', 'LTFU']

    def _resolve_visit(self, ag):
        """Play out one care-seeking visit for ``ag`` and return the resulting state.

        Draws one of three entry sectors with probabilities ``p_ent``, then
        whether the visit yields a diagnosis (``p_dx``) and a treatment
        initiation (``p_txi``). On initiation, ``tx_alo[sector][0]`` gives the
        probability of 'TxPub' versus 'TxPri'. Any failure leads to 'ExCS'.
        Updates the agent's ``N_Vis``, ``HasDx`` and ``HasTxI``.
        """
        p = self.Pars
        ag.N_Vis += 1
        sector = rd.choice([0, 1, 2], p=p['p_ent'])
        # `not (u < prob)` (rather than `u >= prob`) keeps NaN handling identical.
        if not (rd.random() < p['p_dx'][sector]):
            return 'ExCS'
        ag.HasDx = True
        if not (rd.random() < p['p_txi'][sector]):
            return 'ExCS'
        ag.HasTxI = True
        return 'TxPub' if rd.random() < p['tx_alo'][sector][0] else 'TxPri'

    def _sim_an_agent(self):
        """Simulate one agent from 'Asym' at time 0 until it reaches an absorbing state.

        Returns
        -------
        Agent
            The agent with its full ``History`` and milestone flags.
        """
        ag = Agent()
        state, ti = 'Asym', 0
        ag.append(state, ti)

        while state not in self.Absorbing:
            # Always use this instance's samplers and parameters, never a notebook global.
            nxt, tte = self.Samplers[state]()

            if nxt == 'Sym':
                ag.HasOnset = True
            elif nxt in ('CSI', 'ReCSI'):
                nxt = self._resolve_visit(ag)
            elif nxt == 'Succ':
                ag.HasTxSucc = True

            state = nxt
            ti += tte
            ag.append(state, ti)
        return ag

    def simulate(self, n = 100):
        """Simulate ``n`` independent agents and return them as a list."""
        return [self._sim_an_agent() for _ in range(n)]
        
        

In [8]:
# Build a simulator from the parameter set at index 4 of pars['pars'].
# NOTE(review): index 4 is a hard-coded choice; consider a named config constant.
p = pars['pars'][4]
simulator = Simulator(p)

In [9]:
# Simulate 30 independent agent trajectories with the simulator built above.
ags = simulator.simulate(n=30)

In [11]:
# Snapshot times from 0 to 2 inclusive in steps of 0.05 (41 points).
ts = np.linspace(start=0, stop=2, num=41)
ts

array([0.  , 0.05, 0.1 , 0.15, 0.2 , 0.25, 0.3 , 0.35, 0.4 , 0.45, 0.5 ,
       0.55, 0.6 , 0.65, 0.7 , 0.75, 0.8 , 0.85, 0.9 , 0.95, 1.  , 1.05,
       1.1 , 1.15, 1.2 , 1.25, 1.3 , 1.35, 1.4 , 1.45, 1.5 , 1.55, 1.6 ,
       1.65, 1.7 , 1.75, 1.8 , 1.85, 1.9 , 1.95, 2.  ])

In [13]:
def count_history(ags, ts):
    """Tabulate how many agents occupy each state at each time point.

    Parameters
    ----------
    ags : list of Agent
        Simulated agents; each ``History`` is a list of ``(state, time)``
        records in the order they were appended.
    ts : iterable of float
        Time points at which to take a snapshot.

    Returns
    -------
    pandas.DataFrame
        One row per time point: a ``Time`` column plus one column per state
        occupied at any snapshot. States absent at a snapshot are filled with 0.

    An agent's state at ``t0`` is its last history record (in append order)
    with time <= t0. Agents with no such record (``t0`` earlier than their
    first record) are left out of that snapshot instead of raising IndexError.
    """
    rows = []
    for t0 in ts:
        cnt = Counter()
        for ag in ags:
            reached = [st for st, t in ag.History if t <= t0]
            if reached:
                cnt[reached[-1]] += 1
        rows.append({'Time': t0, **cnt})
    return pd.DataFrame(rows).fillna(0)

In [91]:
# Back-of-envelope reconciliation of test volumes with detections.
# NOTE(review): abbreviation meanings are inferred from names only — confirm:
#   ssm ~ smear microscopy, naat ~ nucleic-acid amplification test,
#   bac ~ bacteriologically confirmed, cdx ~ clinically diagnosed.
test_ssm = 11178433
test_naat = 2179976

det_bac = 1173775
det_cdx = 1231040
det_all = det_bac + det_cdx

# Assumed positive predictive value of all detections combined.
ppv = 0.50

# Implied number of true positives among all detections.
tp_all = det_all * ppv

# Assumed (sensitivity, specificity) of each test.
sens_ssm, spec_ssm = 0.64, 0.98
sens_naat, spec_naat = 0.8, 0.98
# x: fraction of tested people who are true cases (a trial value here; it is solved for in a later cell).
cs = 0.0192
x = cs
# Expected positive results summed over both tests: true positives x*sens plus
# false positives (1-x)*(1-spec), each scaled by the number of tests.
test_ssm * (x * sens_ssm + (1 - x) * (1 - spec_ssm)) + test_naat * (x * sens_naat + (1 - x) * (1 - spec_naat))

432883.5670080002

In [92]:
# Same expected-positives count rearranged as intercept + x * slope (linear in x); output should match the previous cell.
test_ssm * (1 - spec_ssm) + test_naat * (1 - spec_naat) + x * (test_ssm * (sens_ssm - 1 + spec_ssm) + test_naat * (sens_naat - 1 + spec_naat))

432883.5670080002

In [93]:
# Solve the linear relation for the x at which expected test positives equal det_bac.
x = (det_bac - test_ssm * (1 - spec_ssm) - test_naat * (1 - spec_naat)) / (test_ssm * (sens_ssm - 1 + spec_ssm) + test_naat * (sens_naat - 1 + spec_naat))
x

0.10504064383085725

In [94]:
# True and false test positives implied by the solved x; tp + fp should reproduce det_bac.
tp = test_ssm * (x * sens_ssm) + test_naat * (x * sens_naat)
fp = test_ssm * (1 - x) * (1 - spec_ssm) + test_naat * (1 - x) * (1 - spec_naat)

tp + fp

1173775.0

In [95]:
# False negatives and true negatives among all tests at the solved x.
fn = test_ssm * (x * (1 - sens_ssm)) + test_naat * (x * (1 - sens_naat))
tn = test_ssm * ((1 - x) * spec_ssm) + test_naat * ((1 - x) * spec_naat)
fn, tn

(468505.5442775997, 11716128.455722399)

In [96]:
# Back out the accuracy of 'cdx' detection.
# - Sensitivity: true positives not found by testing (tp_all - tp) are attributed
#   to test false negatives.
# - Specificity: false positives not explained by testing (fp_all - fp) are
#   attributed to test true negatives.
sens_cdx = (tp_all - tp) / fn
fp_all = det_all - tp_all
spec_cdx = 1 - (fp_all - fp) / tn
sens_cdx, spec_cdx

(0.5714706381426339, 0.9177797647679579)

In [97]:
# Combined sensitivity of "test positive, or test negative but cdx positive", weighted by test volumes.
# NOTE(review): tp1 is a volume-weighted sum of sensitivities, not a count of true
# positives (there is no factor x); only the ratio sens_all is meaningful.
tp1 = test_ssm * (sens_ssm + (1 - sens_ssm) * sens_cdx) + test_naat * (sens_naat + (1 - sens_naat) * sens_cdx)
sens_all = tp1 / (test_ssm + test_naat)

In [98]:
# Combined false-positive rate of "test positive, or test negative but cdx positive", weighted by test volumes.
# NOTE(review): fp1 is a volume-weighted sum of false-positive rates, not a count; only spec_all is meaningful.
fp1 = test_ssm * ((1 - spec_ssm) + spec_ssm * (1 - spec_cdx)) + test_naat * ((1 - spec_naat) + spec_naat * (1 - spec_cdx))
spec_all = 1 - fp1 / (test_ssm + test_naat)

In [99]:
# Combined sensitivity and specificity of testing plus cdx.
sens_all, spec_all

(0.8569185912447516, 0.8994241694725987)

In [76]:
# PPV implied by x, sens_all and spec_all; it should recover ppv (0.5) as a consistency check.
# NOTE(review): this cell's execution count (76) is lower than those of the cells it depends on; re-run top to bottom to confirm.
x * sens_all / (x * sens_all + (1 - x) * (1 - spec_all))

0.5000000000000001