Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SC-DyNeMo to models. #218

Merged
merged 33 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
8c6a74a
feat: Add DampedOscillatorCovarianceMatricesLayer
evanr70 Jan 25, 2024
ee468b9
feat: Add SC-DyNeMo model
evanr70 Jan 25, 2024
e7da554
fix typo and add retrieval functions
evanr70 Jan 25, 2024
2045e6a
Add dataclass decorator to Config class
evanr70 Jan 25, 2024
95384ea
Fix initialization of Config class in sc_dynemo.py
evanr70 Jan 25, 2024
81e01ac
Fix method calls in sc_dynemo.py
evanr70 Jan 25, 2024
3cefdba
refact: rename get_parameters ➡️ get_oscillator_parameters
evanr70 Jan 26, 2024
96022ee
refact: name returns to appease the machine 🤖
evanr70 Jan 26, 2024
b321b82
fix: typo in auto-covariance function equation
evanr70 Jan 26, 2024
c86965f
fix: typo in method name
evanr70 Jan 26, 2024
1320477
Refactor damped oscillator covariance layer
evanr70 Jan 26, 2024
006cc3c
fix: update functions to match refact
evanr70 Jan 26, 2024
0186aa8
Refact: formatting.
cgohil8 Jan 27, 2024
06b554f
Docstrings.
cgohil8 Jan 27, 2024
d053bd7
New separate module for custom errors.
cgohil8 Jan 27, 2024
ade150f
WIP: added SC-DyNeMo example.
cgohil8 Jan 28, 2024
fd22ed3
Reformatted with newest black version.
cgohil8 Jan 28, 2024
1a490be
added initializers and self.layers for resetting weights
evanr70 Jan 29, 2024
70e92d6
Refact: Avoid flake8 warning of undefined variable in evidence method.
RukuangHuang Jan 31, 2024
6258b31
Refact: don't load memmaps by default in Data class.
cgohil8 Feb 1, 2024
3f2ef1d
Add pushes to the main branch to the workflow.
evanr70 Feb 1, 2024
3fa5f9e
fix: actions/checkout@v4
evanr70 Feb 1, 2024
adb7e2f
fmt: black
evanr70 Feb 1, 2024
54a9712
Refact: avoid flake8 warning of undefined variable in hmm_poi evidenc…
RukuangHuang Feb 1, 2024
218025e
feat: Add constraint parameter to LearnableTensorLayer
evanr70 Feb 1, 2024
663f04e
docs: add constraint description
evanr70 Feb 1, 2024
4620e6f
Enhance: Individual plots now deleted if combined = True for power an…
RukuangHuang May 15, 2024
d6f9cdb
Feat: Function to re-normalise mixing coefs with correlations.
RukuangHuang May 16, 2024
d915f7d
Merge branch 'main' into feat/damped-oscillator
cgohil8 May 21, 2024
de98d38
Merge branch 'main' into feat/damped-oscillator
cgohil8 Jul 11, 2024
c975156
Minor tweaks.
cgohil8 Jul 11, 2024
1bf6c43
Updated example script.
cgohil8 Jul 11, 2024
3932f79
Added WIP note.
cgohil8 Jul 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions examples/simulation/sc-dynemo_hmm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""Example script for training Single-Channel DyNeMo on simulated data.

Note, this model is a work in progress and has not been fully validated.

This example script achieves a dice ~= 0.8.
"""

print("Importing packages")

import numpy as np

from osl_dynamics.data import Data
from osl_dynamics.inference import modes, metrics
from osl_dynamics.simulation import HMM
from osl_dynamics.models.sc_dynemo import Config, Model
from osl_dynamics.utils import plotting

# Number of time points and sampling frequency
n_samples = 25600
sampling_frequency = 100

# Simulate a state time course
hmm = HMM(
trans_prob="sequence",
stay_prob=0.9,
n_states=3,
)
stc = hmm.generate_states(n_samples)

# Simulate observed data
t = np.arange(n_samples) / sampling_frequency
x = np.random.normal(0, 0.02, size=n_samples)

# State 1 - theta bursts
indices = stc[:, 0] == 1
phi = np.random.uniform(0, 2 * np.pi)
x[indices] += 2 * np.sin(2 * np.pi * 5 * t[indices] + phi)

# State 2 - alpha bursts
indices = stc[:, 1] == 1
phi = np.random.uniform(0, 2 * np.pi)
x[indices] += 2 * np.sin(2 * np.pi * 10 * t[indices] + phi)

# State 3 - beta bursts
indices = stc[:, 2] == 1
phi = np.random.uniform(0, 2 * np.pi)
x[indices] += np.sin(2 * np.pi * 20 * t[indices] + phi)

# Create Data object and prepare data
data = Data(x)
data.tde(n_embeddings=5)
data.standardize()

# Build model
config = Config(
n_modes=3,
n_channels=data.n_channels,
sequence_length=100,
inference_n_units=32,
inference_normalization="layer",
model_n_units=32,
model_normalization="layer",
learn_alpha_temperature=True,
initial_alpha_temperature=1.0,
learn_means=False,
learn_covariances=True,
learn_oscillator_amplitude=True,
oscillator_damping_limit=20,
oscillator_frequency_limit=(1, 30),
sampling_frequency=sampling_frequency,
do_kl_annealing=True,
kl_annealing_curve="tanh",
kl_annealing_sharpness=10,
n_kl_annealing_epochs=50,
batch_size=16,
learning_rate=0.001,
n_epochs=100,
)
model = Model(config)
model.summary()

# Train model
model.fit(data)

# Get inferred mixing coefficients and hard classify
alp = model.get_alpha(data)
alp = modes.argmax_time_courses(alp)

# Trim the simulate state time courses to match the inferred alphas
stc = stc[data.n_embeddings // 2 : alp.shape[0]]

# Match modes to simulation
stc, alp = modes.match_modes(stc, alp)

# Plot alphas
plotting.plot_alpha(stc, alp, n_samples=2000, filename="alpha.png")

# Print dice
dice = metrics.dice_coefficient(stc, alp)
print("Dice coefficient:", dice)
165 changes: 162 additions & 3 deletions osl_dynamics/inference/layers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
"""Custom Tensorflow layers.

"""
"""Custom Tensorflow layers."""

import sys

Expand Down Expand Up @@ -449,6 +447,8 @@ class LearnableTensorLayer(layers.Layer):
Regularizer for the tensor. Must be from `inference.regularizers
<https://osl-dynamics.readthedocs.io/en/latest/autoapi/osl_dynamics\
/inference/regularizers/index.html>`_.
constraint : tf.keras.constraints.Constraint, optional
Constraint for the tensor. Limits the values the weights can take.
kwargs : keyword arguments, optional
Keyword arguments to pass to the base class.
"""
Expand All @@ -460,6 +460,7 @@ def __init__(
initializer=None,
initial_value=None,
regularizer=None,
constraint=None,
**kwargs,
):
super().__init__(**kwargs)
Expand Down Expand Up @@ -498,6 +499,9 @@ def __init__(
# This should be a function of the tensor that returns a float
self.regularizer = regularizer

# Constraint for the tensor
self.constraint = constraint

def add_regularization(self, tensor, static_loss_scaling_factor):
# Calculate the regularisation from the tensor
reg = self.regularizer(tensor)
Expand All @@ -515,6 +519,7 @@ def build(self, input_shape):
dtype=tf.float32,
initializer=self.tensor_initializer,
trainable=self.learn,
constraint=self.constraint,
)
self.built = True

Expand Down Expand Up @@ -898,6 +903,160 @@ def call(self, inputs, **kwargs):
return tf.linalg.diag(diagonals)


class DampedOscillatorLayer(layers.Layer):
    """Layer to learn a set of damped oscillators.

    Each oscillator has a learnable damping, frequency and (optionally)
    amplitude, and is evaluated on a time axis of :code:`m` samples as
    :code:`amplitude * exp(-damping * tau) * cos(2 * pi * frequency * tau)`.

    Parameters
    ----------
    n : int
        Number of oscillators.
    m : int
        Number of elements.
    sampling_frequency : float
        Sampling frequency in Hz.
    damping_limit : float
        Upper limit for the damping parameter.
        Values are clipped to [0, damping_limit].
    frequency_limit : tuple
        Limits for the frequency parameter.
        Values are clipped to [frequency_limit[0], frequency_limit[1]].
        Upper limit should not be higher than the Nyquist frequency.
    learn_amplitude : bool
        Should the amplitudes be learnable?
        If not, they will be fixed to 1.0.
        Overridden if the general `learn` argument is False.
    learn : bool
        Should the oscillators be learnable?
    kwargs : keyword arguments, optional
        Keyword arguments to pass to the base class.
    """

    def __init__(
        self,
        n,
        m,
        sampling_frequency,
        damping_limit,
        frequency_limit,
        learn_amplitude,
        learn,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.sampling_frequency = sampling_frequency
        self.damping_limit = damping_limit
        # Keep the frequency limits so call() can enforce them
        self.frequency_limit = frequency_limit

        # Time axis with shape (1, m), in seconds
        self.tau = (
            tf.expand_dims(tf.range(0, m, dtype=tf.float32), axis=0)
            / sampling_frequency
        )

        self.damping = LearnableTensorLayer(
            shape=(n, 1),
            learn=learn,
            initializer=initializers.Constant(0.5),
            name=self.name + "_damping",
        )

        self.frequency = LearnableTensorLayer(
            shape=(n, 1),
            learn=learn,
            initializer=initializers.RandomUniform(
                minval=frequency_limit[0],
                maxval=frequency_limit[1],
            ),
            name=self.name + "_frequency",
        )

        self.amplitude = LearnableTensorLayer(
            shape=(n, 1),
            learn=learn and learn_amplitude,
            initializer=initializers.Constant(1.0),
            name=self.name + "_amplitude",
        )

        # Kept in a list so callers can reset/re-initialize the weights
        self.layers = [self.damping, self.frequency, self.amplitude]

    def call(self, inputs, **kwargs):
        """Calculate damped oscillator.

        Note
        ----
        The :code:`inputs` passed to this method are not used.
        """
        damping = self.damping(inputs, **kwargs)
        damping = tf.clip_by_value(damping, 0, self.damping_limit)
        frequency = self.frequency(inputs, **kwargs)
        # Fix: clip to the user-specified frequency_limit rather than the
        # hard-coded [1, Nyquist] range, so the limits documented above are
        # respected during training as well as at initialization.
        frequency = tf.clip_by_value(
            frequency, self.frequency_limit[0], self.frequency_limit[1]
        )
        omega = 2 * np.pi * frequency
        amplitude = self.amplitude(inputs, **kwargs)
        return amplitude * tf.exp(-damping * self.tau) * tf.cos(omega * self.tau)


class DampedOscillatorCovarianceMatricesLayer(layers.Layer):
    """Layer to learn a set of damped oscillator covariances.

    Each covariance is a symmetric Toeplitz matrix whose first row/column is
    the auto-covariance function produced by a :class:`DampedOscillatorLayer`.

    Parameters
    ----------
    n : int
        Number of matrices.
    m : int
        Number of rows/columns.
    sampling_frequency : float
        Sampling frequency in Hz.
    damping_limit : float
        Upper limit for the damping parameter.
        Values are clipped to [0, damping_limit].
    frequency_limit : tuple[float, float]
        Limits for the frequency parameter.
        Upper limit should not be higher than the Nyquist frequency.
    learn_amplitude : bool
        Should the amplitudes be learnable?
        If not, they will be fixed to 1.0.
        Overridden if the general `learn` argument is False.
    learn : bool
        Should the matrices be learnable?
    kwargs : keyword arguments, optional
        Keyword arguments to pass to the base class.
    """

    def __init__(
        self,
        n,
        m,
        sampling_frequency,
        damping_limit,
        frequency_limit,
        learn_amplitude,
        learn,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # All learnable parameters live in the underlying oscillator layer
        self.oscillator_layer = DampedOscillatorLayer(
            n=n,
            m=m,
            sampling_frequency=sampling_frequency,
            damping_limit=damping_limit,
            frequency_limit=frequency_limit,
            learn_amplitude=learn_amplitude,
            learn=learn,
        )

        # Kept in a list so callers can reset/re-initialize the weights
        self.layers = [self.oscillator_layer]

    def call(self, inputs, **kwargs):
        """Retrieve the covariance matrices.

        Note
        ----
        The :code:`inputs` passed to this method are not used.
        """
        # Evaluate the damped oscillator auto-covariance functions, then use
        # each one as the first row and column of a Toeplitz covariance matrix
        acf = self.oscillator_layer(inputs, **kwargs)
        toeplitz_op = tf.linalg.LinearOperatorToeplitz(row=acf, col=acf)
        return toeplitz_op.to_dense()


class MatrixLayer(layers.Layer):
"""Layer to learn a matrix.

Expand Down
Loading