Denoise the MDT directly
====
Instead of working with the currents (derived from the gradient of the MDT, scaled by gravity and the Coriolis parameter), we can probably smooth and denoise the MDT directly.

Read in MDTs
----

In [None]:
import pathlib

# Need to mount the SING RDSF dir somewhere
rdsf_dir = pathlib.Path("~/geog_rdsf/").expanduser()

# Noisy MDT (the field we want to denoise)
noisy_filepath = (
    rdsf_dir
    / "data"
    / "projects"
    / "SING"
    / "richard_stuff"
    / "Table2"
    / "dtu18_eigen-6c4_do0280_rr0004.dat"
)
# Clean MDTs (judging by the path, from CMIP6 historical runs)
clean_filepath = (
    rdsf_dir / "data" / "projects" / "dtop" / "cmip6" / "cmip6_historical_mdts_yr5.dat"
)

# Fail early with a useful message - a missing file usually means the RDSF dir isn't mounted
assert noisy_filepath.exists(), (
    f"{noisy_filepath} not found - is the RDSF directory mounted at {rdsf_dir}?"
)
assert clean_filepath.exists(), (
    f"{clean_filepath} not found - is the RDSF directory mounted at {rdsf_dir}?"
)

In [None]:
import numpy as np
from current_denoising.generation import ioutils

# Sentinel value in the .dat file for missing data (presumably land) - replaced by NaN below
FILL_VALUE = -1.9e19

# It's called read_currents, but actually just reads the array
noisy_mdt = ioutils.read_currents(noisy_filepath)

# Match the sentinel with a relative tolerance rather than exact equality, so the mask still
# works if the data has passed through float32 (float32(-1.9e19) is not exactly -1.9e19)
noisy_mdt[np.isclose(noisy_mdt, FILL_VALUE)] = np.nan

In [None]:
"""
Find the residual between the MDT and the smoothed MDT
"""

from current_denoising.utils import util
from current_denoising.generation import mdt

# Width (sigma), in km, of the Gaussian used to smooth the MDT
sigma_km = 200


def get_residual(img: np.ndarray, sigma_km: float) -> np.ndarray:
    """
    Get the residual (difference) between the provided MDT gridded field and a Gaussian smoothed version.

    :param img: the (2d) gridded field from which a residual is extracted.
                Land values should be marked with NaN.
    :param sigma_km: the size, in km, of the gaussian smoothing filter.

    :returns: the (2d) residual between `img` and the smoothed `img`. Locations that are NaN
              in `img` are NaN in the output, and the (NaN-ignoring) global mean of the
              residual is subtracted.

    Notes:
    Performs an approximate smoothing - the size of the smoothing kernel varies with latitude, which is
    much more accurate than assuming the grid points are equal sized, but it is still not exact.

    The original image may contain NaNs, which will be replaced by their nearest neighbour
    values before smoothing (keeping them as NaN would cause the NaNs to propagate throughout
    the entire grid; replacing them with a constant (e.g. 0.0 or the global mean) would distort
    the residual near the coasts, which is what we care about). This nearest-neighbour search is
    not a true physical nearest neighbour, since it doesn't account for the variation of grid
    point sizes with latitude. However, since the nearest neighbour search is performed over very
    short distances, it should be good enough.
    """
    nan_mask = np.isnan(img)

    # Smooth a copy where NaNs have been replaced by their nearest values
    smoothed = mdt.gauss_smooth(mdt.fill_nan_with_nearest(img), sigma_km)

    # Put NaN values back in to indicate land
    smoothed = np.where(nan_mask, np.nan, smoothed)

    residual = img - smoothed

    # Remove the global mean (ignoring NaNs)
    return ioutils._remove_nanmean(residual)


residual = get_residual(noisy_mdt, sigma_km)

In [None]:
import matplotlib.pyplot as plt
from matplotlib.colors import CenteredNorm

from current_denoising.plotting import maps


def mdt_imshow(current_grid: np.ndarray, axis: plt.Axes, norm=None, **kwargs):
    """
    Show a global gridded field (e.g. an MDT or a residual) on a lat/long axis, with land in black.

    :param current_grid: 2d gridded field. NaN values are treated as land and drawn in black.
    :param axis: the axis to draw on. A colorbar is added to its figure.
    :param norm: optional matplotlib normalisation for the colour scale.
    :param kwargs: extra keyword arguments passed to `imshow` (they override the defaults
                   `origin="upper"`, `cmap="seismic"`). If `extent` is not given, it is taken
                   from the lat/long grid returned by `util.lat_long_grid`.

    :returns: the AxesImage of the field (not of the land overlay).
    """
    imshow_kw = {
        "origin": "upper",
        "cmap": "seismic",
    }
    imshow_kw.update(kwargs)

    # Only work out the lat/long grid if the caller didn't give an extent
    if "extent" not in imshow_kw:
        lat, long = util.lat_long_grid(current_grid.shape)
        imshow_kw["extent"] = [long[0], long[-1], lat[0], lat[-1]]
    extent = imshow_kw["extent"]

    im = axis.imshow(current_grid, norm=norm, **imshow_kw)

    # Plot also the land, using the same extent and origin as the field so the two line up
    axis.imshow(
        np.isnan(current_grid),
        cmap=maps.clear2black_cmap(),
        extent=extent,
        vmin=0,
        vmax=1,
        origin=imshow_kw["origin"],
    )

    # Add a colorbar
    axis.get_figure().colorbar(im, ax=axis)

    return im


fig, axis = plt.subplots(figsize=(18, 9))

mdt_imshow(residual, axis, norm=CenteredNorm(vcenter=0.0, halfrange=0.5))

# Raw f-string: "\s" is not a valid Python escape (SyntaxWarning on Python >= 3.12);
# the rendered title is unchanged
fig.suptitle(rf"Residual: smoothed $\sigma=${sigma_km}km")

Choose tiles to use in GAN training
----

In [None]:
"""
Select tiles from the noisy residual to use in GAN training

Tiles must lie within `latitude_threshhold` degrees of the equator, must avoid the points in
`forbidden_mask`, and must pass `ioutils.select_tile` (per the original note, this selects e.g.
tiles that have 80% of their power in frequencies above 5% of the minimum).
"""

tile_size = 32
latitude_threshhold = 60.0
# Exclude points within 20 (units as returned by ioutils.distance_from_land - TODO confirm) of land
forbidden_mask = ioutils.distance_from_land(residual) < 20

# Get the tiles from our map based on the criteria
# Return the indices (their locations) for plotting
tiles, indices = ioutils.extract_tiles(
    residual,
    forbidden_mask=forbidden_mask,
    tile_criterion=ioutils.select_tile,
    max_latitude=latitude_threshhold,
    tile_size=tile_size,
    return_indices=True,
)

num_tiles = len(tiles)

In [None]:
def _plot_tiles(tiles, indices, tile_size, axis, *, reference=None, **imshow_kw):
    """
    Plot tiles on a global map, placing each one at its provided index.

    :param tiles: iterable of 2d tiles, each `tile_size` x `tile_size`.
    :param indices: iterable of (row, column) positions of each tile's top-left corner.
    :param tile_size: side length of each tile, in grid points.
    :param axis: the axis to draw on.
    :param reference: the full grid the tiles came from; sets the map size and its NaN (land)
                      mask. Defaults to the global `residual`.
    :param imshow_kw: extra keyword arguments forwarded to `mdt_imshow`. If none of `norm`,
                      `vmin` or `vmax` is given, a colour scale centred on 0 with half-range
                      0.5 is used.
    """
    if reference is None:
        reference = residual

    # Put the tiles into a single numpy array for plotting; uncovered areas stay at zero
    tile_grid = np.zeros_like(reference)
    for tile, (y, x) in zip(tiles, indices):
        tile_grid[y : y + tile_size, x : x + tile_size] = tile

    # Put NaNs back in for plotting
    tile_grid = np.where(np.isnan(reference), np.nan, tile_grid)

    # Default colour scale, unless the caller chose one (norm and vmin/vmax can't be combined)
    if not any(key in imshow_kw for key in ("norm", "vmin", "vmax")):
        imshow_kw["norm"] = CenteredNorm(vcenter=0.0, halfrange=0.5)

    mdt_imshow(tile_grid, axis=axis, **imshow_kw)


fig, axis = plt.subplots(figsize=(16, 9))

_plot_tiles(tiles, indices, tile_size, axis, cmap="seismic")

# Mark the latitude cut-off used when selecting tiles
for t in (latitude_threshhold, -latitude_threshhold):
    axis.axhline(t, color="r", linestyle="--")
axis.text(-197, latitude_threshhold, f"{latitude_threshhold}" + r"$\degree$", color="r")

fig.suptitle(f"{num_tiles} non-overlapping tiles")
fig.tight_layout()

In [None]:
"""
Show some of the training dataset
"""

fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(8, 8))

# One training tile per panel, on a fixed symmetric colour scale; ticks hidden
for axis, tile in zip(axes.flat, tiles):
    axis.imshow(tile, vmin=-0.5, vmax=0.5, cmap="seismic")
    axis.set(xticks=[], yticks=[])

fig.suptitle("Training Tiles")
fig.tight_layout()

Generate some tiles using the GAN
----
Load the generator from disk; it should already have been trained.

In [None]:
import torch
from current_denoising.generation import dcgan

# We'll read the generator from here (path is relative to the current working directory)
output_dir = pathlib.Path("explanatory_notebooks/outputs/mdt_gan/test/")
generator_path = output_dir / "generator_final.pth"
assert generator_path.exists(), (
    f"{generator_path} not found - train the GAN first, or check the working directory"
)

# TODO nicer way of recording the hyperparams in the generator object for later...?
# These must match the values with which the generator was trained
latent_channels = 32
latent_size = 4

generator = dcgan.Generator(tile_size, latent_channels, latent_size)
# weights_only=True: load tensors only, never arbitrary pickled objects
generator.load_state_dict(torch.load(generator_path, weights_only=True))
_ = generator.to("cuda")

In [None]:
"""
Generate some tiles
"""

n_gen = 25

gen_tiles = dcgan.generate_tiles(
    generator,
    n_tiles=n_gen,
    noise_size=latent_size,
    device="cuda",
)
# Rescale to the mean and std of the real training tiles (epsilon guards against a zero std)
scaled = tiles.mean() + tiles.std() * (gen_tiles - gen_tiles.mean()) / (
    gen_tiles.std() + 1e-8
)
fig, axes = plt.subplots(5, 5, figsize=(12, 12))

for axis, img in zip(axes.flat, scaled):
    im = axis.imshow(img, vmin=-0.5, vmax=0.5, cmap="seismic")
    # extent is (left, right, bottom, top): x spans the columns (axis 2), y spans the rows (axis 1)
    im.set_extent([0, scaled.shape[2], 0, scaled.shape[1]])

fig.suptitle(f"Output size: {scaled[0].shape}")
fig.tight_layout()

cax = fig.add_axes([1.05, 0.15, 0.05, 0.7])
fig.colorbar(im, cax=cax, label="Mean Dynamic Topography /m")

In [None]:
"""
Generate some tiles, larger than the training set, by passing the generator a larger
(4 x latent_size) random latent input
"""

big_gen_tiles = dcgan.generate_tiles(
    generator,
    n_tiles=n_gen,
    noise_size=4 * latent_size,
    device="cuda",
)
# Rescale to the mean and std of the real training tiles (epsilon guards against a zero std)
big_scaled = tiles.mean() + tiles.std() * (big_gen_tiles - big_gen_tiles.mean()) / (
    big_gen_tiles.std() + 1e-8
)
fig, axes = plt.subplots(5, 5, figsize=(12, 12))

for axis, img in zip(axes.flat, big_scaled):
    im = axis.imshow(img, vmin=-0.5, vmax=0.5, cmap="seismic")
    # extent is (left, right, bottom, top): x spans the columns (axis 2), y spans the rows (axis 1)
    im.set_extent([0, big_scaled.shape[2], 0, big_scaled.shape[1]])

fig.suptitle(f"Output size: {big_scaled[0].shape}")
fig.tight_layout()

cax = fig.add_axes([1.05, 0.15, 0.05, 0.7])
fig.colorbar(im, cax=cax, label="Mean Dynamic Topography /m")

Stitch these tiles together
----
For comparison, stitch the small tiles together using quilting. This will likely give worse results than the GAN's native outfilling.

In [None]:
"""
Compare real 128x128 patches with quilted patches and with patches generated natively by our GAN
"""

from current_denoising.generation import quilting

real_tiles = ioutils.extract_tiles(
    residual,
    forbidden_mask=None,
    # Choose tiles with not much NaN in them, that also pass our RMS and FFT criteria
    tile_criterion=lambda tile: ioutils.select_tile(tile)
    and (np.isnan(tile).mean() < 0.2),
    max_latitude=latitude_threshhold,
    tile_size=128,
    allow_nan=True,
)

# Shared colour scale for every panel (loop-invariant, so defined once outside the loop)
kw = {"cmap": "seismic", "vmin": -0.5, "vmax": 0.5}

fig, axes = plt.subplots(4, 3, figsize=(9, 12))

for axs, gen_tile, real_tile in zip(axes, big_scaled, real_tiles):
    # A fresh batch of small GAN tiles to quilt from, rescaled like the training tiles
    patches = dcgan.generate_tiles(
        generator,
        n_tiles=n_gen,
        noise_size=latent_size,
        device="cuda",
    )
    patches = tiles.mean() + tiles.std() * (patches - patches.mean()) / (
        patches.std() + 1e-8
    )

    # Re-seeds torch's global RNG from a non-deterministic source.
    # NOTE(review): unclear whether quilting.quilt uses torch's RNG - confirm this is needed
    torch.seed()
    quilt = quilting.quilt(
        patches, target_size=(128, 128), patch_overlap=4, repeat_penalty=0
    )

    axs[0].imshow(quilt, **kw)
    axs[1].imshow(real_tile, **kw)
    axs[2].imshow(gen_tile, **kw)

axes[0, 0].set_title("Quilted")
axes[0, 1].set_title("Real Noise")
axes[0, 2].set_title("Native GAN Outfilling")

fig.tight_layout()

Apply this noise to the clean data
----

In [None]:
# The metadata sits next to the data file, with "_meta" appended to the stem and a .txt suffix
clean_metadata_path = clean_filepath.with_stem(clean_filepath.stem + "_meta").with_suffix(
    ".txt"
)

clean_mdt = ioutils.read_clean_mdt(
    path=clean_filepath,
    metadata_path=clean_metadata_path,
    year=2001,
    model="CMCC-CM2-HR4",
    remove_mean=True,
)

# Residual of the clean MDT, using the same smoothing scale as for the noisy MDT
clean_residual = get_residual(clean_mdt, sigma_km)

In [None]:
fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(12, 12))

# Top panel: the clean MDT itself
mdt_imshow(clean_mdt, axes[0], cmap="RdYlGn", vmin=-1.5, vmax=1.5)
axes[0].set_title("Clean MDT")

# Bottom panel: its residual from the smoothed field
mdt_imshow(clean_residual, axes[1], cmap="seismic", vmin=-0.5, vmax=0.5)
axes[1].set_title("Clean Residual")

In [None]:
"""
Our noise should have a location-dependent strength
"""

from current_denoising.generation import applying_noise

# TODO This just does a naive, latitude-unaware Gaussian smooth of the MDT gradient
# Is this the right thing for our strength map? I doubt it
# No we want to do an RMS/land-based one - it's very rough so it doesn't matter how accurate we are
grad = np.gradient(residual)
grad = np.sqrt(grad[0] ** 2 + grad[1] ** 2)
strength_map = applying_noise.noise_strength_map(grad, filter_size=10)

# For now the gradient-based map above is deliberately overridden with a uniform strength
# of 1 everywhere. Set this to False to use the gradient-based map instead.
USE_UNIFORM_STRENGTH = True
if USE_UNIFORM_STRENGTH:
    strength_map = np.ones_like(strength_map)

In [None]:
fig, axis = plt.subplots(figsize=(12, 6))

# Show the map with a colorbar so its (possibly uniform) values can be read off
im = axis.imshow(strength_map)
fig.colorbar(im, ax=axis)

fig.suptitle("Noise strength map")
fig.tight_layout()

We will randomly extract patches of clean signal and apply our noise to them; these (clean, noisy) pairs will form our denoising training dataset.

In [None]:
"""
Extract tiles from the clean residual and apply GAN-generated noise to them
to build up our training dataset

Also extract the corresponding "true" noisy and clean tiles for later comparison
"""

from current_denoising.denoising import data
from sklearn.model_selection import train_test_split

denoising_patch_size = 128
max_nan_fraction = 0.3

# NOTE(review): unseeded, so the pairs and the train/test split differ on every run
rng = np.random.default_rng()

# Generate some noise tiles
n_noise_tiles = 32
noise_tiles = dcgan.generate_tiles(
    generator,
    n_tiles=n_noise_tiles,
    noise_size=4 * latent_size,
    device="cuda",
)

# Rescale the noise tiles to zero mean and to the same std as the real noise tiles.
# NOTE(review): unlike the preview cells, tiles.mean() is not added back - presumably
# deliberate, since the residual has had its global mean removed; confirm
noise_tiles = noise_tiles - noise_tiles.mean()
noise_tiles = tiles.std() * noise_tiles / noise_tiles.std()

assert all(
    x == denoising_patch_size for x in noise_tiles.shape[1:]
), f"{noise_tiles.shape=}, {denoising_patch_size=}"

# Extract all the training pairs, and return indices
tile_pairs, indices = data.get_training_pairs(
    clean_residual,
    strength_map,
    noise_tiles,
    max_latitude=latitude_threshhold,
    max_nan_fraction=max_nan_fraction,
    rng=rng,
    return_indices=True,
)

# Each pair is stacked as (clean, noisy) along axis 1
clean_tiles, noisy_tiles = np.moveaxis(tile_pairs, 1, 0)

# Use these indices to get the corresponding real noisy tiles out
real_tiles = util.tiles_from_indices(residual, indices, denoising_patch_size)

(
    clean_train,
    clean_test,
    noisy_train,
    noisy_test,
    real_train,
    real_test,
    idx_train,
    idx_test,
) = train_test_split(clean_tiles, noisy_tiles, real_tiles, indices, train_size=0.8)

fig, axes = plt.subplots(3, 2, figsize=(16, 12))

# Place each set of tiles on a global map (left column: train, right column: test)
_plot_tiles(clean_train, idx_train, denoising_patch_size, axes[0, 0])
_plot_tiles(noisy_train, idx_train, denoising_patch_size, axes[1, 0])
_plot_tiles(real_train, idx_train, denoising_patch_size, axes[2, 0])

_plot_tiles(clean_test, idx_test, denoising_patch_size, axes[0, 1])
_plot_tiles(noisy_test, idx_test, denoising_patch_size, axes[1, 1])
_plot_tiles(real_test, idx_test, denoising_patch_size, axes[2, 1])

axes[0, 0].set_title("Training Data")
axes[0, 1].set_title("Test Data")

axes[0, 0].set_ylabel("Clean")
axes[1, 0].set_ylabel("Synthetic")
axes[2, 0].set_ylabel("Noisy (real)")

Train the denoising model
----
Train the denoising model on patches of clean MDT with synthetic (GAN-generated) noise added.

In [None]:
"""
This is what the training and testing data looks like
"""
batch_size = 4

# Training data is loaded with 4 worker processes; test data with num_workers=0
train_config = data.DataConfig(train=True, batch_size=batch_size, num_workers=4)
test_config = data.DataConfig(train=False, batch_size=batch_size, num_workers=0)

train_loader = data.dataloader(clean_train, noisy_train, train_config)
test_loader = data.dataloader(clean_test, noisy_test, test_config)

In [None]:
"""
Plot some tiles
"""

# Colour scale defined here, so this cell doesn't depend on a loop variable from an earlier cell
kw = {"cmap": "seismic", "vmin": -0.5, "vmax": 0.5}

fig, axes = plt.subplots(2, 4, figsize=(8, 4))

batch = next(iter(train_loader))
for ax_row, clean, noisy in zip(axes.T, batch[0], batch[1], strict=True):
    # Top row: clean tile; bottom row: noisy tile. Each is drawn first, then its NaNs are
    # overlaid in black. vmin/vmax pin the mask scale, so an all-NaN tile still shows black.
    for ax, tile in zip(ax_row, (clean, noisy)):
        ax.imshow(tile.squeeze(), **kw)
        ax.imshow(
            torch.isnan(tile).squeeze(), cmap=maps.clear2black_cmap(), vmin=0, vmax=1
        )
        ax.set_xticklabels([])
        ax.set_yticklabels([])

fig.suptitle("Training Data")
fig.tight_layout()

In [None]:
# Colour scale defined here, so this cell doesn't depend on a loop variable from an earlier cell
kw = {"cmap": "seismic", "vmin": -0.5, "vmax": 0.5}

fig, axes = plt.subplots(2, 4, figsize=(8, 4))

batch = next(iter(test_loader))
for ax_row, clean, noisy in zip(axes.T, batch[0], batch[1], strict=True):
    # Top row: clean tile; bottom row: noisy tile
    ax_row[0].imshow(clean.squeeze(), **kw)

    ax_row[1].imshow(noisy.squeeze(), **kw)

    for ax in ax_row:
        ax.set_xticklabels([])
        ax.set_yticklabels([])

fig.suptitle("Validation Data")
fig.tight_layout()

In [None]:
"""
Train the denoiser
"""

from current_denoising.denoising import model, train

# NOTE(review): the meaning of the positional arguments (4, 0.1) isn't visible here -
# see model.get_attention_unet
net = model.get_attention_unet(4, 0.1)

# Returns the trained network plus the training and validation loss histories
net, train_loss, val_loss = train.train_model(
    net,
    "cuda",
    n_epochs=250,
    train_data=train_loader,
    val_data=test_loader,
)

In [None]:
from current_denoising.plotting import training

# Plot the training and validation loss histories returned by train.train_model
fig = training.plot_losses(train_loss, val_loss)

Evaluate the denoising model on different noisy MDTs
----

In [None]:
"""
First show it on some test data
"""

n_test = 4
test_noise = dcgan.generate_tiles(
    generator,
    n_tiles=n_test,
    noise_size=4 * latent_size,
    device="cuda",
)

# Zero-mean the noise, then scale it to the std of the real noise tiles
test_noise = test_noise - test_noise.mean()
test_noise = tiles.std() * test_noise / test_noise.std()

# A looser NaN limit than for the training pairs
test_data = data.get_training_pairs(
    clean_residual,
    strength_map,
    test_noise,
    max_latitude=latitude_threshhold,
    max_nan_fraction=0.8,
    rng=rng,
)

In [None]:
from current_denoising.denoising import inference

# Colour scale defined here, so this cell doesn't depend on a loop variable from an earlier cell
kw = {"cmap": "seismic", "vmin": -0.5, "vmax": 0.5}

fig, axes = plt.subplots(3, 4, figsize=(12, 9))

# get_training_pairs stacks each pair as (clean, noisy) - the dataset cell unpacks it with
# `clean_tiles, noisy_tiles = np.moveaxis(tile_pairs, 1, 0)` - so unpack in that order here
for axs, (clean, noisy) in zip(axes.T, test_data):
    axs[0].imshow(noisy, **kw)
    axs[1].imshow(clean, **kw)

    # Denoise the *noisy* tile; squeeze away singleton dims, as done for the real tiles below
    axs[2].imshow(inference.denoise(net, noisy).squeeze(), **kw)

    for ax in axs:
        ax.set_xticks([])
        ax.set_yticks([])

axes[0, 0].set_ylabel("Noisy (original)")
axes[1, 0].set_ylabel("Clean (target)")
axes[2, 0].set_ylabel("Denoised")
fig.suptitle("Validation Data")

In [None]:
"""
Denoise on some real noisy tiles

These patches were from the testing data - i.e. not in the regions used to train
the model
"""

fig, axes = plt.subplots(3, 4, figsize=(12, 9))

kw = {"vmin": -0.5, "vmax": 0.5, "cmap": "seismic"}

# NOTE(review): clean_test holds the clean MDT residual at the same locations, not a clean
# version of these real noisy tiles - it's a visual reference rather than a true target
for column, noisy, clean in zip(axes.T, real_test, clean_test):
    denoised = inference.denoise(net, noisy).squeeze()

    # Rows: real noisy tile, clean reference, denoised output
    for ax, panel in zip(column, (noisy, clean, denoised)):
        ax.imshow(panel, **kw)
        ax.set_xticks([])
        ax.set_yticks([])

axes[0, 0].set_ylabel("Noisy (original)")
axes[1, 0].set_ylabel("Clean (target)")
axes[2, 0].set_ylabel("Denoised")
fig.suptitle("Noisy Residual")