In [None]:
import pandas as pd 
import numpy as np 
import matplotlib.pyplot as plt
import plotly.express as px 


import anndata
import scanpy as sc 


In [None]:
# Load the UMI count matrix and the per-cell metadata table
# (the metadata's first column is used as its index).
adata = anndata.read_csv(
    '../files/Tsukahara_2021/GSE173947_envA_timecourse_umi_counts.csv'
)
meta = pd.read_csv(
    '../files/Tsukahara_2021/GSE173947_envA_timecourse_metadata.csv',
    index_col=0,
)

# Attach the metadata columns to the observations by matching on the index
# (default inner join).
adata.obs = pd.merge(adata.obs, meta, left_index=True, right_index=True)

# Keep an untouched copy so later cells can restart from the raw data.
raw_adata = adata.copy()


In [None]:
# Display the AnnData object loaded above
adata

### TODO Tasks 
- Identify changes in OR (olfactory receptor) expression across changes of environment
- Create category bins for ORs that are up-regulated, unchanged, or down-regulated after a change of environment
- Check whether Rhbdf2 and associated genes show consistent patterns across those bins

#### Preprocessing

In [None]:
# Reset `adata` to a fresh copy of the stored raw data so the preprocessing
# below can be re-run from an unmodified state.
adata = raw_adata.copy()

In [None]:
# Basic preprocessing steps (these calls operate on `adata` directly; no result is reassigned)
sc.pp.filter_cells(adata, min_genes=200)  # Filter cells with fewer than 200 expressed genes
sc.pp.filter_genes(adata, min_cells=3)  # Filter genes expressed in fewer than 3 cells

# Calculate QC metrics
# The mitochondrial (mt) annotation, the QC violin plot and the cell filters
# below are left disabled because the data does not contain any mt genes.
# adata.var['mt'] = adata.var_names.str.startswith('MT-')  # annotate the group of mitochondrial genes as 'mt'
sc.pp.calculate_qc_metrics(adata, percent_top=None, log1p=False, inplace=True)
# sc.pl.violin(adata, ['n_genes_by_counts', 'total_counts', 'pct_counts_mt'],
#              jitter=0.4, multi_panel=True) 
# (disabled) Filter cells with many detected genes / a high mt fraction
# adata = adata[adata.obs.n_genes_by_counts < 2500, :]
# adata = adata[adata.obs.pct_counts_mt < 5, :]

# Normalize counts with target_sum=1e4
sc.pp.normalize_total(adata, target_sum=1e4)  # Normalize counts

# (disabled) log-transform -- downstream cells therefore work on normalized,
# non-log counts. NOTE(review): confirm this is intended for the PCA/UMAP below.
# sc.pp.log1p(adata) 

# (disabled) Scale the data
# sc.pp.scale(adata, max_value=10)

#### Umap visualization

In [None]:
# Perform PCA
sc.tl.pca(adata, svd_solver='arpack')

# Build the neighbor graph using 10 neighbors and 50 PCs
sc.pp.neighbors(adata, n_neighbors=10, n_pcs=50)

# Leiden clustering; 'leiden' is used as a color key in a later plotting cell
sc.tl.leiden(adata)

# Perform UMAP, initialized from the PAGA layout (init_pos='paga'),
# so PAGA has to be computed and laid out first.
sc.tl.paga(adata)
sc.pl.paga(adata, plot=False)  # remove `plot=False` if you want to see the coarse-grained graph
sc.tl.umap(adata, init_pos='paga')

In [None]:
# Plot the UMAP colored by the `source` key, using 5x5 figures on a white background
sc.set_figure_params(figsize = [5,5], facecolor = 'white')
sc.pl.umap(adata, color = ['source'], size = 10)

In [None]:
# Plot the UMAP colored by `Rhbdf2` and by the `leiden` clustering
sc.pl.umap(adata, color = ['Rhbdf2', 'leiden'], size = 10)

#### Separating clusters of differentially regulated ORs across environment change

In [None]:
from scipy.stats import ttest_ind

def diff_across_variable(adata, 
                         diff_across = 'source',
                         genes = []):
    """Run pairwise t-tests of gene expression between groups of cells, per Olfr.

    For every unordered pair of values found in ``adata.obs[diff_across]`` and
    for every ``adata.obs['top_Olfr']`` label containing ``'Olfr'``, the cells
    carrying that label are split by group and ``scipy.stats.ttest_ind`` is
    applied column-wise to the expression (``adata.X``) of that Olfr gene plus
    every gene listed in ``genes``.

    Parameters
    ----------
    adata : AnnData
        Must provide the ``obs`` columns ``top_Olfr`` and ``diff_across``, and
        ``var_names`` containing the genes to test.
    diff_across : str
        Name of the ``obs`` column whose values define the groups to compare.
        The column is used for pairing, for splitting the cells and for the
        "present in both groups" check (previously ``'source'`` and ``'env'``
        were hard-coded for the latter two).
    genes : list of str
        Additional genes tested alongside each Olfr gene. The list is not
        modified.

    Returns
    -------
    pandas.DataFrame
        One row per (Olfr, group pair) with the columns ``Olfr``,
        ``env_change`` (``"<group1>_<group2>"``), ``Olfr_tstat``,
        ``Olfr_pvalue`` and ``<gene>_tstat`` / ``<gene>_pvalue`` for every
        entry of ``genes``. A positive t-statistic means a higher mean in the
        first group of the pair. The frame has a unique integer index and
        always carries these columns, even when no comparison could be made.

    Notes
    -----
    An Olfr label is skipped for a pair when it has no cells in one of the two
    groups. Olfr labels missing from ``adata.var_names`` are skipped and
    reported in a single warning; if a gene from ``genes`` is missing, a
    warning is issued and an empty table is returned (instead of silently
    swallowing the lookup error for every Olfr).
    """
    import warnings

    extra_genes = list(genes)
    columns = ['Olfr', 'env_change', 'Olfr_tstat', 'Olfr_pvalue']
    for gene in extra_genes:
        columns += [f'{gene}_tstat', f'{gene}_pvalue']
    # Statistic columns are kept as floats so NaN checks work on empty results too.
    stat_dtypes = dict.fromkeys(columns[2:], float)

    known_genes = set(adata.var_names)
    missing_genes = [gene for gene in extra_genes if gene not in known_genes]
    if missing_genes:
        warnings.warn(
            f'genes not found in adata.var_names: {missing_genes}; '
            'no comparison was run'
        )
        return pd.DataFrame(columns=columns).astype(stat_dtypes)

    # Non-string labels (e.g. NaN) cannot be Olfr names and are ignored.
    top_Olfr = [Olfr for Olfr in adata.obs['top_Olfr'].unique()
                if isinstance(Olfr, str) and 'Olfr' in Olfr]
    environments = adata.obs[diff_across].unique()

    # Extract the expression matrix and group labels once per Olfr; they do
    # not depend on which pair of groups is being compared.
    per_olfr = {}
    skipped_olfr = []
    for Olfr in top_Olfr:
        if Olfr not in known_genes:
            skipped_olfr.append(Olfr)
            continue
        genes_to_compare = [Olfr] + extra_genes
        cells = adata[(adata.obs['top_Olfr'] == Olfr).to_numpy()]
        cells = cells[:, genes_to_compare]
        labels = cells.obs[diff_across].to_numpy()
        expression = cells.X
        if hasattr(expression, 'toarray'):
            # ttest_ind needs a dense array
            expression = expression.toarray()
        expression = np.asarray(expression).reshape(len(labels),
                                                    len(genes_to_compare))
        per_olfr[Olfr] = (expression, labels)

    if skipped_olfr:
        warnings.warn(
            f'{len(skipped_olfr)} top_Olfr labels are not in adata.var_names '
            f'and were skipped: {skipped_olfr}'
        )

    records = []
    for i in range(len(environments)):
        for j in range(i + 1, len(environments)):
            env1 = environments[i]
            env2 = environments[j]

            for Olfr, (expression, labels) in per_olfr.items():
                in_env1 = labels == env1
                in_env2 = labels == env2

                # Skip Olfr that are not present in both groups
                if not (in_env1.any() and in_env2.any()):
                    continue

                # Column-wise t-test: one statistic per compared gene
                t_stat, p_value = ttest_ind(expression[in_env1],
                                            expression[in_env2])

                stats = [value
                         for pair in zip(t_stat, p_value)
                         for value in pair]
                records.append([Olfr, f'{env1}_{env2}'] + stats)

    # Build the table once instead of concatenating inside the loop.
    return pd.DataFrame(records, columns=columns).astype(stat_dtypes)

# adata_subset_Olfr = adata_subset[:, 'Olfr1018']
# adata_subset_Olfr = adata_subset_Olfr[(adata_subset_Olfr.obs.top_Olfr == 'Olfr1018')]


In [None]:
# Find gene expression differences across environments
results = diff_across_variable(adata, genes=['Rhbdf2', 'S100a5'])

# Drop rows whose Olfr t-statistic is NaN (e.g. an Olfr present in only one cell)
results = results[np.logical_not(np.isnan(results['Olfr_tstat']))]

In [None]:
# Create bins of identifiers for Olfr across environment change.
# A row is labelled only when it is significant (p < 0.05) and its t-statistic
# exceeds 2 in magnitude; the sign decides the direction. The t-statistic comes
# from ttest_ind(env1, env2), so a positive value means a higher mean in the
# first environment of the `env_change` pair.
results['Olfr_change'] = 'no_change'
results.loc[(results['Olfr_tstat'] > 2) & (results['Olfr_pvalue'] < 0.05), 'Olfr_change'] = 'up_regulated'
# Fixed: the threshold was `< 2`, which also labelled significant rows with a
# t-statistic between 0 and 2 as down-regulated.
results.loc[(results['Olfr_tstat'] < -2) & (results['Olfr_pvalue'] < 0.05), 'Olfr_change'] = 'down_regulated'
results['Olfr_change'].hist()

In [None]:
# Distribution of the Rhbdf2 t-statistic among rows labelled 'up_regulated'
results.loc[results['Olfr_change'] == 'up_regulated', 'Rhbdf2_tstat'].hist()

In [None]:
# Load the cached results table. If the cache has not been written yet, save
# the in-memory `results` instead so the notebook also runs from a fresh
# checkout (previously the save had to be un-commented by hand).
# NOTE: an existing cache file replaces the `results` computed above; delete
# '../output/nb_4/results.csv' to refresh it.
try:
    results = pd.read_csv('../output/nb_4/results.csv', index_col = 0)
except FileNotFoundError:
    results.to_csv('../output/nb_4/results.csv')

In [None]:
# Quickly visualize the t-statistic (statistical difference) of each change of
# environment relative to the 'baseline-overnight' group.
baseline_pairs = ['baseline-overnight_envA-45m',
                  'baseline-overnight_envA-2h',
                  'baseline-overnight_envA-24h',
                  'baseline-overnight_envA-5d',
                  'baseline-overnight_1ON-envA2w']

# Boolean indexing already returns a new frame, so no explicit copy is needed.
plot_data = results[results['env_change'].isin(baseline_pairs)]
plot_data = plot_data.groupby('env_change')

# Create a figure and axis object for the plot
fig, ax = plt.subplots()
# One semi-transparent histogram per env_change group
for category, group in plot_data:
    ax.hist(group['Olfr_tstat'], bins=100, alpha=0.5, label=str(category))
# Add labels and legend to the plot (the x label now matches the plotted column)
ax.set_xlabel('Olfr_tstat')
ax.set_ylabel('Frequency')
ax.legend(title='env_change')

# Show the plot
plt.show()

In [None]:
# Volcano-style plot: t-statistic against -log(p-value) for the chosen gene.
# `plot_by` selects the '<plot_by>_tstat' / '<plot_by>_pvalue' column pair.

plot_by = 'Olfr'

plot_data = results.copy()
fig = px.scatter(x = plot_data[plot_by+'_tstat'], 
                 y = -np.log(plot_data[plot_by+'_pvalue']), 
                 color = plot_data['env_change'], 
                 hover_name = plot_data['Olfr'],
                 opacity = 0.5)
# Dashed line marks the p = 0.05 significance threshold (natural log, like the y axis)
fig.add_hline(y=-np.log(0.05), 
              line_width = 3, 
              line_dash = 'dash')
fig.update_layout(
    # The x axis shows the t-statistic, not a fold change, so the title says so.
    title=f"t-statistic vs. -log(p_value) of {plot_by} across env_change",
    xaxis_title="t-statistic",
    yaxis_title="-log(p_value)",
    autosize=False,
    width=700,
    height=500,
    template = 'simple_white'
)
fig.show()
# fig.write_html('../output/nb_4/plots/Vol_Olfr_tstat_env_change.html')
# fig.write_html('../output/nb_4/plots/Vol_S100a5_tstat_env_change.html')

In [None]:
# Display the results table
results

In [None]:
# Display the per-cell annotations stored in adata.obs
adata.obs