# TODO
- [ ] reshape the output of NCopies bums
- [ ] add to the MF model
- [ ] do the same for the GM models

In [1]:
from functools import reduce, partial

import numpy as np
import jax
from jax import random
import jax.numpy as jnp

import nifty8 as ift
import nifty8.re as jft
from nifty8.re.tree_math import ShapeWithDtype

import jubik0 as ju

# Fixed seed so that every random draw below is reproducible.
seed = 42
key = random.PRNGKey(seed)  # root PRNG key; split it before each use

Ciao is not sourced or installed. Therefore some operations can't be performed


# Definitions

In [2]:
class MappedModel(jft.Model):
    """Evaluate a correlated field for ``ndof`` independent excitations at once.

    The xi (excitation) parameter of the wrapped correlated field gains an extra
    axis of length ``ndof``. All other parameters are broadcast, i.e. shared
    between the copies. ``jax.vmap`` maps the wrapped model over the new axis,
    so the output carries a copy axis of length ``ndof`` at the same end
    (front or back) as the xi parameter.
    """

    def __init__(self, correlated_field, cf_prefix, ndof, first_axis=True):
        """
        Parameters
        ----------
        correlated_field : jft.Model
            Correlated-field model whose domain contains the key
            ``cf_prefix + 'xi'``.
        cf_prefix : str
            Prefix the correlated field was built with.
        ndof : int
            Number of copies.
        first_axis : bool
            If True the copy axis is prepended to the xi parameter and to the
            output; otherwise it is appended to both.

        Raises
        ------
        ValueError
            If ``cf_prefix + 'xi'`` is not a key of ``correlated_field.domain``.
        """
        self._cf = correlated_field
        keys = correlated_field.domain.keys()
        xi_key = cf_prefix + 'xi'
        if xi_key not in keys:
            raise ValueError(
                f"{xi_key!r} is not a key of the correlated field's domain; "
                f"available keys: {list(keys)}"
            )

        xi_dom = correlated_field.domain[xi_key]
        if first_axis:
            new_primals = jft.ShapeWithDtype((ndof,) + xi_dom.shape, xi_dom.dtype)
            axs = 0
        else:
            new_primals = jft.ShapeWithDtype(xi_dom.shape + (ndof,), xi_dom.dtype)
            axs = -1
        # Use the same (possibly negative) axis for the output. With a fixed
        # ``out_axes=1`` the copy axis would only be appended for 1-D targets
        # and would land in the middle of higher-dimensional targets.
        self._out_axs = axs

        new_domain = correlated_field.domain.copy()
        new_domain[xi_key] = new_primals

        # The xi initializer draws all ndof copies at once.
        xiinit = partial(jft.random_like, primals=new_primals)

        init = correlated_field.init
        init = {k: init[k] if k != xi_key else xiinit for k in keys}

        # vmap in_axes: map xi along the copy axis, broadcast every other key.
        self._axs = ({k: axs if k == xi_key else None for k in keys},)
        super().__init__(domain=new_domain, init=jft.Initializer(init))

    def __call__(self, x):
        """Evaluate all copies of the correlated field at position ``x``.

        ``x`` may be a ``jft.Vector`` or the plain tree it wraps.
        """
        x = x.tree if isinstance(x, jft.Vector) else x
        return jax.vmap(self._cf, in_axes=self._axs, out_axes=self._out_axs)(x)

In [3]:
class NCopiesCorrField():
    """Builder for a model of several correlated fields sharing one power spectrum."""

    def __init__(self, cf, N_copies, prefix):
        """
        Parameters
        ----------
        cf : Correlated Field Model
            Model from ``nifty.re`` whose domain contains ``prefix + "xi"``.
        N_copies : tuple of int, or int
            Shape of the copy axes; this implicitly defines the number of
            copies (their product). A plain int is treated as ``(N_copies,)``.
            All copies share the same power spectrum but have different
            excitations.
        prefix : str
            Prefix the correlated field was built with.

        Notes
        -----
        ``build_model`` returns a model for multiple correlated fields with the
        same power spectrum but different xi_s. Its xi parameter has the
        flattened shape ``(n_xis, number)``. Its output has shape
        ``cf.target.shape + N_copies``.
        """
        if isinstance(N_copies, int):
            N_copies = (N_copies,)
        self.cf = cf
        self.N_copies = tuple(N_copies)
        # Total number of copies; initial value 1 makes an empty shape mean one copy.
        self.number = reduce(lambda a, b: a * b, self.N_copies, 1)
        self.xi_key = prefix + "xi"
        xi_dom = cf.domain[self.xi_key]
        self.n_xis = xi_dom.size
        # Original xi shape, used to un-flatten each copy's xi column.
        self._xi_shape = tuple(xi_dom.shape)
        self.target_flatten = ShapeWithDtype(shape=(self.n_xis, self.number))
        # Must agree with the reshape done in partial_init_ncopies_model.
        self.target = ShapeWithDtype(
            shape=tuple(self.cf.target.shape) + self.N_copies
        )
        self.new_domain = self._extend_xi_domain()

    def _partly_apply_cf(self, pos_init):
        """Return ``partly(xi, xi_key)``, which evaluates ``self.cf`` with xi replaced.

        ``pos_init`` (a ``jft.Vector`` or the dict it wraps) is never modified.
        """
        is_vector = isinstance(pos_init, jft.Vector)
        tree = pos_init.tree if is_vector else pos_init

        def partly(xi, xi_key):
            # Build a new dict rather than mutating ``tree``. Under jax.vmap
            # ``xi`` is a tracer, and writing it into the caller's input would
            # leak it out of the transformation.
            pos = {**tree, xi_key: jnp.reshape(xi, self._xi_shape)}
            return self.cf(jft.Vector(pos) if is_vector else pos)
        return partly

    def _extend_xi_domain(self):
        """Return the cf domain with the xi entry replaced by the flattened copies."""
        domain = self.cf.domain.copy()
        domain.update({self.xi_key: self.target_flatten})
        return domain

    def partial_init_ncopies_model(self, pos_new):
        """Evaluate every copy at ``pos_new``.

        Returns an array of shape ``cf.target.shape + N_copies``.
        """
        func = self._partly_apply_cf(pos_new)
        tree = pos_new.tree if isinstance(pos_new, jft.Vector) else pos_new
        # Map over the copy axis (axis 1) of the flattened xi only. The key is
        # closed over instead of being passed through vmap. Putting the copy
        # axis last makes the reshape below valid for targets of any rank.
        ncopies_model_func = jax.vmap(
            lambda xi: func(xi, self.xi_key), in_axes=1, out_axes=-1
        )
        res = ncopies_model_func(tree[self.xi_key])
        return res.reshape(tuple(self.cf.target.shape) + self.N_copies)

    def build_model(self):
        """Wrap ``partial_init_ncopies_model`` in a ``jft.Model``."""
        return jft.Model(self.partial_init_ncopies_model,
                         domain=self.new_domain,
                         target=self.target)

In [4]:
def mf_model(freqs, alph, spatial, dev):
    """Build a multi-frequency model: power law in ``freqs`` plus offsets.

    Parameters
    ----------
    freqs : array
        Frequencies/energies at which the power law is evaluated.
    alph : jft.Model or real number
        Spectral index.
    spatial, dev : jft.Model
        Added to the power law when ``alph`` is a model. Not used when ``alph``
        is a number.

    Returns
    -------
    jft.Model
        If ``alph`` is a ``jft.Model``: ``x -> outer(freqs, alph(x)) + spatial(x)
        + dev(x)``, with the power law reshaped to
        ``freqs.shape + alph.target.shape``. Its domain is the union of the
        three input domains.
    array
        If ``alph`` is a real number: ``freqs * alph`` with shape ``freqs.shape``.

    Raises
    ------
    TypeError
        If ``alph`` is neither a ``jft.Model`` nor a real number. Previously
        this case ended in an UnboundLocalError, and ints were rejected.
    """
    # bool subclasses int but is not a meaningful spectral index.
    if isinstance(alph, (int, float)) and not isinstance(alph, bool):
        # NOTE(review): spatial and dev are ignored here - confirm that is intended.
        return jnp.outer(freqs, alph).reshape(freqs.shape)
    if isinstance(alph, jft.Model):
        plaw = lambda x: jnp.outer(freqs, alph(x)).reshape(freqs.shape + alph.target.shape)
        plaw_offset = lambda x: plaw(x) + spatial(x)
        call = lambda x: plaw_offset(x) + dev(x)
        domain = alph.domain | spatial.domain | dev.domain
        return jft.Model(call, domain=domain)
    raise TypeError(
        f"alph must be a jft.Model or a real number, got {type(alph).__name__}"
    )

# Fields

In [14]:
# Number of pixels of the 1-D energy axis. This is a plain int: ``(10)`` is not a tuple.
e_dims = 10
# Shape of the 2-D spatial grid.
s_dims = (128, 128)

# Choose between a regular grid of energies and an irregular set.
RG_Energies = True
freqs = (
    jnp.arange(-2, 10)
    if RG_Energies
    else jnp.array([1, 3, 4, 7, 12, 17, 19.3])
)

## Spatial field

In [11]:
# Zero-mode (offset) and fluctuation settings of the spatial correlated field.
cf_zm = dict(offset_mean=0., offset_std=(1e-3, 1e-4))
cf_fl = dict(
    fluctuations=(1e-1, 5e-3),
    loglogavgslope=(-1., 1e-2),
    flexibility=(1e+0, 5e-1),
    asperity=(5e-1, 5e-2),
    harmonic_type="Fourier",
)

# The maker prefix "space_cf" prefixes every parameter key of the resulting model.
cfm = jft.CorrelatedFieldMaker("space_cf")
cfm.set_amplitude_total_offset(**cf_zm)
cfm.add_fluctuations(
    s_dims,
    distances=1. / s_dims[0],
    prefix="ax1",
    non_parametric_kind="power",
    **cf_fl,
)
correlated_field = cfm.finalize()

## Deviations from the power law

In [12]:
# Zero-mode (offset) and fluctuation settings of the deviation field.
dev_zm = {"offset_mean": 0., "offset_std": (1e-3, 1e-4)}
dev_fl = {
    "fluctuations": (1e-1, 5e-3),
    "loglogavgslope": (-1., 1e-2),
    "flexibility": (1e+0, 5e-1),
    "asperity": (5e-1, 5e-2),
    "harmonic_type": "Fourier",
}
dev_cfm = jft.CorrelatedFieldMaker("dev_cf")
# Pass the dev_* settings defined in this cell, not the spatial cf_zm / cf_fl.
dev_cfm.set_amplitude_total_offset(**dev_zm)
dev_cfm.add_fluctuations(
    e_dims,
    distances=1. / e_dims,
    **dev_fl,
    prefix="ax1",
    non_parametric_kind="power",
)
dev_correlated_field = dev_cfm.finalize()

## Spectral Index $\alpha$

In [13]:
# Zero-mode (offset) and fluctuation settings of the spectral-index field.
alpha_zm = dict(offset_mean=0., offset_std=(1e-3, 1e-4))
alpha_fl = dict(
    fluctuations=(1e-1, 5e-3),
    loglogavgslope=(-1., 1e-2),
    flexibility=(1e+0, 5e-1),
    asperity=(5e-1, 5e-2),
    harmonic_type="Fourier",
)

# The spectral index lives on the same spatial grid as the spatial field.
alpha = jft.CorrelatedFieldMaker("alpha")
alpha.set_amplitude_total_offset(**alpha_zm)
alpha.add_fluctuations(
    s_dims,
    distances=1. / s_dims[0],
    prefix="ax1",
    non_parametric_kind="power",
    **alpha_fl,
)
alpha_field = alpha.finalize()

## Margret and Vincent's code

In [None]:
prefix = "dev_cf"  # must match the prefix passed to jft.CorrelatedFieldMaker
n_copies_model = NCopiesCorrField(dev_correlated_field, s_dims, prefix)
cf_split = n_copies_model.build_model()
# Keep this as the last expression of the cell so Jupyter displays it.
# As a mid-cell expression it had no effect.
n_copies_model.number

## Derived from Philipp's code

In [None]:
# The prefix must be the one given to jft.CorrelatedFieldMaker("space_cf"),
# because the xi key looked up in the model's domain is prefix + "xi".
my_cf = ju.MappedModel(correlated_field, "space_cf", 10, True)

In [None]:
def reshape(model, shape=(128, 2, 5)):
    """Return ``model`` reshaped to ``shape``.

    ``model`` is any object with a ``.reshape`` method, e.g. a NumPy or JAX
    array; its total size must equal the product of ``shape``.

    NOTE(review): the default ``(128, 2, 5)`` holds 1280 elements - confirm it
    matches the output this is applied to.
    """
    return model.reshape(shape)

In [None]:
def newmodel(x):
    """Evaluate ``my_cf`` at ``x`` and pass the result through ``reshape``."""
    out = my_cf(x)
    return reshape(out)

In [None]:
# Split off a fresh subkey and draw a random position in the domain of my_cf.
key, subkey = random.split(key)
pos_init = jft.Vector(jft.random_like(subkey, my_cf.domain))

In [16]:
# Draw a random position in the domain of the deviation field. ``dev_field``
# is never defined in this notebook; the field is ``dev_correlated_field``.
key, subkey = random.split(key)
pos_init = jft.Vector(jft.random_like(subkey, dev_correlated_field.domain))

In [17]:
# Wiener-process model named "margret", discretised with N_steps=10.
# NOTE(review): the meaning of the positional arguments ((0, 1), (1, 3), 2)
# depends on jft.WienerProcess's signature - confirm against its docs.
dev = jft.WienerProcess(
    (0, 1),
    (1, 3),
    2,
    name="margret",
    N_steps=10,
)

In [18]:
# Split off a fresh subkey and draw a random position in the domain of the
# Wiener-process model ``dev``.
key, subkey = random.split(key)
pos_init = jft.Vector(jft.random_like(subkey, dev.domain))

# NCopies with NIFTy

Could we remove this section?

In [8]:
import nifty8 as ift

# Multi-copy correlated field built with classic (non-JAX) NIFTy (``ift``).
# Note: this reassigns ``correlated_field``, overwriting the spatial field
# defined earlier in the notebook.
sp1 = ift.RGSpace([10])  # regular grid with 10 pixels
# NOTE(review): total_N=100 presumably requests 100 copies sharing one
# power-spectrum model - confirm against ift.CorrelatedFieldMaker docs.
cfmaker = ift.CorrelatedFieldMaker('', total_N=100)
# NOTE(review): the positional (mean, std) pairs presumably are fluctuations,
# flexibility, asperity and loglogavgslope, followed by the prefix 'amp1' -
# confirm against the add_fluctuations signature.
cfmaker.add_fluctuations(sp1, (0.1, 1e-2), (2, .2), (.01, .5), (-4, 2.),
                             'amp1')
cfmaker.set_amplitude_total_offset(0., (1e-2, 1e-6))
correlated_field = cfmaker.finalize()

Offset amplitude: 1.00E-02 ± 7.77E-07
Offset amplitude: 1.00E-02 ± 7.77E-07
Offset amplitude: 1.00E-02 ± 7.77E-07
Offset amplitude: 1.00E-02 ± 7.77E-07
Offset amplitude: 1.00E-02 ± 7.77E-07
Offset amplitude: 1.00E-02 ± 7.77E-07
Offset amplitude: 1.00E-02 ± 7.77E-07
Offset amplitude: 1.00E-02 ± 7.77E-07
Offset amplitude: 1.00E-02 ± 7.77E-07
Offset amplitude: 1.00E-02 ± 7.77E-07
Offset amplitude: 1.00E-02 ± 7.77E-07
Offset amplitude: 1.00E-02 ± 7.77E-07
Offset amplitude: 1.00E-02 ± 7.77E-07
Offset amplitude: 1.00E-02 ± 7.77E-07
Offset amplitude: 1.00E-02 ± 7.77E-07
Offset amplitude: 1.00E-02 ± 7.77E-07
Offset amplitude: 1.00E-02 ± 7.77E-07
Offset amplitude: 1.00E-02 ± 7.77E-07
Offset amplitude: 1.00E-02 ± 7.77E-07
Offset amplitude: 1.00E-02 ± 7.77E-07
Offset amplitude: 1.00E-02 ± 7.77E-07
Offset amplitude: 1.00E-02 ± 7.77E-07
Offset amplitude: 1.00E-02 ± 7.77E-07
Offset amplitude: 1.00E-02 ± 7.77E-07
Offset amplitude: 1.00E-02 ± 7.77E-07
Offset amplitude: 1.00E-02 ± 7.77E-07
Offset ampli