Skip to content

Commit

Permalink
Option to specify random seed (#237)
Browse files Browse the repository at this point in the history
* Refact: removed random_seed argument from simulations.
* Feat: function to set random seed.
  • Loading branch information
cgohil8 committed Mar 26, 2024
1 parent 1143323 commit a8ed6ae
Show file tree
Hide file tree
Showing 22 changed files with 48 additions and 144 deletions.
1 change: 0 additions & 1 deletion examples/simulation/dive_hmm-mvn.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@
n_groups=3,
between_group_scale=0.2,
stay_prob=0.9,
random_seed=1234,
)
sim.standardize()
training_data = data.Data(sim.time_series)
Expand Down
1 change: 0 additions & 1 deletion examples/simulation/dynemo_hmm-mvn.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
stay_prob=0.9,
means="zero",
covariances="random",
random_seed=123,
)

# Create Data object for training
Expand Down
1 change: 0 additions & 1 deletion examples/simulation/dynemo_hmm-mvn_high-n-modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
stay_prob=0.9,
means="zero",
covariances="random",
random_seed=123,
)
sim.standardize()
training_data = data.Data(sim.time_series)
Expand Down
1 change: 0 additions & 1 deletion examples/simulation/dynemo_long-range-dep1.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
observation_error=0.0,
gamma_shape=10,
gamma_scale=5,
random_seed=123,
)
sim.standardize()
training_data = data.Data(sim.time_series)
Expand Down
3 changes: 0 additions & 3 deletions examples/simulation/dynemo_long-range-dep2.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,6 @@
bottom_level_trans_probs=bottom_level_trans_probs,
means="zero",
covariances="random",
top_level_random_seed=123,
bottom_level_random_seeds=[124, 126, 127],
data_random_seed=555,
)
sim.standardize()
training_data = Data(sim.time_series)
Expand Down
1 change: 0 additions & 1 deletion examples/simulation/dynemo_soft-mixing.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
sampling_frequency=250,
means="zero",
covariances="random",
random_seed=123,
)
sim_alp = sim.mode_time_course
training_data = data.Data(sim.time_series)
Expand Down
1 change: 0 additions & 1 deletion examples/simulation/hive_hmm-mvn.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
n_groups=3,
between_group_scale=0.2,
stay_prob=0.9,
random_seed=1234,
)
sim.standardize()
sim_stc = np.concatenate(sim.mode_time_course)
Expand Down
1 change: 0 additions & 1 deletion examples/simulation/hmm-poi_hmm-poi.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
n_channels=11,
trans_prob="sequence",
stay_prob=0.9,
random_seed=123,
)

# Create Data object for training
Expand Down
1 change: 0 additions & 1 deletion examples/simulation/hmm_hmm-mvn.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
stay_prob=0.9,
means="zero",
covariances="random",
random_seed=123,
)

# Create Data object for training
Expand Down
1 change: 0 additions & 1 deletion examples/simulation/hmm_hmm-mvn_fisher-kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
n_groups=3,
between_group_scale=0.2,
stay_prob=0.9,
random_seed=1234,
)
sim.standardize()

Expand Down
1 change: 0 additions & 1 deletion examples/simulation/hmm_tinda.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
stay_prob=0.9,
means="zero",
covariances="random",
random_seed=42,
)
sim_stc = sim.state_time_course

Expand Down
1 change: 0 additions & 1 deletion examples/simulation/mdynemo_hmm-mvn.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
stay_prob=0.9,
means="random",
covariances="random",
random_seed=123,
)
sim.standardize()
training_data = data.Data(sim.time_series)
Expand Down
1 change: 0 additions & 1 deletion examples/simulation/state-dynemo_hmm-mvn.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
stay_prob=0.9,
means="zero",
covariances="random",
random_seed=123,
)
sim.standardize()
training_data = data.Data(sim.time_series)
Expand Down
1 change: 0 additions & 1 deletion examples/simulation/state-dynemo_hmm-mvn_high-n-states.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
stay_prob=0.9,
means="zero",
covariances="random",
random_seed=123,
)
sim.standardize()
training_data = data.Data(sim.time_series)
Expand Down
53 changes: 3 additions & 50 deletions osl_dynamics/simulation/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,13 @@ class HMM:
n_states : int, optional
Number of states. Needed when :code:`trans_prob` is a :code:`str` to
construct the transition probability matrix.
random_seed : int, optional
Seed for random number generator.
"""

def __init__(
self,
trans_prob,
stay_prob=None,
n_states=None,
random_seed=None,
):
if isinstance(trans_prob, list):
trans_prob = np.ndarray(trans_prob)
Expand Down Expand Up @@ -112,9 +109,6 @@ def __init__(
# Infer number of states from the transition probability matrix
self.n_states = self.trans_prob.shape[0]

# Setup random number generator
self._rng = np.random.default_rng(random_seed)

@staticmethod
def construct_sequence_trans_prob(stay_prob, n_states):
trans_prob = np.zeros([n_states, n_states])
Expand All @@ -133,7 +127,7 @@ def construct_uniform_trans_prob(stay_prob, n_states):
def generate_states(self, n_samples):
# Here the time course always start from state 0
rands = [
iter(self._rng.choice(self.n_states, size=n_samples, p=self.trans_prob[i]))
iter(np.random.choice(self.n_states, size=n_samples, p=self.trans_prob[i]))
for i in range(self.n_states)
]
states = np.zeros(n_samples, int)
Expand Down Expand Up @@ -165,8 +159,6 @@ class HMM_MAR(Simulation):
stay_prob : float, optional
Used to generate the transition probability matrix is
:code:`trans_prob` is a :code:`str`. Must be between 0 and 1.
random_seed : int, optional
Seed for random number generator.
"""

def __init__(
Expand All @@ -176,10 +168,9 @@ def __init__(
coeffs,
covs,
stay_prob=None,
random_seed=None,
):
# Observation model
self.obs_mod = MAR(coeffs=coeffs, covs=covs, random_seed=random_seed)
self.obs_mod = MAR(coeffs=coeffs, covs=covs)

self.n_states = self.obs_mod.n_states
self.n_channels = self.obs_mod.n_channels
Expand All @@ -190,7 +181,6 @@ def __init__(
trans_prob=trans_prob,
stay_prob=stay_prob,
n_states=self.n_states,
random_seed=random_seed if random_seed is None else random_seed + 1,
)

# Initialise base class
Expand Down Expand Up @@ -246,8 +236,6 @@ class HMM_MVN(Simulation):
is a :code:`str`. Must be between 0 and 1.
observation_error : float, optional
Standard deviation of the error added to the generated data.
random_seed : int, optional
Seed for random number generator.
"""

def __init__(
Expand All @@ -262,7 +250,6 @@ def __init__(
n_covariances_act=1,
stay_prob=None,
observation_error=0.0,
random_seed=None,
):
if n_states is None:
n_states = n_modes
Expand All @@ -275,7 +262,6 @@ def __init__(
n_channels=n_channels,
n_covariances_act=n_covariances_act,
observation_error=observation_error,
random_seed=random_seed,
)

self.n_states = self.obs_mod.n_modes
Expand All @@ -287,7 +273,6 @@ def __init__(
trans_prob=trans_prob,
stay_prob=stay_prob,
n_states=self.n_states,
random_seed=random_seed if random_seed is None else random_seed + 1,
)

# Initialise base class
Expand Down Expand Up @@ -365,8 +350,6 @@ class MDyn_HMM_MVN(Simulation):
is a :code:`str`. Must be between 0 and 1.
observation_error : float, optional
Standard deviation of the error added to the generated data.
random_seed : int, optional
Seed for random number generator.
"""

def __init__(
Expand All @@ -381,7 +364,6 @@ def __init__(
n_covariances_act=1,
stay_prob=None,
observation_error=0.0,
random_seed=None,
):
if n_states is None:
n_states = n_modes
Expand All @@ -394,7 +376,6 @@ def __init__(
n_channels=n_channels,
n_covariances_act=n_covariances_act,
observation_error=observation_error,
random_seed=random_seed,
)

self.n_states = self.obs_mod.n_modes
Expand All @@ -406,13 +387,11 @@ def __init__(
trans_prob=trans_prob,
stay_prob=stay_prob,
n_states=self.n_states,
random_seed=random_seed if random_seed is None else random_seed + 1,
)
self.gamma_hmm = HMM(
trans_prob=trans_prob,
stay_prob=stay_prob,
n_states=self.n_states,
random_seed=random_seed if random_seed is None else random_seed + 2,
)

# Initialise base class
Expand Down Expand Up @@ -476,8 +455,6 @@ class HMM_Poi(Simulation):
Shape must be (n_states, n_channels).
stay_prob : float
Used to generate the transition probability matrix is trans_prob is a str.
random_seed : int
Seed for random number generator.
"""

def __init__(
Expand All @@ -488,14 +465,12 @@ def __init__(
n_states=None,
n_channels=None,
stay_prob=None,
random_seed=None,
):
# Observation model
self.obs_mod = Poisson(
rates=rates,
n_states=n_states,
n_channels=n_channels,
random_seed=random_seed,
)

self.n_states = self.obs_mod.n_states
Expand All @@ -507,7 +482,6 @@ def __init__(
trans_prob=trans_prob,
stay_prob=stay_prob,
n_states=self.n_states,
random_seed=random_seed if random_seed is None else random_seed + 1,
)

# Initialise base class
Expand Down Expand Up @@ -583,8 +557,6 @@ class MSess_HMM_MVN(Simulation):
Standard deviation when generating session-specific stay probability.
observation_error : float, optional
Standard deviation of the error added to the generated data.
random_seed : int, optional
Seed for random number generator.
"""

def __init__(
Expand All @@ -606,7 +578,6 @@ def __init__(
tc_std=0.0,
stay_prob=None,
observation_error=0.0,
random_seed=None,
):
if n_states is None:
n_states = n_modes
Expand All @@ -629,7 +600,6 @@ def __init__(
n_groups=n_groups,
between_group_scale=between_group_scale,
observation_error=observation_error,
random_seed=random_seed,
)

self.n_states = self.obs_mod.n_modes
Expand All @@ -638,7 +608,7 @@ def __init__(

# Vary the stay probability for each session
if stay_prob is not None:
session_stay_prob = self.obs_mod._rng.normal(
session_stay_prob = np.random.normal(
loc=stay_prob,
scale=tc_std,
size=self.n_sessions,
Expand All @@ -662,7 +632,6 @@ def __init__(
trans_prob=trans_prob[i],
stay_prob=session_stay_prob[i],
n_states=self.n_states,
random_seed=random_seed if random_seed is None else random_seed + 1 + i,
)
self.hmm.append(hmm)
self.state_time_course.append(hmm.generate_states(self.n_samples))
Expand Down Expand Up @@ -742,12 +711,6 @@ class HierarchicalHMM_MVN(Simulation):
Number of iterations to add activations to covariance matrices.
observation_error : float, optional
Standard deviation of random noise to be added to the observations.
top_level_random_seed : int, optional
Random seed for generating the state time course of the top level HMM.
bottom_level_random_seeds : list of int, optional
Random seeds for the bottom level HMMs.
data_random_seed : int, optional
Random seed for generating the observed data.
top_level_stay_prob : float, optional
The stay_prob for the top level HMM. Used if
:code:`top_level_trans_prob` is a :code:`str`.
Expand Down Expand Up @@ -779,9 +742,6 @@ def __init__(
n_channels=None,
n_covariances_act=1,
observation_error=0.0,
top_level_random_seed=None,
bottom_level_random_seeds=None,
data_random_seed=None,
top_level_stay_prob=None,
bottom_level_stay_probs=None,
top_level_hmm_type="hmm",
Expand All @@ -799,7 +759,6 @@ def __init__(
n_channels=n_channels,
n_covariances_act=n_covariances_act,
observation_error=observation_error,
random_seed=data_random_seed,
)

self.n_states = self.obs_mod.n_modes
Expand All @@ -808,15 +767,11 @@ def __init__(
if bottom_level_stay_probs is None:
bottom_level_stay_probs = [None] * len(bottom_level_trans_probs)

if bottom_level_random_seeds is None:
bottom_level_random_seeds = [None] * len(bottom_level_trans_probs)

# Top level HMM
# This will select the bottom level HMM at each time point
if top_level_hmm_type.lower() == "hmm":
self.top_level_hmm = HMM(
trans_prob=top_level_trans_prob,
random_seed=top_level_random_seed,
stay_prob=top_level_stay_prob,
n_states=len(bottom_level_trans_probs),
)
Expand All @@ -825,7 +780,6 @@ def __init__(
gamma_shape=top_level_gamma_shape,
gamma_scale=top_level_gamma_scale,
n_states=len(bottom_level_trans_probs),
random_seed=top_level_random_seed,
)
else:
raise ValueError(f"Unsupported top_level_hmm_type: {top_level_hmm_type}")
Expand All @@ -836,7 +790,6 @@ def __init__(
self.bottom_level_hmms = [
HMM(
trans_prob=bottom_level_trans_probs[i],
random_seed=bottom_level_random_seeds[i],
stay_prob=bottom_level_stay_probs[i],
n_states=n_states,
)
Expand Down
Loading

0 comments on commit a8ed6ae

Please sign in to comment.