Denoise the MDT directly
====
Instead of working with currents (the gradient of the MDT, scaled by some constants and the Coriolis parameter), we can probably smooth and denoise the MDT directly.

Read in MDTs
----

In [None]:
import pathlib

# The SING RDSF directory must be mounted locally; this notebook expects it at ~/geog_rdsf
rdsf_dir = pathlib.Path("~/geog_rdsf/").expanduser()

# File holding the noisy MDT
noisy_filepath = (
    rdsf_dir
    / "data"
    / "projects"
    / "SING"
    / "richard_stuff"
    / "Table2"
    / "dtu18_eigen-6c4_do0280_rr0004.dat"
)
# File holding the clean MDTs
clean_filepath = (
    rdsf_dir / "data" / "projects" / "dtop" / "cmip6" / "cmip6_historical_mdts_yr5.dat"
)

# Fail early, naming the file that is missing. An explicit raise (rather than a bare
# assert) gives an actionable message and is not stripped when running with `python -O`.
for required_path in (noisy_filepath, clean_filepath):
    if not required_path.exists():
        raise FileNotFoundError(
            f"{required_path} not found - is the RDSF directory mounted at {rdsf_dir}?"
        )

In [None]:
import numpy as np
from current_denoising.generation import ioutils

# Despite its name, read_currents only loads the gridded array from the file
noisy_mdt = ioutils.read_currents(noisy_filepath)

# Entries equal to -1.9e19 are replaced with NaN
# (presumably -1.9e19 is the file's missing-data fill value - TODO confirm)
is_fill_value = noisy_mdt == -1.9e19
noisy_mdt[is_fill_value] = np.nan

# The metadata file sits next to the clean MDT file: same stem plus "_meta", with a .txt suffix
clean_metadata_filepath = clean_filepath.with_stem(
    clean_filepath.stem + "_meta"
).with_suffix(".txt")

clean_mdt = ioutils.read_clean_mdt(
    path=clean_filepath,
    metadata_path=clean_metadata_filepath,
    year=2001,
    model="CMCC-CM2-HR4",
)

In [None]:
"""
Find the residual between the MDT and the smoothed MDT
"""

from current_denoising.utils import util
from current_denoising.generation import mdt

# Gaussian smoothing scale (sigma) in km; passed to get_residual and quoted in the plot title
sigma_km = 200


def get_residual(img: np.ndarray, sigma_km: float) -> np.ndarray:
    """
    Return the difference between an MDT gridded field and its Gaussian-smoothed version.

    NaN entries of `img` are filled with their nearest values before smoothing, then
    set back to NaN in the smoothed field (NaN indicates land). The difference is
    passed through `ioutils._remove_nanmean` before being returned (to remove the
    global mean).

    :param img: gridded MDT field, with NaN marking land
    :param sigma_km: smoothing scale handed to `mdt.gauss_smooth`
    :returns: the residual field
    """
    # Record the land cells before any filling takes place
    land = np.isnan(img)

    # Smooth a copy of the field in which the NaNs were replaced by nearest values
    filled = mdt.fill_nan_with_nearest(img)
    large_scale = mdt.gauss_smooth(filled, sigma_km)

    # Restore NaN on land so that the residual is NaN there too
    large_scale = np.where(land, np.nan, large_scale)

    return ioutils._remove_nanmean(img - large_scale)


# Residual of the noisy MDT; later cells plot this field and extract training tiles from it
residual = get_residual(noisy_mdt, sigma_km)

In [None]:
import matplotlib.pyplot as plt
from current_denoising.plotting import maps


def mdt_imshow(current_grid: np.ndarray, axis: plt.Axes, **kwargs) -> None:
    """
    Show a gridded MDT field on `axis`, with NaN cells drawn as a land mask, plus a colorbar.

    :param current_grid: gridded field to plot; NaN entries are drawn as land
    :param axis: the axes to draw on
    :param kwargs: forwarded to `axis.imshow`, overriding the defaults set here
                   (origin="upper", cmap="seismic", vmin=-0.5, vmax=0.5). If `extent`
                   is not given, it is taken from the lat/long grid for this shape.
    """
    lat, long = util.lat_long_grid(current_grid.shape)

    # Defaults come first so that any keyword supplied by the caller wins.
    # vmin must be the lower colour limit and vmax the upper one.
    imshow_kw = {
        "origin": "upper",
        "cmap": "seismic",
        "vmin": -0.5,
        "vmax": 0.5,
        "extent": [long[0], long[-1], lat[0], lat[-1]],
        **kwargs,
    }
    extent = imshow_kw["extent"]

    im = axis.imshow(current_grid, **imshow_kw)

    # Overlay the land (NaN cells) on the same extent; the boolean mask spans 0..1
    axis.imshow(
        np.isnan(current_grid),
        cmap=maps.clear2black_cmap(),
        extent=extent,
        vmin=0,
        vmax=1,
        origin="upper",
    )

    # Colorbar for the data image, not for the land overlay
    axis.get_figure().colorbar(im, ax=axis)


fig, axis = plt.subplots(figsize=(18, 9))

mdt_imshow(residual, axis)

# Raw f-string: "\s" is not a valid escape sequence, so the non-raw form triggers an
# invalid-escape warning. The resulting title text is unchanged.
fig.suptitle(rf"Residual: smoothed $\sigma=${sigma_km}km")

Choose tiles to use in GAN training
----

In [None]:
"""
Plot the selected tiles on a map, according to a selection function

This will e.g. select for tiles that have 80% of their power in frequencies above 5% of the minimum
"""

from current_denoising.plotting import maps


def select_tile(tile: np.ndarray) -> bool:
    """
    Decide whether a tile is kept, based on its RMS and its FFT fraction.

    A tile is rejected if `ioutils.tile_rms(tile)` exceeds 0.05, or if
    `ioutils.fft_fraction(tile, 0.05)` is below 0.80; otherwise it is accepted.
    The FFT fraction is only evaluated for tiles that pass the RMS check.
    """
    rms_ok = not (ioutils.tile_rms(tile) > 0.05)
    return rms_ok and not (ioutils.fft_fraction(tile, 0.05) < 0.80)


# Tile extraction settings
tile_size = 32
num_tiles = 512
latitude_threshhold = 60.0

# Fixed seed, so the generator starts from the same state each time this cell is run
rng = np.random.default_rng(seed=0)

# Exclude grid points whose distance-from-land value is below 20
# (units are whatever ioutils.distance_from_land returns - TODO confirm)
land_distance = ioutils.distance_from_land(residual)
forbidden_mask = land_distance < 20

# Draw the tiles from the residual map according to the criteria above.
# The indices (tile locations) are also returned, for plotting.
tiles, indices = ioutils.extract_tiles(
    rng,
    residual,
    tile_size=tile_size,
    num_tiles=num_tiles,
    max_latitude=latitude_threshhold,
    forbidden_mask=forbidden_mask,
    tile_criterion=select_tile,
    return_indices=True,
)

In [None]:
fig, axis = plt.subplots(figsize=(16, 9))

# Paint every selected tile onto an all-zero canvas with the same shape as the map
tile_grid = np.zeros_like(residual)
for (row, col), tile in zip(indices, tiles):
    tile_grid[row : row + tile_size, col : col + tile_size] = tile

# Restore the NaNs of the residual map so that land is drawn as land
tile_grid[np.isnan(residual)] = np.nan

mdt_imshow(tile_grid, axis=axis, vmin=-0.3, vmax=0.3, cmap="seismic")

# Dashed lines at +/- the latitude limit used for tile extraction, with a label
for lat_limit in (latitude_threshhold, -latitude_threshhold):
    axis.axhline(lat_limit, color="r", linestyle="--")
axis.text(-197, latitude_threshhold, f"{latitude_threshhold}" + r"$\degree$", color="r")

fig.tight_layout()

In [None]:
"""
Show some of the training dataset
"""

fig, axes = plt.subplots(5, 5, figsize=(8, 8))

# zip stops once the 5x5 grid of axes is used up, so at most 25 tiles are drawn
tile_plot_kw = {"cmap": "seismic", "vmin": -0.5, "vmax": 0.5}
for axis, tile in zip(axes.flat, tiles):
    axis.imshow(tile, **tile_plot_kw)
    # Pixel coordinates carry no meaning here, so hide the ticks
    axis.set(xticks=[], yticks=[])

fig.suptitle("Training Tiles")
fig.tight_layout()

Generate some tiles using the GAN
----
Load the model from disk - it should already be trained

In [None]:
import torch
from current_denoising.generation import dcgan

# TODO nicer way of recording the hyperparams in the generator object for later...?
# These must match the values with which the generator was trained
latent_channels = 32
latent_size = 4

# Directory that holds the trained generator's weights
output_dir = pathlib.Path("outputs/mdt_gan/test/")

generator = dcgan.Generator(tile_size, latent_channels, latent_size)

# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source
generator.load_state_dict(torch.load(output_dir / "generator_final.pth"))

# Module.to() moves the parameters and returns the module itself
generator = generator.to("cuda")

In [None]:
"""
Generate some tiles
"""

n_gen = 25

gen_tiles = dcgan.generate_tiles(
    generator,
    n_tiles=n_gen,
    noise_size=latent_size,
    device="cuda",
)

# Rescale the generated tiles so that their overall mean and standard deviation match
# those of the training tiles; the 1e-8 guards against division by zero
scaled = tiles.mean() + tiles.std() * (gen_tiles - gen_tiles.mean()) / (
    gen_tiles.std() + 1e-8
)
fig, axes = plt.subplots(5, 5, figsize=(12, 12))

# Each image is scaled[i], so axis 1 holds the rows (height) and axis 2 the columns (width).
# set_extent takes (left, right, bottom, top): the x-range is the width, the y-range the height.
n_rows, n_cols = scaled.shape[1], scaled.shape[2]
for axis, img in zip(axes.flat, scaled):
    im = axis.imshow(img, vmin=-0.5, vmax=0.5, cmap="seismic")
    im.set_extent([0, n_cols, 0, n_rows])

fig.suptitle(f"Output size: {scaled[0].shape}")
fig.tight_layout()

# Shared colorbar in its own axes to the right of the grid; all images use the same
# colour limits, so the last image stands in for all of them
cax = fig.add_axes([1.05, 0.15, 0.05, 0.7])
fig.colorbar(im, cax=cax, label="Mean Dynamic Topography /m")

In [None]:
"""
Generate some tiles, larger than the training set, by passing the generator a slightly larger random latent vector
"""

# Same generator, but fed a latent input 4x the size used for the tiles above
big_gen_tiles = dcgan.generate_tiles(
    generator,
    n_tiles=n_gen,
    noise_size=4 * latent_size,
    device="cuda",
)

# Rescale so that the overall mean and standard deviation match the training tiles;
# the 1e-8 guards against division by zero
big_scaled = tiles.mean() + tiles.std() * (big_gen_tiles - big_gen_tiles.mean()) / (
    big_gen_tiles.std() + 1e-8
)
fig, axes = plt.subplots(5, 5, figsize=(12, 12))

# Each image is big_scaled[i], so axis 1 holds the rows (height) and axis 2 the columns (width).
# set_extent takes (left, right, bottom, top): the x-range is the width, the y-range the height.
n_rows, n_cols = big_scaled.shape[1], big_scaled.shape[2]
for axis, img in zip(axes.flat, big_scaled):
    im = axis.imshow(img, vmin=-0.5, vmax=0.5, cmap="seismic")
    im.set_extent([0, n_cols, 0, n_rows])

fig.suptitle(f"Output size: {big_scaled[0].shape}")
fig.tight_layout()

# Shared colorbar in its own axes to the right of the grid; all images use the same
# colour limits, so the last image stands in for all of them
cax = fig.add_axes([1.05, 0.15, 0.05, 0.7])
fig.colorbar(im, cax=cax, label="Mean Dynamic Topography /m")

Stitch these tiles together
----
As a comparison, stitch the tiles together using quilting - this will likely give a worse result than the native GAN outfilling.

In [None]:
"""
Compare real 128x128 patches to quilted ones to ones generated with our GAN

"""

from current_denoising.generation import quilting

# Real 128x128 tiles, selected with the same mask, criterion and latitude limit as the
# training tiles
real_tiles = ioutils.extract_tiles(
    rng,
    residual,
    tile_size=128,
    num_tiles=128,
    max_latitude=latitude_threshhold,
    forbidden_mask=forbidden_mask,
    tile_criterion=select_tile,
)

fig, axes = plt.subplots(4, 3, figsize=(9, 12))

# Display settings shared by all three columns
kw = {"cmap": "seismic", "vmin": -0.5, "vmax": 0.5}

# One figure row per comparison; zip stops once the 4 rows of axes are used up
for (quilt_ax, real_ax, gan_ax), gen_tile, real_tile in zip(
    axes, big_scaled, real_tiles
):
    # Fresh set of small generated patches, rescaled to the training tiles' mean and std
    patches = dcgan.generate_tiles(
        generator,
        n_tiles=n_gen,
        noise_size=latent_size,
        device="cuda",
    )
    patches = tiles.mean() + tiles.std() * (patches - patches.mean()) / (
        patches.std() + 1e-8
    )

    # torch.seed() re-seeds torch's RNG with a non-deterministic value
    # (presumably so that each quilt differs between rows - TODO confirm)
    torch.seed()
    quilt = quilting.quilt(
        patches, target_size=(128, 128), patch_overlap=4, repeat_penalty=0
    )

    quilt_ax.imshow(quilt, **kw)
    real_ax.imshow(real_tile, **kw)
    gan_ax.imshow(gen_tile, **kw)

# Column headings go on the top row only
column_titles = ("Quilted", "Real Noise", "Native GAN Outfilling")
for top_ax, title in zip(axes[0], column_titles):
    top_ax.set_title(title)

fig.tight_layout()

Train the denoising model
----
Train the denoising model with synthetic noise + real patches of clean MDT

Evaluate the denoising model on different noisy MDTs
----