In [6]:
import os

from datetime import datetime

import obspy
import torch
import random
import torch.nn as nn
import pandas as pd
import numpy as np

from scipy import signal
from torch.utils.data import Dataset
from torchvision import transforms
from plotly import graph_objects as go

In [7]:
def make_deterministic(seed: int = 0):
    """
    Seed every random number generator used by the notebook.

    Seeds Python's ``random``, NumPy's legacy global generator and PyTorch
    (CPU and CUDA), then switches cuDNN to deterministic mode with the
    auto-tuner (benchmark) turned off.

    Passing ``seed == -1`` is the opt-out: nothing is seeded or changed.
    Deterministic cuDNN kernels might make the script slower.
    """
    if seed != -1:
        # Same seeding order as always: random, numpy, torch, torch.cuda.
        for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
            seed_fn(seed)

        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False

def create_spectrogram(st: obspy.core.stream.Stream, minfreq: float = None, maxfreq: float = None, shape=(129, 2555)) -> np.ndarray:
    """
    Build a normalised, fixed-size spectrogram from the first trace of a stream.

    :param st: Stream to process. It is copied first, so the caller's stream is
        not filtered in place. Only ``traces[0]`` is used.
    :param minfreq: Lower corner of the optional band-pass filter.
    :param maxfreq: Upper corner of the optional band-pass filter. The filter is
        applied only when both ``minfreq`` and ``maxfreq`` are given.
    :param shape: ``(rows, cols)`` of the returned array. The spectrogram is
        cropped where it is larger and zero-padded where it is smaller.
    :return: ``float64`` array of exactly ``shape`` with values min-max scaled
        to ``[0, 1]`` (the scaling is computed before cropping).
    """
    st_filt = st.copy()

    if minfreq is not None and maxfreq is not None:
        st_filt.filter('bandpass', freqmin=minfreq, freqmax=maxfreq)

    # st_filt is already a private copy, so the trace can be used directly.
    tr_filt = st_filt.traces[0]

    _, _, spectrogram = signal.spectrogram(tr_filt.data, tr_filt.stats.sampling_rate)

    # Min-max normalise. A constant spectrogram (e.g. an all-zero trace) has a
    # zero value range; dividing by it would fill the output with NaN, so map
    # that case to all zeros instead.
    lowest = spectrogram.min()
    value_range = spectrogram.max() - lowest
    if value_range == 0:
        spectrogram = np.zeros_like(spectrogram)
    else:
        spectrogram = (spectrogram - lowest) / value_range

    # Crop / zero-pad to the requested output shape.
    padded_spectrogram = np.zeros(shape, dtype=np.float64)
    min_rows = min(spectrogram.shape[0], shape[0])
    min_cols = min(spectrogram.shape[1], shape[1])
    padded_spectrogram[:min_rows, :min_cols] = spectrogram[:min_rows, :min_cols]

    return padded_spectrogram

def create_label(st: obspy.core.stream.Stream, row: pd.Series) -> float:
    """
    Express an arrival time as a fraction of the first trace's duration.

    :param st: Stream whose first trace supplies the start and end times.
    :param row: Catalog row; its ``'time_rel(sec)'`` entry is the numerator.
    :return: ``row['time_rel(sec)']`` divided by the trace length in seconds.
    """
    arrival = row['time_rel(sec)']

    # Trace duration taken from the datetime view of start/end time.
    stats = st.traces[0].stats
    duration = stats.endtime.datetime - stats.starttime.datetime

    return arrival / duration.total_seconds()

In [8]:
class ComposeWithLabels:
    """Chain several sample transforms, feeding each one's output to the next."""

    def __init__(self, transforms):
        # Ordered iterable of callables, each taking and returning a sample.
        self.transforms = transforms

    def __call__(self, sample):
        """Run ``sample`` through every transform in order and return the result."""
        result = sample
        for step in self.transforms:
            result = step(result)
        return result

class RandomApplyWithLabels:
    """Apply a wrapped sample transform with probability ``p``."""

    def __init__(self, transform, p=0.5):
        self.transform = transform
        self.p = p

    def __call__(self, sample):
        """Return ``transform(sample)`` when a uniform draw falls below ``p``, else ``sample``."""
        # One draw from torch's global RNG per call, uniform on [0, 1).
        draw = torch.rand(1).item()
        if not draw < self.p:
            return sample
        return self.transform(sample)

class RandomTimeShift:
    """
    Shifts the spectrogram in the time dimension by a random amount.
    Also adjusts the time label accordingly.

    The sample is a dict with a 3-D ``'spectrogram'`` tensor whose last axis is
    time, and a tensor ``'label'`` holding the event position as a fraction of
    that axis. The shift is drawn uniformly from ``[-shift_range, shift_range]``
    (both ends included); vacated columns are zero-filled.
    """
    def __init__(self, shift_range):
        self.shift_range = shift_range

    def __call__(self, sample):
        spectrogram = sample['spectrogram']
        label = sample['label']

        total_time_steps = spectrogram.shape[2]

        # Never shift by more than the axis length, otherwise the slicing below
        # would change the output width. The +1 makes the range symmetric and
        # keeps shift_range == 0 from raising inside randint.
        max_shift = min(abs(int(self.shift_range)), total_time_steps)
        shift = int(np.random.randint(-max_shift, max_shift + 1))
        magnitude = abs(shift)

        if magnitude:
            # new_zeros keeps the padding on the spectrogram's dtype and device.
            padding = spectrogram.new_zeros(spectrogram.shape[0], spectrogram.shape[1], magnitude)
            if shift > 0:
                # Shift to the right
                kept = spectrogram[:, :, :total_time_steps - magnitude]
                spectrogram = torch.cat((padding, kept), dim=2)
            else:
                # Shift to the left
                kept = spectrogram[:, :, magnitude:]
                spectrogram = torch.cat((kept, padding), dim=2)

        # Adjust the label with the SIGNED shift: a left shift moves the event
        # earlier, so the label must decrease.
        offset = shift / total_time_steps if shift else 0.0
        label = torch.clamp(label + offset, 0.0, 1.0)

        sample['spectrogram'] = spectrogram
        sample['label'] = label
        return sample

class RandomTimeMask:
    """
    Zeroes one random contiguous band of time steps in the spectrogram.

    The sample is a dict whose ``'spectrogram'`` entry is a 3-D tensor with time
    on the last axis. The tensor is masked in place. The band width is drawn
    from ``[0, max_mask_size)`` and capped at the length of the time axis.
    """
    def __init__(self, max_mask_size):
        self.max_mask_size = max_mask_size

    def __call__(self, sample):
        spectrogram = sample['spectrogram']
        _, _, t = spectrogram.shape

        # Nothing can be masked; return before randint raises on an empty range.
        if self.max_mask_size <= 0 or t == 0:
            return sample

        # Cap the width so a short spectrogram does not raise.
        mask_size = min(int(np.random.randint(0, self.max_mask_size)), t)
        # +1 because randint's upper bound is exclusive: the band may end on
        # the last column, and a full-width band gets the single valid start 0.
        t0 = int(np.random.randint(0, t - mask_size + 1))
        spectrogram[:, :, t0:t0 + mask_size] = 0

        sample['spectrogram'] = spectrogram
        return sample

class RandomFrequencyMask:
    """
    Zeroes one random contiguous band of frequency rows in the spectrogram.

    The sample is a dict whose ``'spectrogram'`` entry is a 3-D tensor with
    frequency on the middle axis. The tensor is masked in place. The band
    height is drawn from ``[0, max_mask_size)`` and capped at the number of
    frequency rows.
    """
    def __init__(self, max_mask_size):
        self.max_mask_size = max_mask_size

    def __call__(self, sample):
        spectrogram = sample['spectrogram']
        _, f, _ = spectrogram.shape

        # Nothing can be masked; return before randint raises on an empty range.
        if self.max_mask_size <= 0 or f == 0:
            return sample

        # Cap the height so a spectrogram with few rows does not raise.
        mask_size = min(int(np.random.randint(0, self.max_mask_size)), f)
        # +1 because randint's upper bound is exclusive: the band may end on
        # the last row, and a full-height band gets the single valid start 0.
        f0 = int(np.random.randint(0, f - mask_size + 1))
        spectrogram[:, f0:f0 + mask_size, :] = 0

        sample['spectrogram'] = spectrogram
        return sample

class AddNoise:
    """Adds zero-mean Gaussian noise, scaled by ``noise_level``, to the spectrogram."""

    def __init__(self, noise_level=0.005):
        self.noise_level = noise_level

    def __call__(self, sample):
        """Replace ``sample['spectrogram']`` with a noisy copy; the input tensor is left untouched."""
        clean = sample['spectrogram']
        # randn_like draws from torch's global RNG with the input's shape and dtype.
        sample['spectrogram'] = clean + torch.randn_like(clean) * self.noise_level
        return sample

class AmplitudeScaling:
    """Multiplies the whole spectrogram by one random factor drawn from ``scale_range``."""

    def __init__(self, scale_range=(0.8, 1.2)):
        self.scale_range = scale_range

    def __call__(self, sample):
        """Replace ``sample['spectrogram']`` with a uniformly rescaled copy."""
        original = sample['spectrogram']
        # A single factor per call, from NumPy's global RNG.
        factor = np.random.uniform(*self.scale_range)
        sample['spectrogram'] = original * factor
        return sample
    
from scipy.ndimage import gaussian_filter

class RandomSpikeAugmentation:
    def __init__(self, base_spike_value=1.0, spike_duration=1, max_num_spikes=3, fade_factor=0.8, noise_level=0.25, size=4, sigma=1):
        """
        Adds N random spikes to the spectrogram with discrete steps, noise, and frequency fade,
        and applies a Gaussian filter to smooth the result.

        :param base_spike_value: The base value of the spike.
        :param spike_duration: Duration of each spike in time steps.
        :param max_num_spikes: Maximum number of spikes to add (inclusive); at least one is always added.
        :param fade_factor: Factor by which the spike fades at higher frequencies.
        :param noise_level: The amount of random noise to add to the spike.
        :param size: Defines which portion of frequencies will be affected (the first 1/size rows).
        :param sigma: Standard deviation for Gaussian filter.
        """
        self.base_spike_value = base_spike_value
        self.spike_duration = spike_duration
        self.max_num_spikes = max_num_spikes
        self.fade_factor = fade_factor
        self.noise_level = noise_level
        self.size = size
        self.sigma = sigma

    def __call__(self, sample):
        """
        Add the spikes to ``sample['spectrogram']`` and smooth it.

        The input tensor is modified in place while the spikes are added; the
        smoothed result is stored back in the sample as a new tensor. The
        Gaussian filter runs over the whole array, not only the spike region.
        """
        spectrogram = sample['spectrogram']
        _, f, t = spectrogram.shape

        # randint's upper bound is exclusive, so +1 is needed for
        # max_num_spikes itself to be reachable.
        if self.max_num_spikes > 1:
            num_spikes = np.random.randint(1, self.max_num_spikes + 1)
        else:
            num_spikes = 1

        # A spike cannot be longer than the time axis.
        duration = min(self.spike_duration, t)
        # Only the first 1/size of the frequency rows receive a spike.
        num_rows = int(f * (1 / self.size))

        for _ in range(num_spikes):
            # Randomly select the start time for the spike. +1 so the spike may
            # end on the last column and a full-width spike has the valid start 0.
            spike_start = np.random.randint(0, t - duration + 1)

            # Create a spike that fades at higher frequencies and has some discrete steps
            for i in range(num_rows):
                # Compute the fade factor for the current frequency
                fade = self.fade_factor ** i

                # Create a spike with noise and discrete steps
                spike_value = self.base_spike_value * fade + (np.random.randn() * self.noise_level)
                spike = torch.clamp(torch.ones(duration) * spike_value, 0.0, 1.0)

                # Apply the spike to the spectrogram at the current frequency
                spectrogram[:, i, spike_start:spike_start + duration] += spike

        # Convert the spectrogram to numpy for applying the gaussian filter
        spectrogram_np = spectrogram.numpy()

        # Apply Gaussian filter to smooth the spikes
        spectrogram_np = gaussian_filter(spectrogram_np, sigma=self.sigma)

        # Convert back to torch tensor
        sample['spectrogram'] = torch.tensor(spectrogram_np)

        return sample

In [9]:
class SpectrogramDataset(Dataset):
    """
    Dataset serving pre-computed spectrograms (``.npz`` files) and their labels.

    Each row of the dataframe needs a ``filename`` column (only its basename is
    used to locate ``<spectrogram_dir>/<basename>.npz``) and a numeric ``label``
    column.
    """

    def __init__(self, dataframe: pd.DataFrame, sample: float = 1.0, transform=None, augmentations=False,
                 spectrogram_dir: str = './data/lunar/test/spectrograms'):
        """
        :param dataframe: Table with ``filename`` and ``label`` columns.
        :param sample: Fraction of rows to draw. ``1.0`` draws every row once
            (shuffled); any other value draws with replacement.
        :param transform: Callable applied to the loaded numpy array. Defaults
            to torchvision's ``ToTensor``.
        :param augmentations: When true, the random augmentation pipeline is
            applied to every item.
        :param spectrogram_dir: Directory holding the ``.npz`` spectrogram files.
        """
        # NOTE(review): sampling with replacement for sample < 1.0 can yield
        # duplicate rows; kept as-is, confirm it is intended.
        self.samples_df = dataframe.sample(frac=sample, replace=(sample != 1.0))
        self.augmentations = augmentations
        self.spectrogram_dir = spectrogram_dir

        if transform is None:
            self.transform = transforms.Compose([transforms.ToTensor()])
        else:
            self.transform = transform

        if self.augmentations:
            self.augmentation_transforms = ComposeWithLabels([
                RandomApplyWithLabels(RandomTimeShift(shift_range=20), p=1.0),
                RandomApplyWithLabels(RandomTimeMask(max_mask_size=50), p=1.0),
                RandomApplyWithLabels(RandomFrequencyMask(max_mask_size=2), p=0.5),
                RandomApplyWithLabels(AddNoise(noise_level=0.0075), p=1.0),
                RandomApplyWithLabels(AmplitudeScaling(scale_range=(0.8, 1.2)), p=1.0),
                RandomApplyWithLabels(RandomSpikeAugmentation(size=4, max_num_spikes=2), p=1.0),
                RandomApplyWithLabels(RandomSpikeAugmentation(size=2, max_num_spikes=1), p=1.0),
            ])
        else:
            self.augmentation_transforms = None

    def __len__(self):
        """Number of sampled rows."""
        return len(self.samples_df)

    def __getitem__(self, idx):
        """Return ``(spectrogram, label)`` for row ``idx``, both as float64 tensors."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        row = self.samples_df.iloc[idx]

        # Only the basename of the stored path identifies the spectrogram file.
        stem = os.path.basename(row['filename'])
        # np.load on an .npz returns an archive that keeps the file open;
        # the context manager closes it as soon as the array has been read.
        with np.load(os.path.join(self.spectrogram_dir, f'{stem}.npz')) as archive:
            spectrogram = archive['arr_0']
        label = torch.tensor(row['label'], dtype=torch.float64)

        if self.transform:
            spectrogram = self.transform(spectrogram)

        sample = {'spectrogram': spectrogram, 'label': label}

        if self.augmentation_transforms:
            sample = self.augmentation_transforms(sample)

        return sample['spectrogram'].double(), sample['label'].double()

In [10]:
# Convert every .mseed file under the test data folders into a saved
# spectrogram and record the files in a CSV index.
test_data_path = "./data/lunar/test/data"
spectrogram_dir = './data/lunar/test/spectrograms/'

# Create the output directory once, not once per file.
os.makedirs(spectrogram_dir, exist_ok=True)

filenames = []
labels = []
# sorted() makes the processing order reproducible; os.listdir's order is arbitrary.
for folder in sorted(os.listdir(test_data_path)):
    folder_path = os.path.join(test_data_path, folder)
    # Skip stray files sitting next to the data folders; listing them would raise.
    if not os.path.isdir(folder_path):
        continue

    for filename in sorted(os.listdir(folder_path)):
        # splitext strips only the trailing extension, unlike str.replace.
        stem, extension = os.path.splitext(filename)
        if extension != ".mseed":
            continue

        st = obspy.read(os.path.join(folder_path, filename))
        spectrogram = create_spectrogram(st, 0.001, 1.0)
        # np.savez appends ".npz" to the target path.
        np.savez(os.path.join(spectrogram_dir, stem), spectrogram)
        # Path is stored without extension; later cells append ".mseed" again.
        filenames.append(os.path.join(folder_path, stem))
        # Test data has no ground truth, so a placeholder label is stored.
        labels.append(0)

df = pd.DataFrame({'filename': filenames, 'label': labels})
df.to_csv('./data/lunar/test/spectrograms.csv', index=False)

In [11]:
class SeismicCNN(nn.Module):
    """
    Three-block CNN mapping a single-channel spectrogram to one value in (0, 1).

    Each block is a 3x3 convolution, ReLU and 2x2 max-pooling; two fully
    connected layers and a sigmoid follow. All parameters are float64.
    """

    def __init__(self, input_shape=(129, 2555)):
        """
        :param input_shape: ``(height, width)`` of the input spectrogram. The
            default gives the ``64 * 16 * 319`` flattened size this model was
            originally hard-coded for, so existing checkpoints still load.
        """
        super(SeismicCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=(3, 3), padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=(3, 3), padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=(3, 3), padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # Plain int attribute: it is not part of the state dict.
        self.flat_features = self._flattened_size(input_shape)
        self.fc1 = nn.Linear(self.flat_features, 128)
        self.fc2 = nn.Linear(128, 1)
        self.sigmoid = nn.Sigmoid()
        self.double()

    @staticmethod
    def _flattened_size(input_shape):
        """Number of features left after the three conv/pool blocks."""
        height, width = input_shape
        # The padded 3x3 convolutions keep the size; each 2x2 pool floor-halves it.
        for _ in range(3):
            height //= 2
            width //= 2
        return 64 * height * width

    def forward(self, x):
        """
        :param x: float64 tensor shaped ``(batch, 1, H, W)``, or ``(1, H, W)``
            for a single unbatched spectrogram.
        :return: tensor of shape ``(batch, 1)`` with sigmoid outputs.
        :raises ValueError: if the pooled feature map does not match the size
            the model was built for.
        """
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = self.pool(torch.relu(self.conv3(x)))

        # Fail loudly on a wrong input size: view(-1, ...) alone could silently
        # fold such input into a different batch size.
        per_sample = x.shape[-3] * x.shape[-2] * x.shape[-1]
        if per_sample != self.flat_features:
            raise ValueError(
                f"expected {self.flat_features} features per sample after pooling, got {per_sample}"
            )

        x = x.view(-1, self.flat_features)
        x = torch.relu(self.fc1(x))
        x = self.sigmoid(self.fc2(x))
        return x

    def save(self, path):
        """Write the model's state dict to ``path``."""
        torch.save(self.state_dict(), path)

    def load(self, path):
        """Load a state dict from ``path`` and switch the model to eval mode."""
        # map_location lets a checkpoint saved on a GPU load on a CPU-only
        # machine; load_state_dict copies the values into this model's own
        # parameters.
        # NOTE(security): weights_only=False unpickles arbitrary objects. Only
        # load checkpoints from trusted sources, or switch to weights_only=True
        # once the checkpoint is confirmed to be a plain tensor state dict.
        state = torch.load(path, map_location='cpu', weights_only=False)
        self.load_state_dict(state)
        self.eval()

In [12]:
# Run the trained CNN over every test spectrogram and write the detection catalog.
model = SeismicCNN()
model.load('./data/lunar/models/seismic_activity_cnn_best.pth')
model.eval()

test_dataset = SpectrogramDataset(df, augmentations=False)

fnames = []
detection_times = []
relative_times = []
save_images = False
save_folder = "./data/lunar/test"
plots_folder = os.path.join(save_folder, "plots")

if save_images:
    # write_image does not create missing directories.
    os.makedirs(plots_folder, exist_ok=True)

for index in range(len(test_dataset)):
    # The stored label is only a placeholder for test data, so it is ignored.
    spectrogram, _ = test_dataset[index]
    test_filename = test_dataset.samples_df.iloc[index].filename
    event_name = test_filename.split("/")[-1]
    tr = obspy.read(test_filename + ".mseed")[0]
    tr_data = tr.data
    tr_times = tr.times()
    starttime = tr.stats.starttime.datetime
    endtime = tr.stats.endtime.datetime
    total_seconds = (endtime - starttime).total_seconds()

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        prediction = model(spectrogram).item()
    # The model output is used as a fraction of the trace duration.
    relative_time = prediction * total_seconds

    # Sample index of the detection, clamped so it can never run past the trace.
    sampled_time = min(int(relative_time * tr.stats.sampling_rate), len(tr_times) - 1)

    on_time = starttime + pd.Timedelta(seconds=relative_time)
    on_time_str = datetime.strftime(on_time, '%Y-%m-%dT%H:%M:%S.%f')
    fnames.append(event_name + ".mseed")
    detection_times.append(on_time_str)
    relative_times.append(relative_time)

    if save_images:
        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=tr_times, y=tr_data, mode='lines', name='Seismogram'
        ))
        fig.add_vline(x=tr_times[sampled_time], line=dict(color='red'), annotation_text="Trig. On", annotation_position="top left")
        # Customize the layout
        fig.update_layout(
            title="Seismogram with STA/LTA Triggers",
            xaxis_title="Time (s)",
            yaxis_title="Amplitude",
            xaxis_range=[min(tr_times), max(tr_times)],
            height=400,
            width=900
        )
        # Use only the event name: test_filename is a full "./data/..." path
        # and would point into a directory tree that does not exist.
        fig.write_image(os.path.join(plots_folder, f'{event_name}.png'))

detect_df = pd.DataFrame(data = {
    'filename':fnames,
    'time_abs(%Y-%m-%dT%H:%M:%S.%f)':detection_times,
    'time_rel(sec)': relative_times,
})

detect_df.to_csv(f'{save_folder}/catalog.csv', index=False)