# Imputation and DGE analysis

In [None]:
# Import packages
import os, sys, glob, re, math, pickle
import phate, scprep, magic
import graphtools as gt
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import time,random,datetime
import networkx as nx
from sklearn import metrics
from sklearn import model_selection
from scipy import sparse
from scipy.stats import mannwhitneyu, tiecorrect, rankdata
from statsmodels.stats.multitest import multipletests
import scanpy as sc
from sklearn.dummy import DummyClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.cluster import SpectralClustering, OPTICS, cluster_optics_dbscan, AgglomerativeClustering
from bbknn import bbknn
import warnings
%matplotlib inline
%load_ext memory_profiler
import logging

# settings
plt.rc('font', size = 8)
# 'sans-serif' is matplotlib's generic family name; 'sans serif' (with a space) is looked
# up as a literal font name, is not found, and falls back with a findfont warning.
plt.rc('font', family='sans-serif')
# Type 42 (TrueType) fonts keep text editable in exported PDF/PS files
plt.rcParams['pdf.fonttype']=42
plt.rcParams['ps.fonttype']=42
plt.rcParams['text.usetex']=False
plt.rcParams['legend.frameon']=False
plt.rcParams['axes.grid']=False
plt.rcParams['legend.markerscale']=0.5
sc.set_figure_params(dpi=300,dpi_save=600,
                     frameon=False,
                     fontsize=8)
plt.rcParams['savefig.dpi']=600
sc.settings.verbosity=2
# Set n_jobs through the public settings object; assigning to the private
# ScanpyConfig class replaces the class attribute instead of setting the value.
sc.settings.n_jobs = -1

# reproducibility
# np.random.seed() returns None, so keep the seed value in its own name for reuse
# (e.g. as random_state) instead of storing the return value of seed().
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)
rs = RANDOM_SEED

# utils
def mwu(X,Y,gene_names,correction=None,debug=False) :
    '''
    Per-gene two-sided Mann-Whitney U test between X and Y with multiple-testing correction.

    Parameters
    ----------
    X, Y : 2-D array-like or scipy sparse matrix, shape (n_cells, n_genes)
        Column ``i`` holds the values of ``gene_names[i]``. For a single gene
        pass ``x.reshape(-1, 1)`` and ``y.reshape(-1, 1)``.
    gene_names : sequence of str
        One name per column of X and Y.
    correction : str or None, default None
        Method passed to ``statsmodels.stats.multitest.multipletests``
        (e.g. ``'bonferroni'``). ``None`` means Benjamini-Hochberg (``'fdr_bh'``).
    debug : bool, default False
        Print the genes whose p-value cannot be computed.

    Returns
    -------
    pandas.DataFrame
        Columns 'Gene', 'pval', 'pval_corrected'; one row per gene with a
        computable p-value, indexed by the gene's column position.

    Notes
    -----
    A gene whose pooled values are all tied (``tiecorrect == 0``) has no defined
    test: it gets no p-value and is excluded before correction, so it does not
    count towards the number of tests. Extremely small p-values can underflow to 0.
    '''
    method = 'fdr_bh' if correction is None else correction
    if method == 'fdr_bh':
        print('Mann-Whitney U w/Benjamini/Hochberg correction\n')
    else:
        print('Mann-Whitney U w/{} correction\n'.format(method))

    n_genes = len(gene_names)
    # Column positions at which progress is reported, computed once rather than per gene.
    # setdefault keeps the first percentage when positions coincide (small n_genes).
    milestones = {}
    if n_genes:
        positions = np.arange(n_genes)
        for pct in (25, 50, 75):
            milestones.setdefault(int(np.round(np.quantile(positions, pct / 100))), pct)

    def _column(M, i):
        # 1-D dense copy of column i for ndarray, np.matrix or scipy sparse input
        col = M[:, i]
        if sparse.issparse(col):
            col = col.toarray()
        return np.asarray(col).ravel()

    genes, pvals = [], []
    start = time.time()
    for i, g in enumerate(gene_names):
        if i in milestones:
            print('... {}% completed in {:.2f}-s'.format(milestones[i], time.time()-start))
        x = _column(X, i)
        y = _column(Y, i)
        genes.append(g)
        if tiecorrect(rankdata(np.concatenate((x, y)))) == 0:
            if debug:
                print('P-value not calculable for {}'.format(g))
            pvals.append(np.nan)
        else:
            # explicit two-sided test with continuity correction
            pvals.append(mannwhitneyu(x, y, use_continuity=True, alternative='two-sided').pvalue)
    print('... mwu computed in {:.2f}-s\n'.format(time.time() - start))

    p = pd.DataFrame({'Gene': genes, 'pval': np.asarray(pvals, dtype=float)})
    # Drop untestable genes so they do not change the number of tests being corrected
    p_corrected = p.loc[p['pval'].notna(), :].copy()
    if p_corrected.empty:
        p_corrected['pval_corrected'] = np.array([], dtype=float)
    else:
        p_corrected['pval_corrected'] = multipletests(p_corrected['pval'], method=method)[1]
    return p_corrected

def log2aveFC(X,Y,gene_names,AnnData=None) :
    '''Per-gene log2 fold change of mean expression: log2(mean(X)) - log2(mean(Y)).

    Parameters
    ----------
    X, Y : 2-D array-like or scipy sparse matrix, shape (n_cells, n_genes)
        Expression matrices with genes as columns. With `AnnData`, the columns
        must follow ``AnnData.var_names``; otherwise they must follow `gene_names`.
    gene_names : list of str
        Genes to report. With `AnnData`, only the var_names contained in
        `gene_names` are reported, in var_names order.
    AnnData : object with a ``var_names`` index, optional
        Used to subset X and Y to `gene_names`.

    Returns
    -------
    pandas.DataFrame with columns 'Gene' and 'log2FC'.

    Notes
    -----
    This is a difference of logs, so a zero mean gives +/-inf and a negative
    mean gives NaN (numpy emits a RuntimeWarning).
    '''
    def _log2_colmeans(M):
        # np.asarray(...).ravel() flattens the (1, n) np.matrix that sparse/matrix .mean returns
        return np.log2(np.asarray(M.mean(axis=0)).ravel())

    if AnnData is not None:
        wanted = set(gene_names)  # O(1) membership per var_name
        g_idx = [i for i, g in enumerate(AnnData.var_names) if g in wanted]
        return pd.DataFrame({'Gene': AnnData.var_names[g_idx],
                             'log2FC': _log2_colmeans(X[:, g_idx]) - _log2_colmeans(Y[:, g_idx])})
    return pd.DataFrame({'Gene': gene_names,
                         'log2FC': _log2_colmeans(X) - _log2_colmeans(Y)})


# fps: project root, results directory (figures/tables) and processed-data directory
dfp, pfp, pdfp = (
    '/vast/palmer/pi/lim_janghoo/cl2292/',
    '/vast/palmer/pi/lim_janghoo/cl2292/SCA1_OL/results/',
    '/vast/palmer/pi/lim_janghoo/cl2292/SCA1_OL/data/',
)
# scanpy writes its figures into the results directory
sc.settings.figdir = pfp

In [None]:
#Imputation by genotype & timepoint
# Each genotype is imputed separately: PCA -> neighbour graph -> MAGIC on that graph,
# with the result stored in layers['imputed'].
# NOTE(review): an earlier note read "k=45, t=3". In this code 45 is the number of PCs
# passed to sc.pp.neighbors (its neighbour count is scanpy's default) and MAGIC's t is
# left at the library default -- confirm these are the intended parameters.


def _magic_impute(ad, label, prefix='', n_pcs=45):
    '''Impute `ad` in place with MAGIC and store the result in ad.layers['imputed'].

    Runs sc.pp.pca and sc.pp.neighbors(n_pcs=n_pcs) on `ad`, builds a graphtools graph
    from the neighbour connectivities plus self-loops, fits MAGIC on it and transforms
    all genes. `label` names the group in the log; `prefix` is printed before the start
    message. Elapsed time is printed at the end.
    '''
    print('{}Starting imputation for {}\n'.format(prefix, label))
    tic = time.time()

    ad.obs['value'] = 0
    sc.pp.pca(ad)
    sc.pp.neighbors(ad, n_pcs=n_pcs)

    # MAGIC on the scanpy connectivity graph with self-loops added on the diagonal
    G = gt.Graph(data=ad.obsp['connectivities']+sparse.diags([1]*ad.shape[0],format='csr'),
                 precomputed='adjacency',
                 use_pygsp=True)
    G.knn_max = None

    magic_op = magic.MAGIC().fit(X=ad.X, graph=G) # running fit_transform produces wrong shape
    ad.layers['imputed'] = magic_op.transform(ad.X, genes='all_genes')

    print('\n  imputation in {:.2f}-min'.format((time.time() - tic)/60))


# Explicit copies: the per-genotype PCA, neighbour graph and layers are written to
# independent objects rather than to views of `adata`.
wt30 = adata[(adata.obs['genotype']=='SCA1-fl/+'), :].copy()
mut30 = adata[(adata.obs['genotype']=='SCA1-fl/NG2-Cre'), :].copy()

_magic_impute(wt30, '30wk SCA1-fl/+')
_magic_impute(mut30, '30wk SCA1-fl/NG2-Cre', prefix='\n ')



In [None]:
# Merge the two separately imputed genotype objects back into one AnnData;
# batch_key='concat' records which object each cell came from.
# NOTE(review): AnnData.concatenate is deprecated in recent anndata releases in favour
# of anndata.concat -- check before upgrading.
adata = wt30.concatenate([mut30], batch_key='concat')

# save data objects
out_h5ad = os.path.join(pdfp, '250414_OL-SCA1-cKI_imp.h5ad')
adata.write(out_h5ad)
stamp = datetime.datetime.now().strftime('%y%m%d.%H:%M:%S')
print('saved @' + stamp)

In [None]:
## EMD; IMP
# Mutant (SCA1-fl/NG2-Cre) vs WT (SCA1-fl/+) differential expression on the imputed layer,
# per cell type and timepoint: EMD (scprep), Mann-Whitney U with corrected p-values (mwu)
# and log2 fold change of means (log2aveFC), merged on 'Gene'.
wt = adata[(adata.obs['genotype']=='SCA1-fl/+'), :]
mut = adata[(adata.obs['genotype']=='SCA1-fl/NG2-Cre'), :]

dge_grandtotal = time.time()
group = 'ctype'
fname = 'SCA1-flwCre vs SCA1-flwoCre'
dge = pd.DataFrame()
for t in ['30wk']:
    print('Evaluating {}'.format(t))
    start_t = time.time()

    print('\n--------')
    print('...')
    print('--------\n')
    for i in wt.obs[group].unique():
        start = time.time()
        print('\n{}, SCA1 vs WT'.format(i))
        print('----')
        X = np.asarray(wt[(wt.obs[group]==i)&(wt.obs['timepoint']==t), :].layers['imputed'])
        Y = np.asarray(mut[(mut.obs[group]==i)&(mut.obs['timepoint']==t), :].layers['imputed'])

        print('    Ncells in X:{}'.format(X.shape[0]))
        print('    Ncells in Y:{}\n'.format(Y.shape[0]))
        if X.shape[0] == 0 or Y.shape[0] == 0:
            # No comparison is possible against an empty group
            warnings.warn('Skipping {} at {}: no cells in one of the groups.'.format(i, t))
            continue

        # The mutant matrix (Y) is passed first to every statistic
        emd = scprep.stats.differential_expression(Y,X,
                                                   measure = 'emd',
                                                   direction='both',
                                                   gene_names=wt.var_names,
                                                   n_jobs=-1)
        emd['Gene'] = emd.index
        emd = emd.drop(columns='rank')

        # mann-whitney u, corrected p-values
        p = mwu(Y,X, wt.var_names)
        fc = log2aveFC(Y,X,wt.var_names.to_list())

        # Genes without a computable p-value are absent from `p`; the left merges below
        # keep only p's genes, so just report how many were dropped.
        n_dropped = int((~fc['Gene'].isin(p['Gene'])).sum())
        if n_dropped:
            warnings.warn('Warning: {} genes dropped due to p-val NA.'.format(n_dropped))
        dt = pd.merge(p, fc, how='left', on="Gene")
        dt = pd.merge(dt, emd, how='left', on='Gene')
        dt['Cell type'] = i
        dt['timepoint'] = str(t)
        # pval_corrected can underflow to 0, which yields +inf here
        dt['nlog10pvalcorrected'] = -np.log10(dt['pval_corrected'])
        dge = pd.concat([dge, dt], ignore_index=True)
        print('... computed in {:.2f}-s'.format(time.time()-start))
    print('\nFinished timepoint {} in {:.2f}-min'.format(t,(time.time()-start_t)/60))
dge.to_csv(os.path.join(pfp,'250414_dge_'+fname+'.csv'),index=False)

print('DGE finished in {:.2f}-min'.format((time.time()-dge_grandtotal)/60))


In [None]:
## DEG numbers
# Count significant down-/up-regulated genes (adjusted p < 0.01 and EMD beyond +/-0.1)
# per cell type and timepoint, and write the table to the results directory.

dge = pd.read_csv('/vast/palmer/pi/lim_janghoo/cl2292/SCA1_OL/results/250414_dge_SCA1-flwCre vs SCA1-flwoCre.csv')

passes_p = dge['pval_corrected'] < 0.01
downsig = dge.loc[passes_p & (dge['emd'] < -0.1), :]
upsig = dge.loc[passes_p & (dge['emd'] > 0.1), :]

count_rows = []
for t in ['30wk']:
    for c in adata.obs['ctype'].unique():
        in_down = (downsig['Cell type'] == c) & (downsig['timepoint'] == t)
        in_up = (upsig['Cell type'] == c) & (upsig['timepoint'] == t)
        count_rows.append(pd.DataFrame({'Cell type': [str(c)],
                                        'timepoint': [str(t)],
                                        'downregulated': [int(in_down.sum())],
                                        'upregulated': [int(in_up.sum())]}))
num_dge = pd.concat(count_rows, ignore_index=True) if count_rows else pd.DataFrame()
num_dge.to_csv(os.path.join(pfp,'250414_number of DEG.csv'),index=False)

In [None]:
import os, sys, glob, re, math, pickle
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from mpl_toolkits.axes_grid1 import make_axes_locatable
import csv
import gseapy
pfp = '/vast/palmer/pi/lim_janghoo/cl2292/SCA1_OL/results/'


# settings
plt.rc('font', size = 9)
# 'sans-serif' is matplotlib's generic family name; 'sans serif' (with a space) is looked
# up as a literal font name, is not found, and falls back with a findfont warning.
plt.rc('font', family='sans-serif')
# Type 42 (TrueType) fonts keep text editable in exported PDF/PS files
plt.rcParams['pdf.fonttype']=42
plt.rcParams['ps.fonttype']=42
plt.rcParams['text.usetex']=False
plt.rcParams['legend.frameon']=False
plt.rcParams['axes.grid']=False
plt.rcParams['legend.markerscale']=0.5
plt.rcParams['savefig.dpi']=600
sns.set_style("ticks")


def filterdeg(df, ctype, timepoint=None):
    '''Return the 'Gene' values of the rows of `df` belonging to cell type `ctype`.

    When `timepoint` is given, rows must also have that 'timepoint' value.

    Raises
    ------
    ValueError
        If 'Cell type' or 'Gene' is missing, or if `timepoint` is given and the
        'timepoint' column is missing.
    '''
    has_core_columns = all(col in df.columns for col in ('Cell type', 'Gene'))
    if not has_core_columns:
        raise ValueError("DataFrame must contain 'Cell type' and 'Gene' columns.")

    keep = df['Cell type'] == ctype
    if timepoint is not None:
        if 'timepoint' not in df.columns:
            raise ValueError("DataFrame must contain 'timepoint' column for filtering by timepoint.")
        keep = keep & (df['timepoint'] == timepoint)

    return df.loc[keep, 'Gene'].to_list()

def enrichr(genes, title = 'Title',geneset = 'GO_Biological_Process_2023', save = None):
    '''Run Enrichr on `genes` and plot the top terms as a horizontal bar chart.

    The 10 terms with adjusted p < 0.05 and the highest Combined Score are drawn;
    bar length is the Combined Score and bar colour is -log10(adjusted p).

    Parameters
    ----------
    genes : list of str
        Gene symbols passed to ``gseapy.enrichr`` (organism 'Mouse').
    title : str
        Axes title.
    geneset : str
        Enrichr library name, e.g. 'GO_Biological_Process_2023'. It is also
        used (without a trailing _YYYY and with spaces) as the y-axis label.
    save : str or None
        If given, the figure is also saved as ``os.path.join(pfp, save)``.

    Prints a message and draws nothing when there is nothing to plot.
    '''
    import re

    if len(genes) == 0:
        print("No genes supplied; skipping enrichment.")
        return

    res = gseapy.enrichr(gene_list=list(genes), organism = 'Mouse', gene_sets = geneset, cutoff=0.05)
    significant = res.res2d.loc[res.res2d['Adjusted P-value'] < 0.05]
    # .copy(): new columns are added below, so work on an independent frame
    df = significant.sort_values(by=['Combined Score'], ascending = False).head(10).copy()
    if df.empty:
        print("No significant GO terms to plot.")
        return
    df['-log10(Adjusted P-value)'] = -np.log10(df['Adjusted P-value'])
    # Remove the trailing ' (GO:xxxxxxx)' identifier; other term names are left intact
    df['Term'] = df['Term'].str.replace(r'\s*\(GO:\d+\)$', '', regex=True)

    terms = df['Term'].to_list()
    combined_scores = df['Combined Score'].to_list()
    nlog10_padj = df['-log10(Adjusted P-value)'].to_list()

    fig, ax = plt.subplots(1,1, figsize = (2,2))
    bars = ax.barh(terms, combined_scores, color='gray')

    # Colour each bar by its -log10(adjusted p)
    sm = plt.cm.ScalarMappable(cmap='viridis', norm=plt.Normalize(vmin=min(nlog10_padj), vmax=max(nlog10_padj)))
    for bar, value in zip(bars, nlog10_padj):
        bar.set_color(sm.to_rgba(value))

    # Colorbar in an axis appended to the right of the main plot
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    cbar = plt.colorbar(sm, cax=cax)
    cbar.set_label('-log10(Adjusted P-value)')

    ax.set_xlabel('Combined Score')
    ax.set_ylabel(re.sub(r'_\d{4}$', '', geneset).replace('_', ' '))
    ax.set_title(title)
    # Style `ax` directly: after append_axes the pyplot "current axes" is the colorbar
    # axis, so plt.xticks() would not reach the bar chart's tick labels.
    ax.tick_params(axis='x', labelsize=8)

    if save is not None:
        fig.savefig(os.path.join(pfp,save),dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)

def intersection(lst1, lst2):
    '''Return the items of lst1 that also occur in lst2, keeping lst1's order and duplicates.'''
    try:
        lookup = set(lst2)  # O(1) membership instead of scanning lst2 for every item
    except TypeError:
        # unhashable items (e.g. lists): fall back to linear membership tests
        lookup = lst2
    return [value for value in lst1 if value in lookup]


# Significant cKI DEGs (adjusted p < 0.01 and EMD beyond +/-0.1), split by direction
dge = pd.read_csv('/vast/palmer/pi/lim_janghoo/cl2292/SCA1_OL/results/250414_dge_SCA1-flwCre vs SCA1-flwoCre.csv')

significant = dge['pval_corrected'] < 0.01
downsig = dge[significant & (dge['emd'] < -0.1)]
upsig = dge[significant & (dge['emd'] > 0.1)]


In [None]:
#sort by -log(adj_p)
def enrichr2(genes, title = 'Title',geneset = 'GO_Biological_Process_2023', save = None):
    '''Run Enrichr on `genes` and plot the most significant terms as a horizontal bar chart.

    The 10 terms with adjusted p < 0.05 and the largest -log10(adjusted p) are drawn;
    bar length is -log10(adjusted p) and bar colour is the Combined Score.

    Parameters
    ----------
    genes : list of str
        Gene symbols passed to ``gseapy.enrichr`` (organism 'Mouse').
    title : str
        Axes title.
    geneset : str
        Enrichr library name, e.g. 'GO_Cellular_Component_2025'. It is also
        used (without a trailing _YYYY and with spaces) as the y-axis label.
    save : str or None
        If given, the figure is also saved as ``os.path.join(pfp, save)``.

    Prints a message and draws nothing when there is nothing to plot.
    '''
    import re

    if len(genes) == 0:
        print("No genes supplied; skipping enrichment.")
        return

    res = gseapy.enrichr(gene_list=list(genes), organism = 'Mouse', gene_sets = geneset, cutoff=0.05)
    # .copy(): a column is added to this filtered frame
    df = res.res2d.loc[res.res2d['Adjusted P-value'] < 0.05].copy()
    df['-log10(Adjusted P-value)'] = -np.log10(df['Adjusted P-value'])
    df = df.sort_values(by=['-log10(Adjusted P-value)'], ascending = False).head(10).copy()
    if df.empty:
        print("No significant GO terms to plot.")
        return
    # Remove the trailing ' (GO:xxxxxxx)' identifier; other term names are left intact
    df['Term'] = df['Term'].str.replace(r'\s*\(GO:\d+\)$', '', regex=True)

    terms = df['Term'].to_list()
    nlog10_padj = df['-log10(Adjusted P-value)'].to_list()
    combined_scores = df['Combined Score'].to_list()

    fig, ax = plt.subplots(1,1, figsize = (2,2))
    bars = ax.barh(terms, nlog10_padj, color='gray')

    # Colour each bar by its Combined Score
    sm = plt.cm.ScalarMappable(cmap='viridis', norm=plt.Normalize(vmin=min(combined_scores), vmax=max(combined_scores)))
    for bar, value in zip(bars, combined_scores):
        bar.set_color(sm.to_rgba(value))

    # Colorbar in an axis appended to the right of the main plot
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    cbar = plt.colorbar(sm, cax=cax)
    cbar.set_label('Combined Score')

    ax.set_xlabel('-log10(Adjusted P-value)')
    # Label with the library actually queried (this function is also used for GO CC)
    ax.set_ylabel(re.sub(r'_\d{4}$', '', geneset).replace('_', ' '))
    ax.set_title(title)
    # Style `ax` directly: after append_axes the pyplot "current axes" is the colorbar
    # axis, so plt.xticks() would not reach the bar chart's tick labels.
    ax.tick_params(axis='x', labelsize=8)

    if save is not None:
        fig.savefig(os.path.join(pfp,save),dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)

In [None]:
## PC: GO Cellular Component enrichment of up-regulated cKI DEGs, bars ranked by adjusted p
geneset = 'GO_Cellular_Component_2025'

c_list = ['PC']
for c in c_list:
    for t in ['30wk']:
        cKIdown = set(filterdeg(downsig, ctype=c, timepoint=t))
        cKIup = set(filterdeg(upsig, ctype=c, timepoint=t))

        # Only the up-regulated set is plotted here; the down-regulated plot is not generated.
        panel = f'{c}, {t}'
        enrichr2(list(cKIup),
                 title=f'{panel} up cKI',
                 geneset=geneset,
                 save=f'250415_enrichR_{c}_{t}_cKI up_CC_sortbyp.pdf')

    

In [None]:
## OPC / OL: GO Biological Process enrichment of down- and up-regulated cKI DEGs
geneset = 'GO_Biological_Process_2025'

c_list = ['OPC','OL']
for c in c_list:
    for t in ['30wk']:
        cKIdown = set(filterdeg(downsig, ctype=c, timepoint=t))
        cKIup = set(filterdeg(upsig, ctype=c, timepoint=t))

        # One enrichment figure per direction, down first
        for direction, gene_set in (('down', cKIdown), ('up', cKIup)):
            enrichr(list(gene_set),
                    title=f'{c}, {t} {direction} cKI',
                    geneset=geneset,
                    save=f'250415_enrichR_{c}_{t}_cKI {direction}_BP.pdf')

    

In [None]:
## Biological Process sorted by p-value
# Diverging bar plot of enriched terms: down-regulated terms get negative x values and
# up-regulated terms positive ones. The top terms of each direction are chosen by
# significance (sort_by='p_value', top 10) or by Combined Score (sort_by='combined_score', top 5).

geneset = 'GO_Biological_Process_2025'
sort_by = 'p_value'
save_name = None  # e.g. '250415_cKI_OL_enrichr_30wk_pval.pdf' to also save the figure under pfp

c_list = ['OL']

_NLOGP = '-log10(Adjusted P-value)'
if sort_by == 'combined_score':
    x_col, color_col, n_top = 'Combined Score', _NLOGP, 5
elif sort_by == 'p_value':
    x_col, color_col, n_top = _NLOGP, 'Combined Score', 10
else:
    raise ValueError("sort_by must be 'combined_score' or 'p_value'")


def _significant_terms(genes, library):
    '''Enrichr terms with adjusted p < 0.05 for `genes`, plus a -log10(adj. p) column.

    Returns None when `genes` is empty.
    '''
    if not genes:
        return None
    res = gseapy.enrichr(gene_list=list(genes), organism='mouse', gene_sets=library, cutoff=0.05)
    terms = res.res2d.loc[res.res2d['Adjusted P-value'] < 0.05].copy()
    terms[_NLOGP] = -np.log10(terms['Adjusted P-value'].astype(float))
    return terms


def _top_terms(terms, column, n, negate):
    '''The `n` terms with the largest `column`; the column is negated afterwards if `negate`.'''
    top = terms.sort_values(by=column, ascending=False).head(n).copy()
    if negate:
        # Select on the positive value first, then flip. Sorting an already-negated
        # column in descending order would keep the LEAST significant down terms.
        top[column] = -top[column]
    return top


for t in ['30wk']:
    all_results = []
    for c in c_list:
        down = set(filterdeg(downsig, ctype=c, timepoint=t))
        up = set(filterdeg(upsig, ctype=c, timepoint=t))

        parts = []
        for genes, negate in ((down, True), (up, False)):
            terms = _significant_terms(genes, geneset)
            if terms is None or terms.empty:
                continue
            top = _top_terms(terms, x_col, n_top, negate)
            top['Cell type'] = c
            top['Timepoint'] = t
            parts.append(top)
        if parts:
            all_results.append(pd.concat(parts).sort_values(by=x_col, ascending=False))

    if not all_results:
        print('No significant GO terms to plot for {}.'.format(t))
        continue

    df_combined = pd.concat(all_results)
    # Remove the trailing ' (GO:xxxxxxx)' identifier from term names
    df_combined['Term'] = df_combined['Term'].str.replace(r'\s*\(GO:\d+\)$', '', regex=True)
    GO_biological_processes = df_combined['Term'].to_list()
    x_data = df_combined[x_col].to_list()
    color_data = df_combined[color_col].to_list()

    fig, ax = plt.subplots(1, 1, figsize=(2,2))
    # Numeric y positions so a term present in both directions keeps two separate bars
    # (string categories would stack both bars on one row).
    y_pos = np.arange(len(GO_biological_processes))
    bars = ax.barh(y_pos, x_data, color='gray')
    ax.set_yticks(y_pos)
    ax.set_yticklabels(GO_biological_processes)

    # Colour each bar by the secondary statistic
    sm = plt.cm.ScalarMappable(cmap='viridis', norm=plt.Normalize(vmin=min(color_data), vmax=max(color_data)))
    for bar, score in zip(bars, color_data):
        bar.set_color(sm.to_rgba(score))

    # Colorbar in an axis appended to the right of the main plot
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    cbar = plt.colorbar(sm, cax=cax)
    cbar.set_label(color_col, color = 'black')
    cax.tick_params(colors='black')

    ax.set_xlabel(x_col, color = 'black')
    ax.set_ylabel(geneset, color = 'black')
    ax.set_title(t, color = 'black')
    ax.grid(False)
    # Style `ax` directly: after append_axes the pyplot "current axes" is the colorbar
    # axis, so plt.xticks()/plt.yticks() would not reach the bar chart.
    ax.tick_params(axis='x', colors='black', labelsize=8)
    ax.tick_params(axis='y', colors='black')
    ax.axvline(x=0, color = 'black', linestyle='-', linewidth=0.5)

    if save_name is not None:
        fig.savefig(os.path.join(pfp, save_name), dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)
 

In [None]:
import os, sys, glob, re, math, pickle
import pandas as pd
import time,random,datetime
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib_venn
import scanpy as sc
import warnings
import csv
%matplotlib inline
%load_ext memory_profiler

# settings
plt.rc('font', size = 8)
# 'sans-serif' is matplotlib's generic family name; 'sans serif' (with a space) is looked
# up as a literal font name, is not found, and falls back with a findfont warning.
plt.rc('font', family='sans-serif')
# Type 42 (TrueType) fonts keep text editable in exported PDF/PS files
plt.rcParams['pdf.fonttype']=42
plt.rcParams['ps.fonttype']=42
plt.rcParams['text.usetex']=False
plt.rcParams['legend.frameon']=False
plt.rcParams['axes.grid']=False
plt.rcParams['legend.markerscale']=0.5
sc.set_figure_params(dpi=300,dpi_save=600,
                     frameon=False,
                     fontsize=8)
plt.rcParams['savefig.dpi']=600
sc.settings.verbosity=2
# Set n_jobs through the public settings object; assigning to the private
# ScanpyConfig class replaces the class attribute instead of setting the value.
sc.settings.n_jobs = -1

# reproducibility
# np.random.seed() returns None, so keep the seed value in its own name for reuse.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)
rs = RANDOM_SEED

# fps
dfp = '/vast/palmer/pi/lim_janghoo/cl2292/'
pfp = '/vast/palmer/pi/lim_janghoo/cl2292/SCA1_OL/results/'
pdfp = '/vast/palmer/pi/lim_janghoo/cl2292/SCA1_OL/data/'
sc.settings.figdir = pfp

# load DEGs
# Shared DEG cutoffs. Strict inequalities match the DEG-count table and the enrichment
# inputs built from the same cKI results elsewhere in this notebook, so the Venn
# counts agree with those numbers.
EMD_CUTOFF = 0.1
PADJ_CUTOFF = 0.01


def _split_sig(table):
    '''Return (down, up): rows of `table` with pval_corrected < PADJ_CUTOFF and emd
    below -EMD_CUTOFF / above +EMD_CUTOFF respectively.'''
    passes_p = table['pval_corrected'] < PADJ_CUTOFF
    down_rows = table.loc[passes_p & (table['emd'] < -EMD_CUTOFF), :]
    up_rows = table.loc[passes_p & (table['emd'] > EMD_CUTOFF), :]
    return down_rows, up_rows


#OL-SCA1-cKI 30W
dge = pd.read_csv(os.path.join(pfp, '250414_dge_SCA1-flwCre vs SCA1-flwoCre.csv'))
downsig_cKI, upsig_cKI = _split_sig(dge)

#SCA1-KI 5-30W
KI = pd.read_csv(os.path.join(pfp, '250414_dge_KI_imp.csv'))
downsig_KI, upsig_KI = _split_sig(KI)

def filterdeg(df, ctype, timepoint=None):
    '''List the genes of `df` for cell type `ctype` (optionally restricted to `timepoint`).

    Raises ValueError when 'Cell type'/'Gene' are missing, or when `timepoint` is
    requested but `df` has no 'timepoint' column.
    '''
    missing_core = {'Cell type', 'Gene'} - set(df.columns)
    if missing_core:
        raise ValueError("DataFrame must contain 'Cell type' and 'Gene' columns.")

    selected = df[df['Cell type'] == ctype]
    if timepoint is None:
        return selected['Gene'].tolist()

    if 'timepoint' not in df.columns:
        raise ValueError("DataFrame must contain 'timepoint' column for filtering by timepoint.")
    return selected.loc[selected['timepoint'] == timepoint, 'Gene'].tolist()

# Cell types with cKI DEGs in either direction, in order of first appearance, so the
# DOWN and UP figures have the same rows.
c_list = list(dict.fromkeys(downsig_cKI['Cell type'].tolist() + upsig_cKI['Cell type'].tolist()))
nrow = len(c_list)


def _venn_figure(cki_sig, ki_sig, cell_types, direction, fname):
    '''One Venn diagram per cell type comparing cKI and KI DEG sets; saved to pfp/fname.

    `direction` ('DOWN' or 'UP') is used in the set labels and panel titles.
    Returns the figure, or None when `cell_types` is empty.
    '''
    if not cell_types:
        print('No cell types with cKI DEGs; skipping {}.'.format(fname))
        return None
    n = len(cell_types)
    # squeeze=False keeps a 2-D axes array even for a single cell type
    fig, axs = plt.subplots(n, 1, figsize=(2, n*2), squeeze=False)
    for ax, c in zip(axs[:, 0], cell_types):
        cki = set(filterdeg(cki_sig, ctype=c))
        ki = set(filterdeg(ki_sig, ctype=c))
        if cki or ki:
            matplotlib_venn.venn2([cki, ki], (f'OL-SCA1-cKI 30W {direction}', f'SCA1-KI 5-30W {direction}'),
                                  set_colors=('#FF9999','#9999FF'), alpha = 0.8, ax=ax)
            matplotlib_venn.venn2_circles([cki, ki], linewidth=0.6, ax=ax)
        else:
            # nothing to draw for two empty sets
            ax.set_axis_off()
        ax.set_title(f"{c} {direction}")
    fig.savefig(os.path.join(pfp, fname), dpi=300, bbox_inches='tight')
    return fig


# Assigned (not left as the cell's last expression) to avoid a duplicate notebook display
fig_down = _venn_figure(downsig_cKI, downsig_KI, c_list, 'DOWN', '250415_venn_overlappaing DEG_down.pdf')
fig_up = _venn_figure(upsig_cKI, upsig_KI, c_list, 'UP', '250415_venn_overlappaing DEG_up.pdf')