# Waveform-based CNN

Network architecture from https://github.com/gwastro/ml-training-strategies/blob/master/Pytorch/network.py

In [108]:
import torch
import math
import numpy as np
import torch.nn as nn
from torch.fft import fft, rfft, ifft
import matplotlib.pyplot as plt
from scipy import signal
import librosa
import librosa.display
from torchaudio.functional import bandpass_biquad, lfilter
from pathlib import Path

# Kaggle competition slug; used as the leaf directory for both data and outputs.
COMP_NAME = "g2net-gravitational-wave-detection"

# Input and output trees live side by side under one storage mount.
_STORAGE_ROOT = Path("/mnt/storage_dimm2")
INPUT_PATH = _STORAGE_ROOT / "kaggle_data" / COMP_NAME
OUTPUT_PATH = _STORAGE_ROOT / "kaggle_output" / COMP_NAME

import sys
sys.path.append("/home/anjum/kaggle/g2net-gravitational-wave-detection/")

from src.resnet1d import ResNet1D

In [2]:
def load_file(id_, folder="train"):
    """Load one sample's ``.npy`` file and scale it to a peak amplitude of 1.

    Files are sharded on disk by the first three characters of the id:
    ``INPUT_PATH/<folder>/<id[0]>/<id[1]>/<id[2]>/<id>.npy``.

    Args:
        id_: Sample identifier string (at least three characters long).
        folder: Sub-directory of ``INPUT_PATH`` to read from.

    Returns:
        The loaded array divided by its single largest absolute value, so
        the scaling is global across the whole array rather than per row.
    """
    sample_path = INPUT_PATH.joinpath(
        folder, id_[0], id_[1], id_[2], f"{id_}.npy"
    )
    waves = np.load(sample_path)
    peak = np.abs(waves).max()
    return waves / peak


# https://www.kaggle.com/kevinmcisaac/g2net-spectral-whitening
def apply_whiten(signal, window=False):
    """Spectrally whiten a real-valued waveform along its last axis.

    Every frequency bin is divided by its own magnitude so the resulting
    spectrum has constant magnitude, then the signal is transformed back
    to the time domain and rescaled by ``sqrt(win_length / 2)``.

    Args:
        signal: NumPy array of samples. The last axis is treated as time;
            any leading axes (e.g. channels) are whitened independently.
            The array is never modified.
        window: If True, multiply by a periodic Hann window before the FFT.
            Not needed if a window has already been applied. Tukey is
            probably better.

    Returns:
        NumPy array with the same shape as ``signal``. Bins whose magnitude
        is exactly zero produce NaN, since they are divided by zero.
    """
    # The parameter name shadows the module-level ``scipy.signal`` import,
    # so bind the tensor to a distinct local name straight away.
    wave = torch.from_numpy(signal).float()

    # The FFT runs along the last axis, so that axis defines the window
    # length regardless of how many leading axes there are.
    win_length = wave.shape[-1]

    if window:
        hann = torch.hann_window(win_length, periodic=True, dtype=torch.float64)
        # Multiply out of place: for float32 input ``wave`` shares memory
        # with the caller's array, and an in-place product would silently
        # overwrite it. The product is formed in float64 and cast back.
        wave = (wave * hann).to(wave.dtype)

    spec = fft(wave)
    mag = torch.abs(spec)
    return torch.real(ifft(spec / mag)).numpy() * np.sqrt(win_length / 2)


def apply_bandpass(x, lf=35, hf=350, order=4, sr=2048):
    """Zero-phase Butterworth band-pass filter with amplitude renormalisation.

    Args:
        x: Array of samples; filtering is applied along the last axis.
        lf: Lower cut-off frequency, in the same units as ``sr``.
        hf: Upper cut-off frequency, in the same units as ``sr``.
        order: Order passed to the Butterworth design.
        sr: Sampling rate of ``x``.

    Returns:
        The forward-backward filtered signal divided by the square root of
        the fraction of the Nyquist band that the pass-band occupies.
    """
    nyquist = sr / 2
    band_gain = np.sqrt((hf - lf) / nyquist)

    # Second-order sections are used for the filter design.
    sections = signal.butter(
        order, [lf, hf], btype="bandpass", output="sos", fs=sr
    )
    filtered = signal.sosfiltfilt(sections, x)
    return filtered / band_gain


def pad_data(x, padding=0.25, sr=2048):
    """Zero-pad both ends of the last axis by ``padding`` seconds.

    Args:
        x: Array-like with at least one dimension; the last axis is padded.
            Leading axes (e.g. channels or batch) are left untouched, so
            1-D, 2-D and higher-rank inputs are all accepted.
        padding: Amount of padding per side, in seconds.
        sr: Sampling rate used to convert ``padding`` to a sample count.

    Returns:
        A new array whose last axis is longer by ``2 * int(padding * sr)``
        samples, with zeros in the added positions.
    """
    arr = np.asarray(x)
    pad_value = int(padding * sr)
    # (0, 0) for every leading axis, symmetric padding on the last one.
    pad_width = [(0, 0)] * (arr.ndim - 1) + [(pad_value, pad_value)]
    return np.pad(arr, pad_width)

In [3]:
# Sample id used by the smoke tests in the cells below.
# Alternative sample, kept for quick switching:
# wave_id = "098a464da9"  # Super clean signal
wave_id = "000a5b6e5c"

In [102]:
class WaveformCNN(nn.Module):
    """1D convolutional network mapping a multi-channel waveform to one value.

    The network is a single ``nn.Sequential`` stored as ``self.net``:
    an input batch-norm, six unpadded convolutions with ELU activations and
    two max-pooling stages, a global average pool, and a three-layer dense
    head with dropout. No activation is applied to the final output.

    Args:
        n_channels: Number of input channels.
    """

    def __init__(self, n_channels=3):
        super().__init__()

        # Layers are created in network order so that their positions
        # inside ``self.net`` (and therefore the ``state_dict`` keys) and
        # the order of parameter initialisation stay fixed.
        feature_extractor = [
            nn.BatchNorm1d(n_channels),
            # Stage 1: 8 filters, pooled by 4.
            nn.Conv1d(n_channels, 8, 64),
            nn.ELU(),
            nn.Conv1d(8, 8, 32),
            nn.MaxPool1d(4),
            nn.ELU(),
            # Stage 2: 16 -> 32 filters, pooled by 3.
            nn.Conv1d(8, 16, 32),
            nn.ELU(),
            nn.Conv1d(16, 32, 16),
            nn.MaxPool1d(3),
            nn.ELU(),
            # Stage 3: 64 -> 128 filters, collapsed to length 1.
            nn.Conv1d(32, 64, 16),
            nn.ELU(),
            nn.Conv1d(64, 128, 16),
            nn.AdaptiveAvgPool1d(1),
            nn.ELU(),
        ]
        dense_head = [
            nn.Flatten(),
            nn.Linear(128, 64),
            nn.Dropout(p=0.5),
            nn.ELU(),
            nn.Linear(64, 64),
            nn.Dropout(p=0.5),
            nn.ELU(),
            nn.Linear(64, 1),
        ]
        self.net = nn.Sequential(*feature_extractor, *dense_head)

    def forward(self, x):
        """Run ``x`` of shape (batch, n_channels, length) through the network.

        Returns:
            Tensor of shape (batch, 1).
        """
        out = self.net(x)
        return out

In [103]:
# Instantiate the waveform CNN for 3-channel input.
wcnn = WaveformCNN(3)

In [104]:
# Load the selected sample, add a leading batch axis and convert to float32.
data = torch.from_numpy(load_file(wave_id)).unsqueeze(0).float()
# Bare expression so the notebook displays the tensor shape.
data.shape

torch.Size([1, 3, 4096])

In [105]:
# Smoke test: push the single-sample batch through the CNN.
out = wcnn(data)
# Bare expression so the notebook displays the output shape.
out.shape

torch.Size([1, 1])

# Make a residual net

In [112]:
# Build the 1D ResNet imported from src.resnet1d for 3-channel input and a
# single output. The meaning of the remaining keyword arguments is defined
# in that module, which is not shown here.
model = ResNet1D(
    in_channels=3,
    base_filters=128,
    kernel_size=16,
    stride=2,
    n_block=48,
    groups=32,
    n_classes=1,
    downsample_gap=6,
    increasefilter_gap=12,
    verbose=False,
)

# Bare expression so the notebook prints the module tree.
model

ResNet1D(
  (first_block_conv): MyConv1dPadSame(
    (conv): Conv1d(3, 128, kernel_size=(16,), stride=(1,))
  )
  (first_block_bn): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (first_block_relu): ReLU()
  (basicblock_list): ModuleList(
    (0): BasicBlock(
      (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu1): ReLU()
      (do1): Dropout(p=0.5, inplace=False)
      (conv1): MyConv1dPadSame(
        (conv): Conv1d(128, 128, kernel_size=(16,), stride=(1,), groups=32)
      )
      (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU()
      (do2): Dropout(p=0.5, inplace=False)
      (conv2): MyConv1dPadSame(
        (conv): Conv1d(128, 128, kernel_size=(16,), stride=(1,), groups=32)
      )
      (max_pool): MyMaxPool1dPadSame(
        (max_pool): MaxPool1d(kernel_size=1, stride=1, padding=0, dilation=1, ceil_mode=False)
      )
    )
    (1)

In [111]:
# Smoke test: run the same single-sample batch through the ResNet.
out = model(data)
# Bare expression so the notebook displays the output shape.
out.shape

torch.Size([1, 1])