In [26]:
import numpy as np
from scipy.integrate import solve_ivp
from abc import ABCMeta, abstractmethod
import pandas as pd
import matplotlib.pyplot as plt

In [33]:
class AbsModel(metaclass=ABCMeta):
    """Abstract base for ODE models that can be integrated and measured.

    A subclass provides the initial state (``get_y0``), the right-hand side
    of the ODE (the instance itself is the callable handed to ``solve_ivp``)
    and a per-time-point measurement (``measure``). ``simulate`` combines
    the three.
    """

    @abstractmethod
    def get_y0(self, p):
        """Return the initial state passed to ``solve_ivp`` for parameters ``p``."""

    @abstractmethod
    def __call__(self, t, y, p):
        """Return the time derivative of state ``y`` at time ``t`` given ``p``."""

    @abstractmethod
    def measure(self, t, y, p):
        """Return one record describing state ``y`` at time ``t`` given ``p``.

        ``simulate`` collects these records into the rows of a DataFrame.
        """

    def simulate(self, p, t_end=10):
        """Integrate the model over ``[0, t_end]`` and tabulate measurements.

        Parameters
        ----------
        p : object
            Parameter set forwarded unchanged to ``get_y0``, ``__call__``
            and ``measure``.
        t_end : number, default 10
            End of the integration interval.

        Returns
        -------
        tuple
            ``(solution, measures)``: the ``solve_ivp`` result (with dense
            output) and a ``pandas.DataFrame`` holding one ``measure``
            record per sampled time.
        """
        initial_state = self.get_y0(p)
        solution = solve_ivp(
            self,
            [0, t_end],
            initial_state,
            args=(p,),
            dense_output=True,
        )

        # round(t_end) + 1 evenly spaced samples including both end points;
        # for an integer t_end this is 0, 1, ..., t_end.
        n_samples = round(t_end) + 1
        records = []
        for t in np.linspace(0, t_end, n_samples):
            state_at_t = solution.sol(t)
            records.append(self.measure(t, state_at_t, p))

        measures = pd.DataFrame(records)
        return solution, measures

In [44]:
class ModelPlain(AbsModel):
    """Three-compartment ODE model with state ``[asym, sym, ex]``.

    Flows, as written in ``__call__``:

    * ``asym -> sym`` at rate ``r_onset``; ``asym`` is also left at ``r_sc``.
    * ``sym -> ex`` at rate ``r_csi``; ``sym`` is also left at
      ``mu = r_die + r_sc``.
    * ``ex`` is left at ``r_det + mu``.
    * The inflow into ``asym`` (see ``_incidence``) replaces every outflow
      except ``adr * n``, so the three derivatives sum to ``-adr * n``.

    ``p`` is a mapping with keys ``'inc'``, ``'r_die'``, ``'r_sc'``,
    ``'r_onset'``, ``'r_csi'``, ``'r_det'`` and ``'adr'``.
    """

    @staticmethod
    def _incidence(y, p):
        """Return the inflow into ``asym`` for float state array ``y``.

        Shared by ``__call__`` and ``measure`` so the reported 'Inc' always
        matches the inflow that is actually integrated.
        """
        asym, sym, ex = y
        mu = p['r_die'] + p['r_sc']
        n = y.sum()
        return p['r_sc'] * asym + mu * (sym + ex) + p['r_det'] * ex - p['adr'] * n

    def get_y0(self, p):
        """Return the initial state ``[asym, sym, ex]`` as a float array.

        Each compartment is its inflow divided by (outflow rate - ``adr``),
        i.e. the state in which every compartment changes at relative rate
        ``-adr`` while the inflow into ``asym`` equals ``p['inc']``.
        """
        r_die, r_sc = p['r_die'], p['r_sc']
        mu = r_die + r_sc
        r_onset, r_csi, r_det = p['r_onset'], p['r_csi'], p['r_det']
        adr = p['adr']

        asym = p['inc'] / (r_onset + r_sc - adr)
        sym = r_onset * asym / (r_csi + mu - adr)
        ex = r_csi * sym / (r_det + mu - adr)
        return np.array([asym, sym, ex])

    def __call__(self, t, y, p):
        """Return ``d[asym, sym, ex]/dt`` for state ``y`` and parameters ``p``.

        ``t`` is accepted for the ``solve_ivp`` callback signature but is
        not used: the system is autonomous.
        """
        # Force a float array: the derivative buffer below takes its dtype
        # from y, and an integer state would silently truncate the result.
        # This also lets list/tuple states work.
        y = np.asarray(y, dtype=float)

        r_die, r_sc = p['r_die'], p['r_sc']
        mu = r_die + r_sc
        r_onset, r_csi, r_det = p['r_onset'], p['r_csi'], p['r_det']

        asym, sym, ex = y
        inc = self._incidence(y, p)

        dy = np.empty_like(y)
        dy[0] = inc - (r_onset + r_sc) * asym
        dy[1] = r_onset * asym - (r_csi + mu) * sym
        dy[2] = r_csi * sym - (r_det + mu) * ex
        return dy

    def measure(self, t, y, p):
        """Return the output record for state ``y`` at time ``t``.

        Keys: 'Time' (``t``), 'Inc' (inflow into ``asym``), 'CNR'
        (``r_det * ex``), 'Prev' (total ``n`` of the three compartments)
        and 'PrA' / 'PrS' / 'PrC' (shares of ``asym`` / ``sym`` / ``ex``
        in ``n``).
        """
        y = np.asarray(y, dtype=float)
        asym, sym, ex = y
        n = y.sum()

        return {
            'Time': t,
            'Inc': self._incidence(y, p),
            'CNR': p['r_det'] * ex,
            'Prev': n,
            'PrA': asym / n,
            'PrS': sym / n,
            'PrC': ex / n
        }
    

In [45]:
# Instantiate the model and define the parameter set used for this run.
mp = ModelPlain()

pars = dict(
    inc=200,      # inflow used by get_y0 to scale the initial state
    r_die=0.1,
    r_sc=0.2,
    r_onset=3,
    r_csi=2,
    r_det=2,
    adr=0.01,     # relative rate at which the total declines
)

# Initial state on its own, for inspection; simulate() recomputes it internally.
y0 = mp.get_y0(pars)

# Integrate over the default horizon and display the measurement table.
ys, ms = mp.simulate(pars)
ms

Unnamed: 0,Time,Inc,CNR,Prev,PrA,PrS,PrC
0,0.0,200.0,143.4662,216.563424,0.289504,0.379263,0.331234
1,1.0,198.009967,142.038687,214.408582,0.289504,0.379263,0.331234
2,2.0,196.039735,140.625378,212.275181,0.289504,0.379263,0.331234
3,3.0,194.089108,139.226134,210.163007,0.289504,0.379263,0.331234
4,4.0,192.157908,137.840831,208.07185,0.289504,0.379262,0.331234
5,5.0,190.245922,136.46931,206.001501,0.289504,0.379262,0.331234
6,6.0,188.352937,135.111411,203.951752,0.289504,0.379262,0.331234
7,7.0,186.478763,133.766997,201.922398,0.289504,0.379263,0.331234
8,8.0,184.623227,132.435949,199.913236,0.289504,0.379263,0.331234
9,9.0,182.786176,131.118168,197.924067,0.289504,0.379263,0.331234


array([-0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01,
       -0.01])