In [1]:
import gc

#!/usr/bin/env python3
"""
ls_association_with_clock_drift.py
Seismic event localization with clock drift estimation.
Handles stations with unknown clock drift by initially assigning larger uncertainties
and then estimating the drift parameters after initial localization.
"""
import os
import numpy as np
import pandas as pd
from pyproj import Geod
from joblib import Parallel, delayed, Memory
import time
import psutil
import pickle
import warnings
from utils.physics.sound_model.ellipsoidal_sound_model import GridEllipsoidalSoundModel
# === CONFIGURATION ===
ASSO_FILE = "/media/rsafran/CORSAIR/Association/2018/grids/2018/s_-60-5,35-120,350,0.8,0.6.npy"
OUTPUT_DIR = "/media/rsafran/CORSAIR/Association/validated"
OUTPUT_BASENAME = "s_-60-5,35-120,350,0.8,0.6_with_drift"
CHUNK_SIZE = 25  # checkpoint every N dates
N_JOBS = max(1, 0*os.cpu_count() - 1)  # leave one core free
GRID_LAT_BOUNDS = [-60, 5]
GRID_LON_BOUNDS = [35, 120]
GRID_SIZE = 350
ISAS_PATH = "/media/rsafran/CORSAIR/ISAS/extracted/2018"
BATCH_SIZE = 5000
#SOUND_MODEL
arr = os.listdir(ISAS_PATH)
file_list = [os.path.join(ISAS_PATH, fname) for fname in arr if fname.endswith('.nc')]
SOUND_MODEL = GridEllipsoidalSoundModel(file_list)

# === PERFORMANCE MONITORING ===
start_time = time.time()


def log_progress(message):
    elapsed = time.time() - start_time
    process = psutil.Process(os.getpid())
    memory_usage = process.memory_info().rss / 1024 / 1024  # MB
    print(f"[{elapsed:.1f}s | {memory_usage:.1f}MB] {message}")


# === INITIALIZATION ===

# Precompute grid lat/lon
PTS_LAT = np.linspace(*GRID_LAT_BOUNDS, GRID_SIZE)
PTS_LON = np.linspace(*GRID_LON_BOUNDS, GRID_SIZE)

# Geod instance for geodesic calculations
geod = Geod(ellps="WGS84")

def grid_index_to_coord(indices):
    """Convert grid indices to geographic coordinates"""
    i, j = indices
    return [PTS_LAT[i], PTS_LON[j]]

def process_date(date, associations_list):
    res = []
    # Create simplified associations list to avoid serialization issues
    simplified_associations = []
    for detections, valid_points in associations_list:
        simple_detections = []
        for station_obj, det_time in detections:
            # Extract only necessary data from station_obj
            lat, lon = station_obj.get_pos()
            drift = station_obj.get_clock_error(det_time) if "not_ok" in station_obj.other_kwargs.values() else 0
            station_name = station_obj.name  # Get station name
            simple_detections.append(((lat, lon), det_time, drift, station_name))
        simplified_associations.append((simple_detections, valid_points))

    for detections, valid_points in simplified_associations:
        # Skip tiny clusters
        if len(detections) < 7:
            continue

        # Build refined detections & station positions
        station_positions = [pos for pos, _,_, _ in detections]
        detection_times = [t for _, t,_, _ in detections]
        drifts = [d for _,_,d, _ in detections]
        print(np.array(valid_points))
        c0 = np.mean(np.array(valid_points), axis=0)
        print(c0)
        r, _, _ = SOUND_MODEL.localize_with_uncertainties(
            station_positions, detection_times, drift_uncertainties=drifts, initial_pos=c0
        )
        res.append(r)


    return res



"""Main execution function"""
log_progress(f"Starting with {N_JOBS} workers")

# Create output directory if it doesn't exist
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Load input
log_progress(f"Loading associations from {ASSO_FILE}")
associations = np.load(ASSO_FILE, allow_pickle=True).item()
items = list(associations.items())
total_items = len(items)
log_progress(f"Found {total_items} date entries to process")

# Process in batches with checkpoints
validated_associations = {}

for batch_start in range(0, total_items, BATCH_SIZE):
    batch_end = min(batch_start + BATCH_SIZE, total_items)
    batch = items[batch_start:batch_end]

    log_progress(f"Processing batch {batch_start + 1}-{batch_end} of {total_items}")

    # Process batch in parallel
    # Note: we use a smaller chunk_size when jobs > 1 for better load balancing
    effective_chunk = 1 if N_JOBS > 1 else CHUNK_SIZE

    results = Parallel(n_jobs=1, verbose=5, batch_size=effective_chunk)(
        delayed(process_date)(date, lst) for date, lst in batch
    )

    # Store results
    i= 0
    for res in results:
        if res:  # Only store if we have validated results
            validated_associations[i]
            i+=1

    # Checkpoint
    chkpt_path = os.path.join(
        OUTPUT_DIR,
        f"{OUTPUT_BASENAME}_partial_{batch_end}.npy"
    )
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        np.save(chkpt_path, validated_associations)

    log_progress(f"Checkpoint saved: {chkpt_path}")

    # Memory optimization: save and reload larger checkpoints
    if len(validated_associations) > 5000:
        log_progress("Checkpointing and refreshing memory...")
        pickle_path = os.path.join(OUTPUT_DIR, "temp_checkpoint.pkl")
        with open(pickle_path, 'wb') as f:
            pickle.dump(validated_associations, f)

        # Clear and reload
        validated_associations.clear()
        with open(pickle_path, 'rb') as f:
            validated_associations = pickle.load(f)

        # Force garbage collection
        import gc
        gc.collect()


# Final save
log_progress("Saving final results")
final_path = os.path.join(OUTPUT_DIR, f"{OUTPUT_BASENAME}_final.npy")
np.save(final_path, validated_associations)

elapsed = time.time() - start_time
log_progress(f"All done! Final results saved to {final_path}")
log_progress(f"Total execution time: {elapsed:.1f} seconds ({elapsed / 60:.1f} minutes)")


[0.0s | 191.2MB] Starting with 1 workers
[0.0s | 191.2MB] Loading associations from /media/rsafran/CORSAIR/Association/2018/grids/2018/s_-60-5,35-120,350,0.8,0.6.npy
[4.1s | 2949.5MB] Found 13575 date entries to process
[4.1s | 2949.5MB] Processing batch 1-5000 of 13575


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s
[Parallel(n_jobs=1)]: Done  71 tasks      | elapsed:    0.0s
[Parallel(n_jobs=1)]: Done 161 tasks      | elapsed:    0.0s
[Parallel(n_jobs=1)]: Done 287 tasks      | elapsed:    0.1s
[Parallel(n_jobs=1)]: Done 449 tasks      | elapsed:    0.5s
[Parallel(n_jobs=1)]: Done 647 tasks      | elapsed:    0.8s


[[90 26]
 [90 27]
 [90 28]
 [90 29]
 [91 23]
 [91 24]
 [91 25]]
[90.42857143 26.        ]
initial pos  [90.42857143 26.        ]


ValueError: Point [90.42857142857143, 26.0] outside grid bounds. pos : [[90.42857142857143, 26.0], (-38.5465, 52.9287)]