# G2Net

Quick Exploratory Data Analysis for [G2Net Gravitational Wave Detection](https://www.kaggle.com/c/g2net-gravitational-wave-detection) challenge

*Thanks*

* [🕳️G2Net🕳️ - EDA and Modeling](https://www.kaggle.com/ihelon/g2net-eda-and-modeling)
* [MatchedFiltering](https://github.com/moble/MatchedFiltering)

## Useful sources

- Gravitational waves (GW)
    - [First observation](https://journals.aps.org/prl/pdf/10.1103/PhysRevLett.116.061102)
    - [Detecting methods](https://iopscience.iop.org/article/10.1088/0264-9381/21/20/024/pdf) or [🏴‍☠️](https://sci-hub.ru/10.1088/0264-9381/21/20/024)
- Scientific python modules
    - [`PyCBC`](http://pycbc.org/pycbc/latest/html/) + [tutorials](https://github.com/gwastro/PyCBC-Tutorials)
    - [`GWpy`](https://gwpy.github.io/docs/stable/index.html#)
- Matched filtering (the classical way to detect GWs)
    - [Matching Matched Filtering with Deep Networks for Gravitational-Wave Astronomy](https://journals.aps.org/prl/pdf/10.1103/PhysRevLett.120.141103)
    - [Matched filtering notebook](https://nbviewer.jupyter.org/github/moble/MatchedFiltering/blob/binder/MatchedFiltering.ipynb), [code](https://github.com/moble/MatchedFiltering/blob/binder/utilities.py)
- Ideas
    - [Convolutional neural networks: A magic bullet for gravitational-wave detection?](https://arxiv.org/pdf/1904.08693.pdf)
    - [Deep Learning for real-time gravitational wave detection and parameter estimation: Results with Advanced LIGO data](https://reader.elsevier.com/reader/sd/pii/S0370269317310390?token=CE27A8104F32CF9E7B6C1C42755C4BD09041FDCD4C97466386736235D570DE1DA77622F7C26D54BFD43F0D1E4E732409&originRegion=eu-west-1&originCreation=20210908151446)
    - [Gravitational-wave parameter estimation with autoregressive neural network flows](https://arxiv.org/pdf/2002.07656.pdf)
    - [Improving Deep Speech Denoising by Noisy2Noisy Signal Mapping](https://arxiv.org/ftp/arxiv/papers/1904/1904.12069.pdf)
- Kaggle notebooks
    - [Data Transformations + EfficientNet](https://www.kaggle.com/ihelon/g2net-eda-and-modeling)
- Kaggle Discussion
    - [signal fixed time](https://www.kaggle.com/c/g2net-gravitational-wave-detection/discussion/268553)
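Matched filtering, listed above as the classical approach, slides a known waveform (the template) along the data and looks for the lag where their cross-correlation peaks. Below is a minimal NumPy sketch, assuming white Gaussian noise (so the whitening step a real analysis needs is skipped) and circular shifts. The toy chirp template and the `matched_filter_offset` helper are illustrative only, not taken from PyCBC or the MatchedFiltering repo:

```python
import numpy as np


def matched_filter_offset(data, template):
    """Circular cross-correlation of `data` with `template`, computed with FFTs.

    Returns the lag (in samples) that maximizes |correlation| and the full series.
    Assumes white noise, so no whitening is applied first.
    """
    n = data.size
    corr = np.fft.irfft(np.fft.rfft(data) * np.conj(np.fft.rfft(template, n=n)), n=n)
    return int(np.argmax(np.abs(corr))), corr


fs = 2048                       # sampling rate used in this competition, Hz
n = 2 * fs                      # 2 s of data
m = fs // 2                     # 0.5 s toy chirp
tc = np.arange(m) / fs
# Hann-tapered linear chirp sweeping 30 -> 400 Hz (a toy stand-in for a GW template)
template = np.zeros(n)
template[:m] = np.hanning(m) * np.sin(2 * np.pi * (30 * tc + 370 * tc**2))

rng = np.random.default_rng(1)
true_shift = 1000
data = np.roll(template, true_shift) + 0.5 * rng.standard_normal(n)

offset, corr = matched_filter_offset(data, template)
print(offset)                   # recovers true_shift
aligned = np.roll(data, -offset)  # data shifted back onto the template
```

A chirp is used because a single-frequency template would correlate almost as strongly one cycle away from the true lag; the frequency sweep concentrates the correlation into one sharp peak.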

## Setup environment

### Files
**train/** - the training set files, one `.npy` file per observation; labels are provided in the file described below   
**test/** - the test set files; you must predict the probability that the observation contains a gravitational wave   
**training_labels.csv** - target values of whether the associated signal contains a gravitational wave   
**sample_submission.csv** - a sample submission file in the correct format
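The observation files are nested by the first three characters of their id. A small sketch of that lookup, mirroring the path format that `Dataset.__convert_image_id_2_path` builds later in this notebook; the helper name `id_to_path` and the example id are hypothetical:

```python
from pathlib import Path


def id_to_path(sample_id: str, root: str = "../input/g2net-gravitational-wave-detection",
               split: str = "train") -> Path:
    """Hypothetical helper: <root>/<split>/<id[0]>/<id[1]>/<id[2]>/<id>.npy"""
    return Path(root) / split / sample_id[0] / sample_id[1] / sample_id[2] / f"{sample_id}.npy"


print(id_to_path("abc1234567").as_posix())
# ../input/g2net-gravitational-wave-detection/train/a/b/c/abc1234567.npy
```

Loading such a file with `np.load` gives a `(3, 4096)` array: LIGO Hanford, LIGO Livingston and Virgo, 2 s each at 2048 Hz.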

In [None]:
!pip install -qq nnAudio # constant-Q transform (CQT) layers for the Q-transform plots
!pip install efficientnet_pytorch -q # to train the model 

In [None]:
import os

import numpy as np
import pandas as pd

import scipy
import scipy.signal
from scipy.interpolate import InterpolatedUnivariateSpline

import matplotlib.pyplot as plt
import seaborn as sns

import torch
from torch import nn
import torch.nn.functional as F
import torchaudio

device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

In [None]:
train_df = pd.read_csv("../input/g2net-gravitational-wave-detection/training_labels.csv")
train_df.head()

In [None]:
sns.countplot(data=train_df, x="target"); # equal class fractions

## Data exploration

In [None]:
!wget https://github.com/moble/MatchedFiltering/raw/binder/utilities.py # helper functions for preprocessing (bandpass, bump_function)
!wget https://github.com/moble/MatchedFiltering/raw/binder/Data/NR_GW150914.txt # simulated black-hole merger waveform for GW150914 (matched-filter template)
!wget -N https://gist.github.com/nikita-p/8c67b7228f7b3930025bab90ef4ae8ff/raw/1f439dde4b2459bbc524ad3021e011670c1c8552/g2net_models.py # 1D CNN models (imported below as g2net_models)

In [None]:
from utilities import bandpass, bump_function
from scipy.signal import welch
from nnAudio.Spectrogram import CQT1992v2

def fade(signal, fade_length=0.075):
    """Customize `utilities.fade` to work with 2-dim arrays
    """
    n = signal.shape[1]
    t = np.arange(n, dtype=float)
    return signal * bump_function(t, t[0], t[int(fade_length*n)], t[int(-1-fade_length*n)], t[-1])


class Dataset(torch.utils.data.Dataset):
    def __convert_image_id_2_path(self, image_id: str, is_train: bool = True) -> str:
        folder = "train" if is_train else "test"
        return "../input/g2net-gravitational-wave-detection/{}/{}/{}/{}/{}.npy".format(
            folder, image_id[0], image_id[1], image_id[2], image_id 
        )
    
    def estimate_noise_spectrum(self, signal):
        """Estimates noise spectrum via scipy.signal.welch
        
        Parameters
        ----------
        signal : np.array
            signal time representation
            
        Returns
        -------
        noise_interpolator
            interpolator for whitening
        """
        
        number_of_chunks = 8
        length = signal.shape[1]
        points_per_chunk = 2**int(np.log2(length/number_of_chunks))
        f_noise, noise_welch = welch(signal, self.sampling_rate, nperseg=points_per_chunk)
        noise_spectral_density = np.sqrt(2 * length * noise_welch / self.sampling_rate)
        noise_interpolator = [InterpolatedUnivariateSpline(f_noise, noise_spectral_density[i]) for i in range(signal.shape[0])]
        return noise_interpolator
    
    def raw_signal(self, idx: int):
        """Prepares raw signal of item (times, freqs)
        
        Parameters
        ----------
        idx : int
            event number
            
        Returns
        -------
        raw_signal : np.array
            raw signal event from 3 detectors (LIGO_H, LIGO_L, Virgo), time representation
        raw_freq : np.array
            raw signal event, freq representation
        """
        
        image_id = self.img[idx]
        path = self.__convert_image_id_2_path(image_id)
        raw_signal = np.load(path)
        raw_signal = fade(raw_signal)
        raw_freq = np.fft.rfft(raw_signal) / self.sampling_rate
        return raw_signal, raw_freq
    
    def filtered_signal(self, idx: int, return_interpolators: bool = False):
        """Prepares filtered signal of item (whitening + bandpass filter from 35 to upper bound)
        
        Parameters
        ----------
        idx : int
            event number
        return_interpolators : bool
            return interpolators for whitening or not
            
        Returns
        -------
        filtered_signal : np.array
            filtered signal (time representation)
        filtered_freq : np.array
            filtered signal (freq representation)
        noise_interpolator
            interpolators for whitening if return_interpolators is True  
        """
        
        raw_signal, raw_freq = self.raw_signal(idx)
        
        # Estimate the noise spectrum
        noise_interpolator = self.estimate_noise_spectrum(raw_signal)
        raw_frequencies = np.fft.rfftfreq(raw_signal.shape[1], d=1/self.sampling_rate)
        raw_bkg_freq = np.array([noise_interpolator[i](raw_frequencies) for i in range(3)])

        # Equalize the data using this noise estimate
        filtered_freq = raw_freq / raw_bkg_freq
        filtered_signal = fade(self.sampling_rate * np.fft.irfft(filtered_freq))
        
        # Finally, bandpass the equalized data
        filtered_signal = bandpass(filtered_signal, self.sampling_rate, lower_end=35.0,
                                   upper_end=self.upper_bandpass_frequency)
        filtered_freq = np.fft.rfft(filtered_signal) / self.sampling_rate
        
        if return_interpolators:
            return filtered_signal, filtered_freq, noise_interpolator
        
        return filtered_signal, filtered_freq
    
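    def matching_filtering(self, idx: int, bh_signal_phase: complex = 0.1j):
        """Matched filtering of the filtered signal with the simulated signal loaded from file.
        ATTENTION: PRELIMINARY!
        
        Parameters
        ----------
        idx : int
            event number
        bh_signal_phase : complex
            phase of the simulated signal; the imaginary part is the phase in units of pi
        
        Returns
        -------
        resulted_signal : np.array
            filtered signal aligned with and multiplied by the template (time representation)
        resulted_freq : np.array
            the same product (freq representation)
        """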
        
        filtered_signal, filtered_freq, noise_interpolator = self.filtered_signal(idx, return_interpolators = True)
        
        bh_signal = np.exp(bh_signal_phase*np.pi)*self.bh_signal.copy().view('complex')
        bh_signal = fade(bh_signal.real.reshape(1, -1))
        bh_freq = np.fft.rfft(bh_signal) / self.sampling_rate
        
        bh_frequencies = np.fft.rfftfreq(bh_signal.shape[1], d=1/self.sampling_rate)
        bh_bkg_freq = np.array([noise_interpolator[i](bh_frequencies) for i in range(3)])

        filtered_bh_freq = bh_freq.copy() / bh_bkg_freq
        filtered_bh_signal = fade(self.sampling_rate * np.fft.irfft(filtered_bh_freq)
                                  [:, bh_signal.shape[1]//2 - self.sampling_rate:bh_signal.shape[1]//2 + self.sampling_rate])

        filtered_bh_signal = bandpass(filtered_bh_signal, self.sampling_rate, lower_end=35.0,
                                   upper_end=self.upper_bandpass_frequency)
        filtered_bh_freq = np.fft.rfft(filtered_bh_signal) / self.sampling_rate
        
        # Offset
        match = np.fft.irfft(filtered_freq * filtered_bh_freq.conjugate())
        offsets = np.argmax(abs(match), axis=1)
        resulted_signal = np.array([np.roll(filtered_signal[i], -offsets[i])*filtered_bh_signal[i] for i in range(3)])
        resulted_freq = np.fft.rfft(resulted_signal) / self.sampling_rate
        
        return resulted_signal, resulted_freq
    
    def q_transformed(self, idx: int, filtered: bool = True):
        """Q-transform signal
        
        Parameters
        ----------
        idx : int
            event number
        filtered : bool
            use filtered signal from `self.filtered_signal` or raw from `self.raw_signal`
            
        Returns
        -------
        signal_transformed : torch.Tensor
            q-transformed signal
        """
        
        if filtered:
            signal, freq = self.filtered_signal(idx)
        else:
            signal, freq = self.raw_signal(idx)
        
        if self.q_transform is None:
            self.q_transform = CQT1992v2(sr=self.sampling_rate, fmin=35, fmax=self.upper_bandpass_frequency, hop_length=32)
        
        t_signal_torch = torch.from_numpy(signal.astype(np.float32))
        signal_transformed = self.q_transform(t_signal_torch)
        
        return signal_transformed
    
    def event_spectrogram(self, idx: int, filtered: bool = True):
        """Returns spectrogram
        
        Parameters
        ----------
        idx : int
            event number
        filtered : bool
            use filtered signal from `self.filtered_signal` or raw from `self.raw_signal`
            
        Returns
        -------
        spectrogram : torch.Tensor
            signal spectrogram
        """
        if self.spectrogram is None:
            self.spectrogram = torchaudio.transforms.Spectrogram(n_fft = 1024, hop_length=2, power=2, win_length = 150)
        
        if filtered:
            signal, freq = self.filtered_signal(idx)
        else:
            signal, freq = self.raw_signal(idx)
            signal *= 1e21
        spectrogram = self.spectrogram(torch.from_numpy(signal.astype(np.float32)))
        return spectrogram
    
    def mel_transformed(self, idx: int, filtered: bool = True):
        """Returns mel spectrogram
        
        Parameters
        ----------
        idx : int
            event number
        filtered : bool
            use filtered signal from `self.filtered_signal` or raw from `self.raw_signal`
            
        Returns
        -------
        mel : torch.Tensor
            mel signal
        """
        if self.mel_transform is None:
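            # n_stft must match the number of Spectrogram bins: n_fft // 2 + 1 = 513 for n_fft=1024
            self.mel_transform = torchaudio.transforms.MelScale(256, sample_rate=self.sampling_rate, f_min=35,
                                                                f_max=self.upper_bandpass_frequency, n_stft=1024 // 2 + 1)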
        
        spectra = self.event_spectrogram(idx, filtered)
        mel = self.mel_transform(spectra)
        return mel
        
    def load_bh_signal(self, bh_signal_filename: str):
        """Loads simulated signal (black holes waves)
        """
        
        bh_signal_sampling_rate = 4096
        assert bh_signal_sampling_rate%self.sampling_rate == 0
        bh_signal = np.loadtxt(bh_signal_filename).view(dtype=complex)[::(bh_signal_sampling_rate//self.sampling_rate), 0]
        return bh_signal
    
    def __init__(self, train_df_filename: str = "../input/g2net-gravitational-wave-detection/training_labels.csv", 
                 bh_signal_filename: str = 'NR_GW150914.txt'):
        super().__init__()
        self.sampling_rate = 2048
        self.upper_bandpass_frequency = 350
        
        train_df = pd.read_csv(train_df_filename)
        self.img, self.y = train_df.id.values, train_df.target.values
        self.len = len(self.img)
        
        self.bh_signal = self.load_bh_signal(bh_signal_filename)
        
        self.q_transform = None
        self.spectrogram = None
        self.mel_transform = None
        
    def __getitem__(self, idx: int):
        """Get event using index (for model training)
        
        Parameters
        ----------
        idx : int
            event number
        
        Returns
        -------
        tuple
            filtered signal (time repr), true label
        """
        x = self.filtered_signal(idx)[0].astype(np.float32)
#         x = self.mel_transformed(idx)
        true = self.y[idx]
        return (x, true)
    
    def __len__(self):
        """Returns length of the dataset
        """
        
        return self.len
    
def draw_1D_signal(t_signal: np.array, freq: np.array, title: str, event: int = None, target: int = None, t_xlim: tuple = (0, 4096)):
    plt.figure(dpi=120, tight_layout=True, figsize=(9, 3))
    event = '' if event is None else f', event {event}'
    target = '' if target is None else f', target {target}'
    plt.suptitle(f'{title}{event}{target}')
    plt.subplot(121)
    plt.plot(t_signal.T, alpha=0.8, label=['LIGO_H', 'LIGO_L', 'Virgo'], lw=1)
    plt.xlabel(r'$\sim t$, s')
    plt.xlim(t_xlim)
    plt.legend()
    plt.grid(ls=':')
    plt.subplot(122)
    plt.loglog(np.abs(freq.T), alpha=0.8, lw=1)
    plt.xlabel(r'$\sim f$, Hz')
    plt.grid(ls=':')

def draw_2D_signal(img: np.array, title: str, event: int = None, target: int = None):
    event = '' if event is None else f', event {event}'
    target = '' if target is None else f', target {target}'
    fig, ax = plt.subplots(1, 3, dpi=120, figsize=(4*3, 4))
    for i, (a, detector) in enumerate(zip(ax, ('LIGO_H', 'LIGO_L', 'Virgo'))):
        a.imshow(img[i], interpolation='bicubic', aspect='auto', cmap='jet')
        a.set(title=f'{title} {detector}{event}{target}')
    
full_dataset = Dataset()
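`filtered_signal` above whitens each detector by dividing its spectrum by a Welch estimate of the noise amplitude spectral density (ASD). The sketch below reproduces that idea on toy data with NumPy only. As assumptions, it replaces `scipy.signal.welch` with a plain average of non-overlapping Hann-windowed periodograms, uses `np.interp` instead of a spline, and uses synthetic red noise rather than detector data:

```python
import numpy as np


def estimate_asd(x, fs, n_segments=8):
    """Approximate one-sided ASD of a 1-D series: mean periodogram of
    non-overlapping Hann-windowed segments (Welch's method without overlap)."""
    seg_len = x.size // n_segments
    window = np.hanning(seg_len)
    segments = x[: n_segments * seg_len].reshape(n_segments, seg_len) * window
    psd = np.mean(np.abs(np.fft.rfft(segments, axis=1)) ** 2, axis=0)
    psd *= 2.0 / (fs * np.sum(window ** 2))        # density normalization
    return np.fft.rfftfreq(seg_len, d=1.0 / fs), np.sqrt(psd)


def whiten(x, fs, n_segments=8):
    """Divide the spectrum of x by its estimated ASD, interpolated to the full-length grid."""
    f_seg, asd = estimate_asd(x, fs, n_segments)
    f_full = np.fft.rfftfreq(x.size, d=1.0 / fs)
    return np.fft.irfft(np.fft.rfft(x) / np.interp(f_full, f_seg, asd), n=x.size)


def band_power_ratio(x, fs, low=(20, 60), high=(500, 900)):
    """Mean power in a low band divided by mean power in a high band (about 1 for white noise)."""
    p = np.abs(np.fft.rfft(x)) ** 2
    f = np.fft.rfftfreq(x.size, d=1.0 / fs)
    lo_mask = (f > low[0]) & (f < low[1])
    hi_mask = (f > high[0]) & (f < high[1])
    return p[lo_mask].mean() / p[hi_mask].mean()


fs = 2048
rng = np.random.default_rng(0)
white = rng.standard_normal(4 * fs)
f = np.fft.rfftfreq(white.size, d=1.0 / fs)
colored = np.fft.irfft(np.fft.rfft(white) / (f + 10.0), n=white.size)  # toy red noise

print(band_power_ratio(colored, fs))              # large: low frequencies dominate
print(band_power_ratio(whiten(colored, fs), fs))  # close to 1 after whitening
```

After whitening every frequency band carries roughly the same noise power, which is why a weak chirp becomes visible in the filtered plots while it is buried in the raw ones.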

## Look at event `14`
(the clearest one I have seen)

In [None]:
idx = 14

In [None]:
t_signal, freq = full_dataset.raw_signal(idx)
draw_1D_signal(t_signal, freq, 'Raw signal', event=idx, target=full_dataset.y[idx])

In [None]:
t_signal, freq = full_dataset.filtered_signal(idx)
draw_1D_signal(t_signal, freq, 'Filtered signal', event=idx, target=full_dataset.y[idx])

In [None]:
signal, freq = full_dataset.matching_filtering(idx, bh_signal_phase=0.2j)
draw_1D_signal(signal, freq, 'Matching filtering', t_xlim=(2.7e3, 3e3), event=idx, target=full_dataset.y[idx])
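import scipy.integrate
# integrate the aligned product over time (s); `simpson` replaces `simps`, which recent SciPy versions removed
I = [scipy.integrate.simpson(signal[i], x=np.arange(len(signal[i]))/full_dataset.sampling_rate) for i in range(3)]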
print('Correlations:')
for title, i in zip(('LIGO_H', 'LIGO_L', 'Virgo'), I):
    print(f'c ({title}) = {i:.3f}')

In [None]:
q_transformed = full_dataset.q_transformed(idx, filtered=True)
draw_2D_signal(q_transformed, 'CQT', event=idx, target=full_dataset.y[idx])

In [None]:
spectra = full_dataset.event_spectrogram(idx, filtered=True)
draw_2D_signal(spectra, 'Spectrogram', event=idx, target=full_dataset.y[idx])

In [None]:
mel_spectra = full_dataset.mel_transformed(idx, filtered=True)
draw_2D_signal(mel_spectra, 'mel', event=idx, target=full_dataset.y[idx])

## Models

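1D convolutional models take the filtered time series directly (see `Dataset.__getitem__`).

2D convolutional models take time-frequency images (CQT, spectrogram, mel) instead; in my opinion they are probably more effective.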
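For the 2D models, the three detectors can be stacked as image channels, the way RGB channels feed an image CNN. A NumPy-only sketch of such a `(3, freq, time)` input, assuming a plain Hann-window STFT instead of the CQT/mel transforms used above; `stft_power` and its `n_fft`/`hop` values are illustrative choices:

```python
import numpy as np


def stft_power(x, n_fft=256, hop=32):
    """Power spectrogram of a (channels, samples) array with a Hann window.

    Returns an array of shape (channels, n_fft // 2 + 1, n_frames).
    """
    window = np.hanning(n_fft)
    n_frames = 1 + (x.shape[-1] - n_fft) // hop
    starts = hop * np.arange(n_frames)
    idx = starts[:, None] + np.arange(n_fft)[None, :]   # (n_frames, n_fft) sample indices
    frames = x[..., idx] * window                       # (channels, n_frames, n_fft)
    power = np.abs(np.fft.rfft(frames, axis=-1)) ** 2   # (channels, n_frames, n_fft // 2 + 1)
    return np.swapaxes(power, -1, -2)                   # (channels, freq, time)


rng = np.random.default_rng(0)
sample = rng.standard_normal((3, 4096))   # stand-in for one whitened (3, 4096) observation
image = np.log1p(stft_power(sample))      # log scale compresses the dynamic range
print(image.shape)                        # (3, 129, 121)
```

With `fs = 2048` and `n_fft = 256`, row `k` of each channel corresponds to `8 * k` Hz, so only the first ~45 rows fall inside the 35-350 Hz band kept by the bandpass; cropping to that band would shrink the image without losing signal.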

In [None]:
full_dataset = Dataset()
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = 128, shuffle = True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size = 128)

print('Train dataset len:', len(train_dataset))
print('Test dataset len:', len(test_dataset))

In [None]:
%%time
batch0 = next(iter(train_loader)) # time to build one batch on the fly (slow: every item is whitened and bandpassed)

In [None]:
import g2net_models # 1D convolutional models
model = g2net_models.Model1DCNN().to(device) #AUC: 0.829

### Sanity test
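A cheap check before training: with two balanced classes, a model that outputs equal logits for both classes (roughly what a freshly initialized network with small outputs does) has cross-entropy -ln(1/2) = ln 2 ≈ 0.693, so first losses far above that hint at badly scaled outputs or labels. A NumPy sketch of the definition `nn.CrossEntropyLoss` uses (log-softmax over classes, then mean negative log-likelihood); the zero logits are a stand-in for an untrained model:

```python
import numpy as np


def cross_entropy(logits, labels):
    """Mean cross-entropy for integer labels: log-softmax, then negative log-likelihood."""
    z = logits - logits.max(axis=1, keepdims=True)            # numerical stability
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()


logits = np.zeros((128, 2))                                   # indifferent between the classes
labels = np.random.default_rng(0).integers(0, 2, size=128)
print(cross_entropy(logits, labels), np.log(2))               # both ~0.6931
```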

In [None]:
%%time
with torch.no_grad():
    res = model(batch0[0].to(device))
    print('Predictions slice')
    print(res[:3])
    print('\nTrue labels shape:', batch0[1].shape)
    print('Model predicts shape:', res.shape, '\n')

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)#, weight_decay = 1e-5)
loss_fn = nn.CrossEntropyLoss()

In [None]:
import warnings
def train(model, optimizer, loss_fn, dataloader, device = 'cpu', demo: bool = False):
    """Train for one epoch, printing mean loss and accuracy every 10 batches."""
    if demo:
        warnings.warn('Demo mode. Using only first 30 batches')
    model.train()
    losses, accs = list(), list()
    for i, batch in enumerate(dataloader):
        x, y = batch
        x, y = x.float().to(device), y.to(device)
        
        pred = model(x)
        loss = loss_fn(pred, y)
        # store Python floats, not tensors, so old autograd graphs are not kept in memory
        losses.append(loss.item())
        accs.append((pred.argmax(dim=1) == y).float().mean().item())
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if len(losses) == 10:
            mean_loss = sum(losses)/len(losses)
            mean_acc = sum(accs)/len(accs)
            print(f'Batch: {i+1}, Loss: {mean_loss:.4f}, acc: {mean_acc:.2%}')
            losses, accs = list(), list()
            
        if demo and i + 1 >= 30:
            break

from tqdm import tqdm
def test(model, loss_fn, dataloader, device = 'cpu', max_batches: int = 50):
    """Evaluate on the first `max_batches` batches; return P(target=1) and true labels."""
    model.eval()
    losses, accs = list(), list()
    preds, trues = list(), list()
    with torch.no_grad():
        for _, batch in tqdm(zip(range(max_batches), dataloader), total=min(max_batches, len(dataloader))):
            x, y = batch
            x, y = x.float().to(device), y.to(device)

            logits = model(x)
            # CrossEntropyLoss applies log-softmax itself, so it must get raw logits
            loss = loss_fn(logits, y)
            prob = nn.functional.softmax(logits, dim=1)
            losses.append(loss.item())
            accs.append((prob.argmax(dim=1) == y).float().mean().item())
            preds.append(prob[:, 1])
            trues.append(y)
    mean_loss = sum(losses)/len(losses)
    mean_acc = sum(accs)/len(accs)
    print(f'Test loss: {mean_loss:.4f}, acc: {mean_acc:.2%}')
    return torch.cat(preds), torch.cat(trues)

In [None]:
train(model, optimizer, loss_fn, train_loader, device, demo=True)
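The `#AUC: 0.829` note next to `Model1DCNN` refers to ROC AUC, the competition metric, and `test` returns exactly what it needs: predicted P(target=1) and the true labels. A NumPy sketch of the rank (Mann-Whitney) form of ROC AUC, i.e. the probability that a random positive event is scored above a random negative one, with ties counted as 1/2. `roc_auc` is an illustrative helper; `sklearn.metrics.roc_auc_score` should give the same value:

```python
import numpy as np


def roc_auc(scores, labels):
    """ROC AUC as the fraction of (positive, negative) pairs ranked correctly.

    Builds the full pairwise comparison, so memory grows as n_pos * n_neg;
    fine for a few thousand validation events.
    """
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels).astype(bool)
    pos, neg = scores[labels], scores[~labels]
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return (greater + 0.5 * ties) / (len(pos) * len(neg))


scores = np.array([0.1, 0.4, 0.35, 0.8])
labels = np.array([0, 0, 1, 1])
print(roc_auc(scores, labels))   # 0.75: 3 of the 4 positive/negative pairs are ordered correctly
```

Usage with the notebook's evaluation loop: `preds, trues = test(model, loss_fn, test_loader, device)`, then `roc_auc(preds.cpu().numpy(), trues.cpu().numpy())`.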