Skip to content

Commit

Permalink
Multiple embeddings (#235)
Browse files Browse the repository at this point in the history
* Refactor: session_label as a dataclass for later development.
* Feat: draft hierarchical embedding.
* Fix: fixed wrapper and docstrings.
* Fix: Typo.
  • Loading branch information
RukuangHuang committed Mar 25, 2024
1 parent f3e990f commit ff7154d
Show file tree
Hide file tree
Showing 11 changed files with 552 additions and 303 deletions.
25 changes: 14 additions & 11 deletions examples/simulation/dive_hmm-mvn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import os
import numpy as np
from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

from osl_dynamics import data, simulation
from osl_dynamics.inference import metrics, modes, tf_ops
Expand Down Expand Up @@ -72,10 +72,12 @@
random_seed=1234,
)
sim.standardize()
training_data = data.Data(
sim.time_series,
session_labels={"session_id": np.arange(config.n_sessions)},
training_data = data.Data(sim.time_series)
training_data.add_session_labels(
"session_id", np.arange(config.n_sessions), "categorical"
)
training_data.add_session_labels("group_id", sim.assigned_groups, "categorical")
config.session_labels = training_data.get_session_labels()

# Build model
model = Model(config)
Expand All @@ -93,7 +95,7 @@
# Full training
print("Training model")
history = model.fit(training_data)

model.save("tmp")
# Free energy = Log Likelihood - KL Divergence
free_energy = model.free_energy(training_data)
print(f"Free energy: {free_energy}")
Expand All @@ -112,9 +114,10 @@

# Plot the simulated and inferred embeddings with group labels
sim_embeddings = sim.embeddings
inf_embeddings = model.get_embeddings()
inf_embeddings -= np.mean(inf_embeddings, axis=0)
inf_embeddings /= np.std(inf_embeddings, axis=0)
inf_embeddings = model.get_summed_embeddings()
lda_inf_embeddings = LinearDiscriminantAnalysis(n_components=2).fit_transform(
inf_embeddings, sim.assigned_groups
)
group_masks = [sim.assigned_groups == i for i in range(sim.n_groups)]

fig, axes = plotting.create_figure(1, 2, figsize=(10, 5))
Expand All @@ -132,8 +135,8 @@

# Plot the LDA-projected embeddings to visualise the groups
plotting.plot_scatter(
[inf_embeddings[group_mask, 0] for group_mask in group_masks],
[inf_embeddings[group_mask, 1] for group_mask in group_masks],
[lda_inf_embeddings[group_mask, 0] for group_mask in group_masks],
[lda_inf_embeddings[group_mask, 1] for group_mask in group_masks],
x_label="dim1",
y_label="dim2",
annotate=[
Expand All @@ -147,4 +150,4 @@
plotting.save(fig, filename="figures/embeddings.png")

# Delete temporary directory
training_data.delete_dir()
# training_data.delete_dir()
24 changes: 14 additions & 10 deletions examples/simulation/hive_hmm-mvn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import os
import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

from osl_dynamics import data, simulation
from osl_dynamics.inference import metrics, modes, tf_ops
Expand All @@ -32,12 +33,12 @@
dev_activation="tanh",
dev_normalization="layer",
dev_regularizer="l1",
dev_regularizer_factor=10,
dev_regularizer_factor=0.01,
learn_means=False,
learn_covariances=True,
batch_size=64,
learning_rate=0.005,
lr_decay=0.05,
lr_decay=0.1,
n_epochs=40,
learn_trans_prob=True,
do_kl_annealing=True,
Expand Down Expand Up @@ -68,10 +69,12 @@
sim_stc = np.concatenate(sim.mode_time_course)

# Create training dataset
training_data = data.Data(
sim.time_series,
session_labels={"session_id": np.arange(config.n_sessions)},
training_data = data.Data(sim.time_series)
training_data.add_session_labels(
"session_id", np.arange(config.n_sessions), "categorical"
)
training_data.add_session_labels("group_id", sim.assigned_groups, "categorical")
config.session_labels = training_data.get_session_labels()

# Build model
model = Model(config)
Expand Down Expand Up @@ -128,9 +131,10 @@
print("Fractional occupancies (Inferred):", modes.fractional_occupancies(inf_stc))

sim_embeddings = sim.embeddings
inf_embeddings = model.get_embeddings()
inf_embeddings -= np.mean(inf_embeddings, axis=0)
inf_embeddings /= np.std(inf_embeddings, axis=0)
inf_embeddings = model.get_summed_embeddings()
lda_inf_embeddings = LinearDiscriminantAnalysis(n_components=2).fit_transform(
inf_embeddings, sim.assigned_groups
)
group_masks = [sim.assigned_groups == i for i in range(sim.n_groups)]

fig, axes = plotting.create_figure(1, 2, figsize=(10, 5))
Expand All @@ -147,8 +151,8 @@
)

plotting.plot_scatter(
[inf_embeddings[group_mask, 0] for group_mask in group_masks],
[inf_embeddings[group_mask, 1] for group_mask in group_masks],
[lda_inf_embeddings[group_mask, 0] for group_mask in group_masks],
[lda_inf_embeddings[group_mask, 1] for group_mask in group_masks],
x_label="dim_1",
y_label="dim_2",
annotate=[
Expand Down
10 changes: 8 additions & 2 deletions osl_dynamics/config_api/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,9 @@ def train_hive(
if data is None:
raise ValueError("data must be passed.")

if not data.get_session_labels():
data.add_session_labels("session_id", np.arange(data.n_sessions), "categorical")

from osl_dynamics.models import hive

init_kwargs = {} if init_kwargs is None else init_kwargs
Expand Down Expand Up @@ -410,6 +413,7 @@ def train_hive(
"kl_annealing_curve": "tanh",
"kl_annealing_sharpness": 10,
"n_kl_annealing_epochs": 15,
"session_labels": data.get_session_labels(),
}
config_kwargs = override_dict_defaults(default_config_kwargs, config_kwargs)

Expand Down Expand Up @@ -461,15 +465,17 @@ def train_hive(
alpha = model.get_alpha(data)
means, covs = model.get_means_covariances()
session_means, session_covs = model.get_session_means_covariances()
embeddings = model.get_embeddings()
summed_embeddings = model.get_summed_embeddings()
embedding_weights = model.get_embedding_weights()

# Save inferred parameters
save(f"{inf_params_dir}/alp.pkl", alpha)
save(f"{inf_params_dir}/means.npy", means)
save(f"{inf_params_dir}/covs.npy", covs)
save(f"{inf_params_dir}/session_means.npy", session_means)
save(f"{inf_params_dir}/session_covs.npy", session_covs)
save(f"{inf_params_dir}/embeddings.npy", embeddings)
save(f"{inf_params_dir}/summed_embeddings.npy", summed_embeddings)
save(f"{inf_params_dir}/embedding_weights.npy", embedding_weights)


def get_inf_params(data, output_dir, observation_model_only=False):
Expand Down
2 changes: 1 addition & 1 deletion osl_dynamics/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@
/tutorials_build/data_preparation.html>`_
"""

from osl_dynamics.data.base import Data, SessionLabels, load_tfrecord_dataset

# Public API of this subpackage. SessionLabels is exported so extra session
# labels can be constructed directly from `osl_dynamics.data`; without it the
# re-import above is invisible to `import *` and flagged as unused by linters.
__all__ = ["Data", "SessionLabels"]
68 changes: 56 additions & 12 deletions osl_dynamics/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from contextlib import contextmanager
from shutil import rmtree
import os
from dataclasses import dataclass

import numpy as np
from pqdm.threads import pqdm
Expand Down Expand Up @@ -85,7 +86,7 @@ class Data:
use_tfrecord : bool, optional
Should we save the data as a TensorFlow Record? This is recommended for
training on large datasets. Default is :code:`False`.
session_labels : dict, optional
session_labels : list of SessionLabels, optional
Extra session labels.
n_jobs : int, optional
Number of processes to load the data in parallel.
Expand Down Expand Up @@ -122,7 +123,6 @@ def __init__(
self.buffer_size = buffer_size
self.use_tfrecord = use_tfrecord
self.n_jobs = n_jobs
self.session_labels = dict()

# Validate inputs
self.inputs = rw.validate_inputs(inputs)
Expand Down Expand Up @@ -161,9 +161,8 @@ def __init__(
self.keep = list(range(self.n_sessions))

# Extra session labels
if session_labels is not None:
for label_name, label_values in session_labels.items():
self.add_session_labels(label_name, label_values)
if session_labels is None:
self.session_labels = []

def __iter__(self):
return iter(self.arrays)
Expand Down Expand Up @@ -1044,8 +1043,10 @@ def _create_data_dict(self, i, array):

# Add other session labels
placeholder = np.zeros(array.shape[0], dtype=np.float32)
for name, value in self.session_labels.items():
data[name] = placeholder + value[i]
for session_label in self.session_labels:
label_name = session_label.name
label_values = session_label.values
data[label_name] = placeholder + label_values[i]

return data

Expand Down Expand Up @@ -1315,7 +1316,7 @@ def _save_tfrecord(i, filepath):
"sequence_length": self.sequence_length,
"n_channels": self.n_channels,
"step_size": self.step_size,
"session_labels": list(self.session_labels.keys()),
"session_labels": [label.name for label in self.session_labels],
"n_sessions": self.n_sessions,
}
misc.save(f"{tfrecord_dir}/tfrecord_config.pkl", tfrecord_config)
Expand Down Expand Up @@ -1384,7 +1385,7 @@ def tfrecord_dataset(
keep=self.keep,
)

def add_session_labels(self, label_name, label_values):
def add_session_labels(self, label_name, label_values, label_type):
    """Attach a label for each session as an extra data channel.

    Parameters
    ----------
    label_name : str
        Name of the new channel.
    label_values : np.ndarray
        Labels for each session.
    label_type : str
        Type of label, either "categorical" or "continuous".

    Raises
    ------
    ValueError
        If :code:`label_values` is not 1D or its length does not match the
        number of sessions.
    """
    # Exactly one label value per session, supplied as a flat array.
    if label_values.ndim != 1:
        raise ValueError("label_values must be a 1D array.")
    if label_values.shape[0] != self.n_sessions:
        raise ValueError(
            "label_values must have the same length as the number of sessions."
        )

    new_label = SessionLabels(label_name, label_values, label_type)
    self.session_labels.append(new_label)

def get_session_labels(self):
    """Return the extra session labels attached to this dataset.

    Returns
    -------
    session_labels : list of SessionLabels
        Session labels previously added via :code:`add_session_labels`.
    """
    return self.session_labels

def save_preparation(self, output_dir="."):
"""Save a pickle file containing preparation settings.
Expand Down Expand Up @@ -1693,3 +1704,36 @@ def _parse_example(example):
)

return training_datasets, validation_datasets


@dataclass
class SessionLabels:
    """Container for one per-session label.

    Parameters
    ----------
    name : str
        Name of the session label.
    values : np.ndarray
        Value for each session. Must be a 1D array of numbers.
    label_type : str
        Type of the session label. Options are "categorical" and "continuous".
    """

    name: str
    values: np.ndarray
    label_type: str

    def __post_init__(self):
        # Only two label kinds are supported.
        if self.label_type not in ("categorical", "continuous"):
            raise ValueError("label_type must be 'categorical' or 'continuous'.")

        # One value per session, i.e. a flat array.
        if self.values.ndim != 1:
            raise ValueError("values must be a 1D array.")

        if self.label_type == "continuous":
            self.values = self.values.astype(np.float32)
            self.n_classes = None
        else:
            # Categorical labels are stored as integer codes; n_classes is
            # the number of distinct categories present in the values.
            self.values = self.values.astype(np.int32)
            self.n_classes = len(np.unique(self.values))
Loading

0 comments on commit ff7154d

Please sign in to comment.