In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns
from IPython.display import display
from scipy import stats

import utils.DE_plotting_tools as plot_utils


In [None]:
import importlib

# Re-execute utils/DE_plotting_tools so edits to the helper module take effect without a kernel
# restart. importlib.reload() returns the module object, which can differ from the one passed in,
# so rebind the name to its return value (this also stops the cell echoing the module repr).
plot_utils = importlib.reload(plot_utils)

#### Reading DE data 

In [None]:
data_path = "../DE_out/Blobel-15620/"

# Instantiating files to read from. os.listdir() order is filesystem-dependent, so sort it:
# the dict insertion order below drives every later concat / loop over the dict keys.
DE_files = sorted(file for file in os.listdir(data_path) if file.startswith('DE'))
# Separating files with allgene or Olfr only DE
DE_allgene_files = [file for file in DE_files if 'allgene' in file]
DE_Olfr_files = [file for file in DE_files if 'Olfr' in file]


def _read_de_table(path, reset_index=True):
    """Read one DE result CSV, using its first column as the index.

    reset_index=True discards that index and replaces it with a fresh RangeIndex (how the
    Blobel-15620 tables are loaded); reset_index=False keeps the index read from the CSV
    (how the manually added tables are loaded).
    """
    de_table = pd.read_csv(path, index_col=0)
    return de_table.reset_index(drop=True) if reset_index else de_table


# Dict keys are the file names without the '.csv' extension
DE_allgene_df_dict = {file.replace('.csv', ''): _read_de_table(os.path.join(data_path, file))
                      for file in DE_allgene_files}
DE_Olfr_df_dict = {file.replace('.csv', ''): _read_de_table(os.path.join(data_path, file))
                   for file in DE_Olfr_files}

# EXCEPTION manually add in WT vs KO data from other runs. n6 is around wk8-10.
# suffix -> (allgene CSV, Olfr CSV); stored as 'DE_allgene_WTvsKO_<suffix>' / 'DE_Olfr_WTvsKO_<suffix>'
extra_DE_files = {
    'wk8': ('../DE_out/Blobel-15045/DE_allgene_WTvsKO_15045.csv',
            '../DE_out/Blobel-15045/DE_Olfr_WTvsKO_15045.csv'),
    'wk10': ('../DE_out/Blobel-15045/DE_allgene_WTvsKO_14025.csv',
             '../DE_out/Blobel-15045/DE_Olfr_WTvsKO_14025.csv'),
    'n6': ('../DE_out/Blobel-15045/DE_allgene_WTvsKO_n6.csv',
           '../DE_out/Blobel-15045/DE_Olfr_WTvsKO_n6.csv'),
    'n12': ('../DE_out/WTvKO_ALL/DE_allgene_WTvsKO_ALL.csv',
            '../DE_out/WTvKO_ALL/DE_Olfr_WTvsKO_ALL.csv'),
}
for suffix, (allgene_path, olfr_path) in extra_DE_files.items():
    # These tables keep the index read from the CSV (no reset_index)
    DE_allgene_df_dict[f'DE_allgene_WTvsKO_{suffix}'] = _read_de_table(allgene_path, reset_index=False)
    DE_Olfr_df_dict[f'DE_Olfr_WTvsKO_{suffix}'] = _read_de_table(olfr_path, reset_index=False)

print(DE_allgene_df_dict.keys())
print(DE_Olfr_df_dict.keys())

#### Volcano comparison plots 

##### misc comparison

In [None]:
# Olfrs reaching FDR < 0.05 at wk5 or at wk30; only those are drawn in both volcano panels
wk5_olfr = DE_Olfr_df_dict['DE_Olfr_WTvsKO_wk5']
wk30_olfr = DE_Olfr_df_dict['DE_Olfr_WTvsKO_wk30']
diff_OR = (set(wk5_olfr.loc[wk5_olfr.FDR < 0.05, 'symbol'].values)
           | set(wk30_olfr.loc[wk30_olfr.FDR < 0.05, 'symbol'].values))

fig = plot_utils.compare_vol_plot(
    [wk5_olfr[wk5_olfr.symbol.isin(diff_OR)],
     wk30_olfr[wk30_olfr.symbol.isin(diff_OR)]],
    DE_df_name=['DE_Olfr_WTvsKO_wk5', 'DE_Olfr_WTvsKO_wk30'],
    fig_dimension=[800, 500],
)
fig.show()
# fig.write_html('../output/Blobel_15620/.html')

In [None]:
# Unfiltered Olfr-only DE volcanoes for wk5 and wk30 (fig_dimension left at its default)
wk5_wk30_keys = ['DE_Olfr_WTvsKO_wk5', 'DE_Olfr_WTvsKO_wk30']
fig = plot_utils.compare_vol_plot([DE_Olfr_df_dict[key] for key in wk5_wk30_keys],
                                  DE_df_name=list(wk5_wk30_keys))
fig.show()

In [None]:
"""
allgene DE, subset for Olfr. 
"""

plot_df = []
temp = DE_allgene_df_dict['DE_allgene_WTvsKO_wk5'].copy()
temp = temp[temp.symbol.str.contains('Olfr', na=False)]
plot_df.append(temp)

temp = DE_allgene_df_dict['DE_allgene_WTvsKO_wk30'].copy()
temp = temp[temp.symbol.str.contains('Olfr', na=False)]
plot_df.append(temp)

# Adding in n6 wtvsko from previous experiment
temp = DE_allgene_df_dict['DE_allgene_WTvsKO_wk8'].copy()
temp = temp[temp.symbol.str.contains('Olfr', na=False)]
plot_df.append(temp)
temp = DE_allgene_df_dict['DE_allgene_WTvsKO_wk10'].copy()
temp = temp[temp.symbol.str.contains('Olfr', na=False)]
plot_df.append(temp)
temp = DE_allgene_df_dict['DE_allgene_WTvsKO_n6'].copy()
temp = temp[temp.symbol.str.contains('Olfr', na=False)]
plot_df.append(temp)

fig = plot_utils.compare_vol_plot(plot_df, 
                                   DE_df_name = ['DE_allgene_subsetOlfr_WTvsKO_wk5', 
                                                 'DE_allgene_subsetOlfr_WTvsKO_wk30',
                                                 'DE_allgene_subsetOlfr_WTvsKO_wk8',
                                                 'DE_allgene_subsetOlfr_WTvsKO_wk10',
                                                 'DE_allgene_subsetOlfr_WTvsKO_n6'], 
                                #    fig_dimension = [800,500],
                                   fig_title = 'iR2 WT/KO allgene DE, subset for Olfr',
                                   FDR_line = 0.1)

fig.show()
# fig.write_html("../output/Blobel_15620/Rhbdf2_Olfr/WTvsKO_allgene_subsetOlfr.html")



##### n12 WT vs KO For Fig

In [None]:
"""
allgene DE, subset for Olfr. 
"""

plot_df = []
temp = DE_allgene_df_dict['DE_allgene_WTvsKO_wk5'].copy()
plot_df.append(temp)

temp = DE_allgene_df_dict['DE_allgene_WTvsKO_wk30'].copy()
plot_df.append(temp)

# Adding in n6 wtvsko from previous experiment
temp = DE_allgene_df_dict['DE_allgene_WTvsKO_wk8'].copy()
temp = temp[temp.symbol.str.contains('Olfr', na=False)]
plot_df.append(temp)
temp = DE_allgene_df_dict['DE_allgene_WTvsKO_wk10'].copy()
temp = temp[temp.symbol.str.contains('Olfr', na=False)]
plot_df.append(temp)
temp = DE_allgene_df_dict['DE_allgene_WTvsKO_n6'].copy()
temp = temp[temp.symbol.str.contains('Olfr', na=False)]
plot_df.append(temp)

fig = plot_utils.compare_vol_plot(plot_df, 
                                   DE_df_name = ['DE_allgene_subsetOlfr_WTvsKO_wk5', 
                                                 'DE_allgene_subsetOlfr_WTvsKO_wk30',
                                                 'DE_allgene_subsetOlfr_WTvsKO_wk8',
                                                 'DE_allgene_subsetOlfr_WTvsKO_wk10',
                                                 'DE_allgene_subsetOlfr_WTvsKO_n6'], 
                                #    fig_dimension = [800,500],
                                   fig_title = 'iR2 WT/KO allgene DE',
                                   FDR_line = 0.1)

fig.show()
# fig.write_html("../output/Blobel_15620/Rhbdf2_Olfr/WTvsKO_allgene.html")



In [None]:
"""
Vol Plot for Fig 

all Olfr Vol
"""

# DE Olfr
fig = plot_utils.vol_plot(DE_Olfr_df_dict['DE_Olfr_WTvsKO_n12'],
                          logFC_group = ['WT', 'KO'], 
                          manual_color = {'WT': '#19b2e6', 'KO': '#ee6082', 'na': 'lightgrey'}, 
                          FDR_cutoff = 0.05, 
                          FDR_line=None, 
                          fig_fixed_range = True,
                          opacity = 0.5
                          )
fig.update_layout(
    title='Rhbdf2 DE Olfr n12',
    xaxis_title='logFC (KO/WT)',
    yaxis_title='FDR',
    autosize=True,
    template='simple_white')
fig.show()
# fig.write_html(f'../output/fig_image/volcano//WTvsKO_Olfr_FDRp05.html')


"""
Vol Plot for Fig 

allgenes Vol
- filtered Gm, pseudo genes 
- label Olfr 
"""
FDR_cutoff = 0.2

# DE allgene
plot_df = DE_allgene_df_dict['DE_allgene_WTvsKO_n12'].copy()
plot_df = plot_df.dropna(subset=['symbol'])
plot_df = plot_df[(~plot_df.symbol.str.startswith('Gm')) & 
                  (~plot_df.symbol.str.contains('-ps')) &
                  (~plot_df.symbol.str.contains('Olfr'))]
base_fig = plot_utils.vol_plot(plot_df,
                          logFC_group = ['WT', 'KO'], 
                        #   manual_color = {'WT': '#a8d7cb', 'KO': '#d5a6bd', 'na': 'lightgrey'}, 
                          manual_color = {'WT': '#caedf9', 'KO': '#f7b8c8', 'na': 'lightgrey'}, 
                          FDR_cutoff = FDR_cutoff, 
                          FDR_line=None, 
                          fig_fixed_range = True, 
                          opacity = 0.5, 
                          ymin=-0.5
                          )
# DE allgene Olfr genes 
plot_df = DE_allgene_df_dict['DE_allgene_WTvsKO_n12'].copy()
plot_df = plot_df[(plot_df.symbol.str.startswith('Olfr', na=False))]
olfr_fig = plot_utils.vol_plot(plot_df,
                          logFC_group = ['WT_Olfr', 'KO_Olfr'], 
                        #   manual_color = {'WT_Olfr': '#3b806e', 'KO_Olfr': '#803b4d'}, 
                          manual_color = {'WT_Olfr': '#19b2e6', 'KO_Olfr': '#ee6082'}, 
                          FDR_cutoff = FDR_cutoff, 
                          plot_none_sig=False, 
                          opacity = 0.8, 
                          )

# Combine olfr data points into figure 
fig = go.Figure(base_fig.data + olfr_fig.data)
fig.layout = base_fig.layout

# Redefine xmax and xmin 
fig.update_xaxes(range=[-5, 5])
fig.update_layout(
    title='Rhbdf2 DE allgene n12',
    xaxis_title='logFC (KO/WT)',
    yaxis_title='FDR',
    autosize=True,
    template='simple_white')
fig.show()
# fig.write_html(f'../output/fig_image/volcano/WTvsKO_allgene_FDRp05.html')
# fig.write_html(f'../output/fig_image/volcano/WTvsKO_allgene_FDRp05_Filtered-Gm-ps.html')
# fig.write_html(f'../output/fig_image/volcano/WTvsKO_allgene_FDRp2_Filtered-Gm-ps.html')


##### n12 activity gene bars

In [None]:
# Volcano of the n12 all-gene DE, passing the activity genes below as interested_genes
# (FDR_cutoff / FDR_line left at plot_utils defaults; plot_utils.downsample_fig was tried and disabled)
highlight_genes = ['Rhbdf1', 'Rhbdf2', 'S100a5', 'Dlg2', 'Lrrc3b', 'Pcp4l1', 'Kirrel2']
fig = plot_utils.vol_plot(DE_allgene_df_dict['DE_allgene_WTvsKO_n12'],
                          interested_genes=highlight_genes,
                          fig_fixed_range=True,
                          opacity=0.5,
                          ymin=-3)
fig.update_layout(**{
    'title': 'DE allgene activity genes n12',
    'xaxis_title': 'logFC (KO/WT)',
    'yaxis_title': 'FDR',
    'autosize': True,
    'template': 'simple_white',
})
fig.show()
# fig.write_html(f'../output/fig_image/volcano//WTvsKO_allgene_activitygenes.html')


In [None]:
# Inspect the n12 (WTvKO_ALL) all-gene DE table used for the activity-gene bars in the next cell
DE_allgene_df_dict['DE_allgene_WTvsKO_n12']

In [None]:
interested_genes = ['Rhbdf1', 'Rhbdf2', 'S100a5', 'Dlg2', 'Lrrc3b', 'Pcp4l1', 'Kirrel2']

plot_df = DE_allgene_df_dict['DE_allgene_WTvsKO_n12'].copy()
plot_df = plot_df[plot_df.symbol.isin(interested_genes)]


def _round_sig(value, sig=2):
    """Round `value` to `sig` significant figures for display.

    round(p, 1) turned every p-value / FDR below 0.05 into 0.0; significant figures keep the
    magnitude of small values (e.g. 0.00347 -> 0.0035).
    """
    return float(f'{value:.{sig}g}')


# One bar per gene (logFC), labelled with its FDR and p-value
fig = go.Figure()
for _gene in plot_df.symbol.unique():
    subset = plot_df[plot_df.symbol == _gene]
    # Relabel only the first 'p' of the formatted value as 'FDR'
    fdr_label = plot_utils.format_pvalue(_round_sig(subset.FDR.values[0])).replace("p", "FDR", 1)
    pvalue_label = plot_utils.format_pvalue(_round_sig(subset.PValue.values[0]))
    fig.add_trace(
        go.Bar(x=subset['symbol'],
               y=subset['logFC'],
               text=f'<b>{_gene}</b><br>{fdr_label}<br>{pvalue_label}',
               name=_gene,
               textposition='outside',
               opacity=0.8,
               showlegend=False)
    )

# NOTE(review): fixed y-range assumes every highlighted gene has logFC in [-2, 0]; bars outside are clipped
fig.update_yaxes(range=[-2, 0])
fig.update_traces(textfont_size=20)

fig.update_layout(
    xaxis=dict(visible=False),
    yaxis_title='logFC (KO/WT)',
    autosize=True,
    template='simple_white')
fig.show()
# fig.write_html('../output/fig_image/volcano//WTvsKO_allgene_activitygenes_bar.html')
fig.write_html('../output/fig_image/volcano//WTvsKO_allgene_activitygenes_bar_pvalue.html')

#### Olfr change across age

In [None]:
# Long-format Olfr DE: one row per Olfr per DE table, tagged with its dict key in 'group'
DE_Olfr_df = pd.concat(DE_Olfr_df_dict, names=['group']).reset_index(level='group')

##### Olfr pairwise logFC correlation

In [None]:
from scipy.stats import pearsonr

# Pairwise scatter of per-Olfr logFC between two time points, using only Olfrs with
# FDR < FDR_cutoff at either time point, plus a least-squares line and Pearson r / p-value.
merge_columns = ['symbol', 'logFC', 'FDR']
compare_list = [['DE_Olfr_WTvsKO_wk5', 'DE_Olfr_WTvsKO_wk8'],
                ['DE_Olfr_WTvsKO_wk5', 'DE_Olfr_WTvsKO_wk10'],
                ['DE_Olfr_WTvsKO_wk5', 'DE_Olfr_WTvsKO_wk30'],
                ['DE_Olfr_WTvsKO_wk8', 'DE_Olfr_WTvsKO_wk10'],
                ['DE_Olfr_WTvsKO_wk8', 'DE_Olfr_WTvsKO_wk30'],
                ['DE_Olfr_WTvsKO_wk10', 'DE_Olfr_WTvsKO_wk30']]
FDR_cutoff = 0.1

for _A, _B in compare_list:
    # Time-point label is the 4th '_' field of the key, e.g. 'wk5'
    _A_wk = _A.split('_')[3]
    _B_wk = _B.split('_')[3]
    # Separate name so the pooled DE_Olfr_df built in an earlier cell is not overwritten
    pair_df = pd.merge(DE_Olfr_df_dict[_A][merge_columns],
                       DE_Olfr_df_dict[_B][merge_columns], on=['symbol'])
    # Significant in A or in B ("significant in both" is already covered by either)
    plot_df = pair_df[(pair_df['FDR_x'] < FDR_cutoff) | (pair_df['FDR_y'] < FDR_cutoff)]
    x_list = plot_df['logFC_x']
    y_list = plot_df['logFC_y']
    if len(plot_df) < 2:
        # pearsonr and the fit line need at least two points
        print(f'{_A_wk} vs {_B_wk}: fewer than 2 Olfrs with FDR < {FDR_cutoff}; skipped')
        continue

    fig = go.Figure()
    # Reference lines at y=0 and x=0
    fig.add_shape(x0=-7, x1=7, y0=0, y1=0, type='line', opacity=0.1, line=dict(color='grey', width=3))
    fig.add_shape(x0=0, x1=0, y0=-7, y1=7, type='line', opacity=0.1, line=dict(color='grey', width=3))

    fig.add_traces(go.Scatter(x=x_list,
                              y=y_list,
                              name=f'{_A}<br>{_B}',
                              mode='markers',
                              text=plot_df['symbol'],
                              marker=dict(size=10, opacity=0.3)))

    # Calculate Pearson Correlation
    r, r_p = pearsonr(x_list, y_list)
    # Least-squares line y = slope*x + intercept over the observed x span; pairing min(x) with
    # r*min(y) and max(x) with r*max(y) does not describe the relationship in the data.
    slope, intercept = np.polyfit(x_list, y_list, 1)
    x_fit = np.array([x_list.min(), x_list.max()])
    fig.add_trace(go.Scatter(x=x_fit,
                             y=slope * x_fit + intercept,
                             mode='lines',
                             line=dict(dash='dot',
                                       width=5,
                                       color='rgba(0, 0, 0, 0.5)'),
                             # 3 significant figures so small p-values are not displayed as 0
                             name=f'Pearson correlation: {round(r, 3)} <br>Pearson p-value: {r_p:.3g}'))

    fig.update_traces(
        textposition='top center',
        hovertemplate=
        '<b>%{text}</b>' +
        '<br>LogFC_x: %{x}' +
        '<br>LogFC_y: %{y}<br>')

    fig.update_layout(
        title=f'WT/KO Olfr de logFC between {_A_wk} / {_B_wk}',
        xaxis=dict(title=f'logFC {_A_wk}'),
        yaxis=dict(title=f'logFC {_B_wk}'),
        height=700,
        width=700,
        template='simple_white'
    )
    fig.show()

    # fig.write_html(f'../output/Blobel_15620/Rhbdf2_Olfr/corr_scatter/WTvsKO_Olfr_FDRp1_{_A_wk}_{_B_wk}.html')

In [None]:
# Heatmap of pairwise Pearson correlations of Olfr logFC between time points. For each pair,
# Olfrs are matched on symbol and only those with FDR < FDR_cutoff at either time point are used.
merge_columns = ['symbol', 'logFC', 'FDR']
FDR_cutoff = 0.1
desired_order = ['wk5', 'wk8', 'wk10', 'wk30']

# Map each displayed time-point label (4th '_' field of keys like 'DE_Olfr_WTvsKO_wk5') to its key
olfr_week_to_key = {}
for key in DE_Olfr_df_dict:
    if 'n6' in key:
        continue
    week = key.split('_')[3]
    if week not in desired_order:
        continue
    assert week not in olfr_week_to_key, f'More than one Olfr DE table for {week}: {olfr_week_to_key[week]}, {key}'
    olfr_week_to_key[week] = key

# Only the upper triangle is computed; Pearson r is symmetric so it is mirrored.
# Missing time points or pairs with < 2 significant Olfrs stay NaN.
result_df = pd.DataFrame(np.nan, index=desired_order, columns=desired_order)
for i, week_a in enumerate(desired_order):
    for week_b in desired_order[i:]:
        if week_a not in olfr_week_to_key or week_b not in olfr_week_to_key:
            continue
        pair_df = pd.merge(DE_Olfr_df_dict[olfr_week_to_key[week_a]][merge_columns],
                           DE_Olfr_df_dict[olfr_week_to_key[week_b]][merge_columns], on=['symbol'])
        sig_df = pair_df[(pair_df['FDR_x'] < FDR_cutoff) | (pair_df['FDR_y'] < FDR_cutoff)]
        if len(sig_df) >= 2:
            r, _r_p = stats.pearsonr(sig_df['logFC_x'], sig_df['logFC_y'])
        else:
            r = np.nan
        result_df.loc[week_a, week_b] = r
        result_df.loc[week_b, week_a] = r

# Create the heatmap
heatmap_fig, heatmap_ax = plt.subplots(figsize=(10, 8))
sns.heatmap(result_df, annot=True, cmap='viridis', fmt=".3f", ax=heatmap_ax)
heatmap_ax.set_title("Week Pairwise Correlation Heatmap")
heatmap_ax.set_xlabel("")
heatmap_ax.set_ylabel("")
# heatmap_fig.savefig('../output/Blobel_15620/Rhbdf2_Olfr/corr_scatter/WTvsKO_Olfr_FDRp1_correlation_heatmap.png')
plt.show()

##### Olfr logFC across age 

In [None]:
# Long-format Olfr DE: one row per Olfr per DE table, tagged with the table key in 'group'
DE_Olfr_df = pd.concat(DE_Olfr_df_dict).reset_index(level=0).rename(columns={'level_0': 'group'})
# Keys look like 'DE_Olfr_WTvsKO_<group>'; keep only the <group> part (4th '_' field)
DE_Olfr_df['group'] = DE_Olfr_df['group'].str.split('_').str[3]
group_categories = ['wk5', 'wk8', 'wk10', 'wk30', 'n6', 'n12']
# A label outside group_categories would become NaN in the Categorical below, and the NaN-column
# drop would then silently delete the whole 'group' column; fail loudly instead.
unexpected_groups = set(DE_Olfr_df['group']) - set(group_categories)
assert not unexpected_groups, f'Unrecognised DE groups: {sorted(map(str, unexpected_groups))}'
DE_Olfr_df['group'] = pd.Categorical(DE_Olfr_df['group'],
                                     categories=group_categories, ordered=True)
# Drop columns containing any NaN (e.g. columns present in only some DE tables), but never the
# 'symbol' identifier (same rule as the all-gene frame)
DE_Olfr_df = DE_Olfr_df.drop(columns=[col for col in DE_Olfr_df.columns
                                      if col not in ['symbol'] and DE_Olfr_df[col].isna().any()])

In [None]:
"""
Plot for Figs 
"""
plot_df = DE_Olfr_df[~DE_Olfr_df.group.isin(['n6'])]

for _group in plot_df.sort_values('group').group.unique(): 
        fig = plot_utils.plot_olfr_lines(plot_df, 
                        FDR_group=[_group], 
                        FDR_cutoff=0.05, 
                        std_shade=False, 
                        exclude_group = ['n12'],
                        labels={'KO+_Olfr': 'greater_than', 
                                'KO-_Olfr': 'lesser_than'}, 
                        manual_color={'KO+_Olfr': '#EF5350', 
                                      'KO-_Olfr': '#19b2e6'},
                        plot_title = 'KO/WT Olfr de logFC across age')
        fig.update_yaxes(range=[-10, 10])
        fig.update_annotations(y=10)
        fig.update_layout(yaxis=dict(title='logFC (KO/WT)'), 
                          showlegend=False)
        fig.show()
        
        # fig.write_html(f'../output/Blobel_15620/Rhbdf2_Olfr/olfr_change_wks/WTvsKO_Olfr_FDRp05_{_group}.html')
        fig.write_html(f'../output/fig_image/olfr_change_wks/WTvsKO_Olfr_FDRp05_{_group}.html')

In [None]:
# Single subplot for all timepoints


from plotly.subplots import make_subplots

# One row per time point (n6 and n12 excluded), stacking the per-group Olfr line plots
plot_df = DE_Olfr_df[~DE_Olfr_df.group.isin(['n6', 'n12'])]
groups = plot_df.sort_values('group').group.unique()

num_rows = len(groups)
num_cols = 1
fig_combined = make_subplots(rows=num_rows, cols=num_cols, subplot_titles=groups)

combined_annotations = []
subplot_title = {}
for i, _group in enumerate(groups, start=1):
    fig = plot_utils.plot_olfr_lines(plot_df,
                                     FDR_group=[_group],
                                     FDR_cutoff=0.05,
                                     std_shade=False,
                                     labels={'KO+_Olfr': 'greater_than',
                                             'KO-_Olfr': 'lesser_than'},
                                     manual_color={'KO+_Olfr': '#EF5350',
                                                   'KO-_Olfr': '#19b2e6'},
                                     plot_title=f'KO/WT Olfr de logFC across age - Group {_group}')

    subplot_title[_group] = fig.layout.title.text

    # Re-target this figure's annotations onto subplot i (axes x{i} / y{i})
    combined_annotations.extend([dict(xref=f'x{i}', yref=f'y{i}', text=ann.text,
                                      x=ann.x, y=ann.y,
                                      showarrow=ann.showarrow, arrowhead=ann.arrowhead,
                                      ax=ann.ax, ay=ann.ay) for ann in fig.layout.annotations])

    for trace in fig.data:
        trace.update(showlegend=False)
        fig_combined.add_trace(trace, row=i, col=1)

# Replace each subplot's bare group-label title with that group's full plot title
fig_combined.for_each_annotation(lambda plot: plot.update(text=subplot_title[plot.text]))

# Pin the transferred annotations to the highest annotation y so they line up across rows.
# Guard against no annotations (max() of an empty sequence raises) and annotations without a y.
annotation_ys = [ann['y'] for ann in combined_annotations if ann['y'] is not None]
max_y_annotation = max(annotation_ys) if annotation_ys else None
for ann in combined_annotations:
    if max_y_annotation is not None:
        ann['y'] = max_y_annotation
    fig_combined.add_annotation(**ann)

# Shared y-range from all trace values (always including 0). NaN / None gaps are ignored;
# Python's max()/min() give order-dependent results when a NaN is present.
ymin, ymax = 0, 0
for trace in fig_combined.data:
    if trace.y is None or len(trace.y) == 0:
        continue
    y_values = np.asarray(trace.y, dtype=float)
    if np.isnan(y_values).all():
        continue
    ymax = max(ymax, float(np.nanmax(y_values)))
    ymin = min(ymin, float(np.nanmin(y_values)))

fig_combined.update_yaxes(title_text="logFC (KO/WT)",
                          range=[ymin * 1.2, ymax * 1.2])
fig_combined.update_layout(
                        #    height=1200, width=800,
                           margin=dict(l=50, r=50, b=0, t=100, pad=10),
                           template='simple_white')
fig_combined.show()
# fig_combined.write_html('../output/fig_image/olfr_change_wks/WTvsKO_Olfr_FDRp05.html')

In [None]:
# # Deprecated: choosing Olfrs/genes by significance across multiple days biases the comparison.
# # Kept for reference only; this cell executes nothing.
# plot_df = DE_Olfr_df[~DE_Olfr_df.group.isin(['n6', 'n12'])]
# compare_groups = [['wk5', 'wk8'], ['wk8', 'wk10'], ['wk10', 'wk30']]
# # compare_groups = [['wk5'], ['wk8']]
# for _group in compare_groups: 
#         fig = plot_utils.plot_olfr_lines(plot_df, 
#                                 FDR_group=_group, 
#                                 FDR_cutoff=0.2, 
#                                 labels={'KO+_Olfr': 'greater_than', 
#                                         'KO-_Olfr': 'lesser_than'}, 
#                                 plot_title = 'WT/KO Olfr de logFC across age')
#         fig.update_yaxes(range=[-10, 10])
#         fig.update_annotations(y=10)
#         fig.update_layout(yaxis=dict(title='logFC (WT/KO)'))
#         fig.show()
#         # fig.write_html(f'../output/fig_image/olfr_change_wks/WTvsKO_Olfr_FDRp05_{_group}.html')

##### allgenes logFC across age

In [None]:
# Long-format all-gene DE: one row per gene per DE table, tagged with its dict key in 'group'
DE_allgene_df = pd.concat(DE_allgene_df_dict, names=['group']).reset_index(level='group')

In [None]:
# Heatmap of pairwise Pearson correlations of all-gene logFC between time points. For each pair,
# genes are matched on ensembl_gene_id and only genes with FDR < FDR_cutoff at either time point are used.
merge_columns = ['ensembl_gene_id', 'logFC', 'FDR']
FDR_cutoff = 0.05
desired_order = ['wk5', 'wk8', 'wk10', 'wk30']

# Map each displayed time-point label (4th '_' field of keys like 'DE_allgene_WTvsKO_wk5') to its key
allgene_week_to_key = {}
for key in DE_allgene_df_dict:
    if 'n6' in key:
        continue
    week = key.split('_')[3]
    if week not in desired_order:
        continue
    assert week not in allgene_week_to_key, f'More than one allgene DE table for {week}: {allgene_week_to_key[week]}, {key}'
    allgene_week_to_key[week] = key

# Only the upper triangle is computed; Pearson r is symmetric so it is mirrored. The merged frame
# gets its own name so the pooled DE_allgene_df from the previous cell is not overwritten.
# Missing time points or pairs with < 2 significant genes stay NaN.
result_df = pd.DataFrame(np.nan, index=desired_order, columns=desired_order)
for i, week_a in enumerate(desired_order):
    for week_b in desired_order[i:]:
        if week_a not in allgene_week_to_key or week_b not in allgene_week_to_key:
            continue
        pair_df = pd.merge(DE_allgene_df_dict[allgene_week_to_key[week_a]][merge_columns],
                           DE_allgene_df_dict[allgene_week_to_key[week_b]][merge_columns],
                           on=['ensembl_gene_id'])
        sig_df = pair_df[(pair_df['FDR_x'] < FDR_cutoff) | (pair_df['FDR_y'] < FDR_cutoff)]
        if len(sig_df) >= 2:
            r, _r_p = stats.pearsonr(sig_df['logFC_x'], sig_df['logFC_y'])
        else:
            r = np.nan
        result_df.loc[week_a, week_b] = r
        result_df.loc[week_b, week_a] = r

# Create the heatmap
heatmap_fig, heatmap_ax = plt.subplots(figsize=(10, 8))
sns.heatmap(result_df, annot=True, cmap='viridis', fmt=".3f", ax=heatmap_ax)
heatmap_ax.set_title("allgene Week Pairwise Correlation Heatmap")
heatmap_ax.set_xlabel("")
heatmap_ax.set_ylabel("")
# heatmap_fig.savefig('../output/Blobel_15620/Rhbdf2_allgene/corr_scatter/WTvsKO_Olfr_FDRp05_correlation_heatmap.png')
plt.show()

In [None]:
# Long-format all-gene DE: one row per gene per DE table, tagged with the table key in 'group'
DE_allgene_df = pd.concat(DE_allgene_df_dict).reset_index(level=0).rename(columns={'level_0': 'group'})
# Keys look like 'DE_allgene_WTvsKO_<group>'; keep only the <group> part (4th '_' field)
DE_allgene_df['group'] = DE_allgene_df['group'].str.split('_').str[3]
group_categories = ['wk5', 'wk8', 'wk10', 'wk30', 'n6', 'n12']
# A label outside group_categories would become NaN in the Categorical below, and the NaN-column
# drop would then silently delete the whole 'group' column; fail loudly instead.
unexpected_groups = set(DE_allgene_df['group']) - set(group_categories)
assert not unexpected_groups, f'Unrecognised DE groups: {sorted(map(str, unexpected_groups))}'
DE_allgene_df['group'] = pd.Categorical(DE_allgene_df['group'],
                                        categories=group_categories, ordered=True)
# Drops columns with nan, but exclude 'symbol' column (some genes have no symbol)
DE_allgene_df = DE_allgene_df.drop(columns=[col for col in DE_allgene_df.columns
                                            if col not in ['symbol'] and DE_allgene_df[col].isna().any()])

In [None]:

# One all-gene logFC-across-age figure per time point (n6 and n12 excluded). max_genes=10 is
# passed to plot_utils.plot_olfr_lines (presumably caps the genes drawn; confirm in plot_utils).
plot_df = DE_allgene_df[~DE_allgene_df.group.isin(['n6', 'n12'])].dropna()

for _group in plot_df.sort_values('group').group.unique():
    print(_group)
    fig = plot_utils.plot_olfr_lines(plot_df,
                                     FDR_group=[_group],
                                     FDR_cutoff=0.05,
                                     std_shade=False,
                                     max_genes=10,
                                     manual_color={'WT': '#bfef45', 'KO': '#fabed4'},
                                     labels={'WT': 'lesser_than',
                                             'KO': 'greater_than'},
                                     # All-gene (not Olfr) results, with logFC as KO/WT like the y-axis
                                     plot_title='KO/WT allgene de logFC across age')
    fig.update_yaxes(range=[-10, 10])
    fig.update_annotations(y=10)
    fig.update_layout(yaxis=dict(title='logFC (KO/WT)'))
    fig.show()

    # fig.write_html(f'../output/Blobel_15620/Rhbdf2_allgene/allgene_change_wks/WTvsKO_allgene_FDRp05_{_group}.html')
    # fig.write_html(f'../output/fig_image/olfr_change_wks/WTvsKO_Olfr_FDRp05_{_group}.html')

In [None]:
# Single subplot for all timepoints


from plotly.subplots import make_subplots

# All-gene counterpart of the combined Olfr figure: one row per time point (n6/n12 excluded)
plot_df = DE_allgene_df[~DE_allgene_df.group.isin(['n6', 'n12'])].dropna()
ordered_groups = plot_df.sort_values('group').group.unique()

num_rows, num_cols = len(ordered_groups), 1
fig_combined = make_subplots(rows=num_rows, cols=num_cols, subplot_titles=ordered_groups)

combined_annotations = []
subplot_title = {}
for row, _group in enumerate(ordered_groups, start=1):
    fig = plot_utils.plot_olfr_lines(
        plot_df,
        FDR_group=[_group],
        FDR_cutoff=0.05,
        std_shade=False,
        max_genes=50,
        manual_color={'WT': '#bfef45', 'KO': '#fabed4'},
        labels={'WT': 'lesser_than',
                'KO': 'greater_than'},
        plot_title=f'KO/WT allgene de logFC across age - Group {_group}',
    )
    subplot_title[_group] = fig.layout.title.text

    # Move this figure's annotations onto row `row` of the combined figure (axes x{row} / y{row})
    for ann in fig.layout.annotations:
        combined_annotations.append(dict(xref=f'x{row}', yref=f'y{row}', text=ann.text,
                                         x=ann.x, y=ann.y,
                                         showarrow=ann.showarrow, arrowhead=ann.arrowhead,
                                         ax=ann.ax, ay=ann.ay))

    for trace in fig.data:
        trace.update(showlegend=False)
        fig_combined.add_trace(trace, row=row, col=1)


def _retitle_subplot(annotation):
    """Swap a subplot's bare group-label title for that group's full plot title."""
    annotation.update(text=subplot_title[annotation.text])


fig_combined.for_each_annotation(_retitle_subplot)

for ann in combined_annotations:
    fig_combined.add_annotation(**ann)

# (height=1200, width=800 was tried and disabled)
fig_combined.update_layout(template='simple_white')
fig_combined.update_yaxes(title_text="logFC (KO/WT)")
fig_combined.show()

# fig_combined.write_html('../output/fig_image/olfr_change_wks/WTvsKO_allgene_FDRp05.html')

In [None]:
# logFC of selected activity genes across time points (n6 excluded); large markers flag FDR < 0.05
interested_genes = ['S100a5', 'Rhbdf2', 'Rhbdf1', 'Kirrel2', 'Lrrc3b', 'Dlg2']
manual_color = plot_utils.distinct_colors(interested_genes, category='pastel')

plot_df = DE_allgene_df[DE_allgene_df.group != 'n6'].dropna()
# Chronological x-axis order from the ordered 'group' categorical (unused categories removed).
# By default plotly orders category ticks by first appearance across traces, which misorders
# the axis whenever the first gene plotted lacks a time point.
x_order = list(plot_df.group.cat.remove_unused_categories().cat.categories)

fig = go.Figure()
for _gene in interested_genes:
    subset_df = plot_df[plot_df['symbol'] == _gene].sort_values('group')
    fig.add_traces(go.Scatter(x=subset_df.group,
                              y=subset_df.logFC,
                              name=_gene,
                              mode='lines+markers',
                              hovertext=_gene,
                              opacity=0.8,
                              marker=dict(size=[20 if _fdr < 0.05 else 5 for _fdr in subset_df.FDR],
                                          color=manual_color[_gene]),
                              ))

fig.update_xaxes(categoryorder='array', categoryarray=x_order)
fig.update_layout(
    title='Activity genes across weeks',
    yaxis=dict(title='logFC (KO/WT)'),
    template='simple_white')
# fig.write_html(f'../output/Blobel_15620/Rhbdf2_allgene/allgene_change_wks/WTvsKO_activitygenes.html')
fig.show()

#### Nostril Olfr (Santoro) and GEP genes (Tsukahara)

In [None]:
# Santoro Olfr list: every fold_diff label other than 'Close+_Olfr' (missing values
# included) is collapsed into a single 'Close-_Olfr' class.
nostril_Olfr = pd.read_csv('../files/Santoro_2020/Occlu_diff_Olfr.csv', index_col=0)
nostril_Olfr['fold_diff'] = nostril_Olfr['fold_diff'].where(
    nostril_Olfr['fold_diff'] == 'Close+_Olfr', 'Close-_Olfr')

# iR2 Olfr groups (iR2_group column) and Tsukahara 2021 GEP gene groups (GEP_group column)
iR2_Olfr = pd.read_csv('../files/iR2_Olfr.csv', index_col=0)
GEP_genes = pd.read_csv('../files/Tsukahara_2021/GEP_genes.csv')

##### Volcano

In [None]:
"""
Olfr DE only Volcano plot
"""

cell_markers = pd.read_csv('../files/CELL_top_markers.csv', index_col = 0)[0:50]
plot_df = DE_allgene_df_dict['DE_allgene_WTvsKO_n12']
fig = go.Figure()
# plot blank all genes excluding markers 
background_df = plot_df[~plot_df.symbol.isin(GEP_genes.gene)]
fig.add_trace(go.Scatter(x=background_df['logFC'], 
                            y=-np.log10(background_df['FDR']),
                            # text=background_df['symbol'],
                            mode='markers', 
                            name = 'allgenes', 
                            marker=dict(size = 10, 
                                        color = 'lightgrey', 
                                        opacity=0.2)))
manual_color = {'GEP_low': '#4363d8','GEP_high':'#ffd8b1'}
for _group in GEP_genes.GEP_group.unique(): 
    temp = plot_df[plot_df.symbol.isin(GEP_genes[GEP_genes['GEP_group'] == _group].gene)]
    fig.add_trace(go.Scatter(x=temp['logFC'], 
                            y=-np.log10(temp['FDR']),
                            text=temp['symbol'],
                            mode='markers', 
                            name=_group,
                            marker=dict(size = 10, 
                                        color = manual_color[_group],
                                        opacity=0.5)))

# Add the horizontal line at y=0.5
fig.add_shape(type='line', x0=-8, x1=8,
              y0=-np.log10(0.05), y1=-np.log10(0.05),
              line=dict(color='grey', width=3, dash='dash'))

fig.update_traces( 
    textposition='top center',
    hovertemplate = '<b>%{text}</b>' + '<br>LogFC: %{x}'+'<br>FDR: %{y}<br>')
fig.update_layout(
    title=f'GEP genes',
    xaxis_title = 'logFC', 
    yaxis_title = '-log(FDR)',
    autosize=True,
    template='simple_white'
)
fig.show()
# fig.write_html(f'../output/fig_image/volcano/DE_allgenes_GEPgenes.html')

##### n12

In [None]:
"""
plot for Fig n12 
"""

# Cross scatter plot to show the logFC correlation between Olfrs across weeks 
merge_columns = ['symbol', 'logFC', 'FDR']

fig = go.Figure()

for _genes in list(GEP_genes.GEP_group.unique()) + ['na']:
    _wk = 'n12'
    
    plot_df = DE_allgene_df_dict['DE_allgene_WTvsKO_n12'][merge_columns]
    if _genes in GEP_genes.GEP_group.unique():
        plot_df = plot_df[plot_df['symbol'].isin(GEP_genes[GEP_genes['GEP_group'] == _genes].gene)]
    else: 
        plot_df = plot_df[~plot_df['symbol'].isin(GEP_genes['gene'])]
        plot_df = plot_df.sample(frac=0.01, random_state=0)
        
    fig.add_trace(go.Violin(y=plot_df['logFC'], 
                            name=f'{_genes}_{_wk}', 
                            text = plot_df['symbol'], 
                            # opacity = 0.9,
                            points='all', pointpos = 0))
    fig.update_traces(meanline_visible=True)
        
        
manual_color = ['#4363d8','#ffd8b1', '#D3D3D3']
for i in range(len(fig.data)): 
    fig.data[i]['marker'] = {'color': manual_color[i // 1 % len(manual_color)]}
        
# Stat test
fig = plot_utils.downsample_fig(fig, 
                                sample_method='linspace', 
                                # seed=4,
                                max_points = 100)
fig = plot_utils.add_p_value_annotation(fig, [[0,1], [0,2], [1,2]], 
                                            # y_padding=False,
                                            test_type = 'ranksums', 
                                            p_round=10,
                                            include_tstat=True)
# fig = plot_utils.add_p_value_annotation(fig, [[0,0], [1,1], [2,2]], 
#                                             test_type = 'ttest_1samp', 
#                                             popmean = 0, 
#                                             include_tstat=True)

fig.update_layout(
    showlegend=False,
    title = f'KO/WT Olfr de logFC<br><sup>ttest_1samp (popmean = 0), GEP from Tsukahara2021</sup>',
    yaxis=dict(title=f'logFC KO / WT'),
    # autosize=True,
    # height=700, width=700, 
    margin=dict(l=50,r=50,b=100,t=150,pad=10),
    template='simple_white'
)
fig.show()
fig.write_html(f'../output/fig_image/Violin/GEP_n12.html')



##### Week-by-week comparison

In [None]:
# Violin plots of Olfr logFC per Santoro fold_diff class (plus a background of Olfrs not in the
# Santoro list) for each week; every trace is tested against a mean of 0 (one-sample t-test)
merge_columns = ['symbol', 'logFC', 'FDR']
compare_list = ['DE_Olfr_WTvsKO_wk5', 'DE_Olfr_WTvsKO_wk8', 'DE_Olfr_WTvsKO_wk30']
fold_diff_groups = list(nostril_Olfr.fold_diff.unique())

fig = go.Figure()
for _de in compare_list:
    _wk = _de.split('_')[3]  # e.g. 'DE_Olfr_WTvsKO_wk5' -> 'wk5'
    de_df = DE_Olfr_df_dict[_de][merge_columns]
    for _nostril_olfr in fold_diff_groups + ['na']:
        if _nostril_olfr in fold_diff_groups:
            # Olfrs in this Santoro fold_diff class
            plot_df = de_df[de_df['symbol'].isin(nostril_Olfr[nostril_Olfr['fold_diff'] == _nostril_olfr].id)]
        else:
            # Background: Olfrs absent from the Santoro list, 10% sample with a fixed seed
            plot_df = de_df[~de_df['symbol'].isin(nostril_Olfr['id'])]
            plot_df = plot_df.sample(frac=0.1, random_state=0)
        fig.add_trace(go.Violin(y=plot_df['logFC'],
                                name=f'{_nostril_olfr}_{_wk}',
                                text=plot_df['symbol'],
                                points='all'))
fig.update_traces(meanline_visible=True)

# Colour by position within each week's block of traces; the trailing 'na' background is grey
group_colors = ['#990011', '#317773']
n_groups = len(fold_diff_groups) + 1
for i, trace in enumerate(fig.data):
    pos = i % n_groups
    trace['marker'] = {'color': '#a9a9a9' if pos == n_groups - 1 else group_colors[pos % len(group_colors)]}

# Stat test: one-sample t-test of each trace against 0. Pairs are derived from the trace count
# instead of a hand-written [[0,0], ..., [8,8]] list that assumes exactly 3 weeks x 3 groups.
fig = plot_utils.add_p_value_annotation(fig, [[i, i] for i in range(len(fig.data))],
                                        y_padding=False,
                                        test_type='ttest_1samp',
                                        popmean=0)

fig.update_layout(
    showlegend=False,
    title=f'KO/WT Olfr de logFC<br><sup>ttest_1samp (popmean = 0), fold_diff from Santoro2015</sup>',
    yaxis=dict(title=f'logFC KO / WT'),
    margin=dict(l=50, r=50, b=100, t=150, pad=10),
    template='simple_white'
)
fig.show()
# fig.write_html(f'../output/Blobel_15620/Violin/fold_diff_wkComoparison.html')


In [None]:
# Violin plots of Olfr logFC per iR2 group (plus a background of Olfrs not in the iR2 list)
# for each week; every trace is tested against a mean of 0 (one-sample t-test)
merge_columns = ['symbol', 'logFC', 'FDR']
compare_list = ['DE_Olfr_WTvsKO_wk5', 'DE_Olfr_WTvsKO_wk8', 'DE_Olfr_WTvsKO_wk30']
iR2_groups = list(iR2_Olfr.iR2_group.unique())

fig = go.Figure()
for _de in compare_list:
    _wk = _de.split('_')[3]  # e.g. 'DE_Olfr_WTvsKO_wk5' -> 'wk5'
    de_df = DE_Olfr_df_dict[_de][merge_columns]
    for _olfr in iR2_groups + ['na']:
        if _olfr in iR2_groups:
            # Olfrs assigned to this iR2 group
            plot_df = de_df[de_df['symbol'].isin(iR2_Olfr[iR2_Olfr['iR2_group'] == _olfr].id)]
        else:
            # Background: Olfrs not in the iR2 list, 10% sample with a fixed seed
            plot_df = de_df[~de_df['symbol'].isin(iR2_Olfr['id'])]
            plot_df = plot_df.sample(frac=0.1, random_state=0)
        fig.add_trace(go.Violin(y=plot_df['logFC'],
                                name=f'{_olfr}_{_wk}',
                                text=plot_df['symbol'],
                                points='all'))
fig.update_traces(meanline_visible=True)

# Colour by position within each week's block of traces; the trailing 'na' background is grey
group_colors = ['#EF5350', '#66BB6A']
n_groups = len(iR2_groups) + 1
for i, trace in enumerate(fig.data):
    pos = i % n_groups
    trace['marker'] = {'color': '#a9a9a9' if pos == n_groups - 1 else group_colors[pos % len(group_colors)]}

# Stat test: one-sample t-test of each trace against 0. Pairs are derived from the trace count
# instead of a hand-written [[0,0], ..., [8,8]] list that assumes exactly 3 weeks x 3 groups.
# (Alternative used previously: pairwise ranksums within each week, e.g. [[0,1], [0,2], [1,2]].)
fig = plot_utils.add_p_value_annotation(fig, [[i, i] for i in range(len(fig.data))],
                                        y_padding=False,
                                        test_type='ttest_1samp',
                                        popmean=0)

fig.update_layout(
    showlegend=False,
    title=f'KO/WT Olfr de logFC<br><sup>ttest_1samp (popmean = 0), iR2 group defined via week8 FDR < 0.2</sup>',
    yaxis=dict(title=f'logFC KO / WT'),
    margin=dict(l=50, r=50, b=100, t=150, pad=10),
    template='simple_white'
)
fig.show()
# fig.write_html(f'../output/Blobel_15620/Violin/iR2_wkComoparison.html')


In [None]:
# Violin plots of all-gene logFC per GEP group (plus all non-GEP genes as background) for each
# week; every trace is tested against a mean of 0 (one-sample t-test)
merge_columns = ['symbol', 'logFC', 'FDR']
compare_list = ['DE_allgene_WTvsKO_wk5', 'DE_allgene_WTvsKO_wk8', 'DE_allgene_WTvsKO_wk10', 'DE_allgene_WTvsKO_wk30']
gep_groups = list(GEP_genes.GEP_group.unique())

fig = go.Figure()
for _de in compare_list:
    _wk = _de.split('_')[3]  # e.g. 'DE_allgene_WTvsKO_wk5' -> 'wk5'
    de_df = DE_allgene_df_dict[_de][merge_columns]
    for _genes in gep_groups + ['na']:
        if _genes in gep_groups:
            # Genes in this Tsukahara GEP group
            plot_df = de_df[de_df['symbol'].isin(GEP_genes[GEP_genes['GEP_group'] == _genes].gene)]
        else:
            # Background: every non-GEP gene (not subsampled here; downsample_fig is applied below)
            plot_df = de_df[~de_df['symbol'].isin(GEP_genes['gene'])]
        fig.add_trace(go.Violin(y=plot_df['logFC'],
                                name=f'{_genes}_{_wk}',
                                text=plot_df['symbol'],
                                points='all', pointpos=0))
fig.update_traces(meanline_visible=True)

# Colour by group name (same mapping as the GEP volcano) instead of by trace position
manual_color = {'GEP_low': '#4363d8', 'GEP_high': '#ffd8b1', 'na': '#D3D3D3'}
for trace in fig.data:
    # Trace names are '<group>_<week>'; strip only the final suffix
    trace['marker'] = {'color': manual_color.get(trace.name.rsplit('_', 1)[0], '#D3D3D3')}

n_traces = len(fig.data)
# NOTE(review): downsample_fig runs before add_p_value_annotation; if the helper tests the trace
# y-values, the p-values are computed on the downsampled (<= 100 point) data — confirm intended.
fig = plot_utils.downsample_fig(fig,
                                sample_method='linspace',
                                max_points=100)
# Stat test: one-sample t-test of each trace against 0. Pairs are derived from the trace count
# instead of a hand-written 12-entry list that assumes exactly 4 weeks x 3 groups.
fig = plot_utils.add_p_value_annotation(fig, [[i, i] for i in range(n_traces)],
                                        y_padding=False,
                                        test_type='ttest_1samp',
                                        include_tstat=True,
                                        popmean=0)

fig.update_layout(
    showlegend=False,
    title='KO/WT allgene DE logFC<br><sup>ttest_1samp (popmean = 0), GEP from Tsukahara2021</sup>',
    yaxis=dict(title='logFC KO / WT'),
    margin=dict(l=50, r=50, b=100, t=150, pad=10),
    template='simple_white'
)
fig.show()
# fig.write_html(f'../output/Blobel_15620/Violin/GEP_wkComoparison.html')
# fig.write_html(f'../output/fig_image/Violin/GEP_wkComoparison.html')


#### McClintock 2017 review: activity genes

In [None]:
# Re-load the reference gene sets and add the McClintock 2017 review activity genes.
# fold_diff is collapsed exactly as in the first load of this file, so re-running this cell does
# not leave a differently-labelled nostril_Olfr behind for the Santoro week-comparison violins.
nostril_Olfr = pd.read_csv('../files/Santoro_2020/Occlu_diff_Olfr.csv', index_col=0)
nostril_Olfr.loc[~(nostril_Olfr.fold_diff == 'Close+_Olfr'), 'fold_diff'] = 'Close-_Olfr'

iR2_Olfr = pd.read_csv('../files/iR2_Olfr.csv', index_col=0)

# Columns used below: 'Gene Symbol' and 'Response to activity'
Activity_genes = pd.read_csv('../files/Mcclintock_2017review_suppl.csv')

In [None]:
"""
Olfr DE only Volcano plot
"""

plot_df = DE_allgene_df_dict['DE_allgene_WTvsKO_n12']
fig = go.Figure()
# plot blank all genes excluding markers 
background_df = plot_df[~plot_df.symbol.isin(Activity_genes['Gene Symbol'])]
fig.add_trace(go.Scatter(x=background_df['logFC'], 
                            y=-np.log10(background_df['FDR']),
                            # text=background_df['symbol'],
                            mode='markers', 
                            name = 'allgenes', 
                            marker=dict(size = 10, 
                                        color = 'lightgrey', 
                                        opacity=0.2)))

manual_color = plot_utils.distinct_colors(Activity_genes['Response to activity'].unique())
for _group in Activity_genes['Response to activity'].unique(): 
    temp = plot_df[plot_df.symbol.isin(Activity_genes[Activity_genes['Response to activity'] == _group]['Gene Symbol'])]
    fig.add_trace(go.Scatter(x=temp['logFC'], 
                            y=-np.log10(temp['FDR']),
                            text=temp['symbol'],
                            mode='markers', 
                            name=_group,
                            marker=dict(size = 10, 
                                        color = manual_color[_group],
                                        opacity=0.5)))

# Add the horizontal line at y=0.5
fig.add_shape(type='line', x0=-8, x1=8,
              y0=-np.log10(0.05), y1=-np.log10(0.05),
              line=dict(color='grey', width=3, dash='dash'))

fig.update_traces( 
    textposition='top center',
    hovertemplate = '<b>%{text}</b>' + '<br>LogFC: %{x}'+'<br>FDR: %{y}<br>')
fig.update_layout(
    title=f'McClintock 2017 Review Activity genes',
    xaxis_title = 'logFC', 
    yaxis_title = '-log(FDR)',
    autosize=True,
    template='simple_white'
)
fig.show()
# fig.write_html(f'../output/fig_image/volcano/DE_allgenes_Activitygenes.html')

In [None]:
"""
plot for Fig n12 
"""

# Cross scatter plot to show the logFC correlation between Olfrs across weeks 
merge_columns = ['symbol', 'logFC', 'FDR']

fig = go.Figure()

for _group in list(Activity_genes['Response to activity'].unique()) + ['na']:
    _wk = 'n12'
    
    plot_df = DE_allgene_df_dict['DE_allgene_WTvsKO_n12'][merge_columns]
    if _group in Activity_genes['Response to activity'].unique():
        plot_df = plot_df[plot_df['symbol'].isin(Activity_genes[Activity_genes['Response to activity'] == _group]['Gene Symbol'])]
    else: 
        plot_df = plot_df[~plot_df['symbol'].isin(Activity_genes['Gene Symbol'])]
        plot_df = plot_df.sample(frac=0.01, random_state=0)
        
    fig.add_trace(go.Violin(y=plot_df['logFC'], 
                            name=f'{_group}_{_wk}', 
                            text = plot_df['symbol'], 
                            # opacity = 0.9,
                            points='all', pointpos = 0))
    fig.update_traces(meanline_visible=True)
    
        
        

# manual_color = plot_utils.distinct_colors(Activity_genes['Response to activity'].unique())
# manual_color['na'] = '#D3D3D3'
manual_color = plot_utils.distinct_colors(list(Activity_genes['Response to activity'].unique()) + ['na'], 
                                          custom_color=['#4363d8','#ffd8b1', '#D3D3D3'])

for i in range(len(fig.data)): 
    fig.data[i]['marker'] = {'color': manual_color[fig.data[i].name.split('_')[0]]}
        
# Stat test
fig = plot_utils.downsample_fig(fig, 
                                sample_method='linspace', 
                                # seed=4,
                                max_points = 100)
fig = plot_utils.add_p_value_annotation(fig, [[0,1], [0,2], [1,2]], 
                                            # y_padding=False,
                                            test_type = 'ranksums', 
                                            include_tstat=True, 
                                            p_round=10)
# fig = plot_utils.add_p_value_annotation(fig, [[0,0], [1,1], [2,2]], 
#                                             test_type = 'ttest_1samp', 
#                                             popmean = 0, 
#                                             include_tstat=True)

fig.update_layout(
    showlegend=False,
    title = f'KO/WT Olfr de logFC<br><sup>ranksums, Activity genes from Wang2017</sup>',
    yaxis=dict(title=f'logFC KO / WT'),
    # autosize=True,
    # height=700, width=700, 
    margin=dict(l=50,r=50,b=100,t=150,pad=10),
    template='simple_white'
)
fig.show()
fig.write_html(f'../output/fig_image/Violin/Activitygenes_n12_2.html')



#### Cell type comparison

##### Volcano

In [None]:
"""
allgene DE cellmarkers only Volcano plot: one figure per DE table in DE_allgene_df_dict.
"""
# First 50 rows of the marker table; each column holds one cell type's marker symbols
cell_markers = pd.read_csv('../files/CELL_top_markers.csv', index_col=0)[0:50]

for de_name, plot_df in DE_allgene_df_dict.items():
    fig = go.Figure()
    for cell in cell_markers:
        temp = plot_df[plot_df.symbol.isin(cell_markers[cell])]
        fig.add_trace(go.Scatter(x=temp['logFC'],
                                 y=-np.log10(temp['FDR']),
                                 text=temp['symbol'],
                                 mode='markers',
                                 name=cell,
                                 marker=dict(size=10, opacity=0.3)))

    # Significance threshold at FDR = 0.05, i.e. y = -log10(0.05) ~ 1.3, spanning the full x-range
    fig.add_hline(y=-np.log10(0.05), line=dict(color='violet', width=3, dash='dash'))

    fig.update_traces(
        textposition='top center',
        hovertemplate=
        '<b>%{text}</b>' +
        '<br>LogFC: %{x}' +
        '<br>FDR: %{y}<br>')

    fig.update_layout(
        title=f'{de_name} cell markers only',
        autosize=True,
        template='simple_white'
    )
    fig.show()
    # fig.write_html(f'../output/Blobel_15620/Cellmarker/DE_allgene_{de_name}_cellmarkers.html')

In [None]:
"""
allgene DE  Volcano plot
"""

cell_markers = pd.read_csv('../files/CELL_top_markers.csv', index_col = 0)[0:50]
plot_df = DE_allgene_df_dict['DE_allgene_WTvsKO_n12']
fig = go.Figure()
# plot blank all genes excluding markers 
background_df = plot_df[~plot_df.symbol.isin(cell_markers.stack().to_list())]
fig.add_trace(go.Scatter(x=background_df['logFC'], 
                            y=-np.log10(background_df['FDR']),
                            # text=background_df['symbol'],
                            mode='markers', 
                            name = 'allgenes', 
                            marker=dict(size = 10, 
                                        color = 'lightgrey', 
                                        opacity=0.2)))
for cell in cell_markers: 
    temp = plot_df[plot_df.symbol.isin(cell_markers[cell])]
    fig.add_trace(go.Scatter(x=temp['logFC'], 
                            y=-np.log10(temp['FDR']),
                            text=temp['symbol'],
                            mode='markers', 
                            name=cell,
                            marker=dict(size = 10, opacity=0.5)))

# Add the horizontal line at y=0.5
fig.add_shape(type='line', x0=-8, x1=8,
              y0=-np.log10(0.05), y1=-np.log10(0.05),
              line=dict(color='grey', width=3, dash='dash'))

fig.update_traces( 
    textposition='top center',
    hovertemplate = '<b>%{text}</b>' + '<br>LogFC: %{x}'+'<br>FDR: %{y}<br>')
fig.update_layout(
    title=f'cell markers only',
    xaxis_title = 'logFC', 
    yaxis_title = '-log(FDR)',
    autosize=True,
    template='simple_white'
)
fig.show()
# fig.write_html(f'../output/fig_image/volcano/DE_allgene_cellmarkers.html')

##### Heatmap

In [None]:
# Marker table (one column per cell type); keep only its first 50 rows
_markers_path = '../files/CELL_top_markers.csv'
cell_markers = pd.read_csv(_markers_path, index_col=0)
cell_markers = cell_markers[:50]


In [None]:
import seaborn as sns  # sns is used below but is not part of the notebook's top import cell

n_top_genes = 5

plot_df = DE_allgene_df_dict['DE_allgene_WTvsKO_n12']

# Long table of (symbol, cell_type, logFC) for each cell type's top-n markers found in the DE
# table. Built from a list of records rather than by growing a DataFrame one row at a time.
heatmap_records = []
for _celltype in cell_markers:
    subset = plot_df[plot_df.symbol.isin(cell_markers[_celltype][0:n_top_genes])]
    heatmap_records.extend(
        {'symbol': _symbol, 'cell_type': _celltype, 'logFC': _logfc}
        for _symbol, _logfc in zip(subset['symbol'], subset['logFC'])
    )
heatmap_df = pd.DataFrame(heatmap_records, columns=['symbol', 'cell_type', 'logFC'])

# Wide tables: one row per cell type (first-appearance order), one column per matched gene.
# Cell types can match different numbers of genes (a marker missing from the DE table, or a
# symbol present in several DE rows). Wrapping each row in a Series pads short rows with NaN;
# building the frame from raw arrays of unequal length raises ValueError instead.
# seaborn masks NaN cells, so padded positions are drawn blank.
_by_celltype = heatmap_df.groupby('cell_type', sort=False)
heatmap_logFC = pd.DataFrame({ct: pd.Series(g['logFC'].to_numpy()) for ct, g in _by_celltype}).T
heatmap_symbol = pd.DataFrame({ct: pd.Series(g['symbol'].to_numpy()) for ct, g in _by_celltype}).T

fig, ax = plt.subplots(figsize=(10, 6))
sns.heatmap(heatmap_logFC, cmap='RdBu_r', annot=heatmap_symbol, fmt="", linewidths=0.5,
            vmin=-0.5, vmax=0.5, center=0, ax=ax)
ax.set_title('Heatmap of logFC by cell type')
# Columns hold different genes in each row, so the column index is hidden; genes are annotated in-cell
plt.setp(ax.get_xticklabels(), visible=False)
ax.set_xlabel('genes')
ax.set_ylabel('Cell Type')

fig.savefig('../output/WTvKO_ALL/celltype/Heatmap_celltype_n12.png')
plt.show()

#### GO analysis 

In [None]:
import utils.go_utils as go_utils

##### n12 GO 

In [None]:
"""
Quick GO to look at WT vs ALL data 
"""
DE_allgene_WTvsALL = pd.read_csv('../DE_out/WTvKO_ALL/DE_allgene_WTvsKO_ALL.csv', index_col=0)

wt_genes = DE_allgene_WTvsALL[(DE_allgene_WTvsALL.FDR < 0.1) & 
                              (DE_allgene_WTvsALL.logFC < 0)].symbol.values
ko_genes = DE_allgene_WTvsALL[(DE_allgene_WTvsALL.FDR < 0.1) & 
                              (DE_allgene_WTvsALL.logFC > 0)].symbol.values
go_list = [wt_genes, ko_genes]

go_dict = {}
for i, genes in enumerate(go_list): 
    go_dict[i] = go_utils.go_it(genes)
    go_dict[i]['n_genes/n_go'] = go_dict[i].n_genes/go_dict[i].n_go
    go_dict[i]['n_genes/n_study'] = go_dict[i].n_genes/go_dict[i].n_study
    
go_dict[0]['group'] = 'WT'
go_dict[1]['group'] = 'KO'
go_df = pd.concat([go_dict[0], go_dict[1]])
go_df.to_csv('../output/WTvKO_ALL/GO/GO_terms.csv')

In [None]:
# Horizontal grouped bar chart: number of study genes per GO term, WT vs KO
go_df = pd.read_csv('../output/WTvKO_ALL/GO/GO_terms.csv')
fig = px.bar(go_df,
             x='n_genes',
             y='term',
             orientation='h',
             color='group',
             barmode='group',
             hover_data=['study_genes'])

# Colour by group name rather than trace position, so colours do not depend on CSV row order
manual_color = {'WT': 'lightgrey', 'KO': 'pink'}
for trace in fig.data:
    trace['marker']['color'] = manual_color.get(trace['name'], 'lightgrey')

fig.update_layout(
    title='Rhbdf2 DE',
    # Axis titles describe what is plotted: x = n_genes, y = GO term
    xaxis_title='n_genes',
    yaxis_title='GO term',
    autosize=True,
    plot_bgcolor='rgba(0, 0, 0, 0)',
    template='simple_white'
)
fig.show()
# fig.write_html(f'../output/WTvKO_ALL/GO/WTvsKO_n12_GO.html')

##### Exporting data tables for figures

In [None]:
from ast import literal_eval


def _parse_gene_list(cell):
    """Turn a stringified list (as written by to_csv) back into a list; other values pass through."""
    return literal_eval(cell) if "[" in cell else cell


go_df = pd.read_csv('../output/WTvKO_ALL/GO/GO_terms.csv', index_col=0)
go_df['study_genes'] = go_df['study_genes'].apply(_parse_gene_list)

In [None]:
# Transform go_df to publish data table figure
go_df_short = go_df.explode('study_genes').reset_index(drop=True)[['GO', 'term', 'class', 'p_corr', 'study_genes', 'group']]
go_df_short.to_csv('../output/fig_image/GO/go_df_short.csv')

In [None]:
go_df_short = go_df_short.sort_values(['group', 'p_corr'])

# For each group: its unique study genes, in order of first appearance (most significant term first)
group_genes = {
    group: go_df_short.loc[go_df_short['group'] == group, 'study_genes'].drop_duplicates().tolist()
    for group in go_df_short['group'].unique()
}

# One column per group; shorter columns are padded with empty strings (not NaN) so the CSV
# shows blank cells. Building from Series also handles an empty go_df_short, where the old
# max() over no lists raised ValueError.
go_gene_table = pd.DataFrame(
    {group: pd.Series(genes, dtype=object) for group, genes in group_genes.items()}
).fillna('')
go_gene_table.to_csv('../output/fig_image/GO/go_gene_table.csv')

##### GO plots for figures

In [None]:
"""
Bar only 
"""

go_df = pd.read_csv('../output/WTvKO_ALL/GO/GO_terms.csv')

plot_df = go_df.groupby('group').head(4).sort_values('group', ascending=False)


fig = go.Figure()
for _group in plot_df.group.unique():
    subset = plot_df[plot_df.group == _group].sort_values('p_corr', ascending=False)
    fig.add_trace(
        go.Bar(x = -np.log10(subset['p_corr']), 
               y = subset['term'], 
               name = _group,
               orientation = 'h',
               text = subset['n_genes'],
               textposition='outside', 
               insidetextfont=dict(family='Arial', size=15, color='black'), 
               showlegend=False
            )
    )
# manually assign color
manual_color = plot_utils.distinct_colors(plot_df.group.unique(),
                                          custom_color = ['#a8d7cb', '#d5a6bd'])
for i, _fig in enumerate(fig.data): 
    fig.data[i]['marker']['color'] = manual_color[_fig['name']]


fig.update_traces( 
    hovertemplate =
    '<b>%{text}</b>' + 
    '<br>%{y}')
fig.update_layout(
    title='iRhom2 GO term',
    xaxis=dict(title='Fold enrichment'),
    autosize=True,
    template='simple_white')
fig.show()

# fig.write_html(f'../output/fig_image/GO/WTvsKO_n12_GO.html')

In [None]:
from ast import literal_eval

# Re-read GO results; study_genes was stringified by to_csv, so parse list-looking cells back into lists
go_df = pd.read_csv('../output/WTvKO_ALL/GO/GO_terms.csv')
go_df['study_genes'] = [literal_eval(cell) if "[" in cell else cell for cell in go_df['study_genes']]

In [None]:
"""
GO vol 
"""

go_df = pd.read_csv('../output/WTvKO_ALL/GO/GO_terms.csv')
go_df['study_genes'] = go_df['study_genes'].apply(lambda x: literal_eval(x) if "[" in x else x)

vol_df = pd.read_csv('../DE_out/WTvKO_ALL/DE_allgene_WTvsKO_ALL.csv', index_col = 0)

FDR_cutoff = 0.2
plot_df = vol_df.copy()
plot_df = plot_df.dropna(subset=['symbol'])
plot_df = plot_df[(~plot_df.symbol.str.startswith('Gm')) & 
                  (~plot_df.symbol.str.contains('-ps')) &
                  (~plot_df.symbol.str.contains('Olfr'))]
plot_df = plot_df.dropna(subset=['symbol'])
base_fig = plot_utils.vol_plot(plot_df,
                          logFC_group = ['KO', 'WT'], 
                          manual_color = {'WT': 'lightgrey', 'KO': 'lightgrey', 'na': 'lightgrey'}, 
                          FDR_cutoff = FDR_cutoff, 
                          FDR_line=None, 
                          fig_fixed_range = True, 
                          opacity = 0.5, 
                          ymin=-0.5
                          )
# DE allgene Olfr genes 
go_df = go_df.groupby('group').head(4).sort_values('group', ascending=False)


plot_df = vol_df.copy()
plot_df = plot_df[(plot_df.symbol.isin(sum([_gene for _term in go_df.term for _gene in go_df[go_df.term == _term].study_genes], [])))]
olfr_fig = plot_utils.vol_plot(plot_df,
                          logFC_group = ['WT_GO_genes', 'KO_GO_genes'], 
                        #   manual_color = {'WT_Olfr': '#d5a6bd', 'KO_Olfr': '#a8d7cb'}, 
                          manual_color = {'WT_GO_genes': '#19b2e6', 'KO_GO_genes': '#ee6082'}, 
                          FDR_cutoff = FDR_cutoff, 
                          plot_none_sig=False, 
                          opacity = 0.8, 
                          )

# Combine olfr data points into figure 
fig = go.Figure(base_fig.data + olfr_fig.data)
fig.layout = base_fig.layout


# manually assign color
# manual_color = plot_utils.distinct_colors(plot_df.group.unique(),
#                                           custom_color = ['#a8d7cb', '#d5a6bd'])
# for i, _fig in enumerate(fig.data): 
#     fig.data[i]['marker']['color'] = manual_color[_fig['name']]


fig.update_traces( 
    hovertemplate =
    '<b>%{text}</b>' + 
    '<br>%{y}')
fig.update_layout(
    title='iRhom2 GO term',
    xaxis=dict(title='Fold enrichment'),
    autosize=True,
    template='simple_white')
fig.show()

# fig.write_html(f'../output/fig_image/GO/WTvsKO_n12_GO.html')

##### old 

In [None]:
"""
GO plot for Figure 
"""

from ast import literal_eval

go_df = pd.read_csv('../output/WTvKO_ALL/GO/GO_terms.csv')
go_df['study_genes'] = go_df['study_genes'].apply(lambda x: literal_eval(x) if "[" in x else x)
plot_df = go_df.sort_values(['group','p_corr'], ascending=True).copy()

# Filter by manually defined go_terms 
# plotting_term = ['defense response', 
#                  'innate immune response',
#                  'defense response to protozoan',
#                  'response to bacterium', 
#                  'sensory perception of smell', 
#                  'detection of chemical stimulus involved in sensory perception of smell',
#                  'G protein-coupled receptor signaling pathway', 
#                  'olfactory receptor activity'
#                  ]
# plot_df = plot_df[plot_df['term'].isin(plotting_term)]

# Filter by top n per group 
plot_df = plot_df.groupby('group').head(4)

# Assign bar color with group
# plot_df['color'] = plot_df.group.apply(lambda x: '#d5a6bd' if x == 'KO' else '#a8d7cb')
plot_df['color'] = plot_df.group.apply(lambda x: '#ee6082' if x == 'KO' else '#19b2e6')

# Construct bar dataframe 
bar_x = list(-np.log10(plot_df['p_corr']))
bar_y = list(plot_df['term'])
bar_color = list(plot_df['color'])
# bar_text = list(plot_df['n_genes'])
bar_hover = list(plot_df['study_genes'])
# Construct scatter dataframe 
scatter_x = []
scatter_y = []
scatter_size = []
scatter_color = []
scatter_hover = []

DE_allgene_df = DE_allgene_df_dict['DE_allgene_WTvsKO_n12']
for _term in plot_df.term:
    for _gene in plot_df[plot_df.term == _term].study_genes.item():
        scatter_x.append(DE_allgene_df[DE_allgene_df.symbol == _gene].logFC.item())
        scatter_y.append(_term)
        _fdr = -np.log10(DE_allgene_df[DE_allgene_df.symbol == _gene].FDR.item())
        scatter_size.append( 15 if _fdr >= 2.3 else 12 if  _fdr >= 1.3 else 5)
        # scatter_color.append('#d9ead3' if go_df[go_df.term ==_term].group.item() == 'WT' else '#c90076')
        scatter_color.append('#1389b1' if go_df[go_df.term ==_term].group.item() == 'WT' else '#9f1134')
        scatter_hover.append(_gene)

layout = go.Layout(
    # title='<br>GO analysis between WT vs Rhbdf2 KO<br>\
    #     <sup>WTvsKO allgene DE n12<\sup><span style="font-size: 10px;"> </span>',
#     xaxis=dict(title='Expression AUC'),
#     xaxis2=dict(title='n_genes/n_go', overlaying='x', side='top'),
    xaxis=dict(title='Fold enrichment'),
    xaxis2=dict(title='logFC (WT/ KO)', overlaying='x', side='top'),
    yaxis=dict(autorange="reversed"),
    template='simple_white',
    bargap=0.3,
#     autosize=False,
#     width=800,
#     height=400
    font=dict(
        size=15,  # Set the font size here
    )
)

# Create the figure object and add the traces
fig = go.Figure(layout=layout)

# Create a scatter plot with a different x-axis
fig.add_trace(go.Scatter(
                    x=scatter_x,
                    y=scatter_y,
                    mode='markers',
                    hovertext = scatter_hover,
                    marker = dict(
                        color = scatter_color, 
                        size = scatter_size,
                        # line=dict(width=1.5, color='rgb(0, 0, 0)')
#                         opacity = 0.5,
                    ),
                    showlegend=False,
                    xaxis = 'x2')
             )

# Create a horizontal bar chart
fig.add_trace(go.Bar(
                    x=bar_x,
                    y=bar_y,
                    # text= bar_text,
                    hovertext=bar_hover,
                    textposition='inside',
                    insidetextanchor="start",
                    insidetextfont=dict(family='Arial', size=15, color='black'),
                    orientation='h',
                    marker=dict(
                        color=bar_color
                          ),
                    showlegend=False)
             )


fig.update_layout(
    xaxis2=dict(range=[-6,6])
)

# Show the plot
fig.show()
# fig.write_html(f'../output/fig_image/GO/WTvsKO_n12_GO_all.html')
# fig.write_html(f'../output/fig_image/GO/WTvsKO_n12_GO_top4.html')

In [None]:
"""
Conduct GO analysis on all the columns in GeneSets from different anlaysis
into the df dictionary 
"""
logFC_cutoff = 0
FDR_cutoff = 0.1

wt_wk5_genes = DE_allgene_df_dict['DE_allgene_WTvsKO_wk5'][(DE_allgene_df_dict['DE_allgene_WTvsKO_wk5'].FDR < FDR_cutoff) & 
                                                           (DE_allgene_df_dict['DE_allgene_WTvsKO_wk5'].logFC < -logFC_cutoff)].symbol.values
ko_wk5_genes = DE_allgene_df_dict['DE_allgene_WTvsKO_wk5'][(DE_allgene_df_dict['DE_allgene_WTvsKO_wk5'].FDR < FDR_cutoff) & 
                                                           (DE_allgene_df_dict['DE_allgene_WTvsKO_wk5'].logFC > logFC_cutoff)].symbol.values
wt_wk30_genes = DE_allgene_df_dict['DE_allgene_WTvsKO_wk30'][(DE_allgene_df_dict['DE_allgene_WTvsKO_wk30'].FDR < FDR_cutoff) & 
                                                             (DE_allgene_df_dict['DE_allgene_WTvsKO_wk30'].logFC < -logFC_cutoff)].symbol.values
ko_wk30_genes = DE_allgene_df_dict['DE_allgene_WTvsKO_wk30'][(DE_allgene_df_dict['DE_allgene_WTvsKO_wk30'].FDR < FDR_cutoff) & 
                                                             (DE_allgene_df_dict['DE_allgene_WTvsKO_wk30'].logFC > logFC_cutoff)].symbol.values

go_list = [wt_wk5_genes, ko_wk5_genes, wt_wk30_genes, ko_wk30_genes]
list_name = ['wt_wk5_genes', 'ko_wk5_genes', 'wt_wk30_genes', 'ko_wk30_genes']

go_dict = {}
for i, genes in enumerate(go_list): 
    k = list_name[i]
    go_dict[k] = go_utils.go_it(genes)
    go_dict[k]['n_genes/n_go'] = go_dict[k].n_genes/go_dict[k].n_go
    go_dict[k]['n_genes/n_study'] = go_dict[k].n_genes/go_dict[k].n_study
    # Create some labels for grouping later 
    go_dict[k]['group'] = k.split('_')[0]
    go_dict[k]['age'] = k.split('_')[1]

In [None]:
from IPython.display import display

# For each age, show the GO terms whose study genes include an Olfr gene
for wk in ['wk5', 'wk30']:
    # Both wt and ko results for this age. Match the age token exactly: a substring test
    # would let e.g. 'wk3' also select the 'wk30' results.
    go_df = pd.concat([res for k, res in go_dict.items() if k.split('_')[1] == wk])
    display(go_df[go_df.study_genes.astype(str).str.contains('Olfr')])

In [None]:
# One GO bar chart per age: n_genes per term, coloured by wt/ko
manual_color = {'wt': 'lightgrey',
                'ko': 'pink'}

for wk in ['wk5', 'wk30']:
    # Both wt and ko results for this age; the age token is matched exactly (a substring test
    # would let 'wk3' also select 'wk30')
    go_df = pd.concat([res for k, res in go_dict.items() if k.split('_')[1] == wk])
    # go_df.to_csv(f'../output/Blobel_15620/GO/{wk}_GO_terms.csv')

    fig = px.bar(go_df,
                 x='n_genes',
                 y='term',
                 orientation='h',
                 color='group',
                 hover_data=['study_genes'])

    # Colour by group name; unexpected names fall back to grey instead of raising KeyError
    for trace in fig.data:
        trace['marker']['color'] = manual_color.get(trace['name'], 'lightgrey')

    fig.update_layout(
        title=f'iR2 WT/KO DE {wk}',
        autosize=True,
        plot_bgcolor='rgba(0, 0, 0, 0)',
        template='simple_white'
    )
    fig.show()
    fig.write_html(f'../output/Blobel_15620/GO/{wk}_GO.html')


#### ... 