In [None]:
'''Make plots and tables
author: Carsten Knutsen
Date: March 18 2023
conda environment: bulk_rnaseq
'''

In [None]:
import pandas as pd
import os
import scanpy as sc
from anndata import AnnData
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import zscore
from sklearn.decomposition import PCA


In [None]:
# Paths used throughout the notebook. Inputs and generated tables/figures
# currently live in the same folder.
output_dir = 'data/output_data'
input_dir = 'data/output_data'

# Make sure the destination exists before any cell tries to write into it.
os.makedirs(name=output_dir, exist_ok=True)

## Put DEG lists into Excel

In [None]:
# Collect every DEG test table (one CSV per comparison) into a single Excel
# workbook with one sheet per CSV. `deg_excel` is read back by later cells.
deg_excel = f'{output_dir}/deg_tests.xlsx'
csv_dir = f'{input_dir}/deg_tests'

# os.listdir returns every directory entry (e.g. .DS_Store,
# .ipynb_checkpoints); only CSV files can be parsed as DEG tables.
csv_files = sorted(fn for fn in os.listdir(csv_dir) if fn.lower().endswith('.csv'))
if not csv_files:
    raise FileNotFoundError(f'No CSV files found in {csv_dir}')

# Resolve all sheet names before opening the writer so that a naming problem
# is reported without leaving a half-written workbook behind.
sheet_names = {}
for fn in csv_files:
    stem = fn.split('.')[0]
    # Sheet name = file stem without its last '_'-separated token, cut to
    # Excel's 31-character limit. A stem without '_' would give an empty
    # name, so fall back to the stem itself in that case.
    sheet_name = ('_'.join(stem.split('_')[:-1]) or stem)[:31]
    if sheet_name in sheet_names:
        raise ValueError(
            f'Sheet name {sheet_name!r} derived from {fn!r} collides with '
            f'{sheet_names[sheet_name]!r}; rename one of the files.'
        )
    sheet_names[sheet_name] = fn

with pd.ExcelWriter(deg_excel) as writer:
    for sheet_name, fn in sheet_names.items():
        print(fn)
        sheet = pd.read_csv(f'{csv_dir}/{fn}', index_col=0, header=0)
        sheet.to_excel(writer, sheet_name=sheet_name)
    

## PCA

In [None]:
# PCA of the log10-transformed expression table in tmm.csv, drawn as one
# point per sample and saved to pca.png.
PSEUDOCOUNT = 0.001  # added before log10 so that zero values stay finite

tmm = pd.read_csv(f'{input_dir}/tmm.csv', header=0, index_col=0)
# Transpose so that each row is one sample (the rows receive the per-sample
# annotations below); columns are presumably genes -- TODO confirm.
log_df = np.log10(tmm.T + PSEUDOCOUNT)

# Sample annotations are assigned purely by position, so they are only
# correct if tmm.csv lists its 12 samples in exactly this order:
#   Adult/Pseudomonas 1-3, Adult/Saline 1-3,
#   Juvenile/Pseudomonas 1-3, Juvenile/Saline 1-3.
# TODO confirm against the column names of tmm.csv.
ages = ['Adult'] * 6 + ['Juvenile'] * 6
treatments = (['Pseudomonas'] * 3 + ['Saline'] * 3) * 2
replicates = ['1', '2', '3'] * 4
if len(log_df) != len(ages):
    raise ValueError(
        f'tmm.csv has {len(log_df)} samples but {len(ages)} hard-coded '
        'sample annotations are defined; update the annotation lists.'
    )

# With few samples and many features scikit-learn may choose its randomized
# SVD solver; a fixed random_state keeps the coordinates reproducible.
pca = PCA(n_components=2, random_state=0)
principalComponents = pca.fit_transform(log_df)
principalDf = pd.DataFrame(data=principalComponents,
                           index=log_df.index,
                           columns=['PC1', 'PC2'])
principalDf['Age'] = ages
principalDf['Treatment'] = treatments
principalDf['Replicate'] = replicates
principalDf['Sample'] = (principalDf['Age'] + '_'
                         + principalDf['Treatment'] + '_'
                         + principalDf['Replicate'])
principalDf['name'] = principalDf.index

# Format the explained variance numerically; slicing str(value) truncates
# instead of rounding and garbles values such as 100.0 or 5e-05.
pc1_pct, pc2_pct = pca.explained_variance_ratio_ * 100

fig, ax = plt.subplots(1, 1)
sns.scatterplot(data=principalDf,
                x='PC1',
                y='PC2',
                hue='Treatment',
                hue_order=['Saline', 'Pseudomonas'],
                style='Age',
                s=300,
                linewidth=0,
                ax=ax)
ax.set_xlabel(f'PC1 ({pc1_pct:.1f}% variance)')
ax.set_ylabel(f'PC2 ({pc2_pct:.1f}% variance)')
# Place the legend outside the axes, on the explicit Axes object rather than
# whatever pyplot considers the current axes.
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0)
fig.savefig(f'{output_dir}/pca.png', dpi=300, bbox_inches='tight')
plt.close(fig)


## Volcano Plots for each comparison

In [None]:
# Volcano plots: for every comparison sheet in the DEG workbook, draw
# log2 fold change against -log10(FDR) twice -- once labelling the most
# extreme genes ('top') and once labelling a fixed marker list ('markers').
LFC_CUTOFF = 1      # |logFC| must exceed this to count as up/down-regulated
FDR_CUTOFF = 0.05   # FDR must be below this to count as significant
X_COL = '$Log_{2}$(FC)'
Y_COL = '$-Log_{10}$(FDR)'
markers = ['Bmx', 'Gja5', 'Car4', 'Ednrb', 'Glp1r', 'Kit', 'Mmrn1', 'Hpgd', 'Slc6a2', 'Car8']

deg_lists = pd.read_excel(deg_excel, sheet_name=None, index_col=0)
for k in sorted(deg_lists.keys()):
    df = deg_lists[k]
    df[Y_COL] = -np.log10(df['FDR'])
    df[X_COL] = df['logFC']

    # Vectorised classification. NaN comparisons evaluate to False, so rows
    # with missing values fall through to 'NS'.
    significant = df['FDR'] < FDR_CUTOFF
    df['color'] = np.select(
        [significant & (df['logFC'] > LFC_CUTOFF),
         significant & (df['logFC'] < -LFC_CUTOFF)],
        ['Upregulated', 'Downregulated'],
        default='NS',
    )
    print(k)
    print(df.value_counts('color'))

    # Genes to label: the 5 largest and 5 smallest logFC among significant
    # genes plus the first 5 rows of the table as stored (presumably the most
    # significant -- TODO confirm the sheet ordering). The three groups can
    # overlap, so drop repeats while keeping the order.
    df_sub = df.loc[significant].sort_values('logFC', ascending=False)
    genes = list(dict.fromkeys(
        df_sub.head(5).index.tolist()
        + df_sub.tail(5).index.tolist()
        + df.head(5).index.tolist()
    ))

    for label_set, gene_ls in [('top', genes), ('markers', markers)]:
        fig, ax = plt.subplots(1, 1, figsize=(3, 3))
        sns.scatterplot(data=df,
                        y=Y_COL,
                        x=X_COL,
                        hue='color',
                        hue_order=['Upregulated', 'NS', 'Downregulated'],
                        palette=['green', 'grey', 'red'],
                        s=10,
                        linewidth=0,
                        legend=False,
                        ax=ax
                        )
        for gene in gene_ls:
            # A marker gene may be absent from a given comparison table;
            # skip it instead of aborting the whole loop with a KeyError.
            if gene not in df.index:
                continue
            ax.text(df.loc[gene, X_COL], df.loc[gene, Y_COL], gene, size=6)
        ax.set_title(f'{k}')
        fig.savefig(f'{output_dir}/volcano_{k}_{label_set}.png', dpi=300, bbox_inches="tight")
        # Close this specific figure so figures do not pile up across the loop.
        plt.close(fig)
