This assignment is the first of two exercises in which we analyze LIGO data to find the gravitational-wave transient caused by the coalescence of two neutron stars (GW170817).

You are given a real 2048-second segment of Hanford LIGO data, sampled at 4096 Hz (down-sampled from the original 16 kHz data). Along with this PDF, you should have:

1. `strain.npy`, readable by NumPy, containing the strain data.
2. `gw_search_functions`, containing helper functions and constants.
3. The timestamps corresponding to the strain, which are not uploaded because of their size; they are instead provided in `gw_search_functions`.

---

In this notebook, we will practice using FFT/RFFT to perform a matched-filter search in the data and compute a test statistic. Special attention will be given to understanding the **normalization** of inputs and the expected test statistic values. Finally, we will apply a glitch-removal procedure to the data.

We recommend getting this code from https://github.com/JonathanMushkin/GW_search_tutorial and using the `pyproject.toml` to define an environment.

Please contact jonathan.mushkin[at]weizmann.ac.il with any questions or comments.



# Introduction

Under the null and signal hypotheses, the data model is:

$$
\begin{aligned}
H_0&: \quad s(t) = n(t) \\
H_1&: \quad s(t) = n(t) + h(t)
\end{aligned}
$$

The noise $n(t)$ is approximately stationary and Gaussian with power spectral density $S_n(f)$. The approximation breaks down for two reasons:  
1. The spectral shape changes smoothly over a few seconds.  
2. There are *glitches*: unexplained, time-localized loud transients.

Under the Gaussian-noise approximation, the log-likelihood ratio between $H_1$ (with waveform $h$) and $H_0$, given strain data $s$, is:

$$
\ln \mathcal{L} = \Re \langle h, s \rangle - \frac{1}{2} \langle h, h \rangle
$$

with the inner product defined as:

$$
\langle a, b \rangle = \sum_f \frac{a(f) b^\ast(f)}{S_n(f)}\,\mathrm{d}f = \sum_f \tilde{a}(f) \tilde{b}^\ast(f)\,\mathrm{d}f
$$

where the tilde denotes the whitened series, $\tilde{a}(f) = a(f)/\sqrt{S_n(f)}$.

The strain signal at the detector is a linear combination of the two polarizations:

$$
h(f) = F_+ h_+(f) + F_\times h_\times(f)
$$

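Under the non-precessing, dominant-mode approximation, the two polarizations have the same frequency-domain shape up to a real amplitude factor and a $90^\circ$ phase shift:

$$
h_\times(f) \propto i\, h_+(f)
$$

(i.e., a sine in one is a cosine in the other), so the quadrature template can be taken as $i\,h_+$. The detector response is therefore $h(f) = A\, h_+(f)$ with a single complex amplitude $A$, which can be maximized over.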

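We define the complex-overlap time series using "inverted convolution" notation, where $\overleftarrow{x}(t) = x(-t)$ denotes time reversal, so each convolution is a cross-correlation of the data with the template: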

$$
z(t) = z_{\cos}(t) + i\, z_{\sin}(t) = (\tilde{s} \ast \overleftarrow{\tilde{h}_+})(t) + i\, (\tilde{s} \ast \overleftarrow{\tilde{h}_\times})(t)
$$

Using a normalization such that:

$$
\langle h_+, h_+ \rangle = 1,
$$

the Signal-to-Noise Ratio (SNR) time series is:

$$
\text{SNR}^2(t) = |z(t)|^2 = |z_{\cos}(t)|^2 + |z_{\sin}(t)|^2
$$

and the log-likelihood, maximized over the complex amplitude, becomes:

$$
\ln \mathcal{L} = \frac{1}{2} \text{SNR}^2(t)
$$
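The step to $\frac{1}{2}\text{SNR}^2$ is a maximization of $\ln\mathcal{L}$ over the complex amplitude. A short sketch, writing the template at a given lag as $h = A\,h_+$ with complex $A$, $\langle h_+, h_+\rangle = 1$, and $w \equiv \langle h_+, s\rangle$:

$$
\ln\mathcal{L}(A) = \Re\left(A\,w\right) - \frac{1}{2}|A|^2 \;\le\; |A|\,|w| - \frac{1}{2}|A|^2 \;\le\; \frac{1}{2}|w|^2 ,
$$

with both bounds saturated at $A = w^\ast$ (the second inequality is $0 \le \frac{1}{2}(|A| - |w|)^2$). Since $z_{\cos} = \Re\langle s, h_+\rangle$ and $z_{\sin} = \Re\langle s, i h_+\rangle = \Im\langle s, h_+\rangle$, we have $z = \langle s, h_+\rangle = w^\ast$, and therefore $\max_A \ln\mathcal{L} = \frac{1}{2}|z|^2$.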

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal, stats
import gw_search_functions

plt.rcParams["axes.labelsize"] = 14
plt.rcParams["xtick.labelsize"] = 12
plt.rcParams["ytick.labelsize"] = 12
plt.rcParams["axes.titlesize"] = 16

## 1 
Load the time-domain data and Fourier-transform it.

In [None]:
filename = "strain.npy"
event_name = "GW170817"
detector_name = "H"
fs = 2**12  # Hz

strain = np.load(filename)
times = np.arange(len(strain)) / fs
dt = times[1] - times[0]
freqs = np.fft.rfftfreq(len(strain), d=dt)
df = freqs[1] - freqs[0]

tukey_window = signal.windows.tukey(M=len(strain), alpha=0.1)
strain_f = np.fft.rfft(strain * tukey_window)

# plot the time-domain signal after applying the Tukey window

fig, ax = plt.subplots()
_ = ax.plot(times, tukey_window * strain)

## 2
In the next few cells we walk through the whitening of the data. We estimate the PSD with Welch's method and forcibly remove frequencies below 20 Hz. The cut avoids FFT artifacts that arise from the GW signal length, the duration of the data, and the periodicity assumed by the FFT.
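Why does the whitening filter below carry a factor `np.sqrt(2 * dt)`? For white noise with standard deviation $\sigma$, the one-sided PSD is $S = 2\sigma^2\,{\rm d}t$. A filter $\sqrt{2\,{\rm d}t}/{\rm ASD}$ therefore maps the data to unit variance. Here is a minimal sketch on synthetic white noise. The 1 s Welch segments, $\sigma$, and the seed are arbitrary assumptions, and it omits the high-pass taper and the time-domain windowing used in the notebook.

```python
import numpy as np
from scipy import signal

rng = np.random.default_rng(0)
fs = 4096                 # sampling rate [Hz]
dt = 1 / fs
N = 64 * fs               # 64 s of synthetic data
sigma_t = 3.0             # arbitrary noise standard deviation

x = rng.normal(scale=sigma_t, size=N)   # white Gaussian noise

# Welch PSD estimate (one-sided density), median-averaged as in the notebook
psd_freqs, psd = signal.welch(
    x, fs=fs, nperseg=fs, noverlap=fs // 2, average="median", scaling="density"
)
freqs = np.fft.rfftfreq(N, d=dt)
asd = np.sqrt(np.interp(freqs, psd_freqs, psd))

# whitening filter sqrt(2 dt) / ASD, zeroed below fmin
fmin = 20
whitening = np.zeros_like(freqs)
band = freqs >= fmin
whitening[band] = np.sqrt(2 * dt) / asd[band]

white = np.fft.irfft(np.fft.rfft(x) * whitening, n=N)
print(white.std())   # close to 1: the whitened samples are approximately N(0, 1)
```

Removing the band below `fmin` takes away about 1% of the variance here, so the printed value is close to, but not exactly, 1.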

In [None]:
seg_duration = 64
overlap_duration = 32
nperseg = int(seg_duration * fs)
noverlap = int(overlap_duration * fs)
welch_dict = {
    "x": strain,
    "fs": fs,
    "nperseg": nperseg,
    "noverlap": noverlap,
    "average": "median",
    "scaling": "density",
}
psd_freqs, psd_estimation = signal.welch(**welch_dict)
asd_estimation = psd_estimation**(1/2)
fmin = 20
asd = np.interp(freqs, psd_freqs, asd_estimation)
min_idx = np.searchsorted(psd_freqs, fmin)
plt.loglog(psd_freqs[min_idx:], asd_estimation[min_idx:])
plt.title("ASD")
plt.xlabel("Freq. [Hz]")

In [None]:
# Create a high-pass filter:
# ramp it as sin^2 from 0 to 1 over the (fmin, fmin + 1 Hz) interval
highpass_filter = np.zeros(len(freqs))
i1, i2 = np.searchsorted(freqs, (fmin, fmin + 1))
highpass_filter[i1:i2] = np.sin(np.linspace(0, np.pi / 2, i2 - i1)) ** 2
highpass_filter[i2:] = 1.0

# whitening filter is 1/asd(f) * high-pass filter
whitening_filter_raw = highpass_filter / np.interp(
    x=freqs, xp=psd_freqs, fp=asd_estimation
)

# To avoid ripples in the Fourier domain, apply a window in the time domain

padded_tukey_window = np.fft.fftshift(
    np.pad(
        signal.windows.tukey(M=nperseg, alpha=0.1),
        pad_width=(len(strain) - nperseg) // 2,
        constant_values=0,
    )
)
# transform to the time domain, apply the window, and return to the frequency domain
whitening_filter = (
    highpass_filter
    * np.fft.rfft(padded_tukey_window * np.fft.irfft(whitening_filter_raw))
).real * np.sqrt(2 * dt)

fig, ax = plt.subplots()

_ = ax.loglog(freqs[i1:], whitening_filter[i1:])
_ = ax.set_ylim(1e18)
_ = ax.set_title("whitening filter")
_ = ax.set_xlabel("freq [Hz]")
_ = ax.set_ylabel("whitening filter")

**Apply the (frequency-domain) whitening filter to the data. Find a very obvious "glitch" in the time-domain whitened data. Plot a histogram of the whitened time-domain strain (which is real-valued), using samples from a glitch-free region.**


In [None]:
wht_strain_f = strain_f  # this is wrong
wht_strain_t = strain # this is wrong
plt.plot(times, wht_strain_t)
plt.title("Whitened strain")
plt.xlabel("time [sec]")

plt.figure()
counts, bins, _ = plt.hist(
    wht_strain_t[:], # this is wrong
    log=True,
    bins=100,
    density=True,
)

plt.plot(bins, 1/bins**2) # this is wrong 

## 3
Create a single template for a search, with arbitrarily selected masses of $m_1=1.5$ and $m_2=1.25$ (in solar masses). 

Remove frequencies below 20 Hz to avoid FFT artifacts related to the signal length and the waveform length.

Apply a linear-free transformation to the phase. This removes the arbitrary phase and time shift in a standardized way. The code that does this is provided. See https://arxiv.org/abs/1904.01683 for the derivation.

**Plot the time-domain templates, with and without the linear-free transformation, so that they are localized in the middle of the time axis. Set the plot limits so that the waveform features and their differences are visible.**
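The transformation itself is implemented in `gw_search_functions.phases_to_linear_free_phases`, whose internals are not shown here. The following is a hedged sketch of the idea only. It assumes the transformation amounts to subtracting the weighted least-squares fit of a constant plus a term linear in frequency, which is the part of the phase that a phase offset and a time shift can produce. The function `linear_free_phase`, the toy phase, and the toy weights are all hypothetical, and the sketch may differ from the provided function and from the cited derivation in details such as weighting and normalization.

```python
import numpy as np


def linear_free_phase(phase, freqs, weights):
    """Subtract the weighted least-squares fit of a + b * f from the phase.

    Hypothetical sketch: weights**2 plays the role of |h(f)|^2 / S_n(f).
    """
    w2 = weights**2
    # basis functions: constant phase offset and linear-in-frequency term (time shift)
    basis = np.vstack([np.ones_like(freqs), freqs])   # shape (2, n_freqs)
    gram = (basis * w2) @ basis.T                      # 2 x 2 weighted normal matrix
    rhs = (basis * w2) @ phase
    a, b = np.linalg.solve(gram, rhs)
    return phase - (a + b * freqs)


freqs = np.linspace(20.0, 500.0, 2000)
weights = freqs ** (-7 / 6)                    # toy weights (flat-PSD assumption)
phase = 3e3 * (np.pi * freqs) ** (-5 / 3)      # toy chirp-like phase

# add an arbitrary phase offset and a time shift (a term linear in f)
shifted = phase + 0.7 + 2 * np.pi * freqs * 0.01

lf = linear_free_phase(phase, freqs, weights)
lf_shifted = linear_free_phase(shifted, freqs, weights)
print(np.max(np.abs(lf - lf_shifted)))   # ~0: offset and shift are removed
```

Both phases map to the same linear-free phase. This is the "standardized" choice: any two templates that differ only by a phase offset and a time shift become identical.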

In [None]:
m1 = 1.5
m2 = 1.25
fmin = 20
i1 = np.searchsorted(freqs, fmin)
phase = np.zeros_like(freqs)
phase[i1:] = gw_search_functions.masses_to_phases(m1, m2, freqs[i1:])
amp = np.zeros_like(freqs)
amp[i1:] = freqs[i1:] ** (-7 / 6)
h = amp * np.exp(1j * phase)

In [None]:
weights = amp * whitening_filter
weights /= np.sum(weights**2) ** (1 / 2)
phase_linear_free = gw_search_functions.phases_to_linear_free_phases(phase, freqs, weights)
h_linear_free = amp * np.exp(1j * phase_linear_free)

In [None]:
fig, ax = plt.subplots()
ax.plot(times, np.fft.fftshift(np.fft.irfft(h)), label="original phase")
ax.plot(times, np.fft.fftshift(np.fft.irfft(h_linear_free)), label="linear-free phase")
ax.set_xlim(850, 1030)
ax.set_xlabel("time [sec]")
ax.set_ylabel("h [arb.]")
ax.legend()

## 4

Prepare the template for use. 

**Normalize it so that, if it appears in the data with amplitude 1, its autocorrelation function equals 1 at zero lag.**

Then generate the complex-overlap time series.

**Plot a histogram of the real and imaginary parts of the complex overlap, using a segment of data without an obvious glitch. Overlay the theoretical prediction.**

The complex overlaps are predicted to follow a standard normal distribution. Make sure you understand why (a single line of calculation suffices).
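Here is a quick numerical check of this claim on synthetic data. The chirp-like template, the sample count, and the seed are arbitrary choices, not the tutorial's. Whitened noise has i.i.d. unit-variance samples. The overlap with a template normalized to $\sum_j h_j^2 = 1$ is therefore a weighted sum of Gaussians with variance $\sum_j h_j^2 = 1$. The same holds for the $90^\circ$-shifted copy. The sketch computes both overlaps at every lag with the same FFT cross-correlation pattern as the notebook.

```python
import numpy as np

rng = np.random.default_rng(1)
N = 2**16

# whitened noise: i.i.d. standard normal samples
n = rng.normal(size=N)

# arbitrary chirp-like template (hypothetical), normalized so that sum(h**2) = 1
t = np.arange(N) / N
h = np.sin(2 * np.pi * (500 * t + 4000 * t**2)) * np.exp(-(((t - 0.5) / 0.1) ** 2))
h /= np.sqrt(np.sum(h**2))

n_f = np.fft.rfft(n)
h_f = np.fft.rfft(h)

# circular cross-correlation at every lag, with the template and with its
# 90-degree phase-shifted copy (i * h_f), mirroring the notebook's z_cos and z_sin
z_cos = np.fft.irfft(n_f * h_f.conj(), n=N)
z_sin = np.fft.irfft(n_f * (1j * h_f).conj(), n=N)

print(z_cos.std(), z_sin.std())          # both close to 1
print(np.corrcoef(z_cos, z_sin)[0, 1])   # close to 0
```

The two overlaps are uncorrelated at equal lag because the template and its quadrature copy are orthogonal.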

In [None]:
normalization = 1  # this is wrong
h_linear_free = h_linear_free * normalization
wht_template = h_linear_free * whitening_filter

In [None]:
# Complex overlaps
z_cos = np.fft.irfft(wht_strain_f * wht_template.conj())
z_sin = np.fft.irfft(wht_strain_f * (1j * wht_template).conj())
z = z_cos + 1j * z_sin


In [None]:
# indices without a glitch
t_start = 0  # this is wrong
t_end = len(times)  # this is wrong
tslice = slice(*np.searchsorted(times, (t_start, t_end)))
# keywords for the histogram
hist_kwargs = {"bins": 200, "density": True, "log": True, "histtype": "step"}
# create 2 histograms
counts, edges, patches = plt.hist(z_cos[tslice], **hist_kwargs, label="z_cos")
counts, edges, patches = plt.hist(z_sin[tslice], **hist_kwargs, label="z_sin")
# overlay normal distribution with mu=0 and sigma=1
plt.plot(edges, stats.norm().pdf(edges), label="normal distribution")
plt.legend(loc="lower center")

## 5
**Compute the $\text{SNR}^2$ time series.**

To verify your SNR time-series results, use the estimated ASD to draw mock data without a GW transient. Because the FFT normalization is subtle, the code below already does this.
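The mock-data cell relies on the discrete Fourier normalization. For a one-sided PSD $S$, each interior `rfft` bin of an $N$-sample series has $E|X_k|^2 = N f_s S/2$, split equally between the real and imaginary parts. The DC and Nyquist bins are real. The minimal sketch below assumes a flat PSD, and all its numbers are illustrative. It checks the recipe in two ways: the time-domain variance should be $S f_s/2$, and Welch's method should recover $S$.

```python
import numpy as np
from scipy import signal

rng = np.random.default_rng(4)
fs = 4096
N = 64 * fs
freqs = np.fft.rfftfreq(N, d=1 / fs)

S = 4.0                                   # assumed flat one-sided PSD (arb. units^2 / Hz)
asd = np.full(len(freqs), np.sqrt(S))

# same recipe as the notebook: E|X_k|^2 = N * fs * S / 2, split between Re and Im
sigma = asd * np.sqrt(fs * N) / 2
mock_f = rng.normal(scale=sigma) + 1j * rng.normal(scale=sigma)
mock_f[0] = rng.normal(scale=sigma[0] * np.sqrt(2))     # DC bin is real
mock_f[-1] = rng.normal(scale=sigma[-1] * np.sqrt(2))   # Nyquist bin is real (N even)
mock = np.fft.irfft(mock_f, n=N)

# a flat one-sided PSD S corresponds to a time-domain variance S * fs / 2
print(mock.var() / (S * fs / 2))   # close to 1

# Welch's method recovers the input PSD
_, psd = signal.welch(mock, fs=fs, nperseg=fs)
print(psd[1:-1].mean() / S)        # close to 1
```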

**Create the ${\rm SNR}^2$ time-series on the mock data.**

**On the same figure, plot the histograms of the ${\rm SNR}^2$ of the real data and of the mock data. Overlay the theoretical prediction.**

Do you understand why the test statistic follows the $\chi^2(2)$ distribution?
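The $\chi^2(2)$ law follows from the previous step. ${\rm SNR}^2 = z_{\cos}^2 + z_{\sin}^2$ is the sum of squares of two independent standard normals, since $z_{\cos}$ and $z_{\sin}$ are uncorrelated Gaussians at equal lag. The small Monte Carlo sketch below uses an arbitrary sample size and seed. It also checks a convenient closed form that is handy when setting thresholds: $\chi^2(2)$ is an exponential distribution with mean 2, so $P({\rm SNR}^2 > x) = e^{-x/2}$.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(2)
n = 1_000_000

# real and imaginary parts of the complex overlap, each ~ N(0, 1) under H_0
z_cos = rng.normal(size=n)
z_sin = rng.normal(size=n)
snr2 = z_cos**2 + z_sin**2   # sum of two squared standard normals -> chi^2(2)

# chi^2(2) is exponential with mean 2: P(SNR^2 > x) = exp(-x / 2)
x = 10.0
print(np.mean(snr2 > x), stats.chi2(df=2).sf(x), np.exp(-x / 2))
print(snr2.mean())   # close to 2
```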

In [None]:
snr2 = np.abs(z) ** 2

In [None]:
N = len(strain)
freqs = np.fft.rfftfreq(N, d=1 / fs)
sigma = asd * np.sqrt(fs * N) / 2

re = np.random.normal(scale=sigma, size=len(freqs))
im = np.random.normal(scale=sigma, size=len(freqs))
mock_strain_f = re + 1j * im

# DC and Nyquist (real only)
mock_strain_f[0] = np.random.normal(scale=sigma[0] * np.sqrt(2))
if N % 2 == 0:
    mock_strain_f[-1] = np.random.normal(scale=sigma[-1] * np.sqrt(2))

mock_strain = np.fft.irfft(mock_strain_f, n=N)

# plot the strain and mock strain, to see they have similar amplitude
plt.semilogy(freqs, np.abs(strain_f), alpha=0.5, label="strain")
plt.semilogy(freqs, np.abs(mock_strain_f), alpha=0.5, label="mock_strain")
plt.ylabel(r"$|{\rm strain}(f)|$")
plt.xlabel("frequency [Hz]")
plt.legend()

**Calculate the SNR$^2$ for the mock data.**

**Plot the histograms of the mock SNR$^2$, the real SNR$^2$, and the pdf of $\chi^2$ with two degrees of freedom.**


In [None]:
wht_mock_strain_f = ()
mock_z_cos = ()
mock_z_sin = ()
mock_snr2 = ()

In [None]:
hist_kwargs = {
    "histtype": "step",
    "density": True,
    "log": True,
    "bins": range(200),
}

counts, edges, patches = plt.hist(
    snr2, **hist_kwargs, label=r"real data SNR$^2$"
)
counts, edges, patches = plt.hist(
    mock_snr2, **hist_kwargs, label=r"mock data SNR$^2$"
)
plt.plot(edges, stats.chi2(df=2).pdf(edges), label=r"$\chi^2(2)$ pdf")
# focus on interesting portion of the histogram
plt.xlim(0, 100)
plt.ylim(1 / np.diff(edges).mean() / len(snr2) / 10)
plt.legend()

## 6
Glitches are short stretches of data with excess power that come neither from the stationary noise nor from an astrophysical GW transient. Their shape is unrelated to the shape of a GW transient, so they fail a signal-consistency test. This test is defined per template. We create $h_{\rm low}$ and $h_{\rm high}$:
\begin{align}
    h_{\rm low}(f) = 
    \begin{cases}
    h_{+}(f) & f< \bar{f}\\
    0 & f \geq \bar{f}
    \end{cases}
    \\
    h_{\rm high}(f) = \begin{cases}
        0 & f<\bar{f}\\
        h_{+}(f) & f \geq \bar{f}
    \end{cases}
\end{align}
where $\bar{f}$ splits the template's accumulated SNR$^2$ into two equal halves:
\begin{equation}
    \sum_{f=0}^{\bar{f}} \frac{|h_+|^2}{S_n(f)} {\rm d}f= \sum_{f=\bar{f}}^{f_{\rm max}}\frac{|h_+|^2}{S_n(f)}{\rm d}f
\end{equation}
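$h_{\rm low}$ and $h_{\rm high}$ are normalized to unit norm ($\langle h_{\rm low}, h_{\rm low}\rangle=\langle h_{\rm high}, h_{\rm high}\rangle=1$). Their complex overlaps $z_{\rm low}(t)$ and $z_{\rm high}(t)$ are therefore complex-normal random variables whose real and imaginary parts each have unit variance. Because $h_{\rm low}$ and $h_{\rm high}$ occupy disjoint frequency bands, they are orthogonal, so under Gaussian noise $z_{\rm low}$ and $z_{\rm high}$ are independent. The glitch test $g(t)$ is defined as: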
\begin{equation}
    g(t) = \frac{1}{2}|z_{\rm low} - z_{\rm high}|^2(t)
\end{equation}
Under the noise hypothesis or under signal consistent with $h_+$, $g$ follows a $\chi^2(2)$ distribution. In the presence of a glitch, $z_{\rm low}$ and $z_{\rm high}$ will have large amplitudes and different phases, which will lead to a large $g(t)$.
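Here is a Monte Carlo sketch of this behavior on synthetic overlaps; the complex amplitudes and seed are arbitrary. A signal $A\,h_+$ with $\langle h_+, h_+\rangle = 1$ contributes the same mean, $A/\sqrt{2}$, to each unit-norm half, because each half carries half of the SNR$^2$. That mean cancels in $z_{\rm low} - z_{\rm high}$, leaving $g \sim \chi^2(2)$. The "glitch" here is a hypothetical toy that puts power only into the high-frequency half.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(3)
n = 200_000


def complex_normal(size):
    # real and imaginary parts each ~ N(0, 1), as for z_low and z_high under noise
    return rng.normal(size=size) + 1j * rng.normal(size=size)


# signal consistent with h_+: the same mean A / sqrt(2) in both unit-norm halves
signal_amp = 8 * np.exp(0.3j)   # arbitrary complex amplitude A
z_low = complex_normal(n) + signal_amp / np.sqrt(2)
z_high = complex_normal(n) + signal_amp / np.sqrt(2)
g_signal = 0.5 * np.abs(z_low - z_high) ** 2

# hypothetical glitch: power only in the high-frequency half
z_high_glitch = complex_normal(n) + 8
g_glitch = 0.5 * np.abs(complex_normal(n) - z_high_glitch) ** 2

threshold = stats.chi2(df=2).isf(0.01)   # 1% false-positive rate
print(np.mean(g_signal > threshold))     # close to 0.01
print(np.mean(g_glitch > threshold))     # close to 1
```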

To mark an element of the time series as a glitch, it must both have ${\rm SNR}^2$ larger than some value, which you will set by inspecting the ${\rm SNR}^2$ histogram, AND have $g(t)$ above a threshold corresponding to a 1% false-positive rate (the probability of rejecting a measurement under the no-glitch hypothesis).


Create a test statistic to detect glitches.

Use it to remove glitches from the ${\rm SNR}^2$ time series.

**Plot the histogram of the cleaned ${\rm SNR}^{2}$ time series. Overlay the theoretical prediction.**

In [None]:
# find f_bar: the frequency at which the cumulative SNR^2 reaches half of the total
frac_snr2 = np.cumsum(np.abs(wht_template) ** 2)
frac_snr2 /= frac_snr2[-1]
j = np.searchsorted(frac_snr2, 0.5)

In [None]:
# create the low- and high-frequency templates
wht_h_low, wht_h_high = np.zeros((2, len(freqs)), dtype=complex)
# normalize them so each has norm 1
wht_h_low[:j] = wht_template[:j] * 1 # this is wrong
wht_h_high[j:] = wht_template[j:] * 1 # this is wrong

In [None]:
# check normalization of the full and split templates;
# irfft(a * b.conj())[0] is used here as the zero-lag overlap <a, b>
# (once the templates are normalized, these should print 1, 1, 1, and 0)
def zero_lag_overlap(a, b):
    return np.fft.irfft(a * b.conj())[0]


print("<h, h> = ", zero_lag_overlap(wht_template, wht_template))
print("<h_low, h_low> = ", zero_lag_overlap(wht_h_low, wht_h_low))
print("<h_high, h_high> = ", zero_lag_overlap(wht_h_high, wht_h_high))
print("<h_low, h_high> = ", zero_lag_overlap(wht_h_low, wht_h_high))

In [None]:
z_low = ()
z_high = ()
glitch_test_statistic = 0.5 * np.abs(z_low - z_high) ** 2

In [None]:
# scatter-plot z_low - z_high (real vs. imaginary part) in a quiet stretch and around a glitch

fig, axs = plt.subplots(nrows=1, ncols=2, sharex=True, sharey=True)
tslice = slice(*np.searchsorted(times, (100, 101)))
axs[0].scatter(
    (z_low - z_high)[tslice].real,
    (z_low - z_high)[tslice].imag,
    s=1,
    alpha=0.5,
)
axs[0].set_title("not around glitch")
tslice = slice(*np.searchsorted(times, (1258, 1259)))
axs[1].scatter(
    (z_low - z_high)[tslice].real,
    (z_low - z_high)[tslice].imag,
    s=1,
    alpha=0.5,
)
axs[1].set_title("around glitch")
for ax in axs:
    ax.set_xlabel(r"$z_{\rm low} - z_{\rm high}$ (real)")
    ax.set_ylabel(r"$z_{\rm low} - z_{\rm high}$ (imaginary)")
    ax.set_aspect("equal")
fig.tight_layout()

**Find a threshold on the $\chi^2(2)$ glitch test statistic such that it removes 1 in 100 good signals.**

**Create a glitch-removal mask that flags samples where the glitch test is above the threshold AND the SNR is above 5 (${\rm SNR}^2 > 25$).**

In [None]:
glitch_test_threshold = 1  # this is wrong
glitch_mask = ()

In [None]:
fig, ax = plt.subplots()

hist_kwargs = {
    "histtype": "step",
    "density": True,
    "log": True,
    "bins": range(200),
}
counts, edges, patches = ax.hist(
    snr2, **hist_kwargs, label=r"SNR$^2$ before glitch-vetoing"
)
counts, edges, patches = ax.hist(
    snr2[~glitch_mask], **hist_kwargs, label=r"SNR$^2$ after glitch vetoing"
)


ax.plot(edges, stats.chi2(df=2).pdf(edges), label=r"$\chi^2(2)$")
y_lower_limit = 0.5 / (np.diff(edges).mean() * len(snr2))
ax.set_xlim(right=100)
ax.set_ylim(y_lower_limit)
leg = ax.legend()