### Piecewise linear models with pystan

Based on Jan Vanhove's tutorial on Bayesian breakpoint (piecewise linear) models: \
https://janhove.github.io/analysis/2018/07/04/bayesian-breakpoint-model

Useful reference for translating code between RStan and PyStan: \
https://pystan.readthedocs.io/en/latest/differences_pystan_rstan.html

In [None]:
import numpy as np
import pandas as pd
import pystan as ps
import statsmodels.formula.api as smf
import matplotlib.pyplot as plt
import plotly.express as px
import glob
import arviz
%matplotlib inline
import tqdm
import matplotlib
# Notebook display defaults: never elide DataFrame columns.
pd.set_option('display.max_columns', None)
# Customize matplotlib: render all text in a monospace family.
# NOTE(review): 'font.sans-serif' is only consulted when the family is 'sans-serif';
# with 'font.family' set to 'monospace' the Ubuntu entry is unused
# (perhaps 'font.monospace': ['Ubuntu Mono'] was intended?) — confirm.
matplotlib.rcParams.update({
    'font.family': 'monospace',
    'font.sans-serif': ['Ubuntu'],
})

In [None]:
import os


def load_final_summaries(pattern='./10xv3_final_summaries/*', suffix='-final_summary.csv'):
    """Load every CSV matching *pattern* into a dict keyed by dataset name.

    The key is the file's base name truncated at the first occurrence of
    *suffix* (e.g. ``pbmc-final_summary.csv`` -> ``pbmc``). Each DataFrame is
    sorted by ``sampled_cells`` and then ``total_UMIs``, both ascending.

    Args:
        pattern: glob pattern selecting the summary files.
        suffix: file-name tail stripped off to form the dataset key.

    Returns:
        dict mapping dataset name -> sorted DataFrame (empty if nothing matches).
    """
    summaries = {}
    # sorted() makes the dataset order (and hence downstream processing order) deterministic.
    for path in sorted(glob.glob(pattern)):
        # basename is independent of directory depth and of the OS path separator,
        # unlike the previous path.split('/')[2].
        name = os.path.basename(path).split(suffix)[0]
        summaries[name] = pd.read_csv(path).sort_values(
            ["sampled_cells", "total_UMIs"], ascending=True
        )
    return summaries


# load the 10xv3 results with 30x sampling for each cell/depth combination
dfs = load_final_summaries()

In [None]:
# Compile the piecewise (breakpoint) regression model; the resulting object is
# reused for every `stan_model.sampling(...)` call in the fitting loop.
_STAN_MODEL_NAME = "seqdepth_piecewise_v2"
stan_model = ps.StanModel(
    file=f"{_STAN_MODEL_NAME}.stan",
    model_name=_STAN_MODEL_NAME,
)


In [None]:
# Show full cell contents and don't wrap wide DataFrame reprs across multiple "pages".
pd.set_option(
    'display.max_colwidth', None,
    'display.expand_frame_repr', False,
)


In [None]:
import os

# Every file written below goes here; create it up front so the first
# savefig/to_csv call cannot fail on a missing folder.
RESULTS_DIR = './results2'
os.makedirs(RESULTS_DIR, exist_ok=True)

# Parameters drawn in the ArviZ trace plot (names must exist in the Stan model).
TRACE_VARS = ['slope_before', 'slope_after', 'intercept', 'bp',
              'before_variance', 'after_variance', 'after_over_before',
              'before_over_after', 'bp_umis']

# Fit the piecewise model separately for every (dataset, sampled_cells) subset,
# keep the posterior summary in results[dataset][ncells], and write trace plots,
# raw draws, summaries and a line-fit figure to RESULTS_DIR.
results = {}
for dataset, ddf in dfs.items():
    print(dataset)
    results[dataset] = {}
    for ncells in ddf.sampled_cells.unique():
        print(dataset, ncells)
        tag = f'{dataset}-{ncells}'
        df = ddf[ddf['sampled_cells'] == ncells]
        # The model is fit on log2-transformed data; compute the transforms once
        # and reuse them for the scatter plot below.
        log_umis = np.log2(df["UMIs_per_cell"])
        log_error = np.log2(df["validation_error"])

        # --- Fit ---------------------------------------------------------------
        data_dict = {"umis_per_cell": log_umis, "validation_error": log_error, "N": len(df)}
        # NOTE(review): adapt_delta is a target acceptance rate; exactly 1 is unusual
        # (0.95-0.99 is the common range) and can shrink the step size drastically.
        # Kept to reproduce earlier runs — confirm it is intended.
        stan_fit = stan_model.sampling(
            data=data_dict,
            iter=4000,
            control={'adapt_delta': 1, 'max_treedepth': 12},
        )

        s = stan_fit.summary()
        summary = pd.DataFrame(s['summary'], columns=s['summary_colnames'],
                               index=s['summary_rownames'])
        # Compact view for display and the figure title: the first six rows plus the
        # four rows before the last one (presumably skipping lp__, which PyStan lists
        # last — confirm).
        summary_head = pd.concat([summary.head(6), summary.iloc[-5:-1]]).copy()
        display(summary_head)
        results[dataset][ncells] = summary

        # --- Trace plot and raw outputs ----------------------------------------
        arviz.plot_trace(stan_fit, var_names=TRACE_VARS)
        plt.savefig(os.path.join(RESULTS_DIR, f'{tag}.png'), format='png', dpi=200)
        full_stan_results = stan_fit.to_dataframe()
        full_stan_results.to_csv(os.path.join(RESULTS_DIR, f'full_stan_{tag}.csv'))
        summary.to_csv(os.path.join(RESULTS_DIR, f'summary_stan_{tag}.csv'))
        plt.show()

        # --- Posterior line fits -----------------------------------------------
        summary_text = str(summary_head.round(3))
        extracted = stan_fit.extract()
        plt.figure(num=None, figsize=(16, 9), dpi=80, facecolor='w', edgecolor='k')
        plt.ylim(8, 12)
        plt.xlim(8, 15)
        axes = plt.gca()
        # The x-limits were fixed explicitly just above, so they stay constant
        # while drawing; read them once instead of once per posterior draw.
        x_min, x_max = axes.get_xlim()

        nlines = 4000
        intercepts = extracted['intercept'][:nlines]
        bps = extracted['bp'][:nlines]
        # One faint segment per posterior draw: red left of the draw's breakpoint...
        for intercept, slope, bp in zip(intercepts, extracted['slope_before'][:nlines], bps):
            x_vals = np.array([x_min, bp])
            y_vals = intercept + slope * (x_vals - bp)
            plt.plot(x_vals, y_vals, '--', color='red', alpha=0.01)
        # ...green right of it, plus a faint vertical line at the breakpoint.
        for intercept, slope, bp in zip(intercepts, extracted['slope_after'][:nlines], bps):
            x_vals = np.array([bp, x_max])
            y_vals = intercept + slope * (x_vals - bp)
            # Only the '--' fmt: combining fmt '-' with linestyle='--' made matplotlib
            # warn about a redundantly defined linestyle (the keyword won anyway).
            plt.plot(x_vals, y_vals, '--', color='green', alpha=0.002)
            plt.axvline(x=bp, linestyle='--', color='gray', alpha=0.002)

        # Posterior-median fit on top (black before, blue after the breakpoint).
        # Read from the full summary: summary_head is a positional slice that may
        # miss these rows or contain them twice.
        bp_med = summary.loc['bp', '50%']
        intercept_med = summary.loc['intercept', '50%']
        x_vals = np.array([x_min, bp_med])
        y_vals = intercept_med + summary.loc['slope_before', '50%'] * (x_vals - bp_med)
        plt.plot(x_vals, y_vals, '-', color='black', alpha=1)
        x_vals = np.array([bp_med, x_max])
        y_vals = intercept_med + summary.loc['slope_after', '50%'] * (x_vals - bp_med)
        plt.plot(x_vals, y_vals, '-', color='blue', alpha=1)
        plt.axvline(x=bp_med, linestyle='--', color='black', alpha=1)

        plt.scatter(log_umis, log_error, color='black', alpha=0.5, s=10)
        axes.set_aspect('equal', adjustable='box')
        plt.yticks([8, 9, 10, 11, 12])
        plt.grid(True, ls='--', alpha=0.5)
        # NOTE(review): the same points are scattered again in cyan over the black
        # layer; kept to preserve the figure's look — confirm both layers are wanted.
        plt.scatter(log_umis, log_error, color='cyan', alpha=0.5, s=10)
        title = f'Piecewise stan model \n {dataset} {ncells} cells \n {summary_text}'
        plt.title(title)
        plt.savefig(os.path.join(RESULTS_DIR, f'linefits_{tag}.png'), format='png', dpi=200)

        plt.show()

