### Here we want to test whether it's possible to find an approximate parameter transformation which enables the metric to be flattened

In particular, we want to find a $J$ such that

$$
I \sim J\, g\, J^{\mathrm{T}}
$$

where $g$ is the metric in the original parameter space. We want to do this because we found that for 6D parameter spaces the metric is poorly conditioned, making the calculation of the determinant very unstable. Hopefully we can fix that!

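If we take the (lower-triangular) Cholesky decomposition of the metric,

$$
g = U\, U^{\mathrm{T}}
$$

then $J = U^{-1}$, since $J\, g\, J^{\mathrm{T}} = U^{-1} U\, U^{\mathrm{T}} U^{-\mathrm{T}} = I$. Note that once we have found this transformation we can also transform the parameters as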

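$$
\theta_n = \theta_o J^{-1}
$$

and conversely

$$
\theta_o = \theta_n J
$$

where $\theta$ is treated as a row vector. This is required because we want the line element to be (approximately) Euclidean in the new coordinates. Substituting $d\theta_o = d\theta_n J$ gives

$$
ds^2 = d\theta_o\, g\, d\theta_o^{\mathrm{T}} = d\theta_n \left(J\, g\, J^{\mathrm{T}}\right) d\theta_n^{\mathrm{T}} \sim d\theta_n\, d\theta_n^{\mathrm{T}}.
$$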

We will start by computing $J$ from the metric averaged over a batch of randomly sampled points, and then check how well it flattens the metric at individual points.

In [10]:
from functools import partial
from typing import Callable
import numpy as np
from diffbank.bank import Bank
from diffbank.utils import gen_templates_rejection
from jax import random
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cholesky, inv, det, eigh
from jax import jit
from diffbank.waveforms import kappa6D_modified
from diffbank.waveforms import kappa6D
from diffbank.metric import get_g

In [2]:
# Frequency band (Hz) over which the waveforms are evaluated.
f_l, f_u = 32.0, 512.0

# Bounds of the sampling box, one (low, high) pair per parameter, in the
# column order used by `propose`: (Mt, eta, chi1, chi2, k1, k2).
Mt_range = (2, 9)
eta_range = (0.139, 0.25)
chi1_range = chi2_range = (-0.8, 0.8)
k1_range = k2_range = (0.0, 1.0)

def get_Sn_aLIGO(
    path: str = "../scripts/LIGO-P1200087-v18-aLIGO_MID_LOW.txt",
) -> Callable[[jnp.ndarray], jnp.ndarray]:
    """
    Get interpolator for noise curve.

    Args:
        path: two-column text file of (frequency, Sn) samples readable by
            ``np.loadtxt``. Defaults to the aLIGO MID_LOW curve, relative to
            the notebook's working directory.

    Returns:
        Function mapping frequencies to linearly interpolated Sn values.
        Frequencies outside the tabulated range map to ``inf``.
    """
    xp, yp = np.loadtxt(path, unpack=True)
    # Outside the tabulated band the noise is treated as infinite.
    return lambda f: jnp.interp(f, xp, yp, left=jnp.inf, right=jnp.inf)


def propose(key, n, ranges=None):
    """
    Proposal distribution for var rejection sampling.

    Draws ``n`` points uniformly from an axis-aligned box.

    Args:
        key: JAX PRNG key.
        n: number of points to draw.
        ranges: optional sequence of ``(low, high)`` bounds, one per
            dimension. Defaults to the notebook-level ranges in the order
            (Mt, eta, chi1, chi2, k1, k2).

    Returns:
        Array of shape ``(n, len(ranges))``. Column ``j`` lies within
        ``ranges[j]``.
    """
    if ranges is None:
        ranges = (Mt_range, eta_range, chi1_range, chi2_range, k1_range, k2_range)
    # Derive the sample dimension from the bounds so the two cannot drift apart
    # (previously a hard-coded 6 next to six separately listed ranges).
    lows = jnp.stack([low for low, _ in ranges])
    highs = jnp.stack([high for _, high in ranges])
    return random.uniform(key, shape=(n, len(ranges)), minval=lows, maxval=highs)

In [3]:
# Let's check whether we can find badly behaved metrics: build banks for the
# original and the reparameterised 6D waveform so their metrics can be
# compared at the same sample points.
fs = jnp.linspace(f_l, f_u, 3000)
Sn_aLIGO = get_Sn_aLIGO()
mm = 0.95  # the banks receive m_star = 1 - mm
eta = 0.95  # passed to Bank(eta=...); unrelated to eta_range above

# NOTE(review): both banks share name="6D". Confirm Bank does not use the name
# for anything (e.g. output file names) where the two would collide.
bank, bank_modified = (
    Bank(
        waveform.Amp,
        waveform.Psi,
        fs,
        Sn_aLIGO,
        sample_base=propose,
        m_star=1 - mm,
        eta=eta,
        name="6D",
    )
    for waveform in (kappa6D, kappa6D_modified)
)

# Fixed-seed batch of proposal samples reused by the cells below.
seed = 10
key = random.PRNGKey(seed)
N = 100
thetas = propose(key, N)

In [12]:
# Coordinate transform that whitens the sample-averaged metric. With the lower
# Cholesky factor U of g_average (g_average = U U^T), J = U^{-1} satisfies
# J g_average J^T = I.
gs = jax.lax.map(bank.g_fun, thetas)
g_average = gs.mean(axis=0)
# Eigen-decomposition of the averaged metric; the eigenvalue spread shows how
# poorly conditioned it is.
print(eigh(g_average))
U = cholesky(g_average, lower=True)
J = inv(U)

# jax's cholesky signals a non-positive-definite (or NaN) input with NaNs
# rather than raising, so validate J before overwriting the saved transform.
if not bool(jnp.all(jnp.isfinite(J))):
    raise ValueError(
        "Cholesky factorisation of g_average failed: J contains non-finite values"
    )
J_WHITENING_ATOL = 1e-4  # loose, because g_average is badly conditioned
np.testing.assert_allclose(
    np.asarray(J @ g_average @ J.T), np.eye(J.shape[0]), atol=J_WHITENING_ATOL
)

# NOTE(review): bank_modified is built in an earlier cell. If kappa6D_modified
# reads this file at import time, it will not see the updated J until the
# kernel is restarted. Confirm.
np.save("../src/diffbank/waveforms/J_kappa6D.npy", np.array(J))
print(J)

(DeviceArray([2.70914640e-03, 1.13603599e-01, 4.11648826e+00,
             6.34754843e+01, 1.19431736e+04, 1.06374063e+07],            dtype=float64), DeviceArray([[ 4.53125182e-05, -1.35201010e-04, -1.50428240e-02,
               5.02528785e-02,  9.92985174e-01,  1.05965961e-01],
             [-3.03676102e-06, -2.57621277e-05,  2.36659901e-03,
              -1.05521382e-02, -1.05542378e-01,  9.94356001e-01],
             [-1.89364933e-03, -2.67610322e-02,  3.83778645e-01,
              -9.21499299e-01,  5.29864098e-02, -5.06903117e-03],
             [ 1.49128044e-02,  7.75197872e-03, -9.23021175e-01,
              -3.84338849e-01,  5.60531614e-03, -1.28659378e-03],
             [-1.34355414e-02, -9.99519298e-01, -1.76259399e-02,
               2.16266179e-02, -1.50611254e-03,  8.56547839e-05],
             [ 9.99796733e-01, -1.35981207e-02,  1.42583495e-02,
               4.27569886e-03, -4.88138297e-05,  8.95843038e-06]],            dtype=float64))
[[ 2.76056185e-03 -7.18429215e-20  

In [5]:
# Check the whitening at a single sample point: compare the raw metric g with
# J g J^T, which should be close to the identity if J flattens the metric there.
k = 0
g_k = gs[k]
g_k_flat = J @ g_k @ J.T
print(g_k)
print(g_k_flat)
print(jnp.sqrt(det(g_k_flat)))

[[ 1.42557956e+04  3.14955033e+05 -3.24663204e+03 -4.97371980e+02
   6.09051870e+01  2.00263795e-01]
 [ 3.14955033e+05  6.96120514e+06 -7.13727604e+04 -1.09358693e+04
   1.33675510e+03  4.39561489e+00]
 [-3.24663204e+03 -7.13727604e+04  7.83503795e+02  1.19804031e+02
  -1.49656904e+01 -4.91844246e-02]
 [-4.97371980e+02 -1.09358693e+04  1.19804031e+02  1.83202001e+01
  -2.28699358e+00 -7.51627567e-03]
 [ 6.09051870e+01  1.33675510e+03 -1.49656904e+01 -2.28699358e+00
   2.87462515e-01  9.44598523e-04]
 [ 2.00263795e-01  4.39561489e+00 -4.91844246e-02 -7.51627567e-03
   9.44598523e-04  3.10395251e-06]]
[[ 0.10863917  0.54316065 -0.23767474  0.32689749 -0.09772111 -0.11689447]
 [ 0.54316065  2.71859515 -1.13524838  1.59219122 -0.48770107 -0.58064566]
 [-0.23767474 -1.13524838  1.46832939 -1.46936072  0.22968633  0.32348672]
 [ 0.32689749  1.59219122 -1.46936072  1.58344371 -0.30657127 -0.40561096]
 [-0.09772111 -0.48770107  0.22968633 -0.30657127  0.08878457  0.10631948]
 [-0.11689447 -0.5

In [6]:
# Demonstration of the problem: at some sample points the template density
# comes out as nan, which we attribute to a poorly conditioned metric.
n_show = 10
for i in range(n_show):
    density = bank.density_fun(thetas[i])
    print(density)

1.1833529757553127e-13
1.3104750787224564e-14
2.8236843549070744e-14
2.584297165062539e-13
7.899986250320024e-14
2.51318622794933e-15
nan
2.526565693013264e-15
1.4286877778592906e-13
1.6849411646749166e-14


In [13]:
# Evaluate the metric of the reparameterised waveform at the first few sample
# points, mapped into the new coordinates. The goal was to remove the nans, but
# the recorded output below still shows nan determinants and a negative
# eigenvalue, so the problem is not solved yet.
# NOTE(review): this maps theta_n = J^{-1} theta_o (column-vector convention).
# For J g J^T = I the required relation is theta_o = theta_n J (row vectors),
# i.e. theta_n = J^{-T} theta_o. J is lower-triangular, so these differ.
# Confirm against how kappa6D_modified maps theta_n back to theta_o.
J_inv = inv(J)  # loop-invariant, compute once
for i in range(0, 3):
    theta_n = jnp.dot(J_inv, thetas[i])
    g_n = bank_modified.g_fun(theta_n)  # evaluate the metric once per point
    print("Modified: %.25f" % jnp.sqrt(det(g_n)))
    print(eigh(g_n))

Modified: nan
(DeviceArray([-4.12596168e-15,  2.93631775e-13,  1.84056168e-07,
              5.24662279e-04,  1.09225239e+00,  5.23865100e+02],            dtype=float64), DeviceArray([[-1.13993311e-01,  3.03399303e-02, -8.37075228e-04,
               1.75150685e-02,  1.62466205e-01, -9.79480611e-01],
             [-9.59582852e-01,  2.55467929e-01,  3.09282570e-03,
              -3.95197012e-03, -2.23732433e-02,  1.15806480e-01],
             [ 6.47933083e-03,  6.27931750e-03,  6.09947978e-01,
              -3.02546984e-01, -7.21410602e-01, -1.26151177e-01],
             [-2.55574773e-03, -9.38213816e-03, -7.73908525e-01,
              -4.25645377e-01, -4.61335752e-01, -8.34648194e-02],
             [-4.32643242e-03, -1.84258232e-02,  1.70109674e-01,
              -8.52436185e-01,  4.89625251e-01,  6.57581467e-02],
             [ 2.57175121e-01,  9.66099716e-01, -9.02731235e-03,
              -1.79301553e-02,  1.03610750e-02,  1.40072973e-03]],            dtype=float64))
Modified: nan
(