# Pipeline EASE2‑3 km – Kara Sea sea‑ice forecast

End‑to‑end workflow that downloads and pre‑processes ERA5 and OSI‑SAF data, re‑grids them to a 3 km EASE2 grid over the Kara Sea, builds PyTorch datasets, and trains a UNet‑lite model to predict sea‑ice presence (concentration ≥ 15 %) `horizon` days ahead.

## 1 Overview
```
┌──────────────┐   ┌────────────┐   ┌─────────────┐   ┌───────────────┐   ┌───────────┐
│  Download &  │   │ Re‑grid to │   │ Crop + Mask │   │  DataLoader   │   │  CNN/UNet │
│  Averaging   ├──►│  EASE2‑3   ├──►│ 313×313 ROI │──►│  + Normalise  ├──►│  Training │
└──────────────┘   └────────────┘   └─────────────┘   └───────────────┘   └────┬──────┘
                                                                                │ preds
                                                                                ▼
                                                                            Post‑proc
```

## 2 Source datasets
| Source | Variables (daily) | Coverage | Native res. | Tool |
|--------|-------------------|----------|-------------|------|
| ERA5 (CDS) | uas, vas, tas (single), zg250, zg500 (pressure) | 1979‑present | 0.25° | custom `ERA5Downloader` |
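| OSI‑SAF SIC v2p0 (CDR) | sea‑ice concentration (%) | 1979‑2015 for the CDR files (later years are distributed as ICDR files with a different name — verify) | EASE2‑25 km | `osisaf_sic_download_regrid.py` (wget + re‑grid to 3 km) |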
| Land/Sea mask | water_mask | NH | EASE2‑3 km | script |
| Bathymetry (optional) | GEBCO | global | ~500 m | preprocess |

## 3 Patched ERA5Downloader (ease_target & bbox)

In [None]:
# ice_patch_era5.py
from typing import Optional, Tuple, List
import numpy as np
import iris
from icenet.data.interfaces.downloader import ClimateDownloader

class ERA5Downloader(ClimateDownloader):
    """ERA5 downloader with `ease_target` (e.g. 'EASE2-3') and `bbox` support.

    Keyword-only arguments (everything else is forwarded to ``ClimateDownloader``):

    ease_target
        Target grid name of the form ``"EASE2-<resolution in km>"``.
        ``"EASE2-25"`` (the default) defers to the parent's template cube;
        any other value builds a synthetic template, see ``sic_ease_cube``.
    bbox
        ``(lon_w, lat_s, lon_e, lat_n)`` in degrees. When given, CDS requests
        are clipped to this box; otherwise ``self.hemisphere_loc`` is used.
    grid_shape
        ``(ny, nx)`` of the synthetic template (default 313 x 313, the Kara ROI).
    grid_origin_m
        ``(x0, y0)``: projection coordinates, in metres, of the first cell of
        the synthetic template. Default ``(0, 0)`` (the projection origin).
    cds_client
        Optional object with a ``retrieve(dataset, request, target)`` method
        (e.g. ``cdsapi.Client()``); stored as ``self.client`` when given.
    """

    # Fallback CDS variable names for the variables this pipeline requests.
    # An instance attribute ``_cdi_map`` set by a parent class shadows this.
    _cdi_map = {
        "tas": "2m_temperature",
        "uas": "10m_u_component_of_wind",
        "vas": "10m_v_component_of_wind",
        "zg": "geopotential",
    }

    def __init__(self, *args,
                 ease_target: str = "EASE2-25",
                 bbox: Optional[Tuple[float, float, float, float]] = None,
                 grid_shape: Tuple[int, int] = (313, 313),
                 grid_origin_m: Tuple[float, float] = (0.0, 0.0),
                 cds_client=None,
                 **kwargs):
        # Set before super().__init__ in case the parent touches sic_ease_cube.
        self._ease_target = ease_target
        self._bbox = bbox
        self._grid_shape = tuple(grid_shape)
        self._grid_origin_m = tuple(grid_origin_m)
        super().__init__(*args, **kwargs)
        if cds_client is not None:
            self.client = cds_client

    # 1) dynamic template cube
    @property
    def sic_ease_cube(self):
        """Template cube for ``ease_target``, built once and cached.

        ``"EASE2-25"`` is delegated to the parent class. Any other target gets
        a zero-filled ``grid_shape`` cube on the EASE-Grid 2.0 projection
        (Lambert azimuthal equal-area centred on the pole, WGS84 ellipsoid),
        with cell spacing equal to the resolution in the target name.

        NOTE(review): with the default ``grid_origin_m=(0, 0)`` the grid starts
        at the pole and extends towards +x/+y. By our reading of the EASE2-North
        axes the Kara Sea lies at x > 0, y < 0, so a Kara ROI probably needs an
        explicit origin -- confirm against the template file used downstream.
        """
        key = self._ease_target
        # Cache dict presumably created by the parent class -- verify.
        if key not in self._sic_ease_cubes:
            if key == "EASE2-25":
                return super().sic_ease_cube
            self._sic_ease_cubes[key] = self._build_ease_template(key)
        return self._sic_ease_cubes[key]

    def _build_ease_template(self, target):
        """Return a zero-filled (ny, nx) float32 cube on the ``target`` EASE2 grid.

        Raises:
            ValueError: ``target`` is not of the form ``"EASE2-<km>"``.
        """
        from iris.coord_systems import GeogCS, LambertAzimuthalEqualArea
        from iris.coords import DimCoord

        try:
            res_km = float(target.split("-")[1])
        except (IndexError, ValueError) as err:
            raise ValueError(
                f"ease_target must look like 'EASE2-<km>', got {target!r}") from err

        # EASE-Grid 2.0 North/South (EPSG:6931 / 6932): LAEA on WGS84, pole-centred.
        wgs84 = GeogCS(semi_major_axis=6378137.0, inverse_flattening=298.257223563)
        crs = LambertAzimuthalEqualArea(
            latitude_of_projection_origin=90.0 if self.north else -90.0,
            longitude_of_projection_origin=0.0,
            false_easting=0.0,
            false_northing=0.0,
            ellipsoid=wgs84,
        )
        ny, nx = self._grid_shape
        x0, y0 = self._grid_origin_m
        step = res_km * 1000.0  # km -> m
        x = DimCoord(x0 + np.arange(nx) * step,
                     'projection_x_coordinate', units='m', coord_system=crs)
        y = DimCoord(y0 + np.arange(ny) * step,
                     'projection_y_coordinate', units='m', coord_system=crs)
        return iris.cube.Cube(np.zeros((ny, nx), np.float32),
                              dim_coords_and_dims=[(y, 0), (x, 1)])

    # 2) download request with bbox
    def _single_api_download(self, var, level, req_dates, download_path):
        """Submit one CDS request for ``var`` and write the result to ``download_path``.

        The request comes from ``_build_request_dict``. Its ``area`` is the
        ``bbox`` when one was given, otherwise ``self.hemisphere_loc``. When
        ``level`` is truthy the request goes to
        ``reanalysis-era5-pressure-levels``; otherwise it goes to
        ``reanalysis-era5-single-levels``.

        Raises:
            RuntimeError: no CDS client is available as ``self.client``.
        """
        retrieve_dict = self._build_request_dict(var, level, req_dates)
        if self._bbox:
            lon_w, lat_s, lon_e, lat_n = self._bbox
            # CDS "area" is ordered [North, West, South, East].
            retrieve_dict['area'] = [lat_n, lon_w, lat_s, lon_e]
        else:
            retrieve_dict['area'] = self.hemisphere_loc

        dataset = ("reanalysis-era5-pressure-levels" if level
                   else "reanalysis-era5-single-levels")
        # The request must be sent here: the parent's method does not receive
        # `retrieve_dict`, so delegating to it would drop the bbox.
        client = getattr(self, "client", None)
        if client is None:
            raise RuntimeError(
                "ERA5Downloader needs a CDS client: pass cds_client=cdsapi.Client() "
                "or use a parent class that sets self.client")
        client.retrieve(dataset, retrieve_dict, download_path)

    def _build_request_dict(self, var, level, req_dates):
        """Build the CDS request body (without ``area``) for one variable.

        ``var`` carries the pressure level as a suffix (e.g. ``"zg500"``) when
        ``level`` is given; the suffix is stripped to look up the CDS name.
        All days and hours of the requested months are asked for.
        """
        var_prefix = var[:-len(str(level))] if level else var
        return {
            "product_type": "reanalysis",
            "variable": self._cdi_map[var_prefix],
            "year": req_dates[0].year,
            # sorted: a bare set gives a non-deterministic month order.
            "month": sorted({f"{d.month:02d}" for d in req_dates}),
            "day": [f"{d:02d}" for d in range(1, 32)],
            "time": [f"{h:02d}:00" for h in range(24)],
            "format": "netcdf",
            **({"pressure_level": level} if level else {}),
        }


### Re‑grid lat‑lon NetCDF to EASE2‑3 km

In [None]:
# regrid_kara.py
import iris, numpy as np
from icenet.data.utils import assign_lat_lon_coord_system

KARA_TEMPLATE = iris.load_cube("template_ease2-3_kara.nc")  # 313×313 target grid

def regrid_latlon_to_ease3(latlon_nc: str, out_nc: str):
    """Linearly interpolate a lat/lon NetCDF cube onto the Kara 3 km template.

    The single cube in ``latlon_nc`` gets its lat/lon coordinate system
    assigned (icenet helper), is regridded with ``iris.analysis.Linear`` onto
    ``KARA_TEMPLATE`` and written to ``out_nc`` with NaN as fill value.
    """
    source = assign_lat_lon_coord_system(iris.load_cube(latlon_nc))
    on_template = source.regrid(KARA_TEMPLATE, iris.analysis.Linear())
    iris.save(on_template, out_nc, fill_value=np.nan)


### ERA5 post‑processing

In [None]:
# postproc_vars.py
import iris
from icenet.data.utils import gridcell_angles_from_dim_coords, invert_gridcell_angles, rotate_grid_vectors

TEMPLATE = iris.load_cube("template_ease2-3_kara.nc")

def postprocess_var(nc_path: str):
    """Convert selected ERA5 variables to analysis units, rewriting ``nc_path`` in place.

    * ``zg250`` / ``zg500``: divided by standard gravity (9.80665) and
      labelled in metres (geopotential -> geopotential height).
    * ``tas`` in kelvin: converted to degrees Celsius.

    Any other variable is left untouched and the file is not rewritten.
    Both conversions are skipped when the cube is already in the target
    unit, so running this twice on the same file is harmless.
    """
    import numpy as np  # this module does not import numpy at top level

    cube = iris.load_cube(nc_path)
    name = cube.name()
    if name in ("zg250", "zg500"):
        # Without this guard a second pass would divide by g again.
        if cube.units == "m":
            return
        cube.data /= 9.80665
        cube.units = "m"
        iris.save(cube, nc_path, fill_value=np.nan)
    elif name == "tas" and cube.units == "K":
        cube.convert_units("celsius")
        iris.save(cube, nc_path, fill_value=np.nan)

def rotate_uas_vas(u_nc: str, v_nc: str):
    """Rotate the uas/vas wind pair by the template grid-cell angles, in place.

    Angles come from ``TEMPLATE``'s dimension coordinates and are inverted
    before ``rotate_grid_vectors`` is applied (icenet helpers; presumably
    this turns east/north components into grid x/y -- verify). Both files
    must already be on the ``TEMPLATE`` grid; they are overwritten with NaN
    as fill value.
    """
    import numpy as np  # this module does not import numpy at top level

    uas = iris.load_cube(u_nc)
    vas = iris.load_cube(v_nc)
    angles = gridcell_angles_from_dim_coords(TEMPLATE)
    invert_gridcell_angles(angles)  # modifies `angles` in place (return value unused)
    uas_r, vas_r = rotate_grid_vectors(uas, vas, angles)
    iris.save(uas_r, u_nc, fill_value=np.nan)
    iris.save(vas_r, v_nc, fill_value=np.nan)


### OSI‑SAF SIC download + regrid

In [None]:
# osisaf_sic_download_regrid.py
import os, subprocess, numpy as np, iris
from datetime import date, timedelta

# Root URL of the OSI-SAF reprocessed SIC v2p0 archive; wget_day appends
# "<YYYY>/<MM>/<file name>" to it.
# NOTE(review): the "cdr-v2p0" file names built from this root are believed to
# exist only for the reprocessed period (ending ~2015); later years appear to be
# published under a different (ICDR) name -- verify before downloading to 2024.
FTP_ROOT = "ftp://osisaf.met.no/reprocessed/ice/conc/v2p0"
# Hemisphere tag embedded in the OSI-SAF file names ("nh" for the Kara Sea work).
HEMI = "nh"
# Target grid for regridding (the 3 km Kara template); loaded once at import
# time, so the file must exist before this module is imported.
KARA_TEMPLATE = iris.load_cube("template_ease2-3_kara.nc")

def wget_day(ymd: date, dest="sic25_raw"):
    """Download one daily OSI-SAF SIC v2p0 EASE2 25 km file unless already present.

    The transfer goes to ``<final path>.part`` and is renamed to the final
    name only after wget exits successfully. A failed or interrupted download
    therefore never leaves a truncated file that a later run would mistake
    for a complete one.

    Args:
        ymd: Day to fetch.
        dest: Directory for raw files (created if missing).

    Returns:
        Local path of the downloaded (or already present) file.

    Raises:
        subprocess.CalledProcessError: wget failed, e.g. the day is not in the archive.
    """
    import contextlib

    fname = f"ice_conc_{HEMI}_ease2-250_cdr-v2p0_{ymd:%Y%m%d}1200.nc"
    url = f"{FTP_ROOT}/{ymd.year}/{ymd:%m}/{fname}"
    os.makedirs(dest, exist_ok=True)
    out = os.path.join(dest, fname)
    if os.path.exists(out):
        return out

    part = out + ".part"
    try:
        subprocess.run(["wget", "-q", "-O", part, url], check=True)
    except BaseException:
        # wget -O creates the file even on failure; never leave it behind.
        with contextlib.suppress(FileNotFoundError):
            os.remove(part)
        raise
    os.replace(part, out)
    return out

def regrid_sic25_to_3(raw_nc, out_nc):
    """Regrid one OSI-SAF 25 km SIC file onto the Kara 3 km template and save it as float32.

    The projection coordinates are converted to metres to match the
    template. A concentration stored in percent is converted to a 0-1
    fraction, so thresholds such as the dataset's ``>= 0.15`` ice/no-ice
    split apply directly. Output is written with NaN as fill value.
    """
    cube = iris.load_cube(raw_nc, "sea_ice_area_fraction")
    for c in ("projection_x_coordinate", "projection_y_coordinate"):
        cube.coord(c).convert_units("meters")
    if cube.units == "%":
        cube.convert_units("1")  # percent -> fraction
    cube_e3 = cube.regrid(KARA_TEMPLATE, iris.analysis.Linear())
    # iris.cube.Cube has no astype(); cast the data payload instead.
    cube_e3.data = cube_e3.data.astype(np.float32)
    iris.save(cube_e3, out_nc, fill_value=np.nan)

def main(start=date(1979,1,1), end=date(2024,12,31)):
    """Download and regrid every day in ``[start, end]`` (inclusive).

    Raw files go to ``sic25_raw/`` and regridded files to ``sic3k/``. The
    latter have the same name with ``ease2-250`` replaced by ``ease2-3k``.
    Days whose regridded file already exists are not regridded again. Days
    that cannot be downloaded (wget non-zero exit) are logged and skipped
    instead of aborting the whole run.

    Returns:
        List of dates that could not be downloaded.
    """
    import logging

    log = logging.getLogger(__name__)
    raw_dir, out_dir = "sic25_raw", "sic3k"
    os.makedirs(out_dir, exist_ok=True)
    missing = []
    for offset in range((end - start).days + 1):
        day = start + timedelta(days=offset)
        try:
            raw = wget_day(day, raw_dir)
        except subprocess.CalledProcessError as err:
            log.warning("OSI-SAF SIC unavailable for %s (wget exit %s); skipping",
                        day, err.returncode)
            missing.append(day)
            continue
        out = os.path.join(out_dir,
                           os.path.basename(raw).replace("ease2-250", "ease2-3k"))
        if not os.path.exists(out):
            regrid_sic25_to_3(raw, out)
    return missing


## 4 Dataset & DataLoader

In [None]:
# dataset_kara.py
import xarray as xr, numpy as np, torch
from torch.utils.data import Dataset

class IceNetKaraDataset(Dataset):
    """Sliding-window samples for binary sea-ice presence on the Kara grid.

    A sample covers ``lag_days`` consecutive calendar days ending at day
    ``t0`` plus a target ``horizon`` days after ``t0``. It returns
    ``(x, y, m)`` as float32 tensors:

    * ``x``: every predictor variable (all data variables except
      ``sea_ice_area_fraction``) for each window day, flattened day-major into
      channels, followed by the SIC of day ``t0``. Shape
      ``(lag_days * n_vars + 1, H, W)``. NaNs are replaced by ``input_fill``
      unless it is ``None``.
    * ``y``: 1.0 where SIC at ``t0 + horizon`` is ``>= 0.15``, else 0.0
      (SIC assumed to be a 0-1 fraction -- confirm for the zarr store).
    * ``m``: the mask read from ``mask_path``.

    Windows are only formed from consecutive calendar days. After the
    val/test years are removed, no sample straddles the gap.

    Args:
        zarr_path: Consolidated zarr store with a ``time`` dimension
            (assumed sorted, numpy datetime64).
        lag_days: Length of the input window in days (>= 1).
        horizon: Days between the last input day and the target.
        split: ``'train'``, ``'val'``; any other value selects the test years.
        val_years, test_years: Years held out for validation / testing.
        mask_path: NetCDF file holding the mask data array.
        input_fill: Replacement for NaNs in ``x``; ``None`` keeps NaNs.
    """

    def __init__(self, zarr_path, lag_days=28, horizon=1, split='train',
                 val_years=(2018,), test_years=(2022,),
                 mask_path="watermask_kara_EASE2-3km.nc", input_fill=0.0):
        if lag_days < 1:
            raise ValueError(f"lag_days must be >= 1, got {lag_days}")
        ds = xr.open_zarr(zarr_path, consolidated=True)
        # .dt works on the DataArray; apply_ufunc would hand the function a bare
        # ndarray, which has no .dt accessor.
        years = ds.time.dt.year.values
        val_years, test_years = list(val_years), list(test_years)
        if split == 'train':
            idx = ~np.isin(years, val_years + test_years)
        elif split == 'val':
            idx = np.isin(years, val_years)
        else:
            idx = np.isin(years, test_years)

        self.ds = ds.isel(time=idx).chunk({})
        self.lag, self.h = lag_days, horizon
        self.mask = xr.open_dataarray(mask_path).astype(np.float32)
        # The mask is identical for every sample: materialise it once.
        self._mask_np = np.asarray(self.mask.values, dtype=np.float32)
        self.pred_vars = [v for v in ds.data_vars if v != 'sea_ice_area_fraction']
        self.input_fill = input_fill
        self._starts = self._valid_starts()

    def _valid_starts(self):
        """Window start indices whose window and target lie on consecutive days."""
        n = self.ds.sizes['time']
        span = self.lag - 1 + self.h  # index steps from window start to target
        n_starts = n - span
        if n_starts <= 0:
            return np.empty(0, dtype=np.int64)
        days = self.ds.time.values.astype('datetime64[D]')
        starts = np.arange(n_starts)
        # With sorted, unique days, "span steps == span days" means no gap inside.
        elapsed = (days[starts + span] - days[starts]).astype(np.int64)
        return starts[elapsed == span]

    def __len__(self):
        return len(self._starts)

    def __getitem__(self, i):
        start = int(self._starts[i])
        t0 = start + self.lag - 1  # last day of the input window
        win = self.ds.isel(time=slice(start, t0 + 1))

        # .values computes the (dask-backed) zarr data into numpy arrays;
        # torch.from_numpy cannot take dask arrays.
        sic_now = win['sea_ice_area_fraction'].isel(time=-1).values.astype(np.float32)
        ny, nx = sic_now.shape[-2:]
        if self.pred_vars:
            # (lag, n_vars, H, W) -> (lag * n_vars, H, W), day-major channels.
            x = np.stack([win[v].values for v in self.pred_vars], axis=1)
            x = x.reshape(-1, ny, nx).astype(np.float32)
        else:
            x = np.empty((0, ny, nx), dtype=np.float32)
        x = np.concatenate([x, sic_now[None]], axis=0)
        if self.input_fill is not None:
            # NaN fill values would otherwise propagate through every convolution.
            x[np.isnan(x)] = self.input_fill

        sic_fut = self.ds['sea_ice_area_fraction'].isel(time=t0 + self.h).values
        # NaN >= 0.15 is False, so missing target cells become 0.
        y = (sic_fut >= 0.15).astype(np.float32)
        return torch.from_numpy(x), torch.from_numpy(y), torch.from_numpy(self._mask_np)


In [None]:
# loader_kara.py
from torch.utils.data import DataLoader
from dataset_kara import IceNetKaraDataset

def get_loaders(zarr_path, batch_size=8, lag_days=28, horizon=1):
    """Build the training and validation DataLoaders for the Kara dataset.

    Both loaders use 4 worker processes and pinned memory; only the
    training loader shuffles.

    Returns:
        ``(train_loader, val_loader)``
    """
    window = dict(lag_days=lag_days, horizon=horizon)
    datasets = {
        split: IceNetKaraDataset(zarr_path, split=split, **window)
        for split in ('train', 'val')
    }
    loader_opts = dict(batch_size=batch_size, num_workers=4, pin_memory=True)
    return (
        DataLoader(datasets['train'], shuffle=True, **loader_opts),
        DataLoader(datasets['val'], shuffle=False, **loader_opts),
    )
