In [90]:
%load_ext autoreload
%autoreload 2

import os
import pymc as pm
import arviz as az
import numpy as np
import pandas as pd
import pytensor
from pytensor import tensor as T
from sklearn.preprocessing import scale, StandardScaler, LabelEncoder
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.cluster.hierarchy import linkage, leaves_list
import itertools
import pickle
import dill
from pyprojroot.here import here
import numpyro

numpyro.enable_x64()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Display the absolute path that pyprojroot's `here` resolves for this notebook.
here('submission/draft/survival_clustering.ipynb')

PosixPath('/Users/alzhang/Documents/projects/tfri_halo/submission/draft/survival_clustering.ipynb')

In [3]:
def create_trace_table(trace, export_variables = ('props', 'beta_clust', 'beta_stage', 'beta_age', 'beta_chemo', 'beta_rt', 'beta_brachy', 'beta_histotype')):
    """Flatten selected posterior variables of the first chain into one wide table.

    NOTE: this function is defined again further down (section "Output traces")
    with a default list that omits 'beta_histotype'; when the cells run in
    order, that later definition is the one in effect.

    Parameters
    ----------
    trace : object exposing ``trace.posterior[name]``
        Each entry is indexed with ``[0]`` to select the first chain; the
        result must expose ``ndim`` and ``shape``.
    export_variables : iterable of str
        Names of the posterior variables to export.

    Returns
    -------
    pandas.DataFrame
        One row per draw. Scalar variables become a single column named after
        the variable; vector variables become columns ``<variable>_<i>``.

    Raises
    ------
    ValueError
        If a variable has more than one non-draw dimension. Previously this
        case only printed a message and then appended a stale (or undefined)
        frame.
    """
    export_dfs = []

    for variable in export_variables:
        # First chain only: (draw,) for scalars, (draw, k) for vectors.
        posterior = trace.posterior[variable][0]

        if posterior.ndim == 2:
            column_names = [f'{variable}_{i}' for i in range(posterior.shape[1])]
        elif posterior.ndim == 1:
            column_names = [variable]
        else:
            raise ValueError(
                f"Cannot export '{variable}': expected 1 or 2 dimensions per chain, "
                f"got {posterior.ndim}."
            )

        export_dfs.append(pd.DataFrame(posterior, columns=column_names))

    # Concatenate the per-variable frames horizontally (column bind)
    return pd.concat(export_dfs, axis=1)

## Inputs

In [4]:
# Survival data in long format, keyed by outcome.
clinical_long_path = here('results/survival_cluster/clinical_long.tsv')
clinical_long = pd.read_csv(clinical_long_path, sep='\t')

# Count data in long format, keyed by TIL type and region.
counts_long_path = here('results/survival_cluster/counts_final.tsv')
counts_long = pd.read_csv(counts_long_path, sep='\t')

In [5]:
# Keep p53abn cases only, then drop every row that has any missing value.
is_p53abn = clinical_long['eclass2_ngs'] == 'p53abn'
clinical_long = clinical_long.loc[is_p53abn].dropna()

eclass_encoder = LabelEncoder()
stage_encoder = LabelEncoder()

# Derived covariates: integer codes for class and stage, a carcinosarcoma flag,
# and age at diagnosis standardised to zero mean / unit variance.
clinical_long = clinical_long.assign(
    eclass2_ngs_idx=eclass_encoder.fit_transform(clinical_long['eclass2_ngs']),
    stage_idx=stage_encoder.fit_transform(clinical_long['stage_main']),
    carcinosarcoma=(clinical_long['hist_rev'] == 'carcinosarcoma (MMMT)'),
    age_dx=scale(clinical_long['age_dx'], axis=0, with_mean=True, with_std=True, copy=True),
)

clinical_vars = [
    "acc_num", "outcome", "time", "status",
    "chemo", "rt", "brachy",
    "eclass2_ngs_idx", "stage_idx", "age_dx", "carcinosarcoma",
]

clinical_selected = clinical_long[clinical_vars].drop_duplicates()

In [6]:
# Key every measurement by "<variable>_<region>" so it can become a wide column.
counts_long['variable_region'] = counts_long['variable'] + '_' + counts_long['region']

# One row per sample, one column per variable/region; samples with any missing cell are dropped.
wide_layout = {'index': 'acc_num', 'columns': 'variable_region'}
counts_wide = counts_long.pivot(values='value', **wide_layout).dropna()
areas_wide = counts_long.pivot(values='area_region_mm', **wide_layout).dropna()

In [7]:
# Accepts counts and areas in wide format, clinical data in long format
# Returns clinical data, areas, and counts indexed identically
def get_inputs(counts, areas, clinical, outcome):
    """Align counts, areas and clinical rows for one survival outcome.

    Parameters
    ----------
    counts, areas : pandas.DataFrame
        Wide tables indexed by sample id.
    clinical : pandas.DataFrame
        Long table with at least the columns 'acc_num' and 'outcome'.
    outcome : str
        Value of ``clinical['outcome']`` to keep.

    Returns
    -------
    dict
        Keys 'clinical', 'areas' and 'counts'; all three frames are restricted
        to, and ordered by, the sorted sample ids present in every input.
    """
    clinical = clinical[clinical['outcome'] == outcome].set_index('acc_num')

    # Intersect all three indices. counts and areas are filtered independently
    # upstream, so a sample may exist in one but not the other; using only
    # counts/clinical here would make areas.loc[...] raise a KeyError.
    common_samples = np.intersect1d(
        np.intersect1d(counts.index, areas.index),
        clinical.index,
    )

    return {
        'clinical': clinical.loc[common_samples, :],
        'areas': areas.loc[common_samples, :],
        'counts': counts.loc[common_samples, :],
    }

In [8]:
# One input bundle per survival outcome. Each bundle carries an N x C counts table
# (N samples, C region*cell-type columns) with the corresponding N x C areas table.
# Note that this layout is DIFFERENT from previous.
os_inputs, pfs_inputs, dss_inputs = (
    get_inputs(counts_wide, areas_wide, clinical_selected, outcome=outcome_name)
    for outcome_name in ('os', 'pfs', 'dss')
)

In [44]:
def fit_survcluster_model(inputs, nclusts = 2, ncenters = 20, interval_length = 0.3, epsilon = 1e-6, ndraw = 1000, ntune=1000, chains = 1, random_seed = None):
    """Fit a joint count-mixture / piecewise-constant-hazard survival model with PyMC.

    Every sample gets a latent cluster label. The label selects the per-column
    density of the negative-binomial count likelihood (scaled by area) and also
    enters the hazard's linear predictor together with stage, age, chemo, rt
    and brachy. Survival is modelled through the Poisson formulation of a
    piecewise-constant hazard over intervals of width ``interval_length``.

    Parameters
    ----------
    inputs : dict
        'counts' and 'areas': samples x columns tables of identical shape;
        'clinical': DataFrame with columns 'time', 'status', 'stage_idx',
        'age_dx', 'chemo', 'rt', 'brachy'. All three must share row order.
        'stage_idx' is used directly as an index into the stage coefficients.
    nclusts : int
        Number of latent clusters (the first is the reference level).
    ncenters : int
        Number of RBF centres used for the count-dependent NB dispersion.
    interval_length : float
        Width of each baseline-hazard interval, in the units of 'time'.
    epsilon : float
        Small constant added to areas and to the dispersion to keep them positive.
    ndraw, ntune : int
        Posterior draws and tuning steps per chain.
    chains : int
        Number of chains passed to ``pm.sample``. The default of 1 keeps the
        previous behaviour; note that r_hat cannot be computed from one chain.
    random_seed : int or None
        Seed passed to ``pm.sample``; ``None`` keeps the previous unseeded behaviour.

    Returns
    -------
    dict
        ``{'trace': <sampling result>, 'model': <pm.Model>}``.
    """
    count_mat = np.array(inputs['counts'])
    area_mat = np.array(inputs['areas'])
    clinical_df = inputs['clinical']
    time = clinical_df['time'].values
    event = clinical_df['status'].values
    stage = clinical_df['stage_idx'].values
    age = clinical_df['age_dx'].values
    chemo = clinical_df['chemo'].values.astype(int)
    rt = clinical_df['rt'].values.astype(int)
    brachy = clinical_df['brachy'].values.astype(int)

    nstages = len(np.unique(stage))
    nsamples, ncelltypes = count_mat.shape

    # Pooled density (total count / total area) per column, repeated for every
    # cluster; used as the prior mean of the cluster-specific densities.
    mean_mu = np.sum(count_mat, axis = 0)/np.sum(area_mat, axis = 0)
    mean_mu = np.repeat(mean_mu[np.newaxis,:], nclusts, axis=0)

    area_nonzero = area_mat + epsilon

    # RBF centres evenly spaced between 0 and the largest observed count.
    rbf_step = np.max(count_mat)/(ncenters-1.)
    centers = np.arange(ncenters) * rbf_step

    # --- Discretise follow-up time into intervals ---
    samples = np.arange(nsamples)
    interval_bounds = np.arange(0, time.max() + interval_length + 1, interval_length)

    nintervals = interval_bounds.size - 1
    intervals = np.arange(nintervals)

    # Interval in which each sample's follow-up ends. The 0.01 offset places a
    # time that sits exactly on a boundary into the interval that ends there.
    # Clamp at 0 so that times below 0.01 do not produce index -1 (which would
    # silently address the last interval).
    last_period = np.floor((time - 0.01) / interval_length).astype(int)
    last_period = np.maximum(last_period, 0)

    death = np.zeros((nsamples, nintervals))
    death[samples, last_period] = event

    # Full exposure for every interval strictly before the last one, then the
    # remaining time in the last interval, so each row sums to the observed
    # time. Deriving the full-exposure mask from `time >= lower bound` instead
    # left a spurious full interval after `last_period` whenever a time fell
    # within 0.01 above a boundary.
    exposure = np.greater.outer(last_period, intervals) * interval_length
    exposure[samples, last_period] = time - interval_bounds[last_period]

    coords = {"intervals": intervals}

    with pm.Model(coords = coords) as survival_mixture_model:
        # Priors for survival coefficients; the first cluster and the first
        # stage are reference levels fixed at 0.
        beta_clust0 = pm.Normal("beta_clust0", mu=0, sigma=5, shape=nclusts-1)
        beta_clust = pm.Deterministic("beta_clust", pm.math.concatenate([[0], beta_clust0]))
        beta_stage0 = pm.Normal("beta_stage0", mu=0, sigma=5, shape=nstages-1)
        beta_stage = pm.Deterministic("beta_stage", pm.math.concatenate([[0], beta_stage0]))
        beta_age = pm.Normal("beta_age", mu=0, sigma=5)
        beta_chemo = pm.Normal("beta_chemo", mu = 0, sigma = 5)
        beta_rt = pm.Normal("beta_rt", mu = 0, sigma = 5)
        beta_brachy = pm.Normal("beta_brachy", mu = 0, sigma = 5)

        # Priors for count coefficient (density per cluster and column)
        mu_clust = pm.Gamma("mu_clust", mu = mean_mu, sigma = 100, shape = (nclusts, ncelltypes))

        # Latent categorical variable for 'clust'
        props = pm.Dirichlet('props', np.ones(nclusts))
        clust = pm.Categorical("clust", p=props, shape=nsamples)

        # Expected count = area * density of the sample's cluster
        count_mu = area_nonzero * mu_clust[clust]

        # RBF based NB dispersion: repeat count_mu along a new trailing axis of
        # length ncenters, evaluate the basis at every centre, and combine with
        # positive weights exp(theta_a).
        theta_a = pm.Normal("theta_a", mu = 0, sigma=1, shape=ncenters)
        theta_b = pm.Normal("theta_b", mu = 0, sigma=1, shape=ncenters)
        count_mu_rep = T.reshape(T.repeat(count_mu, repeats=ncenters), newshape=(nsamples, ncelltypes, ncenters))
        rbf_basis = T.exp(-T.exp(theta_b) * (count_mu_rep - centers[np.newaxis, np.newaxis, :])**2)
        count_disp = T.dot(rbf_basis, T.exp(theta_a)) + epsilon

        # Likelihood for count data
        count_v = pm.NegativeBinomial("count_v", mu=count_mu, alpha=count_disp, observed=count_mat)

        # Baseline hazard, one rate per interval (there is no separate
        # intercept in the linear predictor). Trailing values look like prior
        # settings that were tried -- TODO confirm.
        lambda0 = pm.Gamma("lambda0", 0.1, 0.1, dims = "intervals") # 0.01, 0.05, 0.1

        # Linear predictor for Cox PH model
        linear_predictor = (
            beta_clust[clust]
            + beta_stage[stage]
            + beta_age * age
            + beta_chemo * chemo
            + beta_rt * rt
            + beta_brachy * brachy
        )

        # Hazard per sample and interval, then expected event count given exposure.
        lambda_ = pm.Deterministic("lambda_", T.outer(T.exp(linear_predictor), lambda0))

        mu = pm.Deterministic("mu", exposure * lambda_)

        # Poisson-Cox PH equivalence
        obs = pm.Poisson("obs", mu=mu, observed=death)

        survival_mixture_model.debug(verbose = True)

        trace = pm.sample(ndraw, tune=ntune, chains=chains, random_seed=random_seed, progressbar=True)

        return {'trace': trace, 'model': survival_mixture_model}

In [55]:
# Fit the two-cluster model on the 'os' outcome inputs.
os_outputs = fit_survcluster_model(
    os_inputs,
    nclusts=2,
    ncenters=20,
    interval_length=0.3,
    epsilon=1e-6,
    ndraw=1000,
    ntune=1000,
)

point={'beta_clust0': array([0.]), 'beta_stage0': array([0., 0., 0.]), 'beta_age': array(0.), 'beta_chemo': array(0.), 'beta_rt': array(0.), 'beta_brachy': array(0.), 'mu_clust_log__': array([[5.74134442, 2.3838835 , 4.46843508, 1.23147393, 2.59370369,
        1.00939819, 5.61453408, 4.55177006, 2.46390665, 1.13416853,
        5.53723606, 3.99462307, 3.63119581, 2.00995528, 4.84529417,
        2.38146401, 5.03593986, 2.95538798, 4.79499169, 3.24137003,
        6.13982856, 4.427772  , 6.11182582, 3.85459295, 5.21345766,
        2.7528743 ],
       [5.74134442, 2.3838835 , 4.46843508, 1.23147393, 2.59370369,
        1.00939819, 5.61453408, 4.55177006, 2.46390665, 1.13416853,
        5.53723606, 3.99462307, 3.63119581, 2.00995528, 4.84529417,
        2.38146401, 5.03593986, 2.95538798, 4.79499169, 3.24137003,
        6.13982856, 4.427772  , 6.11182582, 3.85459295, 5.21345766,
        2.7528743 ]]), 'props_simplex__': array([0.]), 'clust': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

Sequential sampling (1 chains in 1 job)
CompoundStep
>NUTS: [beta_clust0, beta_stage0, beta_age, beta_chemo, beta_rt, beta_brachy, mu_clust, props, theta_a, theta_b, lambda0]
>BinaryGibbsMetropolis: [clust]


Sampling 1 chain for 1_000 tune and 1_000 draw iterations (1_000 + 1_000 draws total) took 622 seconds.
Only one chain was sampled, this makes it impossible to run some convergence checks


In [56]:
# Posterior summary of mixture proportions and hazard coefficients for the 'os' fit.
pm.summary(
    os_outputs['trace'],
    var_names=[
        "props", "beta_clust", "beta_stage", "beta_age",
        "beta_chemo", "beta_rt", "beta_brachy",
    ],
)



Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
props[0],0.509,0.038,0.439,0.577,0.004,0.003,91.0,541.0,
props[1],0.491,0.038,0.423,0.561,0.004,0.003,91.0,541.0,
beta_clust[0],0.0,0.0,0.0,0.0,0.0,0.0,1000.0,1000.0,
beta_clust[1],-0.447,0.253,-0.945,0.017,0.008,0.006,956.0,733.0,
beta_stage[0],0.0,0.0,0.0,0.0,0.0,0.0,1000.0,1000.0,
beta_stage[1],0.625,0.694,-0.609,1.938,0.024,0.018,956.0,613.0,
beta_stage[2],1.063,0.333,0.397,1.642,0.014,0.011,568.0,581.0,
beta_stage[3],2.894,0.44,2.069,3.745,0.023,0.017,360.0,446.0,
beta_age,0.205,0.127,-0.017,0.454,0.004,0.003,1042.0,588.0,
beta_chemo,-0.872,0.33,-1.458,-0.25,0.013,0.009,641.0,566.0,


In [51]:
# Fit the two-cluster model on the 'pfs' outcome inputs.
pfs_outputs = fit_survcluster_model(
    pfs_inputs,
    nclusts=2,
    ncenters=20,
    interval_length=0.3,
    epsilon=1e-6,
    ndraw=1000,
    ntune=1000,
)

point={'beta_clust0': array([0.]), 'beta_stage0': array([0., 0., 0.]), 'beta_age': array(0.), 'beta_chemo': array(0.), 'beta_rt': array(0.), 'beta_brachy': array(0.), 'mu_clust_log__': array([[5.68196172, 2.24199047, 4.48661082, 1.18186252, 2.61167384,
        1.02248025, 5.59179312, 4.53887048, 2.45747758, 1.10592862,
        5.5459417 , 4.01413411, 3.62124397, 2.01490522, 4.83820893,
        2.37418687, 5.05412827, 2.97242243, 4.77933392, 3.23039384,
        6.15899309, 4.37556219, 6.12371957, 3.82459632, 5.22114322,
        2.75491914],
       [5.68196172, 2.24199047, 4.48661082, 1.18186252, 2.61167384,
        1.02248025, 5.59179312, 4.53887048, 2.45747758, 1.10592862,
        5.5459417 , 4.01413411, 3.62124397, 2.01490522, 4.83820893,
        2.37418687, 5.05412827, 2.97242243, 4.77933392, 3.23039384,
        6.15899309, 4.37556219, 6.12371957, 3.82459632, 5.22114322,
        2.75491914]]), 'props_simplex__': array([0.]), 'clust': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

Sequential sampling (1 chains in 1 job)
CompoundStep
>NUTS: [beta_clust0, beta_stage0, beta_age, beta_chemo, beta_rt, beta_brachy, mu_clust, props, theta_a, theta_b, lambda0]
>BinaryGibbsMetropolis: [clust]


Sampling 1 chain for 1_000 tune and 1_000 draw iterations (1_000 + 1_000 draws total) took 584 seconds.
Only one chain was sampled, this makes it impossible to run some convergence checks


In [53]:
# Posterior summary of mixture proportions and hazard coefficients for the 'pfs' fit.
pm.summary(
    pfs_outputs['trace'],
    var_names=[
        "props", "beta_clust", "beta_stage", "beta_age",
        "beta_chemo", "beta_rt", "beta_brachy",
    ],
)



Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
props[0],0.494,0.036,0.427,0.562,0.001,0.001,787.0,739.0,
props[1],0.506,0.036,0.438,0.573,0.001,0.001,787.0,739.0,
beta_clust[0],0.0,0.0,0.0,0.0,0.0,0.0,1000.0,1000.0,
beta_clust[1],-0.43,0.235,-0.887,0.004,0.008,0.006,877.0,771.0,
beta_stage[0],0.0,0.0,0.0,0.0,0.0,0.0,1000.0,1000.0,
beta_stage[1],-0.215,0.833,-1.823,1.182,0.027,0.028,1111.0,444.0,
beta_stage[2],1.509,0.314,0.944,2.091,0.012,0.009,749.0,601.0,
beta_stage[3],2.865,0.43,2.071,3.66,0.018,0.013,569.0,474.0,
beta_age,0.013,0.124,-0.236,0.219,0.004,0.004,1115.0,680.0,
beta_chemo,-0.47,0.346,-1.083,0.196,0.014,0.01,649.0,578.0,


In [59]:
# Fit the two-cluster model on the 'dss' outcome inputs.
dss_outputs = fit_survcluster_model(
    dss_inputs,
    nclusts=2,
    ncenters=20,
    interval_length=0.3,
    epsilon=1e-6,
    ndraw=1000,
    ntune=1000,
)

point={'beta_clust0': array([0.]), 'beta_stage0': array([0., 0., 0.]), 'beta_age': array(0.), 'beta_chemo': array(0.), 'beta_rt': array(0.), 'beta_brachy': array(0.), 'mu_clust_log__': array([[5.75870179, 2.31838358, 4.46375287, 1.24355049, 2.60599007,
        1.03637965, 5.60291106, 4.55877398, 2.46794208, 1.15072592,
        5.54594424, 4.0232128 , 3.6484534 , 2.04259638, 4.85501928,
        2.39959025, 5.05330594, 2.98003355, 4.80937621, 3.27254467,
        6.14341007, 4.42421188, 6.11291037, 3.83479628, 5.2219894 ,
        2.75504639],
       [5.75870179, 2.31838358, 4.46375287, 1.24355049, 2.60599007,
        1.03637965, 5.60291106, 4.55877398, 2.46794208, 1.15072592,
        5.54594424, 4.0232128 , 3.6484534 , 2.04259638, 4.85501928,
        2.39959025, 5.05330594, 2.98003355, 4.80937621, 3.27254467,
        6.14341007, 4.42421188, 6.11291037, 3.83479628, 5.2219894 ,
        2.75504639]]), 'props_simplex__': array([0.]), 'clust': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

Sequential sampling (1 chains in 1 job)
CompoundStep
>NUTS: [beta_clust0, beta_stage0, beta_age, beta_chemo, beta_rt, beta_brachy, mu_clust, props, theta_a, theta_b, lambda0]
>BinaryGibbsMetropolis: [clust]


Sampling 1 chain for 1_000 tune and 1_000 draw iterations (1_000 + 1_000 draws total) took 755 seconds.
Only one chain was sampled, this makes it impossible to run some convergence checks


In [60]:
# Posterior summary of mixture proportions and hazard coefficients for the 'dss' fit.
pm.summary(
    dss_outputs['trace'],
    var_names=[
        "props", "beta_clust", "beta_stage", "beta_age",
        "beta_chemo", "beta_rt", "beta_brachy",
    ],
)



Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
props[0],0.513,0.038,0.44,0.58,0.002,0.001,551.0,559.0,
props[1],0.487,0.038,0.42,0.56,0.002,0.001,551.0,559.0,
beta_clust[0],0.0,0.0,0.0,0.0,0.0,0.0,1000.0,1000.0,
beta_clust[1],-0.481,0.287,-0.97,0.124,0.01,0.008,865.0,628.0,
beta_stage[0],0.0,0.0,0.0,0.0,0.0,0.0,1000.0,1000.0,
beta_stage[1],0.099,1.364,-2.207,2.542,0.052,0.051,806.0,476.0,
beta_stage[2],1.618,0.399,0.865,2.297,0.016,0.012,618.0,577.0,
beta_stage[3],3.449,0.54,2.494,4.444,0.024,0.017,521.0,641.0,
beta_age,0.093,0.14,-0.168,0.365,0.005,0.004,849.0,618.0,
beta_chemo,-0.991,0.402,-1.814,-0.305,0.015,0.011,739.0,569.0,


## Output traces

In [81]:
def create_trace_table(trace, export_variables = ('props', 'beta_clust', 'beta_stage', 'beta_age', 'beta_chemo', 'beta_rt', 'beta_brachy')):
    """Flatten selected posterior variables of the first chain into one wide table.

    Parameters
    ----------
    trace : object exposing ``trace.posterior[name]``
        Each entry is indexed with ``[0]`` to select the first chain; the
        result must expose ``ndim`` and ``shape``.
    export_variables : iterable of str
        Names of the posterior variables to export.

    Returns
    -------
    pandas.DataFrame
        One row per draw. Scalar variables become a single column named after
        the variable; vector variables become columns ``<variable>_<i>``.

    Raises
    ------
    ValueError
        If a variable has more than one non-draw dimension. Previously this
        case only printed a message and then appended a stale (or undefined)
        frame.
    """
    export_dfs = []

    for variable in export_variables:
        # First chain only: (draw,) for scalars, (draw, k) for vectors.
        posterior = trace.posterior[variable][0]

        if posterior.ndim == 2:
            column_names = [f'{variable}_{i}' for i in range(posterior.shape[1])]
        elif posterior.ndim == 1:
            column_names = [variable]
        else:
            raise ValueError(
                f"Cannot export '{variable}': expected 1 or 2 dimensions per chain, "
                f"got {posterior.ndim}."
            )

        export_dfs.append(pd.DataFrame(posterior, columns=column_names))

    # Concatenate the per-variable frames horizontally (column bind)
    return pd.concat(export_dfs, axis=1)

def extract_cluster_assignments(trace, sample_names):
    """Return the first chain's sampled 'clust' labels as a samples x draws table.

    The posterior 'clust' array of the first chain is transposed so that each
    row is a sample, rows are labelled with ``sample_names``, and the labels are
    then moved into a regular column by ``reset_index``.
    """
    first_chain_labels = trace.posterior['clust'][0]
    per_sample = pd.DataFrame(first_chain_labels.T, index=sample_names)
    return per_sample.reset_index()

In [83]:
# Coefficient tables, one per outcome.
os_trace_table, pfs_trace_table, dss_trace_table = (
    create_trace_table(fit['trace'])
    for fit in (os_outputs, pfs_outputs, dss_outputs)
)

# Per-sample cluster draws, labelled with the sample ids of the matching inputs.
os_cluster_assignments, pfs_cluster_assignments, dss_cluster_assignments = (
    extract_cluster_assignments(fit['trace'], sample_names=fit_inputs['counts'].index)
    for fit, fit_inputs in (
        (os_outputs, os_inputs),
        (pfs_outputs, pfs_inputs),
        (dss_outputs, dss_inputs),
    )
)

In [91]:
trace_output_dir = here('results/survival_cluster/traces')
#model_output_dir = here('results/survival_cluster/models')

# Create the output directory up front so a fresh checkout does not fail on the
# first write below.
os.makedirs(trace_output_dir, exist_ok=True)

# One entry per outcome: (coefficient table, cluster assignments, fit outputs).
outcome_results = {
    'os': (os_trace_table, os_cluster_assignments, os_outputs),
    'pfs': (pfs_trace_table, pfs_cluster_assignments, pfs_outputs),
    'dss': (dss_trace_table, dss_cluster_assignments, dss_outputs),
}

for outcome_name, (trace_table, cluster_assignments, outputs) in outcome_results.items():
    trace_table.to_csv(os.path.join(trace_output_dir, f'{outcome_name}_trace_table.tsv'), sep='\t')
    cluster_assignments.to_csv(os.path.join(trace_output_dir, f'{outcome_name}_cluster_assignments.tsv'), sep='\t')

    # Output arviz inferencedata objects
    outputs['trace'].to_netcdf(os.path.join(trace_output_dir, f'{outcome_name}_results.nc'))

'/Users/alzhang/Documents/projects/tfri_halo/results/survival_cluster/traces/dss_results.nc'

In [41]:
# ## ATTEMPT TO MARGINALIZE OUT LATENT VARIABLE

# with pm.Model(coords = coords) as survival_mixture_model:
#     # Priors for survival coefficients
    
#     beta_clust0 = pm.Normal("beta_clust0", mu=0, sigma=5, shape=nclusts-1)
#     beta_clust = pm.Deterministic("beta_clust", pm.math.concatenate([[0], beta_clust0]))
#     beta_stage0 = pm.Normal("beta_stage0", mu=0, sigma=5, shape=nstages-1)
#     beta_stage = pm.Deterministic("beta_stage", pm.math.concatenate([[0], beta_stage0]))
#     beta_age = pm.Normal("beta_age", mu=0, sigma=5)
#     beta_chemo = pm.Normal("beta_chemo", mu = 0, sigma = 5)
#     beta_rt = pm.Normal("beta_rt", mu = 0, sigma = 5)
#     beta_brachy = pm.Normal("beta_brachy", mu = 0, sigma = 5)

#     # Priors for count coefficient
#     mu_clust = pm.Gamma("mu_clust", mu = mean_mu, sigma = 100, shape = (nclusts, ncelltypes))
    
#     # Latent categorical variable for 'clust'
#     props = pm.Dirichlet('props', np.ones(nclusts))

#     # NB distribution for count data using adjusted area and mu_clust; Nsamp x Ncelltype x Nclust
#     count_mu = np.repeat(area_nonzero[:,:,np.newaxis], nclusts, axis = 2) * mu_clust.transpose(1, 0)

#     # RBF based NB dispersion 
#     theta_a = pm.Normal("theta_a", mu = 0, sigma=1, shape=ncenters)
#     theta_b = pm.Normal("theta_b", mu = 0, sigma=1, shape=ncenters)

#     count_disp = T.dot(T.exp(-T.exp(theta_b) * (T.reshape(T.repeat(count_mu, repeats=ncenters), newshape=(nsamples, ncelltypes, nclusts, ncenters)) - centers)**2), T.exp(theta_a)) + epsilon
    
#     # Likelihood for count data
#     comp_countdists = pm.NegativeBinomial.dist(mu=count_mu, alpha=count_disp)
#     count_v = pm.Mixture("count_v", w=props, comp_dists=comp_countdists, observed=count_mat)
    
#     # Don't forget lambda_0 AFTERWARDS ADD THIS -- need to fix intercepts and dimensions as a result
#     lambda0 = pm.Gamma("lambda0", 0.1, 0.1, dims = "intervals") # 0.01, 0.05, 0.1
    
#     linear_predictor_init = (
#         beta_stage[stage]
#         + beta_age * age
#         + beta_chemo * chemo
#         + beta_rt * rt
#         + beta_brachy * brachy
#     )
#     linear_predictor = np.repeat(linear_predictor_init[:,np.newaxis], nclusts, axis=1) + beta_clust

#     lambda_ = pm.Deterministic("lambda_", T.reshape(T.repeat(T.exp(linear_predictor), repeats=nintervals), newshape=(nsamples,nclusts,nintervals)) * lambda0)

#     mu = pm.Deterministic("mu", np.repeat(exposure[:,:,np.newaxis], nclusts, axis=2) * lambda_.transpose(0, 2, 1))

#     # Poisson-Cox PH equivalence
#     comp_cox = pm.Poisson.dist(mu=mu)
    
#     obs = pm.Mixture('obs', w=props, comp_dists=comp_cox, observed=death)

#     #survival_mixture_model.debug(verbose = True)

#     trace2 = pm.sample(100, tune=100, chains=1, nuts_sampler="nutpie", progressbar=True) #nuts_sampler="numpyro", progressbar=True)

#     #trace = pm.sample(100, tune=100, chains=1)

Only 100 samples in chain.
  def numba_funcified_fgraph(scalar_variable, scalar_variable_3, scalar_variable_9, scalar_variable_20, scalar_variable_22, scalar_variable_26):
  return inner(x)


In [261]:
# with pm.Model(coords = coords) as recover_model:
#     beta_clust0 = pm.Normal("beta_clust0", mu=0, sigma=5, shape=nclusts-1)
#     beta_clust = pm.Deterministic("beta_clust", pm.math.concatenate([[0], beta_clust0]))
#     beta_stage0 = pm.Normal("beta_stage0", mu=0, sigma=5, shape=nstages-1)
#     beta_stage = pm.Deterministic("beta_stage", pm.math.concatenate([[0], beta_stage0]))
#     beta_age = pm.Normal("beta_age", mu=0, sigma=5)
#     beta_chemo = pm.Normal("beta_chemo", mu = 0, sigma = 5)
#     beta_rt = pm.Normal("beta_rt", mu = 0, sigma = 5)
#     beta_brachy = pm.Normal("beta_brachy", mu = 0, sigma = 5)

#     # Priors for count coefficient
#     mu_clust = pm.Gamma("mu_clust", mu = mean_mu, sigma = 100, shape = (nclusts, ncelltypes))
    
#     # Latent categorical variable for 'clust'
#     props = pm.Dirichlet('props', np.ones(nclusts))

#     # NB distribution for count data using adjusted area and mu_clust
#     count_mu = np.repeat(area_nonzero[:,:,np.newaxis], nclusts, axis = 2) * mu_clust.transpose(1, 0)

#     # RBF based NB dispersion 
#     theta_a = pm.Normal("theta_a", mu = 0, sigma=1, shape=ncenters)
#     theta_b = pm.Normal("theta_b", mu = 0, sigma=1, shape=ncenters)

#     count_disp = T.dot(T.exp(-T.exp(theta_b) * (T.reshape(T.repeat(count_mu, repeats=ncenters), newshape=(nsamples, ncelltypes, nclusts, ncenters)) - centers[np.newaxis, np.newaxis, np.newaxis, :])**2), T.exp(theta_a)) + epsilon
    
#     # Likelihood for count data
#     comp_countdists = pm.NegativeBinomial.dist(mu=count_mu, alpha=count_disp)
    
#     # Don't forget lambda_0 AFTERWARDS ADD THIS -- need to fix intercepts and dimensions as a result
#     lambda0 = pm.Gamma("lambda0", 0.1, 0.1, dims = "intervals") # 0.01, 0.05, 0.1
    
#     linear_predictor = (
#         beta_stage[stage]
#         + beta_age * age
#         + beta_chemo * chemo
#         + beta_rt * rt
#         + beta_brachy * brachy
#     )
#     linear_predictor = np.repeat(linear_predictor[:,np.newaxis], nclusts, axis=1) + beta_clust

#     lambda_ = pm.Deterministic("lambda_", T.reshape(T.repeat(T.exp(linear_predictor), repeats=len(intervals)), newshape=(nsamples,nclusts,nintervals)) * lambda0)

#     mu = pm.Deterministic("mu", np.repeat(exposure[:,:,np.newaxis], nclusts, axis=2) * lambda_.transpose(0, 2, 1))

#     # Poisson-Cox PH equivalence
#     comp_cox = pm.Poisson.dist(mu=mu)

#     log_probs = T.sum(pm.logp(comp_countdists, np.repeat(count_mat[:,:,np.newaxis], nclusts, axis = 2)), axis = 1) + T.sum(pm.logp(comp_cox, np.repeat(death[:,:,np.newaxis], nclusts, axis = 2)), axis = 1) + T.log(props)

#     idx = pm.Categorical("idx", logit_p=log_probs)

#     recover_model.debug(verbose=True)

#     pp = pm.sample_posterior_predictive(trace, var_names=['idx'])

[261   3]
Add.0
point={'beta_clust0': array([0., 0.]), 'beta_stage0': array([0., 0., 0., 0.]), 'beta_age': array(0.), 'beta_chemo': array(0.), 'beta_rt': array(0.), 'beta_brachy': array(0.), 'mu_clust_log__': array([[5.62589496, 2.3172128 , 4.40131308, 1.12417924, 2.55557906,
        0.9915293 , 5.55979347, 4.52830928, 2.58190128, 1.32119953,
        5.50099733, 4.06111265, 3.57283163, 2.02989106, 4.82881983,
        2.40047905, 5.04023114, 3.03233591, 4.78287891, 3.28191928,
        6.12732447, 4.46113984, 6.04317306, 3.84333015, 5.23025371,
        2.84658753],
       [5.62589496, 2.3172128 , 4.40131308, 1.12417924, 2.55557906,
        0.9915293 , 5.55979347, 4.52830928, 2.58190128, 1.32119953,
        5.50099733, 4.06111265, 3.57283163, 2.02989106, 4.82881983,
        2.40047905, 5.04023114, 3.03233591, 4.78287891, 3.28191928,
        6.12732447, 4.46113984, 6.04317306, 3.84333015, 5.23025371,
        2.84658753],
       [5.62589496, 2.3172128 , 4.40131308, 1.12417924, 2.55557906,
 

Sampling: [idx]


No problems found


In [204]:
# NOTE(review): `trace` is not assigned at notebook scope in the visible cells (it is only a
# local inside fit_survcluster_model and a name in commented-out code), so this cell appears
# to rely on leftover kernel state -- confirm, or point it at one of the *_outputs['trace'].
pm.summary(trace, var_names=["props", "beta_clust", "beta_stage", "beta_age", "beta_chemo", "beta_rt", "beta_brachy"])



Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
props[0],0.366,0.056,0.26,0.45,0.03,0.023,5.0,13.0,
props[1],0.141,0.095,0.043,0.345,0.062,0.05,4.0,13.0,
props[2],0.493,0.049,0.396,0.548,0.032,0.026,3.0,14.0,
beta_clust[0],0.0,0.0,0.0,0.0,0.0,0.0,100.0,100.0,
beta_clust[1],-4.07,2.18,-7.294,-0.725,0.808,0.596,7.0,21.0,
beta_clust[2],-4.985,1.537,-7.198,-2.419,0.528,0.391,9.0,21.0,
beta_stage[0],0.0,0.0,0.0,0.0,0.0,0.0,100.0,100.0,
beta_stage[1],0.842,0.699,-0.323,2.125,0.149,0.107,22.0,52.0,
beta_stage[2],1.195,0.319,0.541,1.728,0.047,0.043,38.0,117.0,
beta_stage[3],2.404,0.353,1.839,3.107,0.121,0.092,9.0,19.0,


In [20]:
# NOTE(review): same call as the previous summary cell, and it also depends on a
# notebook-level `trace` that the visible cells never define -- confirm which run this shows.
pm.summary(trace, var_names=["props", "beta_clust", "beta_stage", "beta_age", "beta_chemo", "beta_rt", "beta_brachy"])



Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
props[0],0.508,0.038,0.439,0.583,0.003,0.002,215.0,202.0,
props[1],0.492,0.038,0.417,0.561,0.003,0.002,215.0,202.0,
beta_clust[0],0.0,0.0,0.0,0.0,0.0,0.0,500.0,500.0,
beta_clust[1],-0.451,0.23,-0.897,-0.064,0.014,0.01,283.0,327.0,
beta_stage[0],0.0,0.0,0.0,0.0,0.0,0.0,500.0,500.0,
beta_stage[1],0.689,0.677,-0.631,1.766,0.033,0.031,415.0,315.0,
beta_stage[2],1.056,0.321,0.479,1.624,0.022,0.016,203.0,262.0,
beta_stage[3],2.899,0.427,2.153,3.673,0.035,0.025,147.0,219.0,
beta_age,0.21,0.121,-0.005,0.446,0.008,0.006,232.0,203.0,
beta_chemo,-0.867,0.31,-1.414,-0.271,0.025,0.018,151.0,209.0,


In [185]:
# NOTE(review): `nclusts` only exists as a parameter of fit_survcluster_model in the visible
# cells; at notebook scope this relies on leftover kernel state.
np.arange(nclusts)

array([0, 1, 2])

In [258]:
# NOTE(review): `pp` is only assigned inside the commented-out recover_model cell, so this
# summary cannot be reproduced from the visible live code.
pm.summary(pp, group="posterior_predictive")



Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
idx[0],1.06,0.993,0.0,2.0,0.183,0.132,30.0,31.0,
idx[1],0.33,0.473,0.0,1.0,0.168,0.123,8.0,8.0,
idx[2],1.82,0.386,1.0,2.0,0.212,0.166,3.0,3.0,
idx[3],0.62,0.885,0.0,2.0,0.475,0.369,3.0,3.0,
idx[4],0.96,0.887,0.0,2.0,0.091,0.076,94.0,20.0,
...,...,...,...,...,...,...,...,...,...
idx[256],0.41,0.637,0.0,2.0,0.297,0.225,4.0,3.0,
idx[257],1.90,0.302,1.0,2.0,0.118,0.087,7.0,7.0,
idx[258],1.37,0.906,0.0,2.0,0.120,0.085,58.0,54.0,
idx[259],1.79,0.409,1.0,2.0,0.242,0.192,3.0,3.0,


In [262]:
# NOTE(review): repeat of the posterior-predictive summary; `pp` is only assigned inside
# commented-out code, so this depends on leftover kernel state.
pm.summary(pp, group="posterior_predictive")



Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
idx[0],0.30,0.704,0.0,2.0,0.119,0.086,36.0,37.0,
idx[1],0.21,0.409,0.0,1.0,0.129,0.094,10.0,10.0,
idx[2],1.83,0.378,1.0,2.0,0.202,0.157,3.0,3.0,
idx[3],0.74,0.917,0.0,2.0,0.455,0.348,4.0,3.0,
idx[4],0.74,0.860,0.0,2.0,0.128,0.095,45.0,22.0,
...,...,...,...,...,...,...,...,...,...
idx[256],0.32,0.566,0.0,1.0,0.241,0.181,5.0,3.0,
idx[257],1.90,0.302,1.0,2.0,0.120,0.089,6.0,6.0,
idx[258],1.46,0.858,0.0,2.0,0.119,0.088,50.0,55.0,
idx[259],1.79,0.409,1.0,2.0,0.242,0.192,3.0,3.0,


In [263]:
# NOTE(review): displays `pp`, which is only assigned inside commented-out code.
pp

In [286]:
# NOTE(review): displays `trace`, which is not defined at notebook scope in the visible cells.
trace

In [302]:
# Inspect the integer stage codes of the samples in the 'os' input bundle.
os_inputs['clinical']['stage_idx']

acc_num
02S-2772      2
02S-39903     1
1621017269    0
1621020349    0
1621020869    2
             ..
VS17-3680     3
VS17-3718     0
VS17-4669     1
VS17-57       2
VS17-5945     2
Name: stage_idx, Length: 261, dtype: int64

In [37]:
# Check of the repeat-then-reshape idiom used for the RBF dispersion: repeating each element
# 4 times and reshaping to (2, 3, 4) puts the copies along the last axis, so slicing index 0
# of that axis gives back the original 2 x 3 array.
tmp1 = np.reshape(np.repeat([[0, 1, 2], [3,4,5]], repeats=4), (2,3,4))
tmp1[:,:,0]

array([[0, 1, 2],
       [3, 4, 5]])