In [1]:
import torch
import ml4gw

import numpy as np
from ml4gw import gw
from ml4gw import transforms 
from ml4gw.spectral import fast_spectral_density, spectral_density
from ml4gw.transforms import SnrRescaler, Whiten
from ml4gw.transforms.transform import FittableSpectralTransform
from ml4gw.distributions import LogNormal, PowerLaw

from gasf_data.utils import h5_thang

%matplotlib inline

# Predefined functions

#### Sampling

In [2]:
def masking(
    glitch_info: dict,
    segment_duration: float,
    segment_start_time: float=0,
    shift_range: float = 3, 
    pad_width: float = 1.5, # Make this default to half of the kernel width
    sample_rate: int=4096, 
    merge_edges: bool=True
)->dict:
    """Build per-detector index masks covering each glitch and both segment edges.

    Args:
        glitch_info (dict): Maps a detector key to a NumPy array of glitch
            trigger times, in the same time frame as ``segment_start_time``.
            NOTE: every array is shifted by ``segment_start_time`` IN PLACE,
            so the caller's arrays are modified. Pass copies if the original
            times must survive the call.
        segment_duration (float): Duration of the background, in seconds.
        segment_start_time (float, optional): Start time of the background.
            Defaults to 0.
        shift_range (float, optional): Width, in seconds, of the window masked
            around each glitch (half on each side). Defaults to 3.
        pad_width (float, optional): Width, in seconds, masked at each end of
            the segment. Glitches closer than this to either end are dropped.
            Must be at least ``shift_range / 2``. Defaults to 1.5.
        sample_rate (int, optional): Sampling rate of the background, used to
            convert every time to a sample index. Defaults to 4096.
        merge_edges (bool, optional): If True, whenever a mask ends after the
            next mask starts, its end is pulled back to that start. This only
            compares neighbouring rows, so it presumably expects the glitch
            times to be sorted - TODO confirm with the data source.

    Returns:
        dict: For each detector key an integer array of shape
        ``(n_glitches + 2, 2)`` holding ``[start_idx, end_idx]`` rows. Row 0 is
        the leading edge pad, row -1 the trailing edge pad, and the rows in
        between are the glitch windows.

    Raises:
        AttributeError: If ``pad_width`` is smaller than ``shift_range / 2``.
    """

    if pad_width < shift_range/2:
        raise AttributeError(
            f"pad_width {pad_width} is shorter than half of the shift_range {shift_range/2}"
        )

    half_window = int(shift_range*sample_rate/2)
    pad_idx_count = pad_width*sample_rate
    seg_idx_count = segment_duration*sample_rate

    mask_kernel = {}
    for ifo, glitch_time in glitch_info.items():

        # Align the times so that the segment starts at t = 0 s.
        # Deliberately in place: later calls in this notebook reuse the
        # shifted arrays (see the docstring note).
        glitch_time -= segment_start_time

        # Drop glitches that live inside the padded edges.
        ### This popping may need another argument passing.
        inside = (glitch_time > pad_width) & (glitch_time < segment_duration - pad_width)

        # Convert with the caller's sample_rate (was a hard-coded 4096).
        glitch_idx = (glitch_time[inside] * sample_rate).astype("int")

        mask = np.zeros((len(glitch_idx) + 2, 2), dtype="int")

        # Edge pads: first and last row.
        mask[0, :] = (0, int(pad_idx_count))
        mask[-1, :] = (int(seg_idx_count - pad_idx_count), int(seg_idx_count))

        # Glitch windows: one row per retained glitch.
        mask[1:-1, 0] = glitch_idx - half_window
        mask[1:-1, 1] = glitch_idx + half_window

        mask_kernel[ifo] = mask

    if merge_edges:
        for mask in mask_kernel.values():
            for i in range(mask.shape[0] - 1):
                # Overlap with the next mask: end this one where the next begins.
                if mask[i, 1] > mask[i+1, 0]:
                    mask[i, 1] = mask[i+1, 0]

    return mask_kernel


def filtering_idxs(
    mask_dict: dict,
    *n_idxs: int,
    full: bool=False,
):
    """Collect the sample indices lying between consecutive masks.

    For every detector the gaps ``[mask[i, 1], mask[i + 1, 0])`` between
    neighbouring mask rows are concatenated into one index tensor. Either that
    whole tensor is returned, or a random draw (with replacement) from it.

    Args:
        mask_dict (dict): Maps a detector key to an indexable 2-D array of
            ``[start_idx, end_idx]`` rows, as produced by ``masking``.
        *n_idxs (int): Shape of the random draw, forwarded to
            ``torch.randint`` as its size. Ignored when ``full`` is True.
        full (bool, optional): If True, return every collected index instead
            of a random draw. Defaults to False.

    Returns:
        dict: Maps each detector key to a tensor of sample indices.
    """

    idx_dict = {}
    for ifo, mask in mask_dict.items():

        free_spans = []
        for i in range(len(mask) - 1):
            start = int(mask[i, 1])
            # Overlapping masks leave no free gap; clamp so that the span is
            # empty instead of making torch.arange raise on a reversed range.
            stop = max(start, int(mask[i+1, 0]))
            free_spans.append(torch.arange(start, stop))

        collected_idx = torch.cat(free_spans)

        if full:
            # Previously this result was overwritten by the random draw below.
            idx_dict[ifo] = collected_idx
            continue

        picks = torch.randint(0, len(collected_idx), n_idxs)
        idx_dict[ifo] = collected_idx[picks]

    return idx_dict


def strain_sampling(
    strain,
    mask: dict,
    sample_counts,
    sample_rate = 4096,
    kernel_width = 2,
):
    """Cut random windows out of ``strain`` that avoid the masked regions.

    Each detector's window centres are drawn from the unmasked indices of its
    own mask, and the windows are cut from that detector's row of ``strain``.
    (Before, every detector pass overwrote all channels, so only the last
    detector's mask decided where the windows were taken.)

    Args:
        strain: 2-D tensor indexed as ``strain[channel, sample]``.
            Assumes the rows are in the same order as the keys of ``mask`` -
            TODO confirm against the caller.
        mask (dict): Per-detector masks as produced by ``masking``.
        sample_counts: Number of windows to draw per detector.
        sample_rate: Sampling rate used to convert ``kernel_width`` to samples.
            Defaults to 4096.
        kernel_width: Window length in seconds. Defaults to 2.

    Returns:
        Tensor of shape ``(sample_counts, len(mask), 2 * half_width)`` where
        ``half_width = int(kernel_width * sample_rate / 2)``.
    """

    half_kernel_width_idx = int(kernel_width * sample_rate / 2)

    # Size the output from the same half-width that is used for slicing, so
    # that the assignment below always fits.
    sampled_strain = torch.zeros([sample_counts, len(mask), 2 * half_kernel_width_idx])

    # Consider moving this part out of the function
    sampling_idx = filtering_idxs(
        mask, 
        sample_counts,
    )

    for channel, idxs in enumerate(sampling_idx.values()):
        for i, idx in enumerate(idxs):

            center = int(idx)
            sampled_strain[i, channel, :] = strain[
                channel, center - half_kernel_width_idx:center + half_kernel_width_idx
            ]

    return sampled_strain


def glitch_sampler(
    glitch_info,
    strain,
    segment_duration,
    segment_start_time,
    ifos,
    sample_counts,
    sample_rate = 4096,
    shift_range = 0.9,
    kernel_width = 3,
):
    """Cut random windows out of ``strain`` that are centred near glitches.

    For every entry of ``ifos`` a glitch is picked at random for each sample,
    a window centre is drawn uniformly from that glitch's mask interval, and a
    window of ``kernel_width`` seconds around the centre is copied from the
    matching row of ``strain``.

    Args:
        glitch_info: Per-detector glitch times, forwarded to ``masking``.
        strain: 2-D tensor indexed as ``strain[channel, sample]``; row ``i``
            is used for ``ifos[i]``.
        segment_duration: Forwarded to ``masking``.
        segment_start_time: Forwarded to ``masking``.
        ifos: Keys of ``glitch_info`` to sample from, in channel order.
        sample_counts: Number of windows per detector.
        sample_rate: Sampling rate. Defaults to 4096.
        shift_range: Width in seconds of the interval around each glitch from
            which the window centre is drawn. Defaults to 0.9.
        kernel_width: Window length in seconds; also passed to ``masking`` as
            ``pad_width``. Defaults to 3.

    Returns:
        Tensor of shape ``(sample_counts, len(ifos), sample_rate * kernel_width)``.
    """

    half_span = int(kernel_width * sample_rate / 2)

    batch = torch.zeros([sample_counts, len(ifos), sample_rate*kernel_width])

    windows_by_ifo = masking(
        glitch_info,
        segment_duration=segment_duration,
        segment_start_time=segment_start_time,
        shift_range=shift_range,
        pad_width=kernel_width,
        sample_rate=sample_rate,
        merge_edges=False,
    )

    for channel, ifo in enumerate(ifos):

        # Keep only the glitch rows: the first and last rows are the edge pads.
        windows_by_ifo[ifo] = windows_by_ifo[ifo][1:-1]
        glitch_windows = windows_by_ifo[ifo]

        # First draw which glitch each sample uses, then where inside that
        # glitch's interval the window is centred.
        chosen = np.random.randint(0, len(glitch_windows), (sample_counts,))
        centers = np.random.randint(
            glitch_windows[chosen][:, 0],
            glitch_windows[chosen][:, 1],
            size=(sample_counts)
        )

        for row in range(sample_counts):
            lo = centers[row] - half_span
            hi = centers[row] + half_span
            batch[row, channel, :] = strain[channel, lo: hi]

    return batch

# Data Making

#### Global Variables

In [3]:
# --- Data layout ---
NUM_CHANNLES = 2
SAMPLE_RATE = 4096
KERNEL_WIDTH = 3
WAVEFORM_DURATION = 3  # previously tried: 8
HIGHPASS = 32

# --- Spectral estimation (passed on as fftlength / overlap) ---
FFTLENGTH = 2
OVERLAP = 1

# --- Data-set size: ITERATION * BATCH_SIZE samples are drawn per class ---
# previously tried: ITERATION = 20, BATCH_SIZE = 320
ITERATION = 10
BATCH_SIZE = 1163

# --- Target SNR distribution for the injections ---
MIN_SNR = 8
MAX_SNR = 50
ALPHA = 3
SNR_DISTRO = PowerLaw(MIN_SNR, MAX_SNR, ALPHA)

# --- Analysed GPS span ---
GPSSTARTTIME = 1262471488
GPSENDTIME = 1262507488
GPSTOTALTIME = GPSENDTIME - GPSSTARTTIME

# The background duration is the analysed GPS span
# (previously tried fixed values: 36000, 4096).
BACKGROUND_DURATION = GPSTOTALTIME

#### Some pseudo data

In [4]:
# strain = torch.randn(NUM_CHANNLES, SAMPLE_RATE*BACKGROUND_DURATION)
# signals = torch.randn(ITERATION*BATCH_SIZE, NUM_CHANNLES, SAMPLE_RATE*WAVEFORM_DURATION)

# glitch_info = {
#     'H1/time': np.sort(np.random.uniform(0+0.5, BACKGROUND_DURATION-0.5, 3046)), 
#     'L1/time': np.sort(np.random.uniform(0+0.5, BACKGROUND_DURATION-0.5, 3046))
# }


# signals.shape
# strain.shape


# glitch_idx_h1 = (glitch_info['H1/time']*SAMPLE_RATE).astype("int")
# glitch_idx_l1 = (glitch_info['L1/time']*SAMPLE_RATE).astype("int")
   
# GPSSTARTTIME = 0

#### Some Real data

In [5]:
# Filenames
# NOTE(review): absolute, user-specific location; edit DATA_DIR to relocate the data.
DATA_DIR = '/home/dfredin/gwgasf/data'

signalsFilename = f'{DATA_DIR}/BBH_project.h5'
glitchFilename = f'{DATA_DIR}/glitch_info.h5'

# Ten consecutive strain files per detector. The file names carry the GPS
# start of each file and its 4096 s length.
STRAIN_FILE_GPS_START = 1262469120
STRAIN_FILE_DURATION = 4096
STRAIN_FILE_COUNT = 10

strain_file_starts = [
    STRAIN_FILE_GPS_START + k * STRAIN_FILE_DURATION
    for k in range(STRAIN_FILE_COUNT)
]
H1strainFilenames = [
    f'{DATA_DIR}/H1/H-H1_GWOSC_O3b_4KHZ_R1-{gps}-{STRAIN_FILE_DURATION}.hdf5'
    for gps in strain_file_starts
]
L1strainFilenames = [
    f'{DATA_DIR}/L1/L-L1_GWOSC_O3b_4KHZ_R1-{gps}-{STRAIN_FILE_DURATION}.hdf5'
    for gps in strain_file_starts
]


def _load_strain(filenames):
    """Read 'strain/Strain' from every file and concatenate them in the given order."""
    return np.concatenate(
        [np.array(h5_thang(name).h5_data()['strain/Strain']) for name in filenames]
    )


# BBH Signals

signal_info = h5_thang(signalsFilename).h5_data()
signal_keys = list(signal_info.keys())

signal_bbh_H1 = signal_info['waveforms_H1']
signal_bbh_L1 = signal_info['waveforms_L1']

# Stack the two detectors along axis 1.
signals = torch.FloatTensor(np.stack([signal_bbh_H1, signal_bbh_L1], axis=1))
del signal_bbh_H1, signal_bbh_L1


# LIGO Strain
strain_keys = h5_thang(H1strainFilenames[0]).h5_keys()

# Row 0 is H1, row 1 is L1. The per-file arrays are temporaries, so nothing
# besides the stacked tensor stays alive after this statement.
strain = torch.FloatTensor(
    np.stack(
        [_load_strain(H1strainFilenames), _load_strain(L1strainFilenames)],
        axis=0,
    )
)


# Glitches
glitch_list = 'H1/time','L1/time'
glitch_info = h5_thang(glitchFilename).h5_data(glitch_list)
glitch_keys = h5_thang(glitchFilename).h5_keys()

glitchGPStime_h1 = np.array(glitch_info['H1/time'])
glitchGPStime_l1 = np.array(glitch_info['L1/time'])

# NOTE(review): the conversion factor below is SAMPLE_RATE**2 / GPSTOTALTIME
# samples per second rather than SAMPLE_RATE, so these indices do not follow
# the `time * 4096` convention used by masking(). It looks like a leftover
# from a 4096 s background - confirm before relying on the glitch positions.
# NOTE(review): the strain files start at STRAIN_FILE_GPS_START, not at
# GPSSTARTTIME - confirm which origin the strain indices should use.
glitch_idx_h1 = (((glitchGPStime_h1-GPSSTARTTIME)/GPSTOTALTIME)*SAMPLE_RATE*SAMPLE_RATE).astype("int")
# NOTE(review): only the first 3046 L1 entries are used - TODO confirm why.
glitch_idx_l1 = (((glitchGPStime_l1[0:3046]-GPSSTARTTIME)/GPSTOTALTIME)*SAMPLE_RATE*SAMPLE_RATE).astype("int")

del glitchGPStime_h1, glitchGPStime_l1


In [6]:
half_glitch_idx_span = 2048 ##???? half-width, in samples, of every added noise burst
glitch_amp = 3

# Same tensor object as `strain` (no copy is made): the bursts are added in place.
contaminated_strain = strain

# Channel 0 receives the H1 bursts, channel 1 the L1 bursts.
# NOTE(review): a pulse closer than half_glitch_idx_span to either end of the
# strain yields a slice shorter than the burst - presumably never happens
# with this data; verify.
for channel, pulse_idxs in enumerate((glitch_idx_h1, glitch_idx_l1)):
    for pulse in pulse_idxs:
        burst = glitch_amp*torch.randn((2*half_glitch_idx_span))
        contaminated_strain[channel, (pulse-half_glitch_idx_span):pulse+half_glitch_idx_span] += burst


del glitch_idx_h1, glitch_idx_l1, strain

#### Precalculate PSD

In [7]:
# This is served for more detailed control.
# `psds` is consumed by the whitening cell further down, so it has to be
# computed here: one PSD per channel, estimated from the contaminated strain.
spec_trans = FittableSpectralTransform()

# Number of frequency bins of a real FFT over one KERNEL_WIDTH-second kernel.
NUM_FREQS = int((SAMPLE_RATE*KERNEL_WIDTH)/2) + 1

psds = torch.empty([NUM_CHANNLES, NUM_FREQS])

for channel in range(NUM_CHANNLES):
    psds[channel, :] = spec_trans.normalize_psd(
        contaminated_strain[channel],
        sample_rate=SAMPLE_RATE,
        num_freqs=NUM_FREQS,
        fftlength=FFTLENGTH,
        overlap=OVERLAP,
    )

#### Sampling BG

In [8]:
# Background

# The loading cell treats the glitch times as GPS times, so the segment start
# has to be supplied here; with the default start of 0 every glitch falls
# outside [pad_width, segment_duration - pad_width] and nothing gets masked.
# masking() shifts the arrays it is given in place, so it receives copies and
# glitch_info itself is left untouched for the cells that follow.
mask_dict = masking(
    {ifo: np.array(times) for ifo, times in glitch_info.items()},
    segment_duration=BACKGROUND_DURATION,
    segment_start_time=GPSSTARTTIME,
    shift_range=KERNEL_WIDTH,
    pad_width=KERNEL_WIDTH/2,
    sample_rate=SAMPLE_RATE,
    merge_edges = True
)

background = strain_sampling(
    contaminated_strain,
    mask_dict,
    sample_counts=ITERATION*BATCH_SIZE,
    sample_rate=SAMPLE_RATE,
    kernel_width=KERNEL_WIDTH
)

#### Sampling Glitch

In [9]:
# Glitch data

# NOTE: glitch_sampler() forwards glitch_info to masking(), which subtracts
# segment_start_time from the time arrays in place.
glitch_sampler_kwargs = dict(
    glitch_info=glitch_info,
    strain=contaminated_strain,
    segment_duration=BACKGROUND_DURATION,
    segment_start_time=GPSSTARTTIME,
    ifos=['H1/time', 'L1/time'],
    sample_counts=ITERATION*BATCH_SIZE,
    sample_rate=SAMPLE_RATE,
    shift_range=0.9,
    kernel_width=3,
)

glitches = glitch_sampler(**glitch_sampler_kwargs)

#### Sampling Injections

In [10]:
# Injection

# NOTE(review): segment_start_time is left at its default of 0, so this call
# expects glitch_info to hold times relative to the segment start already.
# That is only the case because the earlier glitch_sampler() call subtracted
# GPSSTARTTIME from these arrays in place (through masking()); running the
# cells in a different order changes the resulting mask.
injection_mask_kwargs = dict(
    segment_duration=BACKGROUND_DURATION,
    shift_range=KERNEL_WIDTH,
    pad_width=KERNEL_WIDTH/2,
    merge_edges=True,
)

mask_dict = masking(glitch_info, **injection_mask_kwargs)

In [11]:
####
# Background windows onto which the rescaled signals get added.
n_injection_samples = ITERATION*BATCH_SIZE

sampled_bg = strain_sampling(
    contaminated_strain,
    mask_dict,
    n_injection_samples,
    kernel_width=KERNEL_WIDTH,
)

: 

In [None]:
# Initializing the SNR rescaler

rescaler_kwargs = dict(
    num_channels=NUM_CHANNLES,
    sample_rate=SAMPLE_RATE,
    waveform_duration=WAVEFORM_DURATION,
    highpass=HIGHPASS,
)

rescaler = SnrRescaler(**rescaler_kwargs)

In [None]:
#########
# Fit the rescaler on one background series per channel.
h1_background = contaminated_strain[0]
l1_background = contaminated_strain[1]

rescaler.fit(h1_background, l1_background, fftlength=FFTLENGTH, overlap=OVERLAP)

In [None]:
# Draw one target SNR per sample, then rescale the signals to those SNRs.
requested_snrs = SNR_DISTRO(BATCH_SIZE*ITERATION)

rescaled_signals, target_snrs = rescaler.forward(signals, target_snrs=requested_snrs)

In [None]:
# Injection = background window plus rescaled signal, element-wise.
injection = torch.add(sampled_bg, rescaled_signals)

In [None]:
# # Injection

# mask_dict = masking(
#     glitch_info,
#     segment_duration=BACKGROUND_DURATION,
#     shift_range=KERNEL_WIDTH,
#     pad_width=KERNEL_WIDTH/2,
#     merge_edges = True
# )

# sampled_bg = strain_sampling(
#     contaminated_strain,
#     mask_dict,
#     sample_counts=ITERATION*BATCH_SIZE,
#     kernel_width=KERNEL_WIDTH
# )

# # Initalizing some class

# rescaler = SnrRescaler(
#     num_channels=NUM_CHANNLES, 
#     sample_rate = SAMPLE_RATE,
#     waveform_duration = WAVEFORM_DURATION,
#     highpass = HIGHPASS,
# )

# rescaler.fit(
#     contaminated_strain[0, :],
#     contaminated_strain[1, :],
#     fftlength=FFTLENGTH,
#     overlap=OVERLAP
# )

# rescaled_signals, target_snrs = rescaler.forward(
#     signals,
#     target_snrs=SNR_DISTRO(BATCH_SIZE*ITERATION)
# )

# injection = sampled_bg + rescaled_signals

In [None]:
# Whitening transform, built positionally from FFTLENGTH, SAMPLE_RATE and HIGHPASS.
whiten_model = Whiten(FFTLENGTH, SAMPLE_RATE, HIGHPASS)

In [None]:
# Whiten the three data sets with the same PSDs, in the order
# background -> injection -> glitches.
# `psds` has to be provided by the PSD pre-calculation cell above.
background_data, injected_data, glitch_data = (
    whiten_model(batch, psds)
    for batch in (background, injection, glitches)
)



In [None]:
# Display the shapes of the three whitened data sets as the cell output.
background_data.shape, injected_data.shape, glitch_data.shape