Skip to content

Commit

Permalink
Fixes:
Browse files Browse the repository at this point in the history
* Bug in MSess_HMM_MVN in special cases.
* Corrected dimension when using unit norm embeddings.
  • Loading branch information
cgohil8 committed Apr 19, 2024
1 parent 3ead063 commit 49d7b06
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 9 deletions.
4 changes: 2 additions & 2 deletions examples/simulation/hive_hmm-mvn.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
n_channels=20,
sequence_length=200,
n_sessions=100,
embeddings_dim=2,
embeddings_dim=5,
spatial_embeddings_dim=2,
dev_n_layers=5,
dev_n_units=32,
Expand Down Expand Up @@ -90,7 +90,7 @@

# Model initialization
model.random_state_time_course_initialization(
training_data, n_epochs=3, n_init=5, take=1
training_data, n_epochs=5, n_init=5, take=1
)

# Full model training
Expand Down
3 changes: 3 additions & 0 deletions osl_dynamics/inference/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1989,6 +1989,9 @@ class EmbeddingLayer(layers.Layer):

def __init__(self, input_dim, output_dim, unit_norm, **kwargs):
super().__init__(**kwargs)
if unit_norm:
output_dim = output_dim - 1

self.embedding_layer = layers.Embedding(
input_dim=input_dim,
output_dim=output_dim,
Expand Down
10 changes: 5 additions & 5 deletions osl_dynamics/simulation/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,10 +583,6 @@ def __init__(
if n_states is None:
n_states = n_modes

# Construct trans_prob for each session
if isinstance(trans_prob, str) or trans_prob is None:
trans_prob = [trans_prob] * n_sessions

# Observation model
self.obs_mod = MSess_MVN(
session_means=session_means,
Expand All @@ -607,6 +603,10 @@ def __init__(
self.n_channels = self.obs_mod.n_channels
self.n_sessions = self.obs_mod.n_sessions

# Construct trans_prob for each session
if isinstance(trans_prob, str) or trans_prob is None:
trans_prob = [trans_prob] * self.n_sessions

# Vary the stay probability for each session
if stay_prob is not None:
session_stay_prob = np.random.normal(
Expand All @@ -618,7 +618,7 @@ def __init__(
session_stay_prob = np.minimum(session_stay_prob, 1)
session_stay_prob = np.maximum(session_stay_prob, 0)
else:
session_stay_prob = [stay_prob] * n_sessions
session_stay_prob = [stay_prob] * self.n_sessions

# Initialise base class
super().__init__(n_samples=n_samples)
Expand Down
5 changes: 3 additions & 2 deletions osl_dynamics/simulation/mvn.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,8 +485,9 @@ def __init__(
self.n_modes = session_covariances.shape[1]
self.n_channels = session_covariances.shape[2]

self.validate_embedding_parameters()
self.create_embeddings()
if not session_means == "zero":
self.validate_embedding_parameters()
self.create_embeddings()

self.group_means = super().create_means(session_means)
self.session_means = self.create_session_means(session_means)
Expand Down

0 comments on commit 49d7b06

Please sign in to comment.