# MASH
Summary of [MASH](https://github.com/marbl/Mash) results from project: `[{{ project().name }}]`

## Description
Mash performs fast genome and metagenome distance estimation using MinHash.

In [None]:
import pandas as pd
from pathlib import Path

import warnings
# Silence every warning for the rest of the session.
# NOTE(review): this also hides library deprecation notices; consider
# filtering specific categories instead.
warnings.filterwarnings(action="ignore")

#import os
import seaborn as sns
import matplotlib.pyplot as plt
import scipy.cluster.hierarchy as shc
from sklearn.cluster import AgglomerativeClustering, KMeans
from sklearn.preprocessing import MinMaxScaler
import numpy as np
import networkx as nx
import community as community_louvain
import plotly.graph_objects as go
# Use seaborn's "paper" plotting context for the figures below.
sns.set_context(context="paper")

def kMeansRes(scaled_data, k, alpha_k=0.02):
    """Return the penalised, normalised K-Means inertia for a single ``k``.

    Scoring approach from
    https://medium.com/towards-data-science/an-approach-for-choosing-number-of-clusters-for-k-means-c28e614ecb2c

    The K-Means inertia for ``k`` clusters is divided by the total sum of
    squared deviations from the column means, and ``alpha_k * k`` is added as
    a penalty on the number of clusters.

    Parameters
    ----------
    scaled_data : array-like
        Scaled data; rows are samples and columns are features for clustering.
    k : int
        Number of clusters passed to ``KMeans``.
    alpha_k : float, default 0.02
        Manually tuned penalty per cluster.

    Returns
    -------
    float
        Scaled inertia for ``k`` (the caller picks the minimum).
    """
    centred = scaled_data - scaled_data.mean(axis=0)
    total_inertia = np.square(centred).sum()

    # random_state fixed so repeated runs give the same clustering.
    model = KMeans(n_clusters=k, random_state=0)
    model.fit(scaled_data)

    penalty = alpha_k * k
    return model.inertia_ / total_inertia + penalty

def chooseBestKforKMeans(scaled_data, k_range, alpha_k=0.02):
    """Score every ``k`` in ``k_range`` and return the one with the lowest scaled inertia.

    Parameters
    ----------
    scaled_data : array-like
        Scaled data; rows are samples and columns are features for clustering.
    k_range : iterable of int
        Candidate numbers of clusters.
    alpha_k : float, default 0.02
        Per-cluster penalty forwarded to ``kMeansRes``.

    Returns
    -------
    best_k : int
        The ``k`` with minimum scaled inertia, or 1 when ``k_range`` is empty.
    results : pandas.DataFrame
        Indexed by ``k`` with a single ``'Scaled Inertia'`` column.
    """
    scores = [(k, kMeansRes(scaled_data, k, alpha_k=alpha_k)) for k in k_range]
    results = pd.DataFrame(scores, columns=['k', 'Scaled Inertia']).set_index('k')

    if results.empty:
        print("WARNING: Cannot determine best k, returning k as 1.")
        return 1, results

    # Select the column by name: ``results.idxmin()[0]`` relied on positional
    # ``Series.__getitem__`` with an integer key, which pandas 2.x deprecates.
    best_k = int(results['Scaled Inertia'].idxmin())
    return best_k, results

In [None]:
def create_edge_trace(Graph, name, showlegend=False, color='#888', width=0.5, opacity=0.8,
                      legendgroup="edges", legendgrouptitle_text="edges"):
    """Build one Plotly line trace holding every edge of ``Graph`` whose
    ``"relation_type"`` attribute equals ``name``.

    Parameters
    ----------
    Graph : networkx.Graph
        Graph whose nodes carry a 2-D ``'pos'`` attribute and whose edges carry
        a ``'relation_type'`` attribute.
    name : str
        Relation type to select; also used as the trace name.
    showlegend, color, width, opacity, legendgroup, legendgrouptitle_text
        Passed through to ``go.Scatter`` (``color``/``width`` via ``line``).

    Returns
    -------
    plotly.graph_objects.Scatter
        A ``mode='lines'`` trace whose edges are separated by gaps.
    """
    # Read edge attributes from ``Graph`` itself. Previously this looked them up
    # on the notebook-global ``G`` and only worked when the caller passed that
    # same graph.
    edges = np.array([edge for edge in Graph.edges() if Graph.edges[edge]["relation_type"] == name])

    # One (x, y) row per endpoint: rows 2i and 2i+1 are the ends of edge i.
    pos = np.array([Graph.nodes[e]['pos'] for e in edges.flatten()], dtype=float).reshape(-1, 2)

    # Insert a gap (None becomes NaN in a float array) between endpoint pairs so
    # the edges are drawn as separate segments rather than one polyline.
    breaks = np.arange(2, len(pos), 2)
    xs = np.insert(pos[:, 0], breaks, None)
    ys = np.insert(pos[:, 1], breaks, None)

    return go.Scatter(
        x=xs,
        y=ys,
        name=name,
        opacity=opacity,
        line=dict(width=width, color=color),
        hoverinfo='none',
        mode='lines',
        showlegend=showlegend,
        legendgroup=legendgroup,
        legendgrouptitle_text=legendgrouptitle_text,
    )

def create_node_trace(G, node_trace_category, color, showtextlabel=False, nodesize=10, nodeopacity=0.8,
                      nodesymbol="circle", linewidth=1, linecolor="black", textposition="top center",
                      showlegend=False, legendgroup="nodes", legendgrouptitle_text="nodes"):
    """Build one Plotly marker trace for all nodes of ``G`` whose
    ``"node_trace"`` attribute equals ``node_trace_category``.

    Each node must carry ``'pos'`` (an (x, y) pair) and ``'text'`` (hover
    label) attributes. ``showtextlabel`` additionally draws the text next to
    each marker. The remaining keyword arguments style the marker and legend.

    Returns
    -------
    plotly.graph_objects.Scatter
    """
    mode = "markers+text" if showtextlabel else "markers"

    selected = np.array([n for n in G.nodes() if G.nodes[n]["node_trace"] == node_trace_category])
    coords = np.array([G.nodes[n]['pos'] for n in selected.flatten()]).reshape(-1, 2)
    labels = np.array([G.nodes[n]['text'] for n in selected])

    marker_style = dict(
        symbol=nodesymbol,
        opacity=nodeopacity,
        showscale=False,
        color=color,
        size=nodesize,
        line=dict(width=linewidth, color=linecolor),
    )
    return go.Scatter(
        x=coords[:, 0].tolist(),
        y=coords[:, 1].tolist(),
        text=labels.tolist(),
        textposition=textposition,
        mode=mode,
        hoverinfo='text',
        name=node_trace_category,
        showlegend=showlegend,
        legendgroup=legendgroup,
        legendgrouptitle_text=legendgrouptitle_text,
        marker=marker_style,
    )

In [None]:
# Root folder of the project report that all inputs below are read from.
report_dir = Path("..") / "data" / "external" / "G1034_20230801"

In [None]:
# Pairwise MASH distance matrix; the first column holds the genome IDs.
df_mash = pd.read_csv(report_dir / "mash" / "df_mash.csv", index_col=0)

## Hierarchical Clustering based on MASH distances

In [None]:
# Correlate the genomes' MASH-distance profiles (missing distances set to 0).
df_mash_corr = df_mash.fillna(0).corr()

plt.figure(figsize=(30, 10))

# Ward linkage on the rows of the correlation matrix. ``clusters`` is reused as
# the row/column linkage of the clustermap further down.
selected_data = df_mash_corr.copy()
clusters = shc.linkage(selected_data, method="ward", metric="euclidean", optimal_ordering=True)

# Horizontal dendrogram: genome labels run along the y-axis.
shc.dendrogram(Z=clusters, labels=df_mash_corr.index, orientation="left")

# Enlarge the tick labels on both axes (y first, then x).
for set_ticks in (plt.yticks, plt.xticks):
    set_ticks(fontsize=14)
plt.show()

## Estimate Number of Clusters

In [None]:
# choose features
data_for_clustering = df_mash.copy()
data_for_clustering.fillna(0,inplace=True)

# create data matrix
data_matrix = np.asarray(data_for_clustering).astype(float)
data_matrix

# scale the data
mms = MinMaxScaler()
scaled_data = mms.fit_transform(data_matrix)

# choose k range
if len(df_mash) <= 21:
    max_range = len(df_mash) - 1
else:
    max_range = 20

k_range=range(2, max_range)
# compute adjusted intertia
best_k, results = chooseBestKforKMeans(scaled_data, k_range)

# plot the results
plt.figure(figsize=(7,4))
plt.plot(results,'o')
plt.title('Adjusted Inertia for each K')
plt.xlabel('K')
plt.ylabel('Adjusted Inertia')
plt.xticks(range(2,max_range,1))
print(f"Estimated number of clusters: {best_k}")
plt.show()

## MASH Clustermap

In [None]:
n_clusters = best_k

# ColorBrewer Set3 (12 colours). Only the 12 largest clusters get a colour;
# any further clusters are drawn in grey.
color_set3 = ['#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3', '#fdb462',
              '#b3de69', '#fccde5', '#d9d9d9', '#bc80bd', '#ccebc5', '#ffed6f']
max_colored = min(best_k, len(color_set3))

# ``affinity`` was deprecated in scikit-learn 1.2 and removed in 1.4 (it is now
# ``metric``). Ward linkage only accepts Euclidean distance, which is the
# default, so omitting the argument works on every version.
Agg_hc = AgglomerativeClustering(n_clusters=n_clusters, linkage='ward')
y_hc = Agg_hc.fit_predict(df_mash_corr)

df_hclusts = pd.DataFrame(index=df_mash_corr.index, columns=['hcluster', 'color_code'])
df_hclusts['hcluster'] = y_hc

# Largest clusters (by member count) are coloured; everything else is grey.
top_clusters = df_hclusts['hcluster'].value_counts().index.tolist()[:max_colored]
dict_top_colors = dict(zip(top_clusters, color_set3))
df_hclusts['color_code'] = df_hclusts['hcluster'].map(dict_top_colors).fillna("#808080")
comm_colors = df_hclusts['color_code']

# clustermap creates its own figure, so no preceding ``plt.figure()`` call
# (which only produced an extra, empty figure).
# Cells show percent similarity, ordered by the Ward linkage from the
# dendrogram cell.
g = sns.clustermap((1 - df_mash) * 100,
                   figsize=(20, 20), row_linkage=clusters, col_linkage=clusters,
                   row_colors=comm_colors, col_colors=comm_colors, cmap="rocket_r")
g.ax_heatmap.set_xlabel('Genomes', fontsize=18)
g.ax_heatmap.set_ylabel('Genomes', fontsize=18)
plt.setp(g.ax_heatmap.get_xticklabels(), rotation=90, fontsize=12)
plt.setp(g.ax_heatmap.get_yticklabels(), rotation=0, fontsize=12)

# Enlarge the colour-bar tick labels. ``tick_params`` restyles the existing
# labels. Re-setting them via ``set_yticklabels`` pinned a FixedFormatter,
# which triggers a matplotlib warning and freezes the label text.
cbar = g.cax
cbar.tick_params(labelsize=16)

plt.show()

## Community Detection

In [None]:
# Minimum MASH similarity (1 - distance) for two genomes to be linked.
cut_off = 0.95

# MASH similarity from the distance matrix.
similarity_matrix = 1 - df_mash

# Keep similarities >= cut_off. Everything below the cut-off, and missing
# values, becomes 0, which ``from_pandas_adjacency`` treats as "no edge".
df_mash_graph = similarity_matrix.where(similarity_matrix >= cut_off, 0)
print(df_mash_graph.shape)

# Graph with similarity-weighted edges.
G = nx.from_pandas_adjacency(df=df_mash_graph)
print(len(list(nx.connected_components(G))))

# Drop self-loops. Diagonal entries (a genome vs. itself) presumably have
# distance 0, i.e. similarity 1, so they pass the cut-off.
G.remove_edges_from(list(nx.selfloop_edges(G)))

# Louvain community detection. ``partition`` maps node -> community ID.
partition = community_louvain.best_partition(G)
print('Number of communities detected: ', len(set(partition.values())), ', cutoff:', cut_off)

# Draw nodes coloured by community. ``cmap`` only takes effect when
# ``node_color`` is given; previously it was missing, so the detected
# communities were not visible in the plot.
pos = nx.spring_layout(G)
nodelist = list(partition)
nx.draw_networkx_nodes(G, pos, nodelist=nodelist, node_color=[partition[n] for n in nodelist],
                       cmap=plt.get_cmap('jet'), node_size=20)
nx.draw_networkx_edges(G, pos, alpha=0.5)
plt.show()

## Visualization
### Build Taxonomy Mapping

In [None]:
df_gtdb = pd.read_csv(report_dir / 'tables' / 'df_gtdb_gtdbtk_meta.csv', index_col='genome_id')

# genome_id -> "Genus epithet". Anything up to the last "__" in the genus
# token (a GTDB-style rank prefix) is stripped.
tax_mapping = {}
for genome_id, organism in df_gtdb.Organism.items():
    tokens = organism.split() if isinstance(organism, str) else []
    if len(tokens) == 2:
        genus, epithet = tokens
        species = " ".join([genus.split("__")[-1], epithet])
    else:
        # Previously ``species`` kept the *previous* genome's value here, or
        # raised NameError on the first genome. Fall back to the raw organism
        # text, or to "Unclassified" (the label inspected further down) when
        # it is missing or empty.
        print(f"WARNING: {tokens}")
        species = " ".join(tokens) if tokens else "Unclassified"
    tax_mapping[genome_id] = species

### Build MASH Species Mapping

In [None]:
# One Louvain community label per genome, e.g. "MASH_species_3".
species_labels = {genome: f"MASH_species_{community}" for genome, community in partition.items()}
df_partition = pd.DataFrame.from_dict({"MASH_partition_0.95": species_labels})

# Inner join with the hierarchical clusters on genome ID, then prefix the
# cluster numbers ("hcluster_0", ...).
df_clusters = df_hclusts.merge(df_partition, left_index=True, right_index=True)
df_clusters['hcluster'] = df_clusters['hcluster'].map("hcluster_{}".format)
df_clusters

In [None]:
# Colour each MASH species (Louvain community) by cycling through the 12 Set3
# colours, largest community first. The old counter was incremented before
# its first use, so the largest community got color_set3[1] and
# color_set3[0] was only reached on the 12th community.
species_order = df_clusters["MASH_partition_0.95"].value_counts().index
mash_species_color_map = {
    label: color_set3[i % len(color_set3)] for i, label in enumerate(species_order)
}
df_clusters['MASH_partition_0.95_color'] = df_clusters['MASH_partition_0.95'].map(mash_species_color_map)

# ``tax_mapping`` holds "Genus epithet"; keep only the first word (the genus).
# The ``.str`` accessors leave genomes missing from ``tax_mapping`` as NaN
# instead of raising AttributeError on ``float.split``.
df_clusters['Species'] = df_clusters.index.map(tax_mapping)
df_clusters['Species'] = df_clusters['Species'].str.split().str[0]
df_clusters.Species.value_counts()

In [None]:
# Inspect genomes whose genus could not be resolved. A notebook cell only
# renders its last expression, so this table was computed and silently
# discarded; print it explicitly.
print(df_clusters[df_clusters.Species == 'Unclassified'])
df_gtdb.loc['NBC_01635']

In [None]:
# Full GTDB metadata row for one genome (manual inspection).
df_gtdb.loc['NBC_01245', :]

In [None]:
grouping_name = "hcluster"
color_grouping_name = "color_code"

# Marker style per group for the network plot: the group's colour (the last
# genome seen in each group wins) and a circle symbol.
node_annotation_map = {
    group: {'color': color, 'node_symbol': "circle"}
    for group, color in zip(df_clusters[grouping_name], df_clusters[color_grouping_name])
}

In [None]:
# Style for MASH edges (not referenced by the visible plotting code below).
edge_annotation_map = {'mash': dict(color='black', width=0.5)}

In [None]:
traces = []
cutoff = cut_off

# Link genome pairs whose similarity (1 - distance) is >= cutoff, the same
# rule as the Community Detection cell. The graph used to be built from the
# raw distances. ``from_pandas_adjacency`` treats 0 as "no edge", so that
# version never linked identical genomes (distance 0). NaN distances, in
# contrast, became edges and later produced a "nan%" trace.
similarity = 1 - df_mash
G = nx.from_pandas_adjacency(similarity.where(similarity >= cutoff, 0))
G.remove_edges_from(list(nx.selfloop_edges(G)))  # genome vs. itself

# Store the MASH distance as the edge weight, as before, so the layout below
# receives the same edge weights it did previously.
for _, _, data in G.edges(data=True):
    data['weight'] = 1 - data['weight']

# Graphviz "neato" layout (nx_agraph requires pygraphviz). Attach the
# plotting attributes used by create_node_trace to each node.
pos = nx.nx_agraph.graphviz_layout(G, prog='neato')
for n, p in pos.items():
    G.nodes[n]['pos'] = p
    G.nodes[n]['node_trace'] = df_clusters.loc[n, grouping_name]
    G.nodes[n]['text'] = f'{n}<br>{tax_mapping[n]}<br>{df_clusters.loc[n, "MASH_partition_0.95"]}'

# Label each edge with its similarity rounded to two decimals and shown as a
# percentage ("97%"). Each distinct label becomes one legend entry/trace.
weights = []
for e in G.edges:
    label = f"{1 - G.edges[e]['weight']:.2f}"
    weights.append(label)
    G.edges[e]['relation_type'] = f'{float(label):.0%}'
weights = sorted(set(weights))

# Linear map of similarity from [cutoff, 1] onto line width [0.2, 3].
x_max, x_min = 1, cutoff
y_max, y_min = 3, 0.2
x = (y_max - y_min) / (x_max - x_min)
c = y_max - (x_max * x)

for w in weights:
    width = float(w) * x + c
    edge_trace = create_edge_trace(G, f'{float(w):.0%}', color='black', width=width, showlegend=True, opacity=0.5,
                                   legendgroup="MASH distances", legendgrouptitle_text="Similarity")
    traces.append(edge_trace)

# One marker trace per group, coloured via ``node_annotation_map``.
for trace in df_clusters[grouping_name].unique():
    node_trace = create_node_trace(
        G, trace, node_annotation_map[trace]['color'],
        showtextlabel=False,
        nodesymbol=node_annotation_map[trace]['node_symbol'],
        nodeopacity=0.8,
        showlegend=True,
        linecolor=None,
        linewidth=0.5,
        nodesize=16,
        textposition="middle center",
        legendgroup="genomes",
        legendgrouptitle_text=grouping_name,
    )
    traces.append(node_trace)

In [None]:
# Transparent page, white plot area, hidden grid/ticks, thin black frame.
axis_style = dict(showgrid=False, zeroline=False, showticklabels=False,
                  linecolor='black', mirror=True, linewidth=1)
layout = go.Layout(
    paper_bgcolor='rgba(0,0,0,0)',
    plot_bgcolor='white',
    showlegend=True,
    hovermode='closest',
    margin=dict(b=20, l=5, r=5, t=40),
    xaxis=dict(**axis_style),
    yaxis=dict(**axis_style),
    width=1800,
    height=1700,
)
fig = go.Figure(data=traces, layout=layout)
fig = fig.update_layout(legend=dict(orientation="v"))

In [None]:
# Save the interactive network, creating ../figures/ if necessary.
outfile_html = Path("../figures") / "MASH_Group.html"
outfile_html.parent.mkdir(parents=True, exist_ok=True)
fig.write_html(outfile_html)

## Export Supplementary Table

In [None]:
df_supplementary = pd.read_excel("../data/external/103_table_v5.xlsx")

# ``df_louvainn`` is not defined in any earlier cell, so this merge raised
# NameError. Build it from the Louvain communities of the Community Detection
# cell (as labelled in ``df_clusters``), with the ``genome_id`` and
# ``Louvainn_partition`` columns the following cells use.
# NOTE(review): confirm this is the table that was originally intended here.
df_louvainn = (
    df_clusters["MASH_partition_0.95"]
    .rename("Louvainn_partition")
    .rename_axis("genome_id")
    .reset_index()
)
df_supplementary = df_supplementary.merge(df_louvainn, left_on='strain_name', right_on='genome_id').set_index("genome_id")
df_supplementary

In [None]:
# Same file as before, addressed through ``report_dir``.
df_gtdbtk = pd.read_csv(report_dir / "tables" / "gtdbtk.bac120.summary.tsv", sep="\t")

# Species = last ";"-separated rank of the classification, with everything up
# to its last "__" removed (empty string when that rank has no name).
gtdb_species = df_gtdbtk["classification"].str.split(";").str[-1].str.split("__").str[-1]
species_by_genome = dict(zip(df_gtdbtk["user_genome"], gtdb_species))

# Annotate existing rows only. Setting ``df_supplementary.loc[genome_id, ...]``
# row by row appended a new, otherwise empty row for every GTDB-Tk genome that
# is not in the supplementary table.
df_supplementary["GTDB_species"] = df_supplementary.index.map(species_by_genome)

# Sanity check: GTDB species found within each MASH/Louvain community.
for mash_clusters in df_supplementary.Louvainn_partition.unique():
    subset = df_supplementary[df_supplementary.Louvainn_partition == mash_clusters]
    print(subset.GTDB_species.unique(), subset.Louvainn_partition.unique())

In [None]:
# Re-run Louvain community detection on a freshly loaded distance matrix.
# NOTE(review): this keeps similarities strictly greater than ``cut_off`` and
# keeps self-loops, whereas the Community Detection cell uses ``>=`` and
# removes them, so the two partitions can differ.
df_mash = pd.read_csv(report_dir / 'mash' / 'df_mash.csv', index_col=0)
cut_off = 0.95
df_mash_sim = 1 - df_mash
df_mash_graph = df_mash_sim.where(df_mash_sim > cut_off, 0)
print(df_mash_graph.shape)

G = nx.from_pandas_adjacency(df=df_mash_graph)
print(len(list(nx.connected_components(G))))

partition = community_louvain.best_partition(G)
n_communities = len(set(partition.values()))
print('Number of communities detected: ', n_communities, ', cutoff:', cut_off)

In [None]:
# Same community detection as the previous cell, but ``df_mash.csv`` is read
# from the current working directory instead of ``report_dir``.
# NOTE(review): confirm this file is intentionally different; otherwise this
# cell duplicates the previous one.
df_mash = pd.read_csv('df_mash.csv', index_col=0)
cut_off = 0.95
df_mash_sim = 1 - df_mash
above_cutoff = df_mash_sim > cut_off
df_mash_graph = df_mash_sim[above_cutoff].fillna(0)
print(df_mash_graph.shape)

G = nx.from_pandas_adjacency(df=df_mash_graph)
n_components = len(list(nx.connected_components(G)))
print(n_components)

partition = community_louvain.best_partition(G)
print('Number of communities detected: ', len(set(partition.values())), ', cutoff:', cut_off)

## References
<font size="2">
{% for i in project().rule_used['mash']['references'] %}
- *{{ i }}*
{% endfor %}
</font>