# A Better 🐍 PyTree 🌲 interface!

So ✨ `Equinox` ✨ is an awesome package that gives us the ability to build object-oriented software in `Jax` 🤯. However, it is a low-level package designed around generality and flexibility, which can make building a simple API for user-facing software a challenge. The less that astronomers have to learn about 😱 `lambda` functions 😱, the better! To help with this, I've constructed a class with PyTree helper methods, designed to give a simpler and more intuitive interface.

There is also a class designed to simplify interfacing with some of the most valuable packages used in conjunction with our software, namely `Optax` and `Numpyro` 😎.

If you are new to `Jax`, note that all objects are *immutable*, which simply means that you can NOT do in-place updates, i.e. any time we update some parameter, we return a new version of that object. This will become clear throughout the tutorial!
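
As a loose analogy in plain Python, the sketch below uses a frozen dataclass to show what immutability means in practice. The `Point` class is purely illustrative and is not part of ∂Lux: assigning to a field fails, and an "update" hands back a new object.

```python
from dataclasses import dataclass, replace, FrozenInstanceError


@dataclass(frozen=True)
class Point:  # hypothetical class, for illustration only
    x: float
    y: float


p = Point(1.0, 2.0)

# In-place updates are rejected on a frozen (immutable) object
try:
    p.x = 5.0
    mutated = True
except FrozenInstanceError:
    mutated = False

# Instead, build a new object with the changed field; `p` is untouched
q = replace(p, x=5.0)
print(mutated, p, q)
```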


---
## What is a PyTree? 

PyTrees are the base object that Jax works with under the hood. Fundamentally they are any series of nested lists, tuples and dictionaries (https://jax.readthedocs.io/en/latest/pytrees.html). Equinox simply allows us to extend this definition to classes, hence object-oriented Jax! All classes in ∂Lux are PyTrees at the base level. Because these structures are arbitrary, indexing 'leaves' and setting new values on them can be difficult, since each 'leaf' must be referred to via some 'path'.
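
To make the idea of leaves concrete, here is a small sketch using only `jax.tree_util` on a made-up nested container (the `tree` below is an illustrative example, not a ∂Lux object).

```python
import jax

# A pytree built only from nested dicts, lists and tuples
tree = {"a": [1.0, 2.0], "b": (3.0, {"c": 4.0})}

# Leaves are the values at the ends of the branches
leaves = jax.tree_util.tree_leaves(tree)

# tree_map applies a function to every leaf and keeps the structure
doubled = jax.tree_util.tree_map(lambda leaf: 2 * leaf, tree)

print(leaves)
print(doubled)
```

Note that nothing here tells us *where* a given leaf lives in the structure, which is the gap that paths are meant to fill.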

In [1]:
import jax
import jax.numpy as np
import matplotlib.pyplot as plt
from dLux.base import Base, ExtendedBase
%matplotlib inline



## Examples

Let's create an example class that inherits from the `Base` class, which contains much of the low-level functionality. For this we will have some nested classes with various parameters. Let's instantiate these and have a look.

In [2]:
# Example class
class Variances(Base):
    var_x: float
    var_y: float
    some_list: list
    some_dict: dict

    def __init__(self, var_x, var_y, some_list, some_dict):
        self.var_x = var_x
        self.var_y = var_y
        self.some_list = some_list
        self.some_dict = some_dict

# Example class
class SuperGaussian(Base):
    variances: object
    power: float

    def __init__(self, variances, power):
        self.variances = variances
        self.power = power
        
# Create an instance of the SuperGaussian object
var_x, var_y = 10, 10
power = 1
some_list = [-1, -2]
some_dict = {'a': 'foo', 'b': 'bar'}

# Create the object
variances = Variances(var_x, var_y, some_list, some_dict)
pytree = SuperGaussian(variances, power)

# Examine the object
print(pytree)

SuperGaussian(
  variances=Variances(
    var_x=10,
    var_y=10,
    some_list=[-1, -2],
    some_dict={'a': 'foo', 'b': 'bar'}
  ),
  power=1
)


Nice! Here we have a nested structure, so to look at some of these class methods, we first need to understand the 'path' object.

## The `path` object

A `path` is simply a string that refers to some place in a pytree, with nested structures connected with dots '.', similar to accessing class attributes. Some example paths for our example pytree would look like this:

 - 'variances.var_x'
 - 'power'
 - 'variances.some_list.0'
 - 'variances.some_dict.a'
 - 'variances.some_dict'

Each of these paths refers to some place in the pytree, not necessarily a leaf.
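
One way to picture how a dotted path could be followed is the sketch below. The `resolve` helper and the `SimpleNamespace` stand-in are hypothetical and written only for illustration; this is not how ∂Lux implements it internally.

```python
from types import SimpleNamespace


def resolve(tree, path):
    """Follow a dot-separated path through attributes, dict keys and list indices."""
    node = tree
    for part in path.split("."):
        if isinstance(node, dict):
            node = node[part]
        elif isinstance(node, (list, tuple)):
            node = node[int(part)]
        else:
            node = getattr(node, part)
    return node


# Stand-in for the nested example object (not a real dLux class)
tree = SimpleNamespace(
    power=1,
    variances=SimpleNamespace(
        var_x=10, some_list=[-1, -2], some_dict={"a": "foo", "b": "bar"}
    ),
)

print(resolve(tree, "variances.var_x"))
print(resolve(tree, "variances.some_list.0"))
print(resolve(tree, "variances.some_dict.a"))
```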

---
## New Methods

We have built a series of methods to operate on the parts of the pytree that these paths refer to, mirroring the `.at[]` interface of `jax.numpy` arrays:

 - `.get()`
 - `.set()`
 - `.add()`
 - `.multiply()`
 - `.divide()`
 - `.power()`
 - `.min()`
 - `.max()`
 - `.apply()`
 - `.apply_args()`
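
For reference, this is the `jax.numpy` array interface that the method names echo. The snippet uses plain JAX only and shows a handful of the `.at[]` operations; each one returns a new array.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])

# Each call returns a new array; `x` itself is never modified
x_set = x.at[0].set(10.0)       # replace a value
x_add = x.at[1].add(5.0)        # add to a value
x_mul = x.at[2].multiply(2.0)   # scale a value
x_min = x.at[2].min(0.5)        # elementwise minimum with 0.5
x_max = x.at[0].max(4.0)        # elementwise maximum with 4.0
x_get = x.at[1].get()           # read a value

print(x, x_set, x_add, x_mul, x_min, x_max, x_get)
```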

---

### `.get()`

In [3]:
print(pytree.get('variances.var_x'))
print(pytree.get('variances.some_list.0'))
print(pytree.get('variances.some_dict'))

10
-1
{'a': 'foo', 'b': 'bar'}


In [4]:
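# Examine the output gaussian
# NOTE: this version of SuperGaussian does not define `.model()` (it is added
# in the Equinox section below), so this cell raises the error shown
gauss = pytree.model()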

# Plot
plt.imshow(gauss)
plt.colorbar()
plt.show()

AttributeError: 'SuperGaussian' object has no attribute 'model'

---
---

## Accessor Methods

Accessors:
>
> - .get(path, pmap=dict)
>
> - .get(paths, pmap=dict)

This method takes in either a single path or a list of paths and returns the corresponding attributes!

Let's define some paths and check that everything works!

In [None]:
# Define paths
path1 = "power"
path2 = "variances.var_x"
path3 = "variances.some_list.1"
paths = [path1, path2, path3]

# Access single objects using .get()
print(pytree.get(path1))
print(pytree.get(path2))
print(pytree.get(path3))

# Access multiple objects at once by passing a list of paths
print(pytree.get(paths))

Great! Simple enough, now let's move on to the Updater methods.

---
---

## Updater Methods

> - .set(paths, values, pmap=dict)

This returns an updated version of the pytree, with the values placed at the corresponding paths.

> - .apply(paths, fns, pmap=dict)

This returns an updated version of the pytree, with the corresponding function applied to the values specified by the paths.


Note that these methods perform no checks and do not preserve data type. If you pass in the wrong data type you will very likely break your code. For example, if you pass in a list instead of a jax array, no errors will be thrown until some other part of the downstream code expects a jax array in its place. Be careful!
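
The failure mode described above can be reproduced in plain Python. In this sketch the `Model` class is a hypothetical stand-in (not a ∂Lux class): swapping a float for a list succeeds silently, and the error only appears later when the value is used.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Model:  # hypothetical stand-in, for illustration only
    power: float

    def evaluate(self, x):
        return x ** self.power


good = Model(power=2.0)

# The type annotation is not enforced, so this "update" raises nothing
bad = replace(good, power=[2.0])

# The problem only surfaces downstream, when the value is actually used
try:
    bad.evaluate(3.0)
    failed = False
except TypeError:
    failed = True

print(good.evaluate(3.0), failed)
```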

In [None]:
# Define paths
path1 = "power"
path2 = "variances.var_x"
path3 = "variances.some_list.1"
paths = [path1, path2, path3]

value = [-10]
values = [1e2, 1e3, 1e4]

print(pytree.set(path1, value))
print(pytree.set(paths, values))

In [None]:
# Define paths
path1 = "power"
path2 = "variances.var_x"
path3 = "variances.some_list.1"
paths = [path1, path2, path3]

fn = lambda x: 5 * x
fns = [lambda x: -x, lambda x: 1e2 * x, lambda x: x + 5]

print(pytree.apply(path1, fn))
print(pytree.apply(paths, fns))

---
### Nesting

So now is a good time to introduce the nesting concept. Let's say we wanted to update multiple parameters with the *same* value. We can achieve this simply by nesting our paths within each other! Each value will be applied to every path in the corresponding nested list!

This also works with the `.apply()` method.
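
The pairing rule can be sketched in a few lines of plain Python. The `expand` helper is hypothetical and only illustrates how a nested list of paths lines up with a flat list of values; it is not taken from the ∂Lux source.

```python
def expand(paths, values):
    """Pair each value with every path in the matching (possibly nested) entry."""
    pairs = []
    for entry, value in zip(paths, values):
        # A bare string is treated as a group containing one path
        group = entry if isinstance(entry, list) else [entry]
        pairs.extend((path, value) for path in group)
    return pairs


paths = [["power", "variances.var_x"], "variances.var_y"]
pairs = expand(paths, [-1e2, 1e4])
print(pairs)
```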

In [None]:
# Define paths
path1 = "power"
path2 = "variances.var_x"
path3 = "variances.some_list.1"

# Nested paths structure
paths = [[path1, path2], path3]
print(pytree.set(paths, [-1e2, 1e4]))

---

## Path dictionary

The path dictionary (the `pmap` argument) is a way to further simplify our interface with PyTrees. By defining the paths we care about inside the dictionary, we can use simple keys to refer to those leaves! This is especially useful for highly nested structures or leaves that we want to refer to many times. Note that we don't have to define *every* path to *every* leaf inside the dictionary; we can use a mix of paths and keys to refer to objects. We can also use the nesting concept with keys and paths interchangeably.

Note: The `pmap` keys MUST NOT match any of the parameter names within any of the classes or sub-classes, or the methods will break, i.e. each key must be uniquely named from all parameters!
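
A rough mental model of the key lookup, and of why clashing names are a problem, is sketched below. The `to_path` helper is hypothetical, and the lookup rule it uses is an assumption made for illustration rather than the actual ∂Lux logic.

```python
def to_path(key_or_path, pmap=None):
    """Return the full path for a key, or the input unchanged if it is not a key."""
    if pmap is not None and key_or_path in pmap:
        return pmap[key_or_path]
    return key_or_path


pmap = {"pow": "power", "xvar": "variances.var_x"}

# Keys and full paths can be mixed freely
resolved = [to_path(p, pmap) for p in ["pow", "variances.var_y", "xvar"]]
print(resolved)

# A key that shares its name with a real parameter becomes ambiguous:
# "power" no longer reaches the `power` attribute
clashing = {"power": "variances.var_x"}
print(to_path("power", clashing))
```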

In [None]:
pmap = {
    "pow":  "power",
    "xvar": "variances.var_x",
    "yvar": "variances.var_y",
}

print(pytree.get("pow", pmap=pmap))
print(pytree.get(["pow", "xvar"], pmap=pmap))

The methods have also been built to be flexible in the way that you pass in the path objects.

 1. A single list of keys is understood as referencing multiple leaves, rather than a single path
 2. Nested lists of keys work identically to nested paths
 3. Single keys do not *need* to be wrapped in lists
 
Let's see those in action, using each of these cases to access the same parameters.

In [None]:
# 1
print(pytree.get(["yvar", "pow"], pmap=pmap))

# 2
print(pytree.get([["variances.var_y"], "power"], pmap=pmap))

# 3
print(pytree.get([["variances.var_y"], "pow"], pmap=pmap))

---
---

## Interfacing Functions!


### Equinox filter function interface!

> .get_args(paths, pmap=dict)

This takes in a list of paths and returns a filter_spec ready to be passed straight into any Equinox filter function!
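
For intuition, an Equinox filter_spec is a pytree of booleans with the same structure as the model, with `True` marking the leaves to operate on. The sketch below builds one for a plain nested dict; `make_filter_spec` is a hypothetical helper written for this illustration, not the ∂Lux implementation.

```python
def make_filter_spec(tree, paths):
    """Return a tree of booleans shaped like `tree`, True only at `paths`."""

    def build(node, prefix):
        if isinstance(node, dict):
            return {
                key: build(value, f"{prefix}.{key}" if prefix else key)
                for key, value in node.items()
            }
        # A leaf is selected only if its full path was requested
        return prefix in paths

    return build(tree, "")


# Nested dict standing in for the model (illustrative values)
tree = {"variances": {"var_x": 10.0, "var_y": 10.0}, "power": 1.0}
spec = make_filter_spec(tree, ["variances.var_x", "power"])
print(spec)
```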

Let's see how we can use this to optimise a model using Equinox alone.

In [None]:
from dLux.base import ExtendedBase

In [None]:
class Variances(ExtendedBase):
    var_x: float
    var_y: float
    useless_list: list

    def __init__(self, var_x, var_y, useless_list):
        self.var_x = var_x
        self.var_y = var_y
        self.useless_list = useless_list


class SuperGaussian(ExtendedBase):
    variances: object
    power: float

    def __init__(self, variances, power):
        self.variances = variances
        self.power = power

    def model(self, flatten=False):
        xs = np.linspace(-50, 50, 100)
        XX, YY = np.meshgrid(xs, xs)

        x = (XX / self.variances.var_x) ** 2
        y = (YY / self.variances.var_y) ** 2

        g = np.exp(-((x + y) ** self.power))

        if flatten:
            return g.flatten()
        else:
            return g

In [None]:
# Create an instance of the SuperGaussian object
var_x, var_y = 10, 10
power = 1
useless_list = [-1, -2]

# Create the object
variances = Variances(var_x, var_y, useless_list)
pytree = SuperGaussian(variances, power)

# Examine the object
print(pytree)

# Examine the output gaussian
gauss = pytree.model()

# Plot
plt.imshow(gauss)
plt.colorbar()
plt.show()

In [None]:
import equinox as eqx
from tqdm.notebook import tqdm

In [None]:
# Define paths to the variables we care about, and new values
paths = ["xvar", "yvar", "pow"]
new_values = [np.array(15.0), np.array(5.0), np.array(1.5)]

# Get a new pytree to optimise
model_pytree = pytree.set(paths, new_values, pmap=pmap)

# Generate a filter_spec to pass to equinox filter functions
filter_spec = model_pytree.get_args(paths, pmap=pmap)

In [None]:
# Define the loss function
@eqx.filter_jit()
@eqx.filter_value_and_grad(arg=filter_spec)
def loss_fn(model, data):
    return np.sum((model.model() - data) ** 2)

In [None]:
# Make some fake data and evaluate loss
fake_data = pytree.model()
loss, grads = loss_fn(model_pytree, fake_data)
print(loss, grads.variances.var_x, grads.variances.var_y, grads.power)
print(grads)

In [None]:
# Define a basic step function
# (jax.tree_util.tree_map is used because the jax.tree_map alias is deprecated)
get_step = lambda grads, lr: jax.tree_util.tree_map(lambda leaf: -lr * leaf, grads)

# Optimise the model
for i in tqdm(range(500)):
    loss, grads = loss_fn(model_pytree, fake_data)
    model_pytree = eqx.apply_updates(model_pytree, get_step(grads, 1e-2))

# Print the final values to check that everything works
(model_pytree.variances.var_x, model_pytree.variances.var_y, model_pytree.power)

Awesome! As we can see, we were able to recover our true parameters!
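
The same update pattern, stripped of ∂Lux and Equinox, looks like this on a plain dictionary pytree. The quadratic `loss_fn` and its parameters are made up for illustration; only `jax.grad` and `jax.tree_util.tree_map` are used.

```python
import jax
import jax.numpy as jnp


def loss_fn(params):
    # Quadratic bowl with its minimum at w = 3, b = -1
    return (params["w"] - 3.0) ** 2 + (params["b"] + 1.0) ** 2


params = {"w": jnp.array(0.0), "b": jnp.array(0.0)}
lr = 0.1

for _ in range(100):
    grads = jax.grad(loss_fn)(params)
    # Scale every gradient leaf by -lr to get the step
    step = jax.tree_util.tree_map(lambda g: -lr * g, grads)
    # Add the step to the parameters, leaf by leaf
    params = jax.tree_util.tree_map(lambda p, s: p + s, params, step)

print(params)
```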

---

### Optax param_spec interface!

So next we want to be able to actually optimise a model using optax, so we need to define a param_spec!

> .get_param_spec(paths, groups, pmap=dict, get_args=bool)

This function lets us generate a param_spec in order to group parameters and apply optimisers to them. We could pass in the filter_spec from before, or we could use the inbuilt functionality that returns the correct filter_spec for the given param_spec.

Let's group the two variances together, put the power in its own group, and see how we go.
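
For intuition, `optax.multi_transform` expects a pytree of labels matching the structure of the parameters, plus a dictionary mapping each label to an optimiser. The sketch below builds such a label tree for a plain nested dict. `make_param_spec` is a hypothetical helper, and the `"null"` default label is an assumption based on the optimiser dictionary used later in this tutorial.

```python
def make_param_spec(tree, paths, groups, default="null"):
    """Return a tree of group labels shaped like `tree`."""
    # Map every path to its group, expanding nested path lists
    labels = {}
    for entry, group in zip(paths, groups):
        for path in entry if isinstance(entry, list) else [entry]:
            labels[path] = group

    def build(node, prefix):
        if isinstance(node, dict):
            return {
                key: build(value, f"{prefix}.{key}" if prefix else key)
                for key, value in node.items()
            }
        # Leaves that were not assigned a group get the default label
        return labels.get(prefix, default)

    return build(tree, "")


# Nested dict standing in for the model (illustrative values)
tree = {"variances": {"var_x": 15.0, "var_y": 5.0, "other": 0.0}, "power": 1.5}
spec = make_param_spec(
    tree, [["variances.var_x", "variances.var_y"], "power"], ["var", "pow"]
)
print(spec)
```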

In [None]:
import optax

In [None]:
# Define paths to the variables we care about, and new values
paths = ["xvar", "yvar", "pow"]
new_values = [np.array(15.0), np.array(5.0), np.array(1.5)]

# Get a new pytree to update
model_pytree = pytree.set(paths, new_values, pmap=pmap)

In [None]:
# Define the parameter groups for the param spec
# Use the nested path functionality to group the variances together!
paths = [["xvar", "yvar"], "pow"]
groups = ["var", "pow"]
param_spec, filter_spec = model_pytree.get_param_spec(
    paths, groups, get_args=True, pmap=pmap
)

print(param_spec)
print(filter_spec)

In [None]:
#  Define Learning rates
var_lr = 1e0
pow_lr = 1e-1

# Use the generated param spec to map optimisers
# Be sure to match the values defined in 'groups'!
optim = optax.multi_transform(
    {"null": optax.adam(0.0), "var": optax.adam(var_lr), "pow": optax.adam(pow_lr)},
    param_spec,
)

# Initialise & optimise for a single epoch
opt_state = optim.init(model_pytree)
for i in tqdm(range(100)):
    loss, grads = loss_fn(model_pytree, fake_data)
    updates, opt_state = optim.update(grads, opt_state)
    model_pytree = eqx.apply_updates(model_pytree, updates)

# Print the final values to check that everything works
(model_pytree.variances.var_x, model_pytree.variances.var_y, model_pytree.power)

Great, it all works!

---

### Optax optimiser interface!

> .get_optimiser(paths, optimisers, get_args=bool, pmap=dict)

So in most use-cases, we can avoid the need to interact with the optax.multi_transform function altogether, allowing us to *only* define the optimisers we wish to apply to each parameter and group. Let's have a look at how we can do that!

In [None]:
# Define paths to the variables we care about, and new values
paths = ["xvar", "yvar", "pow"]
new_values = [np.array(15.0), np.array(5.0), np.array(1.5)]

# Get a new pytree to update
model_pytree = pytree.set(paths, new_values, pmap=pmap)

In [None]:
# Define paths and corresponding optimisers
paths = [["xvar", "yvar"], "pow"]
optimisers = [optax.adam(1e-1), optax.adam(1e-2)]

# Get optimiser and filter_spec
optim, fs = pytree.get_optimiser(
    paths, optimisers, get_args=True, pmap=pmap
)

# Initialise & Optimise
opt_state = optim.init(model_pytree)

for i in tqdm(range(100)):
    loss, grads = loss_fn(model_pytree, fake_data)
    updates, opt_state = optim.update(grads, opt_state)
    model_pytree = eqx.apply_updates(model_pytree, updates)

# Print the final values to check that everything works
(model_pytree.variances.var_x, model_pytree.variances.var_y, model_pytree.power)

How easy was that!

---

### Numpyro interface!

The last package we want to be able to interact with easily is Numpyro, so we can run MCMCs!

> .update_and_model(model_fn, paths, values, pmap=dict, *args, **kwargs)

So for this method the paths, values and pmap should all be familiar by now. The difference here is that we must also specify which method is used to generate our model. This is done using a string that names the method. Similarly, if we need to pass extra arguments into the modelling function, we can do that with the *args and **kwargs. I will show how to pass in keyword arguments here!
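
As a rough picture of how a method can be selected by a string and handed extra arguments, here is a plain-Python sketch. The `Toy` class and `call_by_name` helper are hypothetical, and using `getattr` is an assumption about the mechanism, not a description of the ∂Lux source.

```python
class Toy:  # hypothetical class, for illustration only
    def __init__(self, scale):
        self.scale = scale

    def model(self, flatten=False):
        grid = [[self.scale * 1, self.scale * 2], [self.scale * 3, self.scale * 4]]
        if flatten:
            return [value for row in grid for value in row]
        return grid


def call_by_name(obj, fn_name, *args, **kwargs):
    """Look up a method from its name and forward any extra arguments."""
    return getattr(obj, fn_name)(*args, **kwargs)


toy = Toy(scale=2)
print(call_by_name(toy, "model"))
print(call_by_name(toy, "model", flatten=True))
```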

For those who haven't used Numpyro, you need to define a modelling function with all of the parameters you wish to sample. This minimal example should give you a good idea of what to do, but for a more in-depth exploration of its functionality and behaviour, check out [this great tutorial](https://dfm.io/posts/intro-to-numpyro/)

In [None]:
import numpyro as npy
import numpyro.distributions as dist
import jax.random as jr
import chainconsumer as cc

In [None]:
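def modelling_fn(data, model, pmap=None):
    """Numpyro model: sample the three parameters and compare the resulting
    flattened model image against the data."""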
    # Define parameter sampling
    var_x = npy.sample("x variance", dist.Uniform(0, 100))
    var_y = npy.sample("y variance", dist.Uniform(0, 100))
    power = npy.sample("power", dist.Uniform(0, 10))

    # Define paths and values
    paths = ["xvar", "yvar", "pow"]
    values = [var_x, var_y, power]

    with npy.plate("data", len(data)):
        # Normal likelihood (unit scale by default) centred on the model image
        likelihood = dist.Normal(
            model.update_and_model(
                "model", paths, values, pmap=pmap, flatten=True
            )
        )

        return npy.sample("super-gaussian", likelihood, obs=data)


# This has not yet been correctly configured for the mkdocs framework, but will 
# be at some time in the future
# graph = npy.render_model(
#     modelling_fn, model_args=(fake_data.flatten(), model_pytree, pmap)
# )

In [None]:
# Using the model above, we can now sample from the posterior distribution
# using the No U-Turn Sampler (NUTS).
sampler = npy.infer.MCMC(
    npy.infer.NUTS(modelling_fn),
    num_warmup=2000,
    num_samples=2000,
    progress_bar=True,
)
%time sampler.run(jr.PRNGKey(0), fake_data.flatten(), model_pytree, pmap)

In [None]:
sampler.print_summary()
values_out = sampler.get_samples()

In [None]:
chain = cc.ChainConsumer()
chain.add_chain(values_out)
chain.configure(
    serif=True, shade=True, bar_shade=True, shade_alpha=0.2, spacing=1.0, max_ticks=3
)
fig = chain.plotter.plot(truth={"power": 1, "x variance": 10, "y variance": 10})
fig.set_size_inches((15, 15));