Skip to content

Commit

Permalink
Upgrade TensorFlow Version (#253)
Browse files Browse the repository at this point in the history
* Install: upgraded to TensorFlow v2.11.0.
* Fix: bug re-initialising with TF 2.11.0.
* Fix: bug with data leakage in training-validation split.
* Refact: don't set random seed in HIVE simulation example.
  • Loading branch information
cgohil8 committed May 24, 2024
1 parent 3bc71bb commit f6874bf
Show file tree
Hide file tree
Showing 8 changed files with 42 additions and 64 deletions.
6 changes: 3 additions & 3 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ If you have already installed `OSL <https://github.com/OHBA-analysis/osl>`_ you
conda activate osl
cd osl-dynamics
pip install tensorflow==2.9.1
pip install tensorflow-probability==0.17
pip install tensorflow==2.11.0
pip install tensorflow-probability==0.19.0
pip install -e .
Note, if you're using a Mac computer you need to install TensorFlow with ``pip install tensorflow-macos==2.9.1`` instead of ``tensorflow==2.9.1``.
Note, if you're using a Mac computer you need to install TensorFlow with ``pip install tensorflow-macos==2.11.0`` instead of ``tensorflow==2.11.0``.

Removing osl-dynamics
---------------------
Expand Down
10 changes: 5 additions & 5 deletions doc/install.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ OSL Dynamics can be installed in three steps. Open a Terminal and execute the fo

::

conda create --name osld python=3.10
conda create --name osld python=3.10.14
conda activate osld

Note, this environment must be activated every time you want to use osl-dynamics.
Expand All @@ -23,27 +23,27 @@ OSL Dynamics can be installed in three steps. Open a Terminal and execute the fo

::

pip install tensorflow==2.9.1
pip install tensorflow==2.11.0

If you have GPU resources you may need to install additional libraries (CUDA/cuDNN), see https://www.tensorflow.org/install/pip for detailed instructions.

If you are using an Apple Mac, you will need to use the following instead:

::

pip install tensorflow-macos==2.9.1
pip install tensorflow-macos==2.11.0

If pip can not find the package, then you can try installing TensorFlow with conda:

::

conda install tensorflow=2.9.1
conda install tensorflow=2.11.0

After you have installed TensorFlow, install the tensorflow-probability addon with:

::

pip install tensorflow-probability==0.17
pip install tensorflow-probability==0.19.0

#. Finally, install osl-dynamics:

Expand Down
8 changes: 4 additions & 4 deletions envs/linux.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name: osld
dependencies:
- python=3.10.13
- pip=23.3.1
- python=3.10.14
- pip=24.0
- pip:
- glmtools==0.2.1
- jupyter==1.0.0
Expand All @@ -21,5 +21,5 @@ dependencies:
- seaborn==0.13.0
- tabulate==0.9.0
- tqdm==4.66.1
- tensorflow==2.9.1
- tensorflow-probability==0.17
- tensorflow==2.11.0
- tensorflow-probability==0.19.0
8 changes: 4 additions & 4 deletions envs/mac.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name: osld
dependencies:
- python=3.10.13
- pip=23.3.1
- python=3.10.14
- pip=24.0
- pip:
- glmtools==0.2.1
- jupyter==1.0.0
Expand All @@ -21,5 +21,5 @@ dependencies:
- seaborn==0.13.0
- tabulate==0.9.0
- tqdm==4.66.1
- tensorflow-macos==2.9.1
- tensorflow-probability==0.17
- tensorflow-macos==2.11.0
- tensorflow-probability==0.19.0
1 change: 0 additions & 1 deletion examples/simulation/hive_hmm-mvn.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
# GPU settings
tf_ops.gpu_growth()

set_random_seed(1234)
# Settings
config = Config(
n_states=5,
Expand Down
36 changes: 11 additions & 25 deletions osl_dynamics/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1166,13 +1166,11 @@ def dataset(
return full_dataset.prefetch(tf.data.AUTOTUNE)

else:
# Calculate how many batches should be in the training dataset
dataset_size = len(full_dataset)
training_dataset_size = round((1.0 - validation_split) * dataset_size)

# Split the full dataset into a training and validation dataset
training_dataset = full_dataset.take(training_dataset_size)
validation_dataset = full_dataset.skip(training_dataset_size)
training_dataset, validation_dataset = tf.keras.utils.split_dataset(
full_dataset,
right_size=validation_split,
)
_logger.info(
f"{len(training_dataset)} batches in training dataset, "
f"{len(validation_dataset)} batches in the validation "
Expand Down Expand Up @@ -1209,28 +1207,16 @@ def dataset(
training_datasets = []
validation_datasets = []
for i in range(len(full_datasets)):
# Calculate the number of batches in the training dataset
dataset_size = len(full_datasets[i])
training_dataset_size = round(
(1.0 - validation_split) * dataset_size
)

# Split this session's dataset
training_datasets.append(
full_datasets[i]
.take(training_dataset_size)
.prefetch(tf.data.AUTOTUNE)
)
validation_datasets.append(
full_datasets[i]
.skip(training_dataset_size)
.prefetch(tf.data.AUTOTUNE)
tds, vds = tf.keras.utils.split_dataset(
full_datasets[i],
right_size=validation_split,
)
training_datasets.append(tds.prefetch(tf.data.AUTOTUNE))
validation_datasets.append(vds.prefetch(tf.data.AUTOTUNE))
_logger.info(
f"Session {i}: "
f"{len(training_datasets[i])} batches in training dataset, "
f"{len(validation_datasets[i])} batches in the validation "
"dataset."
f"{len(tds)} batches in training dataset, "
f"{len(vds)} batches in the validation dataset."
)
return training_datasets, validation_datasets

Expand Down
32 changes: 13 additions & 19 deletions osl_dynamics/data/tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,17 +279,15 @@ def _parse_example(example):
return full_dataset.prefetch(tf.data.AUTOTUNE)

else:
# Calculate how many batches should be in the training dataset
dataset_size = dtf.get_n_batches(full_dataset)
training_dataset_size = round((1.0 - validation_split) * dataset_size)

# Split the dataset into training and validation datasets
training_dataset = full_dataset.take(training_dataset_size)
validation_dataset = full_dataset.skip(training_dataset_size)
training_dataset, validation_dataset = tf.keras.utils.split_dataset(
full_dataset,
right_size=validation_split,
)
_logger.info(
f"{training_dataset_size} batches in training dataset, "
+ f"{dataset_size - training_dataset_size} batches in the validation "
+ "dataset."
f"{len(training_dataset)} batches in training dataset, "
f"{len(validation_dataset)} batches in the validation "
"dataset."
)
return training_dataset.prefetch(
tf.data.AUTOTUNE
Expand Down Expand Up @@ -326,18 +324,14 @@ def _parse_example(example):
training_datasets = []
validation_datasets = []
for i, ds in enumerate(full_datasets):
# Calculate how many batches should be in the training dataset
dataset_size = dtf.get_n_batches(ds)
training_dataset_size = round((1.0 - validation_split) * dataset_size)

# Split the dataset into training and validation datasets
training_datasets.append(ds.take(training_dataset_size))
validation_datasets.append(ds.skip(training_dataset_size))
tds, vds = tf.keras.utils.split_dataset(
full_datasets[i],
right_size=validation_split,
)
_logger.info(
f"Session {i}: "
+ f"{training_dataset_size} batches in training dataset, "
+ f"{dataset_size - training_dataset_size} batches in the validation "
+ "dataset."
f"{len(tds)} batches in training dataset, "
f"{len(vds)} batches in the validation dataset."
)

return training_datasets, validation_datasets
Expand Down
5 changes: 2 additions & 3 deletions osl_dynamics/inference/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
"""

from copy import deepcopy

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
Expand Down Expand Up @@ -211,7 +209,8 @@ def reinitialize_layer_weights(layer):
#
# We need to create a new initializer to get new
# random values
new_initializer = deepcopy(initializer)
config = initializer.get_config()
new_initializer = initializer_type.from_config(config)

# Get the variable (i.e. weights) we want to re-initialize
if key == "recurrent_initializer":
Expand Down

0 comments on commit f6874bf

Please sign in to comment.