# ML-ready data generation
We read pressure data from historical balloon flights, pre-process it, and store it as fixed-length waveform snippets in .h5 files.

In [None]:
%matplotlib notebook
import pandas as pd
import numpy as np
from obspy.core.utcdatetime import UTCDateTime
import importlib
import matplotlib.pyplot as plt
import obspy
import os
import netCDF4 as nc
from tqdm import tqdm

## Load data

### Ridgecrest

In [None]:
def create_one_obspy_trace(times, amp, balloon, starttime, dt, name_maps):
    """Wrap one balloon time series in an ``obspy.Trace``.

    Parameters
    ----------
    times : sequence
        Sample times; only ``times[0]`` is used, as an offset added to
        ``starttime``.
    amp : array-like
        Samples stored as the trace data.
    balloon : str
        Key into ``name_maps`` selecting the station code.
    starttime : UTCDateTime
        Reference time to which ``times[0]`` is added.
    dt : float
        Sampling interval assigned to ``stats.delta``.
    name_maps : dict
        Maps balloon names to station codes.

    Returns
    -------
    obspy.Trace
    """
    trace = obspy.Trace()
    trace.data = amp

    header = trace.stats
    header.delta = dt
    header.station = name_maps[balloon]
    header.starttime = starttime + times[0]
    return trace

def load_balloon_data(dir_data, starttimes, name_maps):
    """Load every balloon CSV found under ``dir_data`` into obspy streams.

    Files are expected to be named ``<balloon>_<type>.csv`` where ``<type>``
    is ``GPS`` or ``Baro``. Files whose balloon name is not a key of
    ``starttimes`` are ignored.

    Parameters
    ----------
    dir_data : str
        Root directory, walked recursively.
    starttimes : dict
        Maps balloon name to the reference ``UTCDateTime`` that the
        ``GPSTime(s)`` column is added to.
    name_maps : dict
        Maps balloon name to the station code stored on the trace.

    Returns
    -------
    dict
        ``{'GPS': obspy.Stream, 'Baro': obspy.Stream}``.
    """
    datas = {'GPS': obspy.Stream(), 'Baro': obspy.Stream()}
    for subdir, _dirs, files in os.walk(dir_data):
        for file in files:
            if '.csv' not in file:
                continue

            balloon = file.split('_')[0]
            if balloon not in starttimes:
                continue
            print(balloon)

            type_data = file.split('_')[1].split('.')[0]
            if type_data not in datas:
                # Previously a KeyError when appending to the stream.
                print(f'Skipping {file}: unknown data type "{type_data}"')
                continue

            data = pd.read_csv(os.path.join(subdir, file), header=[0])

            times = data['GPSTime(s)'].values
            if times.size < 2:
                # The sampling interval needs at least two samples.
                print(f'Skipping {file}: fewer than two samples')
                continue
            dt = times[1] - times[0]

            # Use the altitude column when it exists and is numeric;
            # otherwise fall back to the last column of the file.
            try:
                amp = data['WGS84Altitude(m)'].astype(float).values
            except (KeyError, ValueError, TypeError):
                amp = data[data.columns[-1]].values

            tr_data = create_one_obspy_trace(times, amp, balloon, starttimes[balloon], dt, name_maps)
            datas[type_data] += tr_data

    return datas

# Reference date of each Ridgecrest flight; the CSV time column is added to it.
starttimes = dict(
    Hare=UTCDateTime(2019, 7, 22),
    Tortoise=UTCDateTime(2019, 7, 22),
    Hare2=UTCDateTime(2019, 8, 9),
    CrazyCatLower=UTCDateTime(2019, 8, 9),
    CrazyCatUpper=UTCDateTime(2019, 8, 9),
)

# Short station codes stored on the traces.
name_maps = dict(
    Hare='hare',
    Tortoise='tort',
    Hare2='hare2',
    CrazyCatLower='CraLo',
    CrazyCatUpper='CraUp',
)

dir_data = '/projects/infrasound/data/infrasound/2023_ML_balloon/data/Siddharth_balloon/'
st_crazycat = load_balloon_data(
    dir_data,
    starttimes,
    name_maps,
)

In [None]:
# Save the Ridgecrest barometer and GPS streams as MiniSEED files in the data folder.
folder = '/projects/infrasound/data/infrasound/2023_ML_balloon/data/Siddharth_balloon/'
st_crazycat['Baro'].write(f"{folder}st_all.mseed", format="MSEED")
st_crazycat['GPS'].write(f"{folder}st_all_gps.mseed", format="MSEED")

### Strateole

In [None]:
from tqdm import tqdm

folder = '/projects/infrasound/data/infrasound/2023_ML_balloon/data/Strateole2/'

# Pressure ("P1s") files. The matching location file has the same name
# without the 'P1s_' token.
files = [
    f'{folder}ST2_C0_01_STR1_TSEN_P1s_v03b.nc',
    f'{folder}ST2_C1_04_STR2_TSEN_P1s_v01.nc',
    f'{folder}ST2_C1_03_TTL4_TSEN_P1s_v01.nc',
    f'{folder}ST2_C1_01_TTL5_TSEN_P1s_v01.nc',
    f'{folder}ST2_C1_16_TTL5_TSEN_P1s_v01.nc',
    f'{folder}ST2_C1_17_TTL3_TSEN_P1s_v01.nc',
]

st_strateole = obspy.Stream()
st_strateole_gps = obspy.Stream()
for file in tqdm(files):
    file_location = file.replace('P1s_', '')
    name_parts = file_location.split('/')[-1].split('_')
    # Station code is 'b' + flight number + campaign, e.g. 'b01C0'.
    station = f"b{name_parts[2]}{name_parts[1]}"

    # Read everything needed while the files are open, then close them.
    with nc.Dataset(file) as dataset, nc.Dataset(file_location) as dataset_location:
        # Both traces take their start time from the location file.
        # NOTE(review): the first value of the pressure 'time' variable is not
        # added to the start time -- confirm it starts at date_start.
        starttime = UTCDateTime(dataset_location.date_start)
        time_pressure = dataset.variables['time'][:].filled()
        pressure = dataset.variables['pressure'][:].filled()
        time_gps = dataset_location.variables['time'][:].filled()
        alt = dataset_location.variables['alt'][:].filled()

    # Pressure trace; sampling interval taken from the first two time stamps.
    dt = np.diff(time_pressure)[0]
    tr = obspy.Trace()
    tr.data = pressure
    tr.stats.station = station
    tr.stats.delta = dt
    tr.stats.starttime = starttime
    st_strateole += tr

    # Altitude trace built from the location file.
    dt = np.diff(time_gps)[0]
    print(file_location)
    tr = obspy.Trace()
    tr.data = alt
    tr.stats.station = station
    tr.stats.delta = dt
    tr.stats.starttime = starttime
    st_strateole_gps += tr

In [None]:
# Save the Strateole pressure and altitude streams as MiniSEED files.
folder = '/projects/infrasound/data/infrasound/2023_ML_balloon/data/Strateole2/'
st_strateole.write(f"{folder}st_all.mseed", format="MSEED")
st_strateole_gps.write(f"{folder}st_all_gps.mseed", format="MSEED")

In [None]:
# Usable time window per Strateole balloon, pasted from the templates printed
# by the loop below.
# NOTE(review): these values were pasted from an earlier version of the loop
# that added the ref_time-relative offset to each trace's own start time, so
# the offset was counted twice for every trace but the first. The 2023 dates
# of the C1 balloons are probably affected -- regenerate them from the
# corrected printout below and confirm.
strateole=dict(
    b01C0 = [UTCDateTime("2019-11-13T01:11:48.000000Z"), UTCDateTime("2020-02-27T18:13:28.000000Z")],
    b04C1 = [UTCDateTime("2023-09-29T05:34:34.000000Z"), UTCDateTime("2023-10-28T15:30:44.000000Z")],
    b03C1 = [UTCDateTime("2023-09-28T23:47:58.000000Z"), UTCDateTime("2023-10-29T22:30:08.000000Z")],
    b01C1 = [UTCDateTime("2023-09-27T00:41:36.000000Z"), UTCDateTime("2023-09-28T01:37:16.000000Z")],
    b16C1 = [UTCDateTime("2023-12-03T21:53:48.000000Z"), UTCDateTime("2023-12-30T01:36:28.000000Z")],
    b17C1 = [UTCDateTime("2023-12-10T01:25:44.000000Z"), UTCDateTime("2023-12-30T07:23:24.000000Z")]
)

plt.figure()
# The x axis is seconds since the start of the first altitude trace.
ref_time = st_strateole_gps[0].stats.starttime
for tr_gps in st_strateole_gps:
    offset = tr_gps.stats.starttime-ref_time
    plt.plot(tr_gps.times()+offset, tr_gps.data, label=tr_gps.stats.station)
    # Shade the trace without its first and last 1e4 s. Both offsets are
    # measured from ref_time.
    offset_start, offset_end = tr_gps.stats.starttime-ref_time+1e4, tr_gps.stats.endtime-ref_time-1e4
    plt.axvspan(offset_start, offset_end, color='black', alpha=0.3)
    # Turn the offsets back into absolute times by adding ref_time, which is
    # what they are relative to.
    template = f'{tr_gps.stats.station} = [UTCDateTime("{ref_time+offset_start}"), UTCDateTime("{ref_time+offset_end}")]'
    print(template)
plt.legend()

### Minibooster

In [None]:
from tqdm import tqdm

folder = '/projects/infrasound/data/infrasound/2023_ML_balloon/data/MINIBOOSTER_balloon/signal_ascii/'

# Each ASCII file has a one-line header "year doy hour minute second dt"
# followed by one sample per line.
st_minibooster = obspy.Stream()
for subdir, dirs, files in os.walk(folder):
    for file in tqdm(files):

        if '.mseed' in file:
            continue

        # os.path.join also works for nested directories, where os.walk
        # returns `subdir` without a trailing separator.
        filepath = os.path.join(subdir, file)
        station = file.split('.')[1]

        with open(filepath, 'r') as fid:
            lines = fid.readlines()

        try:
            year, doy, hour, minute, second, dt = lines[0].split()
        except (IndexError, ValueError):
            # Empty file, or a header that does not have exactly six fields.
            print(f'Skipping {filepath}: unexpected header')
            continue

        tr = obspy.Trace()
        tr.data = np.array(lines[1:]).astype(float)
        tr.stats.station = station
        tr.stats.delta = float(dt)
        tr.stats.starttime = UTCDateTime(f'{year}-{doy}T{hour}:{minute}:{second}')
        st_minibooster += tr

In [None]:
import scipy.signal

# Visual check of one explosion arrival: band-passed pressure on top, and
# either the high/low sensor cross-correlation or the band-passed altitude
# below. Needs `st_minibooster_gps`, which a later cell builds.

# NOTE: `id` shadows the builtin; the name is kept because it is a notebook global.
id = 0
tr = st_minibooster[id].copy()
# np.ma.filled accepts both masked and plain arrays.
tr.data = np.ma.filled(tr.data, 0.)
tr.detrend()

tr_gps = st_minibooster_gps[id].copy()
tr_gps.detrend()

low_avail = False
tr_low = None
if low_avail:
    tr_low = st_minibooster[id+1].copy()
    tr_low.data = np.ma.filled(tr_low.data, 0.)
    tr_low.detrend()

station = tr.stats.station[:2]

# Pick the explosion to look at by uncommenting exactly one line.
#explosion = UTCDateTime('2020-08-20T06:30:00')
#explosion = tr.stats.starttime + 9260. # explosion 1 b1 b2 b3
explosion = tr.stats.starttime + 16180. # explosion 2 b1 b2 b3
#explosion = tr.stats.starttime + 24000.-50. # explosion 2 b1
offset = explosion-tr.stats.starttime

duration = 200.
# Sample index of the explosion, and the number of samples in `duration`.
idx_explosion = np.argmin(abs(tr.times()-offset))
idx_duration = np.argmin(abs(tr.times()-duration))

offset_gps = explosion-tr_gps.stats.starttime
idx_explosion_gps = np.argmin(abs(tr_gps.times()-offset_gps))
idx_duration_gps = np.argmin(abs(tr_gps.times()-duration))

# Band-pass the pressure trace(s).
order = 2
lowcut = 0.1
highcut = 1.
# Keep the pressure band for the title; lowcut/highcut are reused for the GPS band below.
pressure_band = f'{lowcut}-{highcut}'
sos = scipy.signal.butter(order, [lowcut, highcut], fs=1./tr.stats.delta, btype='band', output='sos')
data_filt = scipy.signal.sosfilt(sos, tr.data)
sig1 = data_filt[idx_explosion:idx_explosion+idx_duration]
if tr_low is not None:
    data_filt_low = scipy.signal.sosfilt(sos, tr_low.data)
    sig2 = data_filt_low[idx_explosion:idx_explosion+idx_duration]
    lags = scipy.signal.correlation_lags(sig1.size, sig2.size, mode='same')
    corr = scipy.signal.correlate(sig1/np.linalg.norm(sig1, axis=0), sig2/np.linalg.norm(sig2, axis=0), mode='same')

# Band-pass the altitude trace.
lowcut = 0.05
highcut = 0.49
sos = scipy.signal.butter(order, [lowcut, highcut], fs=1./tr_gps.stats.delta, btype='band', output='sos')
data_filt_gps = scipy.signal.sosfilt(sos, tr_gps.data)
sig1_gps = data_filt_gps[idx_explosion_gps:idx_explosion_gps+idx_duration_gps]

fig = plt.figure(figsize=(5,5))
grid = fig.add_gridspec(3, 1)

# Top panel: filtered pressure.
ax = fig.add_subplot(grid[:2,0])
ax.plot(tr.times()[idx_explosion:idx_explosion+idx_duration]-offset, sig1, label='high')
if tr_low is not None:
    ax.plot(tr.times()[idx_explosion:idx_explosion+idx_duration]-offset, sig2, label='low')
ax.legend()
# The title reports the band applied to the pressure data shown in this panel.
ax.set_title(f'Balloon {station} - band: {pressure_band} Hz')

# Bottom panel: cross-correlation if a low sensor is available, altitude otherwise.
if tr_low is not None:
    ax = fig.add_subplot(grid[2,0])
    ax.plot(tr.times()[1]*lags, corr)
    ax.set_xlabel(f'Time (s) since {explosion}')
else:
    ax = fig.add_subplot(grid[2,0], sharex=ax)
    ax.plot(tr_gps.times()[idx_explosion_gps:idx_explosion_gps+idx_duration_gps]-offset_gps, sig1_gps, '-o')
    ax.set_xlabel(f'Time (s) since {explosion}')

In [None]:
# Usable time window per Minibooster sensor.
minibooster = dict(
    B1HI=[UTCDateTime("2020-08-20T05:37:51.000000Z"), UTCDateTime("2020-08-20T13:42:51.000000Z")],
    B2HI=[UTCDateTime("2020-08-20T05:44:23.000000Z"), UTCDateTime("2020-08-20T13:04:23.000000Z")],
    B2LO=[UTCDateTime("2020-08-20T05:44:45.000000Z"), UTCDateTime("2020-08-20T13:54:45.000000Z")],
    B3HI=[UTCDateTime("2020-08-20T05:27:50.000000Z"), UTCDateTime("2020-08-20T07:56:10.000000Z")],
    B3LO=[UTCDateTime("2020-08-20T05:50:54.000000Z"), UTCDateTime("2020-08-20T10:09:14.000000Z")],
)

# Hand-picked (start, end) offsets in seconds for each station.
# NOTE(review): the shaded span is drawn on the shared axis (seconds since
# ref_time) while the printed template adds the same numbers to each trace's
# own start time -- confirm which reference is intended.
picked_offsets = {
    'B3HI': (17100, 26000),
    'B3LO': (18500, 34000),
    'B1HI': (18500, 47600),
    'B2HI': (18500, 44900),
    'B2LO': (18500, 47900),
}

plt.figure()
ref_time = st_minibooster_gps[0].stats.starttime
for tr_gps in st_minibooster_gps:
    trace_start = tr_gps.stats.starttime
    offset = trace_start-ref_time
    plt.plot(tr_gps.times()+offset, tr_gps.data, label=tr_gps.stats.station)
    # Stations without a pick get a zero-length window.
    offset_start, offset_end = picked_offsets.get(tr_gps.stats.station, (0., 0.))
    plt.axvspan(offset_start, offset_end, color='black', alpha=0.3)
    template = f'{tr_gps.stats.station} = [UTCDateTime("{tr_gps.stats.starttime+offset_start}"), UTCDateTime("{tr_gps.stats.starttime+offset_end}")]'
    print(template)
plt.legend()

In [None]:
# Save the Minibooster pressure stream as a MiniSEED file.
folder = '/projects/infrasound/data/infrasound/2023_ML_balloon/data/MINIBOOSTER_balloon/signal_ascii/'
st_minibooster.write(f"{folder}st_all.mseed", format="MSEED")

In [None]:
from tqdm import tqdm
from obspy.core.utcdatetime import UTCDateTime

folder = '/projects/infrasound/data/infrasound/2023_ML_balloon/data/MINIBOOSTER_balloon/balloon_trajectories/'

# Regular expression extracting the time stamp, latitude, longitude and elevation.
# NOTE(review): this captures the `gps-time` field, not a UTC time, so the
# traces may be offset from UTC by the leap-second count -- confirm.
pattern = r'\s+gps-time=(\S+)\s+gps-leap=\S+\s+iers-leap=\S+\s+lat=([\+\-]?\d+\.\d+)\s+lon=([\+\-]?\d+\.\d+)\s+elev=([\d\.]+)'

st_minibooster_gps = obspy.Stream()
for subdir, dirs, files in os.walk(folder):
    for file in tqdm(files):

        if '.mseed' in file:
            continue

        # os.path.join also works for nested directories, where os.walk
        # returns `subdir` without a trailing separator.
        filepath = os.path.join(subdir, file)
        station = file.split('_')[0]

        # Read one record per non-blank line. pd.read_csv(delimiter='\n') is
        # rejected by current pandas versions.
        with open(filepath, 'r') as fid:
            records = [line.rstrip('\n') for line in fid if line.strip()]
        if not records:
            print(f'Skipping {filepath}: empty file')
            continue

        df = pd.DataFrame({'text': records})['text'].str.extract(pattern)
        df.columns = ['gps-time', 'lat', 'lon', 'elev']
        # Assign whole columns so each one takes its new dtype.
        df['gps-time'] = pd.to_datetime(df['gps-time'])
        # float() parses an optional leading sign, so the sign is kept.
        df['lat'] = df['lat'].astype(float)
        df['lon'] = df['lon'].astype(float)
        df['elev'] = df['elev'].astype(float)

        # Sampling interval from the first two time stamps.
        dt = (df['gps-time'].iloc[1] - df['gps-time'].iloc[0]).total_seconds()
        starttime = UTCDateTime(df['gps-time'].iloc[0])

        tr = obspy.Trace()
        tr.data = df['elev'].values
        tr.stats.station = station
        tr.stats.delta = dt
        tr.stats.starttime = starttime

        st_minibooster_gps += tr

In [None]:
# Save the Minibooster altitude stream as a MiniSEED file.
folder = '/projects/infrasound/data/infrasound/2023_ML_balloon/data/MINIBOOSTER_balloon/balloon_trajectories/'
st_minibooster_gps.write(f"{folder}st_all_gps.mseed", format="MSEED")

### Starliner

In [None]:
folder = '/projects/infrasound/data/infrasound/2023_ML_balloon/data/starliner_scrub/20210803_starliner_scrub/balloon_data/'

st_starliner = obspy.Stream()
datas = {}
for subdir, dirs, files in os.walk(folder):
    for file in tqdm(files):
        filepath = subdir + os.sep + file

        # Skip 'DDF' files and previously exported MiniSEED bundles.
        if 'DDF' in file:
            continue

        if '.mseed' in file:
            continue

        try:
            tr = obspy.read(filepath)[0]
            # Station code is 'b' + the station-name suffix + HI/LO, where
            # channels ending in 'B' or '2' are labelled LO.
            balloon = tr.stats.station[4:]
            channel = tr.stats.channel[-1]
            ext = 'HI'
            if channel == 'B' or channel == '2':
                ext = 'LO'
            station = f'b{balloon}{ext}'
            tr.stats.station = station
            st_starliner += tr
        except Exception as err:
            # Best effort: unreadable files are still skipped, but reported
            # instead of being dropped silently.
            print(f'Skipping {filepath}: {type(err).__name__}: {err}')
            continue
        

In [None]:
# Save the Starliner pressure stream as a MiniSEED file.
folder = '/projects/infrasound/data/infrasound/2023_ML_balloon/data/starliner_scrub/20210803_starliner_scrub/balloon_data/'
st_starliner.write(f"{folder}st_all.mseed", format="MSEED")

In [None]:
folder = '/projects/infrasound/data/infrasound/2023_ML_balloon/data/starliner_scrub/20210803_starliner_scrub/trajectory_data/'

# Column names assigned to the 'bounder' files (read without a header row).
bounder_columns = ['AA', 'SEF', 'HTR', 'BW0', 'BW1', 'MET','MAP','PAR','GAR','GF1','GF2','EXT','Pres (hPa)','dP (hPa/line)','Temp (C)','Batt (V)','Cap (V)','GPS Flag (HEX)','GPS SV','Lon','Lat','Alt (m)','Date','Time','gSpeed (m/s)','aRate (m/s)','aFilt (m/s)','heading (deg)', 'N/A']

# Fields extracted from each line of the 'gps' files.
gps_pattern = r'utc-time=(\S+)\s+gps-time=\S+\s+gps-leap=\S+\s+iers-leap=\S+\s+lat=([\+\-]?\d+\.\d+)\s+lon=([\+\-]?\d+\.\d+)\s+elev=([\d\.]+)'

st_starliner_gps = obspy.Stream()
datas = {}
for subdir, dirs, files in os.walk(folder):
    for file in tqdm(files):
        filepath = subdir + os.sep + file
        # The balloon number is read from the parent directory name.
        balloon = subdir.split('/')[-2].split('_')[1]

        if '.mseed' in file:
            continue

        channel = file.split('.')[0].split('_')

        # A second file-name token of 'a' or '2' marks the LO sensor.
        ext = 'HI'
        if len(channel) > 1 and channel[1] in ('a', '2'):
            ext = 'LO'
        balloon = f'b{balloon}{ext}'

        print(channel)
        if len(channel) > 1:
            channel = channel[0][:4]+channel[1][:1]
        else:
            channel = channel[0][:5]

        print(filepath)
        if 'bounder' in file:
            data = pd.read_csv(filepath, skiprows=67, header=None, names=bounder_columns)
            data = data.loc[:, ['Lon','Lat','Alt (m)','Date','Time']].copy()
            # 'Date' is sliced as YYYYMMDD and combined with 'Time'.
            year = data['Date'].astype(str).str[:4]
            month = data['Date'].astype(str).str[4:6]
            day = data['Date'].astype(str).str[6:]
            data['Date'] = pd.to_datetime(year+'-'+month+'-'+day+'T'+data['Time'])

        elif 'gps' in file:
            # Read one record per non-blank line. pd.read_csv(delimiter='\n')
            # is rejected by current pandas versions.
            with open(filepath, 'r') as fid:
                records = [line.rstrip('\n') for line in fid if line.strip()]
            if not records:
                print(f'Skipping {filepath}: empty file')
                continue

            data = pd.DataFrame({'text': records})['text'].str.extract(gps_pattern)
            data.columns = ['Date', 'Lat', 'Lon', 'Alt (m)']
            data['Lat'] = data['Lat'].astype(float)
            data['Lon'] = data['Lon'].astype(float)
            data['Alt (m)'] = data['Alt (m)'].astype(float)
            data['Date'] = pd.to_datetime(data['Date'])

        else:
            with open(filepath, 'r') as fid:
                lines = fid.readlines()

            # Every line containing 'Longitude' is the header of a table.
            locs = [j for j, line in enumerate(lines) if 'Longitude' in line]
            if not locs:
                # Without this guard, `data` from the previous file would be reused.
                print(f'Skipping {filepath}: no "Longitude" header line found')
                continue
            # Sentinel so the last table runs to the end of the file once the
            # 10-line margin applied to every table end is subtracted.
            locs.append(len(lines)+10)

            # NOTE(review): each table overwrites `data`, so only the last
            # table of the file reaches the trace -- confirm this is intended.
            for isep in range(1, len(locs)):

                data = pd.DataFrame(lines[locs[isep-1]:locs[isep]-10])
                data = data[0].str.split(',', expand=True)
                data.columns = data.iloc[0].values
                data = data.iloc[1:].copy()

                # Two header variants exist: 'Date'/'Time' and 'UTC Date'/'UTC Time'.
                try:
                    # The year and month are hard-coded; only the day is read.
                    data['Date    '] = '2021-08-'+data['Date    '].str.split('/').str[1]
                    data['Date    '] = pd.to_datetime(data['Date    ']+'T'+data['Time    '])
                    data = data.loc[:, ['Date    ', 'Latitude  ', 'Longitude ', 'Alt-m  ']].copy()
                    data.columns = ['Date', 'Lat', 'Lon', 'Alt (m)']
                    # The first character of the altitude field is dropped before conversion.
                    data['Alt (m)'] = data['Alt (m)'].str[1:].astype(float)
                except KeyError:
                    data['UTC Date'] = '2021-08-'+data['UTC Date'].str.split('/').str[1]
                    data['UTC Date'] = pd.to_datetime(data['UTC Date']+'T'+data['UTC Time'])
                    data = data.loc[:, ['UTC Date', 'Latitude  ', 'Longitude ', 'Alt-m  ']].copy()
                    data.columns = ['Date', 'Lat', 'Lon', 'Alt (m)']
                    data['Lat'] = data['Lat'].str[2:].astype(float)
                    data['Alt (m)'] = data['Alt (m)'].str[1:].astype(float)

        # Sampling interval from the first two time stamps.
        dt = (data['Date'].iloc[1] - data['Date'].iloc[0]).total_seconds()
        starttime = UTCDateTime(data['Date'].iloc[0])

        tr = obspy.Trace()
        tr.data = data['Alt (m)'].values
        tr.stats.station = balloon
        tr.stats.channel = channel
        tr.stats.delta = dt
        tr.stats.starttime = starttime

        st_starliner_gps += tr

In [None]:
# Save the Starliner altitude stream as a MiniSEED file.
folder = '/projects/infrasound/data/infrasound/2023_ML_balloon/data/starliner_scrub/20210803_starliner_scrub/trajectory_data/'
st_starliner_gps.write(f"{folder}st_all_gps.mseed", format="MSEED")

In [None]:
# Usable time window per Starliner sensor.
starliner = dict(
    b1LO=[UTCDateTime("2021-08-04T18:48:14.000000Z"), UTCDateTime("2021-08-05T00:46:34.000000Z")],
    b1HI=[UTCDateTime("2021-08-04T18:48:08.000000Z"), UTCDateTime("2021-08-05T03:54:48.000000Z")],
    b2HI=[UTCDateTime("2021-08-03T16:19:04.000000Z"), UTCDateTime("2021-08-04T02:06:04.000000Z")],
    b2LO=[UTCDateTime("2021-08-03T16:23:52.000000Z"), UTCDateTime("2021-08-04T02:10:52.000000Z")],
    b3HI=[UTCDateTime("2021-08-03T16:23:13.000000Z"), UTCDateTime("2021-08-04T01:58:13.000000Z")],
    b4HI=[UTCDateTime("2021-08-03T17:23:13.000000Z"), UTCDateTime("2021-08-04T01:21:43.000000Z")],
)

# Hand-picked (start, end) offsets in seconds for each station.
# NOTE(review): the shaded span is drawn on the shared axis (seconds since
# ref_time) while the printed template adds the same numbers to each trace's
# own start time -- confirm which reference is intended.
picked_offsets = {
    'b3HI': (18370, 52870),
    'b2LO': (17650, 52870),
    'b2HI': (17650, 52870),
    'b4HI': (21990, 50700),
    'b1HI': (56900., 89700),
    'b1LO': (56900., 78400),
}

plt.figure()
ref_time = st_starliner_gps[-1].stats.starttime
offset_start, offset_end = 0., 0.
for tr_gps in st_starliner_gps:
    # A station without a pick keeps the window of the previous trace.
    offset_start, offset_end = picked_offsets.get(tr_gps.stats.station, (offset_start, offset_end))
    trace_start = tr_gps.stats.starttime
    offset = trace_start-ref_time
    plt.plot(tr_gps.times()+offset, tr_gps.data, label=tr_gps.stats.station)
    plt.axvspan(offset_start, offset_end, color='black', alpha=0.3)
    template = f'{tr_gps.stats.station} = [UTCDateTime("{tr_gps.stats.starttime+offset_start}"), UTCDateTime("{tr_gps.stats.starttime+offset_end}")]'
    print(template)
plt.legend()

### Building validation/training datasets

In [None]:
def trim_st(st_all, duration, return_times, overlap=1., freq_min=0.01, freq_max=20., max_traces=6):
    """Cut each trace into gap-free sections and fixed-length windows.

    Each trace is trimmed to its time window and split at masked gaps into
    contiguous sections. Each section is detrended, band-pass filtered,
    resampled to ``2*freq_max`` and cut into windows of ``duration`` seconds.

    Parameters
    ----------
    st_all : obspy.Stream
        Input traces.
    duration : float
        Window length in seconds.
    return_times : sequence of (UTCDateTime, UTCDateTime)
        Start and end time to keep for each trace, in the order of ``st_all``.
    overlap : float
        Step between consecutive window starts, as a fraction of
        ``duration`` (1.0 gives back-to-back windows, 0.25 gives a new
        window every quarter of a window).
    freq_min, freq_max : float
        Corner frequencies of the zero-phase band-pass filter.
    max_traces : int or None
        Only the first ``max_traces`` traces are processed. ``None``
        processes all. The default of 6 keeps the earlier hard-coded limit.

    Returns
    -------
    new_st : obspy.Stream
        All full-length windows. Each is named ``<station>-<window id>``.
    sections : dict
        Station -> list of ``(starttime, endtime)``, one per section.
    sections_ids : dict
        Station -> list (one entry per section) of the window names kept.
    sections_unique_ids : dict
        Station -> list of section numbers, unique across all stations.
    """
    new_st = obspy.Stream()
    window_id = -1
    sections = dict()
    sections_ids = dict()
    sections_unique_ids = dict()
    cpt_section_ids = -1
    print('Trimming streams')
    for itr, tr_loop in enumerate(st_all):

        if max_traces is not None and itr >= max_traces:
            break

        starttime, endtime = return_times[itr]
        tr_in = tr_loop.copy()
        tr_in.trim(starttime=starttime, endtime=endtime)
        station = tr_in.stats.station

        sections[station] = []
        sections_ids[station] = []
        sections_unique_ids[station] = []

        # Indices of the valid (unmasked) samples. getmaskarray always
        # returns a full boolean array, even when the mask is `nomask`.
        if isinstance(tr_in.data, np.ma.MaskedArray):
            valid_idx = np.flatnonzero(~np.ma.getmaskarray(tr_in.data))
        else:
            valid_idx = np.arange(tr_in.data.size)
        if valid_idx.size == 0:
            print(f'{station}: no valid samples in the requested time window, skipping')
            continue

        # A jump larger than 1 between consecutive valid indices is a gap.
        # Section bounds are inclusive sample indices.
        gaps = np.flatnonzero(np.diff(valid_idx) > 1)
        section_first = np.r_[valid_idx[0], valid_idx[gaps+1]]
        section_last = np.r_[valid_idx[gaps], valid_idx[-1]]

        skipped_windows = 0
        total_windows = 0
        for id_section, (i_first, i_last) in enumerate(zip(section_first, section_last)):

            cpt_section_ids += 1
            sections_unique_ids[station].append(cpt_section_ids)

            tr = tr_in.copy()
            # Keep the last valid sample, and drop the mask because the
            # section has no gaps.
            tr.data = np.ma.getdata(tr_in.data[i_first:i_last+1]).copy()
            # Slicing does not move the time axis, so shift the start time
            # to the first kept sample.
            tr.stats.starttime = tr_in.stats.starttime + float(i_first*tr_in.stats.delta)

            sections[station].append( (tr.stats.starttime, tr.stats.endtime) )
            sections_ids[station].append([])

            # A section shorter than one window cannot give a full window.
            if tr.stats.endtime-tr.stats.starttime < duration:
                continue

            tr.detrend()
            tr.filter('bandpass', freqmin=freq_min, freqmax=freq_max, zerophase=True, corners=6)
            window_offsets = np.arange(0., tr.times()[-1], duration*overlap)

            # Resample the section once rather than once per window.
            tr_resampled = tr.copy()
            tr_resampled.resample(freq_max*2)

            for window_offset in tqdm(window_offsets):
                window_id += 1
                tr_loc_Baro = tr_resampled.copy()
                tr_loc_Baro.trim(starttime=tr.stats.starttime+window_offset, endtime=tr.stats.starttime+window_offset+duration)
                tr_loc_Baro.stats.station = tr.stats.station + '-' + str(window_id)
                total_windows += 1
                # Windows cut short by the end of the section are dropped.
                if tr_loc_Baro.stats.endtime-tr_loc_Baro.stats.starttime<duration:
                    skipped_windows += 1
                    continue
                new_st += tr_loc_Baro
                sections_ids[station][id_section].append(tr_loc_Baro.stats.station)

        if skipped_windows > 0:
            print(f'{tr_loop.stats.station}: Skipping {skipped_windows}/{total_windows} windows')

    return new_st, sections, sections_ids, sections_unique_ids
    
from math import log, ceil, floor
def closest_power(x, power=8):
    possible_results = floor(log(x, power)), ceil(log(x, power))
    return min(possible_results, key= lambda z: abs(x-power**z))
   
def find_closest_duration(target_duration, target_sampling):
    """Return the duration closest to ``target_duration`` that holds a power-of-two number of samples.

    The requested sample count ``int(target_sampling*target_duration)`` is
    replaced by the nearest power of two and converted back to seconds at
    ``target_sampling``.
    """
    requested_samples = int(target_sampling*target_duration)
    exponent = closest_power(requested_samples, power=2)
    return 2**exponent/target_sampling
    
def get_times(st, flight_id):
    """Return the time window to keep for every trace of ``st``.

    Parameters
    ----------
    st : obspy.Stream
        Traces of one campaign.
    flight_id : str
        Campaign name: 'ridgecrest', 'minibooster', 'strateole' or
        'starliner'. For any other value every trace keeps its full span.

    Returns
    -------
    list
        One ``[starttime, endtime]`` list per trace, in the order of ``st``.
        Traces whose station has no hand-picked window get the tuple
        ``(tr.stats.starttime, tr.stats.endtime)``. A bound given as the
        string 'default' is replaced by that trace's own start or end time.
    """
    # NOTE(review): the strateole C1 windows (dated 2023) may come from a
    # window-picking printout that counted the start offset twice -- confirm
    # against the data before relying on them.
    available_times = dict(
        ridgecrest=dict(
            CraLo = [UTCDateTime('2019-08-09T18:52:00'), UTCDateTime('2019-08-10T02:43:31')],
            CraUp = [UTCDateTime('2019-08-09T18:52:00'), UTCDateTime('2019-08-10T02:43:31')],
            tort = [UTCDateTime('2019-07-22T14:09:00'), 'default']
        ),
        minibooster=dict(
            B1HI = [UTCDateTime("2020-08-20T05:37:51.000000Z"), UTCDateTime("2020-08-20T13:42:51.000000Z")],
            B2HI = [UTCDateTime("2020-08-20T05:44:23.000000Z"), UTCDateTime("2020-08-20T13:04:23.000000Z")],
            B2LO = [UTCDateTime("2020-08-20T05:44:45.000000Z"), UTCDateTime("2020-08-20T13:54:45.000000Z")],
            B3HI = [UTCDateTime("2020-08-20T05:27:50.000000Z"), UTCDateTime("2020-08-20T07:56:10.000000Z")],
            B3LO = [UTCDateTime("2020-08-20T05:50:54.000000Z"), UTCDateTime("2020-08-20T10:09:14.000000Z")]
        ),
        strateole=dict(
            b01C0 = [UTCDateTime("2019-11-13T01:11:48.000000Z"), UTCDateTime("2020-02-27T18:13:28.000000Z")],
            b04C1 = [UTCDateTime("2023-09-29T05:34:34.000000Z"), UTCDateTime("2023-10-28T15:30:44.000000Z")],
            b03C1 = [UTCDateTime("2023-09-28T23:47:58.000000Z"), UTCDateTime("2023-10-29T22:30:08.000000Z")],
            b01C1 = [UTCDateTime("2023-09-27T00:41:36.000000Z"), UTCDateTime("2023-09-28T01:37:16.000000Z")],
            b16C1 = [UTCDateTime("2023-12-03T21:53:48.000000Z"), UTCDateTime("2023-12-30T01:36:28.000000Z")],
            b17C1 = [UTCDateTime("2023-12-10T01:25:44.000000Z"), UTCDateTime("2023-12-30T07:23:24.000000Z")]
        ),
        starliner=dict(
            b1LO = [UTCDateTime("2021-08-04T18:48:14.000000Z"), UTCDateTime("2021-08-05T00:46:34.000000Z")],
            b1HI = [UTCDateTime("2021-08-04T18:48:08.000000Z"), UTCDateTime("2021-08-05T03:54:48.000000Z")],
            b2HI = [UTCDateTime("2021-08-03T16:19:04.000000Z"), UTCDateTime("2021-08-04T02:06:04.000000Z")],
            b2LO = [UTCDateTime("2021-08-03T16:23:52.000000Z"), UTCDateTime("2021-08-04T02:10:52.000000Z")],
            b3HI = [UTCDateTime("2021-08-03T16:23:13.000000Z"), UTCDateTime("2021-08-04T01:58:13.000000Z")],
            b4HI = [UTCDateTime("2021-08-03T17:23:13.000000Z"), UTCDateTime("2021-08-04T01:21:43.000000Z")]
        )
    )

    flight_windows = available_times.get(flight_id, {})

    return_times = []
    for tr in st:
        default_time = (tr.stats.starttime, tr.stats.endtime)
        picked = flight_windows.get(tr.stats.station)
        if picked is None:
            return_times.append(default_time)
            continue
        # Build a new list per trace so that resolving 'default' for one
        # trace never changes the window of another trace with the same
        # station name.
        window = [
            default if isinstance(bound, str) and bound == 'default' else bound
            for bound, default in zip(picked, default_time)
        ]
        return_times.append(window)

    return return_times
    
# --- Windowing and filtering settings ---
target_duration = 150.           # requested window length (s)
overlap = 0.25                   # step between window starts, as a fraction of the duration
freq_min, freq_max = 0.15, 2.5   # band-pass corner frequencies (Hz)
target_sampling = freq_max*2.

# --- Campaign selection: keep exactly one of the blocks below active ---
# st = st_crazycat['Baro'].copy()
# flight_id = 'ridgecrest'
# event_type = 'earthquake'
#
# st = st_strateole.copy()
# flight_id = 'strateole'
# event_type = 'earthquake'
#
# st = st_minibooster.copy()
# st.merge()
# flight_id = 'minibooster'
# event_type = 'explosion'

st = st_starliner.copy()
st.merge()
flight_id, event_type = 'starliner', 'explosion'

# --- Build the fixed-length windows ---
return_times = get_times(st, flight_id)
duration = find_closest_duration(target_duration, target_sampling)
target_size = int(duration*target_sampling)
new_st, sections, sections_ids, sections_unique_ids = trim_st(
    st,
    duration,
    return_times,
    overlap=overlap,
    freq_min=freq_min,
    freq_max=freq_max,
)

In [None]:
# Show the selected, merged stream as the cell output.
st

In [None]:
import h5py

def store_all_as_hdf5(datasets, st_all, target_size, flight_id, datasets_section_ids, event_type='earthquake', new_st_label=None):
    """Write one HDF5 file of fixed-length waveform windows per dataset.

    Parameters
    ----------
    datasets : dict
        Dataset name -> list of window names (the ``stats.station`` values of
        the windows in ``st_all``).
    st_all : obspy.Stream
        Stream holding the windows named in ``datasets``.
    target_size : int
        Number of samples kept per window. Shorter windows are skipped.
    flight_id : str
        Campaign name, stored with every window and used in the file name.
    datasets_section_ids : dict
        Dataset name -> list of section ids aligned with ``datasets``.
    event_type : str
        Label stored with every window.
    new_st_label : obspy.Stream or None
        Optional stream of label traces with the same window names. All-zero
        labels are stored when it is None.

    Returns
    -------
    dict or None
        The ``results`` dictionary of the last dataset written, or None when
        ``datasets`` is empty.

    Notes
    -----
    The output path comes from the module-level ``filename`` template, which
    is formatted with ``dataset`` and ``flight_id``.
    """
    results = None
    for dataset in datasets:

        results = {'X': [], 'label': [], 'event_type': [], 'window': [], 'id': [], 'station': [], 'flight_id': [], 'section_id': []}
        print(f'Building dataset {dataset}')

        window_names = datasets[dataset]
        for istation, station in tqdm(enumerate(window_names), total=len(window_names)):

            tr_label = None
            if new_st_label is not None:
                tr_label = new_st_label.select(station=station)[0].copy()

            section_id = datasets_section_ids[dataset][istation]

            # Look the window up in the stream passed by the caller, not in
            # the notebook-global `new_st`.
            tr = st_all.select(station=station)[0].copy()
            window = (str(tr.stats.starttime), str(tr.stats.endtime))

            if tr.data.size == 0:
                print('problem size')
                print(tr.stats.station)
                continue

            if abs(tr.data).max() == 0.:
                print('Problem amplitude')
                continue

            # Add a trailing channel axis and cut to the requested length.
            X = np.expand_dims(tr.data, axis=-1)
            X = X[:target_size,:]

            if tr_label is None:
                label = np.zeros_like(X)
            else:
                label = np.expand_dims(tr_label.data, axis=-1)
            label = label[:target_size,:]

            if X.shape[0] < target_size or label.shape[0] < target_size:
                print('problem size')
                print(tr.stats.station)
                print(X.shape)
                continue

            if np.isnan(X).any():
                print('problem nan')
                print(tr.stats.station)
                continue

            results['X'].append( X )
            results['label'].append( label )
            results['event_type'].append( event_type )
            results['window'].append( window )
            results['flight_id'].append( flight_id )
            # The window name is '<station>-<window id>'.
            results['station'].append( tr.stats.station.split('-')[0] )
            results['id'].append( tr.stats.station )
            results['section_id'].append( section_id )

        # Open the HDF5 file in "write" mode
        with h5py.File(filename.format(dataset=dataset, flight_id=flight_id), "w") as f:
            f.create_dataset('X', data=results['X'], dtype='float32')
            f.create_dataset('label', data=results['label'], dtype='float32')
            # NOTE(review): 'event_type' is stored as the string form of the
            # whole list, unlike the other fields -- left as is so that
            # existing readers keep working; confirm this is intended.
            f.create_dataset('event_type', data=str(results['event_type']))
            f.create_dataset('window', data=results['window'])
            f.create_dataset('flight_id', data=results['flight_id'])
            f.create_dataset('station', data=results['station'])
            f.create_dataset('id', data=results['id'])
            f.create_dataset('section_id', data=results['section_id'])

    return results

def find_bounds_dataset_section(list_quantiles, requested_dataset, l_stations):
    """Return the inclusive index bounds of one dataset split within a section.

    The section is cut into consecutive, non-overlapping slices, one per entry
    of ``list_quantiles`` (in its iteration order). Each slice holds
    ``int(len(l_stations) * quantile)`` items.

    Parameters
    ----------
    list_quantiles : dict
        Maps a dataset name to the fraction of the section assigned to it.
    requested_dataset : hashable
        Name of the dataset whose bounds are wanted.
    l_stations : sequence
        Items of the section; only its length is used here.

    Returns
    -------
    tuple or None
        ``(first_index, last_index)``, BOTH inclusive (a caller slicing with
        them needs ``last_index + 1`` as the stop), or ``None`` when
        ``requested_dataset`` is not a key of ``list_quantiles``.

    Raises
    ------
    ValueError
        If the requested split contains no item.
    """
    n_items = len(l_stations)
    all_indices = np.arange(n_items)

    start = 0
    for dataset, quantile in list_quantiles.items():
        n_split = int(n_items * quantile)
        idx = all_indices[start:start + n_split]
        # Advance the offset cumulatively. Resetting it to the size of the
        # current split made the third and later splits overlap earlier ones.
        start += n_split
        if dataset == requested_dataset:
            if idx.size == 0:
                raise ValueError(
                    f'dataset "{requested_dataset}" is empty for a section '
                    f'of {n_items} item(s)'
                )
            return idx.min(), idx.max()

    return None

def prepare_datasets_dates(new_st, sections, sections_ids, sections_unique_ids, list_quantiles, target_size, min_size=10):
    """Distribute the items of every flight section over the requested datasets.

    For each dataset named in ``list_quantiles`` and each section of each
    flight, the index bounds of that dataset inside the section are obtained
    from ``find_bounds_dataset_section`` and the matching items are collected.

    Parameters
    ----------
    new_st, sections, target_size, min_size
        Not used by this function; kept so existing calls keep working.
    sections_ids : mapping
        Maps a flight/station key to a list of sections, each section being a
        sliceable sequence of items.
    sections_unique_ids : mapping
        Maps the same keys to one unique id per section, in the same order as
        ``sections_ids``.
    list_quantiles : dict
        Maps a dataset name to the fraction of each section assigned to it.

    Returns
    -------
    tuple of dict
        ``(datasets, datasets_section_ids)``: for every dataset name, the list
        of selected items and the parallel list of section ids (one id per
        selected item).
    """
    datasets = dict()
    datasets_section_ids = dict()
    for dataset in tqdm(list_quantiles):

        selected_items = []
        selected_section_ids = []
        for station in sections_ids:  # Loop over each flight
            l_sections = sections_ids[station]
            for i_section, l_stations in enumerate(l_sections):  # Loop over each segment in each flight
                section_unique_id = sections_unique_ids[station][i_section]
                i_min, i_max = find_bounds_dataset_section(list_quantiles, dataset, l_stations)
                # The bounds are inclusive (min and max index), whereas a
                # slice stop is exclusive: without the +1 the last item of
                # every split would be silently dropped.
                selected = l_stations[i_min:i_max + 1]
                selected_items.extend(selected)
                selected_section_ids.extend([section_unique_id] * len(selected))

        datasets[dataset] = selected_items
        datasets_section_ids[dataset] = selected_section_ids

    for dataset in list_quantiles:
        print(f'dataset "{dataset}" ({len(datasets[dataset])} inputs):')

    return datasets, datasets_section_ids

# Output path template; the {dataset} and {flight_id} fields are filled in
# later through str.format.
filename = '/projects/infrasound/data/infrasound/2023_ML_balloon/data/{dataset}_waveform_dataset_{flight_id}_0.015Hz.h5'
# A single split, named "all", with quantile 1.0.
list_quantiles = {'all': 1.}
# No label stream is supplied for this run.
new_st_label = None

datasets, datasets_section_ids = prepare_datasets_dates(
    new_st, sections, sections_ids, sections_unique_ids,
    list_quantiles, target_size,
)
results = store_all_as_hdf5(
    datasets, new_st, target_size, flight_id, datasets_section_ids,
    event_type=event_type, new_st_label=new_st_label,
)

## Test tsfresh

In [None]:
from tqdm import tqdm
def split_time_series(df, duration, overlap, sample_rate):
    """Cut a regularly sampled series into fixed-length, overlapping chunks.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per sample. Rows are ordered by index before slicing.
    duration : float
        Length of one chunk, expressed in the same unit as ``sample_rate``.
    overlap : float
        Fraction of a chunk shared with the following chunk (0 means no
        overlap). The resulting step must be at least one sample.
    sample_rate : float
        Despite its name, this is the sampling INTERVAL (time between two
        consecutive samples), not a frequency: a chunk holds
        ``duration / sample_rate`` samples.

    Returns
    -------
    pandas.DataFrame
        All chunks concatenated, with an extra ``'id'`` column numbering the
        chunks from 1. Rows that belong to several overlapping chunks appear
        once per chunk.

    Raises
    ------
    ValueError
        If a chunk would hold no sample, if the step between chunks is not
        positive, or if ``df`` is shorter than one chunk.
    """
    df = df.sort_index()

    # The small tolerance keeps float round-off (a ratio that lands a hair
    # below a whole number) from truncating a whole sample away.
    samples_per_chunk = int(duration / sample_rate + 1e-9)
    overlap_samples = int(samples_per_chunk * overlap)
    step_size = samples_per_chunk - overlap_samples

    # Fail early with an explicit message instead of letting a zero/negative
    # step reach np.arange / pd.concat and fail there with an unrelated error.
    if samples_per_chunk < 1 or step_size < 1:
        raise ValueError(
            f'invalid chunking: {samples_per_chunk} sample(s) per chunk with '
            f'a step of {step_size} sample(s); need duration >= sample_rate '
            f'and an overlap leaving a step of at least one sample'
        )

    # Start row of every chunk that fits entirely inside df
    start_indices = np.arange(0, len(df) - samples_per_chunk + 1, step_size)
    if len(start_indices) == 0:
        raise ValueError(
            f'time series too short: {len(df)} sample(s) available, '
            f'{samples_per_chunk} needed for one chunk'
        )

    def slice_chunk(start_idx, chunk_id):
        # Copy so that tagging the chunk does not write into df
        chunk = df.iloc[start_idx:start_idx + samples_per_chunk].copy()
        chunk['id'] = chunk_id
        return chunk

    sub_series = [
        slice_chunk(idx, i + 1)
        for i, idx in tqdm(enumerate(start_indices), total=len(start_indices))
    ]

    # Combine all sub-series into a single DataFrame
    return pd.concat(sub_series)

## Test tsfresh: cut one barometer trace into overlapping, id-tagged chunks
# Analysis window: from a fixed start time to the end of the first 'Baro'
# trace. The variable names suggest this is the constant-altitude phase of
# the flight -- TODO confirm.
start_constant_alt = UTCDateTime('2019-07-22T14:09:00') # Tortoise
end_constant_alt = st_all['Baro'][0].stats.endtime # Tortoise
# Work on a copy: the ObsPy calls below modify the trace in place.
tr = st_all['Baro'][0].copy()
# Resample to 10 Hz, high-pass at 0.25 Hz, then keep only the analysis window.
tr.resample(10.)
tr.filter('highpass', freq=0.25)
tr.trim(starttime=start_constant_alt, endtime=end_constant_alt)

# Two-column table: time relative to the trace start, and the filtered samples.
pd_Baro = pd.DataFrame(np.c_[tr.times(), tr.data], columns=['time', 'pressure'])
# Chunk length, in the same unit as tr.stats.delta (seconds in ObsPy).
duration = 50.
# Fraction of each chunk shared with the next one.
overlap = 0.25
# NOTE: this is the sampling interval (delta), not a rate in Hz.
sample_rate = tr.stats.delta
pd_Baro_total = split_time_series(pd_Baro, duration, overlap, sample_rate)

## Celine data

In [None]:
# Time window kept for the export below (drawn as dotted lines on the plot).
start_constant_alt = UTCDateTime('2019-08-09T18:52:00')
end_constant_alt = UTCDateTime('2019-08-10T02:43:31')
# Copy of the whole 'GPS' stream so the original traces are left untouched.
tr_upper = st_all['GPS'].copy()
# Convert the samples of every trace to float.
for tr in tr_upper:
    tr.data = tr.data.astype(float)
# Relabel the first two traces; assumes the stream holds at least two traces
# and that index 0 is the lower balloon, index 1 the upper one -- TODO confirm.
tr_upper[0].stats.station = 'low'
tr_upper[1].stats.station = 'up'

# Altitude of each trace against time since that trace's own start.
plt.figure()
for tr in tr_upper:
    plt.plot(tr.times(), tr.data, label=tr.stats.station)
plt.legend(frameon=False)
# NOTE(review): `tr` here is the loop variable left over from the loop above,
# i.e. the LAST trace. The window markers and the axis label are therefore
# relative to that trace's start time only -- confirm all traces share it.
plt.axvline(start_constant_alt-tr.stats.starttime, linestyle=':', color='black')
plt.axvline(end_constant_alt-tr.stats.starttime, linestyle=':', color='black')
plt.xlabel(f'Time since {tr.stats.starttime}')
plt.ylabel('Altitude (m)')

# Cut every trace of the stream to the selected window (in place).
tr_upper.trim(starttime=start_constant_alt, endtime=end_constant_alt)


# Export the trimmed stream as a MiniSEED file.
tr_upper.write("../2023_Celine_internship/msc_celine_specfem/utils/utils_NORSAR/test_data_Venus/2019_noise_balloon_GPS.mseed", format="MSEED")