From 82a28560ce90e3d52ad8eb0392cb0ab818b1ca9a Mon Sep 17 00:00:00 2001 From: Chetan Gohil Date: Tue, 21 May 2024 08:58:06 +0100 Subject: [PATCH] Feat: function to match sets of vectors. --- osl_dynamics/inference/modes.py | 78 ++++++++++++++++++++++++++++++--- 1 file changed, 73 insertions(+), 5 deletions(-) diff --git a/osl_dynamics/inference/modes.py b/osl_dynamics/inference/modes.py index fa86b45c..ec686977 100644 --- a/osl_dynamics/inference/modes.py +++ b/osl_dynamics/inference/modes.py @@ -9,8 +9,7 @@ import numpy as np import matplotlib.pyplot as plt from tqdm.auto import trange -from scipy.optimize import linear_sum_assignment -from scipy.cluster import hierarchy +from scipy import cluster, spatial, optimize from sklearn.cluster import AgglomerativeClustering from osl_dynamics import analysis, array_ops @@ -294,7 +293,7 @@ def match_covariances( F[j, k] = -metrics.pairwise_rv_coefficient( np.array([covariances[i][k], covariances[0][j]]) )[0, 1] - order = linear_sum_assignment(F)[1] + order = optimize.linear_sum_assignment(F)[1] # Add the ordered matrix to the list matched_covariances.append(covariances[i][order]) @@ -306,6 +305,75 @@ def match_covariances( return tuple(matched_covariances) +def match_vectors(*vectors, comparison="correlation", return_order=False): + """Matches vectors. + + Parameters + ---------- + vectors : tuple of np.ndarray + Sets of vectors to match. + Each variable must be shape (n_vectors, n_channels). + comparison : str, optional + Must be :code:`'correlation' or :code:`'cosine_similarity'`. + return_order : bool, optional + Should we return the order instead of the matched vectors? + + Returns + ------- + matched_vectors : tuple of np.ndarray + Set of matched vectors of shape (n_vectors, n_channels) + or order if :code:`return_order=True`. + + Examples + -------- + Reorder the vectors directly: + + >>> v1, v2 = match_vectors(v1, v2, comparison="correlation") + + Just get the reordering: + + >>> orders = match_vectors(v1, v2, comparison="correlation", return_order=True) + >>> print(orders[0]) # order for v1 (always unchanged) + >>> print(orders[1]) # order for v2 + """ + # Validation + for vector in vectors[1:]: + if vector.shape != vectors[0].shape: + raise ValueError("Vectors must have the same shape.") + + if comparison not in ["correlation", "cosine_similarity"]: + raise ValueError("Comparison must be 'correlation' or 'cosine_similarity'.") + + # Number of arguments and number of vectors in each argument passed + n_args = len(vectors) + n_vectors = vectors[0].shape[0] + + # Calculate the similarity between vectors + F = np.empty([n_vectors, n_vectors]) + matched_vectors = [vectors[0]] + orders = [np.arange(vectors[0].shape[0])] + for i in range(1, n_args): + for j in range(n_vectors): + # Find the vector that is most similar to vector j + for k in range(n_vectors): + if comparison == "correlation": + F[j, k] = -np.corrcoef(vectors[i][k], vectors[0][j])[0, 1] + elif comparison == "cosine_similarity": + F[j, k] = -( + 1 - spatial.distance.cosine(vectors[i][k], vectors[0][j]) + ) + order = optimize.linear_sum_assignment(F)[1] + + # Add the ordered vector to the list + matched_vectors.append(vectors[i][order]) + orders.append(order) + + if return_order: + return orders + else: + return tuple(matched_vectors) + + def match_modes(*mode_time_courses, return_order=False): """Find correlated modes between mode time courses. @@ -356,7 +424,7 @@ def match_modes(*mode_time_courses, return_order=False): correlation = np.nan_to_num( np.nan_to_num(correlation, nan=np.nanmin(correlation) - 1) ) - matches = linear_sum_assignment(-correlation) + matches = optimize.linear_sum_assignment(-correlation) matched_mode_time_courses.append(mode_time_course[:n_samples, matches[1]]) orders.append(matches[1]) @@ -683,7 +751,7 @@ def average_runs(alpha, n_clusters=None, return_cluster_info=False): if return_cluster_info: # Create a dictionary containing the clustering info - linkage = hierarchy.linkage(dissimilarity, method="ward") + linkage = cluster.hierarchy.linkage(dissimilarity, method="ward") cluster_info = { "correlation": corr, "dissimilarity": dissimilarity,