In [None]:
import sys

sys.path.insert(1, "../")

import glob
from dem_comparison.plots import plot_metrics
from pathlib import Path
from dem_comparison.utils import resample_dataset, read_metrics, filter_data
import multiprocessing as mp
import rasterio as rio
import numpy as np
from tqdm import tqdm
import psutil
import dask.array as da
import geopandas as gpd
from pathlib import Path
import pyproj
from shapely import Point

In [None]:
# Metric pickles of the ocean-corrected comparison run.
# NOTE: without recursive=True, "**" in a glob pattern behaves like "*".
pkls = glob.glob("../Antarctica_Dem_Comparison_Ocean_Corrected/dem_diff_metrics/**.pkl")
metrics, x, y = read_metrics(pkls)

# 5th and 99.8th percentile of the first metric array
# (presumably the mean error, following the ME/STD/MSE/NMAD order - TODO confirm).
lp, up = (np.percentile(metrics[0], q) for q in (5, 99.8))
lp, up

In [None]:
# Hard-coded statistics, presumably copied from an earlier run of the cells
# below so that they do not have to be recomputed - TODO confirm provenance.

# Bounds handed to read_n_close as lp/up when use_percentiles=True.
lp_all, up_all = 4.33984375, 428.4710494

# NOTE: `mean` and `std` are reassigned by the two-pass computation cell.
mean, std = -23.067417, 513.7998

# Presumably mean/std of the percentile-filtered data - TODO confirm.
mean_p, std_p = 6.152629, 4.336989

In [None]:
# Paths of the difference rasters (.tif) of the ocean-corrected run.
diffs = glob.glob("../Antarctica_Dem_Comparison_Ocean_Corrected/dem_diff/**.tif")

# Short names of the four metrics.
labels = ["ME", "STD", "MSE", "NMAD"]


def read_n_close(r, use_percentiles=False, lp=None, up=None, dtype=np.float16):
    """Read band 1 of a raster, drop NaNs and optionally keep an open interval.

    The dataset is opened with ``rio.open`` and closed again before the
    values are filtered.

    Parameters
    ----------
    r : path-like
        Raster passed to ``rio.open``.
    use_percentiles : bool
        When true, keep only values strictly between ``lp`` and ``up``.
    lp, up : float or None
        Exclusive lower / upper bound. ``None`` leaves that side unbounded;
        previously a ``None`` bound raised ``TypeError`` in the comparison.
    dtype : numpy dtype
        Type of the returned array.
        NOTE(review): the float16 default overflows to ``inf`` for
        magnitudes above ~65504 - confirm this is acceptable for the data.

    Returns
    -------
    numpy.ndarray or None
        1-D array of the remaining values cast to ``dtype``. ``None`` when
        the band holds no non-NaN value at all. If values exist but the
        bounds remove all of them, an empty array is returned.
    """
    with rio.open(r, "r") as ds:
        data = ds.read(1)

    # Boolean indexing flattens the band to 1-D.
    data = data[~np.isnan(data)]
    if data.size == 0:
        return None

    if use_percentiles:
        keep = np.ones(data.shape, dtype=bool)
        if lp is not None:
            keep &= data > lp
        if up is not None:
            keep &= data < up
        data = data[keep]
    return data.astype(dtype)

In [None]:
# Stack the valid pixels of every difference raster into one 1-D dask array.
# The pieces are collected first and concatenated once at the end, instead of
# re-concatenating inside the loop (which nests the task graph once per tile).
# The empty seed keeps the final concatenation valid when no raster yields
# data; it presumably also keeps the float64 result dtype of the original
# in-loop concatenation through dtype promotion - TODO confirm.
pieces = [da.array([])]
loop = tqdm(diffs, total=len(diffs))
for d in loop:
    dt = read_n_close(d)  # , lp=lp_all, up=up_all, use_percentiles=True)
    if dt is not None:
        # read_n_close returns None for rasters without any non-NaN pixel;
        # those used to crash here on `dt.shape`.
        pieces.append(da.from_array(dt, chunks=dt.shape))
    # Stop early when less than 5 GiB of RAM remain available.
    free_mem = psutil.virtual_memory().available / (1024**3)
    if free_mem < 5:
        print("Free memory is low, breaking the loop")
        break
    loop.set_postfix(free_mem=free_mem)
dask_data = da.concatenate(pieces)

In [None]:
def _filtered_tiles(paths, lower, upper):
    """Yield the float32 values strictly inside (lower, upper) of each raster.

    Rasters for which read_n_close returns None (no non-NaN pixel) are
    skipped. The progress bar postfix shows the available RAM in GiB.
    """
    progress = tqdm(paths, total=len(paths))
    for path in progress:
        values = read_n_close(
            path, dtype=np.float32, lp=lower, up=upper, use_percentiles=True
        )
        progress.set_postfix(free_mem=psutil.virtual_memory().available / (1024**3))
        if values is not None:
            yield values


# Two-pass mean / population standard deviation over all rasters, so that only
# one raster has to be held in memory at a time.
# Accumulators are float64: summing many float32 values into a float32 scalar
# loses precision. This reassigns the hard-coded `mean` / `std` defined earlier.

# Pass 1: number of values and their sum.
count = 0
total = np.float64(0)
for dt in _filtered_tiles(diffs, lp_all, up_all):
    count += len(dt)
    total += dt.sum(dtype=np.float64)
# With count == 0 this yields nan (plus a RuntimeWarning) rather than raising.
mean = total / count

# Pass 2: sum of squared deviations from the mean.
# np.square is used instead of np.pow, which only exists in NumPy >= 2.0.
squared_dev = np.float64(0)
for dt in _filtered_tiles(diffs, lp_all, up_all):
    squared_dev += np.square(dt - mean).sum(dtype=np.float64)
std = np.sqrt(squared_dev / count)

In [None]:
# Plot the metrics of every tile of the ocean-corrected run.

# Inputs and output location.
pkls = glob.glob("../Antarctica_Dem_Comparison_Ocean_Corrected/dem_diff_metrics/**.pkl")
plot_name = "../temp/metric_plots/metrics_percentile_equal_bins.html"

# Flags passed positionally to plot_metrics.
is_error = True
polar = True

# Binning choices; only `equal_bins` is used in the call below.
equal_bins = 8
custom_bins = [[-40, -20, -10, -7.5, -5, -2.5, 0, 2.5, 5, 7.5, 10, 20, 40], 15, 15, 15]

# Value-range choices, one entry per metric; only `no_bounds` is used below.
no_bounds = None
normal_bounds = [(-10, 100), (0, 10), (0, 250), (0, 5)]
extreme_bounds = [(-10000, -10), (10, 50000), (1e4, 1e8), (2, 500)]

# Percentile choices; `percentiles` and `outliers` are used below.
no_percentile = None
percentiles = (5, 99.8)
outliers = False

# return_metrics=True makes the call return three values; the second is ignored.
new_metrics, _, new_percentiles = plot_metrics(
    pkls,
    is_error,
    polar,
    plot_name,
    data_bounds=no_bounds,
    percentiles_bracket=percentiles,
    bins=equal_bins,
    plot_resolution=(700, 1600),
    percentile_outliers=outliers,
    return_metrics=True,
)

In [None]:
# Mean of the first metric array returned by the preceding plot_metrics call.
np.mean(new_metrics[0])

In [None]:
# Same percentile check for the run stored under ../Antarctica_Dem_Comparison
# (note: not the *_Ocean_Corrected directory used by the other cells).
is_error, polar = True, True
pkls = glob.glob("../Antarctica_Dem_Comparison/dem_diff_metrics/**.pkl")

# The first label depends on whether the values are errors or plain means.
first_label = "ME" if is_error else "MEAN"
labels = [first_label, "STD", "MSE", "NMAD"]

metrics, x0, y0 = read_metrics(pkls, numerical_axes=polar)

# 10th and 90th percentile of the first metric, shown as the cell output.
tuple(np.percentile(metrics[0], q) for q in (10, 90))

In [None]:
# Load the coastline geometries and pick the one with the largest area.
dfs = gpd.read_file("../temp/coastlines.gpkg")
polys = dfs.geometry
areas = [p.area for p in polys]
# np.argmax returns a position, so index positionally with .iloc. Plain
# `polys[...]` looks the value up by index label, which only coincides with
# the position while the index is the default 0..n-1.
big_coast = polys.iloc[np.argmax(areas)]

In [None]:
# Disabled one-off export of the largest coastline polygon to a KML file:
# big_ds = dfs.iloc[[np.argmax(areas)], :]
# big_ds.to_file("../temp/big_coastline.kml", driver="KML")

In [None]:
# Metric pickles of the ocean-corrected comparison run.
pkls = glob.glob("../Antarctica_Dem_Comparison_Ocean_Corrected/dem_diff_metrics/**.pkl")

In [None]:
def name_to_coord(fname, transformer):
    """Turn a tile file name into a projected point.

    The base name is split on "_". The first token is a latitude tagged with
    "N" (positive) or otherwise "S" (negative); the second token is a
    longitude ending in "W.pkl" (negative) or otherwise "E.pkl" (positive).

    Parameters
    ----------
    fname : str or path-like
        Path or name of the metric file; only the base name is parsed.
    transformer : object with ``transform(x, y)``
        Called as ``transform(longitude, latitude)``. Presumably a
        ``pyproj.Transformer`` built with ``always_xy=True`` - TODO confirm.

    Returns
    -------
    Point
        Point built from the value returned by ``transformer.transform``.

    Raises
    ------
    ValueError
        If a token is not an integer once its hemisphere tag is removed.
    IndexError
        If the base name contains no "_".
    """
    tokens = Path(fname).name.split("_")

    lat_token = tokens[0]
    latitude = (
        int(lat_token.replace("N", ""))
        if "N" in lat_token
        else -int(lat_token.replace("S", ""))
    )

    lon_token = tokens[1]
    longitude = (
        -int(lon_token.replace("W.pkl", ""))
        if "W.pkl" in lon_token
        else int(lon_token.replace("E.pkl", ""))
    )

    return Point(transformer.transform(longitude, latitude))

In [None]:
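# One-off preprocessing (disabled): keep the metric files whose tile
# coordinate lies inside `big_coast` and write their paths to
# ../temp/inrange_pkls.txt. A later cell reads that file, so uncomment and
# run this once if the file does not exist yet.
# transformer = pyproj.Transformer.from_crs(
#     pyproj.CRS("EPSG:4326"), pyproj.CRS("EPSG:3031"), always_xy=True
# )
# inrange_pkls = []
# loop = tqdm(enumerate(pkls), total=len(pkls))
# for i, p in loop:
#     point_3031 = name_to_coord(p, transformer)
#     if big_coast.contains(point_3031):
#         inrange_pkls.append(p)

# with open("../temp/inrange_pkls.txt", "w") as f:
#     for p in inrange_pkls:
#         f.write(f"{p}\n")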

In [None]:
# Plot the metrics of the tiles listed in ../temp/inrange_pkls.txt.

# Inputs and output location.
pkls = glob.glob("../Antarctica_Dem_Comparison_Ocean_Corrected/dem_diff_metrics/**.pkl")
# One metric-file path per line.
inrange_pkls = Path("../temp/inrange_pkls.txt").read_text().splitlines()
plot_name = "../temp/metric_plots/metrics_extremes_bounds_percentiles_equal_bins_coastlines_filtered.html"

# Flags passed positionally to plot_metrics.
is_error = True
polar = True

# Binning choices; only `equal_bins` is used in the call below.
equal_bins = 25
custom_bins = [[-40, -20, -10, -7.5, -5, -2.5, 0, 2.5, 5, 7.5, 10, 20, 40], 15, 15, 15]

# Value-range choices, one entry per metric; only `no_bounds` is used below.
no_bounds = None
normal_bounds = [(-1000, 50), (0, 10), (0, 250), (0, 5)]
extreme_bounds = [(-10000, -10), (10, 50000), (1e4, 1e8), (2, 500)]

# Percentile choices; `percentiles` and `outliers` are used below.
no_percentile = None
percentiles = (5, 99.8)
outliers = True

# return_metrics=True makes the call return three values; the second is ignored.
new_metrics, _, new_percentiles = plot_metrics(
    inrange_pkls,
    # list(set(pkls) - set(inrange_pkls)),
    is_error,
    polar,
    plot_name,
    data_bounds=no_bounds,  # normal_bounds,
    percentiles_bracket=percentiles,
    bins=equal_bins,
    plot_resolution=(550, 1400),
    percentile_outliers=outliers,
    return_metrics=True,
)

In [None]:
# Standard deviation and mean of the first metric array returned by the
# preceding plot_metrics call.
np.std(new_metrics[0]), np.mean(new_metrics[0])