In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import gzip
import pickle
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from scipy.special import softmax
from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, concatenate, Softmax, MaxPool2D
from tensorflow.keras.models import Model

2025-06-27 08:40:56.677648: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1751006456.685961   14864 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1751006456.688360   14864 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [23]:
class DataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` serving batches of 3-component traces with P/S/N targets.

    Every entry of ``data_dir`` is read as a gzip-compressed text file whose
    first line holds comma-separated integer labels (first value: P index,
    second value: S index) and whose remaining lines are CSV rows; the first
    three columns of those rows are used as the trace.

    Each sample is normalised and cut/padded by the module-level ``Pad``
    function (defined elsewhere in this notebook), and a target curve is built
    for the P and the S pick by stamping a mask around the pick index.
    """

    def __init__(self, data_dir, train=False, batch_size=32, max_files=None, labelMask = False, start_file=None, workers=1, use_multiprocessing=False, sort=True, max_queue_size=10, max_len = 8192, s_freq = 6000, mask_size = 57, **kwargs):
        """Store the configuration and build the list of files to serve.

        :param data_dir: Directory containing the gzip-compressed sample files.
        :param train: Forwarded to ``Pad`` as ``random``; also selects how
            empty targets are normalised in ``preprocess``.
        :param batch_size: Number of files per batch.
        :param max_files: If truthy, keep only the first ``max_files`` files.
        :param labelMask: If truthy, ``targetMask`` returns ``None``.
        :param start_file: If truthy (and ``sort`` is True), drop the first
            ``start_file`` files. It is applied after ``max_files``.
        :param workers: Stored on the instance (Keras loading option).
        :param use_multiprocessing: Stored on the instance (Keras loading option).
        :param sort: Sort the directory listing by name when True.
        :param max_queue_size: Stored on the instance (Keras loading option).
        :param max_len: Length of the target arrays and of the padded window.
        :param s_freq: Stored on the instance; not used inside this class.
        :param mask_size: Nominal width (in samples) of the triangular mask.
        :param kwargs: Forwarded to the parent constructor.
        """
        super().__init__(**kwargs)
        self.data_dir = data_dir
        self.train = train
        self.batch_size = batch_size
        self.max_files = max_files
        self.start_file = start_file
        self.sort = sort
        self.max_len = max_len
        self.s_freq = s_freq
        self.mask_size = mask_size
        self.labelMask = labelMask

        # Build the file list from a single directory listing.
        file_list = os.listdir(data_dir)
        if self.sort:
            file_list = sorted(file_list)
        if self.max_files:
            file_list = file_list[:self.max_files]
        # start_file is only honoured for the sorted listing.
        if self.sort and self.start_file:
            file_list = file_list[self.start_file:]
        self.file_list = file_list

        self.workers = workers
        self.use_multiprocessing = use_multiprocessing
        self.max_queue_size = max_queue_size
        self.on_epoch_end()

    def __len__(self):
        """Return the number of batches."""
        return int(np.ceil(len(self.file_list) / self.batch_size))

    def __getitem__(self, idx):
        """Load batch number ``idx``.

        Files that cannot be read or preprocessed are reported on stdout and
        skipped, so a batch may hold fewer samples than ``batch_size``.

        :return: ``(batch_X, batch_label)`` where ``batch_label`` is the
            ``tf.stack`` of the P, S and N targets along a new last axis.
        """
        batch_indices = self.indices[idx * self.batch_size: (idx + 1) * self.batch_size]
        batch_X, batch_P, batch_S, batch_N = [], [], [], []

        # A distinct loop name keeps the batch index ``idx`` intact.
        for file_index in batch_indices:
            file_name = self.file_list[file_index]
            try:
                Xnp, p_idx, s_idx = self.readNPZ(file_name)
                X, P, S, N, _ = self.preprocess(Xnp, p_idx, s_idx)
            except Exception as e:
                print('Error in processing get_item arrays')
                print(e)
                continue
            # Every sample gets a leading singleton axis.
            batch_X.append(np.expand_dims(X, axis=0))
            batch_P.append(np.expand_dims(P, axis=0))
            batch_S.append(np.expand_dims(S, axis=0))
            batch_N.append(np.expand_dims(N, axis=0))

        # Stacking errors propagate as-is instead of being masked by a later
        # failure on an unbound ``batch_label``.
        batch_X = np.array(batch_X)
        batch_label = tf.stack(
            [np.array(batch_P), np.array(batch_S), np.array(batch_N)], axis=-1
        )
        return (batch_X, batch_label)

    def preprocess(self, X, p_idx, s_idx, MaxAmpIdx=0):
        """
        Pad/cut the trace and build the P, S and N target arrays.

        :param X: Input trace (samples x components).
        :param p_idx: Index of the P pick; values <= 0 mean "no pick".
        :param s_idx: Index of the S pick; values <= 0 mean "no pick".
        :param MaxAmpIdx: Optional index of the maximum amplitude, forwarded
            to ``Pad`` to centre the window; 0 lets ``Pad`` find it.
        :return: ``(X, P, S, N, idx_shift)``. ``N`` is returned as all zeros;
            no noise target is filled in here.
        """
        # Ensure non-negative P and S
        p_idx, s_idx = max(0, p_idx), max(0, s_idx)

        # Normalise and pad/cut to this generator's window length.
        X, idx_shift, wind = Pad(
            X, p_idx, s_idx, self.train, maxAmpIdx=MaxAmpIdx, max_len=self.max_len
        )

        def shifted(pick, outside_window):
            # Picks outside the window and missing picks (0) both become 0;
            # all others move with the window.
            if outside_window:
                return 0
            return pick + idx_shift if pick else pick

        p_idx = shifted(p_idx, wind[0])
        s_idx = shifted(s_idx, wind[1])

        # Initialize P, S and N
        P, S, N = np.zeros(self.max_len), np.zeros(self.max_len), np.zeros(self.max_len)

        # Fetch the target mask and stamp it around each pick.
        mask = self.targetMask()
        P = self.fitMask(p_idx, mask, P)
        S = self.fitMask(s_idx, mask, S)

        # Ensure pdf output. Outside training an empty target stays all zero;
        # during training softmax of zeros makes it uniform.
        if not self.train:
            P /= P.sum() if P.sum() != 0 else 1
            S /= S.sum() if S.sum() != 0 else 1
        else:
            P = P / P.sum() if P.sum() != 0 else softmax(P)
            S = S / S.sum() if S.sum() != 0 else softmax(S)

        return X, P, S, N, idx_shift

    def fitMask(self, idx, mask, PS_array):
        """Stamp ``mask`` onto ``PS_array`` around ``idx`` (element-wise max).

        Nothing is written when ``idx`` is not strictly between 0 and
        ``self.max_len``. A mask overlapping either end of the array is
        truncated to the part that fits. ``PS_array`` is modified in place and
        also returned.
        """
        if idx <= 0 or idx >= self.max_len:
            return PS_array

        maskLen = len(mask)
        start = int(np.ceil(idx - maskLen / 2))
        end = int(np.ceil(idx + maskLen / 2))

        # Clip the stamp to the array and take the matching part of the mask.
        lo = max(start, 0)
        hi = min(end, len(PS_array))
        if lo < hi:
            PS_array[lo:hi] = np.maximum(PS_array[lo:hi], mask[lo - start:hi - start])
        return PS_array

    def popMultiMasks(self, Arv, idx_shift, masks, PS_array, X):
        """Stamp one mask per arrival in ``Arv`` onto ``PS_array``.

        Arrivals that are <= 0, fall at or beyond ``self.max_len`` after the
        shift, or land on a sample whose third component is exactly zero are
        skipped.
        """
        for a, arr in enumerate(Arv):
            if arr <= 0:
                continue
            arr += idx_shift
            # Use the configured window length, not a hard-coded 8192.
            if arr >= self.max_len or X[arr][2] == 0:
                continue
            PS_array = self.fitMask(arr, masks[a], PS_array)
        return PS_array

    def get_single(self,idx):
        """Read and preprocess the file at position ``idx`` of the file list.

        :return: ``(X, P, S, idx_shift)``.
        :raises Exception: Re-raises any read/preprocess error after
            printing which file failed.
        """
        file_name = self.file_list[idx]
        try:
            # readNPZ returns exactly three values.
            Xnp, p_idx, s_idx = self.readNPZ(file_name)
            X, P, S, N, idx_shift = self.preprocess(Xnp, p_idx, s_idx)
        except Exception as e:
            print(f"Error processing single file {file_name}: {e}")
            raise
        return X, P, S, idx_shift

    def total_len(self):
        """Return the total number of files."""
        return len(self.file_list)

    def on_epoch_end(self):
        """Reset the sample order to the (unshuffled) file-list order."""
        self.indices = np.arange(len(self.file_list))

    def targetMask(self):
        """Return the mask stamped around each pick.

        With ``labelMask`` unset this is a triangular window rising from 0 to
        its peak over ``mask_size // 2`` samples and falling back to 0 over
        ``mask_size // 2 + 1`` samples, normalised to sum to 1.
        """
        if self.labelMask:
            # NOTE(review): returns None, which fitMask cannot handle -
            # presumably subclasses override this method; confirm.
            return

        max_value = 1
        ascending_part = np.linspace(0, max_value, self.mask_size // 2, endpoint=False)
        descending_part = np.linspace(max_value, 0, self.mask_size // 2 + 1, endpoint=True)

        pdf = np.concatenate((ascending_part, descending_part))
        return pdf / np.sum(pdf)  # Ensures pdf adds up to 1

    # Get Labels returns the modes for the P and S waves of the target distribution
    def getLabels(self):
        """Return an ``(n, 2)`` integer array of (P, S) target argmax indices.

        Files that fail to load are reported on stdout and skipped, so ``n``
        can be smaller than ``total_len()``.
        """
        labels = []
        for file_name in self.file_list:
            try:
                Xnp, p_idx, s_idx = self.readNPZ(file_name)

                # If getLabels is used, make sure to not shuffle and that random is False
                _, P, S, _, _ = self.preprocess(Xnp, p_idx, s_idx)

                labels.append((P.argmax(), S.argmax()))
            except Exception as e:
                print('Error in getting labels')
                print(e)

        return np.array(labels, dtype=int).reshape(-1, 2)

    def readNPZ(self,fileName):
        """Read one gzip-compressed text sample from ``self.data_dir``.

        The first line is parsed as comma-separated integers (P index, S index,
        ...); the rest is parsed as header-less CSV.

        :return: ``(X, pIdx, sIdx)`` with ``X`` the first three CSV columns as
            a DataFrame.
        :raises RuntimeError: If the file cannot be read or parsed.
        """
        # Resolve against this generator's directory, not a notebook global.
        filePath = os.path.join(self.data_dir, fileName)
        try:
            with gzip.open(filePath, 'rt') as file:
                ## extract p and s labels
                firstLine = file.readline().strip()
                ## extract seismogram as columns of x,y,z
                df = pd.read_csv(file, header=None, engine='python')
            labels = np.array(firstLine.split(','), dtype=int)
            X = df.iloc[0:, :3]
            pIdx = labels[0]
            sIdx = labels[1]
        except Exception as e:
            raise RuntimeError(f"Error reading file {filePath}: {e}") from e
        return X, pIdx, sIdx

    def fileDict(self, idx):
        """Return the file name(s) at ``idx`` (any NumPy-style index)."""
        return np.array(self.file_list)[idx]

    def fileSet(self):
        """Return the file list as a NumPy array."""
        return np.array(self.file_list)

    def getFileIndex(self,search_string):
        """Return the index of the first file name containing ``search_string``, or -1."""
        index = next((i for i, s in enumerate(self.file_list) if search_string in s), -1)
        return index

In [25]:
def Pad(X, p_idx, s_idx, random=False, maxAmpIdx = 0, max_len=8192):
    """Normalise a trace and bring it to exactly ``max_len`` samples.

    Each column is divided by its maximum absolute value (all-zero columns are
    left untouched). Traces shorter than ``max_len`` are zero-padded at the
    end; longer traces are cut to one window of ``max_len`` samples:

    * the first ``max_len`` samples when the peak index is <= ``max_len``;
    * the last ``max_len`` samples when a peak-centred window would run past
      the end of the trace;
    * otherwise a window starting ``max_len // 2`` samples before the peak.

    The peak index is ``maxAmpIdx`` when that is > 0, else the row with the
    largest Euclidean norm after normalisation.

    :param X: 2D array-like (samples x components). It is copied, never
        modified in place.
    :param p_idx: P pick index in the original trace.
    :param s_idx: S pick index in the original trace.
    :param random: Accepted for interface compatibility; it currently has no
        effect on the result.
    :param maxAmpIdx: Optional peak index used to place the window.
    :param max_len: Number of samples in the returned trace.
    :return: ``(X, idx_shift, wind)``. Adding ``idx_shift`` to an original
        sample index gives its index in the returned trace.
        ``wind = [p_outside, s_outside]`` flags picks outside the selected
        window (always ``[False, False]`` for short traces).
    :raises ValueError: If ``X`` is not two-dimensional.
    """
    # Work on a float copy: in-place division would modify the caller's array
    # and fails outright for integer input.
    X = np.array(X, dtype=float)
    if X.ndim != 2:
        raise ValueError("Input array X must be a 2D array.")

    # Per-column normalisation; ``initial`` keeps an empty trace from raising.
    col_peak = np.max(np.abs(X), axis=0, initial=0.0)
    X = X / np.where(col_peak != 0, col_peak, 1.0)

    n_samples = X.shape[0]
    wind = [False, False]

    # Short trace: keep the first three components and zero-pad the tail.
    if n_samples < max_len:
        X = np.pad(X[:, :3], ((0, max_len - n_samples), (0, 0)), mode='constant')
        return X, 0, wind

    if maxAmpIdx > 0:
        max_index = int(maxAmpIdx)
    else:
        row_norms = np.sqrt(np.sum(X ** 2, axis=1))
        max_index = int(np.argmax(row_norms, axis=0))

    start_index = max_index - max_len // 2
    # Derive the end from the start so the window is max_len long even when
    # max_len is odd.
    end_index = start_index + max_len

    if max_index <= max_len:
        window_start = 0
    elif end_index > n_samples:
        window_start = n_samples - max_len
    else:
        window_start = start_index
    window_stop = window_start + max_len

    X = X[window_start:window_stop, :]
    # Taken from the window start, so it cannot be affected by X having just
    # been re-bound to the cut trace.
    idx_shift = -window_start
    # Test picks against the absolute window bounds; a negative slice start
    # would accept every non-negative index.
    wind = [
        not (window_start <= p_idx < window_stop),
        not (window_start <= s_idx < window_stop),
    ]
    return X, idx_shift, wind

In [None]:
def UNetModel(input_size=(1, 8192, 3), f=7, s=4):
    """Build the U-Net style picker network.

    All convolutions use a ``(1, f)`` kernel with ``"same"`` padding. The
    encoder downsamples four times with stride ``(1, s)``; the decoder mirrors
    it with transposed convolutions, concatenating the matching encoder output
    on the channel axis at each level. Three 1x1 convolution heads follow,
    each passed through a softmax over axis ``-2`` and then concatenated on
    the last axis.

    :param input_size: Shape of one input sample (without the batch axis).
    :param f: Kernel width along the second spatial axis.
    :param s: Stride along the second spatial axis for down/up-sampling.
    :return: An uncompiled ``Model``.
    """
    kernel = (1, f)
    stride = (1, s)

    def conv(tensor, n_filters):
        # Plain ReLU convolution that keeps the resolution.
        return Conv2D(n_filters, kernel, activation="relu", padding="same")(tensor)

    def down(tensor, n_filters):
        # Strided ReLU convolution that reduces the resolution.
        return Conv2D(n_filters, kernel, activation="relu", strides=stride, padding="same")(tensor)

    inputs = Input(shape=input_size)

    # Encoding block: full-resolution level first, then three strided levels.
    x = conv(conv(inputs, 8), 8)
    skips = [x]
    for n_down, n_out in ((8, 11), (11, 16), (16, 22)):
        x = conv(down(x, n_down), n_out)
        skips.append(x)

    # Middle (bottleneck)
    x = conv(down(x, 22), 32)

    # Decoding block: upsample, merge with the matching encoder output, convolve.
    for (n_up, n_out), skip in zip(((44, 22), (32, 16), (22, 11), (16, 8)), reversed(skips)):
        upsampled = Conv2DTranspose(n_up, kernel, strides=stride, padding="same")(x)
        merged = concatenate([upsampled, skip], axis=-1)
        x = conv(merged, n_out)

    # Outputs: three independent 1x1 heads sharing one softmax layer.
    heads = [Conv2D(1, (1, 1), padding="same")(x) for _ in range(3)]
    soft = Softmax(axis=-2)
    output = concatenate([soft(head) for head in heads], axis=-1)

    return Model(inputs=inputs, outputs=output)

In [None]:
def resultsHistogram(Preds,dataGen):
    """Plot histograms of P and S pick residuals (label minus prediction).

    The predicted pick for each sample is the argmax of channel 0 (P) or
    channel 1 (S) of ``Preds[i][0]``; the label comes from
    ``dataGen.getLabels()``. Dashed vertical lines and the text block show the
    75th percentile of the absolute residuals over all samples.

    :param Preds: Indexable predictions; ``Preds[i][0]`` must be a 2D array
        with the P curve in column 0 and the S curve in column 1.
    :param dataGen: Generator exposing ``getLabels()`` and ``total_len()``.
    :return: The matplotlib figure.
    :raises ValueError: If the generator holds no samples.
    """
    pick_labels = dataGen.getLabels()
    distPList = []
    distSList = []
    for i in range(dataGen.total_len()):
        sample_pred = Preds[i][0]
        distPList.append(pick_labels[i][0] - np.argmax(sample_pred[:, 0]))
        distSList.append(pick_labels[i][1] - np.argmax(sample_pred[:, 1]))

    if not distPList:
        raise ValueError("No residuals to plot: the data generator is empty.")

    # Percentiles are taken over every residual, not just the last sample.
    p75_P = np.percentile(np.abs(distPList), 75)
    p75_S = np.percentile(np.abs(distSList), 75)

    bound = 50
    fig, ax = plt.subplots()

    l1 = ax.axvline(x=p75_P, color='blue', linestyle='--', label=r'$P_{75}$: P')
    ax.axvline(x=-p75_P, color='blue', linestyle='--')
    l2 = ax.axvline(x=p75_S, color='red', linestyle='--', label='$P_{75}$: S')
    ax.axvline(x=-p75_S, color='red', linestyle='--')

    _, _, l3 = ax.hist(distPList, bins=40, range=(-bound, bound), density=True, color='blue', edgecolor='black', alpha=0.5,
                       label='P waves')
    _, _, l4 = ax.hist(distSList, bins=40, range=(-bound, bound), density=True, color='red', edgecolor='black', alpha=0.5,
                       label='S waves')

    # Custom legend order
    handles = [l1, l2, l3[0], l4[0]]
    legend_labels = [h.get_label() for h in handles]

    ax.set_xlim([-bound, bound])
    ax.set_xlabel('Residual (samples)')

    textblock = (
                 r"$P_{75}: P = $" + str(int(p75_P)) + "\n"
                 r"$P_{75}: S = $" + str(int(p75_S)) + "\n"
                 )
    ax.text(-bound + 2, 0.19, textblock, fontsize=10, color='black', ha='left', va='top')

    ax.set_ylabel('Probability density')
    ax.set_ylim([0, 0.2])
    ax.legend(handles, legend_labels)
    plt.show()
    return fig