In [None]:
import tensorflow as tf
import numpy as np

In [None]:
class SyntheticDataset(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that serves fixed, randomly generated regression data in batches.

    All samples are generated once in ``__init__`` as float32 arrays drawn uniformly
    from [0, 1); the data does not change for the lifetime of the object and this
    class performs no shuffling of its own.

    Args:
        num_samples: Total number of (input, target) rows.
        input_dim: Number of features per input row.
        output_dim: Number of values per target row.
        batch_size: Maximum rows per batch; the final batch may be smaller.
        seed: Optional seed for reproducible data. ``None`` (the default) draws from
            NumPy's global random state, exactly as before.
        **kwargs: Forwarded to the ``Sequence`` base class (in Keras 3 this base is
            ``PyDataset``, which accepts e.g. ``workers`` / ``use_multiprocessing``
            and expects ``super().__init__`` to be called).

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """

    def __init__(
        self, num_samples, input_dim, output_dim, batch_size, seed=None, **kwargs
    ):
        super().__init__(**kwargs)
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        self.num_samples = num_samples
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.batch_size = batch_size

        # np.random.random draws from the global state just like np.random.rand, so
        # seed=None keeps the original behavior; a seed uses an isolated Generator.
        rng = np.random if seed is None else np.random.default_rng(seed)
        self.inputs = rng.random((self.num_samples, self.input_dim)).astype(np.float32)
        self.targets = rng.random((self.num_samples, self.output_dim)).astype(
            np.float32
        )

    def __len__(self):
        """Number of batches, counting a final partial batch."""
        # Exact integer ceil-division; avoids float rounding from np.ceil(a / b).
        return (self.num_samples + self.batch_size - 1) // self.batch_size

    def __getitem__(self, idx):
        """Return batch ``idx`` as an ``(inputs, targets)`` tuple of NumPy arrays.

        Negative indices count from the end. Out-of-range indices raise
        ``IndexError`` (previously they silently produced empty slices), so code
        that indexes ``obj[i]`` until ``IndexError`` terminates correctly.
        """
        num_batches = len(self)
        batch_idx = idx + num_batches if idx < 0 else idx
        if not 0 <= batch_idx < num_batches:
            raise IndexError(
                f"batch index {idx} out of range for {num_batches} batches"
            )
        start = batch_idx * self.batch_size
        end = min(start + self.batch_size, self.num_samples)
        return self.inputs[start:end], self.targets[start:end]

In [None]:
# Example usage:

# batch_size is set here
batch_size = 32

# Keras Sequence holding 1000 pre-generated samples. It can serve batches by
# itself, but here only its raw arrays are reused to build a tf.data pipeline.
dataset = SyntheticDataset(
    num_samples=1000, input_dim=10, output_dim=1, batch_size=batch_size
)

# Slice the (inputs, targets) arrays along their first axis so that each
# dataset element is one (input_row, target_row) pair.
dataloader = tf.data.Dataset.from_tensor_slices((dataset.inputs, dataset.targets))

# Shuffle individual samples first, then group them into batches.
# The shuffle buffer is counted in *elements* (samples): len(dataset) would be
# the number of batches (32), which only shuffles within a small sliding window
# of the 1000 samples. A buffer of num_samples gives a uniform shuffle.
dataloader = dataloader.shuffle(buffer_size=dataset.num_samples).batch(batch_size)

# Print batch shapes rather than full tensors to keep the output readable.
# The last batch holds 1000 % 32 = 8 samples because batch() keeps the remainder.
for inputs, targets in dataloader:
    print(inputs.shape, targets.shape)

In [None]:
class SyntheticDataset2(tf.data.Dataset):
    """Factory for a streaming ``tf.data.Dataset`` of random (input, target) pairs.

    Subclassing ``tf.data.Dataset`` is only a naming convenience: ``__new__``
    returns the dataset built by ``tf.data.Dataset.from_generator`` rather than an
    instance of this class, so ``__init__`` never runs and the returned object is
    not an instance of ``SyntheticDataset2``.

    Values are drawn from a standard normal distribution (unlike
    ``SyntheticDataset``, which draws uniformly from [0, 1)) and are produced
    lazily: each iteration over the dataset creates a fresh generator and
    therefore new random values.
    """

    @staticmethod
    def _generator(num_samples, input_dim, output_dim):
        """Yield ``num_samples`` pairs of float32 normal vectors of lengths
        ``input_dim`` and ``output_dim``."""
        for _ in range(num_samples):
            inputs = tf.random.normal(shape=(input_dim,))
            targets = tf.random.normal(shape=(output_dim,))
            yield inputs, targets

    def __new__(cls, num_samples, input_dim, output_dim, **kwargs):
        """Build the dataset.

        Args:
            num_samples: Number of (input, target) pairs per pass; must be >= 0.
            input_dim: Length of each input vector.
            output_dim: Length of each target vector.
            **kwargs: Accepted for signature compatibility and ignored.

        Returns:
            A ``tf.data.Dataset`` of ``(inputs, targets)`` float32 tensors with
            shapes ``(input_dim,)`` and ``(output_dim,)`` and a known cardinality
            of ``num_samples``.

        Raises:
            ValueError: If ``num_samples`` is negative.
        """
        if num_samples < 0:
            raise ValueError(f"num_samples must be non-negative, got {num_samples}")

        dataset = tf.data.Dataset.from_generator(
            lambda: cls._generator(num_samples, input_dim, output_dim),
            output_signature=(
                tf.TensorSpec(shape=(input_dim,), dtype=tf.float32),
                tf.TensorSpec(shape=(output_dim,), dtype=tf.float32),
            ),
        )
        # from_generator cannot know how many elements the generator yields, so its
        # cardinality is UNKNOWN; declaring it lets len() and Keras step counting work.
        return dataset.apply(tf.data.experimental.assert_cardinality(num_samples))

In [None]:
# Example usage 2:

# Number of samples generated per pass over the dataset.
num_samples = 1000
# batch_size is set here
batch_size = 32
# The shuffle buffer is counted in samples; making it as large as the dataset
# gives a uniform shuffle instead of a sliding-window one.
buffer_size = num_samples

# Streaming dataset whose elements are generated lazily on each iteration.
dataset2 = SyntheticDataset2(num_samples=num_samples, input_dim=10, output_dim=1)

# Shuffle individual samples first, then group them into batches.
dataloader = dataset2.shuffle(buffer_size=buffer_size).batch(batch_size=batch_size)

# Print batch shapes rather than full tensors to keep the output readable.
for inputs, targets in dataloader:
    print(inputs.shape, targets.shape)