Train Denoising Model
====

First we'll look at the data used to train the model, then we'll train it.

The model itself is a basic U-Net (with an attention mechanism); we'll train it on pairs of tiles from our clean MDT and noisy MDT residuals.

In [None]:
import pathlib

# Need to mount the SING RDSF dir somewhere
rdsf_dir = pathlib.Path("~/geog_rdsf/").expanduser()
projects_dir = rdsf_dir / "data" / "projects"

# Noisy MDT and clean (CMIP6) MDT inputs, both located under the RDSF projects dir
noisy_filepath = (
    projects_dir
    / "SING"
    / "richard_stuff"
    / "Table2"
    / "dtu18_eigen-6c4_do0280_rr0004.dat"
)
clean_filepath = projects_dir / "dtop" / "cmip6" / "cmip6_historical_mdts_yr5.dat"

# Fail early with a readable message; a bare assert would be stripped under
# `python -O` and would not say which file is missing
for required_path in (noisy_filepath, clean_filepath):
    if not required_path.exists():
        raise FileNotFoundError(
            f"{required_path} does not exist - is the RDSF share mounted at {rdsf_dir}?"
        )

In [None]:
import numpy as np
from current_denoising.generation import ioutils

# Despite the name, read_currents just reads the array from the file
noisy_mdt = ioutils.read_currents(noisy_filepath)

# Entries equal to -1.9e19 are replaced with NaN
# NOTE(review): presumably -1.9e19 is the fill value marking missing data - confirm
is_fill_value = noisy_mdt == -1.9e19
noisy_mdt[is_fill_value] = np.nan

# The metadata file sits alongside the data file: same stem plus "_meta", with a .txt suffix
clean_metadata_filepath = clean_filepath.with_stem(
    clean_filepath.stem + "_meta"
).with_suffix(".txt")

clean_mdt_options = {
    "path": clean_filepath,
    "metadata_path": clean_metadata_filepath,
    "year": 2001,
    "model": "CMCC-CM2-HR4",
    "remove_mean": True,
}
clean_mdt = ioutils.read_clean_mdt(**clean_mdt_options)

In [None]:
"""
Find the residuals
"""

from current_denoising.utils import util

# Passed to util.get_residual for both grids; reported as the smoothing sigma (in km)
# in the title of the residual plot
sigma_km = 200

# Same residual calculation for the noisy and the clean MDT, in that order
noisy_residual, clean_residual = (
    util.get_residual(mdt, sigma_km) for mdt in (noisy_mdt, clean_mdt)
)

In [None]:
import matplotlib.pyplot as plt
from matplotlib.colors import CenteredNorm

from current_denoising.plotting import maps


def mdt_imshow(current_grid: np.ndarray, axis: plt.Axes, norm=None, **kwargs):
    """
    Imshow for MDT-like grids, with the NaN cells overlaid and a colorbar added.

    The image extent is taken from the `extent` keyword if given; otherwise it is
    built from the first/last values returned by `util.lat_long_grid` for the
    grid's shape.

    :param current_grid: the grid to show. NaN cells are covered by an overlay
                         drawn with `maps.clear2black_cmap()` (the land, according
                         to the original comment).
    :param axis: the axis to draw on; the colorbar is added to this axis' figure.
    :param norm: passed straight to `imshow` for the data image.
    :param kwargs: further `imshow` keyword arguments for the data image. These
                   override the defaults `origin="upper"` and `cmap="seismic"`.
    """
    imshow_kw = {"origin": "upper", "cmap": "seismic", **kwargs}

    # Only build the lat/long grid if the caller didn't provide an extent
    if "extent" not in imshow_kw:
        lat, long = util.lat_long_grid(current_grid.shape)
        imshow_kw["extent"] = [long[0], long[-1], lat[0], lat[-1]]

    # Passing extent to imshow is sufficient - no need to set it again afterwards
    im = axis.imshow(current_grid, norm=norm, **imshow_kw)

    # Plot also the land
    # Use the same origin and extent as the data so that the overlay
    # stays aligned with it if the caller overrides either of them
    axis.imshow(
        np.isnan(current_grid),
        cmap=maps.clear2black_cmap(),
        extent=imshow_kw["extent"],
        vmin=0,
        vmax=1,
        origin=imshow_kw["origin"],
    )

    # Add a colorbar
    axis.get_figure().colorbar(im, ax=axis)


fig, axes = plt.subplots(2, 1, figsize=(18, 18))

# Noisy residual on top, clean residual below, each with its own
# zero-centred colour scale of half-range 0.5
mdt_imshow(noisy_residual, axes[0], norm=CenteredNorm(vcenter=0.0, halfrange=0.5))
mdt_imshow(clean_residual, axes[1], norm=CenteredNorm(vcenter=0.0, halfrange=0.5))

axes[0].set_ylabel("Noisy")
axes[1].set_ylabel("Clean")

# Raw f-string: "\s" is not a valid escape sequence in a normal string literal
# (it triggers a warning); the raw form yields the same text
fig.suptitle(rf"Residual: smoothed $\sigma=${sigma_km}km")

The signal is readily apparent in the clean (lower) plot. Important regions include the large currents near the East coast of the USA, around Japan, around Australia and around Madagascar.

We will not train our denoiser directly on these samples, since this will leave us with no statistically independent testing data.
We will instead train a GAN model to generate tiles of noise (see [the other notebook](./2_train_gan.ipynb)), apply these noise tiles
to tiles of our clean residual and train the denoiser on this.

It may be apparent from the noisy (upper) plot above that the noise strength is not uniform everywhere - it roughly scales with the
size of the signal, and also depends on other geographic features such as proximity to the coast.
To model this, we will apply a noise strength factor to our synthetic noise tiles - in the first instance, this is intended to be just the
distance from the coast (such that regions nearer the coast have more noise applied).
For now, the cell below uses a uniform placeholder map.

In [None]:
"""
TODO - something reasonable here
"""

# Placeholder: a map of ones with the same shape as the noisy residual
strength_map = np.ones_like(noisy_residual)

# Show the map with a colourbar, a title and no axis decorations
strength_image = plt.imshow(strength_map)
plt.colorbar(strength_image)
plt.title("Noise strength map")
plt.axis("off")

plt.tight_layout()

We will extract tiles from the clean residual and apply GAN generated noise to them:

In [None]:
"""
Plot some example tiles: clean, synthetic (GAN) noise, synthetic noisy and real noisy
"""

import warnings

import torch

from current_denoising.generation import dcgan
from current_denoising.denoising import data

# We need the noise tiles (that were used for GAN training)
# so that we can rescale the generated noise to the right
# mean/std
tile_size = 32
latitude_threshhold = 60.0
forbidden_mask = ioutils.distance_from_land(noisy_residual) < 20

# Only the tiles are used from this call; the tile indices used for plotting
# come from get_training_pairs below
tiles, _ = ioutils.extract_tiles(
    noisy_residual,
    forbidden_mask=forbidden_mask,
    tile_criterion=ioutils.select_tile,
    max_latitude=latitude_threshhold,
    tile_size=tile_size,
    return_indices=True,
)

latent_channels = 32
latent_size = 4
n_gen = 4

# Load the trained generator weights and move the generator to the GPU
generator = dcgan.Generator(tile_size, latent_channels, latent_size)
generator.load_state_dict(
    torch.load(pathlib.Path("outputs/mdt_gan/test/") / "generator_final.pth")
)
_ = generator.to("cuda")

# Generate some GAN tiles
# By passing a larger latent size to the generator, we get
# larger output - this also tells us what size tiles to extract.
gen_tiles = dcgan.generate_tiles(
    generator,
    n_tiles=n_gen,
    noise_size=4 * latent_size,
    device="cuda",
)
# Standardise the generated tiles, then rescale them to the mean/std of the
# real noise tiles (the small constant guards against division by zero)
gen_tiles = tiles.mean() + tiles.std() * (gen_tiles - gen_tiles.mean()) / (
    gen_tiles.std() + 1e-8
)

# Apply noise
max_nan_fraction = 0.3
rng = np.random.default_rng()

# We will get warned about not having enough tiles; we don't care
# so suppress this
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    tile_pairs, indices = data.get_training_pairs(
        clean_residual,
        strength_map,
        gen_tiles,
        max_latitude=latitude_threshhold,
        max_nan_fraction=max_nan_fraction,
        rng=rng,
        return_indices=True,
    )

# Get the real noisy tiles out at the right positions, to compare
real_tiles = util.tiles_from_indices(noisy_residual, indices, gen_tiles.shape[1])

# One column per generated tile; squeeze=False keeps `axes` 2d even if n_gen is 1
fig, axes = plt.subplots(4, n_gen, figsize=(3 * n_gen, 12), squeeze=False)
kw = {"origin": "upper", "cmap": "seismic", "vmin": -0.5, "vmax": 0.5}
for axs, clean, gen_tile, noisy, real in zip(
    axes.T, tile_pairs[:, 0], gen_tiles, tile_pairs[:, 1], real_tiles
):
    axs[0].imshow(clean, **kw)
    axs[1].imshow(gen_tile, **kw)
    axs[2].imshow(noisy, **kw)
    axs[3].imshow(real, **kw)

    for axis in axs:
        axis.set_xticks([])
        axis.set_yticks([])

axes[0, 0].set_ylabel("Clean")
axes[1, 0].set_ylabel("Synthetic Noise")
axes[2, 0].set_ylabel("Synthetic Noisy")
axes[3, 0].set_ylabel("Real Noisy")

Now that we've seen some examples of the training data and how it's built, we'll make our actual dataloaders.

In [None]:
"""
Generate some training data + perform train/test split
"""

from sklearn.model_selection import train_test_split

n_gen = 196  # Needs to be > the number of tiles
batch_size = 4  # Should probably be small - we drop the last batch each iteration

# Generate a fresh set of GAN noise tiles, using the same enlarged latent size as before
gen_tiles = dcgan.generate_tiles(
    generator, n_tiles=n_gen, noise_size=4 * latent_size, device="cuda"
)

# Standardise the generated tiles, then rescale to the mean/std of the real noise tiles
centred_noise = gen_tiles - gen_tiles.mean()
noise_spread = gen_tiles.std() + 1e-8
gen_tiles = tiles.mean() + tiles.std() * centred_noise / noise_spread

# Build the training pairs; axis 1 of tile_pairs holds (clean, noisy)
pair_options = {
    "max_latitude": latitude_threshhold,
    "max_nan_fraction": max_nan_fraction,
    "rng": rng,
    "return_indices": True,
}
tile_pairs, indices = data.get_training_pairs(
    clean_residual, strength_map, gen_tiles, **pair_options
)
clean_tiles, noisy_tiles = np.moveaxis(tile_pairs, 1, 0)

# Real noisy tiles taken at the same positions, kept for comparison
real_tiles = util.tiles_from_indices(noisy_residual, indices, gen_tiles.shape[1])

# One shared 80/20 split across all four arrays; the result is ordered as
# (train, test) for each input in turn
split_arrays = train_test_split(
    clean_tiles, noisy_tiles, real_tiles, indices, train_size=0.8
)
clean_train, clean_test = split_arrays[0:2]
noisy_train, noisy_test = split_arrays[2:4]
real_train, real_test = split_arrays[4:6]
idx_train, idx_test = split_arrays[6:8]

# Loaders over the (clean, synthetic noisy) pairs
train_config = data.DataConfig(train=True, batch_size=batch_size, num_workers=4)
test_config = data.DataConfig(train=False, batch_size=batch_size, num_workers=0)

train_loader = data.dataloader(clean_train, noisy_train, train_config)
test_loader = data.dataloader(clean_test, noisy_test, test_config)

In [None]:
"""
Plot some tiles - the training data has augmentations applied
"""

fig, axes = plt.subplots(2, 4, figsize=(8, 4))

# Settings for the NaN overlay. vmin/vmax are fixed (as in mdt_imshow) so that
# a mask containing a single value is still coloured consistently
nan_overlay_kw = {"cmap": maps.clear2black_cmap(), "vmin": 0, "vmax": 1}

# One column per sample in the batch: clean on the top row, noisy on the bottom
batch = next(iter(train_loader))
for ax_col, clean, noisy in zip(axes.T, batch[0], batch[1], strict=True):
    for ax, tile in zip(ax_col, (clean, noisy)):
        # Data first, then the NaN overlay on top - same order on both rows
        ax.imshow(tile.squeeze(), **kw)
        ax.imshow(torch.isnan(tile).squeeze(), **nan_overlay_kw)

        ax.set_xticklabels([])
        ax.set_yticklabels([])

fig.suptitle("Training Data")
fig.tight_layout()

In [None]:
"""
Plot validation data (which doesn't have augmentations applied)
"""
fig, axes = plt.subplots(2, 4, figsize=(8, 4))

# First batch from the test loader: element 0 holds the clean tiles, element 1 the noisy ones
batch = next(iter(test_loader))
clean_batch, noisy_batch = batch[0], batch[1]

# One column of axes per sample - clean on top, noisy underneath
for column, clean, noisy in zip(axes.T, clean_batch, noisy_batch, strict=True):
    clean_ax, noisy_ax = column[0], column[1]

    clean_ax.imshow(clean.squeeze(), **kw)
    noisy_ax.imshow(noisy.squeeze(), **kw)

    # Hide the tick labels on both axes
    for ax in column:
        ax.set_xticklabels([])
        ax.set_yticklabels([])

fig.suptitle("Validation Data")
fig.tight_layout()

In [None]:
from current_denoising.denoising import model, train

# Build the attention U-Net
# NOTE(review): the meaning of the two positional arguments (4, 0.1) is not
# visible from here - confirm against model.get_attention_unet
net = model.get_attention_unet(4, 0.1)

# Train on the GPU for 250 epochs, validating against the test loader.
# train_model hands back the network along with the training and validation losses
training_result = train.train_model(
    net,
    "cuda",
    n_epochs=250,
    train_data=train_loader,
    val_data=test_loader,
)
net, train_loss, val_loss = training_result

In [None]:
from current_denoising.plotting import training

# Plot the training and validation losses returned by train.train_model
# NOTE(review): presumably one loss value per epoch - confirm against train.train_model
fig = training.plot_losses(train_loss, val_loss)

In [None]:
"""
Save the denoiser
"""
output_dir = pathlib.Path("outputs/mdt_gan/test/")

# Make sure the directory exists, so that a finished training run
# isn't lost to a missing output folder
output_dir.mkdir(parents=True, exist_ok=True)

# Only the weights (state dict) are saved, not the whole model object
torch.save(net.state_dict(), output_dir / "denoiser.pth")

In [None]:
"""
Denoise on some real noisy tiles

These patches were from the testing data - i.e. not in the regions used to train
the model
"""
from current_denoising.denoising import inference

fig, axes = plt.subplots(3, 4, figsize=(12, 9))

# Shared colour scale for every panel
kw = {"vmin": -0.5, "vmax": 0.5, "cmap": "seismic"}

# One column per tile: real noisy input, clean target, then the network's output
for column, noisy, clean in zip(axes.T, real_test, clean_test):
    noisy_ax, clean_ax, denoised_ax = column

    noisy_ax.imshow(noisy, **kw)
    clean_ax.imshow(clean, **kw)

    # Run the trained network on the real noisy tile
    denoised = inference.denoise(net, noisy).squeeze()
    denoised_ax.imshow(denoised, **kw)

    # No ticks on any panel
    for ax in column:
        ax.set_xticks([])
        ax.set_yticks([])

# Label each row on its left-most panel
row_labels = ("Noisy (original)", "Clean (target)", "Denoised")
for row_ax, row_label in zip(axes[:, 0], row_labels):
    row_ax.set_ylabel(row_label)

fig.suptitle("Noisy Residual")