In [1]:
import numpy as np

import jax
import jax.numpy as jnp
from jax import grad, jit

from functools import partial
import gwjax
import gwjax.imrphenom

from pycbc.catalog import Merger
from pycbc.filter import resample_to_delta_t, highpass, matched_filter, sigma, get_cutoff_indices
from pycbc.psd import interpolate, inverse_spectrum_truncation
from pycbc.waveform import get_fd_waveform

dynfac = 1.0e23

  from .autonotebook import tqdm as notebook_tqdm
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [2]:
# Define the conditiion data function
def condition(strain, sampling_rate):
    tmp = resample_to_delta_t(highpass(strain, 15.0), 1.0/sampling_rate).crop(2,2)
    return (dynfac*tmp)

# Define the PSD function 
def estimate_psd(strain, delta_f):
    psd = strain.psd(4)
    psd = interpolate(psd, delta_f)
    psd = inverse_spectrum_truncation(psd, int(4 * strain.sample_rate),
                                  low_frequency_cutoff=15, trunc_method='hann')
    return psd

# Define the matched filter function
def matched_filter(signal_duration, sampling_rate, kmin, kmax, fcore, template):
    workspace = jnp.zeros(sampling_rate*signal_duration, dtype=complex)
    template_fft = jnp.fft.fft(template) #WHY did I change this from a rfft to fft?
    result_fft = fcore * jnp.conj(template_fft)
    workspace = workspace.at[kmin:kmax].set(result_fft)#can't change array in Jax
    result = jnp.fft.ifft(workspace)   
    return result

# Define the waveform template function
def waveform_template(freqs, m1, m2, s1, s2):
    params = {'phase': 0., 'geocent_time': 0.,
            'luminosity_distance': 1, 'theta_jn': 0., 
            'm1': m1, 'm2': m2, 'spin1': s1, 'spin2': s2,
            'ra': 0., 'dec': 0., 'pol': 0.}
    # Create the waveform template using the mass parameters
    # Adjust this function based on your specific waveform model
    #may want to do partial on this
    return gwjax.imrphenom.IMRPhenomD(freqs, params)

def snr(invpsd, matched_output, template):
    #Compute sigma
    sigma_squared = jnp.sum(template*jnp.conj(template)*invpsd)

    # Compute the SNR
    snr = matched_output / sigma_squared**0.5
    
    snr_max = len(snr) - 2
    snr_min = 2+4
    snr = snr[snr_min:snr_max]
    peak = jnp.argmin(jnp.absolute(snr))
    snrp = jnp.absolute(snr[peak])
    return snrp

# Define the objective function to minimize (negative SNR)
def objective(params):
    m1, m2, s1, s2 = params
    template, _ = my_waveform_template(m1, m2, s1, s2)
    matched_output = my_matched_filter(template)
    snr_peak = my_snr(matched_output, template)
    return -snr_peak

# Perform gradient descent
def gradient_descent(initial_m1, initial_m2, initial_s1, initial_s2, learning_rate, num_iterations, my_waveform_template, my_matched_filter, my_snr):
 
    # Initialize the parameters
    m1, m2, s1, s2, = initial_m1, initial_m2, initial_s1, initial_s2

    # Define the gradient of the objective function
    objective_grad = jit(grad(objective))

    for i in range(num_iterations):
        # Compute the gradient
        grad_m1, grad_m2, grad_s1, grad_s2 = objective_grad((m1, m2, s1, s2))

        # Update the parameters
        m1 -= learning_rate * grad_m1
        m2 -= learning_rate * grad_m2
        s1 -= learning_rate * grad_s1
        s2 -= learning_rate * grad_s2
        
    return m1, m2, s1, s2

In [3]:
sampling_rate = 2048  # Sampling rate in Hz
signal_duration = 8  # Duration of the signal in seconds 4 or 8

# Get the data and estimate the PSDs
merger = Merger("GW150914")
data = condition(merger.strain('H1'), sampling_rate)

invpsd = estimate_psd(data, data.delta_f)**(-1) #WHY NOT 1/signal_duration which is the example given in GWjax
#idx = int(sampling_rate * (merger.time - data.start_time))
#data[idx-4*sampling_rate:idx+4*sampling_rate]
fcore = data.to_frequencyseries()*invpsd #this is what goes to multiply with template^*

#Get the frequency range
nyquist = sampling_rate//2
freqs = jnp.arange(1+(nyquist*signal_duration))/signal_duration

#Determine the low freq and high frequency cut off
kmin, kmax = 15*signal_duration, 900*signal_duration
#kmin, kmax = get_cutoff_indices(flow=15, fhigh=900, df=data.delta_f, N=signal_duration) # WHY DID you avoid using this?
#And I don't understand what the difference is between freq and kmin

#Make everything the same length
fcore = jnp.asarray(fcore[kmin:kmax])
freqs = freqs[kmin:kmax]
invpsd = jnp.asarray(invpsd[kmin:kmax]) # can't take fft because it needs to start at 0 to nyquist 

#Get the stage ready for the template
my_waveform_template = partial(waveform_template, freqs)
my_matched_filter = partial(matched_filter, signal_duration, sampling_rate, kmin, kmax, fcore)
my_snr = partial(snr, invpsd)

# Set the initial mass parameters for the template
initial_m1 = 39.0 #from the Optuna section of GWtuna
initial_m2 = 34.0 #from t he Optuna section of GWtuna
initial_s1 = 0.0
initial_s2 = 0.0

# Set the learning rate and number of iterations for gradient descent
learning_rate = 0.001
num_iterations = 100

# Perform gradient descent to find the lowest SNR
optimized_m1, optimized_m2, optimized_s1, optimized_s2 = gradient_descent(initial_m1, initial_m2, initial_s1, initial_s2, learning_rate, num_iterations, my_waveform_template, my_matched_filter, my_snr)

print("Optimized Mass and Spin Parameters: m1 =", optimized_m1, "m2 =", optimized_m2, "s1 =", optimized_s1, "s2 =", optimized_s2)

Optimized Mass and Spin Parameters: m1 = 38.99980894159089 m2 = 33.99980487101437 s1 = -0.0012025655172046607 s2 = -0.0010066341932189866
