# Setup

In [None]:
# Base imports
import os
import pickle

# Compute imports
import numpy as np
import pandas as pd
import scipy
from tqdm.notebook import tqdm, trange

# Plotting imports
import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42
from matplotlib import pyplot as plt
import seaborn as sns
from plotly import express as px

# ML import
from sklearn.decomposition import NMF
from sklearn.metrics import mean_squared_error, median_absolute_error
from sklearn.metrics.pairwise import cosine_similarity
from pyphylon.util import load_config



In [None]:
# Project settings. WORKDIR is the data root that the interim/ and processed/
# paths below are joined onto; SPECIES comes from the PG_NAME key and is used
# to build per-pangenome filenames.
CONFIG = load_config("config.yml")
WORKDIR = CONFIG["WORKDIR"]  # root directory for every input path in this notebook
SPECIES = CONFIG["PG_NAME"]  # pangenome name, e.g. interpolated into the P-matrix filename

In [None]:
# Load in (full) P matrix. Columns are genome IDs (selected via
# metadata.genome_id below) and rows are genes (summed/filtered along axis=1).
df_genes = pd.read_pickle(os.path.join(WORKDIR, f'processed/cd-hit-results/{SPECIES}_strain_by_gene.pickle.gz'))

# Load in (full) metadata; dtype='object' keeps IDs as strings so they match
# the P-matrix column labels as stored in the CSV.
ENRICHED_METADATA = os.path.join(WORKDIR, 'interim/enriched_metadata_2d.csv')
metadata = pd.read_csv(ENRICHED_METADATA, index_col=0, dtype='object')

# Filter metadata for Complete sequences only
metadata = metadata[metadata.genome_status == 'Complete']

# Filter P matrix for Complete sequences only and replace N/A with 0.
# fillna returns a new frame, so no explicit .copy() / inplace mutation is needed.
df_genes_complete = df_genes[metadata.genome_id].fillna(0)

# Densify (only if the P matrix was stored sparse; the .sparse accessor raises
# on dense frames) & typecast to int8 for space and compute reasons.
if all(isinstance(dtype, pd.SparseDtype) for dtype in df_genes_complete.dtypes):
    df_genes_complete = df_genes_complete.sparse.to_dense()
df_genes_complete = df_genes_complete.astype('int8')

# Keep only genes present in at least one Complete genome.
# NOTE: the name reads "in Complete seqs" (not "incomplete"); kept for later cells.
inCompleteseqs = df_genes_complete.sum(axis=1) > 0
df_genes_complete = df_genes_complete.loc[inCompleteseqs]

df_genes_complete.shape

In [None]:
# NOTE(review): this cell is disabled (every line is commented out), so
# `df_eggnog` is never defined by it. Re-enable it if eggNOG annotations are
# needed, otherwise consider deleting the cell. Also note that 'df_eggnog.csv'
# is a path relative to the current directory, not joined onto WORKDIR like
# the other inputs in this notebook.
# # Load in eggNOG annotations
# df_eggnog = pd.read_csv('df_eggnog.csv', index_col=0)
# df_eggnog.fillna('-', inplace=True)

# display(
#     df_eggnog.shape,
#     df_eggnog.head()
# )

In [None]:
# Binarized L matrix, read from the NMF output directory under WORKDIR.
L_BIN = os.path.join(WORKDIR, 'processed/nmf-outputs/L_binarized.csv')

# The CSV's first column becomes the row index.
L_binarized = pd.read_csv(
    L_BIN,
    index_col=0,
)

# Rich display (pandas truncates long frames automatically).
L_binarized

In [None]:
# Hierarchically cluster the rows and columns of L_binarized with Ward's
# minimum-variance linkage. The returned ClusterGrid is kept as `g` so the
# clustering order can be read from its dendrograms
# (g.dendrogram_row / g.dendrogram_col .reordered_ind).
# The trailing ';' suppresses the ClusterGrid repr in the cell output.
g = sns.clustermap(L_binarized, cmap='hot_r', method='ward');