Skip to content

Commit

Permalink
Refactor: Cleaned up M-DyNeMo code.
Browse files Browse the repository at this point in the history
  • Loading branch information
RukuangHuang committed May 14, 2024
1 parent 676226e commit 1612dbb
Show file tree
Hide file tree
Showing 6 changed files with 211 additions and 207 deletions.
58 changes: 30 additions & 28 deletions examples/simulation/mdynemo_hmm-mvn.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Example script for running inference on simulated MDyn_HMM_MVN data.
- Multi-dynamic version for dynemo_hmm-mvn.py
- Should achieve a dice of ~0.99 for alpha and ~0.99 for gamma.
- Should achieve a dice of ~0.99 for alpha and ~0.99 for beta.
"""

print("Setting up")
Expand Down Expand Up @@ -29,7 +29,7 @@
theta_normalization="layer",
learn_means=True,
learn_stds=True,
learn_fcs=True,
learn_corrs=True,
do_kl_annealing=True,
kl_annealing_curve="tanh",
kl_annealing_sharpness=10,
Expand Down Expand Up @@ -74,50 +74,50 @@
print(f"Free energy: {free_energy}")

# Inferred mode mixing factors
inf_alpha, inf_gamma = model.get_mode_time_courses(training_data)
inf_alpha, inf_beta = model.get_mode_time_courses(training_data)

inf_alpha = modes.argmax_time_courses(inf_alpha)
inf_gamma = modes.argmax_time_courses(inf_gamma)
inf_beta = modes.argmax_time_courses(inf_beta)

# Simulated mode mixing factors
sim_alpha, sim_gamma = sim.mode_time_course
sim_alpha, sim_beta = sim.mode_time_course

# Inferred means, stds, fcs
inf_means, inf_stds, inf_fcs = model.get_means_stds_fcs()
# Inferred means, stds, corrs
inf_means, inf_stds, inf_corrs = model.get_means_stds_corrs()
sim_means = sim.means
sim_stds = sim.stds
sim_fcs = sim.fcs
sim_corrs = sim.corrs

# Match the inferred and simulated mixing factors
_, order_alpha = modes.match_modes(sim_alpha, inf_alpha, return_order=True)
_, order_gamma = modes.match_modes(sim_gamma, inf_gamma, return_order=True)
_, order_beta = modes.match_modes(sim_beta, inf_beta, return_order=True)

inf_alpha = inf_alpha[:, order_alpha]
inf_gamma = inf_gamma[:, order_gamma]
inf_beta = inf_beta[:, order_beta]

inf_means = inf_means[order_alpha]
inf_stds = np.array([np.diag(std) for std in inf_stds[order_alpha]])
inf_fcs = inf_fcs[order_gamma]
inf_corrs = inf_corrs[order_beta]

# Dice coefficients
dice_alpha = metrics.dice_coefficient(sim_alpha, inf_alpha)
dice_gamma = metrics.dice_coefficient(sim_gamma, inf_gamma)
dice_beta = metrics.dice_coefficient(sim_beta, inf_beta)

print("Dice coefficient for mean:", dice_alpha)
print("Dice coefficient for fc:", dice_gamma)
print("Dice coefficient for power:", dice_alpha)
print("Dice coefficient for FC:", dice_beta)

# Fractional occupancies
fo_sim_alpha = modes.fractional_occupancies(sim_alpha)
fo_sim_gamma = modes.fractional_occupancies(sim_gamma)
fo_sim_beta = modes.fractional_occupancies(sim_beta)

fo_inf_alpha = modes.fractional_occupancies(inf_alpha)
fo_inf_gamma = modes.fractional_occupancies(inf_gamma)
fo_inf_beta = modes.fractional_occupancies(inf_beta)

print("Fractional occupancies mean (Simulation):", fo_sim_alpha)
print("Fractional occupancies mean (DyNeMo):", fo_inf_alpha)

print("Fractional occupancies fc (Simulation):", fo_sim_gamma)
print("Fractional occupancies fc (DyNeMo):", fo_inf_gamma)
print("Fractional occupancies FC (Simulation):", fo_sim_beta)
print("Fractional occupancies FC (DyNeMo):", fo_inf_beta)

# Plots
plotting.plot_alpha(
Expand All @@ -135,18 +135,18 @@
filename="figures/inf_alpha.png",
)
plotting.plot_alpha(
sim_gamma,
sim_beta,
n_samples=2000,
title="Ground truth " + r"$\gamma$",
y_labels=r"$\gamma_{jt}$",
filename="figures/sim_gamma.png",
title="Ground truth " + r"$\beta$",
y_labels=r"$\beta_{jt}$",
filename="figures/sim_beta.png",
)
plotting.plot_alpha(
inf_gamma,
inf_beta,
n_samples=2000,
title="Inferred " + r"$\gamma$",
y_labels=r"$\gamma_{jt}$",
filename="figures/inf_gamma.png",
title="Inferred " + r"$\beta$",
y_labels=r"$\beta_{jt}$",
filename="figures/inf_beta.png",
)
plotting.plot_matrices(
sim_means, main_title="Ground Truth", filename="figures/sim_means.png"
Expand All @@ -160,8 +160,10 @@
)
plotting.plot_matrices(inf_stds, main_title="Inferred", filename="figures/inf_stds.png")
plotting.plot_matrices(
sim_fcs, main_title="Ground Truth", filename="figures/sim_fcs.png"
sim_corrs, main_title="Ground Truth", filename="figures/sim_corrs.png"
)
plotting.plot_matrices(
inf_corrs, main_title="Inferred", filename="figures/inf_corrs.png"
)
plotting.plot_matrices(inf_fcs, main_title="Inferred", filename="figures/inf_fcs.png")

training_data.delete_dir()
58 changes: 29 additions & 29 deletions osl_dynamics/models/inf_mod_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,7 @@ def predict(self, *args, **kwargs):
if not self.config.multiple_dynamics:
return_names = ["ll_loss", "kl_loss", "theta"]
else:
return_names = ["ll_loss", "kl_loss", "mean_theta", "fc_theta"]
return_names = ["ll_loss", "kl_loss", "power_theta", "fc_theta"]
predictions_dict = dict(zip(return_names, predictions))

return predictions_dict
Expand Down Expand Up @@ -686,8 +686,8 @@ def get_mode_logits(
Returns
-------
mean_theta : list or np.ndarray
Mode mixing logits for mean with shape (n_sessions, n_samples,
power_theta : list or np.ndarray
Mode mixing logits for power with shape (n_sessions, n_samples,
n_modes) or (n_samples, n_modes).
fc_theta : list or np.ndarray
Mode mixing logits for FC with shape (n_sessions, n_samples,
Expand Down Expand Up @@ -718,32 +718,32 @@ def get_mode_logits(
iterator = range(n_datasets)
_logger.info("Getting mode logits")

mean_theta = []
power_theta = []
fc_theta = []
for i in iterator:
predictions = self.predict(dataset[i], **kwargs)
mean_theta_ = predictions["mean_theta"]
power_theta_ = predictions["power_theta"]
fc_theta_ = predictions["fc_theta"]
if remove_edge_effects:
trim = step_size // 2 # throw away 25%
mean_theta_ = (
[mean_theta_[0, :-trim]]
+ list(mean_theta_[1:-1, trim:-trim])
+ [mean_theta_[-1, trim:]]
power_theta_ = (
[power_theta_[0, :-trim]]
+ list(power_theta_[1:-1, trim:-trim])
+ [power_theta_[-1, trim:]]
)
fc_theta_ = (
[fc_theta_[0, :-trim]]
+ list(fc_theta_[1:-1, trim:-trim])
+ [fc_theta_[-1, trim:]]
)
mean_theta.append(np.concatenate(mean_theta_))
power_theta.append(np.concatenate(power_theta_))
fc_theta.append(np.concatenate(fc_theta_))

if concatenate or len(mean_theta) == 1:
mean_theta = np.concatenate(mean_theta)
if concatenate or len(power_theta) == 1:
power_theta = np.concatenate(power_theta)
fc_theta = np.concatenate(fc_theta)

return mean_theta, fc_theta
return power_theta, fc_theta

def get_alpha(
self, dataset, concatenate=False, remove_edge_effects=False, **kwargs
Expand Down Expand Up @@ -830,11 +830,11 @@ def get_mode_time_courses(
Prediction data. This can be a list of datasets, one for each
session.
concatenate : bool, optional
Should we concatenate alpha/gamma for each session?
Should we concatenate alpha/beta for each session?
remove_edge_effects : bool, optional
Edge effects can arise due to separating the data into sequences.
We can remove these by predicting overlapping :code:`alpha`/
:code:`gamma` and disregarding the :code:`alpha`/:code:`gamma` near
:code:`beta` and disregarding the :code:`alpha`/:code:`beta` near
the ends. Passing :code:`True` does this by using sequences with 50%
overlap and throwing away the first and last 25% of predictions.
Expand All @@ -843,8 +843,8 @@ def get_mode_time_courses(
alpha : list or np.ndarray
Alpha time course with shape (n_sessions, n_samples, n_modes) or
(n_samples, n_modes).
gamma : list or np.ndarray
Gamma time course with shape (n_sessions, n_samples, n_modes) or
beta : list or np.ndarray
Beta time course with shape (n_sessions, n_samples, n_modes) or
(n_samples, n_modes).
"""
if self.is_multi_gpu:
Expand All @@ -864,7 +864,7 @@ def get_mode_time_courses(

dataset = self.make_dataset(dataset, step_size=step_size)
alpha_layer = self.model.get_layer("alpha")
gamma_layer = self.model.get_layer("gamma")
beta_layer = self.model.get_layer("beta")

n_datasets = len(dataset)
if len(dataset) > 1:
Expand All @@ -875,33 +875,33 @@ def get_mode_time_courses(
_logger.info("Getting mode time courses")

alpha = []
gamma = []
beta = []
for i in iterator:
predictions = self.predict(dataset[i], **kwargs)
mean_theta = predictions["mean_theta"]
power_theta = predictions["power_theta"]
fc_theta = predictions["fc_theta"]
alpha_ = alpha_layer(mean_theta)
gamma_ = gamma_layer(fc_theta)
alpha_ = alpha_layer(power_theta)
beta_ = beta_layer(fc_theta)
if remove_edge_effects:
trim = step_size // 2 # throw away 25%
alpha_ = (
[alpha_[0, :-trim]]
+ list(alpha_[1:-1, trim:-trim])
+ [alpha_[-1, trim:]]
)
gamma_ = (
[gamma_[0, :-trim]]
+ list(gamma_[1:-1, trim:-trim])
+ [gamma_[-1, trim:]]
beta_ = (
[beta_[0, :-trim]]
+ list(beta_[1:-1, trim:-trim])
+ [beta_[-1, trim:]]
)
alpha.append(np.concatenate(alpha_))
gamma.append(np.concatenate(gamma_))
beta.append(np.concatenate(beta_))

if concatenate or len(alpha) == 1:
alpha = np.concatenate(alpha)
gamma = np.concatenate(gamma)
beta = np.concatenate(beta)

return alpha, gamma
return alpha, beta

def losses(self, dataset, **kwargs):
"""Calculates the log-likelihood and KL loss for a dataset.
Expand Down
Loading

0 comments on commit 1612dbb

Please sign in to comment.