In [None]:
# Common imports
import sys
sys.path.append('../deepdown/')

# Import torch
import torch
from torch.utils.data import DataLoader, Dataset

# Utils
from deepdown.utils.data_loader import load_target_data, load_input_data
# NOTE(review): this star import hides which names the notebook actually uses
# (it may also have been what made `torch` available implicitly); prefer
# importing the needed loss functions explicitly.
from deepdown.utils.loss_fcts import *
from deepdown.utils.data_generators import DataGenerator
from deepdown.utils.helpers import split_data
from deepdown.config import Config

### Read the configuration file and load the data

In [None]:
from argparse import Namespace

# Path to the YAML configuration; relative paths resolve against the
# notebook's current working directory.
CONFIG_FILE = "../config.yaml"

# `Config` is built from a CLI-style namespace. A notebook has no command
# line to parse, so construct the namespace by hand.
cli_args = Namespace(config_file=CONFIG_FILE)
config = Config(cli_args)
config.print()

# Short alias for the parsed options, used by every cell below.
conf = config.config

In [None]:
# Sub-directories (relative to the configured data roots) holding each
# variable. Input and target lists are index-aligned: precipitation,
# mean temperature, max temperature, min temperature.
# NOTE(review): trailing slashes are inconsistent across entries; kept as-is
# because the loaders may concatenate these strings — confirm before tidying.
era5land_subdirs = [
    '/precipitation',
    '/temperature',
    '/max_temperature/',
    '/min_temperature/',
]
mch_subdirs = [
    '/RhiresD_v2.0_swiss.lv95/',
    '/TabsD_v2.0_swiss.lv95/',
    '/TmaxD_v2.0_swiss.lv95/',
    '/TminD_v2.0_swiss.lv95/',
]

input_paths = [conf.path_era5land + subdir for subdir in era5land_subdirs]
target_paths = [conf.path_mch + subdir for subdir in mch_subdirs]

In [None]:
# Load the target data from `target_paths` for the configured date range,
# using `conf.path_tmp` as the temporary/cache directory.
target = load_target_data(
    conf.date_start,
    conf.date_end,
    target_paths,
    path_tmp=conf.path_tmp,
)

In [None]:
# Load the input variables for the same date range as the target.
# The target's x/y axes are passed in, presumably so the inputs end up on
# the target grid — TODO confirm against `load_input_data`.
# NOTE(review): `dump_data_to_pickle=True` writes pickles under `path_tmp`;
# only ever reload such files from trusted locations.
input_data = load_input_data(
    date_start=conf.date_start,
    date_end=conf.date_end,
    levels=conf.levels,
    resol_low=conf.resol_low,
    x_axis=target.x,
    y_axis=target.y,
    paths=input_paths,
    path_dem=conf.path_dem,
    dump_data_to_pickle=True,
    path_tmp=conf.path_tmp,
)

In [None]:
# Split inputs and targets by year into train / validation / test sets.
# The same year selections are applied to both, so x_* and y_* stay aligned.
split_years = (conf.years_train, conf.years_valid, conf.years_test)
x_train, x_valid, x_test = [split_data(input_data, years) for years in split_years]
y_train, y_valid, y_test = [split_data(target, years) for years in split_years]

In [None]:
def build_generator(x, y, **kwargs):
    """Build a ``DataGenerator`` for one data split.

    Supplies the configured input/target variables and crop settings, and
    forwards any split-specific keyword arguments (``shuffle``, ``x_mean``,
    ``x_std``, ``tp_log``, ...) unchanged to ``DataGenerator``.

    Args:
        x: Input split as returned by ``split_data``.
        y: Target split as returned by ``split_data``.
        **kwargs: Extra ``DataGenerator`` options for this split.

    Returns:
        The constructed ``DataGenerator``.
    """
    return DataGenerator(
        x, y, conf.input_vars, conf.target_vars, do_crop=conf.do_crop,
        crop_x=conf.crop_x, crop_y=conf.crop_y, **kwargs)


def build_loader(dataset):
    """Wrap ``dataset`` in a torch ``DataLoader`` using ``conf.batch_size``.

    All other ``DataLoader`` options stay at their defaults (in particular
    ``shuffle=False``), so any shuffling comes from ``DataGenerator`` itself.
    NOTE(review): if ``DataGenerator`` only reshuffles in an epoch-end hook,
    ``DataLoader`` will not call it — confirm shuffling behaves as intended.
    """
    return DataLoader(dataset, batch_size=conf.batch_size)


training_set = build_generator(x_train, y_train, shuffle=True, tp_log=None)
loader_train = build_loader(training_set)

# Validation and test generators reuse the training set's x_mean / x_std
# (presumably normalisation statistics) instead of computing their own from
# the held-out years.
# NOTE(review): only the training generator passes tp_log=None explicitly;
# valid/test fall back to DataGenerator's default — confirm these agree.
valid_set = build_generator(
    x_valid, y_valid, shuffle=False,
    x_mean=training_set.x_mean, x_std=training_set.x_std)
loader_val = build_loader(valid_set)

test_set = build_generator(
    x_test, y_test, shuffle=False,
    x_mean=training_set.x_mean, x_std=training_set.x_std)
loader_test = build_loader(test_set)