In [20]:
import torch as th
import pytest
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader, Dataset
import os
import glob

In [59]:
class SeismicDataset(Dataset):
    """Map-style dataset of 3-channel (Z, N, E) seismic windows stored as ``.npz`` files.

    Every ``.npz`` file found recursively below ``folder_path`` is one sample.
    Each file must contain the arrays ``earthquake_waveform_Z``,
    ``earthquake_waveform_N`` and ``earthquake_waveform_E``.
    """

    # Samples kept per channel.
    _WINDOW = 6000
    # Exclusive upper bound of the random crop offset for signal / noise files.
    # NOTE(review): a full-length window needs traces of at least
    # offset + 6000 samples (>= 11999 for signal, >= 17999 for noise) —
    # confirm against the data; shorter traces silently yield shorter windows.
    _SIGNAL_MAX_OFFSET = 6000
    _NOISE_MAX_OFFSET = 12000
    # The returned p_wave_start is this index minus the crop offset; presumably
    # the P-wave arrival sample in the stored traces — TODO confirm.
    _P_WAVE_INDEX = 6000
    _CHANNEL_KEYS = (
        'earthquake_waveform_Z',
        'earthquake_waveform_N',
        'earthquake_waveform_E',
    )

    def __init__(self, folder_path, is_noise = False, randomized = False):
        """
        Args:
            folder_path (str): Root folder; every ``.npz`` file below it
                (searched recursively) becomes one sample.
            is_noise (bool): If True, the random crop offset is drawn from
                ``[0, 12000)`` instead of ``[0, 6000)``. Only has an effect
                when ``randomized`` is True.
            randomized (bool): If True, each ``__getitem__`` call crops the
                window at a random offset; otherwise the window starts at
                sample 0.
        """
        self.folder_path = folder_path
        # Escape the folder so glob metacharacters in it ('[', '*', '?') are
        # matched literally, and sort because glob order is filesystem
        # dependent — this keeps the index -> file mapping reproducible.
        self.file_names = sorted(
            glob.glob(f'{glob.escape(folder_path)}/**/*.npz', recursive=True)
        )
        self.is_noise = is_noise
        self.randomized = randomized

    def __len__(self) -> int:
        """Number of ``.npz`` files found under ``folder_path``."""
        return len(self.file_names)

    def __getitem__(self, idx) -> tuple[th.Tensor, int, str]:
        """Load sample ``idx``.

        Returns:
            tuple: ``(waveform, p_wave_start, wave_name)``.
                ``waveform`` is a ``(3, L)`` tensor with rows Z, N and E, where
                ``L`` is 6000 unless the trace is shorter than offset + 6000.
                It is scaled so that its largest absolute value is 1; an
                all-zero window is returned as zeros.
                ``p_wave_start`` is ``6000 - offset`` for the crop offset used.
                ``wave_name`` is the file name without its extension.
        """
        wave_path = self.file_names[idx]
        wave_name = os.path.splitext(os.path.basename(wave_path))[0]

        max_offset = self._NOISE_MAX_OFFSET if self.is_noise else self._SIGNAL_MAX_OFFSET
        start = np.random.randint(low=0, high=max_offset) if self.randomized else 0
        stop = start + self._WINDOW

        # np.load on an .npz returns an NpzFile that keeps the file handle open
        # until closed; the context manager closes it (otherwise handles leak
        # and, on Windows, the files stay locked). Indexing the NpzFile reads
        # the full array into memory, so the slices stay valid after closing.
        # NOTE(review): allow_pickle=True lets a crafted .npz execute code when
        # loaded — only use with trusted data, and drop it if the arrays are
        # plain numeric.
        with np.load(wave_path, allow_pickle=True) as wave:
            event = np.stack([wave[key][start:stop] for key in self._CHANNEL_KEYS], axis=0)

        tensor = th.from_numpy(event)

        # Peak-normalise to [-1, 1]. An all-zero window would give 0/0 = NaN,
        # so divide by 1 instead, which keeps the zeros. Dividing by a
        # one-tensor, rather than skipping the division, preserves the result
        # dtype of true division.
        peak = tensor.abs().max()
        if peak == 0:
            peak = th.ones_like(peak)
        tensor_normalized = tensor / peak

        p_wave_start = self._P_WAVE_INDEX - start

        return tensor_normalized, p_wave_start, wave_name


In [60]:
import ipytest
ipytest.autoconfig()

# Location of the training signals. Override it with the SEISMIC_TRAIN_DIR
# environment variable instead of editing the path. Data-dependent tests are
# skipped when the folder does not exist (e.g. on another machine or in CI).
TRAIN_DIR = os.environ.get(
    "SEISMIC_TRAIN_DIR",
    "C:/Users/cleme/ETH/Master/DataLab/dsl-as24-challenge-3/data/signal/train",
)
requires_data = pytest.mark.skipif(
    not os.path.isdir(TRAIN_DIR), reason=f"training data not found at {TRAIN_DIR}"
)


@pytest.fixture
def dataset():
    """Non-randomized SeismicDataset over TRAIN_DIR."""
    return SeismicDataset(TRAIN_DIR)


def _first_batch(ds):
    """Return one shuffled batch of size 1 from ``ds``."""
    return next(iter(DataLoader(ds, batch_size=1, shuffle=True)))


def test_addition():
    # Sanity check that ipytest collects and runs tests defined in this notebook.
    assert 1 + 1 == 2


@requires_data
def test_dataset_length(dataset):
    # Expected number of .npz files in the training split.
    assert len(dataset) == 20230


@requires_data
def test_output_shape(dataset):
    next_tensor, _, _ = _first_batch(dataset)
    assert next_tensor.shape == (1, 3, 6000)


@requires_data
def test_is_noise_randomized():
    # Random cropping must be enabled explicitly; without it p_wave_start is
    # always 6000. The signal crop offset lies in [0, 6000), so p_wave_start
    # lies in [1, 6000], and the window must still be full length.
    randomized_dataset = SeismicDataset(TRAIN_DIR, randomized=True)
    next_tensor, p_wave_start, _ = _first_batch(randomized_dataset)
    assert next_tensor.shape == (1, 3, 6000)
    assert 1 <= p_wave_start.item() <= 6000


@requires_data
def test_is_normalized(dataset):
    next_tensor, _, _ = _first_batch(dataset)
    assert th.all(next_tensor.abs() <= 1)
    # Peak normalisation: the largest absolute value is exactly 1.
    assert next_tensor.abs().max().item() == pytest.approx(1.0)


@requires_data
def test_wave_name(dataset):
    _, _, wave_name = _first_batch(dataset)
    # Default collate turns a batch of str into a list of length batch_size,
    # so check the name itself rather than the list length.
    assert len(wave_name) == 1
    assert len(wave_name[0]) > 0


ipytest.run('-vv')

platform win32 -- Python 3.12.3, pytest-8.3.3, pluggy-1.5.0 -- c:\Users\cleme\miniconda3\envs\dsl\python.exe
cachedir: .pytest_cache
rootdir: c:\Users\cleme\ETH\Master\DataLab\dsl-as24-challenge-3
plugins: anyio-4.2.0, typeguard-4.3.0
[1mcollecting ... [0mcollected 6 items

t_dcd6bca15e86460a80240907e5f6295f.py::test_addition [32mPASSED[0m[32m                                  [ 16%][0m
t_dcd6bca15e86460a80240907e5f6295f.py::test_dataset_length [32mPASSED[0m[32m                            [ 33%][0m
t_dcd6bca15e86460a80240907e5f6295f.py::test_output_shape [32mPASSED[0m[32m                              [ 50%][0m
t_dcd6bca15e86460a80240907e5f6295f.py::test_is_noise_randomized [32mPASSED[0m[32m                       [ 66%][0m
t_dcd6bca15e86460a80240907e5f6295f.py::test_is_normalized [32mPASSED[0m[32m                             [ 83%][0m
t_dcd6bca15e86460a80240907e5f6295f.py::test_wave_name [32mPASSED[0m[32m                                 [100%][0m



<ExitCode.OK: 0>