In [1]:
import jax
import jax.numpy as jnp
import numpy as np

from jax import grad, jit, vmap, pmap

from jax import random
import matplotlib.pyplot as plt
from copy import deepcopy
from typing import Tuple, NamedTuple
import functools

#### We built and trained an MLP using JAX in the last notebook but we cannot similarly build other NN Libraries or Layers such as NN.Linear without more knowledge

## Custom PyTrees

In [2]:
# Imagine a linear layer or a convolution layer
class MyContainer:
    """A named Container"""

    def __init__(self, name:str, a:int, b:int, c:int):
        self.name = name
        self.a = a
        self.b = b
        self.c = c   

In [3]:
example_pytree = [MyContainer('Alice', 1, 2, 3), MyContainer("Bob", 4, 5, 6)]
# We would think this Pytree will have 8 leaves
# Nope!

In [5]:
leaves = jax.tree.leaves(example_pytree)
print(f"{repr(example_pytree):<45}\n has {len(leaves)} leaves: \n{leaves}")

[<__main__.MyContainer object at 0x0000019FDC21F230>, <__main__.MyContainer object at 0x0000019FAACEE0D0>]
 has 2 leaves: 
[<__main__.MyContainer object at 0x0000019FDC21F230>, <__main__.MyContainer object at 0x0000019FAACEE0D0>]


- Only 2 leaves!
- In this case, we cannot do manipulations on each value (like use tree.map)
- Image if this was a Linear Layer, we wont be able to do any Stochastic GD

In [6]:
print(jax.tree.map(lambda x: x+1, example_pytree))
# This will not work (would be nice if it did)

TypeError: unsupported operand type(s) for +: 'MyContainer' and 'int'

#### To get it to work, we need define two functions:
- flatten
- unflatten

In [7]:
def flatten_MyContainer(container):
    # Returns an iterable over container contents, and aux data
    flat_contents = [container.a, container.b, container.c]

    # We dont want name to appear as a child, so it is auxiliary data
    # Auxiliary Data is often a description of the strucuture of a node,
    # Eg: Keys of a dict -> anything that is not a node's children

    aux_data = container.name

    return flat_contents, aux_data

In [8]:
def unflatten_MyContainer(aux_data, flat_contents):
    return MyContainer(aux_data, *flat_contents)

In [9]:
# Register a custom PyTree node
jax.tree_util.register_pytree_node(MyContainer, flatten_MyContainer, unflatten_MyContainer)

- register_pytree_node teaches jax how to traverse and manipulate custom Python classes as PyTrees
- When we create custom classes, JAX does not know how to look inside it (which is why previously we saw only 2 leaves)
- Without proper pytree registration, JAX operations fail<br><br><br>
- The flatten function basically tells JAX how to decompose your object into children and auxiliary data
- The unflatten function tells JAX how to reconstruct your object from flattened representation

In [10]:
# Let's try again
leaves = jax.tree.leaves(example_pytree)
print(f"{repr(example_pytree):<45}\n has {len(leaves)} leaves: \n{leaves}")

[<__main__.MyContainer object at 0x0000019FDC21F230>, <__main__.MyContainer object at 0x0000019FAACEE0D0>]
 has 6 leaves: 
[1, 2, 3, 4, 5, 6]


In [11]:
# Lets try applying tree.map again
res = jax.tree.map(lambda x: x+1, example_pytree)
print(jax.tree.leaves(res))

[2, 3, 4, 5, 6, 7]


### Common Gotcha for PyTrees: Mistaking nodes for leaves/children

In [12]:
zeros_tree = [jnp.zeros((2,3)), jnp.zeros((3,4))]
print(zeros_tree)

[Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32), Array([[0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]], dtype=float32)]


In [13]:
# Try to make another similar tree with ones instead of zeros
shapes = jax.tree.map(lambda x:x.shape, zeros_tree)
print(shapes)

[(2, 3), (3, 4)]


In [14]:
ones_tree = jax.tree.map(jnp.ones, shapes)
print(ones_tree)

[(Array([1., 1.], dtype=float32), Array([1., 1., 1.], dtype=float32)), (Array([1., 1., 1.], dtype=float32), Array([1., 1., 1., 1.], dtype=float32))]


#### Issue: Treating the tuples as PyTree nodes instead of leaves!

- zeros_tree is a list with two leaves (2 arrays)
- the .shape returns a tuple; so shapes is a list containing 2 tuples
- JAX treats tuples as PyTree containers, not leaves

List containing:<br>
  ├─ Tuple(2, 3)  ← This is a CONTAINER with 2 children<br>
  │   ├─ 2<br>
  │   └─ 3<br>
  └─ Tuple(3, 4)  ← This is a CONTAINER with 2 children<br>
      ├─ 3<br>
      └─ 4<br>

- Total leaves: 4 scalars [2, 3, 3, 4]
- When we apply do jax.tree.map(...), JAX applies jnp.ones to each leaf (each scalar)
- Then the final structure maintains the original tree (with 4 leaves)

#### Solution 1: Get shapes directly without Intermediate tree

In [18]:
ones_tree = jax.tree.map(lambda x: jnp.ones(x.shape), zeros_tree)
print(ones_tree)

[Array([[1., 1., 1.],
       [1., 1., 1.]], dtype=float32), Array([[1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.]], dtype=float32)]


#### Solution 2: Use tree_map with is_leaf parameter

In [19]:
shapes = jax.tree.map(lambda x: x.shape, zeros_tree)
ones_tree = jax.tree.map(jnp.ones, shapes, is_leaf= lambda x: isinstance(x, tuple))
print(ones_tree)

[Array([[1., 1., 1.],
       [1., 1., 1.]], dtype=float32), Array([[1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.]], dtype=float32)]


#### Great! Now we can create custom layers and train even bigger neural networks!

#### But what if our NN is really big? Can we train it across multiple devices?