In [None]:
from src.notebooks.demo.dev_association import FIRSTS_DETECTIONS
from src.utils.simulation.synthetic import RealStationDataGenerator
from datetime import datetime, timedelta
from pathlib import Path
from src.utils.data_reading.sound_data.station import StationsCatalog
from src.utils.physics.sound_model.spherical_sound_model import HomogeneousSphericalSoundModel as HomogeneousSoundModel
from src.utils.detection.association import compute_candidates, association_is_new, update_valid_grid, update_results
from src.utils.detection.association import compute_grids

In [None]:
# Station catalogue and the dataset (deployment) whose stations are simulated
catalog_path = "../../../data/recensement_stations_OHASISBIO_RS.csv"
dataset = "OHASISBIO-2018"
stations_catalog = StationsCatalog(catalog_path)
STATION = stations_catalog.by_dataset(dataset)

# Homogeneous sound-speed model (presumably m/s — confirm against the sound model class)
SOUND_MODEL = HomogeneousSoundModel(sound_speed=1485.5)

In [None]:
# Quick look: generator built from the stations and sound model only (other options at their defaults)
gen = RealStationDataGenerator(STATION, SOUND_MODEL)
# `datetime` is the class imported in the first cell, so it is called directly
data = gen.generate_events(datetime(year=2018, month=1, day=1))
gen.plot_stations_and_events()

In [None]:
station_coordinates = STATION.get_coordinate_list()
station_coordinates[:, 1]  # second coordinate column (presumably longitude — confirm)

In [None]:
import numpy as np
import pandas as pd
import os
from scipy.spatial.distance import cdist
import matplotlib.pyplot as plt
from concurrent.futures import ProcessPoolExecutor
from functools import partial
import multiprocessing as mp

def load_ridge_data(dorsal_db_path):
    """
    Load the oceanic ridge axes ("dorsales") from a directory of `.xy` files.

    Each file is read as whitespace-separated columns with a header row. Everything
    after a '>' character is ignored, which presumably covers GMT multi-segment
    separators (confirm). Only the 'x' and 'y' columns are kept.

    :param dorsal_db_path: directory containing the `.xy` files
    :return: (ridge_data, all_ridge_points).
        - ridge_data maps each ridge name (the file name without the 'axe-' prefix and
          the '.xy' suffix) to its (N, 2) array of [x, y] points.
        - all_ridge_points stacks the points of every ridge in file-name order.
    :raises ValueError: if the directory contains no `.xy` file.
    """
    # Sorted so the stacking order, and any index-based sampling of all_ridge_points,
    # does not depend on the arbitrary order returned by os.listdir.
    dorsal_files = sorted(f for f in os.listdir(dorsal_db_path) if f.endswith('.xy'))
    print(f"Loading {len(dorsal_files)} ridge files: {dorsal_files}")
    if not dorsal_files:
        raise ValueError(f"No '.xy' ridge file found in {dorsal_db_path}")

    ridge_data = {}
    all_ridge_points = []

    for f in dorsal_files:
        ridge_name = f.replace('axe-', '').replace('.xy', '')
        df = pd.read_csv(os.path.join(dorsal_db_path, f),
                         comment=">", sep=r'\s+')

        # NOTE: NaN rows are not dropped; they would propagate into min/max-based bounds
        ridge_points = df[['x', 'y']].values

        ridge_data[ridge_name] = ridge_points
        all_ridge_points.append(ridge_points)

        print(f"  {ridge_name}: {len(ridge_points)} points")

    # Combine all the ridges into a single point cloud
    all_ridge_points = np.vstack(all_ridge_points)
    print(f"Total ridge points: {len(all_ridge_points)}")

    return ridge_data, all_ridge_points

RIDGE_DATA_DIR = "../../../data/dorsales/"
ridge_data, all_ridge_points = load_ridge_data(RIDGE_DATA_DIR)


def generate_events_near_ridges(n_events, ridge_points, std_km=50, seed=None):
    """
    Generate synthetic event positions scattered around oceanic ridge axes.

    :param n_events: number of events to generate
    :param ridge_points: Nx2 array of ridge coordinates as (lon, lat), i.e. the
        (x, y) columns returned by load_ridge_data
    :param std_km: standard deviation, in km, of the Gaussian noise added around the ridges
    :param seed: optional integer seed. When given, a private RandomState is used so the
        result is reproducible. When None (the default), the global numpy RNG is used.
    :return: (n_events, 2) float array of event positions as (lat, lon). Note the column
        swap relative to ridge_points.
    """
    # RandomState offers the same randint/normal methods as the np.random module.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # 1 deg ~ 111 km. The same degree scatter is used on both axes, so away from the
    # equator the east-west scatter in km is smaller than std_km (by a factor cos(lat)).
    std_deg = std_km / 111.0

    # Pre-shaped so that n_events == 0 still yields a 2-D (0, 2) array.
    events = np.empty((n_events, 2))
    for k in range(n_events):
        # Pick a random ridge point; ridge points are (lon, lat)
        base_point = ridge_points[rng.randint(0, len(ridge_points))]
        # Add Gaussian noise around it (latitude drawn first, then longitude)
        events[k, 0] = base_point[1] + rng.normal(0, std_deg)
        events[k, 1] = base_point[0] + rng.normal(0, std_deg)

    return events

# 100 synthetic epicentres scattered with a 150 km standard deviation around the ridge axes
n_events = 100
simulated_events = generate_events_near_ridges(n_events=n_events,
                                               ridge_points=all_ridge_points,
                                               std_km=150)

In [None]:
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt

# Map extent: bounding box of all ridge points, padded by 5 degrees on each side
lon_min, lat_min = all_ridge_points.min(axis=0) - 5
lon_max, lat_max = all_ridge_points.max(axis=0) + 5

# PlateCarree is a plain lon/lat projection, so coordinates can be plotted directly
fig, ax = plt.subplots(figsize=(10, 6), subplot_kw={'projection': ccrs.PlateCarree()})
ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())

# Background: coastlines, borders and land
ax.coastlines()
ax.add_feature(cfeature.BORDERS, linestyle=':')
ax.add_feature(cfeature.LAND, facecolor='lightgray')

# One line per ridge; ridge points are stored as (lon, lat)
for points in ridge_data.values():
    ax.plot(points[:, 0], points[:, 1])

# Simulated events are stored as (lat, lon), hence the column swap
ax.scatter(simulated_events[:, 1], simulated_events[:, 0], c='red', s=20, label='Simulated Events')
ax.set_xlabel("Longitude")
ax.set_ylabel("Latitude")
ax.legend()
ax.set_title("Simulation d'événements proches des dorsales")
plt.show()

In [None]:
# Example usage with enhanced features:
# a generator with ridge-based events and clock drift.
generator = RealStationDataGenerator(
    stations=STATION,
    sound_model=SOUND_MODEL,
    n_real_events=100,
    n_noise_detections=0,
    ridge_data_path="../../../data/dorsales/",
    ridge_std_km=100,  # scatter (km) of events around the ridges, presumably a std as in generate_events_near_ridges
    perfect_events=True,  # NOTE(review): presumably True disables timing noise; confirm in RealStationDataGenerator
    apply_clock_drift=True,  # apply clock drift errors
    reference_time_years=1,  # 1 year reference time (presumably for the clock drift; confirm)
    seed=42
)

# Generate 96 hours of events. `datetime` is the class imported in the first cell.
start_time = datetime(2023, 1, 1)
events, ground_truth = generator.generate_events(start_time, duration_hours=96)

# Plot results
generator.plot_stations_and_events()

# Show simulation info
print("Simulation parameters:")
for key, value in generator.get_simulation_info().items():
    print(f"  {key}: {value}")

In [None]:
import numpy as np
import pandas as pd
from datetime import datetime, timedelta
# NOTE: the `datetime` module is deliberately not imported here. `import datetime` would
# rebind the name `datetime`, which the earlier cells use as the class imported in the
# first cell.


def process_synthetic_detections(events_df,
                                 min_p_tissnet_primary=0.1,
                                 min_p_tissnet_secondary=0.1,
                                 merge_delta=timedelta(seconds=5)):
    """
    Process synthetic detections from the RealStationDataGenerator.

    For each station, picks are sorted by time and thinned:
    - A pick whose gap to the previous (raw, unfiltered) pick is not larger than
      merge_delta is dropped; the earlier pick is kept.
    - From the 4th pick on (index >= 3), a pick is also dropped when its gap to the
      previous pick equals the previous gap within merge_delta.
    - From the 5th pick on, a pick is also dropped when the two-step gaps
      d[i-1]-d[i-3] and d[i-2]-d[i-4] agree within merge_delta.
    Together these rules thin out regularly spaced sequences.

    Parameters:
    -----------
    events_df : DataFrame
        With columns ['datetime', 'station', 'probability', 'true_event']
        ('true_event' is not used here).
    min_p_tissnet_primary : float
        Minimum probability (strict) for a detection to enter ``detections_merged``.
    min_p_tissnet_secondary : float
        Minimum probability (strict) for a detection to be kept at all.
    merge_delta : timedelta
        Minimum gap between consecutive picks, also used as the regularity tolerance.

    Returns:
    --------
    detections : dict
        station -> np.array([[datetime, probability], ...]) (object dtype), time-sorted
    detections_merged : np.ndarray
        Array of [datetime, probability, station] above the primary threshold,
        sorted by datetime (ties keep station order).
    """
    merge_s = merge_delta.total_seconds()

    # Keep only the events above the secondary threshold
    filtered_events = events_df[events_df['probability'] > min_p_tissnet_secondary].copy()

    detections = {}
    detections_merged_list = []

    # Group by station
    for station in filtered_events['station'].unique():
        station_events = filtered_events[filtered_events['station'] == station].sort_values('datetime')
        d = station_events[['datetime', 'probability']].values

        if len(d) == 0:
            # Defensive: e.g. a NaN station label matches no row
            detections[station] = np.array([]).reshape(0, 2)
            continue

        # Remove detections that are too close together or regularly spaced
        new_d = [d[0]]
        for i in range(1, len(d)):
            dt1 = (d[i, 0] - d[i - 1, 0]).total_seconds()
            # Written as `not >` so that a NaN gap (NaT timestamps) is dropped too
            if not dt1 > merge_s:
                continue
            if i < 3:
                new_d.append(d[i])
                continue

            dt2 = (d[i - 1, 0] - d[i - 2, 0]).total_seconds()
            condition1 = abs(dt1 - dt2) > merge_s

            if i >= 4:
                dt3 = (d[i - 1, 0] - d[i - 3, 0]).total_seconds()
                dt4 = (d[i - 2, 0] - d[i - 4, 0]).total_seconds()
                condition2 = abs(dt3 - dt4) > merge_s
            else:
                condition2 = True

            if condition1 and condition2:
                new_d.append(d[i])

        d = np.array(new_d, dtype=object)
        detections[station] = d

        print(f"Found {len(d)} detections for station {station}")

        # Add to the global list
        for det in d:
            detections_merged_list.append([det[0], det[1], station])

    # Build the final array
    if len(detections_merged_list) == 0:
        return detections, np.array([]).reshape(0, 3)

    detections_merged = np.array(detections_merged_list, dtype=object)

    # Primary threshold (probabilities are stored as objects: cast for a numeric comparison)
    detections_merged = detections_merged[detections_merged[:, 1].astype(float) > min_p_tissnet_primary]

    # Sort by datetime; a stable sort keeps a deterministic order for simultaneous picks
    if len(detections_merged) > 0:
        detections_merged = detections_merged[np.argsort(detections_merged[:, 0], kind="stable")]

    return detections, detections_merged


def analyze_synthetic_detections(detections, detections_merged, ground_truth_df=None):
    """
    Summarise the processed synthetic detections.

    :param detections: dict station -> array of rows [datetime, probability]
    :param detections_merged: array of rows [datetime, probability, station], sorted by datetime
    :param ground_truth_df: optional table of true events (one row per event)
    :return: dict with keys
        - 'total_detections', 'total_merged_detections', 'num_stations_with_detections';
        - 'station_stats': per str(station), with num_detections, avg_probability and
          std_probability;
        - 'time_span_hours' when there are at least two merged detections;
        - 'detection_rate_per_hour' when, in addition, that time span is non-zero;
        - 'num_true_events' and 'detection_efficiency' when ground truth is given.
    """
    analysis = {}

    # Basic statistics
    total_detections = sum(len(det) for det in detections.values())
    analysis['total_detections'] = total_detections
    analysis['total_merged_detections'] = len(detections_merged)
    analysis['num_stations_with_detections'] = sum(1 for dets in detections.values() if len(dets) > 0)

    # Per-station statistics
    station_stats = {}
    for station, dets in detections.items():
        if len(dets) > 0:
            # Probabilities may live in an object array; cast so numpy uses float math
            probs = np.asarray(dets[:, 1], dtype=float)
            avg_probability, std_probability = float(np.mean(probs)), float(np.std(probs))
        else:
            avg_probability, std_probability = 0, 0
        station_stats[str(station)] = {
            'num_detections': len(dets),
            'avg_probability': avg_probability,
            'std_probability': std_probability
        }
    analysis['station_stats'] = station_stats

    # Time span analysis
    if len(detections_merged) > 1:
        time_span_seconds = (detections_merged[-1, 0] - detections_merged[0, 0]).total_seconds()
        time_span_hours = time_span_seconds / 3600
        analysis['time_span_hours'] = time_span_hours
        # All merged detections may share one timestamp (simultaneous picks). The rate is
        # undefined then, so it is omitted instead of raising ZeroDivisionError.
        if time_span_hours > 0:
            analysis['detection_rate_per_hour'] = len(detections_merged) / time_span_hours

    # Ground truth comparison if available
    if ground_truth_df is not None:
        n_true = len(ground_truth_df)
        analysis['num_true_events'] = n_true
        # Station-level detections per true event; this can exceed 1 when several
        # stations detect the same event.
        analysis['detection_efficiency'] = total_detections / n_true if n_true > 0 else 0

    return analysis



def plot_detection_timeline(detections_merged, ground_truth_df=None):
    """
    Plot the timeline of merged detections and, optionally, the ground-truth origin times.

    Top panel: detection probability against time. Bottom panel: one row per station.

    :param detections_merged: array of rows [datetime, probability, station], as returned
        by process_synthetic_detections.
    :param ground_truth_df: optional DataFrame with an 'origin_time' column. Each origin
        time is drawn as a dashed red vertical line on both panels.
    :return: None; the figure is displayed with plt.show().
    """
    # Checked before importing matplotlib so an empty input needs no plotting backend
    if len(detections_merged) == 0:
        print("No detections to plot")
        return

    import matplotlib.pyplot as plt
    from matplotlib.dates import DateFormatter

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 8), sharex=True)

    # Use datetime directly
    timestamps = np.array(detections_merged[:, 0])
    probabilities = detections_merged[:, 1].astype(float)

    # Plot detections timeline
    ax1.scatter(timestamps, probabilities, alpha=0.6, s=30)
    ax1.set_ylabel('Detection Probability')
    ax1.set_title('Synthetic Detections Timeline')
    ax1.grid(True, alpha=0.3)

    # Stations are grouped by their string form, keeping the first object seen for each
    # label. The labels are sorted so the row order is stable between runs; iterating a
    # set of strings would give a hash-dependent order.
    station_objects = detections_merged[:, 2]
    station_map = {}
    for station in station_objects:
        station_map.setdefault(str(station), station)
    station_labels = sorted(station_map)
    unique_stations = [station_map[label] for label in station_labels]
    station_colors = plt.cm.tab10(np.linspace(0, 1, len(unique_stations)))

    for i, station in enumerate(unique_stations):
        station_mask = np.array([s == station for s in station_objects])
        station_times = timestamps[station_mask]
        ax2.scatter(station_times, [i] * len(station_times),
                    c=[station_colors[i]], alpha=0.7, s=50,
                    label=station_labels[i])

    ax2.set_ylabel('Station')
    ax2.set_xlabel('Time')
    ax2.set_title('Detections by Station')
    ax2.set_yticks(range(len(unique_stations)))
    ax2.set_yticklabels(station_labels)
    ax2.grid(True, alpha=0.3)

    # Show the date too: a run can span several days (e.g. duration_hours=96), and
    # '%H:%M' alone then repeats.
    ax2.xaxis.set_major_formatter(DateFormatter('%m-%d %H:%M'))

    # Add ground truth if available
    if ground_truth_df is not None and len(ground_truth_df) > 0:
        for origin_time in ground_truth_df['origin_time']:
            ax1.axvline(origin_time, color='red', alpha=0.5, linestyle='--')
            ax2.axvline(origin_time, color='red', alpha=0.5, linestyle='--')

    # Attach the legend explicitly to the station panel, and create it before
    # tight_layout so the layout can take it into account.
    ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    fig.tight_layout()
    plt.show()



# Example usage
if __name__ == "__main__":

    # Turn the synthetic picks into per-station detections.
    # NOTE(review): thresholds and merge delta are literals here; they are not the
    # MIN_P_TISSNET_* / MERGE_DELTA_S constants of the association configuration.
    DETECTIONS, DETECTIONS_MERGED = process_synthetic_detections(
        events,
        min_p_tissnet_primary=0.1,
        min_p_tissnet_secondary=0.1,
        merge_delta=timedelta(seconds=1)
    )

    # Optional diagnostics: set to True to print detection statistics and plot the
    # detection timeline against the ground truth.
    SHOW_DETECTION_ANALYSIS = False
    if SHOW_DETECTION_ANALYSIS:
        analysis = analyze_synthetic_detections(DETECTIONS, DETECTIONS_MERGED, ground_truth)

        print("\nDetection Analysis:")
        for key, value in analysis.items():
            if key != 'station_stats':
                print(f"{key}: {value}")

        plot_detection_timeline(DETECTIONS_MERGED, ground_truth)

In [None]:
# First detection time of each station in this synthetic run. The global
# FIRSTS_DETECTIONS is only rebuilt from DETECTIONS in the next section; until then the
# name holds the value imported from src.notebooks.demo.dev_association.
{station: dets[0, 0] for station, dets in DETECTIONS.items() if len(dets) > 0}

## Association

In [None]:
# Stations that produced detections, with their first and last detection times
# (the per-station arrays are time-sorted, so row 0 / row -1 are the extremes)
STATIONS = list(DETECTIONS)
FIRSTS_DETECTIONS = {station: DETECTIONS[station][0, 0] for station in STATIONS}
LASTS_DETECTIONS = {station: DETECTIONS[station][-1, 0] for station in STATIONS}

In [None]:
# --- Association configuration ---
# Where the association results are written
ASSOCIATION_OUTPUT_DIR = "../../../data"
DETECTIONS_DIR_NAME = "synthetic_detections"
# If not None, process_detection creates one sub-directory per anchor detection under
# this root and passes it on to update_valid_grid
SAVE_PATH_ROOT = None

# TiSSNet probability thresholds (also encoded in the output file name).
# NOTE(review): the detections were processed with literal 0.1 thresholds, not these constants.
MIN_P_TISSNET_PRIMARY = 0.1
MIN_P_TISSNET_SECONDARY = 0.1

# Search parameters
MERGE_DELTA_S = 10  # threshold (s) below which two detections are considered the same event
REQ_CLOSEST_STATIONS = 0  # the REQ_CLOSEST_STATIONS closest stations must all be part of a valid association
RUN_ASSOCIATION = True  # set to False to reload previously saved associations instead of recomputing them
NCPUS = 8  # NOTE(review): not used by the sequential association loop

In [None]:
# Search grid: the ridge bounding box, padded by GRID_MARGIN_DEG degrees on each side
GRID_MARGIN_DEG = 5
lon_min, lat_min = all_ridge_points.min(axis=0) - GRID_MARGIN_DEG
lon_max, lat_max = all_ridge_points.max(axis=0) + GRID_MARGIN_DEG
LAT_BOUNDS = [lat_min, lat_max]
LON_BOUNDS = [lon_min, lon_max]
GRID_SIZE = 400  # number of points along each axis

# Travel-time grids, computed with zero pick and sound-speed uncertainty
(PTS_LAT, PTS_LON, STATION_MAX_TRAVEL_TIME, GRID_STATION_TRAVEL_TIME,
 GRID_STATION_COUPLE_TRAVEL_TIME, GRID_TOLERANCE) = compute_grids(
    LAT_BOUNDS, LON_BOUNDS, GRID_SIZE, SOUND_MODEL, STATIONS,
    pick_uncertainty=0, sound_speed_uncertainty=0)

In [None]:
from tqdm import tqdm
from concurrent.futures import as_completed

print("starting association")

# Output file for this grid / threshold configuration (spaces are stripped from the path)
OUT_DIR = f"{ASSOCIATION_OUTPUT_DIR}/grids/{DETECTIONS_DIR_NAME}"
Path(OUT_DIR).mkdir(parents=True, exist_ok=True)
_config_tag = (f"{LAT_BOUNDS[0]}-{LAT_BOUNDS[1]},{LON_BOUNDS[0]}-{LON_BOUNDS[1]},"
               f"{GRID_SIZE},{MIN_P_TISSNET_PRIMARY},{MIN_P_TISSNET_SECONDARY}")
OUT_FILE = f"{OUT_DIR}/refined_s_{_config_tag}.npy".replace(" ", "")

# Shared state of the association search:
#  - association_hashlist: set handed to process_detection (and on to association_is_new)
#  - associations: results gathered by update_results (presumably keyed by anchor date)
#  - processed: anchor dates already handled
association_hashlist, associations, processed = set(), {}, set()
def process_detection(arg):
    """Explore every association that can be anchored on a single detection.

    The anchor detection (date1 on station s1) is combined, by depth-first
    backtracking, with candidate detections of the other stations. Every partial
    association of more than two stations whose valid grid is non-empty is passed to
    ``update_results``.

    :param arg: tuple ``(detection, local_association_hashlist, processed)``, packed in a
        single argument (presumably so the function can be mapped over a process pool;
        the loop below calls it sequentially):
        - detection: one row ``(date1, p1, s1)`` of DETECTIONS_MERGED;
        - local_association_hashlist: set handed to ``association_is_new`` (presumably
          it records the associations already explored; confirm in association.py);
        - processed: set of anchor dates already handled. Candidate detections whose date
          is in it are skipped.
    :return: ``(local_association, local_association_hashlist, date1)``, where
        local_association is the dict filled by ``update_results``.

    Reads the module-level globals SAVE_PATH_ROOT, STATIONS, FIRSTS_DETECTIONS,
    LASTS_DETECTIONS, GRID_TOLERANCE, STATION_MAX_TRAVEL_TIME, DETECTIONS, MERGE_DELTA_S,
    GRID_STATION_COUPLE_TRAVEL_TIME, LON_BOUNDS, LAT_BOUNDS, REQ_CLOSEST_STATIONS and
    associations.
    """
    detection, local_association_hashlist, processed = arg
    local_association = {}
    date1, p1, s1 = detection
    # Optional per-anchor output directory, passed down to update_valid_grid
    save_path = SAVE_PATH_ROOT
    if save_path is not None:
        save_path = f'{save_path}/{s1.name}-{date1.strftime("%Y%m%d_%H%M%S")}'
        Path(save_path).mkdir(parents=True, exist_ok=True)

    # Other stations whose detection period (first..last detection, widened by
    # 4*GRID_TOLERANCE seconds) contains date1, sorted by increasing
    # STATION_MAX_TRAVEL_TIME from s1 (a proxy for distance).
    # NOTE(review): GRID_TOLERANCE must be a scalar here, since it is passed to timedelta.
    other_stations = np.array([s2 for s2 in STATIONS if s2 != s1
                               and date1+ timedelta(seconds=4*GRID_TOLERANCE) > FIRSTS_DETECTIONS[s2]
                               and date1 - timedelta(seconds=4*GRID_TOLERANCE) < LASTS_DETECTIONS[s2]])
    other_stations = other_stations[np.argsort([STATION_MAX_TRAVEL_TIME[s1][s2] for s2 in other_stations])]

    # given the detection date1 occurred on station s1, list all the detections of other stations that may be generated by the same source event
    current_association = {s1:date1}
    candidates =  compute_candidates(other_stations, current_association, DETECTIONS, STATION_MAX_TRAVEL_TIME, MERGE_DELTA_S)

    # update the list of other stations to only include the ones having at least a candidate detection
    other_stations = [s for s in other_stations if len(candidates[s]) > 0]

    # Results are only recorded for associations of more than 2 stations (see
    # update_results below), so at least two other stations are needed.
    if len(other_stations) < 2:
        return local_association, local_association_hashlist, date1

    # define the recursive browsing function (that is responsible for browsing the search space of associations for s1-date1)
    def backtrack(station_index, current_association, valid_grid, associations, save_path):
        """Depth-first search over other_stations[station_index:].

        For the current station, each candidate detection is tentatively added (when
        the updated valid grid keeps at least one point) and the search recurses on
        the next station. The station may then also be skipped, unless it is one of
        the REQ_CLOSEST_STATIONS first stations. `valid_grid` is None at the root
        (presumably "no constraint yet" for update_valid_grid; confirm). `associations`
        is only forwarded, never used.
        """
        if station_index == len(other_stations):
            return
        station = other_stations[station_index]

        # Candidates of this station that are compatible with the current partial association
        candidates = compute_candidates([station], current_association, DETECTIONS, STATION_MAX_TRAVEL_TIME, MERGE_DELTA_S)
        for idx in candidates[station]:
            date, p = DETECTIONS[station][idx]
            # Detections that were already used as anchors are skipped
            if date in processed:
                continue
            if not association_is_new(current_association, date, local_association_hashlist):
                continue

            # Restrict the grid of possible source positions with the new detection (dg_new unused)
            valid_grid_new, dg_new = update_valid_grid(current_association, valid_grid, station, date, GRID_STATION_COUPLE_TRAVEL_TIME, GRID_TOLERANCE, save_path, LON_BOUNDS, LAT_BOUNDS)

            valid_points_new = np.argwhere(valid_grid_new)

            if len(valid_points_new) > 0:
                # Tentatively add the detection; it is removed again after the recursive call
                current_association[station] = (date)

                if len(current_association) > 2:
                    update_results(date1, current_association, valid_points_new, local_association, GRID_STATION_COUPLE_TRAVEL_TIME)

                backtrack(station_index + 1, current_association, valid_grid_new, associations, save_path)
                del current_association[station]
        # Also explore associations that skip this station, unless it is one of the
        # REQ_CLOSEST_STATIONS closest stations (those are mandatory)
        if station_index >= REQ_CLOSEST_STATIONS:
            backtrack(station_index + 1, current_association, valid_grid, associations, save_path)
        return
    backtrack(0, current_association, None, associations, save_path=save_path)
    return local_association, local_association_hashlist, date1


if __name__ == '__main__':
    # Sequential version
    if RUN_ASSOCIATION:
        try:
            # Anchor on the most confident detections first (probability, descending)
            DETECTIONS_MERGED = DETECTIONS_MERGED[np.argsort(DETECTIONS_MERGED[:,1])][::-1]
            for det in tqdm(DETECTIONS_MERGED, desc="Processing detections"):
                local_association, local_association_hashlist, date1 = process_detection((det, association_hashlist, processed))
                # Later anchors must not reuse this detection
                processed.add(date1)
                # Update in place. Rebuilding the set/dict with `union` / `|` on every
                # detection copied the whole accumulated content each time, which made the
                # loop quadratic. The resulting contents are identical.
                association_hashlist |= local_association_hashlist
                associations.update(local_association)
        finally:
            # Save the associations even if the execution did not finish properly
            print(f"Sauvegarde des associations dans {OUT_FILE}")
            np.save(OUT_FILE, associations)

## Show results map

In [None]:
import glob2

# OUT_DIR = f"{ASSOCIATION_OUTPUT_DIR}/grids/{DETECTIONS_DIR_NAME}"
# Path(OUT_DIR).mkdir(parents=True, exist_ok=True)
# OUT_FILE = f"{OUT_DIR}/refined_s_{LAT_BOUNDS[0]}-{LAT_BOUNDS[1]},{LON_BOUNDS[0]}-{LON_BOUNDS[1]},{GRID_SIZE},{MIN_P_TISSNET_PRIMARY},{MIN_P_TISSNET_SECONDARY}.npy".replace(" ","")
valid = np.zeros((GRID_SIZE,GRID_SIZE))

MIN_SIZE = 3

# Load every .npy file of this configuration (OUT_FILE and any file sharing its prefix).
# For each grid cell, count the associations with at least MIN_SIZE entries (presumably
# one entry per station) that are compatible with it.
# The loaded dict is bound to `loaded_associations`, so the `associations` computed by the
# association cell is not overwritten.
# NOTE: np.load(allow_pickle=True) unpickles; only load files produced by this notebook.
for f in tqdm(glob2.glob(f"{OUT_FILE[:-4]}*.npy")):
    loaded_associations = np.load(f, allow_pickle=True).item()
    for date, associations_ in loaded_associations.items():
        for (assoc_detections, valid_points) in associations_:
            # Skip associations smaller than MIN_SIZE (same filter as the next cell)
            if len(assoc_detections) < MIN_SIZE:
                continue
            for i, j in valid_points:
                valid[i,j] += 1

plt.figure(figsize=(15,10))
extent = (LON_BOUNDS[0], LON_BOUNDS[-1], LAT_BOUNDS[0], LAT_BOUNDS[-1])
# Rows are flipped, presumably because grid row 0 is the southernmost latitude (confirm in compute_grids)
im = plt.imshow(valid[::-1], aspect=1, cmap="inferno", extent=extent, interpolation=None, vmax =200)
cbar = plt.colorbar(im)
cbar.set_label('Nb of associations')
plt.xlabel("Longitude")
plt.ylabel("Latitude")
plt.title(f"Associations with at least {MIN_SIZE} detections")

# Station markers; get_pos() is used as (lat, lon)
for s in STATIONS:
    p = s.get_pos()

    if p[0] > LAT_BOUNDS[1] or p[0] < LAT_BOUNDS[0] or p[1] > LON_BOUNDS[1] or p[1] < LON_BOUNDS[0]:
        print(f"Station {s.name} out of bounds")
        continue
    plt.plot(p[1], p[0], 'wx', alpha=0.75)
    plt.annotate(s.name, xy=(p[1], p[0]), xytext=(p[1]-(LON_BOUNDS[1]-LON_BOUNDS[0])/15, p[0]+(LAT_BOUNDS[1]-LAT_BOUNDS[0])/100), textcoords="data", color='w', alpha=0.9)

In [None]:
valid = np.zeros((GRID_SIZE,GRID_SIZE))

MIN_SIZE =9

# Load every .npy file of this configuration. For each grid cell, count the associations
# with at least MIN_SIZE entries (presumably one entry per station) that are compatible
# with it. The loaded dict is bound to `loaded_associations`, so the `associations`
# computed by the association cell is not overwritten.
# NOTE: np.load(allow_pickle=True) unpickles; only load files produced by this notebook.
for f in tqdm(glob2.glob(f"{OUT_FILE[:-4]}*.npy")):
    loaded_associations = np.load(f, allow_pickle=True).item()
    for date, associations_ in loaded_associations.items():
        for (assoc_detections, valid_points) in associations_:
            if len(assoc_detections) < MIN_SIZE:
                continue
            for i, j in valid_points:
                valid[i,j] += 1

# Create a figure with cartopy's PlateCarree projection
projection = ccrs.PlateCarree()
fig, ax = plt.subplots(figsize=(12, 8), subplot_kw={'projection': projection})

# Set the extent of the map (min_lon, max_lon, min_lat, max_lat)
ax.set_extent([LON_BOUNDS[0], LON_BOUNDS[1], LAT_BOUNDS[0], LAT_BOUNDS[1]], crs=projection)

# Natural features get higher zorders than the image, so they are drawn on top of it
ax.add_feature(cfeature.LAND, facecolor='lightgray', zorder=2)
# ax.add_feature(cfeature.OCEAN, facecolor='lightblue', zorder=2)
ax.add_feature(cfeature.COASTLINE, edgecolor='black', linewidth=1, zorder=3)
ax.add_feature(cfeature.BORDERS, linestyle=':', edgecolor='black', zorder=3)

# Georeferenced association counts at zorder 1, below the map features.
# With vmax=1, any cell with at least one association saturates: effectively a binary
# coverage map.
extent = (LON_BOUNDS[0], LON_BOUNDS[1], LAT_BOUNDS[0], LAT_BOUNDS[1])
im = ax.imshow(valid[::-1],
               cmap="winter",
               extent=extent,
               interpolation="nearest",
               origin='upper',
               transform=projection,
               zorder=1,
               alpha=1, vmax=1)

# Add a colorbar for the image
cbar = plt.colorbar(im, ax=ax, orientation='vertical', pad=0.05)
cbar.set_label('Nb of associations')

# Station markers and labels (get_pos() is used as (lat, lon)), drawn above everything else
for s in STATIONS:
    lat, lon = s.get_pos()
    if lat > LAT_BOUNDS[1] or lat < LAT_BOUNDS[0] or lon > LON_BOUNDS[1] or lon < LON_BOUNDS[0]:
        print(f"Station {s.name} out of bounds")
        continue
    ax.plot(lon, lat, 'wx', alpha=0.75, markersize=8, transform=projection, zorder=4)
    ax.text(lon - (LON_BOUNDS[1] - LON_BOUNDS[0]) / 15,
            lat + (LAT_BOUNDS[1] - LAT_BOUNDS[0]) / 100,
            s.name,
            color='white',
            alpha=0.9,
            transform=projection,
            zorder=4)

plt.title("Association Data with Land and Sea")
plt.show()
