## Utils

Global configuration, `PRNGSequence`, and checks for installed optional dependencies.

In [None]:
#| default_exp utils

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *
from nbdev import show_doc
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
#| export
from __future__ import print_function, division, annotations
from jax_dataloader.imports import *
import jax_dataloader as jdl
import collections

## Configs

In [None]:
#| export
@dataclass
class Config:
    """Library-wide settings, shared through one module-level instance.

    Use `Config.default()` to obtain the library's default values.
    """
    # NOTE(review): presumably the number of PRNG keys reserved ahead of time;
    # confirm against the PRNG helpers that read it.
    rng_reserve_size: int
    # Seed value for the library; `manual_seed` overwrites it in place.
    global_seed: int

    @classmethod
    def default(cls) -> Config:
        """Build a fresh `Config` holding the library defaults."""
        defaults = dict(rng_reserve_size=1, global_seed=42)
        return cls(**defaults)

In [None]:
#| exporti
# Process-wide configuration instance. `get_config` returns this exact object and
# `manual_seed` mutates it in place, so updates are visible to every caller.
main_config: Config = Config.default()

In [None]:
#| export
def get_config() -> Config:
    """Return the shared library configuration.

    The result is the module-level instance itself, not a copy. Attribute
    changes made through it are therefore seen by every other caller of
    `get_config`.
    """
    shared = main_config
    return shared

In [None]:
#| export
def manual_seed(seed: int):
    """Set the library's global seed.

    Args:
        seed: Value written to `global_seed` on the shared configuration object.
    """
    shared = main_config
    shared.global_seed = seed

In [None]:
# Check that `manual_seed` updates the shared config returned by `get_config`.
# The previous seed is restored afterwards (even if the assert fails) so this
# test does not leak a modified global seed into later cells.
_previous_seed = get_config().global_seed
try:
    manual_seed(11)
    assert get_config().global_seed == 11
finally:
    manual_seed(_previous_seed)

## Check Installation

In [None]:
#| export
def check_pytorch_installed():
    """Raise `ModuleNotFoundError` unless PyTorch is available.

    A module-level `torch_data` of `None` is treated as "not installed".

    Raises:
        ModuleNotFoundError: with installation instructions for PyTorch.
    """
    if torch_data is not None:
        return
    message = (
        "`pytorch` library needs to be installed. "
        "Try `pip install torch`. Please refer to pytorch documentation for details: "
        "https://pytorch.org/get-started/."
    )
    raise ModuleNotFoundError(message)


In [None]:
#| torch
# Smoke test: raises ModuleNotFoundError if PyTorch is unavailable. The `#| torch`
# directive above presumably limits this cell to test runs with the `torch` flag.
check_pytorch_installed()

In [None]:
#| export
def has_pytorch_tensor(batch) -> bool:
    """Return True if a batch contains a PyTorch tensor.

    Only the first element is inspected to decide the batch's layout:
      * if it is a `torch.Tensor`, the answer is True;
      * if it is a tuple/list, the batch is transposed into columns and each
        column is checked recursively (True if any column has a tensor);
      * otherwise the answer is False.

    Returns False for an empty batch, and when PyTorch is unavailable
    (`torch_data is None`), matching the availability guard used by `asnumpy`.
    """
    # Guard before touching `torch.Tensor`, which may not exist without PyTorch.
    if torch_data is None or len(batch) == 0:
        return False
    first = batch[0]
    if isinstance(first, torch.Tensor):
        return True
    if isinstance(first, (tuple, list)):
        # A generator lets `any` stop at the first column that contains a tensor.
        return any(has_pytorch_tensor(column) for column in zip(*batch))
    return False

In [None]:
#| export
def check_hf_installed():
    """Raise `ModuleNotFoundError` unless Hugging Face `datasets` is available.

    A module-level `hf_datasets` of `None` is treated as "not installed".

    Raises:
        ModuleNotFoundError: with installation instructions for `datasets`.
    """
    if hf_datasets is not None:
        return
    message = (
        "`datasets` library needs to be installed. "
        "Try `pip install datasets`. Please refer to huggingface documentation for details: "
        "https://huggingface.co/docs/datasets/installation.html."
    )
    raise ModuleNotFoundError(message)

In [None]:
#| hf
# Smoke test: raises ModuleNotFoundError if `datasets` is unavailable. The `#| hf`
# directive above presumably limits this cell to test runs with the `hf` flag.
check_hf_installed()

In [None]:
#| export
def check_tf_installed():
    """Raise `ModuleNotFoundError` unless TensorFlow is available.

    A module-level `tf` of `None` is treated as "not installed".

    Raises:
        ModuleNotFoundError: with installation instructions for TensorFlow.
    """
    if tf is not None:
        return
    message = (
        "`tensorflow` library needs to be installed. "
        "Try `pip install tensorflow`. Please refer to tensorflow documentation for details: "
        "https://www.tensorflow.org/install/pip."
    )
    raise ModuleNotFoundError(message)

In [None]:
#| tf
# Smoke test: raises ModuleNotFoundError if TensorFlow is unavailable. The `#| tf`
# directive above presumably limits this cell to test runs with the `tf` flag.
check_tf_installed()

## Util Functions

In [None]:
#| export
def asnumpy(x) -> "np.ndarray | list | tuple":
    """Convert an array-like, or a (nested) list/tuple of them, to NumPy.

    Supported inputs:
      * `np.ndarray`: returned unchanged (no copy).
      * list / tuple: converted element-wise and eagerly, keeping the container
        type (list -> list, tuple -> tuple, namedtuple -> same namedtuple).
      * JAX arrays: converted with `np.asarray`.
      * PyTorch tensors (only when PyTorch is available): detached and moved
        to CPU first.
      * TensorFlow tensors (only when TensorFlow is available): via `.numpy()`.

    Raises:
        ValueError: if `x` matches none of the supported types.
    """
    if isinstance(x, np.ndarray):
        return x
    if isinstance(x, (tuple, list)):
        # Materialise eagerly: a lazy `map` would be single-use and not
        # array-like. Keep the caller's container type.
        converted = [asnumpy(item) for item in x]
        if isinstance(x, list):
            return converted
        if hasattr(x, "_fields"):  # namedtuple: rebuild positionally
            return type(x)(*converted)
        return tuple(converted)
    if isinstance(x, jnp.ndarray):
        return np.asarray(x)
    # Short-circuit on the optional-library sentinels so `torch`/`tf` are only
    # dereferenced when they are available.
    if torch_data and isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    if tf and isinstance(x, tf.Tensor):
        return x.numpy()
    raise ValueError(f"Unknown type: {type(x)}")

In [None]:
# Check that `asnumpy` converts each supported array type to a NumPy array.
# The PyTorch / TensorFlow checks run only when the optional library is present
# (same `None` sentinels as the `check_*_installed` helpers). This keeps the cell
# from failing in environments without them.
np_x = np.array([1, 2, 3])
jnp_x = jnp.array([1, 2, 3])
assert np.array_equal(asnumpy(np_x), np_x)
assert np.array_equal(asnumpy(jnp_x), np_x) and not isinstance(asnumpy(jnp_x), jnp.ndarray)
if torch_data is not None:
    torch_x = torch.tensor([1, 2, 3])
    assert np.array_equal(asnumpy(torch_x), np_x) and not isinstance(asnumpy(torch_x), torch.Tensor)
if tf is not None:
    tf_x = tf.constant([1, 2, 3])
    assert np.array_equal(asnumpy(tf_x), np_x) and not isinstance(asnumpy(tf_x), tf.Tensor)
