Skip to content

Commit

Permalink
Merge pull request #134 from EMI-Group/tfds
Browse files Browse the repository at this point in the history
Introduce Tensorflow Dataset
  • Loading branch information
BillHuang2001 committed Apr 23, 2024
2 parents e05c5ee + c150c31 commit 25f0026
Show file tree
Hide file tree
Showing 14 changed files with 236 additions and 456 deletions.
11 changes: 1 addition & 10 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10", "3.11"]
python-version: ["3.10", "3.11"]

steps:
- uses: actions/checkout@v3
Expand All @@ -30,20 +30,11 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install build wheel setuptools pytest
- name: Build and install package Python 3.9 (no GPJax)
if: matrix.python-version == '3.9'
run: |
output=$(python -m build --wheel)
pip install dist/${output##* }[gymnasium,envpool,neuroevolution,distributed,test]
- name: Build and install package Python 3.10 and above
if: matrix.python-version == '3.10' || matrix.python-version == '3.11'
run: |
output=$(python -m build --wheel)
pip install dist/${output##* }[full,test]
- name: Test with pytest Python 3.9
if: matrix.python-version == '3.9'
run: |
pytest -k 'not test_im_moea and not test_gp'
- name: Test with pytest Python 3.10 and above
if: matrix.python-version == '3.10' || matrix.python-version == '3.11'
run: |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ Supervised Learning
.. toctree::
:maxdepth: 1

torchvision
tfds
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
==================
Tensorflow Dataset
==================

.. autoclass:: evox.problems.neuroevolution.TensorflowDataset
:members:

This file was deleted.

4 changes: 2 additions & 2 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@
"envpool",
"gymnasium",
"ray",
"torch",
"torchvision",
"tensorflow_datasets",
"grain",
"gpjax",
]

Expand Down
13 changes: 9 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ test = [
"chex >= 0.1.0",
"flax >= 0.5.0",
"pytest >= 6.0.0",
"tensorflow >= 2.12.0",
]

vis = [
Expand All @@ -55,21 +56,25 @@ gymnasium = ["gymnasium >= 0.29.0"]
envpool = ["envpool >= 0.8.0"]

neuroevolution = [
"torch >= 1.0.0",
"torchvision >= 0.1.0",
"tensorflow-datasets >= 4.0.0",
"grain >= 0.1.0",
"brax >= 0.1.0",
]

distributed = ["ray >= 2.0.0"]

full = [
"gymnasium >= 0.29.0",
"ray >= 2.0.0",
"torch >= 1.0.0",
"torchvision >= 0.1.0",
"envpool >= 0.8.0",
"gpjax >= 0.8.0",
"plotly >= 5.0.0",
"pandas >= 2.0.0",
"tensorflow-datasets >= 4.0.0",
"grain >= 0.1.0",
"brax >= 0.1.0",
"plotly >= 5.0.0",
"pandas >= 2.0.0",
]

gp = ["gpjax >= 0.8.0"]
Expand Down
4 changes: 2 additions & 2 deletions requirements/test-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ flax >= 0.5.0
pytest >= 6.0.0
gymnasium >= 0.29.0
ray >= 2.0.0
torch >= 1.0.0
torchvision >= 0.1.0
tensorflow-datasets >= 4.0.0,
grain >= 0.1.0,
envpool >= 0.8.0
gpjax >= 0.8.0
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,10 @@
import jax.numpy as jnp
import numpy as np
from jax import jit, random, vmap, lax
from jax.tree_util import tree_map
from jax.experimental import io_callback

from evox import Problem, State, jit_class


def _x32_func_call(func):
def inner_func(*args, **kwargs):
return _to_x32_if_needed(func(*args, **kwargs))

return inner_func


def _to_x32_if_needed(values):
if jax.config.jax_enable_x64:
# we have 64-bit enabled, so nothing to do
return values

def to_x32(value):
if value.dtype == np.float64:
return value.astype(np.float32)
elif value.dtype == np.int64:
return value.astype(np.int32)
else:
return value
return tree_map(to_x32, values)
from evox.utils.io import to_x32_if_needed, x32_func_call


@jit_class
Expand Down Expand Up @@ -60,8 +38,8 @@ def evaluate(self, state, pop):
key, subkey = random.split(state.key)
seed = random.randint(subkey, (1,), 0, jnp.iinfo(jnp.int32).max)
io_callback(self.env.seed, None, seed)
obs, info = _to_x32_if_needed(self.env.reset(None))
obs, info = io_callback(_x32_func_call(self.env.reset), (obs, info), None)
obs, info = to_x32_if_needed(self.env.reset(None))
obs, info = io_callback(x32_func_call(self.env.reset), (obs, info), None)
total_reward = 0
i = 0

Expand All @@ -75,11 +53,11 @@ def cond_func(loop_state):
def step(loop_state):
i, done, total_reward, obs = loop_state
action = self.batch_policy(pop, obs)
obs, reward, terminated, truncated, info = _to_x32_if_needed(
obs, reward, terminated, truncated, info = to_x32_if_needed(
self.env.step(np.zeros(action.shape))
)
obs, reward, terminated, truncated, info = io_callback(
_x32_func_call(lambda action: self.env.step(np.copy(action))),
x32_func_call(lambda action: self.env.step(np.copy(action))),
(obs, reward, terminated, truncated, info),
action,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
try:
# optional dependency: torchvision, optax
from .torchvision_dataset import TorchvisionDataset
from .tfds import TensorflowDataset
except ImportError as e:
original_error_msg = str(e)

def TorchvisionDataset(*args, **kwargs):
def TensorflowDataset(*args, **kwargs):
raise ImportError(
f'TorchvisionDataset requires torchvision, optax but got "{original_error_msg}" when importing'
f'TensorflowDataset requires tensorflow-datasets, grain but got "{original_error_msg}" when importing'
)
135 changes: 135 additions & 0 deletions src/evox/problems/neuroevolution/supervised_learning/tfds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
from dataclasses import field
from typing import Any, Callable, List, Optional

import grain.python as pygrain
import jax
import jax.numpy as jnp
import tensorflow_datasets as tfds
from jax.tree_util import tree_map

from evox import Problem, Static, dataclass, jit_class
from evox.utils.io import x32_func_call


def get_dtype_shape(data):
def to_dtype_struct(x):
if hasattr(x, "shape") and hasattr(x, "dtype"):
return jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
elif isinstance(x, int):
return jax.ShapeDtypeStruct(shape=(), dtype=jnp.int32)
elif isinstance(x, float):
return jax.ShapeDtypeStruct(shape=(), dtype=jnp.float32)

return tree_map(to_dtype_struct, data)


@jit_class
@dataclass
class TensorflowDataset(Problem):
"""Wrap a tensorflow dataset as a problem.
TensorFlow Datasets (TFDS) directly depends on the package `tensorflow-datasets` and `grain`.
Additionally, when downloading the dataset for the first time, it requires `tensorflow` to be installed and a active internet connection.
If you want to avoid installing `tensorflow`, you can prepare the dataset beforehand in another environment with `tensorflow` installed,
run:
.. code-block:: python
import tensorflow_datasets as tfds
tfds.data_source(self.dataset)
and then copy the dataset to the target machine.
The default location is`~/tensorflow_datasets`. `~/` means the home directory of the user.
Please notice that the data is loaded under JAX's jit context, so the data should be valid JAX data type,
namely JAX or Numpy arrays, or Python's int, float, list, and dict.
If the data contains other types like strings, you should convert them into arrays using the `operations` parameter.
You can also download the dataset through a proxy server by setting the environment variable `TFDS_HTTP_PROXY` and `TFDS_HTTPS_PROXY`,
for http and https proxy respectively.
The details of the dataset can be found at https://www.tensorflow.org/datasets/catalog/overview
The details about operations/transformations can be found at https://github.com/google/grain/blob/main/docs/transformations.md
Parameters
----------
dataset
The dataset name.
batch_size
The batch size.
loss_func
The loss function.
The function signature is loss(weights, data) -> loss_value, and it should be jittable.
The `weight` is the weight of the neural network, and the `data` is the data from TFDS, which is a dictionary.
split
Which split of the dataset to use.
Default to "train".
operations
The list of transformations to apply to the data.
Default to [].
After the transformations, we will always apply a batch operation to create a batch of data.
datadir
The directory to store the dataset.
Default to None, which means tensorflow-datasets will automatically determine the directory.
seed
The random seed used to seed the dataloader.
Given the same seed, the dataloader should data in the same order.
Default to 0.
try_gcs
Whether to try to download the dataset from Google Cloud Storage.
Usually Google's storage server is faster than the original server of the dataset.
"""

dataset: Static[str]
batch_size: Static[int]
loss_func: Static[Callable]
split: Static[str] = field(default="train")
operations: Static[List[Any]] = field(default_factory=list)
datadir: Static[Optional[str]] = field(default=None)
seed: Static[int] = field(default=0)
try_gcs: Static[bool] = field(default=True)
iterator: Static[pygrain.PyGrainDatasetIterator] = field(init=False)
data_shape_dtypes: Static[Any] = field(init=False)

def __post_init__(self):
if self.datadir is None:
data_source = tfds.data_source(
self.dataset, try_gcs=self.try_gcs, split=self.split
)
else:
data_source = tfds.data_source(
self.dataset,
data_dir=self.datadir,
try_gcs=self.try_gcs,
split=self.split,
)

sampler = pygrain.IndexSampler(
num_records=len(data_source),
shard_options=pygrain.NoSharding(),
shuffle=True,
seed=self.seed,
)

operations = self.operations + [
pygrain.Batch(batch_size=self.batch_size, drop_remainder=True)
]

loader = pygrain.DataLoader(
data_source=data_source,
operations=operations,
sampler=sampler,
worker_count=0,
)
object.__setattr__(self, "iterator", iter(loader))
data_shape_dtypes = get_dtype_shape(self._next_data())
object.__setattr__(self, "data_shape_dtypes", data_shape_dtypes)

@x32_func_call
def _next_data(self):
return next(self.iterator)

def evaluate(self, state, pop):
data = jax.experimental.io_callback(self._next_data, self.data_shape_dtypes)
loss = jax.vmap(self.loss_func, in_axes=(0, None))(pop, data)
return loss, state
Loading

0 comments on commit 25f0026

Please sign in to comment.