In [None]:
### ----- Imports ----- ###

import glob
import os
import os.path
import shutil
import sys

# GPU allocator settings. These are plain environment variables, so they only
# take effect if TensorFlow has not yet initialised its GPU runtime; set them
# before TensorFlow (or any package that may import it, such as mlaapde) loads.
os.environ['TF_GPU_ALLOCATOR'] = 'cuda_malloc_async'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

sys.path.append('/home/sdybing/neic-mlaapde')

import h5py
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from mlaapde.access import MLAAPDE_Access
from mlaapde import UTC

# Alternative: the 3-month subset of MLAAPDE
# mlpa = MLAAPDE_Access(data_dir = '/data/hank/mlaapde_subset/data', random_seed = 616) # 3 months
# dataset = 'subset'

mlpa = MLAAPDE_Access(data_dir = '/data/hank/mlaapde_v1b/data', random_seed = 616)
dataset = 'v1b'

mlpa.data_dir

In [None]:
from tensorflow.keras import layers
from tensorflow.keras import backend as K

# Let TensorFlow grow GPU memory on demand rather than reserving it all up front.
# set_memory_growth raises RuntimeError once the GPUs have been initialised; the
# error is shown and then re-raised so the misconfiguration is not hidden.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for device in gpus:
            tf.config.experimental.set_memory_growth(device, True)
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(f'{len(gpus)} Physical GPU, {len(logical_gpus)} Logical GPUs')
    except RuntimeError as err:
        print(err)
        raise

In [None]:
#mlpa.default_args

In [None]:
### ----- Parameters ----- ###

# Where to save the trained models and figures
models_figs_path = '/home/sdybing/neic-mlaapde/allwaveforms/decimated/'

# MLAAPDE / data-generation parameters
sr = 40                                    # Sampling rate (samples per second)
trim_sec = 60                              # Seconds fetched on each side of the phase pick
trim_pre_sec = trim_sec
trim_post_sec = trim_sec
window_len = trim_pre_sec + trim_post_sec  # Full stored waveform length (s)
n_channels = 3                             # Instrument channels

# Window lengths (s): 7-20 in 1 s steps, 25-40 in 5 s steps, 50-120 in 10 s steps
cut_lens = list(range(7, 21)) + list(range(25, 41, 5)) + list(range(50, 121, 10))
cut_lens_finish = list(range(70, 121, 10))
test_cut_lens = [7, 8]

desired_shift = 3
# Doubled because the shifting scheme ends up moving the pick by half of this value
max_shift = 2 * desired_shift

min_snr_db = False
max_snr_db = False
log_progress_fraction = 100
valid_phases = ['P', 'Pn', 'Pg']
cast_dtype = np.float32

# Training / model parameters
epochs_number = 200
batch_size = 256                           # Kept modest to limit GPU memory use
monte_carlo_sampling = 50
drop_rate = 0.5
filters = [32, 64, 96, 128, 256]

# Only used when loading an already-trained model
training_samps = 100000
training_dataset = 'v1b'
shift_status = 'shifted'
model_folder_path = '/home/sdybing/neic-mlaapde/allwaveforms/float32/'

# Accumulators for the end-of-run error plots
mean_errors = []
std_errors = []

In [None]:
### ----- Where are the HDF5 files getting saved? ----- ###

# Location of the HDF5 data files (created if it does not exist yet)
hdf5_save_dir = '/data/sdybing/allwaveforms/decimated/'
os.makedirs(hdf5_save_dir, exist_ok=True)

# Extra labels to return alongside the waveforms, plus the keyword arguments
# shared by the MLAAPDE sampling / waveform calls
return_labels = ['source_magnitude', 'source_magnitude_type', 'snr_db', 'phase_id']
kwargs = dict(
    valid_phases=valid_phases,
    labels=return_labels,
    trim_pre_sec=trim_pre_sec,
    trim_post_sec=trim_post_sec,
    min_snr_db=min_snr_db,
    max_snr_db=max_snr_db,
    log_progress_fraction=log_progress_fraction,
    cast_dtype=cast_dtype,
)
#kwargs = {'valid_phases':valid_phases, 'labels':return_labels, 'trim_pre_sec':trim_pre_sec, 'trim_post_sec':trim_post_sec, 'min_snr_db':min_snr_db, 'max_snr_db':max_snr_db, 'cast_dtype':cast_dtype}

In [None]:
# Set True to load the smaller pre-saved split from /data/sdybing/checksplitcode/
# instead of the full dataset (this also redirects hdf5_save_dir there).
try_smaller_dataset = False

if try_smaller_dataset:

    hdf5_save_dir = '/data/sdybing/checksplitcode/'
    n_train_samp = 40000
    n_test_samp = 10000
    nsamp = n_train_samp + n_test_samp

    ### ----- Load the data from HDF5 files ----- ###
    # `with` guarantees each file is closed even if a dataset key is missing.

    with h5py.File(os.path.join(hdf5_save_dir, 'training_data.hdf5'), 'r') as training_data:
        train_waves_t = training_data['waveforms'][:]
        train_mags = training_data['magnitudes'][:]

    with h5py.File(os.path.join(hdf5_save_dir, 'validation_data.hdf5'), 'r') as validation_data:
        valid_waves_t = validation_data['waveforms'][:]
        valid_mags = validation_data['magnitudes'][:]

    for arr in (train_waves_t, train_mags, valid_waves_t, valid_mags):
        print(arr.shape)

In [None]:
# If this notebook has been run before and the HDF5 files already exist, set this
# to True to save time: the two extract-and-save cells below only run when False.
run_before = True

In [None]:
### ----- Load and save training data from MLAAPDE ----- ###

if not run_before:

    ### ----- Access the data from MLAAPDE ----- ###

    pt0 = UTC('2013-8-1')  # Start of training window
    pt1 = UTC('2018-10-1') # End of training window/start of validation window

    print('Sampling catalog')
    # Everything sampled here is taken from the 'training' partition (presumably
    # what split = [1,0] requests); the validation window is handled in the next cell.
    samps1, cat1 = mlpa.sample_catalog(time1 = pt0, time2 = pt1, nsamp = False, split = [1,0], **kwargs)
    print('Sample_catalog done, splitting samples')

    samples = {'training': samps1['training']}
    print('Training samples split, getting waves')

    waves = {'training': mlpa.get_waves(samples['training'], **kwargs)}
    print('Get waves done, getting labels')

    labels = {'training': mlpa.get_labels(samples['training'], cat1, labels = return_labels)}
    print('Getting labels done, formatting arrays')

    train_waves = waves['training']
    print('Train waves done')
    train_labels = labels['training']
    print('Train labels done')
    train_waves_t = train_waves.transpose(0,2,1)
    print('Train waves transposed')
    train_mags = train_labels['source_magnitude']
    print('Train mags done')
    train_mags_type = train_labels['source_magnitude_type']
    print('Train mag_types done')
    train_snr_db = train_labels['snr_db']
    print('Train SNRs done')
    train_phase_id = train_labels['phase_id']
    print('Train phase_ids done, next saving the HDF5')

    ### ----- Save the data to HDF5 file ----- ###
    # NOTE(review): the "Load the full dataset" cell reads the keys 'waves' and
    # 'magnitude' and transposes after loading, whereas this cell writes the
    # already-transposed array under 'waveforms' and magnitudes under
    # 'magnitudes'. Files written here will not load there as-is -- confirm
    # which layout is intended.

    # (HDF5 key, array, label used in the progress messages)
    train_outputs = (
        ('waveforms', train_waves_t, 'waves'),
        ('magnitudes', train_mags, 'mags'),
        ('magnitude_types', train_mags_type, 'mag_types'),
        ('snr', train_snr_db, 'SNRs'),
        ('phase_id', train_phase_id, 'phase_ids'),
    )
    with h5py.File(os.path.join(hdf5_save_dir, 'training_data.hdf5'), 'w') as f1:
        for key, values, label in train_outputs:
            print('Saving ' + label)
            f1.create_dataset(key, data = values, chunks = True)
            print(label[0].upper() + label[1:] + ' saved')

In [None]:
### ----- Load and save validation data from MLAAPDE ----- ###

if not run_before:

    ### ----- Access the data from MLAAPDE ----- ###

    pt1 = UTC('2018-10-1') # End of training window/start of validation window
    pt2 = UTC('2020-1-1')  # End of validation window

    # Everything sampled in the later window lands in the 'training' partition
    # (presumably what split = [1,0] requests); it is used here as validation data.
    samps2, cat2 = mlpa.sample_catalog(time1 = pt1, time2 = pt2, nsamp = False, split = [1,0], **kwargs)

    samples = {'validation': samps2['training']}
    waves = {'validation': mlpa.get_waves(samples['validation'], **kwargs)}
    labels = {'validation': mlpa.get_labels(samples['validation'], cat2, labels = return_labels)}

    valid_waves = waves['validation']
    valid_labels = labels['validation']
    valid_waves_t = valid_waves.transpose(0,2,1)
    valid_mags = valid_labels['source_magnitude']
    valid_mags_type = valid_labels['source_magnitude_type']
    valid_snr_db = valid_labels['snr_db']
    valid_phase_id = valid_labels['phase_id']

    ### ----- Save the data to HDF5 file ----- ###
    # NOTE(review): the "Load the full dataset" cell reads the keys 'waves' and
    # 'magnitude' and transposes after loading, whereas this cell writes the
    # already-transposed array under 'waveforms' and magnitudes under
    # 'magnitudes' -- confirm which layout is intended.
    # `with` closes the file even if writing one of the datasets fails.

    valid_outputs = (
        ('waveforms', valid_waves_t),
        ('magnitudes', valid_mags),
        ('magnitude_types', valid_mags_type),
        ('snr', valid_snr_db),
        ('phase_id', valid_phase_id),
    )
    with h5py.File(os.path.join(hdf5_save_dir, 'validation_data.hdf5'), 'w') as f2:
        for key, values in valid_outputs:
            f2.create_dataset(key, data = values)

In [None]:

### ----- Load the full dataset from HDF5 files ----- ###

# NOTE(review): these files are read with the keys 'waves' / 'magnitude' and the
# waveforms are transposed in a later cell, while the extract-and-save cells above
# write 'waveforms' / 'magnitudes' already transposed. Files created with
# run_before = False will therefore not load here as-is.
# `with` closes each file even when a key lookup fails.

with h5py.File(os.path.join(hdf5_save_dir, 'training_data.hdf5'), 'r') as training_data:
    dataset_names = list(training_data.keys())
    print(dataset_names)

    train_waves = training_data['waves'][:]
    train_mags = training_data['magnitude'][:]
    train_phase_id = training_data['phase_id'][:]

with h5py.File(os.path.join(hdf5_save_dir, 'validation_data.hdf5'), 'r') as validation_data:
    valid_waves = validation_data['waves'][:]
    valid_mags = validation_data['magnitude'][:]

In [None]:
# Shapes of the arrays as loaded from disk
for loaded in (train_waves, train_mags, valid_waves, valid_mags):
    print(loaded.shape)

In [None]:
# Swap the last two axes; the cut/shift code below slices these arrays as
# [sample, time, channel].
train_waves_t, valid_waves_t = (arr.transpose(0, 2, 1) for arr in (train_waves, valid_waves))

for shaped in (train_waves_t, train_mags, valid_waves_t, valid_mags):
    print(shaped.shape)

In [None]:
# Drop the untransposed names; the transposed views still reference the data.
del train_waves, valid_waves

In [None]:
# isfinite = np.isfinite(train_waves_t)
# #print(isfinite)
# i = np.where(isfinite == False)[0]
# print(i)

In [None]:
# k = np.unique(i)
# print(k)

In [None]:
# Indices of the two training samples overwritten in the next cell
# ("Fixing the weird nan wave").
idx, idx2 = 741100, 741102

In [None]:
# Fixing the weird nan wave: replace the training samples at idx / idx2 with
# copies of samples 0 / 1. Waveform, magnitude and phase_id are copied together
# so all training arrays stay aligned.

for bad_idx, good_idx in ((idx, 0), (idx2, 1)):
    print(train_phase_id[bad_idx])
    print(train_waves_t[bad_idx])

    # The indices are hard-coded for one version of the dataset; warn if the
    # row no longer looks bad so good data is not overwritten unnoticed.
    if np.isfinite(train_waves_t[bad_idx]).all() and np.isfinite(train_mags[bad_idx]):
        print(f'Warning: sample {bad_idx} is already finite; overwriting it anyway.')

    train_waves_t[bad_idx] = train_waves_t[good_idx]
    train_mags[bad_idx] = train_mags[good_idx]
    train_phase_id[bad_idx] = train_phase_id[good_idx]

In [None]:
# Check to make sure it's good now: each patched sample should match its source.
for patched, source in ((idx, 0), (idx2, 1)):
    print(train_waves_t[patched])
    print(train_mags[patched])
    print(train_mags[source])

In [None]:
# Dataset sizes; nsamp is used in the output directory and model names.
n_train_samp, n_valid_samp = len(train_mags), len(valid_mags)
nsamp = n_valid_samp + n_train_samp
print(nsamp)

In [None]:
##### -------- Functions -------- #####

### ----- Predictions ----- ###

class KerasDropoutPrediction(object):
    """Monte Carlo dropout predictor for a two-output (value, log-variance) model.

    ``model.predict`` is called ``n_iter`` times. If dropout stays active at
    inference (e.g. ``Dropout(...)(x, training=True)``), each pass samples a
    different dropout mask, and the spread of the passes estimates epistemic
    uncertainty.

    The model output must have at least two columns: column 0 is the predicted
    value and column 1 is ``s``, the natural-log variance used by ``customLoss``
    (``0.5 * exp(-s) * err**2 + 0.5 * s``).
    NOTE(review): the model built in the training loop ends in ``Dense(1)``, so it
    cannot be used with this class as-is.
    """

    def __init__(self, model):
        self.model = model

    def predict(self, x, n_iter = 10, batch_size = None):
        """Run ``n_iter`` stochastic forward passes over ``x``.

        Args:
            x: model input.
            n_iter: number of forward passes (must be >= 1).
            batch_size: batch size passed to ``model.predict``; defaults to the
                notebook-level ``batch_size`` if one is defined.

        Returns:
            (yhat_mean, sigma_squared_mean, ep_unc, combined), each with one value
            per input sample:
              yhat_mean          -- mean prediction over the passes
              sigma_squared_mean -- mean aleatoric variance, exp(s)
              ep_unc             -- std of the predictions (epistemic)
              combined           -- prediction variance + aleatoric variance

        Raises:
            ValueError: if ``n_iter`` < 1.
        """
        if n_iter < 1:
            raise ValueError('n_iter must be at least 1')
        if batch_size is None:
            # Backward compatible with the original, which read the global.
            batch_size = globals().get('batch_size')

        print('Making predictions...')
        preds = []
        log_vars = []
        for _ in range(n_iter):
            r = self.model.predict(x, batch_size = batch_size, verbose = 0)
            preds.append(r[:, 0])
            log_vars.append(r[:, 1])

        predM = np.stack(preds)     # (n_iter, n_samples)
        auM = np.stack(log_vars)    # (n_iter, n_samples)

        yhat_mean = predM.mean(axis = 0)
        yhat_squared_mean = np.square(predM).mean(axis = 0)

        # customLoss uses exp(-s), so s is a natural-log variance.
        sigma_squared_mean = np.exp(auM).mean(axis = 0)

        ep_unc = predM.std(axis = 0)

        combined = yhat_squared_mean - np.square(yhat_mean) + sigma_squared_mean

        return yhat_mean, sigma_squared_mean, ep_unc, combined

### ----- Training setup ----- ###

def customLoss(yTrue, yPred):
    """Heteroscedastic regression loss for a two-output (value, log-variance) model.

    yPred[:, 0] is the predicted value y_hat and yPred[:, 1] is s, used as the
    natural-log variance, so the per-sample loss is
        0.5 * exp(-s) * (y - y_hat)**2 + 0.5 * s
    i.e. the Gaussian negative log-likelihood up to a constant.

    Returns one loss value per sample (shape (batch,)).
    """
    y_hat = K.reshape(yPred[:, 0], [-1, 1])
    s = K.reshape(yPred[:, 1], [-1, 1])
    # Bring the targets to the same (batch, 1) shape and dtype as y_hat; a 1-D
    # yTrue would otherwise broadcast against y_hat into a (batch, batch) matrix.
    y_true = K.reshape(tf.cast(yTrue, y_hat.dtype), [-1, 1])
    return tf.reduce_sum(0.5 * K.exp(-1 * s) * K.square(y_true - y_hat) + 0.5 * s, axis=1)

### ----- Training callbacks ----- ###

class PrintSomeValues(tf.keras.callbacks.Callback):
    """Keras callback printing a quick sanity check at the start of every epoch.

    Prints the first validation magnitude and the model's current prediction for
    the first validation waveform. Reads the notebook globals ``valid_mags`` and
    ``shift_valid_waves_t`` (the cut-and-shifted validation waveforms built
    inside the training loop), so those must exist before training starts.
    """

    def on_epoch_begin(self, epoch, logs = None):
        # logs defaults to None rather than a shared mutable {} default.
        print()
        print(f'y_test[0:1] = {valid_mags[0:1]}.')
        print(f'pred = {self.model.predict(shift_valid_waves_t[0:1])}.')

### ----- Data generator ----- ###

# Module-level switches read inside dataGenerator:
#   debug_generator -- print sample indices and array shapes while building a batch
#   debug_plot      -- draw a per-stage diagnostic figure (original, cut, shift,
#                      noise, flip, drop) for a randomly chosen batch position
debug_generator = False
debug_plot = False

class dataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that builds augmented training batches on the fly.

    For every sample in a batch it:
      1. takes the first three channels of the stored waveform,
      2. cuts a ``cut_len``-second window centred on the middle of the record,
      3. shifts it by a random 0..``max_shift`` seconds, keeping
         ``cut_len - max_shift`` seconds,
      4. with probability ``noise_rate`` adds Gaussian noise to each channel,
      5. renormalises by the peak absolute amplitude (all-zero windows are left
         as zeros instead of becoming NaN),
      6. with probability ``flip_rate`` swaps channels 0 and 1 (E and N in the
         debug plots),
      7. zeros each channel with probability ``dropchan_rate``, never all three.

    Args:
        train_waves_t: waveforms indexed as [sample, time, channel].
        train_mags: one magnitude (regression target) per sample.
        n_train_samp: number of samples to draw from.
        window_len: length of the stored waveforms (s).
        cut_len: length of the centred cut window (s).
        max_shift: maximum random shift (s).
        sr: sampling rate (samples/s).
        batch_size: samples per batch.
        n_channels: channels per output sample (the code copies channels 0:3,
            so this is expected to be 3).
        shuff: reshuffle the sample order at the end of every epoch.
        noise_rate, flip_rate, dropchan_rate: augmentation probabilities.

    Reads the module-level flags ``debug_generator`` and ``debug_plot``.
    """

    def __init__(self, train_waves_t, train_mags, n_train_samp, window_len, cut_len, max_shift, sr, batch_size, n_channels, shuff = True, noise_rate = 0.5, flip_rate = 0.5, dropchan_rate = 0.05):
        # Harmless on older tf.keras; newer Keras versions expect the base
        # initialiser to run.
        super().__init__()
        self.train_waves_t = train_waves_t
        self.train_mags = train_mags
        self.n_train_samp = n_train_samp
        self.window_len = window_len
        self.cut_len = cut_len
        self.max_shift = max_shift
        self.shift_len = self.cut_len - self.max_shift   # seconds kept after shifting
        self.sr = sr
        self.shuff = shuff
        self.batch_size = batch_size
        self.full_lengthpts = int(self.window_len * self.sr)
        self.cut_lengthpts = int(self.cut_len * self.sr)
        self.shift_lengthpts = int(self.shift_len * self.sr)
        self.lentraindata = int(self.n_train_samp)
        self.middle = int(self.full_lengthpts / 2)
        self.n_channels = n_channels
        self.noise_rate = noise_rate
        self.flip_rate = flip_rate
        self.dropchan_rate = dropchan_rate
        self.on_epoch_end()

    def on_epoch_end(self):
        """Rebuild the sample order (shuffled when ``shuff``); Keras calls this after each epoch."""
        self.indexes = np.arange(self.lentraindata, dtype = int)
        if self.shuff:
            np.random.shuffle(self.indexes)

    def __len__(self):
        """Number of full batches per epoch; a trailing partial batch is dropped."""
        return self.lentraindata // self.batch_size

    @staticmethod
    def _plot_stage(axes, times, sample, xmax, title):
        """Debug helper: plot the 3 channels of one sample down a column of 3 axes."""
        for ch, (ax, color, label) in enumerate(zip(axes, ('C0', 'C1', 'C2'), ('E', 'N', 'Z'))):
            ax.plot(times, sample[:, ch], color = color, label = label)
            ax.legend(loc = 'upper left', fontsize = 14)
            ax.set_xlim(0, xmax)
            ax.set_ylim(-1.2, 1.2)
            ax.axvline(xmax/2, color = 'black', linestyle = '--', alpha = 0.7)
            # Only the bottom axis of each column shows x tick labels.
            ax.tick_params(axis = 'x', bottom = (ch == 2), labelbottom = (ch == 2))
        axes[0].set_title(title, fontsize = 16)
        axes[1].set_ylabel('Stream-normalized amplitude', fontsize = 14)
        axes[2].set_xlabel('Time (s)', fontsize = 14)

    def __data_generation(self, indexes):
        """Build one augmented batch (x, y) from the given sample indices."""

        # Initialization
        if debug_generator: print('Initial empty shapes!')
        y = np.ones((self.batch_size,))

        full_x = np.zeros((self.batch_size, self.full_lengthpts, self.n_channels))
        if debug_generator: print('Full X shape:' + str(full_x.shape))

        cut_x = np.zeros((self.batch_size, self.cut_lengthpts, self.n_channels))
        if debug_generator: print('Cut X shape:' + str(cut_x.shape))

        shift_x = np.zeros((self.batch_size, self.shift_lengthpts, self.n_channels))
        if debug_generator: print('Shift X shape:' + str(shift_x.shape))

        x = shift_x.copy()

        # Cut window bounds are the same for every sample
        cut_start = int(self.middle - (self.cut_len/2)*self.sr)
        cut_end = int(self.middle + (self.cut_len/2)*self.sr)

        # Make the augmentations
        for i, ix in enumerate(indexes):
            n = 0 # Counts dropped horizontal channels so all 3 are never dropped

            if debug_generator: print('Augmenting!')
            plot_this = False
            if debug_plot:
                wvf_idx = np.random.choice(np.arange(0,len(self.train_mags),1))
                if i == wvf_idx:
                    plot_this = True
                    print('Waveform index: ' + str(wvf_idx))
                    fig, axes = plt.subplots(nrows = 7, ncols = 3, gridspec_kw={'height_ratios': [1, 1, 1, 0.75, 1, 1, 1]}, figsize = (22,10), dpi=300, facecolor = 'white')

            # Original waveforms
            if debug_generator: print(ix)
            full_x[i,] = self.train_waves_t[ix,:,0:3]
            y[i,] = self.train_mags[ix,]
            if debug_generator: print('Original full lengthpts: ' + str(self.full_lengthpts))
            if debug_generator: print('Original full x shape: ' + str(full_x.shape))
            if plot_this:
                times = np.arange(0, self.window_len, 1/self.sr)
                self._plot_stage(axes[0:3, 0], times, full_x[i], self.window_len, 'Original waveforms')

            # Cut to the window length
            cut_x[i,] = full_x[i, cut_start : cut_end, 0:3]
            if debug_generator: print('Cut lengthpts: ' + str(self.cut_lengthpts))
            if debug_generator: print('Cut x shape: ' + str(cut_x.shape))
            if plot_this:
                cut_times = np.arange(0, self.cut_len, 1/self.sr)
                self._plot_stage(axes[0:3, 1], cut_times, cut_x[i], self.cut_len, 'Trimming to desired window length')

            # Random shift of 0..max_shift seconds
            self.time_offset = np.random.uniform(low = 0, high = self.max_shift) # seconds
            self.samps_offset = int(self.time_offset * self.sr)
            self.start = self.samps_offset
            self.end = int(self.start + self.shift_len * self.sr)
            shift_x[i,] = cut_x[i, self.start : self.end, 0:3]
            if debug_generator: print('Shift lengthpts: ' + str(self.shift_lengthpts))
            if debug_generator: print('Shift x shape: '  + str(shift_x.shape))
            if plot_this:
                print(self.time_offset)
                print(self.shift_len)
                shift_times = np.arange(0, self.shift_len, 1/self.sr)
                self._plot_stage(axes[0:3, 2], shift_times, shift_x[i], self.shift_len, 'Shifted ' + str(round(self.time_offset,1)) + ' seconds')

            x[i,] = shift_x[i,]
            if debug_generator: print('Renamed to x shape: ' + str(x.shape))

            # Add extra noise (independent noise level per channel)
            if np.random.random() < self.noise_rate:
                for ch in range(3):
                    x[i,:,ch] = x[i,:,ch] + np.random.normal(0, np.random.uniform(0.01, 0.15), self.shift_lengthpts)
            if plot_this:
                self._plot_stage(axes[4:7, 0], shift_times, x[i], self.shift_len, 'Extra noise')

            # Normalize again now that it's cut; skip all-zero windows, which
            # would otherwise become NaN (0/0) and poison the loss.
            peak = np.max(np.abs(x[i,]))
            if peak > 0:
                x[i,] = x[i,] / peak

            # Flip horizontal channels
            if np.random.random() < self.flip_rate:
                flip = x[i,:,0].copy()
                x[i,:,0] = x[i,:,1]
                x[i,:,1] = flip
            if plot_this:
                self._plot_stage(axes[4:7, 1], shift_times, x[i], self.shift_len, 'Flip horizontal components')

            # Drop channels; a random draw is made for every channel, but the
            # third is only dropped if at most one of the first two was.
            if np.random.random() < self.dropchan_rate:
                x[i,:,0] = 0
                n += 1
            if np.random.random() < self.dropchan_rate:
                x[i,:,1] = 0
                n += 1
            if np.random.random() < self.dropchan_rate and n < 2:
                x[i,:,2] = 0
            if plot_this:
                self._plot_stage(axes[4:7, 2], shift_times, x[i], self.shift_len, 'Drop channel')

                # Row 3 is only a spacer between the two groups of panels
                for spacer in axes[3, :]:
                    spacer.set_visible(False)

                plt.subplots_adjust(hspace = 0)
                plt.show()
                plt.close(fig)

            if debug_generator: print('Final x shape: ' + str(x.shape))

        return x, y

    def __getitem__(self, index):
        """Return batch number ``index`` as (x, y)."""
        iii = self.indexes[index * self.batch_size : (index + 1) * self.batch_size]
        x, y = self.__data_generation(iii)

        return x, y

In [None]:
########## STUFF THAT NEEDS LOOPING ##########

# Set True to print intermediate shapes and plot one random validation waveform
# at each cut/shift stage inside the training loop below.
debug = False

for cut_len in cut_lens_finish:
    print('Cut len: ' + str(cut_len))
    
    ### ----- Where are the trained models/figures getting saved? ----- ###

    save_dir = models_figs_path + str(dataset) + '_' + str(nsamp) + 'samps_' + str(cut_len-6) + 's_window'
    if os.path.isdir(save_dir):
        pass
    else: # deletes directory to start over: shutil.rmtree(save_dir)  
        os.makedirs(save_dir)

    ### ----- Cut and shift validation data to match the training data ----- ###

    ## Cut ##
    if debug:
        rand = np.random.choice(np.arange(0,len(valid_mags),1))
        print('Rand: ' + str(rand))
        valid_times = np.arange(0, window_len, 1/sr)
        plt.figure(facecolor = 'white')
        plt.suptitle('Original validation data')
        plt.subplot(311)
        plt.plot(valid_times, valid_waves_t[rand,:,0], color = 'C0')
        plt.subplot(312)
        plt.plot(valid_times, valid_waves_t[rand,:,1], color = 'C1')
        plt.subplot(313)
        plt.plot(valid_times, valid_waves_t[rand,:,2], color = 'C2')
        plt.subplots_adjust(hspace = 0)
        plt.show();
    
    middle = int(valid_waves_t.shape[1] / 2)
    if debug: print('Middle: ' + str(middle))
    valid_size = int(n_valid_samp)
    if debug: print('Valid size: ' + str(valid_size))
    cut_valid_waves_t = np.zeros((valid_size, int(cut_len*sr), 3)) 
    if debug: print('Cut waves t shape: ' + str(cut_valid_waves_t.shape))

    for i in range(len(valid_waves_t)):
        cut_valid_waves_t[i,] = valid_waves_t[i, int(middle - (cut_len/2)*sr) : int(middle + (cut_len/2)*sr), 0:3]
    if debug: print('Cut waves t shape: ' + str(cut_valid_waves_t.shape))
    if debug:
        valid_cut_times = np.arange(0, cut_len, 1/sr)
        print('Rand: ' + str(rand))
        plt.figure(facecolor = 'white')
        plt.suptitle('Cut validation data')
        plt.subplot(311)
        plt.plot(valid_cut_times, cut_valid_waves_t[rand,:,0], color = 'C0')
        plt.subplot(312)
        plt.plot(valid_cut_times, cut_valid_waves_t[rand,:,1], color = 'C1')
        plt.subplot(313)
        plt.plot(valid_cut_times, cut_valid_waves_t[rand,:,2], color = 'C2')
        plt.subplots_adjust(hspace = 0)
        plt.show();

    ## Shift ##
    shift_len = cut_len - max_shift
    if debug: print('Shift len: ' + str(shift_len))
    time_offset = np.random.uniform(low = 0, high = max_shift, size = valid_size)
    shift_valid_waves_t = np.zeros((valid_size, int(shift_len * sr), 3)) 

    for ii, offset in enumerate(time_offset):
        bin_offset = int(offset * sr)
        start_bin = bin_offset 
        end_bin = int(start_bin + shift_len * sr)
        shift_valid_waves_t[ii, :, 0] = cut_valid_waves_t[ii, start_bin:end_bin, 0] 
        shift_valid_waves_t[ii, :, 1] = cut_valid_waves_t[ii, start_bin:end_bin, 1]
        shift_valid_waves_t[ii, :, 2] = cut_valid_waves_t[ii, start_bin:end_bin, 2]

    if debug: print('Shift waves t shape: ' + str(shift_valid_waves_t.shape))
    if debug:
        valid_shift_times = np.arange(0, shift_len, 1/sr)
        print('Rand: ' + str(rand))
        plt.figure(facecolor = 'white')
        plt.suptitle('Shifted validation data')
        plt.subplot(311)
        plt.plot(valid_shift_times, shift_valid_waves_t[rand,:,0], color = 'C0')
        plt.subplot(312)
        plt.plot(valid_shift_times, shift_valid_waves_t[rand,:,1], color = 'C1')
        plt.subplot(313)
        plt.plot(valid_shift_times, shift_valid_waves_t[rand,:,2], color = 'C2')
        plt.subplots_adjust(hspace = 0)
        plt.show();

    ### ----- Initialize the model and training setup ----- ###
    
    inp1 = tf.keras.layers.Input(shape = ((cut_len - max_shift)*sr, n_channels), name = 'input_layer') 
    e = tf.keras.layers.Conv1D(filters[1], 3, padding = 'same')(inp1) 
    e = tf.keras.layers.Dropout(drop_rate)(e, training = True)
    e = tf.keras.layers.MaxPooling1D(4, padding = 'same')(e)
    e = tf.keras.layers.Conv1D(filters[0], 3, padding = 'same')(e) 
    e = tf.keras.layers.Dropout(drop_rate)(e, training = True)
    e = tf.keras.layers.MaxPooling1D(4, padding = 'same')(e)
    e = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(100, return_sequences = False, dropout = 0.0, recurrent_dropout = 0.0))(e)
    #e = tf.keras.layers.Dense(2)(e)
    e = tf.keras.layers.Dense(1)(e)
    o = tf.keras.layers.Activation('linear', name = 'output_layer')(e)
    model = tf.keras.models.Model(inputs = [inp1], outputs = o)
    #model.summary()

    #model.compile(optimizer = 'Adam', loss = customLoss)
    model.compile(optimizer = 'Adam', loss = tf.keras.losses.MeanSquaredError())
    
    model_name = str(dataset) + '_' + str(nsamp) + 'samps_' + str(shift_len) + 's'
    lr_reducer = tf.keras.callbacks.ReduceLROnPlateau(factor = np.sqrt(0.1), cooldown = 0, patience = 4, min_lr = 0.5e-6)
    m_name = str(model_name) + '_{epoch:03d}.h5' 
    filepath = os.path.join(save_dir, m_name)
    early_stopping_monitor = tf.keras.callbacks.EarlyStopping(monitor = 'val_loss', patience = 5)
    checkpoint = tf.keras.callbacks.ModelCheckpoint(filepath = filepath, monitor = 'val_loss', mode = 'auto', verbose = 1, save_best_only = True)
    psv = PrintSomeValues()
    callbacks = [lr_reducer, early_stopping_monitor, checkpoint, psv]
    training_generator = dataGenerator(train_waves_t, train_mags, n_train_samp, window_len, cut_len, max_shift, sr, batch_size, n_channels)

    ### ----- Train ----- ###

    history = model.fit(training_generator, epochs = epochs_number, validation_data = (shift_valid_waves_t, valid_mags), callbacks = callbacks);

    ### ----- Plot training curves ----- ###

    plt.figure(facecolor = 'white')
    plt.plot(history.history['loss'],label='Training Loss')
    plt.plot(history.history['val_loss'], label='Validation Loss')
    plt.legend()
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    #plt.show()
    plt.savefig(save_dir + '/loss_curves_' + m_name + '.png')
    plt.close();

    ### ----- Make the predictions ----- ###

    # Deterministic point predictions on the validation waveforms (the Monte
    # Carlo dropout variant below is disabled).
    #kdp = KerasDropoutPrediction(model)
    #predict, al_unc, ep_unc, comb = kdp.predict(shift_valid_waves_t, monte_carlo_sampling)
    predict = model.predict(shift_valid_waves_t)

    ### ----- Quick plot of the predictions vs. true magnitudes ----- ###

    fig4, ax = plt.subplots(facecolor = 'white')
    ax.scatter(valid_mags, predict, alpha = 0.4, facecolors = 'r', edgecolors = 'r')
    # 1:1 reference line across the measured-magnitude range.
    mag_lo, mag_hi = valid_mags.min(), valid_mags.max()
    ax.plot([mag_lo, mag_hi], [mag_lo, mag_hi], 'k--', alpha = 1, lw = 2)
    ax.set_xlabel('Measured magnitude')
    ax.set_ylabel('Predicted magnitude')
    # model_name, not m_name: m_name carries the unfilled '_{epoch:03d}.h5'
    # checkpoint template.
    fig4.savefig(os.path.join(save_dir, 'scatter_' + model_name + '.png'))
    plt.close(fig4)

    ### ----- Rename things ----- ###

    # Flatten both sides to 1-D so the subtraction below pairs samples
    # one-to-one; an (n,) minus (n, 1) would silently broadcast to (n, n).
    measured_mags = np.ravel(valid_mags)
    predicted_mags = predict.flatten()

    ### ----- Calculate the error and standard deviation ----- ###

    # fit() validated on (shift_valid_waves_t, valid_mags) and predict() ran on
    # the same waveforms, so the two arrays must line up exactly.
    if measured_mags.shape != predicted_mags.shape:
        raise ValueError('Got ' + str(predicted_mags.size) + ' predictions for '
                         + str(measured_mags.size) + ' measured magnitudes')

    # Signed residuals: positive means the model over-predicts.
    errors = predicted_mags - measured_mags

    mean_error = np.mean(errors)
    std_error = np.std(errors)

    print('Mean error: ' + str(round(mean_error,3)))
    print('Error standard deviation: ' + str(round(std_error,2)))

    # Accumulate per-window statistics; the summary cell plots them against
    # window length after the loop.
    mean_errors.append(mean_error)
    std_errors.append(std_error)

    ### ----- Make the box and whisker plots with STF magnitude line ----- ###

    # Magnitude implied by the window length: half the window is taken as the
    # rupture duration, scaled to a seismic moment in dyne-cm, then converted
    # with Mw = (2/3) * log10(M0) - 10.73.
    Tt = shift_len / 2
    M0_dyncm = (0.625 * 10**23) * Tt**3
    Mw = (2 / 3) * np.log10(M0_dyncm) - 10.73

    for report_line in (f'Rupture duration: {Tt} seconds',
                        f'M0: {M0_dyncm} dyne-cm',
                        f'Mw: {round(Mw, 2)}'):
        print(report_line)

    # One box per 0.1-magnitude bin, M1.1 through M8.4.
    bins = np.arange(11,85,1)/10
    # Match bins with a tight absolute tolerance instead of '==': exact float
    # equality drops every sample whose stored magnitude is not bit-identical to
    # the float64 bin value (e.g. magnitudes held as float32). The tolerance is
    # far below the 0.1 bin spacing, so no sample can match two bins.
    bin_tol = 1e-4
    data_bins = []

    for abin in bins:
        i = np.where(np.isclose(valid_mags, abin, rtol = 0, atol = bin_tol))[0]
        predict_bin = np.array(predicted_mags[i])
        data_bins.append(predict_bin)

    # Distribution of predicted magnitudes per measured-magnitude bin, with the
    # window-length (STF) magnitude Mw marked as a green dashed line.
    fig = plt.figure(figsize = (14, 9), dpi = 300, facecolor = 'white')

    fig.suptitle('MLAAPDE ' + str(dataset) + ' augmented dataset, tested with ' + str(int(n_valid_samp)) + ' ' + str(shift_len) + 's window samples shifted up to ' + str(desired_shift) + 's', fontsize = 18, y = 0.96, color = 'black')
    ax = fig.add_subplot(111)
    ax.set_facecolor('white')
    # model_name rather than m_name: m_name carries the unfilled Keras
    # checkpoint template ('_{epoch:03d}.h5').
    ax.text(x = 30, y = 8.8, s = 'Model: ' + model_name, fontsize = 13, color = 'black')
    ax.grid(which = 'major', axis = 'y')
    ax.grid(which = 'major', axis = 'x')
    ax.set_ylim(1, 8.6)

    # boxplot places box k (1-based) at x = k, holding bin bins[k - 1] = (k + 10) / 10,
    # so magnitude m sits at x = (m - 1) * 10.
    bp = ax.boxplot(data_bins, notch = False, patch_artist = True)
    ax.axvline((Mw - 1) * 10, color = 'green', linestyle = '--', linewidth = 2)

    for patch in bp['boxes']:
        patch.set_facecolor('lightblue')
        patch.set_edgecolor('blue')
    for median in bp['medians']:
        median.set(color = 'blue', linewidth = 3)
    for whisker in bp['whiskers']:
        whisker.set(color = 'blue', linewidth = 1)
    for cap in bp['caps']:
        cap.set(color = 'blue', linewidth = 1)
    for flier in bp['fliers']:
        flier.set(marker = '+', color = 'blue', alpha = 0.5)

    # Fix tick positions before labelling them. Setting labels and then swapping
    # in another locator (or labelling the default AutoLocator) lets labels drift
    # onto the wrong ticks depending on the matplotlib version.
    xtick_mags = np.arange(2, 9)
    ax.set_xticks((xtick_mags - 1) * 10)
    ax.set_xticklabels([f'{m:.1f}' for m in xtick_mags], fontsize = 14, color = 'black')
    ax.set_yticks(range(1, 9))
    ax.set_yticklabels([1, 2, 3, 4, 5, 6, 7, 8], fontsize = 14, color = 'black')
    ax.set_ylabel('Predicted magnitude', fontsize = 16, color = 'black')
    ax.set_xlabel('Measured magnitude', fontsize = 16, color = 'black')
    # 1:1 reference line from M1.1 (x = 1) to M8.0 (x = 70).
    ax.plot((1, 70), (1.1, 8), 'r--', linewidth = 3, alpha = 0.5)
    ax.text(s = 'Testing results', x = 2, y = 8, fontsize = 18, backgroundcolor = 'lightskyblue', color = 'black')
    ax.text(s = 'STF magnitude: ' + str(round(Mw,2)), x = 2, y = 7.5, fontsize = 18, backgroundcolor = 'lightgreen', color = 'black')

    fig.savefig(os.path.join(save_dir, 'boxplot_durline_' + model_name + '.png'), format = 'PNG', facecolor = 'white', transparent = False)
    plt.close(fig)

In [None]:
# Persist the per-window mean and standard deviation of the prediction errors;
# both files share one filename prefix.
stats_prefix = models_figs_path + str(dataset) + '_' + str(nsamp) + 'testsamp_' + str(training_samps) + 'trainsamp_'
for stat_suffix, stat_values in (('meanerrors', mean_errors), ('stderrors', std_errors)):
    np.savetxt(stats_prefix + stat_suffix + '.txt', np.array(stat_values))

In [None]:
### ----- Plot error and std for all windows ----- ###

# Effective window length per cut length. This presumably mirrors how
# shift_len is derived inside the training loop — confirm against that loop.
shift_lengths = [cut_len - max_shift for cut_len in cut_lens]

# Fail clearly if the training loop iterated over a different list
# (e.g. cut_lens_finish or test_cut_lens) than the one used here.
if len(shift_lengths) != len(mean_errors):
    raise ValueError(str(len(mean_errors)) + ' error values but ' + str(len(shift_lengths))
                     + ' window lengths; cut_lens must match the list used for training')

fig_summary, ax_summary = plt.subplots(figsize = (10, 6), facecolor = 'white')
# NOTE(review): '100,000' is hard-coded; confirm it matches training_samps.
ax_summary.set_title('Testing errors/stds: models trained with\n100,000 augmented samples shifted up to ' + str(desired_shift) + ' seconds', fontsize = 16)
ax_summary.errorbar(shift_lengths, mean_errors, yerr = std_errors, fmt = '.', markersize = 10, ecolor = 'C1', capsize = 3, label = 'Error bars show 1 standard\ndeviation above each point and\n1 standard deviation below')
ax_summary.scatter(shift_lengths, mean_errors, color = 'C0')
ax_summary.grid()
ax_summary.set_xlabel('Window length (s)', fontsize = 14)
ax_summary.set_ylabel('Mean error\n(predicted - measured magnitude)', fontsize = 14)
ax_summary.tick_params(axis = 'both', labelsize = 13)
ax_summary.legend(fontsize = 12)
# Zero-bias reference line.
ax_summary.axhline(0, color = 'black', linestyle = '--', alpha = 0.75)

fig_summary.savefig(models_figs_path + str(dataset) + '_' + str(nsamp) + 'testsamp_' + str(training_samps) + 'trainsamp_all_errors_stds.png', format = 'PNG', facecolor = 'white', transparent = False)
plt.close(fig_summary)