# Develop a dataset creation process that also works for the conditional case:

The function that creates the dataset should take the following arguments:

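- Root directory in which the WeatherBench dataset is installed.
- Start date and end date for each of the train, validation, and test sets (datetime objects).
- Variables (dict): maps each variable name to its settings, e.g. `{'geopotential': {'level': [50, 250]}}`.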

- delta_t between timesteps (int)
- lead_time (int, in terms of timesteps)
- conditioning_timesteps (list of ints)

As long as the dataset size still allows it, write the data to a single .pt file.

In [2]:
# from WD.datasets import write_conditional_datasets

# new refactored function:
from WD.datasets import write_conditional_datasets, Conditional_Dataset


# Write some example data sets.

In [2]:
# Run the dataset-writing routine with the template config file; the log it
# prints follows this cell.
write_conditional_datasets(
    "/data/compoundx/WeatherDiff/config_file/template.yaml"
)

Load config file.
Open datasets.
Number of conditioning variables: 3
Number of constant variables: 1
Number of output variables: 2
Normalize datasets.
Slice into train, test and validation set and write to .pt files.
{'data_specs': {'conditioning_time_step': [0, -1, -2], 'delta_t': 6, 'lead_time': 1, 'spatial_resolution': '5.625deg', 'constants': ['orography'], 'conditioning_vars': {'total_precipitation': {'level': None}, 'geopotential': {'level': [50, 250]}}, 'output_vars': {'total_precipitation': {'level': None}, 'geopotential': {'level': [500]}}}, 'exp_data': {'train': {'from': '1979-01-02 00:00:00', 'to': '2015-12-31 00:00:00'}, 'test': {'from': '2017-01-01 00:00:00', 'to': '2018-12-31 00:00:00'}, 'val': {'from': '2016-01-01 00:00:00', 'to': '2016-12-31 00:00:00'}}, 'file_structure': {'dir_WeatherBench': '/data/compoundx/WeatherBench/', 'dir_model_input': '/data/compoundx/WeatherDiff/model_input/'}, 'ds_id': 'F01A8B', 'git-rev-parse': {'dm_zoo': 'a233ea15df1416edb47f8b3afc2478677f8

# Test PyTorch Datasets:

In [10]:
# from WD.datasets import Conditional_Dataset, Unconditional_Dataset

from WD.datasets import Conditional_Dataset, open_datasets
from WD.io import load_config

import xarray as xr
import os

In [19]:
# Load the configuration that the following cells read their variable
# selections from.
config = load_config(
    "/data/compoundx/WeatherDiff/config_file/template_rasp_thuerey_no_precip.yml"
)

In [20]:
# Convert the conditioning / output variable sections of the config into
# plain dicts (variable name -> per-variable settings).
data_specs = config.data_specs
conditioning_variables = data_specs.conditioning_vars.toDict()
output_variables = data_specs.output_vars.toDict()

In [8]:
# Display the conditioning-variable dict to check which variables (and which
# levels, where given) were read from the config.
conditioning_variables

{'temperature': {'level': [50, 250, 500, 600, 700, 850, 925]},
 'geopotential': {'level': [50, 250, 500, 600, 700, 850, 925]},
 'u_component_of_wind': {'level': [50, 250, 500, 600, 700, 850, 925]},
 'v_component_of_wind': {'level': [50, 250, 500, 600, 700, 850, 925]},
 'specific_humidity': {'level': [50, 250, 500, 600, 700, 850, 925]},
 '2m_temperature': None,
 'total_precipitation': None,
 'toa_incident_solar_radiation': None}

In [16]:
# Debugging aid: open the WeatherBench files of every conditioning variable,
# one folder at a time, printing the variable's config and the glob pattern
# before each open (presumably to see which folder is slow to open).
res_datasets = []  # stays empty while the post-processing below is disabled
root_dir = "/data/compoundx/WeatherBench/"
spatial_resolution = "5.625deg"

for foldername, var_config in conditioning_variables.items():
    print(var_config)
    # One sub-folder per variable; match every file at the chosen resolution.
    path = os.path.join(
        root_dir,
        foldername,
        "*_{}.nc".format(spatial_resolution),
    )
    print(path)
    ds = xr.open_mfdataset(path)

    # ------------------------------------------------------------------
    # Disabled post-processing, kept for reference only.
    # It used to be switched off by wrapping it in a triple-quoted string;
    # it is written as real comments here so it cannot be mistaken for a
    # docstring or for live code.
    #
    # NOTE(review), before re-enabling:
    #   * `transform_precipitation` and
    #     `create_variables_from_pressure_levels` are not imported in this
    #     notebook.
    #   * `var_config` is None for some variables in this config (e.g.
    #     '2m_temperature'), so `var_config["level"]` would raise TypeError;
    #     guard with `var_config is not None` first.
    #   * `ds.var()` is xarray's variance reduction; the membership test
    #     probably meant `ds.variables` or `ds.coords` -- confirm.
    #
    # assert len(ds.keys()) == 1
    # varname = list(ds.keys())[0]
    #
    # if varname == "tp":
    #     # rolling *sum* over 6 time steps (the original note said
    #     # "6 hour average", which does not match `.sum()`)
    #     ds = ds.rolling(time=6).sum()
    #     ds["tp"] = transform_precipitation(ds)
    #
    # # extract desired pressure levels:
    # if var_config["level"] is not None:
    #     datasets = create_variables_from_pressure_levels(
    #         ds=ds, var_config=var_config
    #     )
    #     res_datasets.extend(datasets)
    # else:
    #     if "level" in ds.dims:
    #         assert ds.level.size == 1, (
    #             "The given dataset is defined at more"
    #             " than one pressure level, but no"
    #             " pressure levels were selected in the"
    #             " configuration file."
    #         )
    #     if "level" in ds.var():
    #         ds = ds.drop_vars("level")
    #     res_datasets.append(ds)
    # ------------------------------------------------------------------

{'level': [50, 250, 500, 600, 700, 850, 925]}
/data/compoundx/WeatherBench/temperature/*_5.625deg.nc




{'level': [50, 250, 500, 600, 700, 850, 925]}
/data/compoundx/WeatherBench/geopotential/*_5.625deg.nc


KeyboardInterrupt: 

In [22]:
# Same inputs as the manual loop, but routed through the library helper:
# WeatherBench root directory, conditioning-variable dict, spatial resolution.
a = open_datasets(
    "/data/compoundx/WeatherBench/",
    conditioning_variables,
    "5.625deg",
)

In [3]:
# Build the training dataset from a written .pt file together with the config
# file that carries the same ID prefix (E0876B).
train_pt_file = "/data/compoundx/WeatherDiff/model_input/E0876B_train.pt"
train_config_file = "/data/compoundx/WeatherDiff/config_file/E0876B.yml"
ds_train = Conditional_Dataset(
    pt_file_path=train_pt_file,
    config_file_path=train_config_file,
)

# The old version was called with the .pt file only:
# ds_train = Conditional_Dataset(pt_file_path="/data/compoundx/WeatherDiff/input_data/6A9C62_train.pt")