In [None]:
import sys

sys.path.append("..")
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import time
import h5py
import timm
import torch
import torch.nn as nn
from sklearn.model_selection import StratifiedKFold
from timm import create_model
from sklearn.metrics import roc_auc_score
from tqdm import tqdm
from transformers.optimization import (
    get_linear_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
)
import torch.nn.functional as F
from pathlib import Path
from torch.utils.data import DataLoader
import os
from fastprogress.fastprogress import master_bar, progress_bar
from fastai.vision.all import L, unsqueeze
from timm.data.mixup import Mixup
from timm.loss import (
    LabelSmoothingCrossEntropy,
    BinaryCrossEntropy,
    SoftTargetCrossEntropy,
)

from einops import rearrange

In [None]:
class RollTimeFreq:
    """Random roll / channel-swap / flip augmentation for a 3-D array.

    The input is expected to unpack as ``(c, t, f)``; axis 1 and axis 2 are the
    ones the code calls ``t`` and ``f``.

    Args:
        shift_t_percent: upper bound of the roll along axis 1, as a fraction of
            that axis' length.
        shift_f_percent: upper bound of the roll along axis 2, as a fraction of
            that axis' length.
    """

    def __init__(self, shift_t_percent=0.3, shift_f_percent=0.3):
        self.shift_t_percent = shift_t_percent
        self.shift_f_percent = shift_f_percent

    def __call__(self, x):
        """Return an augmented copy of ``x``; the input array is not modified.

        Steps, each decided by an independent ``np.random.rand() > 0.5`` draw:
        roll along axis 1 *or* axis 2, swap the first two channels, flip
        axis 1, flip axis 2.
        """
        _, t, f = x.shape
        # randint requires low < high. int(percent * size) truncates to 0 for a
        # zero percentage or a very short axis, which used to raise ValueError;
        # clamping the bound to 1 makes that case a no-op roll (shift of 0).
        max_roll_t = max(1, int(self.shift_t_percent * t))
        max_roll_f = max(1, int(self.shift_f_percent * f))
        # Both shifts are always drawn (even though only one is applied) so the
        # sequence of random numbers consumed is the same on every call.
        roll_t = np.random.randint(0, max_roll_t)
        roll_f = np.random.randint(0, max_roll_f)
        if np.random.rand() > 0.5:
            x = np.roll(x, roll_t, 1)
        else:
            x = np.roll(x, roll_f, 2)

        if np.random.rand() > 0.5:
            # Swap channel 0 and channel 1; indexing with [1, 0] means the
            # result has exactly two channels.
            x = x[[1, 0], ...]
        if np.random.rand() > 0.5:
            x = np.flip(x, axis=1).copy()
        if np.random.rand() > 0.5:
            x = np.flip(x, axis=2).copy()
        return x

In [None]:
def get_snr(left, right, df):
    """Select rows whose ``snr`` lies strictly between ``left`` and ``right``,
    followed by every row whose ``snr`` equals 0.

    The two selections are concatenated in that order and keep their original
    index labels.
    """
    in_band = df.query(f"snr>{left} & snr<{right}")
    zero_snr = df.query("snr==0")
    return pd.concat([in_band, zero_snr])


def generate_report(df, p, fn):
    """Save out-of-fold predictions and compute ROC AUC overall and per SNR band.

    Args:
        df: DataFrame with a ``target`` column and an ``snr`` column (the latter
            is queried by ``get_snr``).
        p: torch tensor of model outputs; column 1 after a softmax over dim 1 is
            used as the prediction (presumably shape ``(len(df), 2)`` logits --
            confirm against the caller).
        fn: path prefix; the predictions are written to ``f'{fn}_oof.csv'``.

    Returns:
        dict mapping ``roc_all`` and ``roc_<left>_<right>`` to ROC AUC scores.
        Every band is built with ``get_snr``, so it also contains the rows with
        ``snr == 0``.
    """
    # dim is given explicitly: the implicit choice is deprecated and emits a
    # warning. For 2-D input the implicit rule also selected dim=1.
    # detach() allows tensors that still carry gradient history.
    pred = F.softmax(p.detach(), dim=1).cpu().numpy()[:, 1]
    val_df_eval = df.copy()
    val_df_eval["pred"] = pred
    val_df_eval.to_csv(f'{fn}_oof.csv')

    report = {
        "roc_all": roc_auc_score(val_df_eval["target"], val_df_eval["pred"]),
    }
    # Build each SNR subset once instead of once per column.
    for left, right in [(50, 100), (0, 50), (0, 40), (0, 30), (0, 20)]:
        band = get_snr(left, right, val_df_eval)
        report[f"roc_{left}_{right}"] = roc_auc_score(band["target"], band["pred"])
    return report

In [None]:

def scale_data(data):
    """Min-max scale ``data`` so that its minimum maps to 0 and its maximum to 1.

    Args:
        data: numeric array.

    Returns:
        ``(data - min) / (max - min)``. If every element is equal the range is
        zero; zeros are returned in that case instead of the NaNs produced by
        a 0/0 division.
    """
    lowest = np.min(data)
    span = np.max(data) - lowest
    # Dividing by the Python int 1 for a constant input keeps the result dtype
    # the same as in the regular path (float32 stays float32).
    return (data - lowest) / (span if span != 0 else 1)

class DataV0():
    """
    Map-style dataset that builds a two-channel image from an HDF5 file.

    dataset = DataV0(df, freq_tfms)

    img, y = dataset[i]
      img (np.float32): 2 x 360 x 256
      y (NumPy integer): ``target`` of row i cast to int

    df must provide an ``id`` column (path of the HDF5 file; its stem is used
    as the group name inside the file) and a ``target`` column.
    freq_tfms is an optional callable applied to ``img`` with probability 0.5.
    """
    def __init__(self, df, freq_tfms=False):
        self.df = df
        self.freq_tfms = freq_tfms

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        """
        i (int): positional row index into ``df``
        """
        row = self.df.iloc[i]
        label = np.float32(row.target)
        path = row.id
        group_name = Path(row.id).stem
        img = np.empty((2, 360, 256), dtype=np.float32)
        with h5py.File(path, 'r') as h5_file:
            group = h5_file[group_name]

            for ch, detector in enumerate(['H1', 'L1']):
                # First 4096 columns of the SFTs, rescaled by 1e22
                # (presumably complex64 Fourier coefficients -- confirm).
                coeffs = group[detector]['SFTs'][:, :4096] * 1e22

                # Magnitude of each complex coefficient.
                magnitude = np.sqrt(coeffs.real**2 + coeffs.imag**2)

                # Average every 16 adjacent columns: 4096 -> 256.
                # The reshape requires exactly 360 rows.
                pooled = np.mean(magnitude.reshape(360, 256, 16), axis=2)
                img[ch] = scale_data(pooled)

        # rand() is only drawn when a transform was supplied.
        if self.freq_tfms and np.random.rand() > 0.5:
            img = self.freq_tfms(img)

        return img, label.astype('int')
    
    
class DataV0Torch():
    """
    Map-style dataset that builds a two-channel image from a ``torch.load``-ed file.

    dataset = DataV0Torch(df, freq_tfms)

    img, y = dataset[i]
      img (np.float32): 2 x 360 x 256
      y (NumPy integer): ``target`` of row i cast to int

    df must provide an ``id`` column (path passed to ``torch.load``) and a
    ``target`` column. The loaded object is indexed with 'H1' and 'L1'.
    freq_tfms is an optional callable applied to ``img`` with probability 0.5.
    """
    def __init__(self, df, freq_tfms=False):
        self.df = df
        self.freq_tfms = freq_tfms

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        """
        i (int): positional row index into ``df``
        """
        row = self.df.iloc[i]
        label = np.float32(row.target)
        # NOTE(review): torch.load unpickles the file; only use it on files
        # produced by a trusted pipeline.
        payload = torch.load(row.id)
        img = np.empty((2, 360, 256), dtype=np.float32)
        for ch, detector in enumerate(['H1', 'L1']):
            # Presumably the stored values are squared magnitudes, so the
            # square root mirrors DataV0 -- confirm against the export script.
            root = np.sqrt(payload[detector])
            # Average every 16 adjacent columns: 4096 -> 256.
            pooled = np.mean(root.reshape(360, 256, 16), axis=2)
            img[ch] = scale_data(pooled)

        # rand() is only drawn when a transform was supplied.
        if self.freq_tfms and np.random.rand() > 0.5:
            img = self.freq_tfms(img)

        return img, label.astype('int')

In [None]:

# Load the training metadata and wrap it in the HDF5-backed dataset.
# No transform is passed, so freq_tfms keeps its default (False) and items
# are returned without augmentation.
df = pd.read_csv('../data/custom_data/DATA_V10/train.csv')
trn_ds = DataV0(df)
# NOTE(review): the loader is named vld_dl although it wraps trn_ds, which is
# built from train.csv.
vld_dl = DataLoader(
    trn_ds,
    batch_size=32,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)


In [None]:
# Walk through every batch once and discard it; tqdm reports the iteration
# speed (presumably a throughput check of the HDF5 loading path).
for x, y in tqdm(vld_dl):
    pass

In [None]:
# Point every id at the '.pth' file instead of the '.h5' one.
# NOTE(review): str.replace swaps every '.h5' occurrence in the path, not only
# the extension, and the column is overwritten on the existing DataFrame.
df['id'] = df['id'].apply(lambda x: x.replace('.h5', '.pth'))

In [None]:
# Same single pass over the data, this time through the torch.load-backed
# dataset (df ids were rewritten to '.pth' in the previous cell).
# NOTE(review): batch_size (64) and num_workers (2) differ from the HDF5 run
# (32 and 4), so the two tqdm rates are not directly comparable.
trn_ds = DataV0Torch(df)
vld_dl = DataLoader(
    trn_ds,
    batch_size=64,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)
for x, y in tqdm(vld_dl):
    pass