In [1]:
import numpy as np

import jax
import jax.numpy as jnp
from jax import grad, jit

from functools import partial
import gwjax
import gwjax.imrphenom

from pycbc.catalog import Merger
from pycbc.filter import resample_to_delta_t, highpass, matched_filter, sigma, get_cutoff_indices
from pycbc.psd import interpolate, inverse_spectrum_truncation
from pycbc.waveform import get_fd_waveform

  from .autonotebook import tqdm as notebook_tqdm
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [2]:
def snr(signal, template):
    # Compute the matched filter output
    matched_output = matched_filter(signal, template)

    # Compute the signal power
    signal_power = jnp.sum(template**2)

    # Compute the noise power
    noise = signal - matched_output
    noise_power = jnp.sum(noise**2)

    # Compute the SNR
    snr_value = signal_power / noise_power

    return snr_value

# Define the matched filter function
def matched_filter(signal, template):
    signal_fft = jnp.fft.rfft(signal)
    template_fft = jnp.fft.rfft(template)
    result_fft = signal_fft * jnp.conj(template_fft)
    result = jnp.fft.irfft(result_fft)
    return result

# Define the objective function to minimize (negative SNR)
def objective(params):
    signal, template = params
    return -snr(signal, template)

# Perform gradient descent
def gradient_descent(signal, template, learning_rate, num_iterations):
    # Initialize the parameters
    signal_param = signal
    template_param = template

    # Define the gradient of the objective function
    objective_grad = jit(grad(objective))

    for i in range(num_iterations):
        # Compute the gradient
        grad_signal, grad_template = objective_grad((signal_param, template_param))

        # Update the parameters
        signal_param -= learning_rate * grad_signal
        template_param -= learning_rate * grad_template

    return signal_param, template_param

# Example usage
sampling_rate = 4096  # Sampling rate in Hz
signal_duration = 2.0  # Duration of the signal in seconds

# Generate a random gravitational wave signal and template
np.random.seed(0)
signal = np.random.randn(int(sampling_rate * signal_duration))
template = np.random.randn(int(sampling_rate * signal_duration))

# Normalize the template to have unit power
template /= np.sqrt(np.sum(template**2))

# Set the learning rate and number of iterations for gradient descent
learning_rate = 0.01
num_iterations = 100

# Perform gradient descent to find the lowest SNR
optimized_signal, optimized_template = gradient_descent(signal, template, learning_rate, num_iterations)

# Compute the SNR of the optimized signal and template
optimized_snr = snr(optimized_signal, optimized_template)

print("Optimized SNR:", optimized_snr)


Optimized SNR: 6.36941214005156e-05


In [3]:
import jax
import jax.numpy as jnp
import numpy as np
from jax import grad, jit

def snr(signal, template):
    # Compute the matched filter output
    matched_output = matched_filter(signal, template)

    # Compute the signal power
    signal_power = jnp.sum(template**2)

    # Compute the noise power
    noise = signal - matched_output
    noise_power = jnp.sum(noise**2)

    # Compute the SNR
    snr_value = signal_power / noise_power

    return snr_value

# Define the matched filter function
def matched_filter(signal, template):
    signal_fft = jnp.fft.rfft(signal)
    template_fft = jnp.fft.rfft(template)
    result_fft = signal_fft * jnp.conj(template_fft)
    result = jnp.fft.irfft(result_fft)
    return result

# Define the objective function to minimize (negative SNR)
def objective(params):
    signal, template = params
    return -snr(signal, template)

# Perform gradient descent
def gradient_descent(signal, template, learning_rate, num_iterations):
    # Initialize the parameters
    template_param = template

    # Define the gradient of the objective function
    objective_grad = jit(grad(objective))

    for i in range(num_iterations):
        # Compute the gradient
        grad_template = objective_grad((signal, template_param))[1]

        # Update the parameters
        template_param -= learning_rate * grad_template

    return template_param

# Example usage
sampling_rate = 4096  # Sampling rate in Hz
signal_duration = 2.0  # Duration of the signal in seconds

# Generate a random gravitational wave signal and template
np.random.seed(0)
signal = np.random.randn(int(sampling_rate * signal_duration))
template = np.random.randn(int(sampling_rate * signal_duration))

# Normalize the template to have unit power
template /= np.sqrt(np.sum(template**2))

# Set the learning rate and number of iterations for gradient descent
learning_rate = 0.01
num_iterations = 100

# Perform gradient descent to find the lowest SNR
optimized_template = gradient_descent(signal, template, learning_rate, num_iterations)

# Compute the SNR of the optimized signal and template
optimized_snr = snr(signal, optimized_template)

print("Optimized SNR:", optimized_snr)


Optimized SNR: 6.369411780012543e-05


In [4]:
import jax
import jax.numpy as jnp
import numpy as np
from jax import grad, jit

def snr(signal, template):
    # Compute the matched filter output
    matched_output = matched_filter(signal, template[0], template[1])

    # Compute the signal power
    signal_power = jnp.sum(template[0]**2)

    # Compute the noise power
    noise = signal - matched_output
    noise_power = jnp.sum(noise**2)

    # Compute the SNR
    snr_value = signal_power / noise_power

    return snr_value

# Define the matched filter function
def matched_filter(signal, m1, m2):
    template = waveform_template(signal, m1, m2)
    signal_fft = jnp.fft.rfft(signal)
    template_fft = jnp.fft.rfft(template)
    result_fft = signal_fft * jnp.conj(template_fft)
    result = jnp.fft.irfft(result_fft)
    return result

# Define the waveform template function
def waveform_template(signal, m1, m2):
    # Create the waveform template using the mass parameters
    # Adjust this function based on your specific waveform model
    template = m1 * signal + m2 * jnp.sin(signal)
    return template

# Define the objective function to minimize (negative SNR)
def objective(params):
    signal, template = params
    return -snr(signal, template)

# Perform gradient descent
def gradient_descent(signal, template, learning_rate, num_iterations):
    # Initialize the parameters
    m1, m2 = template

    # Define the gradient of the objective function
    objective_grad = jit(grad(objective))

    for i in range(num_iterations):
        # Compute the gradient
        grad_m1, grad_m2 = objective_grad((signal, (m1, m2)))

        # Update the parameters
        m1 -= learning_rate * grad_m1
        m2 -= learning_rate * grad_m2[0]  # Access the first element of grad_m2

    return m1, m2

# Example usage
sampling_rate = 4096  # Sampling rate in Hz
signal_duration = 2.0  # Duration of the signal in seconds

# Generate a random gravitational wave signal
np.random.seed(0)
signal = np.random.randn(int(sampling_rate * signal_duration))

# Set the initial mass parameters for the template
initial_m1 = 30.0
initial_m2 = 20.0

# Set the learning rate and number of iterations for gradient descent
learning_rate = 0.01
num_iterations = 100

# Perform gradient descent to find the lowest SNR
optimized_m1, optimized_m2 = gradient_descent(signal, (initial_m1, initial_m2), learning_rate, num_iterations)

# Compute the SNR of the optimized signal and template
optimized_snr = snr(signal, (optimized_m1, optimized_m2))

print("Optimized SNR:", optimized_snr)
print("Optimized Mass Parameters: m1 =", optimized_m1, "m2 =", optimized_m2)


Optimized SNR: 3.185368316425371e-05
Optimized Mass Parameters: m1 = [30.         29.99999999 29.99999996 ... 30.00000001 29.99999999
 30.00000004] m2 = [20. 20. 20. ... 20. 20. 20.]


In [5]:
import jax
import jax.numpy as jnp
import numpy as np
from jax import grad, jit

def snr(signal, template):
    # Compute the matched filter output
    matched_output = matched_filter(signal, template[0], template[1])

    # Compute the signal power
    signal_power = jnp.sum(template[0]**2)

    # Compute the noise power
    noise = signal - matched_output
    noise_power = jnp.sum(noise**2)

    # Compute the SNR
    snr_value = signal_power / noise_power

    return snr_value

# Define the matched filter function
def matched_filter(signal, m1, m2):
    template = waveform_template(signal, m1, m2)
    signal_fft = jnp.fft.rfft(signal)
    template_fft = jnp.fft.rfft(template)
    result_fft = signal_fft * jnp.conj(template_fft)
    result = jnp.fft.irfft(result_fft)
    return result

# Define the waveform template function
def waveform_template(signal, m1, m2):
    # Create the waveform template using the mass parameters
    # Adjust this function based on your specific waveform model
    template = m1 * signal + m2 * jnp.sin(signal)
    return template

# Define the objective function to minimize (negative SNR)
def objective(params):
    signal, template = params
    return -snr(signal, template)

# Perform gradient descent
def gradient_descent(signal, template, learning_rate, num_iterations):
    # Initialize the parameters
    m1, m2 = template

    # Define the gradient of the objective function
    objective_grad = jit(grad(objective))

    for i in range(num_iterations):
        # Compute the gradient
        grad_m1, grad_m2 = objective_grad((signal, (m1, m2)))

        # Update the parameters
        m1 -= learning_rate * grad_m1
        m2 -= learning_rate * grad_m2[0]  # Access the first element of grad_m2

    return m1, m2

# Example usage
sampling_rate = 4096  # Sampling rate in Hz
signal_duration = 2.0  # Duration of the signal in seconds

# Generate a random gravitational wave signal
np.random.seed(0)
signal = np.random.randn(int(sampling_rate * signal_duration))

# Set the initial mass parameters for the template
initial_m1 = 30.0
initial_m2 = 20.0

# Set the learning rate and number of iterations for gradient descent
learning_rate = 0.01
num_iterations = 100

# Perform gradient descent to find the lowest SNR
optimized_m1, optimized_m2 = gradient_descent(signal, (initial_m1, initial_m2), learning_rate, num_iterations)

# Compute the SNR of the optimized signal and template
optimized_snr = snr(signal, (optimized_m1, optimized_m2))

print("Optimized SNR:", optimized_snr)
print("Optimized Mass Parameters: m1 =", optimized_m1[0], "m2 =", optimized_m2[0])


Optimized SNR: 3.185368316425371e-05
Optimized Mass Parameters: m1 = 29.999999996361467 m2 = 20.00000000014849


### Ripple

In [21]:
from ripple.waveforms import IMRPhenomD
from ripple import ms_to_Mc_eta
from functools import partial

In [22]:
m1_msun = 20.0 # In solar masses
m2_msun = 19.0
chi1 = 0.5 # Dimensionless spin
chi2 = -0.5
tc = 0.0 # Time of coalescence in seconds
phic = 0.0 # Time of coalescence
dist_mpc = 440 # Distance to source in Mpc
inclination = 0.0 # Inclination Angle
polarization_angle = 0.2 # Polarization angle

Mc, eta = ms_to_Mc_eta(jnp.array([m1_msun, m2_msun]))
theta_ripple_h0 = jnp.array([Mc, eta, chi1, chi2, dist_mpc, tc, phic])

# Now we need to generate the frequency grid
f_l = 24
f_u = 512
del_f = 0.01
fs = jnp.arange(f_l, f_u, del_f)

# We also need to give a reference frequency
f_ref = f_l

# And finally lets generate the waveform!
h0_ripple = IMRPhenomD.gen_IMRPhenomD(fs, theta_ripple_h0, f_ref)

In [74]:
# Define the conditiion data function
def condition(strain, sampling_rate):
    tmp = resample_to_delta_t(highpass(strain, 15.0), 1.0/sampling_rate).crop(2,2)
    dynfac = 1.0e23
    return (dynfac*tmp)

def estimate_psd(strain, delta_f):
    psd = strain.psd(4)
    psd = interpolate(psd, delta_f)
    psd = inverse_spectrum_truncation(psd, int(4 * strain.sample_rate),
                                  low_frequency_cutoff=15, trunc_method='hann')
    return psd


# Define the waveform template function
def waveform_template(freqs, m1, m2):
    chi1 = 0 # Dimensionless spin
    chi2 = 0
    tc = 0.0 # Time of coalescence in seconds
    phic = 0.0 # Time of coalescence
    dist_mpc = 0 # Distance to source in Mpc
    inclination = 0.0 # Inclination Angle
    polarization_angle = 0 # Polarization angle
    
    Mc, eta = ms_to_Mc_eta(jnp.array([m1, m2]))
    theta_ripple_h0 = jnp.array([Mc, eta, chi1, chi2, dist_mpc, tc, phic])
    # Create the waveform template using the mass parameters
    # Adjust this function based on your specific waveform model
    # Now we need to generate the frequency grid
    #f_l = 15
    #f_u = 512
    #del_f = 0.01
    #fs = jnp.arange(f_l, f_u, del_f)

    # We also need to give a reference frequency
    #f_ref = f_l
    #may want to do partial on this
    return IMRPhenomD.gen_IMRPhenomD(freqs, theta_ripple_h0, f_ref=(15*signal_duration))


# Define the matched filter function
def matched_filter(template, kmin, kmax):
    workspace = jnp.zeros(sampling_rate*signal_duration, dtype=complex)
    template_fft = jnp.fft.rfft(template)
    result_fft = fcore * jnp.conj(template_fft)
    workspace = workspace.at[kmin:kmax].set(result_fft)#can't change array in Jax
    result = jnp.fft.ifft(workspace)
    return result

# Define the objective function to minimize (negative SNR)
def objective(template, matched_output):
    #Compute sigma
    sigma_squared = jnp.sum(template*jnp.conj(template)*invpsd)

    # Compute the SNR
    snr_value = matched_output / sigma_squared**0.5
    
    return -snr_value

def snr(template, invpsd, matched_output):
    #Compute sigma
    sigma_squared = jnp.sum(template*jnp.conj(template)*invpsd)

    # Compute the SNR
    snr_value = matched_output / sigma_squared**0.5
    
    return snr_value

# Perform gradient descent
def gradient_descent(initial_m1, initial_m2, learning_rate, num_iterations):
    sampling_rate = 2048  # Sampling rate in Hz
    signal_duration = 8  # Duration of the signal in seconds 4 or 8

    # Get the data and estimate the PSDs
    merger = Merger("GW150914")
    data = condition(merger.strain('H1'), sampling_rate)
    print(data.delta_f)
    invpsd = estimate_psd(data, data.delta_f)**(-1) #WHY NOT 1/signal_duration which is the example given in GWjax
    print(invpsd.delta_f)
    #idx = int(sampling_rate * (merger.time - data.start_time))
    #data[idx-4*sampling_rate:idx+4*sampling_rate]
    fcore = data.to_frequencyseries()*invpsd #this is what goes to multiply with template^*
    print(len(data))
    print(len(invpsd))
    print(len(fcore))
    
    #Get the frequency range
    nyquist = sampling_rate//2
    freqs = jnp.arange(1+(nyquist*signal_duration))/signal_duration
    
    #Determine the low freq and high frequency cut off
    kmin, kmax = 15*signal_duration, 900*signal_duration
    #kmin, kmax = get_cutoff_indices(flow=15, fhigh=900, df=data.delta_f, N=signal_duration) # WHY DID you avoid using this?
    #And I don't understand what the difference is between freq and kmin

    #Make everyythin the same length
    fcore = jnp.asarray(fcore[kmin:kmax])
    freqs = freqs[kmin:kmax]
    invpsd = jnp.asarray(invpsd[kmin:kmax]) # can't take fft because it needs to start at 0 to nyquist 
    print(f'fcore {len(fcore)}')
    print(f'freqs {len(freqs)}')
    print(f'invpsd {len(invpsd)}')
    
    print(f'fcore {type(fcore)}')
    print(f'freqs {type(freqs)}')
    print(f'invpsd {type(invpsd)}')
    
    #Get the stage ready for the template
    #my_waveform_template = partial(waveform_template, freqs)

    # Initialize the parameters
    m1, m2 = initial_m1, initial_m2
    template = waveform_template(freqs, m1, m2)
    print(template)
    print(f'template {len(template)}')
    print(f'template {type(template)}')
    
    workspace = jnp.zeros(sampling_rate*signal_duration, dtype=complex)
    template_fft = jnp.fft.fft(template) #WHY did I change this from a rfft to fft?
    result_fft = fcore * jnp.conj(template_fft)
    workspace = workspace.at[kmin:kmax].set(result_fft)#can't change array in Jax
    result = jnp.fft.ifft(workspace)   
    print(f'result {len(result)}')
    #from pycbc.types.array.Array import weighted_inner
    #print(f'{type(weighted_inner(np.array(template)))}')
    print
    sigma_squared = jnp.sum(template*jnp.conj(template)*invpsd)
    
    weighted_inner(jnp.array(template))
    print(f'sigma_squared {sigma_squared}')
    # Compute the SNR
    snr_value = result / (sigma_squared**0.5)
    
    print(f'sigma_squared {snr_value}')
    #Backstage is ready for snr
    my_snr = partial(snr, invpsd)
    #Set the scene for Matched-filtering
    my_matched_filter = partial(matched_filter, kmin, kmax)
    
    print(len(my_matched_filter(my_waveform_template(m1, m2))))
    
    
    # Define the gradient of the objective function
    objective_grad = jit(grad(my_snr(my_matched_filter)))

    for i in range(num_iterations):
        # Compute the gradient
        grad_m1, grad_m2 = objective_grad((my_waveform_template(m1, m2)))

        # Update the parameters
        m1 -= learning_rate * grad_m1
        m2 -= learning_rate * grad_m2  # Access the first element of grad_m2

    return m1, m2


# Set the initial mass parameters for the template
initial_m1 = 30.0
initial_m2 = 20.0

# Set the learning rate and number of iterations for gradient descent
learning_rate = 0.01
num_iterations = 100

# Perform gradient descent to find the lowest SNR
optimized_m1, optimized_m2 = gradient_descent(initial_m1, initial_m2, learning_rate, num_iterations)

# Compute the SNR of the optimized signal and template
optimized_snr = snr(signal, (optimized_m1, optimized_m2))

print("Optimized SNR:", optimized_snr)
print("Optimized Mass Parameters: m1 =", optimized_m1, "m2 =", optimized_m2)

0.03571428571428571
0.03571428571428571
57344
28673
28673
fcore 7080
freqs 7080
invpsd 7080
fcore <class 'jaxlib.xla_extension.ArrayImpl'>
freqs <class 'jaxlib.xla_extension.ArrayImpl'>
invpsd <class 'jaxlib.xla_extension.ArrayImpl'>
[-inf-infj  inf-infj -inf+infj ...  nan+nanj  nan+nanj  nan+nanj]
template 7080
template <class 'jaxlib.xla_extension.ArrayImpl'>
result 16384


NameError: name 'weighted_inner' is not defined

### GWJax

In [None]:
from functools import partial

In [38]:
dynfac = 1.0e23

# Define the conditiion data function
def condition(strain, sampling_rate):
    tmp = resample_to_delta_t(highpass(strain, 15.0), 1.0/sampling_rate).crop(2,2)
    return (dynfac*tmp)

# Define the PSD function 
def estimate_psd(strain, delta_f):
    psd = strain.psd(4)
    psd = interpolate(psd, delta_f)
    psd = inverse_spectrum_truncation(psd, int(4 * strain.sample_rate),
                                  low_frequency_cutoff=15, trunc_method='hann')
    return psd

# Define the matched filter function
def matched_filter(template, kmin, kmax):
    workspace = jnp.zeros(sampling_rate*signal_duration, dtype=complex)
    template_fft = jnp.fft.rfft(template)
    result_fft = fcore * jnp.conj(template_fft)
    workspace = workspace.at[kmin:kmax].set(result_fft)#can't change array in Jax
    result = jnp.fft.ifft(workspace)
    return result

# Define the waveform template function
def waveform_template(freqs, m1, m2):
    params = {'phase': 0., 'geocent_time': 0.,
            'luminosity_distance': 1, 'theta_jn': 0., 
            'm1': m1, 'm2': m2, 'spin1': 0., 'spin2': 0.,
            'ra': 0., 'dec': 0., 'pol': 0.}
    # Create the waveform template using the mass parameters
    # Adjust this function based on your specific waveform model
    #may want to do partial on this
    return gwjax.imrphenom.IMRPhenomD(freqs, params)

# Define the objective function to minimize (negative SNR)
def objective(params):
    m1, m2 = params
    return -snr(signal, my_template)

def snr(template, invpsd):
    # Compute the matched filter output
    matched_output = my_matched_filter(template)
    
    #Compute sigma
    sigma_squared = jnp.sum(template*jnp.conj(template)*invpsd)

    # Compute the SNR
    snr_value = matched_output / sigma_squared**0.5
    
    return snr_value

# Perform gradient descent
def gradient_descent(initial_m1, initial_m2, learning_rate, num_iterations):
    sampling_rate = 2048  # Sampling rate in Hz
    signal_duration = 8  # Duration of the signal in seconds 4 or 8

    # Get the data and estimate the PSDs
    merger = Merger("GW150914")
    data = condition(merger.strain('H1'), sampling_rate)
    print(data.delta_f)
    invpsd = estimate_psd(data, data.delta_f)**(-1) #WHY NOT 1/signal_duration which is the example given in GWjax
    print(invpsd.delta_f)
    #idx = int(sampling_rate * (merger.time - data.start_time))
    #data[idx-4*sampling_rate:idx+4*sampling_rate]
    fcore = data.to_frequencyseries()*invpsd #this is what goes to multiply with template^*
    print(len(data))
    print(len(invpsd))
    print(len(fcore))
    
    #Get the frequency range
    nyquist = sampling_rate//2
    freqs = jnp.arange(1+(nyquist*signal_duration))/signal_duration
    
    #Determine the low freq and high frequency cut off
    kmin, kmax = 15*signal_duration, 900*signal_duration
    #kmin, kmax = get_cutoff_indices(flow=15, fhigh=900, df=data.delta_f, N=signal_duration) # WHY DID you avoid using this?
    #And I don't understand what the difference is between freq and kmin

    #Make everyythin the same length
    fcore = jnp.asarray(fcore[kmin:kmax])
    freqs = freqs[kmin:kmax]
    invpsd = jnp.asarray(invpsd[kmin:kmax]) # can't take fft because it needs to start at 0 to nyquist 
    print(f'fcore {len(fcore)}')
    print(f'freqs {len(freqs)}')
    print(f'invpsd {len(invpsd)}')
    
    print(f'fcore {type(fcore)}')
    print(f'freqs {type(freqs)}')
    print(f'invpsd {type(invpsd)}')
    
    #Get the stage ready for the template
    #my_waveform_template = partial(waveform_template, freqs)

    # Initialize the parameters
    m1, m2 = initial_m1, initial_m2
    template, _ = waveform_template(freqs, m1, m2)
    print(template)
    print(f'template {len(template)}')
    print(f'template {type(template)}')
    
    workspace = jnp.zeros(sampling_rate*signal_duration, dtype=complex)
    template_fft = jnp.fft.fft(template) #WHY did I change this from a rfft to fft?
    result_fft = fcore * jnp.conj(template_fft)
    workspace = workspace.at[kmin:kmax].set(result_fft)#can't change array in Jax
    result = jnp.fft.ifft(workspace)   
    print(f'result {len(result)}')
    
    sigma_squared = jnp.sum(template*jnp.conj(template)*invpsd)
    
    print(f'sigma_squared {sigma_squared}')
    # Compute the SNR
    snr = result / (sigma_squared**0.5)
    print(f'snr {snr}')
    print(f'normal {len(snr)}')
    snr_max = len(snr) - 2
    snr_min = 2+4
    snr = snr[snr_min:snr_max]
    print(f'cropped snr {len(snr)}')
    peak = jnp.argmin(jnp.absolute(snr))
    snrp = jnp.absolute(snr[peak])
    print(f'snrp {snrp}')
    #Backstage is ready for snr

    for i in range(num_iterations):
        # Compute the gradient
        grad_m1, grad_m2 = objective_grad((signal, (m1, m2)))

        # Update the parameters
        m1 -= learning_rate * grad_m1
        m2 -= learning_rate * grad_m2[0]  # Access the first element of grad_m2

    return m1, m2



# Set the initial mass parameters for the template
initial_m1 = 30.0
initial_m2 = 20.0

# Set the learning rate and number of iterations for gradient descent
learning_rate = 0.01
num_iterations = 100

# Perform gradient descent to find the lowest SNR
optimized_m1, optimized_m2 = gradient_descent(initial_m1, initial_m2, learning_rate, num_iterations)

# Compute the SNR of the optimized signal and template
optimized_snr = snr(signal, (optimized_m1, optimized_m2))

print("Optimized SNR:", optimized_snr)
print("Optimized Mass Parameters: m1 =", optimized_m1, "m2 =", optimized_m2)


0.03571428571428571
0.03571428571428571
57344
28673
28673
fcore 7080
freqs 7080
invpsd 7080
fcore <class 'jaxlib.xla_extension.ArrayImpl'>
freqs <class 'jaxlib.xla_extension.ArrayImpl'>
invpsd <class 'jaxlib.xla_extension.ArrayImpl'>
[-6.43331615e-01+3.94752521e+00j -2.94035527e+00-2.65175413e+00j
  3.72878602e+00-1.20959438e+00j ... -6.53569448e-07+8.03943096e-07j
 -6.41389508e-07+8.10994578e-07j -6.29140959e-07+8.17867783e-07j]
template 7080
template <class 'jaxlib.xla_extension.ArrayImpl'>
result 16384
sigma_squared (39.0028892000548+3.869538315663618e-18j)
snr [-0.50327633+0.00564291j -0.11171879-0.23912594j  0.01101527-0.06242673j
 ...  0.53272488+0.12803541j  0.27351817+0.48117682j
 -0.24611037+0.54276529j]
normal 16384
cropped snr 16376
snrp 0.001147618042581102


NameError: name 'objective_grad' is not defined

In [40]:
dynfac = 1.0e23

# Define the conditiion data function
def condition(strain, sampling_rate):
    tmp = resample_to_delta_t(highpass(strain, 15.0), 1.0/sampling_rate).crop(2,2)
    return (dynfac*tmp)

# Define the PSD function 
def estimate_psd(strain, delta_f):
    psd = strain.psd(4)
    psd = interpolate(psd, delta_f)
    psd = inverse_spectrum_truncation(psd, int(4 * strain.sample_rate),
                                  low_frequency_cutoff=15, trunc_method='hann')
    return psd

# Define the matched filter function
def matched_filter(signal_duration, sampling_rate, kmin, kmax, fcore, template):
    workspace = jnp.zeros(sampling_rate*signal_duration, dtype=complex)
    template_fft = jnp.fft.fft(template) #WHY did I change this from a rfft to fft?
    result_fft = fcore * jnp.conj(template_fft)
    workspace = workspace.at[kmin:kmax].set(result_fft)#can't change array in Jax
    result = jnp.fft.ifft(workspace)   
    return result

# Define the waveform template function
def waveform_template(freqs, m1, m2):
    params = {'phase': 0., 'geocent_time': 0.,
            'luminosity_distance': 1, 'theta_jn': 0., 
            'm1': m1, 'm2': m2, 'spin1': 0., 'spin2': 0.,
            'ra': 0., 'dec': 0., 'pol': 0.}
    # Create the waveform template using the mass parameters
    # Adjust this function based on your specific waveform model
    #may want to do partial on this
    return gwjax.imrphenom.IMRPhenomD(freqs, params)

def snr(invpsd, matched_output, template):
    #Compute sigma
    sigma_squared = jnp.sum(template*jnp.conj(template)*invpsd)

    # Compute the SNR
    snr = matched_output / sigma_squared**0.5
    
    snr_max = len(snr) - 2
    snr_min = 2+4
    snr = snr[snr_min:snr_max]
    print(f'cropped snr {len(snr)}')
    peak = jnp.argmin(jnp.absolute(snr))
    snrp = jnp.absolute(snr[peak])
    return snrp

# Define the objective function to minimize (negative SNR)
def objective(params):
    m1, m2 = params
    template, _ = my_waveform_template(m1, m2)
    matched_output = my_matched_filter(template)
    snr_peak = my_snr(matched_output, template)
    return -snr_peak

# Perform gradient descent
def gradient_descent(initial_m1, initial_m2, learning_rate, num_iterations, my_waveform_template, my_matched_filter, my_snr):
 
    # Initialize the parameters
    m1, m2 = initial_m1, initial_m2

    # Define the gradient of the objective function
    objective_grad = jit(grad(objective))

    for i in range(num_iterations):
        # Compute the gradient
        grad_m1, grad_m2 = objective_grad((m1, m2))

        # Update the parameters
        m1 -= learning_rate * grad_m1
        m2 -= learning_rate * grad_m2

    return m1, m2


sampling_rate = 2048  # Sampling rate in Hz
signal_duration = 8  # Duration of the signal in seconds 4 or 8

# Get the data and estimate the PSDs
merger = Merger("GW150914")
data = condition(merger.strain('H1'), sampling_rate)

print(data.delta_f)
invpsd = estimate_psd(data, data.delta_f)**(-1) #WHY NOT 1/signal_duration which is the example given in GWjax
print(invpsd.delta_f)
#idx = int(sampling_rate * (merger.time - data.start_time))
#data[idx-4*sampling_rate:idx+4*sampling_rate]
fcore = data.to_frequencyseries()*invpsd #this is what goes to multiply with template^*
print(len(data))
print(len(invpsd))
print(len(fcore))

#Get the frequency range
nyquist = sampling_rate//2
freqs = jnp.arange(1+(nyquist*signal_duration))/signal_duration

#Determine the low freq and high frequency cut off
kmin, kmax = 15*signal_duration, 900*signal_duration
#kmin, kmax = get_cutoff_indices(flow=15, fhigh=900, df=data.delta_f, N=signal_duration) # WHY DID you avoid using this?
#And I don't understand what the difference is between freq and kmin

#Make everyythin the same length
fcore = jnp.asarray(fcore[kmin:kmax])
freqs = freqs[kmin:kmax]
invpsd = jnp.asarray(invpsd[kmin:kmax]) # can't take fft because it needs to start at 0 to nyquist 
print(f'fcore {len(fcore)}')
print(f'freqs {len(freqs)}')
print(f'invpsd {len(invpsd)}')

print(f'fcore {type(fcore)}')
print(f'freqs {type(freqs)}')
print(f'invpsd {type(invpsd)}')

#Get the stage ready for the template
my_waveform_template = partial(waveform_template, freqs)
my_matched_filter = partial(matched_filter, signal_duration, sampling_rate, kmin, kmax, fcore)
my_snr = partial(snr, invpsd)

# Set the initial mass parameters for the template
initial_m1 = 39.0
initial_m2 = 34.0

# Set the learning rate and number of iterations for gradient descent
learning_rate = 0.01
num_iterations = 100

# Perform gradient descent to find the lowest SNR
optimized_m1, optimized_m2 = gradient_descent(initial_m1, initial_m2, learning_rate, num_iterations, my_waveform_template, my_matched_filter, my_snr)

print("Optimized Mass Parameters: m1 =", optimized_m1, "m2 =", optimized_m2)

0.03571428571428571
0.03571428571428571
57344
28673
28673
fcore 7080
freqs 7080
invpsd 7080
fcore <class 'jaxlib.xla_extension.ArrayImpl'>
freqs <class 'jaxlib.xla_extension.ArrayImpl'>
invpsd <class 'jaxlib.xla_extension.ArrayImpl'>
cropped snr 16376
Optimized Mass Parameters: m1 = 38.99337056455571 m2 = 33.99472396633342


In [49]:
dynfac = 1.0e23

# Define the conditiion data function
def condition(strain, sampling_rate):
    tmp = resample_to_delta_t(highpass(strain, 15.0), 1.0/sampling_rate).crop(2,2)
    return (dynfac*tmp)

# Define the PSD function 
def estimate_psd(strain, delta_f):
    psd = strain.psd(4)
    psd = interpolate(psd, delta_f)
    psd = inverse_spectrum_truncation(psd, int(4 * strain.sample_rate),
                                  low_frequency_cutoff=15, trunc_method='hann')
    return psd

# Define the matched filter function
def matched_filter(signal_duration, sampling_rate, kmin, kmax, fcore, template):
    workspace = jnp.zeros(sampling_rate*signal_duration, dtype=complex)
    template_fft = jnp.fft.fft(template) #WHY did I change this from a rfft to fft?
    result_fft = fcore * jnp.conj(template_fft)
    workspace = workspace.at[kmin:kmax].set(result_fft)#can't change array in Jax
    result = jnp.fft.ifft(workspace)   
    return result

# Define the waveform template function
def waveform_template(freqs, m1, m2, s1, s2):
    params = {'phase': 0., 'geocent_time': 0.,
            'luminosity_distance': 1, 'theta_jn': 0., 
            'm1': m1, 'm2': m2, 'spin1': s1, 'spin2': s2,
            'ra': 0., 'dec': 0., 'pol': 0.}
    # Create the waveform template using the mass parameters
    # Adjust this function based on your specific waveform model
    #may want to do partial on this
    return gwjax.imrphenom.IMRPhenomD(freqs, params)

def snr(invpsd, matched_output, template):
    #Compute sigma
    sigma_squared = jnp.sum(template*jnp.conj(template)*invpsd)

    # Compute the SNR
    snr = matched_output / sigma_squared**0.5
    
    snr_max = len(snr) - 2
    snr_min = 2+4
    snr = snr[snr_min:snr_max]
    print(f'cropped snr {len(snr)}')
    peak = jnp.argmin(jnp.absolute(snr))
    snrp = jnp.absolute(snr[peak])
    return snrp

# Define the objective function to minimize (negative SNR)
def objective(params):
    m1, m2, s1, s2 = params
    template, _ = my_waveform_template(m1, m2, s1, s2)
    matched_output = my_matched_filter(template)
    snr_peak = my_snr(matched_output, template)
    return -snr_peak

# Perform gradient descent
def gradient_descent(initial_m1, initial_m2, initial_s1, initial_s2, learning_rate, num_iterations, my_waveform_template, my_matched_filter, my_snr):
 
    # Initialize the parameters
    m1, m2, s1, s2, = initial_m1, initial_m2, initial_s1, initial_s2

    # Define the gradient of the objective function
    objective_grad = jit(grad(objective))

    for i in range(num_iterations):
        # Compute the gradient
        grad_m1, grad_m2, grad_s1, grad_s2 = objective_grad((m1, m2, s1, s2))

        # Update the parameters
        m1 -= learning_rate * grad_m1
        m2 -= learning_rate * grad_m2
        s1 -= learning_rate * grad_s1
        s2 -= learning_rate * grad_s2
        
    print(m1, m2, s1, s2)
    return m1, m2, s1, s2


sampling_rate = 2048  # Sampling rate in Hz
signal_duration = 8  # Duration of the signal in seconds 4 or 8

# Get the data and estimate the PSDs
merger = Merger("GW150914")
data = condition(merger.strain('H1'), sampling_rate)

print(data.delta_f)
invpsd = estimate_psd(data, data.delta_f)**(-1) #WHY NOT 1/signal_duration which is the example given in GWjax
print(invpsd.delta_f)
#idx = int(sampling_rate * (merger.time - data.start_time))
#data[idx-4*sampling_rate:idx+4*sampling_rate]
fcore = data.to_frequencyseries()*invpsd #this is what goes to multiply with template^*
print(len(data))
print(len(invpsd))
print(len(fcore))

#Get the frequency range
nyquist = sampling_rate//2
freqs = jnp.arange(1+(nyquist*signal_duration))/signal_duration

#Determine the low freq and high frequency cut off
kmin, kmax = 15*signal_duration, 900*signal_duration
#kmin, kmax = get_cutoff_indices(flow=15, fhigh=900, df=data.delta_f, N=signal_duration) # WHY DID you avoid using this?
#And I don't understand what the difference is between freq and kmin

#Make everyythin the same length
fcore = jnp.asarray(fcore[kmin:kmax])
freqs = freqs[kmin:kmax]
invpsd = jnp.asarray(invpsd[kmin:kmax]) # can't take fft because it needs to start at 0 to nyquist 
print(f'fcore {len(fcore)}')
print(f'freqs {len(freqs)}')
print(f'invpsd {len(invpsd)}')

print(f'fcore {type(fcore)}')
print(f'freqs {type(freqs)}')
print(f'invpsd {type(invpsd)}')

#Get the stage ready for the template
my_waveform_template = partial(waveform_template, freqs)
my_matched_filter = partial(matched_filter, signal_duration, sampling_rate, kmin, kmax, fcore)
my_snr = partial(snr, invpsd)

# Set the initial mass parameters for the template
initial_m1 = 39.0
initial_m2 = 34.0
initial_s1 = 0.0
initial_s2 = 0.0

# Set the learning rate and number of iterations for gradient descent
learning_rate = 0.001
num_iterations = 100

# Perform gradient descent to find the lowest SNR
optimized_m1, optimized_m2, optimized_s1, optimized_s2 = gradient_descent(initial_m1, initial_m2, initial_s1, initial_s2, learning_rate, num_iterations, my_waveform_template, my_matched_filter, my_snr)

print("Optimized Mass and Spin Parameters: m1 =", optimized_m1, "m2 =", optimized_m2, "s1 =", optimized_s1, "s2 =", optimized_s2)

0.03571428571428571
0.03571428571428571
57344
28673
28673
fcore 7080
freqs 7080
invpsd 7080
fcore <class 'jaxlib.xla_extension.ArrayImpl'>
freqs <class 'jaxlib.xla_extension.ArrayImpl'>
invpsd <class 'jaxlib.xla_extension.ArrayImpl'>
cropped snr 16376
38.99954887079842 33.99935362808911 0.04664229273664835 0.03875220123103149
Optimized Mass and Spin Parameters: m1 = 38.99954887079842 m2 = 33.99935362808911 s1 = 0.04664229273664835 s2 = 0.03875220123103149
