### Piecewise linear models with pystan

Based on the following tutorial from Jan Vanhove: \
https://janhove.github.io/analysis/2018/07/04/bayesian-breakpoint-model

Useful reference when going between rstan and pystan: \
https://pystan.readthedocs.io/en/latest/differences_pystan_rstan.html

In [None]:
import glob
import os

import arviz
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import pystan as ps
import statsmodels.formula.api as smf
%matplotlib inline

# Show every column when a DataFrame is displayed (None removes the limit).
pd.options.display.max_columns = None


In [None]:
# Load the 10xv3 results with 30x sampling for each cell/depth combination.
# Every file in the summaries directory is read as CSV, sorted by
# sampled_cells then total_UMIs (both ascending), and stored in `dfs` under
# its file name with the '-final_summary.csv' suffix stripped.
dfs = {}
for item in glob.glob('./10xv3_final_summaries/*'):
    # os.path.basename instead of item.split('/')[2]: the positional split
    # only works for an exact './dir/file' shape with '/' separators and
    # raises IndexError when glob returns OS-native separators.
    dataset_name = os.path.basename(item).split('-final_summary.csv')[0]
    dfs[dataset_name] = pd.read_csv(item).sort_values(
        ["sampled_cells", "total_UMIs"], ascending=[True, True]
    )

In [None]:
# Restrict the pbmc_1k_v3 summary to the rows sampled at 999 cells. The
# .copy() gives an independent frame rather than a slice of the one in `dfs`.
df = dfs['10x_genomics_data-pbmc_1k_v3'].loc[
    lambda frame: frame['sampled_cells'] == 999
].copy()
df

In [None]:
# Untransformed view of validation error against UMIs per cell; the low alpha
# keeps overlapping points distinguishable.
df.plot.scatter(x='UMIs_per_cell', y='validation_error', alpha=0.1)
plt.show()

In [None]:
# Build the model object from the Stan program stored next to this notebook.
stan_model = ps.StanModel(
    model_name="seqdepth_piecewise_v1",
    file="seqdepth_piecewise_v1.stan",
)


In [None]:

# Data handed to the Stan program: both variables are natural-log transformed
# and N is the number of rows in `df`.
data_dict = {
    "umis_per_cell": np.log(df["UMIs_per_cell"]),
    "validation_error": np.log(df["validation_error"]),
    "N": len(df),
}

# adapt_delta is the sampler's target acceptance rate and is only valid
# strictly between 0 and 1. The previous value of exactly 1 sits on that
# boundary, so use a value just below it instead.
# NOTE(review): confirm 0.99 is strict enough for this model (check the
# divergence warnings after sampling).
stan_fit = stan_model.sampling(
    data=data_dict,
    iter=4000,
    control={'adapt_delta': 0.99},
)

In [None]:
# Turn the fit summary into a labelled table: rows and columns take their
# names from the 'summary_rownames' and 'summary_colnames' entries.
s = stan_fit.summary()
summary = pd.DataFrame(
    data=s['summary'],
    index=s['summary_rownames'],
    columns=s['summary_colnames'],
)
summary

In [None]:
# Trace plots for the model quantities of interest, drawn with ArviZ.
arviz.plot_trace(
    stan_fit,
    var_names=[
        'slope_before',
        'slope_after',
        'intercept',
        'bp',
        'before_variance',
        'after_variance',
        'slope_ratio',
    ],
)
plt.show()

In [None]:
# Display the fit via its to_dataframe() export; the result is only shown,
# not stored.
stan_fit.to_dataframe()

In [None]:
# Trace plot produced by the fit object itself.
# NOTE(review): PyStan 2 reportedly flags its built-in plotting as deprecated
# in favour of ArviZ — confirm, and consider dropping this cell if so.
stan_fit.traceplot()

In [None]:
# Pull the sampled values out of the fit. Later cells index `dope` by
# parameter name, e.g. dope['bp'] and dope['intercept'].
dope = stan_fit.extract()


In [None]:
# List the names available in the extracted samples.
dope.keys()

In [None]:
# For each parameter: print the standard deviation of its draws, then show a
# 100-bin histogram of them in a separate figure.
for param in ('slope_before', 'slope_after', 'intercept', 'bp',
              'before_variance', 'after_variance'):
    draws = dope[param]
    print(param, np.std(draws))
    plt.hist(draws, bins=100)
    plt.title(param)
    plt.show()

In [None]:
# Overlay the fitted piecewise line of each posterior draw on the log-log
# data. For every draw the line is y = intercept + slope * (x - bp), so
# `intercept` is the fitted value at the breakpoint `bp`; the 'before' slope
# is used left of bp and the 'after' slope right of it.
plt.figure(num=None, figsize=(16, 9), dpi=80, facecolor='w', edgecolor='k')
plt.ylim(5, 9)
plt.xlim(6, 11)
nlines = 8000  # maximum number of draws to overlay

intercepts = dope['intercept'][:nlines]
slopes_before = dope['slope_before'][:nlines]
slopes_after = dope['slope_after'][:nlines]
bps = dope['bp'][:nlines]

# The x-limits are fixed explicitly above, so read them once instead of
# querying the axes on every iteration.
x_min, x_max = plt.gca().get_xlim()

# Segment before the breakpoint: from the left axis limit up to bp.
# The loop variable is `bp`, not `breakpoint`, so the builtin breakpoint()
# is not shadowed for the rest of the session.
for intercept, slope, bp in zip(intercepts, slopes_before, bps):
    x_vals = np.array([x_min, bp])
    y_vals = intercept + slope * (x_vals - bp)
    plt.plot(x_vals, y_vals, '--', color='red', alpha=0.01)

# Segment after the breakpoint: from bp to the right axis limit, plus a faint
# vertical marker at each sampled breakpoint.
for intercept, slope, bp in zip(intercepts, slopes_after, bps):
    x_vals = np.array([bp, x_max])
    y_vals = intercept + slope * (x_vals - bp)
    plt.plot(x_vals, y_vals, '-', color='green', alpha=0.01)
    plt.axvline(x=bp, linestyle='-', color='black', alpha=0.01)

# One vectorised call marks every (bp, intercept) pair, replacing a separate
# scatter call per draw.
plt.scatter(bps, intercepts, color='black', alpha=0.1, s=10)

# Observed data on the same log scale the model was fitted on.
plt.scatter(np.log(df["UMIs_per_cell"]), np.log(df["validation_error"]), alpha=0.25)

plt.title('Piecewise model with breakpoint '+ str(nlines)+' lines')
plt.show()


In [None]:
# Show the first six rows of the summary table.
summary.head(6)