In [5]:
import numpy as np
import torch
from functorch import vmap, grad
from caustic.sources import Sersic
from caustic.lenses import SIE
from caustic.utils import get_meshgrid

In [None]:
import torch
from torch.distributions import transform_to, biject_to


class Parameter:
    """Sketch of a model parameter with constraints and an optional prior.

    Implemented so far:
    - construction;
    - the ``constraints()`` and ``prior()`` accessors;
    - ``repr``.

    The remaining methods are placeholders. They raise ``NotImplementedError``
    instead of silently returning ``None``, so accidental use fails loudly.
    """

    def __init__(self, constraints, prior=None):
        """
        Args:
            constraints: constraint(s) on the parameter value. The format is not
                defined yet; presumably ``torch.distributions.constraints``
                objects (the cell imports ``transform_to``/``biject_to``).
                TODO confirm.
            prior: optional prior distribution; ``None`` means no prior was given.
        """
        # Stored under private names so they do not shadow the ``prior`` and
        # ``constraints`` methods below.
        self._constraints = constraints
        self._prior = prior

    def keys(self):
        """Placeholder; not implemented yet."""
        raise NotImplementedError("Parameter.keys is not implemented yet")

    def values(self):
        """Placeholder; not implemented yet."""
        raise NotImplementedError("Parameter.values is not implemented yet")

    def items(self):
        """Placeholder; not implemented yet."""
        raise NotImplementedError("Parameter.items is not implemented yet")

    def prior(self):
        """Return the prior given to the constructor (may be ``None``)."""
        return self._prior

    def constraints(self):
        """Return the constraints given to the constructor."""
        return self._constraints

    def unconstrained_value(self):
        """Placeholder; not implemented yet."""
        raise NotImplementedError("Parameter.unconstrained_value is not implemented yet")

    def constrained_value(self):
        """Placeholder; not implemented yet."""
        raise NotImplementedError("Parameter.constrained_value is not implemented yet")

    def sample(self):
        """Draw an initial value for the parameter.

        TODO: initialize the value from the prior. If no prior is given, assume
        a Uniform prior over the constraints.
        """
        raise NotImplementedError("Parameter.sample is not implemented yet")

    def __repr__(self):
        # repr() must return a str. The previous `...` stub returned None, which
        # made repr(param) (and notebook display) raise TypeError.
        return (
            f"{type(self).__name__}("
            f"constraints={self._constraints!r}, prior={self._prior!r})"
        )






In [22]:
import torch

from caustic.utils import derotate, translate_rotate
from caustic.lenses.base import ThinLens


class SIE(ThinLens):
    """
    Singular isothermal ellipsoid (SIE) lens with an optional core radius ``s``.

    References:
        Keeton 2001, https://arxiv.org/abs/astro-ph/0102341

    Arguments shared by ``alpha``, ``Psi`` and ``kappa``:
        thx, thy: coordinates at which the lens is evaluated (tensors).
        z_l, z_s, cosmology: accepted for interface compatibility; this model
            does not use them.
        thx0, thy0: lens centre.
        q: axis ratio, presumably in (0, 1). ``q == 1`` makes
            ``f = sqrt(1 - q**2)`` zero, so ``alpha`` evaluates 0/0 and
            returns NaN. Use q slightly below 1 for a near-circular lens.
        phi: orientation angle, forwarded to ``translate_rotate``/``derotate``.
        b: deflection scale (presumably the Einstein radius; TODO confirm).
        s: core radius; ``None`` means a singular profile (s = 0).
    """

    def __init__(self, device=torch.device("cpu")):
        super().__init__(device)

    def _default_core(self, s, like):
        """Return ``s``; if it is None, return a scalar zero with ``like``'s dtype."""
        if s is not None:
            return s
        # Take the dtype from the coordinate tensor, not from thx0: Psi calls
        # alpha with thx0 = thy0 = 0.0 (plain floats), which have no .dtype.
        return torch.tensor(0.0, device=self.device, dtype=like.dtype)

    def _get_psi(self, x, y, q, s):
        """Return sqrt(q**2 * (x**2 + s**2) + y**2) for lens-frame coordinates."""
        return (q**2 * (x**2 + s**2) + y**2).sqrt()

    def alpha(self, thx, thy, z_l, z_s, cosmology, thx0, thy0, q, phi, b, s=None):
        """Deflection angle (alpha_x, alpha_y) at (thx, thy).

        The coordinates are shifted and rotated into the lens frame. The
        deflection is computed there and mapped back with ``derotate(..., phi)``.
        """
        s = self._default_core(s, thx)
        thx, thy = translate_rotate(thx, thy, thx0, thy0, phi)
        psi = self._get_psi(thx, thy, q, s)
        f = (1 - q**2).sqrt()
        ax = b * q.sqrt() / f * (f * thx / (psi + s)).atan()
        ay = b * q.sqrt() / f * (f * thy / (psi + q**2 * s)).atanh()

        return derotate(ax, ay, phi)

    def Psi(self, thx, thy, z_l, z_s, cosmology, thx0, thy0, q, phi, b, s=None):
        """Lensing potential, computed as thx * alpha_x + thy * alpha_y in the lens frame.

        NOTE(review): for a cored profile (s != 0), Keeton (2001) appears to add
        core-dependent logarithmic terms to this expression. They are not
        included here; verify against the reference before relying on s != 0.
        """
        thx, thy = translate_rotate(thx, thy, thx0, thy0, phi)
        # The coordinates are already in the lens frame. Pass thx0 = thy0 = 0 and
        # phi = None so alpha does not shift/rotate them a second time. This
        # assumes translate_rotate/derotate treat phi=None as "no rotation"
        # (TODO confirm).
        ax, ay = self.alpha(thx, thy, z_l, z_s, cosmology, 0.0, 0.0, q, None, b, s)
        return thx * ax + thy * ay

    def kappa(self, thx, thy, z_l, z_s, cosmology, thx0, thy0, q, phi, b, s=None):
        """Convergence 0.5 * sqrt(q) * b / psi at (thx, thy)."""
        s = self._default_core(s, thx)
        thx, thy = translate_rotate(thx, thy, thx0, thy0, phi)
        psi = self._get_psi(thx, thy, q, s)
        return 0.5 * q.sqrt() * b / psi


In [24]:
from caustic.utils import to_elliptical, translate_rotate
from caustic.sources.base import Source


class Sersic(Source):
    """Elliptical Sersic surface-brightness profile."""

    def __init__(self, device=torch.device("cpu"), use_lenstronomy_k=False):
        """
        Args:
            device: torch device, passed to the ``Source`` base class.
            use_lenstronomy_k: set to `True` to calculate k in the Sersic exponential
                using the same formula as lenstronomy. Intended primarily for testing.
        """
        super().__init__(device)
        self.lenstronomy_k_mode = use_lenstronomy_k

    def _sersic_k(self, index):
        """Return the constant k in the Sersic exponent for the given index."""
        if self.lenstronomy_k_mode:
            # Linear approximation, matching lenstronomy.
            return 1.9992 * index - 0.3271
        # Asymptotic expansion of Ciotti & Bertin (1999).
        return 2 * index - 1 / 3 + 4 / 405 / index + 46 / 25515 / index**2

    def brightness(self, thx, thy, thx0, thy0, q, phi, index, th_e, I_e, s=None):
        """Surface brightness at (thx, thy).

        Args:
            thx, thy: coordinates at which to evaluate (tensors).
            thx0, thy0: profile centre.
            q: axis ratio, passed to ``to_elliptical``.
            phi: orientation angle, passed to ``translate_rotate``.
            index: Sersic index n.
            th_e: radius at which the brightness equals ``I_e`` (the exponent
                is zero there).
            I_e: brightness at radius ``th_e``.
            s: softening added to the elliptical radius; ``None`` means 0.

        Returns:
            I_e * exp(-k * ((e / th_e) ** (1 / index) - 1)), where e is the
            softened elliptical radius.
        """
        # Take the dtype from the coordinates rather than thx0, so a plain-float
        # centre does not raise AttributeError.
        s = torch.tensor(0.0, device=self.device, dtype=thx.dtype) if s is None else s
        thx, thy = translate_rotate(thx, thy, thx0, thy0, phi)
        ex, ey = to_elliptical(thx, thy, q)
        e = (ex**2 + ey**2).sqrt() + s

        exponent = -self._sersic_k(index) * ((e / th_e) ** (1 / index) - 1)
        return I_e * exponent.exp()

In [19]:
class Simulation:
    """Render a Sersic source lensed by an SIE onto a fixed image-plane grid."""

    def __init__(self, resolution=0.04, nx=128, ny=128, **hyperparameters):
        """
        Args:
            resolution, nx, ny: passed positionally to ``get_meshgrid`` to build
                the image-plane grid. They are assumed to be pixel size and grid
                shape (TODO confirm against ``get_meshgrid``). The defaults
                reproduce the original ``get_meshgrid(0.04, 128, 128)``.
            **hyperparameters: accepted for forward compatibility; currently unused.
        """
        self.source = Sersic()
        self.lens = SIE()

        self.thx, self.thy = get_meshgrid(resolution, nx, ny)

    def forward(self, x):
        """Ray-trace the grid through the lens and evaluate the source on it.

        Args:
            x: mapping with keys ``"lens"`` and ``"source"``. Each value is a dict
                of keyword arguments:
                - ``x["lens"]`` for ``SIE.alpha`` (z_l, z_s, cosmology, thx0,
                  thy0, q, phi, b and optionally s);
                - ``x["source"]`` for ``Sersic.brightness`` (thx0, thy0, q, phi,
                  index, th_e, I_e and optionally s).

        Returns:
            Source brightness evaluated at the source-plane positions
            ``theta - alpha``, one value per grid point.
        """
        # Previously the lens dict was passed positionally into the z_l slot,
        # so alpha() always raised TypeError for the missing arguments.
        alpha_x, alpha_y = self.lens.alpha(self.thx, self.thy, **x["lens"])
        # Lens equation beta = theta - alpha. This assumes alpha is the reduced
        # deflection (the SIE model here uses no distances); TODO confirm.
        beta_x = self.thx - alpha_x
        beta_y = self.thy - alpha_y
        return self.source.brightness(beta_x, beta_y, **x["source"])

In [20]:
# NOTE(review): this cell cannot succeed as written. Simulation.forward
# subscripts x["lens"], so passing None raises TypeError
# ('NoneType' object is not subscriptable).
# TODO: pass the parameter mapping that forward expects.
# The saved traceback below reports missing positional arguments to alpha()
# instead, so it appears to come from an earlier version of the code (stale).
sim = Simulation()
sim.forward(None)

TypeError: alpha() missing 10 required positional arguments: 'thx', 'thy', 'z_l', 'z_s', 'cosmology', 'thx0', 'thy0', 'q', 'phi', and 'b'