dev: add try_gcs, split, and datadir parameter
BillHuang2001 committed Apr 23, 2024
1 parent 47e6241 commit c150c31
Showing 1 changed file with 28 additions and 2 deletions.
30 changes: 28 additions & 2 deletions src/evox/problems/neuroevolution/supervised_learning/tfds.py
@@ -1,5 +1,5 @@
from dataclasses import field
from typing import Any, Callable, List
from typing import Any, Callable, List, Optional

import grain.python as pygrain
import jax
@@ -45,6 +45,9 @@ class TensorflowDataset(Problem):
namely JAX or NumPy arrays, or Python's int, float, list, and dict.
If the data contains other types like strings, you should convert them into arrays using the `operations` parameter.
You can also download the dataset through a proxy server by setting the environment variables `TFDS_HTTP_PROXY` and `TFDS_HTTPS_PROXY`,
for HTTP and HTTPS proxies respectively.
The details of the dataset can be found at https://www.tensorflow.org/datasets/catalog/overview
The details about operations/transformations can be found at https://github.com/google/grain/blob/main/docs/transformations.md
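For the proxy setup described in the docstring above, a minimal sketch (the proxy address is a placeholder for your own, and the variables must be set before the dataset is first downloaded):

```python
import os

# Placeholder proxy address; substitute your own proxy endpoint.
os.environ["TFDS_HTTP_PROXY"] = "http://127.0.0.1:7890"
os.environ["TFDS_HTTPS_PROXY"] = "http://127.0.0.1:7890"
```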
@@ -58,26 +61,49 @@ class TensorflowDataset(Problem):
The loss function.
The function signature is loss(weights, data) -> loss_value, and it should be jittable.
The `weights` argument is the weights of the neural network, and `data` is a batch of data from TFDS, given as a dictionary.
split
Which split of the dataset to use.
Defaults to "train".
operations
The list of transformations to apply to the data.
Defaults to [].
After the transformations, a batch operation is always applied to create batches of data.
datadir
The directory used to store the dataset.
Defaults to None, in which case tensorflow-datasets determines the directory automatically.
seed
The random seed for the dataloader.
Given the same seed, the dataloader should yield the data in the same order.
Defaults to 0.
try_gcs
Whether to try downloading the dataset from Google Cloud Storage.
Google's storage server is usually faster than the dataset's original host.
Defaults to True.
"""

dataset: Static[str]
batch_size: Static[int]
loss_func: Static[Callable]
split: Static[str] = field(default="train")
operations: Static[List[Any]] = field(default_factory=list)
datadir: Static[Optional[str]] = field(default=None)
seed: Static[int] = field(default=0)
try_gcs: Static[bool] = field(default=True)
iterator: Static[pygrain.PyGrainDatasetIterator] = field(init=False)
data_shape_dtypes: Static[Any] = field(init=False)

def __post_init__(self):
data_source = tfds.data_source(self.dataset, split="train")
if self.datadir is None:
data_source = tfds.data_source(
self.dataset, try_gcs=self.try_gcs, split=self.split
)
else:
data_source = tfds.data_source(
self.dataset,
data_dir=self.datadir,
try_gcs=self.try_gcs,
split=self.split,
)

sampler = pygrain.IndexSampler(
num_records=len(data_source),
shard_options=pygrain.NoSharding(),
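A note on the `__post_init__` change above: `tfds.data_source` treats `data_dir=None` as its default (an assumption worth verifying against your TFDS version), so the two branches could in principle collapse into one call. A sketch of the equivalent direct call:

```python
import tensorflow_datasets as tfds

# Assuming data_dir=None matches tfds' default behavior, this single call
# covers both branches of the if/else above ("mnist" is a placeholder name).
data_source = tfds.data_source(
    "mnist",
    data_dir=None,
    try_gcs=True,
    split="train",
)
print(len(data_source))  # the IndexSampler above uses this record count
```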
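Tying the new parameters together, a usage sketch under stated assumptions: the import path and the "mnist" dataset are hypothetical, the constructor is assumed to take the fields shown in the diff as keyword arguments, and `Normalize` is only an illustration of an `operations` entry:

```python
import grain.python as pygrain
import jax
import jax.numpy as jnp

from evox.problems.neuroevolution.supervised_learning import TensorflowDataset


class Normalize(pygrain.MapTransform):
    """Hypothetical operations entry: scale images to [0, 1]."""

    def map(self, element):
        element["image"] = element["image"].astype("float32") / 255.0
        return element


def loss_func(weights, data):
    # weights: a dict of arrays for a toy linear classifier,
    # e.g. {"w": shape (784, 10), "b": shape (10,)}.
    # data: a batched TFDS dict, e.g. {"image": ..., "label": ...} for MNIST.
    x = data["image"].reshape(data["image"].shape[0], -1)
    logits = x @ weights["w"] + weights["b"]
    one_hot = jax.nn.one_hot(data["label"], 10)
    return -jnp.mean(jnp.sum(jax.nn.log_softmax(logits) * one_hot, axis=-1))


problem = TensorflowDataset(
    dataset="mnist",
    batch_size=128,
    loss_func=loss_func,
    split="train",             # new in this commit
    operations=[Normalize()],
    datadir=None,              # new: let tensorflow-datasets pick the cache directory
    try_gcs=True,              # new: prefer the GCS mirror when available
    seed=0,
)
```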
