## Utils

> Global configs, `PRNGSequence`, and checks for whether optional dependencies (PyTorch, Hugging Face `datasets`, TensorFlow) are installed.

In [None]:
#| default_exp utils

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *
from nbdev import show_doc
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
#| export
from __future__ import print_function, division, annotations
from jax_dataloader.imports import *
import jax_dataloader as jdl

In [None]:
#| export
@dataclass
class Config:
    """Library-wide settings.

    Attributes:
        rng_reserve_size: How many keys `PRNGSequence` splits off in one go
            whenever its buffer of sub-keys is empty.
        global_seed: Seed value kept alongside the other settings; where it
            is consumed is not visible in this module.
    """
    rng_reserve_size: int
    global_seed: int

    @classmethod
    def default(cls) -> Config:
        """Build a `Config` populated with the library's stock values."""
        stock_values = dict(rng_reserve_size=1, global_seed=42)
        return cls(**stock_values)

In [None]:
#| exporti
# Module-level configuration instance, created once with the default values;
# this is the object handed out by `get_config()`.
main_config = Config.default()

In [None]:
#| export
def get_config() -> Config:
    """Return the module-level `Config` instance.

    The same object is returned on every call, so attribute changes made by
    one caller are seen by all others.
    """
    return main_config

In [None]:
#| export
class PRNGSequence(Iterator[PRNGKey]):
    """An iterator of JAX PRNG keys (minimal version of `haiku.PRNGSequence`).

    Keeps one internal key and repeatedly splits it; the split-off sub-keys
    are buffered in a deque and handed out one at a time by `next()`.
    """

    def __init__(self, seed: int):
        self._key = jax.random.PRNGKey(seed)
        # Buffer of sub-keys that have been split off but not yet handed out.
        self._subkeys = collections.deque()

    def reserve(self, num: int) -> None:
        """Splits additional ``num`` keys for later use.

        Does nothing when ``num`` is not positive.
        """
        if num > 0:
            # Split into num + 1 keys: the first replaces the internal key,
            # the remaining ``num`` are queued for consumers.
            new_keys = tuple(jax.random.split(self._key, num + 1))
            self._key = new_keys[0]
            self._subkeys.extend(new_keys[1:])

    def __next__(self):
        """Return the next unused key, refilling the buffer when it is empty."""
        if not self._subkeys:
            # Reserve at least one key: with a configured reserve size of 0
            # (or negative) `reserve` would be a no-op and `popleft` below
            # would raise IndexError on the empty deque.
            self.reserve(max(1, get_config().rng_reserve_size))
        return self._subkeys.popleft()

In [None]:
keys = PRNGSequence(0)
key_1 = next(keys)
key_2 = next(keys)
# Compare key *values*: an identity check (`is not`) passes for any two
# separately created arrays, even if they held the same key.
assert not bool((key_1 == key_2).all())



In [None]:
#| export
def check_pytorch_installed():
    """Ensure PyTorch is available.

    Raises:
        ModuleNotFoundError: If `torch_data` is `None` (presumably set that
            way by the project's import module when PyTorch is missing).
    """
    if torch_data is not None:
        return
    message = (
        "`pytorch` library needs to be installed. "
        "Try `pip install torch`. Please refer to pytorch documentation for details: "
        "https://pytorch.org/get-started/."
    )
    raise ModuleNotFoundError(message)


In [None]:
#| torch
# Smoke check: raises ModuleNotFoundError if `torch_data` is None.
check_pytorch_installed()

In [None]:
#| export
def has_pytorch_tensor(batch) -> bool:
    """Report whether `batch` holds a `torch.Tensor`.

    Only the first element decides how the batch is interpreted: if it is a
    tensor the answer is `True`; if it is a tuple or list, the batch is
    transposed into per-field columns and each column is checked recursively.
    Any other first element (and an empty batch) yields `False`.
    """
    try:
        first = batch[0]
    except IndexError:
        # An empty batch contains no tensors.
        return False
    if isinstance(first, torch.Tensor):
        return True
    if isinstance(first, (tuple, list)):
        # zip(*batch) groups the i-th field of every sample together; the
        # generator lets `any` stop at the first column containing a tensor.
        return any(has_pytorch_tensor(column) for column in zip(*batch))
    return False

In [None]:
#| export
def check_hf_installed():
    """Ensure the Hugging Face `datasets` library is available.

    Raises:
        ModuleNotFoundError: If `hf_datasets` is `None` (presumably set that
            way by the project's import module when the library is missing).
    """
    if hf_datasets is not None:
        return
    message = (
        "`datasets` library needs to be installed. "
        "Try `pip install datasets`. Please refer to huggingface documentation for details: "
        "https://huggingface.co/docs/datasets/installation.html."
    )
    raise ModuleNotFoundError(message)

In [None]:
#| hf
# Smoke check: raises ModuleNotFoundError if `hf_datasets` is None.
check_hf_installed()

In [None]:
#| export
def check_tf_installed():
    """Ensure TensorFlow is available.

    Raises:
        ModuleNotFoundError: If `tf` is `None` (presumably set that way by
            the project's import module when TensorFlow is missing).
    """
    if tf is not None:
        return
    message = (
        "`tensorflow` library needs to be installed. "
        "Try `pip install tensorflow`. Please refer to tensorflow documentation for details: "
        "https://www.tensorflow.org/install/pip."
    )
    raise ModuleNotFoundError(message)

In [None]:
#| tf
# Smoke check: raises ModuleNotFoundError if `tf` is None.
check_tf_installed()

In [None]:
#| export
def is_hf_dataset(dataset) -> bool:
    """Return `True` if `dataset` is a Hugging Face `Dataset` or `DatasetDict`.

    Always returns a real `bool`: when `hf_datasets` is falsy (presumably
    `None` because the library is not installed) the result is `False`
    rather than the falsy module handle itself.
    """
    if not hf_datasets:
        return False
    return isinstance(dataset, (hf_datasets.Dataset, hf_datasets.DatasetDict))


In [None]:
#| export
def is_torch_dataset(dataset) -> bool:
    """Return `True` if `dataset` is a `torch.utils.data.Dataset` instance.

    Always returns a real `bool`: when `torch_data` is falsy (presumably
    `None` because PyTorch is not installed) the result is `False` rather
    than the falsy module handle itself.
    """
    if not torch_data:
        return False
    return isinstance(dataset, torch_data.Dataset)

In [None]:
#| export
def is_jdl_dataset(dataset):
    """Return whether `dataset` is an instance of `jdl.Dataset`."""
    jdl_dataset_cls = jdl.Dataset
    return isinstance(dataset, jdl_dataset_cls)

In [None]:
#| export
def is_tf_dataset(dataset) -> bool:
    """Return `True` if `dataset` is a `tf.data.Dataset` instance.

    Always returns a real `bool`: when `tf` is falsy (presumably `None`
    because TensorFlow is not installed) the result is `False` rather than
    the falsy module handle itself.
    """
    if not tf:
        return False
    return isinstance(dataset, tf.data.Dataset)