Enhancements and features
* Feat: Allow multiple rows in combined network plots.
* Refact: suppress external logger info messages.
* Feat: function to match sets of vectors.
* Feat: Function to re-normalise mixing coefs with correlations.
* Enhance: Individual plots now deleted if combined = True for power and connectivity plots.
cgohil8 committed May 23, 2024
1 parent e9f1898 commit c55188e
Showing 6 changed files with 158 additions and 32 deletions.
20 changes: 16 additions & 4 deletions osl_dynamics/analysis/connectivity.py
@@ -14,6 +14,7 @@
/tutorials_build/dynemo_plotting_networks.html>`_.
"""

import os
from pathlib import Path

import numpy as np
@@ -763,6 +764,7 @@ def save(
axes=None,
combined=False,
titles=None,
n_rows=1,
):
"""Save connectivity maps as image files.
@@ -795,9 +797,13 @@
List of matplotlib axes to plot the connectivity maps on.
combined : bool, optional
Should the connectivity maps be combined on the same figure?
The combined image is always shown on screen (for Jupyter notebooks).
Note if :code:`True` is passed, the individual images will be deleted.
titles : list, optional
List of titles for each connectivity map. Only used if
:code:`combined=True`.
n_rows : int, optional
Number of rows in the combined image. Only used if :code:`combined=True`.

Examples
--------
@@ -884,15 +890,21 @@ def save(
if filename is None:
raise ValueError("filename must be passed to save the combined image.")

n_columns = -(n_modes // -n_rows)
titles = titles or [None] * n_modes
fig, axes = plt.subplots(1, n_modes, figsize=(n_modes * 10, 5))
for i, ax in enumerate(axes):
ax.imshow(plt.imread(output_files[i]))
fig, axes = plt.subplots(n_rows, n_columns, figsize=(n_columns * 5, n_rows * 5))
for i, ax in enumerate(axes.flatten()):
ax.axis("off")
ax.set_title(titles[i], fontsize=20)
if i < n_modes:
ax.imshow(plt.imread(output_files[i]))
ax.set_title(titles[i], fontsize=20)
fig.tight_layout()
fig.savefig(filename)

# Remove the individual images
for output_file in output_files:
os.remove(output_file)
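A minimal usage sketch of the new combined/n_rows behaviour (hypothetical array shapes and output path; a real call to `connectivity.save` may need further arguments, e.g. a parcellation file, that this hunk does not show). Note that `-(n_modes // -n_rows)` is ceiling division, so 6 modes with `n_rows=2` gives 3 columns.

```python
import numpy as np
from osl_dynamics.analysis import connectivity

# Hypothetical data: 6 modes, 42-channel connectivity matrices
conn_map = np.random.randn(6, 42, 42)

connectivity.save(
    conn_map,
    filename="conn_combined.png",  # hypothetical output path
    combined=True,                 # individual images are deleted afterwards
    n_rows=2,                      # -(6 // -2) = 3 columns via ceiling division
    titles=[f"Mode {i + 1}" for i in range(6)],
)
```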


def save_interactive(
connectivity_map,
24 changes: 20 additions & 4 deletions osl_dynamics/analysis/power.py
@@ -247,6 +247,9 @@ def parcel_vector_to_voxel_grid(mask_file, parcellation_file, vector):
Value at each voxel. Shape is (x, y, z), where :code:`x`,
:code:`y` and :code:`z` correspond to 3D voxel locations.
"""
# Suppress INFO messages from nibabel
logging.getLogger("nibabel.global").setLevel(logging.ERROR)

# Validation
mask_file = files.check_exists(mask_file, files.mask.directory)
parcellation_file = files.check_exists(
@@ -304,6 +307,7 @@ def save(
show_plots=True,
combined=False,
titles=None,
n_rows=1,
):
"""Saves power maps.
@@ -343,8 +347,11 @@
combined : bool, optional
Should the individual plots be combined into a single image?
The combined image is always shown on screen (for Jupyter notebooks).
Note if :code:`True` is passed, the individual images will be deleted.
titles : list, optional
List of titles for each power plot. Only used if :code:`combined=True`.
n_rows : int, optional
Number of rows in the combined image. Only used if :code:`combined=True`.

Returns
-------
@@ -470,16 +477,25 @@ def save(
plt.close(fig)

if combined:
n_columns = -(n_modes // -n_rows)

titles = titles or [None] * n_modes
# Combine images into a single image
fig, axes = plt.subplots(1, n_modes, figsize=(n_modes * 5, 5))
for i, ax in enumerate(axes):
ax.imshow(plt.imread(output_files[i]))
fig, axes = plt.subplots(
n_rows, n_columns, figsize=(n_columns * 5, n_rows * 5)
)
for i, ax in enumerate(axes.flatten()):
ax.axis("off")
ax.set_title(titles[i], fontsize=20)
if i < n_modes:
ax.imshow(plt.imread(output_files[i]))
ax.set_title(titles[i], fontsize=20)
fig.tight_layout()
fig.savefig(filename)

# Remove the individual images
for output_file in output_files:
os.remove(output_file)
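The same pattern for power maps, sketched under the same assumptions (hypothetical data and path; `power.save` usually also takes mask/parcellation arguments not shown in this hunk):

```python
import numpy as np
from osl_dynamics.analysis import power

p = np.random.randn(8, 42)  # hypothetical (n_modes, n_channels) power map

power.save(
    p,
    filename="power_combined.png",  # hypothetical output path
    combined=True,
    n_rows=2,  # 8 modes -> -(8 // -2) = 4 columns
    titles=[f"Mode {i + 1}" for i in range(8)],
)
```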


def multi_save(
group_power_map,
16 changes: 8 additions & 8 deletions osl_dynamics/data/base.py
@@ -988,7 +988,7 @@ def trim_time_series(
if verbose:
_logger.info(
f"Removing {n_remove} data points from the start and end"
+ " of each array due to time embedding/sliding window."
" of each array due to time embedding/sliding window."
)

# What data should we trim?
@@ -1009,8 +1009,8 @@
n_keep = n_sequences * sequence_length
if verbose:
_logger.info(
f"Removing {array.shape[0] - n_keep} data points"
+ f" from the end of array {i} due to sequencing."
f"Removing {array.shape[0] - n_keep} data points "
f"from the end of array {i} due to sequencing."
)
array = array[:n_keep]

@@ -1175,8 +1175,8 @@ def dataset(
validation_dataset = full_dataset.skip(training_dataset_size)
_logger.info(
f"{len(training_dataset)} batches in training dataset, "
+ f"{len(validation_dataset)} batches in the validation "
+ "dataset."
f"{len(validation_dataset)} batches in the validation "
"dataset."
)

return training_dataset.prefetch(
@@ -1228,9 +1228,9 @@ def dataset(
)
_logger.info(
f"Session {i}: "
+ f"{len(training_datasets[i])} batches in training dataset, "
+ f"{len(validation_datasets[i])} batches in the validation "
+ "dataset."
f"{len(training_datasets[i])} batches in training dataset, "
f"{len(validation_datasets[i])} batches in the validation "
"dataset."
)
return training_datasets, validation_datasets

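These logging changes swap runtime `+` concatenation for implicit concatenation of adjacent string literals, which the parser joins at compile time; a quick check:

```python
n_remove, i = 5, 0
msg = (
    f"Removing {n_remove} data points "
    f"from the end of array {i} due to sequencing."
)
assert msg == "Removing 5 data points from the end of array 0 due to sequencing."
```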
126 changes: 110 additions & 16 deletions osl_dynamics/inference/modes.py
@@ -9,8 +9,7 @@
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import trange
from scipy.optimize import linear_sum_assignment
from scipy.cluster import hierarchy
from scipy import cluster, spatial, optimize
from sklearn.cluster import AgglomerativeClustering

from osl_dynamics import analysis, array_ops
@@ -294,7 +293,7 @@ def match_covariances(
F[j, k] = -metrics.pairwise_rv_coefficient(
np.array([covariances[i][k], covariances[0][j]])
)[0, 1]
order = linear_sum_assignment(F)[1]
order = optimize.linear_sum_assignment(F)[1]

# Add the ordered matrix to the list
matched_covariances.append(covariances[i][order])
@@ -306,6 +305,75 @@
return tuple(matched_covariances)


def match_vectors(*vectors, comparison="correlation", return_order=False):
"""Matches vectors.
Parameters
----------
vectors : tuple of np.ndarray
Sets of vectors to match.
Each variable must be shape (n_vectors, n_channels).
comparison : str, optional
Must be :code:`'correlation' or :code:`'cosine_similarity'`.
return_order : bool, optional
Should we return the order instead of the matched vectors?
Returns
-------
matched_vectors : tuple of np.ndarray
Set of matched vectors of shape (n_vectors, n_channels)
or order if :code:`return_order=True`.
Examples
--------
Reorder the vectors directly:
>>> v1, v2 = match_vectors(v1, v2, comparison="correlation")
Just get the reordering:
>>> orders = match_vectors(v1, v2, comparison="correlation", return_order=True)
>>> print(orders[0]) # order for v1 (always unchanged)
>>> print(orders[1]) # order for v2
"""
# Validation
for vector in vectors[1:]:
if vector.shape != vectors[0].shape:
raise ValueError("Vectors must have the same shape.")

if comparison not in ["correlation", "cosine_similarity"]:
raise ValueError("Comparison must be 'correlation' or 'cosine_similarity'.")

# Number of arguments and number of vectors in each argument passed
n_args = len(vectors)
n_vectors = vectors[0].shape[0]

# Calculate the similarity between vectors
F = np.empty([n_vectors, n_vectors])
matched_vectors = [vectors[0]]
orders = [np.arange(vectors[0].shape[0])]
for i in range(1, n_args):
for j in range(n_vectors):
# Find the vector that is most similar to vector j
for k in range(n_vectors):
if comparison == "correlation":
F[j, k] = -np.corrcoef(vectors[i][k], vectors[0][j])[0, 1]
elif comparison == "cosine_similarity":
F[j, k] = -(
1 - spatial.distance.cosine(vectors[i][k], vectors[0][j])
)
order = optimize.linear_sum_assignment(F)[1]

# Add the ordered vector to the list
matched_vectors.append(vectors[i][order])
orders.append(order)

if return_order:
return orders
else:
return tuple(matched_vectors)
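A self-contained check of the new matching logic (synthetic data; `v2` is a shuffled copy of `v1`, so the recovered order should undo the shuffle):

```python
import numpy as np
from osl_dynamics.inference import modes

rng = np.random.default_rng(0)
v1 = rng.standard_normal((4, 10))
v2 = v1[[2, 0, 3, 1]]  # v1 with its rows permuted

m1, m2 = modes.match_vectors(v1, v2, comparison="correlation")
assert np.allclose(m1, m2)  # matching recovers the original alignment
```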


def match_modes(*mode_time_courses, return_order=False):
"""Find correlated modes between mode time courses.
@@ -356,7 +424,7 @@ def match_modes(*mode_time_courses, return_order=False):
correlation = np.nan_to_num(
np.nan_to_num(correlation, nan=np.nanmin(correlation) - 1)
)
matches = linear_sum_assignment(-correlation)
matches = optimize.linear_sum_assignment(-correlation)
matched_mode_time_courses.append(mode_time_course[:n_samples, matches[1]])
orders.append(matches[1])

@@ -565,20 +633,46 @@ def reweight_alphas(alpha, covs):
reweighted_alpha : list of np.ndarray or np.ndarray
Re-weighted mixing coefficients. Shape is the same as :code:`alpha`.
"""
if isinstance(alpha, np.ndarray):
alpha = [alpha]
return reweight_mtc(alpha, covs, "covariance")

# Calculate normalised alphas
traces = np.trace(covs, axis1=1, axis2=2)
reweighted_alpha = [a * traces[np.newaxis, :] for a in alpha]
reweighted_alpha = [
na / np.sum(na, axis=1, keepdims=True) for na in reweighted_alpha
]

if len(reweighted_alpha) == 1:
reweighted_alpha = reweighted_alpha[0]
def reweight_mtc(mtc, params, params_type):
"""Re-weight mixing coefficients to account for the magnitude of
observation model parameters.
Parameters
----------
mtc : List[np.ndarray] or np.ndarray
Raw mixing coefficients. Shape must be (n_sessions, n_samples, n_modes)
or (n_samples, n_modes).
params : np.ndarray
Observation model parameters. Shape must be (n_modes, n_channels, n_channels).
params_type : str
Observation model parameters type. Either 'covariance' or 'correlation'.
Returns
-------
reweighted_mtc : List[np.ndarray]
Re-weighted mixing coefficients. Shape is the same as :code:`mtc`.
"""
if isinstance(mtc, np.ndarray):
mtc = [mtc]

if params_type == "covariance":
weights = np.trace(params, axis1=1, axis2=2)
elif params_type == "correlation":
m, n = np.tril_indices(params.shape[-1], -1)
weights = np.sum(np.abs(params[:, m, n]), axis=-1)
else:
raise ValueError("params_type must be 'covariance' or 'correlation'.")

reweighted_mtc = [x * weights[np.newaxis, :] for x in mtc]
reweighted_mtc = [x / np.sum(x, axis=1, keepdims=True) for x in reweighted_mtc]

if len(reweighted_mtc) == 1:
reweighted_mtc = reweighted_mtc[0]

return reweighted_alpha
return reweighted_mtc
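A quick sanity check of the correlation branch (synthetic mixing coefficients; the weights are per-mode sums of absolute lower-triangular correlations, and the re-weighted coefficients should still sum to one over modes):

```python
import numpy as np
from osl_dynamics.inference import modes

rng = np.random.default_rng(0)
mtc = rng.dirichlet(np.ones(3), size=100)  # (n_samples, n_modes)

# Three synthetic (n_channels, n_channels) correlation matrices
corrs = np.stack([np.corrcoef(rng.standard_normal((5, 50))) for _ in range(3)])

reweighted = modes.reweight_mtc(mtc, corrs, "correlation")
assert np.allclose(reweighted.sum(axis=1), 1)
```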


def average_runs(alpha, n_clusters=None, return_cluster_info=False):
@@ -657,7 +751,7 @@ def average_runs(alpha, n_clusters=None, return_cluster_info=False):

if return_cluster_info:
# Create a dictionary containing the clustering info
linkage = hierarchy.linkage(dissimilarity, method="ward")
linkage = cluster.hierarchy.linkage(dissimilarity, method="ward")
cluster_info = {
"correlation": corr,
"dissimilarity": dissimilarity,
1 change: 1 addition & 0 deletions osl_dynamics/models/hmm.py
@@ -49,6 +49,7 @@
_logger = logging.getLogger("osl-dynamics")

warnings.filterwarnings("ignore", category=NumbaWarning)
logging.getLogger("numba.core.transforms").setLevel(logging.ERROR)

EPS = sys.float_info.epsilon

3 changes: 3 additions & 0 deletions osl_dynamics/utils/plotting.py
@@ -23,6 +23,9 @@

_logger = logging.getLogger("osl-dynamics")

# Suppress matplotlib warnings
logging.getLogger("matplotlib.category").setLevel(logging.ERROR)


def set_style(params):
"""Sets matplotlib's style.
