Skip to content

Commit

Permalink
Feat: function to match sets of vectors.
Browse files Browse the repository at this point in the history
  • Loading branch information
cgohil8 committed May 21, 2024
1 parent d6f9cdb commit 82a2856
Showing 1 changed file with 73 additions and 5 deletions.
78 changes: 73 additions & 5 deletions osl_dynamics/inference/modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import trange
from scipy.optimize import linear_sum_assignment
from scipy.cluster import hierarchy
from scipy import cluster, spatial, optimize
from sklearn.cluster import AgglomerativeClustering

from osl_dynamics import analysis, array_ops
Expand Down Expand Up @@ -294,7 +293,7 @@ def match_covariances(
F[j, k] = -metrics.pairwise_rv_coefficient(
np.array([covariances[i][k], covariances[0][j]])
)[0, 1]
order = linear_sum_assignment(F)[1]
order = optimize.linear_sum_assignment(F)[1]

# Add the ordered matrix to the list
matched_covariances.append(covariances[i][order])
Expand All @@ -306,6 +305,75 @@ def match_covariances(
return tuple(matched_covariances)


def match_vectors(*vectors, comparison="correlation", return_order=False):
    """Matches vectors.

    The first set of vectors is used as the reference. Every other set is
    reordered with a one-to-one assignment that maximises the total
    similarity between each reference vector and the vector placed at the
    same index.

    Parameters
    ----------
    vectors : tuple of np.ndarray
        Sets of vectors to match.
        Each variable must be shape (n_vectors, n_channels).
    comparison : str, optional
        Must be :code:`'correlation'` or :code:`'cosine_similarity'`.
    return_order : bool, optional
        Should we return the order instead of the matched vectors?

    Returns
    -------
    matched_vectors : tuple of np.ndarray
        Set of matched vectors of shape (n_vectors, n_channels).
        If :code:`return_order=True`, a list containing the order
        (np.ndarray of shape (n_vectors,)) for each set is returned instead.

    Raises
    ------
    ValueError
        If no vectors are passed, if the first set is not 2D, if the sets
        do not all have the same shape, or if :code:`comparison` is not a
        supported option.

    Examples
    --------
    Reorder the vectors directly:

    >>> v1, v2 = match_vectors(v1, v2, comparison="correlation")

    Just get the reordering:

    >>> orders = match_vectors(v1, v2, comparison="correlation", return_order=True)
    >>> print(orders[0])  # order for v1 (always unchanged)
    >>> print(orders[1])  # order for v2
    """
    # Validation
    if len(vectors) == 0:
        raise ValueError("At least one set of vectors must be passed.")

    reference = vectors[0]
    if reference.ndim != 2:
        raise ValueError("Vectors must be shape (n_vectors, n_channels).")

    for vector in vectors[1:]:
        if vector.shape != reference.shape:
            raise ValueError("Vectors must have the same shape.")

    if comparison not in ["correlation", "cosine_similarity"]:
        raise ValueError("Comparison must be 'correlation' or 'cosine_similarity'.")

    # SciPy's distance for each option is (1 - similarity). Minimising the
    # total distance is therefore the same assignment as maximising the
    # total similarity (the costs only differ by a constant offset).
    metric = "correlation" if comparison == "correlation" else "cosine"

    # The reference set is never reordered
    matched_vectors = [reference]
    orders = [np.arange(reference.shape[0])]

    for candidate in vectors[1:]:
        # distance[j, k] compares reference vector j with candidate vector k
        distance = spatial.distance.cdist(reference, candidate, metric=metric)

        # Column index assigned to each row, i.e. for each reference vector
        # the index of the candidate vector matched to it
        # NOTE(review): zero-variance (or all-zero) vectors give NaN
        # distances, which the solver is expected to reject - confirm this
        # is the desired behaviour.
        order = optimize.linear_sum_assignment(distance)[1]

        # Add the ordered vectors to the list
        matched_vectors.append(candidate[order])
        orders.append(order)

    if return_order:
        return orders
    return tuple(matched_vectors)


def match_modes(*mode_time_courses, return_order=False):
"""Find correlated modes between mode time courses.
Expand Down Expand Up @@ -356,7 +424,7 @@ def match_modes(*mode_time_courses, return_order=False):
correlation = np.nan_to_num(
np.nan_to_num(correlation, nan=np.nanmin(correlation) - 1)
)
matches = linear_sum_assignment(-correlation)
matches = optimize.linear_sum_assignment(-correlation)
matched_mode_time_courses.append(mode_time_course[:n_samples, matches[1]])
orders.append(matches[1])

Expand Down Expand Up @@ -683,7 +751,7 @@ def average_runs(alpha, n_clusters=None, return_cluster_info=False):

if return_cluster_info:
# Create a dictionary containing the clustering info
linkage = hierarchy.linkage(dissimilarity, method="ward")
linkage = cluster.hierarchy.linkage(dissimilarity, method="ward")
cluster_info = {
"correlation": corr,
"dissimilarity": dissimilarity,
Expand Down

0 comments on commit 82a2856

Please sign in to comment.