In [1]:
import numpy as np
import operator
import logging
import time

import optuna

import jax
import jax.numpy as jnp
from jax import grad, jit

from functools import partial
import gwjax
import gwjax.imrphenom

from pycbc.catalog import Merger
from pycbc.filter import resample_to_delta_t, highpass
from pycbc.psd import interpolate, inverse_spectrum_truncation


  from .autonotebook import tqdm as notebook_tqdm
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [2]:
#Set-up the logging 
logger = logging.getLogger(__name__)  
logger.setLevel(logging.INFO) # set log level 

file_handler = logging.FileHandler('GWtuna.log') # define file handler and set formatter
formatter    = logging.Formatter('%(asctime)s : %(levelname)s : %(name)s : %(message)s')
file_handler.setFormatter(formatter)

logger.addHandler(file_handler) # add file handler to logger

In [3]:
# Define the conditiion data function
dynfac = 1.0e23
def condition(strain, sampling_rate):
    tmp = resample_to_delta_t(highpass(strain, 15.0), 1.0/sampling_rate).crop(2,2)
    return (dynfac*tmp)

# Define the PSD function 
def estimate_psd(strain, delta_f):
    psd = strain.psd(4)
    psd = interpolate(psd, delta_f)
    psd = inverse_spectrum_truncation(psd, int(4 * strain.sample_rate),
                                  low_frequency_cutoff=15, trunc_method='hann')
    return psd

# Define the matched filter function
def matched_filter(signal_duration, sampling_rate, kmin, kmax, fcore, template):
    workspace = jnp.zeros(sampling_rate*signal_duration, dtype=complex)
    template_fft = jnp.fft.fft(template) #WHY did I change this from a rfft to fft?
    result_fft = fcore * jnp.conj(template_fft)
    workspace = workspace.at[kmin:kmax].set(result_fft)#can't change array in Jax
    #print(f'template is{workspace}')
    result = jnp.fft.ifft(workspace) #There is a bug here 
    #print(f'The match is {result}')
    return result

# Define the waveform template function
def waveform_template(freqs, m1, m2, s1, s2):
    params = {'phase': 0., 'geocent_time': 0.,
            'luminosity_distance': 1, 'theta_jn': 0., 
            'm1': m1, 'm2': m2, 'spin1': s1, 'spin2': s2,
            'ra': 0., 'dec': 0., 'pol': 0.}
    # Create the waveform template using the mass parameters
    # Adjust this function based on your specific waveform model
    #may want to do partial on this
    return gwjax.imrphenom.IMRPhenomD(freqs, params)

# Define the waveform template function
def optuna_waveform_template(freqs, m1, m2):
    params = {'phase': 0., 'geocent_time': 0.,
            'luminosity_distance': 1, 'theta_jn': 0., 
            'm1': m1, 'm2': m2, 'spin1': 0, 'spin2': 0,
            'ra': 0., 'dec': 0., 'pol': 0.}
    # Create the waveform template using the mass parameters
    # Adjust this function based on your specific waveform model
    #may want to do partial on this
    return gwjax.imrphenom.IMRPhenomD(freqs, params)

def snr(invpsd, matched_output, template):
    #Compute sigma
    sigma_squared = jnp.sum(template*jnp.conj(template)*invpsd) #There is a bug here 

    # Compute the SNR
    snr = matched_output / sigma_squared**0.5 #There is a bug here 
    #print(f'The snr is {snr}')
    snr_max = len(snr) - 2
    snr_min = 2+4
    snr = snr[snr_min:snr_max]
    peak = jnp.argmin(jnp.absolute(snr))
    #print(peak)
    snrp = jnp.absolute(snr[peak])
    #print(snrp)
    return snrp

# Define the objective function to minimize (negative SNR)
def objective(params):
    m1, m2, s1, s2 = params
    template, _ = my_waveform_template(m1, m2, s1, s2)
    matched_output = my_matched_filter(template)
    snr_peak = my_snr(matched_output, template)
    return -snr_peak

# Perform gradient descent
def gradient_descent(initial_m1, initial_m2, initial_s1, initial_s2, learning_rate, num_iterations, my_waveform_template, my_matched_filter, my_snr):
 
    # Initialize the parameters
    m1, m2, s1, s2, = initial_m1, initial_m2, initial_s1, initial_s2

    # Define the gradient of the objective function
    objective_grad = jit(grad(objective))

    for i in range(num_iterations):
        # Compute the gradient
        grad_m1, grad_m2, grad_s1, grad_s2 = objective_grad((m1, m2, s1, s2))

        # Update the parameters
        m1 -= learning_rate * grad_m1
        m2 -= learning_rate * grad_m2
        s1 -= learning_rate * grad_s1
        s2 -= learning_rate * grad_s2
        
    return m1, m2, s1, s2

In [4]:
sampling_rate = 2048  # Sampling rate in Hz
signal_duration = 8  # Duration of the signal in seconds 4 or 8

# Set the initial mass parameters for the template
initial_m1 = 39.0 #from the Optuna section of GWtuna
initial_m2 = 34.0 #from t he Optuna section of GWtuna
initial_s1 = 0.0
initial_s2 = 0.0

# Set the learning rate and number of iterations for gradient descent
learning_rate = 0.001
num_iterations = 100

# Get the data and estimate the PSDs
merger = Merger("GW150914")
data = condition(merger.strain('H1'), sampling_rate)

invpsd = estimate_psd(data, data.delta_f)**(-1) #WHY NOT 1/signal_duration which is the example given in GWjax
#idx = int(sampling_rate * (merger.time - data.start_time))
#data[idx-4*sampling_rate:idx+4*sampling_rate]
fcore = data.to_frequencyseries()*invpsd #this is what goes to multiply with template^*

#Get the frequency range
nyquist = sampling_rate//2
freqs = jnp.arange(1+(nyquist*signal_duration))/signal_duration

#Determine the low freq and high frequency cut off
kmin, kmax = 15*signal_duration, 900*signal_duration
#kmin, kmax = get_cutoff_indices(flow=15, fhigh=900, df=data.delta_f, N=signal_duration) # WHY DID you avoid using this?
#And I don't understand what the difference is between freq and kmin

#Make everything the same length
fcore = jnp.asarray(fcore[kmin:kmax])
freqs = freqs[kmin:kmax]
invpsd = jnp.asarray(invpsd[kmin:kmax]) # can't take fft because it needs to start at 0 to nyquist 

#Get the stage ready for the template
my_waveform_template = partial(waveform_template, freqs)
my_optuna_waveform_template = partial(optuna_waveform_template, freqs)
my_matched_filter = partial(matched_filter, signal_duration, sampling_rate, kmin, kmax, fcore)
my_snr = partial(snr, invpsd)

In [5]:
class NeedsInvestigatingCallback(object):
    """A callback for Optuna which identifies potential events."""

    def __init__(self, early_stopping_rounds: int, snr_threshold: int, direction: str = "minimize") -> None:
        self.snr_threshold = snr_threshold
        self.early_stopping_rounds = early_stopping_rounds
        
        self._iter = 0

        if direction == "minimize":
            self._operator = operator.lt
            self._score = np.inf
        elif direction == "maximize":
            self._operator = operator.gt
            self._score = -np.inf
        else:
            ValueError(f"invalid direction: {direction}")

    def __call__(self, study: optuna.Study, trial: optuna.Trial) -> None:
        """Goes onto Stocastic Gradient Descent."""
        if self._operator(study.best_value, self._score):
            self._iter = 0
            self._score = study.best_value
        else:
            self._iter += 1

        if self._score >= self.snr_threshold:
            if self._iter >= self.early_stopping_rounds:
                study.stop()
                logger.info("Optuna determined that", study.best_params)
                # Perform gradient descent to find the lowest SNR
                optimized_m1, optimized_m2, optimized_s1, optimized_s2 = gradient_descent(initial_m1, initial_m2, initial_s1, initial_s2, learning_rate, num_iterations, my_waveform_template, my_matched_filter, my_snr)
                logger.info("Optimized Mass and Spin Parameters: m1 =", optimized_m1, "m2 =", optimized_m2, "s1 =", optimized_s1, "s2 =", optimized_s2)

In [6]:
def objective(trial):
    m1 = trial.suggest_float('m1', 2, 100, step=0.000001)
    m2 = trial.suggest_float('m2', 2, 100, step=0.000001)
    template, _ = my_optuna_waveform_template(m1, m2)
    matched_output = my_matched_filter(template)
    snr_peak = my_snr(matched_output, template)
    return snr_peak

In [7]:
#Events = ["GW150914", "GW151012", "GW151226", "GW170104", "GW170608", "GW170729", "GW170809", "GW170814", "GW170817", "GW170818", "GW170823"]
Events = ["GW150914"]

In [11]:
for event in Events: 
    start_time = time.time()
    #optuna.logging.disable_default_handler()
    direction="minimize"
    study = optuna.create_study(sampler=optuna.samplers.TPESampler(), direction=direction)
    needs_to_be_investigated = NeedsInvestigatingCallback(300, snr_threshold=6, direction=direction)
    study.optimize(objective, callbacks=[needs_to_be_investigated], n_trials=10)
    logger.info(("Time taken", time.time() - start_time))
    logger.info(f'The event is {event} and has the best {study.best_params} with a snr {study.best_value}')
    print(f'The event is {event} and has the best {study.best_params} with a snr {study.best_value}')

[32m[I 2023-05-19 13:47:42,463][0m A new study created in memory with name: no-name-bb61b770-85b5-48d7-b0f8-adc137534c0a[0m
[32m[I 2023-05-19 13:47:42,487][0m Trial 0 finished with value: 0.00027490978731524327 and parameters: {'m1': 74.161192, 'm2': 40.345521}. Best is trial 0 with value: 0.00027490978731524327.[0m
[32m[I 2023-05-19 13:47:42,509][0m Trial 1 finished with value: 0.0005431384451607949 and parameters: {'m1': 42.234165999999995, 'm2': 29.441498}. Best is trial 0 with value: 0.00027490978731524327.[0m
[32m[I 2023-05-19 13:47:42,531][0m Trial 2 finished with value: 0.0003805096314326402 and parameters: {'m1': 88.231613, 'm2': 30.843317}. Best is trial 0 with value: 0.00027490978731524327.[0m
[32m[I 2023-05-19 13:47:42,553][0m Trial 3 finished with value: 0.0002187475531722316 and parameters: {'m1': 95.675871, 'm2': 34.40295}. Best is trial 3 with value: 0.0002187475531722316.[0m
[32m[I 2023-05-19 13:47:42,576][0m Trial 4 finished with value: 0.0005247213421

The event is GW150914 and has the best {'m1': 95.675871, 'm2': 34.40295} with a snr 0.0002187475531722316
