# Advanced Guide to `caskade`

The beginner's guide laid out the basics of constructing simulators in `caskade`. Here we present the more powerful capabilities and techniques that let you perform complex analyses easily and efficiently. The techniques are in no particular order, so you may search for points of interest or scan through for relevant sections.

In [None]:
import torch
import numpy as np
import caskade as ckd
from time import time, sleep
import matplotlib.pyplot as plt
import h5py
from IPython.display import display

In [None]:
class Gaussian(ckd.Module):
    def __init__(self, name, x0=None, y0=None, q=None, phi=None, sigma=None, I0=None):
        super().__init__(name)
        self.x0 = ckd.Param("x0", x0) # position
        self.y0 = ckd.Param("y0", y0)
        self.q = ckd.Param("q", q) # axis ratio
        self.phi = ckd.Param("phi", phi) # orientation
        self.sigma = ckd.Param("sigma", sigma) # width
        self.I0 = ckd.Param("I0", I0) # intensity

    @ckd.forward
    def _r(self, x, y, x0=None, y0=None, q=None, phi=None):
        x, y = x - x0, y - y0
        s, c = torch.sin(phi), torch.cos(phi)
        x, y = c * x - s * y, s * x + c * y
        return (x ** 2 + (y * q) ** 2).sqrt()
    
    @ckd.forward
    def brightness(self, x, y, sigma=None, I0=None):
        return I0 * (-self._r(x, y)**2 / sigma**2).exp()
    
class Combined(ckd.Module):
    def __init__(self, name, first, second, ratio=0.5):
        super().__init__(name)
        self.first = first # Modules are automatically registered
        self.ratio = ckd.Param("ratio", ratio, valid=(0,1))
        self.second = second

    @ckd.forward
    def brightness(self, x, y, ratio):
        return ratio * self.first.brightness(x, y) + (1 - ratio) * self.second.brightness(x, y)

## Ways of accessing Param values

When running a simulation there are several ways to access the value of a `Param` object. Here is a mostly complete listing.

In [None]:
class TryParam(ckd.Module):
    def __init__(self, submod):
        super().__init__()
        self.x = ckd.Param("x", 1.0)
        self.y = ckd.Param("y", 2.0)
        self.submod = submod

    @ckd.forward
    def test_access(self, a, x, k=1, y=None):
        # Regular function arguments (a, k) are not caskade objects and so behave normally
        total = a
        total += k

        # Getting values from Param objects
        total += x ** 2 # as arg of function (preferred)
        total += y ** 2 # as kwarg of function (preferred)
        total += self.x.value ** 2 # by attribute (allowed but discouraged)
        total += self.submod.I0.value ** 2 # by attribute of submod (allowed but may indicate inefficient code)

        # Modifying values of Param objects
        x = 3.0 # locally modify param value (allowed)
        total += x ** 2 # use modified value, will not change the param value globally
        total += self.submod.brightness(0,0, sigma=2.0) # call module with modified param value, only affects this call (allowed)
        self.x.value = 4.0 # modify param value globally (explicitly forbidden)
        return total
    
G = Gaussian("G", x0=5, y0=5, q=0.5, phi=0.0, sigma=1.0, I0=1.0)
T = TryParam(G)

try:
    T.test_access(0.0)
except ckd.ActiveStateError as e:
    print("Caught ActiveStateError:", e)

# Outside a @forward function, we can still access param values like so:
print("x:", T.x.value)
# If a Param is a pointer, and you access the `value` it will try to evaluate the pointer
G.sigma = T.x
print("sigma:", G.sigma.value) # Basic pointer to another Param
G.sigma = lambda p: p.x.value * 2.0
G.sigma.link(T.x)
print("sigma:", G.sigma.value) # Function pointer

## Control dynamic vs static param

One of the most powerful features of `caskade` is its flexible system for switching which parameters are dynamic (involved in sampling/fitting) and which are static (fixed). This allows a single simulator object to perform many tasks with a uniform interface. Here we will see a few options for controlling this feature.

In [None]:
# All params initialized with a value
G1 = Gaussian("G1", x0=5, y0=5, q=0.5, phi=0.0, sigma=1.0, I0=1.0)
G2 = Gaussian("G2", x0=5, y0=5, q=0.5, phi=0.0, sigma=1.0, I0=1.0)
C = Combined("C", G1, G2)
print("All params are static automatically when given a value")
display(C.graphviz())

# Set individual param to dynamic
G1.x0.to_dynamic() # call function to set dynamic
G1.q = None # set to None to make dynamic
C.to_dynamic() # only sets immediate children to dynamic
print("Individual params can be set to dynamic")
display(C.graphviz())

# Set all simulator params to be dynamic
C.to_dynamic(local_only=False)
print("All params for the entire simulator may be set to dynamic")
display(C.graphviz())

# Even when set to dynamic, the params remember their original values
print("x0:", G1.x0.value)
G1.x0 = G1.x0.value # Setting value sets to static
G1.q.to_static() # Setting to static, uses the earlier value

# Setting any value will make it static
G1.I0 = 10.0 
print("Individual params can be set to static")
display(C.graphviz())

# Similarly a whole simulator can be set static
C.to_static(local_only=False)
print("All params for the entire simulator may be set to static")
display(C.graphviz())

# Use a param list to set multiple params to dynamic
paramset1 = ckd.NodeList([G1.x0, G1.q, G2.phi, G2.sigma])
paramset1.to_dynamic() # set all params in the list to dynamic
print("Use a NodeList to curate which params are set to dynamic/static")
display(C.graphviz())

# NOTE: trying to set a dynamic param to static when there is no stored value will throw an error
badparam = ckd.Param("badparam")
print("Blank param is dynamic: ", badparam.dynamic)
try:
    badparam.to_static()
except Exception as e:
    print(f"Caught error: {type(e)}: {e}")
    print("Param is still dynamic: ", badparam.dynamic)

## Call function with internally modified param value

A `caskade` simulator is often built from nested modules that call each other's functions. Sometimes one may wish to call a function with a different value for one of the Params than what was given in the input (for example when computing a reference for comparison). Here we show how to do this kind of local Param modification. This is also covered in [Ways of accessing Param values](#ways-of-accessing-param-values).

In [None]:
class TryModify(ckd.Module):
    def __init__(self, submod):
        super().__init__()
        self.submod = submod
        self.newval1 = torch.tensor(2.0)
        self.newval2 = torch.tensor(3.0)

    @ckd.forward
    def test_modify(self):
        init = self.submod.brightness(0,0) # call with original param values
        mod = self.submod.brightness(0,0, sigma=self.newval1) # call with modified param value
        with ckd.OverrideParam(self.submod.sigma, self.newval2):
            othermod = self.submod.brightness(0,0) # call with temporarily modified param value
        assert init != mod
        assert init != othermod
        assert mod != othermod
        print("See, they are all different!")
        return init, mod, othermod
    
G = Gaussian("G", x0=5, y0=5, q=0.5, phi=0.0, sigma=1.0, I0=1.0)
T = TryModify(G)
print(T.test_modify())
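
The `with` form above follows the general "temporarily override, then restore" pattern. As a rough illustration of that pattern only, here is a minimal pure-Python sketch. The `Box` class and `override` helper are hypothetical names invented for this example and are not part of `caskade`; how `ckd.OverrideParam` works internally is not shown here.

```python
from contextlib import contextmanager


class Box:
    """Hypothetical stand-in for a parameter: it only holds a value."""

    def __init__(self, value):
        self.value = value


@contextmanager
def override(box, new_value):
    """Temporarily replace box.value, restoring the old value on exit."""
    old_value = box.value
    box.value = new_value
    try:
        yield box
    finally:
        # Runs even if the body of the `with` block raises
        box.value = old_value


sigma = Box(1.0)
with override(sigma, 3.0):
    inside = sigma.value  # value seen inside the block
after = sigma.value  # original value is back

print(inside, after)  # 3.0 1.0
```

The `try`/`finally` is the important design choice: the original value comes back even when the code inside the block fails.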

## Reparametrize a Module

Sometimes it makes sense to write a module and its functions in one parametrization, while for some analyses or for user interpretation another parametrization is more natural. For example, a model may be easier to write in Cartesian coordinates even though polar coordinates are easier for users to interpret. The example below expresses `x0` and `y0` as functions of new `r` and `theta` params.

In [None]:
G = Gaussian("G", x0=5, y0=5, q=0.5, phi=0.0, sigma=1.0, I0=1.0) # default in cartesian coordinates
r = ckd.Param("r", 1.0) # radius
theta = ckd.Param("theta", 0.0) # angle
G.x0 = lambda p: p.r.value * torch.cos(p.theta.value)
G.x0.link(r)
G.x0.link(theta)
G.y0 = lambda p: p.r.value * torch.sin(p.theta.value)
G.y0.link(r)
G.y0.link(theta)

G.graphviz()
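
The pointer functions above encode the usual polar-to-Cartesian map, x0 = r cos(theta) and y0 = r sin(theta). If you want the reparametrized module to start at the same position as before, you need the inverse map to choose `r` and `theta`. The following is a small standalone sketch using only the standard library; the helper names are made up for illustration.

```python
import math


def cartesian_to_polar(x0, y0):
    """Return (r, theta) for a Cartesian position (x0, y0)."""
    return math.hypot(x0, y0), math.atan2(y0, x0)


def polar_to_cartesian(r, theta):
    """Return (x0, y0) for a polar position (r, theta)."""
    return r * math.cos(theta), r * math.sin(theta)


# Polar coordinates that reproduce the original position x0=5, y0=5
r, theta = cartesian_to_polar(5.0, 5.0)
x0, y0 = polar_to_cartesian(r, theta)

print(r, theta)  # roughly 7.07 and pi/4
print(x0, y0)  # round trip recovers (5, 5) up to floating point error
```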

## Save, Append, and Load the Param values

It is possible to save the state of the params in a `caskade` simulator in an HDF5 file. Once saved, one can append to the file to create a "chain" such as in MCMC sampling.

Note: it is also possible to store metadata in the HDF5 file. Simply add the metadata to the `.meta` attribute of any `caskade` node and it will be stored at the appropriate place in the graph. See the [Add meta data](#add-meta-data-to-a-param-or-module) section for how to do this.

In [None]:
# Recreate the gaussian in polar coordinates example
G = Gaussian("G", x0=5, y0=5, q=0.5, phi=0.0, sigma=1.0, I0=1.0)
r = ckd.Param("r", 1.0)
theta = ckd.Param("theta", 0.0)
G.x0 = lambda p: p.r.value * torch.cos(p.theta.value)
G.x0.link(r)
G.x0.link(theta)
G.y0 = lambda p: p.r.value * torch.sin(p.theta.value)
G.y0.link(r)
G.y0.link(theta)

# Run the "MCMC"
G.save_state("gauss_chain.h5", appendable=True) # save the initial state

# Pretend to run a sampling chain
for _ in range(100):
    G.x0.value += np.random.normal(0.01, 0.1)
    G.y0.value += np.random.normal(0.01, 0.1)
    G.q.value = np.clip(G.q.value + 0.1 * np.random.randn(), 0.1, 0.9)
    G.phi.value = (G.phi.value + 0.1 * np.random.randn()) % np.pi
    G.sigma.value += np.random.normal(0.1, 0.05)
    G.I0.value += np.random.normal(0.01, 0.5)

    G.append_state("gauss_chain.h5") # append the new state
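
For readers unfamiliar with how an HDF5 file can grow one sample at a time, here is a minimal `h5py` sketch of an appendable dataset. It is an assumption-laden illustration and not the layout `caskade` writes: the file name `toy_chain.h5` and the dataset name `value` are chosen only for this example.

```python
import os
import tempfile

import h5py
import numpy as np

# Hypothetical file in a temporary directory, used only for this sketch
path = os.path.join(tempfile.mkdtemp(), "toy_chain.h5")

# Create a dataset holding one sample; maxshape=(None,) makes axis 0 resizable
with h5py.File(path, "w") as f:
    f.create_dataset("value", data=np.array([1.0]), maxshape=(None,))

# Append samples one at a time by growing the first axis
for step in range(1, 5):
    with h5py.File(path, "a") as f:
        ds = f["value"]
        n = ds.shape[0]
        ds.resize((n + 1,))
        ds[n] = 1.0 + 0.1 * step

# Read the whole chain back as a NumPy array
with h5py.File(path, "r") as f:
    chain = f["value"][:]

print(chain.shape)  # (5,)
```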

In [None]:
# Now we can read the chain back in
fig, axarr = plt.subplots(6, 6, figsize=(12, 12))
with h5py.File("gauss_chain.h5", "r") as f: # Load the hdf5 file directly
    for i, ikey in enumerate(["x0", "y0", "q", "phi", "sigma", "I0"]):
        idata = f["G"][ikey]["value"] # access values for a given param
        for j, jkey in enumerate(["x0", "y0", "q", "phi", "sigma", "I0"]):
            jdata = f["G"][jkey]["value"] # access values for a given param
            if i < j:
                axarr[i,j].axis("off")
                continue
            elif i == j:
                axarr[i, j].hist(idata, bins=50, color="k")
                axarr[i, j].set_xlabel(ikey)
                axarr[i, j].set_ylabel("Counts")
            else:
                axarr[i, j].scatter(jdata, idata, s=2, color="k")
                # Label only the off-diagonal panels here so the histogram labels are kept
                axarr[i, j].set_xlabel(jkey)
                axarr[i, j].set_ylabel(ikey)
plt.show()

You can also simply load the state of a module from the HDF5 file.

In [None]:
G.load_state("gauss_chain.h5", 32) # Load the 32nd state from the chain

print("Loaded state 32:")
print(f"x0: {G.x0.value.item():.2f}") 
print(f"y0: {G.y0.value.item():.2f}")
print(f"q: {G.q.value.item():.2f}")
print(f"phi: {G.phi.value.item():.2f}")
print(f"sigma: {G.sigma.value.item():.2f}")
print(f"I0: {G.I0.value.item():.2f}")

## Add meta data to a Param or Module

Sometimes it is useful to carry extra data right next to your params. For example, you may want to keep track of the uncertainty of a param value. The best way to do this is to attach attributes to the `meta` container of a `Param`. This is essentially an empty class that you may build on however you like. Anything you do to this object is guaranteed not to interfere with `caskade` internals. Similarly, attributes created with the `meta_` prefix are guaranteed not to interfere with `caskade` internals.

In [None]:
p = ckd.Param("p", 1.0) 

p.meta.extra_info = 42 # add attribute to meta container (preferred)
p.meta_extra_info = 42 # add attribute with "meta_" prefix (allowed)
p.extra_info = 42 # add attribute directly to Param object (allowed but discouraged due to potential conflicts)

It is also possible to define new types of `Param` objects by subclassing `Param`. Be careful not to diverge too far from the base class if the subclass needs to interact with other `caskade`-based packages. A straightforward example is a package in which every parameter stores an uncertainty: rather than creating the attribute on each new `Param`, one can define a class that has it from the outset.

In [None]:
class ParamU(ckd.Param):
    def __init__(self, *args, uncertainty = None, **kwargs):
        super().__init__(*args, **kwargs)
        if uncertainty is None:
            self.uncertainty = torch.zeros_like(self.value)
        else:
            self.uncertainty = uncertainty

p = ParamU("p", 1.0)
print(f"p: {p.value} +- {p.uncertainty}")
p2 = ParamU("p2", 2.0, uncertainty=0.1)
print(f"p2: {p2.value} +- {p2.uncertainty}")

## Break up a Param Tensor

Sometimes a `Param` value is naturally a multi-component tensor, but we only wish for part of it to be dynamic. This can be accomplished by creating new params and linking appropriately.

In [None]:
# This is the param we plan to use
x = ckd.Param("x", torch.arange(10)) # param has 10 elements
print("Original x tensor", x.value)

# These are sub params for the broken primary param
x_dynamic = ckd.Param("x_dynamic", torch.arange(3)) # want first three elements to be dynamic
x_dynamic.to_dynamic()
x_static = ckd.Param("x_static", torch.arange(3,10)) # want last seven elements to be static

# This rebuilds the full param from the broken params
x.value = lambda p: torch.cat((p.x_dynamic.value, p.x_static.value))
x.link(x_dynamic)
x.link(x_static)

# Here we see we get the same result, but now only the first three elements are dynamic!
print("Rebuilt x tensor", x.value)
x.graphviz()
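
One practical consequence of splitting a tensor this way is that gradients only reach the free part. The sketch below shows this with plain PyTorch, without `caskade`, under the assumption that the dynamic piece is the only tensor with `requires_grad=True`. The variable names mirror the cell above but are ordinary tensors here.

```python
import torch

# Free elements: gradients are tracked
x_dynamic = torch.tensor([0.0, 1.0, 2.0], requires_grad=True)
# Fixed elements: plain tensor, no gradient tracking
x_static = torch.arange(3, 10, dtype=torch.float32)

# Rebuild the full 10-element tensor from the two pieces
x_full = torch.cat((x_dynamic, x_static))

# Any scalar function of the full tensor works; here a sum of squares
loss = (x_full**2).sum()
loss.backward()

print(x_full.shape)  # torch.Size([10])
print(x_dynamic.grad)  # gradient 2*x for the three free elements
print(x_static.grad)  # None, the static piece receives no gradient
```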

## Batching with caskade

Adding batch dimensions allows for more efficient computation by requiring less communication between the CPU and GPU, or simply by letting the CPU spend more time doing computations and less time interpreting Python code. In `caskade` it is possible to take full advantage of the batching capabilities of one's code. Here we demo the basic format for doing so.

### Case 1, vmap

`vmap` is a utility in PyTorch that lets you automatically add a batch dimension to your inputs and outputs. You can think of it as a faster version of a `for` loop that stacks all the outputs together.

In [None]:
G = Gaussian("G", x0=5, y0=5, q=0.5, phi=0.0, sigma=1.0, I0=1.0)
G.sigma.to_dynamic()
G.phi.to_dynamic()
x, y = torch.meshgrid(torch.linspace(0,10,100), torch.linspace(0,10,100), indexing="ij")

# Batching using vmap                phi                            sigma
params = torch.stack((torch.linspace(0.0, 3.14/2, 5), torch.linspace(0.5, 4.0, 5)), dim=-1)
img = torch.vmap(G.brightness, in_dims=(None, None, 0), out_dims=0)(x, y, params)
fig, axarr = plt.subplots(1, 5, figsize=(20, 4))
for i, ax in enumerate(axarr):
    ax.imshow(img[i].detach().numpy(), origin="lower")
    ax.axis("off")
plt.show()

# Multiple batching with vmap
# imagine the brightness function could only take a single value, rather than a grid
#                                            batch x y                        batch params
img = torch.vmap(torch.vmap(G.brightness, in_dims=(0,0,None)), in_dims=(None, None, 0))(x.flatten(), y.flatten(), params)
img = img.reshape(5, *x.shape)
fig, axarr = plt.subplots(1, 5, figsize=(20, 4))
for i, ax in enumerate(axarr):
    ax.imshow(img[i].detach().numpy(), origin="lower")
    ax.axis("off")
plt.show()
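
To make the `for` loop analogy for `vmap` concrete, the following standalone sketch compares an explicit loop with `torch.vmap` on a toy function. It uses plain PyTorch only; the `brightness` function here is a simplified stand-in written for this example, not the `Gaussian.brightness` method above.

```python
import torch


def brightness(r, sigma):
    """Gaussian profile on a grid of radii r for one scalar sigma."""
    return torch.exp(-(r**2) / sigma**2)


r = torch.linspace(0.0, 3.0, 7)  # shared grid, not batched
sigmas = torch.tensor([0.5, 1.0, 2.0])  # batch of three widths

# Explicit loop: one call per sigma, results stacked along a new leading axis
looped = torch.stack([brightness(r, s) for s in sigmas])

# vmap: r is shared (None), sigma is mapped over its dimension 0
batched = torch.vmap(brightness, in_dims=(None, 0))(r, sigmas)

print(batched.shape)  # torch.Size([3, 7])
print(torch.allclose(looped, batched))  # True
```

The `in_dims` tuple plays the same role as in the cell above: `None` marks an argument that is shared across the batch, and an integer marks the axis to map over.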

### Case 2, Module with batch dimension

If you write a module assuming the user will pass parameters with a batch dimension, then you can handle direct batching without using wrappers like `vmap`. This requires a bit more care in managing the shapes of each object, but can pay off a lot in terms of speed and flexibility later on!

In [None]:
class GaussianBatched(ckd.Module):
    def __init__(self, name, x0=None, y0=None, q=None, phi=None, sigma=None, I0=None):
        super().__init__(name)
        self.x0 = ckd.Param("x0", x0) # position
        self.y0 = ckd.Param("y0", y0)
        self.q = ckd.Param("q", q) # axis ratio
        self.phi = ckd.Param("phi", phi) # orientation
        self.sigma = ckd.Param("sigma", sigma) # width
        self.I0 = ckd.Param("I0", I0) # intensity

    @ckd.forward
    def _r(self, x, y, x0=None, y0=None, q=None, phi=None):
        x0 = x0.unsqueeze(-1)
        y0 = y0.unsqueeze(-1)
        q = q.unsqueeze(-1)
        phi = phi.unsqueeze(-1)
        x, y = x - x0, y - y0
        s, c = torch.sin(phi), torch.cos(phi)
        x, y = c * x - s * y, s * x + c * y
        return (x ** 2 + (y * q) ** 2).sqrt()
    
    @ckd.forward
    def brightness(self, x, y, sigma=None, I0=None):
        init_shape = x.shape
        B, *_ = sigma.shape
        x = x.flatten()
        y = y.flatten()
        return (I0.unsqueeze(-1) * (-self._r(x, y)**2 / sigma.unsqueeze(-1)**2).exp()).reshape(B, *init_shape)
    
G = GaussianBatched("G", x0=[5], y0=[5], q=[0.5], phi=[0.0], sigma=[1.0], I0=[1.0])
G.to_dynamic() # all params are dynamic
x, y = torch.meshgrid(torch.linspace(0,10,100), torch.linspace(0,10,100), indexing="ij")

# Batching on all dims using batched tensor input
params = G.build_params_array()
params = params.repeat(5, 1) # 5 copies of the same params
params[:,3] = torch.linspace(0.0, 3.14/2, 5) # phi
params[:,4] = torch.linspace(0.5, 4.0, 5) # sigma
img = G.brightness(x, y, params=params)
fig, axarr = plt.subplots(1, 5, figsize=(20, 4))
for i, ax in enumerate(axarr):
    ax.imshow(img[i].detach().numpy(), origin="lower")
    ax.axis("off")
plt.show()

# Batching by setting shapes of params, then flat tensor input
for param in G.dynamic_params:
    param.shape = (5,) + param.shape # add batch dimension to shape
params = params.T.flatten() # now params is a flat tensor again
img = G.brightness(x, y, params=params)
fig, axarr = plt.subplots(1, 5, figsize=(20, 4))
for i, ax in enumerate(axarr):
    ax.imshow(img[i].detach().numpy(), origin="lower")
    ax.axis("off")
plt.show()

# Batching using list input, note that list allows for different shapes, (also true for dictionary params)
params = [
    torch.tensor(5), # x0
    torch.tensor(5), # y0
    torch.tensor(0.5), # q
    torch.linspace(0.0, 3.14/2, 5), # phi, batched
    torch.linspace(0.5, 4.0, 5), # sigma, batched
    torch.tensor(1.0) # I0
]
img = G.brightness(x, y, params=params)
fig, axarr = plt.subplots(1, 5, figsize=(20, 4))
for i, ax in enumerate(axarr):
    ax.imshow(img[i].detach().numpy(), origin="lower")
    ax.axis("off")
plt.show()
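
The shape bookkeeping inside `GaussianBatched` relies on ordinary broadcasting: a parameter of shape `(B,)` becomes `(B, 1)` after `unsqueeze(-1)` and then broadcasts against a flattened grid of shape `(N,)`. Here is a stripped-down sketch of just that step in plain PyTorch. The function `brightness_batched` is a hypothetical simplification (circular profile centered on the origin) written for this example.

```python
import torch


def brightness_batched(x, y, sigma, I0):
    """x, y: grids of any shape; sigma, I0: tensors of shape (B,)."""
    init_shape = x.shape
    B = sigma.shape[0]
    x, y = x.flatten(), y.flatten()  # both (N,)
    r2 = x**2 + y**2  # (N,)
    # (B, 1) against (N,) broadcasts to (B, N)
    img = I0.unsqueeze(-1) * torch.exp(-r2 / sigma.unsqueeze(-1) ** 2)
    # Restore the grid shape behind the batch axis
    return img.reshape(B, *init_shape)


x, y = torch.meshgrid(
    torch.linspace(-1, 1, 4), torch.linspace(-1, 1, 5), indexing="ij"
)
sigma = torch.tensor([0.5, 1.0, 2.0])
I0 = torch.ones(3)

img = brightness_batched(x, y, sigma, I0)
print(img.shape)  # torch.Size([3, 4, 5])
```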

## Remove Param from a Module

It is possible to remove a Param object from a module and later replace it. This may be helpful for getting a simulator exactly the way you want it. You may use this to have multiple modules share a Param rather than just pointing to the same object. Generally, this is not preferred practice since it is just as fast to use pointers and they are more flexible.

In [None]:
G1 = Gaussian("G1", x0=None, y0=5, q=0.5, phi=0.0, sigma=1.0, I0=1.0)
G2 = Gaussian("G2", x0=5, y0=5, q=0.5, phi=0.0, sigma=1.0, I0=1.0)
C = Combined("C", G1, G2)

del G2.x0 # remove a param from a module

C.graphviz()

In [None]:
G2.x0 = G1.x0 # assign a param from one module to another
C.graphviz()

## Pointer functions only called once

When you create a pointer function it may be arbitrarily complex, which may require a lot of compute. To maintain efficiency, the pointer is only called once for a given simulation (that is, within a single `@forward` call), and the value is then stored. This shouldn't matter on the user side, but it is good to know!

In [None]:
class TryCallPointer(ckd.Module):
    def __init__(self):
        super().__init__()
        self.x = ckd.Param("x", 1.0)
        self.y = ckd.Param("y", 2.0)

    @ckd.forward
    def test_call(self):
        total = 0.0
        start = time()
        total += self.x.value
        print(f"first call took {time()-start:.5f} sec")
        start = time()
        total += self.x.value
        print(f"second call took {time()-start:.5f} sec")
        return total
    
def long_function(p):
    sleep(2)
    return 1.0 + p.y.value

T = TryCallPointer()
T.x = long_function
T.x.link(T.y)
print(T.test_call())

print("\nOutside @forward the pointer is called every time:")
start = time()
T.x.value
print(f"first outside call took {time()-start:.5f} sec")
start = time()
T.x.value
print(f"second outside call took {time()-start:.5f} sec")
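
The behavior shown above is a form of caching scoped to one simulation. As a rough mental model only, here is a pure-Python sketch of "evaluate once while active, evaluate every time otherwise". The `CachedPointer` class is hypothetical and written for this example; it does not describe how `caskade` implements pointers.

```python
_UNSET = object()  # sentinel meaning "not evaluated yet"


class CachedPointer:
    """Hypothetical stand-in for a pointer that caches while active."""

    def __init__(self, func):
        self.func = func
        self.active = False  # True while a toy "simulation" is running
        self.calls = 0  # number of times func was actually evaluated
        self._cache = _UNSET

    @property
    def value(self):
        if not self.active:
            # Outside a simulation: evaluate on every access
            self.calls += 1
            return self.func()
        if self._cache is _UNSET:
            # First access inside a simulation: evaluate and store
            self.calls += 1
            self._cache = self.func()
        return self._cache

    def run(self, n_accesses):
        """Toy simulation that reads the value n_accesses times."""
        self.active = True
        try:
            return sum(self.value for _ in range(n_accesses))
        finally:
            self.active = False
            self._cache = _UNSET  # discard the cache when the run ends


p = CachedPointer(lambda: 3.0)
total = p.run(4)  # four reads, one evaluation
calls_inside = p.calls
p.value  # outside a run: evaluated again
p.value  # and again
print(total, calls_inside, p.calls)  # 12.0 1 3
```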