In [2]:
import seisbench.data as sbd
import seisbench.generate as sbg
import seisbench.models as sbm
from seisbench.util import worker_seeding

import obspy
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from obspy.clients.fdsn import Client
from obspy import UTCDateTime

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Load the STEAD dataset and keep only traces whose category is "earthquake_local".
data = sbd.STEAD()
mask = data.metadata["trace_category"] == "earthquake_local"
# filter() is called for its effect on `data`; its return value is not used.
data.filter(mask)
train, dev, test = data.train(), data.dev(), data.test()

In [13]:
# Rebind `data` to the DummyDataset, with components ordered Z, N, E.
data = sbd.DummyDataset(component_order="ZNE")

train, dev, test = data.train_dev_test()

# Show a one-line summary of every split.
for heading, split in (("Train:", train), ("Dev:", dev), ("Test:", test)):
    print(heading, split)


Train: DummyDataset - 60 traces
Dev: DummyDataset - 10 traces
Test: DummyDataset - 30 traces


In [21]:
# Row index of each component, following the "ZNE" channel order.
CHANNEL_INDEX = {component: position for position, component in enumerate("ZNE")}
# Channel to plot.
CHANNEL = "Z"

In [38]:
import seisbench.data as sbd
import numpy as np

# ===========================================================
# 1. Load Dataset
# ===========================================================
data = sbd.DummyDataset(component_order="ZNE")

train, dev, test = data.train_dev_test()

# One summary line per split.
for heading, split in (("Train:", train), ("Dev:", dev), ("Test:", test)):
    print(heading, split)

# ===========================================================
# 2. Channel setup
# ===========================================================
# Component letter -> row index, matching component_order="ZNE" above.
CHANNEL_INDEX = dict(zip("ZNE", range(3)))
CHANNEL = "Z"

def preprocess_function(example):
    """
    Placeholder preprocessing hook: returns ``example`` unchanged.

    No channel extraction, dtype conversion or label handling happens here;
    the statements that would attach ``input_values`` / ``labels`` are
    commented out below.

    Args:
        example: whatever the caller passes in; it is not inspected.

    Returns:
        The very same ``example`` object.
    """
    # TODO: disabled draft of the intended behavior, kept for reference.

    # Store as float32 for ML compatibility
    #example["input_values"] = example
    #example["labels"] = example["label"]  # keep label as-is

    return example

# Report which component is currently selected and the row index it maps to.
print(
    "\n--- Current Channel Configuration ---",
    f"Selected Channel Index: {CHANNEL} ({CHANNEL_INDEX[CHANNEL]})",
    "-----------------------------------",
    sep="\n",
)

# ===========================================================
# 3. "Mapping" the preprocessing over SeisBench dataset
# ===========================================================
def map_seisbench(dataset, func, channel_index=0):
    """
    Apply a preprocessing function to one channel of every trace in a dataset.

    For each index ``i`` in ``range(len(dataset))`` the waveform returned by
    ``dataset.get_waveforms(i)`` is sliced to row ``channel_index`` and passed
    to ``func``.

    Args:
        dataset: object supporting ``len()`` and ``get_waveforms(i)``, where
            the returned array can be indexed as ``[channel_index, :]``.
        func: callable taking the selected single-channel waveform.
        channel_index: row of the waveform array handed to ``func``.
            Defaults to 0, the row this function always used before the
            parameter existed.

    Returns:
        A plain Python list holding ``func``'s result for every trace, in
        dataset order (not a SeisBench dataset object).
    """
    try:
        from tqdm import tqdm
    except ImportError:
        # The progress bar is cosmetic; iterate without it if tqdm is missing.
        def tqdm(iterable):
            return iterable

    return [
        func(dataset.get_waveforms(i)[channel_index, :])
        for i in tqdm(range(len(dataset)))
    ]

# Run the same preprocessing over every split, in train, dev, test order.
train_processed, dev_processed, test_processed = (
    map_seisbench(split, preprocess_function) for split in (train, dev, test)
)

print("\n✅ Finished preprocessing all splits.")
print(
    f"Train size: {len(train_processed)}, "
    f"Dev size: {len(dev_processed)}, "
    f"Test size: {len(test_processed)}"
)


Train: DummyDataset - 60 traces
Dev: DummyDataset - 10 traces
Test: DummyDataset - 30 traces

--- Current Channel Configuration ---
Selected Channel Index: Z (0)
-----------------------------------


100%|██████████| 60/60 [00:00<00:00, 1345.72it/s]
100%|██████████| 10/10 [00:00<00:00, 1387.56it/s]
100%|██████████| 30/30 [00:00<00:00, 1661.33it/s]


✅ Finished preprocessing all splits.
Train size: 60, Dev size: 10, Test size: 30





In [90]:
import torch
from torch.utils.data import Dataset
import torchaudio

In [102]:
# Component letter -> row index used when slicing a waveform array.
CHANNEL_INDEX = dict(Z=0, N=1, E=2)
# Component fed to the model.
CHANNEL = "Z"


# ===========================================================
# 2. Define preprocessing function
# ===========================================================
# One Resample transform per (orig_freq, new_freq) pair, built on first use so
# that it is not rebuilt for every single trace.
_RESAMPLERS = {}


def _get_resampler(orig_freq, new_freq):
    """Return a cached ``torchaudio.transforms.Resample`` for the given rates."""
    key = (orig_freq, new_freq)
    if key not in _RESAMPLERS:
        _RESAMPLERS[key] = torchaudio.transforms.Resample(
            orig_freq=orig_freq, new_freq=new_freq
        )
    return _RESAMPLERS[key]


def preprocess_function(trace, idx, orig_freq=1200, new_freq=16000):
    """
    Extracts the selected component and prepares it for model input.

    Steps: take row ``CHANNEL_INDEX[CHANNEL]`` of ``trace``, convert it to a
    float32 tensor, scale it by its maximum absolute value (skipped when that
    maximum is 0, to avoid dividing by zero) and resample it.

    Args:
        trace: 2-D array indexable as ``trace[channel, :]``.
        idx: integer index of the trace; it is returned as the label.
        orig_freq: ``orig_freq`` passed to ``torchaudio.transforms.Resample``.
            NOTE(review): the default of 1200 looks like the number of samples
            per trace rather than a sampling rate (the batch printed in this
            notebook has 16000 samples per row), so the resampler appears to
            be used to stretch each trace to ``new_freq`` samples. Confirm.
        new_freq: ``new_freq`` passed to ``torchaudio.transforms.Resample``.

    Returns:
        Tuple ``(x_resampled, y)``: the resampled float32 waveform and
        ``idx`` as a ``torch.long`` tensor. The label is only the sample
        index, not a class or phase label.
    """
    single_channel = trace[CHANNEL_INDEX[CHANNEL], :]
    # Convert to float32 tensor
    x = torch.tensor(single_channel, dtype=torch.float32)
    x_max = x.abs().max()
    x_norm = x / x_max if x_max > 0 else x
    x_resampled = _get_resampler(orig_freq, new_freq)(x_norm)

    y = torch.tensor(idx, dtype=torch.long)
    return x_resampled, y


# ===========================================================
# 3. Define custom PyTorch Dataset wrapper
# ===========================================================
class SeisbenchTorchDataset(Dataset):
    """
    Map-style PyTorch dataset that fetches one waveform at a time from a
    wrapped dataset via ``get_waveforms(idx)``.

    Args:
        seisbench_dataset: object providing ``__len__`` and ``get_waveforms``.
        transform: optional callable ``(trace, idx) -> (x, y)``. When it is
            falsy, the raw trace and the index itself are returned.
        sampling_rate: stored on the instance; not used by this class.
    """

    def __init__(self, seisbench_dataset, transform=None, sampling_rate=100.0):
        self.dataset, self.transform, self.sampling_rate = (
            seisbench_dataset,
            transform,
            sampling_rate,
        )

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Only the requested trace is fetched; nothing is preloaded.
        waveform = self.dataset.get_waveforms(idx)
        if not self.transform:
            return {"input_values": waveform, "labels": idx}
        inputs, label = self.transform(waveform, idx)
        return {"input_values": inputs, "labels": label}


# ===========================================================
# 4. Instantiate datasets
# ===========================================================
# Wrap the training split; each item is produced on demand by preprocess_function.
train_dataset = SeisbenchTorchDataset(train, transform=preprocess_function)

# ===========================================================
# 5. Create DataLoader (lazy per-item loading, single process)
# ===========================================================
from torch.utils.data import DataLoader

train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0,     # 0 = load in the main process; raise this for parallel loading
    pin_memory=True
)

# ===========================================================
# 6. Example usage
# ===========================================================
# Pull a single batch and stop; `x` is reused by later cells.
for batch in train_loader:
    x = batch["input_values"]  # shape: (batch_size, T)
    y = batch["labels"]
    print(x.shape, y.shape)
    break

torch.Size([32, 16000]) torch.Size([32])




In [97]:
from transformers import Wav2Vec2Model

In [103]:
# Sanity check: push the batch `x` through the pretrained wav2vec2 base model.
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
with torch.no_grad():  # forward pass only, no gradients recorded
    outputs = model(x)
print("Model output shape:", outputs.last_hidden_state.shape)



Model output shape: torch.Size([32, 49, 768])


In [105]:
# Freeze feature extractor (conv front-end)
model.feature_extractor.requires_grad_(False)

# Freeze all encoder layers except the last 3
for layer in model.encoder.layers[:-3]:
    layer.requires_grad_(False)

In [110]:
# Shape of the batch pulled from train_loader.
x.shape

torch.Size([32, 16000])

In [None]:
from transformers import AutoFeatureExtractor, Wav2Vec2ForPreTraining
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices

# NOTE: this rebinds `model`, replacing the Wav2Vec2Model loaded (and partially
# frozen) in the earlier cells with a freshly loaded pre-training model.
model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")

# Feed the batch taken from train_loader; `input_values` was previously never
# defined in this notebook, so every line using it raised NameError.
input_values = x

# compute masked indices
# Derive both sizes from the actual batch instead of hard-coding 32 / 16000,
# so a smaller batch or a different trace length still works.
batch_size, raw_sequence_length = input_values.shape
sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length).item()
mask_time_indices = _compute_mask_indices(
    shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2
)
sampled_negative_indices = _sample_negative_indices(
    features_shape=(batch_size, sequence_length),
    num_negatives=model.config.num_negatives,
    mask_time_indices=mask_time_indices,
)
mask_time_indices = torch.tensor(data=mask_time_indices, device=input_values.device, dtype=torch.long)
sampled_negative_indices = torch.tensor(
    data=sampled_negative_indices, device=input_values.device, dtype=torch.long
)

with torch.no_grad():
    outputs = model(input_values, mask_time_indices=mask_time_indices)

# compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)

# Report the similarity on the masked time steps. A bare comparison in the
# middle of a cell displays nothing, and the pasted `tensor(True)` output line
# that followed it was not valid code, so the value is printed instead.
masked_similarity = cosine_sim[mask_time_indices.to(torch.bool)].mean()
print("Mean cosine similarity on masked steps:", masked_similarity.item())

# for contrastive loss training model should be put into train mode
model = model.train()
loss = model(
    input_values, mask_time_indices=mask_time_indices, sampled_negative_indices=sampled_negative_indices
).loss