In [None]:
import numpy as np
import glob2
import datetime
from pathlib import Path
from tqdm.notebook import tqdm
import pickle
from matplotlib import pyplot as plt
import seaborn as sns
import matplotlib as mpl
# Publication figure style: seaborn "whitegrid" base, LaTeX-rendered text in a
# Computer Modern serif font, 10 pt body text and 8 pt tick/legend labels.
# The style sheet is applied first so the explicit overrides below win over its values.
plt.style.use("seaborn-v0_8-whitegrid")
plt.rcParams.update({
    # LaTeX / font setup (previously split between rcParams.update and two rc() calls)
    "pgf.texsystem": "pdflatex",
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Computer Modern"],
    # Font sizes, in points
    "font.size": 10,
    "axes.titlesize": 10,
    "axes.labelsize": 10,
    "xtick.labelsize": 8,
    "ytick.labelsize": 8,
    "legend.fontsize": 8,
})
# `rc` is no longer needed here; the import is kept so the name stays defined for other cells.
from matplotlib import rc
from matplotlib.ticker import FuncFormatter
def smart_formatter(x, pos):  # reduce the size of "-"
    """Tick-label formatter for ``matplotlib.ticker.FuncFormatter``.

    Values within 1e-6 of an integer are printed without decimals ("-5", "0",
    "25"); all other values are printed with one decimal ("-0.1", "2.5").
    The result is plain text (no ``$...$``), so under ``text.usetex`` the sign
    is typeset as a text hyphen. This is presumably the "smaller -" the
    original comment refers to, versus the math minus of the default formatter.

    Parameters
    ----------
    x : float
        Tick value.
    pos : int or None
        Tick position supplied by FuncFormatter; unused.

    Returns
    -------
    str
        The formatted label.
    """
    # Compare against the *nearest* integer: int() truncates toward zero, which made
    # values such as 2.9999999999 (float noise from linspace) fall through to "3.0".
    nearest = round(x)
    if abs(x - nearest) < 1e-6:
        # int() normalises numpy scalars and -0.0 so the label reads e.g. "0", never "-0".
        return f"{int(nearest)}"
    return f"{x:.1f}"
def shift_xtick_labels(ax, shift_points=2):  # shift extrema of an axis closer to the center
    """Nudge the outermost visible x tick labels of ``ax`` toward the axis center.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes whose major x tick labels are adjusted in place.
    shift_points : float, default 2
        Horizontal offset in typographic points (1/72 inch).

    Only bottom (label1) texts that are visible and whose tick lies inside the
    current x view limits are considered. Locators such as AutoLocator may
    return ticks beyond the limits that are never drawn, and those must not be
    mistaken for the extremes. The offset is appended to each label's existing
    transform, so calling this twice on the same axes doubles the shift.
    """
    # Local import: the module-level ScaledTranslation import appears after this definition.
    from matplotlib.transforms import ScaledTranslation

    dx = shift_points / 72  # points -> inches, the unit of figure.dpi_scale_trans
    if ax.xaxis_inverted():
        # On an inverted axis the smallest tick value is drawn on the right.
        dx = -dx

    lo, hi = sorted(ax.get_xlim())
    tol = 1e-9 * (hi - lo)  # absorb float noise for ticks placed exactly on the limits
    locs = ax.xaxis.get_majorticklocs()
    ticks = ax.xaxis.get_major_ticks()  # aligned one-to-one with `locs`
    shown = sorted(
        ((loc, tick.label1) for loc, tick in zip(locs, ticks)
         if lo - tol <= loc <= hi + tol and tick.label1.get_visible()),
        key=lambda pair: pair[0],
    )
    if not shown:
        return

    scale = ax.figure.dpi_scale_trans
    first = shown[0][1]
    first.set_transform(first.get_transform() + ScaledTranslation(dx, 0, scale))
    if len(shown) > 1:
        last = shown[-1][1]
        last.set_transform(last.get_transform() + ScaledTranslation(-dx, 0, scale))

from matplotlib.transforms import ScaledTranslation
import math
import pandas as pd
from utils.data_reading.sound_data.station import StationsCatalog
from utils.detection.association_geodesic import squarize
from utils.detection.association_geodesic_3D import compute_candidates, update_valid_grid, update_results, load_detections, compute_grids

In [None]:
path = '../../../../data/MAHY/loc_3D/twin-cat'
files = glob2.glob(f"{path}/*raw_OBS-fixed.csv")
# Fail here rather than later with an obscure error from plt.subplots(0, ...).
if not files:
    raise FileNotFoundError(f"no '*raw_OBS-fixed.csv' file found in {path}")

# error type -> {dataset name -> per-event error series}
error_dicts = {
    'lat_err': {},
    'lon_err': {},
    'depth_err': {},
    'time_err': {}
}

# Half-width of the plotted window for each error type (same units as `units`).
bounds_dict = {
    'lat_err': 0.2,
    'lon_err': 0.2,
    'depth_err': 50,
    'time_err': 5
}

names = {
    'lat_err': 'Latitude',
    'lon_err': 'Longitude',
    'depth_err': 'Depth',
    'time_err': 'Time'
}


units = {
    'lat_err': '°',
    'lon_err': '°',
    'depth_err': 'km',
    'time_err': 's'
}

for f in sorted(files):
    # Dataset key = file-name prefix before the first "_". Path(...).name is
    # separator-agnostic, unlike splitting the whole path on "/".
    dataset = Path(f).name.split("_")[0]
    if dataset in error_dicts['lat_err']:
        # Two files with the same prefix would silently overwrite each other.
        raise ValueError(f"two files map to dataset {dataset!r} (second one: {f})")
    df = pd.read_csv(f, parse_dates=['date', 'h_date'])

    # Each error is the `h_<col>` column minus the `<col>` column.
    # NOTE(review): which of the two is the reference catalogue is not visible here — confirm.
    for coord in ('lat', 'lon', 'depth'):
        error_dicts[f'{coord}_err'][dataset] = df[f"h_{coord}"] - df[coord]
    error_dicts['time_err'][dataset] = (df["h_date"] - df["date"]).dt.total_seconds()

In [None]:
# Grid of error histograms: one row per dataset, one column per error type.
datasets = sorted(list(error_dicts['lat_err'].keys()))
error_types = list(error_dicts.keys())

n_datasets = len(datasets)
n_errors = len(error_types)

fig, axes = plt.subplots(
    n_datasets, n_errors,
    figsize=(5.5, 5.5*1),
    squeeze=False
)


def latex_formatter(x, _):
    """Format ``x`` with two decimals inside LaTeX math mode (not used in this cell)."""
    return f"${x:.2f}$"


# Horizontal offsets of the 5% / 95% quantile labels from their lines, as a
# fraction of the axis half-width, so the text sits beside the line rather than on it.
QUANTILE_LABEL_SHIFTS = (-0.30, 0.27)

for i, dataset in enumerate(datasets):
    for j, err_type in enumerate(error_types):
        ax = axes[i][j]
        bound = bounds_dict[err_type]
        if j == 0:
            # Row label, placed outside the axes to the left of the first column.
            ax.text(-0.5, 0.5, dataset, transform=ax.transAxes,
                    fontsize=9, fontweight='bold', va='center', ha='right', rotation=90)
        # Keep errors strictly inside the plotted window; NaNs fail the comparison and are dropped too.
        data = np.array(error_dicts[err_type][dataset])
        data = data[np.abs(data) < bound]
        ax.hist(data, bins=20, color="#4c72b0", edgecolor="black", alpha=0.8)

        ax.set_xlim(-bound, bound)
        xticks = np.linspace(-bound, bound, 5)
        ax.set_xticks(xticks)

        if i == n_datasets - 1:
            ax.set_xlabel(f"{names[err_type]} error ({units[err_type]})", fontsize=8)
            # Short thick segments at the bottom of the axes emphasise the tick positions.
            for x in xticks:
                ax.axvline(x, ymin=0, ymax=0.025, color="black", linewidth=2, linestyle="-", alpha=0.7, zorder=20)
            ax.xaxis.set_major_formatter(FuncFormatter(smart_formatter))
            shift_xtick_labels(ax)
        else:
            ax.set_xticklabels([])
        if j == 0:
            ax.set_ylabel("Count", fontsize=8)
        else:
            # NOTE(review): the axes do not share y, so every column has its own count
            # scale; hiding these labels leaves that scale unreadable (consider sharey="row").
            ax.set_yticklabels([])

        ax.tick_params(axis='both', which='both', labelsize=8)
        ax.grid(True, axis='y', linestyle='--', alpha=0.5)

        if data.size == 0:
            # np.quantile raises on an empty array: leave this panel without quantile marks.
            continue

        # 5% / 95% quantile markers.
        # NOTE(review): computed on the clipped data (|err| < bound), so the tails are
        # understated whenever outliers were dropped above.
        q05, q95 = np.quantile(data, [0.05, 0.95])
        label_y = 0.75 * ax.get_ylim()[1]  # 75% of the panel height
        for q, shift in zip((q05, q95), QUANTILE_LABEL_SHIFTS):
            ax.axvline(q, color="red", linestyle="--", linewidth=1.5, zorder=5, alpha=0.3)

            # One decimal for km-valued errors, two otherwise.
            q_label = f"{q:.2f}" if units[err_type] != "km" else f"{q:.1f}"
            ax.text(
                q + shift * bound, label_y,
                q_label,
                color="red",
                ha="center",
                va="bottom",
                fontsize=8,
                fontweight="bold",
                rotation=0,
                zorder=6,
            )

fig.savefig('../../../../data/MAHY/figures/OBS_errors.pdf', bbox_inches='tight', dpi=300, pad_inches=0)

In [None]:
# Joint density of longitude vs depth errors, one figure per dataset.
for dataset in error_dicts['lat_err'].keys():
    lon_err = np.asarray(error_dicts['lon_err'][dataset], dtype=float)
    depth_err = np.asarray(error_dicts['depth_err'][dataset], dtype=float)
    # hist2d auto-detects its range from min/max and fails on NaN, so drop events
    # where either error is missing (one joint mask keeps the two arrays aligned).
    valid = np.isfinite(lon_err) & np.isfinite(depth_err)

    fig, ax = plt.subplots()
    *_, image = ax.hist2d(lon_err[valid], depth_err[valid], bins=20, density=True, cmap="inferno")
    fig.colorbar(image, ax=ax, label="Density")
    ax.set_xlabel(f"{names['lon_err']} error ({units['lon_err']})")
    ax.set_ylabel(f"{names['depth_err']} error ({units['depth_err']})")
    ax.set_title(dataset)