Skip to content

Commit

Permalink
Feat: Allow extra session labels as extra channels in Data class. (#229)
Browse files Browse the repository at this point in the history
  • Loading branch information
RukuangHuang committed Mar 13, 2024
1 parent 9769f70 commit 6dd2358
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 10 deletions.
5 changes: 4 additions & 1 deletion examples/simulation/dive_hmm-mvn.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@
random_seed=1234,
)
sim.standardize()
training_data = data.Data(sim.time_series)
training_data = data.Data(
sim.time_series,
session_labels={"session_id": np.arange(config.n_sessions)},
)

# Build model
model = Model(config)
Expand Down
5 changes: 4 additions & 1 deletion examples/simulation/hive_hmm-mvn.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@
sim_stc = np.concatenate(sim.mode_time_course)

# Create training dataset
training_data = data.Data(sim.time_series)
training_data = data.Data(
sim.time_series,
session_labels={"session_id": np.arange(config.n_sessions)},
)

# Build model
model = Model(config)
Expand Down
62 changes: 54 additions & 8 deletions osl_dynamics/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ class Data:
use_tfrecord : bool, optional
Should we save the data as a TensorFlow Record? This is recommended for
training on large datasets. Default is :code:`False`.
session_labels : dict, optional
Extra session labels.
n_jobs : int, optional
Number of processes to load the data in parallel.
Default is 1, which loads data in serial.
Expand All @@ -104,6 +106,7 @@ def __init__(
store_dir="tmp",
buffer_size=100000,
use_tfrecord=False,
session_labels=None,
n_jobs=1,
):
self._identifier = id(self)
Expand All @@ -119,6 +122,7 @@ def __init__(
self.buffer_size = buffer_size
self.use_tfrecord = use_tfrecord
self.n_jobs = n_jobs
self.session_labels = dict()

# Validate inputs
self.inputs = rw.validate_inputs(inputs)
Expand Down Expand Up @@ -156,6 +160,11 @@ def __init__(
# Arrays to keep when making TensorFlow Datasets
self.keep = list(range(self.n_sessions))

# Extra session labels
if session_labels is not None:
for label_name, label_values in session_labels.items():
self.add_session_labels(label_name, label_values)

def __iter__(self):
return iter(self.arrays)

Expand Down Expand Up @@ -1016,6 +1025,30 @@ def count_sequences(self, sequence_length, step_size=None):
]
)

def _create_data_dict(self, i, array):
"""Create a dictionary of data for a single session.
Parameters
----------
i : int
Index of the session.
array : np.ndarray
Time series data for a single session.
Returns
-------
data : dict
Dictionary of data for a single session.
"""
data = {"data": array}

# Add other session labels
placeholder = np.zeros(array.shape[0], dtype=np.float32)
for name, value in self.session_labels.items():
data[name] = placeholder + value[i]

return data

def dataset(
self,
sequence_length,
Expand Down Expand Up @@ -1071,10 +1104,7 @@ def dataset(
# length
array = self.arrays[i][: n_sequences[i] * sequence_length]

# Dataset with the time series data and ID
array_tracker = np.zeros(array.shape[0], dtype=np.float32)
array_tracker = array_tracker + i
data = {"data": array, "session_id": array_tracker}
data = self._create_data_dict(i, array)

# Create dataset
dataset = dtf.create_dataset(
Expand Down Expand Up @@ -1243,10 +1273,7 @@ def _save_tfrecord(i, filepath):
# sequence length
array = self.arrays[i][: n_sequences[i] * sequence_length]

# Create a dataset with the time series data and ID
array_tracker = np.zeros(array.shape[0], dtype=np.float32)
array_tracker = array_tracker + i
data = {"data": array, "session_id": array_tracker}
data = self._create_data_dict(i, array)

# Save the dataset
dtf.save_tfrecord(
Expand Down Expand Up @@ -1399,6 +1426,25 @@ def _parse_example(example):

return training_datasets, validation_datasets

def add_session_labels(self, label_name, label_values):
"""Add session labels as a new channel to the data.
Parameters
----------
label_name : str
Name of the new channel.
label_values : np.ndarray
Labels for each session.
"""
if label_values.ndim != 1:
raise ValueError("label_values must be a 1D array.")
if len(label_values) != self.n_sessions:
raise ValueError(
"label_values must have the same length as the number of sessions."
)

self.session_labels[label_name] = label_values

def save_preparation(self, output_dir="."):
"""Save a pickle file containing preparation settings.
Expand Down

0 comments on commit 6dd2358

Please sign in to comment.