Making a windowed prior

In [1]:
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import tinygp
from tinygp import GaussianProcess, kernels
from stingray import Lightcurve
jax.config.update("jax_enable_x64", True)
import functools



In [7]:
# Making the Data:
from gpmodelling import get_kernel, get_mean
import warnings
warnings.filterwarnings("ignore")

Times = np.linspace(0,1,256)
hqpoparams = {
    "arn" : jnp.exp(1.0),    "crn" : jnp.exp(1.0),
    "aqpo": jnp.exp(-0.4),    "cqpo": jnp.exp(1),    "freq": 20,}
mean_params = {"A" : 3,    "t0" : 0.5,    "sig" : 0.2,}
kernel = get_kernel(kernel_type = "QPO_plus_RN", kernel_params = hqpoparams)
mean = get_mean(mean_type = "gaussian",  mean_params = mean_params)

gp = GaussianProcess(kernel = kernel, X = Times, mean_value = mean(Times))

counts = gp.sample(key = jax.random.PRNGKey(6))

In [40]:
# Trying out tinygp:
print("log prob: ", gp.log_probability(counts)) 

window_minimum = 0.2; a = jnp.searchsorted(Times, window_minimum)
window_maximum = 0.65; b = jnp.searchsorted(Times, window_maximum)

gp2 = GaussianProcess(kernel = kernel, X = Times[a:b], mean_value = mean(Times[a:b]))
print("log prob windowed: ", gp2.log_probability(counts[a:b]))

wind_counts = jnp.where((Times < window_minimum) & (Times > window_maximum), counts, 0)
print(wind_counts.shape)
print("log prob wind : ", gp.log_probability(wind_counts))


log prob:  -58.273614459545655
log prob windowed:  -26.10427910528807
(256,)
log prob wind :  63.80927271414357


In [8]:
# Calculating evidences using jaxns
import tensorflow_probability.substrates.jax as tfp
from jax import random, numpy as jnp
from jaxns import ExactNestedSampler, TerminationCondition, analytic_log_evidence, Prior, Model
from jaxns.special_priors import ForcedIdentifiability

tfpd = tfp.distributions

In [28]:
T = Times[-1] - Times[0]    # Total time
f = 1/(Times[1] - Times[0]) # Sampling frequency
min = jnp.min(counts)
max = jnp.max(counts)
span = max - min

def RNprior_model():
    arn = yield Prior(tfpd.Uniform(0.1*span, 2*span), name='arn') 
    crn = yield Prior(tfpd.Uniform(jnp.log(1/T), jnp.log(f)), name='crn')
    A = yield Prior(tfpd.Uniform(0.1*span, 2*span), name='A') 
    t0 = yield Prior(tfpd.Uniform(Times[0]-0.1*T, Times[-1]+0.1*T), name='t0')
    sig = yield Prior(tfpd.Uniform(0.5*1/f, 2*T), name='sig')
    t_window = yield ForcedIdentifiability(n = 2, low = Times[0], high = Times[-1], name='t_window')
    return arn, crn, A, t0, sig, t_window

def RNlog_likelihood1(arn, crn, A, t0, sig):
    rnlikelihood_params = {"arn": arn, "crn": crn,
                        "aqpo": 0.0, "cqpo": 0.0, "freq": 0.0, }
                
    mean_params = { "A": A, "t0": t0, "sig": sig, }
    
    kernel = get_kernel(kernel_type = "RN", kernel_params = rnlikelihood_params)
    mean = get_mean(mean_type = "gaussian",  mean_params = mean_params)
    gp = GaussianProcess(kernel = kernel, X = Times, mean_value = mean(Times))
    return gp.log_probability(counts)

def RNlog_likelihood2(arn, crn, A, t0, sig, t_window):
    rnlikelihood_params = {"arn": arn, "crn": crn,
                        "aqpo": 0.0, "cqpo": 0.0, "freq": 0.0, }            
    mean_params = { "A": A, "t0": t0, "sig": sig, }
    window_minimum = t_window[0] # becomes a tracer value
    window_maximum = t_window[1]
    # Times remains a jnp array

    windowed_indices = jnp.where(jnp.logical_and(window_minimum < Times, Times < window_maximum))[0]
    # Abstract tracer value encountered where concrete value is expected
    # Conretization type error
    kernel = get_kernel(kernel_type = "RN", kernel_params = rnlikelihood_params)
    mean = get_mean(mean_type = "gaussian",  mean_params = mean_params)
    gp = GaussianProcess(kernel = kernel, X = Times[windowed_indices], mean_value = mean(Times[windowed_indices]))
    return gp.log_probability(counts[windowed_indices])

def RNlog_likelihood3(arn, crn, A, t0, sig, t_window):
    rnlikelihood_params = {"arn": arn, "crn": crn,
                        "aqpo": 0.0, "cqpo": 0.0, "freq": 0.0, }       
    mean_params = { "A": A, "t0": t0, "sig": sig, }
    window_minimum = t_window[0] # becomes a tracer value
    window_maximum = t_window[1]

    mask = jnp.logical_and(window_minimum < Times, Times < window_maximum)
    kernel = get_kernel(kernel_type = "RN", kernel_params = rnlikelihood_params)
    mean = get_mean(mean_type = "gaussian",  mean_params = mean_params)
    gp = GaussianProcess(kernel = kernel, X = Times[mask], mean_value = mean(Times[mask]))
    # numpy.ndarray conversion method __array__() was called on the JAX Tracer object Tracedwith with
    # Tracer Array Conversion Error
    return gp.log_probability(counts[mask])

def RNlog_likelihood4(arn, crn, A, t0, sig, t_window):
    rnlikelihood_params = {"arn": arn, "crn": crn,
                        "aqpo": 0.0, "cqpo": 0.0, "freq": 0.0, }       
    mean_params = { "A": A, "t0": t0, "sig": sig, }
    window_minimum = t_window[0]; a = jnp.searchsorted(Times, window_minimum)
    window_maximum = t_window[1]; b = jnp.searchsorted(Times, window_maximum)

    kernel = get_kernel(kernel_type = "RN", kernel_params = rnlikelihood_params)
    mean = get_mean(mean_type = "gaussian",  mean_params = mean_params)
    gp = GaussianProcess(kernel = kernel, X = Times[a:b], mean_value = mean(Times[a:b]))
    # The __index__() method was called on the JAX Tracer object Tracedwith
    # TracerIntegerConversionError
    return gp.log_probability(counts[a:b])

def RNlog_likelihood5(arn, crn, A, t0, sig, t_window):
    rnlikelihood_params = {"arn": arn, "crn": crn,
                        "aqpo": 0.0, "cqpo": 0.0, "freq": 0.0, }       
    mean_params = { "A": A, "t0": t0, "sig": sig, }
    window_minimum = t_window[0]; a = jnp.searchsorted(Times, window_minimum)
    window_maximum = t_window[1]; b = jnp.searchsorted(Times, window_maximum)
    print(a) # a tracer value

    times = jax.lax.dynamic_slice(Times, (a,), (b-a,)); # Requires concrete integer index, not tracer value
    count = jax.lax.dynamic_slice(counts, (a,), (b-a,))
    kernel = get_kernel(kernel_type = "RN", kernel_params = rnlikelihood_params)
    mean = get_mean(mean_type = "gaussian",  mean_params = mean_params)
    gp = GaussianProcess(kernel = kernel, X = times, mean_value = mean(Times))
    # The __index__() method was called on the JAX Tracer object Tracedwith
    # TracerIntegerConversionError
    return gp.log_probability(count)

model = Model(prior_model=RNprior_model, log_likelihood=RNlog_likelihood5)
model.sanity_check(random.PRNGKey(10), S=100)


Traced<ShapedArray(int32[])>with<BatchTrace(level=2/0)> with
  val = Traced<ShapedArray(int32[100])>with<DynamicJaxprTrace(level=1/0)>
  batch_dim = 0


TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[])>with<BatchTrace(level=2/0)> with
  val = Traced<ShapedArray(int32[100])>with<DynamicJaxprTrace(level=1/0)>
  batch_dim = 0,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
This BatchTracer with object id 11101559040 was created on line:
  /var/folders/z9/d9jc5k554dl6jd5l6z3_h5f80000gn/T/ipykernel_93247/3202572062.py:80 (RNlog_likelihood5)

In [25]:
Times2 = jnp.linspace(0,1,25)
window_minimum = 0.2
window_maximum = 0.65

windowed_indices = jnp.where(jnp.logical_and(window_minimum < Times2, Times2 < window_maximum))[0]
print(f"windowed_indices = {windowed_indices}")
print(f"Times2[windowed_indices] = {Times2[windowed_indices]}")

mask = jnp.logical_and(Times2 > window_minimum, Times2 < window_maximum)
print("mask: ", mask)
print("masked times: ", Times2[mask])

a = jnp.searchsorted(Times2, window_minimum)
b = jnp.searchsorted(Times2, window_maximum)
print("a = ", a, "b = ", b)
print("windowd times: ", Times2[a:b])

print("dynamic slice:", jax.lax.dynamic_slice(Times2, (a,), (b-a,)) )


windowed_indices = [ 5  6  7  8  9 10 11 12 13 14 15]
Times2[windowed_indices] = [0.20833333 0.25       0.29166667 0.33333333 0.375      0.41666667
 0.45833333 0.5        0.54166667 0.58333333 0.625     ]
mask:  [False False False False False  True  True  True  True  True  True  True
  True  True  True  True False False False False False False False False
 False]
masked times:  [0.20833333 0.25       0.29166667 0.33333333 0.375      0.41666667
 0.45833333 0.5        0.54166667 0.58333333 0.625     ]
a =  5 b =  16
windowd times:  [0.20833333 0.25       0.29166667 0.33333333 0.375      0.41666667
 0.45833333 0.5        0.54166667 0.58333333 0.625     ]
dynamic slice: [0.20833333 0.25       0.29166667 0.33333333 0.375      0.41666667
 0.45833333 0.5        0.54166667 0.58333333 0.625     ]
