In [None]:
%load_ext autoreload
%autoreload 2

GAN Training
====
An illustration of the GAN training data/loop

We'll take some patches of pure noise (from the deep ocean, with low RMS; we expect these tiles to be pure noise because the signal varies slowly)
and train a GAN to reproduce them.

In [None]:
"""
Take some tiles from our noisy data and display them, to illustrate what's going on
"""

import pathlib
import numpy as np
from current_denoising.plotting import maps
from current_denoising.generation import ioutils


# This is the "real" file that Laura made plots of in her paper doi:10.1017/eds.2023.41
filepath = pathlib.Path(
    "/home/mh19137/geog_rdsf/data/projects/SING/richard_stuff/Table2/currents/dtu18_eigen-6c4_do0280_rr0004_cs.dat"
)
# Fail early, and say which file is missing, rather than erroring inside the reader
assert filepath.exists(), f"Currents file not found: {filepath}"
data = ioutils.read_currents(filepath)

# Extract some tiles, rejecting the ones with high latitude or RMS.
# An RMS of 0.20 is the 50th percentile in the dtu18_eigen-6c4_do0280_rr0004_cs data,
# and we use the same latitude threshold as Laura did.
# NOTE(review): the RMS threshold used below (0.18) is slightly stricter than that 0.20 median
tile_size = 32  # in grid points
rms_threshold = 0.18
latitude_threshold = 64.0
rng = np.random.default_rng(1234)  # fixed seed so the same tiles are drawn on every run

tiles, indices = ioutils.extract_tiles(
    rng,
    data,
    num_tiles=512,
    max_rms=rms_threshold,
    # Use the named threshold rather than repeating the literal, so the two cannot drift apart
    max_latitude=latitude_threshold,
    tile_size=tile_size,
    return_indices=True,
)

In [None]:
import matplotlib.pyplot as plt

# Start from an all-NaN canvas with the same shape as the full map,
# so that only the extracted tiles end up holding values
tile_grid = np.nan * np.ones_like(data)

# Write each tile into the canvas at its (y, x) index
for (row, col), patch in zip(indices, tiles):
    tile_grid[row : row + tile_size, col : col + tile_size] = patch

fig, axis = plt.subplots(figsize=(10, 5))
maps.imshow(tile_grid, axis=axis)

# Overlay the locations where the input data is NaN, using the clear-to-black colormap
nan_overlay = 1.4 * np.isnan(data)
axis.imshow(
    nan_overlay,
    cmap=maps.clear2black_cmap(),
    extent=[-180, 180, -90, 90],
    vmin=0,
    vmax=1.4,
)

In [None]:
"""
Turn our images into a dataloader with the right transforms
"""

from current_denoising.generation import dcgan

# Wrap the extracted tiles in the project's TileLoader; this object is what gets
# handed to torch's DataLoader when the training config is built
dataset = dcgan.TileLoader(tiles)

In [None]:
import pathlib
import torch

batch_size = 64

# Everything the training function and the plotting helpers read lives in this one dict
config = dict(
    n_epochs=100,
    n_critic=5,
    lambda_gp=10,
    learning_rate=5e-5,
    d_g_lr_ratio=4,
    latent_dim=64,
    img_size=tile_size,
    channels=1,
    batch_size=batch_size,
    dataloader=torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True, num_workers=8
    ),
    output_dir=pathlib.Path("outputs/gan/"),
    plot_interval=10,
)

# exist_ok=True already makes this a no-op when the directory is there
config["output_dir"].mkdir(parents=True, exist_ok=True)

In [None]:
"""
We might want to train lots of slightly different models, so write one big function for the training + monitoring plots

"""

import matplotlib.cm as cm
import matplotlib.colors as colors
from torch.autograd import Variable

from current_denoising.generation import dcgan
from current_denoising.plotting import training, img_validation


def _fid_plot(gen_loss, disc_loss, fid_scores, config, plot_dir, fid_interval=20):
    """
    Plot the training losses next to the FID score history and save the figure
    to `plot_dir / "fid.png"`.

    :param gen_loss: generator losses, passed straight to `training.plot_losses`
    :param disc_loss: discriminator losses, passed straight to `training.plot_losses`
    :param fid_scores: sequence of FID scores, plotted in order
    :param config: must contain "learning_rate" and "d_g_lr_ratio" (used for the figure title)
    :param plot_dir: directory the figure is saved into (pathlib.Path; must already exist)
    :param fid_interval: x-axis spacing between consecutive FID scores. Defaults to 20,
                         the value that was previously hard-coded here.
                         NOTE(review): presumably the number of epochs between FID
                         evaluations in dcgan.train - confirm

    """
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    training.plot_losses(
        gen_loss,
        disc_loss,
        labels=("Generator Loss", "Discriminator Loss"),
        axis=axes[0],
    )
    axes[0].set_title("Losses")

    axes[1].plot([fid_interval * i for i in range(len(fid_scores))], fid_scores)
    axes[1].set_title("fid_score")

    # Label the figure with the hyperparameters being studied; this title used to be
    # built but never attached to anything
    fig.suptitle(
        f"lr={config['learning_rate']}, d_g_lr_ratio={config['d_g_lr_ratio']}"
    )
    fig.savefig(plot_dir / "fid.png")


def _gp_plot(gps, w_dists, config, plot_dir):
    """
    Plot the scaled gradient penalty and the Wasserstein distance (top panel) and
    their ratio (bottom panel), then save the figure to `plot_dir / "gp_wd.png"`.

    Both inputs are converted to arrays and reduced along axis 1 (mean line plus a
    min-max band), so they must be 2-D and share the same length along axis 0.
    NOTE(review): presumably one row per epoch and one column per critic update -
    confirm against dcgan.train

    :param gps: gradient penalty values; multiplied by config["lambda_gp"] before plotting
    :param w_dists: Wasserstein distance values
    :param config: must contain "lambda_gp"
    :param plot_dir: directory the figure is saved into (pathlib.Path; must already exist)

    """
    fig, axes = plt.subplots(2, 1, figsize=(15, 5))

    gp_arr = config["lambda_gp"] * np.array(gps)
    w_dist_arr = np.array(w_dists)

    x = np.arange(len(w_dists))
    gp_mean = gp_arr.mean(axis=1)
    w_dist_mean = w_dist_arr.mean(axis=1)

    axes[0].plot(x, gp_mean, label="Gradient Penalty")
    axes[0].fill_between(x, gp_arr.min(axis=1), gp_arr.max(axis=1), alpha=0.2)

    axes[0].plot(x, w_dist_mean, label="Wasserstein Distance")
    axes[0].fill_between(x, w_dist_arr.min(axis=1), w_dist_arr.max(axis=1), alpha=0.2)

    # The tuning criterion is lambda * GP / |W|, so divide by the magnitude of the
    # Wasserstein distance: a negative W would otherwise flip the sign of the ratio
    # and push it outside the (positive) reference band drawn below
    axes[1].plot(x, gp_mean / np.abs(w_dist_mean), color="C2")
    axes[1].axhline(0.1, color="k", linestyle="dashed")
    axes[1].axhline(0.6, color="k", linestyle="dashed")

    axes[1].set_title("Ratio; high -> GP dominates, low -> WD dominates")
    axes[0].legend()
    fig.tight_layout()
    fig.savefig(plot_dir / "gp_wd.png")


def _grad_plot(g_grads, d_grads, config, plot_dir):
    """
    Plot the generator and discriminator gradients (top panel) and their ratio
    scaled by config["d_g_lr_ratio"] (bottom panel), then save the figure to
    `plot_dir / "grads.png"`.

    :param g_grads: sequence of generator gradient values
    :param d_grads: sequence of discriminator gradient values, same length as g_grads
    :param config: must contain "d_g_lr_ratio"
    :param plot_dir: directory the figure is saved into (pathlib.Path; must already exist)

    """
    fig, (top, bottom) = plt.subplots(2, 1, figsize=(15, 5))

    top.plot(g_grads, label="Generator Gradients", color="C0")
    top.plot(d_grads, label="Discriminator Gradients", color="C1")

    scaled_ratio = config["d_g_lr_ratio"] * np.array(g_grads) / d_grads
    bottom.plot(scaled_ratio, color="C2", label="ratio")

    # Dashed reference band for the ratio
    for level in (0.85, 1.2):
        bottom.axhline(level, color="k", linestyle="dashed")

    top.legend()
    bottom.legend()

    fig.tight_layout()
    fig.savefig(plot_dir / "grads.png")


def _grad_norm_plot(grad_norms, config, plot_dir):
    """
    Plot the mean gradient norm with a min-max band and save the figure to
    `plot_dir / "grad_norm.png"`.

    The input is reduced along axis 1, so it must be 2-D.

    :param grad_norms: 2-D sequence of gradient norms
    :param config: unused; kept so all the plotting helpers share a signature
    :param plot_dir: directory the figure is saved into (pathlib.Path; must already exist)

    """
    norms = np.array(grad_norms)
    steps = np.arange(len(norms))

    fig, axis = plt.subplots()
    axis.plot(steps, norms.mean(axis=1), color="C0")
    axis.fill_between(
        steps,
        norms.min(axis=1),
        norms.max(axis=1),
        alpha=0.2,
        color="C0",
    )

    # Dashed reference band either side of a gradient norm of 1
    for level in (0.9, 1.1):
        axis.axhline(level, color="k", linestyle="dashed")

    axis.set_title("Gradient Norm")
    fig.tight_layout()
    fig.savefig(plot_dir / "grad_norm.png")


def train_gan(config: dict, plot_dir: pathlib.Path):
    """
    Train the GAN and make lots of debug plots

    Builds a fresh generator and discriminator from `config`, trains them with
    `dcgan.train`, then writes the following figures into `plot_dir`:
    fid.png, gp_wd.png, grads.png, grad_norm.png, generated.png, hists.png, ffts.png

    :param config: training configuration. This function itself reads "learning_rate",
                   "d_g_lr_ratio", "batch_size", "latent_dim" and "dataloader"; the
                   whole dict is also passed on to the dcgan constructors, dcgan.train
                   and the plotting helpers.
    :param plot_dir: directory the figures are saved into; must already exist

    """
    generator = dcgan.Generator(config)
    discriminator = dcgan.Discriminator(config)

    # Train the GAN
    (
        generator,
        discriminator,
        gen_loss,
        disc_loss,
        fid_scores,
        w_dists,
        gps,
        g_grads,
        d_grads,
        grad_norms,
    ) = dcgan.train(generator, discriminator, config)

    # Shared label for the figures below, so plots from different runs can be told apart
    title = f"lr={config['learning_rate']}, d_g_lr_ratio={config['d_g_lr_ratio']}"

    # Plot training losses + FID
    _fid_plot(gen_loss, disc_loss, fid_scores, config, plot_dir)

    # Plot contributions of gradient penalty and Wasserstein distance to discriminator loss
    _gp_plot(gps, w_dists, config, plot_dir)

    # Plot gradients
    _grad_plot(g_grads, d_grads, config, plot_dir)

    # Plot grad norm
    _grad_norm_plot(grad_norms, config, plot_dir)

    # Generate some images and display them.
    # Put the latent vectors on whichever device the trained generator lives on,
    # instead of hard-coding CUDA
    # (assumes the generator is a torch.nn.Module with at least one parameter)
    device = next(generator.parameters()).device
    z_g = torch.as_tensor(
        np.random.normal(0, 1, (config["batch_size"], config["latent_dim"])),
        dtype=torch.float32,
        device=device,
    )
    # Only the images are needed here, so don't build an autograd graph
    with torch.no_grad():
        gen_imgs = generator(z_g)

    fig = img_validation.show(gen_imgs, cmap="turbo")
    mappable = cm.ScalarMappable(
        norm=colors.Normalize(vmin=0.0, vmax=1.4), cmap="turbo"
    )
    mappable.set_array([])
    fig.colorbar(mappable, ax=fig.axes)
    fig.suptitle(f"Generated images - {title}")
    fig.savefig(plot_dir / "generated.png")

    # Fetch one batch of real images and reuse it for both comparisons below, so the
    # dataloader workers are only started once
    real_imgs = next(iter(config["dataloader"]))

    # Plot histograms of the generated and real images
    fig, axis = plt.subplots(1, 1, figsize=(8, 5))
    hist_kw = {
        "bins": np.linspace(-1, 1, 150),
        "density": True,
        "alpha": 0.5,
        "histtype": "step",
        "linewidth": 2,
    }
    img_validation.hist(gen_imgs, axis=axis, **hist_kw, label="Generated images")
    img_validation.hist(
        real_imgs,
        axis=axis,
        **hist_kw,
        label="Real images",
        linestyle="dashed",
    )
    axis.set_title("Image Hists - " + title)
    axis.legend()
    fig.tight_layout()
    fig.savefig(plot_dir / "hists.png")

    # Plot FFTs of the generated and real images
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))

    img_validation.fft(gen_imgs, axis=axes[0])
    fig.suptitle(title)
    axes[0].set_title("Generated images FFT")

    img_validation.fft(real_imgs, axis=axes[1])
    axes[1].set_title("Real images FFT")
    fig.tight_layout()
    fig.savefig(plot_dir / "ffts.png")

In [None]:
# Override the epoch count set in the config cell (100) with a shorter run, then train
# once and write the debug plots to the configured output directory.
# Note this changes the shared config dict in place, so later cells see n_epochs=40
# until they set it again.
config["n_epochs"] = 40
train_gan(config, config["output_dir"])

Tuning Learning Rate
----
Here, we're looking for:
 - FID drops early (first 15 epochs)
 - Gradient norm near 1
 - $\lambda$ GP / |W| mostly between 0.1 and 0.6
 - Wasserstein distance peaks at first, then small positive plateau (without oscillations)

In [None]:
%%capture
"""
Train a lot of different GANs with different learning rates
"""

config["n_epochs"] = 41

# Learning rates in the sweep; 2e-4, 1e-4 and 5e-5 are currently left out
lrs = [1e-3, 7e-4, 5e-4, 3e-4]

for lr in lrs:
    # Each learning rate gets its own output directory, so the plots don't overwrite each other
    config.update(
        learning_rate=lr,
        output_dir=pathlib.Path(f"outputs/gan_lr_study/lr_{lr:.0e}/"),
    )
    config["output_dir"].mkdir(parents=True, exist_ok=True)
    train_gan(config, config["output_dir"])

Balance Pass
----
We now want to tune the ratio of the critic/generator's learning rates.

Here we will look for:
 - Effective step size (grads.png) ratio near 1 (roughly - between 0.3 and 2.0)
 - FID still decreasing
 - W still peaks then decreases, without too much noise
 - $\lambda$ GP/|W| still between 0.1 and 0.6

In [None]:
%%capture
"""
Now that I have an LR to start with, train a few GANs with different d_g_lr_ratio
"""
config["learning_rate"] = 0.01

for ratio in (0.5, 1, 2, 4, 8, 16):
    # Each ratio gets its own output directory, so the plots don't overwrite each other
    config.update(
        d_g_lr_ratio=ratio,
        output_dir=pathlib.Path(f"outputs/gan_dg_ratio_study/ratio_{ratio:.1f}/"),
    )
    config["output_dir"].mkdir(parents=True, exist_ok=True)
    train_gan(config, config["output_dir"])

In [None]:
"""
Now that we have a sensible d_g_lr_ratio, train a longer run
"""
# Settings for the long final run
config["learning_rate"] = 0.005
config["d_g_lr_ratio"] = 9
config["n_epochs"] = 800
config["output_dir"] = pathlib.Path("outputs/gan_final/")

# The debug plots are written with fig.savefig, which does not create missing
# directories - make sure the output directory exists before the long training run,
# rather than failing at the first plot after it has finished
config["output_dir"].mkdir(parents=True, exist_ok=True)

train_gan(config, config["output_dir"])