# Setup

In [None]:
# Base imports
import os
import pickle

# Compute imports
import numpy as np
import pandas as pd

import scipy
from scipy import spatial as sp
from scipy.spatial.distance import hamming, squareform
from scipy.sparse import csr_matrix
from scipy.cluster import hierarchy as hc
from scipy.cluster.hierarchy import cophenet


from tqdm.notebook import tqdm, trange

# Plotting imports
import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42
from matplotlib import pyplot as plt
import seaborn as sns
from plotly import express as px

# ML import
import prince
from sklearn.decomposition import NMF
from sklearn.cluster import KMeans
from sklearn.metrics import mean_squared_error, median_absolute_error, confusion_matrix, auc
from sklearn.mixture import GaussianMixture
from pyphylon.util import load_config

from pyphylon.pangenome import get_gene_frequency_submatrices, connectivity
from pyphylon.models import run_nmf, normalize_nmf_outputs, binarize_nmf_outputs, generate_nmf_reconstructions, calculate_nmf_reconstruction_metrics

In [None]:
# Load run configuration from config.yml.
# WORKDIR is the base directory every input/output path below is joined onto;
# PG_NAME (stored as SPECIES) is the prefix of the CD-HIT strain-by-gene pickle.
CONFIG = load_config("config.yml")
WORKDIR = CONFIG["WORKDIR"]
SPECIES = CONFIG["PG_NAME"]

In [None]:
# Gene presence/absence (P) matrix from CD-HIT clustering; columns are genome IDs
# (they are selected by metadata genome_id below). The stored frame uses the pandas
# sparse accessor; missing entries are filled with 0 and the result densified to int8.
# NOTE(review): unpickling can execute arbitrary code — only load trusted project outputs.
df_genes = (
    pd.read_pickle(os.path.join(WORKDIR, f'processed/cd-hit-results/{SPECIES}_strain_by_gene.pickle.gz'))
    .fillna(0)
    .sparse.to_dense()
    .astype('int8')
)

df_genes.shape

In [None]:
# Enriched genome metadata; every column is read as str (dtype='object') so IDs keep their exact text.
metadata_path = os.path.join(WORKDIR, 'interim/enriched_metadata_2d.csv')
metadata = pd.read_csv(metadata_path, index_col=0, dtype='object')
metadata

In [None]:
# Keep only genomes whose genome_status is 'Complete'
is_complete = metadata['genome_status'] == 'Complete'
metadata_complete = metadata.loc[is_complete]

# Restrict the P matrix to those genomes (columns), then drop genes absent from all of them
df_genes_complete = df_genes[metadata_complete['genome_id']]
inCompleteseqs = df_genes_complete.sum(axis=1) > 0
df_genes_complete = df_genes_complete.loc[inCompleteseqs]

df_genes_complete.shape

In [None]:
# Sparse (int8, fill value 0) copy of the filtered P matrix and its CSR form for fast row sums
df_genes_complete_sparse = df_genes_complete.astype(pd.SparseDtype("int8", 0))
csr_genes = csr_matrix(df_genes_complete_sparse.sparse.to_coo())
csr_genes

In [None]:
# Per-gene frequency: the row sum of the P matrix, i.e. how many complete genomes carry each gene
gene_row_sums = np.asarray(csr_genes.sum(axis=1)).ravel()
df_genes_freq = pd.Series(gene_row_sums, index=df_genes_complete_sparse.index, name='freq')
df_genes_freq.sort_values().hist()

In [None]:
# Full accessory genome (CAR genomes); its columns are the genomes that later label the consensus matrix
acc_genome_path = os.path.join(WORKDIR, 'processed/CAR_genomes/df_acc.csv')
df_acc_complete = pd.read_csv(acc_genome_path, index_col=0)
df_acc_complete

## Submatrices of accessory genome

In [None]:
# Split the accessory genome into gene-frequency-range submatrices, accessed as
# P_submatrices[min_key][max_key] (presumably frequency bounds — TODO confirm in pyphylon).
# The (0, 100) entry is dropped here and skipped explicitly in every later loop.
P_submatrices = get_gene_frequency_submatrices(df_acc_complete)
P_submatrices[0][100] = None # this is just the full accessory genome, we can remove this "submatrix"

# MCA (for rank analysis)

In [None]:
# Step 1: fit MCA once with prince. Keep as many components as df_acc_complete has
# columns so the whole variance spectrum is available for the rank analysis below.
mca = prince.MCA(
    n_components=df_acc_complete.shape[1],
    n_iter=3,              # iterations of the underlying CA solver
    copy=True,
    check_input=True,
    engine='sklearn',
    random_state=42,       # fixed seed for reproducibility
).fit(df_acc_complete)

mca

In [None]:
# Step 2: cumulative explained variance (percent) as a function of the number of MCA dimensions.
# The series is indexed by component COUNT (1..n), not 0-based position, so
# cumulative_variance[k] is the variance captured by the first k components. Both the
# "Number of Dimensions" x-axis and the threshold ranks derived from this index depend on that.
# np.asarray guards against the index of a Series input being aligned away (which would give NaNs).
explained_variance_percentage = mca.percentage_of_variance_
_variance_values = np.asarray(explained_variance_percentage)
cumulative_variance = pd.Series(
    _variance_values,
    index=np.arange(1, len(_variance_values) + 1),
).cumsum()

In [None]:
# Full cumulative explained-variance curve, with reference lines at selected thresholds
sns.set(style='whitegrid')
plt.figure(figsize=(12, 6))
plt.plot(cumulative_variance, marker='o', linestyle='-', color='blue')
plt.title('Cumulative Explained Variance by Dimension')
plt.xlabel('Number of Dimensions')
plt.ylabel('Cumulative Explained Variance')

# threshold[p] = first index of cumulative_variance at which it reaches p (p = 1..98)
threshold = {}
for pct in range(1, 99):
    reached = cumulative_variance.index[cumulative_variance >= pct]
    threshold[pct] = reached[0]

# One dashed vertical line per highlighted threshold
threshold_styles = ((70, 'grey'), (75, 'limegreen'), (80, 'purple'), (85, 'pink'), (90, 'maroon'))
for pct, colour in threshold_styles:
    plt.axvline(
        x=threshold[pct],
        color=colour,
        label=f'{pct}% Explained Variance: {threshold[pct]}',
        linestyle='--',
    )

plt.legend()
plt.show()

In [None]:
# Zoomed curve: only the leading components that each explain more than 0.01% of the variance
n_significant_components = (explained_variance_percentage > 0.01).sum()

sns.set(style='whitegrid')
plt.figure(figsize=(12, 6))
plt.plot(cumulative_variance.iloc[:n_significant_components], marker='o', linestyle='-', color='blue')
plt.title('Cumulative Explained Inertia (Variance) by Dimension')
plt.xlabel('Number of Dimensions')
plt.ylabel('Cumulative Explained Inertia (Variance)')

# Same threshold reference lines as the full plot
threshold_styles = ((70, 'grey'), (75, 'limegreen'), (80, 'purple'), (85, 'pink'), (90, 'maroon'))
for pct, colour in threshold_styles:
    plt.axvline(
        x=threshold[pct],
        color=colour,
        label=f'{pct}% Explained Variance: {threshold[pct]}',
        linestyle='--',
    )

plt.legend()
plt.show()

In [None]:
# Candidate NMF rank taken from the Mash-based analysis (hard-coded; presumably the number of
# Mash clusters from an earlier step — TODO confirm and update if that clustering changes)
MASH_RANK = 16 # Mash rank

In [None]:
# Candidate NMF ranks: a minimal rank (2), the Mash rank, and the MCA variance thresholds,
# deduplicated and in ascending order
candidate_ranks = {2, MASH_RANK}
candidate_ranks.update(threshold[pct] for pct in (70, 75, 80, 85, 90))
rank_list = sorted(candidate_ranks)

rank_list

# NMF decomposition on accessory genome

- NMF at various ranks
- Find the "best" model (lowest AIC)
- Re-run NMF at ranks around the "best" rank
- Repeat for NMF of submatrices

## Initial NMF decomposition across ranks

In [None]:
# Fit NMF on the full accessory genome at every candidate rank (up to 10,000 iterations).
# W_dict / H_dict are presumably keyed by rank — TODO confirm against pyphylon.models.run_nmf.
W_dict, H_dict = run_nmf(
    data=df_acc_complete,
    ranks=rank_list,
    max_iter=10_000
)

In [None]:
# Normalize the per-rank NMF factors into L and A matrices (presumably L from W and A from H)
L_norm_dict, A_norm_dict = normalize_nmf_outputs(df_acc_complete, W_dict, H_dict)

In [None]:
# Convert the normalized L/A matrices to binary form (method defined in pyphylon.models.binarize_nmf_outputs)
L_binarized_dict, A_binarized_dict = binarize_nmf_outputs(L_norm_dict, A_norm_dict)

In [None]:
# Rebuild P from the binarized L and A at each rank; also returns error and confusion results per rank
P_reconstructed_dict, P_error_dict, P_confusion_dict = generate_nmf_reconstructions(df_acc_complete, L_binarized_dict, A_binarized_dict)

In [None]:
# Per-rank reconstruction metrics; the table has an 'AIC' column used to pick the best rank
df_metrics = calculate_nmf_reconstruction_metrics(P_reconstructed_dict, P_confusion_dict)

In [None]:
# Ranks ordered from lowest to highest AIC; the lowest-AIC rank is treated as "best" below
df_metrics.sort_values(by='AIC')

## Rerunning with extra values near "best rank"

In [None]:
# New ranks: every rank within +/-3 of the current best (lowest-AIC) rank, merged with the
# original candidates. NMF needs at least one component, and 2 is always a candidate, so the
# raw window can reach 0 or negative ranks. Clip those, and deduplicate so no rank is refit twice.
best_rank = df_metrics['AIC'].idxmin(axis=0)
neighbour_ranks = range(best_rank - 3, best_rank + 3 + 1)
extra_ranks = sorted(set(rank_list) | {r for r in neighbour_ranks if r >= 1})

# NMF run
W_dict, H_dict = run_nmf(
    data=df_acc_complete,
    ranks=extra_ranks,
    max_iter=10_000
)

# Postprocess
L_norm_dict, A_norm_dict = normalize_nmf_outputs(df_acc_complete, W_dict, H_dict)
L_binarized_dict, A_binarized_dict = binarize_nmf_outputs(L_norm_dict, A_norm_dict)

# Reconstruction & error metrics
P_reconstructed_dict, P_error_dict, P_confusion_dict = generate_nmf_reconstructions(df_acc_complete, L_binarized_dict, A_binarized_dict)
df_metrics = calculate_nmf_reconstruction_metrics(P_reconstructed_dict, P_confusion_dict)

df_metrics.sort_values(by='AIC')

In [None]:
# Best rank for the full accessory genome after the refined sweep (lowest AIC)
best_rank_by_aic = df_metrics['AIC'].idxmin()

## Running NMF with extra rank list on submatrices

In [None]:
# One independent container per NMF artefact, each keyed by the submatrix min_key for now
# (nested to [min_key][max_key] with make_dict_in_dict before the NMF loop).
(
    W_submatrices, H_submatrices,
    L_norm_submatrices, A_norm_submatrices,
    L_binarized_submatrices, A_binarized_submatrices,
    P_reconstructed_submatrices, P_error_submatrices, P_confusion_submatrices,
    df_metrics_submatrices,
) = (dict.fromkeys(P_submatrices.keys()) for _ in range(10))


# Helper function to make dict of dicts
def make_dict_in_dict(d: dict):
    """Nest ``d`` in place: every value becomes a fresh dict keyed by d's keys (sorted), all None.

    Turns ``{min_key: ...}`` into ``{min_key: {max_key: None, ...}}``. Existing outer key
    order is preserved; each inner dict is a distinct object. Returns None.
    """
    ordered_keys = sorted(d)
    d.update({outer_key: dict.fromkeys(ordered_keys) for outer_key in ordered_keys})

# Every per-submatrix container, so they can all be nested in a single pass
dod_list = [
    W_submatrices, H_submatrices,
    L_norm_submatrices, A_norm_submatrices,
    L_binarized_submatrices, A_binarized_submatrices,
    P_reconstructed_submatrices, P_error_submatrices, P_confusion_submatrices,
    df_metrics_submatrices,
]

# Turn each {min_key: None} container into {min_key: {max_key: None}}
for container in dod_list:
    make_dict_in_dict(container)

# Decompose every valid (min_key, max_key) submatrix with NMF at each of the extra ranks,
# then normalize, binarize, reconstruct and score each result.
for min_key in tqdm(P_submatrices.keys(), desc='Iterating over min keys'):
    for max_key in tqdm(P_submatrices.keys(), desc='Iterating over max keys'):
        # Skip the full accessory genome (0, 100) and any non-increasing range
        if (min_key == 0 and max_key == 100) or not min_key < max_key:
            continue

        P_sub = P_submatrices[min_key][max_key]

        # NMF run
        W_sub, H_sub = run_nmf(data=P_sub, ranks=extra_ranks, max_iter=10_000)
        W_submatrices[min_key][max_key] = W_sub
        H_submatrices[min_key][max_key] = H_sub

        # Postprocess
        L_norm_sub, A_norm_sub = normalize_nmf_outputs(P_sub, W_sub, H_sub)
        L_norm_submatrices[min_key][max_key] = L_norm_sub
        A_norm_submatrices[min_key][max_key] = A_norm_sub

        L_bin_sub, A_bin_sub = binarize_nmf_outputs(L_norm_sub, A_norm_sub)
        L_binarized_submatrices[min_key][max_key] = L_bin_sub
        A_binarized_submatrices[min_key][max_key] = A_bin_sub

        # Reconstruction & error metrics
        P_recon_sub, P_err_sub, P_conf_sub = generate_nmf_reconstructions(P_sub, L_bin_sub, A_bin_sub)
        P_reconstructed_submatrices[min_key][max_key] = P_recon_sub
        P_error_submatrices[min_key][max_key] = P_err_sub
        P_confusion_submatrices[min_key][max_key] = P_conf_sub

        df_metrics_submatrices[min_key][max_key] = calculate_nmf_reconstruction_metrics(P_recon_sub, P_conf_sub)

In [None]:
# Valid (min_key, max_key) frequency ranges: increasing pairs, excluding (0, 100) (the full genome)
valid_ranges = [
    (lo, hi)
    for lo in sorted(P_submatrices.keys())
    for hi in sorted(P_submatrices.keys())
    if lo < hi and not (lo == 0 and hi == 100)
]

# Best (lowest-AIC) rank of the full accessory genome, computed once and reused
best_full_rank = df_metrics['AIC'].idxmin(axis=0)

# Per-range containers for the best rank and its models, plus a 'full' entry for the full genome
best_ranks_dict = dict.fromkeys(valid_ranges)
best_ranks_dict['full'] = best_full_rank

best_L_norm_dict = dict.fromkeys(valid_ranges)
best_L_norm_dict['full'] = L_norm_dict[best_full_rank]

best_A_norm_dict = dict.fromkeys(valid_ranges)
best_A_norm_dict['full'] = A_norm_dict[best_full_rank]

best_L_binarized_dict = dict.fromkeys(valid_ranges)
best_L_binarized_dict['full'] = L_binarized_dict[best_full_rank]

best_A_binarized_dict = dict.fromkeys(valid_ranges)
best_A_binarized_dict['full'] = A_binarized_dict[best_full_rank]

# Get best ranks and models for each submatrix (one pass, in sorted range order)
for min_key, max_key in valid_ranges:
    sub_rank = df_metrics_submatrices[min_key][max_key]['AIC'].idxmin(axis=0)
    print(f'submatrix range: {min_key, max_key}')
    print(f"best rank by AIC: {sub_rank}")

    best_ranks_dict[(min_key, max_key)] = sub_rank
    best_L_norm_dict[(min_key, max_key)] = L_norm_submatrices[min_key][max_key][sub_rank]
    best_A_norm_dict[(min_key, max_key)] = A_norm_submatrices[min_key][max_key][sub_rank]

    best_L_binarized_dict[(min_key, max_key)] = L_binarized_submatrices[min_key][max_key][sub_rank]
    best_A_binarized_dict[(min_key, max_key)] = A_binarized_submatrices[min_key][max_key][sub_rank]

In [None]:
# Lowest-AIC rank of the full accessory genome, for comparison with the per-submatrix ranks
df_metrics['AIC'].idxmin(axis=0) # Best rank for the full accessory genome (for comparison)

## Finding robust clusters of strains

### Consensus matrix with all submatrices

In [None]:
# Connectivity matrix of each valid submatrix, computed from the binarized A matrix at its best rank
# (presumably genome-by-genome co-membership — TODO confirm in pyphylon.pangenome.connectivity)
conn_dict = dict.fromkeys(P_submatrices.keys())
make_dict_in_dict(conn_dict)

for min_key in P_submatrices.keys():
    for max_key in P_submatrices.keys():
        if (min_key == 0 and max_key == 100) or not min_key < max_key:
            continue
        best_A_sub = A_binarized_submatrices[min_key][max_key][best_ranks_dict[(min_key, max_key)]]
        conn_dict[min_key][max_key] = connectivity(
            P_submatrices[min_key][max_key].values,
            best_A_sub.values,
        )

In [None]:
# Consensus matrix for these runs (H matrix, default): element-wise mean of the connectivity
# matrices of all valid submatrices. The accumulator is sized from the genome count
# (df_acc_complete columns, which also label the result) instead of a hard-coded
# conn_dict[0][25] entry that only exists for one particular submatrix layout.
n_genomes = df_acc_complete.shape[1]
consensus_matrix = np.zeros(shape=(n_genomes, n_genomes))

num_conn_mat = 0
for min_key in P_submatrices.keys():
    for max_key in P_submatrices.keys():
        if (min_key == 0 and max_key == 100) or not min_key < max_key:
            continue
        num_conn_mat += 1
        consensus_matrix += conn_dict[min_key][max_key]

if num_conn_mat == 0:
    raise ValueError('No submatrix connectivity matrices to average into a consensus matrix')
consensus_matrix /= num_conn_mat

df_consensus_matrix = pd.DataFrame(consensus_matrix, index=df_acc_complete.columns, columns=df_acc_complete.columns)
df_consensus_matrix

In [None]:
# Cut height for the dendrogram, as a fraction of the largest consensus distance.
# Lower values cut lower and give more clusters; higher values give fewer.
thresh = 0.7

# Consensus entries are similarities; 1 - consensus is the distance used for linkage.
# NOTE(review): Ward linkage formally assumes Euclidean distances, which 1 - consensus need not be.
df_consensus_dist = 1 - df_consensus_matrix
dist = scipy.spatial.distance.squareform(df_consensus_dist)
link = hc.linkage(dist, method='ward')

# Flat clusters: merge everything below thresh * max distance
consensus_clst = pd.DataFrame(index=df_acc_complete.columns)
consensus_clst['cluster'] = hc.fcluster(link, thresh * dist.max(), 'distance')

In [None]:
# Highest cluster label; scipy's fcluster labels are 1-based, so this is the number of clusters
consensus_clst.cluster.max()

In [None]:
# Bar plot of consensus cluster sizes (number of genomes per cluster) at the current `thresh`
cluster_sizes = consensus_clst.cluster.value_counts().sort_index()
px.bar(
    x=cluster_sizes.index,
    y=cluster_sizes.values,
)

In [None]:
# Give each consensus cluster a colour from the concatenated tab20b + tab20c palettes (40 colours).
# The palette is cycled when there are more clusters than colours. Previously zip() silently
# stopped at 40, leaving any further clusters without a colour (NaN in the 'color' column).
cmb = matplotlib.colormaps.get_cmap('tab20b')
cmc = matplotlib.colormaps.get_cmap('tab20c')
cm_colors = cmb.colors + cmc.colors

cluster_ids = sorted(consensus_clst.cluster.unique())
consensus_clr = {cid: cm_colors[i % len(cm_colors)] for i, cid in enumerate(cluster_ids)}
consensus_clst['color'] = consensus_clst.cluster.map(consensus_clr)

print('Number of colors: ', len(consensus_clr))
print('Number of clusters', len(consensus_clst.cluster.unique()))

In [None]:
size = 9

# Heatmap of the consensus matrix, with rows and columns ordered by the consensus linkage.
# TODO: optional phylogroup / Mash-cluster colour bars and a Mash legend were sketched here earlier.
# They rely on objects not defined in the visible notebook (phylogroup_clst, clst, mash_color_dict_31, patches).
sns.set(rc={'figure.facecolor': 'white'})
g = sns.clustermap(
    df_consensus_matrix,
    figsize=(size, size),
    row_linkage=link,
    col_linkage=link,
    xticklabels=False,
    yticklabels=False,
    cmap='hot_r',
)

In [None]:
# Strict upper-triangle consensus values in row-major (i < j) order, i.e. scipy's condensed form.
# np.triu_indices yields exactly the order of the former nested comprehension, without an
# O(n^2) Python-level loop.
upper_i, upper_j = np.triu_indices(consensus_matrix.shape[0], k=1)
avec = consensus_matrix[upper_i, upper_j]

# consensus entries are similarities, conversion to distances
Y = 1 - avec
Z = hc.linkage(Y, method='ward')

# Cophenetic correlation between the linkage Z and the distances Y it was built from
coph_cor, _ = cophenet(Z, Y)

coph_cor # Cophenetic correlation of consensus matrix (ideally 0.7 or higher)

In [None]:
# Dispersion coefficient: mean of 4 * (C - 0.5)^2 over all entries. It is 1 when every
# consensus value is 0 or 1 (crisp co-clustering) and 0 when every value is 0.5.
centered = consensus_matrix - 0.5
dispersion = (4 * (centered * centered)).sum() / consensus_matrix.size

dispersion

### After removing small-rank submatrix models ((50, 100) and (75, 100))

In [None]:
# Connectivity matrices as before, but also leaving out the (50, 100) and (75, 100) submatrices
excluded_ranges = {(0, 100), (50, 100), (75, 100)}

conn_dict_filt = dict.fromkeys(P_submatrices.keys())
make_dict_in_dict(conn_dict_filt)

for min_key in P_submatrices.keys():
    for max_key in P_submatrices.keys():
        if (min_key, max_key) in excluded_ranges or not min_key < max_key:
            continue
        best_A_sub = A_binarized_submatrices[min_key][max_key][best_ranks_dict[(min_key, max_key)]]
        conn_dict_filt[min_key][max_key] = connectivity(
            P_submatrices[min_key][max_key].values,
            best_A_sub.values,
        )

In [None]:
# Consensus matrix for these runs (H matrix, default): element-wise mean of the connectivity
# matrices of the remaining submatrices, skipping (0, 100), (50, 100) and (75, 100).
# The accumulator is sized from the genome count rather than a hard-coded conn_dict_filt[0][25] entry.
excluded_ranges = {(0, 100), (50, 100), (75, 100)}
n_genomes = df_acc_complete.shape[1]
consensus_matrix_filt = np.zeros(shape=(n_genomes, n_genomes))

num_conn_mat = 0
for min_key in P_submatrices.keys():
    for max_key in P_submatrices.keys():
        if (min_key, max_key) in excluded_ranges or not min_key < max_key:
            continue
        num_conn_mat += 1
        consensus_matrix_filt += conn_dict_filt[min_key][max_key]

if num_conn_mat == 0:
    raise ValueError('No submatrix connectivity matrices to average into a consensus matrix')
consensus_matrix_filt /= num_conn_mat

df_consensus_matrix_filt = pd.DataFrame(consensus_matrix_filt, index=df_acc_complete.columns, columns=df_acc_complete.columns)
df_consensus_matrix_filt

In [None]:
# Cut height for the filtered dendrogram, as a fraction of the largest consensus distance.
# Lower values give more clusters.
thresh = 0.7

# 1 - consensus is the distance used for linkage.
# NOTE(review): Ward linkage formally assumes Euclidean distances, which 1 - consensus need not be.
df_consensus_filt_dist = 1 - df_consensus_matrix_filt
dist = scipy.spatial.distance.squareform(df_consensus_filt_dist)
link = hc.linkage(dist, method='ward')

# Flat clusters: merge everything below thresh * max distance
consensus_clst_filt = pd.DataFrame(index=df_acc_complete.columns)
consensus_clst_filt['cluster'] = hc.fcluster(link, thresh * dist.max(), 'distance')

In [None]:
# Highest filtered cluster label; scipy's fcluster labels are 1-based, so this is the cluster count
consensus_clst_filt.cluster.max()

In [None]:
# Bar plot of filtered consensus cluster sizes (genomes per cluster) at the current `thresh`
cluster_sizes_filt = consensus_clst_filt.cluster.value_counts().sort_index()
px.bar(
    x=cluster_sizes_filt.index,
    y=cluster_sizes_filt.values,
)

In [None]:
# Give each filtered consensus cluster a colour from tab20b + tab20c (40 colours), cycling the
# palette when there are more clusters than colours. Previously zip() stopped at 40 and left any
# further clusters without a colour (NaN).
cmb = matplotlib.colormaps.get_cmap('tab20b')
cmc = matplotlib.colormaps.get_cmap('tab20c')
cm_colors = cmb.colors + cmc.colors

cluster_ids_filt = sorted(consensus_clst_filt.cluster.unique())
consensus_clr_filt = {cid: cm_colors[i % len(cm_colors)] for i, cid in enumerate(cluster_ids_filt)}
consensus_clst_filt['color'] = consensus_clst_filt.cluster.map(consensus_clr_filt)

print('Number of colors: ', len(consensus_clr_filt))
print('Number of clusters', len(consensus_clst_filt.cluster.unique()))

In [None]:
size = 9

# Heatmap of the filtered consensus matrix, rows and columns ordered by its linkage.
# TODO: optional phylogroup / Mash-cluster colour bars and a Mash legend were sketched here earlier.
# They rely on objects not defined in the visible notebook (phylogroup_clst, clst, mash_color_dict_31, patches).
sns.set(rc={'figure.facecolor': 'white'})
g = sns.clustermap(
    df_consensus_matrix_filt,
    figsize=(size, size),
    row_linkage=link,
    col_linkage=link,
    xticklabels=False,
    yticklabels=False,
    cmap='hot_r',
)

In [None]:
# Strict upper-triangle values of the filtered consensus matrix in row-major (i < j) order
# (scipy's condensed form), via np.triu_indices instead of an O(n^2) Python comprehension.
upper_i, upper_j = np.triu_indices(consensus_matrix_filt.shape[0], k=1)
avec = consensus_matrix_filt[upper_i, upper_j]

# consensus entries are similarities, conversion to distances
Y = 1 - avec
Z = hc.linkage(Y, method='ward')

# Cophenetic correlation between the linkage Z and the distances Y it was built from
coph_cor, _ = cophenet(Z, Y)

coph_cor # Cophenetic correlation of consensus matrix (ideally 0.7 or higher)

In [None]:
# Dispersion coefficient of the filtered consensus matrix: mean of 4 * (C - 0.5)^2
# (1 for a perfectly crisp 0/1 matrix, 0 when every entry is 0.5)
centered_filt = consensus_matrix_filt - 0.5
dispersion = (4 * (centered_filt * centered_filt)).sum() / consensus_matrix_filt.size

dispersion

In [None]:
# Number of consensus clusters containing at least 10 genomes.
# NOTE(review): this uses consensus_clst (all submatrices), not consensus_clst_filt from the
# section just above — confirm which clustering is intended here.
(consensus_clst.cluster.value_counts() >= 10).sum()

### Finding robust sets across ranks

In [None]:
# Available best models: one per (min_key, max_key) submatrix range, plus 'full' for the full accessory genome
best_L_norm_dict.keys()

## Find best run for main model

In [None]:
# Best (lowest-AIC) model of the full accessory genome: normalized and binarized L/A matrices
L_norm, A_norm = best_L_norm_dict['full'], best_A_norm_dict['full']
L_bin, A_bin = best_L_binarized_dict['full'], best_A_binarized_dict['full']

# Save NMF outputs

In [None]:
# Write the selected full-genome NMF model (normalized and binarized L/A) to processed/nmf-outputs/.
# makedirs(exist_ok=True) replaces the check-then-create pattern, which could race and raise if the
# directory appeared between the check and the create. The output directory is joined once and reused.
newpath = os.path.join(WORKDIR, 'processed/nmf-outputs/')
os.makedirs(newpath, exist_ok=True)

L_norm.to_csv(os.path.join(newpath, 'L.csv'))
A_norm.to_csv(os.path.join(newpath, 'A.csv'))

L_bin.to_csv(os.path.join(newpath, 'L_binarized.csv'))
A_bin.to_csv(os.path.join(newpath, 'A_binarized.csv'))