#### Sequence similarity network clustering

In [17]:
import pickle
from ete3 import Tree
from sys import exit
import numpy as np
import pandas as pd
import matplotlib as mpl
from matplotlib import cm  # Import colormap
# Global matplotlib text defaults for every figure drawn after this point:
# a 14 pt base font, with 'large' axes titles, axes labels and tick labels.
mpl.rc('font', size=14)
mpl.rc('axes', titlesize='large', labelsize='large')
# mpl.rc accepts a tuple of group names, so both tick axes share one call.
mpl.rc(('xtick', 'ytick'), labelsize='large')
import matplotlib.pyplot as plt
from sklearn import linear_model
from sklearn.metrics import adjusted_mutual_info_score
import time
import openpyxl  # Required for Excel output

# Load the simulated tree (newick format=1) and give the root a name:
# the string form of len(t), which is also kept in num_leaf.
t = Tree("/home/cdchiang/vae/PF01494/latent_space/simulated_msa/output/random_tree.newick", format=1)
num_leaf = len(t)
t.name = str(num_leaf)

# Walk the tree parents-first (preorder) so that a parent's features already
# exist when its children are reached, and attach two features to every node:
#   anc     - names of all ancestors, ordered from the root downwards
#   sumdist - summed branch length from the root to this node (root = 0)
for node in t.traverse('preorder'):
    at_root = node.is_root()
    lineage = [] if at_root else node.up.anc + [node.up.name]
    depth = 0 if at_root else node.up.sumdist + node.dist
    node.add_feature('anc', lineage)
    node.add_feature('sumdist', depth)

# Root-to-node distance cutoffs used to cut the tree into clusters, and the
# AS values that select which SSN node-table CSV is compared against it.
# The cutoffs come from an integer range divided by 10: np.arange with a
# fractional step can gain or lose its last element to float round-off.
dist_cutoff_list = np.arange(5, 13) / 10  # 0.5, 0.6, ..., 1.2
AS_list = [10, 15, 20, 25, 30, 35]

# One dict per (dist_cutoff, AS) combination that could be evaluated
results = []

# Each SSN table depends only on AS, so read every file once up front instead
# of once per cutoff. AS values whose file is missing are simply absent from
# this mapping and are skipped in the main loop.
SSN_dict_by_AS = {}
for AS in AS_list:
    csv_filename = f'simulated_leaf_msa_AS{AS} Full Network colorized default node.csv'
    try:
        df = pd.read_csv(csv_filename)
    except FileNotFoundError:
        print(f"File {csv_filename} not found. Skipping.")
        continue

    # Rows without a name or a cluster number cannot be compared
    df = df.dropna(subset=['name', 'Node Count Cluster Number'])

    # Sequence name (as int, to match the tree's leaf names) -> SSN cluster
    SSN_dict_by_AS[AS] = dict(
        zip(df['name'].astype(int), df['Node Count Cluster Number'])
    )

for dist_cutoff in dist_cutoff_list:
    # Plain float rounded to one decimal, for consistent sheet names
    dist_cutoff = round(float(dist_cutoff), 1)

    # A cluster head is the first node on a root-to-leaf path that reaches
    # the cutoff, or the leaf itself when the path ends before reaching it.
    # Leaves whose own branch crosses the cutoff are heads too (singleton
    # clusters), and a node lying exactly on the cutoff counts as having
    # reached it, so every leaf ends up in exactly one cluster.
    head_nodes = []
    for node in t.traverse('preorder'):
        parent = node.up
        if parent is not None and parent.sumdist >= dist_cutoff:
            # The cutoff was already reached at or above the parent, so this
            # node belongs to a cluster whose head is further up.
            continue
        if node.sumdist >= dist_cutoff or node.is_leaf():
            head_nodes.append(node)
    head_node_names = [head.name for head in head_nodes]

    cluster_node_names = {}  # Cluster (head node name) to sequences mapping
    leaf2cluster_dict = {}   # Sequence to cluster number mapping (1-based)
    for cluster_no, head in enumerate(head_nodes, start=1):
        # Use the node object directly rather than searching the tree by name
        members = [leaf.name for leaf in head.iter_leaves()]
        cluster_node_names[head.name] = members
        for leaf_name in members:
            leaf2cluster_dict[int(leaf_name)] = cluster_no

    # Now, for each AS value whose SSN file was available
    for AS in AS_list:
        if AS not in SSN_dict_by_AS:
            continue
        SSN_dict = SSN_dict_by_AS[AS]

        # Only sequences present in both clusterings can be compared;
        # sorting keeps the two label lists in a reproducible order.
        common_keys = sorted(leaf2cluster_dict.keys() & SSN_dict.keys())
        if not common_keys:
            print(f"No common keys for dist_cutoff={dist_cutoff:.1f} and AS={AS}. Skipping.")
            continue

        tree_labels = [leaf2cluster_dict[key] for key in common_keys]
        SSN_labels = [SSN_dict[key] for key in common_keys]

        # Number of distinct clusters among the shared sequences
        tree_cluster_number = len(set(tree_labels))
        ssn_cluster_number = len(set(SSN_labels))

        # AMI calculation
        ami_score = adjusted_mutual_info_score(tree_labels, SSN_labels)
        print(f"Distance cutoff: {dist_cutoff:.1f}, AS value: {AS}")
        print(f"Tree clusters: {tree_cluster_number}, SSN clusters: {ssn_cluster_number}")
        print(f"Adjusted Mutual Information (AMI) score: {ami_score:.4f}\n")

        # Store the results
        results.append({
            'dist_cutoff': dist_cutoff,
            'AS': AS,
            'tree_cluster_number': tree_cluster_number,
            'ssn_cluster_number': ssn_cluster_number,
            'ami_score': ami_score,
        })

# Collect the per-(dist_cutoff, AS) result dicts into one table
results_df = pd.DataFrame(results)

# Workbook that receives one sheet per dist_cutoff
output_filename = 'SSN_clustering_results.xlsx'

if results_df.empty:
    # With no results there is no 'dist_cutoff' column to group by, so the
    # groupby below would raise KeyError. Report it and skip the export.
    print(f"No results to write; {output_filename} was not created.")
else:
    # Use ExcelWriter to write to multiple sheets
    with pd.ExcelWriter(output_filename, engine='openpyxl') as writer:
        # Group the results by 'dist_cutoff'; each group becomes one sheet
        for dist_cutoff, group in results_df.groupby('dist_cutoff'):
            # Use the dist_cutoff value as the sheet name
            sheet_name = f"Cutoff_{dist_cutoff}"
            group.to_excel(writer, sheet_name=sheet_name, index=False)

    print(f"Results have been written to {output_filename} with separate sheets for each dist_cutoff.")



Distance cutoff: 0.5, AS value: 10
Tree clusters: 11, SSN clusters: 1
Adjusted Mutual Information (AMI) score: 0.0000

Distance cutoff: 0.5, AS value: 15
Tree clusters: 11, SSN clusters: 4
Adjusted Mutual Information (AMI) score: 0.6282

Distance cutoff: 0.5, AS value: 20
Tree clusters: 11, SSN clusters: 51
Adjusted Mutual Information (AMI) score: 0.7569

Distance cutoff: 0.5, AS value: 25
Tree clusters: 11, SSN clusters: 296
Adjusted Mutual Information (AMI) score: 0.5905

Distance cutoff: 0.5, AS value: 30
Tree clusters: 11, SSN clusters: 977
Adjusted Mutual Information (AMI) score: 0.4530

Distance cutoff: 0.5, AS value: 35
Tree clusters: 11, SSN clusters: 2044
Adjusted Mutual Information (AMI) score: 0.3391

Distance cutoff: 0.6, AS value: 10
Tree clusters: 22, SSN clusters: 1
Adjusted Mutual Information (AMI) score: 0.0000

Distance cutoff: 0.6, AS value: 15
Tree clusters: 22, SSN clusters: 4
Adjusted Mutual Information (AMI) score: 0.5920

Distance cutoff: 0.6, AS value: 20
Tree 