# Setup

In [None]:
# Base imports
import os
import pickle

# Compute imports
import numpy as np
import pandas as pd
import scipy
from tqdm.notebook import tqdm, trange

# Plotting imports
import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42
from matplotlib import pyplot as plt
import seaborn as sns
from plotly import express as px

# ML import
from sklearn.decomposition import NMF
from sklearn.metrics import mean_squared_error, median_absolute_error
from sklearn.metrics.pairwise import cosine_similarity
from pyphylon.util import load_config

In [None]:
# Load the run configuration. WORKDIR is the root that every 'processed/...' path
# below is joined onto; PG_NAME is used as the pan-genome/species prefix for file names.
CONFIG = load_config("config.yml")
WORKDIR, SPECIES = CONFIG["WORKDIR"], CONFIG["PG_NAME"]

In [None]:
# Binarized NMF L matrix: one column per phylon.
# Rows are presumably pan-genome gene clusters (index_col=0) — TODO confirm.
L_BIN_PATH = os.path.join(WORKDIR, 'processed/nmf-outputs/L_binarized.csv')
L_BIN = pd.read_csv(L_BIN_PATH, index_col=0)
# Give the NMF components 1-based, human-readable names: phylon1, phylon2, ...
L_BIN.columns = [f'phylon{i}' for i in range(1, L_BIN.shape[1] + 1)]

In [None]:
# Collect per-locus functional annotations from the Bakta outputs.
# collect_functions is meant to run only once, so its result is cached in
# processed/all_functions.csv. It is rebuilt only when that file is missing, which
# keeps Restart & Run All working on a fresh WORKDIR. Delete the CSV to force a rebuild.
from pyphylon.biointerp import collect_functions

ALL_FUNCTIONS_CSV = os.path.join(WORKDIR, 'processed/all_functions.csv')
if not os.path.exists(ALL_FUNCTIONS_CSV):
    collect_functions(WORKDIR, 'processed/bakta/').to_csv(ALL_FUNCTIONS_CSV)

# Always load from disk so fresh and cached runs yield identical dtypes/index.
all_functions = pd.read_csv(ALL_FUNCTIONS_CSV, index_col=0)

In [None]:
# Pan-genome matrix produced by CD-HIT clustering (strain-by-gene, per the file name).
pangenome_pickle = os.path.join(WORKDIR, f'processed/cd-hit-results/{SPECIES}_strain_by_gene.pickle.gz')
df_genes = pd.read_pickle(pangenome_pickle)

In [None]:
from pyphylon.biointerp import get_pg_to_locus_map

# Attach each locus's annotation to the pan-genome cluster it belongs to.
pg2locus_map = get_pg_to_locus_map(WORKDIR, SPECIES)
functions2genes = all_functions.merge(pg2locus_map, left_on='locus', right_on='gene_id')

# Collapse to one annotation per cluster. groupby().first() takes the first non-null
# value of EACH column independently, so 'product' and 'go' may come from different loci.
cluster_functions = (
    functions2genes
    .groupby('cluster')
    .first()
    .reset_index()
    [['cluster', 'product', 'go']]
)
cluster_functions

In [None]:
from pyphylon.biointerp import explode_go_annos

# Expand the per-cluster 'go' annotations into a long table
# (presumably one row per cluster/GO-term pair — verify in explode_go_annos).
cluster_to_go_functions = explode_go_annos(cluster_functions)
cluster_to_go_functions

In [None]:
# Only GO terms annotated on more than MIN_GO_TERM_CLUSTERS rows are tested for enrichment;
# rarer terms cannot reach meaningful overlaps. Most frequent terms are listed first.
MIN_GO_TERM_CLUSTERS = 3

go_functions_count = cluster_to_go_functions.groupby('go').count()
is_frequent_term = go_functions_count['cluster'] > MIN_GO_TERM_CLUSTERS
go_functions = go_functions_count.loc[is_frequent_term].sort_values(by='cluster', ascending=False)
go_functions

In [None]:
# Calculate a single enrichment: one GO term tested against one phylon (quick check
# before the full sweep below).
# NOTE(review): cluster_to_go_functions is passed before functions2genes here, the reverse
# of calc_all_phylon_go_enrichments — confirm against calc_enrichment's signature.
from pyphylon.biointerp import calc_enrichment

go_term, phylon = 'GO:0005524', 'phylon1'
calc_enrichment(
    L_BIN,
    cluster_to_go_functions,
    go_term,
    functions2genes,
    phylon,
    phylon_contribution_cutoff=0,
)

In [None]:
# Test every phylon against every frequent GO term, keep the nominally significant hits,
# attach human-readable term names, and save the table.
# TODO: speed this up - shrinking functions2genes to only accessory genes seemed to help.
from pyphylon.biointerp import calc_all_phylon_go_enrichments, get_go_mapping

P_VALUE_CUTOFF = 0.05  # nominal threshold; no multiple-testing correction is applied here
PHYLON_CONTRIBUTION_CUTOFF = 0.5
EXCLUDED_TERMS = ['SO:0001217']  # just a category for "protein encoding gene"

phylon_go_enrichments_all = calc_all_phylon_go_enrichments(
    L_BIN, functions2genes, cluster_to_go_functions, go_functions,
    phylon_contribution_cutoff=PHYLON_CONTRIBUTION_CUTOFF,
)
phylon_go_enrichments = phylon_go_enrichments_all[phylon_go_enrichments_all['p_value'] < P_VALUE_CUTOFF]

go_mapping = get_go_mapping()
phylon_go_enrichments = pd.merge(phylon_go_enrichments, go_mapping, left_on='function', right_index=True, how='left')

# Terms without an entry in go_mapping fall back to their raw ID as display name.
# Filled positionally (np.where) rather than via .loc[labels, ...]: the left merge keeps the
# left frame's index, and if that index has duplicate labels, label-based assignment would
# also overwrite rows that already have a proper name.
term_names = phylon_go_enrichments['name']
phylon_go_enrichments['name'] = np.where(term_names.isna(), phylon_go_enrichments['function'], term_names)

phylon_go_enrichments = phylon_go_enrichments[~phylon_go_enrichments['function'].isin(EXCLUDED_TERMS)]

ENRICHMENT_CSV = os.path.join(WORKDIR, 'processed/phylon_go_enrichment.csv')
phylon_go_enrichments.to_csv(ENRICHMENT_CSV)
# Re-read so downstream cells see the same dtypes/index as when loading a saved run.
phylon_go_enrichments = pd.read_csv(ENRICHMENT_CSV, index_col=0)

In [None]:
# Phylon x GO-term matrix of enrichment p-values. Pairs absent from the table (not significant,
# or never tested) are filled with the p-value cutoff so they all share one colour.
ENRICHMENT_P_CUTOFF = 0.05
phylon_go_enrichments_mat = pd.pivot_table(phylon_go_enrichments, index='phylon', columns='function', values='p_value')
enrichment_grid = sns.clustermap(phylon_go_enrichments_mat.fillna(ENRICHMENT_P_CUTOFF), cmap='rocket_r')
# clustermap draws several axes (dendrograms, heatmap, colorbar); plt.title would label
# whichever one happens to be current. Title the whole clustermap figure instead.
plt.suptitle('phylon functional enrichments')

In [None]:
# Explore a single phylon: show the first ten of its enriched GO terms (in saved order).
phylon = 'phylon1'
rows_for_phylon = phylon_go_enrichments['phylon'] == phylon
phylon_go_enrichments.loc[rows_for_phylon].head(10)

In [None]:
# Explore all phylons: an interactive scatter of enriched GO terms plus a word cloud per phylon.
from pyphylon.biointerp import gen_phylon_wordcloud

for phylon in phylon_go_enrichments['phylon'].unique():
    # .assign builds a new frame, so nothing is written into a filtered slice of
    # phylon_go_enrichments. The old .loc write into the mask result raised
    # SettingWithCopyWarning.
    phylon_enr = phylon_go_enrichments[phylon_go_enrichments['phylon'] == phylon].assign(
        # Plotly hover text is HTML: put each ';'-separated product on its own line.
        products=lambda frame: frame['products'].str.replace(';', '<br>', regex=False)
    )
    # hover_data as a list is accepted by every plotly.express version.
    fig = px.scatter(phylon_enr, x='overlap', y='logp', text='name', size='overlap',
                     hover_data=['products'], title=phylon)
    print(phylon)
    fig.show()
    gen_phylon_wordcloud(L_BIN, functions2genes, phylon, cutoff=0)

