In [1]:
from einops import rearrange
import copy
import h5py
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from pdb import set_trace
import matplotlib.pyplot as plt
from torch import nn
from x_transformers import  Encoder, Decoder
from x_transformers.autoregressive_wrapper import exists
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score
from fastai.vision.all import BCEWithLogitsLossFlat
from transformers.optimization import (
    get_linear_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
)
from fastprogress.fastprogress import master_bar, progress_bar
import os
from timm import create_model
import random
from joblib import Parallel, delayed
from tqdm import tqdm
from secrets import token_hex

def get_random_name(len_k=16):
    """Return a random lowercase hex string made from ``len_k`` random bytes.

    The result has ``2 * len_k`` characters (32 by default) and is used as a
    collision-resistant file name stem.
    """
    return token_hex(len_k)




In [None]:
def read_torch(fn):
    """Load a ``torch.save``-d record and stack its H1/L1 SFT amplitudes.

    The record must contain ``'H1_SFTs_amplitudes'`` and
    ``'L1_SFTs_amplitudes'``. Each is cropped to its first 4096 columns
    (``[:, :4096]``), and the two are stacked along a new leading axis
    (index 0 = H1, index 1 = L1).

    Returns:
        dict with a single key ``"sft"`` holding the stacked numpy array.
    """
    record = torch.load(fn)
    cropped = [
        record[key][:, :4096]
        for key in ('H1_SFTs_amplitudes', 'L1_SFTs_amplitudes')
    ]
    return {"sft": np.stack(cropped)}
                 
    
def preprocess(sft):
    """Scale raw SFT amplitudes by 1e22 and return their squared magnitude.

    ``sft`` is expected to be complex-valued (a real input simply has a zero
    imaginary part). The result is ``re(x)**2 + im(x)**2`` with ``x = sft * 1e22``.
    """
    scaled = sft * 1e22
    return scaled.real ** 2 + scaled.imag ** 2


def normalize(data, time_bin=32):
    """Per-channel normalise, bin the last axis, then standardise globally.

    Steps:
      1. divide each channel ``data[i]`` by its own mean;
      2. average non-overlapping groups of ``time_bin`` columns along the last
         axis (4096 -> 128 with the default ``time_bin=32``);
      3. subtract the global mean and divide by the global (ddof=0) std.

    The input is left unmodified; a new array is built for step 1.

    Args:
        data: array-like of shape (channels, rows, cols), e.g. the (2, 360, 4096)
            output of ``preprocess``.
        time_bin: number of adjacent columns averaged into one output column.

    Returns:
        torch.Tensor of shape (channels, rows, cols // time_bin).

    Raises:
        ValueError: if ``cols`` is not divisible by ``time_bin``.
    """
    # New array instead of assigning into data[i], so the caller's buffer survives.
    per_channel = np.stack([channel / channel.mean() for channel in data])
    n_channels, n_rows, n_cols = per_channel.shape
    if n_cols % time_bin:
        raise ValueError(f"last axis ({n_cols}) is not divisible by time_bin={time_bin}")
    binned = per_channel.reshape(n_channels, n_rows, n_cols // time_bin, time_bin).mean(-1)
    centered = binned - binned.mean()
    return torch.tensor(centered / centered.std())


def read_h5(file):
    """Read H1/L1 SFTs and GPS timestamps from an HDF5 file.

    The file must contain a top-level group named after the file stem. That
    group holds ``H1`` and ``L1`` subgroups, each with ``SFTs`` and
    ``timestamps_GPS`` datasets.

    Returns:
        dict with
          ``"sft"``: both SFT arrays cropped to ``[:, :4096]`` and stacked
            (index 0 = H1, index 1 = L1);
          ``"timestamps"``: ``{"H1": ..., "L1": ...}`` raw timestamp arrays.
    """
    path = Path(file)
    sfts, timestamps = {}, {}
    with h5py.File(path, "r") as handle:
        root = handle[path.stem]
        for detector in ("H1", "L1"):
            node = root[detector]
            sfts[detector] = node["SFTs"][()]
            timestamps[detector] = node["timestamps_GPS"][()]
    # The datasets were read fully into memory above, so the file can be closed first.
    return {
        "sft": np.stack([sfts["H1"][:, :4096], sfts["L1"][:, :4096]]),
        "timestamps": timestamps,
    }
# Build and cache the validation set into `cashe_dataset_eval` (disabled; uncomment to regenerate).
#class ValLoader(torch.utils.data.Dataset):
#    """
#    dataset = Dataset(data_type, df)
#
#    img, y = dataset[i]
#      img (np.float32): 2 x 360 x 128
#      y (np.float32): label 0 or 1
#    """
#    def __init__(self, df, freq_tfms=False):
#        self.df = df
#        self.tfms = freq_tfms
#        
#
#    def __len__(self):
#        return len(self.df)
#
#    def __getitem__(self, i):
#        """
#        i (int): get ith data
#        """
#        r = self.df.iloc[i]
#        y = np.float32(r.target)
#        img = normalize(preprocess(read_h5(r.id)['sft']))
#        data = {"sft": img, "target": y}
#        return data, f"{Path(r.id).stem}.pth"
#
#
#
#val_df = pd.read_csv('../data/SPLITS/V_22/val_df.csv')
#comp_train = pd.read_csv('../data/train_labels.csv')
#comp_train.columns = ['fn', 'target']
#comp_train = comp_train.query('target>=0')
#comp_train['fn'] = comp_train['fn'].apply(lambda x: Path('../data/train')/f'{x}.hdf5')
#comp_train.columns = ['id', 'target']
#comp_train['data_type'] = 'comp_train'
#real_noise_fns =  sorted(
#            Path("../data/custom_data/DATA_V33/data/").glob("*.pth"),
#            key=lambda x: str(x).split("_")[-2])
#real_noise_df = pd.DataFrame({"id": real_noise_fns[1100:], 'target': 0., 'snr': 0})
#real_noise_df['id'] = real_noise_df['id'].apply(lambda x: Path(str(x).replace('.pth', '.h5')))
#
#val_df = pd.concat([val_df, comp_train, real_noise_df], ignore_index=True)
#vld_ds = ValLoader(val_df)
#folder_name = Path('cashe_dataset_eval')
#os.makedirs(folder_name, exist_ok=True)
#for i in tqdm(range(len(vld_ds))):
#    data, name = vld_ds[i]
#    torch.save(data, folder_name/name)
#

In [None]:
def _augment_noise(img):
    """Randomly flip, circularly shift and channel-swap a (2, rows, cols) numpy image.

    Each augmentation fires independently with probability 0.5. The random
    draws happen in this fixed order: flip axis 1, flip axis 2, roll axis 1
    (plus its shift draw), channel permutation.
    """
    if np.random.rand() <= 0.5:  # flip along axis 1 (the 360-row axis)
        img = np.flip(img, axis=1).copy()
    if np.random.rand() <= 0.5:  # flip along axis 2 (the binned 128-column axis)
        img = np.flip(img, axis=2).copy()
    if np.random.rand() <= 0.5:  # circular shift along axis 1 by a random offset
        img = np.roll(img, np.random.randint(low=0, high=img.shape[1]), axis=1)
    if np.random.rand() <= 0.5:  # randomly swap the H1/L1 channels
        img = img[np.random.permutation([0, 1]), ...]
    return img


def save_preoprocessed(save_folder, noise_fns, signal_fns, signal_prob=0.6):
    """Generate one random training sample and write it with ``torch.save``.

    With probability ``1 - signal_prob`` the sample is noise only (label 0.0).
    It is preprocessed, normalised, randomly augmented and saved under a random
    hex name. Otherwise the noise and signal SFTs are summed before
    preprocessing (label 1.0), and the sample is saved as
    ``<noise stem>_<signal stem>.pth``.

    NOTE(review): augmentations are applied to negatives only, so they are
    correlated with the label (e.g. a rolled seam could become a "noise" cue);
    confirm this is intended.
    NOTE(review): positive file names are deterministic, so drawing the same
    (noise, signal) pair twice rewrites the same file with identical content,
    yielding fewer files than requested.

    Args:
        save_folder: output directory (``str`` or ``Path``).
        noise_fns: sequence of ``Path`` objects readable by ``read_torch``.
        signal_fns: sequence of ``Path`` objects readable by ``read_torch``.
        signal_prob: probability of producing a positive sample. The default
            of 0.6 matches the former hard-coded ``rand() >= 0.6`` noise threshold.

    Returns:
        None (``torch.save`` returns nothing).
    """
    noise_fn = random.choice(noise_fns)
    signal_fn = random.choice(signal_fns)

    if np.random.rand() >= signal_prob:
        noise_img = normalize(preprocess(read_torch(noise_fn)["sft"])).numpy()
        img = torch.tensor(_augment_noise(noise_img))
        y = 0.0
        name = get_random_name() + '.pth'
    else:
        # Complex amplitudes are summed *before* taking the power in preprocess.
        img = normalize(
            preprocess(read_torch(noise_fn)["sft"] + read_torch(signal_fn)["sft"])
        )
        y = 1.0
        name = noise_fn.stem + '_' + signal_fn.stem + '.pth'

    torch.save({"sft": img, "target": y}, Path(save_folder) / name)

def _sorted_pth_files(folder):
    """All ``*.pth`` files in ``folder``, sorted by ``str(path).split("_")[-2]``."""
    return sorted(Path(folder).glob("*.pth"), key=lambda p: str(p).split("_")[-2])


# Signal injections from both SIGNAL_V0 and SIGNAL_V1.
signal = [
    *Path("../data/custom_data/SIGNAL_V0/data").glob("*.pth"),
    *Path("../data/custom_data/SIGNAL_V1/data").glob("*.pth"),
]

# NOTE(review): the sort key comes from the full path string, not only the file
# name; the [:1100] split below relies on this order being stable — confirm.
real_noise_fns = _sorted_pth_files("../data/custom_data/DATA_V33/data/")
fake_noise_fns = _sorted_pth_files("../data/custom_data/DATA_V34/data/")
noise_sim = [*Path("../data/custom_data/DATA_V31_V32_NOISE/").glob("*.pth")]

# First 1100 real-noise files, all fake noise, and 2000 draws (with replacement,
# via the global numpy RNG) from the simulated-noise pool.
noise = real_noise_fns[:1100] + fake_noise_fns + list(np.random.choice(noise_sim, 2000))

len(signal), len(noise)

In [None]:
# Write `n_samples` random samples into `folder_name` with `n_jobs` worker processes.
# Uses whichever `save_preoprocessed` is currently defined; in a top-to-bottom run,
# that is the first (augmenting) version. Set n_jobs=1 to run serially for debugging.
folder_name = Path('cashe_dataset')
n_samples = 15000
n_jobs = 32
os.makedirs(folder_name, exist_ok=True)
# tqdm counts tasks as joblib pulls them from the generator (dispatch), which can
# run slightly ahead of completion. The trailing `;` stops Jupyter from displaying
# the returned list of `n_samples` Nones.
Parallel(n_jobs=n_jobs)(
    delayed(save_preoprocessed)(save_folder=folder_name, noise_fns=noise, signal_fns=signal)
    for _ in tqdm(range(n_samples))
);

In [None]:
#signal = list(Path("../data/custom_data/SIGNAL_V0/data").glob("*.pth"))
#
#real_noise_fns = sorted(
#    Path("../data/custom_data/DATA_V33/data/").glob("*.pth"),
#    key=lambda x: str(x).split("_")[-2],
#)
#
#fake_noise_fns = sorted(
#    Path("../data/custom_data/DATA_V34/data/").glob("*.pth"),
#    key=lambda x: str(x).split("_")[-2],
#)
#
#noise = list(Path('../data/custom_data/DATA_V31_V32_NOISE').glob('*.pth'))
#
#
#len(signal), len(noise)

In [None]:
#folder_name = Path('cashe_dataset')
#n_samples = 20000
#os.makedirs(folder_name, exist_ok=True)
#Parallel(n_jobs=16)(
#    delayed(save_preoprocessed)(save_folder=folder_name, noise_fns=noise, signal_fns=signal)
#    for i in tqdm(range(n_samples))
#)

In [None]:
def save_preoprocessed(save_folder, noise_fns, signal_fns, signal_prob=0.5, mix_alpha=0.5, mix_beta=0.5):
    """Generate one *mixed* training sample and write it with ``torch.save``.

    This re-defines the ``save_preoprocessed`` from earlier in the notebook.
    Cells executed after this one use this version.

    A weight ``prob ~ Beta(mix_alpha, mix_beta)`` is drawn, then:
      * with probability ``1 - signal_prob``: label 0.0, the ``"sft"`` of a
        random second noise file and of ``noise_fn`` blended as
        ``prob * other + (1 - prob) * noise_fn``;
      * otherwise: label 1.0, ``prob * noise_fn + (1 - prob) * signal_fn``.
    The blend is preprocessed and normalised. File names start with
    ``mix_{prob}_{1-prob}_``.

    NOTE(review): files are read here with ``torch.load(fn)["sft"]`` and no
    crop. The other ``save_preoprocessed`` definition reads the same
    directories through ``read_torch`` (``H1_SFTs_amplitudes`` /
    ``L1_SFTs_amplitudes`` keys, ``[:, :4096]`` crop). Confirm these files
    carry an ``"sft"`` entry of the shape ``normalize`` expects.
    NOTE(review): the positive label stays 1.0 even when ``prob`` is close to 1
    and the signal is scaled almost to zero. Beta(0.5, 0.5) puts much of its
    mass near 0 and 1, so this case is common.

    Args:
        save_folder: output directory (``str`` or ``Path``).
        noise_fns, signal_fns: sequences of ``Path`` objects.
        signal_prob: probability of a positive sample (0.5 = original behaviour).
        mix_alpha, mix_beta: Beta-distribution parameters for ``prob``.

    Returns:
        None (``torch.save`` returns nothing).
    """
    noise_fn = random.choice(noise_fns)
    signal_fn = random.choice(signal_fns)
    prob = np.random.beta(mix_alpha, mix_beta)
    mix_tag = f'mix_{prob}_{1-prob}_'

    if np.random.rand() >= signal_prob:
        # The second noise file may coincide with noise_fn (sampling with replacement).
        noise_2 = prob * torch.load(random.choice(noise_fns))['sft']
        noise_1 = (1 - prob) * torch.load(noise_fn)["sft"]
        img = normalize(preprocess(noise_2 + noise_1))
        y = 0.0
        name = mix_tag + noise_fn.stem + '.pth'
    else:
        mixed = prob * torch.load(noise_fn)["sft"] + (1 - prob) * torch.load(signal_fn)["sft"]
        img = normalize(preprocess(mixed))
        y = 1.0
        name = mix_tag + noise_fn.stem + '_' + signal_fn.stem + '.pth'

    torch.save({"sft": img, "target": y}, Path(save_folder) / name)

In [None]:
# How many samples are currently cached in the output folder?
fns = [*Path('cashe_dataset').glob('*.pth')]
len(fns)

In [None]:
# Second sample pool, used with the mixing save_preoprocessed. It takes only
# SIGNAL_V0, and the *whole* DATA_V31_V32_NOISE directory instead of 2000 random draws.
def _split_token(path):
    """Sort key: second-to-last '_'-separated token of the full path string."""
    return str(path).split("_")[-2]


signal = [*Path("../data/custom_data/SIGNAL_V0/data").glob("*.pth")]

real_noise_fns = sorted(Path("../data/custom_data/DATA_V33/data/").glob("*.pth"), key=_split_token)
fake_noise_fns = sorted(Path("../data/custom_data/DATA_V34/data/").glob("*.pth"), key=_split_token)

noise = [
    *real_noise_fns[:1100],
    *fake_noise_fns,
    *Path('../data/custom_data/DATA_V31_V32_NOISE').glob('*.pth'),
]
len(signal), len(noise)


In [None]:
# Write `n_samples` mixed samples into the same cache folder with `n_jobs` workers.
# Uses whichever `save_preoprocessed` is currently defined; in a top-to-bottom run,
# that is the mixing version defined above. Set n_jobs=1 to run serially.
folder_name = Path('cashe_dataset')
n_samples = 25000
n_jobs = 16
os.makedirs(folder_name, exist_ok=True)
# tqdm tracks dispatch rather than completion. The trailing `;` hides the
# returned list of `n_samples` Nones.
Parallel(n_jobs=n_jobs)(
    delayed(save_preoprocessed)(save_folder=folder_name, noise_fns=noise, signal_fns=signal)
    for _ in tqdm(range(n_samples))
);

In [None]:
def _drop_unloadable(paths):
    """Delete every file in ``paths`` that ``torch.load`` fails to read.

    Only ordinary ``Exception``s mark a file as corrupt. KeyboardInterrupt and
    SystemExit propagate, so interrupting the scan never deletes a file that was
    being loaded at that moment.

    NOTE(review): on torch >= 2.6, ``torch.load`` defaults to
    ``weights_only=True``. A valid file holding non-allowlisted objects would
    fail to load and be deleted here.

    Returns:
        list of the paths that were removed.
    """
    removed = []
    for path in paths:
        try:
            torch.load(path)
        except Exception as err:
            print(f"removing {path}: {err!r}")
            os.remove(path)
            removed.append(path)
    return removed


fns = list(Path('cashe_dataset').glob('*.pth'))
corrupt_fns = _drop_unloadable(tqdm(fns))
len(corrupt_fns)

 43%|████████████████████████████████▏                                          | 32118/74981 [00:43<00:57, 740.03it/s]