In [None]:
# Automatically reload edited modules (e.g. current_denoising) before each cell runs
%load_ext autoreload
%autoreload 2

GAN Training
====
An illustration of the GAN training data/loop

We'll take some patches of pure noise (from the deep ocean, where the RMS is low — we expect these tiles to be pure noise, since the signal there is slowly varying)
and train a GAN to reproduce them.

In [None]:
"""
Take some tiles from our noisy data and display them, to illustrate what's going on
"""

import pathlib
import numpy as np
from current_denoising.plotting import maps
from current_denoising.generation import ioutils


# This is the "real" file that Laura made plots of in her paper doi:10.1017/eds.2023.41
# NOTE(review): absolute, user-specific path - consider a configurable DATA_DIR before sharing
filepath = pathlib.Path(
    "/home/mh19137/geog_rdsf/data/projects/SING/richard_stuff/Table2/currents/dtu18_eigen-6c4_do0280_rr0004_cs.dat"
)
assert filepath.exists(), f"Currents file not found: {filepath}"
data = ioutils.read_currents(filepath)

# Extract some tiles, rejecting the ones with high latitude or RMS.
# An RMS of 0.20 is the 50th percentile in the dtu18_eigen-6c4_do0280_rr0004_cs data;
# we use a slightly stricter threshold than the median so only low-RMS tiles are kept.
# The latitude threshold is the same one Laura used.
tile_size = 32  # in grid points
rms_threshold = 0.18
latitude_threshold = 64.0
num_tiles = 512
rng = np.random.default_rng(1234)

tiles, indices = ioutils.extract_tiles(
    rng,
    data,
    num_tiles=num_tiles,
    max_rms=rms_threshold,
    max_latitude=latitude_threshold,
    tile_size=tile_size,
    return_indices=True,
)

In [None]:
import matplotlib.pyplot as plt

# Paint every extracted tile back onto an otherwise all-NaN copy of the grid, so the
# map shows where on the globe the training tiles were taken from.
tile_grid = np.ones_like(data) * np.nan
for tile, (row, col) in zip(tiles, indices):
    tile_grid[row : row + tile_size, col : col + tile_size] = tile

fig, axis = plt.subplots(figsize=(10, 5))
maps.imshow(tile_grid, axis=axis)

# Overlay the NaN cells of the original data (presumably land) using the project's
# clear-to-black colour map, scaled so NaN cells sit at the top of the range.
mask_level = 1.4
nan_overlay = mask_level * np.isnan(data)
axis.imshow(
    nan_overlay,
    cmap=maps.clear2black_cmap(),
    extent=[-180, 180, -90, 90],
    vmin=0,
    vmax=mask_level,
)
fig.colorbar(axis.images[0], ax=axis, label="m/s")

In [None]:
"""
Turn our images into a dataloader with the right transforms
"""

from current_denoising.generation import dcgan

# Wrap the extracted tiles in the project's TileLoader dataset. Any per-tile transforms
# (NOTE(review): presumably scaling to the generator's output range) live in dcgan - confirm there.
dataset = dcgan.TileLoader(tiles)

In [None]:
"""
We might want to train lots of slightly different models, so write big function for the training + monitoring plots

"""

import pathlib

import torch
import matplotlib.cm as cm
import matplotlib.colors as colors
from torch.autograd import Variable
from scipy.stats import wasserstein_distance

from current_denoising.generation import dcgan
from current_denoising.plotting import training, img_validation


def _gp_plot(training_metrics, lambda_gp, plot_dir):
    """Save the gradient-penalty vs Wasserstein-distance ratio plot as ``plot_dir/gp_wd.png``."""
    ratio_fig = training_metrics.plot_gp_wd_ratio(lambda_gp)
    destination = plot_dir / "gp_wd.png"
    ratio_fig.savefig(destination)
    plt.close(ratio_fig)


def _grad_plot(train_metrics, g_lr, d_lr, plot_dir):
    """Save the generator/critic parameter-gradient plot as ``plot_dir/grads.png``."""
    grads_fig = train_metrics.plot_param_gradients(g_lr, d_lr)
    destination = plot_dir / "grads.png"
    grads_fig.savefig(destination)
    plt.close(grads_fig)


def _grad_norm_plot(train_metrics, plot_dir):
    """Save the critic gradient-norm plot as ``plot_dir/grad_norm.png``."""
    norm_fig = train_metrics.plot_critic_grad_norms()
    destination = plot_dir / "grad_norm.png"
    norm_fig.savefig(destination)
    plt.close(norm_fig)


def _imgs_plot(imgs, plot_dir, title, vmin=0.0, vmax=1.4):
    """
    Show a grid of generated images with a shared colourbar and save it.

    :param imgs: batch of generated images, passed straight to ``img_validation.show``
    :param plot_dir: directory to write ``generated.png`` into
    :param title: run description, appended to the figure title
    :param vmin: lower limit of the colourbar
    :param vmax: upper limit of the colourbar
    """
    fig = img_validation.show(imgs, cmap="turbo")

    # The colourbar comes from a standalone ScalarMappable, so its range is not tied to
    # how img_validation.show scales the images.
    # NOTE(review): confirm vmin/vmax match the image scaling (the histogram plot bins over [-1, 1]).
    mappable = cm.ScalarMappable(
        norm=colors.Normalize(vmin=vmin, vmax=vmax), cmap="turbo"
    )
    mappable.set_array([])
    fig.colorbar(mappable, ax=fig.axes)

    # Previously the title argument was accepted but never shown
    fig.suptitle(f"Generated images - {title}")
    fig.savefig(plot_dir / "generated.png")
    plt.close(fig)


def _hist_plot(imgs, dataloader, plot_dir, title):
    """
    Overlay pixel-value histograms of the generated images and one batch of real images.

    The real batch is the next batch drawn from ``dataloader``. The figure is saved as
    ``plot_dir/hists.png`` and then closed.
    """
    fig, axis = plt.subplots(1, 1, figsize=(8, 5))
    shared_style = dict(
        bins=np.linspace(-1, 1, 150),
        density=True,
        alpha=0.5,
        histtype="step",
        linewidth=2,
    )

    img_validation.hist(imgs, axis=axis, label="Generated images", **shared_style)

    real_batch = next(iter(dataloader))
    img_validation.hist(
        real_batch,
        axis=axis,
        label="Real images",
        linestyle="dashed",
        **shared_style,
    )

    axis.set_title(f"Image Hists - {title}")
    axis.legend()
    fig.tight_layout()
    fig.savefig(plot_dir / "hists.png")
    plt.close(fig)


def _fft_plot(imgs, plot_dir, title, real_loader=None):
    """
    Plot FFTs of generated and real images side by side and return their MSE.

    :param imgs: batch of generated images
    :param plot_dir: directory to write ``ffts.png`` into
    :param title: run description, used in the real-image panel title
    :param real_loader: DataLoader to draw one batch of real images from. Defaults to
        the notebook-level ``dataloader`` for backwards compatibility; pass it
        explicitly to avoid depending on global state.
    :returns: mean squared difference between the generated and real FFT arrays
        returned by ``img_validation.fft``
    """
    if real_loader is None:
        # Fall back to the notebook-global loader (the original behaviour)
        real_loader = dataloader

    fig, axes = plt.subplots(1, 2, figsize=(12, 6))

    _, gen_fft = img_validation.fft(imgs, axis=axes[0])
    axes[0].set_title("Generated images FFT")

    _, real_fft = img_validation.fft(next(iter(real_loader)), axis=axes[1])
    axes[1].set_title(f"Real images FFT {title}")

    fig.tight_layout()
    fig.savefig(plot_dir / "ffts.png")
    plt.close(fig)

    return np.mean((gen_fft - real_fft) ** 2)


def train_gan(
    hyperparams: dcgan.GANHyperParams,
    dataloader: torch.utils.data.DataLoader,
    img_size: int,
    batch_size: int,
    output_dir: pathlib.Path,
    device: str = "cuda",
    n_stable_epochs: int = 30,
) -> tuple[torch.nn.Module, torch.nn.Module, dict]:
    """
    Train a new GAN, save diagnostic plots to ``output_dir`` and summarise the run.

    :param hyperparams: hyperparameters passed to ``dcgan.train_new_gan``
    :param dataloader: loader yielding batches of real tiles; also used as the "real"
        reference for the histogram plot and the histogram Wasserstein metric
    :param img_size: image size passed to ``dcgan.train_new_gan``
    :param batch_size: passed to ``dcgan._gen_imgs`` when generating validation images
    :param output_dir: directory the plots are written to (created if missing)
    :param device: torch device to train on
    :param n_stable_epochs: how many of the final epochs the gradient metrics average
        over; earlier epochs are skipped to give the model time to stabilise
    :returns: (generator, discriminator, metrics), where metrics has the keys
        mean_wd_gp_ratio, mean_gradient_ratio, avg_grad_norm, hist_wasserstein, fft_mse
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    generator, discriminator, train_metrics = dcgan.train_new_gan(
        dataloader,
        hyperparams,
        device,
        img_size=img_size,
        output_dir=output_dir,
    )

    # Plain interpolation - the f-string "=" specifier would print the expression text too
    title = f"g_lr={hyperparams.g_lr}, d_lr={hyperparams.d_lr}"

    # Training losses
    fig = train_metrics.plot_scores()
    fig.suptitle(title)
    fig.tight_layout()
    fig.savefig(output_dir / "losses.png")
    plt.close(fig)

    # Contributions of gradient penalty and Wasserstein distance to the critic loss
    _gp_plot(train_metrics, hyperparams.lambda_gp, output_dir)

    # Parameter gradients
    _grad_plot(train_metrics, hyperparams.g_lr, hyperparams.d_lr, output_dir)

    # Critic gradient norms
    _grad_norm_plot(train_metrics, output_dir)

    # Generate some images and compare them with real ones
    gen_imgs = dcgan._gen_imgs(generator, batch_size)
    _imgs_plot(gen_imgs, output_dir, title)
    _hist_plot(gen_imgs, dataloader, output_dir, title)
    fft_mse = _fft_plot(gen_imgs, output_dir, title)

    # Skip the early epochs for the gradient metrics - the model needs time to stabilise
    stable = slice(-n_stable_epochs, None)

    metrics = {
        # lambda_gp * GP / W, averaged over epochs
        "mean_wd_gp_ratio": (
            hyperparams.lambda_gp
            * np.mean(train_metrics.gradient_penalties, axis=1)
            / np.mean(train_metrics.wasserstein_dists, axis=1)
        ).mean(),
        # Ratio of learning-rate-scaled generator to critic parameter gradients
        "mean_gradient_ratio": abs(
            (
                (hyperparams.g_lr * train_metrics.generator_param_gradients[stable])
                / (hyperparams.d_lr * train_metrics.critic_param_gradients[stable])
            ).mean()
        ),
        "avg_grad_norm": np.mean(train_metrics.critic_interp_grad_norms, axis=1)[
            stable
        ].mean(),
        # Distance between the pixel-value distributions of generated and real images
        "hist_wasserstein": wasserstein_distance(
            gen_imgs.detach().cpu().numpy().flatten(),
            next(iter(dataloader)).cpu().numpy().flatten(),
        ),
        "fft_mse": fft_mse,
    }

    return generator, discriminator, metrics

In [None]:
batch_size = 128

output_dir = pathlib.Path("outputs/gan_test/")
output_dir.mkdir(parents=True, exist_ok=True)

# Short run to check the training loop and plots work before the hyperparameter search
loader_kwargs = dict(
    batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True
)
dataloader = torch.utils.data.DataLoader(dataset, **loader_kwargs)

# Learning-rate bounds found by hand: 2e-5 is too low, 2e-2 is too high
hyperparams = dcgan.GANHyperParams(
    n_epochs=5,
    g_lr=1e-3,
    d_lr=2e-4,
    n_critic=5,
    lambda_gp=12,
    generator_latent_dim=64,
    n_discriminator_blocks=4,
)

train_gan(
    hyperparams,
    dataloader,
    img_size=tile_size,
    batch_size=batch_size,
    output_dir=output_dir,
)

Tuning
----
Here, we're looking for:
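 - Critic gradient norm near 1
 - $\lambda_{\mathrm{GP}} \cdot \mathrm{GP} / |W|$ (the gradient-penalty term relative to the Wasserstein distance) mostly between 0.1 and 0.6
 - A Wasserstein distance that peaks at first, then settles to a small positive plateau (without oscillations)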

We will perform a random search over some of our parameters so that we can find a decent set to train with.

In [None]:
"""
Train a lot of different GANs with different hyperparameters
"""

import pickle

# Hundreds of figures are created below; don't display them as they're made
plt.ioff()

n_runs = 500
n_epochs = 75

rng = np.random.default_rng(0)
metrics: list[dict] = []
out_dir = pathlib.Path("outputs/gan_tuning/")
out_dir.mkdir(parents=True, exist_ok=True)
metrics_path = pathlib.Path("metrics.pkl")

for i in range(n_runs):
    # Every draw uses the seeded ``rng`` so the whole search is reproducible
    # (lambda_gp previously came from the unseeded global np.random).
    # Learning rates are log-uniform in [1e-5, 1e-2).
    hyperparams = dcgan.GANHyperParams(
        n_epochs=n_epochs,
        g_lr=10 ** rng.uniform(-5, -2),
        d_lr=10 ** rng.uniform(-5, -2),
        n_critic=int(rng.integers(1, 10)),  # upper bound is exclusive: 1..9
        lambda_gp=rng.uniform(0, 20),
        generator_latent_dim=64,
        n_discriminator_blocks=4,
    )

    output_dir = out_dir / str(i)
    output_dir.mkdir(parents=True, exist_ok=True)

    _, _, run_metrics = train_gan(
        hyperparams,
        dataloader,
        img_size=tile_size,
        batch_size=batch_size,
        output_dir=output_dir,
    )

    run_metrics["g_lr"] = hyperparams.g_lr
    run_metrics["d_lr"] = hyperparams.d_lr
    run_metrics["lambda_gp"] = hyperparams.lambda_gp
    run_metrics["n_critic"] = hyperparams.n_critic
    metrics.append(run_metrics)

    # Checkpoint after every run so a crash late in a long search loses at most one run
    with open(metrics_path, "wb") as f:
        pickle.dump(metrics, f)

In [None]:
"""
Plot the metrics
"""
import matplotlib.pyplot as plt
import pickle

# Lots of plots will have been made with interactive mode off, so close them all
# and then turn it back on
plt.close("all")
plt.ion()

# NOTE: pickle.load can execute arbitrary code - only load files this notebook wrote
with open("metrics.pkl", "rb") as f:
    metrics = pickle.load(f)

param_names = ["g_lr", "d_lr", "lambda_gp", "n_critic"]
metric_names = [k for k in metrics[0] if k not in param_names]

# One panel per metric, so none are silently dropped (a fixed 1x4 grid hid fft_mse)
fig, axes = plt.subplots(
    1, len(metric_names), figsize=(3 * len(metric_names), 4), squeeze=False
)
for axis, key in zip(axes.flat, metric_names):
    axis.plot([d[key] for d in metrics], "o", markersize=1)
    axis.set_title(key)
    if "wd_gp" in key:
        axis.axhline(0.1, color="k", linestyle="dashed")
        axis.axhline(0.6, color="k", linestyle="dashed")
        axis.set_ylim(-3, 3)
    if "gradient_ratio" in key:
        axis.axhline(0.85, color="k", linestyle="dashed")
        axis.axhline(1.2, color="k", linestyle="dashed")
        axis.set_ylim(0, 2)
    if "grad_norm" in key:
        axis.axhline(0.9, color="k", linestyle="dashed")
        axis.axhline(1.1, color="k", linestyle="dashed")
        axis.set_ylim(0, 2)
    # Choose log scale by name rather than by panel position
    if key in ("hist_wasserstein", "fft_mse"):
        axis.set_yscale("log")

fig.supxlabel("Run number")

fig.tight_layout()

In [None]:
"""
Plot the metrics against the hyperparameters
"""

n_metrics, n_params = len(metric_names), len(param_names)
fig, axes = plt.subplots(n_metrics, n_params, figsize=(12, 12))

# Row r shows metric_names[r]; column c shows param_names[c]; one point per tuning run
for row, key in enumerate(metric_names):
    y_vals = [run[key] for run in metrics]
    for col, param in enumerate(param_names):
        axes[row, col].scatter(
            [run[param] for run in metrics],
            y_vals,
            alpha=0.7,
            label=param,
            s=1,
        )

for col, param in enumerate(param_names):
    axes[0, col].set_title(param)

# The first two columns are the learning rates, which span orders of magnitude
for row, key in enumerate(metric_names):
    axes[row, 0].set_ylabel(key)
    axes[row, 0].set_xscale("log")
    axes[row, 1].set_xscale("log")

# Row 0: target band 0.1-0.6
for axis in axes[0]:
    axis.set_ylim(0, 1)
    axis.axhline(0.6, color="k", linestyle="dashed")
    axis.axhline(0.1, color="k", linestyle="dashed")

# Rows 1-2: target band around 1
for axis in axes[1:3].flat:
    axis.set_ylim(0, 2)
    axis.axhline(0.8, color="k", linestyle="dashed")
    axis.axhline(1.2, color="k", linestyle="dashed")

for axis in axes[3]:
    axis.set_yscale("log")

# Mark an example hyperparameter choice in each column
example_params = {"g_lr": 1e-3, "d_lr": 1e-3, "lambda_gp": 5, "n_critic": 7}
for col, value in enumerate(example_params.values()):
    for axis in axes[:, col]:
        axis.axvline(value, color="r", linestyle="--", alpha=0.3)

fig.tight_layout()

In [None]:
"""
Make a new dataframe showing the divergence away from the expected values
"""

import pandas as pd

# One row per tuning run; columns are the metrics returned by train_gan plus the sampled hyperparameters
loss_df = pd.DataFrame(metrics)


def wd_gp_loss(x):
    """
    Distance of ``x`` outside the target band (0.1, 0.6); zero inside the band.

    NaN propagates through (falls into the upper branch).
    """
    lower, upper = 0.1, 0.6
    if lower < x < upper:
        return 0
    return lower - x if x <= lower else x - upper


# Turn each metric into a non-negative "distance from its target" loss
loss_df["wd_gp_loss"] = loss_df["mean_wd_gp_ratio"].apply(wd_gp_loss)
loss_df["grad_ratio_loss"] = (loss_df["mean_gradient_ratio"] - 1).abs()
loss_df["grad_norm_loss"] = (loss_df["avg_grad_norm"] - 1).abs()
loss_df["wasserstein_loss"] = loss_df["hist_wasserstein"]

# Name the loss columns explicitly instead of relying on their position in the frame
loss_column_names = [
    "wd_gp_loss",
    "grad_ratio_loss",
    "grad_norm_loss",
    "wasserstein_loss",
]
loss_df = loss_df.drop(columns=metric_names)
loss_df

In [None]:
"""
Plot 1d scatter plots of the losses against run number
"""

fig, axes = plt.subplots(1, 4, figsize=(12, 4))

for axis, loss_name in zip(axes.flat, loss_column_names):
    axis.plot(loss_df[loss_name], "o", markersize=1)
    axis.set_title(loss_name)

# Hand-picked y-limits for the first three panels; the last one uses a log scale
y_limits = [(-1, 5), (-1, 5), (-1, 3)]
for axis, (lo, hi) in zip(axes, y_limits):
    axis.set_ylim(lo, hi)
axes[3].set_yscale("log")

fig.supxlabel("Run number")

In [None]:
"""
Plot a pairplot-esque grid of losses and hyperparams
"""

scatter_kw = {"s": 5}

for loss_name in loss_column_names:
    fig, axes = plt.subplots(2, 3, figsize=(9, 6))

    # Only the first two entries of param_names (the learning rates) get a row; each row
    # plots that parameter against each of the other three.
    # NOTE(review): this means lambda_gp vs n_critic is never shown - intentional?
    for row_axes, y_param in zip(axes, param_names):
        x_params = [p for p in param_names if p != y_param]
        for axis, x_param in zip(row_axes, x_params):
            im = axis.scatter(
                loss_df[x_param],
                loss_df[y_param],
                c=loss_df[loss_name],
                **scatter_kw,
            )
            axis.set_xlabel(x_param)
            axis.set_ylabel(y_param)

            if y_param.endswith("lr"):
                axis.set_yscale("log")
            if x_param.endswith("lr"):
                axis.set_xscale("log")

    fig.suptitle(loss_name)
    fig.tight_layout()

    # Colourbar in its own axes placed to the right of the subplot grid
    fig.colorbar(im, cax=fig.add_axes([1.05, 0.15, 0.05, 0.7]))

In [None]:
"""
Plot a 2d scatter plot of hist wasserstein distance against the LRs
"""
from matplotlib.colors import LogNorm
# TODO also do this with some metric for the FFTs

# One point per tuning run, placed by its two learning rates and coloured (log scale)
# by the Wasserstein distance between generated and real pixel-value distributions
g_lr_vals = [run["g_lr"] for run in metrics]
d_lr_vals = [run["d_lr"] for run in metrics]
hist_wds = [run["hist_wasserstein"] for run in metrics]

fig, axis = plt.subplots()
scatter = axis.scatter(
    g_lr_vals, d_lr_vals, c=hist_wds, s=10, norm=LogNorm(), cmap="plasma_r"
)
axis.loglog()
fig.colorbar(scatter, label="Hist Wasserstein Distance")
axis.set_xlabel("g_lr")
axis.set_ylabel("d_lr")

fig.tight_layout()

In [None]:
"""
Plot a 3D surface of hist wasserstein distance against the LRs
"""

import numpy as np
from mpl_toolkits.mplot3d import Axes3D

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection="3d")

# Learning rates were sampled log-uniformly, so put them on log10 axes
log_g_lrs = np.log10(loss_df["g_lr"])
log_d_lrs = np.log10(loss_df["d_lr"])
wds = loss_df["wasserstein_loss"]

surf = ax.plot_trisurf(log_g_lrs, log_d_lrs, wds, cmap="magma", alpha=0.5)

ax.set_xlabel("log10(g_lr)")
ax.set_ylabel("log10(d_lr)")
ax.set_zlabel("Hist Wasserstein Distance")
fig.colorbar(surf, ax=ax, label="Hist Wasserstein Distance", shrink=0.5)

ax.view_init(azim=45)
fig.tight_layout()

In [None]:
"""
Now that we have a sensible d_g_lr_ratio, train a longer run
"""

# Deliberately stop "Run All" here: the cells below start a long training run and should be
# run by hand. An explicit raise is used because ``assert`` statements vanish under ``python -O``.
raise RuntimeError(
    "Stopping before the long final training run - run the cells below manually."
)

In [None]:
import torch

batch_size = 128
config = {
    "n_epochs": 300,
    "n_critic": 5,
    "lambda_gp": 13.7,
    "learning_rate": 0.005,
    "d_g_lr_ratio": 12.7,
    "latent_dim": 64,
    "img_size": tile_size,
    "channels": 1,  # not used by train_gan; kept for the record
    "batch_size": batch_size,
    "dataloader": torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True, num_workers=8
    ),
    "output_dir": pathlib.Path("outputs/gan_final/"),
    "plot_interval": 10,  # not used by train_gan; kept for the record
}
config["output_dir"].mkdir(parents=True, exist_ok=True)

# train_gan expects a GANHyperParams object plus explicit loader/size arguments,
# so build those from the config dict.
# NOTE(review): the mapping of learning_rate / d_g_lr_ratio onto g_lr / d_lr assumes the
# ratio means d_lr / g_lr with learning_rate as d_lr - confirm before a long run.
hyperparams = dcgan.GANHyperParams(
    n_epochs=config["n_epochs"],
    g_lr=config["learning_rate"] / config["d_g_lr_ratio"],
    d_lr=config["learning_rate"],
    n_critic=config["n_critic"],
    lambda_gp=config["lambda_gp"],
    generator_latent_dim=config["latent_dim"],
    n_discriminator_blocks=4,
)

generator, discriminator, _ = train_gan(
    hyperparams,
    config["dataloader"],
    img_size=config["img_size"],
    batch_size=config["batch_size"],
    output_dir=config["output_dir"],
)

In [None]:
# Save the trained weights (state dicts only) into the run's output directory
for filename, model in (
    ("generator_final.pth", generator),
    ("discriminator_final.pth", discriminator),
):
    torch.save(model.state_dict(), config["output_dir"] / filename)