A PyTorch UNet diffusion implementation that uses the IceNet library for data download and for post-processing of sea ice forecasts.

This notebook has been designed to be independent of other notebooks.

### Highlights
The key features of this notebook are:
* [1. Download](#1.-Download) 
* [2. Data Processing](#2.-Data-Processing)
* [3. Train](#3.-Train)
* [4. Prediction](#4.-Prediction)
* [5. Outputs and Plotting](#5.-Outputs-and-Plotting)

Please note that this notebook relies on a PyTorch data loader implementation that is only available from icenet v0.2.8 onwards.

To install the necessary Python packages, use the conda environment file `icenet-notebooks/pytorch/environment.yml` on a Linux system. It sets up PyTorch, TensorFlow, CUDA and the other required modules, a combination that can be tricky to get working manually:

```bash
conda env create -f environment.yml
```

### Contributions
#### PyTorch implementation of UnetDiffusion
Maria Carolina Novitasari

#### PyTorch implementation of IceNet

Andrew McDonald ([icenet-gan](https://github.com/ampersandmcd/icenet-gan))

Bryn Noel Ubald (Refactor, updates for daily predictions and matching icenet library)

#### Notebook
Bryn Noel Ubald (author)

#### PyTorch Integration
Bryn Noel Ubald

Ryan Chan

### How to Download Daily Data for IceNet

#### Download SIC Data

To download Sea Ice Concentration (SIC) data, modify the script below with the desired date range:

```python
import pandas as pd
from icenet.data.sic.osisaf import SICDownloader

sic = SICDownloader(
    dates=[
        pd.to_datetime(date).date()  # Dates to download SIC data for
        for date in pd.date_range("2020-01-01", "2020-12-31", freq="D")
    ],
    delete_tempfiles=True,           # Delete temporary downloaded files after use
    north=False,                     # Use mask for the Northern Hemisphere (set to True if needed)
    south=True,                      # Use mask for the Southern Hemisphere
    parallel_opens=True,             # Enable parallel processing with dask.delayed
)

sic.download()
```

#### Download ERA5 Data  

##### Setup ERA5 API

Use the following link to set up the ERA5 API: [https://cds.climate.copernicus.eu/how-to-api?](https://cds.climate.copernicus.eu/how-to-api?).

##### ERA5 Downloader

Run the following script with your desired dates:

```python
import pandas as pd
from icenet.data.interfaces.cds import ERA5Downloader

era5 = ERA5Downloader(
    var_names=["tas", "zg", "uas", "vas"],      # Name of variables to download
    dates=[                                     # Dates to download the variable data for
        pd.to_datetime(date).date()
        for date in pd.date_range("2020-01-01", "2020-12-31", freq="D")
    ],
    path="./data",                              # Location to download data to (default is `./data`)
    delete_tempfiles=True,                      # Whether to delete temporary downloaded files
    levels=[None, [250, 500], None, None],      # Pressure levels per variable, in the same order as var_names (e.g. for zg); None where no levels apply
    max_threads=4,                              # Maximum number of concurrent downloads
    north=False,                                # Whether to download data for the northern hemisphere
    south=True,                                 # Whether to download data for the southern hemisphere
    use_toolbox=False)                          # Experimental, alternative download method

era5.download()                                 # Start downloading
```
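
The `levels` list is read position by position alongside `var_names`: the entry at each index applies to the variable at the same index. The sketch below illustrates that pairing and the level-suffixed names (`zg250`, `zg500`) that the preprocessing step later in this notebook refers to. `expand_var_names` is a hypothetical helper written for illustration and is not part of the icenet API.

```python
from typing import List, Optional, Sequence


def expand_var_names(
    var_names: Sequence[str],
    levels: Sequence[Optional[Sequence[int]]],
) -> List[str]:
    """Pair each variable with its levels and return level-suffixed names.

    Hypothetical helper for illustration only; not part of icenet.
    """
    if len(var_names) != len(levels):
        raise ValueError("var_names and levels must have the same length")

    expanded = []
    for name, var_levels in zip(var_names, levels):
        if var_levels is None:
            # No pressure levels requested: keep the name as-is
            expanded.append(name)
        else:
            # One entry per requested pressure level, e.g. zg -> zg250, zg500
            expanded.extend(f"{name}{level}" for level in var_levels)
    return expanded


names = expand_var_names(
    ["tas", "zg", "uas", "vas"],
    [None, [250, 500], None, None],
)
print(names)
```

Raising on a length mismatch catches the most likely mistake, which is adding a variable without adding its matching `levels` entry.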

The prototype data currently in use (South Pole, 2020) can be downloaded from **Baskerville** at the following path: `/vjgo8416-ice-frcst/shared/prototype_data/`

In [None]:
import os
import numpy as np
import pandas as pd

import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Tuple
from torchmetrics import Metric
import lightning.pytorch as pl
from lightning.pytorch.utilities.types import TRAIN_DATALOADERS
from torchmetrics import MetricCollection
from torch_ema import ExponentialMovingAverage
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

# We also set the logging level so that we get some feedback from the API
import logging
logging.basicConfig(level=logging.INFO)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.set_float32_matmul_precision('medium')

In [None]:
import sys, os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
print(os.getcwd())

from models import GaussianDiffusion, UNetDiffusion, LitDiffusion
from trainers import train_diffusion_icenet

In [None]:
from datetime import datetime
import sys

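timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
# Redirect stdout to a timestamped log file, so later print() output goes there.
# Create the log directory first so that open() does not fail if it is missing.
os.makedirs('training_logs', exist_ok=True)
sys.stdout = open(f'training_logs/training_log_{timestamp}.txt', 'w')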

## 1. Download

In [None]:
import numpy
from icenet.data.sic.mask import Masks
from icenet.data.interfaces.cds import ERA5Downloader
from icenet.data.sic.osisaf import SICDownloader

In [None]:
# Unset SLURM_NTASKS if it's causing issues
if "SLURM_NTASKS" in os.environ:
    del os.environ["SLURM_NTASKS"]

# Optionally, set SLURM_NTASKS_PER_NODE if needed
os.environ["SLURM_NTASKS_PER_NODE"] = "1"  # or whatever value is appropriate

### Mask data

Generate the masks used when processing the data.

In [None]:
masks = Masks(north=False, south=True)
masks.generate(save_polarhole_masks=False)

### Climate and Sea Ice data

Download climate variables from ERA5 and sea ice concentration from OSI-SAF.

In [None]:
era5 = ERA5Downloader(
    var_names=["tas", "zg", "uas", "vas"],
    levels=[None, [250, 500], None, None],
    dates=[pd.to_datetime(date).date() for date in
           pd.date_range("2020-01-01", "2020-04-30", freq="D")],
    delete_tempfiles=False,
    max_threads=64,
    north=False,
    south=True,
    # NOTE: there appears to be a bug with the toolbox API at present (icenet#54)
    use_toolbox=False
)

# era5.download()

In [None]:
sic = SICDownloader(
    dates=[pd.to_datetime(date).date() for date in
           pd.date_range("2020-01-01", "2020-04-30", freq="D")],
    delete_tempfiles=False,
    north=False,
    south=True,
    parallel_opens=False,
)

# sic.download()

Re-grid ERA5 reanalysis data, and rotate wind vector data from ERA5 to align with EASE2 projection.

In [None]:
era5.regrid()
era5.rotate_wind_data()

## 2. Data Processing

Process downloaded datasets.

To keep things simple, we first define the train, validation and test dates.

In [None]:
processing_dates = dict(
    train=[pd.to_datetime(el) for el in pd.date_range("2020-01-01", "2020-03-31")],
    val=[pd.to_datetime(el) for el in pd.date_range("2020-04-03", "2020-04-23")],
    test=[pd.to_datetime(el) for el in pd.date_range("2020-04-01", "2020-04-02")],
)
processed_name = "notebook_api_pytorch_data"
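
Because the three date ranges are typed in by hand, it can be worth confirming that they do not overlap before processing. The following is a minimal sketch using only the standard library. `daily_range` and `find_overlaps` are hypothetical helpers, and the dates simply mirror the ranges above.

```python
from datetime import date, timedelta
from itertools import combinations
from typing import Dict, List, Set, Tuple


def daily_range(start: date, end: date) -> List[date]:
    """Return every day from start to end, inclusive."""
    return [start + timedelta(days=i) for i in range((end - start).days + 1)]


def find_overlaps(splits: Dict[str, List[date]]) -> Dict[Tuple[str, str], Set[date]]:
    """Return the dates shared by any two splits, keyed by the pair of split names."""
    overlaps = {}
    for (name_a, dates_a), (name_b, dates_b) in combinations(splits.items(), 2):
        shared = set(dates_a) & set(dates_b)
        if shared:
            overlaps[(name_a, name_b)] = shared
    return overlaps


splits = {
    "train": daily_range(date(2020, 1, 1), date(2020, 3, 31)),
    "val": daily_range(date(2020, 4, 3), date(2020, 4, 23)),
    "test": daily_range(date(2020, 4, 1), date(2020, 4, 2)),
}

for name, dates in splits.items():
    print(f"{name}: {len(dates)} days ({dates[0]} to {dates[-1]})")

overlaps = find_overlaps(splits)
print("Overlapping splits:", overlaps if overlaps else "none")
```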

Next, we create the data preprocessors and configure them for the dataset we want to create.

In [None]:
from icenet.data.processors.era5 import IceNetERA5PreProcessor
from icenet.data.processors.meta import IceNetMetaPreProcessor
from icenet.data.processors.osi import IceNetOSIPreProcessor

pp = IceNetERA5PreProcessor(
    ["uas", "vas"],
    ["tas", "zg500", "zg250"],
    processed_name,
    processing_dates["train"],
    processing_dates["val"],
    processing_dates["test"],
    linear_trends=tuple(),
    north=False,
    south=True
)

osi = IceNetOSIPreProcessor(
    ["siconca"],
    [],
    processed_name,
    processing_dates["train"],
    processing_dates["val"],
    processing_dates["test"],
    linear_trends=tuple(),
    north=False,
    south=True
)

meta = IceNetMetaPreProcessor(
    processed_name,
    north=False,
    south=True
)

Next, we initialise the data processors using `init_source_data`, which scans the data source directories to work out what data is available for processing based on the parameters. Since we named the processed data `"notebook_api_pytorch_data"` above, it will create a data loader config file, `loader.notebook_api_pytorch_data.json`, in the current directory.

In [None]:
# Causes hanging on training, when generating sample.
pp.init_source_data(
    lag_days=1,
)
pp.process()

osi.init_source_data(
    lag_days=1,
)
osi.process()

meta.process()

At this point, the preprocessed data is ready to be converted into a network dataset, or to have a configuration for the network dataset created from it.

### Dataset creation

As with the `icenet_dataset_create` command, we can create a dataset configuration for training the network. This can include cached data for the network in the form of a TFRecordDataset-compatible set of tfrecords. To achieve this, we create the `IceNetDataLoader`, which can generate both `IceNetDataSet` configurations (which provide the necessary functionality for training and prediction) and individual data samples for direct usage.

In [None]:
from icenet.data.loaders import IceNetDataLoaderFactory

implementation = "dask"
loader_config = "loader.notebook_api_pytorch_data.json"
dataset_name = "notebook_api_pytorch_data"
lag = 1

dl = IceNetDataLoaderFactory().create_data_loader(
    implementation,
    loader_config,
    dataset_name,
    lag,
    n_forecast_days=7,
    north=False,
    south=True,
    output_batch_size=1,
    generate_workers=4)

At this point we can either use `generate` or `write_dataset_config_only` to produce a ready-to-go `IceNetDataSet` configuration. Both of these will generate a dataset config, `dataset_config.notebook_api_pytorch_data.json` (recall we set the dataset name as `notebook_api_pytorch_data` above).

In this case, for PyTorch, we read the data in directly rather than using cached tfrecords inputs.
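
Both config files follow the naming pattern described above, so their presence can be checked before moving on. The following is a small sketch, assuming the files are written to the current working directory. `expected_config_paths` is a hypothetical helper, not an icenet function.

```python
from pathlib import Path
from typing import Dict


def expected_config_paths(name: str, directory: str = ".") -> Dict[str, Path]:
    """Build the loader and dataset config paths for a given dataset name."""
    base = Path(directory)
    return {
        "loader": base / f"loader.{name}.json",
        "dataset": base / f"dataset_config.{name}.json",
    }


paths = expected_config_paths("notebook_api_pytorch_data")
for kind, path in paths.items():
    # Report whether each config file has been written yet
    status = "found" if path.is_file() else "missing"
    print(f"{kind}: {path} ({status})")
```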

In [None]:
dl.write_dataset_config_only()

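We can now import the PyTorch dataset class and set the path to the dataset config. The `IceNetDataSetPyTorch` object itself is created from this config in the Prediction section: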

In [None]:
from icenet.data.dataset import IceNetDataSetPyTorch
dataset_config = f"dataset_config.{dataset_name}.json"

## 3. Train

We implement a custom PyTorch class for training.

### IceNet2 U-Net Diffusion model

Maria Carolina Novitasari's PyTorch implementation of diffusion using a U-Net.

In [None]:
class Interpolate(nn.Module):
    def __init__(self, scale_factor, mode):
        super().__init__()
        self.interp = F.interpolate
        self.scale_factor = scale_factor
        self.mode = mode

    def forward(self, x):
        x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode)
        return x

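The code for the following components is not defined inline in this notebook; `LitDiffusion` and `train_diffusion_icenet` are imported from the local `models` and `trainers` modules near the top of the notebook:

* Custom metrics for use in validation and monitoring
* Custom loss functions
* `LitDiffusion`, a _LightningModule_ wrapper for the UNetDiffusion model
* `train_diffusion_icenet`, a function for training the UNetDiffusion model using PyTorch Lightning

Conduct the actual training run.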

In [None]:
seed = 45

# Training configuration
learning_rate = 1e-5  # alternative: 1e-4
max_epochs = 1  # alternative: 500
filter_size = 3
n_filters_factor = 0.5
timesteps = 1000

batch_size = 8  # alternatives: 4, 16, 32, 64
shuffle = False
persistent_workers = True
num_workers = 8

print("batch_size...", batch_size)
print("num_workers...", num_workers)

# Print all training parameters
print("Training parameters:")
print(f"  configuration_path: {dataset_config}")
print(f"  learning_rate: {learning_rate}")
print(f"  max_epochs: {max_epochs}")
print(f"  batch_size: {batch_size}")
print(f"  n_workers: {num_workers}")
print(f"  filter_size: {filter_size}")
print(f"  n_filters_factor: {n_filters_factor}")
print(f"  seed: {seed}")
print(f"  timesteps: {timesteps}")

# Call the training function
model, trainer, checkpoint_callback = train_diffusion_icenet(
    configuration_path=dataset_config,
    learning_rate=learning_rate,
    max_epochs=max_epochs,
    batch_size=batch_size,
    n_workers=num_workers,
    filter_size=filter_size,
    n_filters_factor=n_filters_factor,
    seed=seed,
    timesteps=timesteps,
    persistent_workers=persistent_workers
) 
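
The `timesteps` argument is presumably the number of noising steps in the diffusion process. As a rough illustration of what a noise schedule over that many steps can look like, the sketch below computes a cosine beta schedule in plain Python. This is an assumption made for illustration: the schedule actually used by `GaussianDiffusion` is not shown in this notebook, and `cosine_beta_schedule` with its `s` and `max_beta` parameters is a hypothetical helper.

```python
import math
from typing import List


def cosine_beta_schedule(
    timesteps: int, s: float = 0.008, max_beta: float = 0.999
) -> List[float]:
    """Return per-step noise variances (betas) following a cosine schedule."""

    def alpha_bar(t: int) -> float:
        # Unnormalised fraction of the original signal retained after t steps
        return math.cos((t / timesteps + s) / (1 + s) * math.pi / 2) ** 2

    betas = []
    for t in range(timesteps):
        # beta_t is the share of the remaining signal lost at step t
        beta = 1 - alpha_bar(t + 1) / alpha_bar(t)
        # Clip so the final steps never remove the signal entirely
        betas.append(min(beta, max_beta))
    return betas


betas = cosine_beta_schedule(1000)

# Cumulative signal retained halfway through the schedule
signal = 1.0
for beta in betas[:500]:
    signal *= 1 - beta

print(f"first beta: {betas[0]:.2e}, last beta: {betas[-1]:.3f}")
print(f"signal retained after 500 steps: {signal:.3f}")
```

With a schedule like this, noise is added slowly at first and faster towards the end, so roughly half of the signal is still present at the midpoint.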


## 4. Prediction

Predict using the best checkpoint from training.

In [None]:
checkpoint_callback.best_k_models

In [None]:
best_checkpoint = checkpoint_callback.best_model_path
best_checkpoint

In [None]:
# Load the best result from the checkpoint
best_model = LitDiffusion.load_from_checkpoint(best_checkpoint)

# disable randomness, dropout, etc...
best_model.eval()

In [None]:
test_dataset = IceNetDataSetPyTorch(configuration_path=dataset_config, mode="test")
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers,
                             persistent_workers=persistent_workers, shuffle=False)

# automatically load the best weights (if best_model isn't added)
trainer.test(dataloaders=test_dataloader)

In [None]:
# # cosine results
# [{'test_loss': 0.48640450835227966,
#   'test_accuracy': 51.62199783325195,
#   'test_accuracy_0': 51.56485366821289,
#   'test_accuracy_1': 51.77238464355469,
#   'test_accuracy_2': 52.235984802246094,
#   'test_accuracy_3': 51.62677764892578,
#   'test_accuracy_4': 51.081172943115234,
#   'test_accuracy_5': 51.31882858276367,
#   'test_accuracy_6': 51.75397872924805,
#   'test_sieerror': 729145600.0,
#   'test_sieerror_0': 104656872.0,
#   'test_sieerror_1': 104813752.0,
#   'test_sieerror_2': 104287504.0,
#   'test_sieerror_3': 104360624.0,
#   'test_sieerror_4': 104103128.0,
#   'test_sieerror_5': 103494376.0,
#   'test_sieerror_6': 103429376.0}]

In [None]:
logging.info("Generating predictions")

predictions = trainer.predict(best_model, dataloaders=test_dataloader)

In [None]:
# trainer.predict returns one entry per batch from the dataloader
for batch_idx, prediction in enumerate(predictions):
    print(f"Batch: {batch_idx} | Prediction: {prediction.shape}")

## 5. Outputs and Plotting

Create prediction output directory

In [None]:
# dataset = "pytorch_notebook"
network_name = "api_pytorch_dataset"
output_name = "example_pytorch_forecast_diff"
output_folder = os.path.join(".", "results", "predict", output_name,
                             "{}.{}".format(network_name, seed))
os.makedirs(output_folder, exist_ok=True)

Convert and output predictions to numpy files

In [None]:
idx = 0
# Each entry in `predictions` holds one batch; save every sample in it to its own file
for prediction in predictions:
    print("prediction shape...", prediction.shape)
    for sample in range(prediction.shape[0]):
        # Replace underscores with hyphens so pandas can parse the date string
        date = pd.Timestamp(test_dataset.dates[idx].replace('_', '-'))
        output_path = os.path.join(output_folder, date.strftime("%Y_%m_%d.npy"))
        # Move the last axis of the sample to the front before saving
        forecast = prediction[sample, :, :, :].movedim(-1, 0)
        forecast_np = forecast.detach().cpu().numpy()
        np.save(output_path, forecast_np)
        idx += 1

In [None]:
forecast.shape

Create a CSV file listing all the test dates we have predicted for. It is used to generate the final netCDF output with `icenet_output`.

In [None]:
# !printf "2020-04-01\n2020-04-02" | tee testdates_diff.csv
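
If shell magics are unavailable, the same file can be written from Python. The following is a minimal sketch using only the standard library. `write_test_dates` is a hypothetical helper, and the one-date-per-line `YYYY-MM-DD` layout simply mirrors the commented-out `printf` command above.

```python
from datetime import date, timedelta
from pathlib import Path
from typing import List


def write_test_dates(path: str, start: date, end: date) -> List[str]:
    """Write one ISO-formatted date per line, from start to end inclusive."""
    n_days = (end - start).days + 1
    lines = [(start + timedelta(days=i)).isoformat() for i in range(n_days)]
    Path(path).write_text("\n".join(lines) + "\n")
    return lines


written = write_test_dates("testdates_diff.csv", date(2020, 4, 1), date(2020, 4, 2))
print(written)
```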

In [None]:
# !icenet_output -m -o results/predict example_pytorch_forecast_diff notebook_api_pytorch_data testdates_diff.csv

In [None]:
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
import matplotlib.pyplot as plt
import os
from datetime import datetime

# Change this to the actual version dir
log_dir = "lightning_logs/version_1065922"

# Load the logs
event_acc = EventAccumulator(log_dir)
event_acc.Reload()

# List all scalar tags to find the correct name
print("Available tags:", event_acc.Tags()['scalars'])

# Get the scalar events for val_loss
val_loss_events = event_acc.Scalars('val_loss')

# FIX: Use index as epoch number instead of .step
steps = list(range(1, len(val_loss_events) + 1))
values = [e.value for e in val_loss_events]

# Plot
plt.figure(figsize=(8, 5))
plt.plot(steps, values, label='Validation Loss', color='blue')
plt.xlabel("Epoch")
plt.ylabel("Validation Loss")
plt.title("Validation Loss Over Training")
plt.legend()
plt.grid(True)
plt.show()


In [None]:
# Update with your actual path
log_dir = "lightning_logs/version_1065922"

event_acc = EventAccumulator(log_dir)
event_acc.Reload()

# List all available scalar tags
print("Available scalar tags:")
print(event_acc.Tags()['scalars'])  # e.g. val_accuracy, val_accuracy_0, etc.

# Get accuracy for all lead times (overall accuracy)
accuracy_events = event_acc.Scalars('val_accuracy')

# FIX: Use index as epoch number instead of .step
steps = list(range(1, len(accuracy_events) + 1))
values = [e.value for e in accuracy_events]

# Plot
plt.figure(figsize=(8, 5))
plt.plot(steps, values, label='Validation Accuracy', color='green')
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.title("Validation Accuracy Over Epochs")
plt.legend()
plt.grid(True)
plt.show()


Plotting the forecast

In [None]:
import xarray as xr
import datetime as dt
from IPython.display import HTML

In [None]:
# from icenet.plotting.video import xarray_to_video as xvid
# from icenet.data.sic.mask import Masks

# ds = xr.open_dataset("results/predict/example_pytorch_forecast_diff.nc")
# land_mask = Masks(south=True, north=False).get_land_mask()
# ds.info()

Animate result

In [None]:
# forecast_date = ds.time.values[0]
# fc = ds.sic_mean.isel(time=0).drop_vars("time").rename(dict(leadtime="time"))
# fc['time'] = [pd.to_datetime(forecast_date) \
#               + dt.timedelta(days=int(e)) for e in fc.time.values]

# anim = xvid(fc, 15, figsize=(4,4), mask=land_mask)
# HTML(anim.to_jshtml())

Check min/max of predicted SIC fraction

In [None]:
# print( forecast_np[:, :, :, 0].shape )
# fmin, fmax = np.min(forecast_np[:, :, :, 0]), np.max(forecast_np[:, :, :, 0])
# print( f"First forecast day min: {fmin:.4f}, max: {fmax:.4f}" )

#### Load original input dataset

This is the original input dataset (pre-normalisation) for comparison.

In [None]:
# # Load original input dataset (domain not normalised)
# xr.plot.contourf(xr.open_dataset("data/osisaf/south/siconca/2020.nc").isel(time=92).ice_conc, levels=50)

## Version
- IceNet Codebase: v0.2.8