GAN Training
====
Train the GAN, produce some monitoring plots, and save the trained generator and critic weights to disk.

In [None]:
import pathlib

# The SING RDSF share must be mounted at this location before running the notebook
rdsf_dir = pathlib.Path("~/geog_rdsf/").expanduser()

# Noisy MDT input
noisy_filepath = (
    rdsf_dir
    / "data"
    / "projects"
    / "SING"
    / "richard_stuff"
    / "Table2"
    / "dtu18_eigen-6c4_do0280_rr0004.dat"
)
# Clean MDT input
clean_filepath = (
    rdsf_dir / "data" / "projects" / "dtop" / "cmip6" / "cmip6_historical_mdts_yr5.dat"
)

# Fail early and name every missing input, instead of a bare AssertionError
# that does not say which file could not be found
_missing_inputs = [
    str(path) for path in (noisy_filepath, clean_filepath) if not path.exists()
]
if _missing_inputs:
    raise FileNotFoundError(
        f"Required input file(s) not found - is the RDSF mounted at {rdsf_dir}? "
        f"Missing: {', '.join(_missing_inputs)}"
    )

In [None]:
import numpy as np
from current_denoising.generation import ioutils

# Despite its name, read_currents just returns the array stored in the file
noisy_mdt = ioutils.read_currents(noisy_filepath)
# Overwrite, in place, every element equal to -1.9e19 with NaN
# (presumably the file's missing-data fill value - TODO confirm)
noisy_mdt[noisy_mdt == -1.9e19] = np.nan

# The metadata file sits next to the data file and is named "<stem>_meta.txt"
clean_mdt = ioutils.read_clean_mdt(
    path=clean_filepath,
    metadata_path=clean_filepath.with_name(f"{clean_filepath.stem}_meta.txt"),
    year=2001,
    model="CMCC-CM2-HR4",
)

In [None]:
"""
Find the residual between the MDT and the smoothed MDT
"""

from current_denoising.utils import util

# Second argument to util.get_residual - presumably the smoothing length scale
# in km, going by the name (TODO confirm against util.get_residual)
sigma_km = 200
# Residual of the noisy MDT, as computed by util.get_residual
residual = util.get_residual(noisy_mdt, sigma_km)

In [None]:
"""
Select some tiles from the residual based on some criteria
"""

from current_denoising.plotting import maps


# NOTE(review): rng is created here but is not passed to anything in this cell -
# confirm whether extract_tiles is meant to receive it
rng = np.random.default_rng(seed=0)

# Side length of each tile; also passed to train_gan as img_size further down
tile_size = 32
# NOTE(review): num_tiles is not used in this cell - confirm whether it should
# be passed to extract_tiles
num_tiles = 512
latitude_threshhold = 60.0
# True wherever distance_from_land(residual) is below 20
# (units are not visible here - TODO confirm)
forbidden_mask = ioutils.distance_from_land(residual) < 20

# Get the tiles from our map based on the criteria
# Also return the indices (the tile locations) so they can be plotted
tiles, indices = ioutils.extract_tiles(
    residual,
    forbidden_mask=forbidden_mask,
    tile_criterion=ioutils.select_tile,
    max_latitude=latitude_threshhold,
    tile_size=tile_size,
    return_indices=True,
)

In [None]:
"""
Train a simple GAN on these tiles
"""

from current_denoising.generation import dcgan

# Wrap the extracted tiles; this object is handed to a torch DataLoader in a later cell
dataset = dcgan.TileLoader(tiles)

In [None]:
import torch
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from scipy.stats import wasserstein_distance

from current_denoising.plotting import training, img_validation


def _gp_plot(training_metrics, lambda_gp, plot_dir):
    """
    Save the figure from training_metrics.plot_gp_wd_ratio(lambda_gp) as plot_dir / "gp_wd.png"

    The figure is closed once it has been written.
    """
    figure = training_metrics.plot_gp_wd_ratio(lambda_gp)
    out_path = plot_dir / "gp_wd.png"

    figure.savefig(out_path)
    plt.close(figure)


def _grad_plot(train_metrics, g_lr, d_lr, plot_dir):
    """
    Save the figure from train_metrics.plot_param_gradients(g_lr, d_lr) as plot_dir / "grads.png"

    The figure is closed once it has been written.
    """
    figure = train_metrics.plot_param_gradients(g_lr, d_lr)
    out_path = plot_dir / "grads.png"

    figure.savefig(out_path)
    plt.close(figure)


def _grad_norm_plot(train_metrics, plot_dir):
    """
    Save the figure from train_metrics.plot_critic_grad_norms() as plot_dir / "grad_norm.png"

    The figure is closed once it has been written.
    """
    figure = train_metrics.plot_critic_grad_norms()
    out_path = plot_dir / "grad_norm.png"

    figure.savefig(out_path)
    plt.close(figure)


def _imgs_plot(imgs, plot_dir, title):
    """
    Display the generated images with a shared colourbar and save to plot_dir / "generated.png"

    :param imgs: images to show; passed straight to img_validation.show
    :param plot_dir: directory the figure is written to
    :param title: description of the run, appended to the figure title
    """
    cmap = "turbo"
    fig = img_validation.show(imgs, cmap=cmap)

    # Standalone mappable, so the colourbar is not tied to any single image.
    # NOTE(review): the 0.0 - 1.4 range is hard-coded here; confirm it matches
    # the colour scaling that img_validation.show applies to the images
    mappable = cm.ScalarMappable(norm=colors.Normalize(vmin=0.0, vmax=1.4), cmap=cmap)
    mappable.set_array([])
    fig.colorbar(mappable, ax=fig.axes)

    # Include the run description, as the histogram plot's title does
    fig.suptitle(f"Generated images - {title}")
    fig.savefig(plot_dir / "generated.png")
    plt.close(fig)


def _hist_plot(imgs, dataloader, plot_dir, title):
    """
    Overlay histograms of the generated images and one batch of real images

    The real batch is the first one yielded by a fresh iterator over dataloader.
    The figure is written to plot_dir / "hists.png" and then closed.
    """
    fig, axis = plt.subplots(1, 1, figsize=(8, 5))

    # Options shared by both histograms
    shared_kw = dict(
        bins=np.linspace(-1, 1, 150),
        density=True,
        alpha=0.5,
        histtype="step",
        linewidth=2,
    )

    img_validation.hist(imgs, axis=axis, **shared_kw, label="Generated images")

    # Real images are drawn dashed to tell them apart from the generated ones
    real_batch = next(iter(dataloader))
    img_validation.hist(
        real_batch, axis=axis, **shared_kw, label="Real images", linestyle="dashed"
    )

    axis.set_title(f"Image Hists - {title}")
    axis.legend()

    fig.tight_layout()
    out_path = plot_dir / "hists.png"
    fig.savefig(out_path)
    plt.close(fig)


def _fft_plot(imgs, plot_dir, title, real_imgs=None):
    """
    Plot FFTs of the generated and real images side by side and save to plot_dir / "ffts.png"

    :param imgs: generated images
    :param plot_dir: directory the figure is written to
    :param title: description of the run, appended to the real-image panel title
    :param real_imgs: batch of real images to compare against. If omitted, one batch is
                      drawn from the notebook-level `dataloader` variable, which must then
                      exist; this keeps three-argument calls working as they did before.

    :returns: mean squared difference between the two FFTs returned by img_validation.fft
    """
    if real_imgs is None:
        # Legacy path: relies on a global `dataloader` having been defined in the notebook
        real_imgs = next(iter(dataloader))

    fig, axes = plt.subplots(1, 2, figsize=(12, 6))

    _, gen_fft = img_validation.fft(imgs, axis=axes[0])
    axes[0].set_title("Generated images FFT")

    _, real_fft = img_validation.fft(real_imgs, axis=axes[1])
    axes[1].set_title(f"Real images FFT {title}")
    fig.tight_layout()
    fig.savefig(plot_dir / "ffts.png")
    plt.close(fig)

    fft_mse = np.mean((gen_fft - real_fft) ** 2)

    return fft_mse


def train_gan(
    hyperparams: dcgan.GANHyperParams,
    dataloader: torch.utils.data.DataLoader,
    img_size: int,
    batch_size: int,
    output_dir: pathlib.Path,
    device: str = "cuda",
) -> tuple[torch.nn.Module, torch.nn.Module, dict[str, float]]:
    """
    Train the GAN and make lots of debug plots

    All plots are written to output_dir, which is also passed on to dcgan.train_new_gan.

    :param hyperparams: training configuration passed to dcgan.train_new_gan
    :param dataloader: source of real image batches, for training and for the comparison plots
    :param img_size: passed to dcgan.train_new_gan as img_size
    :param batch_size: number of images to generate for the plots and metrics
    :param output_dir: directory for the plots
    :param device: device passed to dcgan.train_new_gan; defaults to "cuda"

    :returns: the generator, the discriminator (critic) and a dict of summary metrics with keys
              "mean_wd_gp_ratio", "mean_gradient_ratio", "avg_grad_norm",
              "hist_wasserstein" and "fft_mse"
    """
    metrics = {}

    # Train the GAN
    generator, discriminator, train_metrics = dcgan.train_new_gan(
        dataloader,
        hyperparams,
        device,
        img_size=img_size,
        output_dir=output_dir,
    )

    # NB no "=" specifier inside the braces - that would render the expression text as well,
    # giving "g_lr=hyperparams.g_lr=..."
    title = f"g_lr={hyperparams.g_lr}, d_lr={hyperparams.d_lr}"

    # Plot training losses
    fig = train_metrics.plot_scores()
    fig.suptitle(title)
    fig.tight_layout()
    fig.savefig(output_dir / "losses.png")
    plt.close(fig)

    # Plot contributions of gradient penalty and Wasserstein distance to discriminator loss
    _gp_plot(train_metrics, hyperparams.lambda_gp, output_dir)

    # Plot gradients
    _grad_plot(train_metrics, hyperparams.g_lr, hyperparams.d_lr, output_dir)

    # Plot grad norm
    _grad_norm_plot(train_metrics, output_dir)

    # Generate some images and display them
    gen_imgs = generator.gen_imgs(batch_size, hyperparams.generator_latent_size)
    _imgs_plot(gen_imgs, output_dir, title)
    _hist_plot(gen_imgs, dataloader, output_dir, title)

    # Draw one real batch and use it for both the FFT comparison and the Wasserstein metric,
    # so the two are computed against the same images and the loader is only iterated once more
    real_imgs = next(iter(dataloader))
    fft_mse = _fft_plot(gen_imgs, output_dir, title, real_imgs=real_imgs)

    # Track metrics
    metrics["mean_wd_gp_ratio"] = (
        hyperparams.lambda_gp
        * np.mean(train_metrics.gradient_penalties, axis=1)
        / np.mean(train_metrics.wasserstein_dists, axis=1)
    ).mean()

    # Don't want the first epochs - need to give the model some time to stabilise,
    # so only the last 30 entries are used
    metrics["mean_gradient_ratio"] = abs(
        (
            (hyperparams.g_lr * train_metrics.generator_param_gradients[-30:])
            / (hyperparams.d_lr * train_metrics.critic_param_gradients[-30:])
        ).mean()
    )

    # Don't want the first epochs - need to give the model some time to stabilise,
    # so only the last 30 entries are used
    metrics["avg_grad_norm"] = np.mean(train_metrics.critic_interp_grad_norms, axis=1)[
        -30:
    ].mean()

    metrics["hist_wasserstein"] = wasserstein_distance(
        gen_imgs.detach().cpu().numpy().flatten(),
        real_imgs.cpu().numpy().flatten(),
    )
    metrics["fft_mse"] = fft_mse

    return generator, discriminator, metrics

In [None]:
# Used for the DataLoader and passed to train_gan, which generates this many
# images for its plots and metrics
batch_size = 16

# Everything this run writes (plots, saved weights) goes in here
output_dir = pathlib.Path("outputs/mdt_gan/test/")
output_dir.mkdir(parents=True, exist_ok=True)

hyperparams = dcgan.GANHyperParams(
    n_epochs=1000,
    # Generator / critic learning rates; train_gan also puts these in the plot titles
    g_lr=0.0002,
    d_lr=0.0002,
    # Presumably the number of critic updates per generator update - TODO confirm in dcgan
    n_critic=5,
    # Gradient penalty weight; train_gan multiplies the mean gradient penalty by
    # this when computing its "mean_wd_gp_ratio" metric
    lambda_gp=20,
    generator_latent_channels=32,
    # Passed to generator.gen_imgs and dcgan.generate_tiles when generating images
    generator_latent_size=4,
    n_discriminator_blocks=4,
)

In [None]:
# Batches are shuffled, so each next(iter(dataloader)) inside train_gan draws a
# random batch; drop_last discards a final batch smaller than batch_size
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    drop_last=True,
)
# Train, write the debug plots to output_dir and collect the summary metrics
gen, critic, metrics = train_gan(
    hyperparams,
    dataloader,
    img_size=tile_size,
    batch_size=batch_size,
    output_dir=output_dir,
)

In [None]:
# Store only the state dicts of the two networks, alongside the plots in output_dir
torch.save(gen.state_dict(), output_dir / "generator_final.pth")
torch.save(critic.state_dict(), output_dir / "discriminator_final.pth")

In [None]:
"""
Generate and display some example tiles
"""

n_gen = 25
gen_tiles = dcgan.generate_tiles(
    gen,
    n_tiles=n_gen,
    noise_size=hyperparams.generator_latent_size,
    device="cuda",
)

# Rescale so the generated set as a whole has the mean and std of the training tiles;
# the small epsilon avoids a division by zero if the generated values have zero spread
scaled = tiles.mean() + tiles.std() * (gen_tiles - gen_tiles.mean()) / (
    gen_tiles.std() + 1e-8
)

# One panel per generated tile
fig, axes = plt.subplots(5, 5, figsize=(12, 12))
for axis, img in zip(axes.flat, scaled):
    im = axis.imshow(img, cmap="seismic", vmin=-0.5, vmax=0.5)
    # Axis ticks in pixels, taken from the tile itself rather than a hard-coded 32
    im.set_extent([0, img.shape[1], 0, img.shape[0]])

fig.tight_layout()

# Shared colourbar in its own axes to the right of the grid; all panels use the
# same vmin/vmax, so the last image can stand in for all of them
cax = fig.add_axes([1.05, 0.15, 0.05, 0.7])
fig.colorbar(im, cax=cax, label="Mean Dynamic Topography /m")

In [None]:
"""
Plot histograms of generated and training tiles, to compare the distribution
of the rescaled generated values with that of the training data
"""
fig, axis = plt.subplots()
kw = {"histtype": "step", "bins": np.linspace(-0.3, 0.3, 100), "density": True}
# Label the curves so the two distributions can be told apart
axis.hist(tiles.flat, **kw, label="Training tiles")
axis.hist(scaled.flat, **kw, label="Generated tiles (rescaled)")
axis.set_xlabel("Tile value")
axis.set_ylabel("Density")
axis.legend()

# Last expression in the cell, so the notebook displays both standard deviations
scaled.std(), tiles.std()

In [None]:
"""
Generate and display some example tiles from a latent input 4x the configured
generator_latent_size
"""

n_gen = 25
gen_tiles = dcgan.generate_tiles(
    gen,
    n_tiles=n_gen,
    noise_size=4 * hyperparams.generator_latent_size,
    device="cuda",
)

# Rescale so the generated set as a whole has the mean and std of the training tiles;
# the small epsilon avoids a division by zero if the generated values have zero spread
scaled = tiles.mean() + tiles.std() * (gen_tiles - gen_tiles.mean()) / (
    gen_tiles.std() + 1e-8
)

# One panel per generated tile
fig, axes = plt.subplots(5, 5, figsize=(12, 12))
for axis, img in zip(axes.flat, scaled):
    im = axis.imshow(img, cmap="seismic", vmin=-0.5, vmax=0.5)
    # Axis ticks in pixels, taken from the tile itself: with the larger latent input
    # these tiles need not be 32 pixels across, so a fixed 0-32 extent could mislabel them
    im.set_extent([0, img.shape[1], 0, img.shape[0]])

fig.tight_layout()

# Shared colourbar in its own axes to the right of the grid; all panels use the
# same vmin/vmax, so the last image can stand in for all of them
cax = fig.add_axes([1.05, 0.15, 0.05, 0.7])
fig.colorbar(im, cax=cax, label="Mean Dynamic Topography /m")