In [1]:
import os
import numpy as np
import pandas as pd

import torch
from torch.utils.data import DataLoader

# Log at INFO level so the icenet API reports what it is doing.
import logging
logging.basicConfig(level=logging.INFO)

# Choose the compute device: CUDA when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Trade float32 matmul precision for speed ("medium" is the least precise setting).
torch.set_float32_matmul_precision("medium")

## 1. Download

In [2]:
import numpy
from icenet.data.sic.mask import Masks
from icenet.data.sic.osisaf import SICDownloader

### Mask data

Create masks for masking data.

In [3]:
# Build the southern-hemisphere masks; polar-hole masks are not saved.
masks = Masks(south=True, north=False)
masks.generate(save_polarhole_masks=False)

INFO:root:Skipping ./data/masks/south/masks/active_grid_cell_mask_01.npy, already exists
INFO:root:Skipping ./data/masks/south/masks/active_grid_cell_mask_02.npy, already exists
INFO:root:Skipping ./data/masks/south/masks/active_grid_cell_mask_03.npy, already exists
INFO:root:Skipping ./data/masks/south/masks/active_grid_cell_mask_04.npy, already exists
INFO:root:Skipping ./data/masks/south/masks/active_grid_cell_mask_05.npy, already exists
INFO:root:Skipping ./data/masks/south/masks/active_grid_cell_mask_06.npy, already exists
INFO:root:Skipping ./data/masks/south/masks/active_grid_cell_mask_07.npy, already exists
INFO:root:Skipping ./data/masks/south/masks/active_grid_cell_mask_08.npy, already exists
INFO:root:Skipping ./data/masks/south/masks/active_grid_cell_mask_09.npy, already exists
INFO:root:Skipping ./data/masks/south/masks/active_grid_cell_mask_10.npy, already exists
INFO:root:Skipping ./data/masks/south/masks/active_grid_cell_mask_11.npy, already exists
INFO:root:Skipping ./

### Sea Ice data

Download sea ice concentration from OSI-SAF.

In [4]:
# Download daily OSI-SAF sea-ice concentration for 1 Jan - 30 Apr 2020 (southern
# hemisphere only), keeping the temporary intermediate files.
sic = SICDownloader(
    dates=list(pd.date_range("2020-01-01", "2020-04-30", freq="D").date),
    delete_tempfiles=False,
    north=False,
    south=True,
    parallel_opens=False,
)
sic.download()

INFO:root:Downloading SIC datafiles to .temp intermediates...
INFO:root:Excluding 121 dates already existing from 121 dates requested.
INFO:root:Opening for interpolation: ['./data/osisaf/south/siconca/2020.nc']
INFO:root:Processing 0 missing dates


## 2. Data Processing

Process downloaded datasets.

To make life easier, we set up the train, val and test dates here.

In [5]:
# Daily dates for each split (2020): train Jan-Mar, val 3-23 Apr, test 1-2 Apr.
# date_range already yields pd.Timestamp objects, so no extra conversion is needed.
processing_dates = {
    "train": list(pd.date_range("2020-01-01", "2020-03-31")),
    "val": list(pd.date_range("2020-04-03", "2020-04-23")),
    "test": list(pd.date_range("2020-04-01", "2020-04-02")),
}
processed_name = "notebook_api_pytorch_data"

Next, we create the data processors and configure them for the dataset we want to create.

In [6]:
from icenet.data.processors.meta import IceNetMetaPreProcessor
from icenet.data.processors.osi import IceNetOSIPreProcessor


# Both preprocessors target the southern hemisphere only.
hemisphere_kwargs = dict(north=False, south=True)

# OSI-SAF preprocessor for the "siconca" variable over the train/val/test dates
# defined above; linear_trends is left empty.
osi = IceNetOSIPreProcessor(
    ["siconca"],
    [],
    processed_name,
    processing_dates["train"],
    processing_dates["val"],
    processing_dates["test"],
    linear_trends=tuple(),
    **hemisphere_kwargs,
)

# Meta preprocessor configured with the same processed name.
meta = IceNetMetaPreProcessor(processed_name, **hemisphere_kwargs)

Next, we initialise the data processors using `init_source_data`, which scans the data source directories to understand what data is available for processing based on the parameters. Since we named the processed data `"notebook_api_pytorch_data"` above, it will create a data loader config file, `loader.notebook_api_pytorch_data.json`, in the current directory.

In [7]:
# Scan the available source data (with a one-day lag), then run both preprocessors.
osi.init_source_data(lag_days=1)
osi.process()
meta.process()

INFO:root:Processing 91 dates for train category
INFO:root:Including lag of 1 days
INFO:root:Including lead of 93 days
INFO:root:No data found for 2019-12-31, outside data boundary perhaps?
INFO:root:Processing 21 dates for val category
INFO:root:Including lag of 1 days
INFO:root:Including lead of 93 days
INFO:root:Processing 2 dates for test category
INFO:root:Including lag of 1 days
INFO:root:Including lead of 93 days
INFO:root:Got 1 files for siconca
INFO:root:Opening files for siconca
INFO:root:Filtered to 121 units long based on configuration requirements
INFO:root:No normalisation for siconca
INFO:root:Loading configuration ./loader.notebook_api_pytorch_data.json
INFO:root:Writing configuration to ./loader.notebook_api_pytorch_data.json
INFO:root:Loading configuration ./loader.notebook_api_pytorch_data.json
INFO:root:Writing configuration to ./loader.notebook_api_pytorch_data.json


At this point the preprocessed data is ready to convert or create a configuration for the network dataset.

### Dataset creation

As with the `icenet_dataset_create` command, we can create a dataset configuration for training the network. As before, this can include cached data for the network in the format of a `TFRecordDataset`-compatible set of tfrecords. To achieve this, we create the `IceNetDataLoader`, which can generate both `IceNetDataSet` configurations (which easily provide the necessary functionality for training and prediction) and individual data samples for direct usage.

In [8]:
from icenet.data.loaders import IceNetDataLoaderFactory

# Create an IceNet data loader (the "dask" implementation) from the loader
# configuration written by the preprocessors. The config filename is derived from
# `processed_name` so it cannot drift from the name used during preprocessing.
implementation = "dask"
loader_config = f"loader.{processed_name}.json"
dataset_name = "notebook_api_pytorch_data"
lag = 1  # lag in days; matches lag_days=1 given to osi.init_source_data
n_forecast_days = 7  # forecast horizon (days) the dataset is configured for

dl = IceNetDataLoaderFactory().create_data_loader(
    implementation,
    loader_config,
    dataset_name,
    lag,
    n_forecast_days=n_forecast_days,
    north=False,
    south=True,
    output_batch_size=1,
    generate_workers=4,
)

INFO:root:Loading configuration loader.notebook_api_pytorch_data.json


At this point we can either use `generate` or `write_dataset_config_only` to produce a ready-to-go `IceNetDataSet` configuration. Both of these will generate a dataset config, `dataset_config.notebook_api_pytorch_data.json` (recall we set the dataset name as `notebook_api_pytorch_data` above).

In this case, for PyTorch, we will read the data in directly rather than using cached tfrecords inputs.

In [9]:
# Write dataset_config.<dataset_name>.json without generating any cached data;
# the PyTorch dataset below reads the data directly.
dl.write_dataset_config_only()

INFO:root:Writing dataset configuration without data generation
INFO:root:91 train dates in total, NOT generating cache data.
INFO:root:21 val dates in total, NOT generating cache data.
INFO:root:2 test dates in total, NOT generating cache data.
INFO:root:Writing configuration to ./dataset_config.notebook_api_pytorch_data.json


We can now point to the dataset configuration from which the `IceNetDataSetPyTorch` object will be created:

In [10]:
from icenet.data.dataset import IceNetDataSetPyTorch
# Path of the dataset configuration written by dl.write_dataset_config_only().
dataset_config = "dataset_config.{}.json".format(dataset_name)

## 3. Train

We implement a custom model class. As a first baseline, this is a simple persistence model (written in NumPy) that requires no training.

## Persistence Model

Simple persistence model implementation.

In [11]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter


class PersistenceModel:
    """Persistence baseline: every forecast frame equals the latest input frame.

    Parameters
    ----------
    forecast_steps : int, default 12
        Number of future frames emitted along the time axis.
    """

    def __init__(self, forecast_steps: int = 12):
        self.forecast_steps = forecast_steps

    def __call__(self, X) -> np.ndarray:
        """Repeat the last time step of ``X`` ``forecast_steps`` times.

        Parameters
        ----------
        X : array-like, shape (batch, channels, time, height, width)
            Input sequence. Anything ``np.asarray`` accepts (a NumPy array, a
            nested list, a CPU ``torch.Tensor``) is converted to an ndarray first,
            so ``np.repeat`` always operates on a real ndarray.

        Returns
        -------
        np.ndarray, shape (batch, channels, forecast_steps, height, width)
            The latest frame repeated along axis 2, clipped to [0, 1].

        Raises
        ------
        ValueError
            If ``X`` is not 5-dimensional or its time axis is empty.
        """
        X = np.asarray(X)
        if X.ndim != 5:
            raise ValueError(
                "expected a 5-D input (batch, channels, time, height, width), "
                f"got shape {X.shape}"
            )
        if X.shape[2] == 0:
            # Without this check an empty time axis silently yields an empty forecast.
            raise ValueError("input has no time steps to persist")

        # Slice with -1: (not -1) so the time axis is kept: (B, C, 1, H, W).
        latest_frame = X[:, :, -1:, :, :]

        # Repeat the last frame across the forecast steps.
        y_hat = np.repeat(latest_frame, self.forecast_steps, axis=2)

        # Clip to the valid [0, 1] range (assumes SIC is a fraction, not a percentage).
        return y_hat.clip(0, 1)


In [12]:
# PyTorch dataset over the "train" split described by dataset_config.
train_dataset = IceNetDataSetPyTorch(dataset_config, mode="train")

INFO:root:Loading configuration dataset_config.notebook_api_pytorch_data.json
INFO:root:Loading configuration /Users/npedrazzini/Desktop/ice-station-zebra/notebook/loader.notebook_api_pytorch_data.json


In [13]:
# Take the sea-ice concentration field (channel 0) from the inputs of sample 1.
# NOTE(review): assumes train_dataset[i][0] is the input array laid out as
# (height, width, channels) - TODO confirm against IceNetDataSetPyTorch.
# Named `sic_frame` (not `sic`) so it does not shadow the SICDownloader object.
sic_frame = np.asarray(train_dataset[1][0])[:, :, 0]

# Add batch, channel and time axes: (1, 1, 1, H, W), e.g. (1, 1, 1, 432, 432).
X_input = sic_frame[np.newaxis, np.newaxis, np.newaxis, :, :]

# NOTE(review): the data loader was configured with n_forecast_days=7, while this
# baseline forecasts 12 steps - confirm which horizon is intended.
model = PersistenceModel(forecast_steps=12)


## 4. Prediction

In [14]:
# Persistence forecast of shape (1, 1, 12, 432, 432): the last input frame repeated.
prediction = model(X=X_input)

In [17]:
import numpy as np
import xarray as xr
import pandas as pd
import datetime as dt
from icenet.plotting.video import xarray_to_video as xvid
from icenet.data.sic.mask import Masks
from IPython.display import HTML

: 

In [15]:



# forecast = prediction.squeeze(0).squeeze(0)

# forecast_start_date = pd.Timestamp("2020-09-01") #random date
# time = [forecast_start_date + dt.timedelta(days=int(lead)) for lead in range(forecast.shape[0])]

# da = xr.DataArray(
#     forecast,
#     dims=("time", "y", "x"),
#     coords=dict(
#         time=time,
#         y=np.arange(forecast.shape[1]),
#         x=np.arange(forecast.shape[2]),
#         yc=("y", np.arange(forecast.shape[1])),  # add yc coordinate
#         xc=("x", np.arange(forecast.shape[2]))   # add xc coordinate
#     ),
#     name="sic_mean"
# )

# land_mask = Masks(south=True, north=False).get_land_mask()
# print(land_mask.shape)

In [16]:
# anim = xvid(da, fps=15, figsize=(4, 4), mask=land_mask)

# from IPython.display import HTML
# HTML(anim.to_jshtml())