## Evaluation of the models (ResNet and ViT)

**ResNet:** This model stacks convolutional layers with skip connections, applies global average pooling, and integrates low-level features via an additional skip connection from the input. It is tailored to complex pattern recognition in data with spatial dimensions (adopted from https://arxiv.org/pdf/2301.10343).
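
To make the idea of a skip connection concrete, here is a minimal NumPy sketch of a single residual block followed by global average pooling. It is an illustration only, not the architecture from the cited paper: the single-channel layout, the 3x3 kernels, the ReLU activations, and the helper names `conv2d_same` and `residual_block` are all assumptions made for this example.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def conv2d_same(x, kernel):
    """Single-channel 2D cross-correlation with zero padding (output keeps x's shape)."""
    pad_y, pad_x = kernel.shape[0] // 2, kernel.shape[1] // 2
    padded = np.pad(x, ((pad_y, pad_y), (pad_x, pad_x)))
    windows = sliding_window_view(padded, kernel.shape)  # shape (H, W, kh, kw)
    return np.einsum("ijkl,kl->ij", windows, kernel)


def residual_block(x, k1, k2):
    """Hypothetical block: conv -> ReLU -> conv, then add the input back (skip connection)."""
    h = np.maximum(conv2d_same(x, k1), 0.0)
    h = conv2d_same(h, k2)
    return np.maximum(x + h, 0.0)


rng = np.random.default_rng(0)
field = rng.random((8, 8))               # toy non-negative input field
k1 = rng.normal(size=(3, 3)) * 0.1       # random stand-ins for learned kernels
k2 = rng.normal(size=(3, 3)) * 0.1

out = residual_block(field, k1, k2)      # same spatial shape as the input
pooled = out.mean()                      # global average pooling over the spatial axes
```

Because the input is added back before the final activation, a block whose kernels are all zero simply passes a non-negative input through unchanged, which is the property that makes deep stacks of such blocks easier to train.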

**ViT (Vision Transformer):** This configuration features multi-scale convolutional patch embedding with residual connections, followed by a series of transformer blocks. The model processes input data to generate an output with the same spatial dimensions, applying learnable positional encodings and a final dense layer for regression (https://arxiv.org/pdf/2410.12728).

Models were trained on historical data using air temperature and sea-level pressure as predictors and the residuals (observations minus CMIP6) as the target variable. To evaluate model performance, the estimated residuals were added back to the bilinearly interpolated CMIP6 data (resampled to the spatial resolution of the observational data).
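
The residual approach described above can be sketched with plain NumPy on toy grids. This is a simplified illustration under stated assumptions: `bilinear_regrid` is a hypothetical helper (the notebook itself uses xarray interpolation), the grids and fields are synthetic, and the "predicted" residual is just the true residual standing in for a trained model's output.

```python
import numpy as np


def bilinear_regrid(field, src_lat, src_lon, dst_lat, dst_lon):
    """Bilinear interpolation on a rectilinear grid: linear along lon, then along lat.

    Coordinates must be increasing; np.interp clamps outside the source range
    instead of extrapolating.
    """
    along_lon = np.apply_along_axis(
        lambda row: np.interp(dst_lon, src_lon, row), 1, field
    )
    return np.apply_along_axis(
        lambda col: np.interp(dst_lat, src_lat, col), 0, along_lon
    )


# Toy coarse "CMIP6" field and a finer "observational" grid (synthetic values).
src_lat = np.array([0.0, 2.0, 4.0])
src_lon = np.array([0.0, 2.0, 4.0])
dst_lat = np.linspace(0.0, 4.0, 5)
dst_lon = np.linspace(0.0, 4.0, 5)
coarse = src_lat[:, None] + 2.0 * src_lon[None, :]

# Step 1: bring the coarse field onto the observational grid.
coarse_on_fine = bilinear_regrid(coarse, src_lat, src_lon, dst_lat, dst_lon)

# Step 2: the training target is the residual (observations minus interpolated model).
obs = coarse_on_fine + 0.5 * np.sin(dst_lat)[:, None]
target_residual = obs - coarse_on_fine

# Step 3: add the (here: perfect, stand-in) predicted residual back for evaluation.
predicted_residual = target_residual
corrected = coarse_on_fine + predicted_residual
```

With a perfect residual prediction the corrected field reproduces the observations exactly, so any remaining error in practice is attributable to the residual model rather than to the interpolation step.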


The data come from CNRM-ESM2-1, for the historical period and the SSP245 scenario.
The daily historical data were split as follows: training from 1950-2004, validation from 2004-2009, and testing from 2010-2014.
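
A year-based split of a daily time axis can be sketched as follows. The helper `split_by_year` is hypothetical and the time axis is synthetic. Note that 2004 appears in both the training and validation ranges in the text; this sketch assumes validation starts in 2005 so that the subsets do not overlap, which may differ from what was actually done.

```python
import numpy as np


def split_by_year(times, first_year, last_year):
    """Boolean mask selecting timestamps whose calendar year is in [first_year, last_year]."""
    years = times.astype("datetime64[Y]").astype(int) + 1970
    return (years >= first_year) & (years <= last_year)


# Synthetic daily time axis covering the historical period 1950-2014.
times = np.arange("1950-01-01", "2015-01-01", dtype="datetime64[D]")

train_mask = split_by_year(times, 1950, 2004)
val_mask = split_by_year(times, 2005, 2009)   # assumption: 2004 is kept in training only
test_mask = split_by_year(times, 2010, 2014)

train_days, val_days, test_days = times[train_mask], times[val_mask], times[test_mask]
```

Splitting by contiguous year blocks rather than by random days keeps temporally adjacent, highly correlated days from leaking between the training and evaluation subsets.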

#### Define Functions for data loading and plotting

In [8]:
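import xarray as xr
import numpy as np
import cartopy.crs as ccrs          # map projections used in plot_map
import cartopy.feature as cfeature  # country borders used in plot_map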

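def load_and_mean_data(files_or_path, var_name, start_year=None, end_year=None, mask=None, lat=None, lon=None, interp_method=None, is_zarr=False):
    """Load a variable, optionally subset by year, and return its time mean.

    The mean can optionally be interpolated to the given lat/lon grid and
    masked (cells where `mask` is True are set to NaN).
    """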
    if is_zarr:
        ds = xr.open_zarr(files_or_path, consolidated=True)
    else:
        ds = xr.open_mfdataset(files_or_path, combine='by_coords', chunks={'time': -1})

    # Select time range based on start and end years
    if 'time' in ds.coords and hasattr(ds.time.dt, 'year'):
        if start_year is not None:
            ds = ds.sel(time=ds.time.dt.year >= start_year)
        if end_year is not None:
            ds = ds.sel(time=ds.time.dt.year <= end_year)
    elif start_year is not None or end_year is not None:
        print("Warning: start_year and end_year provided, but 'time' coordinate not found or doesn't have a 'year' attribute. Ignoring year selection.")


    data_mean = ds[var_name].mean(dim='time')

    # Interpolate if lat/lon are provided
    if lat is not None and lon is not None and interp_method:
        data_mean = data_mean.interp(lat=lat, lon=lon, method=interp_method, kwargs={"fill_value": "extrapolate"})

    # Apply mask if provided
    if mask is not None:
        data_mean = data_mean.where(~mask, np.nan)

    return data_mean

# Function to plot mean temperature maps
def plot_map(data, ax, title, mean_value, vmin=270, vmax=292, cmap='RdBu_r'):
    data.plot(
        ax=ax, transform=ccrs.PlateCarree(), cmap=cmap, vmin=vmin, vmax=vmax,
        cbar_kwargs={'orientation': 'horizontal', 'pad': 0.1, 'shrink': 0.8, 'aspect': 30, 'label': 'Mean Temperature (K)'}
    )
    ax.coastlines()
    ax.add_feature(cfeature.BORDERS, edgecolor='gray')
    ax.gridlines(draw_labels=True, linewidth=0.5, color='gray', alpha=0.5, linestyle='--')
    ax.set_title(title)
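    # Place the label at lon=10, lat=50.5; the explicit transform keeps it there on non-PlateCarree axes
    ax.text(10, 50.5, f'Mean: {mean_value:.2f} K', fontsize=10, ha='center', va='center', color='black', transform=ccrs.PlateCarree())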

# Function to calculate RMSE and MAE
def calculate_rmse_mae(model1, model2):
    diff = model1 - model2
    rmse = np.sqrt(np.nanmean(diff ** 2))
    mae = np.nanmean(np.abs(diff))
    return rmse, mae
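
As a quick sanity check of the metric function, the snippet below applies it to a toy 2x2 example in which one cell is masked (NaN). The function is repeated so the snippet runs on its own, and the temperature values are made up for illustration.

```python
import numpy as np


def calculate_rmse_mae(model1, model2):
    """RMSE and MAE between two fields, ignoring NaN (masked) cells."""
    diff = model1 - model2
    rmse = np.sqrt(np.nanmean(diff ** 2))
    mae = np.nanmean(np.abs(diff))
    return rmse, mae


# Toy fields in kelvin; the NaN mimics a masked grid cell.
pred = np.array([[281.0, 279.0], [np.nan, 285.0]])
ref = np.array([[280.0, 280.0], [283.0, 282.0]])

rmse, mae = calculate_rmse_mae(pred, ref)
print(f"RMSE: {rmse:.3f} K, MAE: {mae:.3f} K")  # RMSE: 1.915 K, MAE: 1.667 K
```

The valid differences are 1, -1, and 3 K, so the RMSE is sqrt(11/3) and the MAE is 5/3. Note that these are unweighted means over grid cells; on a regular lat/lon grid this gives high-latitude cells the same weight as low-latitude ones.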