This notebook associates detections made on different stations in order to compute possible source positions on a grid.

In [None]:
import numpy as np
import glob2
import datetime
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm.notebook import tqdm
from matplotlib import pyplot as plt

from utils.detection.association import load_detections
from utils.detection.association import compute_grids
from utils.data_reading.sound_data.station import StationsCatalog
from utils.physics.sound_model.spherical_sound_model import HomogeneousSphericalSoundModel as HomogeneousSoundModel
from utils.detection.association import compute_candidates, association_is_new, update_valid_grid, update_results

In [None]:
# paths
# CATALOG_PATH is passed to StationsCatalog, DETECTIONS_DIR holds the raw detection files,
# ASSOCIATION_OUTPUT_DIR is the root under which the association results are saved
CATALOG_PATH = "../../../data/demo"
DETECTIONS_DIR = "../../../data/detection/TiSSNet/demo"
ASSOCIATION_OUTPUT_DIR = "../../../data/detection/association"

# Detections loading parameters
# If True, rebuild the detections from the raw detection files found in DETECTIONS_DIR.
# If False, load "detections.npy" and "detections_merged.npy" from the "cache" sub-directory
# of DETECTIONS_DIR instead. Leave at True by default.
RELOAD_DETECTIONS = True
MIN_P_TISSNET_PRIMARY = 0.5  # min probability of browsed detections
MIN_P_TISSNET_SECONDARY = 0.1  # min probability of detections that can be associated with the browsed one
MERGE_DELTA_S = 5 # threshold (in seconds) below which we consider two events should be merged
MERGE_DELTA = datetime.timedelta(seconds=MERGE_DELTA_S)  # same threshold, as a timedelta

# Stations whose index (in the list of other stations sorted from the closest) is below
# REQ_CLOSEST_STATIONS cannot be skipped when building an association; 0 means no station is required
REQ_CLOSEST_STATIONS = 0

# sound model definition
SOUND_MODEL = HomogeneousSoundModel(sound_speed=1485.5)  # presumably in m/s - TODO confirm

# association running parameters
RUN_ASSOCIATION = True # set to False to load previous associations without processing it again
SAVE_PATH_ROOT = None  # change this to save the grids as figures, leave at None by default
NCPUS = 6  # nb of CPUs used (number of worker processes of the ProcessPoolExecutor)

### Load Detections

In [None]:
# Stations of the catalog, restricted by the two catalog filters
# (presumably dropping stations without dates / without position - TODO confirm).
STATIONS = StationsCatalog(CATALOG_PATH).filter_out_undated().filter_out_unlocated()
# Last component of the detections directory, used to name the association output directory.
# Path(...).name still returns that component when the path ends with a slash,
# whereas split("/")[-1] would return an empty string.
DETECTIONS_DIR_NAME = Path(DETECTIONS_DIR).name

if RELOAD_DETECTIONS:
    # every regular file directly inside DETECTIONS_DIR is given to load_detections (sub-directories are skipped)
    det_files = [f for f in glob2.glob(DETECTIONS_DIR + "/*") if Path(f).is_file()]
    DETECTIONS, DETECTIONS_MERGED = load_detections(det_files, STATIONS, DETECTIONS_DIR, MIN_P_TISSNET_PRIMARY, MIN_P_TISSNET_SECONDARY, MERGE_DELTA)
else:
    # NOTE: allow_pickle=True unpickles arbitrary objects; only load cache files you produced yourself
    DETECTIONS = np.load(f"{DETECTIONS_DIR}/cache/detections.npy", allow_pickle=True).item()
    DETECTIONS_MERGED = np.load(f"{DETECTIONS_DIR}/cache/detections_merged.npy", allow_pickle=True)

# From here on, STATIONS only lists the stations that have at least one detection:
# a station with an empty detection array has no first/last detection to index below.
STATIONS = [s for s in DETECTIONS if len(DETECTIONS[s]) > 0]
# Column 0 of each detection row is its date; first and last rows give the time span covered
# by each station (assumes rows are sorted chronologically - TODO confirm).
FIRSTS_DETECTIONS = {s: DETECTIONS[s][0, 0] for s in STATIONS}
LASTS_DETECTIONS = {s: DETECTIONS[s][-1, 0] for s in STATIONS}

### Compute grids

In [None]:
# Extent of the search grid, each given as [min, max]
LAT_BOUNDS = [-13.4, -12.4]
LON_BOUNDS = [45.2, 46.2]
# number of points along each axis
GRID_SIZE = 50

# Precompute everything the association step needs for this grid and these stations.
# Units of the two uncertainties are not visible here (presumably seconds and m/s - TODO confirm).
(
    PTS_LAT,
    PTS_LON,
    STATION_MAX_TRAVEL_TIME,
    GRID_STATION_TRAVEL_TIME,
    GRID_STATION_COUPLE_TRAVEL_TIME,
    GRID_TOLERANCE,
) = compute_grids(
    LAT_BOUNDS,
    LON_BOUNDS,
    GRID_SIZE,
    SOUND_MODEL,
    STATIONS,
    pick_uncertainty=5,
    sound_speed_uncertainty=2,
)

### Now associate

In [None]:
print("starting association")

# Results are stored in one directory per detections directory.
OUT_DIR = f"{ASSOCIATION_OUTPUT_DIR}/grids/{DETECTIONS_DIR_NAME}"
Path(OUT_DIR).mkdir(parents=True, exist_ok=True)
# The file name encodes the grid bounds, the grid size and the two probability thresholds.
# Spaces are removed from the file name only: removing them from the whole path would make
# OUT_FILE point outside the directory created above whenever OUT_DIR contains a space.
OUT_FILE = f"{OUT_DIR}/" + f"s_{LAT_BOUNDS[0]}-{LAT_BOUNDS[1]},{LON_BOUNDS[0]}-{LON_BOUNDS[1]},{GRID_SIZE},{MIN_P_TISSNET_PRIMARY},{MIN_P_TISSNET_SECONDARY}.npy".replace(" ", "")

# collection handed to every process_detection task, which forwards it to association_is_new
association_hashlist = set()
# results of the association, merged task after task by the main loop and saved to OUT_FILE
associations = {}

def process_detection(arg):
    """Search the associations that can be built from one detection.

    Starting from a single detection (date1 on station s1), the detections of the other
    stations are browsed with a backtracking search; a combination is kept as long as at
    least one cell of the grid remains valid for it.

    Relies on module-level state: STATIONS, DETECTIONS, FIRSTS_DETECTIONS, LASTS_DETECTIONS,
    STATION_MAX_TRAVEL_TIME, GRID_STATION_COUPLE_TRAVEL_TIME, GRID_TOLERANCE, MERGE_DELTA_S,
    REQ_CLOSEST_STATIONS, LAT_BOUNDS, LON_BOUNDS and SAVE_PATH_ROOT.

    Args:
        arg: tuple ``(detection, local_association_hashlist)`` packed into one argument so
            that the function can be submitted to an executor. ``detection`` unpacks as
            ``(date1, p1, s1)``: a date, a value unused here (presumably the detection
            probability) and a station. ``local_association_hashlist`` is forwarded to
            ``association_is_new``.

    Returns:
        tuple ``(local_association, local_association_hashlist)``: the dict filled by
        ``update_results`` for this detection (empty if fewer than two other stations have
        candidate detections) and the hash collection received in ``arg``.
    """
    detection, local_association_hashlist = arg
    local_association = {}
    date1, _p1, s1 = detection

    # optional per-detection directory in which update_valid_grid can save its figures
    save_path = SAVE_PATH_ROOT
    if save_path is not None:
        save_path = f'{save_path}/{s1.name}-{date1.strftime("%Y%m%d_%H%M%S")}'
        Path(save_path).mkdir(parents=True, exist_ok=True)

    # list all other stations whose detections (first to last) overlap date1, with a margin of
    # 4 * GRID_TOLERANCE seconds, and sort them by maximal travel time from s1
    margin = datetime.timedelta(seconds=4 * GRID_TOLERANCE)
    other_stations = [s2 for s2 in STATIONS if s2 != s1
                      and date1 + margin > FIRSTS_DETECTIONS[s2]
                      and date1 - margin < LASTS_DETECTIONS[s2]]
    other_stations.sort(key=lambda s2: STATION_MAX_TRAVEL_TIME[s1][s2])

    # given the detection date1 occurred on station s1, list all the detections of other stations that may be generated by the same source event
    current_association = {s1: date1}
    candidates = compute_candidates(other_stations, current_association, DETECTIONS, STATION_MAX_TRAVEL_TIME, MERGE_DELTA_S)

    # update the list of other stations to only include the ones having at least a candidate detection
    other_stations = [s for s in other_stations if len(candidates[s]) > 0]

    # results are only recorded for associations of more than two stations, so at least
    # two other stations are needed
    if len(other_stations) < 2:
        return local_association, local_association_hashlist

    def backtrack(station_index, valid_grid):
        """Browse the search space of associations for s1-date1.

        Tries every candidate detection of other_stations[station_index] on top of the
        current association (shared dict, modified then restored), then recurses on the
        next station. valid_grid is the grid obtained for the current association
        (None before any station has been added).
        """
        if station_index == len(other_stations):
            return
        station = other_stations[station_index]

        # candidates are recomputed against the whole current association, not only against s1
        station_candidates = compute_candidates([station], current_association, DETECTIONS, STATION_MAX_TRAVEL_TIME, MERGE_DELTA_S)
        for idx in station_candidates[station]:
            date, _p = DETECTIONS[station][idx]
            if not association_is_new(current_association, date, local_association_hashlist):
                continue

            valid_grid_new, _ = update_valid_grid(current_association, valid_grid, station, date, GRID_STATION_COUPLE_TRAVEL_TIME, GRID_TOLERANCE, save_path, LON_BOUNDS, LAT_BOUNDS)
            valid_points_new = np.argwhere(valid_grid_new)

            # no grid cell is compatible with this detection: drop the branch
            if len(valid_points_new) == 0:
                continue

            current_association[station] = date
            if len(current_association) > 2:
                update_results(date1, current_association, valid_points_new, local_association, GRID_STATION_COUPLE_TRAVEL_TIME)
            backtrack(station_index + 1, valid_grid_new)
            # restore the association before trying the next candidate
            del current_association[station]

        # also try without this station, unless it is one of the required closest stations
        if station_index >= REQ_CLOSEST_STATIONS:
            backtrack(station_index + 1, valid_grid)

    backtrack(0, None)
    return local_association, local_association_hashlist

# main part
if RUN_ASSOCIATION:
    try:
        with ProcessPoolExecutor(NCPUS) as executor:
            # One task per merged detection, all submitted before any result is read.
            # NOTE(review): every task is submitted with the set association_hashlist is bound to
            # at this point, and that set is never modified below (union() builds a new set).
            # Hashes gathered from finished tasks therefore never reach the pending ones -
            # confirm whether de-duplication across tasks was intended.
            futures = {executor.submit(process_detection, (det, association_hashlist)): det for det in DETECTIONS_MERGED}
            for future in tqdm(as_completed(futures), total=len(futures)):
                local_association, local_association_hashlist = future.result()
                # kept as a rebinding on purpose: the submitted set is still referenced by pending tasks
                association_hashlist = association_hashlist.union(local_association_hashlist)
                # in-place merge: same content as `associations | local_association`, without
                # copying the whole result dict for every completed task
                associations.update(local_association)
    finally:
        # save the associations no matter if the execution stopped properly
        np.save(OUT_FILE, associations)

### Take a look at the results

In [None]:
# valid[i, j] counts the kept associations for which grid cell (i, j) is a valid source position
valid = np.zeros((GRID_SIZE, GRID_SIZE))

# associations made of fewer than MIN_SIZE detections are ignored
MIN_SIZE = 3

# load every npy file whose path starts like OUT_FILE (extension excluded) and create a grid
# containing the associations with cardinal >= MIN_SIZE
for f in tqdm(glob2.glob(f"{OUT_FILE[:-4]}*.npy")):
    # NOTE: allow_pickle=True unpickles arbitrary objects; only load files produced by this notebook
    associations = np.load(f, allow_pickle=True).item()
    # the keys are not needed here, only the (detections, valid_points) pairs stored under them
    for associations_ in associations.values():
        for detections, valid_points in associations_:
            if len(detections) < MIN_SIZE:
                continue
            for i, j in valid_points:
                valid[i, j] += 1

plt.figure(figsize=(15,10))
extent = (LON_BOUNDS[0], LON_BOUNDS[-1], LAT_BOUNDS[0], LAT_BOUNDS[-1])
# rows are reversed so that row 0 of the grid is drawn at the bottom of the image
# NOTE(review): interpolation=None selects matplotlib's default interpolation; the string
# "none" is the value that disables it - confirm which one was intended
im = plt.imshow(valid[::-1], aspect=1, cmap="inferno", extent=extent, interpolation=None)
cbar = plt.colorbar(im)
cbar.set_label('Nb of associations')

# mark the stations located inside the displayed area; p is used as (latitude, longitude)
for s in STATIONS:
    p = s.get_pos()

    if p[0] > LAT_BOUNDS[1] or p[0] < LAT_BOUNDS[0] or p[1] > LON_BOUNDS[1] or p[1] < LON_BOUNDS[0]:
        print(f"Station {s.name} out of bounds")
        continue
    plt.plot(p[1], p[0], 'wx', alpha=0.75)
    # the label is shifted left by 1/15 of the longitude span and up by 1/100 of the latitude span
    plt.annotate(s.name, xy=(p[1], p[0]), xytext=(p[1]-(LON_BOUNDS[1]-LON_BOUNDS[0])/15, p[0]+(LAT_BOUNDS[1]-LAT_BOUNDS[0])/100), textcoords="data", color='w', alpha=0.9)