Data Selection
====
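Plots illustrating the data selection procedure: smooth the noisy MDT to get a residual (the "noise"), then choose tiles of that residual using per-tile statistics and distance from land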

In [None]:
import pathlib

# Need to mount the SING RDSF dir somewhere; rdsf_dir should point at the mount point
rdsf_dir = pathlib.Path("~/geog_rdsf/").expanduser()

noisy_filepath = (
    rdsf_dir
    / "data"
    / "projects"
    / "SING"
    / "richard_stuff"
    / "Table2"
    / "dtu18_eigen-6c4_do0280_rr0004.dat"
)
clean_filepath = (
    rdsf_dir / "data" / "projects" / "dtop" / "cmip6" / "cmip6_historical_mdts_yr5.dat"
)

# Fail early, with a message saying what's missing, if the share isn't mounted where expected
assert noisy_filepath.exists(), (
    f"Noisy MDT not found at {noisy_filepath} - is the RDSF dir mounted at {rdsf_dir}?"
)
assert clean_filepath.exists(), (
    f"Clean MDT not found at {clean_filepath} - is the RDSF dir mounted at {rdsf_dir}?"
)

In [None]:
import numpy as np
from current_denoising.generation import ioutils

# It's called read_currents, but actually just reads the array
noisy_mdt = ioutils.read_currents(noisy_filepath)

# Replace the fill value with NaN. Compare with a tolerance rather than exact equality,
# in case the fill value was stored at a different float precision than this literal
# (e.g. float32 data upcast to float64 no longer equals -1.9e19 exactly).
# Real MDT values are O(1) m, so nothing genuine is anywhere near this.
NOISY_FILL_VALUE = -1.9e19
noisy_mdt[np.isclose(noisy_mdt, NOISY_FILL_VALUE)] = np.nan

clean_mdt = ioutils.read_clean_mdt(
    path=clean_filepath,
    # Metadata sits next to the data file, named <stem>_meta.txt
    metadata_path=clean_filepath.with_name(clean_filepath.stem + "_meta.txt"),
    year=2001,
    model="CMCC-CM2-HR4",
)

In [None]:
import matplotlib.pyplot as plt
from matplotlib.image import AxesImage

from current_denoising.utils import util


def mdt_imshow(current_grid: np.ndarray, axis: plt.Axes, **kwargs) -> AxesImage:
    """
    Imshow a global gridded field (MDT, residual, mask, ...) on `axis`.

    Defaults: origin="upper", cmap="Spectral", colour limits [-2, 2], and an extent
    spanning the first/last lat and long returned by util.lat_long_grid for the
    grid's shape. Any keyword argument (including `extent`) overrides these defaults
    and is passed on to Axes.imshow.

    :param current_grid: 2D gridded field to show
    :param axis: axis to draw on
    :returns: the AxesImage, e.g. for passing to fig.colorbar
    """
    lat, long = util.lat_long_grid(current_grid.shape)

    imshow_kw = {
        "origin": "upper",
        "cmap": "Spectral",
        "vmax": 2,
        "vmin": -2,
        "extent": [long[0], long[-1], lat[0], lat[-1]],
        # Caller's kwargs win over the defaults above
        **kwargs,
    }

    # imshow applies `extent` itself, so no separate set_extent call is needed
    return axis.imshow(current_grid, **imshow_kw)


fig, axes = plt.subplots(2, 1, figsize=(12, 12))

# Noisy MDT on the top panel, simulated noise-free MDT on the bottom panel
for panel_axis, panel_field, panel_title in zip(
    axes,
    (noisy_mdt, clean_mdt),
    ("Mean Dynamic Topography (noisy)", "MDT (simulated, no noise)"),
):
    im = mdt_imshow(panel_field, panel_axis)
    panel_axis.set_title(panel_title)
    fig.colorbar(im, ax=panel_axis)

We need to remove NaNs before Gaussian smoothing
----
Replace each NaN with the value of its nearest non-NaN neighbour

In [None]:
"""
Show the MDT with NaNs replaced by the nearest non-NaN value
"""

from current_denoising.generation import mdt

nan_filled = mdt.fill_nan_with_nearest(noisy_mdt)

fig, axis = plt.subplots(1, 1, figsize=(10, 6))

im = mdt_imshow(nan_filled, axis)
fig.colorbar(im, ax=axis)
axis.set_title("NaN replaced by nearest neighbour")

fig.tight_layout()

Smooth and find the residual
----
The residual (noisy MDT minus its smoothed version) is the "noise" we want to remove

In [None]:
"""
Applying a Gaussian filter to the gridded field is non-trivial (the grid point size changes with latitude)
"""

from scipy.ndimage import gaussian_filter

from current_denoising.generation import mdt
from current_denoising.utils import util

sigma_km = 200
sigma_grid = sigma_km / (util.KM_PER_DEG / 4)


# TODO for now just do a naive smoothing
def naive_smooth(img: np.ndarray, sigma: float | None = None) -> np.ndarray:
    """
    Invalid but simple smoothing of a gridded field containing NaNs

    Invalid since the kernel is constant in size in terms of grid points,
    which means its physical size varies spatially.

    NaNs are filled with their nearest neighbour before filtering (gaussian_filter
    can't handle NaNs), then put back, so the output is NaN exactly where the input is.

    :param img: gridded field, possibly containing NaNs
    :param sigma: Gaussian sigma in grid points; defaults to the notebook-level
        `sigma_grid` (roughly sigma_km at the equator)
    :returns: smoothed field with the input's NaN mask re-applied
    """
    if sigma is None:
        sigma = sigma_grid

    nan_mask = np.isnan(img)
    field = mdt.fill_nan_with_nearest(img)

    # sigma_grid is ~7-8 grid points, i.e. ~200 km sigma at the equator
    field = gaussian_filter(field, sigma=sigma)
    return np.where(nan_mask, np.nan, field)


noisy_mdt_smoothed = naive_smooth(noisy_mdt)
residual = noisy_mdt - noisy_mdt_smoothed

fig, axes = plt.subplots(2, 1, figsize=(12, 12))
smoothed_ax, residual_ax = axes

# Top: the smoothed field; bottom: what the smoothing removed
im = mdt_imshow(noisy_mdt_smoothed, smoothed_ax)
smoothed_ax.set_title("Gaussian Smoothed MDT (naive)")
fig.colorbar(im, ax=smoothed_ax)

im = mdt_imshow(residual, residual_ax, vmin=-0.5, vmax=0.5, cmap="seismic")
residual_ax.set_title("Residual")
fig.colorbar(im, ax=residual_ax)

fig.tight_layout()

In [None]:
"""
But I've written a function to do it approximately
"""


def better_smooth(img: np.ndarray) -> np.ndarray:
    """
    Approximate smoothing with variable kernel
    """
    nan_mask = np.isnan(img)

    field = mdt.fill_nan_with_nearest(img)

    # Approximately the same kernel size as above
    return np.where(nan_mask, np.nan, mdt.gauss_smooth(field, sigma_km))


noisy_mdt_smoothed2 = better_smooth(noisy_mdt)
residual = noisy_mdt - noisy_mdt_smoothed2
smoothing_diff = noisy_mdt_smoothed - noisy_mdt_smoothed2

fig, axes = plt.subplots(3, 1, figsize=(12, 12))

# (field, title, extra imshow kwargs) for each panel, top to bottom
panels = [
    (noisy_mdt_smoothed2, "Smoothed MDT (varying kernel)", {}),
    (residual, "Residual", {"vmin": -0.5, "vmax": 0.5, "cmap": "seismic"}),
    (
        smoothing_diff,
        "Difference between naive & latitude-dependent smoothing",
        {"vmin": -0.1, "vmax": 0.1, "cmap": "PiYG"},
    ),
]
for panel_axis, (panel_field, panel_title, panel_kw) in zip(axes, panels):
    im = mdt_imshow(panel_field, panel_axis, **panel_kw)
    panel_axis.set_title(panel_title)
    fig.colorbar(im, ax=panel_axis)

fig.tight_layout()

In [None]:
"""
First we use some heuristics to extract some tiles from the MDT residual - we don't want ones that are too far from the equator (they are distorted),
and we don't want ones that contain too much variance
"""

from math import isqrt
from current_denoising.generation import ioutils


def plot_mdt_tiles(
    tiles: np.ndarray, indices: list[tuple[int, int]], grid_shape: tuple[int, int]
) -> plt.Figure:
    """
    Plot a grid of tiles (as an Nxshapexshape np array), labelling the lat/longs according
    to the indices extracted given that the global gridded field was shaped grid_shape
    """
    n_row = isqrt(tiles.shape[0])
    assert (
        n_row**2 == tiles.shape[0]
    ), f"must have square number of tiles, got {tiles.shape[0]}"

    fig, axes = plt.subplots(n_row, n_row, figsize=(12, 12))

    lat, long = util.lat_long_grid(grid_shape)

    for axis, tile, (y, x) in zip(axes.flat, tiles, indices):
        im = axis.imshow(
            tile, origin="upper", vmin=-0.5, vmax=0.5, cmap="seismic", aspect="equal"
        )
        im.set_extent([long[x], long[x + tile_size], lat[y], lat[y + tile_size]])

    fig.tight_layout()

    cax = fig.add_axes([1.05, 0.15, 0.05, 0.7])
    fig.colorbar(im, cax=cax, label="Mean Dynamic Topography /m")

    return fig


tile_size = 32
rng = np.random.default_rng(0)  # fixed seed, so the same example tiles are drawn each run

tiles, indices = ioutils.extract_tiles(
    rng,
    residual,
    num_tiles=25,
    max_latitude=np.inf,  # no latitude cut for these examples
    tile_size=tile_size,
    return_indices=True,
)

fig = plot_mdt_tiles(tiles, indices, residual.shape)
fig.suptitle("Example MDT patches")

In [None]:
"""
Plot the amount of Fourier power above a threshold
"""

from typing import Callable


def fft_power_fraction_factory(
    window_size: int,
    threshold: float,
) -> Callable[..., np.ndarray]:
    """
    Build a callable(arr, axis=(-2,-1)) -> ndarray that returns the fraction of FFT power
    in a radial band relative to total power over the given axes.

    This basically does the same thing as ioutils.fft_fraction, but can be broadcast over
    an array of tiles.

    threshold is in [0,1] relative to Nyquist radius.
    band = "high" -> r_norm >= threshold
    band = "low"  -> r_norm <= threshold
    """
    y, x = np.indices((window_size, window_size))
    cy, cx = window_size // 2, window_size // 2
    r = np.hypot(x - cx, y - cy)
    r_norm = r / r.max()

    mask = (r_norm >= threshold).astype(float)

    def _f(arr: np.ndarray, axis: tuple[int, int] = (-2, -1)) -> np.ndarray:
        # Fill NaNs per-window with the window nanmean
        means = np.nanmean(arr, axis=axis, keepdims=True)
        arr_filled = np.where(np.isnan(arr), means, arr)

        power = (
            np.abs(np.fft.fftshift(np.fft.fft2(arr_filled, axes=axis), axes=axis)) ** 2
        )

        # Sum band power vs total power over the same axes
        # mask broadcasts over the last two dims (the window axes)
        band_power = (power * mask).sum(axis=axis)
        total_power = power.sum(axis=axis)

        return band_power / total_power

    return _f


def apply_to_map(f, field: np.ndarray | None = None, window_size: int | None = None):
    """
    Apply a per-window statistic over sliding windows of a gridded field.

    :param f: callable(arr, axis) that reduces windows to one value each, e.g.
        np.nanstd or the output of fft_power_fraction_factory. Presumably
        util.apply_to_sliding_window calls it with the window axes - confirm there
    :param field: gridded field to evaluate on; defaults to the notebook's `residual`
    :param window_size: window side length; defaults to the notebook's `tile_size`
    :returns: map of f over the field, set to NaN wherever `field` is NaN (land)
    """
    # Fall back to the notebook globals, so existing apply_to_map(f) calls still work
    if field is None:
        field = residual
    if window_size is None:
        window_size = tile_size

    retval = util.apply_to_sliding_window(field, f, window_size)
    return np.where(np.isnan(field), np.nan, retval)


# Fraction of each window's Fourier power at normalised radial frequency >= power_threshold
power_threshold = 0.05
fft_power_fcn = fft_power_fraction_factory(window_size=tile_size, threshold=power_threshold)
power_fraction_map = apply_to_map(fft_power_fcn)

In [None]:
"""
Need a custom fcn for moments that takes an axis parameter so that we can apply it
in a vectorised way to our grid
"""


def central_moment(
    arr: np.ndarray, order: int, axis: tuple[int, int] = (-2, -1)
) -> np.ndarray:
    mean = np.nanmean(arr, axis=axis, keepdims=True)

    centered = arr - mean
    vals = centered**order

    return np.nanmean(vals, axis=axis)


def skew(arr, axis):
    """
    Standardised third moment, m3 / m2^(3/2), over `axis` (NaN-aware).

    The tiny epsilon in the denominator avoids dividing by zero for constant windows.
    """
    variance = central_moment(arr, 2, axis=axis)
    third = central_moment(arr, 3, axis=axis)
    std_cubed = np.sqrt(variance) ** 3
    return third / (std_cubed + 1e-16)


def kurt(arr, axis):
    """
    Standardised fourth moment, m4 / m2^2, over `axis` (NaN-aware).

    This is plain (not excess) kurtosis, so a Gaussian gives ~3. The tiny epsilon
    in the denominator avoids dividing by zero for constant windows.
    """
    variance = central_moment(arr, 2, axis=axis)
    fourth = central_moment(arr, 4, axis=axis)
    return fourth / (variance**2 + 1e-16)


# Per-window spread and shape statistics (NaN wherever the residual is NaN)
std_map, skewness_map, kurtosis_map = (
    apply_to_map(stat) for stat in (np.nanstd, skew, kurt)
)

In [None]:
# Per-window mean, and root-mean-square about zero (not about the mean)
mean_map = apply_to_map(np.nanmean)
rms_map = apply_to_map(
    lambda window, axis: np.sqrt(np.nanmean(np.square(window), axis=axis))
)

In [None]:
fig, axes = plt.subplots(2, 3, figsize=(15, 5))


def plot_on_axis(arr, axis, title, **kw):
    """
    Show `arr` on `axis` with mdt_imshow, add a colorbar and set the title.

    The colorbar goes on the axis' own figure, so this doesn't rely on a global
    `fig` being the figure that owns `axis`.

    :param arr: 2D gridded field
    :param axis: axis to draw on
    :param title: axis title
    :param kw: extra keyword arguments for mdt_imshow
    """
    im = mdt_imshow(arr, axis, **kw)
    axis.figure.colorbar(im, ax=axis)
    axis.set_title(title)


maps = [power_fraction_map, mean_map, rms_map, std_map, skewness_map, kurtosis_map]
titles = [
    # Fraction of each tile's Fourier power at normalised radial frequency >= threshold
    f"Tile FFT power fraction at r >= {100 * power_threshold:g}% of max r",
    "Tile mean",
    "Tile RMS",
    "Tile Std",
    "Tile Skewness",
    "Tile Kurtosis",
]
# Colour limits per map; None overrides mdt_imshow's +-2 default with autoscaling
vmins = [None, -0.1, None, None, -5, None]
vmaxs = [None, 0.1, 0.2, 0.2, 5, 10]
# Extra imshow kwargs per map
kw = [{}, {"cmap": "seismic"}, {}, {}, {"cmap": "seismic"}, {}]

# `map_` rather than `map`, so the builtin isn't shadowed for later cells
for map_, axis, title, vmin, vmax, kw_ in zip(
    maps, axes.flat, titles, vmins, vmaxs, kw, strict=True
):
    plot_on_axis(map_, axis, title, vmin=vmin, vmax=vmax, **kw_)

fig.tight_layout()

In [None]:
fig, axes = plt.subplots(2, 3, figsize=(15, 5))
for map_, axis, title, vmin, vmax in zip(
    maps, axes.flat, titles, vmins, vmaxs, strict=True
):
    # Where no colour limit was set, use the data range for the histogram
    if vmin is None:
        vmin = np.nanmin(map_)
    if vmax is None:
        vmax = np.nanmax(map_)

    # Drop NaNs (land) explicitly; values outside [vmin, vmax] fall outside the bins
    values = map_[~np.isnan(map_)]
    counts, _, _ = axis.hist(values, bins=np.linspace(vmin, vmax, 100))

    # hist returns float counts; cast so the title reads "1,234", not "1,234.0"
    axis.set_title(f"{title}; total={int(np.sum(counts)):,}")

fig.tight_layout()

In [None]:
"""
Choose some criteria + plot maps of tiles which pass
"""

from current_denoising.plotting.maps import clear2black_cmap

criteria = [(0.8, None), (-0.05, 0.05), (None, 0.05), (None, 0.05), (-2, 2), (None, 6)]


def meets_criteria(arr, min_max):
    """
    Boolean mask of the elements strictly inside the open interval (min, max).

    `min_max` is a (min, max) pair, where either bound may be None for "unbounded".
    NaN elements never meet the criteria, since comparisons with NaN are False.
    """
    lower, upper = min_max
    if lower is None:
        lower = -np.inf
    if upper is None:
        upper = np.inf
    return (arr > lower) & (arr < upper)


def plot_accepted(axis, arr, min_max, name, land_mask=None):
    """
    Show where a per-tile statistic meets its (min, max) criterion, with land overlaid.

    :param axis: axis to draw on
    :param arr: 2D map of the statistic
    :param min_max: (min, max) bounds for meets_criteria; None means unbounded
    :param name: name of the statistic, used in the title
    :param land_mask: boolean mask to overlay using clear2black_cmap; defaults to
        the NaNs of the notebook's `residual`
    """
    if land_mask is None:
        land_mask = np.isnan(residual)

    # Show the statistic only where it's accepted; rejected points are set to NaN
    # and so are left blank
    keep = meets_criteria(arr, min_max)
    axis.imshow(np.where(keep, arr, np.nan))

    # Overlay the land mask (presumably transparent over sea and black over land,
    # going by the cmap name - confirm in clear2black_cmap)
    axis.imshow(land_mask, cmap=clear2black_cmap())
    axis.set_title(f"{name}: [{min_max[0]}, {min_max[1]}]")
    axis.set_axis_off()


fig, axes = plt.subplots(2, 3, figsize=(15, 5))
# strict=True, as in the other loops, so a missing criterion/title raises instead of
# being silently dropped; `map_` avoids shadowing the builtin
for axis, map_, criterion, name in zip(axes.flat, maps, criteria, titles, strict=True):
    plot_accepted(axis, map_, criterion, name)
fig.tight_layout()

In [None]:
"""
Use FFT and RMS
"""

keep = meets_criteria(maps[0], criteria[0]) & meets_criteria(maps[2], criteria[2])

fig, axis = plt.subplots(figsize=(8, 4))

mdt_imshow(np.where(keep, residual, np.nan), axis, vmin=-0.5, vmax=0.5, cmap="seismic")
mdt_imshow(np.isnan(residual), cmap=clear2black_cmap(), axis=axis, vmin=0, vmax=1)

In [None]:
"""
Make a "distance from land" map
"""

land_distance = ioutils.distance_from_land(residual)

fig, axis = plt.subplots(figsize=(10, 4))
im = mdt_imshow(land_distance, axis, vmin=1, vmax=50)
fig.colorbar(im, ax=axis)
mdt_imshow(np.isnan(residual), axis, cmap=clear2black_cmap(), vmin=0, vmax=1)

In [None]:
fft_bounds = (0.8, np.inf)
rms_bounds = (0, 0.05)
distance_bounds = (20, np.inf)

# title -> (statistic map, (min, max) bounds), one panel each
selections = {
    "Fourier power": (power_fraction_map, fft_bounds),
    "RMS": (rms_map, rms_bounds),
    "Distance from land": (land_distance, distance_bounds),
}

fig, axes = plt.subplots(1, 3, figsize=(20, 4))

for axis, (title, (map_, bounds)) in zip(axes, selections.items(), strict=True):
    keep = meets_criteria(map_, bounds)
    accepted_residual = np.where(keep, residual, np.nan)
    mdt_imshow(accepted_residual, axis, vmin=-0.5, vmax=0.5, cmap="seismic")
    mdt_imshow(np.isnan(residual), axis, cmap=clear2black_cmap(), vmin=0, vmax=1)
    axis.set_title(title)

fig.tight_layout()

In [None]:
"""
Get the overall map
"""

keep = (
    meets_criteria(power_fraction_map, fft_bounds)
    & meets_criteria(rms_map, rms_bounds)
    & meets_criteria(land_distance, distance_bounds)
)

fig, axis = plt.subplots(figsize=(16, 8))
mdt_imshow(np.where(keep, residual, np.nan), axis, vmin=-0.5, vmax=0.5, cmap="seismic")
mdt_imshow(np.isnan(residual), axis, vmin=0, vmax=1, cmap=clear2black_cmap())

fig.suptitle("This should match the map in the main notebook")
fig.tight_layout()