In [None]:
#!/usr/bin/env python
# encoding: utf-8

import os
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt 
import numpy as np
import matplotlib as mpl
from concurrent import futures
from functools import partial
from tqdm import tqdm
mpl.rcParams['figure.dpi'] = 60
import yaml
import math
import os
import assembly

# All outputs (figures, CSVs) below are written relative to out_dir.
# NOTE(review): '/' is the filesystem root — confirm this is the intended output location.
out_dir='/'
os.makedirs(out_dir,exist_ok=True)
os.chdir(out_dir)

# config.yaml is resolved relative to out_dir (after the chdir above).
# safe_load: the config only needs plain YAML types, so the broader FullLoader is unnecessary.
with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)

# Sample names are the keys of the per-sample 'spatial_infor' section; a concrete list
# (rather than a dict_keys view) keeps the sample order fixed for the later cells.
samples= list(config['spatial_infor'].keys())
# True: rebuild each <sample>_spot_infor.csv from the raw spot tables (computes the spot cis pct);
# False: reuse the CSVs written by a previous run.
infor_df_replace=False
species='mm10'
# Chromosome names used to classify contacts as cis.
# NOTE(review): relies on the private `_chromsizes` mapping of the project `assembly` module —
# presumably chromosome -> length; confirm against that module.
chrom=list(assembly.build(species, 1)._chromsizes.keys())
# Inclusive bounds on the number of unique reads for a cluster to be counted.
min_cluster_size=2
max_cluster_size=1000


In [None]:

def nCr(n, r):
    """Return the binomial coefficient C(n, r) = n! / (r! (n - r)!) using exact integer math."""
    numerator = math.factorial(n)
    denominator = math.factorial(r) * math.factorial(n - r)
    return numerator // denominator

def get_spot_cis_pct(spot_name, cluster_dir, chroms=None, min_size=None, max_size=None):
    """Percentage of a spot's pairwise cluster contacts that are cis (same chromosome).

    Each line of ``<cluster_dir>/<spot_name>`` is one cluster: the first
    whitespace-separated field is skipped and the remaining fields are read
    names; duplicate read names within a line are collapsed. A cluster with k
    unique reads (``min_size <= k <= max_size``) contributes k*(k-1)/2
    contacts, and the pairs whose reads share a chromosome are cis.

    A read's chromosome is the text after the first ']_' up to the next ']_'
    and then up to the first ':'. Reads whose chromosome is not in ``chroms``,
    and reads with no ']_' at all, still count toward the cluster's contacts
    but never toward cis contacts (previously a single such malformed read
    made the whole spot return 0).

    Parameters
    ----------
    spot_name : str
        File name of the spot's cluster file inside ``cluster_dir``.
    cluster_dir : str
        Directory holding the per-spot cluster files.
    chroms : iterable of str, optional
        Chromosomes eligible for cis contacts; defaults to the module-level ``chrom``.
    min_size, max_size : int, optional
        Inclusive bounds on unique reads per cluster; default to the
        module-level ``min_cluster_size`` / ``max_cluster_size``.

    Returns
    -------
    float or int
        Cis percentage rounded to 2 decimals; 0 when the file cannot be
        opened/read or when no cluster passes the size filter.
    """
    # Fall back to the notebook-level settings so existing calls keep working.
    if chroms is None:
        chroms = chrom
    if min_size is None:
        min_size = min_cluster_size
    if max_size is None:
        max_size = max_cluster_size
    known_chroms = set(chroms)

    total_contacts = 0
    total_contacts_cis = 0
    spot_cluster_file = os.path.join(cluster_dir, spot_name)
    try:
        with open(spot_cluster_file, 'r') as f:
            for line in f:
                # Field 0 is skipped; the remaining fields are read names (deduplicated).
                no_dupl_reads = set(line.split()[1:])
                cluster_size = len(no_dupl_reads)
                if not (min_size <= cluster_size <= max_size):
                    continue
                # k unique reads give C(k, 2) pairwise contacts.
                total_contacts += cluster_size * (cluster_size - 1) // 2

                reads_per_chrom = {}
                for read in no_dupl_reads:
                    parts = read.split(']_')
                    if len(parts) < 2:
                        # No ']_' separator: chromosome unknown, so this read is never cis.
                        continue
                    read_chrom = parts[1].split(':')[0]
                    if read_chrom in known_chroms:
                        reads_per_chrom[read_chrom] = reads_per_chrom.get(read_chrom, 0) + 1
                # Cis pairs are the C(k, 2) pairs within each chromosome.
                total_contacts_cis += sum(k * (k - 1) // 2 for k in reads_per_chrom.values())
    except OSError:
        # Missing or unreadable cluster file: best-effort result of 0, as before.
        return 0
    if total_contacts == 0:
        # No cluster passed the size filter; avoid dividing by zero.
        return 0
    return round(total_contacts_cis/total_contacts*100, 2)


In [None]:
ncols=5
# At least one row so plt.subplots never receives 0 rows for an empty sample list.
nrows=max(1, math.ceil(len(samples)/ncols))
s_width=5
s_height=6


def _panel_grid():
    """Create an nrows x ncols grid with one panel per sample.

    Figure height scales with nrows (the original used ncols for both
    dimensions). squeeze=False keeps the axes array 2-D, so ``axes[row, col]``
    also works when there is a single row of samples.
    """
    return plt.subplots(nrows, ncols, figsize=(ncols*s_width, nrows*s_height), squeeze=False)


def _orient(ax, x_r, y_r):
    """Apply the sample's axis flips (x inverted when x_r; y inverted unless y_r) and a 1:1 aspect."""
    if x_r:
        ax.invert_xaxis()
    if not y_r:
        ax.invert_yaxis()
    ax.set_aspect('equal')


def _plot_spot_hue(ax, data, hue, title, x_r, y_r):
    """Scatter every spot at (x, y), coloured by column ``hue``, with no legend."""
    sns.scatterplot(data, x='x', y='y', hue=hue, palette='Spectral_r', s=5, legend=False, ax=ax)
    ax.set_title(title)
    _orient(ax, x_r, y_r)


def _plot_spot_kept(ax, data, sample, x_r, y_r):
    """Scatter the spots in ``data`` in blue and title the panel with their count."""
    sns.scatterplot(data, x='x', y='y', color='blue', s=5, ax=ax)
    _orient(ax, x_r, y_r)
    ax.set_title(f'{sample} remained spots:{str(data.shape[0])}')


def _plot_density(ax, data, column, threshold, title):
    """KDE of ``column`` with a dashed red vertical line at ``threshold``."""
    sns.kdeplot(data, x=column, ax=ax)
    ax.vlines(threshold, 0, ax.get_ylim()[1], colors='r', linestyles='dashed', label='min_'+column)
    ax.set_title(title)


def _barcode_index(barcode, part):
    """Field ``part`` of a '.'-separated barcode, with oddBo/oddTop/evenBo/evenTop tags removed, as str(int)."""
    token = barcode.strip().split('.')[part]
    for tag in ('oddBo', 'oddTop', 'evenBo', 'evenTop'):
        token = token.replace(tag, '')
    return str(int(token))


fig_reads, axs_reads = _panel_grid()
fig_contacts, axs_contacts = _panel_grid()
fig_clusters, axs_clusters = _panel_grid()
fig_cis_pct, axs_cis_pct = _panel_grid()

fig_log1p_reads, axs_log1p_reads = _panel_grid()
fig_log1p_contacts, axs_log1p_contacts = _panel_grid()
fig_log1p_clusters, axs_log1p_clusters = _panel_grid()

fig_den_reads, axs_den_reads = _panel_grid()
fig_den_contacts, axs_den_contacts = _panel_grid()
fig_den_clusters, axs_den_clusters = _panel_grid()
fig_den_spot_cis_pct, axs_den_spot_cis_pct = _panel_grid()

fig_filter_reads, axs_filters_reads = _panel_grid()
fig_filter_contacts, axs_filters_contacts = _panel_grid()
fig_filter_clusters, axs_filters_clusters = _panel_grid()
fig_filter_cis_pct, axs_filters_cis_pct = _panel_grid()

fig_filters, axs_filters = _panel_grid()

# Spatial heat maps: (figure, axes grid, hue column, suptitle, output file).
hue_panels = [
    (fig_reads, axs_reads, 'Num of Reads', 'num of reads', 'num_of_reads.png'),
    (fig_log1p_reads, axs_log1p_reads, 'log1p_reads', 'log1p reads', 'log1p_reads.png'),
    (fig_contacts, axs_contacts, 'Num of Contacts', 'num of contacts', 'num_of_contacts.png'),
    (fig_log1p_contacts, axs_log1p_contacts, 'log1p_contacts', 'log1p contacts', 'log1p_contacts.png'),
    (fig_clusters, axs_clusters, 'Num of Clusters', 'num of clusters', 'num_of_clusters.png'),
    (fig_log1p_clusters, axs_log1p_clusters, 'log1p_clusters', 'log1p clusters', 'log1p_clusters.png'),
    (fig_cis_pct, axs_cis_pct, 'spot_cis_pct', 'spot cis pct', 'spot_cis_pct.png'),
]
# Single-metric filters (spots with metric > per-sample threshold): (figure, axes grid, metric, suptitle, output file).
filter_panels = [
    (fig_filter_reads, axs_filters_reads, 'log1p_reads', 'filter reads spots', 'filter_reads_spots.png'),
    (fig_filter_contacts, axs_filters_contacts, 'log1p_contacts', 'filter contacts spots', 'filter_contacts_spots.png'),
    (fig_filter_clusters, axs_filters_clusters, 'log1p_clusters', 'filter clusters spots', 'filter_clusters_spots.png'),
    (fig_filter_cis_pct, axs_filters_cis_pct, 'spot_cis_pct', 'filter spot cis pct spots', 'filter_spot_cis_pct_spots.png'),
]
# Metric densities with the threshold marked: (figure, axes grid, metric, suptitle, output file, x limits or None).
density_panels = [
    (fig_den_reads, axs_den_reads, 'log1p_reads', 'log1p reads density', 'log1p_reads_density.png', None),
    (fig_den_contacts, axs_den_contacts, 'log1p_contacts', 'log1p contacts density', 'log1p_contacts_density.png', None),
    (fig_den_clusters, axs_den_clusters, 'log1p_clusters', 'log1p clusters density', 'log1p_clusters_density.png', None),
    (fig_den_spot_cis_pct, axs_den_spot_cis_pct, 'spot_cis_pct', 'spot cis pct density', 'spot_cis_pct_density.png', (0, 100)),
]

# Re-read config.yaml — presumably so that threshold edits are picked up when only this cell is re-run.
with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)

# Per-sample frames are collected and concatenated once (concat inside the loop is quadratic).
sample_frames = []
filtered_frames = []

for idx, sample_name in enumerate(tqdm(samples)):
    spot_cfg = config['spatial_infor'][sample_name]

    # Spatial layout: barcode columns, coordinate offsets, axis flips, and the kept odd/even barcode range.
    x_bc=spot_cfg['x_bc']
    y_bc=spot_cfg['y_bc']
    x_offset=spot_cfg['x_offset']
    y_offset=spot_cfg['y_offset']
    x_r=spot_cfg['x_r']
    y_r=spot_cfg['y_r']
    odd_start=spot_cfg['odd_start']
    odd_end=spot_cfg['odd_end']
    even_start=spot_cfg['even_start']
    even_end=spot_cfg['even_end']

    # Per-sample QC thresholds (spots must be strictly greater).
    min_spot_cis_pct=spot_cfg['min_spot_cis_pct']
    min_log1p_reads=spot_cfg['min_log1p_reads']
    min_log1p_contacts=spot_cfg['min_log1p_contacts']
    min_log1p_clusters=spot_cfg['min_log1p_clusters']
    thresholds = {
        'log1p_reads': min_log1p_reads,
        'log1p_contacts': min_log1p_contacts,
        'log1p_clusters': min_log1p_clusters,
        'spot_cis_pct': min_spot_cis_pct,
    }

    cluster_dir=os.path.join('/home/spaceA/cluster','clusters_'+sample_name+'_single_filtered')

    # The per-spot summary table may be named .txt or .tsv.
    spot_file=os.path.join(cluster_dir,'one_last_time.txt')
    if not os.path.exists(spot_file):
        spot_file=os.path.join(cluster_dir,'one_last_time.tsv')

    df_save_f=sample_name+'_spot_infor.csv'
    if infor_df_replace:
        df=pd.read_table(spot_file)
        df['odd'] = df['Cell Barcode ID'].apply(lambda bc: _barcode_index(bc, 0))
        df['even'] = df['Cell Barcode ID'].apply(lambda bc: _barcode_index(bc, 1))

        # Keep spots inside the configured (inclusive) odd/even barcode window; .copy() so the
        # column assignments below modify an independent frame, not a view of the slice.
        odd_idx = df['odd'].astype(int)
        even_idx = df['even'].astype(int)
        df = df[odd_idx.between(odd_start, odd_end) & even_idx.between(even_start, even_end)].copy()

        df['x']=df[x_bc].apply(lambda v: int(v)-x_offset)
        df['y']=df[y_bc].apply(lambda v: int(v)-y_offset)

        count_cols = {'reads': 'Num of Reads', 'contacts': 'Num of Contacts', 'clusters': 'Num of Clusters'}
        for suffix, count_col in count_cols.items():
            df['log1p_'+suffix]=np.log1p(df[count_col])
        for suffix, count_col in count_cols.items():
            counts = df[count_col]
            # Non-positive counts map to log10(1) = 0 instead of -inf / NaN.
            df['log10_'+suffix]=np.log10(counts.where(counts > 0, 1))

        # Spot cis percentage, one cluster file per spot, computed in parallel.
        with futures.ProcessPoolExecutor(max_workers=60) as pool:
            res=pool.map(partial(get_spot_cis_pct,cluster_dir=cluster_dir),df['Cell Barcode ID'])
            df['spot_cis_pct']=list(res)
        df.to_csv(df_save_f)
    else:
        # index_col=0 restores the index written by to_csv instead of adding an 'Unnamed: 0' column.
        df=pd.read_csv(df_save_f, index_col=0)

    df['sample']=sample_name
    sample_frames.append(df)

    row_num, col_num = divmod(idx, ncols)

    for _, axes_grid, hue, _, _ in hue_panels:
        _plot_spot_hue(axes_grid[row_num, col_num], df, hue, sample_name, x_r, y_r)

    for _, axes_grid, metric, _, _ in filter_panels:
        kept = df[df[metric] > thresholds[metric]]
        _plot_spot_kept(axes_grid[row_num, col_num], kept, sample_name, x_r, y_r)

    for _, axes_grid, metric, _, _, xlim in density_panels:
        density_ax = axes_grid[row_num, col_num]
        _plot_density(density_ax, df, metric, thresholds[metric], sample_name)
        if xlim is not None:
            density_ax.set_xlim(*xlim)

    # Final spot set combines the reads, cis-pct and clusters thresholds.
    # NOTE(review): min_log1p_contacts is plotted above but not applied here — confirm that is intended.
    df_f=df[(df['log1p_reads']>min_log1p_reads) & (df['spot_cis_pct']>min_spot_cis_pct) & (df['log1p_clusters']>min_log1p_clusters)].copy()
    filtered_frames.append(df_f)
    df_f.to_csv(sample_name+'_filtered_spot_infor.csv')
    _plot_spot_kept(axs_filters[row_num, col_num], df_f, sample_name, x_r, y_r)

# Title and save every figure once, after all panels are drawn.
figure_outputs = [(out_fig, title, png) for out_fig, _, _, title, png in hue_panels + filter_panels]
figure_outputs += [(out_fig, title, png) for out_fig, _, _, title, png, _ in density_panels]
figure_outputs.append((fig_filters, 'final spots', 'final_spots.png'))
for out_fig, suptitle, out_png in figure_outputs:
    out_fig.suptitle(suptitle, fontsize=20)
    out_fig.savefig(out_png)

df_all=pd.concat(sample_frames, ignore_index=True) if sample_frames else pd.DataFrame()
df_f_all=pd.concat(filtered_frames, ignore_index=True) if filtered_frames else pd.DataFrame()

df_all_f='all_samples_spot_infor.csv'
df_all.to_csv(df_all_f)

df_f_all_f='all_samples_filtered_spot_infor.csv'
df_f_all.to_csv(df_f_all_f)

In [None]:
# One violin panel per QC metric (top to bottom), comparing the filtered spots across samples.
qc_metrics = ['log10_reads', 'log10_clusters', 'log10_contacts', 'spot_cis_pct']
fig, axs = plt.subplots(nrows=len(qc_metrics), ncols=1, figsize=(18, 18))
for metric_ax, metric in zip(axs, qc_metrics):
    sns.violinplot(df_f_all, x='sample', y=metric, ax=metric_ax)
    metric_ax.tick_params(axis='x', rotation=45)
# The spot_cis_pct panel's y-axis is limited to 0-40.
axs[-1].set_ylim(0, 40)
fig.suptitle('all sample filter spot qc', fontsize=20)
fig.savefig('all_sample_filter_spot_qc.png')