PyTorch U-Net implementation for sea ice forecasting, using the IceNet library for data download and post-processing.

This notebook has been designed to be independent of other notebooks.

### Highlights
The key features of this notebook are:
* [1. Download](#1.-Download) 
* [2. Data Processing](#2.-Data-Processing)
* [3. Train](#3.-Train)
* [4. Prediction](#4.-Prediction)
* [5. Outputs and Plotting](#5.-Outputs-and-Plotting)

Please note that this notebook relies on a PyTorch data loader implementation that is only available in icenet v0.2.8+.

To install the necessary Python packages on a Linux system, you can use the conda environment file `icenet-notebooks/pytorch/environment.yml`. It sets up PyTorch, TensorFlow, CUDA and the other required modules, a combination that can be tricky to get working manually:

```bash
conda env create -f environment.yml
```

### Contributions
#### PyTorch implementation of UNetDiffusion
Maria Carolina Novitasari

#### PyTorch implementation of IceNet

Andrew McDonald ([icenet-gan](https://github.com/ampersandmcd/icenet-gan))

Bryn Noel Ubald (Refactor, updates for daily predictions and matching icenet library)

#### Notebook
Bryn Noel Ubald (author)

#### PyTorch Integration
Bryn Noel Ubald

Ryan Chan

### How to Download Daily Data for IceNet

#### Download SIC Data

To download Sea Ice Concentration (SIC) data, modify the script below with the desired date range:

```python
import pandas as pd
from icenet.data.sic.osisaf import SICDownloader

sic = SICDownloader(
    dates=[
        pd.to_datetime(date).date()  # Dates to download SIC data for
        for date in pd.date_range("2020-01-01", "2020-12-31", freq="D")
    ],
    delete_tempfiles=True,           # Delete temporary downloaded files after use
    north=False,                     # Whether to download Northern Hemisphere data (set to True if needed)
    south=True,                      # Whether to download Southern Hemisphere data
    parallel_opens=True,             # Open files in parallel (via dask.delayed)
)

sic.download()
```

#### Download ERA5 Data  

##### Set up the ERA5 API

Follow the instructions at [https://cds.climate.copernicus.eu/how-to-api?](https://cds.climate.copernicus.eu/how-to-api?) to set up access to the Copernicus Climate Data Store (CDS) API, which is used to download ERA5 data.

##### ERA5 Downloader

Run the following script, adjusting the dates as needed:

```python
import pandas as pd
from icenet.data.interfaces.cds import ERA5Downloader

era5 = ERA5Downloader(
    var_names=["tas", "zg", "uas", "vas"],      # Name of variables to download
    dates=[                                     # Dates to download the variable data for
        pd.to_datetime(date).date()
        for date in pd.date_range("2020-01-01", "2020-12-31", freq="D")
    ],
    path="./data",                              # Location to download data to (default is `./data`)
    delete_tempfiles=True,                      # Whether to delete temporary downloaded files
    levels=[None, [250, 500], None, None],      # Pressure levels (hPa) per variable, aligned with var_names; None for single-level variables
    max_threads=4,                              # Maximum number of concurrent downloads
    north=False,                                # Whether to download Northern Hemisphere data
    south=True,                                 # Whether to download Southern Hemisphere data
    use_toolbox=False)                          # Experimental alternative download method

era5.download()                                 # Start downloading
```
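The `levels` list is read position by position against `var_names`: `None` marks a single-level variable, while a list requests that variable at each listed pressure level. The processing step later refers to the geopotential fields as `zg500` and `zg250`, i.e. the variable name with the level appended. The plain-Python sketch below illustrates that pairing; `expand_variable_names` is a hypothetical helper for illustration only, not an icenet function.

```python
# Hypothetical helper, not part of icenet: pairs each entry in `var_names`
# with the entry at the same position in `levels`.
def expand_variable_names(var_names, levels):
    if len(var_names) != len(levels):
        raise ValueError("var_names and levels must have the same length")
    names = []
    for var, var_levels in zip(var_names, levels):
        if var_levels is None:
            # Single-level variable (e.g. "tas"): the name is unchanged
            names.append(var)
        else:
            # Pressure-level variable (e.g. "zg"): one name per requested level
            names.extend(f"{var}{level}" for level in var_levels)
    return names


print(expand_variable_names(["tas", "zg", "uas", "vas"], [None, [250, 500], None, None]))
# ['tas', 'zg250', 'zg500', 'uas', 'vas']
```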

The prototype data currently in use (Southern Hemisphere, 2020) can be downloaded from **Baskerville** at the following path: `/vjgo8416-ice-frcst/shared/prototype_data/`

In [None]:
import os
import numpy as np
import pandas as pd

import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Tuple
from torchmetrics import Metric
import lightning.pytorch as pl
from lightning.pytorch.utilities.types import TRAIN_DATALOADERS
from torchmetrics import MetricCollection

# We also set the logging level so that we get some feedback from the API
import logging
logging.basicConfig(level=logging.INFO)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.set_float32_matmul_precision('medium')

In [None]:
from datetime import datetime
import sys

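timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
# Redirect all print() output from here on to a timestamped log file; it will no
# longer appear in the notebook. Restore it with `sys.stdout = original_stdout`.
original_stdout = sys.stdout
sys.stdout = open(f'training_log_{timestamp}.txt', 'w')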

## 1. Download

In [None]:
import numpy
from icenet.data.sic.mask import Masks
from icenet.data.interfaces.cds import ERA5Downloader
from icenet.data.sic.osisaf import SICDownloader

In [None]:
# Unset SLURM_NTASKS if it's causing issues
if "SLURM_NTASKS" in os.environ:
    del os.environ["SLURM_NTASKS"]

# Optionally, set SLURM_NTASKS_PER_NODE if needed
os.environ["SLURM_NTASKS_PER_NODE"] = "1"  # or whatever value is appropriate

### Mask data

Generate the masks (e.g. land and active grid cell masks) used to mask the data.

In [None]:
masks = Masks(north=False, south=True)
masks.generate(save_polarhole_masks=False)

### Climate and Sea Ice data

Download climate variables from ERA5 and sea ice concentration from OSI-SAF.

In [None]:
era5 = ERA5Downloader(
    var_names=["tas", "zg", "uas", "vas"],
    levels=[None, [250, 500], None, None],
    dates=[pd.to_datetime(date).date() for date in
           pd.date_range("2020-01-01", "2020-04-30", freq="D")],
    delete_tempfiles=False,
    max_threads=64,
    north=False,
    south=True,
    # NOTE: there appears to be a bug with the toolbox API at present (icenet#54)
    use_toolbox=False
)

# era5.download()

In [None]:
sic = SICDownloader(
    dates=[pd.to_datetime(date).date() for date in
           pd.date_range("2020-01-01", "2020-04-30", freq="D")],
    delete_tempfiles=False,
    north=False,
    south=True,
    parallel_opens=False,
)

# sic.download()

Regrid the ERA5 reanalysis data to the EASE-Grid 2.0 (EASE2) projection used by the sea ice data, and rotate the ERA5 wind vectors to align with that projection.

In [None]:
era5.regrid()
era5.rotate_wind_data()
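Rotating the wind data re-expresses each (u, v) wind vector in the axes of the target grid, which at most locations point in different directions from the ERA5 latitude/longitude axes. The core operation is a plain 2D rotation. The sketch below shows it for a single vector and a known angle; it illustrates the idea only and is not icenet's `rotate_wind_data` implementation, which works out the rotation angles from the grids themselves.

```python
import math


def rotate_vector(u, v, angle_rad):
    """Rotate the 2D vector (u, v) counter-clockwise by angle_rad radians."""
    cos_a, sin_a = math.cos(angle_rad), math.sin(angle_rad)
    return u * cos_a - v * sin_a, u * sin_a + v * cos_a


# A 5 m/s wind along the first axis, rotated by 90 degrees, lies along the second axis
u_rot, v_rot = rotate_vector(5.0, 0.0, math.pi / 2)
print(round(u_rot, 6), round(v_rot, 6))

# Rotation changes the components but preserves the wind speed
u2, v2 = rotate_vector(3.0, 4.0, math.radians(30))
print(round(math.hypot(u2, v2), 6))  # 5.0
```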

## 2. Data Processing

Process the downloaded datasets.

To keep things simple, we first define the training, validation and test dates.

In [None]:
processing_dates = dict(
    train=[pd.to_datetime(el) for el in pd.date_range("2020-01-01", "2020-03-31")],
    val=[pd.to_datetime(el) for el in pd.date_range("2020-04-03", "2020-04-23")],
    test=[pd.to_datetime(el) for el in pd.date_range("2020-04-01", "2020-04-02")],
)
processed_name = "notebook_api_pytorch_data"
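A quick sanity check on these splits can catch typos in the date ranges. The plain-Python sketch below rebuilds the same inclusive daily ranges with `datetime` (standing in for `pd.date_range`), prints the size of each split and asserts that no date is shared between splits. Note that the test days (1 to 2 April) fall before the validation days (3 to 23 April).

```python
from datetime import date, timedelta


def date_range(start, end):
    """Inclusive list of daily dates, mirroring pd.date_range(start, end)."""
    n_days = (end - start).days + 1
    return [start + timedelta(days=i) for i in range(n_days)]


splits = {
    "train": date_range(date(2020, 1, 1), date(2020, 3, 31)),
    "val": date_range(date(2020, 4, 3), date(2020, 4, 23)),
    "test": date_range(date(2020, 4, 1), date(2020, 4, 2)),
}

for name, dates in splits.items():
    print(name, len(dates), dates[0], dates[-1])

# Check that no date appears in more than one split
names = list(splits)
for i, a in enumerate(names):
    for b in names[i + 1:]:
        overlap = set(splits[a]) & set(splits[b])
        assert not overlap, f"{a} and {b} overlap on {sorted(overlap)}"
```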

Next, we create the data processors and configure them for the dataset we want to create.

In [None]:
from icenet.data.processors.era5 import IceNetERA5PreProcessor
from icenet.data.processors.meta import IceNetMetaPreProcessor
from icenet.data.processors.osi import IceNetOSIPreProcessor

pp = IceNetERA5PreProcessor(
    ["uas", "vas"],
    ["tas", "zg500", "zg250"],
    processed_name,
    processing_dates["train"],
    processing_dates["val"],
    processing_dates["test"],
    linear_trends=tuple(),
    north=False,
    south=True
)

osi = IceNetOSIPreProcessor(
    ["siconca"],
    [],
    processed_name,
    processing_dates["train"],
    processing_dates["val"],
    processing_dates["test"],
    linear_trends=tuple(),
    north=False,
    south=True
)

meta = IceNetMetaPreProcessor(
    processed_name,
    north=False,
    south=True
)

Next, we initialise the data processors using `init_source_data`, which scans the data source directories to determine what data is available for processing based on the parameters. Since we named the processed data `"notebook_api_pytorch_data"` above, it will create a data loader config file, `loader.notebook_api_pytorch_data.json`, in the current directory.

In [None]:
# Causes hanging on training, when generating sample.
pp.init_source_data(
    lag_days=1,
)
pp.process()

osi.init_source_data(
    lag_days=1,
)
osi.process()

meta.process()

At this point, the preprocessed data is ready to be converted into, or used to create a configuration for, the network dataset.

### Dataset creation

As with the `icenet_dataset_create` command, we can create a dataset configuration for training the network. As before, this can include cached data for the network in the form of a set of TFRecordDataset-compatible tfrecords. To achieve this, we create an `IceNetDataLoader`, which can generate both `IceNetDataSet` configurations (which provide the necessary functionality for training and prediction) and individual data samples for direct use.

In [None]:
from icenet.data.loaders import IceNetDataLoaderFactory

implementation = "dask"
loader_config = "loader.notebook_api_pytorch_data.json"
dataset_name = "notebook_api_pytorch_data"
lag = 1

dl = IceNetDataLoaderFactory().create_data_loader(
    implementation,
    loader_config,
    dataset_name,
    lag,
    n_forecast_days=7,
    north=False,
    south=True,
    output_batch_size=1,
    generate_workers=4)
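To make `lag` and `n_forecast_days` concrete, the plain-Python sketch below lists the dates that make up one sample for a given initialisation date. It rests on an assumption about date alignment (that the `lag` input days end on the initialisation date and the `n_forecast_days` targets are the days that follow it); check this against the icenet loader before relying on it. `sample_dates` is a hypothetical helper, not part of icenet.

```python
from datetime import date, timedelta


def sample_dates(init_date, lag, n_forecast_days):
    """Hypothetical illustration of the input and target dates for one sample.

    ASSUMPTION: the `lag` input days end on `init_date`, and the targets are
    the `n_forecast_days` days that follow it.
    """
    inputs = [init_date - timedelta(days=i) for i in reversed(range(lag))]
    targets = [init_date + timedelta(days=i) for i in range(1, n_forecast_days + 1)]
    return inputs, targets


inputs, targets = sample_dates(date(2020, 4, 1), lag=1, n_forecast_days=7)
print("inputs: ", [d.isoformat() for d in inputs])
print("targets:", targets[0].isoformat(), "to", targets[-1].isoformat())
```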

At this point we can either use `generate` or `write_dataset_config_only` to produce a ready-to-go `IceNetDataSet` configuration. Both of these will generate a dataset config, `dataset_config.notebook_api_pytorch_data.json` (recall we set the dataset name as `notebook_api_pytorch_data` above).

In this case, since we are using PyTorch, the data will be read in directly rather than from cached tfrecord inputs, so we only need to write the dataset config.

In [None]:
dl.write_dataset_config_only()

Next, we import the `IceNetDataSetPyTorch` class and set the path to the dataset config generated above:

In [None]:
from icenet.data.dataset import IceNetDataSetPyTorch
dataset_config = f"dataset_config.{dataset_name}.json"

In [None]:
batch_size = 16
shuffle = False
persistent_workers = True
num_workers = 8

## 3. Train

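We implement custom PyTorch classes for training: the diffusion model, custom metrics and loss functions, and a PyTorch Lightning module that wraps them.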

### IceNet2 U-Net Diffusion model

PyTorch diffusion model with a U-Net backbone, by Maria Carolina Novitasari.
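For reference, `GaussianDiffusion` below implements the standard DDPM equations (Ho et al., 2020). With $\alpha_t = 1 - \beta_t$ and $\bar\alpha_t = \prod_{i=1}^{t} \alpha_i$, the forward process (`q_sample`) and one reverse step (`p_sample`) are:

$$
x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)
$$

$$
x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t, t)\right) + \sigma_t z, \qquad \sigma_t^2 = \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\,\beta_t
$$

where $z \sim \mathcal{N}(0, I)$ for all but the final step, at which $z = 0$. The cosine schedule sets $\bar\alpha_t = f(t)/f(0)$ with $f(t) = \cos^2\!\left(\frac{t/T + s}{1 + s}\cdot\frac{\pi}{2}\right)$ and $\beta_t = 1 - \bar\alpha_t/\bar\alpha_{t-1}$, clipped to at most 0.999.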

In [None]:
class Interpolate(nn.Module):
    def __init__(self, scale_factor, mode):
        super().__init__()
        self.interp = F.interpolate
        self.scale_factor = scale_factor
        self.mode = mode

    def forward(self, x):
        x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode)
        return x
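`Interpolate(scale_factor=2, mode='nearest')` is used in the decoder's `upconv_block`. With `mode='nearest'` and an integer scale factor of 2, each input pixel is simply copied into a 2x2 block. The plain-Python sketch below shows that behaviour on a single 2D channel, for illustration only; `F.interpolate` itself operates on batched `[B, C, H, W]` tensors.

```python
def upsample_nearest_2x(grid):
    """Nearest-neighbour 2x upsampling of a 2D list (rows of values)."""
    out = []
    for row in grid:
        # Repeat every value twice along the width
        wide = [value for value in row for _ in range(2)]
        # Repeat every widened row twice along the height
        out.append(wide)
        out.append(list(wide))
    return out


print(upsample_nearest_2x([[1, 2], [3, 4]]))
# [[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]]
```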

In [None]:
class GaussianDiffusion:
    """
    Implements the forward and reverse processes of a Denoising Diffusion Probabilistic Model (DDPM),
    including support for cosine and linear beta schedules.
    """
    
    def __init__(self, timesteps: int = 1000, beta_schedule: str = 'cosine'):
        """
        Initialize diffusion parameters and precompute useful constants.

        Args:
            timesteps (int): Total number of diffusion steps.
            beta_schedule (str): Type of beta schedule to use. Options: 'linear', 'cosine'.
        """
        self.timesteps = timesteps
        
        if beta_schedule == 'linear':
            self.betas = torch.linspace(1e-4, 0.02, timesteps)
        elif beta_schedule == 'cosine':
            self.betas = self._cosine_beta_schedule(timesteps)
        else:
            raise ValueError(f"Unknown beta_schedule '{beta_schedule}'; expected 'linear' or 'cosine'.")

        self.alphas = 1. - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)
        self.posterior_variance = self.betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
    
    # Note: Nichol & Dhariwal (2021) use s=0.008; this notebook uses s=0.015
    def _cosine_beta_schedule(self, timesteps, s=0.015):
        """
        Compute beta schedule using a cosine function.

        Args:
            timesteps (int): Total number of timesteps.
            s (float): Small offset to prevent singularities near 0.

        Returns:
            torch.Tensor: Beta values of shape (timesteps,).
        """
        steps = timesteps + 1
        x = torch.linspace(0, timesteps, steps)
        alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        
        return torch.clip(betas, 0, 0.999)
    
    def q_sample(self, x_start: torch.Tensor, t: torch.Tensor, noise: torch.Tensor = None) -> torch.Tensor:
        """
        Add noise to x_start at timestep t, using the forward diffusion process.

        Args:
            x_start (torch.Tensor): Original input tensor (clean image).
            t (torch.Tensor): Timesteps for each sample in the batch (shape: [B]).
            noise (torch.Tensor, optional): Noise to add. If None, standard Gaussian noise is used.

        Returns:
            torch.Tensor: Noisy sample at timestep t.
        """
        if noise is None:
            noise = torch.randn_like(x_start)
            
        sqrt_alphas_cumprod_t = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
        
        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
    
    def p_sample(self, x: torch.Tensor, t: torch.Tensor, pred_noise: torch.Tensor) -> torch.Tensor:
        """
        Perform a single reverse diffusion step.

        Args:
            x (torch.Tensor): Current noisy sample at timestep t.
            t (torch.Tensor): Timesteps for each sample in the batch (shape: [B]).
            pred_noise (torch.Tensor): Model's predicted noise (εθ) for x at timestep t.

        Returns:
            torch.Tensor: Sample from the previous timestep (x_{t-1}).
        """
        betas_t = self._extract(self.betas, t, x.shape)
        sqrt_one_minus_alphas_cumprod_t = self._extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape)
        sqrt_recip_alphas_t = self._extract(self.sqrt_recip_alphas, t, x.shape)
        
        # Equation 11 of the DDPM paper (Ho et al., 2020); pred_noise is εθ
        model_mean = sqrt_recip_alphas_t * (x - betas_t * pred_noise / sqrt_one_minus_alphas_cumprod_t)
        
        # Create mask for where t == 0
        nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        
        # Only add noise if t != 0
        posterior_variance_t = self._extract(self.posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        
        return model_mean + nonzero_mask * torch.sqrt(posterior_variance_t) * noise

    def _extract(self, a: torch.Tensor, t: torch.Tensor, x_shape: Tuple[int]) -> torch.Tensor:
        """
        Extract values from a tensor at specific timesteps t and reshape for broadcasting.

        Args:
            a (torch.Tensor): 1D tensor containing precomputed values (e.g., alpha or beta schedule).
            t (torch.Tensor): Timesteps for each sample in the batch (shape: [B]).
            x_shape (Tuple[int]): Target shape for broadcasting (same as input sample x).

        Returns:
            torch.Tensor: Extracted and reshaped values for each timestep in the batch.
        """
        a = a.to(t.device)
        out = a[t]  # Shape: [batch_size]

        # Reshape to [batch_size, 1, ..., 1] so the values broadcast against x
        return out.view((t.shape[0],) + (1,) * (len(x_shape) - 1))
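To see how the cosine schedule spreads noise across the diffusion steps, the plain-Python sketch below re-implements `_cosine_beta_schedule` (same formula and `s=0.015` default), rebuilds $\bar\alpha_t$ from the clipped betas as `torch.cumprod` does in `__init__`, and applies the forward process to a single scalar pixel. It is a standalone illustration; the class itself works on batched tensors.

```python
import math


def cosine_beta_schedule(timesteps, s=0.015):
    """Pure-Python copy of GaussianDiffusion._cosine_beta_schedule."""
    f = [math.cos(((t / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
         for t in range(timesteps + 1)]
    alphas_cumprod = [v / f[0] for v in f]
    betas = [1 - alphas_cumprod[t + 1] / alphas_cumprod[t] for t in range(timesteps)]
    return [min(max(b, 0.0), 0.999) for b in betas]


T = 1000
betas = cosine_beta_schedule(T)

# Rebuild alpha_bar_t = prod(1 - beta) from the clipped betas
alpha_bar = []
running = 1.0
for b in betas:
    running *= 1.0 - b
    alpha_bar.append(running)

# Forward process for one scalar pixel: x_t = sqrt(ab) * x0 + sqrt(1 - ab) * eps
x0, eps = 0.8, 0.5
for t in (0, 250, 500, 999):
    x_t = math.sqrt(alpha_bar[t]) * x0 + math.sqrt(1 - alpha_bar[t]) * eps
    print(f"t={t:4d} alpha_bar={alpha_bar[t]:.4f} x_t={x_t:.4f}")
```

$\bar\alpha_t$ starts close to 1 (almost no noise) and falls close to 0 at the last step, where $x_t$ is almost pure noise.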

In [None]:
#mc
class UNetDiffusion(nn.Module):
    """
    U-Net architecture for conditional DDPM-based forecasting.
    Inputs include noisy predictions, time step embeddings, and conditioning inputs.
    Supports configurable depth, filter size, and number of forecast days/classes.
    """
    
    def __init__(self,
                 input_channels,
                 filter_size=3,
                 n_filters_factor=1,
                 n_forecast_days=7,
                 n_output_classes=1,
                 timesteps=1000,
                 **kwargs):
        """
        Initialize the U-Net diffusion model.

        Args:
            input_channels (int): Number of input conditioning channels (e.g., meteorological variables).
            filter_size (int): Convolution kernel size for all conv layers.
            n_filters_factor (float): Scaling factor for channel depth across the network.
            n_forecast_days (int): Number of days to forecast.
            n_output_classes (int): Number of output regression targets per forecast day.
            timesteps (int): Number of diffusion timesteps.
            **kwargs: Additional arguments (ignored).
        """
        super(UNetDiffusion, self).__init__()

        self.input_channels = input_channels
        self.filter_size = filter_size
        self.n_filters_factor = n_filters_factor
        self.n_forecast_days = n_forecast_days
        self.n_output_classes = n_output_classes
        self.timesteps = timesteps
        
        # Time embedding
        self.time_embed_dim = 256
        self.time_embed = nn.Sequential(
            nn.Linear(self.time_embed_dim, self.time_embed_dim * 4),
            nn.SiLU(),
            nn.Linear(self.time_embed_dim * 4, self.time_embed_dim),
        )
        
        # Channel calculations
        start_out_channels = 64
        reduced_channels = self._make_divisible(int(start_out_channels * n_filters_factor), 8)
        channels = {
            start_out_channels * 2**pow: self._make_divisible(reduced_channels * 2**pow, 8)
            for pow in range(4)
        }

        self.initial_conv_channels = (n_output_classes * n_forecast_days) + input_channels
        
        # Encoder
        self.conv1 = self.conv_block(self.initial_conv_channels, channels[64])
        self.conv2 = self.conv_block(channels[64], channels[128])
        self.conv3 = self.conv_block(channels[128], channels[256])
        self.conv4 = self.conv_block(channels[256], channels[256])

        # Bottleneck
        self.conv5 = self.bottleneck_block(channels[256], channels[512])

        # Decoder
        self.up6 = self.upconv_block(channels[512], channels[256])
        self.up7 = self.upconv_block(channels[256], channels[256])
        self.up8 = self.upconv_block(channels[256], channels[128])
        self.up9 = self.upconv_block(channels[128], channels[64])

        self.up6b = self.conv_block(channels[512] + self.time_embed_dim, channels[256])
        self.up7b = self.conv_block(channels[512] + self.time_embed_dim, channels[256])
        self.up8b = self.conv_block(channels[256] + self.time_embed_dim, channels[128])
        self.up9b = self.conv_block(channels[128] + self.time_embed_dim, channels[64], final=True)

        # Final layer
        self.final_layer = nn.Conv2d(channels[64], n_output_classes * n_forecast_days, kernel_size=1, padding="same")

    def forward(self, x, t, y, sample_weight):
        """
        Forward pass of the U-Net diffusion model.

        Args:
            x (torch.Tensor): Noisy target tensor of shape [B, H, W, n_classes * n_forecast_days];
                a 5D tensor [B, H, W, n_classes, n_forecast_days] is flattened to this shape.
            t (torch.Tensor): Diffusion timestep tensor of shape [B].
            y (torch.Tensor): Conditioning input tensor of shape [B, H, W, input_channels].
            sample_weight (torch.Tensor or None): Optional weighting mask [B, H, W, n_classes, n_forecast_days];
                not used in the forward pass.

        Returns:
            torch.Tensor: Predicted noise of shape [B, H, W, n_classes, n_forecast_days].
        """
        # Time embedding
        t = self._timestep_embedding(t)
        t = self.time_embed(t)

        # Flatten the class and forecast-day dimensions so x can be concatenated with y
        if x.dim() == 5:
            x = x.reshape(x.shape[0], x.shape[1], x.shape[2], -1)

        # Concatenate with conditional input
        x = torch.cat([x, y], dim=-1)  # [b,h,w,(d*c)+input_channels]
        
        # Convert to channel-first format
        x = torch.movedim(x, -1, 1)  # [b,channels,h,w]

        # Encoder pathway
        bn1 = self.conv1(x)
        conv1 = F.max_pool2d(bn1, kernel_size=2)
        bn2 = self.conv2(conv1)
        conv2 = F.max_pool2d(bn2, kernel_size=2)
        bn3 = self.conv3(conv2)
        conv3 = F.max_pool2d(bn3, kernel_size=2)
        bn4 = self.conv4(conv3)
        conv4 = F.max_pool2d(bn4, kernel_size=2)

        # Bottleneck
        bn5 = self.conv5(conv4)

        # Decoder with time embedding
        up6 = self.up6(bn5)
        up6 = torch.cat([bn4, up6], dim=1)
        up6 = self._add_time_embedding(up6, t)
        up6 = self.up6b(up6)
        
        up7 = self.up7(up6)
        up7 = torch.cat([bn3, up7], dim=1)
        up7 = self._add_time_embedding(up7, t)
        up7 = self.up7b(up7)
        
        up8 = self.up8(up7)
        up8 = torch.cat([bn2, up8], dim=1)
        up8 = self._add_time_embedding(up8, t)
        up8 = self.up8b(up8)
        
        up9 = self.up9(up8)
        up9 = torch.cat([bn1, up9], dim=1)
        up9 = self._add_time_embedding(up9, t)
        up9 = self.up9b(up9)

        # Final output
        output = self.final_layer(up9)  # [b, c_out, h, w]
        output = torch.movedim(output, 1, -1)  # [b, h, w, c_out]

        b, h, w, c = output.shape
        output = output.reshape((b, h, w, self.n_output_classes, self.n_forecast_days))

        return output
        
    def _make_divisible(self, v, divisor):
        """
        Ensures a value is divisible by a specified divisor.

        Args:
            v (int): Value to adjust.
            divisor (int): Value to divide by.

        Returns:
            int: Adjusted value divisible by divisor.
        """
        return max(divisor, (v // divisor) * divisor)

    def _get_num_groups(self, channels):
        """
        Determines the maximum number of groups that divide `channels` for GroupNorm.

        Args:
            channels (int): Number of feature channels.

        Returns:
            int: Optimal number of groups.
        """
        num_groups = 8  # Start with preferred group count
        while num_groups > 1:
            if channels % num_groups == 0:
                return num_groups
            num_groups -= 1
        return 1  # Fallback to GroupNorm(1,...) which is equivalent to LayerNorm

    def _timestep_embedding(self, timesteps, dim=256, max_period=10000):
        """
        Converts timestep integers into sinusoidal positional embeddings.

        Args:
            timesteps (torch.Tensor): Timestep tensor [B].
            dim (int): Embedding dimension.
            max_period (int): Frequency range.

        Returns:
            torch.Tensor: Embedding tensor of shape [B, dim].
        """
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(timesteps.device)
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding
    
    def _add_time_embedding(self, x, t):
        """
        Concatenates time embedding across spatial dimensions.

        Args:
            x (torch.Tensor): Feature map tensor [B, C, H, W].
            t (torch.Tensor): Time embedding tensor [B, D].

        Returns:
            torch.Tensor: Time-conditioned feature map [B, C+D, H, W].
        """
        b, c, h, w = x.shape
        t = t.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, h, w)
        return torch.cat([x, t], dim=1)
    
    def conv_block(self, in_channels, out_channels, final=False):
        """
        Standard convolutional block with GroupNorm and SiLU activation.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            final (bool): Whether to add an extra conv layer at the end.

        Returns:
            nn.Sequential: Conv block.
        """
        num_groups = self._get_num_groups(out_channels)
        
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=self.filter_size, padding="same"),
            nn.GroupNorm(num_groups, out_channels),
            nn.SiLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=self.filter_size, padding="same"),
            nn.GroupNorm(num_groups, out_channels),
            nn.SiLU(),
        ]
        if not final:
            return nn.Sequential(*layers)
        else:
            final_layers = [
                nn.Conv2d(out_channels, out_channels, kernel_size=self.filter_size, padding="same"),
                nn.GroupNorm(num_groups, out_channels),
                nn.SiLU(),
            ]
            return nn.Sequential(*(layers + final_layers))

    def bottleneck_block(self, in_channels, out_channels):
        """
        Bottleneck block at the center of the U-Net.

        Args:
            in_channels (int): Input channel size.
            out_channels (int): Output channel size.

        Returns:
            nn.Sequential: Bottleneck block.
        """
        num_groups = self._get_num_groups(out_channels)
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=self.filter_size, padding="same"),
            nn.GroupNorm(num_groups, out_channels),
            nn.SiLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=self.filter_size, padding="same"),
            nn.GroupNorm(num_groups, out_channels),
            nn.SiLU(),
        )

    def upconv_block(self, in_channels, out_channels):
        """
        Upsampling block with interpolation and convolution.

        Args:
            in_channels (int): Input channel size.
            out_channels (int): Output channel size.

        Returns:
            nn.Sequential: Upsampling block.
        """
        num_groups = self._get_num_groups(out_channels)
        return nn.Sequential(
            Interpolate(scale_factor=2, mode='nearest'),
            nn.Conv2d(in_channels, out_channels, kernel_size=2, padding="same"),
            nn.GroupNorm(num_groups, out_channels),
            nn.SiLU()
        )
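Each decoder block `up6b` to `up9b` must receive the skip-connection channels, the upsampled channels and the 256-channel time embedding concatenated together, and the input grid must survive four 2x max-pooling steps. The plain-Python sketch below re-derives that bookkeeping from the rules in `__init__`, so the declared input sizes (e.g. `channels[512] + self.time_embed_dim` for `up6b`) can be checked for any `n_filters_factor`. The 432 x 432 grid size is taken from the shape comments in `IceNetAccuracy.update`; the helper names are hypothetical.

```python
def make_divisible(v, divisor):
    # Same rounding rule as UNetDiffusion._make_divisible
    return max(divisor, (v // divisor) * divisor)


def unet_channels(n_filters_factor, start_out_channels=64):
    # Same channel dictionary as built in UNetDiffusion.__init__
    reduced = make_divisible(int(start_out_channels * n_filters_factor), 8)
    return {start_out_channels * 2**p: make_divisible(reduced * 2**p, 8) for p in range(4)}


def decoder_in_channels(ch, time_embed_dim=256):
    # skip connection + upsampled features + broadcast time embedding
    return {
        "up6b": ch[256] + ch[256] + time_embed_dim,  # bn4 + up6
        "up7b": ch[256] + ch[256] + time_embed_dim,  # bn3 + up7
        "up8b": ch[128] + ch[128] + time_embed_dim,  # bn2 + up8
        "up9b": ch[64] + ch[64] + time_embed_dim,    # bn1 + up9
    }


for factor in (1, 0.5):
    ch = unet_channels(factor)
    print(factor, ch, decoder_in_channels(ch))

# Four 2x max-pooling steps: the grid must be divisible by 2**4 = 16
grid_size = 432
assert grid_size % 2**4 == 0
print("bottleneck grid:", grid_size // 2**4)  # 27
```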

Define custom metrics for use in validation and monitoring

In [None]:
class IceNetAccuracy(Metric):
    """Binary accuracy metric for use at multiple leadtimes.

    Reference: https://lightning.ai/docs/torchmetrics/stable/pages/implement.html
    """    

    # Set class properties
    is_differentiable: bool = False
    higher_is_better: bool = True
    full_state_update: bool = True

    def __init__(self, leadtimes_to_evaluate: list):
        """Custom loss/metric for binary accuracy in classifying SIC>15% for multiple leadtimes.

        Args:
            leadtimes_to_evaluate: A list of leadtimes to consider
                e.g., [0, 1, 2, 3, 4, 5] to consider first six days in accuracy computation or
                e.g., [0] to only look at the first day's accuracy
                e.g., [5] to only look at the sixth day's accuracy
        """
        super().__init__()
        self.leadtimes_to_evaluate = leadtimes_to_evaluate
        self.add_state("weighted_score", default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state("possible_score", default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor, sample_weight: torch.Tensor):
        # preds and target are shape (b, h, w, t), e.g. torch.Size([2, 432, 432, 7])
        preds = (preds > 0.15).long()
        target = (target > 0.15).long()

        # Match sample_weight to (b, h, w, t). Reshaping only drops singleton
        # dimensions, whereas a bare squeeze() would also drop the batch
        # dimension when the batch size is 1.
        sample_weight = sample_weight.reshape(target.shape)
        base_score = preds[:, :, :, self.leadtimes_to_evaluate] == target[:, :, :, self.leadtimes_to_evaluate]
        self.weighted_score += torch.sum(base_score * sample_weight[:, :, :, self.leadtimes_to_evaluate])
        self.possible_score += torch.sum(sample_weight[:, :, :, self.leadtimes_to_evaluate])

    def compute(self):
        return self.weighted_score.float() / self.possible_score * 100.0


class SIEError(Metric):
    """
    Sea Ice Extent error metric (in km^2) for use at multiple leadtimes.
    """ 

    # Set class properties
    is_differentiable: bool = False
    higher_is_better: bool = False
    full_state_update: bool = True

    def __init__(self, leadtimes_to_evaluate: list):
        """Construct an SIE error metric (in km^2) for use at multiple leadtimes.

        Args:
            leadtimes_to_evaluate: A list of leadtimes to consider
                e.g., [0, 1, 2, 3, 4, 5] to consider six days in computation or
                e.g., [0] to only look at the first day
                e.g., [5] to only look at the sixth day
        """
        super().__init__()
        self.leadtimes_to_evaluate = leadtimes_to_evaluate
        self.add_state("pred_sie", default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state("true_sie", default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor, sample_weight: torch.Tensor):
        # preds and target are shape (b, h, w, t)
        preds = (preds > 0.15).long()
        target = (target > 0.15).long()
        self.pred_sie += preds[:, :, :, self.leadtimes_to_evaluate].sum()
        self.true_sie += target[:, :, :, self.leadtimes_to_evaluate].sum()

    def compute(self):
        return (self.pred_sie - self.true_sie) * 25**2 # each pixel is 25x25 km
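The two metrics behave quite differently. The plain-Python sketch below mirrors their logic on four flattened cells at a single leadtime: accuracy is weighted by `sample_weight`, so masked cells are ignored, while the SIE error counts every cell and lets over- and under-predictions cancel out. The function names are hypothetical stand-ins for `IceNetAccuracy` and `SIEError`.

```python
def weighted_binary_accuracy(preds, targets, weights, threshold=0.15):
    """Weighted percentage of cells where pred and target agree on SIC > threshold."""
    score = sum(w for p, t, w in zip(preds, targets, weights)
                if (p > threshold) == (t > threshold))
    return 100.0 * score / sum(weights)


def sie_error_km2(preds, targets, threshold=0.15, cell_km=25):
    """Predicted minus true sea ice extent, counting every cell (no weights)."""
    n_pred = sum(p > threshold for p in preds)
    n_true = sum(t > threshold for t in targets)
    return (n_pred - n_true) * cell_km ** 2


preds = [0.9, 0.1, 0.2, 0.0]
targets = [0.8, 0.3, 0.1, 0.0]
weights = [1.0, 1.0, 1.0, 0.0]  # last cell masked out (e.g. land)

acc = weighted_binary_accuracy(preds, targets, weights)
sie = sie_error_km2(preds, targets)
print(round(acc, 2), sie)  # 33.33 0
```

Here two cells are misclassified in opposite directions, so the accuracy drops to about 33% while the SIE error is exactly zero.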

Define custom loss functions

In [None]:
class WeightedBCEWithLogitsLoss(nn.BCEWithLogitsLoss):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, inputs, targets, sample_weights):
        """
        Weighted BCEWithLogitsLoss loss.

        Compute BCEWithLogitsLoss loss weighted by masking.

        Using BCEWithLogitsLoss instead of BCELoss, as pytorch docs mentions it is
        more numerically stable.
        https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html
        
        """
        # Computing using nn.BCEWithLogitsLoss base class. This class must be instantiated via:
        # >>> criterion = WeightedBCEWithLogitsLoss(reduction='none')
        loss = super().forward(
                            (inputs.movedim(-2, 1)),
                            (targets.movedim(-1, 1))
                         )*sample_weights.movedim(-1, 1)
        
        return loss.mean()

class WeightedL1Loss(nn.L1Loss):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, inputs, targets, sample_weights):
        """
        Weighted L1 loss.

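        Compute L1 loss weighted by masking. `inputs` are treated as logits and
        passed through a sigmoid first. Instantiate with `reduction='none'`, e.g.
        `WeightedL1Loss(reduction='none')`, so the weights are applied per element.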
        
        """
        y_hat = torch.sigmoid(inputs)

        loss = super().forward(
                            (100*y_hat), 
                            (100*targets)
                         )*sample_weights
        
        return loss.mean()

class WeightedMSELoss(nn.MSELoss):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, inputs, targets, sample_weights):
        """
        Weighted MSE loss.

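        Compute MSE loss weighted by masking. Instantiate with `reduction='none'`,
        e.g. `WeightedMSELoss(reduction='none')`, so the weights are applied per element.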
        
        """
        y_hat = inputs
        y_hat = y_hat.squeeze()
        targets = targets.squeeze()

        sample_weights = sample_weights.squeeze()
        loss = super().forward((100*y_hat), (100*targets))*sample_weights
        
        return loss.mean()
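The weighted losses only mask individual cells if the underlying PyTorch loss is built with `reduction='none'`, as the comment in `WeightedBCEWithLogitsLoss` notes; with the default `reduction='mean'`, `super().forward` already returns a single scalar. The plain-Python sketch below mirrors `WeightedMSELoss` on two cells, one of which is masked out, to show the difference. `weighted_mse` is a hypothetical stand-in written for illustration.

```python
def weighted_mse(preds, targets, weights, reduction="none"):
    """Mimics WeightedMSELoss: scale by 100, square the error, weight, then average."""
    sq_err = [(100 * p - 100 * t) ** 2 for p, t in zip(preds, targets)]
    if reduction == "mean":
        # What the base loss returns with reduction="mean": one scalar for all
        # cells, so the weights can only rescale it, not mask individual cells
        per_cell = [sum(sq_err) / len(sq_err)] * len(sq_err)
    else:
        per_cell = sq_err
    weighted = [e * w for e, w in zip(per_cell, weights)]
    return sum(weighted) / len(weighted)


preds, targets, weights = [0.5, 0.25], [0.25, 0.75], [1.0, 0.0]
print(weighted_mse(preds, targets, weights, "none"), weighted_mse(preds, targets, weights, "mean"))
# 312.5 781.25
```

With `reduction='none'`, the large error in the masked cell is excluded; with `reduction='mean'`, it leaks into the loss.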

A _LightningModule_ wrapper for training the UNetDiffusion model with PyTorch Lightning.

In [None]:
#mc

class LitDiffusion(pl.LightningModule):
    """
    PyTorch Lightning wrapper for training and evaluating a diffusion-based model
    (e.g., DDPM) for conditional forecasting. Handles training loop, sampling, 
    metrics, and optimizer configuration.
    """
    def __init__(self,
                 model: nn.Module,
                 learning_rate: float,
                 criterion: callable,
                 timesteps: int = 1000):
        """
        Initialize the LightningModule for DDPM training.

        Args:
            model (nn.Module): The U-Net-style diffusion model.
            learning_rate (float): Optimizer learning rate.
            criterion (callable): Loss function used for evaluation (e.g., WeightedBCE or WeightedMSE).
            timesteps (int): Number of diffusion steps (T).
        """
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate
        self.timesteps = timesteps
        self.diffusion = GaussianDiffusion(timesteps=timesteps)
        self.criterion = criterion
        
        self.n_output_classes = model.n_output_classes

        metrics = {
            "val_accuracy": IceNetAccuracy(leadtimes_to_evaluate=list(range(self.model.n_forecast_days))),
            "val_sieerror": SIEError(leadtimes_to_evaluate=list(range(self.model.n_forecast_days)))
        }
        for i in range(self.model.n_forecast_days):
            metrics[f"val_accuracy_{i}"] = IceNetAccuracy(leadtimes_to_evaluate=[i])
            metrics[f"val_sieerror_{i}"] = SIEError(leadtimes_to_evaluate=[i])
        self.metrics = MetricCollection(metrics)

        test_metrics = {
            "test_accuracy": IceNetAccuracy(leadtimes_to_evaluate=list(range(self.model.n_forecast_days))),
            "test_sieerror": SIEError(leadtimes_to_evaluate=list(range(self.model.n_forecast_days)))
        }
        for i in range(self.model.n_forecast_days):
            test_metrics[f"test_accuracy_{i}"] = IceNetAccuracy(leadtimes_to_evaluate=[i])
            test_metrics[f"test_sieerror_{i}"] = SIEError(leadtimes_to_evaluate=[i])
        self.test_metrics = MetricCollection(test_metrics)

        self.save_hyperparameters()

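    def forward(self, x, sample_weight=None):
        """
        Run the model in inference mode.

        Args:
            x (torch.Tensor or tuple): Either a tensor [B, H, W, C] or a
                (x, y, sample_weight) batch from the DataLoader.
            sample_weight (torch.Tensor or None): Optional weights, used when
                `x` is a tensor rather than a batch tuple.

        Returns:
            torch.Tensor: Generated sample(s).
        """
        if isinstance(x, (list, tuple)):
            batch = x
            x = batch[0]  # Extract input features from batch tuple
            if len(batch) > 2:
                sample_weight = batch[2]
        elif not isinstance(x, torch.Tensor):
            x = torch.tensor(x, device=self.device)
        # sample() takes sample_weight as a required positional argument
        return self.sample(x, sample_weight)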


    def training_step(self, batch):
        """
        One training step using DDPM loss (predicted noise vs. true noise).

        Args:
            batch (tuple): (x, y, sample_weight).

        Returns:
            dict: {"loss": loss}
        """
        x, y, sample_weight = batch
        y = y.squeeze(-1)  # Removes the last dimension (size 1)
        
        # Sample random timesteps
        t = torch.randint(0, self.timesteps, (x.shape[0],), device=x.device).long()
        
        # Create noisy version
        noise = torch.randn_like(y)
        noisy_y = self.diffusion.q_sample(y, t, noise)
        
        # Predict the noise
        pred_noise = self.model(noisy_y, t, x, sample_weight)
        
        pred_noise = pred_noise.squeeze()
        noise = noise.squeeze()

        # Standard DDPM objective: unweighted MSE between predicted and true noise.
        # sample_weight conditions the model above but does not weight this loss.
        loss = F.mse_loss(pred_noise, noise)

        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
        return {"loss": loss}

    def validation_step(self, batch):
        """
        One validation step: generate a forecast by reverse diffusion and score it with `self.criterion`.

        Args:
            batch (tuple): (x, y, sample_weight)

        Returns:
            dict: {"val_loss": loss}
        """
        x, y, sample_weight = batch
        y = y.squeeze(-1)  # Removes the last dimension (size 1)
        
        outputs = self.sample(x, sample_weight)
        y_hat = torch.sigmoid(outputs)

        loss = self.criterion(y_hat, y, sample_weight)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)  # epoch-level loss

        self.metrics.update(y_hat, y, sample_weight)  
        return {"val_loss": loss}

    @torch.no_grad()  # sampling never needs gradients; avoids storing activations for every step
    def sample(self, x, sample_weight, num_samples=1):
        """
        Perform reverse diffusion sampling starting from noise.

        Args:
            x (torch.Tensor): Conditioning input [B, H, W, C].
            sample_weight (torch.Tensor or None): Optional weights.
            num_samples (int): Not used (for future batching).

        Returns:
            torch.Tensor: Final denoised output [B, H, W, C_out].
        """
        shape = (x.shape[0], *x.shape[1:-1], self.model.n_forecast_days * self.n_output_classes)
        device = x.device
        
        # Start from pure noise
        y = torch.randn(shape, device=device)
        
        for t in reversed(range(0, self.timesteps)):
            t_batch = torch.full((x.shape[0],), t, device=device, dtype=torch.long)
            pred_noise = self.model(y, t_batch, x, sample_weight)
        
            pred_noise = pred_noise.squeeze(3)
            y = self.diffusion.p_sample(y, t_batch, pred_noise)
            
        return y

    def on_validation_epoch_end(self):
        """
        Called at the end of validation to log all collected metrics.
        """
        self.log_dict(self.metrics.compute(), on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.metrics.reset()

    def test_step(self, batch, batch_idx):
        """
        One test step: generate a forecast by reverse diffusion, score it with `self.criterion` and update the test metrics.

        Args:
            batch (tuple): (x, y, sample_weight)
            batch_idx (int): Batch index.

        Returns:
            torch.Tensor: Loss value.
        """
        x, y, sample_weight = batch
        y = y.squeeze(-1)  # Removes the last dimension (size 1)
        
        outputs = self.sample(x, sample_weight)
        y_hat = torch.sigmoid(outputs)
        loss = self.criterion(y_hat, y, sample_weight)
        
        self.log("test_loss", loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)  # epoch-level loss
        self.test_metrics.update(y_hat, y, sample_weight)
    
        return loss

    def on_test_epoch_end(self):
        """
        Called at the end of test loop to log all collected metrics.
        """
        self.log_dict(self.test_metrics.compute(), on_step=False, on_epoch=True, sync_dist=True)
        self.test_metrics.reset()

    def configure_optimizers(self):
        """
        Set up the optimizer.

        Returns:
            torch.optim.Optimizer: Adam optimizer.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer
        
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """
        Perform prediction on a single batch.

        Args:
            batch (tuple): (x, y, sample_weight).
            batch_idx (int): Batch index.
            dataloader_idx (int): Index of the DataLoader (if using multiple).

        Returns:
            torch.Tensor: Prediction tensor [B, H, W, C_out].
        """
        x, _, sample_weight = batch  # targets are not needed for prediction

        outputs = self.sample(x, sample_weight)
        y_hat = torch.sigmoid(outputs)

        return y_hat
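
`sample` runs the reverse process. It starts from pure Gaussian noise, calls the network once per timestep, and lets `GaussianDiffusion.p_sample` step from y_t to y_{t-1}. Every validation, test and prediction batch goes through all `timesteps` steps (1000 by default), so evaluating one batch costs roughly that many UNet forward passes. The NumPy sketch below shows the textbook DDPM ancestral update under two assumptions: a linear beta schedule, and the common variance choice sigma_t^2 = beta_t. The notebook's `p_sample` may differ in detail, and `predict_noise` is a hypothetical placeholder for the trained UNet.

```python
import numpy as np

T = 50  # small T for illustration; the notebook uses timesteps=1000
betas = np.linspace(1e-4, 0.02, T)   # assumed linear beta schedule
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas)  # a_bar_t


def p_sample(y_t, t, pred_noise, rng):
    """One DDPM ancestral step y_t -> y_{t-1}, using sigma_t^2 = beta_t."""
    coef = betas[t] / np.sqrt(1.0 - alphas_cumprod[t])
    mean = (y_t - coef * pred_noise) / np.sqrt(alphas[t])
    if t == 0:
        return mean  # the final step adds no noise
    return mean + np.sqrt(betas[t]) * rng.standard_normal(np.shape(y_t))


def sample(cond, predict_noise, rng, n_forecast_days=7):
    """Reverse diffusion from pure noise, conditioned on `cond` of shape [B, H, W, C]."""
    y = rng.standard_normal(cond.shape[:-1] + (n_forecast_days,))  # [B, H, W, days]
    for t in reversed(range(T)):
        y = p_sample(y, t, predict_noise(y, t, cond), rng)
    return y


def predict_noise(y, t, cond):
    """Hypothetical placeholder for the trained UNet: always predicts zero noise."""
    return np.zeros_like(y)


rng = np.random.default_rng(42)
cond = rng.standard_normal((2, 8, 8, 9))  # toy conditioning inputs
out = sample(cond, predict_noise, rng)
sic = 1.0 / (1.0 + np.exp(-out))          # sigmoid, as in validation_step / predict_step
print(out.shape, sic.min(), sic.max())
```

If the noise prediction is perfect, a single step at t = 0 recovers y_0 exactly. This makes a handy sanity check for any `p_sample` implementation.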
        

In [None]:
from lightning.pytorch.callbacks import ModelCheckpoint

def train_diffusion_icenet(configuration_path,
                          learning_rate,
                          max_epochs,
                          batch_size,
                          n_workers,
                          filter_size,
                          n_filters_factor,
                          seed,
                          timesteps=1000):
    """
    Train IceNet diffusion model using the specified parameters.

    Args:
        configuration_path (str): Path to IceNet configuration YAML file.
        learning_rate (float): Learning rate for optimizer.
        max_epochs (int): Number of training epochs.
        batch_size (int): Mini-batch size.
        n_workers (int): Number of workers for data loading.
        filter_size (int): Convolution kernel size used in UNet layers.
        n_filters_factor (float): Scaling factor for number of filters in UNet.
        seed (int): Random seed for reproducibility.
        timesteps (int): Number of diffusion steps (T).

    Returns:
        tuple: (model, trainer, checkpoint_callback)
            model (UNetDiffusion): Trained model.
            trainer (pl.Trainer): PyTorch Lightning trainer used for training.
            checkpoint_callback (ModelCheckpoint): Callback used for saving the best model.
    """
    # init
    pl.seed_everything(seed)
    
    # configure datasets and dataloaders
    train_dataset = IceNetDataSetPyTorch(configuration_path, mode="train")
    val_dataset = IceNetDataSetPyTorch(configuration_path, mode="val")
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=n_workers,
                                 persistent_workers=persistent_workers, shuffle=False)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, num_workers=n_workers,
                               persistent_workers=persistent_workers, shuffle=False)

    #mc debug
    # Inspect the shapes of one (x, y, sample_weight) batch from the train dataloader
    x, y, sample_weight = next(iter(train_dataloader))
    print(f"x: {tuple(x.shape)}, y: {tuple(y.shape)}, sample_weight: {tuple(sample_weight.shape)}")

    # construct diffusion model
    model = UNetDiffusion(
        input_channels=train_dataset._num_channels,
        filter_size=filter_size,
        n_filters_factor=n_filters_factor,
        n_forecast_days=train_dataset._n_forecast_days,
        timesteps=timesteps
    )
    
    criterion = WeightedMSELoss(reduction="none")
    
    # configure PyTorch Lightning module
    lit_module = LitDiffusion(
        model=model,
        learning_rate=learning_rate,
        criterion=criterion,
        timesteps=timesteps
    )

    # set up checkpointing and trainer configuration
    # (passing the callback to the Trainer stops Lightning adding its own default
    # ModelCheckpoint, which would otherwise be the one used for ckpt_path="best")
    # checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min")
    checkpoint_callback = ModelCheckpoint(monitor="val_accuracy", mode="max")
    trainer = pl.Trainer(
        accelerator="auto",
        devices=1,
        log_every_n_steps=5,
        max_epochs=max_epochs,
        num_sanity_val_steps=1,
        fast_dev_run=False,
        callbacks=[checkpoint_callback],
    )

    # train model
    print(f"Training {len(train_dataset)} examples / {len(train_dataloader)} batches (batch size {batch_size}).")
    print(f"Validating {len(val_dataset)} examples / {len(val_dataloader)} batches (batch size {batch_size}).")
    trainer.fit(lit_module, train_dataloader, val_dataloader)

    return model, trainer, checkpoint_callback

In [None]:
print("""
seed = 45

#mc
model, trainer, checkpoint_callback = train_diffusion_icenet(
    configuration_path=dataset_config,
    learning_rate=3e-4, #3e-4, #1e-4,
    max_epochs=150,
    batch_size=batch_size,
    n_workers=num_workers,
    filter_size=3,
    n_filters_factor=0.5, #0.7, #1.0, #0.4,
    seed=seed,
    timesteps=1000
)
""")

Conduct actual training run.

In [None]:
seed = 45

#mc
model, trainer, checkpoint_callback = train_diffusion_icenet(
    configuration_path=dataset_config,
    learning_rate=3e-4, #3e-4, #1e-4,
    max_epochs=150,
    batch_size=batch_size,
    n_workers=num_workers,
    filter_size=3,
    n_filters_factor=0.5, #0.7, #1.0, #0.4,
    seed=seed,
    timesteps=1000
)

## 4. Prediction

Predict using the best checkpoint (highest `val_accuracy`) saved during training.
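With `ModelCheckpoint(monitor="val_accuracy", mode="max")`, `checkpoint_callback.best_k_models` maps each saved checkpoint path to its monitored score. `best_model_path` is the entry with the highest validation accuracy. By default only the single best checkpoint is kept (`save_top_k=1`), so the dictionary usually has one entry. The plain-Python sketch below mimics the selection on a made-up dictionary. The paths and scores are illustrative only, and Lightning stores the scores as tensors.

```python
def best_checkpoint_path(best_k_models, mode="max"):
    """Return the checkpoint path with the best monitored score."""
    pick = max if mode == "max" else min
    return pick(best_k_models, key=lambda path: float(best_k_models[path]))


# Illustrative values only (not results from this notebook)
best_k_models = {
    "lightning_logs/version_0/checkpoints/epoch=3-step=40.ckpt": 48.2,
    "lightning_logs/version_0/checkpoints/epoch=7-step=80.ckpt": 51.6,
}
print(best_checkpoint_path(best_k_models))
```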

In [None]:
checkpoint_callback.best_k_models

In [None]:
best_checkpoint = checkpoint_callback.best_model_path
best_checkpoint

In [None]:
# Load the best result from the checkpoint
# best_model = LitUNet.load_from_checkpoint(best_checkpoint)

#mc
best_model = LitDiffusion.load_from_checkpoint(best_checkpoint)

# disable randomness, dropout, etc...
best_model.eval()

In [None]:
test_dataset = IceNetDataSetPyTorch(configuration_path=dataset_config, mode="test")
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers,
                             persistent_workers=persistent_workers, shuffle=False)

# Evaluate the best checkpoint loaded above (highest val_accuracy) on the test set
trainer.test(best_model, dataloaders=test_dataloader)

In [None]:
# # cosine results
# [{'test_loss': 0.48640450835227966,
#   'test_accuracy': 51.62199783325195,
#   'test_accuracy_0': 51.56485366821289,
#   'test_accuracy_1': 51.77238464355469,
#   'test_accuracy_2': 52.235984802246094,
#   'test_accuracy_3': 51.62677764892578,
#   'test_accuracy_4': 51.081172943115234,
#   'test_accuracy_5': 51.31882858276367,
#   'test_accuracy_6': 51.75397872924805,
#   'test_sieerror': 729145600.0,
#   'test_sieerror_0': 104656872.0,
#   'test_sieerror_1': 104813752.0,
#   'test_sieerror_2': 104287504.0,
#   'test_sieerror_3': 104360624.0,
#   'test_sieerror_4': 104103128.0,
#   'test_sieerror_5': 103494376.0,
#   'test_sieerror_6': 103429376.0}]
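
The accuracy values above are percentages. The NumPy sketch below computes a mask-weighted binary accuracy for one lead time, as a guide to reading them. A pixel counts as ice when its concentration exceeds 15% (the usual sea ice extent threshold), and masked pixels are excluded from both the numerator and the denominator. We assume `IceNetAccuracy` works along these lines, but its exact thresholding and weighting are defined in the IceNet code. `binary_accuracy` and the toy arrays are hypothetical.

```python
import numpy as np


def binary_accuracy(y_hat, y, weights, threshold=0.15):
    """Mask-weighted accuracy (in %) of ice / no-ice classification at one lead time."""
    ice_pred = y_hat > threshold
    ice_true = y > threshold
    correct = (ice_pred == ice_true).astype(float)
    # Masked pixels (weight 0) count in neither the numerator nor the denominator
    return float(100.0 * np.sum(correct * weights) / np.sum(weights))


y_true = np.array([[0.9, 0.1], [0.5, 0.0]])
y_pred = np.array([[0.8, 0.3], [0.1, 0.0]])
mask = np.array([[1.0, 1.0], [1.0, 0.0]])  # last pixel masked out (e.g. land)
print(binary_accuracy(y_pred, y_true, mask))
```

Under this definition, a forecast that labels pixels as ice or no-ice at random scores around 50%. That gives a rough baseline for the numbers above.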

In [None]:
logging.info("Generating predictions")

predictions = trainer.predict(best_model, dataloaders=test_dataloader)

In [None]:
# trainer.predict returns one prediction tensor per test batch
for batch_idx, prediction in enumerate(predictions):
    print(f"Batch: {batch_idx} | Prediction: {prediction.shape}")

## 5. Outputs and Plotting

Create prediction output directory

In [None]:
# dataset = "pytorch_notebook"
network_name = "api_pytorch_dataset"
output_name = "example_pytorch_forecast_diff"
output_folder = os.path.join(".", "results", "predict", output_name,
                             "{}.{}".format(network_name, seed))
os.makedirs(output_folder, exist_ok=True)

Convert and output predictions to numpy files
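The next cell goes through the predictions batch by batch. A running index `idx` into `test_dataset.dates` matches each forecast to its date. For each date, the cell moves the lead-time axis to the front and saves one `YYYY_MM_DD.npy` file under `results/predict/<output_name>/<network_name>.<seed>/`, the directory this notebook later passes to `icenet_output`. Below is a self-contained NumPy sketch of that bookkeeping. It uses toy arrays, made-up dates and a temporary directory, and `np.moveaxis` plays the role of `Tensor.movedim`.

```python
import os
import tempfile
from datetime import date

import numpy as np

# Toy stand-ins: two "batches" of predictions shaped [B, H, W, n_forecast_days]
predictions = [np.zeros((2, 4, 4, 7)), np.ones((1, 4, 4, 7))]
dates = ["2020_04_01", "2020_04_02", "2020_04_03"]  # hypothetical test dates

output_folder = os.path.join(tempfile.mkdtemp(), "example_forecast", "network.45")
os.makedirs(output_folder, exist_ok=True)

idx = 0  # running index over samples across all batches
for prediction in predictions:
    for i in range(prediction.shape[0]):
        day = date.fromisoformat(dates[idx].replace("_", "-"))
        output_path = os.path.join(output_folder, day.strftime("%Y_%m_%d.npy"))
        # [H, W, n_forecast_days] -> [n_forecast_days, H, W]
        np.save(output_path, np.moveaxis(prediction[i], -1, 0))
        idx += 1

print(sorted(os.listdir(output_folder)))
```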

In [None]:
idx = 0  # running index into test_dataset.dates across all batches
for prediction in predictions:  # one tensor per test batch, [B, H, W, C_out]
    for batch in range(prediction.shape[0]):
        date = pd.Timestamp(test_dataset.dates[idx].replace('_', '-'))
        output_path = os.path.join(output_folder, date.strftime("%Y_%m_%d.npy"))
        # [H, W, C_out] -> [C_out, H, W] (C_out = n_forecast_days for a single output class)
        forecast = prediction[batch, :, :, :].movedim(-1, 0)
        forecast_np = forecast.detach().cpu().numpy()
        np.save(output_path, forecast_np)
        idx += 1

In [None]:
forecast.shape

Create a CSV file listing the test dates we have predicted for. `icenet_output` uses it to generate the final netCDF output.
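You can also build the date list in Python, which is useful for a longer test period. The sketch below writes one `YYYY-MM-DD` date per line, with no header and no trailing newline. This is the same content the `printf` command in the next cell produces. `write_test_dates` is a hypothetical helper, and the date range should match the dates you actually predicted for.

```python
from datetime import date, timedelta


def write_test_dates(path, start, end):
    """Write one YYYY-MM-DD date per line for the inclusive range [start, end]."""
    n_days = (end - start).days + 1
    dates = [(start + timedelta(days=i)).isoformat() for i in range(n_days)]
    with open(path, "w") as f:
        f.write("\n".join(dates))
    return dates


dates = write_test_dates("testdates_diff.csv", date(2020, 4, 1), date(2020, 4, 2))
print(dates)
```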

In [None]:
!printf "2020-04-01\n2020-04-02" | tee testdates_diff.csv

In [None]:
!icenet_output -m -o results/predict example_pytorch_forecast_diff notebook_api_pytorch_data testdates_diff.csv

In [None]:
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
import matplotlib.pyplot as plt

# Change this to the version directory of your run
# (for the training run above, `trainer.logger.log_dir` gives the path)
log_dir = "lightning_logs/version_1062817"

# Load the logs
event_acc = EventAccumulator(log_dir)
event_acc.Reload()

# List all scalar tags to find the correct name
print("Available tags:", event_acc.Tags()['scalars'])

# Get the scalar events for val_loss
val_loss_events = event_acc.Scalars('val_loss')

# val_loss is logged once per epoch, so use the event index as the epoch number
# (each event's .step records the training step, not the epoch)
steps = list(range(1, len(val_loss_events) + 1))
values = [e.value for e in val_loss_events]

# Plot
plt.figure(figsize=(8, 5))
plt.plot(steps, values, label='Validation Loss', color='blue')
plt.xlabel("Epoch")
plt.ylabel("Validation Loss")
plt.title("Validation Loss Over Training")
plt.legend()
plt.grid(True)
plt.show()


In [None]:
# Change this to the version directory of your run (e.g. `trainer.logger.log_dir`)
log_dir = "lightning_logs/version_1062817"

event_acc = EventAccumulator(log_dir)
event_acc.Reload()

# List all available scalar tags
print("Available scalar tags:")
print(event_acc.Tags()['scalars'])  # e.g. val_accuracy, val_accuracy_0, etc.

# Get accuracy for all lead times (overall accuracy)
accuracy_events = event_acc.Scalars('val_accuracy')

# val_accuracy is logged once per epoch, so use the event index as the epoch number
steps = list(range(1, len(accuracy_events) + 1))
values = [e.value for e in accuracy_events]

# Plot
plt.figure(figsize=(8, 5))
plt.plot(steps, values, label='Validation Accuracy', color='green')
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.title("Validation Accuracy Over Epochs")
plt.legend()
plt.grid(True)
plt.show()


Plotting the forecast

In [None]:
import xarray as xr
import datetime as dt
from IPython.display import HTML

In [None]:
from icenet.plotting.video import xarray_to_video as xvid
from icenet.data.sic.mask import Masks

ds = xr.open_dataset("results/predict/example_pytorch_forecast_diff.nc")
land_mask = Masks(south=True, north=False).get_land_mask()
ds.info()

Animate result
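The cell below turns the `leadtime` dimension of the first forecast into calendar dates (the forecast date plus the lead time in days), so that `xarray_to_video` can label each frame. The plain-Python sketch shows the same arithmetic with a hypothetical forecast date, assuming lead times are numbered 1 to 7:

```python
from datetime import date, timedelta

forecast_date = date(2020, 4, 1)  # hypothetical forecast initialisation date
leadtimes = range(1, 8)           # assumed lead-time labels 1..7
frame_dates = [forecast_date + timedelta(days=int(lt)) for lt in leadtimes]
print([d.isoformat() for d in frame_dates])
```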

In [None]:
forecast_date = ds.time.values[0]
fc = ds.sic_mean.isel(time=0).drop_vars("time").rename(dict(leadtime="time"))
fc['time'] = [pd.to_datetime(forecast_date) \
              + dt.timedelta(days=int(e)) for e in fc.time.values]

anim = xvid(fc, 15, figsize=(4,4), mask=land_mask)
HTML(anim.to_jshtml())

Check min/max of predicted SIC fraction

In [None]:
# forecast_np has shape [n_forecast_days, H, W]; index 0 is the first forecast day
print(forecast_np[0].shape)
fmin, fmax = np.min(forecast_np[0]), np.max(forecast_np[0])
print(f"First forecast day min: {fmin:.4f}, max: {fmax:.4f}")

#### Load original input dataset

This is the original input dataset (pre-normalisation) for comparison.
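The cell below selects `time=92` from the 2020 file. If the file holds one time step per day starting on 1 January, a date's index is its day of the year minus one. The small plain-Python sketch below uses this to check which day is being compared, or to find the index for another date. `day_index` is a hypothetical helper.

```python
from datetime import date, timedelta


def day_index(day):
    """0-based position of `day` in a daily series starting on 1 January of that year."""
    return (day - date(day.year, 1, 1)).days


print(day_index(date(2020, 4, 2)))
print(date(2020, 1, 1) + timedelta(days=92))
```

In the leap year 2020, index 92 is 2 April and 2020-04-01 is index 91.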

In [None]:
# Load original input dataset (domain not normalised)
xr.plot.contourf(xr.open_dataset("data/osisaf/south/siconca/2020.nc").isel(time=92).ice_conc, levels=50)

## Version
- IceNet Codebase: v0.2.8