In [None]:
!pwd

/content


In [None]:
# Install the CDS API client into the kernel's own environment (%pip, not !pip),
# pinned to the version the recorded run of this notebook installed.
%pip install -q cdsapi==0.6.1

Collecting cdsapi
  Downloading cdsapi-0.6.1.tar.gz (13 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: cdsapi
  Building wheel for cdsapi (setup.py) ... [?25l[?25hdone
  Created wheel for cdsapi: filename=cdsapi-0.6.1-py2.py3-none-any.whl size=12006 sha256=617a0447d698a3c5518b8baea985db414133cb2e06bceb724ee527767c16c484
  Stored in directory: /root/.cache/pip/wheels/7c/63/08/45461d6f6636c1aba7846828d8c787a064073945048f76d44a
Successfully built cdsapi
Installing collected packages: cdsapi
Successfully installed cdsapi-0.6.1


In [None]:
# Read the CDS API key ("<UID>:<API key>") from the environment instead of committing it
# to the notebook.
# NOTE(review): the key that used to be hard-coded on this line has been published with
# the notebook and should be revoked and regenerated.
import os

# When CDSAPI_KEY is unset this is None; download_copernicus_era5 then leaves
# ~/.cdsapirc untouched instead of writing a new one.
cdsKey = os.environ.get("CDSAPI_KEY")

In [None]:
import os
import cdsapi

In [None]:
def download_copernicus_era5(
    dst,
    variable,
    year,
    regions=None,
    pressure=False,
    api_key=None,
    month="03",
    day="01",
):
    """Download hourly ERA5 reanalysis data for one variable, one file per region.

    :param dst: Directory the NetCDF files are written to (created if missing).
    :param variable: CDS variable name, e.g. ``"2m_temperature"``.
    :param year: Year to request; converted with ``str``.
    :param regions: Mapping ``region name -> area`` where ``area`` is passed through
        unchanged as the request's ``"area"``. Defaults to a single ``"Globe"``
        entry ``[90, -180, -90, 180]``.
    :param pressure: If true, request ``reanalysis-era5-pressure-levels`` at levels
        1000/850/500/50; otherwise ``reanalysis-era5-single-levels``.
    :param api_key: If given, ``~/.cdsapirc`` is (over)written with this key before
        the client is created.
    :param month: Month(s) to request. Defaults to ``"03"`` (the value that was
        previously hard-coded).
    :param day: Day(s) to request. Defaults to ``"01"`` (previously hard-coded).

    Each region is saved as ``{dst}/{variable}_{year}_{region}_0.25deg.nc``.
    """
    if regions is None:
        # Built per call rather than as a shared mutable default argument.
        regions = {"Globe": [90, -180, -90, 180]}

    if api_key is not None:
        content = f"url: https://cds.climate.copernicus.eu/api/v2\nkey: {api_key}"
        # expanduser works even where the HOME environment variable is not set.
        rc_path = os.path.join(os.path.expanduser("~"), ".cdsapirc")
        with open(rc_path, "w") as f:
            f.write(content)

    os.makedirs(dst, exist_ok=True)
    client = cdsapi.Client()

    base_args = {
        "product_type": "reanalysis",
        "format": "netcdf",
        "variable": variable,
        "year": str(year),
        "month": month,
        "day": day,
        # every hour of the day: "00:00" ... "23:00"
        "time": [str(i).rjust(2, "0") + ":00" for i in range(0, 24)],
    }
    if pressure:
        src = "reanalysis-era5-pressure-levels"
        base_args["pressure_level"] = [1000, 850, 500, 50]
    else:
        src = "reanalysis-era5-single-levels"

    for reg_name, area in regions.items():
        # Fresh request dict per region so no request shares state with another.
        request = dict(base_args, area=area)
        client.retrieve(src, request, f"{dst}/{variable}_{year}_{reg_name}_0.25deg.nc")

In [None]:
# Region name -> bounding box, used as the `area` of each download request.
# NOTE(review): the values look like [north, west, south, east] in degrees — confirm
# against the CDS API documentation.
twoRegs = {
    "CA": [50, -126, 30, -112],
    "Globe": [90, -180, -90, 180],
}

In [None]:
# Single-region variant of the mapping above (only the "CA" box), to keep the download small.
oneReg = {
    "CA": [50, -126, 30, -112]
}

In [None]:
# Download each surface variable for 2024 over the "CA" region into ./data.
# `variables` is reused by the conversion and reader cells further down.
variables = ['2m_temperature', 'total_precipitation']
for var in variables:
  download_copernicus_era5("data", var, 2024, regions=oneReg, api_key=cdsKey)

2024-04-03 21:53:15,553 INFO Welcome to the CDS
INFO:cdsapi:Welcome to the CDS
2024-04-03 21:53:15,558 INFO Sending request to https://cds.climate.copernicus.eu/api/v2/resources/reanalysis-era5-single-levels
INFO:cdsapi:Sending request to https://cds.climate.copernicus.eu/api/v2/resources/reanalysis-era5-single-levels
2024-04-03 21:53:15,857 INFO Request is queued
INFO:cdsapi:Request is queued
2024-04-03 21:53:18,601 INFO Request is running
INFO:cdsapi:Request is running
2024-04-03 21:53:24,475 INFO Request is completed
INFO:cdsapi:Request is completed
2024-04-03 21:53:24,480 INFO Downloading https://download-0018.copernicus-climate.eu/cache-compute-0018/cache/data6/adaptor.mars.internal-1712181201.8340993-28264-19-7554ada1-2694-4c89-adb1-0c6932a16013.nc to data/2m_temperature_2024_CA_0.25deg.nc (218.1K)
INFO:cdsapi:Downloading https://download-0018.copernicus-climate.eu/cache-compute-0018/cache/data6/adaptor.mars.internal-1712181201.8340993-28264-19-7554ada1-2694-4c89-adb1-0c6932a1601

In [None]:
# Standard library
import glob

# Third party
import numpy as np
import xarray as xr
from tqdm import tqdm

In [None]:
# Maps a long variable name (as used in download requests and file names) to the short
# variable code used to index the opened NetCDF dataset.
NAME_TO_VAR = {
    "2m_temperature": "t2m",
    "10m_u_component_of_wind": "u10",
    "10m_v_component_of_wind": "v10",
    "mean_sea_level_pressure": "msl",
    "surface_pressure": "sp",
    "toa_incident_solar_radiation": "tisr",
    "total_precipitation": "tp",
    "land_sea_mask": "lsm",
    "orography": "orography",
    "latitude": "lat2d",
    # Misspelled key kept so existing lookups of "lattitude" keep working.
    "lattitude": "lat2d",
    "geopotential": "z",
    "u_component_of_wind": "u",
    "v_component_of_wind": "v",
    "temperature": "t",
    "relative_humidity": "r",
    "specific_humidity": "q",
    "vorticity": "vo",
    "potential_vorticity": "pv",
    "total_cloud_cover": "tcc",
}

In [None]:
# Scratch check: a slice that reaches past either end of an array does not raise,
# it simply returns the elements that exist (see the printed output below).
a = np.arange(12)
b = np.arange(11)
print(a[:20])
print(b[-20:])

[ 0  1  2  3  4  5  6  7  8  9 10 11]
[ 0  1  2  3  4  5  6  7  8  9 10]


In [None]:
!ls data

2m_temperature_2024_CA_0.25deg.nc  total_precipitation_2024_CA_0.25deg.nc


In [None]:
# Number of trailing hourly time steps kept per year. It is set to a single day (24)
# here so that the resulting shards are small enough to inspect by hand.
HOURS_PER_YEAR = 24  # look at the shard with small data


def _accumulate_yearly_stats(name, values, climatology, normalize_mean, normalize_std):
    """Append one year's statistics of ``values`` to the running per-variable lists.

    ``normalize_mean`` / ``normalize_std`` are ``None`` when no normalization
    statistics are collected (any partition other than ``"train"``).
    """
    if normalize_mean is not None:
        # mean / std over time and the two trailing (spatial) axes
        normalize_mean.setdefault(name, []).append(values.mean(axis=(0, 2, 3)))
        normalize_std.setdefault(name, []).append(values.std(axis=(0, 2, 3)))
    # time-mean field for this year
    climatology.setdefault(name, []).append(values.mean(axis=0))


def nc2np(
    path,
    variables,
    years,
    save_dir,
    partition,
    num_shards_per_year,
    pressure_levels=None,
):
    """Convert per-variable NetCDF files into sharded ``.npz`` files plus statistics.

    For every year and variable, all files in ``path`` matching
    ``*{var}*{year}*.nc`` are opened together, the last ``HOURS_PER_YEAR`` time
    steps are kept, and the result is split into ``num_shards_per_year`` equal
    shards saved as ``{save_dir}/{partition}/{year}_{shard_id}.npz``.

    :param path: Directory containing the NetCDF files.
    :param variables: Long variable names; each must be a key of ``NAME_TO_VAR``.
    :param years: Iterable of years to process.
    :param save_dir: Root output directory.
    :param partition: Sub-directory name. For ``"train"`` the per-variable mean and
        std are additionally aggregated over the years and written to
        ``normalize_mean.npz`` / ``normalize_std.npz`` in ``save_dir``.
    :param num_shards_per_year: Number of shards per year; must divide
        ``HOURS_PER_YEAR``.
    :param pressure_levels: Optional collection of pressure levels to keep for
        4-D (pressure-level) variables. ``None`` keeps every level in the file.

    A ``climatology.npz`` (mean over time within a year, then over the years) is
    always written to ``{save_dir}/{partition}``.

    3-D variables get a singleton axis inserted at position 1; ``tp`` is replaced
    by a trailing 6-step accumulation followed by ``log(eps + x) - log(eps)``.
    """
    os.makedirs(os.path.join(save_dir, partition), exist_ok=True)

    # Shard size is independent of the year: validate once, before any file is opened.
    assert HOURS_PER_YEAR % num_shards_per_year == 0
    num_hrs_per_shard = HOURS_PER_YEAR // num_shards_per_year

    is_train = partition == "train"
    normalize_mean = {} if is_train else None
    normalize_std = {} if is_train else None
    climatology = {}

    for year in tqdm(years):
        np_vars = {}

        # non-constant fields
        for var in variables:
            ps = glob.glob(os.path.join(path, f"*{var}*{year}*.nc"))
            code = NAME_TO_VAR[var]
            # The context manager releases the underlying file handles once the
            # arrays have been materialised with to_numpy().
            with xr.open_mfdataset(
                ps, combine="by_coords", parallel=True
            ) as ds:  # dataset for a single variable
                print(ds[code].shape)
                if len(ds[code].shape) == 3:  # surface level variables
                    ds[code] = ds[code].expand_dims("val", axis=1)
                    if code == "tp":  # accumulate 6 hours and log transform
                        tp = ds[code].to_numpy()
                        print(tp[:, 0, 0, 0])
                        tp_cum_6hrs = np.cumsum(tp, axis=0)
                        # cumsum[t] - cumsum[t - 6] is the sum of the last 6 steps;
                        # the first 6 entries keep their shorter running sums.
                        tp_cum_6hrs[6:] = tp_cum_6hrs[6:] - tp_cum_6hrs[:-6]
                        print(tp_cum_6hrs[:, 0, 0, 0])
                        eps = 0.001
                        tp_cum_6hrs = np.log(eps + tp_cum_6hrs) - np.log(eps)
                        # keep only the trailing HOURS_PER_YEAR time steps
                        np_vars[var] = tp_cum_6hrs[-HOURS_PER_YEAR:]
                        print(np_vars[var][:, 0, 0, 0])
                    else:
                        # keep only the trailing HOURS_PER_YEAR time steps
                        np_vars[var] = ds[code].to_numpy()[-HOURS_PER_YEAR:]

                    _accumulate_yearly_stats(
                        var, np_vars[var], climatology, normalize_mean, normalize_std
                    )

                else:  # pressure-level variables
                    assert len(ds[code].shape) == 4
                    all_levels = ds["level"][:].to_numpy()
                    if pressure_levels is not None:
                        all_levels = np.intersect1d(all_levels, pressure_levels)
                    for level in all_levels:
                        ds_level = ds.sel(level=[level])
                        name = f"{var}_{int(level)}"
                        # keep only the trailing HOURS_PER_YEAR time steps
                        np_vars[name] = ds_level[code].to_numpy()[-HOURS_PER_YEAR:]

                        _accumulate_yearly_stats(
                            name,
                            np_vars[name],
                            climatology,
                            normalize_mean,
                            normalize_std,
                        )

        for shard_id in range(num_shards_per_year):
            start_id = shard_id * num_hrs_per_shard
            end_id = start_id + num_hrs_per_shard
            sharded_data = {k: v[start_id:end_id] for k, v in np_vars.items()}
            np.savez(
                os.path.join(save_dir, partition, f"{year}_{shard_id}.npz"),
                **sharded_data,
            )
        print("===")
        for k in np_vars.keys():
            print(k, np_vars[k].shape)
        print("===")

    if is_train:
        # Aggregate the yearly statistics. Every entry is a list of yearly arrays,
        # so every variable is stacked and aggregated the same way.
        for var in list(normalize_mean):
            mean = np.stack(normalize_mean[var], axis=0)
            std = np.stack(normalize_std[var], axis=0)
            # var(X) = E[var(X|Y)] + var(E[X|Y])
            variance = (
                (std**2).mean(axis=0)
                + (mean**2).mean(axis=0)
                - mean.mean(axis=0) ** 2
            )
            # Rounding can push a tiny variance below zero; clip so sqrt stays finite.
            normalize_std[var] = np.sqrt(np.clip(variance, 0, None))
            # E[X] = E[E[X|Y]]
            mean = mean.mean(axis=0)
            if var == "total_precipitation":
                # precipitation is only scaled, never shifted
                mean = np.zeros_like(mean)
            normalize_mean[var] = mean

        np.savez(os.path.join(save_dir, "normalize_mean.npz"), **normalize_mean)
        np.savez(os.path.join(save_dir, "normalize_std.npz"), **normalize_std)

    climatology = {
        k: np.mean(np.stack(v, axis=0), axis=0) for k, v in climatology.items()
    }
    print("Climatology !!!!!!")
    for k, v in climatology.items():
        print(k, v.shape)
    np.savez(
        os.path.join(save_dir, partition, "climatology.npz"),
        **climatology,
    )

In [None]:
# Convert the downloaded 2024 files into 4 shards under processed/test. Because the
# partition is "test", no normalization statistics are computed or written.
nc2np("data", variables, [2024], "processed", "test", 4)

100%|██████████| 1/1 [00:01<00:00,  1.16s/it]

(24, 81, 57)
(24, 81, 57)
[2.6178826e-04 8.6783106e-04 2.1930784e-05 3.3346005e-05 1.7786934e-04
 3.2709690e-04 4.1101594e-04 6.7328848e-04 7.2953431e-04 7.7630195e-04
 9.5749216e-04 9.1743527e-04 6.4181024e-04 1.4016451e-04 4.7666952e-05
 1.9094441e-05 8.5784122e-06 3.3346005e-05 7.5824326e-05 2.2221543e-04
 3.7670112e-04 6.9805607e-04 1.7461781e-03 1.2502746e-03]
[0.00026179 0.00112962 0.00115155 0.0011849  0.00136277 0.00168986
 0.00183909 0.00164455 0.00235215 0.00309511 0.00387473 0.00446507
 0.00469586 0.00416274 0.00348087 0.00272366 0.00177475 0.00089066
 0.00032467 0.00040673 0.00073576 0.00141472 0.00315232 0.00436925]
[0.23253012 0.7559433  0.7661886  0.7815685  0.85983276 0.98949003
 1.0434837  0.9725003  1.2096024  1.4097929  1.584065   1.6983767
 1.7397404  1.6414671  1.4998174  1.3147082  1.0205607  0.6369262
 0.28116703 0.3412652  0.5514455  0.88158417 1.4236674  1.6806884 ]
===
2m_temperature (24, 1, 81, 57)
total_precipitation (24, 1, 81, 57)
===
Climatology !!!!!!
2m




In [None]:
import torch
from torch.utils.data import IterableDataset

In [None]:
class NpyReader(IterableDataset):
    """Iterate over paired ``.npz`` shards, yielding one sample tuple per shard.

    Each item is ``(inputs, outputs, variables, out_variables)`` where ``inputs`` and
    ``outputs`` are dicts ``variable name -> array`` with axis 1 squeezed out.

    :param inp_file_list: Paths of the input shards.
    :param out_file_list: Paths of the output shards, paired by position with
        ``inp_file_list``.
    :param variables: Names of the arrays to read from each input shard.
    :param out_variables: Names of the arrays to read from each output shard;
        ``None`` means the same as ``variables``.
    :param shuffle: Stored on the instance; ``__iter__`` does not use it.

    Paths containing ``"climatology"`` are dropped from both lists.
    """

    def __init__(
        self,
        inp_file_list,
        out_file_list,
        variables,
        out_variables,
        shuffle=False,
    ):
        super().__init__()
        assert len(inp_file_list) == len(out_file_list)
        self.inp_file_list = [f for f in inp_file_list if "climatology" not in f]
        self.out_file_list = [f for f in out_file_list if "climatology" not in f]
        # Filtering must not break the positional pairing of the two lists.
        assert len(self.inp_file_list) == len(self.out_file_list)
        self.variables = variables
        self.out_variables = out_variables if out_variables is not None else variables
        self.shuffle = shuffle

    @staticmethod
    def _read_fields(archive, names):
        """Load the named arrays from an open archive and squeeze out axis 1."""
        return {k: np.squeeze(archive[k], axis=1) for k in names}

    def __iter__(self):
        n_files = len(self.inp_file_list)

        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            iter_start = 0
            iter_end = n_files
        else:
            if not torch.distributed.is_initialized():
                rank = 0
                world_size = 1
            else:
                rank = torch.distributed.get_rank()
                world_size = torch.distributed.get_world_size()
            num_workers_per_ddp = worker_info.num_workers
            num_shards = num_workers_per_ddp * world_size
            # Floor division: every worker gets the same number of files and any
            # remainder at the end of the list is not visited.
            per_worker = n_files // num_shards
            worker_id = rank * num_workers_per_ddp + worker_info.id
            iter_start = worker_id * per_worker
            iter_end = iter_start + per_worker

        for idx in range(iter_start, iter_end):
            path_inp = self.inp_file_list[idx]
            path_out = self.out_file_list[idx]
            same_file = path_out == path_inp
            # Read everything that is needed and close the archive again, so file
            # handles do not pile up while the generator is being consumed.
            with np.load(path_inp) as inp:
                x = self._read_fields(inp, self.variables)
                if same_file:
                    y = self._read_fields(inp, self.out_variables)
            if not same_file:
                with np.load(path_out) as out:
                    y = self._read_fields(out, self.out_variables)
            yield x, y, self.variables, self.out_variables

In [None]:
# Input and output shards come from the same glob, so the two lists are identical and
# NpyReader reuses the loaded input archive as the output (its path_out == path_inp branch).
inp_lister_test = sorted(
    glob.glob(os.path.join("processed", "test", "*.npz"))
)
out_lister_test = sorted(
    glob.glob(os.path.join("processed", "test", "*.npz"))
)

In [None]:
print(inp_lister_test)

['processed/test/2024_0.npz', 'processed/test/2024_1.npz', 'processed/test/2024_2.npz', 'processed/test/2024_3.npz', 'processed/test/climatology.npz']


In [None]:
# Shard reader over the test partition; the climatology file in the lists is filtered
# out by NpyReader itself. The same variables are read as inputs and as outputs.
tmp = NpyReader(
    inp_file_list=inp_lister_test,
    out_file_list=out_lister_test,
    variables=variables,
    out_variables=variables,
    shuffle=False,
)

In [None]:
# One iteration per shard: print every input array's shape and the variable list.
for xIn, xOut, varIn, varOut in tmp:
  for k, v in xIn.items():
    print(k, v.shape)
  print(varIn)

2m_temperature (6, 81, 57)
total_precipitation (6, 81, 57)
['2m_temperature', 'total_precipitation']
2m_temperature (6, 81, 57)
total_precipitation (6, 81, 57)
['2m_temperature', 'total_precipitation']
2m_temperature (6, 81, 57)
total_precipitation (6, 81, 57)
['2m_temperature', 'total_precipitation']
2m_temperature (6, 81, 57)
total_precipitation (6, 81, 57)
['2m_temperature', 'total_precipitation']


In [None]:
class DirectForecast(IterableDataset):
    """Turn each shard of a wrapped dataset into (history stack, future target) samples.

    For every shard yielded by ``dataset`` as
    ``(inputs, outputs, variables, out_variables)``, sample ``i`` consists of

    * input: the ``history`` time steps ``i, i + window, ..., i + (history - 1) * window``
      stacked along a new axis 1, and
    * output: the time step ``i + (history - 1) * window + pred_range``.

    :param dataset: Iterable of ``(inputs, outputs, variables, out_variables)`` where
        ``inputs`` / ``outputs`` map names to NumPy arrays with time on axis 0.
    :param src: ``"era5"`` uses ``pred_range`` and ``window`` as given;
        ``"mpi-esm1-2-hr"`` requires both to be multiples of 6 and divides them by 6.
    :param pred_range: Distance in time steps from the last history step to the target.
    :param history: Number of history steps per sample.
    :param window: Spacing in time steps between consecutive history steps.
    :raises ValueError: If ``src`` is not one of the two supported sources.
    """

    def __init__(self, dataset, src, pred_range=6, history=3, window=6):
        super().__init__()
        self.dataset = dataset
        self.history = history
        if src == "era5":
            self.pred_range = pred_range
            self.window = window
        elif src == "mpi-esm1-2-hr":
            assert pred_range % 6 == 0
            assert window % 6 == 0
            self.pred_range = pred_range // 6
            self.window = window // 6
        else:
            # Fail here rather than with an AttributeError on the first iteration.
            raise ValueError(
                f"Unsupported src {src!r}; expected 'era5' or 'mpi-esm1-2-hr'."
            )

    @staticmethod
    def _stack_history(series, history, window, offset):
        """Return the (N, history, ...) stack of shifted slices of ``series``.

        ``offset`` is the distance from a sample's first history step to its
        target, so only the first ``len(series) - offset`` start positions are usable.
        """
        n_samples = max(series.size(0) - offset, 0)
        return torch.stack(
            [series[t * window : t * window + n_samples] for t in range(history)],
            dim=1,
        )

    def __iter__(self):
        offset = (self.history - 1) * self.window + self.pred_range
        for inp_data, out_data, variables, out_variables in self.dataset:
            inp_data = {
                k: self._stack_history(
                    torch.from_numpy(inp_data[k].astype(np.float32)),
                    self.history,
                    self.window,
                    offset,
                )
                for k in inp_data.keys()  # N, T, H, W
            }
            out_data = {
                k: torch.from_numpy(out_data[k].astype(np.float32))
                for k in out_data.keys()
            }

            inp_data_len = inp_data[variables[0]].size(0)

            # Target of sample i sits `offset` steps after its first history step.
            output_ids = torch.arange(inp_data_len) + offset
            out_data = {k: out_data[k][output_ids] for k in out_data.keys()}
            yield inp_data, out_data, variables, out_variables

In [None]:
# With pred_range=1, history=2, window=2 each 6-step shard yields 3 samples.
# NOTE(review): some recorded outputs further down show 4 samples per shard, which is
# what window=1 would produce — they look stale relative to this cell; re-run to confirm.
tmp1 = DirectForecast(tmp, 'era5', pred_range=1, history=2, window=2)

In [None]:
# let's see what goes in:
for xIn, xOut, varIn, varOut in tmp:
  print("in: ")
  for k, v in xIn.items():
    print(f'{k.ljust(20)}: {v[:, 0, 0]}')
  print("out: ")
  for k, v in xOut.items():
    print(f'{k.ljust(20)}: {v[:, 0, 0]}')
  break

in: 
2m_temperature      : [273.28662 273.08453 271.6197  270.7301  270.22946 270.12387]
total_precipitation : [0.23253012 0.7559433  0.7661886  0.7815685  0.85983276 0.98949003]
out: 
2m_temperature      : [273.28662 273.08453 271.6197  270.7301  270.22946 270.12387]
total_precipitation : [0.23253012 0.7559433  0.7661886  0.7815685  0.85983276 0.98949003]


In [None]:
# Same grid point after DirectForecast: each input row is one sample's history steps,
# each output entry is that sample's target. Only the first shard is shown.
for xIn, xOut, varIn, varOut in tmp1:
  print("in: ")
  for k, v in xIn.items():
    print('-')
    print(f'{k.ljust(20)}: {v[:, :, 0, 0]}')
  print("out: ")
  for k, v in xOut.items():
    print(f'{k.ljust(20)}: {v[:, 0, 0]}')
  break

in: 
-
2m_temperature      : tensor([[273.2866, 271.6197],
        [273.0845, 270.7301],
        [271.6197, 270.2295]])
-
total_precipitation : tensor([[0.2325, 0.7662],
        [0.7559, 0.7816],
        [0.7662, 0.8598]])
out: 
2m_temperature      : tensor([270.7301, 270.2295, 270.1239])
total_precipitation : tensor([0.7816, 0.8598, 0.9895])


In [None]:
# let's check input dimensions:
for xIn, xOut, varIn, varOut in tmp:
  print("in: ")
  for k, v in xIn.items():
    print(f'{k.ljust(20)}: {v.shape}')
  print("out: ")
  for k, v in xOut.items():
    print(f'{k.ljust(20)}: {v.shape}')
  break

in: 
2m_temperature      : (6, 81, 57)
total_precipitation : (6, 81, 57)
out: 
2m_temperature      : (6, 81, 57)
total_precipitation : (6, 81, 57)


In [None]:
# let's check the output dim's:
for xIn, xOut, varIn, varOut in tmp1:
  print("in: ")
  for k, v in xIn.items():
    print(f'{k.ljust(20)}: {v.shape}')
  print("out: ")
  for k, v in xOut.items():
    print(f'{k.ljust(20)}: {v.shape}')
  break

in: 
2m_temperature      : torch.Size([4, 2, 81, 57])
total_precipitation : torch.Size([4, 2, 81, 57])
out: 
2m_temperature      : torch.Size([4, 81, 57])
total_precipitation : torch.Size([4, 81, 57])


In [None]:
# Repeat of the value inspection for tmp1 (history rows and targets at grid point (0, 0)).
for xIn, xOut, varIn, varOut in tmp1:
  print("in: ")
  for k, v in xIn.items():
    print('-')
    print(f'{k.ljust(20)}: {v[:, :, 0, 0]}')
  print("out: ")
  for k, v in xOut.items():
    print(f'{k.ljust(20)}: {v[:, 0, 0]}')
  break

in: 
-
2m_temperature      : tensor([[273.2866, 273.0845],
        [273.0845, 271.6197],
        [271.6197, 270.7301],
        [270.7301, 270.2295]])
-
total_precipitation : tensor([[0.2325, 0.7559],
        [0.7559, 0.7662],
        [0.7662, 0.7816],
        [0.7816, 0.8598]])
out: 
2m_temperature      : tensor([271.6197, 270.7301, 270.2295, 270.1239])
total_precipitation : tensor([0.7662, 0.7816, 0.8598, 0.9895])


In [None]:
class ContinuousForecast(IterableDataset):
    """Like a direct forecast, but each sample carries its own lead time.

    For every shard yielded by ``dataset`` as
    ``(inputs, outputs, variables, out_variables)``, sample ``i`` consists of

    * input: the ``history`` time steps ``i, i + window, ..., i + (history - 1) * window``
      stacked along a new axis 1,
    * a prediction range ``r_i`` (random in ``[min_pred_range, max_pred_range]`` or
      fixed to ``max_pred_range``), and
    * output: the time step ``i + (history - 1) * window + r_i``.

    Items are ``(inputs, outputs, lead_times, variables, out_variables)`` with
    ``lead_times = hrs_each_step * r / 100`` cast to the inputs' dtype.

    :param dataset: Iterable of ``(inputs, outputs, variables, out_variables)`` where
        ``inputs`` / ``outputs`` map names to NumPy arrays with time on axis 0.
    :param random_lead_time: Draw a prediction range per sample; if false,
        ``min_pred_range`` must equal ``max_pred_range``.
    :param min_pred_range: Smallest prediction range (inclusive), in time steps.
    :param max_pred_range: Largest prediction range (inclusive), in time steps. It
        also bounds how many samples a shard yields.
    :param hrs_each_step: Factor applied to the prediction range when computing
        ``lead_times``.
    :param history: Number of history steps per sample.
    :param window: Spacing in time steps between consecutive history steps.
    """

    def __init__(
        self,
        dataset,
        random_lead_time=True,
        min_pred_range=6,
        max_pred_range=120,
        hrs_each_step=1,
        history=3,
        window=6,
    ):
        super().__init__()
        if not random_lead_time:
            assert min_pred_range == max_pred_range
        self.dataset = dataset
        self.random_lead_time = random_lead_time
        self.min_pred_range = min_pred_range
        self.max_pred_range = max_pred_range
        self.hrs_each_step = hrs_each_step
        self.history = history
        self.window = window

    @staticmethod
    def _stack_history(series, history, window, offset):
        """Return the (N, history, ...) stack of shifted slices of ``series``.

        ``offset`` is the largest distance from a sample's first history step to
        its target, so only the first ``len(series) - offset`` start positions
        are usable.
        """
        n_samples = max(series.size(0) - offset, 0)
        return torch.stack(
            [series[t * window : t * window + n_samples] for t in range(history)],
            dim=1,
        )

    def __iter__(self):
        history_span = (self.history - 1) * self.window
        # Reserve room for the largest lead time so every drawn target index exists.
        max_offset = history_span + self.max_pred_range
        for inp_data, out_data, variables, out_variables in self.dataset:
            inp_data = {
                k: self._stack_history(
                    torch.from_numpy(inp_data[k].astype(np.float32)),
                    self.history,
                    self.window,
                    max_offset,
                )
                for k in inp_data.keys()  # N, T, H, W
            }
            out_data = {
                k: torch.from_numpy(out_data[k].astype(np.float32))
                for k in out_data.keys()
            }

            inp_data_len = inp_data[variables[0]].size(0)
            dtype = inp_data[variables[0]].dtype

            if self.random_lead_time:
                predict_ranges = torch.randint(
                    low=self.min_pred_range,
                    high=self.max_pred_range + 1,  # randint's upper bound is exclusive
                    size=(inp_data_len,),
                )
            else:
                predict_ranges = (
                    torch.ones(inp_data_len).to(torch.long) * self.max_pred_range
                )
            lead_times = self.hrs_each_step * predict_ranges / 100
            lead_times = lead_times.to(dtype)
            output_ids = torch.arange(inp_data_len) + history_span + predict_ranges

            out_data = {k: out_data[k][output_ids] for k in out_data.keys()}
            yield inp_data, out_data, lead_times, variables, out_variables

In [None]:
# Random lead time of 1 or 2 steps per sample; with history=2, window=2 and
# max_pred_range=2 each 6-step shard yields 2 samples.
tmp2 = ContinuousForecast(tmp, random_lead_time=True, hrs_each_step=1, min_pred_range=1, max_pred_range=2, history=2, window=2)

In [None]:
# Baseline again: raw series at grid point (0, 0) of the first shard, for comparison
# with the ContinuousForecast output in the next cell.
for xIn, xOut, varIn, varOut in tmp:
  print("in: ")
  for k, v in xIn.items():
    print('-')
    print(f'{k.ljust(20)}: {v[:, 0, 0]}')
  print("out: ")
  for k, v in xOut.items():
    print(f'{k.ljust(20)}: {v[:, 0, 0]}')
  break

in: 
-
2m_temperature      : [273.28662 273.08453 271.6197  270.7301  270.22946 270.12387]
-
total_precipitation : [0.23253012 0.7559433  0.7661886  0.7815685  0.85983276 0.98949003]
out: 
2m_temperature      : [273.28662 273.08453 271.6197  270.7301  270.22946 270.12387]
total_precipitation : [0.23253012 0.7559433  0.7661886  0.7815685  0.85983276 0.98949003]


In [None]:
# ContinuousForecast yields 5-tuples: the extra third element is the per-sample lead
# time (prediction range * hrs_each_step / 100). Only the first shard is shown.
for xIn, xOut, leadTime, varIn, varOut in tmp2:
  print(f"lead time: {leadTime}")
  print("in: ")
  for k, v in xIn.items():
    print('-')
    print(f'{k.ljust(20)}: {v[:, :, 0, 0]}')
  print("out: ")
  for k, v in xOut.items():
    print(f'{k.ljust(20)}: {v[:, 0, 0]}')
  break

lead time: tensor([0.0200, 0.0100])
in: 
-
2m_temperature      : tensor([[273.2866, 271.6197],
        [273.0845, 270.7301]])
-
total_precipitation : tensor([[0.2325, 0.7662],
        [0.7559, 0.7816]])
out: 
2m_temperature      : tensor([270.2295, 270.2295])
total_precipitation : tensor([0.8598, 0.8598])


# **Simple synthetic test**

In [None]:
class FakeData(IterableDataset):
    """Tiny synthetic stand-in for a shard reader, for checking index arithmetic by eye.

    Yields ``niter`` items ``(inputs, outputs, vars, vars)``. In item ``i`` (0-based)
    every variable ``k0 ... k{nvar-1}`` maps to ``(i + 1) * np.arange(dim)`` in both
    the input and the output dict.

    :param dim: Length of each 1-D series.
    :param nvar: Number of variables per item (default 1).
    :param niter: Number of items yielded (default 1).
    """

    def __init__(self, dim, nvar=1, niter=1):
        super().__init__()
        self._niter = niter
        self._nvar = nvar
        self._vars = [f'k{i}' for i in range(self._nvar)]
        self.dim = dim

    def __iter__(self):
        for i in range(self._niter):
            yield {k: (i+1)*np.arange(self.dim) for k in self._vars}, {
                k: (i+1)*np.arange(self.dim) for k in self._vars
            }, self._vars, self._vars

In [None]:
def printNoLead(x):
    """Print the input and output dicts of every 4-tuple sample in ``x``.

    Each sample is ``(inputs, outputs, in_vars, out_vars)``; only the two dicts
    are printed, one ``name --> value`` line per entry.
    """
    for sample in x:
        inputs, outputs, _in_vars, _out_vars = sample
        for header, fields in (("In: ", inputs), ("Out: ", outputs)):
            print(header)
            for name, value in fields.items():
                print(f'{name.ljust(10)} --> {value}')

def printLead(x):
    """Print lead time, inputs and outputs of every 5-tuple sample in ``x``.

    Each sample is ``(inputs, outputs, lead_time, in_vars, out_vars)``. The lead
    time is shown multiplied by 100.
    """
    def show(fields):
        # one "name --> value" line per dict entry
        for name, value in fields.items():
            print(f'{name.ljust(10)} --> {value}')

    for sample in x:
        inputs, outputs, lead, _in_vars, _out_vars = sample
        print("In: ")
        print(f"lead time: {lead*100}")
        show(inputs)
        print("Out: ")
        show(outputs)

In [None]:
dataIn = FakeData(20)

In [None]:
printNoLead(dataIn)

In: 
k0         --> [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
Out: 
k0         --> [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]


In [None]:
dataDirect = DirectForecast(dataIn, 'era5', pred_range=1, history=3, window=5)

In [None]:
printNoLead(dataDirect)

In: 
k0         --> tensor([[ 0.,  5., 10.],
        [ 1.,  6., 11.],
        [ 2.,  7., 12.],
        [ 3.,  8., 13.],
        [ 4.,  9., 14.],
        [ 5., 10., 15.],
        [ 6., 11., 16.],
        [ 7., 12., 17.],
        [ 8., 13., 18.]])
Out: 
k0         --> tensor([11., 12., 13., 14., 15., 16., 17., 18., 19.])


In [None]:
dataContinousRand = ContinuousForecast(dataIn, random_lead_time=True, min_pred_range=1, max_pred_range=4,
                                   hrs_each_step=1, history=3, window=5)

In [None]:
printLead(dataContinousRand)

In: 
lead time: tensor([1., 2., 1., 1., 3., 1.])
k0         --> tensor([[ 0.,  5., 10.],
        [ 1.,  6., 11.],
        [ 2.,  7., 12.],
        [ 3.,  8., 13.],
        [ 4.,  9., 14.],
        [ 5., 10., 15.]])
Out: 
k0         --> tensor([11., 13., 13., 14., 17., 16.])


In [None]:
dataContinousFixed = ContinuousForecast(dataIn, random_lead_time=False, min_pred_range=4, max_pred_range=4,
                                   hrs_each_step=1, history=4, window=3)

In [None]:
printLead(dataContinousFixed)

In: 
lead time: tensor([4., 4., 4., 4., 4., 4., 4.])
k0         --> tensor([[ 0.,  3.,  6.,  9.],
        [ 1.,  4.,  7., 10.],
        [ 2.,  5.,  8., 11.],
        [ 3.,  6.,  9., 12.],
        [ 4.,  7., 10., 13.],
        [ 5.,  8., 11., 14.],
        [ 6.,  9., 12., 15.]])
Out: 
k0         --> tensor([13., 14., 15., 16., 17., 18., 19.])


In [None]:
class Downscale(IterableDataset):
    """Pass samples of the wrapped dataset through, converting arrays to float32 tensors.

    Each item of ``dataset`` is ``(inputs, outputs, variables, out_variables)``; the
    two dicts are converted entry by entry, the variable lists are forwarded as is.
    """

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset

    def __iter__(self):
        def as_float_tensors(arrays):
            # cast to float32 first, then wrap the resulting array as a tensor
            return {
                name: torch.from_numpy(arrays[name].astype(np.float32))
                for name in arrays.keys()
            }

        for raw_inputs, raw_outputs, variables, out_variables in self.dataset:
            inputs = as_float_tensors(raw_inputs)
            outputs = as_float_tensors(raw_outputs)
            yield inputs, outputs, variables, out_variables

In [None]:
class IndividualDataIter(IterableDataset):
    """Split the batched samples of a task dataset into individual, transformed samples.

    The wrapped ``dataset`` must be a ``DirectForecast``, ``ContinuousForecast`` or
    ``Downscale`` instance. From every item it yields, each ``subsample``-th entry
    along axis 0 is emitted on its own as

    * ``(x, y, variables, out_variables)`` for ``DirectForecast`` / ``Downscale``, or
    * ``(x, y, lead_time, variables, out_variables)`` for ``ContinuousForecast``.

    :param dataset: The task dataset to unroll.
    :param transforms: Mapping ``variable name -> callable`` applied to the inputs,
        or ``None`` to leave them untouched.
    :param output_transforms: Same for the outputs.
    :param subsample: Step between the emitted indices along axis 0.
    :raises RuntimeError: When iterated with an unsupported dataset type.
    """

    def __init__(
        self,
        dataset,
        transforms,
        output_transforms,
        subsample=6,
    ):
        super().__init__()
        self.dataset = dataset
        self.transforms = transforms
        self.output_transforms = output_transforms
        self.subsample = subsample

    def __iter__(self):
        # The dataset type is fixed for the whole iteration, so classify it once.
        four_tuple = isinstance(self.dataset, (DirectForecast, Downscale))
        with_lead_times = not four_tuple and isinstance(
            self.dataset, ContinuousForecast
        )
        if not (four_tuple or with_lead_times):
            # Without this check an unsupported dataset would fail further down
            # with an unrelated unbound-variable error.
            raise RuntimeError("Not supported task.")
        is_forecast = isinstance(self.dataset, (DirectForecast, ContinuousForecast))

        for sample in self.dataset:
            if four_tuple:
                inp, out, variables, out_variables = sample
            else:
                inp, out, lead_times, variables, out_variables = sample

            # All variables must agree on the number of samples, on both sides.
            inp_shapes = set([inp[k].shape[0] for k in inp.keys()])
            out_shapes = set([out[k].shape[0] for k in out.keys()])
            assert len(inp_shapes) == 1
            assert len(out_shapes) == 1
            inp_len = next(iter(inp_shapes))
            out_len = next(iter(out_shapes))
            assert inp_len == out_len

            for i in range(0, inp_len, self.subsample):
                x = {k: inp[k][i] for k in inp.keys()}
                y = {k: out[k][i] for k in out.keys()}
                if self.transforms is not None:
                    # Forecast inputs keep their history axis in front, so the
                    # temporary singleton axis is inserted at position 1; downscale
                    # inputs get it at position 0.
                    axis = 1 if is_forecast else 0
                    x = {
                        k: self.transforms[k](x[k].unsqueeze(axis)).squeeze(axis)
                        for k in x.keys()
                    }
                if self.output_transforms is not None:
                    y = {
                        k: self.output_transforms[k](y[k].unsqueeze(0)).squeeze(0)
                        for k in y.keys()
                    }
                if four_tuple:
                    result = x, y, variables, out_variables
                else:
                    result = x, y, lead_times[i], variables, out_variables
                yield result

### Fake data: testing the individual data iterable

In [None]:
class FakeData(IterableDataset):
    """Tiny synthetic stand-in for a shard reader, two variables by default.

    Yields ``niter`` items ``(inputs, outputs, vars, vars)``. In item ``i`` (0-based)
    every variable ``k0 ... k{nvar-1}`` maps to ``(i + 1) * np.arange(dim)`` in both
    the input and the output dict.

    NOTE(review): this redefines the ``FakeData`` class from earlier in the notebook;
    the two differed only in the number of variables, which is now a parameter.

    :param dim: Length of each 1-D series.
    :param nvar: Number of variables per item (default 2).
    :param niter: Number of items yielded (default 1).
    """

    def __init__(self, dim, nvar=2, niter=1):
        super().__init__()
        self._niter = niter
        self._nvar = nvar
        self._vars = [f'k{i}' for i in range(self._nvar)]
        self.dim = dim

    def __iter__(self):
        for i in range(self._niter):
            yield {k: (i+1)*np.arange(self.dim) for k in self._vars}, {
                k: (i+1)*np.arange(self.dim) for k in self._vars
            }, self._vars, self._vars

In [None]:
dataIn = FakeData(20)

In [None]:
printNoLead(dataIn)

In: 
k0         --> [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
k1         --> [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
Out: 
k0         --> [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
k1         --> [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]


In [None]:
dataDirect = DirectForecast(dataIn, 'era5', pred_range=1, history=3, window=5)

In [None]:
printNoLead(dataDirect)

In: 
k0         --> tensor([[ 0.,  5., 10.],
        [ 1.,  6., 11.],
        [ 2.,  7., 12.],
        [ 3.,  8., 13.],
        [ 4.,  9., 14.],
        [ 5., 10., 15.],
        [ 6., 11., 16.],
        [ 7., 12., 17.],
        [ 8., 13., 18.]])
k1         --> tensor([[ 0.,  5., 10.],
        [ 1.,  6., 11.],
        [ 2.,  7., 12.],
        [ 3.,  8., 13.],
        [ 4.,  9., 14.],
        [ 5., 10., 15.],
        [ 6., 11., 16.],
        [ 7., 12., 17.],
        [ 8., 13., 18.]])
Out: 
k0         --> tensor([11., 12., 13., 14., 15., 16., 17., 18., 19.])
k1         --> tensor([11., 12., 13., 14., 15., 16., 17., 18., 19.])


In [None]:
directInd = IndividualDataIter(dataDirect, transforms=None, output_transforms=None, subsample=4)

In [None]:
for sample in directInd:
  print(sample)

({'k0': tensor([ 0.,  5., 10.]), 'k1': tensor([ 0.,  5., 10.])}, {'k0': tensor(11.), 'k1': tensor(11.)}, ['k0', 'k1'], ['k0', 'k1'])
({'k0': tensor([ 4.,  9., 14.]), 'k1': tensor([ 4.,  9., 14.])}, {'k0': tensor(15.), 'k1': tensor(15.)}, ['k0', 'k1'], ['k0', 'k1'])
({'k0': tensor([ 8., 13., 18.]), 'k1': tensor([ 8., 13., 18.])}, {'k0': tensor(19.), 'k1': tensor(19.)}, ['k0', 'k1'], ['k0', 'k1'])


In [3]:
import torch
from typing import Optional, Union

In [4]:
def gaussian_crps(
    pred: torch.distributions.Normal,
    target: Union[torch.FloatTensor, torch.DoubleTensor],
    aggregate_only: bool = False,
    lat_weights: Optional[Union[torch.FloatTensor, torch.DoubleTensor]] = None,
) -> Union[torch.FloatTensor, torch.DoubleTensor]:
    """Closed-form continuous ranked probability score of a Gaussian forecast.

    With ``z = (target - mean) / std`` the score per element is
    ``std * (z * (2 * Phi(z) - 1) + 2 * phi(z) - 1 / sqrt(pi))``, where ``Phi`` and
    ``phi`` are the standard normal CDF and PDF.

    :param pred: Predicted normal distribution; ``pred.loc`` / ``pred.scale`` must be
        4-dimensional, since the per-channel score averages over axes 0, 2 and 3.
    :param target: Ground truth, broadcastable against ``pred.loc``.
    :param aggregate_only: If true, return only the overall mean score (a scalar).
    :param lat_weights: Optional weights multiplied into the element-wise score
        before averaging.
    :return: The scalar mean if ``aggregate_only``; otherwise a 1-D tensor with the
        per-channel means followed by the overall mean.
    """
    mean, std = pred.loc, pred.scale
    z = (target - mean) / std
    # zeros_like / ones_like need a tensor, not the distribution object itself.
    standard_normal = torch.distributions.Normal(
        torch.zeros_like(mean), torch.ones_like(mean)
    )
    pdf = torch.exp(standard_normal.log_prob(z))
    cdf = standard_normal.cdf(z)
    # The constant term of the Gaussian CRPS is 1 / sqrt(pi), not 1 / pi.
    inv_sqrt_pi = 1.0 / (torch.pi ** 0.5)
    crps = std * (z * (2 * cdf - 1) + 2 * pdf - inv_sqrt_pi)
    if lat_weights is not None:
        crps = crps * lat_weights
    per_channel_losses = crps.mean([0, 2, 3])
    loss = crps.mean()
    if aggregate_only:
        return loss
    return torch.cat((per_channel_losses, loss.unsqueeze(0)))

In [5]:
size = (2, 3, 4, 5)

In [18]:
m1 = torch.randn(size)
s1 = torch.rand(size)

In [19]:
p1 = torch.distributions.Normal(m1, s1)
t1 = m1+0.01

In [20]:
p1.size

AttributeError: 'Normal' object has no attribute 'size'

In [22]:
p1.loc.shape

torch.Size([2, 3, 4, 5])

## Test Metrics classes

In [8]:
# Standard library
from dataclasses import dataclass
from functools import wraps
from typing import List, Union, Optional

# Third party
import numpy.typing as npt
import numpy as np
import torch

# A prediction is either a plain float/double tensor or a Normal distribution.
Pred = Union[torch.FloatTensor, torch.DoubleTensor, torch.distributions.Normal]


@dataclass
class MetricsMetaInfo:
    """Side information handed to metrics that need more than pred/target.

    :param in_vars: Names of the input variables.
    :param out_vars: Names of the output variables.
    :param lat: Latitude coordinates of the grid.
    :param lon: Longitude coordinates of the grid.
    :param climatology: Optional climatology tensor; `None` when it is not
        available or not needed.
    """

    in_vars: List[str]
    out_vars: List[str]
    lat: npt.ArrayLike
    lon: npt.ArrayLike
    # Callers may pass `None`, so the field is Optional and defaults to `None`;
    # passing it positionally as before still works.
    climatology: Optional[torch.Tensor] = None

# Lookup table filled by `register`: metric name -> metric class.
METRICS_REGISTRY = {}


def register(name):
    """Build a class decorator that files a metric class under `name`.

    The decorated class is stored in `METRICS_REGISTRY[name]`, receives a
    `name` class attribute equal to `name`, and is returned unchanged
    otherwise. Registering the same name again overwrites the earlier entry.
    """

    def _record(metric_class):
        # Store first, then tag the class with the name it was stored under.
        METRICS_REGISTRY[name] = metric_class
        setattr(metric_class, "name", name)
        return metric_class

    return _record

def handles_probabilistic(metric):
    """Let a metric written for tensors also accept a Normal forecast.

    The returned callable inspects its first argument: a
    `torch.distributions.Normal` is replaced by its mean (`loc`) before
    `metric` is invoked, anything else is forwarded untouched. All remaining
    positional and keyword arguments are passed through as given.
    """

    @wraps(metric)  # the wrapper keeps metric's name, docstring, etc.
    def _with_point_forecast(pred: Pred, *args, **kwargs):
        is_distribution = isinstance(pred, torch.distributions.Normal)
        point_forecast = pred.loc if is_distribution else pred
        return metric(point_forecast, *args, **kwargs)

    return _with_point_forecast

In [6]:
class Metric:
    """Base class shared by every ClimateLearn metric."""

    def __init__(
        self, aggregate_only: bool = False, metainfo: Optional[MetricsMetaInfo] = None
    ):
        r"""
        .. highlight:: python

        :param aggregate_only: When `True`, only the aggregate metric is
            returned. When `False` (the default), the per-channel metrics are
            returned together with the aggregate.
        :type aggregate_only: bool
        :param metainfo: Extra meta-information that some metrics rely on;
            may be omitted by metrics that do not need it.
        :type metainfo: MetricsMetaInfo|None
        """
        self.metainfo = metainfo
        self.aggregate_only = aggregate_only

    def __call__(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the metric. The base class only defines the interface and
        always raises; child classes provide the computation.

        :param pred: The predicted value(s).
        :type pred: torch.Tensor
        :param target: The ground truth target value(s).
        :type target: torch.Tensor

        :return: A tensor. See child classes for specifics.
        :rtype: torch.Tensor
        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError


class LatitudeWeightedMetric(Metric):
    """Base class for metrics that scale their errors by latitude weights.

    Although `metainfo` defaults to `None`, the constructor reads
    `metainfo.lat`, so a meta-info object with latitudes must be supplied.
    """

    def __init__(
        self, aggregate_only: bool = False, metainfo: Optional[MetricsMetaInfo] = None
    ):
        super().__init__(aggregate_only, metainfo)
        # cos(latitude), latitudes given in degrees ...
        cos_lat = np.cos(np.deg2rad(self.metainfo.lat))
        # ... rescaled so that the weights average to exactly 1.
        normalized = cos_lat / cos_lat.mean()
        # Stored with shape (1, 1, n_lat, 1) so it broadcasts along dimension 2
        # of 4-D inputs. The dtype is inherited from the numpy array.
        self.lat_weights = torch.from_numpy(normalized).view(1, 1, -1, 1)

    def cast_to_device(
        self, pred: Union[torch.FloatTensor, torch.DoubleTensor]
    ) -> None:
        r"""
        .. highlight:: python

        Moves the stored latitude weights to the device `pred` lives on.
        Only the device is changed; the dtype of the weights is left as is.
        """
        target_device = pred.device
        self.lat_weights = self.lat_weights.to(device=target_device)

In [28]:
# Grid spacing in degrees; coordinates below are cell centres (offset by half
# a step), giving latitudes 30.5..34.5 and longitudes -119.5..-112.5.
ddeg = 1.0
lat = np.arange(30.0 + ddeg / 2.0, 35.0, ddeg)
lon = np.arange(-120 + ddeg / 2.0, -112.0, ddeg)

# The toy setup uses the same two variable names for inputs and outputs.
inVar = ["k1", "k2"]
outVar = ["k1", "k2"]

In [29]:
# Sanity check: number of latitude and longitude points on the toy grid.
print(lat.shape, lon.shape)

(5,) (8,)


In [30]:
# Bundle the variable names and grid coordinates for the metrics; no
# climatology is supplied.
minfo = MetricsMetaInfo(
    in_vars=inVar,
    out_vars=outVar,
    lat=lat,
    lon=lon,
    climatology=None,
)

In [32]:
# Confirm the coordinates were stored unchanged on the meta-info object.
minfo.lat, minfo.lon

(array([30.5, 31.5, 32.5, 33.5, 34.5]),
 array([-119.5, -118.5, -117.5, -116.5, -115.5, -114.5, -113.5, -112.5]))

In [36]:
# Instantiate the latitude-weighted base metric on the 1-degree toy grid; its
# constructor computes the latitude weights from `minfo.lat`.
oneDegGrid =  LatitudeWeightedMetric(aggregate_only=False, metainfo=minfo)

In [39]:
# Inspect the weights: cos(latitude) normalised to mean 1, stored as a
# (1, 1, n_lat, 1) tensor so they broadcast over 4-D inputs.
print(oneDegGrid.lat_weights)
oneDegGrid.lat_weights.shape # B, C, H(lat), W(lon)

tensor([[[[1.0219],
          [1.0113],
          [1.0003],
          [0.9890],
          [0.9775]]]], dtype=torch.float64)


torch.Size([1, 1, 5, 1])

### Generate fake data on the same grid

In [60]:
# Fake batch laid out as (batch, variable, lat, lon) on the toy grid.
# Predictions are all ones and targets all zeros, so every squared error is 1.
batch_size = 3
size = (batch_size, len(inVar), len(lat), len(lon))
pred = torch.ones(size)
target = torch.zeros(size)

In [41]:
@handles_probabilistic
def mse(
    pred: Pred,
    target: Union[torch.FloatTensor, torch.DoubleTensor],
    aggregate_only: bool = False,
    lat_weights: Optional[Union[torch.FloatTensor, torch.DoubleTensor]] = None,
) -> Union[torch.FloatTensor, torch.DoubleTensor]:
    """Mean-squared error, optionally weighted point-wise.

    :param pred: The predicted values.
    :param target: The ground truth values.
    :param aggregate_only: If `True`, return only the overall mean.
    :param lat_weights: Optional weights multiplied into the squared errors
        before averaging.

    :return: A singleton tensor if `aggregate_only` is `True`. Otherwise a 1-D
        tensor with the means over dimensions 0, 2 and 3 (one entry per index
        of dimension 1) followed by the overall mean.
    """
    squared_error = (pred - target).square()
    weighted = squared_error if lat_weights is None else squared_error * lat_weights
    # Both reductions are computed up front, exactly as many values as needed
    # for the full (per-channel + aggregate) result.
    channel_means = weighted.mean(dim=(0, 2, 3))
    overall_mean = weighted.mean()
    if not aggregate_only:
        return torch.cat([channel_means, overall_mean.unsqueeze(0)])
    return overall_mean

@register("lat_mse")
class LatWeightedMSE(LatitudeWeightedMetric):
    """Mean-squared error with latitude weighting."""

    def __call__(
        self,
        pred: Union[torch.FloatTensor, torch.DoubleTensor],
        target: Union[torch.FloatTensor, torch.DoubleTensor],
    ) -> Union[torch.FloatTensor, torch.DoubleTensor]:
        r"""
        .. highlight:: python

        :param pred: Predictions, shape [B,C,H,W].
        :type pred: torch.FloatTensor|torch.DoubleTensor
        :param target: Ground truth, shape [B,C,H,W].
        :type target: torch.FloatTensor|torch.DoubleTensor

        :return: If `self.aggregate_only` is `True`, a singleton tensor with
            the aggregate MSE. Otherwise a tensor of shape [C+1]: the C
            channel-wise MSEs followed by the aggregate MSE.
        :rtype: torch.FloatTensor|torch.DoubleTensor
        """
        # The weights must live on the same device as the data before use.
        super().cast_to_device(pred)
        return mse(
            pred,
            target,
            aggregate_only=self.aggregate_only,
            lat_weights=self.lat_weights,
        )

In [44]:
# Latitude-weighted MSE on the toy grid, returning per-channel and aggregate
# values (aggregate_only=False).
tmp = LatWeightedMSE(aggregate_only=False, metainfo=minfo)

In [45]:
# The @register("lat_mse") decorator should have filed the class here.
METRICS_REGISTRY

{'lat_mse': __main__.LatWeightedMSE}

In [61]:
# With pred == 1 and target == 0 every squared error is 1, and the latitude
# weights average to 1, so the per-channel values and the aggregate are all 1.
# The result is float64 because the latitude weights are float64.
tmp(pred, target)

tensor([1.0000, 1.0000, 1.0000], dtype=torch.float64)