In [1]:
from einops import rearrange
import copy
import h5py
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from pdb import set_trace
import matplotlib.pyplot as plt
from torch import nn
from x_transformers import  Encoder, Decoder
from x_transformers.autoregressive_wrapper import exists
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score
from fastai.vision.all import BCEWithLogitsLossFlat
from transformers.optimization import (
    get_linear_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
)
from fastprogress.fastprogress import master_bar, progress_bar
import os
from timm import create_model
import random
from tqdm import tqdm
from scipy.stats import binned_statistic
from copy import copy
import pickle
from torch.utils.data import Dataset
from torchvision import transforms
OUT = 'init'
DATA = 'data/gwaves_train_v5.pickle'
with open(DATA, 'rb') as f: 
    clean_data = pickle.load(f)

In [2]:
class CFG:
    bs = 64
    nw = 8
    model_name = "convnext_large_in22k"
    lr = 1e-3
    wd = 1e-4
    epoch = 100
    warmup_pct = 0.1
    num_classes = 1
    dropout_rate = 0.3
    folder = "CAT_V0"
    mixup=False
    exp_name = f"{folder}_{model_name}"

In [3]:
# iofass based
class G2NetDatasetTrain(Dataset):
    def __init__(
        self,
        noise_fns, 
        signal_fns, 
        sz_f=360,
        sz_t=128,
        p=0.66,
        p_ns=0.5,
        depth0=15,
        depth1=50,
        tfms= True
    ):
        self.data = clean_data
        self.fns =list(clean_data.keys())
        self.tfrms = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
            ]
        )
        self.p, self.p_ns, self.sz_f, self.sz_t = p, p_ns, sz_f, sz_t
        self.depth0, self.depth1 = depth0, depth1

        # A_noise after multiplying by 1e22: very important parameter for model without norm
        # 5e-2*np.sqrt(1800)/2 = 1.0606601717798214
        self.A_noise = 1.065
        Tsft = 1800
        self.tosqrtSX = 2 / np.sqrt(Tsft)
        self.noise_fns = noise_fns
        self.signal_fns = signal_fns
        self.tfms = tfms

        with open("data/real_noise_std.pickle", "rb") as handle:
            self.std_est = pickle.load(handle)

    def __len__(self):
        return len(self.data)

    def get_on_fly(self, idx):
        data = self.data[self.fns[idx]]

        n_stat = [data["H1_ts"], data["L1_ts"]]
        noise = [[] for source in range(len(n_stat))]
        if torch.rand(1).item() < 0.5:  # nonstationary noise
            std_est = random.choice(self.std_est)
            A_noise = [
                np.where(
                    std_est["H1_ts"] > 5,
                    std_est["H1_std"],
                    self.A_noise * np.ones_like(std_est["H1_std"]),
                ),
                np.where(
                    std_est["L1_ts"] > 5,
                    std_est["L1_std"],
                    self.A_noise * np.ones_like(std_est["L1_std"]),
                ),
            ]
        else:
            A_noise = self.A_noise * np.ones((2, self.sz_t))
        for source in range(len(n_stat)):
            for n in n_stat[source]:
                if n == 0:
                    noise[source].append(torch.zeros(self.sz_f))
                    continue
                dist = torch.distributions.chi2.Chi2(2 * n)  # 2 because real+img
                sample = dist.sample((self.sz_f,)) / n
                sample *= A_noise[source][n] * A_noise[source][n]
                noise[source].append(sample)
        noise = torch.stack([torch.stack(noise[0], -1), torch.stack(noise[1], -1)], 0)

        if torch.rand(1).item() < self.p:  # positive sample
            target = 1
            x0 = torch.from_numpy(
                np.stack([data["H1"], data["L1"]], 0).astype(np.float32)
            )
            c, h, w = x0.shape
            noise = noise.repeat(1, 2, 1)[:, :h, :]
            x0 = torch.cat([x0, noise], 0)#[:, 20:-20]
            x0 = self.tfrms(x0)  # transform noise and signal togather

            depth = self.depth0 + (self.depth1 - self.depth0) * torch.rand(1)
            # print(depth)
            scale = (
                1e2 * self.A_noise * self.tosqrtSX
            )  # noise-free data is generated at h=1e-2
            x = x0[c:] + x0[:c] * scale**2 / depth**2  # noise + signal

            if target == 1:
                raw = (
                    x0[:c]
                    / x0[:c].max()
                    * (self.depth1 - depth)
                    / (self.depth1 - self.depth0)
                )
            else:
                raw = torch.zeros_like(x)
        else:
            target = 0
            x = self.tfrms(noise)
            raw = torch.zeros_like(x)

        x[0] /= torch.max(x[0].mean(0, keepdim=True), 0.1 * torch.ones_like(x[0]))
        x[1] /= torch.max(x[1].mean(0, keepdim=True), 0.1 * torch.ones_like(x[1]))

        # x = torch.cat([x, 0.5 * (x[0] + x[1]).unsqueeze(0)], 0)
        # raw = torch.cat([raw, 0.5 * (raw[0] + raw[1]).unsqueeze(0)], 0)

        return x, target

    def generate_random_file(self):
        noise_fn = random.choice(self.noise_fns)
        signal_fn = random.choice(self.signal_fns)
        if np.random.rand() >= 0.66:
            img = normalize(torch.load(noise_fn))
            y = 0.0
        else:
            img = normalize(combine(torch.load(signal_fn), torch.load(noise_fn)))
            y = 1.0

        img = img.numpy()
        if np.random.rand() <= 0.5:  # horizontal flip
            img = np.flip(img, axis=1).copy()
        if np.random.rand() <= 0.5:  # vertical flip
            img = np.flip(img, axis=2).copy()
        if np.random.rand() <= 0.5:  # vertical shift
            img = np.roll(img, np.random.randint(low=0, high=img.shape[1]), axis=1)
        if np.random.rand() <= 0.5:  # channel shuffle
            img = img[np.random.permutation([0, 1]), ...]
        return torch.tensor(img), y
    
    
    def __getitem__(self, idx):
        x, y = self.get_on_fly(idx)
        if self.tfms:
            if np.random.rand() <= 0.3:
                x = freq_mask(x)
            if np.random.rand() <= 0.3:
                x = time_mask(x)
        return x, y


def read_pkl(filename):  
    data = dict()
    with open(filename, 'rb') as file1: 
        k = pickle.load(file1)
        data['L1_SFTs_amplitudes'] = np.array(k["L1"]['spectrogram'])
        data['L1_ts'] = np.array(k["L1"]['timestamps'])
        # Retrieve the Hanford decector data
        data['H1_SFTs_amplitudes'] =  np.array(k["H1"]['spectrogram'])
        data['H1_ts'] = np.array(k["H1"]['timestamps'])
        data['freq'] = np.array(k['frequency'])
    return data 
    
def normalize_pickle(data, sz_t=128):     
    time_ids = {"H1": data["H1_ts"], "L1": data["L1_ts"]}
    mean_statH = binned_statistic(
        time_ids["H1"],
        data["H1_SFTs_amplitudes"],
        statistic="mean",
        bins=sz_t,
        range=(
            max(time_ids["H1"].min(), time_ids["L1"].min()),
            min(time_ids["H1"].max(), time_ids["L1"].max()),
        ),
    )
    mean_statL = binned_statistic(
        time_ids["L1"],
        data["L1_SFTs_amplitudes"],
        statistic="mean",
        bins=sz_t,
        range=(
            max(time_ids["H1"].min(), time_ids["L1"].min()),
            min(time_ids["H1"].max(), time_ids["L1"].max()),
        ),
    )
    mean_statH = np.nan_to_num(np.transpose(mean_statH.statistic, (0, 1)))
    mean_statL = np.nan_to_num(np.transpose(mean_statL.statistic, (0, 1)))

    x = torch.from_numpy(np.stack([mean_statH, mean_statL], 0).astype(np.float32))
    c, h, w = x.shape
    x[0] /= torch.max(x[0].mean(0, keepdim=True), 0.1 * torch.ones_like(x[0]))
    x[1] /= torch.max(x[1].mean(0, keepdim=True), 0.1 * torch.ones_like(x[1]))
    #x = torch.cat([x, 0.5 * (x[0] + x[1]).unsqueeze(0)], 0)
    return x

def read_data(path):
    data = {}
    with h5py.File(path, "r") as f:
        ID_key =  path.stem
        # Retrieve the frequency data
        try:
            data['freq'] = np.array(f['frequency_Hz'])
        except:
            data['freq'] = np.array(f[ID_key]['frequency_Hz'])
        # Retrieve the Livingston decector data
        data['L1_SFTs_amplitudes'] = np.array(f[ID_key]['L1']['SFTs'])
        data['L1_ts'] = np.array(f[ID_key]['L1']['timestamps_GPS'])
        # Retrieve the Hanford decector data
        data['H1_SFTs_amplitudes'] = np.array(f[ID_key]['H1']['SFTs'])
        data['H1_ts'] = np.array(f[ID_key]['H1']['timestamps_GPS'])
    return data

class ValLoaderPickle(torch.utils.data.Dataset):
    """
    dataset = Dataset(data_type, df)

    img, y = dataset[i]
      img (np.float32): 2 x 360 x 128
      y (np.float32): label 0 or 1
    """
    def __init__(self, df, freq_tfms=False):
        self.df = df
        self.tfms = freq_tfms
        

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        """
        i (int): get ith data
        """
        r = self.df.iloc[i]
        y = np.float32(r.target)
        try:
            img = normalize_pickle(read_pkl(r.id))
        except:
            img = normalize(read_data(r.id))
        return img.float(), y

In [4]:
def get_snr(left, right, df):
    df_ = pd.concat([df.query(f"snr>{left} & snr<{right}"), df.query("snr==0")])
    return df_


def generate_report(df, p, fn):
    pred = torch.sigmoid(p).cpu().numpy().reshape(-1)
    val_df_eval = df.copy()
    val_df_eval["pred"] = pred
    val_df_eval = val_df_eval.dropna(subset="pred")
    val_df_eval.to_csv(f"{fn}_oof.csv")

    roc_100 = roc_auc_score(val_df_eval["target"], val_df_eval["pred"])
    tr_comp = val_df_eval.query('data_type=="comp_train"')
    roc_comp_train = roc_auc_score(tr_comp["target"], tr_comp["pred"])
    return {"roc_all": roc_100, "roc_comp_train": roc_comp_train}


class SaveModel:
    def __init__(self, folder, exp_name, best=np.inf):
        self.best = best
        self.folder = Path(folder) / f"{exp_name}.pth"

    def __call__(self, score, model, epoch):
        if score < self.best:
            self.best = score
            print(f"Better model found at epoch {epoch} with value: {self.best}.")
            torch.save(model.state_dict(), self.folder)


class SaveModelMetric:
    def __init__(self, folder, exp_name, best=-np.inf):
        self.best = best
        self.folder = Path(folder) / f"{exp_name}.pth"

    def __call__(self, score, model, epoch):
        if score > self.best:
            self.best = score
            print(f"Better model found at epoch {epoch} with value: {self.best}.")
            torch.save(model.state_dict(), self.folder)


class SaveModelEpoch:
    def __init__(self, folder, exp_name, best=-np.inf):
        self.best = best
        self.folder = Path(folder)
        self.exp_name = exp_name

    def __call__(self, score, model, epoch):
        self.best = score
        print(f"Better model found at epoch {epoch} with value: {self.best}.")
        torch.save(model.state_dict(), f"{self.folder/self.exp_name}_{epoch}.pth")


def custom_auc_score(p, gt):
    p = torch.nan_to_num(p)
    return roc_auc_score(gt.cpu().numpy(), torch.sigmoid(p).cpu().numpy().reshape(-1))


def fit_mixup(
    epochs,
    model,
    train_dl,
    valid_dl,
    loss_fn,
    opt,
    metric,
    val_df,
    folder="models",
    exp_name="exp_00",
    device=None,
    sched=None,
    mixup_=False,
    save_md=SaveModelEpoch,
):
    if device is None:
        device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )

    os.makedirs(folder, exist_ok=True)
    loss_fn_trn = loss_fn
    if mixup_:
        mixup = Mixup(num_classes=2, mixup_alpha=0.4, prob=0.8)
        loss_fn_trn = BinaryCrossEntropy()
    mb = master_bar(range(epochs))

    mb.write(
        [
            "epoch",
            "train_loss",
            "valid_loss",
            "val_metric",
            "roc_all",
            "roc_comp_train",
        ],
        table=True,
    )
    model.to(device)  # we have to put our model on gpu
    # scaler = torch.cuda.amp.GradScaler()  # this for half precision training
    save_md = save_md(folder, exp_name)

    for i in mb:  # iterating  epoch
        trn_loss, val_loss = 0.0, 0.0
        trn_n, val_n = len(train_dl.dataset), len(valid_dl.dataset)
        model.train()  # set model for training
        for (xb, yb) in progress_bar(train_dl, parent=mb):
            xb, yb = xb.to(device), yb.to(device)  # putting batches to device
            if mixup_:
                xb, yb = mixup(xb, yb)

            out = model(xb)  # forward pass
            loss = loss_fn_trn(out, yb)  # calulation loss

            trn_loss += loss.item()
            # print(loss.item())
            opt.zero_grad()  # zeroing optimizer
            loss.backward()  # backward
            opt.step()  # optimzers step
            if sched is not None:
                sched.step()  # scuedular step

        trn_loss /= mb.child.total

        # putting model in eval mode
        model.eval()
        gt = []
        pred = []
        # after epooch is done we can run a validation dataloder and see how are doing
        with torch.no_grad():
            for (xb, yb) in progress_bar(valid_dl, parent=mb):
                xb, yb = xb.to(device), yb.to(device)
                out = model(xb)
                loss = loss_fn(torch.nan_to_num(out), yb)
                val_loss += loss.item()

                gt.append(yb.detach())
                pred.append(out.detach())
        # calculating metric
        metric_ = metric(torch.cat(pred), torch.cat(gt))
        # saving model if necessary
        save_md(metric_, model, i)
        val_loss /= mb.child.total
        dict_res = generate_report(val_df, torch.cat(pred), f"{folder}/{exp_name}_{i}")

        pd.DataFrame(
            {
                "trn_loss": [trn_loss],
                "val_loss": [val_loss],
                "metric": [metric_],
                "roc_all": [dict_res["roc_all"]],
                "roc_comp_train": [dict_res["roc_comp_train"]],
            }
        ).to_csv(f"{folder}/{exp_name}_{i}.csv", index=False)
        mb.write(
            [
                i,
                f"{trn_loss:.6f}",
                f"{val_loss:.6f}",
                f"{metric_:.6f}",
                f"{dict_res['roc_all']:.6f}",
                f"{dict_res['roc_comp_train']:.6f}",
            ],
            table=True,
        )
    print("Training done")
    # loading the best checkpoint

In [5]:
def time_mask(spec):
    cloned = spec.clone().detach()
    len_spectro = cloned.shape[2]
    num_masks = np.random.randint(3, 8)
    T = np.random.randint(7, 12)
    for i in range(0, num_masks):
        t = random.randrange(0, T)
        t_zero = random.randrange(0, len_spectro - t)

        # avoids randrange error if values are equal and range is empty
        if (t_zero == t_zero + t): return cloned

        mask_end = random.randrange(t_zero, t_zero + t)
        cloned[:, :,t_zero:mask_end] = 0
    return cloned




def freq_mask(spec):
    cloned = spec.clone().detach()
    num_mel_channels = cloned.shape[1]
    num_masks = np.random.randint(3, 8)
    F = np.random.randint(10, 30)
    for i in range(0, num_masks):        
        f = random.randrange(0, F)
        f_zero = random.randrange(0, num_mel_channels - f)

        # avoids randrange error if values are equal and range is empty
        if (f_zero == f_zero + f): return cloned

        mask_end = random.randrange(f_zero, f_zero + f) 
        cloned[:, f_zero:mask_end, :] = 0
    
    return cloned


def normalize(data, sz_t=128):     
    time_ids = {"H1": data["H1_ts"], "L1": data["L1_ts"]}
    mean_statH = binned_statistic(
        time_ids["H1"],
        np.abs(data["H1_SFTs_amplitudes"] * 1e22) ** 2,
        statistic="mean",
        bins=sz_t,
        range=(
            max(time_ids["H1"].min(), time_ids["L1"].min()),
            min(time_ids["H1"].max(), time_ids["L1"].max()),
        ),
    )
    mean_statL = binned_statistic(
        time_ids["L1"],
        np.abs(data["L1_SFTs_amplitudes"] * 1e22) ** 2,
        statistic="mean",
        bins=sz_t,
        range=(
            max(time_ids["H1"].min(), time_ids["L1"].min()),
            min(time_ids["H1"].max(), time_ids["L1"].max()),
        ),
    )
    mean_statH = np.nan_to_num(np.transpose(mean_statH.statistic, (0, 1)))
    mean_statL = np.nan_to_num(np.transpose(mean_statL.statistic, (0, 1)))

    x = torch.from_numpy(np.stack([mean_statH, mean_statL], 0).astype(np.float32))
    c, h, w = x.shape
    x[0] /= torch.max(x[0].mean(0, keepdim=True), 0.1 * torch.ones_like(x[0]))
    x[1] /= torch.max(x[1].mean(0, keepdim=True), 0.1 * torch.ones_like(x[1]))
    #x = torch.cat([x, 0.5 * (x[0] + x[1]).unsqueeze(0)], 0)
    return x
    
    
def combine(sig_, noise, w=1.):
    sig = copy(sig_)
    min_value_l= min(sig['L1_SFTs_amplitudes'].shape[1], noise['L1_SFTs_amplitudes'].shape[1])
    min_value_h= min(sig['H1_SFTs_amplitudes'].shape[1], noise['H1_SFTs_amplitudes'].shape[1])
    sig['L1_SFTs_amplitudes'] = w * sig['L1_SFTs_amplitudes'][:, :min_value_l] + noise['L1_SFTs_amplitudes'][:, :min_value_l]
    sig['H1_SFTs_amplitudes'] = w * sig['H1_SFTs_amplitudes'][:, :min_value_h] + noise['H1_SFTs_amplitudes'][:, :min_value_h]
    sig['H1_ts'] = sig['H1_ts'][:min_value_h]
    sig['L1_ts'] = sig['L1_ts'][:min_value_l]
    return sig


In [6]:
signal = list(Path("../data/custom_data/SIGNAL_V0/data").glob("*.pth")) + list(Path("../data/custom_data/SIGNAL_V1/data").glob("*.pth"))
len(signal)
noise = []

In [7]:
comp_train = pd.read_csv("../data/train_labels.csv")
comp_train.columns = ["fn", "target"]
comp_train = comp_train.query("target>=0")
comp_train["fn"] = comp_train["fn"].apply(lambda x: Path("../data/train") / f"{x}.hdf5")
comp_train.columns = ["id", "target"]
comp_train["data_type"] = "comp_train"

df_eval = pd.read_csv('../../val/v21v.csv')
df_eval.id = df_eval.id.apply(lambda x: Path(f"../../val/v21_val/{x}.pickle"))

df_eval = pd.concat([df_eval, comp_train], ignore_index=True)

In [8]:
# Train - val split
fold =0
trn_ds = G2NetDatasetTrain(noise, signal)
vld_ds = ValLoaderPickle(df_eval)

trn_dl = DataLoader(
    trn_ds,
    batch_size=CFG.bs,
    shuffle=True,
    num_workers=CFG.nw,
    pin_memory=True,
    drop_last=True,
    persistent_workers=True,
)
vld_dl = DataLoader(
    vld_ds,
    batch_size=CFG.bs,
    shuffle=False,
    num_workers=CFG.nw,
    pin_memory=True,
)

custom_model = create_model(
                    CFG.model_name,
                    pretrained=True,
                    num_classes=1,
                    in_chans=2,
                )



opt = torch.optim.AdamW(custom_model.parameters(), lr=CFG.lr, weight_decay=CFG.wd)
loss_func = BCEWithLogitsLossFlat()
warmup_steps = 1000
total_steps = int(len(trn_dl) * CFG.epoch)
sched = get_linear_schedule_with_warmup(
    opt, num_warmup_steps=warmup_steps, num_training_steps=total_steps
)
fit_mixup(
    epochs=CFG.epoch,
    model=custom_model,
    train_dl=trn_dl,
    valid_dl=vld_dl,
    loss_fn=loss_func,
    opt=opt,
    val_df=df_eval,
    metric=custom_auc_score,
    folder=CFG.folder,
    exp_name=f"{CFG.exp_name}_{fold}",
    device="cuda:0",
    sched=sched,
)


epoch,train_loss,valid_loss,val_metric,roc_all,roc_comp_train
0,0.648979,0.643454,0.432039,0.431804,0.331912
1,0.639805,0.643413,0.487041,0.486863,0.45455
2,0.639623,0.641968,0.57786,0.577776,0.643331
3,0.646103,0.645419,0.415347,0.415094,0.324937
4,0.638541,0.641891,0.570926,0.570835,0.672444
5,0.63967,0.645749,0.604045,0.603989,0.73835
6,0.640141,0.642222,0.477221,0.477032,0.406806
7,0.644841,0.642238,0.50982,0.509665,0.518169
8,0.639285,0.642076,0.557189,0.557084,0.631081
9,0.639697,0.641939,0.397522,0.397251,0.308925


Better model found at epoch 0 with value: 0.4320393994206779.
Better model found at epoch 1 with value: 0.4870412404302893.
Better model found at epoch 2 with value: 0.5778600958534994.
Better model found at epoch 3 with value: 0.41534698402937775.
Better model found at epoch 4 with value: 0.5709258481500424.
Better model found at epoch 5 with value: 0.604045417459063.
Better model found at epoch 6 with value: 0.47722090917439886.
Better model found at epoch 7 with value: 0.509819968256953.
Better model found at epoch 8 with value: 0.5571890911075305.
Better model found at epoch 9 with value: 0.39752228420227415.
Better model found at epoch 10 with value: 0.4219102165687961.
Better model found at epoch 11 with value: 0.40590265917291657.
Better model found at epoch 12 with value: 0.4303653995704755.
Better model found at epoch 13 with value: 0.417066751995617.
Better model found at epoch 14 with value: 0.44222145911778915.
Better model found at epoch 15 with value: 0.390639158849182.
B

KeyboardInterrupt: 

In [None]:
#trn_ds = G2NetDatasetTrain(noise, signal)
#for i in range(len(trn_ds)):
#    x, y = trn_ds[i]
#    print(y)
#    plt.imshow(x.mean(0))
#    plt.pause(0.1)
#    