Denoise the MDT directly
====
Instead of working with the currents (computed from the gradient of the MDT, scaled by some constants and the Coriolis parameter), we can probably smooth and denoise the MDT directly.

Read in MDTs
----

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import pathlib

# Need to mount the SING RDSF dir at ~/geog_rdsf/ for these paths to resolve
rdsf_dir = pathlib.Path("~/geog_rdsf/").expanduser()

# Noisy MDT
noisy_filepath = rdsf_dir.joinpath(
    "data",
    "projects",
    "SING",
    "richard_stuff",
    "Table2",
    "dtu18_eigen-6c4_do0280_rr0004.dat",
)
# Simulated (noise-free) CMIP6 MDTs
clean_filepath = rdsf_dir.joinpath(
    "data", "projects", "dtop", "cmip6", "cmip6_historical_mdts_yr5.dat"
)

# Fail early with a hint, rather than a bare AssertionError, if the share isn't mounted
assert (
    noisy_filepath.exists()
), f"{noisy_filepath} not found - is the RDSF mounted at {rdsf_dir}?"
assert (
    clean_filepath.exists()
), f"{clean_filepath} not found - is the RDSF mounted at {rdsf_dir}?"

In [None]:
import numpy as np
from current_denoising.generation import ioutils

# It's called read_currents, but actually just reads the array
noisy_mdt = ioutils.read_currents(noisy_filepath)

# Missing points (presumably land) are flagged with a large negative fill value.
# Compare with a relative tolerance rather than `==`. If the fill value was written at a
# different precision than it is read in (e.g. float32 on disk, float64 in memory), an
# exact comparison would silently miss every fill value.
NOISY_FILL_VALUE = -1.9e19
noisy_mdt[np.isclose(noisy_mdt, NOISY_FILL_VALUE, rtol=1e-6, atol=0.0)] = np.nan

clean_mdt = ioutils.read_clean_mdt(
    path=clean_filepath,
    metadata_path=clean_filepath.with_stem(clean_filepath.stem + "_meta").with_suffix(
        ".txt"
    ),
    year=2001,
    model="CMCC-CM2-HR4",
)

In [None]:
import matplotlib.pyplot as plt
from matplotlib.image import AxesImage

from current_denoising.utils import util


def mdt_imshow(current_grid: np.ndarray, axis: plt.Axes, **kwargs) -> AxesImage:
    """
    Imshow a global gridded field (e.g. an MDT) on ``axis`` with lat/long axis labels.

    Defaults, all overridable via ``kwargs``:
    ``origin="upper"``, ``cmap="Spectral"``, ``vmin=-2``, ``vmax=2``.
    If no ``extent`` is given, it spans the first/last longitudes and latitudes
    returned by ``util.lat_long_grid`` for this grid's shape.

    :param current_grid: 2D gridded field to show.
    :param axis: axis to draw on.
    :param kwargs: extra keyword arguments forwarded to ``Axes.imshow``.
    :returns: the image, e.g. for passing to ``fig.colorbar``.
    """
    imshow_kw = {"origin": "upper", "cmap": "Spectral", "vmax": 2, "vmin": -2, **kwargs}

    # Only build the lat/long grid when the caller hasn't supplied an extent
    if "extent" not in imshow_kw:
        lat, long = util.lat_long_grid(current_grid.shape)
        imshow_kw["extent"] = [long[0], long[-1], lat[0], lat[-1]]

    # imshow applies the extent itself, so no separate set_extent call is needed
    return axis.imshow(current_grid, **imshow_kw)


fig, axes = plt.subplots(2, 1, figsize=(12, 12))

# Noisy MDT on top, simulated noise-free MDT underneath, each with its own colourbar
for field, ax, title in zip(
    (noisy_mdt, clean_mdt),
    axes,
    ("Mean Dynamic Topography (noisy)", "MDT (simulated, no noise)"),
):
    im = mdt_imshow(field, ax)
    ax.set_title(title)
    fig.colorbar(im, ax=ax)

NaNs must be removed before Gaussian smoothing
----
Replace each NaN with the value of its nearest non-NaN neighbour.

In [None]:
"""
Show the MDT with NaNs replaced by the nearest non-NaN value
"""

from current_denoising.generation import mdt

nan_filled = mdt.fill_nan_with_nearest(noisy_mdt)

fig, axis = plt.subplots(1, 1, figsize=(10, 6))

im = mdt_imshow(nan_filled, axis)
fig.colorbar(im, ax=axis)
axis.set_title("NaN replaced by nearest neighbour")

fig.tight_layout()

Smooth and find residual
----
This is the "noise" we want to remove

In [None]:
"""
Applying a Gaussian filter to the gridded field is non-trivial (the grid point size changes with latitude)
"""

from scipy.ndimage import gaussian_filter

from current_denoising.generation import mdt
from current_denoising.utils import util

# Smoothing length scale (Gaussian standard deviation): 200km, converted to grid points.
# The grid has 4 points per degree, so one grid point is KM_PER_DEG / 4 km. Grid point
# size changes with latitude, so this conversion is only approximate.
sigma_km = 200
km_per_grid_point = util.KM_PER_DEG / 4
sigma_grid = sigma_km / km_per_grid_point


# TODO for now just do a naive smoothing
def naive_smooth(img: np.ndarray, sigma: "float | None" = None) -> np.ndarray:
    """
    Invalid but simple smoothing of a gridded field containing NaNs

    Invalid since the kernel is constant in size in terms of grid points,
    which means it varies in size spatially.

    NaNs are filled with their nearest non-NaN neighbour before filtering (so the
    filter sees a complete field), then restored in the output.

    :param img: 2D gridded field, possibly containing NaNs.
    :param sigma: Gaussian standard deviation in grid points. Defaults to the
        notebook-level ``sigma_grid`` (``sigma_km`` converted with the equatorial
        grid spacing - about 7 grid points, not a 200km "radius").
    :returns: smoothed field, NaN wherever ``img`` was NaN.
    """
    if sigma is None:
        sigma = sigma_grid

    nan_mask = np.isnan(img)
    field = mdt.fill_nan_with_nearest(img)

    # gaussian_filter's default boundary mode (reflect) is used. If the grid spans all
    # longitudes, the east and west edges are therefore NOT treated as neighbours.
    smoothed = gaussian_filter(field, sigma=sigma)
    return np.where(nan_mask, np.nan, smoothed)


fig, axes = plt.subplots(2, 1, figsize=(12, 12))

noisy_mdt_smoothed = naive_smooth(noisy_mdt)
# Deliberately not called `residual`. That name holds the residual from the
# latitude-dependent smoothing, which the tile extraction uses, so re-running this
# cell must not silently replace it with the naive residual.
naive_residual = noisy_mdt - noisy_mdt_smoothed

im = mdt_imshow(noisy_mdt_smoothed, axes[0])
axes[0].set_title("Gaussian Smoothed MDT (naive)")
fig.colorbar(im, ax=axes[0])

im = mdt_imshow(naive_residual, axes[1], vmin=-0.5, vmax=0.5, cmap="seismic")
axes[1].set_title("Residual")
fig.colorbar(im, ax=axes[1])

fig.tight_layout()

In [None]:
"""
But I've written a function to do it approximately
"""


def better_smooth(img: np.ndarray) -> np.ndarray:
    """
    Smooth a gridded field (possibly containing NaNs) with ``mdt.gauss_smooth``.

    The kernel size is ``sigma_km`` (presumably kilometres, given the name), so it
    adapts to the grid instead of being a fixed number of grid points. NaNs are
    temporarily filled with their nearest neighbour and restored in the result.
    """
    # Record the NaN positions before anything else touches the input
    was_nan = np.isnan(img)
    smoothed = mdt.gauss_smooth(mdt.fill_nan_with_nearest(img), sigma_km)
    return np.where(was_nan, np.nan, smoothed)


noisy_mdt_smoothed2 = better_smooth(noisy_mdt)
residual = noisy_mdt - noisy_mdt_smoothed2
smoothing_diff = noisy_mdt_smoothed - noisy_mdt_smoothed2

fig, axes = plt.subplots(3, 1, figsize=(12, 12))

# (field, title, extra mdt_imshow kwargs) for each panel, top to bottom
panels = [
    (noisy_mdt_smoothed2, "Smoothed MDT (varying kernel)", {}),
    (residual, "Residual", {"vmin": -0.5, "vmax": 0.5, "cmap": "seismic"}),
    (
        smoothing_diff,
        "Difference between naive & latitude-dependent smoothing",
        {"vmin": -0.1, "vmax": 0.1, "cmap": "PiYG"},
    ),
]
for ax, (field, title, panel_kw) in zip(axes, panels):
    im = mdt_imshow(field, ax, **panel_kw)
    ax.set_title(title)
    fig.colorbar(im, ax=ax)

fig.tight_layout()

Optional: train a GAN to generate realistic-looking tiles of MDT noise
----

This would give us a large set of realistic noise tiles to use as training data for the denoising model.

In [None]:
"""
First we use some heuristics to extract some tiles from the MDT residual - we don't want ones that are too far from the equator (they are distorted),
and we don't want ones that contain too much variance
"""

from math import isqrt
from current_denoising.generation import ioutils


def plot_mdt_tiles(
    tiles: np.ndarray, indices: list[tuple[int, int]], grid_shape: tuple[int, int]
) -> plt.Figure:
    """
    Plot a square grid of tiles, labelling each tile's axes with its lat/long extent.

    :param tiles: N x height x width array of tiles; N must be a square number.
    :param indices: (row, column) index in the global grid of each tile's first element.
    :param grid_shape: shape of the global gridded field the tiles were extracted from,
        used to look up the latitudes/longitudes of the tile edges.
    :returns: the figure, with a shared colourbar on the right.
    :raises AssertionError: if N is not a square number.
    """
    n_row = isqrt(tiles.shape[0])
    assert (
        n_row**2 == tiles.shape[0]
    ), f"must have square number of tiles, got {tiles.shape[0]}"

    # squeeze=False so `axes` is always a 2D array. A single tile would otherwise give
    # a bare Axes with no `.flat`.
    fig, axes = plt.subplots(n_row, n_row, figsize=(12, 12), squeeze=False)

    lat, long = util.lat_long_grid(grid_shape)

    for axis, tile, (y, x) in zip(axes.flat, tiles, indices):
        # Use each tile's own size, not the global `tile_size`, because generated tiles
        # can be larger. Clamp so tiles touching the grid edge don't index past the end.
        height, width = tile.shape[0], tile.shape[1]
        y_end = min(y + height, len(lat) - 1)
        x_end = min(x + width, len(long) - 1)

        im = axis.imshow(tile, origin="upper", vmin=-0.5, vmax=0.5, cmap="seismic")
        im.set_extent([long[x], long[x_end], lat[y], lat[y_end]])

    fig.tight_layout()

    cax = fig.add_axes([1.05, 0.15, 0.05, 0.7])
    fig.colorbar(im, cax=cax, label="Mean Dynamic Topography /m")

    return fig


rng = np.random.default_rng(0)

tile_size = 32
# Example sample of 25 tiles: no latitude limit and no tile criterion
tiles, indices = ioutils.extract_tiles(
    rng,
    residual,
    num_tiles=25,
    max_latitude=np.inf,
    tile_size=tile_size,
    return_indices=True,
)

fig = plot_mdt_tiles(tiles, indices, residual.shape)
fig.suptitle("Example MDT patches")

In [None]:
"""
Choose some tiles as examples, from specific chosen locations on the grid
"""

from scipy.stats import moment

from scipy.stats import kurtosis, skew

# Four normal, two distorted by globe, two full of signal near coasts
lats = [0, -20, -20, 30, -50, 80, 10, 45]
longs = [0, 50, -100, 130, -100, 0, -51, 120]
# Tile edge length in degrees: the grid has 4 points per degree
tile_deg = tile_size // 4
example_tiles = [
    util.get_tile(residual, (lat, long), tile_deg) for lat, long in zip(lats, longs)
]

fig, axes = plt.subplots(len(example_tiles), 3, figsize=(9, 3 * len(example_tiles)))

bins = np.linspace(-0.3, 0.3, 100)
for tile, axs, lat, long in zip(example_tiles, axes, lats, longs):
    # Plot the tile
    axs[0].imshow(
        tile,
        extent=[long, long + tile_deg, lat, lat + tile_deg],
        vmin=-0.5,
        vmax=0.5,
        cmap="seismic",
    )

    # Plot the hist of this tile, titled with its summary statistics.
    # scipy.stats.moment returns central moments (variance and the unnormalised 3rd/4th
    # moments), not std/skew/kurtosis, so compute the named statistics directly.
    # kurtosis uses Fisher's (excess) definition, i.e. 0 for a Gaussian.
    values = np.ravel(tile)
    axs[1].hist(values, bins=bins)
    tile_stats = {
        "std": float(np.nanstd(values)),
        "skew": float(skew(values, nan_policy="omit")),
        "kurtosis": float(kurtosis(values, nan_policy="omit")),
    }
    axs[1].set_title(
        "\n".join(f"{name} {value:.4f}" for name, value in tile_stats.items())
    )

    # Log power spectrum. NaNs (e.g. coasts) are filled with the tile mean, otherwise the
    # whole FFT would be NaN. Zero-power bins give -inf, which is fine for display.
    filled = np.where(np.isnan(tile), np.nanmean(tile), tile)
    with np.errstate(divide="ignore"):
        log_power = np.log2(np.abs(np.fft.fftshift(np.fft.fft2(filled))) ** 2)
    axs[2].imshow(log_power)

axes[0, 0].set_title("Tile")
axes[0, 2].set_title("log power spectrum")
fig.tight_layout()

In [None]:
"""
Plot the amount of Fourier power above a threshold
"""

from typing import Callable

from matplotlib.colors import TwoSlopeNorm


def fft_power_fraction_factory(
    window_size: int,
    threshold: float,
    *,
    band: str = "high",  # "high" for high-freq fraction, "low" for low-freq fraction
) -> Callable[..., np.ndarray]:
    """
    Build a callable(arr, axis=(-2,-1)) -> ndarray that returns the fraction of FFT power
    in a radial band relative to total power over the given axes.

    threshold is in [0, 1], relative to the LARGEST radial distance from the zero-frequency
    bin in a window_size x window_size shifted spectrum. This is the corner of the
    spectrum, approximately sqrt(2) times the Nyquist radius along a single axis.
    band = "high" -> r_norm >= threshold
    band = "low"  -> r_norm <= threshold

    The returned callable fills NaNs in each window with that window's mean before the
    FFT. Windows that are entirely NaN, or that have zero total power, give NaN.

    :raises ValueError: if band is not "high" or "low".
    """
    import warnings

    if band not in ("high", "low"):
        raise ValueError(f"band must be 'high' or 'low', got {band!r}")

    # Radial distance of each bin from the zero-frequency bin. fftshift places that bin
    # at index window_size // 2 along each axis.
    y, x = np.indices((window_size, window_size))
    centre = window_size // 2
    r = np.hypot(x - centre, y - centre)
    r_norm = r / r.max()

    mask = (r_norm >= threshold) if band == "high" else (r_norm <= threshold)
    mask = mask.astype(float)  # for multiplication

    window_axes = (-2, -1)

    def _f(arr: np.ndarray, axis: tuple[int, int] = (-2, -1)) -> np.ndarray:
        # Move the window axes to the end so the (window_size, window_size) mask lines
        # up with them, whichever axes the caller passes
        windows = np.moveaxis(arr, axis, window_axes)

        # Fill NaNs per-window with the window nanmean. All-NaN windows stay NaN, and
        # the "mean of empty slice" warning is silenced because it fires for every
        # all-land window.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=RuntimeWarning)
            means = np.nanmean(windows, axis=window_axes, keepdims=True)
        filled = np.where(np.isnan(windows), means, windows)

        # Centred power spectrum over the window axes
        spectrum = np.fft.fftshift(
            np.fft.fft2(filled, axes=window_axes), axes=window_axes
        )
        power = np.abs(spectrum) ** 2

        # Band power vs total power over the same axes
        band_power = (power * mask).sum(axis=window_axes)
        total_power = power.sum(axis=window_axes)

        # 0/0 (e.g. an all-zero window) -> NaN, without a RuntimeWarning
        with np.errstate(invalid="ignore", divide="ignore"):
            return band_power / total_power

    return _f


def apply_to_map(f, field=None, window_size=None):
    """
    Apply ``f`` over sliding windows of a gridded field, masking out NaN points.

    :param f: callable passed to ``util.apply_to_sliding_window``.
    :param field: gridded field; defaults to the notebook-level ``residual``.
    :param window_size: window edge length in grid points; defaults to ``tile_size``.
    :returns: the sliding-window result, set to NaN wherever ``field`` is NaN.
    """
    if field is None:
        field = residual
    if window_size is None:
        window_size = tile_size

    retval = util.apply_to_sliding_window(field, f, window_size)
    return np.where(np.isnan(field), np.nan, retval)


# Frequency threshold, as a fraction of the largest radial frequency in a tile's spectrum
power_threshold = 0.05
fft_power_fcn = fft_power_fraction_factory(tile_size, power_threshold, band="high")
power_fraction_map = apply_to_map(fft_power_fcn)

# Only one panel is drawn, so don't create an empty second subplot
fig, axis = plt.subplots(1, 1)

im = mdt_imshow(
    power_fraction_map,
    axis,
    cmap="RdGy_r",
    vmin=None,
    vmax=None,
)
fig.colorbar(im, ax=axis)
axis.set_title(
    f"Fraction of tile FFT power above {100*power_threshold}% of max frequency"
)

fig.tight_layout()

Choose tiles to keep based on our criteria
----

In [None]:
"""
Plot the selected tiles on a map with the latitudes indicated. Hopefully we can see that we've de-selected ones near the coasts

This will e.g. select for tiles that have 80% of their power in frequencies above 5% of the minimum
"""

from current_denoising.plotting import maps

latitude_threshhold = 64.0
tiles, indices = ioutils.extract_tiles(
    rng,
    residual,
    num_tiles=1024,
    tile_criterion=(lambda x: -ioutils.fft_fraction(x, power_threshold), -power_fraction),
    max_latitude=latitude_threshhold,
    tile_size=tile_size,
    return_indices=True,
)

tile_grid = np.ones_like(residual) * np.nan

for tile, (y, x) in zip(tiles, indices):
    tile_grid[y : y + tile_size, x : x + tile_size] = tile

fig, axis = plt.subplots(figsize=(16, 9))
im = mdt_imshow(tile_grid, axis=axis, vmin=-0.3, vmax=0.3, cmap="seismic")
fig.colorbar(im, ax=axis)

lat, long = util.lat_long_grid(tile_grid.shape)
extent = [long[0], long[-1], lat[0], lat[-1]]

axis.imshow(
    np.isnan(residual),
    cmap=maps.clear2black_cmap(),
    extent=extent,
    vmin=0,
    vmax=1,
    origin="upper",
)

for t in (latitude_threshhold, -latitude_threshhold):
    axis.axhline(t, color="r", linestyle="--")
axis.text(-197, latitude_threshhold, f"{latitude_threshhold}" + r"$\degree$", color="r")

fig.suptitle(
    f"Extracted patches; keeping tiles with\n{power_fraction*100}% of power above {power_threshold*100}% of the maximum"
)
fig.tight_layout()

In [None]:
n_show = 25  # plot_mdt_tiles needs a square number of tiles
fig = plot_mdt_tiles(tiles[:n_show], indices[:n_show], residual.shape)

In [None]:
# Deliberately fail here so "Run All" stops before the GAN training below (which
# trains on "cuda"); run the following cells manually to train the GAN.
assert False, "to stop the GAN executing if i dont want it to"

In [None]:
"""
Train a simple GAN on these tiles
"""

from current_denoising.generation import dcgan

dataset = dcgan.TileLoader(tiles)

In [None]:
import torch
import matplotlib.cm as cm
import matplotlib.colors as colors
from scipy.stats import wasserstein_distance

from current_denoising.plotting import training, img_validation


def _gp_plot(training_metrics, lambda_gp, plot_dir):
    """Save the gradient-penalty vs Wasserstein-distance ratio plot to plot_dir/gp_wd.png."""
    out_path = plot_dir / "gp_wd.png"
    figure = training_metrics.plot_gp_wd_ratio(lambda_gp)
    figure.savefig(out_path)
    plt.close(figure)


def _grad_plot(train_metrics, g_lr, d_lr, plot_dir):
    """Save the generator/critic parameter-gradient plot to plot_dir/grads.png."""
    out_path = plot_dir / "grads.png"
    figure = train_metrics.plot_param_gradients(g_lr, d_lr)
    figure.savefig(out_path)
    plt.close(figure)


def _grad_norm_plot(train_metrics, plot_dir):
    """Save the critic gradient-norm plot to plot_dir/grad_norm.png."""
    out_path = plot_dir / "grad_norm.png"
    figure = train_metrics.plot_critic_grad_norms()
    figure.savefig(out_path)
    plt.close(figure)


def _imgs_plot(imgs, plot_dir, title):
    """
    Show generated images with a shared colourbar and save to plot_dir/generated.png.

    :param imgs: generated images, passed to ``img_validation.show``.
    :param plot_dir: directory to save the figure in.
    :param title: run description, appended to the figure title.
    """
    fig = img_validation.show(imgs, cmap="turbo")

    # Colourbar for a fixed 0-1.4 range.
    # NOTE(review): assumes img_validation.show uses the same normalisation - confirm.
    mappable = cm.ScalarMappable(
        norm=colors.Normalize(vmin=0.0, vmax=1.4), cmap="turbo"
    )
    mappable.set_array([])
    fig.colorbar(mappable, ax=fig.axes)

    fig.suptitle(f"Generated images - {title}")
    fig.savefig(plot_dir / "generated.png")
    plt.close(fig)


def _hist_plot(imgs, dataloader, plot_dir, title):
    """
    Overlay pixel-value histograms of generated images and one batch of real images,
    and save the figure to plot_dir/hists.png.
    """
    shared_kw = dict(
        bins=np.linspace(-1, 1, 150),
        density=True,
        alpha=0.5,
        histtype="step",
        linewidth=2,
    )

    fig, axis = plt.subplots(1, 1, figsize=(8, 5))

    img_validation.hist(imgs, axis=axis, label="Generated images", **shared_kw)

    real_batch = next(iter(dataloader))
    img_validation.hist(
        real_batch, axis=axis, label="Real images", linestyle="dashed", **shared_kw
    )

    axis.set_title(f"Image Hists - {title}")
    axis.legend()
    fig.tight_layout()
    fig.savefig(plot_dir / "hists.png")
    plt.close(fig)


def _fft_plot(imgs, plot_dir, title, dataloader=None):
    """
    Plot FFTs of generated images and of one batch of real images; save to plot_dir/ffts.png.

    :param imgs: generated images.
    :param plot_dir: directory to save the figure in.
    :param title: run description, used in the real-image panel title.
    :param dataloader: loader of real tiles. If None, falls back to the notebook-level
        ``dataloader`` global; pass it explicitly to avoid depending on hidden state.
    :returns: mean squared difference between the generated and real FFT outputs.
    """
    real_loader = dataloader if dataloader is not None else globals()["dataloader"]

    fig, axes = plt.subplots(1, 2, figsize=(12, 6))

    _, gen_fft = img_validation.fft(imgs, axis=axes[0])
    axes[0].set_title("Generated images FFT")

    _, real_fft = img_validation.fft(next(iter(real_loader)), axis=axes[1])
    axes[1].set_title(f"Real images FFT {title}")
    fig.tight_layout()
    fig.savefig(plot_dir / "ffts.png")
    plt.close(fig)

    return np.mean((gen_fft - real_fft) ** 2)


def train_gan(
    hyperparams: dcgan.GANHyperParams,
    dataloader: torch.utils.data.DataLoader,
    img_size: int,
    batch_size: int,
    output_dir: pathlib.Path,
    device: str = "cuda",
    n_stable: int = 30,
) -> tuple[torch.nn.Module, torch.nn.Module, dict[str, float]]:
    """
    Train the GAN and make lots of debug plots

    The plots (losses, gradient-penalty/Wasserstein ratio, parameter gradients, critic
    gradient norms, generated images, histograms, FFTs) are saved in ``output_dir``.

    :param hyperparams: GAN hyperparameters.
    :param dataloader: loader of real tiles, used for training and for the
        real-vs-generated comparison plots and metrics.
    :param img_size: image size, passed to ``dcgan.train_new_gan``.
    :param batch_size: number of images to generate for the diagnostic plots.
    :param output_dir: directory for plots (also passed to ``dcgan.train_new_gan``).
    :param device: torch device to train on.
    :param n_stable: number of final entries (presumably epochs) of the gradient
        metrics to average over, skipping the early, unstable part of training.
    :returns: (generator, discriminator, metrics). ``metrics`` holds the scalar
        diagnostics mean_wd_gp_ratio, mean_gradient_ratio, avg_grad_norm,
        hist_wasserstein and fft_mse.
    """
    metrics = {}

    # Train the GAN
    (generator, discriminator, train_metrics) = dcgan.train_new_gan(
        dataloader,
        hyperparams,
        device,
        img_size=img_size,
        output_dir=output_dir,
    )

    title = f"g_lr={hyperparams.g_lr}, d_lr={hyperparams.d_lr}"

    # Plot training losses
    fig = train_metrics.plot_scores()
    fig.suptitle(title)
    fig.tight_layout()
    fig.savefig(output_dir / "losses.png")
    plt.close(fig)

    # Plot contributions of gradient penalty and Wasserstein distance to discriminator loss
    _gp_plot(train_metrics, hyperparams.lambda_gp, output_dir)

    # Plot gradients
    _grad_plot(train_metrics, hyperparams.g_lr, hyperparams.d_lr, output_dir)

    # Plot grad norm
    _grad_norm_plot(train_metrics, output_dir)

    # Generate some images and compare them with the real data passed to this function
    gen_imgs = generator.gen_imgs(batch_size, hyperparams.generator_latent_size)
    _imgs_plot(gen_imgs, output_dir, title)
    _hist_plot(gen_imgs, dataloader, output_dir, title)
    fft_mse = _fft_plot(gen_imgs, output_dir, title, dataloader=dataloader)

    # Track metrics
    metrics["mean_wd_gp_ratio"] = (
        hyperparams.lambda_gp
        * np.mean(train_metrics.gradient_penalties, axis=1)
        / np.mean(train_metrics.wasserstein_dists, axis=1)
    ).mean()

    # Don't want the first epochs - need to give the model some time to stabilise
    metrics["mean_gradient_ratio"] = abs(
        (
            (hyperparams.g_lr * train_metrics.generator_param_gradients[-n_stable:])
            / (hyperparams.d_lr * train_metrics.critic_param_gradients[-n_stable:])
        ).mean()
    )

    # Don't want the first epochs - need to give the model some time to stabilise
    metrics["avg_grad_norm"] = np.mean(train_metrics.critic_interp_grad_norms, axis=1)[
        -n_stable:
    ].mean()

    metrics["hist_wasserstein"] = wasserstein_distance(
        gen_imgs.detach().cpu().numpy().flatten(),
        next(iter(dataloader)).cpu().numpy().flatten(),
    )
    metrics["fft_mse"] = fft_mse

    return generator, discriminator, metrics

In [None]:
# Training configuration
batch_size = 128
hyperparams = dcgan.GANHyperParams(
    n_epochs=500,
    g_lr=0.0002,
    d_lr=0.0002,
    n_critic=5,
    lambda_gp=20,
    generator_latent_channels=32,
    generator_latent_size=4,
    n_discriminator_blocks=4,
)

# Diagnostic plots from train_gan are written here
output_dir = pathlib.Path("outputs/mdt_gan/test/")
output_dir.mkdir(parents=True, exist_ok=True)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
)

gen, critic, metrics = train_gan(
    hyperparams,
    dataloader,
    img_size=tile_size,
    batch_size=batch_size,
    output_dir=output_dir,
)

In [None]:
# Persist the final weights alongside the training plots
for model, filename in (
    (gen, "generator_final.pth"),
    (critic, "discriminator_final.pth"),
):
    torch.save(model.state_dict(), output_dir / filename)

In [None]:
"""
Generate and display some example tiles
"""

n_gen = 16
gen_tiles = dcgan.generate_tiles(
    gen, n_tiles=n_gen, noise_size=hyperparams.generator_latent_size, device="cuda"
)
scaled = (gen_tiles - gen_tiles.mean()) / (gen_tiles.max() - gen_tiles.min())
plot_mdt_tiles(scaled, np.zeros((n_gen, 2), dtype=int), (100, 100)).suptitle(
    f"Output size: {scaled[0].shape}"
)

In [None]:
"""
Repeat with a larger input size, to demonstrate the outfilling
"""

gen_tiles = dcgan.generate_tiles(
    gen, n_tiles=n_gen, noise_size=2 * hyperparams.generator_latent_size, device="cuda"
)
scaled = (gen_tiles - gen_tiles.mean()) / (gen_tiles.max() - gen_tiles.min())
fig = plot_mdt_tiles(scaled, np.zeros((n_gen, 2), dtype=int), (100, 100))

fig.suptitle(f"Output size: {scaled[0].shape}")

Stitch these tiles together
----
For comparison, stitch the tiles together using image quilting; this will likely give worse results than the GAN's native outfilling.

Train the denoising model
----
Train the denoising model on real patches of clean MDT with synthetic noise added.

Evaluate the denoising model on different noisy MDTs
----