### Here we want to test whether its possible to find an approximate parameter transformation which enables the metric to be flattened

In particular, we want to find a $J$ such that

$$
I \sim J\, g\, J^{\mathrm{T}}
$$

where $g$ is the metric in the original parameter space. We want to do this because we found that for 6D parameter spaces the metric is poorly conditioned, making the calculation of the determinant very unstable. Hopefully we can fix that!

If we decompose the metric as 

$$
g = U^{\mathrm{T}} \, U
$$

then $J = U^{-1}$. Note that once we have found this tranformation we can also transform the parameters as 

$$
\theta_n = \theta_o J
$$

and similarly

$$
\theta_o = \theta_n (J)^{-1}
$$

This is enforced since we want $d\theta_n = d\theta_n g d\theta_n = d\theta_o J g J d\theta_o$

We will start by taking a single point in the parameter space and trying this

In [1]:
from functools import partial
from typing import Callable
import numpy as np
from diffbank.bank import Bank
from diffbank.utils import gen_templates_rejection
from jax import random
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cholesky, inv, det
from jax import jit
from diffbank.metric import get_g, get_metric_ellipse
# from diffbank.waveforms import taylorF2
from diffbank.waveforms import twoPN_simple

In [2]:
f_u = 512.0  # Hz
f_l = 32.0  # Hz

Mt_range = (2, 9)
eta_range = (0.139, 0.25)
# chi1_range = (-0.8, 0.8)
# chi2_range = (-0.8, 0.8)
# k1_range = (0.0, 1.0)
# k2_range = (0.0, 1.0)

def get_Sn_aLIGO() -> Callable[[jnp.ndarray], jnp.ndarray]:
    """
    Get interpolator for noise curve.
    """
    xp, yp = np.loadtxt("../scripts/LIGO-P1200087-v18-aLIGO_MID_LOW.txt", unpack=True)
    return lambda f: jnp.interp(f, xp, yp, left=jnp.inf, right=jnp.inf)


def propose(key, n):
    """
    Proposal distribution for var rejection sampling.
    """
    return random.uniform(
        key,
        shape=(
            n,
            2,
        ),
        minval=jnp.stack(
            (Mt_range[0], eta_range[0])
        ),
        maxval=jnp.stack(
            (Mt_range[1], eta_range[1])
        ),
    )

In [3]:
# # Lets just check to see if we can get bad metrics
# fs = jnp.linspace(f_l, f_u, 50000)
# Sn_aLIGO = get_Sn_aLIGO()
# mm = 0.95
# eta = 0.95

# def Psi_mod(fs, theta):
#     theta_o = jnp.dot(J, theta)
#     return taylorF2.Psi(fs, theta_o)

# def Amp_mod(fs, theta):
#     theta_o = jnp.dot(J, theta)
#     return taylorF2.Amp(fs, theta_o)

# bank = Bank(
#     taylorF2.Amp,
#     taylorF2.Psi,
#     fs,
#     Sn_aLIGO,
#     sample_base = propose,
#     m_star=1 - mm,
#     eta=eta,
#     name="6D",
# )

# bank_modified = Bank(
#     Amp_mod,
#     Psi_mod,
#     fs,
#     Sn_aLIGO,
#     sample_base = propose,
#     m_star=1 - mm,
#     eta=eta,
#     name="6D",
# )

# seed = 10
# key = random.PRNGKey(seed)
# N = 10
# thetas = propose(key, N)

In [4]:
# Lets just check to see if we can get bad metrics
fs = jnp.linspace(f_l, f_u, 50000)
Sn_aLIGO = get_Sn_aLIGO()
mm = 0.95
eta = 0.95

def Psi_mod(fs, theta):
    theta_o = jnp.dot(J, theta)
    return twoPN_simple.Psi(fs, theta_o)

def Amp_mod(fs, theta):
    theta_o = jnp.dot(J, theta)
    return twoPN_simple.Amp(fs, theta_o)

bank = Bank(
    twoPN_simple.Amp,
    twoPN_simple.Psi,
    fs,
    Sn_aLIGO,
    sample_base = propose,
    m_star=1 - mm,
    eta=eta,
    name="2D",
)

bank_modified = Bank(
    Amp_mod,
    Psi_mod,
    fs,
    Sn_aLIGO,
    sample_base = propose,
    m_star=1 - mm,
    eta=eta,
    name="2D",
)

seed = 10
key = random.PRNGKey(seed)
N = 10
thetas = propose(key, N)



In [6]:
#This is how we can calculate the appropriate coordinate transform
gs = jax.lax.map(bank.g_fun, thetas)
g_average = gs[0]
# g_average = gs.mean(axis=0)

# First lets check that the matrix is symmetric and positive definite
def check_symmetric(a, rtol=1e-10, atol=1e-10):
    return np.allclose(a, a.transpose(), rtol=rtol, atol=atol)

def is_pos_def(x):
    return np.all(np.linalg.eigvals(x) > 0)

def condition(x):
    return np.linalg.eigvals(x).max()/np.linalg.eigvals(x).min()

# print(g_average)
print(check_symmetric(np.array(g_average)))
print(is_pos_def(np.array(g_average)))
print(condition(np.array(g_average)))

# Now lets try to take the decomposition
L = cholesky(g_average, lower=True)
J = inv(L)
G = J @ g_average @ J.T
print(check_symmetric(np.array(G)))
print(is_pos_def(np.array(G)))
print(condition(np.array(G)))

print(gs)
print(G)
# print(J)
# np.save("../src/diffbank/waveforms/J_taylorF2.npy", np.array(J))

True
True
2883358.7662019897
True
True
1.00000000002536
[[[1.42715304e+03 2.37818279e+04]
  [2.37818279e+04 3.96334657e+05]]

 [[2.47820025e+05 1.86973821e+06]
  [1.86973821e+06 1.41068564e+07]]

 [[1.54396526e+03 2.60914375e+04]
  [2.60914375e+04 4.40960330e+05]]

 [[1.15488287e+04 1.70688588e+05]
  [1.70688588e+05 2.52284003e+06]]

 [[2.31563916e+04 2.55741528e+05]
  [2.55741528e+05 2.82452301e+06]]

 [[6.49671370e+04 6.59108304e+05]
  [6.59108304e+05 6.68696062e+06]]

 [[7.53240622e+02 1.92903758e+04]
  [1.92903758e+04 4.94091774e+05]]

 [[1.85288146e+05 2.74602564e+06]
  [2.74602564e+06 4.06975374e+07]]

 [[5.10070562e+03 1.23956581e+05]
  [1.23956581e+05 3.01257747e+06]]

 [[1.00171705e+03 1.89542547e+04]
  [1.89542547e+04 3.58688878e+05]]]
[[1.0000000e+00 2.4158453e-13]
 [0.0000000e+00 1.0000000e+00]]


In [7]:
k = 0
print(gs[k])
print(J @ gs[k] @ J.T)
print(jnp.sqrt(det(J @ gs[k] @ J.T)))

[[  1427.15304149  23781.82793658]
 [ 23781.82793658 396334.65725511]]
[[1.0000000e+00 2.4158453e-13]
 [0.0000000e+00 1.0000000e+00]]
0.99999999998732


In [8]:
# Here is a demonstration of the problem
# For some points in the parameter space, the density is nan
# This is due to a poorly conditioned metric
for i in range(0, 3):
    print(bank.g_fun(thetas[i]))

[[  1427.15304149  23781.82793658]
 [ 23781.82793658 396334.65725511]]
[[  247820.0249643   1869738.21332848]
 [ 1869738.21332848 14106856.42578878]]
[[  1543.96525581  26091.43753873]
 [ 26091.43753873 440960.32974307]]


In [9]:
# This should no longer get nans
for i in range(0, 3):
    theta_n = inv(J) @ thetas[i]
#     print("Original the first: %.15f" % jnp.sqrt(det(gs[i])))
#     print("Original: %.15f" % jnp.sqrt(det(J @ gs[i] @ J.T)))
#     print("Modified: %.15f" % bank_modified.density_fun(theta_n))
    print(J @ gs[i] @ J.T)
    print(bank_modified.g_fun(theta_n))

[[1.0000000e+00 2.4158453e-13]
 [0.0000000e+00 1.0000000e+00]]
[[2859046.83433869 -171673.53206241]
 [-171673.53206241   10308.26122112]]
[[ 1.73646426e+02 -9.64748122e+03]
 [-9.64748122e+03  5.36000853e+05]]
[[ 1.01617450e+08 -6.10605837e+06]
 [-6.10605837e+06  3.66904984e+05]]
[[1.08184982 1.54995251]
 [1.54995251 3.30402851]]
[[3181016.23929434 -191004.8459692 ]
 [-191004.8459692    11468.92956227]]


In [10]:
def get_metric_ellipse_simple(g):
    eigval, norm_eigvec = jnp.linalg.eig(g)
    r_major, r_minor = 1 / jnp.sqrt(eigval)
    U = jnp.linalg.inv(norm_eigvec)
    ang = jnp.arccos(U[0, 0] / jnp.linalg.norm(U[:, 0]))

    return jnp.array([r_major, r_minor, ang])

for i in range(0, 1):
    theta_n = inv(J) @ thetas[i]
    print(get_metric_ellipse_simple(J @ gs[i] @ J.T))
    print(get_metric_ellipse(theta_n, Amp_mod, Psi_mod, fs, Sn_aLIGO))

[1.+0.j 1.+0.j 0.-0.j]
[5.90347667e-04-0.j 1.69391974e+03-0.j 5.99737132e-02-0.j]
