### Here we want to test whether its possible to find an approximate parameter transformation which enables the metric to be flattened

In particular, we want to find a $J$ such that

$$
I \sim J\, g\, J^{\mathrm{T}}
$$

where $g$ is the metric in the original parameter space. We want to do this because we found that for 6D parameter spaces the metric is poorly conditioned, making the calculation of the determinant very unstable. Hopefully we can fix that!

If we decompose the metric as 

$$
g = U^{\mathrm{T}} \, U
$$

then $J = U^{-1}$. Note that once we have found this tranformation we can also transform the parameters as 

$$
\theta_n = \theta_o J
$$

and similarly

$$
\theta_o = \theta_n (J)^{-1}
$$

This is enforced since we want $d\theta_n = d\theta_n g d\theta_n = d\theta_o J g J d\theta_o$

We will start by taking a single point in the parameter space and trying this

In [1]:
from functools import partial
from typing import Callable
import numpy as np
from diffbank.bank import Bank
from diffbank.utils import gen_templates_rejection
from jax import random
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cholesky, inv, det
from jax import jit
from diffbank.metric import get_g
from diffbank.waveforms import taylorF2_modified
from diffbank.waveforms import taylorF2



[[ 2.62024325e-03  3.60693965e-20  0.00000000e+00  0.00000000e+00]
 [-5.71936909e-03  7.09145187e-04 -0.00000000e+00 -0.00000000e+00]
 [-3.44132377e-03  1.15421576e-03  1.55187850e-01  0.00000000e+00]
 [ 1.22738751e-02 -1.74407032e-03 -3.06029146e-01  7.46873242e-01]]


In [2]:
f_u = 512.0  # Hz
f_l = 32.0  # Hz

Mt_range = (2, 9)
eta_range = (0.139, 0.25)
chi1_range = (-0.8, 0.8)
chi2_range = (-0.8, 0.8)
# k1_range = (0.0, 1.0)
# k2_range = (0.0, 1.0)

def get_Sn_aLIGO() -> Callable[[jnp.ndarray], jnp.ndarray]:
    """
    Get interpolator for noise curve.
    """
    xp, yp = np.loadtxt("../scripts/LIGO-P1200087-v18-aLIGO_MID_LOW.txt", unpack=True)
    return lambda f: jnp.interp(f, xp, yp, left=jnp.inf, right=jnp.inf)


def propose(key, n):
    """
    Proposal distribution for var rejection sampling.
    """
    return random.uniform(
        key,
        shape=(
            n,
            4,
        ),
        minval=jnp.stack(
            (Mt_range[0], eta_range[0], chi1_range[0], chi2_range[0])
        ),
        maxval=jnp.stack(
            (Mt_range[1], eta_range[1], chi1_range[1], chi2_range[1])
        ),
    )

In [3]:
# Lets just check to see if we can get bad metrics
fs = jnp.linspace(f_l, f_u, 3000)
Sn_aLIGO = get_Sn_aLIGO()
mm = 0.95
eta = 0.95

bank = Bank(
    taylorF2.Amp,
    taylorF2.Psi,
    fs,
    Sn_aLIGO,
    sample_base = propose,
    m_star=1 - mm,
    eta=eta,
    name="6D",
)

bank_modified = Bank(
    taylorF2_modified.Amp,
    taylorF2_modified.Psi,
    fs,
    Sn_aLIGO,
    sample_base = propose,
    m_star=1 - mm,
    eta=eta,
    name="6D",
)

seed = 10
key = random.PRNGKey(seed)
N = 10
thetas = propose(key, N)

In [4]:
#This is how we can calculate the appropriate coordinate transform
gs = jax.lax.map(bank.g_fun, thetas)
g_average = gs.sum(axis=0)/gs.shape[0]

# First lets check that the matrix is symmetric and positive definite
def check_symmetric(a, rtol=1e-15, atol=1e-15):
    return np.allclose(a, a.transpose(), rtol=rtol, atol=atol)

def is_pos_def(x):
    return np.all(np.linalg.eigvals(x) > 0)

print(g_average)
print(check_symmetric(np.array(g_average)))
print(is_pos_def(np.array(g_average)))

# Now lets try to take the decomposition
L = cholesky(g_average, lower=True)
J = inv(L)
G = jnp.dot(jnp.dot(J,g_average),J.transpose())
print(G)
print(check_symmetric(np.array(G)))
print(is_pos_def(np.array(G)))

print(J)
np.save("../src/diffbank/waveforms/J_taylorF2.npy", np.array(J))

[[ 1.45652110e+05  1.17470751e+06 -5.50706676e+03 -1.90697184e+03]
 [ 1.17470751e+06  1.14627222e+07 -5.92050585e+04 -1.67965492e+04]
 [-5.50706676e+03 -5.92050585e+04  3.59742179e+02  9.96512551e+01]
 [-1.90697184e+03 -1.67965492e+04  9.96512551e+01  3.47403999e+01]]
True
True
[[ 1.00000000e+00  0.00000000e+00  0.00000000e+00 -1.33226763e-15]
 [ 5.08632043e-17  1.00000000e+00 -1.33226763e-15  1.66533454e-15]
 [-2.97887169e-16 -6.39710596e-16  1.00000000e+00  2.44249065e-15]
 [-1.19154867e-15  1.31094036e-15  6.08148261e-15  1.00000000e+00]]
False
True
[[ 2.62024325e-03  3.60693965e-20  0.00000000e+00  0.00000000e+00]
 [-5.71936909e-03  7.09145187e-04 -0.00000000e+00 -0.00000000e+00]
 [-3.44132377e-03  1.15421576e-03  1.55187850e-01  0.00000000e+00]
 [ 1.22738751e-02 -1.74407032e-03 -3.06029146e-01  7.46873242e-01]]


In [5]:
k = 2
print(gs[k])
print(jnp.dot(jnp.dot(J,gs[k]),J.T))
print(jnp.sqrt(det(jnp.dot(jnp.dot(J,gs[k]),J.T))))

[[ 3.69727030e+04  5.64594157e+05 -4.10430826e+03 -8.87737789e+02]
 [ 5.64594157e+05  8.62168919e+06 -6.26555093e+04 -1.35521659e+04]
 [-4.10430826e+03 -6.26555093e+04  4.79703630e+02  1.03578884e+02]
 [-8.87737789e+02 -1.35521659e+04  1.03578884e+02  2.23663881e+01]]
[[ 0.25384255  0.49501293 -0.29480446  0.16276106]
 [ 0.49501293  0.96532219 -0.57272103  0.31529033]
 [-0.29480446 -0.57272103  0.92953622 -0.75851998]
 [ 0.16276106  0.31529033 -0.75851998  0.65678329]]
1.8979303567345375e-07


In [6]:
# Here is a demonstration of the problem
# For some points in the parameter space, the density is nan
# This is due to a poorly conditioned metric
for i in range(0, 3):
    print(bank.g_fun(thetas[i]))

[[ 4.17103455e+03  7.24669526e+04 -8.03075279e+02 -3.00470073e+02]
 [ 7.24669526e+04  1.25904410e+06 -1.39634368e+04 -5.22429251e+03]
 [-8.03075279e+02 -1.39634368e+04  1.63315375e+02  6.10106652e+01]
 [-3.00470073e+02 -5.22429251e+03  6.10106652e+01  2.27931471e+01]]
[[ 5.08688687e+03  7.14689778e+04 -6.52827088e+02 -3.66682396e+02]
 [ 7.14689778e+04  1.00414215e+06 -9.16102559e+03 -5.14572159e+03]
 [-6.52827088e+02 -9.16102559e+03  8.80670490e+01  4.94180293e+01]
 [-3.66682396e+02 -5.14572159e+03  4.94180293e+01  2.77310298e+01]]
[[ 3.69727030e+04  5.64594157e+05 -4.10430826e+03 -8.87737789e+02]
 [ 5.64594157e+05  8.62168919e+06 -6.26555093e+04 -1.35521659e+04]
 [-4.10430826e+03 -6.26555093e+04  4.79703630e+02  1.03578884e+02]
 [-8.87737789e+02 -1.35521659e+04  1.03578884e+02  2.23663881e+01]]


In [7]:
# This should no longer get nans
for i in range(0, 10):
    theta_n = jnp.dot(inv(J), thetas[i])
    print("Original: %.15f" % jnp.sqrt(det(jnp.dot(jnp.dot(J,gs[i]),J.T))))
    print("Modified: %.15f" % jnp.sqrt(det(bank_modified.g_fun(theta_n))))

Original: 0.000000016151077
Modified: 0.000000016171458
Original: 0.000000007974249
Modified: 0.000000007972787
Original: 0.000000189793036
Modified: 0.000000189920159
Original: 0.000000073346759
Modified: 0.000000073429506
Original: 0.000000124682961
Modified: 0.000000124727025
Original: 0.000000011257714
Modified: 0.000000011265104
Original: 0.000000013080073
Modified: 0.000000013095189
Original: 0.000001054239783
Modified: 0.000001055337227
Original: 0.000000690776731
Modified: 0.000000691158452
Original: 0.000000195630018
Modified: 0.000000195575684
