### Here we want to test whether its possible to find an approximate parameter transformation which enables the metric to be flattened

In particular, we want to find a $J$ such that

$$
I \sim J\, g\, J^{\mathrm{T}}
$$

where $g$ is the metric in the original parameter space. We want to do this because we found that for 6D parameter spaces the metric is poorly conditioned, making the calculation of the determinant very unstable. Hopefully we can fix that!

If we decompose the metric as 

$$
g = U^{\mathrm{T}} \, U
$$

then $J = U^{-1}$. Note that once we have found this tranformation we can also transform the parameters as 

$$
\theta_n = \theta_o J
$$

and similarly

$$
\theta_o = \theta_n (J)^{-1}
$$

This is enforced since we want $d\theta_n = d\theta_n g d\theta_n = d\theta_o J g J d\theta_o$

We will start by taking a single point in the parameter space and trying this

In [6]:
from functools import partial
from typing import Callable
import numpy as np
from diffbank.bank import Bank
from diffbank.utils import gen_templates_rejection
from jax import random
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cholesky, inv, det
from jax import jit
from diffbank.metric import get_g
from diffbank.waveforms import taylorF2_modified
from diffbank.waveforms import taylorF2

In [7]:
f_u = 512.0  # Hz
f_l = 32.0  # Hz

Mt_range = (2, 9)
eta_range = (0.139, 0.25)
chi1_range = (-0.8, 0.8)
chi2_range = (-0.8, 0.8)
# k1_range = (0.0, 1.0)
# k2_range = (0.0, 1.0)

def get_Sn_aLIGO() -> Callable[[jnp.ndarray], jnp.ndarray]:
    """
    Get interpolator for noise curve.
    """
    xp, yp = np.loadtxt("../scripts/LIGO-P1200087-v18-aLIGO_MID_LOW.txt", unpack=True)
    return lambda f: jnp.interp(f, xp, yp, left=jnp.inf, right=jnp.inf)


def propose(key, n):
    """
    Proposal distribution for var rejection sampling.
    """
    return random.uniform(
        key,
        shape=(
            n,
            4,
        ),
        minval=jnp.stack(
            (Mt_range[0], eta_range[0], chi1_range[0], chi2_range[0])
        ),
        maxval=jnp.stack(
            (Mt_range[1], eta_range[1], chi1_range[1], chi2_range[1])
        ),
    )

In [8]:
# Lets just check to see if we can get bad metrics
fs = jnp.linspace(f_l, f_u, 3000)
Sn_aLIGO = get_Sn_aLIGO()
mm = 0.95
eta = 0.95

bank = Bank(
    taylorF2.Amp,
    taylorF2.Psi,
    fs,
    Sn_aLIGO,
    sample_base = propose,
    m_star=1 - mm,
    eta=eta,
    name="6D",
)

bank_modified = Bank(
    taylorF2_modified.Amp,
    taylorF2_modified.Psi,
    fs,
    Sn_aLIGO,
    sample_base = propose,
    m_star=1 - mm,
    eta=eta,
    name="6D",
)

seed = 10
key = random.PRNGKey(seed)
N = 100
thetas = propose(key, N)

In [9]:
#This is how we can calculate the appropriate coordinate transform
gs = jax.lax.map(bank.g_fun, thetas)
g_average = gs.sum(axis=0)/gs.shape[0]
print(jnp.sqrt(det(g_average)))
U = cholesky(g_average, lower=True)
J = inv(U)
np.save("../src/diffbank/waveforms/J_taylorF2.npy", np.array(J))
print(str(J))

4472605.888955843
[[ 2.66386507e-03 -6.61029137e-20  0.00000000e+00  0.00000000e+00]
 [-1.10381424e-02  1.23357098e-03 -0.00000000e+00 -0.00000000e+00]
 [-8.79594991e-03  1.55144904e-03  1.29331445e-01  0.00000000e+00]
 [ 6.05216804e-03 -9.44200422e-04 -1.91455425e-01  5.26088577e-01]]


In [12]:
k = 2
print(gs[k])
print(jnp.dot(jnp.dot(J,gs[k]),J.T))
print(jnp.sqrt(det(jnp.dot(jnp.dot(J,gs[k]),J.T))))

[[ 5.64962913e+04  5.67662064e+05 -3.51821663e+03 -1.43829192e+03]
 [ 5.67662064e+05  5.70376433e+06 -3.53343372e+04 -1.44452649e+04]
 [-3.51821663e+03 -3.53343372e+04  2.30005325e+02  9.39362523e+01]
 [-1.43829192e+03 -1.44452649e+04  9.39362523e+01  3.83652297e+01]]
[[ 0.40090769  0.20415154 -0.18981591 -0.73828754]
 [ 0.20415154  0.10399406 -0.09407525 -0.37560513]
 [-0.18981591 -0.09407525  0.27886776  0.37506446]
 [-0.73828754 -0.37560513  0.37506446  1.3630481 ]]
7.619052255881044e-08


In [7]:
# Here is a demonstration of the problem
# For some points in the parameter space, the density is nan
# This is due to a poorly conditioned metric
for i in range(0, 3):
    print(bank.g_fun(thetas[i]))

[[ 2.77904508e+04  3.30334042e+05 -2.84905117e+03 -9.60952951e+02]
 [ 3.30334042e+05  3.92746462e+06 -3.37444892e+04 -1.13824247e+04]
 [-2.84905117e+03 -3.37444892e+04  3.08095054e+02  1.03813144e+02]
 [-9.60952951e+02 -1.13824247e+04  1.03813144e+02  3.49807125e+01]]
[[ 1.37168910e+05  1.68264520e+06 -1.07122133e+04 -2.17576478e+03]
 [ 1.68264520e+06  2.06428547e+07 -1.31110203e+05 -2.66311238e+04]
 [-1.07122133e+04 -1.31110203e+05  8.82297979e+02  1.79013196e+02]
 [-2.17576478e+03 -2.66311238e+04  1.79013196e+02  3.63215879e+01]]
[[ 5.64962913e+04  5.67662064e+05 -3.51821663e+03 -1.43829192e+03]
 [ 5.67662064e+05  5.70376433e+06 -3.53343372e+04 -1.44452649e+04]
 [-3.51821663e+03 -3.53343372e+04  2.30005325e+02  9.39362523e+01]
 [-1.43829192e+03 -1.44452649e+04  9.39362523e+01  3.83652297e+01]]


In [17]:
# This should no longer get nans
for i in range(0, 10):
    theta_n = jnp.dot(inv(J_temp), thetas[i])
    print("Original: %.15f" % jnp.sqrt(det(jnp.dot(jnp.dot(J,gs[i]),J.T))))
    print("Modified: %.15f" % jnp.sqrt(det(bank_modified.g_fun(theta_n))))

Original: 0.000000060299249
Modified: 0.000000060280087
Original: 0.000000556326029
Modified: 0.000000556048538
Original: 0.000000076190523
Modified: 0.000000076184671
Original: 0.000000097839188
Modified: 0.000000097877756
Original: 0.000000005008866
Modified: 0.000000005011089
Original: 0.000000006787082
Modified: 0.000000006788318
Original: 0.000000056069658
Modified: 0.000000056024654
Original: 0.000000091258262
Modified: 0.000000091265790
Original: 0.000000008855287
Modified: 0.000000008855856
Original: 0.000000228396470
Modified: 0.000000228152456
