In [1]:
%load_ext autoreload
%autoreload 2

import os
# Select which GPUs this kernel may see. This runs in the first cell, before
# `torch` is imported in the next cell.
# NOTE(review): the "1,0" ordering presumably makes physical GPU 1 the default
# device (cuda:0) -- confirm against the machine's GPU layout.
os.environ["CUDA_VISIBLE_DEVICES"]="1,0"

In [2]:
import h5py
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import scipy.signal
import pandas as pd
from tqdm.notebook import tqdm
import gc
import pickle

In [3]:
#rm -r '/media/iafoss/New Volume/ML/G2Net2022/data/denoise_apr_H'

In [4]:
# --- Locations -------------------------------------------------------------
# PATH    : directory holding the per-sample test files (<id>.hdf5)
# PATH_EX : directory holding the external strain files that are scanned
# OUT     : directory receiving one pickle per matched sample id
PATH = '/media/iafoss/New Volume/ML/G2Net2022/data/test'
PATH_EX = '/media/iafoss/New Volume/ML/G2Net2022/data/external_apr_H'
OUT = '/media/iafoss/New Volume/ML/G2Net2022/data/denoise_apr_H'

# Key of the SFT array (as returned by extract_data_from_hdf5) to match against.
SOURCE = 'H1_SFTs_amplitudes'

# --- Numeric settings --------------------------------------------------------
T = 1800    # FFT window length is SR*T samples
SR = 4096   # sample rate; alternative value 16384 -- must match the SR of the external data
SZ = 360    # number of frequency bins compared per sample
TH = 10     # distance threshold below which an SFT counts as matched (alternative value: 1.5)

os.makedirs(OUT, exist_ok=True)

# External strain files, processed in sorted filename order.
files = [os.path.join(PATH_EX, name) for name in sorted(os.listdir(PATH_EX))]

# Only the samples not flagged in the `stationery` column are processed.
df = pd.read_csv('data/test_stationery.csv')
df = df[~df['stationery']]
df.head()

Unnamed: 0,id,stationery,freq
14,004f1b282,False,391.430278
17,006e25113,False,479.295278
23,008ec5560,False,317.4775
24,00948246a,False,227.008611
38,0112d6cc3,False,335.5375


In [5]:
def read_hdf5(fname):
    """Load the strain series and basic metadata from an HDF5 strain file.

    Reads the ``strain/Strain`` dataset together with its ``Xspacing``
    attribute, plus the ``meta/GPSstart`` and ``meta/Duration`` scalars.

    Args:
        fname: Path of the HDF5 file; it is opened read-only.

    Returns:
        dict with the keys
        ``'strain'``   -- full contents of ``strain/Strain``,
        ``'ts'``       -- the ``Xspacing`` attribute of that dataset
                          (presumably the sample spacing -- verify),
        ``'gpsStart'`` -- scalar read from ``meta/GPSstart``,
        ``'duration'`` -- scalar read from ``meta/Duration``,
        ``'has_nan'``  -- ``True`` when at least one strain sample is NaN.
    """
    with h5py.File(fname, 'r') as f:
        strain_ds = f['strain']['Strain']
        strain = strain_ds[:]
        ts = strain_ds.attrs['Xspacing']

        meta = f['meta']
        gpsStart = meta['GPSstart'][()]
        duration = meta['Duration'][()]
    # Reduce the NaN mask directly instead of copying every NaN sample out of
    # the array just to count them; bool() keeps a plain Python bool.
    has_nan = bool(np.isnan(strain).any())
    return {'strain':strain, 'ts':ts,
            'gpsStart':gpsStart, 'duration':duration, 'has_nan':has_nan}

def extract_data_from_hdf5(path):
    """Read the frequency axis and both detectors' SFT data from a sample file.

    The file's first top-level group is taken as the sample. From it the
    ``frequency_Hz`` dataset and, for each of the ``L1`` (Livingston) and
    ``H1`` (Hanford) sub-groups, the ``SFTs`` and ``timestamps_GPS`` datasets
    are loaded as NumPy arrays.

    Args:
        path: Path of the HDF5 sample file; it is opened read-only.

    Returns:
        dict with keys ``'freq'``, ``'L1_SFTs_amplitudes'``, ``'L1_ts'``,
        ``'H1_SFTs_amplitudes'`` and ``'H1_ts'``.
    """
    with h5py.File(path, "r") as f:
        # Everything lives under the first top-level group.
        sample = f[list(f.keys())[0]]
        data = {'freq': np.array(sample['frequency_Hz'])}
        # Livingston (L1) first, then Hanford (H1).
        for detector in ('L1', 'H1'):
            det_group = sample[detector]
            data[f'{detector}_SFTs_amplitudes'] = np.array(det_group['SFTs'])
            data[f'{detector}_ts'] = np.array(det_group['timestamps_GPS'])
    return data

class Model_FFT(nn.Module):
    """Sliding-window amplitude spectra of a 1-D strain series.

    A Tukey-windowed real FFT of length ``N`` is evaluated every ``sr``
    samples and only the bins inside ``bin_range`` are kept.

    Args:
        N: Window length in samples.
        sr: Sample rate; it is also the hop between consecutive windows and
            the divisor applied to the FFT output.
        bin_range: ``(first, last)`` rFFT bin indices to keep (``last`` is
            exclusive). The default ``(89500, 901500)`` covers roughly
            50-500 Hz when ``N / sr == 1800``; pass other indices when ``N``
            or ``sr`` differ from that ratio.

    Attributes:
        window: Tukey window (alpha=0.001) stored as a non-trainable parameter
            so that ``.cuda()`` / ``.to()`` move it with the module.
        range: ``[first, last]`` bin indices kept from every spectrum.
        freq: Frequencies of the kept bins, ``rfftfreq(N) * sr`` sliced to
            ``range``.
    """
    def __init__(self, N=SR*T, sr=SR, bin_range=(89500, 901500)):
        super().__init__()
        window = scipy.signal.windows.tukey(N, 0.001)
        self.window = nn.Parameter(torch.from_numpy(window),requires_grad=False)
        self.range = list(bin_range)
        self.freq = (np.fft.rfftfreq(N)*sr)[self.range[0]:self.range[1]]
        self.sr, self.N = sr, N

    def forward(self, x):
        """Compute the amplitude spectrum of every NaN-free window of ``x``.

        Args:
            x: Strain tensor; it is sliced along its first axis while its
                length is taken from the last axis, so a 1-D tensor is
                expected. It must live on the same device as ``self.window``.

        Returns:
            ``(spectra, shifts)``. ``spectra`` is a float32 CPU tensor with
            one row per kept window holding ``|rfft(x * window) / sr * 1e22|``
            restricted to ``self.range``. ``shifts`` is a LongTensor with each
            window's start offset divided by ``sr``. ``(None, None)`` is
            returned when no window qualifies (input shorter than ``N`` or
            every window contains NaN).
        """
        lo, hi = self.range
        ys,shifts = [],[]
        with torch.no_grad():
            # The stop is `len - N + 1` so that the last full window (the one
            # starting at `len - N`) is included; with `len - N` an input of
            # exactly N samples would yield no window at all.
            for i in range(0, x.shape[-1] - self.N + 1, self.sr):
                xi = x[i:i+self.N]
                # Windows touching a data gap are skipped entirely.
                if torch.isnan(xi).any(): continue
                y = torch.fft.rfft(xi*self.window)[lo:hi] / self.sr
                # Scale before abs/float so the float32 result is well conditioned.
                y = (y*1e22).abs().float().cpu()
                ys.append(y)
                shifts.append(i//self.sr)
        if ys:
            return torch.stack(ys,0), torch.tensor(shifts, dtype=torch.long)
        return None,None

In [None]:
# Scan every external strain file, compute its sliding-window spectra and, for
# each test sample, store the difference between every test SFT and the
# closest real spectrum whenever their distance is below TH.
fft_model = Model_FFT().cuda()
freq = fft_model.freq

data_prev = None
for fname in files:
    print(fname)
    data = torch.from_numpy(read_hdf5(fname)['strain'])
    if data_prev is not None:
        # Prepend the tail of the previous file so that windows spanning the
        # file boundary are evaluated as well.
        data = torch.cat([data_prev,data])
    # Only the last SR*T samples are ever reused; clone them so the full
    # concatenated tensor is not kept alive through a view.
    data_prev = data[-SR*T:].clone()
    stfts,shifts = fft_model(data.float().cuda())
    if stfts is None: continue

    for index, row in tqdm(df.iterrows(), total=len(df)):
        idx = row['id']
        data_src = extract_data_from_hdf5(os.path.join(PATH, idx+'.hdf5'))
        # First spectrum bin whose frequency is closest to the sample's first bin.
        freq_start = (np.abs(freq - data_src['freq'][0])).argmin()

        tgt = stfts[:,freq_start:freq_start+SZ]
        # Transposed so that rows index the sample's SFTs.
        src = torch.from_numpy(np.abs(data_src[SOURCE]*1e22)).permute(1,0)
        dists = torch.cdist(src.cuda(),tgt.cuda()).cpu()

        # Best matching spectrum for every test SFT; one mask is reused for
        # the match test, the subtraction and the output keys.
        values,indices = dists.min(-1)
        matched = values < TH
        if not matched.any(): continue

        out_path = os.path.join(OUT,idx+'.pickle')
        if os.path.isfile(out_path):
            # NOTE: pickle.load executes arbitrary code; OUT must only contain
            # files written by this notebook.
            with open(out_path, 'rb') as handle:
                denoised_data = pickle.load(handle)
        else:
            denoised_data = {}

        dif = src[matched] - tgt[indices[matched]]
        for dif_i, i in zip(dif,torch.where(matched)[0]):
            # Keyed by the SFT's position along the sample's time axis.
            denoised_data[i.item()] = dif_i.numpy()
        with open(out_path, 'wb') as handle:
            pickle.dump(denoised_data, handle, protocol=pickle.HIGHEST_PROTOCOL)
    del stfts
    gc.collect()

/media/iafoss/New Volume/ML/G2Net2022/data/external_apr_H/H-H1_GWOSC_O3a_4KHZ_R1-1238163456-4096.hdf5
/media/iafoss/New Volume/ML/G2Net2022/data/external_apr_H/H-H1_GWOSC_O3a_4KHZ_R1-1238167552-4096.hdf5
