# Loading packages and data

In [2]:
import random
from collections import Counter
from tqdm import tqdm
import os

import torch
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from matplotlib_venn import venn2
import numpy as np
import pandas as pd
import seaborn as sns
import time
import csv
import json # for saving dictionnaries
import pickle # for saving dictionnaries
import copy

from sklearn.decomposition import PCA
import esm
from sklearn.cluster import KMeans
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.metrics import adjusted_mutual_info_score
from collections import Counter, OrderedDict
from Bio.Blast import Applications
from sklearn.metrics.pairwise import pairwise_distances
import itertools


Due to the on going maintenance burden of keeping command line application
wrappers up to date, we have decided to deprecate and eventually remove these
modules.

We instead now recommend building your command line and invoking it directly
with the subprocess module.


In [4]:
# MODIFICATION OF THE PATHS
path_to_data = "Datasets/"
species = "pfal_pber"
FASTA_PATH = path_to_data + species + ".fasta"

model_list = ["esm2_t48_15B_UR50D", "esm2_t36_3B_UR50D", "esm2_t33_650M_UR50D", "esm2_t30_150M_UR50D", "esm2_t12_35M_UR50D", "esm2_t6_8M_UR50D"]
EMB_LAYER_list = [48, 36, 33, 30, 12, 6]

model_index = 3 # change to select the model
model = model_list[model_index]
EMB_LAYER = EMB_LAYER_list[model_index]

EMB_PATH = path_to_data + species + "_emb_" + model + "/"


Xs = []
prot_names_and_group = []
aaSequences = []
for header, _seq in esm.data.read_fasta(FASTA_PATH):
    fn = f'{EMB_PATH}/{header[0:]}.pt'
    if os.path.isfile(fn):
        embs = torch.load(fn)
        Xs.append(embs['mean_representations'][EMB_LAYER])
        prot_names_and_group.append(header.split('|'))        
        aaSequences.append(_seq)

Xs = torch.stack(Xs, dim=0).numpy()
print(Xs.shape)

Xs_train = Xs
prot_names_and_group_train = prot_names_and_group

RuntimeError: stack expects a non-empty TensorList

This code handles proteome fasta files downloaded from OrthoMCL-DB.

In particular the variable prot_names_and_group stores protein information under the format: >species_code|protein_code | organism=Species full name | Protein description | Ortholog_group_id

For other databases, adapt accordingly.

# Clustering

In [19]:
list_num_pca_components = [20, 40, 60, 80, 100, 120, 0]
list_num_pca_components = [20]
list_n_clusters = [int(Xs_train.shape[0] / 2)]

### New run

In [20]:
def saving_from_kmeans(Xs_train_pca, prot_names_and_group_train, kmeans):
    
    n_clusters = kmeans.n_clusters
    n_samples = Xs_train_pca.shape[0]
    X_labels = kmeans.labels_
    # We do not save training data Xs_train_pca because of memory requirements. It can be recreated quickly if necessary.

    return [n_clusters, n_samples, prot_names_and_group_train, X_labels, kmeans,
            "n_clusters, n_samples, prot_names_and_group_train, X_labels, kmeans"]

In [22]:
kmeans_saving = {}
for num_pca_components in list_num_pca_components:
    print("num_pca_components: " + str(num_pca_components))
    t0 = time.time()
    if num_pca_components==0:
        Xs_train_pca = Xs_train
    else:
        pca = PCA(n_components = num_pca_components, svd_solver = "full")
        Xs_train_pca = pca.fit_transform(Xs_train)

    kmeans_saving["n_pca" + str(num_pca_components)] = {}
    for n_clusters in list_n_clusters:
        print("n_clusters: " + str(n_clusters))
        kmeans = KMeans(n_clusters = n_clusters, n_init=5, random_state=0).fit(Xs_train_pca)
        kmeans_saving["n_pca" + str(num_pca_components)]["n_clusters" + str(n_clusters)] = saving_from_kmeans(Xs_train_pca, prot_names_and_group_train, kmeans)
        
    print("Elapsed time: " + str(time.time() - t0) + " seconds")
    kmeans_saving["n_pca" + str(num_pca_components)]["running_time_by_PCA"] = str(time.time() - t0) + " seconds"
    kmeans_saving["n_pca" + str(num_pca_components)]["Xs_train_pca"] = Xs_train_pca

if True: # Change to True for saving the results
    print("Saving results")
    with open(path_to_data + "Results/" + species + "_" + model + "_" + str(len(list_num_pca_components)) + "PCA_" +
              str(len(list_n_clusters)) + "cluster_kmeans_save.pkl", 'wb') as fp:
        pickle.dump(kmeans_saving, fp)

num_pca_components: 20
n_clusters: 5131
Elapsed time: 22.357588529586792 seconds
Saving results


### Loading results from a previous run

In [23]:
with open(path_to_data + "Results/pfal_pber_esm2_t33_650M_UR50D_1PCA_1cluster_kmeans_save.pkl", 'rb') as f:
    kmeans_saving = pickle.load(f)
    

In [24]:
print(list(kmeans_saving.keys()))
print(list(kmeans_saving[list(kmeans_saving.keys())[0]].keys()))

['n_pca20']
['n_clusters5131', 'running_time_by_PCA', 'Xs_train_pca']


# Analysing results

In [25]:
def measure_pairwise_performance(saved_results, Xs_train_pca):

    n_clusters = saved_results[0]
    n_samples = saved_results[1]
    prot_names_and_group_train = saved_results[2]
    X_labels = saved_results[3]
    kmeans = saved_results[4]

    X_dist = kmeans.transform(Xs_train_pca)**2 # distance from each point to all of the clusters
    
    orthologs_naiveSearch = []
    orthologs_distanceBasedSearch = []
    orthologs_1_to_1 = []
    orthologs_SATURN = []
    
    n_species_total = len(list(set([prot[0] for prot in prot_names_and_group_train])))
    for cluster in range(n_clusters):
        ind_fromCluster = [i for i, x in enumerate(X_labels) if x==cluster] # get all indices of sequences in that cluster
        if len(ind_fromCluster) == 1:
            continue # skipping an iteration in case finding a cluster of size 1

        ind_sorted = np.argsort(X_dist[ind_fromCluster,cluster]) # sort by increasing distance to the centroid
        all_prots = [prot_names_and_group_train[ind_fromCluster[i]] for i in ind_sorted]
        all_specs = [prot_names_and_group_train[ind_fromCluster[i]][0] for i in ind_sorted]
        
        # Naive search for ortholog combinations
        if len(list(set(all_specs))) > 1: # need at least 2 different species inside the cluster
            for i1 in range(len(all_specs) - 1): # we explore all pairwise combinations that give orthologs
                for i2 in range(i1 + 1, len(all_specs)):
                    if all_specs[i1] != all_specs[i2]:
                        orthologs_naiveSearch.append([all_prots[i1], all_prots[i2]])
        
        # Distance-based search for ortholog among the 2 sequences closest to the centroid
        all_prots_1st_pair = all_prots[0:2]
        if all_prots_1st_pair[0][0] != all_prots_1st_pair[1][0]: # search for orthologs (different species)
            orthologs_distanceBasedSearch.append(all_prots_1st_pair)

        # Search for 1 to 1 orthologs
        if len(all_specs) == len(list(set(all_specs))) == n_species_total:
            orthologs_1_to_1.append(all_prots)

        # SATURN search (build pairs by increasing distances from the centroid)
        ind_species2 = 1
        if all_specs[0] != all_specs[ind_species2]:
            orthologs_SATURN.append([all_prots[0], all_prots[ind_species2]])
        else:
            while all_specs[0] == all_specs[ind_species2]:
                ind_species2 += 1
                if ind_species2 == len(all_specs):
                    break
                if all_specs[0] != all_specs[ind_species2]:
                    orthologs_SATURN.append([all_prots[0], all_prots[ind_species2]])
                    break
   
    return [orthologs_naiveSearch,
            orthologs_distanceBasedSearch,
            orthologs_1_to_1,
            orthologs_SATURN]

In [26]:
def measure_group_performance(saved_results):

    n_clusters = saved_results[0]
    n_samples = saved_results[1]
    prot_names_and_group_train = saved_results[2]
    X_labels = saved_results[3]
    kmeans = saved_results[4]

    toReturn = []

    # First, make a list of all groups present in each cluster.
    list_groups_in_clusters = [[] for i in range(n_clusters)]
    for i_label in range(n_samples):
        list_groups_in_clusters[X_labels[i_label]].append(prot_names_and_group_train[i_label][4])
    #print(list_groups_in_clusters)

    # Gather all true group names
    list_all_groups = list(set([prot[4] for prot in prot_names_and_group_train])) # remove duplicated names
    
    # Count how many sequences from each group are present in each cluster
    list_OG_count_in_cluster = np.zeros((len(list_all_groups), n_clusters))
    for i_cluster in range(n_clusters):
        for group in list_groups_in_clusters[i_cluster]:
            list_OG_count_in_cluster[list_all_groups.index(group), i_cluster] += 1
    if not(np.sum(list_OG_count_in_cluster) == n_samples):
        print("Error: missing some sequences in the count matrix")
    #print(list_OG_count_in_cluster)

    family_complet_stat = 0
    for i_group in range(len(list_all_groups)):
        family_complet_stat += max(list_OG_count_in_cluster[i_group,:])
    family_complet_stat = family_complet_stat / n_samples
    toReturn.append(family_complet_stat)

    toReturn.append(adjusted_mutual_info_score([prot[4] for prot in prot_names_and_group_train], X_labels))

    # Count groups that are totally and exclusively contained in a single cluster
    i_count_success_exactMatch = 0
    for i_group in range(len(list_all_groups)):
        if np.count_nonzero(list_OG_count_in_cluster[i_group,:] == 0) == n_clusters - 1: # check if all sequences from the OG belong to only one of the 100 cluster, and the other clusters contain 0 sequences from that group
            # Check if the corresponding cluster contains only that group
            i_cluster = np.nonzero(list_OG_count_in_cluster[i_group,:])[0][0]
            if np.count_nonzero(list_OG_count_in_cluster[:,i_cluster] == 0) == len(list_all_groups) - 1:
                i_count_success_exactMatch += 1

    toReturn.append(i_count_success_exactMatch / n_clusters)

    toReturn.append("Family Completeness, Adjusted Mutual Information, Exact matches over total")
    return(toReturn)

In [None]:
def get_total_count_n2m_orthologs(prot_names_and_group_train):
    list_all_OGs = list(set([protein[4] for protein in prot_names_and_group_train]))
    
    list_proteins_in_OG = [[] for i in range(len(list_all_OGs))]
    
    for protein in prot_names_and_group_train:
        list_proteins_in_OG[list_all_OGs.index(protein[4])].append(protein[0])

    names_species = list(set([protein[0] for protein in prot_names_and_group_train]))
    
    total_n2m_ortholog_count = 0
    for list_proteins in list_proteins_in_OG:
        species = []
        for name in names_species:
            species.append(list_proteins.count(name)) # we just need the counts

        for pairwise_comb in list(itertools.combinations([i for i in range(len(species))], 2)): # so can deal with more than 2 species
            total_n2m_ortholog_count += species[pairwise_comb[0]] * species[pairwise_comb[1]]

    return(total_n2m_ortholog_count)

def get_total_count_121_orthologs(prot_names_and_group_train):
    list_all_OGs = list(set([protein[4] for protein in prot_names_and_group_train]))
    
    list_proteins_in_OG = [[] for i in range(len(list_all_OGs))]
    
    for protein in prot_names_and_group_train:
        list_proteins_in_OG[list_all_OGs.index(protein[4])].append(protein[0])

    names_species = list(set([protein[0] for protein in prot_names_and_group_train]))
    print(names_species)
    
    total_121_ortholog_count = 0
    total_121_ortholog_count_groups = 0
    for list_proteins in list_proteins_in_OG:

        if (len(list(set(list_proteins))) == len(list_proteins) == len(names_species)): # if getting exactly one protein from each species
            total_121_ortholog_count_groups += 1
            species = []
            for name in names_species:
                species.append(list_proteins.count(name))
    
            for pairwise_comb in list(itertools.combinations([i for i in range(len(species))], 2)): # so can deal with more than 2 species
                total_121_ortholog_count += species[pairwise_comb[0]] * species[pairwise_comb[1]]

    return(total_121_ortholog_count)

print(get_total_count_n2m_orthologs(prot_names_and_group_train))
print(get_total_count_121_orthologs(prot_names_and_group_train))

### Working on saved run

In [27]:
performance_stored = {}

n_samples = Xs_train.shape[0]
list_all_groups_no_set = [prot[4] for prot in prot_names_and_group_train]
n_species = len(list(set([prot[0] for prot in prot_names_and_group_train])))

i_pca = 0
for num_pca_components in list_num_pca_components:
    i_clu = 0
    print("num_pca_components: " + str(num_pca_components))
    t0 = time.time()
    # Need to recreate training data as it is too heavy to be stored
    if num_pca_components==0:
        Xs_train_pca = Xs_train
    else:
        pca = PCA(n_components = num_pca_components, svd_solver = "full")
        Xs_train_pca = pca.fit_transform(Xs_train)

    performance_stored["n_pca" + str(num_pca_components)] = {}

    for n_clusters in list_n_clusters:
        print("n_clusters: " + str(n_clusters))
        performance_stored["n_pca" + str(num_pca_components)]["n_clusters" + str(n_clusters)] = {}

        # Pair level
        results = measure_pairwise_performance(kmeans_saving["n_pca" + str(num_pca_components)]["n_clusters" + str(n_clusters)], Xs_train_pca)
        Xs_train_pca
        for index in [0,1,2,3]:
            n_retrieved = 0
            n_correct = 0
            FalsePositive_pairs_list = []
            for prots in results[index]:
                if len(list(set([p[4] for p in prots]))) == 1: # verify that they all come from the same group
                    if index==2: # for 1:1 orthologs, Additional check that we are not from a n:m configuration
                        if list_all_groups_no_set.count(prots[0][4]) == n_species:
                            n_correct += 1
                        else: # Study False Positives
                            FalsePositive_pairs_list.append([prots[0][1], prots[0][4], prots[1][1], prots[1][4]])
                    else:
                        n_correct += 1
                else: # Study False Positives
                    FalsePositive_pairs_list.append([prots[0][1], prots[0][4], prots[1][1], prots[1][4]])
                n_retrieved += 1
            print("Index: " + str(index))
            print([n_correct, n_retrieved])

            if True: # if not needed to store the false positives
                FalsePositive_pairs_list = []

            if n_retrieved != 0:
                print((n_correct / n_retrieved * 100))
            if index==0:
                performance_stored["n_pca" + str(num_pca_components)]["n_clusters" + str(n_clusters)]["Naive pairing"] = [n_correct, n_retrieved, FalsePositive_pairs_list]
            if index==1:
                performance_stored["n_pca" + str(num_pca_components)]["n_clusters" + str(n_clusters)]["Distance-based pairing"] = [n_correct, n_retrieved, FalsePositive_pairs_list]
            if index==2:
                performance_stored["n_pca" + str(num_pca_components)]["n_clusters" + str(n_clusters)]["1-to-1 orthologs"] = [n_correct, n_retrieved, FalsePositive_pairs_list]

        # Group level
        results_group = measure_group_performance(kmeans_saving["n_pca" + str(num_pca_components)]["n_clusters" + str(n_clusters)])
        performance_stored["n_pca" + str(num_pca_components)]["n_clusters" + str(n_clusters)]["Family Completeness"] = results_group[0]
        performance_stored["n_pca" + str(num_pca_components)]["n_clusters" + str(n_clusters)]["Adjusted Mutual Information"] = results_group[1]
        performance_stored["n_pca" + str(num_pca_components)]["n_clusters" + str(n_clusters)]["% exact group match"] = results_group[2] * 100

        i_clu+=1
    i_pca+=1
    print("Elapsed time: " + str(time.time() - t0) + " seconds")

print(performance_stored)

num_pca_components: 20
n_clusters: 5131
Index: 0
[1960, 3480]
56.32183908045977
Index: 1
[1729, 1990]
86.88442211055276
Index: 2
[1316, 1587]
82.92375551354758
Index: 3
[1784, 2150]
82.97674418604652
Elapsed time: 38.9793267250061 seconds
{'n_pca20': {'n_clusters5131': {'Naive pairing': [1960, 3480, []], 'Distance-based pairing': [1729, 1990, []], '1-to-1 orthologs': [1316, 1587, []], 'Family Completeness': 0.7231803566208711, 'Adjusted Mutual Information': 0.43927034395026243, '% exact group match': 34.82751900214383}}}
