Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce Tensorflow Dataset #134

Merged
merged 10 commits into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 1 addition & 10 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10", "3.11"]
python-version: ["3.10", "3.11"]

steps:
- uses: actions/checkout@v3
Expand All @@ -30,20 +30,11 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install build wheel setuptools pytest
- name: Build and install package Python 3.9 (no GPJax)
if: matrix.python-version == '3.9'
run: |
output=$(python -m build --wheel)
pip install dist/${output##* }[gymnasium,envpool,neuroevolution,distributed,test]
- name: Build and install package Python 3.10 and above
if: matrix.python-version == '3.10' || matrix.python-version == '3.11'
run: |
output=$(python -m build --wheel)
pip install dist/${output##* }[full,test]
- name: Test with pytest Python 3.9
if: matrix.python-version == '3.9'
run: |
pytest -k 'not test_im_moea and not test_gp'
- name: Test with pytest Python 3.10 and above
if: matrix.python-version == '3.10' || matrix.python-version == '3.11'
run: |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ Supervised Learning
.. toctree::
:maxdepth: 1

torchvision
tfds
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
==================
Tensorflow Dataset
==================

.. autoclass:: evox.problems.neuroevolution.TensorflowDataset
:members:

This file was deleted.

4 changes: 2 additions & 2 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@
"envpool",
"gymnasium",
"ray",
"torch",
"torchvision",
"tensorflow_datasets",
"grain",
"gpjax",
]

Expand Down
13 changes: 9 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ test = [
"chex >= 0.1.0",
"flax >= 0.5.0",
"pytest >= 6.0.0",
"tensorflow >= 2.12.0",
]

vis = [
Expand All @@ -55,21 +56,25 @@ gymnasium = ["gymnasium >= 0.29.0"]
envpool = ["envpool >= 0.8.0"]

neuroevolution = [
"torch >= 1.0.0",
"torchvision >= 0.1.0",
"tensorflow-datasets >= 4.0.0",
"grain >= 0.1.0",
"brax >= 0.1.0",
]

distributed = ["ray >= 2.0.0"]

full = [
"gymnasium >= 0.29.0",
"ray >= 2.0.0",
"torch >= 1.0.0",
"torchvision >= 0.1.0",
"envpool >= 0.8.0",
"gpjax >= 0.8.0",
"plotly >= 5.0.0",
"pandas >= 2.0.0",
"tensorflow-datasets >= 4.0.0",
"grain >= 0.1.0",
"brax >= 0.1.0",
"plotly >= 5.0.0",
"pandas >= 2.0.0",
]

gp = ["gpjax >= 0.8.0"]
Expand Down
4 changes: 2 additions & 2 deletions requirements/test-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ flax >= 0.5.0
pytest >= 6.0.0
gymnasium >= 0.29.0
ray >= 2.0.0
torch >= 1.0.0
torchvision >= 0.1.0
tensorflow-datasets >= 4.0.0
grain >= 0.1.0
envpool >= 0.8.0
gpjax >= 0.8.0
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,10 @@
import jax.numpy as jnp
import numpy as np
from jax import jit, random, vmap, lax
from jax.tree_util import tree_map
from jax.experimental import io_callback

from evox import Problem, State, jit_class


def _x32_func_call(func):
    """Wrap ``func`` so that its return value is passed through ``_to_x32_if_needed``.

    Parameters
    ----------
    func
        Any callable. Its positional and keyword arguments are forwarded
        unchanged.

    Returns
    -------
    callable
        A wrapper with the same call signature as ``func`` whose result is
        ``_to_x32_if_needed(func(*args, **kwargs))``.
    """
    # Imported locally so the block does not depend on a file-level import.
    from functools import wraps

    # Preserve the wrapped callable's name/docstring for debugging and tracing.
    @wraps(func)
    def inner_func(*args, **kwargs):
        return _to_x32_if_needed(func(*args, **kwargs))

    return inner_func


def _to_x32_if_needed(values):
    """Downcast 64-bit leaves of ``values`` to 32-bit unless JAX x64 mode is enabled.

    Parameters
    ----------
    values
        A pytree. Leaves exposing a ``dtype`` of ``float64`` / ``int64`` are
        converted with ``astype`` to ``float32`` / ``int32``.

    Returns
    -------
    The same pytree structure. When ``jax.config.jax_enable_x64`` is truthy the
    input is returned untouched; otherwise 64-bit leaves are downcast and every
    other leaf (including leaves without a ``dtype`` attribute, such as plain
    Python scalars) is returned as-is.
    """
    if jax.config.jax_enable_x64:
        # 64-bit mode is enabled, so there is nothing to convert.
        return values

    def to_x32(value):
        # Leaves without a dtype (e.g. Python int/float) cannot be downcast
        # with ``astype``; pass them through instead of raising AttributeError.
        dtype = getattr(value, "dtype", None)
        if dtype == np.float64:
            return value.astype(np.float32)
        if dtype == np.int64:
            return value.astype(np.int32)
        return value

    return tree_map(to_x32, values)
from evox.utils.io import to_x32_if_needed, x32_func_call


@jit_class
Expand Down Expand Up @@ -60,8 +38,8 @@ def evaluate(self, state, pop):
key, subkey = random.split(state.key)
seed = random.randint(subkey, (1,), 0, jnp.iinfo(jnp.int32).max)
io_callback(self.env.seed, None, seed)
obs, info = _to_x32_if_needed(self.env.reset(None))
obs, info = io_callback(_x32_func_call(self.env.reset), (obs, info), None)
obs, info = to_x32_if_needed(self.env.reset(None))
obs, info = io_callback(x32_func_call(self.env.reset), (obs, info), None)
total_reward = 0
i = 0

Expand All @@ -75,11 +53,11 @@ def cond_func(loop_state):
def step(loop_state):
i, done, total_reward, obs = loop_state
action = self.batch_policy(pop, obs)
obs, reward, terminated, truncated, info = _to_x32_if_needed(
obs, reward, terminated, truncated, info = to_x32_if_needed(
self.env.step(np.zeros(action.shape))
)
obs, reward, terminated, truncated, info = io_callback(
_x32_func_call(lambda action: self.env.step(np.copy(action))),
x32_func_call(lambda action: self.env.step(np.copy(action))),
(obs, reward, terminated, truncated, info),
action,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
try:
# optional dependency: tensorflow-datasets, grain
from .torchvision_dataset import TorchvisionDataset
from .tfds import TensorflowDataset
except ImportError as e:
original_error_msg = str(e)

def TorchvisionDataset(*args, **kwargs):
def TensorflowDataset(*args, **kwargs):
raise ImportError(
f'TorchvisionDataset requires torchvision, optax but got "{original_error_msg}" when importing'
f'TensorflowDataset requires tensorflow-datasets, grain but got "{original_error_msg}" when importing'
)
135 changes: 135 additions & 0 deletions src/evox/problems/neuroevolution/supervised_learning/tfds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
from dataclasses import field
from typing import Any, Callable, List, Optional

import grain.python as pygrain
import jax
import jax.numpy as jnp
import tensorflow_datasets as tfds
from jax.tree_util import tree_map

from evox import Problem, Static, dataclass, jit_class
from evox.utils.io import x32_func_call


def get_dtype_shape(data):
    """Describe every leaf of ``data`` as a ``jax.ShapeDtypeStruct``.

    Parameters
    ----------
    data
        A pytree whose leaves are array-like objects (anything exposing both
        ``shape`` and ``dtype``) or plain Python ``int`` / ``float`` scalars.

    Returns
    -------
    A pytree with the same structure as ``data`` in which each leaf is replaced
    by a ``jax.ShapeDtypeStruct``. Array-like leaves keep their own shape and
    dtype; Python ints map to a scalar ``int32`` and Python floats to a scalar
    ``float32``.

    Raises
    ------
    TypeError
        If a leaf is neither array-like nor a Python ``int`` / ``float``
        (for example a string).
    """

    def to_dtype_struct(x):
        if hasattr(x, "shape") and hasattr(x, "dtype"):
            return jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
        if isinstance(x, int):
            return jax.ShapeDtypeStruct(shape=(), dtype=jnp.int32)
        if isinstance(x, float):
            return jax.ShapeDtypeStruct(shape=(), dtype=jnp.float32)
        # Fail loudly here: silently returning None would only surface later
        # as a confusing pytree-structure mismatch.
        raise TypeError(
            f"Unsupported leaf of type {type(x).__name__!r}; expected an array "
            "or a Python int/float. Convert other types to arrays first."
        )

    return tree_map(to_dtype_struct, data)


@jit_class
@dataclass
class TensorflowDataset(Problem):
    """Wrap a tensorflow dataset as a problem.

    TensorFlow Datasets (TFDS) directly depends on the packages `tensorflow-datasets` and `grain`.
    Additionally, when downloading the dataset for the first time, it requires `tensorflow` to be installed and an active internet connection.
    If you want to avoid installing `tensorflow`, you can prepare the dataset beforehand in another environment with `tensorflow` installed,
    run:

    .. code-block:: python

        import tensorflow_datasets as tfds
        tfds.data_source("mnist")  # replace "mnist" with the name of your dataset

    and then copy the dataset to the target machine.
    The default location is `~/tensorflow_datasets`. `~/` means the home directory of the user.

    Please notice that the data is loaded under JAX's jit context, so the data should be valid JAX data type,
    namely JAX or Numpy arrays, or Python's int, float, list, and dict.
    If the data contains other types like strings, you should convert them into arrays using the `operations` parameter.

    You can also download the dataset through a proxy server by setting the environment variables `TFDS_HTTP_PROXY` and `TFDS_HTTPS_PROXY`,
    for http and https proxy respectively.

    The details of the dataset can be found at https://www.tensorflow.org/datasets/catalog/overview
    The details about operations/transformations can be found at https://github.com/google/grain/blob/main/docs/transformations.md

    Note that one batch is drawn from the dataloader during construction in order to
    record the shapes and dtypes of the data; that batch is not passed to `loss_func`.

    Parameters
    ----------
    dataset
        The dataset name.
    batch_size
        The batch size.
    loss_func
        The loss function.
        The function signature is loss(weights, data) -> loss_value, and it should be jittable.
        The `weights` is the weights of the neural network, and the `data` is the data from TFDS, which is a dictionary.
    split
        Which split of the dataset to use.
        Defaults to "train".
    operations
        The list of transformations to apply to the data.
        Defaults to [].
        After the transformations, we will always apply a batch operation to create a batch of data.
    datadir
        The directory to store the dataset.
        Defaults to None, which means tensorflow-datasets will automatically determine the directory.
    seed
        The random seed used to seed the dataloader.
        Given the same seed, the dataloader should yield data in the same order.
        Defaults to 0.
    try_gcs
        Whether to try to download the dataset from Google Cloud Storage.
        Usually Google's storage server is faster than the original server of the dataset.
        Defaults to True.
    """

    dataset: Static[str]
    batch_size: Static[int]
    loss_func: Static[Callable]
    split: Static[str] = field(default="train")
    operations: Static[List[Any]] = field(default_factory=list)
    datadir: Static[Optional[str]] = field(default=None)
    seed: Static[int] = field(default=0)
    try_gcs: Static[bool] = field(default=True)
    # Populated in __post_init__: the iterator over the grain DataLoader.
    iterator: Static[pygrain.PyGrainDatasetIterator] = field(init=False)
    # Populated in __post_init__: ShapeDtypeStruct pytree describing one batch.
    data_shape_dtypes: Static[Any] = field(init=False)

    def __post_init__(self):
        """Build the grain dataloader and record the shape/dtype of one batch."""
        # Only forward ``data_dir`` when the user supplied one, so that
        # tensorflow-datasets picks its own default location otherwise.
        source_kwargs = {"try_gcs": self.try_gcs, "split": self.split}
        if self.datadir is not None:
            source_kwargs["data_dir"] = self.datadir
        data_source = tfds.data_source(self.dataset, **source_kwargs)

        sampler = pygrain.IndexSampler(
            num_records=len(data_source),
            shard_options=pygrain.NoSharding(),
            shuffle=True,
            seed=self.seed,
        )

        # Batching is always the last step, after the user's transformations.
        # ``list(...)`` lets callers pass any sequence (e.g. a tuple).
        operations = list(self.operations) + [
            pygrain.Batch(batch_size=self.batch_size, drop_remainder=True)
        ]

        loader = pygrain.DataLoader(
            data_source=data_source,
            operations=operations,
            sampler=sampler,
            worker_count=0,
        )
        # ``object.__setattr__`` is used to assign the ``init=False`` fields.
        object.__setattr__(self, "iterator", iter(loader))
        # Pull one batch to learn the result structure io_callback must declare.
        data_shape_dtypes = get_dtype_shape(self._next_data())
        object.__setattr__(self, "data_shape_dtypes", data_shape_dtypes)

    @x32_func_call
    def _next_data(self):
        """Return the next batch from the dataloader iterator."""
        return next(self.iterator)

    def evaluate(self, state, pop):
        """Compute the loss of every individual in ``pop`` on the next batch.

        The batch is fetched on the host through ``io_callback`` and shared by
        the whole population; ``loss_func`` is vmapped over the leading axis of
        ``pop``. Returns ``(loss, state)`` with ``state`` unchanged.
        """
        # Import the submodule explicitly rather than relying on
        # ``jax.experimental`` being reachable as an attribute after ``import jax``.
        from jax.experimental import io_callback

        data = io_callback(self._next_data, self.data_shape_dtypes)
        loss = jax.vmap(self.loss_func, in_axes=(0, None))(pop, data)
        return loss, state
Loading
Loading