In [1]:
# IMPORTS
import os
import sys
import dataclasses
import typing as t
from datetime import datetime, timedelta

import numpy as np
import xarray as xr

import matplotlib.pyplot as plt
from matplotlib.ticker import LogLocator, NullFormatter, ScalarFormatter

from IPython.display import clear_output
from dask.diagnostics import ProgressBar
from tqdm.notebook import tqdm

# local clone of weatherbench2 found at : https://github.com/google-research/weatherbench2
sys.path.append('/scratch/dx2/cl6824/weatherbench2')
from weatherbench2 import schema
from weatherbench2.derived_variables import ZonalEnergySpectrum

In [2]:
# Register a global dask progress bar so every dask computation triggered later
# (e.g. `.load()` and `.to_netcdf()` on lazy datasets) reports its progress.
dask_progress_bar = ProgressBar()
dask_progress_bar.register()

### ERA5

In [None]:
# Build one NetCDF file per month of 2024 from the ERA5 pressure-level reanalysis.
# Each file holds u² + v² + w² (stored as "wind_speed") at 500 and 200 hPa.
# NOTE(review): despite its name, "wind_speed" is the *squared* wind magnitude.
# ERA5's `w` is normally a pressure velocity (Pa s⁻¹), while u/v are in m s⁻¹,
# so confirm that mixing these units is intended. Every model cell in this
# notebook uses the same definition, so the spectra remain mutually comparable.
era5_pl_root = os.path.join("/g/data/rt52/era5", "pressure-levels", "reanalysis")
u_path = os.path.join(era5_pl_root, "u", "2024")
v_path = os.path.join(era5_pl_root, "v", "2024")
w_path = os.path.join(era5_pl_root, "w", "2024")
u_files = sorted(os.path.join(u_path, e) for e in os.listdir(u_path))
v_files = sorted(os.path.join(v_path, e) for e in os.listdir(v_path))
w_files = sorted(os.path.join(w_path, e) for e in os.listdir(w_path))
# The three lists are paired by position below, so they must cover the same months.
assert len(u_files) == len(v_files) == len(w_files), "u/v/w file lists have different lengths"

era5_tmp_dir = "/scratch/dx2/cl6824/Fourier/ERA5/tmp"
os.makedirs(era5_tmp_dir, exist_ok=True)


def select_era5_levels(ds):
    """Keep only the 500 and 200 hPa levels of an ERA5 pressure-level file."""
    return ds.sel(level=[500, 200])


n_months = len(u_files)
print("starting")
for i, month_files in enumerate(zip(u_files, v_files, w_files)):
    print(f"making files n°{i+1}/{n_months}")
    # The context manager closes this month's source files once the output is written.
    with xr.open_mfdataset(
        list(month_files),
        combine='by_coords',
        preprocess=select_era5_levels,
        chunks={},
    ) as winds:
        print("dataset opened")
        squared_wind = winds.assign(
            wind_speed=winds["u"]**2 + winds["v"]**2 + winds["w"]**2
        ).drop_vars(["u", "v", "w"])
        print("kinetic energy retrieved")
        # Written inside the `with` block: the arrays are lazy and are still read
        # from the source files at this point.
        squared_wind.to_netcdf(os.path.join(era5_tmp_dir, f"times_{i:02d}.nc"))


In [None]:
# Compute the zonal energy spectrum of each monthly "times_*.nc" file, separately
# at 200 and 500 hPa. Results are saved next to the inputs as one
# "fourier_<level>hPa_<NN>.nc" file per (level, month).
base_path = "/scratch/dx2/cl6824/Fourier/ERA5/tmp"
files = sorted(
    os.path.join(base_path, efile)
    for efile in os.listdir(base_path)
    if efile.endswith(".nc") and efile.startswith("times")
)
# Built once and reused for every file and level. It is only parameterised by the
# variable name (presumably stateless; re-check if weatherbench2 changes).
zonal_spectrum = ZonalEnergySpectrum("wind_speed")

n_files = len(files)
for i, efile in enumerate(files):
    # 1-based counter over the files actually found, matching the monthly-file cell.
    print(f"performing file n°{i+1}/{n_files}")
    with xr.open_dataset(efile) as ds:
        for elevel in [200, 500]:
            print("loading")
            eds = ds.sel(level=elevel).load()
            print("computing ZES")
            res = zonal_spectrum.compute(eds)
            print("saving results")
            res.to_netcdf(os.path.join(base_path, f"fourier_{elevel}hPa_{i:02d}.nc"))


In [None]:
# Combine the monthly spectra into one yearly file per level by averaging over time.
base_path = "/scratch/dx2/cl6824/Fourier/ERA5/tmp"
# Monthly spectra are named "fourier_<level>hPa_<NN>.nc".
ffiles = [efile for efile in os.listdir(base_path) if efile.startswith("fourier") and efile.endswith(".nc")]
for elevel in [200, 500]:
    level_prefix = f"fourier_{elevel}hPa_"
    files_level = sorted(os.path.join(base_path, efile) for efile in ffiles if efile.startswith(level_prefix))
    if not files_level:
        raise FileNotFoundError(f"no monthly spectra matching {level_prefix}*.nc in {base_path}")

    out_dir = f"/scratch/dx2/cl6824/Fourier/ERA5/wind_speed_{elevel}hPa"
    os.makedirs(out_dir, exist_ok=True)
    # Write inside the `with` block so the lazy mean is computed while the inputs are open.
    with xr.open_mfdataset(files_level, combine='by_coords', chunks={}) as ds_level:
        ds_level_mean = ds_level.mean(dim="time")
        ds_level_mean.to_netcdf(os.path.join(out_dir, f"{elevel}hPa_wind_speed_spectrum.nc"))

### AIFS

In [20]:
# Compute zonal energy spectra of the AIFS deterministic forecasts.
# Produces one file per (pressure level, lead time), averaged over time across
# all files in `base_path`.
max_lead_time = 60
base_path = "/g/data/dx2/cl6824/ML/AIFS/outputs/postproc/v1/"
files = sorted(os.path.join(base_path, e) for e in os.listdir(base_path))
# Every lead-time index is visited. Outputs that already exist are skipped unless
# `overwrite` is True, so an interrupted run resumes under "Run All" without
# hand-editing this list.
lead_times = list(range(max_lead_time))
overwrite = False
aifs_levels = [500, 200]
# Built once and reused (presumably stateless; it is only parameterised by the variable name).
zonal_spectrum = ZonalEnergySpectrum("wind_speed")


def aifs_spectrum_path(elevel, lead_time_index):
    """Return the output path for one level and lead-time index.

    Time index k is labelled as lead time (k + 1) * 6 h.
    """
    return (
        f"/scratch/dx2/cl6824/Fourier/AIFS/wind_speed_{elevel}hPa/"
        f"{elevel}hPa_wind_speed_spectrum_leadtime_{(lead_time_index+1)*6:03d}h.nc"
    )


def make_aifs_prep(lead_time_index):
    """Return an `open_mfdataset` preprocess function for one forecast time index.

    The index is bound explicitly rather than read from the loop variable
    through a closure.
    """
    def prep(ds):
        # Keep the wind components at the selected time and pressure levels.
        ds = ds[["u", "v", "w"]].isel(time=[lead_time_index]).sel(level=[500, 200])
        # Squared wind magnitude u² + v² + w², stored as "wind_speed".
        return ds.assign(
            wind_speed=ds["u"]**2 + ds["v"]**2 + ds["w"]**2
        ).drop_vars(["u", "v", "w"])
    return prep


for lead_time_index in lead_times:
    out_paths = {elevel: aifs_spectrum_path(elevel, lead_time_index) for elevel in aifs_levels}
    pending = {elevel: path for elevel, path in out_paths.items() if overwrite or not os.path.exists(path)}
    if not pending:
        continue

    clear_output(wait=True)
    print(f"doing lead time {lead_time_index}/{max_lead_time}")

    # The context manager closes all forecast files before the next lead time is opened.
    with xr.open_mfdataset(
        files,
        combine='by_coords',
        preprocess=make_aifs_prep(lead_time_index),
        chunks={},
    ) as ds_lead_time:
        for elevel, out_path in pending.items():
            eds = ds_lead_time.sel(level=elevel).load()
            res = zonal_spectrum.compute(eds).mean(dim="time")
            os.makedirs(os.path.dirname(out_path), exist_ok=True)
            res.to_netcdf(out_path)

doing lead time 59/60
[########################################] | 100% Completed | 187.88 s
[########################################] | 100% Completed | 187.97 s
[########################################] | 100% Completed | 57.88 s
[########################################] | 100% Completed | 58.18 s


### IFS

In [None]:
# Compute zonal energy spectra of the IFS fields at 500 and 200 hPa, averaged
# over the initialisation dates listed below.
base_path = "/g/data/dx2/cl6824/ML/IFS/"
files = sorted(os.path.join(base_path, el) for el in os.listdir(base_path) if el.endswith(".grib"))
# Every third day of 2024 at 12 UTC: 122 dates from 2024-01-01T12 to 2024-12-29T12.
first_date = datetime(2024, 1, 1, 12)
dates = np.array([np.datetime64(first_date + timedelta(days=3 * x)) for x in range(122)])


def prep_ifs(ds):
    """Keep the selected dates and replace u, v, w by u² + v² + w² ("wind_speed")."""
    ds = ds.sel(time=ds.time.isin(dates))
    return ds.assign(
        wind_speed=ds["u"]**2 + ds["v"]**2 + ds["w"]**2
    ).drop_vars(["u", "v", "w"])


# Built once and reused for both levels.
zonal_spectrum = ZonalEnergySpectrum("wind_speed")
# chunks={} opens the data lazily as dask arrays using the backend's preferred
# chunking. Data values are only read at `.load()` below, one level at a time.
with xr.open_mfdataset(
    files,
    combine='by_coords',
    preprocess=prep_ifs,
    chunks={},
    decode_timedelta=True,
) as ds:
    for elevel in [500, 200]:
        eds = ds.sel(isobaricInhPa=elevel).load()
        print("ds loaded")
        res = zonal_spectrum.compute(eds).mean(dim="time")
        # Drop the leftover GRIB "number"/"step" coordinates from the time mean.
        res = res.drop_vars(["number", "step"])
        print("computation done, saving ...")

        out_dir = f"/scratch/dx2/cl6824/Fourier/IFS/wind_speed_{elevel}hPa"
        os.makedirs(out_dir, exist_ok=True)
        res.to_netcdf(os.path.join(out_dir, f"{elevel}hPa_wind_speed_spectrum.nc"))
        print(f"level {elevel} done")

### IFS FC (IFS forecasts, per lead time)

In [8]:
# Compute zonal energy spectra of the IFS forecasts (IFS_FC) at 500 hPa.
# Produces one file per lead time, averaged over time across all GRIB files in `base_path`.
base_path = "/g/data/dx2/cl6824/ML/IFS_FC/"
files = sorted(os.path.join(base_path, el) for el in os.listdir(base_path) if el.endswith(".grib"))
# Every third day of 2024 at 12 UTC (122 dates).
# NOTE(review): `dates` is not used to filter the forecasts in this cell.
# Confirm that the IFS_FC files already contain only the intended initialisations.
first_date = datetime(2024, 1, 1, 12)
dates = np.array([np.datetime64(first_date + timedelta(days=3 * x)) for x in range(122)])

max_lead_time = 40
lead_times = list(range(0, max_lead_time))
# Step index 0 is not processed.
# NOTE(review): output names label step index k as (k + 1) * 6 h. Check this
# against the decoded `step` coordinate: if index 0 is the 0 h analysis, every
# label is 6 h too late.
fc_lead_times = lead_times[1:]
fc_levels = [500]
# Existing outputs are skipped unless `overwrite` is True, so interrupted runs resume.
overwrite = False
# Built once and reused for every lead time.
zonal_spectrum = ZonalEnergySpectrum("wind_speed")


def ifs_fc_spectrum_path(elevel, lead_time_index):
    """Return the output path for one level and step index (labelled (k + 1) * 6 h)."""
    return (
        f"/scratch/dx2/cl6824/Fourier/IFS_FC/wind_speed_{elevel}hPa/"
        f"{elevel}hPa_wind_speed_spectrum_leadtime_{(lead_time_index+1)*6:03d}h.nc"
    )


def make_ifs_fc_prep(lead_time_index):
    """Return an `open_mfdataset` preprocess function for one forecast step index."""
    def prep(ds):
        # Keep the wind components at the selected step and pressure levels.
        ds = ds[["u", "v", "w"]].isel(step=[lead_time_index]).sel(isobaricInhPa=[500, 200])
        # Squared wind magnitude u² + v² + w², stored as "wind_speed".
        return ds.assign(
            wind_speed=ds["u"]**2 + ds["v"]**2 + ds["w"]**2
        ).drop_vars(["u", "v", "w"])
    return prep


for lead_time_index in fc_lead_times:
    out_paths = {elevel: ifs_fc_spectrum_path(elevel, lead_time_index) for elevel in fc_levels}
    pending = {elevel: path for elevel, path in out_paths.items() if overwrite or not os.path.exists(path)}
    if not pending:
        continue

    clear_output(wait=True)
    print(f"doing lead time {lead_time_index}/{max_lead_time}")

    # Passed directly, as in the IFS cell; the context manager closes the GRIB files.
    with xr.open_mfdataset(
        files,
        combine='by_coords',
        preprocess=make_ifs_fc_prep(lead_time_index),
        chunks={},
        decode_timedelta=True,
    ) as ds:
        for elevel, out_path in pending.items():
            print(f"doing level {elevel}")
            eds = ds.sel(isobaricInhPa=elevel).load()
            print("loaded")
            res = zonal_spectrum.compute(eds).mean(dim="time")
            res = res.drop_vars(["step"])
            print("saving")
            os.makedirs(os.path.dirname(out_path), exist_ok=True)
            res.to_netcdf(out_path)

doing lead time 39/40
doing level 500
[########################################] | 100% Completed | 533.75 s
loaded
saving


### AIFS - ENS

In [None]:
# Compute zonal energy spectra of the AIFS ensemble outputs (files starting with
# "ifs" in `base_path`). Produces one file per (pressure level, lead time),
# averaged over time.
base_path = "/g/data/dx2/cl6824/ML/AIFS/outputs/postproc/ens/ensv1"
max_lead_time = 60
files = sorted(os.path.join(base_path, e) for e in os.listdir(base_path) if e.startswith("ifs"))
lead_times = list(range(0, max_lead_time))
# The files appear to include the initial conditions (-6 h and 0 h) as their first
# two time steps. Forecast index k is therefore read at time index k + 2, which
# keeps the labels consistent with the deterministic AIFS outputs.
ens_time_offset = 2
ens_levels = [500, 200]
# Existing outputs are skipped unless `overwrite` is True, so interrupted runs resume.
overwrite = False
# Built once and reused for every lead time and level.
zonal_spectrum = ZonalEnergySpectrum("wind_speed")


def aifs_ens_spectrum_path(elevel, lead_time_index):
    """Return the output path for one level and forecast index (labelled (k + 1) * 6 h)."""
    return (
        f"/scratch/dx2/cl6824/Fourier/AIFS_ens/wind_speed_{elevel}hPa/"
        f"{elevel}hPa_wind_speed_spectrum_leadtime_{(lead_time_index+1)*6:03d}h.nc"
    )


def make_aifs_ens_prep(lead_time_index):
    """Return an `open_mfdataset` preprocess function for one forecast index."""
    def prep(ds):
        # Keep the wind components at the selected (offset) time and pressure levels.
        ds = ds[["u", "v", "w"]].isel(time=[lead_time_index + ens_time_offset]).sel(level=[500, 200])
        # Squared wind magnitude u² + v² + w², stored as "wind_speed".
        return ds.assign(
            wind_speed=ds["u"]**2 + ds["v"]**2 + ds["w"]**2
        ).drop_vars(["u", "v", "w"])
    return prep


for lead_time_index in lead_times:
    out_paths = {elevel: aifs_ens_spectrum_path(elevel, lead_time_index) for elevel in ens_levels}
    pending = {elevel: path for elevel, path in out_paths.items() if overwrite or not os.path.exists(path)}
    if not pending:
        continue

    print(f"doing lead time {lead_time_index}/{max_lead_time}")

    # The context manager closes all ensemble files before the next lead time is opened.
    with xr.open_mfdataset(
        files,
        combine='by_coords',
        preprocess=make_aifs_ens_prep(lead_time_index),
        chunks={},
    ) as ds:
        for elevel, out_path in pending.items():
            print(f"doing level {elevel}")
            eds = ds.sel(level=elevel).load()
            print("loaded")
            res = zonal_spectrum.compute(eds).mean(dim="time")
            print("saving")
            os.makedirs(os.path.dirname(out_path), exist_ok=True)
            res.to_netcdf(out_path)

### TESTS