In [None]:
"""
Earthquake Event Inference Script Using Trained 2D CNN Model on PSD Features
---------------------------------------------------------------------------

This script performs batch inference for earthquake detection on infrasound PSD data
collected from the Paros sensor network. It processes a user-defined time range by:

1. Loading normalization statistics (mean and std) calculated during training.
2. Loading a trained 2D CNN model (EarthquakeCNN2d) for earthquake vs background classification.
3. Querying and preprocessing waveform data from the Paros sensor via InfluxDB over fixed
   60-second segments, extracting PSD feature vectors (11 windows x 52 frequency bins).
4. Normalizing the PSD vectors with the loaded statistics.
5. Running inference on each PSD vector to predict earthquake probability.
6. Saving results into CSV logs:
   - All predictions
   - Only windows predicted as earthquake events
   - High-confidence earthquake events (probability ≥ 0.90)

This script can be adapted for real-time monitoring or retrospective event analysis by
changing the start and end times, sensor credentials, or output paths.

Author: Ethan Gelfand
Date: 08/12/2025
"""

import numpy as np
import torch
import os
import csv
from tqdm import tqdm
from datetime import datetime, timedelta, UTC
from cnn_model import EarthquakeCNN2d
from DataQueryUtils import psd_vectors_from_range 

# Load the normalization statistics (mean and std) saved during training.
# They are handed to psd_vectors_from_range so the PSD vectors are scaled the
# same way the training data was.
mean = np.load("../DataCollection_Preprocessing/Exported_Paros_Data/mean.npy")
std = np.load("../DataCollection_Preprocessing/Exported_Paros_Data/std.npy")

# Load the trained 2D CNN. The checkpoint is a state_dict (it is passed straight
# to load_state_dict), so unpickling is restricted to tensors and plain
# containers with weights_only=True: this avoids executing arbitrary pickled
# code, and newer torch releases default to it and warn when it is left unset.
model = EarthquakeCNN2d(input_shape=(11, 52))
model.load_state_dict(
    torch.load(
        "../ModelTraining/fold_outputs/fold_3/CNNmodel.pth",
        map_location="cpu",
        weights_only=True,
    )
)
# Put layers with train/eval-specific behavior (e.g. dropout, batch-norm, if
# present in EarthquakeCNN2d) into inference mode.
model.eval()
model.to("cpu")

print("Starting Inferences")

# Time range to run inference over (change as needed). These datetimes are naive.
# NOTE(review): confirm which timezone psd_vectors_from_range assumes for them.
start_time = datetime(2025, 5, 5, 0, 0, 0, tzinfo=None)
# NOTE(review): the recorded run for this range produced 1439 segments rather
# than the 1440 minutes in a day. If only complete 60 s segments are emitted,
# the last minute before midnight is skipped; start_time + timedelta(days=1)
# would cover it — confirm the end-time semantics before changing.
end_time = datetime(2025, 5, 5, 23, 59, 59, tzinfo=None)

# Query the Paros sensor for the range and get (window_start, window_end,
# psd_vector) entries, normalized with the training mean/std.
results = psd_vectors_from_range(
    start_time=start_time,
    end_time=end_time,
    sensor_id="141929",
    box_id="parost2",
    # Read the InfluxDB password from the environment so it never has to be
    # written into the notebook; "*****" remains the placeholder fallback.
    password=os.environ.get("PAROS_PASSWORD", "*****"),
    fs_in=20,
    fs_out=100,
    window_duration=10,
    overlap=0.5,
    mean=mean,
    std=std,
)


# Output CSVs: every prediction, windows predicted as earthquakes, and
# high-confidence earthquake windows.
log_dir = "LoggedData"  # not named `dir`, to avoid shadowing the builtin
os.makedirs(log_dir, exist_ok=True)

# Derive the file-name date tag from start_time so the log names always follow
# the queried day (2025-05-05 -> "5_05_2025", the same names as before).
date_tag = f"{start_time.month}_{start_time.day:02d}_{start_time.year}"
log_path = os.path.join(log_dir, f"Earthquake_Predictions_{date_tag}.csv")
event_log_path = os.path.join(log_dir, f"Earthquake_Event_Log_{date_tag}.csv")
strong_event_log_path = os.path.join(log_dir, f"Earthquake_Strong_Event_Log_{date_tag}.csv")

EARTHQUAKE_CLASS = 1  # output index 1 = earthquake, index 0 = background
STRONG_EVENT_THRESHOLD = 0.90  # minimum earthquake probability for the strong-event log

header = ["query_time", "window_start", "window_end", "predicted_class", "prob_earthquake", "prob_background"]

with (
    open(log_path, mode="w", newline="", encoding="utf-8") as f_all,
    open(event_log_path, mode="w", newline="", encoding="utf-8") as f_event,
    open(strong_event_log_path, mode="w", newline="", encoding="utf-8") as f_strong_event,
):
    writer_all = csv.writer(f_all)
    writer_event = csv.writer(f_event)
    writer_strong_event = csv.writer(f_strong_event)
    writer_all.writerow(header)
    writer_event.writerow(header)
    writer_strong_event.writerow(header)

    # Gradients are never needed here, so a single no_grad context covers
    # every window instead of re-entering it per iteration.
    with torch.no_grad():
        for window_start, window_end, psd_vector in tqdm(results, desc="Running inferences", colour="green"):
            # Add batch and channel dimensions: (11, 52) -> (1, 1, 11, 52).
            input_tensor = torch.tensor(psd_vector, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
            probs = torch.softmax(model(input_tensor), dim=1).numpy()[0]
            pred = int(np.argmax(probs))

            row = [
                # Wall-clock time (UTC) at which this prediction was made,
                # not the time of the analyzed window.
                datetime.now(UTC).isoformat(timespec="seconds"),
                window_start.isoformat(timespec="seconds"),
                window_end.isoformat(timespec="seconds"),
                pred,
                round(float(probs[EARTHQUAKE_CLASS]), 5),
                round(float(probs[0]), 5),
            ]

            writer_all.writerow(row)
            if pred == EARTHQUAKE_CLASS:
                writer_event.writerow(row)
                if probs[EARTHQUAKE_CLASS] >= STRONG_EVENT_THRESHOLD:
                    writer_strong_event.writerow(row)

print("Completed Inferences")



Starting Inferences


Running inferences: 100%|██████████| 1439/1439 [00:00<00:00, 1531.05it/s]

Completed Inferences



