Skip to content

Commit

Permalink
HIVE enhance (#248)
Browse files Browse the repository at this point in the history
* Fix: Bug in MSess_HMM_MVN in special cases.
* Feat: Added method to generate covariances from HIVE/DIVE prior.
* Doc: Updated docstrings, more consistent variable names and improved wrapper for HIVE.
* Rename: map --> param.
* black.
  • Loading branch information
RukuangHuang committed Apr 22, 2024
1 parent 49d7b06 commit 1f7ed56
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 82 deletions.
2 changes: 1 addition & 1 deletion osl_dynamics/config_api/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ def train_hive(
save(f"{inf_params_dir}/session_means.npy", session_means)
save(f"{inf_params_dir}/session_covs.npy", session_covs)
save(f"{inf_params_dir}/summed_embeddings.npy", summed_embeddings)
save(f"{inf_params_dir}/embedding_weights.npy", embedding_weights)
save(f"{inf_params_dir}/embedding_weights.pkl", embedding_weights)


def get_inf_params(data, output_dir, observation_model_only=False):
Expand Down
36 changes: 18 additions & 18 deletions osl_dynamics/inference/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1454,48 +1454,48 @@ def call(self, inputs):
return concat_embeddings


class SessionMapLayer(layers.Layer):
"""Layer for getting the array specific maps.
class SessionParamLayer(layers.Layer):
"""Layer for getting the array specific parameters.
This layer adds deviations to the group spatial maps.
This layer adds deviations to the group spatial parameters.
Parameters
----------
which_map : str
Which spatial map are we using? Must be :code:`'means'` or
param : str
Which parameter are we using? Must be :code:`'means'` or
:code:`'covariances'`.
epsilon : float
Error added to the diagonal of covariances for numerical stability.
kwargs : keyword arguments, optional
Keyword arguments to pass to the base class.
"""

def __init__(self, which_map, epsilon, **kwargs):
def __init__(self, param, epsilon, **kwargs):
super().__init__(**kwargs)
self.which_map = which_map
self.param = param
self.epsilon = epsilon
if which_map == "covariances":
if param == "covariances":
self.bijector = tfb.Chain(
[tfb.CholeskyOuterProduct(), tfb.FillScaleTriL()],
)
elif which_map == "means":
elif param == "means":
self.bijector = tfb.Identity()
else:
raise ValueError("which_map must be one of 'means' and 'covariances'.")
raise ValueError("param must be one of 'means' and 'covariances'.")

def call(self, inputs):
group_map, dev = inputs
group_map = self.bijector.inverse(group_map)
group_param, dev = inputs
group_param = self.bijector.inverse(group_param)

# Match dimensions for addition
group_map = tf.expand_dims(group_map, axis=0)
session_map = tf.add(group_map, dev)
session_map = self.bijector(session_map)
group_param = tf.expand_dims(group_param, axis=0)
session_param = tf.add(group_param, dev)
session_param = self.bijector(session_param)

if self.which_map == "covariances":
session_map = add_epsilon(session_map, self.epsilon, diag=True)
if self.param == "covariances":
session_param = add_epsilon(session_param, self.epsilon, diag=True)

return session_map
return session_param


class MixSessionSpecificParametersLayer(layers.Layer):
Expand Down
6 changes: 3 additions & 3 deletions osl_dynamics/models/dive.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
SampleGammaDistributionLayer,
SoftmaxLayer,
ConcatEmbeddingsLayer,
SessionMapLayer,
SessionParamLayer,
MixSessionSpecificParametersLayer,
ZeroLayer,
InverseCholeskyLayer,
Expand Down Expand Up @@ -880,10 +880,10 @@ def _model_structure(config):
# Add deviations to group level parameters

# Layer definitions
session_means_layer = SessionMapLayer(
session_means_layer = SessionParamLayer(
"means", config.covariances_epsilon, name="session_means"
)
session_covs_layer = SessionMapLayer(
session_covs_layer = SessionParamLayer(
"covariances", config.covariances_epsilon, name="session_covs"
)

Expand Down
6 changes: 3 additions & 3 deletions osl_dynamics/models/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
VectorsLayer,
CovarianceMatricesLayer,
ConcatEmbeddingsLayer,
SessionMapLayer,
SessionParamLayer,
ZeroLayer,
InverseCholeskyLayer,
SampleGammaDistributionLayer,
Expand Down Expand Up @@ -918,10 +918,10 @@ def _model_structure(config):
# Add deviations to group level parameters

# Layer definitions
session_means_layer = SessionMapLayer(
session_means_layer = SessionParamLayer(
"means", config.covariances_epsilon, name="session_means"
)
session_covs_layer = SessionMapLayer(
session_covs_layer = SessionParamLayer(
"covariances", config.covariances_epsilon, name="session_covs"
)

Expand Down
119 changes: 62 additions & 57 deletions osl_dynamics/models/obs_mod.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,16 +182,17 @@ def set_embeddings_initializer(model, initial_embeddings):
model : osl_dynamics.models.*.Model.model
The model. * must be :code:`hive` or :code:`dive`.
initial_embeddings : dict
The initial_embeddings dictionary. {name: initial_embeddings}
The initial_embeddings dictionary. {name: value}
"""

def _set_embeddings_initializer(layer_name, initial_embeddings):
# Helper function to set a single layer's initializer
def _set_embeddings_initializer(layer_name, value):
embedding_layer = model.get_layer(layer_name)
embedding_layer.embedding_layer.embeddings_initializer = WeightInitializer(
initial_embeddings
value
)

for k, v in config.items():
for k, v in initial_embeddings.items():
_set_embeddings_initializer(f"{k}_embeddings", v)


Expand Down Expand Up @@ -455,25 +456,25 @@ def get_covs_spatial_embeddings(model):
return covs_spatial_embeddings.numpy()


def get_spatial_embeddings(model, map):
def get_spatial_embeddings(model, param):
"""Wrapper for getting the spatial embeddings for the means and covariances."""
if map == "means":
if param == "means":
return get_means_spatial_embeddings(model)
elif map == "covs":
elif param == "covs":
return get_covs_spatial_embeddings(model)
else:
raise ValueError("map must be either 'means' or 'covs'")
raise ValueError("param must be either 'means' or 'covs'")


def get_concatenated_embeddings(model, map, session_labels):
def get_concatenated_embeddings(model, param, session_labels):
"""Get the concatenated embeddings.
Parameters
----------
model : osl_dynamics.models.*.Model.model
The model. * must be :code:`hive` or :code:`dive`.
map : str
The map to use. Either :code:`"means"` or :code:`"covs"`.
param : str
The param to use. Either :code:`"means"` or :code:`"covs"`.
embeddings : np.ndarray, optional
Input embeddings. If :code:`None`, they are retrieved from
the model. Shape is (n_sessions, embeddings_dim).
Expand All @@ -485,14 +486,14 @@ def get_concatenated_embeddings(model, map, session_labels):
embeddings_dim + spatial_embeddings_dim).
"""
embeddings = get_summed_embeddings(model, session_labels)
if map == "means":
if param == "means":
spatial_embeddings = get_means_spatial_embeddings(model)
concat_embeddings_layer = model.get_layer("means_concat_embeddings")
elif map == "covs":
elif param == "covs":
spatial_embeddings = get_covs_spatial_embeddings(model)
concat_embeddings_layer = model.get_layer("covs_concat_embeddings")
else:
raise ValueError("map must be either 'means' or 'covs'")
raise ValueError("param must be either 'means' or 'covs'")
concat_embeddings = concat_embeddings_layer([embeddings, spatial_embeddings])
return concat_embeddings.numpy()

Expand Down Expand Up @@ -561,55 +562,55 @@ def get_covs_dev_mag_parameters(model):
return covs_dev_mag_inf_alpha.numpy(), covs_dev_mag_inf_beta.numpy()


def get_dev_mag_parameters(model, map):
def get_dev_mag_parameters(model, param):
"""Wrapper for getting the deviance magnitude parameters for the means
and covariances."""
if map == "means":
if param == "means":
return get_means_dev_mag_parameters(model)
elif map == "covs":
elif param == "covs":
return get_covs_dev_mag_parameters(model)
else:
raise ValueError("map must be either 'means' or 'covs'")
raise ValueError("param must be either 'means' or 'covs'")


def get_dev_mag(model, map):
def get_dev_mag(model, param):
"""Getting the deviance magnitude.
Parameters
----------
model : osl_dynamics.models.*.Model.model
The model. * must be :code:`hive` or :code:`dive`.
map : str
The map. Must be either :code:`'means'` or :code:`'covs'`.
param : str
The param. Must be either :code:`'means'` or :code:`'covs'`.
Returns
-------
dev_mag : np.ndarray
The deviance magnitude. Shape is (n_sessions, n_states, 1).
"""
if map == "means":
if param == "means":
alpha, beta = get_means_dev_mag_parameters(model)
dev_mag_layer = model.get_layer("means_dev_mag")
elif map == "covs":
elif param == "covs":
alpha, beta = get_covs_dev_mag_parameters(model)
dev_mag_layer = model.get_layer("covs_dev_mag")
else:
raise ValueError("map must be either 'means' or 'covs'")
raise ValueError("param must be either 'means' or 'covs'")

n_sessions = alpha.shape[0]
dev_mag = dev_mag_layer([alpha, beta, np.arange(n_sessions)[..., None]])
return dev_mag.numpy()


def get_dev_map(model, map, session_labels):
def get_dev_map(model, param, session_labels):
"""Get the deviance map.
Parameters
----------
model : osl_dynamics.models.*.Model.model
The model. * must be :code:`hive` or :code:`dive`.
map : str
The map to use. Either :code:`"means"` or :code:`"covs"`.
param : str
The param to use. Either :code:`"means"` or :code:`"covs"`.
embeddings : np.ndarray, optional
Input embeddings. If :code:`None`, they are retrieved from
the model. Shape is (n_sessions, embeddings_dim).
Expand All @@ -618,21 +619,21 @@ def get_dev_map(model, map, session_labels):
-------
dev_map : np.ndarray
The deviance map.
If :code:`map="means"`, shape is (n_sessions, n_states, n_channels).
If :code:`map="covs"`, shape is (n_sessions, n_states,
If :code:`param="means"`, shape is (n_sessions, n_states, n_channels).
If :code:`param="covs"`, shape is (n_sessions, n_states,
n_channels * (n_channels + 1) // 2).
"""
concat_embeddings = get_concatenated_embeddings(model, map, session_labels)
if map == "means":
concat_embeddings = get_concatenated_embeddings(model, param, session_labels)
if param == "means":
dev_decoder_layer = model.get_layer("means_dev_decoder")
dev_map_layer = model.get_layer("means_dev_map")
norm_dev_map_layer = model.get_layer("norm_means_dev_map")
elif map == "covs":
elif param == "covs":
dev_decoder_layer = model.get_layer("covs_dev_decoder")
dev_map_layer = model.get_layer("covs_dev_map")
norm_dev_map_layer = model.get_layer("norm_covs_dev_map")
else:
raise ValueError("map must be either 'means' or 'covs'")
raise ValueError("param must be either 'means' or 'covs'")
dev_decoder = dev_decoder_layer(concat_embeddings)
dev_map = dev_map_layer(dev_decoder)
norm_dev_map = norm_dev_map_layer(dev_map)
Expand All @@ -655,12 +656,8 @@ def get_session_dev(
Whether the mean is learnt.
learn_covariances : bool
Whether the covariances are learnt.
embeddings : np.ndarray, optional
Input embeddings. Shape is (n_sessions, embeddings_dim).
If :code:`None`, then the embeddings are retrieved from the model.
n_neighbours : int, optional
The number of nearest neighbours if :code:`embedding` is not
:code:`None`.
session_labels : List[osl_dynamics.data.SessionLabel]
List of session labels.
Returns
-------
Expand Down Expand Up @@ -706,6 +703,8 @@ def get_session_means_covariances(
Whether the mean is learnt.
learn_covariances : bool
Whether the covariances are learnt.
session_labels : List[osl_dynamics.data.SessionLabel]
List of session labels.
Returns
-------
Expand All @@ -729,31 +728,37 @@ def get_session_means_covariances(
return mu.numpy(), D.numpy()


def get_nearest_neighbours(model, embeddings, n_neighbours):
"""Get the indices of the nearest neighours in the embedding space.
def generate_covariances(model, session_labels):
"""Generate covariances from the generative model.
Parameters
----------
model : osl_dynamics.models.*.Model.model
The model. * must be :code:`hive` or :code:`dive`.
embeddings : np.ndarray
Input embeddings. Shape is (n_sessions, embeddings_dim).
n_neighbours : int
The number of nearest neighbours.
session_labels : List[osl_dynamics.data.SessionLabel]
List of session labels.
Returns
-------
nearest_neighbours : np.ndarray
The indices of the nearest neighbours.
Shape is (n_sessions, n_neighbours).
covs : np.ndarray
The covariances. Shape is (n_sessions, n_states, n_channels, n_channels)
or (n_states, n_channels, n_channels).
"""
model_embeddings = get_embeddings(model)
distances = np.linalg.norm(
np.expand_dims(embeddings, axis=1) - np.expand_dims(model_embeddings, axis=0),
axis=-1,
)

# Sort distances and get indices of nearest neighbours
sorted_distances = np.argsort(distances, axis=1)
nearest_neighbours = sorted_distances[:, :n_neighbours]
return nearest_neighbours
dev_map = get_dev_map(model, "covs", session_labels)
concat_embeddings = get_concatenated_embeddings(model, "covs", session_labels)

covs_dev_decoder_layer = model.get_layer("covs_dev_decoder")
dev_mag_mod_layer = model.get_layer("covs_dev_mag_mod_beta")
dev_mag_mod = 1 / dev_mag_mod_layer(covs_dev_decoder_layer(concat_embeddings))

# Generate deviations
dev_layer = model.get_layer("covs_dev")
dev = dev_layer([dev_mag_mod, dev_map])

# Generate covariances
group_covs = get_observation_model_parameter(model, "group_covs")
covs_layer = model.get_layer("session_covs")
covs = np.squeeze(covs_layer([group_covs, dev]).numpy())

return covs

0 comments on commit 1f7ed56

Please sign in to comment.