In [1]:
from einops import rearrange
import copy
import h5py
from pathlib import Path
import numpy as np
import pandas as pd

import torch
torch.cuda.set_device(1)

from pdb import set_trace
import matplotlib.pyplot as plt
from torch import nn
from x_transformers import  Encoder, Decoder
from x_transformers.autoregressive_wrapper import exists
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score
from fastai.vision.all import BCEWithLogitsLossFlat
from transformers.optimization import (
    get_linear_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
)
from fastprogress.fastprogress import master_bar, progress_bar

from timm import create_model
import random
from tqdm import tqdm

In [2]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [3]:
from vit_pytorch.vit_with_patch_merger import ViT
from vit_pytorch.crossformer import CrossFormer

In [4]:
class CFG:
    bs = 128
    nw = 4
    model_name = "VIT"
    lr = 1e-3
    wd = 1e-4
    epoch = 100
    warmup_pct = 0.1
    num_classes = 1
    dropout_rate = 0.3
    folder = "EXP_200_BASELINE_CASHE_VIT"
    mixup=False
    exp_name = f"{folder}_{model_name}"

In [5]:
def get_snr(left, right, df):
    df_ = pd.concat([df.query(f"snr>{left} & snr<{right}"), df.query("snr==0")])
    return df_


def generate_report(df, p, fn):
    pred = torch.sigmoid(p).cpu().numpy().reshape(-1)
    val_df_eval = df.copy()
    val_df_eval["pred"] = pred
    val_df_eval.to_csv(f"{fn}_oof.csv")

    roc_100 = roc_auc_score(val_df_eval["target"], val_df_eval["pred"])
    
    tr_comp = val_df_eval.query('data_type=="comp_train"')
    roc_comp_train = roc_auc_score(tr_comp['target'], tr_comp['pred'])
    roc_0_50 = roc_auc_score(
        get_snr(0, 50, val_df_eval)["target"], get_snr(0, 50, val_df_eval)["pred"]
    )
    roc_15_50 = roc_auc_score(
        get_snr(15, 50, val_df_eval)["target"], get_snr(15, 50, val_df_eval)["pred"]
    )
    roc_25_50 = roc_auc_score(
        get_snr(25, 50, val_df_eval)["target"], get_snr(25, 50, val_df_eval)["pred"]
    )
    roc_0_40 = roc_auc_score(
        get_snr(0, 40, val_df_eval)["target"], get_snr(0, 40, val_df_eval)["pred"]
    )

    roc_0_30 = roc_auc_score(
        get_snr(0, 30, val_df_eval)["target"], get_snr(0, 30, val_df_eval)["pred"]
    )

    return {
        "roc_all": roc_100,
        "roc_0_50": roc_0_50,
        "roc_15_50": roc_15_50,
        "roc_25_50": roc_25_50,
        "roc_0_40": roc_0_40,
        "roc_0_30": roc_0_30,
        "roc_comp_train": roc_comp_train
    }

class SaveModel:
    def __init__(self, folder, exp_name, best=np.inf):
        self.best = best
        self.folder = Path(folder) / f"{exp_name}.pth"

    def __call__(self, score, model, epoch):
        if score < self.best:
            self.best = score
            print(f"Better model found at epoch {epoch} with value: {self.best}.")
            torch.save(model.state_dict(), self.folder)


class SaveModelMetric:
    def __init__(self, folder, exp_name, best=-np.inf):
        self.best = best
        self.folder = Path(folder) / f"{exp_name}.pth"

    def __call__(self, score, model, epoch):
        if score > self.best:
            self.best = score
            print(f"Better model found at epoch {epoch} with value: {self.best}.")
            torch.save(model.state_dict(), self.folder)


class SaveModelEpoch:
    def __init__(self, folder, exp_name, best=-np.inf):
        self.best = best
        self.folder = Path(folder)
        self.exp_name = exp_name

    def __call__(self, score, model, epoch):
        self.best = score
        print(f"Better model found at epoch {epoch} with value: {self.best}.")
        torch.save(model.state_dict(), f"{self.folder/self.exp_name}_{epoch}.pth")


def custom_auc_score(p, gt):
    return roc_auc_score(gt.cpu().numpy(),  torch.sigmoid(p).cpu().numpy().reshape(-1))


def fit_mixup(
    epochs,
    model,
    train_dl,
    valid_dl,
    loss_fn,
    opt,
    metric,
    val_df,
    folder="models",
    exp_name="exp_00",
    device=None,
    sched=None,
    mixup_=False,
    save_md=SaveModel,
):

    os.makedirs(folder, exist_ok=True)
    loss_fn_trn = loss_fn
    if mixup_:
        mixup = Mixup(num_classes=2, mixup_alpha=0.4, prob=0.8)
        loss_fn_trn = BinaryCrossEntropy()
    mb = master_bar(range(epochs))

    mb.write(
        [
            "epoch",
            "train_loss",
            "valid_loss",
            "val_metric",
            "roc_all",
            "roc_0_50",
            "roc_15_50",
            "roc_25_50",
            "roc_0_40",
            "roc_0_30",
            "roc_comp_train",
        ],
        table=True,
    )
    model.to(device)  # we have to put our model on gpu
    #scaler = torch.cuda.amp.GradScaler()  # this for half precision training
    save_md = save_md(folder, exp_name)

    for i in mb:  # iterating  epoch
        trn_loss, val_loss = 0.0, 0.0
        trn_n, val_n = len(train_dl.dataset), len(valid_dl.dataset)
        model.train()  # set model for training
        for (xb, yb) in progress_bar(train_dl, parent=mb):
            xb, yb = xb.to(device), yb.to(device)  # putting batches to device
            if mixup_:
                xb, yb = mixup(xb, yb)
           
            out = model(xb)  # forward pass
            loss = loss_fn_trn(out, yb)  # calulation loss

            trn_loss += loss.item()
            #print(loss.item())
            opt.zero_grad()  # zeroing optimizer
            loss.backward()  # backward
            opt.step()  # optimzers step
            if sched is not None:
                sched.step()  # scuedular step

        trn_loss /= mb.child.total

        # putting model in eval mode
        model.eval()
        gt = []
        pred = []
        # after epooch is done we can run a validation dataloder and see how are doing
        with torch.no_grad():
            for (xb, yb) in progress_bar(valid_dl, parent=mb):
                xb, yb = xb.to(device), yb.to(device)
                out = model(xb)
                loss = loss_fn(out, yb)
                val_loss += loss.item()

                gt.append(yb.detach())
                pred.append(out.detach())
        # calculating metric
        metric_ = metric(torch.cat(pred), torch.cat(gt))
        # saving model if necessary
        save_md(metric_, model, i)
        val_loss /= mb.child.total
        dict_res = generate_report(val_df, torch.cat(pred), f"{folder}/{exp_name}_{i}")

        pd.DataFrame(
            {
                "trn_loss": [trn_loss],
                "val_loss": [val_loss],
                "metric": [metric_],
                "roc_all": [dict_res["roc_all"]],
                "roc_0_50": [dict_res["roc_0_50"]],
                "roc_15_50": [dict_res["roc_15_50"]],
                "roc_25_50": [dict_res["roc_25_50"]],
                "roc_0_40": [dict_res["roc_0_40"]],
                "roc_0_30": [dict_res["roc_0_30"]],
                "roc_comp_train": [dict_res["roc_comp_train"]]
            }
        ).to_csv(f"{folder}/{exp_name}_{i}.csv", index=False)
        mb.write(
            [
                i,
                f"{trn_loss:.6f}",
                f"{val_loss:.6f}",
                f"{metric_:.6f}",
                f"{dict_res['roc_all']:.6f}",
                f"{dict_res['roc_0_50']:.6f}",
                f"{dict_res['roc_15_50']:.6f}",
                f"{dict_res['roc_25_50']:.6f}",
                f"{dict_res['roc_0_40']:.6f}",
                f"{dict_res['roc_0_30']:.6f}",
                f"{dict_res['roc_comp_train']:.6f}",
            ],
            table=True,
        )
    print("Training done")
    # loading the best checkpoint

In [6]:
def time_mask(spec, T=10):
    cloned = spec.clone().detach()
    len_spectro = cloned.shape[2]
    num_masks = np.random.randint(3, 8)
    for i in range(0, num_masks):
        t = random.randrange(0, T)
        t_zero = random.randrange(0, len_spectro - t)

        # avoids randrange error if values are equal and range is empty
        if (t_zero == t_zero + t): return cloned

        mask_end = random.randrange(t_zero, t_zero + t)
        cloned[:, :,t_zero:mask_end] = 0
    return cloned




def freq_mask(spec, F=30):
    cloned = spec.clone().detach()
    num_mel_channels = cloned.shape[1]
    num_masks = np.random.randint(3, 8)
    for i in range(0, num_masks):        
        f = random.randrange(0, F)
        f_zero = random.randrange(0, num_mel_channels - f)

        # avoids randrange error if values are equal and range is empty
        if (f_zero == f_zero + f): return cloned

        mask_end = random.randrange(f_zero, f_zero + f) 
        cloned[:, f_zero:mask_end, :] = 0
    
    return cloned

In [7]:
def preprocess(sft):
    sft = sft * 1e22
    sft = sft.real**2 + sft.imag**2
    return sft


def normalize(data):
    data[0] = data[0] / data[0].mean()
    data[1] = data[1] / data[1].mean()
    data = data.reshape(2, 360, 128, 32).mean(-1)  # compress 4096 -> 128
    data = data - data.mean()
    data = data / data.std()
    return torch.tensor(data)


def read_h5(file):
    file = Path(file)
    with h5py.File(file, "r") as f:
        filename = file.stem
        k = f[filename]
        h1 = k["H1"]
        l1 = k["L1"]
        h1_stft = h1["SFTs"][()]
        h1_timestamp = h1["timestamps_GPS"][()]
        l1_stft = l1["SFTs"][()]
        l1_timestamp = l1["timestamps_GPS"][()]
        
        data_dict = {"sft" : np.stack([h1_stft[:, :4096], l1_stft[:, :4096]]), 
                 "timestamps": {"H1": h1_timestamp, 
                                    "L1": l1_timestamp}}
        
        return data_dict
    
class TrainDatasetCashe(torch.utils.data.Dataset):
    """
    dataset = Dataset(data_type, df)

    img, y = dataset[i]
      img (np.float32): 2 x 360 x 128
      y (np.float32): label 0 or 1
    """

    def __init__(self, signal_fns, noise_fns, cashe_fns, freq_tfms=False, iteration=75000):
        self.signal_fns = signal_fns
        self.noise_fns = noise_fns
        self.cashe_fns = cashe_fns
        self.tfms = freq_tfms
        self.iteration = iteration

    def __len__(self):
        return len(self.cashe_fns)
    
    
    def generate_random_file(self):
        noise_fn = random.choice(self.noise_fns)
        signal_fn = random.choice(self.signal_fns)
        if np.random.rand() >= 0.5:
            img = normalize(preprocess(torch.load(noise_fn)["sft"]))
            y = 0.
        else:
            img = normalize(
                preprocess(torch.load(noise_fn)["sft"] + torch.load(signal_fn)["sft"])
            )
            y = 1.
        return img, y
    
    def get_cashe(self, i):
        #fn = random.choice(self.cashe_fns)
        fn = self.cashe_fns[i]
        data = torch.load(fn)
        img, y = data['sft'], data['target']
        return img, y

    def __getitem__(self, i):
        img, y = self.get_cashe(i)

        if self.tfms:
            if np.random.rand() <= 0.5:
                img = freq_mask(img)
            if np.random.rand() <= 0.5:
                img = time_mask(img)
            img = img.numpy()
            if np.random.rand() <= 0.6:  # horizontal flip
                img = np.flip(img, axis=1).copy()
            if np.random.rand() <= 0.6:  # vertical flip
                img = np.flip(img, axis=2).copy()
            if np.random.rand() <= 0.6:  # vertical shift
                img = np.roll(img, np.random.randint(low=0, high=img.shape[1]), axis=1)
            if np.random.rand() <= 0.5:  # channel shuffle
                img = img[np.random.permutation([0, 1]), ...]
        return img, y
    
class ValLoader(torch.utils.data.Dataset):
    """
    dataset = Dataset(data_type, df)

    img, y = dataset[i]
      img (np.float32): 2 x 360 x 128
      y (np.float32): label 0 or 1
    """
    def __init__(self, df, freq_tfms=False):
        self.df = df
        self.tfms = freq_tfms
        

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        """
        i (int): get ith data
        """
        r = self.df.iloc[i]
        y = np.float32(r.target)
        img = torch.load(r.id)['sft']
        return img, y



In [8]:
signal = list(Path('../data/custom_data/SIGNAL_V0/data').glob('*.pth'))

real_noise_fns =  sorted(
            Path("../data/custom_data/DATA_V33/data/").glob("*.pth"),
            key=lambda x: str(x).split("_")[-2])

fake_noise_fns = sorted(
    Path("../data/custom_data/DATA_V34/data/").glob("*.pth"),
    key=lambda x: str(x).split("_")[-2],
)


noise = list(Path('../data/custom_data/DATA_V31_V32_NOISE').glob('*.pth')) + real_noise_fns[:1100] + fake_noise_fns
cashe_fns = list(Path('cashe_dataset').glob('*.pth'))

val_df = pd.read_csv('../data/SPLITS/V_22/val_df.csv')
comp_train = pd.read_csv('../data/train_labels.csv')
comp_train.columns = ['fn', 'target']
comp_train = comp_train.query('target>=0')
comp_train['fn'] = comp_train['fn'].apply(lambda x: Path('../data/train')/f'{x}.hdf5')
comp_train.columns = ['id', 'target']
comp_train['data_type'] = 'comp_train'
real_noise_df = pd.DataFrame({"id": real_noise_fns[1100:], 'target': 0., 'snr': 0})
real_noise_df['id'] = real_noise_df['id'].apply(lambda x: Path(str(x).replace('.pth', '.h5')))

val_df = pd.concat([val_df, comp_train, real_noise_df], ignore_index=True)
#val_df = comp_train #pd.concat([val_df, comp_train, real_noise_df], ignore_index=True)
val_df['id'] = val_df['id'].apply(lambda x: Path('cashe_dataset_eval')/f"{Path(x).stem}.pth")

len(signal), len(noise), len(cashe_fns)

(5769, 10671, 91086)

In [9]:
val_df['id']

0                 cashe_dataset_eval/hb_2a1259daf9307954b6fc3f4596948072_noise.pth
1                       cashe_dataset_eval/hb_fe2911c3fe2591b30fb9b81774055a70.pth
2                 cashe_dataset_eval/hb_3e8aa29dffdd5deeeee1ede4663b8ff4_noise.pth
3                 cashe_dataset_eval/hb_b25dfec297f7816e5c18694507262d04_noise.pth
4                 cashe_dataset_eval/hb_91b24f77852228f71680e994ff31a349_noise.pth
                                           ...                                    
4693    cashe_dataset_eval/hb_ac6b569d6acc698cfa64f65d0d61cdaa_fe6f5a121_noise.pth
4694    cashe_dataset_eval/hb_9d5b620c48841a419555bfa00af1bc29_fecaed870_noise.pth
4695    cashe_dataset_eval/hb_9e35fc491b49468dc9f3a9176c50a04c_ff771a983_noise.pth
4696    cashe_dataset_eval/hb_fba04b04ad6b588b75bcc52ec6c40b2f_ff7f1c8ba_noise.pth
4697    cashe_dataset_eval/hb_454dfb63aaa2ee07dbe36c5328b37930_fffa17f67_noise.pth
Name: id, Length: 4698, dtype: object

In [10]:
# Train - val split
fold = 0
trn_ds = TrainDatasetCashe(signal, noise, cashe_fns, True)
vld_ds = ValLoader(val_df)

trn_dl = DataLoader(
    trn_ds,
    batch_size=CFG.bs,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    drop_last=True,
)
vld_dl = DataLoader(
    vld_ds,
    batch_size=CFG.bs,
    shuffle=False,
    num_workers=CFG.nw,
    pin_memory=True,
)

custom_model = ViT(
    num_classes=1,
    image_size=(360, 128),  # image size is a tuple of (height, width)
    patch_size=(10, 16),  # patch size is a tuple of (height, width)
    dim=1024,
    depth=6,
    heads=12,
    mlp_dim=2048,
    dropout=0.1,
    emb_dropout=0.1,
    channels=2,
)




opt = torch.optim.AdamW(custom_model.parameters(), lr=CFG.lr, weight_decay=CFG.wd)
loss_func = BCEWithLogitsLossFlat()
warmup_steps = int(len(trn_dl) * int(CFG.warmup_pct * CFG.epoch))
total_steps = int(len(trn_dl) * CFG.epoch)
sched = get_linear_schedule_with_warmup(
    opt, num_warmup_steps=warmup_steps, num_training_steps=total_steps
)

fit_mixup(
    epochs=CFG.epoch,
    model=custom_model,
    train_dl=trn_dl,
    valid_dl=vld_dl,
    loss_fn=loss_func,
    opt=opt,
    val_df=val_df,
    metric=custom_auc_score,
    folder=CFG.folder,
    exp_name=f"{CFG.exp_name}_{fold}",
    device="cuda:1",
    sched=sched,
)

epoch,train_loss,valid_loss,val_metric,roc_all,roc_0_50,roc_15_50,roc_25_50,roc_0_40,roc_0_30,roc_comp_train
0,0.676365,0.848184,0.552666,0.552666,0.549777,0.549777,0.54807,0.554396,0.58216,0.478625
1,0.660071,0.680695,0.636342,0.636342,0.512941,0.512941,0.512486,0.508467,0.520183,0.677488
2,0.614383,0.648725,0.647968,0.647968,0.510664,0.510664,0.509733,0.490859,0.495809,0.710038
3,0.603797,0.640319,0.670117,0.670117,0.527019,0.527019,0.526423,0.506798,0.518956,0.698862
4,0.609033,0.656899,0.648289,0.648289,0.523507,0.523507,0.525281,0.503218,0.493457,0.674012
5,0.607841,0.650965,0.655623,0.655623,0.516901,0.516901,0.525629,0.492881,0.469574,0.682937
6,0.617954,0.669534,0.616897,0.616897,0.49469,0.49469,0.493466,0.503021,0.542326,0.690613
7,0.615944,0.659311,0.629917,0.629917,0.504809,0.504809,0.504413,0.508478,0.534129,0.682625
8,0.611062,0.651937,0.655325,0.655325,0.523103,0.523103,0.524577,0.519965,0.533514,0.7106
9,0.611126,0.671469,0.648771,0.648771,0.522005,0.522005,0.515834,0.507154,0.538906,0.705225


Better model found at epoch 0 with value: 0.5526663367425493.
Better model found at epoch 36 with value: 0.5315411122108582.
Training done


In [None]:
df = pd.read_csv('../EXP_200/EXP_200_BASELINE_CASHE_V2/EXP_200_BASELINE_CASHE_V2_convnext_large_in22k_0_9_oof.csv')

In [None]:
from sklearn.metrics import roc_auc_score
def get_score(df):
    return roc_auc_score(df['target'], df['pred'])

def get_freq(left, right, df):
    df_ = pd.concat([df.query(f"f0>{left} & f0<{right}"), df.query("snr==0")])
    return df_

def get_snr(left, right, df):
    df_ = pd.concat([df.query(f"snr>{left} & snr<{right}"), df.query("snr==0")])
    return df_

In [None]:
get_score(get_freq(50, 100, df))

In [None]:
get_score(get_freq(100, 200, df))

In [None]:
get_score(get_freq(200, 300, df))

In [None]:
get_score(get_freq(300, 400, df))

In [None]:
get_score(get_freq(400, 500, df))

In [None]:
print("rho_10_40, freq  50 100:", get_score(get_snr(10, 40, (get_freq(50, 100, df)))))
print("rho_10_40, freq 100, 200:", get_score(get_snr(10, 40, (get_freq(100, 200, df)))))
print("rho_10_40, freq 200, 300:", get_score(get_snr(10, 40, (get_freq(200, 300, df)))))
print("rho_10_40, freq 300, 400:", get_score(get_snr(10, 40, (get_freq(300, 400, df)))))
print("rho_10_40, freq 400, 500:", get_score(get_snr(10, 40, (get_freq(400, 500, df)))))

In [None]:
bad = get_snr(10, 40, (get_freq(400, 500, df))).query('target==1')

In [None]:
def read_data(file):
    file = Path(file)
    with h5py.File(file, "r") as f:
        filename = file.stem
        k = f[filename]
        h1 = k["H1"]
        l1 = k["L1"]
        h1_stft = h1["SFTs"][()]
        h1_timestamp = h1["timestamps_GPS"][()]
        l1_stft = l1["SFTs"][()]
        l1_timestamp = l1["timestamps_GPS"][()]
        freq = f["frequency_Hz"][:]
        
        data_dict = {"sft" : np.stack([h1_stft[:, :4096], l1_stft[:, :4096]]), 
                 "timestamps": {"H1": h1_timestamp, 
                                    "L1": l1_timestamp} , 
                 "frequency": freq}
        
        return data_dict
    

def preprocess(sft):
    sft = sft * 1e22
    sft = sft.real**2 + sft.imag**2
    return sft

def h5_to_torch(fn):
    fn = Path(fn)
    data = read_data(fn)
    torch.save(data, str(fn).replace(fn.suffix, '.pth'))
    
#noise_fns = list(Path('../data/custom_data/DATA_V34/data/').glob('*.h5'))
#Parallel(n_jobs=16)(
#    delayed(h5_to_torch)(fn=i)
#    for i in tqdm(noise_fns)
#)

In [None]:
bad['id']