In [None]:
import datetime
import os
import pickle
from pathlib import Path

import glob2
import numpy as np
from tqdm import tqdm

from utils.data_reading.sound_data.station import StationsCatalog
from utils.physics.sound_model.spherical_sound_model import HomogeneousSphericalSoundModel as HomogeneousSoundModel

from utils.detection.association import update_candidates, association_is_new, update_valid_grid, update_results

In [None]:
det_dir = "../../../data/detection/TiSSNet/OHASISBIO-2018"
# Station catalog. Set the DATASET_CATALOG environment variable to point elsewhere instead of
# editing this machine-specific absolute path.
catalog_path = os.environ.get("DATASET_CATALOG", "/home/plerolland/Bureau/dataset.yaml")
# NOTE: despite its name this is a .csv file path, not a directory; it is not used in the cells shown.
out_dir = "../../../data/detection/association/OHASISBIO_grid.csv"

YEAR = 2018
# Keep dated, located stations that end after Jan 1 of YEAR and start before Jan 1 of YEAR + 1.
# The upper bound used to be Dec 31 of YEAR, which dropped a station starting during Dec 31.
STATIONS = (
    StationsCatalog(catalog_path)
    .filter_out_undated()
    .filter_out_unlocated()
    .ends_after(datetime.datetime(YEAR, 1, 1))
    .starts_before(datetime.datetime(YEAR + 1, 1, 1))
)
SOUND_MODEL = HomogeneousSoundModel(sound_speed=1485.5)  # homogeneous sound speed, presumably m/s

MIN_P_TISSNET_PRIMARY = 0.5  # min probability of browsed detections
MIN_P_TISSNET_SECONDARY = 0.1  # min probability of detections that can be associated with the browsed one

In [None]:
# Search grid, in degrees. (Earlier runs used the narrower LAT [-47, -27] / LON [68, 88] window.)
LAT_BOUNDS = [-60, -15]
LON_BOUNDS = [40, 90]
N_GRID_PTS = 300  # grid nodes along each axis
PTS_LAT = np.linspace(LAT_BOUNDS[0], LAT_BOUNDS[1], N_GRID_PTS)
PTS_LON = np.linspace(LON_BOUNDS[0], LON_BOUNDS[1], N_GRID_PTS)

METERS_PER_DEGREE = 111_000  # approximate length of one degree of latitude (and of longitude at the equator)
SOUND_SPEED_UNCERTAINTY = 2  # subtracted from the model sound speed to get the slowest plausible speed
# Worst-case grid cell size in meters. Cells are not square: the lat and lon spans differ, and a
# degree of longitude shrinks with cos(latitude), so take the widest longitude cell inside the grid.
CELL_LAT_M = (PTS_LAT[1] - PTS_LAT[0]) * METERS_PER_DEGREE
CELL_LON_M = (PTS_LON[1] - PTS_LON[0]) * METERS_PER_DEGREE * np.max(np.cos(np.radians(PTS_LAT)))
# Time for sound to cross one cell diagonal at the slowest plausible speed.
GRID_MAX_RES_TIME = np.hypot(CELL_LAT_M, CELL_LON_M) / (SOUND_MODEL.sound_speed - SOUND_SPEED_UNCERTAINTY)
PICK_UNCERTAINTY = 5  # added to a travel time, so presumably seconds
GENERIC_TOLERANCE = 10  # passed to update_candidates; presumably seconds - TODO confirm
GRID_TOLERANCE = GRID_MAX_RES_TIME + PICK_UNCERTAINTY
print(f"Grid tolerance of {GRID_TOLERANCE:.2f}s")

# Travel time from every grid node to each station: GRID_STATION_TRAVEL_TIME[s][ilat, ilon].
GRID_STATION_TRAVEL_TIME = {}
for s in tqdm(STATIONS, desc="computing travel time grid"):
    station_pos = s.get_pos()
    travel_grid = np.zeros((len(PTS_LAT), len(PTS_LON)))
    for ilat, lat in enumerate(PTS_LAT):
        for ilon, lon in enumerate(PTS_LON):
            travel_grid[ilat, ilon] = SOUND_MODEL.get_sound_travel_time([lat, lon], station_pos)
    GRID_STATION_TRAVEL_TIME[s] = travel_grid

# Expected arrival-time difference at every grid node for each station pair: [s][s2] = t(s2) - t(s).
GRID_STATION_COUPLE_TRAVEL_TIME = {
    s: {s2: GRID_STATION_TRAVEL_TIME[s2] - GRID_STATION_TRAVEL_TIME[s] for s2 in STATIONS}
    for s in STATIONS
}

# Direct station-to-station travel time, used to bound the time window for candidate detections.
STATION_MAX_TRAVEL_TIME = {s : {s2 : SOUND_MODEL.get_sound_travel_time(s.get_pos(), s2.get_pos()) for s2 in STATIONS} for s in STATIONS}

In [None]:
MERGE_DELTA = datetime.timedelta(seconds=5)  # threshold below which we consider two events should be merged


def _load_pickled_records(path):
    """Return the list of objects pickled back-to-back in `path`, in file order.

    SECURITY: pickle.load can execute arbitrary code; only use it on trusted detection files.
    """
    records = []
    with open(path, "rb") as f:
        while True:
            try:
                records.append(pickle.load(f))
            except EOFError:
                break
    return records


def _remove_duplicates_and_periodic(d):
    """Filter a non-empty, date-sorted 2-column array of (date, probability) rows.

    The first row is always kept. Row i is dropped when it is within MERGE_DELTA of row i-1
    (duplicate). For i >= 3 it is also dropped when either the last gap (i-1 -> i) or the last
    two-gap span (i-2 -> i) matches the one before it within MERGE_DELTA. Such regularly spaced
    events probably mean we encounter seismic airgun shots. Comparisons use the raw neighbouring
    rows, including rows that were themselves dropped.
    """
    kept = [d[0]]
    for i in range(1, len(d)):
        if d[i, 0] - d[i - 1, 0] <= MERGE_DELTA:
            continue  # too close to the previous event: duplicate
        if i >= 3:
            gap_repeats = abs((d[i, 0] - d[i - 1, 0]) - (d[i - 1, 0] - d[i - 2, 0])) <= MERGE_DELTA
            span_repeats = abs((d[i, 0] - d[i - 2, 0]) - (d[i - 1, 0] - d[i - 3, 0])) <= MERGE_DELTA
            if gap_repeats or span_repeats:
                continue  # part of a regularly spaced series
        kept.append(d[i])
    return np.array(kept)


DETECTIONS = {}

for det_file in tqdm(glob2.glob(det_dir + "/*")):
    # File names are assumed to end in "_<station name>-<start year>" plus a 2-character extension
    # (presumably ".p", stripped by [:-2]) - TODO confirm against the detection writer.
    file_tag = det_file[:-2].split("/")[-1].split("_")[-1]
    s_name, y_start = "-".join(file_tag.split("-")[:-1]), file_tag.split("-")[-1]
    station = STATIONS.by_name(s_name).by_starting_year(int(y_start))[0]

    records = _load_pickled_records(det_file)
    if len(records) == 0:
        d = np.empty((0, 2), dtype=object)  # empty file: nothing to filter
    else:
        d = np.array(records)[:, :2]  # keep the (date, probability) columns only
        d = d[d[:, 1] > MIN_P_TISSNET_SECONDARY]
        d = d[np.argsort(d[:, 0])]
        if len(d) > 0:
            d = _remove_duplicates_and_periodic(d)

    DETECTIONS[station] = d
    print(f"Found {len(d)} detections for station {s_name}")

# We keep all detections in a single (n, 3) object array of (date, probability, station) rows,
# sorted by date, to then browse detections. Stations without a detection file are skipped
# instead of raising KeyError; reshape keeps the (0, 3) shape when there are no detections.
merged_rows = [(det[0], det[1], s) for s in STATIONS if s in DETECTIONS for det in DETECTIONS[s]]
DETECTIONS_MERGED = np.array(merged_rows, dtype=object).reshape(-1, 3)
DETECTIONS_MERGED = DETECTIONS_MERGED[DETECTIONS_MERGED[:, 1] > MIN_P_TISSNET_PRIMARY]
DETECTIONS_MERGED = DETECTIONS_MERGED[np.argsort(DETECTIONS_MERGED[:, 0])]

In [None]:
associations = {}
association_hashlist = set()

print("starting association")

REQ_CLOSEST_STATIONS = 3  # The REQ_CLOSEST_STATIONS th closest stations will be required for an association to be valid

# Set to e.g. "../../../data/detection/association/grids" to save intermediate grids; None disables saving.
SAVE_PATH_ROOT = None

# DETECTIONS_MERGED : (n,3) = n_detections x (det_time, det_probability, station)
for date1, p1, s1 in tqdm(DETECTIONS_MERGED):
    save_path = SAVE_PATH_ROOT
    if save_path is not None:
        save_path = f'{save_path}/{s1.name}-{date1.strftime("%Y%m%d_%H%M%S")}'
        Path(save_path).mkdir(parents=True, exist_ok=True)

    # Other stations sorted by travel time from s1 (closest first). Stations with no loaded
    # detections cannot provide candidates, so they are excluded up front.
    other_stations = np.array([s2 for s2 in STATIONS if s2 != s1 and s2 in DETECTIONS])
    other_stations = other_stations[np.argsort([STATION_MAX_TRAVEL_TIME[s1][s2] for s2 in other_stations])]

    current_association = {s1: date1}
    anchors = [[si, datei] for si, datei in current_association.items()]
    candidates = update_candidates(None, other_stations, anchors, DETECTIONS, STATION_MAX_TRAVEL_TIME, GENERIC_TOLERANCE)

    # Drop stations with no candidate detection around date1.
    # NOTE(review): this filtering shifts indices, so REQ_CLOSEST_STATIONS applies to the closest
    # stations *that have candidates*, not to the closest stations overall - confirm this is intended.
    other_stations = [s for s in other_stations if len(candidates[s]) > 0]
    candidates = {s : candidates[s] for s in other_stations}

    def backtrack(station_index, current_association, valid_grid, associations, candidates, save_path):
        """Recursively try to extend `current_association` with detections of other_stations[station_index:].

        current_association maps station -> detection date and is mutated in place. Every entry added
        here is removed again before the next candidate is tried, including on the "no solution" path.
        valid_grid is the grid returned by update_valid_grid for the current association (None at the
        root call). Stations at index >= REQ_CLOSEST_STATIONS may also be skipped entirely.
        """
        if station_index == len(other_stations):
            return
        station = other_stations[station_index]

        for idx in candidates[station]:
            date, p = DETECTIONS[station][idx]
            if not association_is_new(current_association, date, association_hashlist):
                continue

            valid_grid_new, dg_new = update_valid_grid(current_association, valid_grid, station, date, GRID_STATION_COUPLE_TRAVEL_TIME, GRID_TOLERANCE, save_path, LON_BOUNDS, LAT_BOUNDS)
            valid_points_new = np.argwhere(valid_grid_new)
            if len(valid_points_new) == 0:
                continue  # no grid point is compatible with this detection

            current_association[station] = date
            try:
                # update_results presumably records the association in associations/association_hashlist
                # and returns a falsy value when no solution exists - TODO confirm.
                if len(current_association) > 3 and not update_results(date1, current_association, valid_points_new, association_hashlist, associations, GRID_STATION_COUPLE_TRAVEL_TIME):
                    continue  # no solution
                candidates_new = update_candidates(candidates, other_stations[station_index+1:], [[station, date]], DETECTIONS, STATION_MAX_TRAVEL_TIME, GENERIC_TOLERANCE, deep_copy=True)
                backtrack(station_index + 1, current_association, valid_grid_new, associations, candidates_new, save_path)
            finally:
                # Always undo the tentative assignment; previously the "no solution" `continue`
                # skipped this and leaked the station into later candidates and branches.
                del current_association[station]
        # also try without self
        if station_index >= REQ_CLOSEST_STATIONS:
            backtrack(station_index + 1, current_association, valid_grid, associations, candidates, save_path)

    backtrack(0, current_association, None, associations, candidates, save_path=save_path)