In [1]:
# import all packages used in this notebook

# The local `src` package lives in the project root, taken here to be the parent
# of the current working directory (i.e. the directory above notebooks/). Put it
# on sys.path once, before `import src`, so re-running this cell does not add duplicates.
import os
import sys

src_path = os.path.abspath(os.pardir)
if src_path not in sys.path:
    sys.path.append(src_path)

import src
import wandb
from torchvision.ops.focal_loss import sigmoid_focal_loss
import torch
from torch import nn
from torch.utils.data import DataLoader
import numpy as np

/users/anddon76/icenet/icenet-gan/notebooks


# Forecast
Let's continue by generating forecasts with trained models in this notebook.

As before, we maintain alignment with the [`icenet-paper` repository](https://github.com/tom-andersson/icenet-paper).

### 1. Construct test dataset
We'll need data to make a forecast. Let's load our held-out test data to evaluate our models' capacity for generalisation.

In [2]:
# Held-out test split, built from the same dataloader config that both training
# runs log as `dataloader_config` (see the run configs displayed below).
dataloader_config_path = f"{src.dataloader_config_folder}/2023_06_24_1235_icenet_gan.json"
test_dataset = src.IceNetDataset(dataloader_config_path, mode="test")
test_dataset.obs_forecast_IDs

Setting the data generator's random seed to 42
Checking forecast start dates for missing SIC dates... Setting up the variable paths for dataset_no_cmip... Done.
Setting the number of input months for each input variable.
Loading and augmenting the polar holes... Done in 0s.

on_epoch_end called
Setup complete.



[Timestamp('2018-02-01 00:00:00'),
 Timestamp('2018-03-01 00:00:00'),
 Timestamp('2018-04-01 00:00:00'),
 Timestamp('2018-05-01 00:00:00'),
 Timestamp('2018-06-01 00:00:00'),
 Timestamp('2018-07-01 00:00:00'),
 Timestamp('2018-08-01 00:00:00'),
 Timestamp('2018-09-01 00:00:00'),
 Timestamp('2018-10-01 00:00:00'),
 Timestamp('2018-11-01 00:00:00'),
 Timestamp('2018-12-01 00:00:00'),
 Timestamp('2019-01-01 00:00:00'),
 Timestamp('2019-02-01 00:00:00'),
 Timestamp('2019-03-01 00:00:00'),
 Timestamp('2019-04-01 00:00:00'),
 Timestamp('2019-05-01 00:00:00'),
 Timestamp('2019-06-01 00:00:00')]

### 2. Load trained models
We'll also need our trained models to make forecasts.

Our best UNet training run is [here](https://wandb.ai/andrewmcdonald/icenet-gan/runs/wq09bzy7/overview?workspace=user-andrewmcdonald) with colloquial name `radiant-sponge-59-great-unet` and hash name `wq09bzy7`.

Our best GAN training run is [here](https://wandb.ai/andrewmcdonald/icenet-gan/runs/4iuiyi32/overview?workspace=user-andrewmcdonald) with colloquial name `stilted-armadillo-99-great-onestep-gan` and hash name `4iuiyi32`.

In [3]:
# Look up both training runs in the shared wandb project by their hash names.
api = wandb.Api()
wandb_project_path = f"{src.config.WANDB_USERNAME}/icenet-gan"

unet_name = "wq09bzy7"
unet_run = api.run(f"{wandb_project_path}/{unet_name}")

gan_name = "4iuiyi32"
gan_run = api.run(f"{wandb_project_path}/{gan_name}")


In [4]:
# display the hyperparameters logged for the UNet run; several are reused below to rebuild the model
unet_run.config

{'name': 'default',
 'seed': 42,
 'model': 'unet',
 'devices': 1,
 'criterion': 'focal',
 'n_workers': 8,
 'precision': 16,
 'batch_size': 10,
 'max_epochs': 100,
 'accelerator': 'auto',
 'filter_size': 3,
 'fast_dev_run': False,
 'learning_rate': 0.0001,
 'n_to_visualise': 1,
 'n_filters_factor': 1,
 'dataloader_config': '2023_06_24_1235_icenet_gan.json',
 'limit_val_batches': 1,
 'log_every_n_steps': 10,
 'limit_train_batches': 1,
 'num_sanity_val_steps': 1}

In [5]:
# construct a blank UNet using the training run's architecture hyperparameters
unet = src.UNet(
    input_channels=test_dataset.tot_num_channels,
    filter_size=unet_run.config["filter_size"],
    n_filters_factor=unet_run.config["n_filters_factor"],
    n_forecast_months=test_dataset.config["n_forecast_months"]
)

# construct the criterion the run was trained with; an unknown value fails loudly instead of
# raising a NameError or silently reusing a `criterion` left in the kernel by an earlier run
if unet_run.config["criterion"] == "ce":
    criterion = nn.CrossEntropyLoss(reduction="none")
elif unet_run.config["criterion"] == "focal":
    criterion = sigmoid_focal_loss  # reduction="none" by default
else:
    raise ValueError(f"Unsupported UNet criterion: {unet_run.config['criterion']!r}")

# lightning will load checkpointed weights onto this model
unet_best_ckpt = f"{src.config.WANDB_DIR}/radiant-sponge-59-great-unet/checkpoints/best-epoch=11-step=456.ckpt"
lit_module_unet = src.LitUNet.load_from_checkpoint(unet_best_ckpt, model=unet, criterion=criterion)
lit_module_unet

LitUNet(
  (model): UNet(
    (conv1a): Conv2d(50, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (conv1b): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2a): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (conv2b): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv3a): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (conv3b): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv4a): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (conv4b): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (bn4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine

In [6]:
# display the hyperparameters logged for the GAN run; several are reused below to rebuild the generator/discriminator
gan_run.config

{'name': 'default',
 'seed': 42,
 'model': 'gan',
 'sigma': 1,
 'devices': 1,
 'criterion': 'focal',
 'n_workers': 8,
 'precision': 16,
 'batch_size': 2,
 'max_epochs': 100,
 'accelerator': 'auto',
 'd_lr_factor': 1,
 'filter_size': 3,
 'fast_dev_run': False,
 'learning_rate': 0.0001,
 'n_to_visualise': 1,
 'generator_lambda': 500,
 'n_filters_factor': 1,
 'dataloader_config': '2023_06_24_1235_icenet_gan.json',
 'limit_val_batches': 1,
 'log_every_n_steps': 10,
 'discriminator_mode': 'onestep',
 'limit_train_batches': 1,
 'num_sanity_val_steps': 1,
 'discriminator_criterion': 'ce',
 'generator_fake_criterion': 'ce',
 'generator_structural_criterion': 'ce'}

In [7]:
# construct a blank generator and discriminator using the training run's architecture hyperparameters
generator = src.Generator(
    input_channels=test_dataset.tot_num_channels,
    filter_size=gan_run.config["filter_size"],
    n_filters_factor=gan_run.config["n_filters_factor"],
    n_forecast_months=test_dataset.config["n_forecast_months"]
)
discriminator = src.Discriminator(
    input_channels=test_dataset.tot_num_channels,
    filter_size=gan_run.config["filter_size"],
    n_filters_factor=gan_run.config["n_filters_factor"],
    n_forecast_months=test_dataset.config["n_forecast_months"],
    mode=gan_run.config["discriminator_mode"]
)

# construct criteria
# NOTE(review): the run logs generator_fake_criterion / discriminator_criterion as "ce", but both are
# hard-coded to BCE-with-logits here — confirm this matches the training script.
generator_fake_criterion = nn.BCEWithLogitsLoss(reduction="none")

# the structural criterion is configurable per run; an unknown value fails loudly instead of raising a
# NameError or silently reusing a `generator_structural_criterion` left in the kernel by an earlier run
structural_criterion_name = gan_run.config["generator_structural_criterion"]
if structural_criterion_name == "l1":
    generator_structural_criterion = nn.L1Loss(reduction="none")
elif structural_criterion_name == "ce":
    generator_structural_criterion = nn.CrossEntropyLoss(reduction="none")
elif structural_criterion_name == "focal":
    generator_structural_criterion = sigmoid_focal_loss  # reduction="none" by default
else:
    raise ValueError(f"Unsupported generator structural criterion: {structural_criterion_name!r}")

discriminator_criterion = nn.BCEWithLogitsLoss(reduction="none")

# lightning will load checkpointed weights onto these models
gan_best_ckpt = f"{src.config.WANDB_DIR}/stilted-armadillo-99-great-onestep-gan/checkpoints/best-epoch=7-step=3024-v1.ckpt"
lit_module_gan = src.LitGAN.load_from_checkpoint(
    gan_best_ckpt,
    generator=generator,
    discriminator=discriminator,
    generator_fake_criterion=generator_fake_criterion,
    generator_structural_criterion=generator_structural_criterion,
    discriminator_criterion=discriminator_criterion
)
lit_module_gan

LitGAN(
  (generator): Generator(
    (conv1a): Conv2d(51, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (conv1b): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2a): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (conv2b): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv3a): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (conv3b): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv4a): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (conv4b): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (bn4): BatchNorm2d(512, eps=1e-05, momentum=0.1

### 3. Pass data through trained models
Let's pass our held-out test data through each trained model and visualise the output tensors.

We'll check our available GPU memory with `!nvidia-smi` and will enable `torch.no_grad()` to save memory since we won't need gradients in this setting.

In [8]:
# One loader per model, reusing each run's logged batch size and worker count instead of a magic number.
# shuffle=False yields samples in dataset order, so predictions should line up with
# test_dataset.obs_forecast_IDs (assumed to share that order — TODO confirm).
# DataLoader rejects persistent_workers=True when num_workers == 0, so only enable it when workers exist.
unet_n_workers = unet_run.config["n_workers"]
gan_n_workers = gan_run.config["n_workers"]
test_dataloader_unet = DataLoader(test_dataset, batch_size=unet_run.config["batch_size"],
                                  num_workers=unet_n_workers, persistent_workers=unet_n_workers > 0,
                                  pin_memory=False, shuffle=False)
test_dataloader_gan = DataLoader(test_dataset, batch_size=gan_run.config["batch_size"],
                                 num_workers=gan_n_workers, persistent_workers=gan_n_workers > 0,
                                 pin_memory=False, shuffle=False)

In [9]:
# snapshot GPU memory usage before running inference
!nvidia-smi

Tue Jun 27 17:30:06 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A2           On   | 00000000:98:00.0 Off |                    0 |
|  0%   37C    P0    19W /  60W |   1382MiB / 15356MiB |     21%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [10]:
# pass batches through unet and accumulate per-sample predictions into a list
# eval() makes the BatchNorm layers use their stored running statistics. A freshly constructed module
# is in training mode, where predictions would depend on how samples are batched and the running
# statistics would be updated with test data (Lightning's docs call eval() after load_from_checkpoint).
lit_module_unet.eval()
y_hat_unet = []
with torch.no_grad():  # gradients are not needed for inference
    for batch in test_dataloader_unet:
        x, _, _ = batch  # targets and sample weights are not needed to forecast
        pred_unet = lit_module_unet(x.to(lit_module_unet.device)).cpu().numpy()
        y_hat_unet.extend(pred_unet)
y_hat_unet = np.array(y_hat_unet)
y_hat_unet.shape



(17, 432, 432, 3, 6)

In [11]:
# pass batches through gan and accumulate per-sample predictions into a list
# eval() makes the BatchNorm layers use their stored running statistics. A freshly constructed module
# is in training mode, where predictions would depend on how samples are batched and the running
# statistics would be updated with test data (Lightning's docs call eval() after load_from_checkpoint).
# NOTE(review): the generator's first conv takes one more input channel than the UNet's, and the run
# logs `sigma`, so LitGAN's forward may sample noise — if so, seed torch here for reproducible forecasts.
lit_module_gan.eval()
y_hat_gan = []
with torch.no_grad():  # gradients are not needed for inference
    for batch in test_dataloader_gan:
        x, _, _ = batch  # targets and sample weights are not needed to forecast
        pred_gan = lit_module_gan(x.to(lit_module_gan.device)).cpu().numpy()
        y_hat_gan.extend(pred_gan)
y_hat_gan = np.array(y_hat_gan)
y_hat_gan.shape

(17, 432, 432, 3, 6)

In [12]:
# return unoccupied blocks held by PyTorch's CUDA caching allocator to the driver, then re-check GPU memory;
# memory still referenced by live tensors (e.g. the loaded model weights) is not freed by this call
torch.cuda.empty_cache()  # free pytorch memory
!nvidia-smi

Tue Jun 27 17:30:23 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A2           On   | 00000000:98:00.0 Off |                    0 |
|  0%   44C    P0    45W /  60W |    616MiB / 15356MiB |     24%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

### 4. Format and save forecasts
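Let's save our forecasts as netCDF files before we move to evaluate them in the next notebook.

**TODO:** no code cell in this notebook writes the netCDF files yet — `y_hat_unet` and `y_hat_gan` currently exist only in kernel memory.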

This will allow us to extend our downstream evaluation methods without needing to load and run our computationally-intensive models each time we want to obtain a forecast.

This will also make our forecasts accessible to others who may be interested in sea ice but not in the technical nuances of deep learning.

### 5. All set
With our forecasts safe and sound inside netCDF files, we're clear to move forward.

We'll continue by evaluating these forecasts in the next notebook.