### Imports

In [1]:
from main import *
import xarray as xr

### Datasets

In [None]:
# NOTE: the loading cell rebinds every name below from its path string to the opened
# dataset, so this cell must be re-run before the loading cell is re-run.

# --- Inputs used by the 200 m / Abu Dhabi plots (file names dated Aug 25, 2023) ---
cd_200 = "/lustrefs/taiga/corrdiff_inference/heat_dates/aug_25_2023_200m.nc"
wrf_200 = "/lustrefs/taiga/vast102/weather/training_data/wrf_200m_dataset/wrf_data_200/wrfout_200_2023-08-25.nc"
era5_native_ad = "/inception_users/adah.kibet/era5_2023_aug_25_ccropped_AD.nc"
fcn_ad = "/inception_users/adah.kibet/codebase/corrdiff-uae-inference/examples/sfno_cropped_sfno_aug_2023_25_v1.nc"

# --- Inputs used by the 2 km / UAE plots (file names dated Aug 17, 2023) ---
# cd_2km_r and cd_2km are per-hour templates: the loading cell fills "{hour:02d}"
# via .format(hour=...) for hours 0..23.
cd_2km_r = "/lustrefs/taiga/vast102/weather/weather_navigator/inference_regression/inference_2023-08-17T{hour:02d}_00_00_regression.nc"
# wrf_2km and era5_interp point at the same combined file; the loading cell reads it
# twice through load_wrf_dataset(), once with "output" and once with "input".
wrf_2km = "/lustrefs/taiga/vast102/weather/training_data/weather_wrf_dataset/combined_data_q1000/test/2023/08/inp_out_combined_2023_08_17.nc"
era5_native_uae = "/inception_users/adah.kibet/era5_2023_aug_17_ccropped_UAE_real.nc"
era5_interp = "/lustrefs/taiga/vast102/weather/training_data/weather_wrf_dataset/combined_data_q1000/test/2023/08/inp_out_combined_2023_08_17.nc"
cd_2km = "/lustrefs/taiga/vast102/weather/weather_navigator/modulus_last/examples/generative/corrdiff/outputs/inference/all_2km/inference_2023-08-17T{hour:02d}_00_00_all_2km.nc"

# TODO: update the SFNO UAE file so it covers Aug 17, 2023.
# NOTE(review): the file name below still says "aug_2023_25" and ends in "_ad" —
# confirm it is the intended UAE crop before comparing it with the Aug 17 data.
fcn_uae = "/inception_users/adah.kibet/sfno_cropped_sfno_aug_2023_25_v1_ad.nc"

In [6]:
def _load_hourly_predictions(path_template, n_hours=24):
    """Open one prediction file per hour and stack them along a new ``lead_time`` dim.

    Parameters
    ----------
    path_template : str
        Path containing an ``{hour}`` format field; it is filled with 0..n_hours-1.
    n_hours : int, default 24
        Number of hourly files to open.

    Returns
    -------
    xarray.Dataset
        The "prediction" group of every file, with the length-1 ``time`` dim
        dropped and the files concatenated on ``lead_time`` (values 0..n_hours-1).

    Raises
    ------
    ValueError
        If a file's ``time`` dimension does not have exactly one entry.
    """
    hourly = []
    for lead in range(n_hours):
        file_path = path_template.format(hour=lead)
        ds = xr.open_dataset(file_path, group="prediction")
        # Dataset.sizes is the name -> length mapping; indexing Dataset.dims
        # as a mapping is deprecated in recent xarray releases.
        if ds.sizes["time"] != 1:
            raise ValueError(f"{file_path} has unexpected time dim")
        # Drop the singleton time axis, then tag the slice with its lead hour.
        ds = ds.isel(time=0, drop=True).expand_dims({"lead_time": [lead]})
        hourly.append(ds)
    return xr.concat(hourly, dim="lead_time")


# NOTE: every name below is rebound from a path string to the opened dataset, so the
# paths cell has to be re-run before this cell can be executed a second time.
cd_200 = xr.open_dataset(cd_200)
wrf_200 = xr.open_dataset(wrf_200)
wrf_2km = load_wrf_dataset(wrf_2km, "output")
era5_native_ad = xr.open_dataset(era5_native_ad)
era5_native_uae = xr.open_dataset(era5_native_uae)
era5_interp = load_wrf_dataset(era5_interp, "input")
fcn_ad = xr.open_dataset(fcn_ad)
fcn_uae = xr.open_dataset(fcn_uae)

# Both CorrDiff 2 km products are stored as 24 single-time files (one per hour).
cd_2km_r = _load_hourly_predictions(cd_2km_r)
cd_2km = _load_hourly_predictions(cd_2km)


### Derived Variables

In [7]:
# Run compute_derived_variables() on the CorrDiff and WRF datasets (both resolutions)
# and rebind each name to the returned dataset.
cd_200, cd_2km, wrf_200, wrf_2km = (
    compute_derived_variables(source)
    for source in (cd_200, cd_2km, wrf_200, wrf_2km)
)

### Interactive Plots



#### 2KM T2

In [8]:
# Label -> dataset mapping handed to the viewer for the 2 km temperature comparison.
# NOTE(review): the CorrDiff/WRF/ERA5 2 km paths in the Datasets cell are dated
# 2023-08-17 while the title says August 25 — confirm which date is intended.
datasets = dict(
    [
        ("CorrDiff (Regression) 2KM ", cd_2km_r),
        ("CorrDiff - 2KM", cd_2km),
        ("WRF 2KM", wrf_2km),
        ("ERA5 Native 25KM - UAE", era5_native_uae),
        ("ERA5 Interpolation 2KM - UAE ", era5_interp),
        ("FourCastNet 25KM - UAE", fcn_uae),
    ]
)
interactive_dataset_viewer(
    datasets,
    "2M Temperature (K) - August 25, 2023, UAE",
    cmap="coolwarm",
)

VBox(children=(HBox(children=(Dropdown(description='Dataset 1', layout=Layout(width='200px'), options=(None, '…

Output()

#### 200m T2

In [34]:
# Label -> dataset mapping for the 200 m (Abu Dhabi) temperature comparison.
datasets = dict(
    [
        ("CorrDiff 200M", cd_200),
        ("WRF 200M", wrf_200),
        ("FourCastNet 25KM - AD", fcn_ad),
        ("ERA5 Native 25KM - AD", era5_native_ad),
    ]
)
interactive_dataset_viewer(
    datasets,
    "2M Temperature (K) - August 25, 2023, Abu Dhabi",
    cmap="coolwarm",
)


VBox(children=(HBox(children=(Dropdown(description='Dataset 1', layout=Layout(width='200px'), options=(None, '…

Output()

#### 2KM HI

In [35]:
# Dates are intentionally left unchanged for this plot: these are the heat dates.
# NOTE(review): the 2 km path templates in the Datasets cell are dated 2023-08-17
# while the title says August 25 — confirm.
datasets = {"CorrDiff 2KM ": cd_2km, "WRF 2KM": wrf_2km}
interactive_dataset_viewer(
    datasets,
    "Heat Index (K) - August 25, 2023, UAE",
    cmap="coolwarm",
)


VBox(children=(HBox(children=(Dropdown(description='Dataset 1', layout=Layout(width='200px'), options=(None, '…

Output()

#### 200M HI

In [36]:
# CorrDiff vs. WRF at 200 m for the heat-index view.
datasets = {"CorrDiff 200M ": cd_200, "WRF 200M": wrf_200}
interactive_dataset_viewer(
    datasets,
    "Heat Index (K) - August 25, 2023, Abu Dhabi",
    cmap="coolwarm",
)


VBox(children=(HBox(children=(Dropdown(description='Dataset 1', layout=Layout(width='200px'), options=(None, '…

Output()

#### 2km FI

In [13]:
# CorrDiff vs. WRF at 2 km for the fog-index view.
# NOTE(review): the 2 km path templates in the Datasets cell are dated 2023-08-17
# while the title says August 25 — confirm.
datasets = {"CorrDiff 2KM ": cd_2km, "WRF 2KM": wrf_2km}
interactive_dataset_viewer(
    datasets,
    "Fog Index - August 25, 2023, UAE",
    cmap="viridis",
)


VBox(children=(HBox(children=(Dropdown(description='Dataset 1', layout=Layout(width='200px'), options=(None, '…

Output()

#### 200m FI

In [None]:
# CorrDiff vs. WRF at 200 m for the fog-index view.
datasets = {"CorrDiff 200M": cd_200, "WRF 200M": wrf_200}
interactive_dataset_viewer(
    datasets,
    "Fog Index - August 25, 2023, Abu Dhabi",
    cmap="viridis",
)

VBox(children=(HBox(children=(Dropdown(description='Dataset 1', layout=Layout(width='200px'), options=(None, '…

Output()

#### 2KM - Wind (U10, V10 and Total Wind)

In [10]:
# compute total wind speed
def compute_total_wind(ds, u10, v10):
    """Add a total wind speed entry ``"TW"`` to ``ds`` and return ``ds``.

    ``"TW"`` is ``sqrt(ds[u10]**2 + ds[v10]**2)``.

    Parameters
    ----------
    ds : mapping-like
        Container indexed by variable name; it is modified in place.
    u10, v10 : str
        Keys of the two wind components inside ``ds``.

    Returns
    -------
    The same ``ds`` object that was passed in, now holding ``"TW"``.
    """
    squared_magnitude = ds[u10] ** 2 + ds[v10] ** 2
    ds["TW"] = np.sqrt(squared_magnitude)
    return ds


In [11]:
# (label, dataset, u-component name, v-component name) for every 2 km / UAE source.
# compute_total_wind() adds "TW" to each dataset in place and returns it.
wind_sources = [
    ("CorrDiff (Regression) 2KM ", cd_2km_r, "U10", "V10"),
    ("CorrDiff - 2KM", cd_2km, "U10", "V10"),
    ("WRF 2KM", wrf_2km, "U10", "V10"),
    ("ERA5 Native 25KM - UAE", era5_native_uae, "u10", "v10"),
    ("ERA5 Interpolation 2KM - UAE ", era5_interp, "u10", "v10"),
    ("FourCastNet 25KM - UAE", fcn_uae, "u10m", "v10m"),
]
datasets = {
    label: compute_total_wind(source, u_name, v_name)
    for label, source, u_name, v_name in wind_sources
}

interactive_dataset_viewer(
    datasets,
    "Horizontal/Zonal Wind Speed (M/S) - August 25, 2023, UAE",
    cmap="viridis",
)

# TODO: change these once the new dates are received.

VBox(children=(HBox(children=(Dropdown(description='Dataset 1', layout=Layout(width='200px'), options=(None, '…

Output()

In [44]:
# Relies on the `datasets` dict built by an earlier cell; this cell does not define one.
interactive_dataset_viewer(
    datasets,
    "Vertical/Meridional Wind Speed (M/S) - August 25, 2023, UAE",
    cmap="coolwarm",
)


VBox(children=(HBox(children=(Dropdown(description='Dataset 1', layout=Layout(width='200px'), options=(None, '…

Output()

In [None]:
# Total wind speed ("TW") is already attached to every dataset by compute_total_wind()
# when the 2 km wind `datasets` dict is built, so nothing is recomputed here.
interactive_dataset_viewer(
    datasets,
    "Total Wind Speed (M/S) - August 25, 2023, UAE",
    cmap="coolwarm",
)


VBox(children=(HBox(children=(Dropdown(description='Dataset 1', layout=Layout(width='200px'), options=(None, '…

Output()

#### 200M - Wind (U10, V10 and Total Wind)

In [None]:
# (label, dataset, u-component name, v-component name) for every 200 m / Abu Dhabi source.
# compute_total_wind() adds "TW" to each dataset in place and returns it.
wind_sources = [
    ("CorrDiff 200M", cd_200, "U10", "V10"),
    # ("WRF 200M", wrf_200, "U10", "V10"),  # currently disabled
    ("FourCastNet - AD", fcn_ad, "u10m", "v10m"),
    ("ERA5 Native - AD", era5_native_ad, "u10", "v10"),
]
datasets = {
    label: compute_total_wind(source, u_name, v_name)
    for label, source, u_name, v_name in wind_sources
}

interactive_dataset_viewer(
    datasets,
    "Horizontal/Zonal Wind Speed (M/S) - August 25, 2023, Abu Dhabi",
    cmap="coolwarm",
)

VBox(children=(HBox(children=(Dropdown(description='Dataset 1', layout=Layout(width='200px'), options=(None, '…

Output()

In [48]:
# Relies on the `datasets` dict built by an earlier cell; this cell does not define one.
interactive_dataset_viewer(
    datasets,
    "Vertical/Meridional Wind Speed (M/S) - August 25, 2023, Abu Dhabi",
    cmap="coolwarm",
)


VBox(children=(HBox(children=(Dropdown(description='Dataset 1', layout=Layout(width='200px'), options=(None, '…

Output()

In [49]:
# Relies on the `datasets` dict built by an earlier cell; this cell does not define one.
interactive_dataset_viewer(
    datasets,
    "Total Wind Speed (M/S) - August 25, 2023, Abu Dhabi",
    cmap="coolwarm",
)


VBox(children=(HBox(children=(Dropdown(description='Dataset 1', layout=Layout(width='200px'), options=(None, '…

Output()

### Loss Functions

- RMSE: Measures average error magnitude (lower = better)
- Bias: Measures systematic over/under-prediction (0 = ideal)
- Correlation: Measures linear relationship strength (higher = better)

-------------------
1. RMSE (Root Mean Square Error):
- Measures the average magnitude of forecast errors
- Formula: sqrt(mean((forecast - truth)²))
- Units: same as original data
- Lower is better (0 = perfect)

2. Bias (Mean Error):
- Measures systematic over/under-prediction
- Formula: mean(forecast - truth)
- Positive = forecast too high, Negative = forecast too low
- Units: same as original data
- 0 = no systematic bias (ideal)

3. Correlation (Pearson r):
- Measures linear relationship between forecast and truth
- Range: -1 to +1
- +1 = perfect positive correlation
- 0 = no linear relationship
- -1 = perfect negative correlation
- Values closer to +1 are better (a strongly negative value means the forecast is anti-correlated with the truth)

In [19]:
# Label -> dataset mapping for the 2 km sources passed to loss_functions().
datasets = dict(
    [
        ("CorrDiff - 2KM", cd_2km),
        ("WRF 2KM", wrf_2km),
        ("ERA5 Interpolation - UAE", era5_interp),
    ]
)
loss_functions(datasets)

interactive(children=(Dropdown(description='Dataset 1', index=2, options=('CorrDiff - 2KM', 'WRF 2KM', 'ERA5 I…