# The `Base` Class

The `Base` class is the main module of `Zodiax`, designed to make working with pytrees simpler. If you haven't already, please read the [PyTree overview](https://louisdesdoigts.github.io/zodiax/pytree/) and the `Zodiax` overview [here](https://louisdesdoigts.github.io/zodiax/usage/). The `Base` class implements 10 methods:

- `.get(path)` - get the value of a leaf
- `.set(path, value)` - set the value of a leaf
- `.add(path, value)` - add a value to a leaf
- `.multiply(path, value)` - multiply a leaf by a value
- `.divide(path, value)` - divide a leaf by a value
- `.power(path, value)` - raise a leaf to a power
- `.min(path, value)` - take the minimum of a leaf and value
- `.max(path, value)` - take the maximum of a leaf and value
- `.apply(path, fn)` - apply the function to the leaf
- `.apply_args(path, fn, args)` - apply the function to the leaf, passing in the extra arguments

Let's look at some examples of how to use these methods:

```python
from zodiax import Base
import jax.numpy as np
```

```python
# Example class
class Variances(Base):
    var_x: float
    var_y: float
    some_list: list
    some_dict: dict

    def __init__(self, var_x, var_y, some_list, some_dict):
        self.var_x = var_x
        self.var_y = var_y
        self.some_list = some_list
        self.some_dict = some_dict

# Example class
class SuperGaussian(Base):
    variances: object
    power: float

    def __init__(self, variances, power):
        self.variances = variances
        self.power = power

# Define the parameters of the SuperGaussian object
var_x, var_y = 10, 10
power = 1
some_list = [-1, -2]
some_dict = {'a': 'foo', 'b': 'bar'}

# Create the object
variances = Variances(var_x, var_y, some_list, some_dict)
pytree = SuperGaussian(variances, power)

# Examine the object
print(pytree)
```

```
SuperGaussian(
  variances=Variances(
    var_x=10,
    var_y=10,
    some_list=[-1, -2],
    some_dict={'a': 'foo', 'b': 'bar'}
  ),
  power=1
)
```


Nice! Here we have a nested structure, so before looking at these methods we first need to understand the `path` object.

## The `path` object

A `path` is simply a string that refers to some place in a pytree. Each level of nesting is separated by a dot `.`, similar to accessing class attributes, and list indices and dictionary keys are written the same way. Some example paths for our example pytree are:

 - 'variances.var_x'
 - 'power'
 - 'variances.some_list.0'
 - 'variances.some_dict.a'
 - 'variances.some_dict'

Each of these paths refers to some place in the pytree, not necessarily a leaf. For example, `'variances.some_dict'` refers to a whole dictionary. Now let's define some paths and look at the `.get(path)` method:

```python
path1 = 'variances.var_x'
path2 = 'variances.some_list.0'
path3 = 'variances.some_dict'

# Get individual items
print(pytree.get(path1))
print(pytree.get(path2))
print(pytree.get(path3))

# Get list of items
print(pytree.get([path1, path2, path3]))
```

```
10
-1
{'a': 'foo', 'b': 'bar'}
[10, -1, {'a': 'foo', 'b': 'bar'}]
```
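Conceptually, `.get` walks the path one key at a time. The snippet below is a minimal pure-Python sketch of that idea, not Zodiax's actual implementation. The helper `resolve_path` is hypothetical, it assumes dictionary keys are strings and list indices are written as integers, and `types.SimpleNamespace` stands in for the `Variances` and `SuperGaussian` classes.

```python
from types import SimpleNamespace

def resolve_path(tree, path):
    """Hypothetical helper: follow a dot-separated path through nested objects."""
    node = tree
    for key in path.split("."):
        if isinstance(node, dict):
            node = node[key]            # dictionary entry, e.g. 'some_dict.a'
        elif isinstance(node, (list, tuple)):
            node = node[int(key)]       # sequence index, e.g. 'some_list.0'
        else:
            node = getattr(node, key)   # class attribute, e.g. 'variances'
    return node

# Plain-Python stand-in for the SuperGaussian pytree defined above
variances = SimpleNamespace(var_x=10, var_y=10, some_list=[-1, -2],
                            some_dict={"a": "foo", "b": "bar"})
tree = SimpleNamespace(variances=variances, power=1)

print(resolve_path(tree, "variances.var_x"))        # 10
print(resolve_path(tree, "variances.some_list.0"))  # -1
print(resolve_path(tree, "variances.some_dict.a"))  # foo
```

Each type of node is handled by a different lookup (item access, integer index or attribute), which is why the same dotted syntax works across classes, lists and dictionaries.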


Now we can use these paths with the `.set(path, value)` method to create an updated version of the pytree.

This method takes a path and a value, and returns a pytree with the leaf or subtree specified by the path replaced by `value`. Similarly, we can pass in a list of paths and a list of values, and all of the corresponding parameters will be updated!

```python
value1 = 100
value2 = -100
value3 = {'a': 'FOO', 'b': 'BAR'}

print(pytree.set([path1, path2, path3], [value1, value2, value3]))
```

```
SuperGaussian(
  variances=Variances(
    var_x=100,
    var_y=10,
    some_list=[-100, -2],
    some_dict={'a': 'FOO', 'b': 'BAR'}
  ),
  power=1
)
```
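Because `.set` returns an updated pytree rather than editing the existing one in place, `pytree` itself still holds its original values afterwards. The later `.add` example, which starts from `var_x=10`, relies on this. The sketch below mimics that copy-on-write behaviour on plain nested dictionaries and lists. It is illustrative only and is not how Zodiax works internally: the helper `set_path` is hypothetical and only handles `dict` and `list` nodes.

```python
import copy

def set_path(tree, path, value):
    """Hypothetical helper: return a copy of `tree` with the node at `path` replaced."""
    new_tree = copy.deepcopy(tree)  # work on a copy so the input is never modified
    *parents, last = path.split(".")
    node = new_tree
    for key in parents:
        node = node[int(key)] if isinstance(node, list) else node[key]
    if isinstance(node, list):
        node[int(last)] = value     # replace a list element by index
    else:
        node[last] = value          # replace a dictionary entry by key
    return new_tree

tree = {"variances": {"var_x": 10, "some_list": [-1, -2]}, "power": 1}
updated = set_path(tree, "variances.some_list.0", -100)

print(updated["variances"]["some_list"])  # [-100, -2]
print(tree["variances"]["some_list"])     # [-1, -2]  (original unchanged)
```

Copying the whole structure keeps the sketch short. A more efficient version would rebuild only the containers along the path and share everything else.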


We can also pass a single value to update multiple parameters at once, or use 'nesting' to group paths that should share the same value!

```python
# Assign all paths to zero
print(pytree.set([path1, path2, path3], 0))

# Assign nested paths
print(pytree.set([path1, [path2, path3]], [0, 100]))
```

```
SuperGaussian(
  variances=Variances(var_x=0, var_y=10, some_list=[0, -2], some_dict=0),
  power=1
)
SuperGaussian(
  variances=Variances(var_x=0, var_y=10, some_list=[100, -2], some_dict=100),
  power=1
)
```
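One way to read the two calls above is that every update is first expanded into `(path, value)` pairs. A single value is shared by all paths, and a nested list of paths shares the value in the matching position. The hypothetical `expand_updates` function below sketches that rule in plain Python. It is an assumption inferred from the outputs shown, not Zodiax's code.

```python
def expand_updates(paths, values):
    """Hypothetical helper: pair each path with the value it should receive."""
    if isinstance(paths, str):            # a single path takes the value as-is
        return [(paths, values)]
    if not isinstance(values, list):      # one value broadcast to every path
        values = [values] * len(paths)
    if len(paths) != len(values):
        raise ValueError("paths and values must have the same length")
    pairs = []
    for path, value in zip(paths, values):
        pairs.extend(expand_updates(path, value))  # recurse into nested groups
    return pairs

print(expand_updates(["a", "b", "c"], 0))
# [('a', 0), ('b', 0), ('c', 0)]
print(expand_updates(["a", ["b", "c"]], [0, 100]))
# [('a', 0), ('b', 100), ('c', 100)]
```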


You might notice that `some_dict`, which should presumably be a dictionary, is no longer a dictionary! This is because these methods do *not* do any type checking at all: if you pass in the wrong data type, it will still be assigned to the leaf. This is an important caveat to be aware of. If you accidentally pass in a list, or a JAX array with the wrong dimensionality, the later behaviour of the object cannot be guaranteed!

We can also use the `.add(path, value)` method to add a value to a leaf. As described, this works in the same manner as the `.set()` method, except that it adds `value` to the leaf specified by the path rather than replacing it!

```python
# Add different values
print(pytree.add([path1, path2], [1e3, -1e3]))

# Add the same value
print(pytree.add([path1, path2], 1e3))
```

```
SuperGaussian(
  variances=Variances(
    var_x=1010.0,
    var_y=10,
    some_list=[-1001.0, -2],
    some_dict={'a': 'foo', 'b': 'bar'}
  ),
  power=1
)
SuperGaussian(
  variances=Variances(
    var_x=1010.0,
    var_y=10,
    some_list=[999.0, -2],
    some_dict={'a': 'foo', 'b': 'bar'}
  ),
  power=1
)
```
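The arithmetic methods can all be read as the same three steps: `get` the current leaf, combine it with `value`, then `set` the result. The pure-Python sketch below illustrates that pattern on a flat dictionary. The `OPERATIONS` table and the `update_leaf` helper are hypothetical names for illustration. Using `operator.add`, `operator.truediv` and the other operators as stand-ins for each method is an assumption, not a description of Zodiax's internals. The built-in `min`/`max` used here only handle scalars; for arrays you would want elementwise `np.minimum`/`np.maximum`.

```python
import operator

# Hypothetical mapping from method name to the binary operation it applies
OPERATIONS = {
    "add": operator.add,
    "multiply": operator.mul,
    "divide": operator.truediv,
    "power": operator.pow,
    "min": min,   # scalar-only stand-in
    "max": max,   # scalar-only stand-in
}

def update_leaf(tree, key, method, value):
    """Hypothetical helper: get -> combine -> set, returning a new dict."""
    current = tree[key]                              # get
    new_value = OPERATIONS[method](current, value)   # combine
    return {**tree, key: new_value}                  # set on a copy; input untouched

leaves = {"var_x": 10, "power": 1}
print(update_leaf(leaves, "var_x", "add", 1e3))   # {'var_x': 1010.0, 'power': 1}
print(update_leaf(leaves, "var_x", "divide", 4))  # {'var_x': 2.5, 'power': 1}
print(update_leaf(leaves, "var_x", "max", 50))    # {'var_x': 50, 'power': 1}
```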


The rest of the methods work in a similar manner, so I will not go into detail about them here. The exceptions are the `.apply(path, fn)` and `.apply_args(path, fn, args)` methods, which take a function instead of a value. Let's define some functions and have a look:

```python
# Square function
def square(x):
    return x ** 2

# Apply function
print(pytree.apply([path1, path2], square))

# Cube function
cube = lambda x: x ** 3

# Apply function
print(pytree.apply([path1, path2], cube))

# Log function (returns JAX arrays, printed below as f32[] scalars)
log = np.log10

# Apply function
print(pytree.apply([path1, path2], log))
```

```
SuperGaussian(
  variances=Variances(
    var_x=100,
    var_y=10,
    some_list=[1, -2],
    some_dict={'a': 'foo', 'b': 'bar'}
  ),
  power=1
)
SuperGaussian(
  variances=Variances(
    var_x=1000,
    var_y=10,
    some_list=[-1, -2],
    some_dict={'a': 'foo', 'b': 'bar'}
  ),
  power=1
)
SuperGaussian(
  variances=Variances(
    var_x=f32[],
    var_y=10,
    some_list=[f32[], -2],
    some_dict={'a': 'foo', 'b': 'bar'}
  ),
  power=1
)
```


The `.apply_args(path, fn, args)` method can be used to pass in extra arguments if we have a more complex function! The values in `args` are passed to the function after the leaf itself:

```python
# Function that needs extra arguments
def mult_and_power(x, mult, power):
    return (x * mult) ** power

# Apply function, passing mult=2 and power=3
print(pytree.apply_args([path1, path2], mult_and_power, (2, 3)))
```

```
SuperGaussian(
  variances=Variances(
    var_x=8000,
    var_y=10,
    some_list=[-8, -2],
    some_dict={'a': 'foo', 'b': 'bar'}
  ),
  power=1
)
```
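Passing `args` is essentially a convenience. We would expect `apply_args(path, fn, args)` to give the same result as `apply` with a function that has those arguments baked in, for example via `functools.partial` or a lambda. The check below verifies that equivalence in plain Python for the two leaves updated above (10 and -1). The `apply_leaf` and `apply_args_leaf` helpers are hypothetical single-value stand-ins, not Zodiax methods.

```python
from functools import partial

def mult_and_power(x, mult, power):
    return (x * mult) ** power

def apply_leaf(x, fn):
    """Hypothetical single-leaf stand-in for `.apply`."""
    return fn(x)

def apply_args_leaf(x, fn, args):
    """Hypothetical single-leaf stand-in for `.apply_args`: extra args follow the leaf."""
    return fn(x, *args)

for leaf in (10, -1):
    via_args = apply_args_leaf(leaf, mult_and_power, (2, 3))
    via_partial = apply_leaf(leaf, partial(mult_and_power, mult=2, power=3))
    via_lambda = apply_leaf(leaf, lambda x: mult_and_power(x, 2, 3))
    assert via_args == via_partial == via_lambda
    print(leaf, "->", via_args)
# 10 -> 8000
# -1 -> -8
```

With Zodiax itself, the equivalent call would presumably be `pytree.apply([path1, path2], partial(mult_and_power, mult=2, power=3))`. `apply_args` simply saves you from building the partial function yourself.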
