# 🥶 Dealing with non-jax types

In essence, this section shows how to pass non-inexact types (e.g. `int`, `str`, callables) through jax transformations like `jax.grad`.

## `jax` and inexact data types
`jax` transformations like `jax.grad` can handle pytrees of inexact data types (`float`, `complex`, and arrays of `float`/`complex`). Any other input type leads to a `TypeError`, as the following example shows.

In [1]:
import jax


@jax.grad
def identity_grad(x):
    # x can be any pytree (here, a list of floats)
    return sum(x)


# valid input
identity_grad([1.0, 1.0])

# invalid input (not inexact)
try:
    identity_grad([1])
except TypeError as e:
    print(e)

grad requires real- or complex-valued inputs (input dtype that is a sub-dtype of np.inexact), but got int32. If you want to use Boolean- or integer-valued inputs, use vjp or set allow_int to True.
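
The error message itself points at an alternative: `jax.grad` accepts `allow_int=True`. The minimal sketch below (the function name `loss` and the input values are made up for illustration) suggests that integer inputs are then accepted, and that their gradient comes back with the trivial `float0` dtype rather than a usable number.

```python
import jax


def loss(x):
    # x[0] is an int, x[1] is a float; the product promotes to float
    return x[0] * x[1] ** 2


int_grad, float_grad = jax.grad(loss, allow_int=True)([3, 2.0])

print(float_grad)      # derivative w.r.t. the float entry: 3 * 2 * 2.0
print(int_grad.dtype)  # integer inputs get a `float0` placeholder gradient
```

This avoids the error, but the integer leaf still takes part in the transformation. The masking approach described next removes such leaves from the transformation entirely.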


## Using `tree_mask`

* However, when your function needs to accept non-inexact data types, we can mask the non-inexact leaves with a frozen wrapper through `pytreeclass.tree_mask`. A masked leaf is wrapped in a container that yields no leaves when it interacts with jax transformations.

* Masking with `tree_mask` is equivalent to applying `freeze` to the masked leaves.
```python
>>> import pytreeclass as pytc
>>> import jax
>>> tree = [1, 2, {"a": 3, "b": 4.}]
>>> # mask all non-differentiable nodes by default
>>> def mask_if_nondiff(x):
...     return pytc.freeze(x) if pytc.is_nondiff(x) else x
>>> masked_tree = jax.tree_util.tree_map(mask_if_nondiff, tree)
```

In [2]:
import pytreeclass as pytc

# 1 is an int of non-inexact type
# thus when `tree_mask` is applied it will wrap it
# with a frozen wrapper and this wrapper will be indicated
# in the object repr/str with a `#` prefix
print(pytc.tree_mask(1))

# the wrapped object is an instance of a `Frozen` type variant
# frozen types yield no leaves when flattened by jax internals,
# which excludes them from jax transformations
print(type(pytc.tree_mask(1)))

#1
<class 'pytreeclass._src.tree_mask._FrozenHashable'>
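
To make "yields no leaves" concrete, here is a minimal sketch of how such a wrapper could be built with plain `jax`. The class `Frozen` below is a hypothetical stand-in written for illustration, not the PyTreeClass implementation: it registers itself as a pytree node with no children and carries the wrapped value as static auxiliary data.

```python
import jax
import jax.tree_util as jtu


@jtu.register_pytree_node_class
class Frozen:
    """Hypothetical wrapper that hides its value from jax transformations."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        # no children means no leaves; the value travels as auxiliary data
        return (), self.value

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data)

    def __repr__(self):
        return f"#{self.value!r}"


tree = [Frozen(1), 2.0]

# only the float shows up as a leaf
leaves = jtu.tree_leaves(tree)

# `jax.grad` never sees the int, so it raises no TypeError
grads = jax.grad(lambda t: t[1] ** 2)(tree)
print(leaves, grads)
```

Because the value lives in the auxiliary data, it becomes part of the tree structure. Under `jax.jit` this assumes the wrapped value can be hashed and compared, which is presumably why the real library distinguishes hashable and non-hashable variants.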


In [3]:
import pytreeclass as pytc


@jax.grad
def identity_grad(x):
    return 1.0


try:
    # this will fail
    identity_grad([1, 1.0])
except TypeError as e:
    print(e)

# this will work because the tree_mask will
# wrap the non-inexact type (int)
identity_grad([pytc.tree_mask(1), 1.0])

grad requires real- or complex-valued inputs (input dtype that is a sub-dtype of np.inexact), but got int32. If you want to use Boolean- or integer-valued inputs, use vjp or set allow_int to True.


[#1, Array(0., dtype=float32, weak_type=True)]

Notice that with `pytc.tree_mask` we were able to pass a non-inexact type to a `jax` transformation without `jax` complaining.
However, inside the function we need to unmask the value if we want to use it. If we do not use the value, there is no need to unfreeze it, as the following example shows.

In [4]:
import pytreeclass as pytc


@jax.grad
def identity_grad(x):
    # this function does not use the frozen value
    return x[1] ** 2


print(identity_grad([pytc.tree_mask(1), 1.0]))

[#1, Array(2., dtype=float32, weak_type=True)]


However, if the function needs to use a non-inexact value, we freeze the value before passing it in and unfreeze it inside the function. The next example illustrates this.

In [5]:
import pytreeclass as pytc


@jax.grad
def func(x):
    # this function uses the non-inexact and inexact values
    # the non-inexact value is frozen so we need to unfreeze it
    x = pytc.tree_unmask(x)
    return x[0] ** 2 + x[1] ** 2


print(func([pytc.tree_mask(1), 1.0]))

[#1, Array(2., dtype=float32, weak_type=True)]


The result of the previous cell reveals something interesting. We know that $\frac{d}{dx} x^2 = 2x$, yet the derivative was evaluated only for the inexact `float` value, giving `Array(2.)`; the `int` value, which was frozen on input, came back unchanged. This is working as intended. In fact, the mechanism is useful not only for passing otherwise invalid types to `jax` transformations without raising an error, but also for excluding values from differentiation and updates. The following example shows this:

In [6]:
import pytreeclass as pytc
import jax


@jax.grad
def func(x):
    x = pytc.tree_unmask(x)
    return x**2


# using `tree_mask` with a mask that always returns `True`
# to select all leaves
print(func(pytc.tree_mask(1.0, mask=lambda _: True)))

# or using `pytc.freeze` to apply frozen wrapper directly
print(func(pytc.freeze(1.0)))

#1.0
#1.0


Another example that masks values with a frozen wrapper using `pytc.freeze`:

In [7]:
import pytreeclass as pytc
import jax


@jax.grad
def sum_grad(x):
    # unfreeze the input in case any of the values are frozen
    # this is not necessary if you know the input is not frozen
    x = pytc.tree_unmask(x)
    return sum(x)


print(sum_grad([pytc.freeze(1.0), 1.0]))
print(sum_grad([1.0, pytc.freeze(1.0)]))
print(sum_grad([pytc.freeze(1.0), pytc.freeze(1.0)]))
print(sum_grad([1.0, 1.0]))

[#1.0, Array(1., dtype=float32, weak_type=True)]
[Array(1., dtype=float32, weak_type=True), #1.0]
[#1.0, #1.0]
[Array(1., dtype=float32, weak_type=True), Array(1., dtype=float32, weak_type=True)]


Using the previous scheme, you can achieve low-overhead training with `jax` and `PyTreeClass`.
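
As a rough illustration of what such a training loop can look like, the sketch below runs a few gradient-descent steps in which one parameter stays fixed. It uses only `jax`; `Frozen`, `unfreeze`, and `loss` are hypothetical stand-ins for the frozen wrapper, `pytc.tree_unmask`, and a user-defined loss, not PyTreeClass APIs.

```python
import jax
import jax.tree_util as jtu


@jtu.register_pytree_node_class
class Frozen:
    # hypothetical stand-in for the frozen wrapper (not a PyTreeClass API)
    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        return (), self.value  # no leaves; the value is static metadata

    @classmethod
    def tree_unflatten(cls, value, children):
        return cls(value)


def is_frozen(x):
    return isinstance(x, Frozen)


def unfreeze(tree):
    # hypothetical counterpart of `tree_unmask`: unwrap every Frozen node
    return jtu.tree_map(
        lambda x: x.value if is_frozen(x) else x, tree, is_leaf=is_frozen
    )


def loss(params):
    p = unfreeze(params)
    return (3.0 * p["w"] + p["b"] - 7.0) ** 2  # minimized at w = 2 when b = 1


params = {"w": 0.0, "b": Frozen(1.0)}

for _ in range(50):
    grads = jax.grad(loss)(params)
    # the frozen entry has no leaves, so `tree_map` only updates `w`
    params = jtu.tree_map(lambda p, g: p - 0.05 * g, params, grads)

print(params["w"], params["b"].value)
```

The update step needs no special casing for the frozen entry: since it contributes no leaves, both the gradient computation and the `tree_map` update skip it.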

## Recipes for using `tree_mask` with a mask
The following examples show how to effectively use `tree_mask` with `TreeClass` instances to freeze certain values.

In [8]:
from __future__ import annotations
import jax
import jax.tree_util as jtu
import jax.numpy as jnp
import pytreeclass as pytc


@pytc.autoinit
class Tree(pytc.TreeClass):
    a: int = 1
    b: float = 2.0
    c: jax.Array = jnp.array([3.0, 4.0, 5.0])


tree = Tree()
tree

Tree(a=1, b=2.0, c=f32[3](μ=4.00, σ=0.82, ∈[3.00,5.00]))

### Freeze leaves by specifying a mask

In [9]:
# let's freeze all int values
mask = jtu.tree_map(lambda x: isinstance(x, int), tree)
frozen_tree = pytc.tree_mask(tree, mask)
print(frozen_tree)
# Tree(a=#1, b=2.0, c=[3. 4. 5.])

# frozen values are excluded from `tree_leaves`
print(jtu.tree_leaves(frozen_tree))
# [2.0, Array([3., 4., 5.], dtype=float32)]

# `a` does not get updated by `tree_map`
print(jtu.tree_map(lambda x: x + 100, frozen_tree))
# Tree(a=#1, b=102.0, c=[103. 104. 105.])

# unfreeze all frozen leaves (no mask is passed here)
unfrozen_tree = pytc.tree_unmask(frozen_tree)
print(unfrozen_tree)
# Tree(a=1, b=2.0, c=[3. 4. 5.])

Tree(a=#1, b=2.0, c=[3. 4. 5.])
[2.0, Array([3., 4., 5.], dtype=float32)]
Tree(a=#1, b=102.0, c=[103. 104. 105.])
Tree(a=1, b=2.0, c=[3. 4. 5.])
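
A mask is simply a pytree of booleans with the same structure as the tree it describes. The sketch below builds one on a plain dict using only `jax`, selecting leaves by dtype instead of by Python type. The helper name `is_non_inexact` is made up for this example, and it assumes every leaf is numeric.

```python
import jax.numpy as jnp
import jax.tree_util as jtu

tree = {"a": 1, "b": 2.0, "c": jnp.array([3.0, 4.0]), "d": jnp.array([1, 2])}


def is_non_inexact(x):
    # hypothetical helper: True for leaves whose dtype is not float/complex
    return not jnp.issubdtype(jnp.asarray(x).dtype, jnp.inexact)


# one boolean per leaf, arranged in the same dict structure as `tree`
mask = jtu.tree_map(is_non_inexact, tree)
print(mask)
```

Unlike the `isinstance(x, int)` check, a dtype-based predicate also catches integer arrays such as `d`.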


### Freeze leaves by specifying the leaf name

Since `tree_mask` applies `freeze` using `tree_map`, when only a single leaf needs freezing we can apply `freeze` to it directly.

In [10]:
frozen_tree = tree.at["a"].apply(pytc.freeze)
print(frozen_tree)  # `a` has a prefix `#`
# Tree(a=#1, b=2.0, c=[3. 4. 5.])

# frozen values are excluded from `tree_leaves`
print(jtu.tree_leaves(frozen_tree))
# [2.0, Array([3., 4., 5.], dtype=float32)]

# `a` does not get updated by `tree_map`
print(jtu.tree_map(lambda x: x + 100, frozen_tree))
# Tree(a=#1, b=102.0, c=[103. 104. 105.])

# unfreeze `a`
unfrozen_tree = pytc.tree_unmask(frozen_tree)
print(unfrozen_tree)
# Tree(a=1, b=2.0, c=[3. 4. 5.])

Tree(a=#1, b=2.0, c=[3. 4. 5.])
[2.0, Array([3., 4., 5.], dtype=float32)]
Tree(a=#1, b=102.0, c=[103. 104. 105.])
Tree(a=1, b=2.0, c=[3. 4. 5.])
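
For pytrees that do not offer the `.at[...]` indexer, a name-based selection can be sketched with `jax.tree_util.tree_map_with_path`. The example below works on a plain dict and only builds the boolean mask; the helper name `is_named_a` is made up for illustration.

```python
import jax.tree_util as jtu

tree = {"a": 1, "b": 2.0}


def is_named_a(path, leaf):
    # `path` is a tuple of key entries; for a dict the last one is a DictKey
    return path[-1].key == "a"


# True only for the leaf stored under the key "a"
mask = jtu.tree_map_with_path(is_named_a, tree)
print(mask)
```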


## Dealing with buffers

The following shows how to deal with buffer arrays in two ways:
1) Using `jax.lax.stop_gradient`.
2) Using a frozen wrapper.

### Using `jax.lax.stop_gradient`

From the [jax docs](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.stop_gradient.html): "Operationally `stop_gradient` is the identity function, that is, it returns argument `x` unchanged. However, `stop_gradient` prevents the flow of gradients during forward or reverse-mode automatic differentiation. If there are multiple nested gradient computations, `stop_gradient` stops gradients for all of them."
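
A small sketch of that behavior, with made-up names `f`, `buffer`, and `x`: the forward value is the same as it would be without `stop_gradient`, while the gradient with respect to the wrapped argument is zero.

```python
import jax
import jax.numpy as jnp


def f(buffer, x):
    # `buffer` contributes to the value but not to the gradient
    return (jax.lax.stop_gradient(buffer) * x).sum()


buffer = jnp.array([1.0, 2.0, 3.0])
x = jnp.array([1.0, 1.0, 1.0])

value = f(buffer, x)
grad_buffer, grad_x = jax.grad(f, argnums=(0, 1))(buffer, x)

print(value)        # same value as without stop_gradient
print(grad_buffer)  # all zeros: no gradient flows into `buffer`
print(grad_x)       # equals `buffer`, as for an ordinary product
```

Note that `buffer` is still a differentiable input here, so `jax` still produces a (zero) gradient array for it.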

In [11]:
import pytreeclass as pytc
import jax
import jax.numpy as jnp


# using jax.lax.stop_gradient
@pytc.autoinit
class Tree(pytc.TreeClass):
    buffer: jax.Array

    def __call__(self, x):
        return jax.lax.stop_gradient(self.buffer) + x


@jax.jit
def func(t: Tree, x):
    return t(x).sum()


x = jnp.array([1.0, 2.0, 3.0])
t = Tree(buffer=jnp.array([1.0, 2, 3]))

%timeit func(t, x)
%timeit jax.grad(func)(t, x)

3.61 µs ± 138 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
1.06 ms ± 356 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


### Using a frozen wrapper

In [12]:
import pytreeclass as pytc
import jax
import jax.numpy as jnp


@pytc.autoinit
class Tree(pytc.TreeClass):
    buffer: jax.Array

    def __call__(self, x):
        return self.buffer + x


@jax.jit
def func(t: Tree, x: jax.Array):
    t = pytc.tree_unmask(t)  # unmask the frozen leaves
    return t(x).sum()


x = jnp.array([1.0, 2.0, 3.0])
t = Tree(buffer=pytc.freeze(jnp.array([1.0, 2, 3])))

%timeit func(t, x)
%timeit jax.grad(func)(t, x)

3.58 µs ± 290 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
341 µs ± 52.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
