Skip to content

Commit

Permalink
Feat: Allow embedding vectors to be passed in for simulation.
Browse files Browse the repository at this point in the history
  • Loading branch information
RukuangHuang committed May 13, 2024
1 parent 9c3e12e commit 59d5b81
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 48 deletions.
5 changes: 5 additions & 0 deletions osl_dynamics/simulation/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,9 @@ class MSess_HMM_MVN(Simulation):
Number of channels.
n_covariances_act : int, optional
Number of iterations to add activations to covariance matrices.
embedding_vectors : np.ndarray, optional
Embedding vectors for each state, shape should be
(n_states, embeddings_dim).
n_sessions : int, optional
Number of sessions.
embeddings_dim : int
Expand Down Expand Up @@ -570,6 +573,7 @@ def __init__(
n_modes=None,
n_channels=None,
n_covariances_act=1,
embedding_vectors=None,
n_sessions=None,
embeddings_dim=None,
spatial_embeddings_dim=None,
Expand All @@ -590,6 +594,7 @@ def __init__(
n_modes=n_states,
n_channels=n_channels,
n_covariances_act=n_covariances_act,
embedding_vectors=embedding_vectors,
n_sessions=n_sessions,
embeddings_dim=embeddings_dim,
spatial_embeddings_dim=spatial_embeddings_dim,
Expand Down
110 changes: 62 additions & 48 deletions osl_dynamics/simulation/mvn.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,9 @@ class MSess_MVN(MVN):
Number of channels.
n_covariances_act : int, optional
Number of iterations to add activations to covariance matrices.
embedding_vectors : np.ndarray, optional
Embedding vectors for each session, shape should be
(n_sessions, embeddings_dim).
n_sessions : int, optional
Number of sessions.
embeddings_dim : int, optional
Expand All @@ -400,6 +403,7 @@ def __init__(
n_modes=None,
n_channels=None,
n_covariances_act=1,
embedding_vectors=None,
n_sessions=None,
embeddings_dim=None,
spatial_embeddings_dim=None,
Expand All @@ -416,6 +420,12 @@ def __init__(
self.n_groups = n_groups
self.between_group_scale = between_group_scale

if embedding_vectors is not None:
n_sessions = embedding_vectors.shape[0]
self.n_sessions = n_sessions
embeddings_dim = embedding_vectors.shape[1]
self.embeddings_dim = embeddings_dim

# Both the session means and covariances were passed as numpy arrays
if isinstance(session_means, np.ndarray) and isinstance(
session_covariances, np.ndarray
Expand Down Expand Up @@ -468,8 +478,8 @@ def __init__(
self.n_modes = session_means.shape[1]
self.n_channels = session_means.shape[2]

self.validate_embedding_parameters()
self.create_embeddings()
self.validate_embedding_parameters(embedding_vectors)
self.create_embeddings(embedding_vectors)

self.group_means = None
self.session_means = session_means
Expand All @@ -486,8 +496,8 @@ def __init__(
self.n_channels = session_covariances.shape[2]

if not session_means == "zero":
self.validate_embedding_parameters()
self.create_embeddings()
self.validate_embedding_parameters(embedding_vectors)
self.create_embeddings(embedding_vectors)

self.group_means = super().create_means(session_means)
self.session_means = self.create_session_means(session_means)
Expand All @@ -509,60 +519,64 @@ def __init__(
self.n_modes = n_modes
self.n_channels = n_channels

self.validate_embedding_parameters()
self.create_embeddings()
self.validate_embedding_parameters(embedding_vectors)
self.create_embeddings(embedding_vectors)

self.group_means = super().create_means(session_means)
self.session_means = self.create_session_means(session_means)

self.group_covariances = super().create_covariances(session_covariances)
self.session_covariances = self.create_session_covariances()

def validate_embedding_parameters(self):
if self.embeddings_dim is None:
raise ValueError(
"Session means or covariances not passed, please pass "
"'embeddings_dim'."
)
if self.spatial_embeddings_dim is None:
raise ValueError(
"Session means or covariances not passed, please pass "
"'spatial_embeddings_dim'."
)
if self.embeddings_scale is None:
raise ValueError(
"Session means or covariances not passed, please pass "
"'embeddings_scale'."
)
if self.n_groups is None:
raise ValueError(
"Session means or covariances not passed, please pass 'n_groups'."
)
if self.between_group_scale is None:
raise ValueError(
"Session means or covariances not passed, please pass "
"'between_group_scale'."
)

def create_embeddings(self):
# Assign groups to sessions
assigned_groups = np.random.choice(self.n_groups, self.n_sessions)
self.group_centroids = np.random.normal(
scale=self.between_group_scale,
size=[self.n_groups, self.embeddings_dim],
)
def validate_embedding_parameters(self, embedding_vectors):
if embedding_vectors is None:
if self.embeddings_dim is None:
raise ValueError(
"Session means or covariances not passed, please pass "
"'embeddings_dim'."
)
if self.spatial_embeddings_dim is None:
raise ValueError(
"Session means or covariances not passed, please pass "
"'spatial_embeddings_dim'."
)
if self.embeddings_scale is None:
raise ValueError(
"Session means or covariances not passed, please pass "
"'embeddings_scale'."
)
if self.n_groups is None:
raise ValueError(
"Session means or covariances not passed, please pass 'n_groups'."
)
if self.between_group_scale is None:
raise ValueError(
"Session means or covariances not passed, please pass "
"'between_group_scale'."
)

embeddings = np.zeros([self.n_sessions, self.embeddings_dim])
for i in range(self.n_groups):
group_mask = assigned_groups == i
embeddings[group_mask] = np.random.multivariate_normal(
mean=self.group_centroids[i],
cov=self.embeddings_scale * np.eye(self.embeddings_dim),
size=[np.sum(group_mask)],
def create_embeddings(self, embedding_vectors):
if embedding_vectors is None:
# Assign groups to sessions
assigned_groups = np.random.choice(self.n_groups, self.n_sessions)
self.group_centroids = np.random.normal(
scale=self.between_group_scale,
size=[self.n_groups, self.embeddings_dim],
)

self.assigned_groups = assigned_groups
self.embeddings = embeddings
embeddings = np.zeros([self.n_sessions, self.embeddings_dim])
for i in range(self.n_groups):
group_mask = assigned_groups == i
embeddings[group_mask] = np.random.multivariate_normal(
mean=self.group_centroids[i],
cov=self.embeddings_scale * np.eye(self.embeddings_dim),
size=[np.sum(group_mask)],
)

self.assigned_groups = assigned_groups
self.embeddings = embeddings
else:
self.embeddings = embedding_vectors

def create_linear_transform(self, input_dim, output_dim, scale=0.1):
linear_transform = np.random.normal(
Expand Down

0 comments on commit 59d5b81

Please sign in to comment.