### Piecewise linear models with pystan

Based on the following tutorial from Jan Vanhove: \
https://janhove.github.io/analysis/2018/07/04/bayesian-breakpoint-model

Useful reference when going between rstan and pystan: \
https://pystan.readthedocs.io/en/latest/differences_pystan_rstan.html

In [None]:
import glob
import os

import arviz
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import pystan as ps
import statsmodels.formula.api as smf
import tqdm

%matplotlib inline
# Show every DataFrame column instead of eliding the middle ones.
pd.set_option('display.max_columns', None)
# Customize matplotlib fonts.
# NOTE(review): 'font.sans-serif' is only consulted when the family is
# 'sans-serif'; with family 'monospace' the Ubuntu entry presumably has no
# visible effect -- confirm which one is intended.
matplotlib.rcParams.update({
    'font.family': 'monospace',
    'font.sans-serif': ['Ubuntu'],
})

In [None]:
# load the 10xv3 results with 30x sampling for each cell/depth combination
def dataset_name_from_path(path):
    """Return the dataset name encoded in a '<name>-final_summary.csv' path.

    Only the final path component is used, so the result does not depend on
    how many directories precede the file or on the path separator.
    """
    return os.path.basename(path).split('-final_summary.csv')[0]


# Maps dataset name -> summary table sorted by sampled_cells, then total_UMIs.
dfs = {}
for item in glob.glob('./10xv3_final_summaries/*'):
    dfs[dataset_name_from_path(item)] = pd.read_csv(item).sort_values(
        ["sampled_cells", "total_UMIs"], ascending=(True, True))

In [None]:
# Compile the Stan program from disk; the resulting model object is what the
# fitting cells call .sampling() on.
model_name = "seqdepth_piecewise_v1"
stan_model = ps.StanModel(file=model_name + ".stan", model_name=model_name)


In [None]:
# Let wide summary tables print in full: no column-width truncation and no
# wrapping of the frame repr across several blocks.
for option, value in (
    ('display.max_colwidth', None),
    ('display.expand_frame_repr', False),
):
    pd.set_option(option, value)


In [None]:
# Fit the piecewise Stan model to every dataset / sampled-cell-count
# combination (max_treedepth raised to 12). For each fit this saves, under
# ./results1/: the trace plot, the full draws, the summary table and a plot of
# the posterior line fits over the data.
results_dir = './results1/'
# Create the output folder up front so a missing directory cannot make
# savefig/to_csv fail only after a sampling run has already finished.
os.makedirs(results_dir, exist_ok=True)
trace_vars = ['slope_before', 'slope_after', 'intercept', 'bp',
              'before_variance', 'after_variance', 'slope_ratio']
summary_columns = ['mean', 'se_mean', 'sd', '2.5%', '2.5%UMI', '25%', '50%',
                   '50%UMI', '75%', '97.5%', '97.5%UMI', 'n_eff', 'Rhat']

results = {}
for dataset, ddf in dfs.items():
    print(dataset)
    results[dataset] = {}
    for ncells in ddf.sampled_cells.unique():
        print(dataset, ncells)
        df = ddf[ddf['sampled_cells'] == ncells]
        run_tag = dataset + '-' + str(ncells)

        # Both variables are passed to Stan on the natural-log scale.
        data_dict = {"umis_per_cell": np.log(df["UMIs_per_cell"]),
                     "validation_error": np.log(df["validation_error"]),
                     "N": len(df)}
        # NOTE(review): adapt_delta=1 is the extreme end of the target
        # acceptance range -- confirm this is intended rather than e.g. 0.99.
        stan_fit = stan_model.sampling(data=data_dict,
                                       iter=5000,
                                       control={'adapt_delta': 1, 'max_treedepth': 12})

        s = stan_fit.summary()
        summary = pd.DataFrame(s['summary'], columns=s['summary_colnames'],
                               index=s['summary_rownames'])
        display(summary.head(10))
        results[dataset][ncells] = summary

        # Trace plot: saved first, shown after the CSVs have been written.
        arviz.plot_trace(stan_fit, trace_vars)
        plt.savefig(results_dir + run_tag + '.png', format='png', dpi=200)

        full_stan_results = stan_fit.to_dataframe()
        full_stan_results.to_csv(results_dir + 'full_stan_' + run_tag + '.csv')
        summary.to_csv(results_dir + 'summary_stan_' + run_tag + '.csv')
        plt.show()

        # First six summary rows with exp() of the quantiles added, used as the
        # text block in the plot title.
        # NOTE(review): exp() is applied to all six rows; it is presumably only
        # meaningful for rows that live on the log-UMI scale (e.g. bp) -- confirm.
        summary_head = summary.head(6).copy()
        summary_head['2.5%UMI'] = np.exp(summary_head['2.5%']).astype(int)
        summary_head['97.5%UMI'] = np.exp(summary_head['97.5%']).astype(int)
        summary_head['50%UMI'] = np.exp(summary_head['50%']).astype(int)
        summary_head = summary_head[summary_columns]
        summary_text = str(summary_head.head(6).round(3))

        extracted = stan_fit.extract()
        plt.figure(num=None, figsize=(16, 9), dpi=80, facecolor='w', edgecolor='k')
        plt.ylim(5, 9)
        plt.xlim(6, 11)
        nlines = 8000  # plot at most this many posterior draws
        # The x-limits were fixed just above, so read them once instead of once
        # per plotted draw.
        x_min, x_max = plt.gca().get_xlim()
        intercepts = extracted['intercept'][:nlines]
        bps = extracted['bp'][:nlines]

        # Segment left of the breakpoint: from the left axis edge up to bp.
        # The loop variable is named bp so the builtin breakpoint() is not shadowed.
        for intercept, slope, bp in zip(intercepts, extracted['slope_before'][:nlines], bps):
            x_vals = np.array([x_min, bp])
            y_vals = intercept + slope * (x_vals - bp)
            plt.plot(x_vals, y_vals, '--', color='red', alpha=0.01)

        # Segment right of the breakpoint, plus a marker at the breakpoint itself.
        for intercept, slope, bp in zip(intercepts, extracted['slope_after'][:nlines], bps):
            x_vals = np.array([bp, x_max])
            y_vals = intercept + slope * (x_vals - bp)
            # The line style is given only by keyword: the original also passed
            # the fmt string '-', which conflicted with linestyle='--'.
            plt.plot(x_vals, y_vals, linestyle='--', color='green', alpha=0.002)
            plt.axvline(x=bp, linestyle='--', color='black', alpha=0.002)
            plt.scatter(bp, intercept, color='blue', alpha=0.002, s=10)

        # Observed data on the same log-log scale the model was fit on.
        plt.scatter(np.log(df["UMIs_per_cell"]), np.log(df["validation_error"]),
                    color='cyan', alpha=0.5, s=10)
        title = 'Piecewise stan model \n ' + dataset + ' ' + str(ncells) + ' cells \n ' + summary_text
        plt.title(title)
        plt.grid(True)
        plt.savefig(results_dir + 'linefits_' + run_tag + '.png', format='png', dpi=200)

        plt.show()



In [None]:
# Same fit/plot pipeline as the cell above but without the max_treedepth
# override.
# NOTE: this cell resets `results` and writes to the same ./results1/ file
# names as the cell above, so running both overwrites the earlier outputs.
results_dir = './results1/'
# Create the output folder up front so a missing directory cannot make
# savefig/to_csv fail only after a sampling run has already finished.
os.makedirs(results_dir, exist_ok=True)
trace_vars = ['slope_before', 'slope_after', 'intercept', 'bp',
              'before_variance', 'after_variance', 'slope_ratio']
summary_columns = ['mean', 'se_mean', 'sd', '2.5%', '2.5%UMI', '25%', '50%',
                   '50%UMI', '75%', '97.5%', '97.5%UMI', 'n_eff', 'Rhat']

results = {}
for dataset, ddf in dfs.items():
    print(dataset)
    results[dataset] = {}
    # Iterate over the cell counts of the *current dataset* (ddf). The original
    # iterated over `df`, the stale single-cell-count subset left behind by an
    # earlier iteration, so the wrong cell counts were fitted.
    for ncells in ddf.sampled_cells.unique():
        print(dataset, ncells)
        df = ddf[ddf['sampled_cells'] == ncells]
        run_tag = dataset + '-' + str(ncells)

        # Both variables are passed to Stan on the natural-log scale.
        data_dict = {"umis_per_cell": np.log(df["UMIs_per_cell"]),
                     "validation_error": np.log(df["validation_error"]),
                     "N": len(df)}
        # NOTE(review): adapt_delta=1 is the extreme end of the target
        # acceptance range -- confirm this is intended rather than e.g. 0.99.
        stan_fit = stan_model.sampling(data=data_dict,
                                       iter=5000,
                                       control={'adapt_delta': 1})
        s = stan_fit.summary()
        summary = pd.DataFrame(s['summary'], columns=s['summary_colnames'],
                               index=s['summary_rownames'])
        display(summary.head(10))
        results[dataset][ncells] = summary

        # Trace plot: saved first, shown after the CSVs have been written.
        arviz.plot_trace(stan_fit, trace_vars)
        plt.savefig(results_dir + run_tag + '.png', format='png', dpi=200)

        full_stan_results = stan_fit.to_dataframe()
        full_stan_results.to_csv(results_dir + 'full_stan_' + run_tag + '.csv')
        summary.to_csv(results_dir + 'summary_stan_' + run_tag + '.csv')
        plt.show()

        # First six summary rows with exp() of the quantiles added, used as the
        # text block in the plot title.
        # NOTE(review): exp() is applied to all six rows; it is presumably only
        # meaningful for rows that live on the log-UMI scale (e.g. bp) -- confirm.
        summary_head = summary.head(6).copy()
        summary_head['2.5%UMI'] = np.exp(summary_head['2.5%']).astype(int)
        summary_head['97.5%UMI'] = np.exp(summary_head['97.5%']).astype(int)
        summary_head['50%UMI'] = np.exp(summary_head['50%']).astype(int)
        summary_head = summary_head[summary_columns]
        summary_text = str(summary_head.head(6).round(3))

        extracted = stan_fit.extract()
        plt.figure(num=None, figsize=(16, 9), dpi=80, facecolor='w', edgecolor='k')
        plt.ylim(5, 9)
        plt.xlim(6, 11)
        nlines = 8000  # plot at most this many posterior draws
        # The x-limits were fixed just above, so read them once instead of once
        # per plotted draw.
        x_min, x_max = plt.gca().get_xlim()
        intercepts = extracted['intercept'][:nlines]
        bps = extracted['bp'][:nlines]

        # Segment left of the breakpoint: from the left axis edge up to bp.
        # The loop variable is named bp so the builtin breakpoint() is not shadowed.
        for intercept, slope, bp in zip(intercepts, extracted['slope_before'][:nlines], bps):
            x_vals = np.array([x_min, bp])
            y_vals = intercept + slope * (x_vals - bp)
            plt.plot(x_vals, y_vals, '--', color='red', alpha=0.01)

        # Segment right of the breakpoint, plus a marker at the breakpoint itself.
        for intercept, slope, bp in zip(intercepts, extracted['slope_after'][:nlines], bps):
            x_vals = np.array([bp, x_max])
            y_vals = intercept + slope * (x_vals - bp)
            plt.plot(x_vals, y_vals, '-', color='green', alpha=0.002)
            plt.axvline(x=bp, linestyle='--', color='black', alpha=0.002)
            plt.scatter(bp, intercept, color='blue', alpha=0.002, s=10)

        # Observed data on the same log-log scale the model was fit on.
        plt.scatter(np.log(df["UMIs_per_cell"]), np.log(df["validation_error"]),
                    color='cyan', alpha=0.5, s=10)
        title = 'Piecewise stan model \n ' + dataset + ' ' + str(ncells) + ' cells \n ' + summary_text
        plt.title(title)
        plt.grid(True)
        plt.savefig(results_dir + 'linefits_' + run_tag + '.png', format='png', dpi=200)

        plt.show()



In [None]:
# Rebuild the summary table for whichever fit `stan_fit` currently holds
# (the last one produced by the fitting loop) and display it.
s = stan_fit.summary()
summary = pd.DataFrame(
    data=s['summary'],
    index=s['summary_rownames'],
    columns=s['summary_colnames'],
)
summary

In [None]:
# Trace plot of the model parameters for the current fit.
trace_vars = [
    'slope_before',
    'slope_after',
    'intercept',
    'bp',
    'before_variance',
    'after_variance',
    'slope_ratio',
]
arviz.plot_trace(stan_fit, var_names=trace_vars)
plt.show()

In [None]:
# Display the current fit converted to a DataFrame (notebook cell output).
stan_fit.to_dataframe()

In [None]:
# Trace plot via the fit object's own traceplot() method, for comparison with
# the arviz version.
stan_fit.traceplot()

In [None]:
# Pull the posterior draws out of the current fit; the cells below index the
# result by parameter name (e.g. dope['bp']).
dope = stan_fit.extract()


In [None]:
# List the names available in the extracted draws.
dope.keys()

In [None]:
# For each parameter: print the standard deviation of its draws and show a
# 100-bin histogram of them.
hist_params = (
    'slope_before',
    'slope_after',
    'intercept',
    'bp',
    'before_variance',
    'after_variance',
)
for param in hist_params:
    draws = dope[param]
    print(param, np.std(draws))
    plt.hist(draws, bins=100)
    plt.title(param)
    plt.show()

In [None]:
# Overlay the posterior line fits held in `dope` on the observed data.
plt.figure(num=None, figsize=(16, 9), dpi=80, facecolor='w', edgecolor='k')
plt.ylim(5, 9)
plt.xlim(6, 11)
nlines = 8000  # plot at most this many posterior draws
# The x-limits were fixed just above, so read them once instead of once per
# plotted draw.
x_min, x_max = plt.gca().get_xlim()
intercepts = dope['intercept'][:nlines]
bps = dope['bp'][:nlines]

# Segment left of the breakpoint: from the left axis edge up to bp.
# The loop variable is named bp so the builtin breakpoint() is not shadowed.
for intercept, slope, bp in zip(intercepts, dope['slope_before'][:nlines], bps):
    x_vals = np.array([x_min, bp])
    y_vals = intercept + slope * (x_vals - bp)
    plt.plot(x_vals, y_vals, '--', color='red', alpha=0.01)

# Segment right of the breakpoint, plus a marker at the breakpoint itself.
for intercept, slope, bp in zip(intercepts, dope['slope_after'][:nlines], bps):
    x_vals = np.array([bp, x_max])
    y_vals = intercept + slope * (x_vals - bp)
    plt.plot(x_vals, y_vals, '-', color='green', alpha=0.01)
    plt.axvline(x=bp, linestyle='-', color='black', alpha=0.01)
    plt.scatter(bp, intercept, color='black', alpha=0.1, s=10)

# Observed data on the log-log scale.
# NOTE: `df` is whatever subset the last fitting-loop iteration left behind.
plt.scatter(np.log(df["UMIs_per_cell"]), np.log(df["validation_error"]), alpha=0.25)

plt.title('Piecewise model with breakpoint ' + str(nlines) + ' lines')
plt.show()


In [None]:
# Display the first six rows of the current summary table.
summary.head(6)