In [1]:
from asterion.glitch import HeGlitch, BCZGlitch, Glitch
from asterion.asy import FirstOrderAsy
import numpyro.distributions as dist
import numpyro
from numpyro.handlers import scope
import astropy.units as u
from tinygp import GaussianProcess, kernels

In [5]:
from typing import Union, Optional

In [6]:
s = u.microhertz.to_string("latex_inline")

In [74]:
from collections.abc import MutableMapping, KeysView, ValuesView, ItemsView, Set

class Parameter:
    
    def __init__(self, name: str, symbol: str="", unit: Optional[Union[str, u.Unit]]=None):
        self.name = name
        
        if unit is None:
            unit = u.dimensionless_unscaled
        
        self.unit = u.Unit(unit)
        
        if not symbol[0] == symbol[-1] == "$":
            raise ValueError("Symbol must start and finish with '$' to constitute valid latex math")
        
        self.symbol = symbol
        
    def __repr__(self):
#         if self.unit.is_unity():
#             return self.name
#         return f"{self.name} [{self.unit}]"
        unit = "'" + self.unit.to_string() + "'"
        return f"{self.__class__.__name__}({repr(self.name)}, symbol={repr(self.symbol)}, unit={unit})"
    
    def _repr_latex_(self):
        if self.unit.is_unity():
            return self.symbol
        name = self.symbol.strip("$")
        unit = self.unit.to_string("latex").strip("$")
        return  f"${name} \\left[ {unit} \\right]$"


class LogParameter(Parameter):
    def __init__(self, parameter, base="e"):
        if base not in ["e", 10]:
            raise NotImplementedError(f"Base {base} not implemented. Choose from ['e', 10]")
        
        if base == "e":
            name = "_".join(["log", parameter.name])
            symbol = rf"$\log\left({parameter.symbol.strip('$')}\right)$"
        else:
            name = "_".join(["log{base}", parameter.name])
            symbol = rf"$\log_{base}\left({parameter.symbol.strip('$')}\right)$"
        
        unit = u.dimensionless_unscaled
        
        super().__init__(name, symbol=symbol, unit=unit)
        
        self.base_parameter = parameter
        self.base = base


class ParameterCollection(MutableMapping):

    def __init__(self, iterable):
#         self._parameters = {}
        self.store = {}
        for value in iterable:
            self[value.name] = self._validate(value)

    def __iter__(self):
        return iter(self.store)
#         return iter(self._parameters)
        
    def __contains__(self, value):
        return value in self.store
#         return value in self._parameters

    def __len__(self):
        return len(self.store)
#         return len(self._parameters)
    
    def __getitem__(self, key):
        return self.store[key]
#         return self._parameters[key]
    
    def __setitem__(self, key, value):
        self.store[key] = self._validate(value)
#         self._parameters[key] = self._validate(value)

    def __delitem__(self, key):
        del self.store[key]
#         del self._parameters[key]

    def __repr__(self):
        return f"{self.__class__.__name__}({repr(list(self.values()))})"

    def __str__(self):
        return str(list(self.values()))
        
    def _repr_latex_(self):
        return "$\\left[" \
            + ",".join([value._repr_latex_().strip("$") for value in self.values()]) \
            + "\\right]$"
        
    def _validate(self, value):
        if not isinstance(value, Parameter):
            raise TypeError(f"Element is of type: '{type(value)}' but must be an instance of 'Parameter'.")
        if value.name in self:
            raise KeyError(f"Parameter with the name '{value.name}' already exists in collection.")
        return value
    
    def keys(self):
        return self.store.keys()
    
    def values(self):
        return self.store.values
    
    def items(self):
        return self.store.items()

In [80]:
params = ParameterCollection([
    Parameter("delta_nu", unit="microHertz", symbol=r"$\Delta\nu$"),
    Parameter("epsilon", symbol=r"$\epsilon$")
])

params.update(ParameterCollection([Parameter("epslon", symbol=r"$\epsilon$")]))

In [81]:
params.items()

dict_items([('delta_nu', Parameter('delta_nu', symbol='$\\Delta\\nu$', unit='uHz')), ('epsilon', Parameter('epsilon', symbol='$\\epsilon$', unit='')), ('epslon', Parameter('epslon', symbol='$\\epsilon$', unit=''))])

In [77]:
other_params = params = ParameterCollection([
    Parameter("delta_", unit="microHertz", symbol=r"$\Delta\nu$"),
])

In [85]:
import warnings
import jax.numpy as jnp

In [86]:
class CircularTransform(dist.transforms.Transform):
    codomain = dist.constraints.circular
    
    def __call__(self, x):
        return jnp.remainder(x + jnp.pi, 2 * jnp.pi) - jnp.pi

    def _inverse(self, y):
        warnings.warn(
            "CircularTransform is not a bijective transform."
            " The inverse of `y` will be `y`.",
            stacklevel=find_stack_level(),
        )
        return y

In [87]:
from numpyro.infer.reparam import TransformReparam
from numpyro.handlers import reparam

In [93]:
class ModelParameterWarning(Warning):
    def __init__(self, name):
        super().__init__(f"Model parameter '{name}' is not registered in model.params")

In [10]:
divider = "/"

class Model:
    params: ParameterCollection

        
# class GlitchMean:
#     # a function of functions
#     def __init__(self, asy_func, he_glitch, bcz_glitch):
#         self.asy_func = asy_func
#         self.he_glitch = he_glitch
#         self.bcz_glitch = bcz_glitch
    
#     def __call__(self, n)
#         nu_asy = numpyro.deterministic("nu_asy", self.asy_func(n))
#         dnu_he = numpyro.deterministic("dnu_he", self.he_func(nu_asy))
#         dnu_bcz = numpyro.deterministic("dnu_bcz", self.bcz_func(nu_asy))
#         return nu_asy + dnu_he + dnu_bcz

    
class GlitchModelFactory:
    # takes all model parameters and builds a gp and mean function
    def __init__(self, delta_nu, epsilon, a_he, b_he, tau_he, phi_he
                 a_cz, tau_cz, phi_cz, kerenl_var, kernel_scale, diag):
        self.asy_func = FirstOrderAsy(delta_nu, epsilon)
        self.he_glitch = HeGlitch(a_he, b_he, tau_he, phi_he)
        self.cz_glitch = BCZGlitch(a_cz, tau_cz, phi_cz)
        
        self.kernel = kernel_var * kernels.ExpSquared(kernel_scale)   
        self.diag = diag
    
    def mean(self, n):
        nu_asy = numpyro.deterministic("nu_asy", self.asy_func(n))
        dnu_he = numpyro.deterministic("dnu_he", self.he_func(nu_asy))
        dnu_bcz = numpyro.deterministic("dnu_bcz", self.bcz_func(nu_asy))
        return nu_asy + dnu_he + dnu_bcz

    def gp(self, n):
        return GaussianProcess(self.kernel, n, diag=self.diag)

class GlitchModel(Model):
    _he = r"\mathrm{He}"
    _bcz = r"\mathrm{BCZ}"
    
    # metadata about model parameters
    params = ParameterCollection([
        Parameter("delta_nu", symbol=r"$\Delta\nu$", unit=u.microhertz),
        Parameter("epsilon", symbol=r"$\epsilon$"),
        Parameter("a_he", symbol=rf"$a_{_he}$"),
        Parameter("b_he", symbol=rf"$b_{_he}$", unit=u.megasecond**2),
        Parameter("tau_he", symbol=rf"$\tau_{_he}$", unit=u.megasecond),
        Parameter("phi_he", symbol=rf"$\phi_{_he}$", unit=u.rad),
        Parameter("a_bcz", symbol=rf"$a_{_bcz}$", unit=u.microhertz**3),
        Parameter("tau_bcz", symbol=rf"$\tau_{_bcz}$", unit=u.megasecond),
        Parameter("phi_bcz", symbol=rf"$\phi_{_bcz}$", unit=u.rad),
        Parameter("<A_he>", symbol=rf"$\langle A_{_he} \rangle$", unit=u.microhertz),
        Parameter("<A_bcz>", symbol=rf"$\langle A_{_bcz} \rangle$", unit=u.microhertz),
        Parameter("nu_asy", symbol=r"$\nu_\mathrm{asy}$", unit=u.microhertz),
        Parameter("dnu_he", symbol=rf"$\delta\nu_{_he}$", unit=u.microhertz),
        Parameter("dnu_bcz", symbol=rf"$\delta\nu_{_bcz}$", unit=u.microhertz),
        Parameter("kernel_var", symbol=rf"$\sigma_k^2$", unit=u.microhertz),
        Parameter("kernel_scale", symbol=rf"$\lambda_k$"),
        Parameter("nu", symbol=r"$\nu$", unit=u.microhertz),
        Parameter("f", symbol=r"$f$", unit=u.microhertz),  # GP noiseless
        Parameter("y", symbol=r"$y$", unit=u.microhertz),  # GP obs
    ])

    # log params
    params.update(
        ParameterCollection([
            LogParameter(params["a_he"]),
            LogParameter(params["b_he"]),
            LogParameter(params["tau_he"]),
            LogParameter(params["a_bcz"]),
            LogParameter(params["tau_bcz"]),
        ])
    )

    def __init__(
        self,
        *,
        delta_nu,
        log_tau_he,
        log_tau_bcz,
        low=None,
        high=None,
        **prior,
    ):
        
        
        self.prior = prior
        
        # Required prior
        self.prior["delta_nu"] = delta_nu
        self.prior["log_tau_he"] = log_tau_he
        self.prior["log_tau_bcz"] = log_tau_bcz
        
        # Optional prior
        self.prior.setdefault("epsilon", dist.Normal(1.2, 0.3))
        self.prior.setdefault("log_a_he", dist.Normal())
        
        self.low = low
        self.high = high

    
    def sample_asy_func(self):
        """Sample the asy function."""
        delta_nu = numpyro.sample("delta_nu", self.prior["delta_nu"])
        epsilon = numpyro.sample("epsilon", self.prior["epsilon"])
        return FirstOrderAsy(delta_nu, epsilon)

    def sample_he_func(self) -> HeGlitch:
        """Sample the helium glitch function."""
        a = numpyro.deterministic(
            "a_he",
            jnp.exp(numpyro.sample("log_a_he", self.prior["log_a_he"])),
        )

        b = numpyro.deterministic(
            jnp.exp(numpyro.sample("log_b_he", self.prior["log_b_he"])),
        )

        tau = numpyro.deterministic(
            "tau_he",
            jnp.exp(numpyro.sample("log_tau_he", self.prior["log_tau_he"])),
        )

        phi = numpyro.sample("phi_he", self.prior["phi_he"])

        return HeGlitch(a, b, tau, phi)
    
    def sample_bcz_func(self) -> BCZGlitch:
        """Sample the BCZ glitch function."""
        a = numpyro.deterministic(
            "a_bcz",
            jnp.exp(numpyro.sample("log_a_bcz", self.prior["log_a_bcz"])),
        )

        tau = numpyro.deterministic(
            "tau_bcz",
            jnp.exp(numpyro.sample("log_tau_bcz", self.prior["log_tau_bcz"])),
        )

        phi = numpyro.sample("phi_bcz", self.prior["phi_bcz"])

        return BCZGlitch(a, tau, phi)
    
    def sample_mean(self):
        # Sample from function priors
        asy_func = self.sample_asy_func()
        he_func = self.sample_he_func()
        bcz_func = self.sample_bcz_func()
        
        return GlitchMean(asy_func, he_func, bcz_glitch)
        
    def sample_kernel(self) -> kernels.Kernel:
        amp = numpyro.sample("kernel_var", self.prior["kernel_var"])
        scale = numpyro.sample("kernel_scale", self.prior["kernel_scale"])
        return amp * kernels.ExpSquared(scale)        
    
    def predict(self, gp, y, n, mean=None):
        """Predict nu given conditioned gp and mean."""
        cond = gp.condition(y, n) # conditioned result 

        if mean is None:
            # E.g. if mean is zero or inclided in GP
            return numpyro.sample("nu", cond.gp.numpyro_dist())
        
        # Manually add on mean
        nu_mean = mean
        if callable(mean):
            nu_mean = mean(n)

        f = numpyro.sample("f", cond.gp.numpyro_dist())
        return numpyro.deterministic("nu", nu_mean + f)
    
    def average_amplitude(self, nu):
        low = nu.min() if self.low is None else self.low
        high = nu.max() if self.high is None else self.high

        A_he = numpyro.deterministic(
            "<A_he>", 
            mean.he_func.average_amplitude(low, high)
        )
        A_bcz = numpyro.deterministic(
            "<A_bcz>", 
            mean.bcz_func.average_amplitude(low, high)
        )
        return A_he, A_bcz

    def __call__(self, n, nu=None, diag=None, n_pred=None):
        
        mean = self.sample_mean()
        kernel = self.sample_kernel()
        gp = GaussianProcess(kernel, n, diag=diag)
        # replace this with something that returns the gp and mean
        # given the parameters
        nu_mean = numpyro.deterministic("nu_mean", mean(n))
        
        obs = None
        if nu is not None:
            obs = nu - nu_mean

        y = numpyro.sample("y", gp.numpyro_dist(), obs=obs)
        
        # Predict 'true' nu and compute average amplitudes
        self.average_amplitude(
            self.predict(gp, y, n, mean=nu_mean)
        )
        
        # Predict nu for given radial orders 'n_pred'
        if n_pred is not None:
            # Predict under the pred scope
            scope(
                self.predict, prefix="pred", divider=divider
            )(
                gp,
                y,
                n_pred, 
                mean=numpyro.deterministic("nu_mean", mean(n_pred))
            )

In [None]:
def build_gp():
    # returns the gp and the mean function
    return GaussianProcess(kernel, n, diag=diag)

def log_like

In [102]:
class A:
    def __init__(self):
        self.a = None

In [18]:
from collections.abc import ValuesView
isinstance(GlitchModel.params.values(), ValuesView)

True

In [19]:
GlitchModel.params._parameters.items()

dict_items([('delta_nu', Parameter('delta_nu', symbol='$\\Delta\\nu$', unit='uHz')), ('epsilon', Parameter('epsilon', symbol='$\\epsilon$', unit='')), ('a_he', Parameter('a_he', symbol='$a_\\mathrm{He}$', unit='')), ('b_he', Parameter('b_he', symbol='$b_\\mathrm{He}$', unit='Ms2')), ('tau_he', Parameter('tau_he', symbol='$\\tau_\\mathrm{He}$', unit='Ms')), ('phi_he', Parameter('phi_he', symbol='$\\phi_\\mathrm{He}$', unit='rad')), ('a_bcz', Parameter('a_bcz', symbol='$a_\\mathrm{BCZ}$', unit='uHz3')), ('tau_bcz', Parameter('tau_bcz', symbol='$\\tau_\\mathrm{BCZ}$', unit='Ms')), ('phi_bcz', Parameter('phi_bcz', symbol='$\\phi_\\mathrm{BCZ}$', unit='rad')), ('<A_he>', Parameter('<A_he>', symbol='$\\langle A_\\mathrm{He} \\rangle$', unit='uHz')), ('<A_bcz>', Parameter('<A_bcz>', symbol='$\\langle A_\\mathrm{BCZ} \\rangle$', unit='uHz')), ('nu_asy', Parameter('nu_asy', symbol='$\\nu_\\mathrm{asy}$', unit='uHz')), ('dnu_he', Parameter('dnu_he', symbol='$\\delta\\nu_\\mathrm{He}$', unit='uHz

In [20]:
GlitchModel.params.items()

ItemsView(ParameterCollection([Parameter('delta_nu', symbol='$\\Delta\\nu$', unit='uHz'), Parameter('epsilon', symbol='$\\epsilon$', unit=''), Parameter('a_he', symbol='$a_\\mathrm{He}$', unit=''), Parameter('b_he', symbol='$b_\\mathrm{He}$', unit='Ms2'), Parameter('tau_he', symbol='$\\tau_\\mathrm{He}$', unit='Ms'), Parameter('phi_he', symbol='$\\phi_\\mathrm{He}$', unit='rad'), Parameter('a_bcz', symbol='$a_\\mathrm{BCZ}$', unit='uHz3'), Parameter('tau_bcz', symbol='$\\tau_\\mathrm{BCZ}$', unit='Ms'), Parameter('phi_bcz', symbol='$\\phi_\\mathrm{BCZ}$', unit='rad'), Parameter('<A_he>', symbol='$\\langle A_\\mathrm{He} \\rangle$', unit='uHz'), Parameter('<A_bcz>', symbol='$\\langle A_\\mathrm{BCZ} \\rangle$', unit='uHz'), Parameter('nu_asy', symbol='$\\nu_\\mathrm{asy}$', unit='uHz'), Parameter('dnu_he', symbol='$\\delta\\nu_\\mathrm{He}$', unit='uHz'), Parameter('dnu_bcz', symbol='$\\delta\\nu_\\mathrm{BCZ}$', unit='uHz'), Parameter('kernel_var', symbol='$\\sigma_k^2$', unit='uHz'), 

In [21]:
ValuesView(GlitchModel.params._parameters)

ValuesView({'delta_nu': Parameter('delta_nu', symbol='$\\Delta\\nu$', unit='uHz'), 'epsilon': Parameter('epsilon', symbol='$\\epsilon$', unit=''), 'a_he': Parameter('a_he', symbol='$a_\\mathrm{He}$', unit=''), 'b_he': Parameter('b_he', symbol='$b_\\mathrm{He}$', unit='Ms2'), 'tau_he': Parameter('tau_he', symbol='$\\tau_\\mathrm{He}$', unit='Ms'), 'phi_he': Parameter('phi_he', symbol='$\\phi_\\mathrm{He}$', unit='rad'), 'a_bcz': Parameter('a_bcz', symbol='$a_\\mathrm{BCZ}$', unit='uHz3'), 'tau_bcz': Parameter('tau_bcz', symbol='$\\tau_\\mathrm{BCZ}$', unit='Ms'), 'phi_bcz': Parameter('phi_bcz', symbol='$\\phi_\\mathrm{BCZ}$', unit='rad'), '<A_he>': Parameter('<A_he>', symbol='$\\langle A_\\mathrm{He} \\rangle$', unit='uHz'), '<A_bcz>': Parameter('<A_bcz>', symbol='$\\langle A_\\mathrm{BCZ} \\rangle$', unit='uHz'), 'nu_asy': Parameter('nu_asy', symbol='$\\nu_\\mathrm{asy}$', unit='uHz'), 'dnu_he': Parameter('dnu_he', symbol='$\\delta\\nu_\\mathrm{He}$', unit='uHz'), 'dnu_bcz': Parameter(

In [97]:
d = {}
d.setdefault("a", 1)

1

In [98]:
d

{'a': 1}