Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions src/aspire/source/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
from scipy.linalg import eigh, qr
from sklearn.metrics import adjusted_rand_score

from aspire.image import Image
from aspire.noise import NoiseAdder
Expand Down Expand Up @@ -406,18 +407,19 @@ def eval_eigs(self, eigs_est, lambdas_est):

def eval_clustering(self, vol_idx):
"""
Evaluate clustering estimation
Evaluate clustering estimation using an adjusted Rand score.

:param vol_idx: Indexes of the volumes determined (0-indexed)
:return: Accuracy [0-1] in terms of proportion of correctly assigned labels
:return: Accuracy [-0.5, 1] in terms of proportion of correctly assigned labels.
Identical clusters (up to a permutation) have a score of 1, random labeling
will be close to 0, and discordant clusterings will be negative.
"""
assert (
len(vol_idx) == self.n
), f"Need {self.n} vol indexes to evaluate clustering"
# Remember that `states` is 1-indexed while vol_idx is 0-indexed
correctly_classified = np.sum(self.states - 1 == vol_idx)
# Remember that `states` is 1-indexed while vol_idx is 0-indexed.

return correctly_classified / self.n
return adjusted_rand_score(self.states - 1, vol_idx)

def eval_coords(self, mean_vol, eig_vols, coords_est):
"""
Expand All @@ -435,7 +437,6 @@ def eval_coords(self, mean_vol, eig_vols, coords_est):

# 0-indexed states vector
states = self.states - 1

coords_true = coords_true[states]
res_norms = res_norms[states]
res_inners = res_inners[:, states]
Expand Down