In [None]:
from einops import rearrange
import copy
import h5py
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from pdb import set_trace
import matplotlib.pyplot as plt
from torch import nn
from x_transformers import  Encoder, Decoder
from x_transformers.autoregressive_wrapper import exists
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score
from fastai.vision.all import BCEWithLogitsLossFlat
from transformers.optimization import (
    get_linear_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
)
from fastprogress.fastprogress import master_bar, progress_bar
import os
from timm import create_model
import random
from joblib import Parallel, delayed
from tqdm import tqdm
from secrets import token_hex
from scipy.stats import binned_statistic

def get_random_name(len_k=16):
    """Return a random hexadecimal string generated from ``len_k`` random bytes.

    ``secrets.token_hex`` emits two hex digits per byte, so the result is
    ``2 * len_k`` characters long (32 characters with the default).
    """
    return token_hex(len_k)
from copy import copy

In [None]:
# File formats: DATA_V31, DATA_V32 and train are stored as h5 (HDF5) files;
# everything else is stored as .pth files.

In [None]:
    
def read_data(path):
    """Load one HDF5 sample into a flat dict of NumPy arrays.

    The file must contain a group named after the file stem, holding ``L1``
    and ``H1`` sub-groups that each provide ``SFTs`` and ``timestamps_GPS``.
    ``frequency_Hz`` is taken from the file root when it exists there and
    from the stem-named group otherwise.

    Args:
        path: Location of the HDF5 file (``str`` or ``pathlib.Path``).

    Returns:
        dict with the keys ``freq``, ``L1_SFTs_amplitudes``, ``L1_ts``,
        ``H1_SFTs_amplitudes`` and ``H1_ts``.

    Raises:
        KeyError: If an expected group or dataset is missing from the file.
    """
    path = Path(path)  # accept plain strings as well as Path objects
    data = {}
    with h5py.File(path, "r") as f:
        ID_key = path.stem
        # Retrieve the frequency data
        try:
            data['freq'] = np.array(f['frequency_Hz'])
        except KeyError:
            # Only a missing root-level dataset triggers the fallback; any
            # other failure (I/O error, interrupt, ...) is left to propagate.
            data['freq'] = np.array(f[ID_key]['frequency_Hz'])
        # Retrieve the Livingston detector data
        data['L1_SFTs_amplitudes'] = np.array(f[ID_key]['L1']['SFTs'])
        data['L1_ts'] = np.array(f[ID_key]['L1']['timestamps_GPS'])
        # Retrieve the Hanford detector data
        data['H1_SFTs_amplitudes'] = np.array(f[ID_key]['H1']['SFTs'])
        data['H1_ts'] = np.array(f[ID_key]['H1']['timestamps_GPS'])
    return data
    
    
def normalize(data, sz_t=128):
    """Turn the raw SFTs of both detectors into a time-binned, normalized power tensor.

    For each detector the squared magnitude of the SFTs (scaled by 1e22) is
    averaged into ``sz_t`` equal-width time bins. The bins span the window in
    which both detectors have timestamps; samples outside it are ignored and
    bins that receive no sample become 0. Each channel is then divided, per
    time bin, by its mean over axis 0, with the divisor floored at 0.1.

    Args:
        data: Mapping with ``H1_ts`` / ``L1_ts`` (1-D timestamps) and
            ``H1_SFTs_amplitudes`` / ``L1_SFTs_amplitudes`` (2-D arrays whose
            last axis lines up with the timestamps; rows are presumably
            frequency bins). Both detectors need the same number of rows.
        sz_t: Number of time bins in the output.

    Returns:
        ``torch.float32`` tensor of shape ``(2, n_rows, sz_t)``; channel 0 is
        H1 and channel 1 is L1.

    Raises:
        ValueError: If an SFT array is not 2-D.
    """
    detectors = ("H1", "L1")  # also fixes the output channel order
    time_ids = {det: np.asarray(data[f"{det}_ts"]) for det in detectors}

    # Time window covered by both detectors, shared by the two binnings.
    common_range = (
        max(time_ids["H1"].min(), time_ids["L1"].min()),
        min(time_ids["H1"].max(), time_ids["L1"].max()),
    )

    channels = []
    for det in detectors:
        power = np.abs(np.asarray(data[f"{det}_SFTs_amplitudes"]) * 1e22) ** 2
        if power.ndim != 2:
            raise ValueError(
                f"{det}_SFTs_amplitudes must be 2-D, got {power.ndim}-D input"
            )
        binned = binned_statistic(
            time_ids[det],
            power,
            statistic="mean",
            bins=sz_t,
            range=common_range,
        )
        # Empty bins come back as NaN; map them to 0.
        channels.append(np.nan_to_num(binned.statistic))

    x = torch.from_numpy(np.stack(channels, 0).astype(np.float32))
    for ch in range(len(detectors)):
        # Per-time-bin mean over axis 0, floored at 0.1 to avoid tiny divisors.
        x[ch] /= torch.max(x[ch].mean(0, keepdim=True), 0.1 * torch.ones_like(x[ch]))
    #x = torch.cat([x, 0.5 * (x[0] + x[1]).unsqueeze(0)], 0)
    return x
    
# Generating the validation set
class ValLoader(torch.utils.data.Dataset):
    """
    Dataset that reads the file referenced by each dataframe row and
    returns it as a normalized sample.

    dataset = ValLoader(df)

    sample, fname = dataset[i]
      sample["sft"]: output of ``normalize(read_data(row.id))``
      sample["target"] (np.float32): the row's ``target`` value
      fname (str): ``<stem of row.id>.pth``
    """

    def __init__(self, df, freq_tfms=False):
        # freq_tfms is only stored; nothing in this class reads self.tfms.
        self.df, self.tfms = df, freq_tfms

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        """
        i (int): positional index of the row to load
        """
        row = self.df.iloc[i]
        label = np.float32(row.target)
        sample = {"sft": normalize(read_data(row.id)), "target": label}
        return sample, f"{Path(row.id).stem}.pth"



#val_df = pd.read_csv('../data/SPLITS/V_22/val_df.csv')
#comp_train = pd.read_csv('../data/train_labels.csv')
#comp_train.columns = ['fn', 'target']
#comp_train = comp_train.query('target>=0')
#comp_train['fn'] = comp_train['fn'].apply(lambda x: Path('../data/train')/f'{x}.hdf5')
#comp_train.columns = ['id', 'target']
#comp_train['data_type'] = 'comp_train'
#real_noise_fns =  sorted(
#            Path("../data/custom_data/DATA_V33/data/").glob("*.pth"),
#            key=lambda x: str(x).split("_")[-2])
#real_noise_df = pd.DataFrame({"id": real_noise_fns[1100:], 'target': 0., 'snr': 0})
#real_noise_df['id'] = real_noise_df['id'].apply(lambda x: Path(str(x).replace('.pth', '.h5')))
#
#val_df = pd.concat([val_df, comp_train, real_noise_df], ignore_index=True)
#val_df['id'] = val_df['id'].apply(lambda x: Path(x))
#vld_ds = ValLoader(val_df)
#folder_name = Path('cashe_dataset_eval')
#os.makedirs(folder_name, exist_ok=True)
#for i in tqdm(range(len(vld_ds))):
#    data, name = vld_ds[i]
#    torch.save(data, folder_name/name)
#

In [None]:
def combine(sig_, noise, w=1.):
    """Overlay a weighted signal on a noise sample, detector by detector.

    For L1 and H1 separately, both SFT arrays are cut along axis 1 to the
    shorter of the two lengths and summed as ``w * signal + noise``; the
    signal's timestamps are truncated to the same length. All other entries
    are carried over from ``sig_``, and the noise timestamps are not used.

    Returns a shallow copy of ``sig_`` with the four entries replaced; the
    input dicts themselves are not modified.
    """
    mixed = copy(sig_)  # shallow copy: keys are rebound below, never mutated in place
    detector_keys = (
        ('L1_SFTs_amplitudes', 'L1_ts'),
        ('H1_SFTs_amplitudes', 'H1_ts'),
    )
    for amp_key, ts_key in detector_keys:
        n_cols = min(mixed[amp_key].shape[1], noise[amp_key].shape[1])
        mixed[amp_key] = w * mixed[amp_key][:, :n_cols] + noise[amp_key][:, :n_cols]
        mixed[ts_key] = mixed[ts_key][:n_cols]
    return mixed


def save_preoprocessed(save_folder, noise_fns, signal_fns):
    """Generate one random training sample and write it into ``save_folder``.

    A noise file and a signal file are drawn at random. When the uniform
    draw is below 0.66 the sample is positive: signal and noise are merged
    with ``combine``, normalized, and saved as
    ``<noise stem>_<signal stem>.pth`` with target 1.0. Otherwise the sample
    is noise only: the normalized noise is randomly augmented and saved under
    a random hexadecimal name with target 0.0.

    The saved object is ``{"sft": tensor, "target": float}``. The return
    value is whatever ``torch.save`` returns.

    NOTE(review): the inputs are read with ``torch.load`` using its default
    arguments; recent torch releases default to ``weights_only=True``, which
    may reject non-tensor payloads - confirm against the installed version.
    """
    # Both files are always drawn, even though negatives never use the signal.
    noise_fn = random.choice(noise_fns)
    signal_fn = random.choice(signal_fns)

    has_signal = np.random.rand() < 0.66
    if has_signal:
        mixed = combine(torch.load(signal_fn), torch.load(noise_fn))
        img = normalize(mixed)
        y = 1.0
        name = f"{noise_fn.stem}_{signal_fn.stem}.pth"
    else:
        arr = normalize(torch.load(noise_fn)).numpy()
        y = 0.0
        # Each augmentation below is applied independently with probability 0.5.
        if np.random.rand() <= 0.5:
            # flip along axis 1 (the rows of every channel)
            arr = np.flip(arr, axis=1).copy()
        if np.random.rand() <= 0.5:
            # flip along axis 2 (the columns of every channel)
            arr = np.flip(arr, axis=2).copy()
        if np.random.rand() <= 0.5:
            # circular shift along axis 1 by a random offset
            offset = np.random.randint(low=0, high=arr.shape[1])
            arr = np.roll(arr, offset, axis=1)
        if np.random.rand() <= 0.5:
            # random order of the two channels
            arr = arr[np.random.permutation([0, 1]), ...]
        img = torch.tensor(arr)
        name = get_random_name() + '.pth'

    return torch.save({"sft": img, "target": y}, save_folder / name)

# Signal-only samples, gathered from both SIGNAL_* folders.
signal = [
    *Path("../data/custom_data/SIGNAL_V0/data").glob("*.pth"),
    *Path("../data/custom_data/SIGNAL_V1/data").glob("*.pth"),
]


def _second_to_last_token(fn):
    """Sort key: the second-to-last ``_``-separated piece of the full path string."""
    return str(fn).split("_")[-2]


real_noise_fns = sorted(
    Path("../data/custom_data/DATA_V33/data/").glob("*.pth"),
    key=_second_to_last_token,
)

fake_noise_fns = sorted(
    Path("../data/custom_data/DATA_V34/data/").glob("*.pth"),
    key=_second_to_last_token,
)

noise_sim = list(Path("../data/custom_data/DATA_V31_V32_NOISE/").glob("*.pth"))

# Only the first 1100 sorted DATA_V33 files go into the noise pool.
# NOTE(review): the commented-out validation code in this notebook takes
# real_noise_fns[1100:], so this looks like a train/validation split - confirm.
n_real_noise_train = 1100
noise = [*real_noise_fns[:n_real_noise_train], *fake_noise_fns, *noise_sim]


len(signal), len(noise)

In [None]:
# Generate n_samples random samples into folder_name using 16 worker processes.
folder_name = Path('cashe_dataset')
n_samples = 20000

# Fail fast here rather than with an IndexError from random.choice inside a worker.
if not noise or not signal:
    raise RuntimeError("noise and signal file lists must both be non-empty")

os.makedirs(folder_name, exist_ok=True)
# Serial alternative, useful for debugging:
#for i in tqdm(range(1000)):
#    save_preoprocessed(folder_name, noise, signal)
#
# The result is bound to a name so that the cell does not echo a list with
# one None per generated sample as its output.
_save_results = Parallel(n_jobs=16)(
    delayed(save_preoprocessed)(save_folder=folder_name, noise_fns=noise, signal_fns=signal)
    for _sample_idx in tqdm(range(n_samples))
)

In [None]:
#def cehck_torch(x):
#    try:
#        torch.load(x)
#        return 0
#    except:
#        return x

In [None]:
#for i in tqdm(signal):
#    try:
#        torch.load(i)
#    except:
#        print(i)
#    

In [None]:
#out = Parallel(n_jobs=16)(
#    delayed(cehck_torch)(i)
#    for i in tqdm(signal)
#)