In [1]:
import numpy as np
import pandas as pd
from scipy.optimize import linear_sum_assignment
from sklearn.cluster import KMeans
from sklearn.metrics import accuracy_score, adjusted_rand_score, normalized_mutual_info_score

In [2]:
# NOTE(review): absolute, machine-specific path; the notebook only runs where
# this exact location exists. Consider a configurable data directory.
embed_csv_path = '/home/haiping_liu/code/My_model/ImCluster/data/hoch_berd_embed.csv'
table = pd.read_csv(embed_csv_path)

# Columns are selected by position, not by name:
#   0..19 -> feature matrix handed to the clustering step
#            (presumably a 20-dimensional embedding, judging by the file name; TODO confirm)
#   20    -> batch identifier of each row
#   21    -> cell-type annotation, used below as the ground-truth labels
data = table.iloc[:, :20].values
batch_id = table.iloc[:, 20].values
cell_type = table.iloc[:, 21].values

## K-means clustering

In [4]:
def map_labels(true_labels, predicted_labels):
    """Relabel cluster ids with the ground-truth labels they best correspond to.

    A contingency table (true label x predicted cluster) is built, and the
    Hungarian algorithm (``scipy.optimize.linear_sum_assignment``) picks the
    one-to-one pairing of clusters and true labels that maximizes the number
    of samples on which they agree.

    If there are more predicted clusters than distinct true labels, a
    one-to-one pairing cannot cover every cluster. Each unpaired cluster is
    then mapped to its majority true label, so that every predicted label
    always has a translation. Be aware that an accuracy computed from the
    result in that situation is more optimistic than a strict one-to-one
    (Hungarian) accuracy, where unpaired clusters would count as errors.

    Parameters
    ----------
    true_labels : array-like of shape (n_samples,)
        Ground-truth label of each sample.
    predicted_labels : array-like of shape (n_samples,)
        Cluster id assigned to each sample.

    Returns
    -------
    numpy.ndarray of shape (n_samples,)
        ``predicted_labels`` translated into the vocabulary of ``true_labels``.

    Raises
    ------
    ValueError
        If the two inputs do not contain the same number of samples.
    """
    true_labels = np.asarray(true_labels).ravel()
    predicted_labels = np.asarray(predicted_labels).ravel()
    if true_labels.shape != predicted_labels.shape:
        raise ValueError(
            "true_labels and predicted_labels must have the same length, got "
            f"{true_labels.shape[0]} and {predicted_labels.shape[0]}"
        )
    if true_labels.size == 0:
        # Nothing to map; also avoids argmax over an empty axis below.
        return true_labels.copy()

    # Encode both labelings as dense integer indices into their sorted unique sets.
    true_label_set, true_idx = np.unique(true_labels, return_inverse=True)
    pred_label_set, pred_idx = np.unique(predicted_labels, return_inverse=True)
    true_idx = true_idx.ravel()
    pred_idx = pred_idx.ravel()

    # contingency[i, j] = number of samples with true label i in predicted cluster j.
    # np.add.at accumulates correctly when the same (i, j) pair repeats.
    contingency = np.zeros((len(true_label_set), len(pred_label_set)), dtype=np.int64)
    np.add.at(contingency, (true_idx, pred_idx), 1)

    # Optimal one-to-one pairing that maximizes the total overlap.
    row_ind, col_ind = linear_sum_assignment(contingency, maximize=True)

    # Start from the majority true label of every cluster (the fallback for
    # clusters the one-to-one pairing leaves out), then overwrite the paired ones.
    pred_to_true = contingency.argmax(axis=0)
    pred_to_true[col_ind] = row_ind

    return true_label_set[pred_to_true[pred_idx]]

# One metrics record is collected per tested cluster count.
results = []

# Sweep k = 2, 3, ..., 8 (the upper bound of range() is exclusive).
for n_clusters in range(2, 9):
    # Fixed seed keeps the clustering reproducible across runs.
    kmeans = KMeans(random_state=42, n_init=10, n_clusters=n_clusters)
    kmeans_labels = kmeans.fit_predict(data)

    # Accuracy compares label values directly, so cluster ids are first
    # translated into cell-type labels.
    mapped_labels = map_labels(cell_type, kmeans_labels)
    acc = accuracy_score(cell_type, mapped_labels)

    # ARI and NMI are computed on the raw cluster ids.
    ari = adjusted_rand_score(cell_type, kmeans_labels)
    nmi = normalized_mutual_info_score(cell_type, kmeans_labels)

    record = {"Clusters": n_clusters, "ACC": acc, "ARI": ari, "NMI": nmi}
    results.append(record)

# Tabulate the sweep: one row per cluster count.
results_df = pd.DataFrame(results)
print(results_df)

   Clusters       ACC       ARI       NMI
0         2  0.475638  0.217359  0.187703
1         3  0.640013  0.499535  0.411553
2         4  0.647339  0.503203  0.437579
3         5  0.656969  0.541886  0.442599
4         6  0.678798  0.555574  0.464776
5         7  0.596734  0.432202  0.457715
6         8  0.590521  0.425688  0.465630
