In [1]:
import numpy as np

def contingency_matrix(true_labels, predicted_labels):
    """
    Computes the contingency matrix between true and predicted labels.

    Labels may be any values numpy can sort (ints, strings, ...); they do not
    have to be contiguous integers starting at 0. Each distinct label is mapped
    to a row/column index following the sorted order of the unique labels.

    Parameters:
    true_labels : list or numpy array
        True cluster labels of the data points.
    predicted_labels : list or numpy array
        Predicted cluster labels of the data points. Must have the same
        length as ``true_labels``.

    Returns:
    numpy.ndarray
        int64 contingency matrix (n_clusters_true x n_clusters_pred); entry
        [i, j] counts the points whose true label is the i-th unique true
        label and whose predicted label is the j-th unique predicted label.

    Raises:
    ValueError
        If the two label sequences have different lengths.
    """
    true_labels = np.asarray(true_labels)
    predicted_labels = np.asarray(predicted_labels)
    if true_labels.shape[0] != predicted_labels.shape[0]:
        raise ValueError(
            "true_labels and predicted_labels must have the same length "
            "(got %d and %d)" % (true_labels.shape[0], predicted_labels.shape[0])
        )

    # Dense re-indexing: indexing by raw label values breaks for labels that
    # are not exactly 0..k-1 (gaps, negatives, strings).
    classes, class_idx = np.unique(true_labels, return_inverse=True)
    clusters, cluster_idx = np.unique(predicted_labels, return_inverse=True)

    contingency = np.zeros((classes.shape[0], clusters.shape[0]), dtype=np.int64)
    # Unbuffered scatter-add: repeated (row, col) pairs are all counted, unlike
    # fancy-index `+=`, and it avoids a Python-level loop over every point.
    np.add.at(contingency, (class_idx.ravel(), cluster_idx.ravel()), 1)
    return contingency

def calculate_adjusted_rand_index(true_labels, predicted_labels):
    """
    Calculates the Adjusted Rand Index (ARI) between true and predicted cluster labels.

    With n_ij the contingency counts, a_i / b_j its row / column sums and
        S = sum_ij C(n_ij, 2),  A = sum_i C(a_i, 2),
        B = sum_j C(b_j, 2),    N = C(n_samples, 2),
    the ARI is
        (S - A*B/N) / ((A + B)/2 - A*B/N)
      = 2 * (S*N - A*B) / ((A + B)*N - 2*A*B).
    The second form is evaluated in exact Python integers with a single final
    division, so large inputs (~1e6 points) cannot overflow int64.

    Parameters:
    true_labels : list or numpy array
        True cluster labels of the data points.
    predicted_labels : list or numpy array
        Predicted cluster labels of the data points (same length).

    Returns:
    float
        Adjusted Rand Index between true and predicted cluster labels
        (1.0 for identical partitions). Degenerate inputs -- fewer than two
        points, or both partitions all-singletons, or both a single cluster --
        return 1.0, matching scikit-learn's convention for full agreement.

    Raises:
    ValueError
        If the two label sequences have different lengths.
    """
    n_samples = len(true_labels)
    if len(predicted_labels) != n_samples:
        raise ValueError(
            "true_labels and predicted_labels must have the same length "
            "(got %d and %d)" % (n_samples, len(predicted_labels))
        )
    # No pairs exist, so the two partitions cannot disagree.
    if n_samples < 2:
        return 1.0

    contingency = contingency_matrix(true_labels, predicted_labels)

    row_sums = np.sum(contingency, axis=1)  # a_i
    col_sums = np.sum(contingency, axis=0)  # b_j

    # int(...) lifts numpy int64 scalars to unbounded Python ints before the
    # products below, which can exceed the int64 range.
    sum_comb_rows = int(np.sum(comb2(row_sums)))   # A
    sum_comb_cols = int(np.sum(comb2(col_sums)))   # B
    # C(0, 2) == C(1, 2) == 0, so only cells >= 2 contribute; selecting them
    # avoids materialising C(n_ij, 2) for the whole (possibly huge) dense matrix.
    sum_comb_cells = int(np.sum(comb2(contingency[contingency > 1])))  # S
    total_pairs = comb2(n_samples)  # N (n_samples is a Python int)

    numerator = 2 * (sum_comb_cells * total_pairs - sum_comb_rows * sum_comb_cols)
    denominator = (sum_comb_rows + sum_comb_cols) * total_pairs - 2 * sum_comb_rows * sum_comb_cols

    # The denominator is zero only when A == B == 0 (both all-singletons) or
    # A == B == N (both one cluster): the partitions are identical.
    if denominator == 0:
        return 1.0
    return numerator / denominator

def comb2(n):
    """
    Number of unordered pairs that can be drawn from n items, i.e. C(n, 2).

    Only arithmetic operators are used, so this also works element-wise when
    ``n`` is a numpy integer array.

    Parameters:
    n : int or numpy array of ints
        Item count(s).

    Returns:
    int or numpy array of ints
        n * (n - 1) / 2, computed with integer floor division.
    """
    twice_pairs = n * (n - 1)
    return twice_pairs // 2




In [2]:
# Load per-gene cluster assignments from panta annotated_clusters.json output
import json
def loadClusterFromFile(filepath):
    """
    Load cluster assignments from an ``annotated_clusters.json`` file.

    The JSON top level is read as an object whose values each carry a
    ``'gene_id'`` list of member genes (inferred from the access pattern
    below -- TODO confirm against the panta output format). Clusters are
    numbered 0, 1, 2, ... in the file's key order. Prints a one-line
    "<genes> seq, <clusters> cluster" summary.

    Parameters:
    filepath : str
        Path to the JSON file.

    Returns:
    list of int
        Cluster id of every gene, ordered by sorted gene id, so label lists
        loaded from different files line up element-by-element as long as
        they cover the same gene set.
    """
    # `with` guarantees the handle is closed (the original never closed it).
    with open(filepath) as f:
        data = json.load(f)

    map_cluster = {}
    for c_id, cluster in enumerate(data.values()):
        for gene in cluster['gene_id']:
            # A gene listed in several clusters keeps its last assignment.
            map_cluster[gene] = c_id

    print(str(len(map_cluster)) + " seq, " + str(len(data)) + " cluster")
    return [map_cluster[gene] for gene in sorted(map_cluster)]
            

In [3]:
# All Sp600 clustering runs live under one output directory; each run writes
# its clusters to <run>/annotated_clusters.json.
SP600_OUT_DIR = '/mnt/data/data/amromics/panta2/panta/out/Sp600'


def cluster_file(run_name):
    """Return the annotated_clusters.json path of one clustering run under SP600_OUT_DIR."""
    return SP600_OUT_DIR + '/' + run_name + '/annotated_clusters.json'


g_cdhit_a_diamond_c_mcl = loadClusterFromFile(cluster_file('g_cdhit_a_diamond_c_mcl'))
g_diamond_a_diamond_c_mcl = loadClusterFromFile(cluster_file('g_diamond_a_diamond_c_mcl'))
g_mmseq_a_diamond_c_mcl = loadClusterFromFile(cluster_file('g_mmseq_a_diamond_c_mcl'))
c_mmseq = loadClusterFromFile(cluster_file('c_mmseq'))
c_diamond = loadClusterFromFile(cluster_file('c_diamond'))

# ARI compares label lists position-by-position (genes sorted by id), so every
# run must cover the same genes; at minimum the counts must agree.
assert len({len(labels) for labels in (g_cdhit_a_diamond_c_mcl, g_diamond_a_diamond_c_mcl,
                                       g_mmseq_a_diamond_c_mcl, c_mmseq, c_diamond)}) == 1, \
    "clustering runs cover different numbers of genes"

1194625 seq, 7633 cluster
1194625 seq, 27044 cluster
1194625 seq, 7366 cluster
1194625 seq, 9443 cluster
1194625 seq, 16782 cluster


In [4]:
# ARI of the g_diamond run against the g_cdhit run (used as reference)
calculate_adjusted_rand_index(true_labels=g_cdhit_a_diamond_c_mcl,
                              predicted_labels=g_diamond_a_diamond_c_mcl)


0.7890185563290272

In [5]:
# ARI of the g_mmseq run against the g_cdhit run (used as reference)
calculate_adjusted_rand_index(true_labels=g_cdhit_a_diamond_c_mcl,
                              predicted_labels=g_mmseq_a_diamond_c_mcl)

0.9951394757083515

In [6]:
# ARI of the c_mmseq run against the g_cdhit run (used as reference)
calculate_adjusted_rand_index(true_labels=g_cdhit_a_diamond_c_mcl,
                              predicted_labels=c_mmseq)

0.9844329097212963

In [7]:
# ARI of the c_diamond run against the g_cdhit run (used as reference)
calculate_adjusted_rand_index(true_labels=g_cdhit_a_diamond_c_mcl,
                              predicted_labels=c_diamond)

0.9260373926710616