In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from aurora import Batch, Metadata, AuroraSmallPretrained
from datetime import datetime, timedelta
from torch.nn import L1Loss
import torch.optim as optim
import xarray as xr
import numpy as np



In [1]:
from aurora import Aurora, Batch, Metadata
from aurora import Batch as BaseBatch
from aurora import Metadata as BaseMetadata
from aurora.batch import interpolate

# Import dependencies
import contextlib
from functools import partial
import dataclasses
from datetime import datetime, timedelta
import torch
from torch.nn import L1Loss  # = MAE: Mean Absolute Error = '.abs().mean()'
from torch.nn import Module  # for custom WeightedMAELoss (Aurora loss)
import torch.nn.functional as F  # for downsample batch
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from peft import get_peft_model
import tempfile
import requests
import pickle
import dask
import numpy as np
import pandas as pd
import xarray as xr
import zarr
import gcsfs
from typing import Union, Callable

In [2]:
# Aurora core
from aurora import Aurora, Batch, Metadata
from aurora.batch import interpolate

# Dataset & dataloader
from torch.utils.data import Dataset, DataLoader

# Model training utils
import torch
import torch.nn.functional as F
from torch.nn import L1Loss, Module
import torch.optim as optim
from torch.amp import autocast

# LoRA PEFT
from peft import get_peft_model, LoraConfig

# Data processing
import xarray as xr
import numpy as np
import pandas as pd

# Filesystem + storage
import gcsfs
import zarr

# General utilities
from typing import Union, Callable
from datetime import datetime, timedelta
import dataclasses
import tempfile
import contextlib
from functools import partial


# Avoid CUDA error: invalid configuration argument on F.scaled_dot_product_attention
# https://stackoverflow.com/questions/77343471/pytorch-cuda-error-invalid-configuration-argument
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_math_sdp(True)


# Constants
STATIC_VARS_HF_URL = "https://huggingface.co/microsoft/aurora/resolve/main/aurora-0.25-static.pickle"
GCS_URL = "gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721.zarr/"
LEAD_TIME = pd.Timedelta("6h")
AURORA_PRESSURE_LEVELS = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]
AURORA_VARIABLE_NAMES = {
    "surface": [
        '10u',
        '10v',
        '2t',
        'msl',
    ],
    "atmospheric": [
        't',
        'u',
        'v',
        'q',
        'z',
    ],
    "static": [
        'z',  # geopotential_at_sea_level
        'lsm',  # land_sea_mask
        'slt',  # soil_type
    ]
}
VARIABLES_STATISTICS: dict[str, tuple[float, float]] = { # mean, std (location, scale)
    'z': (-1386.496, 58844.67),
    'lsm': (0.0, 1.0),
    'slt': (0.0, 7.0),
    '2t': (278.514, 21.22036),
    '10u': (-0.05135059, 5.547512),
    '10v': (0.189158, 4.765339),
    'msl': (100957.8, 1332.246),
    'z_50': (199373.0, 5875.553),
    'z_100': (157642.1, 5510.64),
    'z_150': (133141.4, 5823.912),
    'z_200': (115330.0, 5820.169),
    'z_250': (101223.1, 5536.585),
    'z_300': (89414.15, 5091.916),
    'z_400': (69980.38, 4150.851),
    'z_500': (54115.37, 3353.187),
    'z_600': (40648.33, 2695.808),
    'z_700': (28928.82, 2136.436),
    'z_850': (13749.78, 1470.321),
    'z_925': (7015.005, 1228.997),
    'z_1000': (738.1545, 1072.307),
    'u_50': (5.653076, 15.29281),
    'u_100': (10.27951, 13.52611),
    'u_150': (13.54061, 16.04335),
    'u_200': (14.20915, 17.6763),
    'u_250': (13.34584, 17.9671),
    'u_300': (11.80173, 17.11917),
    'u_400': (8.817291, 14.34276),
    'u_500': (6.563273, 11.98419),
    'u_600': (4.814521, 10.33421),
    'u_700': (3.345237, 9.168821),
    'u_850': (1.418379, 8.188043),
    'u_925': (0.6172657, 7.940808),
    'u_1000': (-0.03328723, 6.141778),
    'v_50': (0.004226111, 7.058931),
    'v_100': (0.01411897, 7.47931),
    'v_150': (-0.03697671, 9.57199),
    'v_200': (-0.04507801, 11.88069),
    'v_250': (-0.02980338, 13.38039),
    'v_300': (-0.0229477, 13.34044),
    'v_400': (-0.01771003, 11.22955),
    'v_500': (-0.02387986, 9.181708),
    'v_600': (-0.02716674, 7.803569),
    'v_700': (0.02153583, 6.87104),
    'v_850': (0.142815, 6.264443),
    'v_925': (0.205348, 6.470644),
    'v_1000': (0.1867637, 5.308203),
    't_50': (212.4864, 10.26284),
    't_100': (208.4042, 12.52901),
    't_150': (213.3201, 8.928709),
    't_200': (218.0615, 7.189547),
    't_250': (222.771, 8.529282),
    't_300': (228.8696, 10.71679),
    't_400': (242.1368, 12.69102),
    't_500': (252.9492, 13.06447),
    't_600': (261.1347, 13.42046),
    't_700': (267.401, 14.76523),
    't_850': (274.56, 15.5888),
    't_925': (277.3572, 16.08798),
    't_1000': (281.013, 17.13983),
    'q_50': (2.67818e-06, 3.571687e-07),
    'q_100': (2.633677e-06, 5.703754e-07),
    'q_150': (5.254625e-06, 3.794077e-06),
    'q_200': (1.940632e-05, 2.267534e-05),
    'q_250': (5.773618e-05, 7.446644e-05),
    'q_300': (0.0001273861, 0.0001684361),
    'q_400': (0.0003855659, 0.0005078644),
    'q_500': (0.0008529599, 0.001079294),
    'q_600': (0.001541429, 0.001769722),
    'q_700': (0.002431637, 0.002549169),
    'q_850': (0.004575618, 0.004112368),
    'q_925': (0.006033134, 0.005071058),
    'q_1000': (0.007030342, 0.005913548)
}

class Xaurora(Aurora):
    def __init__(self, *args, **kwargs) -> None:
        self.autocast = kwargs.pop("autocast", True)  # Remove 'autocast' from kwargs
        self.autocast_dtype = torch.bfloat16

        super().__init__(*args, **kwargs)
        
    def forward(self, batch: Batch, lead_time: timedelta) -> Batch:
        member_id = batch.metadata.member_id

        # Move batch to correct dtype and device
        p = next(self.parameters())
        batch = batch.type(p.dtype)
        batch = batch.normalise()
        batch = batch.crop(patch_size=self.encoder.patch_size)
        batch = batch.to(p.device)

        # Determine resolution of latent patch representation
        H, W = batch.spatial_shape
        patch_res = (
            self.encoder.latent_levels,
            H // self.encoder.patch_size,
            W // self.encoder.patch_size,
        )

        # Repeat static variables to match batch/time dims
        B, T = next(iter(batch.surf_vars.values())).shape[:2]
        batch = dataclasses.replace(
            batch,
            static_vars={k: v[None, None].repeat(B, T, 1, 1) for k, v in batch.static_vars.items()},
        )

        # Encode inputs
        x = self.encoder(batch, lead_time=lead_time)

        # Apply backbone with autocast if enabled
        with torch.autocast(device_type="cuda", dtype=self.autocast_dtype) if self.autocast else contextlib.nullcontext():
            x = self.backbone(
                x,
                lead_time=lead_time,
                patch_res=patch_res,
                rollout_step=batch.metadata.rollout_step,
            )

        # Decode forecast
        pred = self.decoder(
            x,
            batch,
            lead_time=lead_time,
            patch_res=patch_res,
        )

        # Wrap in Batch object and restore metadata
        pred = Batch.from_aurora_batch(pred, member_id=member_id)
        assert pred.metadata.member_id == batch.metadata.member_id, "Member ID mismatch."

        # Remove temporal dims from static variables
        pred = dataclasses.replace(
            pred,
            static_vars={k: v[0, 0] for k, v in batch.static_vars.items()},
        )

        # Ensure output shape includes time dim (1 step)
        pred = dataclasses.replace(
            pred,
            surf_vars={k: v[:, None] for k, v in pred.surf_vars.items()},
            atmos_vars={k: v[:, None] for k, v in pred.atmos_vars.items()},
            metadata=dataclasses.replace(
                pred.metadata,
                rollout_step=batch.metadata.rollout_step + 1,
            )
        )

        # Unnormalize the output
        pred = pred.unnormalise()
        return pred

    
XauroraSmall = partial(
    Xaurora,
    encoder_depths=(2, 6, 2),
    encoder_num_heads=(4, 8, 16),
    decoder_depths=(2, 6, 2),
    decoder_num_heads=(16, 8, 4),
    embed_dim=256,
    num_heads=8,
    use_lora=False,
)


@dataclasses.dataclass
class Metadata(BaseMetadata):
    """Metadata in a batch.

    Args:
        lat (:class:`torch.Tensor`): Latitudes.
        lon (:class:`torch.Tensor`): Longitudes.
        time (tuple[datetime, ...]): For every batch element, the time.
        atmos_levels (tuple[int | float, ...]): Pressure levels for the atmospheric variables in
            hPa.
        rollout_step (int, optional): How many roll-out steps were used to produce this prediction.
            If equal to `0`, which is the default, then this means that this is not a prediction,
            but actual data. This field is automatically populated by the model and used to use a
            separate LoRA for every roll-out step. Generally, you are safe to ignore this field.
        member_id (int, optional): The member ID of the ensemble member. It defaults to `0`.
    """

    lat: torch.Tensor
    lon: torch.Tensor
    time: tuple[datetime, ...]
    atmos_levels: tuple[int | float, ...]
    rollout_step: int = 0
    member_id: int|list[int] = None


@dataclasses.dataclass
class Batch:
    """A batch of data.

    Args:
        surf_vars (dict[str, :class:`torch.Tensor`]): Surface-level variables with shape
            `(b, t, h, w)`.
        static_vars (dict[str, :class:`torch.Tensor`]): Static variables with shape `(h, w)`.
        atmos_vars (dict[str, :class:`torch.Tensor`]): Atmospheric variables with shape
            `(b, t, c, h, w)`.
        metadata (:class:`Metadata`): Metadata associated to this batch.
    """

    surf_vars: dict[str, torch.Tensor]
    static_vars: dict[str, torch.Tensor]
    atmos_vars: dict[str, torch.Tensor]
    metadata: Metadata
    
    @property
    def spatial_shape(self) -> tuple[int, int]:
        """Get the spatial shape from an arbitrary surface-level variable."""
        return next(iter(self.surf_vars.values())).shape[-2:]
    
    def save(self, path: str) -> None:
        """Save the batch to a file."""
        torch.save(self, path)
    
    @staticmethod
    def load(path: str) -> "Batch":
        """Load a batch from a file."""
        return torch.load(path, weights_only=False)
    
    def fillna(self, how: str="spatial_conv", **how_kwargs):
        if how == "spatial_conv":
            return Batch(
                surf_vars={k: fillna_spatial_mean_conv(v, **how_kwargs) for k, v in self.surf_vars.items()},
                static_vars={k: fillna_spatial_mean_conv(v, **how_kwargs) for k, v in self.static_vars.items()},
                atmos_vars={k: fillna_spatial_mean_conv(v, **how_kwargs) for k, v in self.atmos_vars.items()},
                metadata=self.metadata,
            )
        else:
            raise ValueError(f"Unknown fillna method {how}.")
    
    def normalise(
        self, 
        stats: dict[str, tuple[float, float]]=VARIABLES_STATISTICS
    ) -> "Batch":
        assert all(
            k in stats.keys() for k in self.surf_vars.keys()
        ), "Not all surface variables have statistics."
        return Batch(
            surf_vars={
                k: normalise_surf_var(v, k, stats=stats) for k, v in self.surf_vars.items()
            },
            static_vars={
                k: normalise_surf_var(v, k, stats=stats) for k, v in self.static_vars.items()
            },
            atmos_vars={
                k: normalise_atmos_var(v, k, self.metadata.atmos_levels, stats=stats)
                for k, v in self.atmos_vars.items()
            },
            metadata=self.metadata,
        )
    
    def unnormalise(
        self, 
        stats: dict[str, tuple[float, float]]=VARIABLES_STATISTICS
    ) -> "Batch":
        assert all(
            k in stats.keys() for k in self.surf_vars.keys()
        ), "Not all surface variables have statistics."
        return Batch(
            surf_vars={
                k: unnormalise_surf_var(v, k, stats=stats) for k, v in self.surf_vars.items()
            },
            static_vars={
                k: unnormalise_surf_var(v, k, stats=stats) for k, v in self.static_vars.items()
            },
            atmos_vars={
                k: unnormalise_atmos_var(v, k, self.metadata.atmos_levels, stats=stats)
                for k, v in self.atmos_vars.items()
            },
            metadata=self.metadata,
        )
        
    def crop(self, patch_size: int) -> "Batch":
        """Crop the variables in the batch to patch size `patch_size`."""
        h, w = self.spatial_shape

        if w % patch_size != 0:
            raise ValueError("Width of the data must be a multiple of the patch size.")

        if h % patch_size == 0:
            return self
        
        elif h % patch_size == 1:
            return Batch(
                surf_vars={k: v[..., :-1, :] for k, v in self.surf_vars.items()},
                static_vars={k: v[..., :-1, :] for k, v in self.static_vars.items()},
                atmos_vars={k: v[..., :-1, :] for k, v in self.atmos_vars.items()},
                metadata=Metadata(
                    lat=self.metadata.lat[:-1],
                    lon=self.metadata.lon,
                    atmos_levels=self.metadata.atmos_levels,
                    time=self.metadata.time,
                    rollout_step=self.metadata.rollout_step,
                    member_id=self.metadata.member_id,
                ),
            )
        else:
            raise ValueError(
                f"There can at most be one latitude too many, "
                f"but there are {h % patch_size} too many."
            )
            
    def crop_right(self, n: int) -> "Batch":
        """Crop the rightmost `n` columns of the variables in the batch."""
        return Batch(
            surf_vars={k: v[..., :-n, :] for k, v in self.surf_vars.items()},
            static_vars={k: v[..., :-n, :] for k, v in self.static_vars.items()},
            atmos_vars={k: v[..., :-n, :] for k, v in self.atmos_vars.items()},
            metadata=Metadata(
                lat=self.metadata.lat[:-n],
                lon=self.metadata.lon,
                atmos_levels=self.metadata.atmos_levels,
                time=self.metadata.time,
                rollout_step=self.metadata.rollout_step,
                member_id=self.metadata.member_id,
            ),
        )

    def _fmap(self, f: Callable[[torch.Tensor], torch.Tensor]) -> "Batch":
        return Batch(
            surf_vars={k: f(v) for k, v in self.surf_vars.items()},
            static_vars={k: f(v) for k, v in self.static_vars.items()},
            atmos_vars={k: f(v) for k, v in self.atmos_vars.items()},
            metadata=Metadata(
                lat=f(self.metadata.lat),
                lon=f(self.metadata.lon),
                atmos_levels=self.metadata.atmos_levels,
                time=self.metadata.time,
                rollout_step=self.metadata.rollout_step,
                member_id=self.metadata.member_id,
            ),
        )

    def to(self, device: str | torch.device) -> "Batch":
        """Move the batch to another device."""
        return self._fmap(lambda x: x.to(device))

    def type(self, t: type) -> "Batch":
        """Convert everything to type `t`."""
        return self._fmap(lambda x: x.type(t))

    def regrid(self, res: float) -> "Batch":
        """Regrid the batch to a `res` degrees resolution.

        This results in `float32` data on the CPU.

        This function is not optimised for either speed or accuracy. Use at your own risk.
        """

        shape = (round(180 / res) + 1, round(360 / res))
        lat_new = torch.from_numpy(np.linspace(90, -90, shape[0]))
        lon_new = torch.from_numpy(np.linspace(0, 360, shape[1], endpoint=False))
        interpolate_res = partial(
            interpolate,
            lat=self.metadata.lat,
            lon=self.metadata.lon,
            lat_new=lat_new,
            lon_new=lon_new,
        )

        return Batch(
            surf_vars={k: interpolate_res(v) for k, v in self.surf_vars.items()},
            static_vars={k: interpolate_res(v) for k, v in self.static_vars.items()},
            atmos_vars={k: interpolate_res(v) for k, v in self.atmos_vars.items()},
            metadata=Metadata(
                lat=lat_new,
                lon=lon_new,
                atmos_levels=self.metadata.atmos_levels,
                time=self.metadata.time,
                rollout_step=self.metadata.rollout_step,
                member_id=self.metadata.member_id,
            ),
        )
        
    def downsample(self, factor: int) -> "Batch":
        """Downsample batch by factor using Batch.regrid.
        
        See notes about using it carefully in .regrid"""
        
        # detach and to cpu 
        if self.metadata.lat.device != torch.device('cpu'):
            out = self._fmap(lambda x: x.detach().cpu())
        else:
            out = self
        
        _, W = out.spatial_shape
        current_resolution = 360 / W
        new_resolution = current_resolution * factor
        
        return out.regrid(new_resolution)
    
    @classmethod        
    def from_aurora_batch(cls, aurora_batch: BaseBatch, **metadata_kwargs) -> "Batch":
        return cls(
            surf_vars=aurora_batch.surf_vars,
            static_vars=aurora_batch.static_vars,
            atmos_vars=aurora_batch.atmos_vars,
            metadata=Metadata(
                lat=aurora_batch.metadata.lat,
                lon=aurora_batch.metadata.lon,
                time=aurora_batch.metadata.time,
                atmos_levels=aurora_batch.metadata.atmos_levels,
                rollout_step=aurora_batch.metadata.rollout_step,
                **metadata_kwargs
            ),
        )


class WeightedMAELoss(Module):
    """
    Weighted Mean Absolute Error loss used in Aurora fine-tuning.
    Updated to use `reduction='mean'` instead of manually normalizing.
    """
    def __init__(
        self,
        gamma=1.0,  # dataset weight — reduced from 2.0 to stabilize magnitude
        alpha=0.25,  # overall surface loss weight
        beta=1.0,  # overall atmospheric loss weight
        surf_var_weights=None,  # dict of surface variable weights
        atmos_var_weights=None  # dict of atmospheric variable weights
    ):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.beta = beta
        self.l1loss = L1Loss(reduction='mean')  # avg over all pixels/levels

        self.surf_var_weights = surf_var_weights or {
            "2t": 1.0, "10u": 1.0, "10v": 1.0, "msl": 1.0
        }
        self.atmos_var_weights = atmos_var_weights or {
            "z": 1.0, "q": 1.0, "t": 1.0, "u": 1.0, "v": 1.0
        }

    def forward(self, pred_batch, target_batch) -> torch.Tensor:
        device = next(iter(pred_batch.surf_vars.values())).device

        surface_loss_sum = torch.tensor(0.0, device=device)
        atmospheric_loss_sum = torch.tensor(0.0, device=device)

        for k in pred_batch.surf_vars:
            w_S_k = self.surf_var_weights.get(k, 1.0)
            var_mae = self.l1loss(pred_batch.surf_vars[k], target_batch.surf_vars[k])
            surface_loss_sum += w_S_k * var_mae

        for k in pred_batch.atmos_vars:
            w_A_k = self.atmos_var_weights.get(k, 1.0)
            var_mae = self.l1loss(pred_batch.atmos_vars[k], target_batch.atmos_vars[k])
            atmospheric_loss_sum += w_A_k * var_mae

        VS = len(pred_batch.surf_vars)
        VA = len(pred_batch.atmos_vars)

        combined_loss = self.alpha * surface_loss_sum + self.beta * atmospheric_loss_sum
        final_loss = (self.gamma / (VS + VA)) * combined_loss

        return final_loss



class AuroraDataset(Dataset):
    def __init__(
        self,
        # xr Datasets
        surface_ds: xr.Dataset,
        atmospheric_ds: xr.Dataset,
        static_ds: xr.Dataset,
        # variables
        surface_variables: list[str]=AURORA_VARIABLE_NAMES["surface"],
        atmospheric_variables: list[str]=AURORA_VARIABLE_NAMES["atmospheric"],
        static_variables: list[str]=AURORA_VARIABLE_NAMES["static"],
        # temporal parameters
        base_frequency: Union[str, pd.Timedelta]="6h",
        input_temporal_length: Union[str, int, pd.Timedelta]="12h",
        output_temporal_length: Union[str, int, pd.Timedelta]="6h",
        inter_sample_gap: Union[str, int, pd.Timedelta]="6h",
        forecast_horizon: Union[str, int, pd.Timedelta]="6w",
        init_frequency: Union[str, int, pd.Timedelta]="1d",
        init_gap: Union[str, int, pd.Timedelta]="0h",
        downsampling_rate: int=1,
    ):
        """
        Initialise the AuroraDataset.

        Args:
            surface_ds: xr.Dataset
                The surface dataset.
            atmospheric_ds: xr.Dataset
                The atmospheric dataset.
            static_ds: xr.Dataset
                The static dataset.
            surface_variables: list[str]
                The surface variables to include. Defaults to Aurora's surface variables.
            atmospheric_variables: list[str]
                The atmospheric variables to include. Defaults to Aurora's atmospheric variables.
            static_variables: list[str]
                The static variables to include. Defaults to Aurora's static variables.
            base_frequency: Union[str, pd.Timedelta]
                The base frequency of the data. Defaults to 6 hours.
            input_temporal_length: Union[str, int, pd.Timedelta]
                The length of the input time series. If `int`, considered to be a multiple of `base_frequency`. Defaults to 12 hours.
            output_temporal_length: Union[str, int, pd.Timedelta]
                The length of the output time series. If `int`, considered to be a multiple of `base_frequency`. Defaults to 6 hours.
            inter_sample_gap: Union[str, int, pd.Timedelta]
                The gap between consecutive samples. If `int`, considered to be a multiple of `base_frequency`. Defaults to 6 hours.
            forecast_horizon: Union[str, int, pd.Timedelta]
                The forecast horizon. If `int`, considered to be a multiple of `base_frequency`. Defaults to 6 weeks.
            init_frequency: Union[str, int, pd.Timedelta]
                The frequency of the initialisation times. If `int`, considered to be a multiple of `base_frequency`. Defaults to 1 day.
            init_gap: Union[str, int, pd.Timedelta]
                The gap between the initialisation time and the start of the input time series. If `int`, considered to be a multiple of 
                `base_frequency`. Defaults to 0 hours.
            downsampling_rate: int
                The downsampling rate of the dataset. Defaults to 1.
        """
        super().__init__()

        # get sorted datasets
        # Aurora requires latitudes to be in descending order
        # and longitudes to be in ascending order
        self.surface_ds = surface_ds.sortby("latitude", ascending=False).sortby("longitude", ascending=True)
        self.atmospheric_ds = atmospheric_ds.sortby("latitude", ascending=False).sortby("longitude", ascending=True)
        self.static_ds = static_ds.sortby("latitude", ascending=False).sortby("longitude", ascending=True)

        self.atmospheric_variables = atmospheric_variables
        self.surface_variables = surface_variables
        self.static_variables = static_variables

        self.base_frequency = pd.Timedelta(base_frequency) if isinstance(base_frequency, str) else base_frequency
        self.input_temporal_length = convert_to_steps(input_temporal_length, self.base_frequency, check_valid=True)
        self.output_temporal_length = convert_to_steps(output_temporal_length, self.base_frequency, check_valid=True)
        self.inter_sample_gap = convert_to_steps(inter_sample_gap, self.base_frequency, check_valid=True)
        self.forecast_horizon = convert_to_steps(forecast_horizon, self.base_frequency, check_valid=True)
        self.init_frequency = convert_to_steps(init_frequency, self.base_frequency, check_valid=True)
        self.init_gap = convert_to_steps(init_gap, self.base_frequency, check_valid=True)

        self.downsampling_rate = downsampling_rate
        self._spatial_shape = None

        # extract timestamps
        assert (surface_ds.time == atmospheric_ds.time).all(), f"got different timestamps for surface and atmospheric data."
        self.timestamps = self.surface_ds.time.values.astype("datetime64[s]")
        
        # add init gap
        if self.init_gap < 0:
            # reverse the init gap index based on the length of self.timestamps
            self.init_gap = len(self.timestamps) + self.init_gap
        self.timestamps = self.timestamps[self.init_gap:]        

        # compute feasible length of the time series
        discarded_timesteps = self.forecast_horizon + self.output_temporal_length + self.input_temporal_length + self.inter_sample_gap
        self.num_samples = len(self.timestamps) - discarded_timesteps
        assert self.timestamps[self.num_samples-1] + discarded_timesteps * self.base_frequency == self.timestamps[-1]

        # add init frequency
        self.num_samples = self.num_samples // self.init_frequency


    @classmethod
    def from_cloud_storage(
        cls,
        gcs_url: str,
        start_year: str|int=None,
        end_year: str|int=None,
        static_url: str=STATIC_VARS_HF_URL,
        atmospheric_variables: list[str]=AURORA_VARIABLE_NAMES["atmospheric"],
        surface_variables: list[str]=AURORA_VARIABLE_NAMES["surface"],
        pressure_levels: list[int]=AURORA_PRESSURE_LEVELS,
        variable_names_map: dict[str, str]=None,
        **kwargs
    ) -> "AuroraDataset":
        """
        No need to download anything locally!

        Args:
            gcs_url: str
                The url to the zarr store in the cloud storage.
            static_url: str
                The url to the static dataset in the cloud storage.
            atmospheric_variables: list[str]
                The atmospheric variables to include. Defaults to Aurora's atmospheric variables.
            surface_variables: list[str]
                The surface variables to include. Defaults to Aurora's surface variables.
            pressure_levels: list[int]
                The pressure levels to include. Defaults to Aurora's pressure levels.
            variable_names_map: dict[str, str]
                A dictionary mapping the variable names to the desired names.
            **kwargs:
                Additional arguments to pass to the AuroraDataset constructor.
        Returns:
            AuroraDataset
        """
        static_ds = load_static_ds_local("static_data/static.nc")

        # load surface and atmospheric ds
        surface_ds, atmospheric_ds = load_gcs_datasets(
            gcs_url=gcs_url,
            start_year=start_year,
            end_year=end_year,
            atmospheric_variables=atmospheric_variables,
            surface_variables=surface_variables,
            pressure_levels=pressure_levels,
            variable_names_map=variable_names_map
        )
        return cls(
            surface_ds=surface_ds,
            atmospheric_ds=atmospheric_ds,
            static_ds=static_ds,
            **kwargs
        )

    @property
    def spatial_shape(self) -> tuple[int]:
        """
        return (nlat, nlon)
        """
        # return if already computed
        if self._spatial_shape is not None:
            return self._spatial_shape
        # compute
        batch = self[0][0]
        if self.downsampling_rate is not None and self.downsampling_rate > 1:
            with torch.no_grad():
                batch = downsample_batch(batch, self.downsampling_rate)

        _,_,H,W = batch.surf_vars[next(iter(batch.surf_vars.keys()))].shape
        self._spatial_shape = (H, W)
        return self._spatial_shape

    def get_matching_timestamp(self, timestamp: np.datetime64|datetime) -> np.datetime64:
        """
        Find a create an OUTPUT batch from the requested timestamp. 
        """
        if self.output_temporal_length > 1:
            raise NotImplementedError("get_matching_timestamp is only implemented for output_temporal_length=1.")

        if isinstance(timestamp, datetime):
            timestamp = np.datetime64(timestamp)

        # get output slice
        output_slice = [
            timestamp + np.timedelta64(pd.Timedelta(self.base_frequency).value * n, "s")
            for n in range(self.output_temporal_length)
        ]

        # return the batch
        return xr_to_batch(
            self.surface_ds.sel(time=output_slice).compute(),
            self.atmospheric_ds.sel(time=output_slice).compute(),
            self.static_ds.compute(),
            surface_variables=self.surface_variables,
            static_variables=self.static_variables,
            atmospheric_variables=self.atmospheric_variables
        )

    def get_matching_output(self, batch: Batch) -> Batch:
        """
        Creates a Batch from the requested timestamp. The timestamp is the first 
        timestamp of the output time series of length `output_temporal_length`.
        """
        T = batch.atmos_vars[next(iter(batch.atmos_vars.keys()))].shape[1]
        if T != 1:
            raise NotImplementedError("get_matching_output is only implemented for T=1.")

        # get the times
        times = list(batch.metadata.time)
        output_batch = []

        # loop over requested times
        for time in times:
            new = self.get_matching_timestamp(time)
            output_batch.append(new)

        # collate
        output_batch = batch_collate_fn(output_batch)

        # apply downsampling
        if self.downsampling_rate is not None and self.downsampling_rate > 1:
            with torch.no_grad():
                output_batch = downsample_batch(output_batch, self.downsampling_rate)

        return output_batch

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, idx: int) -> Batch:
        if idx < 0:
            idx = self.num_samples + idx

        # add init frequency
        idx = idx * self.init_frequency

        # get input and output indexes
        input_slice = slice(idx, idx+self.input_temporal_length)
        output_slice = slice(idx+self.input_temporal_length+self.inter_sample_gap-1,
                             idx+self.input_temporal_length+self.inter_sample_gap-1+self.output_temporal_length)

        # get timestamps
        input_timestamps = self.timestamps[input_slice]
        output_timestamps = self.timestamps[output_slice]

        # make a few checks
        assert len(input_timestamps) == self.input_temporal_length
        assert len(output_timestamps) == self.output_temporal_length
        assert input_timestamps[-1] + self.inter_sample_gap * pd.Timedelta(self.base_frequency) == output_timestamps[0]

        input_batch = xr_to_batch(
            self.surface_ds.sel(time=input_timestamps).compute(),
            self.atmospheric_ds.sel(time=input_timestamps).compute(),
            self.static_ds.compute(),
            surface_variables=self.surface_variables,
            static_variables=self.static_variables,
            atmospheric_variables=self.atmospheric_variables
        )

        output_batch = xr_to_batch(
            self.surface_ds.sel(time=output_timestamps).compute(),
            self.atmospheric_ds.sel(time=output_timestamps).compute(),
            self.static_ds.compute(),
            surface_variables=self.surface_variables,
            static_variables=self.static_variables,
            atmospheric_variables=self.atmospheric_variables
        )

        if self.downsampling_rate is not None and self.downsampling_rate > 1:
            with torch.no_grad():
                input_batch = downsample_batch(input_batch, self.downsampling_rate)
                output_batch = downsample_batch(output_batch, self.downsampling_rate)

        return input_batch, output_batch


# Functions
def normalise_surf_var(
    x: torch.Tensor,
    name: str,
    stats: dict[str, tuple[float, float]],
    unnormalise: bool = False,
) -> torch.Tensor:
    """Normalise a surface-level variable."""
    location, scale = stats[name]
    if unnormalise:
        return x * scale + location
    else:
        return (x - location) / scale


def normalise_atmos_var(
    x: torch.Tensor,
    name: str,
    atmos_levels: tuple[int | float, ...],
    stats: dict[str, tuple[float, float]],
    unnormalise: bool = False,
) -> torch.Tensor:
    """Normalise an atmospheric variable."""
    level_locations: list[int | float] = []
    level_scales: list[int | float] = []
    for level in atmos_levels:
        name_level = f"{name}_{level}"
        level_locations.append(stats[name_level][0])
        level_scales.append(stats[name_level][1])
    location = torch.tensor(level_locations, dtype=x.dtype, device=x.device)
    scale = torch.tensor(level_scales, dtype=x.dtype, device=x.device)

    if unnormalise:
        return x * scale[..., None, None] + location[..., None, None]
    else:
        return (x - location[..., None, None]) / scale[..., None, None]


unnormalise_surf_var = partial(normalise_surf_var, unnormalise=True)
unnormalise_atmos_var = partial(normalise_atmos_var, unnormalise=True)


def fillna_spatial_mean_conv(tensor: torch.Tensor, kernel_size: int = 3, iterations: int = 5) -> torch.Tensor:
    """
    Approximate filling of NaNs via convolutional local mean.
    
    Supports:
      - 4D (B, T, H, W)
      - 5D (B, T, L, H, W)

    Args:
        tensor: Tensor with NaNs (on GPU or CPU).
        kernel_size: Spatial window size (must be odd).
        iterations: Number of fill iterations.

    Returns:
        Tensor with NaNs replaced by local mean.
    """
    assert kernel_size % 2 == 1, "kernel_size must be odd"
    
    if not torch.isnan(tensor).any():
        return tensor  # No NaNs — early exit

    device = tensor.device
    filled = tensor.clone()
    nan_mask = torch.isnan(filled)
    filled[nan_mask] = 0  # Replace NaNs with 0s for now

    # Set up convolution kernel
    kernel_dim = 2 if tensor.ndim == 4 else 3
    spatial_dims = (-2, -1) if kernel_dim == 2 else (-3, -2, -1)
    padding = kernel_size // 2

    # Define kernel
    kernel_shape = [1] * (tensor.ndim - kernel_dim) + [1] * kernel_dim
    kernel = torch.ones(kernel_shape[:-kernel_dim] + [kernel_size] * kernel_dim, device=device)

    for _ in range(iterations):
        valid_mask = ~torch.isnan(tensor)
        weights = valid_mask.float()
        values = filled * weights

        if kernel_dim == 2:
            smoothed_vals = F.conv2d(values.view(-1, 1, *values.shape[-2:]), kernel, padding=padding, groups=1)
            smoothed_weights = F.conv2d(weights.view(-1, 1, *weights.shape[-2:]), kernel, padding=padding, groups=1)
        else:  # 3D
            smoothed_vals = F.conv3d(values.view(-1, 1, *values.shape[-3:]), kernel, padding=padding, groups=1)
            smoothed_weights = F.conv3d(weights.view(-1, 1, *weights.shape[-3:]), kernel, padding=padding, groups=1)

        smoothed_mean = smoothed_vals / (smoothed_weights + 1e-6)
        smoothed_mean = smoothed_mean.view_as(filled)

        # Update only NaN locations
        filled[nan_mask] = smoothed_mean[nan_mask]

    return filled


def convert_to_steps(temporal_length: Union[str, int, pd.Timedelta], 
                      base_frequency: Union[str, pd.Timedelta],
                      check_valid: bool=True) -> int:
    # convert to timedelta
    if isinstance(temporal_length, str):
        temporal_length = pd.Timedelta(temporal_length)
    elif isinstance(temporal_length, int):
        return temporal_length
    # check if it is a multiple of the base frequency
    if check_valid:
        assert temporal_length % pd.Timedelta(base_frequency) == pd.Timedelta(0), \
            "temporal_length must be a multiple of base_frequency"
    return int(temporal_length / base_frequency)


def downsample_batch(batch, factor, pooling=F.avg_pool2d):
    return batch.downsample(factor)


def batch_collate_fn(batches: list[Batch]|None) -> Batch:
    """
    Custom collate function to combine multiple Batch objects. Collating custom
    objects such as Batch in PyTorch requires a custom function. For efficient
    parallel processing on the GPU, individual samples are combined into a single
    larger 'batch'. In the context of Aurora, one 'sample' is one Aurora Batch
    instance: all the atmospheric and surface variables for a single point in
    time at a specific location.
    """
    # Check whether the input is of type Batch
    _batches = batches.copy()
    for i, batch in enumerate(_batches):
        if batch is None: batches.pop(i)
        elif not isinstance(batch, Batch):
            raise ValueError(f"Expected a list of Aurora batches or NoneType, got {type(batch)}")
    batches = _batches

    if len(batches) == 0:
        return # nothing to batch return None
    elif len(batches) == 1:
        return batches[0] # nothing to batch return the single batch

    # Prediction batches have a single time sample apparently
    times = []
    for batch in batches:
        time = batch.metadata.time
        if isinstance(time, (tuple, list)):
            times.extend(list(time))
        else:
            times.append(time)
    times = tuple(times)

    # batch the data
    return Batch(
        surf_vars={
            var: torch.cat([batch.surf_vars[var] for batch in batches], dim=0)
            for var in batches[0].surf_vars
        },
        atmos_vars={
            var: torch.cat([batch.atmos_vars[var] for batch in batches], dim=0)
            for var in batches[0].atmos_vars
        },
        static_vars={
            var: batches[0].static_vars[var]
            for var in batches[0].static_vars
        },
        metadata=Metadata(
            lat=batches[0].metadata.lat,
            lon=batches[0].metadata.lon,
            atmos_levels=batches[0].metadata.atmos_levels,
            rollout_step=batches[0].metadata.rollout_step,
            time=times,
        )
    )


def batch_pairs_collate_fn(batch_pairs: list[tuple[Batch, Batch]]) -> tuple[Batch, Batch]:
    """
    Collate function for batch input/output pairs. This function is a wrapper designed to 
    handle the common scenario in supervised learning where the DataLoader yields pairs 
    of (input, target).

    Input: It expects a list where each element is a tuple containing an input Batch and a 
    target Batch. 
    E.g., [(input_batch_1, target_batch_1), (input_batch_2, target_batch_2), ...] could be 
    such a list.
    """
    input_batch = batch_collate_fn([batch_pair[0] for batch_pair in batch_pairs])
    output_batch = batch_collate_fn([batch_pair[1] for batch_pair in batch_pairs])

    return input_batch, output_batch


def prepare_lons_lats(lons: np.ndarray, lats: np.ndarray) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Aurora requires decreasing latitudes and increasing longitudes
    """
    if lats[0] < lats[1]: # i.e. increasing
        lats = torch.from_numpy(lats[::-1].copy())
        flip_lats = True
    else:
        lats = torch.from_numpy(lats)
        flip_lats = False

    if lons[0] < lons[1]: # i.e. increasing
        lons = torch.from_numpy(lons)
        flip_lons = False
    else:
        lons = torch.from_numpy(lons[::-1].copy())
        flip_lons = True

    return lons, lats, flip_lats, flip_lons


def prepare_array(x: np.ndarray, shape: tuple[int], flip_lons: bool, flip_lats: bool) -> torch.Tensor:
    x = x.reshape(shape).copy()
    if flip_lons: x = x[...,::-1].copy()
    if flip_lats: x = x[...,::-1,:].copy()
    return torch.from_numpy(x).clone()


def rename_xr_variables(ds: xr.Dataset, variable_names_map: dict[str, str]):
    """
    Directly taken from Eliot's data/utils.py.
    """
    # Intersect with data_vars
    variable_names_map = {k: v for k, v in variable_names_map.items() if k in list(ds.data_vars)+list(ds.coords)}

    # Rename the variables in the dataset
    renamed_ds = ds.rename(variable_names_map)

    return renamed_ds


def load_gcs_datasets(
    gcs_url: str=GCS_URL,
    num_debug_timesteps: int=0,
    start_year: str|int=None,
    end_year: str|int=None,
    atmospheric_variables: list[str]=AURORA_VARIABLE_NAMES["atmospheric"],
    surface_variables: list[str]=AURORA_VARIABLE_NAMES["surface"],
    pressure_levels: list[int]=AURORA_PRESSURE_LEVELS,
    variable_names_map: dict[str, str]=None,
):
    """
    A modified version of the function in Eliot's data/utils.py, allowing for
    selecting only a small number of recent time points for debugging
    purposes.
    """
    if start_year is None:
        assert end_year is None, "If start_year is None, end_year must be None"
    if end_year is None:
        assert start_year is None, "If end_year is None, start_year must be None"

    ds = xr.open_dataset(
        gcs_url,
        engine="zarr",
        chunks={},
        storage_options={"token": "anon"},
    )

    if num_debug_timesteps:  # Debug mode: select a small number of recent time points
        ds = ds.isel(time=slice(-num_debug_timesteps, None))

        # Create a slice from the selected times
        time_slice = slice(ds.time.values[0], ds.time.values[-1])
    elif start_year is not None and end_year is not None:
        start_date = f"{start_year}-01-01"
        end_date = f"{end_year}-12-31"
        time_slice = slice(start_date, end_date)
    else:
        time_slice = None

    # Remove -90 from latitude (i.e. 721 -> 720)
    # Solves ValueError: cannot reshape array of size 1036800 into shape (721,1440)
    if 90 in ds.latitude.values and -90 in ds.latitude.values:
        ds = ds.where(ds.latitude!=-90, drop=True)

    # Get the rename map
    if variable_names_map is None:
        freq = pd.infer_freq(ds.time.values)

        if not freq[0].isnumeric():
            freq = "1" + freq

        if pd.Timedelta(freq) in [pd.Timedelta("6h"), pd.Timedelta("24h")]:
            variable_names_map = ERA5_HRES_T0_WB2_VARIABLE_NAMES_MAP
        else:
            raise ValueError(f"No variable name map for freq {freq}")

    if len(surface_variables) > 0:
        surface_ds = rename_xr_variables(
            ds,
            variable_names_map
        )[surface_variables]

        if time_slice is not None:
            surface_ds = surface_ds.sel(time=time_slice)
            surface_ds = surface_ds.sortby("latitude", ascending=False).sortby("longitude", ascending=True)
    else:
        surface_ds = None

    if len(atmospheric_variables) > 0:
        atmospheric_ds = rename_xr_variables(
            ds,
            variable_names_map
        )[atmospheric_variables].sel(level=pressure_levels)

        if time_slice is not None:
            atmospheric_ds = atmospheric_ds.sel(time=time_slice)
            atmospheric_ds = atmospheric_ds.sortby("latitude", ascending=False).sortby("longitude", ascending=True)
    else:
        atmospheric_ds = None

    return surface_ds, atmospheric_ds


def xr_to_batch(
    surface_ds: xr.Dataset,
    atmospheric_ds: xr.Dataset,
    static_ds: xr.Dataset,
    surface_variables: list[str]=AURORA_VARIABLE_NAMES["surface"],
    static_variables: list[str]=AURORA_VARIABLE_NAMES["static"],
    atmospheric_variables: list[str]=AURORA_VARIABLE_NAMES["atmospheric"],
) -> Batch:
    """
    Create an Aurora Batch from XR Datasets.

    inspired by https://microsoft.github.io/aurora/example_era5.html
    and https://microsoft.github.io/aurora/example_hres_t0.html
    """
    # Converting to `datetime64[s]` ensures that the output of `tolist()` gives
    # `datetime.datetime`s. Note that this needs to be a tuple of length one:
    # one value for every batch element.
    # temporally, we want index 0 to be PREVIOUS time step and index 1 to be CURRENT time step
    # the metadata 'time' refers to the CURRENT time step, i.e. the last index
    if surface_ds.sizes["time"] == 1:
        _time = (surface_ds.time.values[0].astype("datetime64[s]").item(), )
    else:
        times = list(sorted(surface_ds.time.values.astype("datetime64[s]").tolist()))
        _time = (times[-1],) # only the last

    # get shapes for explicit reshaping and get lons, lats, levels
    # the process is repeated for each dataset because the
    # datasets can be empty
    if static_ds is not None and len(static_ds) > 0:
        H, W = static_ds.sizes["latitude"], static_ds.sizes["longitude"]

        # 1. get lons, lats
        lons = static_ds.longitude.values.copy()
        lats = static_ds.latitude.values.copy()

        # 2. prepare lons, lats
        lons, lats, flip_lats_static, flip_lons_static = prepare_lons_lats(lons, lats)

    else:
        flip_lats_static = False
        flip_lons_static = False

    if surface_ds is not None and len(surface_ds) > 0:
        # sort time
        surface_ds = surface_ds.sortby("time", ascending=True)
        T, H, W = surface_ds.sizes["time"], surface_ds.sizes["latitude"], surface_ds.sizes["longitude"]

        # 1. get lons, lats
        lons = surface_ds.longitude.values.copy()
        lats = surface_ds.latitude.values.copy()

        # 2. prepare lons, lats
        lons, lats, flip_lats_surface, flip_lons_surface = prepare_lons_lats(lons, lats)
    else:
        flip_lats_surface = False
        flip_lons_surface = False

    if atmospheric_ds is not None and len(atmospheric_ds) > 0:
        # sort time
        atmospheric_ds = atmospheric_ds.sortby("time", ascending=True)
        C, T, H, W = (
            atmospheric_ds.sizes["level"],
            atmospheric_ds.sizes["time"],
            atmospheric_ds.sizes["latitude"],
            atmospheric_ds.sizes["longitude"],
        )
        # 1. get lons, lats
        lons = atmospheric_ds.longitude.values.copy()
        lats = atmospheric_ds.latitude.values.copy()

        # 2. prepare lons, lats
        lons, lats, flip_lats_atmospheric, flip_lons_atmospheric = prepare_lons_lats(lons, lats)

        # 3. get levels
        levels = tuple(int(level) for level in atmospheric_ds.level.values)
    else:
        levels = tuple()
        flip_lats_atmospheric = False
        flip_lons_atmospheric = False

    return Batch.from_aurora_batch(
        BaseBatch(
            surf_vars = {
                var: prepare_array(
                    surface_ds[var].values,
                    (1, T, H, W),
                    flip_lats=flip_lats_surface,
                    flip_lons=flip_lons_surface)
                for var in surface_variables
            },
            atmos_vars = {
                var: prepare_array(
                    atmospheric_ds[var].values,
                    (1, T, C, H, W),
                    flip_lats=flip_lats_atmospheric,
                    flip_lons=flip_lons_atmospheric)
                for var in atmospheric_variables
            },
            static_vars = {
                var: prepare_array(
                    static_ds[var].values,
                    (H, W),
                    flip_lats=flip_lats_static,
                    flip_lons=flip_lons_static)
                for var in static_variables
            } if static_ds is not None and len(static_ds) > 0 else {},
            metadata=BaseMetadata(
                lat=lats,
                lon=lons,
                time=_time,
                atmos_levels=levels,
            )
        )
    )



def load_static_ds_local(path: str) -> xr.Dataset:
    ds = xr.open_dataset(path, engine="netcdf4")
    if 90 in ds.latitude.values and -90 in ds.latitude.values:
        ds = ds.where(ds.latitude!=-90, drop=True)
    return ds


# Modified version of 'ERA5_HRES_T0_WB2_VARIABLE_NAMES_MAP_6HR' in Eliot's data/constants.py
ERA5_HRES_T0_WB2_VARIABLE_NAMES_MAP = {
    # Surface-level Variables
    '10m_u_component_of_wind': '10u',
    '10m_v_component_of_wind': '10v',
    '2m_temperature': '2t',
    'mean_sea_level_pressure': 'msl',

    # Atmospheric Variables
    'temperature': 't',
    'u_component_of_wind': 'u',
    'v_component_of_wind': 'v',
    'specific_humidity': 'q',
    'geopotential': 'z',
}

# Initialize the base (pre-trained) Aurora model
base_model = XauroraSmall(use_lora=False, autocast=True)
base_model.load_checkpoint(
    "microsoft/aurora", 
    "aurora-0.25-small-pretrained.ckpt", 
    strict=False,  # to avoid error when loading state dict after disabling lora (finetuned model has lora weights)
)

names_target_modules = [
    "backbone.encoder_layers.0.blocks.0.attn.qkv",
    "backbone.encoder_layers.0.blocks.0.attn.proj",
    "backbone.encoder_layers.0.blocks.1.attn.qkv",
    "backbone.encoder_layers.0.blocks.1.attn.proj",
    "backbone.encoder_layers.1.blocks.0.attn.qkv",
    "backbone.encoder_layers.1.blocks.0.attn.proj",
    "backbone.encoder_layers.1.blocks.1.attn.qkv",
    "backbone.encoder_layers.1.blocks.1.attn.proj",
    "backbone.encoder_layers.1.blocks.2.attn.qkv",
    "backbone.encoder_layers.1.blocks.2.attn.proj",
    "backbone.encoder_layers.1.blocks.3.attn.qkv",
    "backbone.encoder_layers.1.blocks.3.attn.proj",
    "backbone.encoder_layers.1.blocks.4.attn.qkv",
    "backbone.encoder_layers.1.blocks.4.attn.proj",
    "backbone.encoder_layers.1.blocks.5.attn.qkv",
    "backbone.encoder_layers.1.blocks.5.attn.proj",
    "backbone.encoder_layers.2.blocks.0.attn.qkv",
    "backbone.encoder_layers.2.blocks.0.attn.proj",
    "backbone.encoder_layers.2.blocks.1.attn.qkv",
    "backbone.encoder_layers.2.blocks.1.attn.proj",
    "backbone.decoder_layers.0.blocks.0.attn.qkv",
    "backbone.decoder_layers.0.blocks.0.attn.proj",
    "backbone.decoder_layers.0.blocks.1.attn.qkv",
    "backbone.decoder_layers.0.blocks.1.attn.proj",
    "backbone.decoder_layers.1.blocks.0.attn.qkv",
    "backbone.decoder_layers.1.blocks.0.attn.proj",
    "backbone.decoder_layers.1.blocks.1.attn.qkv",
    "backbone.decoder_layers.1.blocks.1.attn.proj",
    "backbone.decoder_layers.1.blocks.2.attn.qkv",
    "backbone.decoder_layers.1.blocks.2.attn.proj",
    "backbone.decoder_layers.1.blocks.3.attn.qkv",
    "backbone.decoder_layers.1.blocks.3.attn.proj",
    "backbone.decoder_layers.1.blocks.4.attn.qkv",
    "backbone.decoder_layers.1.blocks.4.attn.proj",
    "backbone.decoder_layers.1.blocks.5.attn.qkv",
    "backbone.decoder_layers.1.blocks.5.attn.proj",
    "backbone.decoder_layers.2.blocks.0.attn.qkv",
    "backbone.decoder_layers.2.blocks.0.attn.proj",
    "backbone.decoder_layers.2.blocks.1.attn.qkv",
    "backbone.decoder_layers.2.blocks.1.attn.proj"
]



In [3]:
import dask
dataset = AuroraDataset.from_cloud_storage(
    gcs_url="gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721.zarr/",  
    start_year=2016,
    end_year=2020
)

# Extract datasets for regional slicing
surface_ds = dataset.surface_ds
atmospheric_ds = dataset.atmospheric_ds
static_ds = dataset.static_ds
#For fine tuning do fro 2016-2020 including 2020

In [4]:
from aurora import Aurora
from peft import LoraConfig, LoraModel
from torch.amp import GradScaler, autocast
from torch.utils.data import DataLoader
from datetime import datetime, timedelta
import torch
import dataclasses
import pandas as pd
import numpy as np

# --- Region Config ---
REGIONS = {
    "eastern_med": {
        "lat": slice(43.75, 30.0),
        "lon": slice(16.0, 43.75)
    },
    "western_med": {
        "lat": slice(45.75, 32.0),
        "lon": slice(348.0, 376.0)
    }
}

def build_model():
    target_modules = [
        f"backbone.encoder_layers.{i}.blocks.0.attn.qkv" for i in range(6)
    ] + [
        f"backbone.encoder_layers.{i}.blocks.0.attn.proj" for i in range(6)
    ] + [
        f"backbone.decoder_layers.{i}.blocks.0.attn.qkv" for i in range(6)
    ] + [
        f"backbone.decoder_layers.{i}.blocks.0.attn.proj" for i in range(6)
    ]

    config = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=target_modules,
        lora_dropout=0.1,
        bias="none"
    )

    base = Xaurora(use_lora=True, autocast=True)
    base.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt", strict=False)
    model = LoraModel(base, config, adapter_name="default")

    # Patch forward
    original_forward = model.forward
    def patched_forward(self, batch, lead_time):
        p = next(self.parameters())
        batch = batch.type(p.dtype).to(p.device)
        return original_forward(batch, lead_time)
    model.forward = patched_forward.__get__(model, type(model))

    model = model.cuda().train()
    model.configure_activation_checkpointing()

    # Print trainable params
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"trainable params: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")
    return model




In [5]:
def move_batch_to_device(batch, device):
    return batch.to(device)

def clean_batch(batch):
    for var_group in [batch.surf_vars, batch.atmos_vars, batch.static_vars]:
        for k, v in var_group.items():
            var_group[k] = v.nan_to_num(nan=0.0).float()
    return batch

def roll_forward(input_batch, prediction):
    new_batch = dataclasses.replace(input_batch)

    for group in ['surf_vars', 'atmos_vars']:
        input_vars = getattr(input_batch, group)
        pred_vars = getattr(prediction, group)
        new_group = {}

        for k in input_vars:

            rolled = torch.cat([input_vars[k][:, -1:], pred_vars[k][:, -1:].detach()], dim=1)
            new_group[k] = rolled

        setattr(new_batch, group, new_group)

    return new_batch



In [7]:
def train_on_region_14d(region_key, surface_ds, atmospheric_ds, static_ds, epochs=5, max_steps=200):
    print(f"\nStarting fine-tuning for region: {region_key}")

    region = REGIONS[region_key]
    surf_reg = surface_ds.sel(latitude=region["lat"], longitude=region["lon"])
    atmos_reg = atmospheric_ds.sel(latitude=region["lat"], longitude=region["lon"])
    static_reg = static_ds.sel(latitude=region["lat"], longitude=region["lon"])

    dataset = AuroraDataset(
        surface_ds=surf_reg,
        atmospheric_ds=atmos_reg,
        static_ds=static_reg,
        input_temporal_length=1,
        output_temporal_length="6h",
        forecast_horizon="14d"
    )

    loader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=lambda x: x[0])
    model = build_model()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
    scaler = GradScaler(enabled=torch.cuda.is_available())

    loss_fn = WeightedMAELoss(
        gamma=2.0, alpha=0.25, beta=1.0,
        surf_var_weights={"2t": 3.0, "10u": 1.0, "10v": 1.0, "msl": 0.05},
        atmos_var_weights={"t": 3.0, "u": 1.0, "v": 1.0, "q": 0.1, "z": 0.05},
    )

    LEAD_TIME = timedelta(hours=6)
    ROLLOUT_STEPS = 56

    for epoch in range(epochs):
        print(f"\nEpoch {epoch+1}/{epochs}")
        for step, (input_batch, _) in enumerate(loader):
            if step >= max_steps:
                print("Max steps reached. Breaking.")
                break

            device = next(model.parameters()).device
            input_batch = clean_batch(move_batch_to_device(input_batch, device))

            # Get start time
            initial_time = input_batch.metadata.time
            if isinstance(initial_time, (list, tuple, np.ndarray)):
                initial_time = initial_time[0]
            if isinstance(initial_time, torch.Tensor):
                initial_time = initial_time.item()
            if isinstance(initial_time, np.datetime64):
                initial_time = pd.to_datetime(str(initial_time)).to_pydatetime()

            # Build target rollout
            target_batches = []
            for i in range(ROLLOUT_STEPS):
                target_time = initial_time + i * LEAD_TIME
                target = dataset.get_matching_timestamp(target_time)
                target = clean_batch(move_batch_to_device(target, device)).normalise()
                target_batches.append(target)

            current_input = input_batch
            step_losses = []
            optimizer.zero_grad()

            for t in range(ROLLOUT_STEPS):
                target = target_batches[t]
                with autocast(device_type="cuda"):
                    pred = model(current_input, LEAD_TIME)
                    pred = clean_batch(pred).normalise()
                    for var in pred.surf_vars:
                        pred.surf_vars[var] = torch.clamp(pred.surf_vars[var], -1.0, 1.0)
                    for var in pred.atmos_vars:
                        pred.atmos_vars[var] = torch.clamp(pred.atmos_vars[var], -1.0, 1.0)
                    loss = loss_fn(pred, target)

                if not torch.isfinite(loss):
                    print(f"[Warning] Non-finite loss at step {t}, setting to zero.")
                    loss = torch.tensor(0.0, device=device)

                step_losses.append(loss)

                with torch.no_grad():
                    detached_pred = Batch(
                        surf_vars={k: v.detach() for k, v in pred.surf_vars.items()},
                        static_vars={k: v.detach() for k, v in pred.static_vars.items()},
                        atmos_vars={k: v.detach() for k, v in pred.atmos_vars.items()},
                        metadata=pred.metadata
                    )
                    current_input = roll_forward(current_input, detached_pred)

                del pred, target, loss
                torch.cuda.empty_cache()

            total_loss = sum(step_losses) / len(step_losses)
            scaler.scale(total_loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()

            print(f"Step {step+1} | Mean Loss: {total_loss.item():.4f}")
            del step_losses, total_loss, current_input
            torch.cuda.empty_cache()

    ckpt_name = f"aurora_{region_key}_14d_finetuned_epochss{epochs}.pt"
    torch.save(model.state_dict(), ckpt_name)
    print(f"Model saved to {ckpt_name}")



In [8]:
train_on_region_14d("western_med", surface_ds, atmospheric_ds, static_ds, epochs=1, max_steps=10)



Starting fine-tuning for region: western_med
trainable params: 3,194,880 / 1,259,495,056 (0.25%)

Epoch 1/1
Step 1 | Mean Loss: 1.3965
Step 2 | Mean Loss: 1.3753
Step 3 | Mean Loss: 1.5722
Step 4 | Mean Loss: 1.6296
Step 5 | Mean Loss: 1.7905
Step 6 | Mean Loss: 1.4943
Step 7 | Mean Loss: 1.5974
Step 8 | Mean Loss: 1.3030
Step 9 | Mean Loss: 1.7901
Step 10 | Mean Loss: 1.4778
Max steps reached. Breaking.
Model saved to aurora_western_med_14d_finetuned_epochss1.pt
