# In [None]:
from sklearn.metrics import adjusted_rand_score
from sklearn.metrics import normalized_mutual_info_score


## Load the preprocessed gene-expression matrix.
# The first line of the file is skipped and no header row is inferred.
# NOTE(review): the path ends in ".xlsx" but is parsed with read_csv — confirm
# the file really holds delimited text rather than an Excel workbook.
matrix = pd.read_csv(
    "/home/GaoBX/Study.xlsx",
    skiprows=[0],
    header=None,
)


## Load the cell-type annotation file (typically the "metadata" file).
# The first column becomes the index; the 'cluster' column holds the reference
# labels that every clustering result below is scored against.
df = pd.read_csv("/home/GaoBX/work/metadata.csv", index_col=0)
true_labels = df["cluster"]


## K-means clustering (plain baseline)
from sklearn.cluster import KMeans

n_clusters = 6
# Use random centroid initialisation so this section is the classic k-means
# baseline; 'k-means++' seeding is evaluated separately in the next section.
kmeans = KMeans(n_clusters=n_clusters, init='random')
kmeans.fit(matrix)
labels = kmeans.labels_
cluster_centers = kmeans.cluster_centers_

# Compare the predicted partition with the reference labels.
ari = adjusted_rand_score(true_labels, labels)
nmi = normalized_mutual_info_score(true_labels, labels)
print(f"ARI score: {ari}")
print(f"NMI score: {nmi}")


## K-means++ clustering
from sklearn.cluster import KMeans

n_clusters = 6
# Fit with k-means++ seeding and keep both the assignments and the centroids.
kmeans = KMeans(init='k-means++', n_clusters=n_clusters)
labels = kmeans.fit(matrix).labels_
cluster_centers = kmeans.cluster_centers_

# Agreement between the predicted partition and the reference labels.
ari = adjusted_rand_score(true_labels, labels)
print(f"ARI score: {ari}")
nmi = normalized_mutual_info_score(true_labels, labels)
print(f"NMI score: {nmi}")


## Mini-batch K-means clustering
from sklearn.cluster import MiniBatchKMeans

# Six clusters, fitted on mini-batches of 250 samples for at most 600 iterations.
kmeans = MiniBatchKMeans(
    n_clusters=6,
    batch_size=250,
    max_iter=600,
)
labels = kmeans.fit(matrix).labels_

# Agreement between the predicted partition and the reference labels.
ari = adjusted_rand_score(true_labels, labels)
print(f"ARI score: {ari}")
nmi = normalized_mutual_info_score(true_labels, labels)
print(f"NMI score: {nmi}")


## Leiden clustering
import bbknn

# NOTE(review): bbknn is called on `matrix`, while every later step works on
# `adata_copy`. Neither `adata_copy`, `sc` nor `sc_plot` is defined in the
# visible part of this file — presumably they come from an earlier cell
# (scanpy imported as `sc`). Confirm which object bbknn should receive.
bbknn.bbknn(matrix)
resolution = 0.5
sc.pp.neighbors(adata_copy)
sc.tl.leiden(adata_copy, resolution=resolution)
sc_plot.umap(adata_copy, color='leiden')
print(adata_copy.obs['leiden'])

# Refresh `labels` with the Leiden assignment before scoring; otherwise the
# metrics below would describe the clustering left over from the previous
# section instead of this one.
# NOTE(review): assumes the rows of `adata_copy.obs` are in the same order as
# `true_labels` — confirm.
labels = adata_copy.obs['leiden'].to_numpy()
ari = adjusted_rand_score(true_labels, labels)
nmi = normalized_mutual_info_score(true_labels, labels)
print("ARI: ", ari)
print("NMI: ", nmi)



## Birch clustering
from sklearn.cluster import Birch

# NOTE(review): `data_reduced` is not defined in the visible part of this file —
# presumably a dimensionality-reduced version of `matrix` from an earlier cell;
# confirm it has one row per entry of `true_labels`.
birch = Birch(n_clusters=6, threshold=0.5)
birch.fit(data_reduced)
labels = birch.labels_

# Score against `true_labels` (the reference labels read from the metadata
# file), the same ground truth used by the other clustering sections.
ari = adjusted_rand_score(true_labels, labels)
nmi = normalized_mutual_info_score(true_labels, labels)
print("ARI: ", ari)
print("NMI: ", nmi)


## Gaussian mixture model clustering
from sklearn.mixture import GaussianMixture

# NOTE(review): `data_reduced` is not defined in the visible part of this file —
# presumably a dimensionality-reduced version of `matrix` from an earlier cell;
# confirm it has one row per entry of `true_labels`.
gmm = GaussianMixture(n_components=6)
gmm.fit(data_reduced)
# A mixture model has no `labels_`; the hard assignment comes from predict().
labels = gmm.predict(data_reduced)

# Score against `true_labels` (the reference labels read from the metadata
# file), the same ground truth used by the other clustering sections.
ari = adjusted_rand_score(true_labels, labels)
nmi = normalized_mutual_info_score(true_labels, labels)
print("ARI: ", ari)
print("NMI: ", nmi)


## Hierarchical (agglomerative) clustering
from sklearn.cluster import AgglomerativeClustering

# Ward linkage only supports Euclidean distance, which is also the default, so
# the distance keyword is omitted: the `affinity=` spelling is rejected by newer
# scikit-learn releases (renamed to `metric=`), and leaving it out works on both.
# NOTE(review): this section asks for 10 clusters while the other sections use
# 6 — confirm this is intentional.
clustering = AgglomerativeClustering(n_clusters=10, linkage='ward').fit(matrix)

# Agreement between the predicted partition and the reference labels.
ari = adjusted_rand_score(true_labels, clustering.labels_)
nmi = normalized_mutual_info_score(true_labels, clustering.labels_)
print(f"ARI score: {ari}")
print(f"NMI score: {nmi}")



## Spectral clustering
from sklearn.cluster import SpectralClustering

# NOTE(review): `data_reduced` is not defined in the visible part of this file —
# presumably a dimensionality-reduced version of `matrix` from an earlier cell;
# confirm it has one row per entry of `true_labels`.
spectral = SpectralClustering(
    n_clusters=6,
    affinity='nearest_neighbors',
    assign_labels='kmeans',
)
labels = spectral.fit_predict(data_reduced)

# Score the spectral assignment itself (`labels`), not the `clustering` object
# left over from the hierarchical-clustering section.
ari = adjusted_rand_score(true_labels, labels)
nmi = normalized_mutual_info_score(true_labels, labels)
print(f"ARI score: {ari}")
print(f"NMI score: {nmi}")



## 2-D visualisation of the clustering result

from sklearn.manifold import TSNE

# NOTE(review): `X_encoded_reshape` and `plt` are not defined in the visible
# part of this file — presumably an encoded feature matrix and
# matplotlib.pyplot from an earlier cell; confirm.
tsne = TSNE(n_components=2, random_state=42, n_jobs=-1)
tsne_coords = tsne.fit_transform(X_encoded_reshape)

# Fixed colour per cluster id; ids outside 0-6 raise KeyError below.
colors = {
    0: 'maroon',
    1: 'darkgreen',
    2: 'darkblue',
    3: 'teal',
    4: 'magenta',
    5: 'sienna',
    6: 'olive'
}
fig = plt.figure()

# Draw all points with one scatter call (one colour per point) instead of
# creating a separate collection for every single cell, which is very slow
# for large datasets.
point_colors = [colors[label] for label in labels]
plt.scatter(tsne_coords[:, 0], tsne_coords[:, 1], color=point_colors)

# Build proxy handles so the legend shows one entry per colour in `colors`.
legend_elements = [
    plt.Line2D([0], [0], marker='o', color='w', label=f'Cluster {i}',
               markerfacecolor=color, markersize=10)
    for i, color in colors.items()
]
legend = plt.legend(handles=legend_elements, loc='center left', bbox_to_anchor=(1, 0.5))

fig.add_artist(legend)
plt.show()


## 3-D visualisation of the clustering result

# PCA is used below but is not imported anywhere in the visible file.
from sklearn.decomposition import PCA

# NOTE(review): `plt` is not defined in the visible part of this file —
# presumably matplotlib.pyplot imported in an earlier cell; confirm.
pca = PCA(n_components=3)
pca.fit(matrix)
total_photo_3d = pca.transform(matrix)
centers = labels

fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(111, projection='3d')
markers = ['o', 'o', 'o', 'o', 'o', 'o', 'o']

# Iterate over the cluster ids that actually occur rather than a hard-coded
# count, and wrap around the marker list so that having more clusters than
# markers cannot raise IndexError.
# NOTE(review): the boolean masks assume `labels` is a NumPy array of integer
# ids with one entry per row of `matrix` — confirm.
for idx, cluster_id in enumerate(sorted(set(labels))):
    in_cluster = labels == cluster_id
    ax.scatter(total_photo_3d[in_cluster, 0],
               total_photo_3d[in_cluster, 1],
               total_photo_3d[in_cluster, 2],
               marker=markers[idx % len(markers)],
               label=f'Cluster {cluster_id + 1}')

# Axis styling and camera angle do not depend on the cluster, so they are
# applied once after all points have been drawn.
ax.tick_params(axis='x', which='major', labelsize=12, length=4, width=1)
ax.tick_params(axis='y', which='major', labelsize=12, length=4, width=1)
ax.tick_params(axis='z', which='major', labelsize=12, length=4, width=1)

ax.view_init(elev=0, azim=90)

# Final figure size; this overrides the figsize passed to plt.figure above.
fig.set_size_inches(10, 7)
plt.show()

