In [None]:
import numpy as np
import glob2
import datetime
from pathlib import Path
from tqdm.notebook import tqdm
import pickle
from matplotlib import pyplot as plt
plt.rcParams.update({'font.size': 18})
import math
import pandas as pd

from utils.data_reading.sound_data.station import StationsCatalog
from utils.detection.association_geodesic import squarize
from utils.detection.association_geodesic_3D import compute_candidates, update_valid_grid, update_results, load_detections, compute_grids

## Parameters initialization

In [None]:
# --- Paths ---
CATALOG_PATH = "/media/plerolland/akoustik"  # root directory handed to StationsCatalog
dataset = "MAHY4"
DETECTIONS_DIR = f"../../../../data/detection/TiSSNet_Pn_OBS-fixed/{dataset}"
SOUND_MODEL_PATH = "../../../../data/sound_model"

# --- Detections loading ---
# min probability of browsed (primary) detections
MIN_P_TISSNET_PRIMARY = 0.3
# min probability of detections that can be associated with the browsed one
MIN_P_TISSNET_SECONDARY = 0.15
# threshold (in seconds) below which we consider two events should be merged
MERGE_DELTA_S = 5
MERGE_DELTA = datetime.timedelta(seconds=MERGE_DELTA_S)

# The REQ_CLOSEST_STATIONS th closest stations will be required for an association to be valid
REQ_CLOSEST_STATIONS = 0

# --- Sound model ---
# stations of the selected dataset that have both dates and a position
STATIONS = (
    StationsCatalog(CATALOG_PATH)
    .filter_out_undated()
    .filter_out_unlocated()
    .by_dataset(dataset)
)
seismic_paths = glob2.glob("../../../../data/MAHY/loc_3D/*.npz")

# --- Association run ---
# change this to save the grids as figures, leave at None by default
SAVE_PATH_ROOT = None

## Detections processing

In [None]:
# Last component of the detections directory path.
DETECTIONS_DIR_NAME = DETECTIONS_DIR.rsplit("/", 1)[-1]

# Load the detections from the cache when available, otherwise build them and fill the cache.
Path(f"{DETECTIONS_DIR}/cache").mkdir(parents=True, exist_ok=True)
DET_PATH = f"{DETECTIONS_DIR}/cache/detections_{MIN_P_TISSNET_SECONDARY}_{MERGE_DELTA_S}.pkl"
if Path(DET_PATH).exists():
    with open(DET_PATH, "rb") as f:
        DETECTIONS = pickle.load(f)
else:
    # note: the catalog is reloaded here without the by_dataset() filter
    STATIONS = StationsCatalog(CATALOG_PATH).filter_out_undated().filter_out_unlocated()
    det_files = [path for path in glob2.glob(DETECTIONS_DIR + "/*.pkl") if Path(path).is_file()]
    DETECTIONS = load_detections(
        det_files,
        STATIONS,
        MIN_P_TISSNET_SECONDARY,
        merge_delta=datetime.timedelta(seconds=MERGE_DELTA_S),
    )
    with open(DET_PATH, "wb") as f:
        pickle.dump(DETECTIONS, f)

# Give every detection a global index, appended as an extra trailing column of its row,
# and keep a reverse lookup from that index to the row.
idx_det = 0
IDX_TO_DET = {}
for station_rank, station in enumerate(DETECTIONS.keys()):
    station.idx = station_rank  # indexes to store efficiently the associations
    indexed_rows = []
    for row in list(DETECTIONS[station]):
        indexed_row = np.concatenate((row, [idx_det]))
        IDX_TO_DET[idx_det] = indexed_row
        indexed_rows.append(indexed_row)
        idx_det += 1
    DETECTIONS[station] = np.array(indexed_rows)
DETECTION_IDXS = np.array(list(range(idx_det)))

# From here on STATIONS is the list of stations that are keys of DETECTIONS, in dict order.
STATIONS = list(DETECTIONS.keys())
for station_rank, station in enumerate(STATIONS):
    station.idx = station_rank
# first column of the first / last row of each station
FIRSTS_DETECTIONS = {station: DETECTIONS[station][0, 0] for station in STATIONS}
LASTS_DETECTIONS = {station: DETECTIONS[station][-1, 0] for station in STATIONS}

# One flat table with rows (column 0, column 1, detection index, station), restricted to
# rows whose column 1 exceeds the primary threshold and sorted by decreasing column 1.
DETECTIONS_MERGED = np.concatenate(
    [[(det[0], det[1], det[2], s) for det in DETECTIONS[s]] for s in STATIONS]
)
primary_mask = DETECTIONS_MERGED[:, 1] > MIN_P_TISSNET_PRIMARY
DETECTIONS_MERGED = DETECTIONS_MERGED[primary_mask]
DETECTIONS_MERGED = DETECTIONS_MERGED[np.argsort(DETECTIONS_MERGED[:, 1])[::-1]]

## Grid computation

In [None]:
# Bounds of the search volume, in the order (depth, lat, lon).
BOUNDS = [(1_000,100_000), (-13.4,-12.4), (45.25,46.25)]
GRID_SIZES = [100, 100]
PICK_UNCERTAINTY = 1  # a value of 0.1 was assigned just before and overridden by this one
MAX_CLOCK_DRIFT = 0.1
GEOMETRICAL = 0.1

# The cache file name encodes the bounds, grid sizes, pick uncertainty and clock drift.
# NOTE(review): GEOMETRICAL is passed to compute_grids but is not part of the file name,
# so changing it alone will silently reuse a previously cached grid.
grid_key = "_".join(
    str(value)
    for value in (*BOUNDS[0], *BOUNDS[1], *BOUNDS[2], *GRID_SIZES, PICK_UNCERTAINTY, MAX_CLOCK_DRIFT)
)
GRID_PATH = f"{DETECTIONS_DIR}/cache/grids_{grid_key}.pkl"

if Path(GRID_PATH).exists():
    # reuse the grids computed by a previous run
    with open(GRID_PATH, "rb") as f:
        GRID_TO_COORDS, TDoA, MAX_TDoA, TDoA_UNCERTAINTY, LATS, DEPTHS, TRAVEL_TIMES = pickle.load(f)
else:
    GRID_TO_COORDS, TDoA, MAX_TDoA, TDoA_UNCERTAINTY, LATS, DEPTHS, TRAVEL_TIMES = compute_grids(
        BOUNDS,
        GRID_SIZES,
        STATIONS,
        seismic_paths,
        pick_uncertainty=PICK_UNCERTAINTY,
        max_clock_drift=MAX_CLOCK_DRIFT,
        geometrical_uncertainty=GEOMETRICAL,
    )
    with open(GRID_PATH, "wb") as f:
        pickle.dump((GRID_TO_COORDS, TDoA, MAX_TDoA, TDoA_UNCERTAINTY, LATS, DEPTHS, TRAVEL_TIMES), f)
GRID_TO_COORDS = np.array(GRID_TO_COORDS)

In [None]:
# Map of the TDoA grid between the first two stations, on a single depth layer.
d, s1, s2 = DEPTHS[80], STATIONS[0], STATIONS[1]

# keep only the grid nodes whose first coordinate (depth) equals the selected layer
depth_mask = GRID_TO_COORDS[:,0]==d
weights = TDoA[s1][s2][depth_mask]

fig, ax = plt.subplots(figsize=(8,8))
sq = squarize(GRID_TO_COORDS[depth_mask,1:], weights, BOUNDS[1], BOUNDS[2], size=1000)

# rows are flipped before display; x spans BOUNDS[2] (lon) and y spans BOUNDS[1] (lat)
im = ax.imshow(sq[::-1], cmap="seismic",extent=(BOUNDS[2][0], BOUNDS[2][-1], BOUNDS[1][0], BOUNDS[1][-1]))
xticks = np.arange(np.floor(BOUNDS[2][0]/0.25)*0.25, np.ceil(BOUNDS[2][-1]/0.25)*0.25 + 0.25, 0.25)
yticks = np.arange(np.floor(BOUNDS[1][0]/0.25)*0.25, np.ceil(BOUNDS[1][-1]/0.25)*0.25 + 0.25, 0.25)
ax.set_xticks(xticks)
ax.set_yticks(yticks)
for s_ in STATIONS:
    p = s_.get_pos()  # used as (lat, lon) below

    if p[0] > BOUNDS[1][1] or p[0] < BOUNDS[1][0] or p[1] > BOUNDS[2][1] or p[1] < BOUNDS[2][0]:
        print(f"Station {s_.name} out of bounds")
        continue
    ax.plot(p[1], p[0], 'yx', alpha=0.75, markersize=10, markeredgewidth=3)
    # stations whose name contains "3" get their label shifted further to the left
    label_shift = 4.5 if "3" in s_.name else 1
    ax.annotate(s_.name, xy=(p[1], p[0]), xytext=(p[1]-label_shift*(BOUNDS[2][1]-BOUNDS[2][0])/30, p[0]+(BOUNDS[1][1]-BOUNDS[1][0])/50), textcoords="data", color='y', alpha=0.9, weight='bold')
cbar = plt.colorbar(im,fraction=0.0415, pad=0.04)
ax.set_xlabel("lon (°)")
ax.set_ylabel("lat (°)")
plt.title(f"Stations {s1.name}-{s2.name} (depth = {d/1_000} km)")
# savefig does not create missing directories: make sure the output folder exists first
Path(f"{DETECTIONS_DIR}/figures").mkdir(parents=True, exist_ok=True)
plt.savefig(f'{DETECTIONS_DIR}/figures/grids_{BOUNDS[0][0]}_{BOUNDS[0][1]}_{BOUNDS[1][0]}_{BOUNDS[1][1]}_{BOUNDS[2][0]}_{BOUNDS[2][1]}_{GRID_SIZES[0]}_{GRID_SIZES[1]}_{PICK_UNCERTAINTY}_{MAX_CLOCK_DRIFT}_xy.png', bbox_inches='tight')

In [None]:
# Vertical section of the TDoA grid between the first two stations, at the middle latitude of LATS.
lat, s1, s2 = LATS[len(LATS)//2], STATIONS[0], STATIONS[1]

# keep only the grid nodes whose second coordinate (lat) equals the selected latitude
lat_mask = GRID_TO_COORDS[:,1]==lat
weights = TDoA[s1][s2][lat_mask]

fig, ax = plt.subplots(figsize=(8,8))
# columns 0 and 2 of the grid coordinates are depth and lon
sq = squarize(GRID_TO_COORDS[lat_mask][:,[0,2]], weights, BOUNDS[0], BOUNDS[2], size=1000)

# the vertical axis is displayed as negative depth in km
im = ax.imshow(sq, cmap="seismic", aspect="auto", extent=(BOUNDS[2][0], BOUNDS[2][-1], -BOUNDS[0][1]/1_000, -BOUNDS[0][0]/1_000))
xticks = np.arange(np.floor(BOUNDS[2][0]/0.25)*0.25, np.ceil(BOUNDS[2][-1]/0.25)*0.25 + 0.25, 0.25)
yticks = np.arange(np.floor(-BOUNDS[0][1]/5_000)*5, np.ceil(-BOUNDS[0][0]/5_000)*5 + 5, 5)
ax.set_xticks(xticks)
ax.set_yticks(yticks)

cbar = plt.colorbar(im,fraction=0.0415, pad=0.04)
ax.set_xlabel("lon (°)")
ax.set_ylabel("depth (km)")
plt.title(f"Stations {s1.name}-{s2.name} (lat = {lat}°)")
# savefig does not create missing directories: make sure the output folder exists first
Path(f"{DETECTIONS_DIR}/figures").mkdir(parents=True, exist_ok=True)
plt.savefig(f'{DETECTIONS_DIR}/figures/grids_{BOUNDS[0][0]}_{BOUNDS[0][1]}_{BOUNDS[1][0]}_{BOUNDS[1][1]}_{BOUNDS[2][0]}_{BOUNDS[2][1]}_{GRID_SIZES[0]}_{GRID_SIZES[1]}_{PICK_UNCERTAINTY}_{MAX_CLOCK_DRIFT}_xz.png', bbox_inches='tight')

## Association
(note: parallelize this with e.g. ProcessPoolExecutor for large datasets)

In [None]:
print("starting association")
MIN_ASSOCIATION_SIZE = 3
ASSOCIATION_RECORD_TOLERANCE = 0

# For each detection index, the association size recorded so far; every detection starts
# at MIN_ASSOCIATION_SIZE + ASSOCIATION_RECORD_TOLERANCE.
max_reached_per_det = dict.fromkeys(
    DETECTION_IDXS, MIN_ASSOCIATION_SIZE + ASSOCIATION_RECORD_TOLERANCE
)

# Filled by the main loop with the first field of every detection already processed as primary.
already_examined = set()

def process_detection(arg):
    """Browse the associations that can be built around one primary detection.

    Parameters
    ----------
    arg : tuple
        ``(detection, already_examined, max_reached_per_det)`` where

        * ``detection`` unpacks as ``(date, probability, detection index, station)``;
        * ``already_examined`` is a container of dates; candidate detections whose
          date is in it are skipped;
        * ``max_reached_per_det`` maps a detection index to the largest association
          size recorded for it so far. It is updated in place.

    Returns
    -------
    tuple
        ``(local_association, max_reached_per_det_modifications)``: the list that was
        handed to ``update_results`` during the search, and a dict
        ``{detection index: association size}`` of the records raised by this call.
    """
    max_candidates_per_station = 10  # only the most probable candidates of a station are tried
    detection, already_examined, max_reached_per_det = arg
    max_reached_per_det_modifications = {}
    local_association = []
    date1, _, idx_det1, s1 = detection

    # list all other stations and sort them by distance from s1
    # (stations are kept only if date1 is within one day of their first/last detection)
    one_day = datetime.timedelta(days=1)
    other_stations = np.array([s2 for s2 in STATIONS if s2 != s1
                               and date1 + one_day > FIRSTS_DETECTIONS[s2]
                               and date1 - one_day < LASTS_DETECTIONS[s2]])
    other_stations = other_stations[np.argsort([MAX_TDoA[s1][s2] for s2 in other_stations])]

    # given the detection date1 occurred on station s1, list all the detections of other stations that may be generated by the same source event
    current_association = {s1:(date1, idx_det1)}
    candidates = compute_candidates(other_stations, current_association, DETECTIONS, MAX_TDoA, MERGE_DELTA_S)

    # update the list of other stations to only include the ones having at least a candidate detection
    other_stations = [s for s in other_stations if len(candidates[s]) > 0]

    # define the recursive browsing function (that is responsible for browsing the search space of associations for s1-date1)
    def backtrack(station_index, current_association, valid_grid):
        if station_index == len(other_stations):
            return
        station = other_stations[station_index]

        # candidates of this station given the current association, most probable first
        station_candidates = compute_candidates([station], current_association, DETECTIONS, MAX_TDoA, MERGE_DELTA_S)[station]
        probabilities = [DETECTIONS[station][cand][1] for cand in station_candidates]
        station_candidates = np.array(station_candidates)[np.argsort(probabilities)][::-1][:max_candidates_per_station]
        for cand in station_candidates:
            date, _, idx_det = DETECTIONS[station][cand]

            if date in already_examined:
                # the det was already browsed as main
                continue
            if len(other_stations) < max_reached_per_det[idx_det] - ASSOCIATION_RECORD_TOLERANCE - 1:
                # the det already belongs to an association larger than what we could have here
                continue

            valid_grid_new, _ = update_valid_grid(current_association, valid_grid, station, date, TDoA, TDoA_UNCERTAINTY)

            valid_points_new = np.argwhere(valid_grid_new)[:, 0]

            if len(valid_points_new) > 0:
                current_association[station] = (date, idx_det)

                # record the association only if it is not smaller than the record of any of its members
                if all(len(current_association) >= max_reached_per_det[member] - ASSOCIATION_RECORD_TOLERANCE
                       for _, member in current_association.values()):
                    update_results(current_association, valid_points_new, local_association, TDoA, TDoA_UNCERTAINTY)
                    for _, member in current_association.values():
                        if len(current_association) > max_reached_per_det[member]:
                            max_reached_per_det[member] = len(current_association)
                            max_reached_per_det_modifications[member] = len(current_association)
                backtrack(station_index + 1, current_association, valid_grid_new)
                # undo the choice before trying the next candidate
                del current_association[station]
        # also try without self
        if station_index >= REQ_CLOSEST_STATIONS:
            backtrack(station_index + 1, current_association, valid_grid)
        return

    if len(other_stations) >= max_reached_per_det[idx_det1]-ASSOCIATION_RECORD_TOLERANCE - 1:
        # we only browse other stations if enough of them remain to reach the size already recorded for this detection
        backtrack(0, current_association, None)
    return local_association, max_reached_per_det_modifications


# Split the primary detections into n_chunks consecutive chunks; each chunk is processed
# and cached in its own file so that an interrupted run can be resumed.
frac = 0.1
n_chunks = math.ceil(1/frac)
chunk_size = len(DETECTIONS_MERGED) // n_chunks
chunks = [DETECTIONS_MERGED[k * chunk_size : (k + 1) * chunk_size] for k in range(n_chunks-1)]
# the last chunk takes the remainder (its start follows n_chunks instead of assuming 10 chunks)
chunks.append(DETECTIONS_MERGED[(n_chunks - 1) * chunk_size :])


# main part (note: process parallelization is a very efficient solution in case needed)
# NOTE(review): when a chunk is skipped because its file exists, already_examined and
# max_reached_per_det are not updated with that chunk's detections, so a resumed run
# may not browse the remaining chunks exactly like an uninterrupted one - confirm this is acceptable.
for i in range(len(chunks)):
    fname = f"{DETECTIONS_DIR}/cache/associations_{PICK_UNCERTAINTY}_{MIN_ASSOCIATION_SIZE}_{i*frac:.02f}.pkl"
    if Path(fname).exists():
        continue
    associations = []
    for det in tqdm(chunks[i]):
        local_association, max_reached_per_det_modifications = process_detection((det, already_examined, max_reached_per_det))
        already_examined.add(det[0])
        associations.extend(local_association)
        # merge the records raised by this detection (the max keeps the merge order-independent)
        for det_idx, reached in max_reached_per_det_modifications.items():
            max_reached_per_det[det_idx] = max(max_reached_per_det[det_idx], reached)
    with open(fname, "wb") as f:
        pickle.dump(associations, f)

In [None]:
# Total number of associations stored in the cached chunk files. Counting from disk also
# works on a fresh kernel when every chunk is already cached, a case where the loop above
# never assigns `associations`.
n_associations = 0
for file in glob2.glob(f"{DETECTIONS_DIR}/cache/associations_{PICK_UNCERTAINTY}_{MIN_ASSOCIATION_SIZE}_*.pkl"):
    with open(file, "rb") as f:
        n_associations += len(pickle.load(f))
print(n_associations)

# Associations plot

In [None]:
# Number of associations (with at least 4 detections) whose valid points include each grid node.
nb_per_coord = np.zeros(len(GRID_TO_COORDS), dtype=int)

association_files = glob2.glob(f"{DETECTIONS_DIR}/cache/associations_{PICK_UNCERTAINTY}_{MIN_ASSOCIATION_SIZE}_*.pkl")
for file in association_files:
    with open(file, "rb") as f:
        associations = pickle.load(f)
    for detections, valid_points in tqdm(associations):
        if len(detections) >= 4:
            # one vote per valid grid node of this association
            for point in valid_points:
                nb_per_coord[point] += 1

In [None]:
from mpl_toolkits.mplot3d import Axes3D
import plotly.express as px
import pandas as pd

%matplotlib qt

# 3D scatter of the grid nodes counted more than 50 times.
# nb_per_coord must be the per-node array here (a later cell rebinds the name to a dict).
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

mask = nb_per_coord > 50
lon_sel = GRID_TO_COORDS[mask, 2]
lat_sel = GRID_TO_COORDS[mask, 1]
depth_km_sel = -GRID_TO_COORDS[mask, 0]/1_000  # plotted as negative depth in km

p = ax.scatter(lon_sel, lat_sel, depth_km_sel, c=nb_per_coord[mask], cmap='inferno', s=10)
#df = pd.DataFrame({'x': GRID_TO_COORDS[mask, 2], 'y': GRID_TO_COORDS[mask, 1], 'z': -GRID_TO_COORDS[mask, 0]/1_000, 'val': nb_per_coord[mask]})
#fig = px.scatter_3d(df, x='x', y='y', z='z', color='val', opacity=1.0, size_max=10)

# zoom on a fixed lon/lat window
ax.set_xlim(45.5, 45.7)
ax.set_ylim(-12.8, -12.6)

fig.colorbar(p, ax=ax, label='nb_per_coord')

plt.show()

In [None]:
fig = plt.figure(figsize=(12, 10))
ax = fig.add_subplot(111, projection='3d')

%matplotlib inline

ax.view_init(elev=15, azim=240)

mask = nb_per_coord > 25
sc = ax.scatter(GRID_TO_COORDS[mask, 2], GRID_TO_COORDS[mask, 1], -GRID_TO_COORDS[mask, 0] / 1_000, c=nb_per_coord[mask], cmap='plasma', s=10, alpha=0.9)

ax.set_xlabel("lon (°)", fontsize=10)
ax.set_ylabel("lat (°)", fontsize=10)
ax.set_zlabel("depth (km)", fontsize=10)
ax.grid(True)
ax.tick_params(axis='both', labelsize=8)     # X et Y
ax.tick_params(axis='z', labelsize=8)

cb = fig.colorbar(sc, ax=ax, shrink=0.6)
cb.set_label("Number of associations")

#plt.savefig(f'{DETECTIONS_DIR}/figures/proj_3d.png', bbox_inches='tight', dpi=300)
plt.show()

In [None]:
import numpy as np
import pickle
import glob2
from tqdm import tqdm
from collections import Counter

# Depth window (exclusive on both sides) applied to the valid points of every association.
# An earlier value of [30_000, 60_000] was overridden by this one.
dep_bounds = [35_000, 45_000]

# (lat, lon) pairs of the first depth layer, and the reverse lookup pair -> position
coords_2d = GRID_TO_COORDS[GRID_TO_COORDS[:, 0] == DEPTHS[0], 1:]
coord_to_idx = dict(zip(map(tuple, coords_2d), range(len(coords_2d))))

# nb_per_coord[size][position]: number of valid points of associations of that size
# that fall on that (lat, lon) position inside the depth window. Only sizes 3 and 4 are tracked.
nb_per_coord = {size: Counter() for size in (3, 4)}

association_files = glob2.glob(f"{DETECTIONS_DIR}/cache/associations_{PICK_UNCERTAINTY}_{MIN_ASSOCIATION_SIZE}_*.pkl")
for file in association_files:
    with open(file, "rb") as f:
        associations = pickle.load(f)
    for detections, valid_points in tqdm(associations):
        valid_coords = GRID_TO_COORDS[valid_points]
        point_depths = valid_coords[:, 0]
        mask = (point_depths > dep_bounds[0]) & (point_depths < dep_bounds[1])

        for lat_lon in valid_coords[mask][:, 1:]:
            key = tuple(lat_lon)
            if key in coord_to_idx:
                nb_per_coord[len(detections)][coord_to_idx[key]] += 1

In [None]:
%matplotlib inline

# 2D map of the association counts accumulated in nb_per_coord, for sizes >= min_size_display.
min_size_display = 4

log = False  # set to True to display log10 of the counts
displayed_sizes = range(min_size_display, 5)
weights = np.array([np.sum([nb_per_coord[n][i] for n in displayed_sizes]) for i in range(len(coords_2d))])

lat_min, lat_max = BOUNDS[1][0], BOUNDS[1][1]
lon_min, lon_max = BOUNDS[2][0], BOUNDS[2][1]

fig, ax = plt.subplots(figsize=(10.5,14))
sq = squarize(coords_2d, weights, BOUNDS[1], BOUNDS[2], size=1000)
if log:
    # clamp to 1 so that empty cells map to 0 after the log
    sq[sq<1] = 1
    sq = np.log10(sq)
#sq[sq<1.5] = 1.5
im = ax.imshow(sq[::-1], cmap="inferno", extent=(lon_min, lon_max, lat_min, lat_max))
tick_step = 0.1
ax.set_xticks(np.arange(np.floor(lon_min/tick_step)*tick_step, np.ceil(lon_max/tick_step)*tick_step + tick_step, tick_step))
ax.set_yticks(np.arange(np.floor(lat_min/tick_step)*tick_step, np.ceil(lat_max/tick_step)*tick_step + tick_step, tick_step))
for s_ in STATIONS:
    p = s_.get_pos()  # used as (lat, lon) below

    if p[0] > lat_max or p[0] < lat_min or p[1] > lon_max or p[1] < lon_min:
        print(f"Station {s_.name} out of bounds")
        continue
    ax.plot(p[1], p[0], 'yx', alpha=0.75, markersize=10, markeredgewidth=3)
    label_pos = (p[1]-(lon_max-lon_min)/30, p[0]+(lat_max-lat_min)/50)
    ax.annotate(s_.name, xy=(p[1], p[0]), xytext=label_pos, textcoords="data", color='y', alpha=0.9, weight='bold')
cbar = plt.colorbar(im,fraction=0.0415, pad=0.04)
cbar.set_label(f'Counts of resulting associations{" (log)" if log else ""}', rotation=270, labelpad=20)
ax.set_xlabel("lon (°)")
ax.set_ylabel("lat (°)")
plt.tight_layout()
# the figures folder must exist before saving
Path(f"{DETECTIONS_DIR}/figures").mkdir(exist_ok=True)
plt.savefig(f'{DETECTIONS_DIR}/figures/_min-{min_size_display}_{dep_bounds[0]/1_000:.0f}-{dep_bounds[1]/1_000:.0f}km_{"log" if log else ""}.png', bbox_inches='tight')

In [None]:
import geopandas as gpd
from shapely.geometry import Polygon
from skimage import measure

# Contour level used for each dataset. Selecting it from `dataset` replaces the previous
# stack of reassignments, where the effective value depended on which line came last.
CONTOUR_THRESHOLDS = {
    "MAHY0": 6000,
    "MAHY1": 2000,
    "MAHY2": 4000,  # note: associations of size 3
    "MAHY3": 1500,
    "MAHY4": 350,
}
# unknown datasets fall back to 350, the value that was previously in effect
threshold = CONTOUR_THRESHOLDS.get(dataset, 350)
lat_bounds, lon_bounds = BOUNDS[1], BOUNDS[2]

# contours are extracted on the row-flipped image, i.e. the orientation that is displayed
contours = measure.find_contours(sq[::-1], threshold)

# pixel index -> coordinate lookups (rows of the flipped image run from lat max to lat min)
lon_vals = np.linspace(lon_bounds[0], lon_bounds[1], sq.shape[1])
lat_vals = np.linspace(lat_bounds[0], lat_bounds[1], sq.shape[0])[::-1]

geoms = []
for contour in contours:
    if len(contour) < 3:
        # too few vertices to build a polygon
        continue
    # find_contours yields (row, col) pairs; indices are truncated to the enclosing pixel
    coords = [(lon_vals[int(x)], lat_vals[int(y)]) for y, x in contour]
    geoms.append(Polygon(coords))

gdf = gpd.GeoDataFrame(geometry=geoms, crs="EPSG:4326")
# make sure the output folder exists before writing
Path(f"{DETECTIONS_DIR}/figures").mkdir(parents=True, exist_ok=True)
gdf.to_file(f"{DETECTIONS_DIR}/figures/contours.geojson", driver="GeoJSON")

In [None]:
# Compare every association with a reference catalogue: an association is a candidate match
# for a catalogue event when its estimated date is within DELTA of the event date.
DELTA = pd.Timedelta(seconds=15)
S_df = pd.read_csv(
    "../../../../data/MAHY/lavayssiere_and_public.csv", header=None, names=["date","lat","lon","depth","mb"], parse_dates=["date"]
)
# restrict the catalogue to the period covered by the detections (with a 5*DELTA margin at the start)
S_df = S_df[(S_df['date'] >= np.min(DETECTIONS_MERGED[:,0]) - 5*DELTA) & (S_df['date'] <= np.max(DETECTIONS_MERGED[:,0]))]
pts_df = []  # one record (date, depth, lat, lon) per association

# catalogue row label -> list of (time shift in s, distance, detections, coords) of candidate associations
matched = {}

association_files = glob2.glob(f"{DETECTIONS_DIR}/cache/associations_{PICK_UNCERTAINTY}_{3}_*.pkl")
for file in association_files:
    with open(file, "rb") as f:
        associations = pickle.load(f)
    for detections, valid_points in tqdm(associations):
        coords = GRID_TO_COORDS[valid_points]
        # group the valid points by depth and keep the depth holding the most points.
        # np.argmax needs a real sequence: given a generator it always returns 0.
        per_depth = [coords[coords[:,0]==d] for d in np.unique(coords[:,0])]
        chosen_depth = np.argmax([len(p) for p in per_depth])
        coords = np.mean(per_depth[chosen_depth], axis=0)

        dates = np.array([IDX_TO_DET[idx][0] for idx in detections[:,1]])
        # mean detection date (first date + mean offset from it), shifted back by a fixed 12 s
        date = dates[0] + np.mean(dates - dates[0]) - datetime.timedelta(seconds=12)

        pts_df.append({"date":date, "depth":coords[0], "lat":coords[1], "lon":coords[2]})

        candidates = S_df[(S_df['date'] >= date - DELTA) & (S_df['date'] <= date + DELTA)]
        for idx in candidates.index:
            # distance in degrees; the depth difference is divided by 111_000 to bring it to the same scale
            # NOTE(review): this assumes the catalogue depth uses the same unit as GRID_TO_COORDS[:, 0] - confirm
            d_diff = np.sqrt(
                (S_df.at[idx, 'lat'] - coords[1])**2+
                (S_df.at[idx, 'lon'] - coords[2])**2+
                ((S_df.at[idx, 'depth'] - coords[0]) / 111_000)**2)

            t_diff = S_df.at[idx, 'date'] - date

            matched.setdefault(idx, []).append((t_diff.total_seconds(), d_diff, detections, coords))

# Keep a single association per matched catalogue event, estimate its origin time from the
# travel times of the grid, and export the result.
n = 0  # number of events with several candidates whose largest candidate has 4 detections
S_df_matched = S_df.loc[list(matched.keys())].copy()
for idx in matched.keys():
    if len(matched[idx]) > 1:
        # several candidates: take the closest one, unless the largest one has 4 detections
        longest = np.argmax([len(d[2]) for d in matched[idx]])
        best = np.argmin([d[1] for d in matched[idx]])
        if len(matched[idx][longest][2]) == 4:
            n += 1
            best = longest
        matched[idx] = [matched[idx][best]]

    _, _, best_detections, best_coords = matched[idx][0]
    # grid node closest to the association position; it does not depend on the station,
    # so it is computed once per event (lat/lon differences are scaled by 111_000)
    cell = np.argmin(np.sqrt((best_coords[0]-GRID_TO_COORDS[:,0])**2 + (111_000*(best_coords[1]-GRID_TO_COORDS[:,1]))**2 + (111_000*(best_coords[2]-GRID_TO_COORDS[:,2]))**2))

    h_dates = []
    for si, di in best_detections:
        s, (date, _, _) = STATIONS[si], IDX_TO_DET[di]
        S_df_matched.loc[idx, s.name] = date
        # detection date minus the travel time from that node to the station
        h_dates.append(date - datetime.timedelta(seconds=TRAVEL_TIMES[s][cell]))
    # mean of the per-station origin time estimates
    S_df_matched.loc[idx, "h_date"] = h_dates[0] + np.mean([hd - h_dates[0] for hd in h_dates])
    S_df_matched.loc[idx, "h_depth"] = best_coords[0]/1_000
    S_df_matched.loc[idx, "h_lat"] = best_coords[1]
    S_df_matched.loc[idx, "h_lon"] = best_coords[2]

print(n, len(S_df_matched), len(S_df))

h = dataset[-1]
matched_columns = ["date","h_date","lat","h_lat","lon","h_lon","depth","h_depth","mb",f"MAHY{h}1",f"MAHY{h}2",f"MAHY{h}3",f"MAHY{h}4"]
# reindex so that a column that was never filled (e.g. a station without any matched detection,
# or no match at all) is written empty instead of making to_csv raise a KeyError
S_df_matched.reindex(columns=matched_columns).to_csv(f'../../../../data/MAHY/loc_3D/twin-cat/{dataset}_OBS-fixed.csv', index=False, float_format='%.3f')

# explicit columns keep the export valid even when no association was found
pts_df = pd.DataFrame(pts_df, columns=["date","depth","lat","lon"])
pts_df.to_csv(f'../../../../data/MAHY/loc_3D/twin-cat/{dataset}_all_OBS-fixed.csv', index=False, columns=["date","lat","lon","depth"], float_format='%.3f')

In [None]:
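# Recorded results per dataset. The three numbers are presumably the values printed by the
# matching cell (n, len(S_df_matched), len(S_df)) and the leading "1" the PICK_UNCERTAINTY used - to be confirmed.
# MAHY4 : 1 - 194 326 532
# MAHY3 : 1 - 403 629 1267
# MAHY2 : 1 - 0 748 1697
# MAHY1 : 1 - 405 609 1379
# MAHY0 : 1 - 759 1120 1920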

In [None]:
# MAHY 4 example:

# no intercept
# 0.2 : 86 322 532
# 0.5 : 119 324 532
# 1 : 165 325 532

# intercept
# 0.2 : 95 321 532
# 0.5 : 147 323 532
# 1 : 194 326 532

In [None]:
# Histogram of the time shifts (seconds, first field of every entry of `matched`)
all_shifts = [entry[0] for entries in matched.values() for entry in entries]
plt.hist(all_shifts)
plt.xlim(-15,15)

# Histogram of the distances (second field of every entry of `matched`), in a new figure
plt.figure()
all_shifts = [entry[1] for entries in matched.values() for entry in entries]
plt.hist(all_shifts)