In [1]:
import numpy as np
import torch
from functorch import vmap, grad
from caustic.sources import Sersic
from caustic.lenses import SIE
from caustic.utils import get_meshgrid

# Sketch

- User creates a class Pytorch style, where he creates the modules he/she needs (Source, Lens, etc.) in the ```__init__``` and the computational graph in the ```forward``` method.

- All **static parameters** are created and defined in ```__init__``` method, and the rest are left as **dynamic parameters** to be inferred.

- The Simulator class knows automatically the distinction by reading off the parameters from its modules.

### Static parameter
A parameter fixed during inference/simulation. For example, we might want to fix the cosmology unless we are doing time-delay cosmography. Then, $H_0$ will be considered a **static parameter**, or hyperparameter. 

In our context, a static parameter will simply be whatever needs to be defined in the ```__init___``` method.

### Dynamic parameter
A parameter that we wish to infer. The value of this parameter is not known in advance by our simulator (remains undefined in the ```__init__``` method). In other words, it will be part of the $\mathbf{x}$ tensor fed into the forward method of the simulator. 

### Sketch
```
class Sketch1(Simulator):
    def __init__(self, **static_parameters, **hyperparameters, device=DEVICE):
        self.source = Sersic(**static_parameters, name="source", device=device)
        self.lens = SIE(**static_parameters, name="lens", device=device)
        self.instruments = Instruments(**static_parameters, name="instruments", device=DEVICE)
        
        self.cosmology = FlatLambdaCDM(**static_parameters)
        self.thx, self.thy = get_meshgrid(**hyperparameters)
        self.zl, self.zs = **hyperparameters
    
    @vmap
    def forward(x):
        x = self.transform_to_physical(x)
        ax, ay = self.lens.alpha(self.thx, self.thy, self.zl, self.zs, self.cosmology, x["lens"])
        y = self.source.brightness(self.thx - ax, self.thy - ay, x["source"])
        y = self.instruments.observe(y, x["instruments"])
        return y
        
```

### What do we need to accomplish such a control flow?
- Every module need to know what their parameters are, as well as their numerical constraints. Thus we need 2 things
    1. Create a Parameter class
    2. Redefine the ```__init__``` of every module in a way to can catch static parameters, and leave undefined the dynamic parameters. 
- The Simulator module will need to know about its dynamic parameters so that the user can easily get access to the structure of $\mathbf{x}$.

In [122]:
import torch
from torch.distributions import transform_to, biject_to
from torch.distributions.constraints import Constraint

class Parameter:
    def __init__(self, value: float, name: str, index: int, dimension: list, constraints: Constraint, bijective_transform=False, device=torch.device("cpu")):
        self.constraints = constraints
        self.dimension = dimension
        self.name = name
        self.index = index
        if value is None:
            self._dynamic = True
            self._value = None
        else:
            self._dynamic = False
            self._value = torch.tensor([value]).to(device)
        if bijective_transform:
            self.transform = biject_to(constraints)
        else:
            self.transform = transform_to(constraints)
    
    @property
    def dynamic(self):
        return self._dynamic
    
    # TODO redefine the setter of this so that changing the setting the value also change the state of the class
    @property
    def value(self):
        return self._value
        

    def constrained_value(self, x):
        return self.transform(x)

    def __repr__(self):
        return self.name


In [149]:
import torch
from torch.distributions import constraints as C

from caustic.utils import derotate, translate_rotate
from caustic.lenses.base import ThinLens


class SIE(ThinLens):
    """
    References:
        Keeton 2001, https://arxiv.org/abs/astro-ph/0102341
    """

    def __init__(self, name="sie", thx0=None, thy0=None, q=None, phi=None, b=None, s=0.0, device=torch.device("cpu")):
        super().__init__(device)
        self.module = None # just create a hook to let Simulator know there are potentially dynamical parameters here
        self.name = name
        
        # Here we store the parameters. Then we can recover the named parameters (kwargs) from these two list at inference time
        self.static_parameters = []  # these will be properties of the class, since they are required for all "modules" of caustic
        self.dynamic_parameters = []
        constraints = {"x0": C.real, "y0": C.real, "q": C.interval(0.05, 1), "phi": C.real, "b": C.positive, "s": C.positive}
        for i, (k, p) in enumerate({"x0": thx0, "y0": thy0, "q": q, "phi": phi, "b": b, "s": s}.items()):
            if p is None:
                self.dynamic_parameters.append(
                    Parameter(p, name=k, index=i, dimension=[1], constraints=constraints[k], device=device)
                )
            else:
                self.static_parameters.append(
                    Parameter(p, name=k, index=i, dimension=[1], constraints=constraints[k], device=device)
                )

    def untangle_x(self, x): # This should be a base class method, since we will always be working from static and dynamic lists
        assert x.shape[0] == len(self.dynamic_parameters) # for now, assume there is no batch dimensions. 
        D = len(self.dynamic_parameters)
        p = list(torch.tensor_split(x, D, dim=0)) + [p.value for p in self.static_parameters] # get all parameters in a list
        i_map = np.argsort([p.index for p in self.dynamic_parameters] + [p.index for p in self.static_parameters]) # get their order so that all methods of this class can work
        return [p[i] for i in i_map]
        
    def _get_psi(self, x, y, q, s):
        return (q**2 * (x**2 + s**2) + y**2).sqrt()

    def alpha(self, thx, thy, z_l, z_s, cosmology, x):
        thx0, thy0, q, phi, b, s = self.untangle_x(x)  # We infer the value of each parameter at inference time from static and dynamic lists.   
        thx, thy = translate_rotate(thx, thy, thx0, thy0, phi)
        psi = self._get_psi(thx, thy, q, s)
        f = (1 - q**2).sqrt()
        ax = b * q.sqrt() / f * (f * thx / (psi + s)).atan()
        ay = b * q.sqrt() / f * (f * thy / (psi + q**2 * s)).atanh()

        return derotate(ax, ay, phi)

    def Psi(self, thx, thy, z_l, z_s, cosmology, x):
        thx0, thy0, q, phi, b, s = self.untangle_x(x)  # We infer the value of each parameter at inference time from static and dynamic lists.   
        thx, thy = translate_rotate(thx, thy, thx0, thy0, phi)
        # Only transform coordinates once: pass thx0=0, thy=0, phi=None to alpha
        ax, ay = self.alpha(thx, thy, z_l, z_s, cosmology, 0.0, 0.0, q, None, b, s)
        return thx * ax + thy * ay

    def kappa(self, thx, thy, z_l, z_s, cosmology, x):
        thx0, thy0, q, phi, b, s = self.untangle_x(x)  # We infer the value of each parameter at inference time from static and dynamic lists.
        thx, thy = translate_rotate(thx, thy, thx0, thy0, phi)
        psi = self._get_psi(thx, thy, q, s)
        return 0.5 * q.sqrt() * b / psi


In [150]:
# skecthing out how the hook will work
class Sketch1:
    def __init__(self):
        self.source = SIE()
sim = Sketch1()
sie = SIE(thx0=10.0)

# This loops accross attributes in search of the hook. This is the idea of how we will catch all parameters from the Simulator point of view, at inference time!
for module in [k for k, v in vars(sim).items() if not k.startswith('_')]:
    for m2 in [k for k,v in vars(getattr(sim, module)).items() if not k.startswith('_')]:
        if m2 == "module":
            print("I found the hook")

I found the hook


In [151]:
sie.static_parameters[0].index

0

In [152]:
sie.dynamic_parameters

[y0, q, phi, b]

In [153]:
sie.untangle_x(torch.arange(4))

[tensor([10.]),
 tensor([0]),
 tensor([1]),
 tensor([2]),
 tensor([3]),
 tensor([0.])]

In [66]:
from caustic.utils import to_elliptical, translate_rotate
from caustic.sources.base import Source


class Sersic(Source):
    
    def __init__(self, device=torch.device("cpu"), use_lenstronomy_k=False):
        """
        Args:
            lenstronomy_k_mode: set to `True` to calculate k in the Sersic exponential
                using the same formula as lenstronomy. Intended primarily for testing.
        """
        super().__init__(device)
        self.lenstronomy_k_mode = use_lenstronomy_k

    def brightness(self, thx, thy, thx0, thy0, q, phi, index, th_e, I_e, s=None):
        s = torch.tensor(0.0, device=self.device, dtype=thx0.dtype) if s is None else s
        thx, thy = translate_rotate(thx, thy, thx0, thy0, phi)
        ex, ey = to_elliptical(thx, thy, q)
        e = (ex**2 + ey**2).sqrt() + s

        if self.lenstronomy_k_mode:
            k = 1.9992 * index - 0.3271
        else:
            k = 2 * index - 1 / 3 + 4 / 405 / index + 46 / 25515 / index**2

        exponent = -k * ((e / th_e) ** (1 / index) - 1)
        return I_e * exponent.exp()

In [None]:
class Simulator:
    def __init__(self, device=torch.device("cpu")):
        

In [19]:
class Sketch(Simulator):
    def __init__(self, **hyperparameters):
        self.source = Sersic()
        self.lens = SIE()
        
        self.thx, self.thy = get_meshgrid(0.04, 128, 128)
        
    def forward(self, x):
        alpha_x, alpha_y = self.lens.alpha(self.thx, self.thy, x["lens"])

In [20]:
sim = Simulation()
sim.forward(None)

TypeError: alpha() missing 10 required positional arguments: 'thx', 'thy', 'z_l', 'z_s', 'cosmology', 'thx0', 'thy0', 'q', 'phi', and 'b'