In [None]:
import time, os, sys
from pyDOE import lhs
from pathlib import Path

In [None]:
# local settings
# Repository root; assumes the notebook is started from a direct subdirectory
# of the repository (e.g. notebooks/) — TODO confirm
basepath = Path().absolute().parent

# Add repository's src folder to python path
# (must happen before the two imports below, which load modules from src/)
sys.path.append(str(basepath.joinpath('src')))
# NOTE(review): star imports; these presumably also provide np, pd, copy, hill,
# Parameters, ActivationModel, dfrow_to_pars and fit_pars_lsq used in later cells — verify.
from fit_mc import *
from tcell_model import *

# Experiment number and well type to analyse
exp = 257
well = 'F'
datapath = basepath.joinpath(f'expdata/kinetics_exp{exp}')
# Cell-type classification scheme; part of the counts file name
ct = 'U-D-CTLA4'
# Cell types whose counts are fitted; the counts table has one 'cnt <celltype>' column per entry
celltypes = ['undivided', 'divided CTLA4-', 'divided CTLA4+']
# Tab-separated, gzip-compressed table of cell counts per line/well/cnt0/replicate/time
df_ct = pd.read_csv(datapath.joinpath(f'celltypes_{ct}.csv.gz'),
                    sep='\t', compression='gzip')

In [None]:
def get_interpolated_data(line, well, cnt0, data_dir=None):
    """Load the measured time courses used as model inputs for one condition.

    Parameters
    ----------
    line, well, cnt0
        Values used to select rows via the ``line``, ``well`` and ``cnt0`` columns.
        In the IL-2 table a ``well`` value of ``'Flat'`` is treated as ``'F'``.
    data_dir : pathlib.Path, optional
        Directory containing the input files; defaults to the global ``datapath``.

    Returns
    -------
    dict of pandas.Series
        Each series is indexed by time, with NaN entries dropped.
        ``'CD25u'``, ``'CD25neg'`` and ``'CD25pos'`` (likewise for CD80 and CD86)
        are mean intensities of cell types ``'U'``, ``'Dneg'`` and ``'Dpos'``.
        ``'IL2'`` is the mean ``'IL-2'`` value from ``processed_data.csv.gz``.
    """
    if data_dir is None:
        data_dir = datapath
    ct = 'U-D-CTLA4'
    data = {}

    # FCS data: mean marker intensity per time point and cell type
    fn_expr = data_dir.joinpath(f'intensity_for_celltypes_{ct}.csv.gz')
    df_expr = pd.read_csv(fn_expr, sep='\t', compression='gzip')
    sdf_expr = df_expr[(df_expr.line == line) & (df_expr.well == well) & (df_expr.cnt0 == cnt0)]
    for marker in ('CD25', 'CD80', 'CD86'):
        for suffix, celltype in (('u', 'U'), ('neg', 'Dneg'), ('pos', 'Dpos')):
            # Select the column *before* averaging: DataFrameGroupBy.mean() over
            # a frame with string columns raises TypeError in pandas >= 2.0.
            data[marker + suffix] = (sdf_expr[sdf_expr.celltype == celltype]
                                     .groupby('time')[marker].mean())

    # IL-2
    fn_pop = data_dir.joinpath('processed_data.csv.gz')
    df_pop = pd.read_csv(fn_pop, sep='\t', compression='gzip')
    df_pop.loc[df_pop.well == 'Flat', 'well'] = 'F'
    sdf_pop = df_pop[(df_pop.line == line) & (df_pop.well == well) & (df_pop.cnt0 == cnt0)]
    data['IL2'] = sdf_pop.groupby('time')['IL-2'].mean()

    # drop time points without a measurement
    return {key: series.dropna() for key, series in data.items()}


def get_data_for_fit(line, well, cnt0, replicate=None):
    """Return the measured cell counts that serve as fit targets for one condition.

    Rows of the global ``df_ct`` are selected by ``line``, ``cnt0`` and ``well``,
    and additionally by ``replicate`` when it is given.

    Returns
    -------
    T : ndarray
        Sorted unique time points of the selection, including t = 0 if present.
    y : ndarray
        Counts for t > 0, stacked per cell type in the order of ``celltypes``
        (all time points of the first type, then the second, ...).  Without
        ``replicate`` these are means over all matching rows per time point.
    yerr : ndarray
        Matching standard deviations (ddof=1) when averaging, or zeros when a
        single ``replicate`` is selected.
    """
    mask = (df_ct.line == line) & (df_ct.cnt0 == cnt0) & (df_ct.well == well)
    if replicate is not None:
        mask = mask & (df_ct.replicate == replicate)
    sdf = df_ct[mask].sort_values('time')
    T = sdf.time.unique()
    sdf = sdf[sdf.time > 0]
    cnt_cols = ['cnt ' + celltype for celltype in celltypes[:3]]
    if replicate is None:
        # Aggregate only the count columns: groupby().mean()/std() over the
        # string columns (line, well) raises TypeError in pandas >= 2.0.
        grouped = sdf.groupby('time')[cnt_cols]
        means, stds = grouped.mean(), grouped.std()
        y = np.concatenate([means[col].values for col in cnt_cols])
        yerr = np.concatenate([stds[col].values for col in cnt_cols])
    else:
        y = np.concatenate([sdf[col].values for col in cnt_cols])
        yerr = np.zeros_like(y)
    return T, y, yerr

In [None]:

def set_up_model_WT(p):
    """Build the ``ActivationModel`` for the 'WT' line from parameter set ``p``.

    Every rate is a closure over ``p`` taking the current model ``state``, so
    parameters are read when a rate is evaluated, not when the model is built
    (only ``delay`` is read immediately).  Unlike the DKO model, divided-cell
    growth has an extra CD28 term driven by the CD80/CD86 signal on Us, Dneg
    and Dpos cells, multiplied by ``1 - hill(Dpos, n_D_ctla4, k_D_ctla4)``.
    """
    def activation(state):
        return p.p_U

    def death_Uu(state):
        return p.d_Uu

    def death_Us(state):
        # death rate lowered by a CD25- and IL-2-dependent term, floored at 0
        rescue = p.f * state.CD25u * hill(state.IL2, p.n_U, p.k_U)
        return np.maximum(0, p.d_Us - rescue)

    def il2_drive(state):
        return p.p_D_cd25 * hill(state.IL2, p.n_D_il2, p.k_D_il2)

    def costim_signal(state):
        # CD80 + CD86 level (capped at 1) weighted by each population's size
        return (np.minimum(1, state.CD80u + state.CD86u) * state.Us
                + np.minimum(1, state.CD80neg + state.CD86neg) * state.Dneg
                + np.minimum(1, state.CD80pos + state.CD86pos) * state.Dpos)

    def cd28_drive(state):
        ctla4_factor = 1 - hill(state.Dpos, p.n_D_ctla4, p.k_D_ctla4)
        return p.p_D_cd28 * hill(costim_signal(state), p.n_D_cd28, p.k_D_cd28) * ctla4_factor

    def growth_Dneg(state):
        return p.p_D + np.minimum(state.CD25neg * il2_drive(state) + cd28_drive(state),
                                  p.p_D_cd25)

    def growth_Dpos(state):
        return p.p_D + np.minimum(state.CD25pos * il2_drive(state) + cd28_drive(state),
                                  p.p_D_cd25)

    def ctla4_on(state):
        return p.p_Dpos * state.CD25neg * hill(state.IL2, p.n_Dpos, p.k_Dpos)

    def ctla4_off(state):
        return p.p_Dneg

    return ActivationModel(activation, death_Uu, death_Us, growth_Dneg, growth_Dpos,
                           ctla4_on, ctla4_off, delay=p.delay)


def set_up_model_DKO(p):
    """Build the ``ActivationModel`` for the 'DKO' line from parameter set ``p``.

    Every rate is a closure over ``p`` taking the current model ``state``, so
    parameters are read when a rate is evaluated, not when the model is built
    (only ``delay`` is read immediately).  Divided-cell growth here has only the
    CD25/IL-2 term, with no CD28 contribution.
    """
    def activation(state):
        return p.p_U

    def death_Uu(state):
        return p.d_Uu

    def death_Us(state):
        # death rate lowered by a CD25- and IL-2-dependent term, floored at 0
        rescue = p.f * state.CD25u * hill(state.IL2, p.n_U, p.k_U)
        return np.maximum(0, p.d_Us - rescue)

    def il2_drive(state):
        return p.p_D_cd25 * hill(state.IL2, p.n_D_il2, p.k_D_il2)

    def growth_Dneg(state):
        return p.p_D + np.minimum(state.CD25neg * il2_drive(state),
                                  state.CD25neg * p.p_D_cd25)

    def growth_Dpos(state):
        return p.p_D + np.minimum(state.CD25pos * il2_drive(state),
                                  state.CD25pos * p.p_D_cd25)

    def ctla4_on(state):
        return p.p_Dpos * state.CD25neg * hill(state.IL2, p.n_Dpos, p.k_Dpos)

    def ctla4_off(state):
        return p.p_Dneg

    return ActivationModel(activation, death_Uu, death_Us, growth_Dneg, growth_Dpos,
                           ctla4_on, ctla4_off, delay=p.delay)


In [None]:
# Baseline parameter values; the entries listed in `parnames` are overwritten during the fit
baseline_values = dict(
    p_U=.1, d_U=.024, p_D=-.13, p_D_cd25=.2, p_D_cd28=.06, p_Dpos=1, p_Dneg=.1,
    k_D_ctla4=100000, k_D_cd28=9000, k_D_cd25=300, k_D_il2=300, k_Dpos=40,
    k_Dpos_CD25=4000, k_Dneg=1e-10,
    n_D_ctla4=2, n_D_cd28=2, n_D_cd25=5, n_D_il2=5, n_Dpos=2, n_Dpos_CD25=3,
    delay=28, U_0=0, n_U=5, f_unsens=0,
)
pars_def = Parameters(**baseline_values)
# Replace the baseline by the stored best fit for the selected well
# (dfrow_to_pars presumably keeps the baseline for anything the table lacks — verify)
df_pars = pd.read_csv(basepath / 'results' / 'best_fits.csv', sep='\t')
well_fit_rows = df_pars.loc[df_pars['well'] == well]
pars_def = dfrow_to_pars(well_fit_rows.T, pars_def)


In [None]:
# function that runs the models included in the fit
def run_models(*args):
    """Residual function for the fit: simulate DKO and WT for each seeding density.

    ``args[0]`` holds trial values for the parameters named in the global
    ``parnames`` (same order).  They are applied to a deep copy of ``pars_def``,
    so the defaults are never modified.  Further positional arguments are ignored.

    Returns a 1-D array of residuals ``(data - model) / U_0``.  The array is
    concatenated over seeding densities (25000, 50000, 100000) and, within each,
    over the lines DKO then WT.  If any simulation raises an ``Exception``, every
    entry is set to 1e18: the optimizer is steered away from that parameter
    region while the residual length stays fixed.  ``KeyboardInterrupt`` and
    errors while loading data are not swallowed.
    """
    import copy  # explicit instead of relying on the star imports to provide it

    penalty = 1e18
    pars = copy.deepcopy(pars_def)
    for i, pn in enumerate(parnames):
        setattr(pars, pn, args[0][i])

    success = True
    out = np.empty(0)
    for cnt0 in [25000, 50000, 100000]:
        models = {'DKO': set_up_model_DKO(pars), 'WT': set_up_model_WT(pars)}
        # the models' rate functions read `pars` lazily, so they see this U_0
        pars.U_0 = cnt0
        y0 = [pars.f_unsens * pars.U_0, (1 - pars.f_unsens) * pars.U_0,
              pars.Dneg_0, pars.Dpos_0]
        for line, model in models.items():
            x, yref, _ = get_data_for_fit(line, well, cnt0)
            if success:
                # data loading does not depend on the parameters: let its errors surface
                external = get_interpolated_data(line, well, cnt0)
                try:
                    model.run(y0, x, external, method=solver)
                except Exception:
                    # a failed simulation (e.g. solver error) is penalized, not fatal
                    success = False
            if success:
                # drop the first (t = 0) model point: the fit targets exclude t = 0
                y = np.concatenate((model.state.Us[1:] + model.state.Uu[1:],
                                    model.state.Dneg[1:], model.state.Dpos[1:]))
                resid = yref - y
            else:
                # keep computing the residual length so the output size stays fixed
                resid = np.full(np.shape(yref), penalty)
            out = np.concatenate((out, resid / pars.U_0))
    if not success:
        out = np.full(out.shape, penalty)
    return out

In [None]:
# fitting settings
solver = 'BDF'
nsamp = 2
ncores = 1
fit_id = 'fit_run00'
parnames = ['delay', 'p_U', 'd_Us']
# Initial-guess strategy: the paper used Latin hypercube sampling (use_lhs = True).
# To get this example running quickly we start close to the stored best fit instead.
use_lhs = False
seed = 0  # fixes the initial guesses so re-running the notebook gives the same starts

# parameter fitting limits
p_min = [0, 0, 0]
p_max = [50, 1, 1]
# default values
p_def = np.array([getattr(pars_def, pn) for pn in parnames])
# limits for initial guesses (only used with LHS)
p0_min = np.array([0, 0, 0])
p0_max = np.array([50, 1, 1])

if use_lhs:
    np.random.seed(seed)  # lhs presumably draws from numpy's global RNG
    lhd = lhs(len(parnames), samples=nsamp)
    p0 = p0_min + (p0_max - p0_min) * lhd
else:
    # 0.1 % multiplicative noise around the stored best fit, one row per start
    rng = np.random.default_rng(seed)
    p0 = p_def * rng.normal(1, 1e-3, size=(nsamp, len(parnames)))
# Clip initial guesses to the fit bounds: noise on a parameter that sits on a
# bound would otherwise give a start outside the feasible region.
p0_list = list(np.clip(p0, p_min, p_max))

# run fit - this may take a while!
best = fit_pars_lsq(run_models, p0_list, parnames, bounds=(p_min, p_max), ncores=ncores,
                    fn_log=None)