# Setup

In [None]:
# Base imports
import os
import pickle

# Compute imports
import numpy as np
import pandas as pd
import scipy
from tqdm.notebook import tqdm, trange

# Plotting imports
import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42
from matplotlib import pyplot as plt
import seaborn as sns
from plotly import express as px

# ML import
from sklearn.decomposition import NMF
from sklearn.cluster import KMeans
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, median_absolute_error, r2_score
from pyphylon.util import load_config
from pyphylon.models import recommended_threshold


In [None]:
# Read config.yml and keep the two entries used throughout this notebook
CONFIG = load_config("config.yml")
WORKDIR, SPECIES = CONFIG["WORKDIR"], CONFIG["PG_NAME"]

In [None]:
# Input files; every path is resolved relative to WORKDIR.

# Strain-by-gene matrix (file name carries the SPECIES tag) and strain metadata
DF_GENES = os.path.join(
    WORKDIR,
    'processed/cd-hit-results/{}_strain_by_gene.pickle.gz'.format(SPECIES),
)
ENRICHED_METADATA = os.path.join(WORKDIR, 'interim/enriched_metadata_2d.csv')
#DF_EGGNOG = '/media/pekar2/pan_phylon/Enterobacter/processed/df_eggnog.csv'

# df_core / df_acc / df_rare tables stored under processed/CAR_genomes
DF_CORE_COMPLETE = os.path.join(WORKDIR, 'processed/CAR_genomes/df_core.csv')
DF_ACC_COMPLETE = os.path.join(WORKDIR, 'processed/CAR_genomes/df_acc.csv')
DF_RARE_COMPLETE = os.path.join(WORKDIR, 'processed/CAR_genomes/df_rare.csv')

# NMF factor matrices
L_MATRIX = os.path.join(WORKDIR, 'processed/nmf-outputs/L.csv')
A_MATRIX = os.path.join(WORKDIR, 'processed/nmf-outputs/A.csv')

In [None]:
# Load in (full) P matrix
# NOTE(review): unpickling can execute arbitrary code embedded in the file;
# only point DF_GENES at files produced by a trusted pipeline.
df_genes = pd.read_pickle(DF_GENES)

# Load in (full) metadata.
# The first CSV column becomes the index; dtype='object' disables numeric
# type inference so every column is kept as read.
metadata = pd.read_csv(ENRICHED_METADATA, index_col=0, dtype='object')

# Load in eggNOG gene annotations (currently disabled)
#df_eggnog = pd.read_csv(DF_EGGNOG, index_col=0, dtype='object')

In [None]:
# Keep only the strains whose genome_status is 'Complete'
is_complete = metadata.genome_status == 'Complete'
metadata_complete = metadata[is_complete]

# Restrict the P matrix to those strains (columns selected by genome_id),
# replace N/A with 0, then densify and downcast to int8 for space and
# compute reasons
df_genes_complete = (
    df_genes[metadata_complete.genome_id]
    .fillna(0)
    .sparse.to_dense()
    .astype('int8')
)

# Drop genes that do not occur in any Complete strain
inCompleteseqs = df_genes_complete.sum(axis=1) > 0
df_genes_complete = df_genes_complete[inCompleteseqs]

df_genes_complete.shape

In [None]:
# Read the three CAR_genomes tables, using the first CSV column as the index
df_core_complete, df_acc_complete, df_rare_complete = (
    pd.read_csv(csv_path, index_col=0)
    for csv_path in (DF_CORE_COMPLETE, DF_ACC_COMPLETE, DF_RARE_COMPLETE)
)

df_acc_complete.shape

In [None]:
# Read the NMF factor matrices, using the first CSV column as the index
L, A = (pd.read_csv(csv_path, index_col=0) for csv_path in (L_MATRIX, A_MATRIX))

# Prefix the phylon axis of each matrix (columns of L, rows of A) with 'phylon'
L.columns = ['phylon{}'.format(x) for x in L.columns]
A.index = ['phylon{}'.format(x) for x in A.index]

display(L.shape, L.head(), A.shape, A.head())

# Normalize L and A matrices

In [None]:
# 99th percentile of every column of L, and its reciprocal as the scale factor
# NOTE(review): a column whose 99th percentile is 0 makes the division below
# produce inf — confirm this cannot happen for the loaded L.
q99_per_column = [np.quantile(L[col], q=0.99) for col in L.columns]
normalization_vals = [1 / q for q in q99_per_column]

# Inverse of each scale factor
recipricol_vals = [1 / x for x in normalization_vals]

# Diagonal scaling matrices built from the two lists
D1, D2 = np.diag(normalization_vals), np.diag(recipricol_vals)

# Ensure they multiply to Identity
sns.heatmap(pd.DataFrame(np.dot(D1, D2)), annot=True, cmap='hot_r')

In [None]:
# Right-multiply L by D1 and left-multiply A by D2, keeping the original labels
L_norm = pd.DataFrame(L.to_numpy() @ D1, index=L.index, columns=L.columns)
A_norm = pd.DataFrame(D2 @ A.to_numpy(), index=A.index, columns=A.columns)

In [None]:
# Initial clustermap of L_norm (Ward linkage, row tick labels hidden)
sns.clustermap(L_norm, method='ward', robust=True, cmap='hot_r', yticklabels=False)

# Binarize L matrix by 3-means clustering (`L_binarized`)

In [None]:
# Start from an all-zero array with the shape and dtype of L_norm's values
L_binarized = np.zeros_like(L_norm.values)

# Binarize one column of L_norm at a time
for j in trange(L_norm.values.shape[1]):
    # KMeans expects 2-D input, so present the column as an (n_rows, 1) array
    weights = L_norm.values[:, j].reshape(-1, 1)

    # 3-means clustering (generally better precision-recall tradeoff than 2-means)
    km = KMeans(n_clusters=3, random_state=0, n_init='auto')
    km.fit(weights)

    # A row is switched on when it was assigned to the cluster whose center is largest
    top_cluster = np.argmax(km.cluster_centers_)
    L_binarized[:, j] = (km.labels_ == top_cluster).astype(int)

# Re-attach the row/column labels of L_norm
L_binarized = pd.DataFrame(L_binarized, index=L_norm.index, columns=L_norm.columns)

# Characterize phylons by strain affinity (`A_binarized`)

`L_norm`, `A_norm`, & `L_binarized` will also have their phylons renamed to match this characterization

__NOTE:__ Only the first phylon is worked through below as an example. Repeat the same steps for every remaining phylon until all of them have been characterized.

## Plotting strain affinities for each phylon

In [None]:
# Phylon affinity distribution by MLST
# (go ahead and add in species coloring if you want for Enterobacter, it will aid in characterization)
# For this case, the recommended threshold may be too high (a value of 1 or 0.8 would be more suited to be threshold)
i = 0

# Affinity of every strain for the phylon being characterized
phylon_affinity = A_norm.loc[f'phylon{i}']

# Plotly Express pairs array-like arguments by position, not by index label.
# Look the MLST labels up by genome_id (the labels of A_norm's columns, as in
# the high-affinity cell below) so each strain is coloured with its own MLST
# regardless of the row order of the metadata table.
mlst_by_strain = (
    metadata_complete
    .set_index('genome_id')
    .mlst
    .reindex(phylon_affinity.index)
)

display(
    px.histogram(phylon_affinity, color=mlst_by_strain, log_y=True),
    f'recommended threshold: {recommended_threshold(A_norm, i)}'
)

## High affinity strains

In [None]:
# Affinity cutoff applied in the next cell to select the high-affinity strains
# of phylon `i`; defaults to the value returned by recommended_threshold.
# Change this to another value if you are using another threshold besides the recommended one
# Generally you want this value to be no less than 0.6
curr_threshold = recommended_threshold(A_norm, i)

In [None]:
# Strains whose affinity for the current phylon reaches the chosen threshold
affinity_row = A_norm.loc[f'phylon{i}']
high_affinity = affinity_row >= curr_threshold
high_affinity_strains = affinity_row[high_affinity].index

# Metadata rows of those strains, looked up by genome_id
high_affinity_metadata = metadata.set_index('genome_id').loc[high_affinity_strains]

# You can add in species classification here too
display(
    high_affinity_metadata.mlst.value_counts(),
    high_affinity_metadata.bioproject_accession.value_counts()
)

## Max affinity strain

In [None]:
# Metadata record of the strain with the largest affinity for the current phylon
max_affinity_strain = A_norm.loc[f'phylon{i}'].idxmax()
metadata.set_index('genome_id').loc[max_affinity_strain]

## Mapping and thresholds

In [None]:
# Change the names of the phylons once you have characterized them
# If you cannot find an initial characterization, label it "unchar-x"
# where x is the number

# phylon_mapping = {
#     'phylon0': 'A-Thailand',
#     'phylon1': 'E-ST11',
#     'phylon2': 'B2-other',
#     'phylon3': 'A-K12',
#     'phylon4': 'D-ST38',
#     'phylon5': 'A-other',
# }

In [None]:
# Set thresholds with k-means as a guide
# if you are changing the threshold from the recommended value,
# add in the k-means suggestion as a comment

# A_thresholds = {
#     'A-Thailand': 0.55,
#     'E-ST11': 0.60, # k-means suggestion: 0.78
#     'B2-other': 0.76,
#     'A-K12': 0.79,
#     'D-ST38': 0.64,
#     'A-other': 0.54,
#     'B2-ST131': 0.60, # k-means suggestion: 0.87
#     'Shigella-flexneri': 0.68,
#     'B1-other': 0.59,
#     'B1-ShigaToxin': 0.91,
#     'Shigella-sonnei': 0.91,
#     'unchar-1': 0.32,
#     'C': 0.60, # k-means suggestion: 0.90
#     'F': 0.70, # k-means suggestion: 0.85
#     'A-BL21': 0.70, # k-means suggestion: 0.85
#     'D-ST32': 1.0,
# }

In [None]:
# L_norm.rename(mapper=phylon_mapping, axis=1, inplace=True)
# L_binarized.rename(mapper=phylon_mapping, axis=1, inplace=True)
# A_norm.rename(mapper=phylon_mapping, axis=0, inplace=True)

In [None]:
# Single cutoff applied to every phylon (no adjusting threshold).
# NOTE(review): the per-phylon A_thresholds template earlier in the notebook is
# commented out, so only this fixed value is used — confirm that is intended.
A_BINARIZATION_THRESHOLD = 0.5

# 1 where the normalized affinity reaches the cutoff, 0 elsewhere. One
# vectorized comparison replaces the row-by-row .loc assignment; the result
# keeps A_norm's labels and the dtype of its values.
A_binarized = (A_norm >= A_BINARIZATION_THRESHOLD).astype(A_norm.values.dtype)

In [None]:
# Sanity check: heatmap of the binarized affinity matrix with the
# x tick labels hidden
sns.heatmap(A_binarized, cmap='Greys', xticklabels=False) # Sanity check

# Plot `L_norm` & `L_binarized` matrices

In [None]:
# Clustermap of L_norm with its current phylon names (Ward linkage);
# the resulting grid is kept in `g`
g = sns.clustermap(L_norm, method='ward', robust=True, cmap='hot_r', yticklabels=False)

In [None]:
# Clustermap of L_binarized with its current phylon names; the grid is kept in
# `g_bin` so the clustered column order can be read back afterwards
g_bin = sns.clustermap(
    L_binarized,
    metric='euclidean',  # metric must be euclidean for ward (even if binarized)
    method='ward',
    robust=True,
    cmap='Greys',
    yticklabels=False,
)

In [None]:
# Phylon names in the column order produced by the L_binarized clustermap
phylon_order = list(g_bin.data2d.columns)

## Uncover relationship between `gene freq` and `num of active phylons`

In [None]:
# Number of phylons each gene is active in (row sums of L_binarized),
# shown as a histogram coloured by that same count
phylons_per_gene = L_binarized.sum(axis=1)
px.histogram(phylons_per_gene, color=phylons_per_gene)

# Fill these in from the histogram above:
# xx (z%) genes in 0 phylons
# xx (z%) genes in yy phylons (all phylons minus unchar-modes)
# xx genes (z%) in only 1 phylon (genes with most differentiating power)
# xx genes (z%) in 2 phylons

In [None]:
# Cum sum: running total of genes, ordered by number of active phylons
L_binarized.sum(axis=1).value_counts().to_frame().sort_index().cumsum()

In [None]:
# Cum sum line plot: running total of genes, ordered by number of active phylons
cumulative_gene_counts = (
    pd.DataFrame(L_binarized.sum(axis=1).value_counts())
    .sort_index()
    .cumsum()
)
sns.lineplot(cumulative_gene_counts)

In [None]:
# Per-gene summary table, indexed like L_binarized:
#   num_active_phylons - row sum of L_binarized
#   gene_freq          - row sum of the full P matrix (df_genes) for the same genes
df_gene_freq_by_phylon = pd.DataFrame(index=L_binarized.index)

df_gene_freq_by_phylon['num_active_phylons'] = L_binarized.sum(axis=1)
df_gene_freq_by_phylon['gene_freq'] = df_genes.loc[L_binarized.index].sum(axis=1)
# The row sum of the sparse P matrix is itself sparse; densify it before modelling
df_gene_freq_by_phylon['gene_freq'] = df_gene_freq_by_phylon['gene_freq'].sparse.to_dense()

# Create a Linear Regression model
model = LinearRegression()

# Fit the model (scikit-learn expects a 2-D feature matrix, hence the reshape)
X = df_gene_freq_by_phylon['gene_freq'].values.reshape(-1,1)
y = df_gene_freq_by_phylon['num_active_phylons'].values
model.fit(X, y)

# Predict the y-values
y_pred = model.predict(X)

# Extract coefficients (slope) and intercept from the model
slope = model.coef_
intercept = model.intercept_

# Calculate R^2 value
r2 = r2_score(y, y_pred)

# Display results.
# coef_ is a length-1 array for a single feature; print its scalar entry so
# the equation reads "y=a*x + b" instead of "y=[a]*x + b".
print(f'Line of best fit: y={slope[0]}*x + {intercept}')
print(f'R2 score: {r2}')

# Pass the frame by keyword so the call does not rely on positional `data` support
ax = sns.regplot(data=df_gene_freq_by_phylon, x='gene_freq', y='num_active_phylons')
plt.show()

# Plot `L_binarized` with sorted genes & phylons

In [None]:
# Build the row order used for the sorted L_binarized plot:
# zero-phylon genes, then single-phylon genes grouped by phylon (in
# phylon_order), then genes active in 2, 3, ... phylons, each group ordered by
# Ward clustering.
gene_order = []

# Number of phylons each gene is active in; computed once and reused below
num_phylons_per_gene = L_binarized.sum(axis=1)

# Add in zero-phylon genes
zero_cond = num_phylons_per_gene == 0
gene_order.extend(L_binarized[zero_cond].index)

# Add in single-phylon genes
single_cond = num_phylons_per_gene == 1
for phylon in phylon_order:
    inPhylon = L_binarized[phylon] == 1
    gene_order.extend(L_binarized[inPhylon & single_cond].index)

# Add in poly-phylon genes
for num_active_phylons in trange(2, int(num_phylons_per_gene.max())+1):
    num_cond = num_phylons_per_gene == num_active_phylons
    genes_with_count = L_binarized[num_cond]

    # Hierarchical clustering needs at least two rows; a count with zero or one
    # gene would make clustermap raise, so append such genes directly.
    if len(genes_with_count) < 2:
        gene_order.extend(genes_with_count.index)
        continue

    gg = sns.clustermap(genes_with_count, method='ward', metric='euclidean', col_cluster=False, yticklabels=False)
    gene_order.extend(gg.data2d.index)

    # Only the clustered row order is needed; close the figure so one open
    # figure per phylon count does not accumulate.
    plt.close(gg.fig)

In [None]:
# Main sorted clustermap: rows follow gene_order (row clustering disabled),
# columns are still clustered
sorted_L_binarized = L_binarized.loc[gene_order]

g = sns.clustermap(
    sorted_L_binarized,
    row_cluster=False,
    method='ward',
    metric='euclidean',
    cmap='Greys',
    yticklabels=False,
)

In [None]:
# Histogram of genes in L_binarized by num of phylons they are active in
active_phylon_counts = L_binarized.sum(axis=1)

fig, ax = plt.subplots()
sns.histplot(data=active_phylon_counts, ax=ax, binwidth=1)
plt.show()

# Plot sorted, corresponding A_binarized matrix

In [None]:
# Build the column order used for the sorted A_binarized plot:
#   1. strains that belong to no phylon,
#   2. for every characterized phylon (in phylon_order): its single-phylon
#      strains, then its not-yet-placed multi-phylon strains,
#   3. the same for every 'unchar' phylon (must come after the characterized ones).
# Strains are appended in A_binarized column order and each strain is placed
# exactly once, so the result is identical from run to run (the previous
# set-difference made the order depend on hash ordering).
strain_order = []
unchar_strain_order = []  # never populated; kept so the name stays defined

# Number of phylons each strain belongs to; computed once and reused below
phylons_per_strain = A_binarized.sum()

# zero-phylon strains
noPhylon = phylons_per_strain == 0
strain_order.extend(phylons_per_strain[noPhylon].index.tolist())

# strain lists
single_phylon_strains = phylons_per_strain[phylons_per_strain == 1].index
multi_phylon_strains = phylons_per_strain[phylons_per_strain > 1].index

# Characterized phylons first, 'unchar' phylons afterwards
characterized_phylons = [p for p in phylon_order if 'unchar' not in p]
unchar_phylons = [p for p in phylon_order if 'unchar' in p]

already_placed = set(strain_order)  # ensures no double-counting

for phylon in characterized_phylons + unchar_phylons:
    # Within a phylon: single-phylon strains before multi-phylon strains
    for candidate_strains in (single_phylon_strains, multi_phylon_strains):
        membership = A_binarized.loc[phylon, candidate_strains]
        for strain in membership[membership == 1].index:
            if strain not in already_placed:
                already_placed.add(strain)
                strain_order.append(strain)

strain_order += unchar_strain_order

len(strain_order)

In [None]:
# Display the binarized affinity matrix for inspection
A_binarized

In [None]:
# A_binarized with rows in phylon_order and columns in strain_order; both
# clusterings are disabled so the plot keeps exactly that order
sorted_A_binarized = A_binarized.loc[phylon_order, strain_order]
sns.clustermap(
    sorted_A_binarized,
    row_cluster=False,
    col_cluster=False,
    cmap='Greys',
    xticklabels=False,
)

# Save L and A matrices

In [None]:
# Output locations for the normalized and binarized matrices (relative to WORKDIR)
L_NORM, A_NORM = (
    os.path.join(WORKDIR, relative_path)
    for relative_path in (
        'processed/nmf-outputs/L_norm.csv',
        'processed/nmf-outputs/A_norm.csv',
    )
)

L_BIN, A_BIN = (
    os.path.join(WORKDIR, relative_path)
    for relative_path in (
        'processed/nmf-outputs/L_binarized.csv',
        'processed/nmf-outputs/A_binarized.csv',
    )
)

In [None]:
# Write each matrix to its CSV: normalized matrices first, then binarized ones
for matrix, destination in (
    (L_norm, L_NORM),
    (A_norm, A_NORM),
    (L_binarized, L_BIN),
    (A_binarized, A_BIN),
):
    matrix.to_csv(destination)