## Utils

Global configs, PRNGSequence, and checks that optional dependencies are installed.

In [None]:
#| default_exp utils

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *
from nbdev import show_doc
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
#| export
from __future__ import print_function, division, annotations
from jax_dataloader.imports import *
import jax_dataloader as jdl
import collections

In [None]:
#| hide
from fastcore.test import test_fail

## Configs

In [None]:
#| export
@dataclass
class Config:
    """Library-wide configuration values."""
    # Both fields are required; `default()` supplies the stock values.
    rng_reserve_size: int
    global_seed: int

    @classmethod
    def default(cls) -> Config:
        """Build the stock configuration: `rng_reserve_size=1`, `global_seed=42`."""
        stock_values = {"rng_reserve_size": 1, "global_seed": 42}
        return cls(**stock_values)

In [None]:
#| exporti
# Module-level configuration object: returned by `get_config` and mutated in place by `manual_seed`.
main_config = Config.default()

In [None]:
#| export
def get_config() -> Config:
    """Return the module-level `Config` object (`main_config`), not a copy."""
    return main_config

In [None]:
#| export
def manual_seed(seed: int):
    """Set the seed for the library

    Overwrites `global_seed` on the module-level config in place, so the new
    value is what `get_config().global_seed` returns from then on.
    """
    main_config.global_seed = seed

In [None]:
# Sanity check: a seed set through `manual_seed` is visible via `get_config()`.
# NOTE: this leaves the global seed at 11 for any cell that runs afterwards.
manual_seed(11)
assert get_config().global_seed == 11

## Check Installation

In [None]:
#| export
def check_pytorch_installed():
    """Fail fast when PyTorch is unavailable.

    Raises:
        ModuleNotFoundError: if `torch_data` is `None`.
    """
    if torch_data is not None:
        return
    install_hint = (
        "`pytorch` library needs to be installed. "
        "Try `pip install torch`. Please refer to pytorch documentation for details: "
        "https://pytorch.org/get-started/."
    )
    raise ModuleNotFoundError(install_hint)


In [None]:
#| torch
# Smoke test: raises ModuleNotFoundError when PyTorch is missing, otherwise returns silently.
check_pytorch_installed()

In [None]:
#| export
def has_pytorch_tensor(batch) -> bool:
    """Report whether a (possibly nested) batch contains a PyTorch tensor.

    Only the first element at each nesting level is type-checked. When that
    element is a tuple or list, the batch is transposed with `zip(*batch)` and
    each resulting group is checked recursively.

    Args:
        batch: An indexable collection of samples.

    Returns:
        `True` if a `torch.Tensor` is found, `False` otherwise. An empty batch
        yields `False`, as does a missing PyTorch installation.
    """
    # NOTE(review): assumes `torch` is bound to `None` when PyTorch is not
    # installed, mirroring the `torch_data is None` convention used by
    # `check_pytorch_installed` -- confirm against `jax_dataloader.imports`.
    if torch is None:
        return False
    try:
        first = batch[0]
    except IndexError:
        # An empty batch cannot contain a tensor.
        return False
    if isinstance(first, torch.Tensor):
        return True
    if isinstance(first, (tuple, list)):
        # Transpose per-sample tuples into per-position groups; the generator
        # lets `any` stop at the first group that holds a tensor.
        return any(has_pytorch_tensor(samples) for samples in zip(*batch))
    return False

In [None]:
#| export
def check_hf_installed():
    """Fail fast when the Hugging Face `datasets` package is unavailable.

    Raises:
        ModuleNotFoundError: if `hf_datasets` is `None`.
    """
    if hf_datasets is not None:
        return
    install_hint = (
        "`datasets` library needs to be installed. "
        "Try `pip install datasets`. Please refer to huggingface documentation for details: "
        "https://huggingface.co/docs/datasets/installation.html."
    )
    raise ModuleNotFoundError(install_hint)

In [None]:
#| hf
# Smoke test: raises ModuleNotFoundError when `datasets` is missing, otherwise returns silently.
check_hf_installed()

In [None]:
#| export
def check_tf_installed():
    """Fail fast when TensorFlow is unavailable.

    Raises:
        ModuleNotFoundError: if `tf` is `None`.
    """
    if tf is not None:
        return
    install_hint = (
        "`tensorflow` library needs to be installed. "
        "Try `pip install tensorflow`. Please refer to tensorflow documentation for details: "
        "https://www.tensorflow.org/install/pip."
    )
    raise ModuleNotFoundError(install_hint)

In [None]:
#| tf
# Smoke test: raises ModuleNotFoundError when TensorFlow is missing, otherwise returns silently.
check_tf_installed()

## Seed Generator

In [None]:
#| export
class Generator:
    """Seed holder that lazily produces matching JAX and PyTorch generators.

    An instance is built from one of three sources:

    * nothing: the seed is read from `get_config().global_seed` at
      construction time;
    * a `torch.Generator`: stored as-is, and its `initial_seed()` becomes the
      seed;
    * a `jax.Array` PRNG key: stored as-is. No integer seed is derived from
      it, so `seed()` returns `None` until `manual_seed` is called.
    """

    def __init__(
        self,
        *,
        generator: jax.Array | torch.Generator | None = None,
    ):
        """Initialise from an optional JAX key or PyTorch generator.

        Args:
            generator: `None`, a `jax.Array` PRNG key, or a `torch.Generator`.

        Raises:
            ValueError: if `generator` is none of the accepted types.
        """
        self._seed = None
        self._jax_generator = None
        self._torch_generator = None

        if generator is None:
            self._seed = get_config().global_seed
        # Guard the torch check so a JAX key can still be passed when PyTorch
        # is not installed.
        # NOTE(review): assumes `torch` is bound to `None` in that case,
        # mirroring the `torch_data is None` convention used by
        # `check_pytorch_installed` -- confirm against `jax_dataloader.imports`.
        elif torch is not None and isinstance(generator, torch.Generator):
            self._torch_generator = generator
            # A torch generator exposes the seed it was created with.
            self._seed = generator.initial_seed()
        elif isinstance(generator, jax.Array):
            self._jax_generator = generator
        else:
            raise ValueError(f"generator=`{generator}` is invalid. Must be either a `jax.random.PRNGKey` or a `torch.Generator`.")

    def seed(self) -> Optional[int]:
        """The initial seed of the generator"""
        # TODO: the seed might not be initialized if the generator is a `jax.random.PRNGKey`
        return self._seed

    def manual_seed(self, seed: int) -> Generator:
        """Set the seed for the generator. This will override the initial seed and the generator.

        Only generators that already exist are rebuilt here; a missing one is
        created lazily from the new seed on first access.

        Returns:
            `self`, so calls can be chained.
        """
        if self._jax_generator is not None:
            self._jax_generator = jrand.PRNGKey(seed)
        if self._torch_generator is not None:
            # A fresh generator is created rather than reseeding in place, so a
            # generator object passed to `__init__` is left untouched.
            self._torch_generator = torch.Generator().manual_seed(seed)
        self._seed = seed
        return self

    def jax_generator(self) -> jax.Array:
        """The JAX generator

        Created from the stored seed on first access if no key was supplied.
        """
        if self._jax_generator is None:
            self._jax_generator = jrand.PRNGKey(self._seed)
        return self._jax_generator

    def torch_generator(self) -> torch.Generator:
        """The PyTorch generator

        Created from the stored seed on first access if none was supplied.

        Raises:
            ModuleNotFoundError: from `check_pytorch_installed` when PyTorch
                is missing.
            ValueError: if there is neither a PyTorch generator nor a seed
                (the instance was built from a JAX key and `manual_seed` has
                not been called).
        """
        check_pytorch_installed()
        if self._torch_generator is None and self._seed is not None:
            self._torch_generator = torch.Generator().manual_seed(self._seed)
        if self._torch_generator is None:
            raise ValueError("Neither pytorch generator or seed is specified.")
        return self._torch_generator

In [None]:
# Example of using the generator
g = Generator()
assert g.seed() == get_config().global_seed
assert jnp.array_equal(g.jax_generator(), jax.random.PRNGKey(get_config().global_seed)) 
assert g.torch_generator().initial_seed() == get_config().global_seed

# Examples of using the generator when passing a `jax.random.PRNGKey` or `torch.Generator`
g_jax = Generator(generator=jax.random.PRNGKey(123))
assert jnp.array_equal(g_jax.jax_generator(), jax.random.PRNGKey(123))
assert g_jax.seed() is None

g_torch = Generator(generator=torch.Generator().manual_seed(123))
assert g_torch.torch_generator().initial_seed() == 123
assert g_torch.seed() == 123
assert jnp.array_equal(g_torch.jax_generator(), jax.random.PRNGKey(123))

In [None]:
#| hide
# `g_jax` was built from a JAX key only, so it holds neither a torch generator nor a seed.
test_fail(g_jax.torch_generator, contains='Neither pytorch generator or seed is specified')

In [None]:
# Example of using `manual_seed` to set the seed
g_jax.manual_seed(456)
assert g_jax.seed() == 456
assert jnp.array_equal(g_jax.jax_generator(), jax.random.PRNGKey(456))
assert g_jax.torch_generator().initial_seed() == 456

g_torch.manual_seed(789)
assert g_torch.seed() == 789
assert g_torch.torch_generator().initial_seed() == 789
assert jnp.array_equal(g_torch.jax_generator(), jax.random.PRNGKey(789))

## Util Functions

In [None]:
#| export
def asnumpy(x) -> np.ndarray:
    if isinstance(x, np.ndarray):
        return x
    elif isinstance(x, jnp.ndarray):
        return x.__array__()
    elif torch_data and isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    elif tf and isinstance(x, tf.Tensor):
        return x.numpy()
    elif isinstance(x, (tuple, list)):
        return map(asnumpy, x)
    else:
        raise ValueError(f"Unknown type: {type(x)}")

In [None]:
# The same values held by each supported backend.
np_x = np.array([1, 2, 3])
jnp_x = jnp.array([1, 2, 3])
torch_x = torch.tensor([1, 2, 3])
tf_x = tf.constant([1, 2, 3])
# Each conversion must match the NumPy values and must no longer be the backend's own type.
assert np.array_equal(asnumpy(np_x), np_x)
assert np.array_equal(asnumpy(jnp_x), np_x) and not isinstance(asnumpy(jnp_x), jnp.ndarray)
assert np.array_equal(asnumpy(torch_x), np_x) and not isinstance(asnumpy(torch_x), torch.Tensor)
assert np.array_equal(asnumpy(tf_x), np_x) and not isinstance(asnumpy(tf_x), tf.Tensor)
