GAN Training Data
====

An illustration of where the training data for the GAN comes from, and what we do with the generated patches

Generation
----
We generate training data tiles by randomly sampling them from the noisy current gridded field.
Then we train a GAN to generate new tiles (not shown here), which we can stitch together and apply to clean data to train our denoising model.

In [None]:
import pathlib

# This is the "real" file that Laura made plots of in her paper doi:10.1017/eds.2023.41
filepath = pathlib.Path(
    "/home/mh19137/geog_rdsf/data/projects/SING/richard_stuff/Table2/currents/dtu18_eigen-6c4_do0280_rr0004_cs.dat"
)
# Fail here, before any reading or plotting, if the data file is not available at this path
assert filepath.exists()

In [None]:
from current_denoising.generation import ioutils

# Load the current field from the file; the cells below treat it as a 2D grid
# (it is passed to imshow and its shape is used to build the lat/long grid)
data = ioutils.read_currents(filepath)

In [None]:
import matplotlib.pyplot as plt

from current_denoising.plotting import maps

# Single wide axis for the whole gridded field
fig, axis = plt.subplots(1, 1, figsize=(12, 6))

# This isn't a great colourmap but it makes the plot match up with Figure 3 in Laura's paper
# NOTE: imshow_kw is reused by later cells so that the tile plots share this colour scale
imshow_kw = {"origin": "upper", "vmin": 0, "vmax": 1.4, "cmap": "turbo"}
im = axis.imshow(data, **imshow_kw)

axis.set_title(filepath.stem)
fig.colorbar(im)

# Relabel the image from array indices to coordinates; lat and long are reused by later cells
# (presumably 1D arrays of grid-point latitudes/longitudes - TODO confirm against maps.lat_long_grid)
lat, long = maps.lat_long_grid(data.shape)
extent = [long[0], long[-1], lat[0], lat[-1]]
im.set_extent(extent)

fig.tight_layout()

In [None]:
"""Extract and plot some tiles"""
import numpy as np
from current_denoising.generation import ioutils

# Seeded generator; it is shared by every extract_tiles call in this notebook
rng = np.random.default_rng(1234)

# Side length of each square tile, in grid points
tile_size = 32
# Infinite thresholds: presumably this disables the RMS and latitude cuts, so tiles
# can come from anywhere in the field - TODO confirm against ioutils.extract_tiles
tiles, indices = ioutils.extract_tiles(
    rng,
    data,
    num_tiles=16,
    max_rms=np.inf,
    max_latitude=np.inf,
    tile_size=tile_size,
    return_indices=True,  # the indices are unpacked as (y, x) when positioning the tiles
)

In [None]:
def plot_tiles(tiles, indices) -> plt.Figure:
    """
    Show tiles on a 4x4 grid of axes, each labelled with the coordinates it was cut from

    Relies on the notebook-level `imshow_kw`, `lat`, `long` and `tile_size`.

    :param tiles: tiles to draw; at most 16 are shown, one per axis
    :param indices: one (y, x) pair per tile, used to index `lat` and `long`

    :returns: the figure holding the axes
    """
    figure, axes_grid = plt.subplots(4, 4, figsize=(9, 9))

    for panel, tile, (row, col) in zip(axes_grid.flat, tiles, indices):
        image = panel.imshow(tile, **imshow_kw)

        # The tile spans tile_size grid points in each direction from its (row, col) index
        tile_extent = [
            long[col],
            long[col + tile_size],
            lat[row],
            lat[row + tile_size],
        ]
        image.set_extent(tile_extent)

    return figure


# Draw the tiles sampled so far (16 were requested, one per axis of the 4x4 grid)
fig = plot_tiles(tiles, indices)
fig.suptitle("Example patches")


### Choosing the right patches
We don't want to take out patches from regions too near the poles (the gridded field is distorted here); neither do we want to take patches which contain land or high-current information (e.g. around the islands in the Caribbean).

To see the distributions of RMS values in the data, see [here](./plot_tile_rms.ipynb).
We can exclude the regions of high RMS, and sample our tiles again:


In [None]:
# Tiles rejecting the ones with high latitude or RMS
# This is the 50th percentile in the dtu18_eigen-6c4_do0280_rr0004_cs data
rms_threshold = 0.20
# The same value is drawn as the dashed lines on the world map, so pass the
# variable (not a repeated literal) to extract_tiles to keep the two in step
latitude_threshold = 64.0
tiles, indices = ioutils.extract_tiles(
    rng,
    data,
    num_tiles=1024,
    max_rms=rms_threshold,
    max_latitude=latitude_threshold,
    tile_size=tile_size,
    return_indices=True,  # the indices are needed to paste the tiles back onto the map
)

In [None]:
# Plot them on a world map
from matplotlib import colors


def clear2black_cmap() -> colors.Colormap:
    """
    Two-entry colormap: fully transparent white, then fully opaque black
    """
    transparent_white = colors.to_rgba("white", alpha=0)
    opaque_black = colors.to_rgba("black", alpha=1)

    return colors.ListedColormap(
        [transparent_white, opaque_black],
        name="clear2black",
    )


# Start from an all-NaN grid the same shape as the data; only the sampled tiles get filled in
tile_grid = np.ones_like(data) * np.nan

# Paste each tile back at the (y, x) index it came from; where tiles overlap, the later one wins
for tile, (y, x) in zip(tiles, indices):
    tile_grid[y : y + tile_size, x : x + tile_size] = tile

fig, axis = plt.subplots(1, 1, figsize=(12, 6))
# Background layer: boolean mask of the non-NaN cells of the data, drawn with the clear/black colormap
# (presumably NaN marks land, so the ocean shows as black - TODO confirm)
axis.imshow(
    ~np.isnan(data),
    origin="upper",
    cmap=clear2black_cmap(),
    extent=(long[0], long[-1], lat[0], lat[-1]),
)
# Foreground layer: the tiles, on the same colour scale as the earlier plots
im = axis.imshow(tile_grid, **imshow_kw)

# Give the tile layer the same coordinate extent as the background layer
im.set_extent([long[0], long[-1], lat[0], lat[-1]])


# Mark the latitude cut in both hemispheres
for t in (latitude_threshold, -latitude_threshold):
    axis.axhline(t, color="r", linestyle="--")
# Only the southern line is labelled; x=-198 is a hand-picked position for the text
axis.text(-198, -latitude_threshold, f"-{latitude_threshold}" + r"$\degree$", color="r")

fig.colorbar(im, ax=axis)
fig.suptitle(f"Extracted patches; patch RMS {rms_threshold=:.2f}m/s")
fig.tight_layout()

These are the tiles that are used to train the GAN.

I can't demonstrate this for now, since all the GPU servers are down, but we would train a GAN on these tiles and generate new ones. Let's pretend we have done this...

Application to data
----
Once we have generated noise tiles, we will want to apply them to the data.

The first step will be to quilt several noise tiles into a larger region - let's take a "batch" of 24 8x8 degree noise tiles and stitch them together to make a 32x32 degree quilt...

In [None]:
from current_denoising.generation import quilting

# Draw a fresh pool of tiles for quilting, with the same RMS and latitude cuts as
# the GAN training tiles. The latitude cut uses the shared latitude_threshold
# variable rather than a repeated literal, so the cuts cannot drift apart.
# No indices are requested here; the result is sliced as a (tile, y, x) array below.
tiles = ioutils.extract_tiles(
    rng,
    data,
    num_tiles=1024,
    max_rms=rms_threshold,
    max_latitude=latitude_threshold,
    tile_size=tile_size,
)

First let's see how it looks if we just stick the tiles together naively:

In [None]:
# Arrange the first 16 tiles as a 4x4 grid of (tile_size x tile_size) patches
naive_tiles = tiles[:16, :, :].reshape(4, 4, tile_size, tile_size)
# No overlap between patches: 4 tiles of 32 grid points per side fill the 128x128 target
naive_quilt = quilting.naive_quilt(naive_tiles, patch_overlap=0, target_size=(128, 128))


def plot_quilt(quilt, *, axis=None):
    """
    Show a quilt as an image with no axis ticks

    The colour scale runs from 0 to 5 * `rms_threshold` (a notebook-level variable).

    :param quilt: image-like array to display
    :param axis: axis to draw on. If None (the default), a new 6x6 inch figure
                 is created and drawn on instead.
    """
    if axis is None:
        _, axis = plt.subplots(figsize=(6, 6))

    axis.imshow(quilt, cmap="turbo", vmin=0, vmax=5 * rms_threshold)

    # Tick values carry no meaning for a stitched quilt, so hide them
    axis.set_xticks([])
    axis.set_yticks([])


# Square figure for the single naively-stitched quilt
fig, axis = plt.subplots(figsize=(6, 6))
plot_quilt(naive_quilt, axis=axis)
fig.suptitle("Naive stitching - sharp boundaries between tiles")
fig.tight_layout()

Then we can do a slightly better "quilt" by matching up patches so that similar edges appear together, and by stitching them using an optimal seam through their overlap region.

For comparison, we'll also plot some "real" quilts - patches of noise from the real data that have the same size as our stitched-together quilts.

In [None]:
# Close the figure left over from the previous cell before opening a new one
plt.close(fig)
# Top row: quilts stitched from tiles; bottom row: real noise of the same size
fig, axes = plt.subplots(2, 3, figsize=(12, 8))

target_size = 128

# Real patches cut directly at the full quilt size, with the same RMS and latitude
# cuts as the tiles (the shared latitude_threshold variable, not a repeated literal)
real_quilts = ioutils.extract_tiles(
    rng,
    data,
    num_tiles=3,
    max_rms=rms_threshold,
    max_latitude=latitude_threshold,
    tile_size=target_size,
)

# axes.T yields one (top, bottom) pair of axes per column
for real_quilt, (ax1, ax2) in zip(real_quilts, axes.T):
    # Randomise the order of tiles, so that we get a different one in the top corner
    rng.shuffle(tiles)
    quilt = quilting.quilt(
        tiles,
        target_size=(target_size, target_size),
        patch_overlap=4,
        allow_rotation=False,
        repeat_penalty=0,
    )
    plot_quilt(quilt, axis=ax1)

    plot_quilt(real_quilt, axis=ax2)


# Label each row once, on its left-most axis
axes[0, 0].set_ylabel("Quilted patches")
axes[1, 0].set_ylabel("Real noise")

fig.suptitle("Quilted patches vs real noise")
fig.tight_layout()

### Adding to the signal

### Weighting