In [1]:
import pymc as pm
import pandas as pd
import numpy as np
import numpy.random as npr
import matplotlib.pyplot as plt
import arviz as az



In [2]:
data_dir = '.'
# Cell-type count columns, one row per condition (CSV index). Selected by name rather
# than by position (formerly columns 7-11) so a change in the CSV's column layout cannot
# silently feed the wrong columns into the models below.
CELL_TYPES = ['cycling', 'effector', 'other', 'progenitor', 'terminal exhausted']
cts = pd.read_csv(f'{data_dir}/cells_per_ko.csv', index_col=0)[CELL_TYPES]

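## Dirichlet-multinomial

Fit on the raw counts, so each sample's multinomial total `n` is its own number of cells (samples with more cells carry more weight in the likelihood).

Note: all unperturbed samples share a single proportion vector with a Dirichlet prior. This is a multinomial likelihood with a Dirichlet prior, not a per-sample Dirichlet-multinomial, so between-sample overdispersion is not modelled.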

In [44]:
# Control rows are labelled 'Unperturbed'; every other condition is a perturbation.
is_control = cts.index == 'Unperturbed'
cts_unperturbed = cts.loc[is_control]
cts_perturbed = cts.loc[~is_control]
cts_perturbed.head(3)

Unnamed: 0_level_0,cycling,effector,other,progenitor,terminal exhausted
condition,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
Tox2,2300.0,247.0,117.0,75.0,1594.0
Arid5b,1234.0,107.0,42.0,27.0,992.0
Dvl2,1194.0,251.0,14.0,53.0,661.0


In [None]:
# Fit dirichlet - multinomial with unperturbed samples.
# Model: one proportion vector shared by all control samples, with a symmetric
# Dirichlet(5, ..., 5) prior, and a multinomial likelihood whose total `n` is each
# sample's own number of cells (row sum of the raw counts).
k = cts_unperturbed.shape[1]  # number of cell-type categories
SAMPLING_SEED = 42  # makes pm.sample reproducible under Restart & Run All

with pm.Model(coords={"cell_type": cts_unperturbed.columns.values}) as dirMulti_model:
    # Registered under the names "cts" / "n" so pm.set_data can later swap in the
    # perturbed samples. Bound to *_data (not `cts` / `n`) so the raw-count DataFrame
    # `cts` is not shadowed by a tensor, which would break re-running earlier cells.
    cts_data = pm.Data("cts", cts_unperturbed.to_numpy(), mutable=True)
    n_data = pm.Data("n", cts_unperturbed.sum(axis=1).to_numpy(), mutable=True)
    # Dirichlet prior for p. Its length comes from the "cell_type" coord (k entries)
    # instead of a hard-coded 5; sampling starts at the centre of the simplex (1/k each),
    # which is a valid point on the simplex.
    proportions = pm.Dirichlet(
        'proportions',
        a=np.full(k, 5.0),
        initval=np.full(k, 1.0 / k),
        dims="cell_type",
    )
    # Likelihood (sampling distribution) of the observed per-sample counts
    counts = pm.Multinomial(
        'counts',
        n=n_data,
        p=proportions,
        observed=cts_data,
    )
    dirMulti_trace = pm.sample(
        draws=2000, chains=4, random_seed=SAMPLING_SEED, return_inferencedata=True
    )

az.summary(dirMulti_trace, round_to=2)

In [None]:
# Save the forest plot and the trace plot of the raw-count fit, in that order.
# Both ArviZ calls return an array of axes; the first axis gives the owning figure.
for plot_fn, png_path in (
    (az.plot_forest, 'dirMult_forest_plot.png'),
    (az.plot_trace, 'dirMult_trace_plot.png'),
):
    axes = plot_fn(dirMulti_trace)
    axes.ravel()[0].figure.savefig(png_path)

In [None]:
# Sample from the posterior predictive distribution for perturbed samples: swap the
# perturbed counts (and their per-sample totals) into the model fitted on the
# unperturbed samples, then simulate replicated counts from that posterior.
# NOTE: pm.set_data mutates dirMulti_model in place; re-run the fitting cell to restore
# the unperturbed data.
PPC_SEED = 42  # makes the posterior predictive draws reproducible
pm.set_data(
    {"cts": cts_perturbed.to_numpy(), "n": cts_perturbed.sum(axis=1).to_numpy()},
    model=dirMulti_model,
)
ppc_test = pm.sample_posterior_predictive(dirMulti_trace, model=dirMulti_model, random_seed=PPC_SEED)

# For each (perturbed sample, cell type): the smaller of the two posterior predictive
# tail probabilities P(pred <= obs) and P(pred >= obs), pooled over all chains and draws.
# Small values mean the observed count is extreme under the unperturbed model.
observed = ppc_test['observed_data']['counts'].values          # (sample, cell type)
predicted = ppc_test['posterior_predictive']['counts'].values  # (chain, draw, sample, cell type)
p_obs_geq = (observed >= predicted).mean(axis=(0, 1))
p_obs_leq = (observed <= predicted).mean(axis=(0, 1))
tail_prob = np.minimum(p_obs_geq, p_obs_leq)

# Number of cell types per perturbed condition whose count is extreme (tail prob < 0.05).
# NOTE(review): a cutoff of 0.05 on the smaller tail is roughly a two-sided test at level
# 0.10; compare 2 * tail_prob against 0.05 if a nominal 5% two-sided level is intended.
pd.Series(
    (tail_prob < 0.05).sum(axis=1),
    index=cts_perturbed.index,
    name='n_extreme_cell_types',
)

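## Dirichlet-multinomial on normalised counts

Ignores the different number of cells per sample. Counts are first rescaled to rounded percentages, so every sample contributes roughly 100 pseudo-cells regardless of how many cells it actually had.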

In [38]:
# Rescale every sample to rounded percentages so each contributes ~100 pseudo-cells,
# independent of how many cells it actually had. Rounding keeps the values whole, as the
# multinomial likelihood needs; row totals can therefore differ slightly from 100.
#
# The raw counts are re-read from disk so this cell is idempotent. This cell rebinds
# cts_unperturbed / cts_perturbed to percentages, so normalising those names in place
# would re-normalise (and re-round) already-normalised data on every re-run. The global
# `cts` is not reused because the model cell may have rebound it to a pm.Data container.
cell_type_columns = ['cycling', 'effector', 'other', 'progenitor', 'terminal exhausted']
raw_counts = pd.read_csv(f'{data_dir}/cells_per_ko.csv', index_col=0)[cell_type_columns]
is_unperturbed = raw_counts.index == 'Unperturbed'


def to_rounded_percent(counts: pd.DataFrame) -> pd.DataFrame:
    """Return `counts` with each row rescaled to sum to 100, rounded to whole numbers."""
    return (counts.div(counts.sum(axis=1), axis=0) * 100).round()


cts_unperturbed = to_rounded_percent(raw_counts.loc[is_unperturbed])
cts_perturbed = to_rounded_percent(raw_counts.loc[~is_unperturbed])
cts_perturbed.head(3)

Unnamed: 0_level_0,cycling,effector,other,progenitor,terminal exhausted
condition,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
Tox2,0.53081,0.057004,0.027002,0.017309,0.367874
Arid5b,0.513739,0.044546,0.017485,0.011241,0.412989
Dvl2,0.549471,0.115509,0.006443,0.02439,0.304188


In [None]:
# Fit dirichlet - multinomial with unperturbed samples (normalised counts).
# Model: one proportion vector shared by all control samples, with a symmetric
# Dirichlet(5, ..., 5) prior, and a multinomial likelihood whose total `n` is each
# sample's row total (about 100 after the percentage rescaling).
k = cts_unperturbed.shape[1]  # number of cell-type categories
SAMPLING_SEED = 42  # makes pm.sample reproducible under Restart & Run All

with pm.Model(coords={"cell_type": cts_unperturbed.columns.values}) as dirMulti_model:
    # Registered under the names "cts" / "n" so pm.set_data can later swap in the
    # perturbed samples. Bound to *_data (not `cts` / `n`) so the raw-count DataFrame
    # `cts` is not shadowed by a tensor, which would break re-running earlier cells.
    cts_data = pm.Data("cts", cts_unperturbed.to_numpy(), mutable=True)
    n_data = pm.Data("n", cts_unperturbed.sum(axis=1).to_numpy(), mutable=True)
    # Dirichlet prior for p. Its length comes from the "cell_type" coord (k entries)
    # instead of a hard-coded 5; sampling starts at the centre of the simplex (1/k each),
    # which is a valid point on the simplex.
    proportions = pm.Dirichlet(
        'proportions',
        a=np.full(k, 5.0),
        initval=np.full(k, 1.0 / k),
        dims="cell_type",
    )
    # Likelihood (sampling distribution) of the observed per-sample counts
    counts = pm.Multinomial(
        'counts',
        n=n_data,
        p=proportions,
        observed=cts_data,
    )
    dirMulti_trace = pm.sample(
        draws=2000, chains=4, random_seed=SAMPLING_SEED, return_inferencedata=True
    )

az.summary(dirMulti_trace, round_to=2)

In [None]:
# Save the forest plot and the trace plot of the normalised-count fit, in that order.
# Both ArviZ calls return an array of axes; the first axis gives the owning figure.
for plot_fn, png_path in (
    (az.plot_forest, 'dirMult_normalized_forest_plot.png'),
    (az.plot_trace, 'dirMult_normalized_trace_plot.png'),
):
    axes = plot_fn(dirMulti_trace)
    axes.ravel()[0].figure.savefig(png_path)

In [None]:
# Sample from the posterior predictive distribution for perturbed samples: swap the
# (normalised) perturbed counts and their per-sample totals into the model fitted on the
# unperturbed samples, then simulate replicated counts from that posterior.
# NOTE: pm.set_data mutates dirMulti_model in place; re-run the fitting cell to restore
# the unperturbed data.
PPC_SEED = 42  # makes the posterior predictive draws reproducible
pm.set_data(
    {"cts": cts_perturbed.to_numpy(), "n": cts_perturbed.sum(axis=1).to_numpy()},
    model=dirMulti_model,
)
ppc_test = pm.sample_posterior_predictive(dirMulti_trace, model=dirMulti_model, random_seed=PPC_SEED)

# For each (perturbed sample, cell type): the smaller of the two posterior predictive
# tail probabilities P(pred <= obs) and P(pred >= obs), pooled over all chains and draws.
# Small values mean the observed count is extreme under the unperturbed model.
observed = ppc_test['observed_data']['counts'].values          # (sample, cell type)
predicted = ppc_test['posterior_predictive']['counts'].values  # (chain, draw, sample, cell type)
p_obs_geq = (observed >= predicted).mean(axis=(0, 1))
p_obs_leq = (observed <= predicted).mean(axis=(0, 1))
tail_prob = np.minimum(p_obs_geq, p_obs_leq)

# Number of cell types per perturbed condition whose count is extreme (tail prob < 0.05).
# NOTE(review): a cutoff of 0.05 on the smaller tail is roughly a two-sided test at level
# 0.10; compare 2 * tail_prob against 0.05 if a nominal 5% two-sided level is intended.
pd.Series(
    (tail_prob < 0.05).sum(axis=1),
    index=cts_perturbed.index,
    name='n_extreme_cell_types',
)