In [None]:
# Common imports
import argparse
import time
import numpy as np
import os
from datetime import datetime
import sys; sys.path.append('../deepdown/')

# Import torch
from torch.utils.data import Dataset
from deepdown.utils.data_generators import DataGenerator
import torch

# Utils
from utils.data_loader import load_target_data, load_input_data
from utils.loss_fcts import *
from utils.data_generators import DataGenerator
from utils.helpers import print_cuda_availability, DEVICE
from models.srgan import Generator, Discriminator
from config import Config

### Read the configuration file and load the data

In [None]:
from argparse import Namespace

# Config is built from an argparse-style namespace (`cli_args`). In the notebook
# there is no command line, so populate the namespace by hand with the path to
# the YAML configuration file.
CONFIG_FILE = "../config.yaml"
cli_args = Namespace()
cli_args.config_file = CONFIG_FILE

conf = Config(cli_args)
conf.print()

In [None]:
# Short alias for the parsed configuration section used throughout this cell.
_cfg = conf.config

# Overall date range, and (first_year, last_year) bounds for each data split.
date_start, date_end = _cfg.date_start, _cfg.date_end
years_train, years_valid, years_test = _cfg.years_train, _cfg.years_valid, _cfg.years_test

# Other parameters
levels = _cfg.levels
resol_low = _cfg.resol_low

# Input (ERA5-Land) variables and one sub-directory per variable.
# NOTE(review): trailing slashes are inconsistent across entries; kept as-is.
# NOTE(review): `input_variables` is not used by the visible cells (the
# DataGenerator cell builds its own `input_vars`) — confirm it is still needed.
input_variables = _cfg.input_vars
input_paths = [
    _cfg.path_era5land + subdir
    for subdir in ('/precipitation', '/temperature', '/max_temperature/', '/min_temperature/')
]

# Crop window: x limits come from lon_limits, y limits from lat_limits.
do_crop = _cfg.do_crop
crop_x, crop_y = _cfg.lon_limits, _cfg.lat_limits

# Target (MeteoSwiss gridded) products, one directory per variable.
target_paths = [
    _cfg.path_mch + subdir
    for subdir in (
        '/RhiresD_v2.0_swiss.lv95/',
        '/TabsD_v2.0_swiss.lv95/',
        '/TmaxD_v2.0_swiss.lv95/',
        '/TminD_v2.0_swiss.lv95/',
    )
]
target_vars = _cfg.target_vars

In [None]:
# Load the target data for the full [date_start, date_end] range from every
# directory in `target_paths`. `path_tmp` is forwarded to the loader
# (presumably a temporary/cache directory — TODO confirm in load_target_data).
target = load_target_data(
    date_start,
    date_end,
    target_paths,
    path_tmp=conf.config.path_tmp,
)



In [None]:
# Keep the x / y coordinates of the loaded target grid; they are passed to
# load_input_data in the next cell to define the domain of the input data.
x_axis, y_axis = target.x, target.y

In [None]:
# Load the input (predictor) data for the same date range, on the target's
# x / y axes. `resol_low` and `path_tmp` are taken from the configuration,
# the same source the target loader uses for `path_tmp`, so the notebook cannot
# silently disagree with config.yaml (they were previously hard-coded to 0.1
# and '../tmp/').
input_data = load_input_data(
    date_start=date_start,
    date_end=date_end,
    levels=levels,
    resol_low=resol_low,
    x_axis=x_axis,
    y_axis=y_axis,
    paths=input_paths,
    path_dem=conf.config.path_dem,
    dump_data_to_pickle=True,
    path_tmp=conf.config.path_tmp,
)



In [None]:
# Spatial cropping is not applied here: it is delegated to DataGenerator via its
# `do_crop`, `crop_x` and `crop_y` arguments, so the full domain is kept.

def select_years(ds, years):
    """Return the part of `ds` whose `time` lies in an inclusive range of years.

    Parameters
    ----------
    ds : xarray.Dataset or xarray.DataArray
        Data with a datetime-like `time` coordinate.
    years : sequence of int
        ``(first_year, last_year)``; both years are included in full.

    Returns
    -------
    Same type as `ds`, restricted along `time`.
    """
    first_year, last_year = years[0], years[1]
    # Partial-string slicing selects whole years, including any timestamps
    # later than midnight on 31 December (a `datetime(y, 12, 31)` upper bound
    # would drop those).
    return ds.sel(time=slice(str(first_year), str(last_year)))


# Split inputs and targets into train / validation / test periods.
x_train = select_years(input_data, years_train)
x_valid = select_years(input_data, years_valid)
x_test = select_years(input_data, years_test)

y_train = select_years(target, years_train)
y_valid = select_years(target, years_valid)
y_test = select_years(target, years_test)



In [None]:
# Variables fed to and predicted by the model. Every input key maps to None;
# what that value means is defined by DataGenerator — TODO confirm there.
# NOTE(review): this list is hard-coded rather than taken from
# conf.config.input_vars (`input_variables`) — confirm that is intended.
input_vars = dict.fromkeys(['band_data', 'tp', 't2m', 't2m_min', 't2m_max'])
output_vars = ['tp', 't']

In [None]:
# Batch size shared by the train / validation / test loaders.
BATCH_SIZE = 32

# Training generator. mean/std and y_mean/y_std are passed as None; presumably
# the generator then derives normalisation statistics from the training data
# itself — TODO confirm in DataGenerator. Cropping follows the `do_crop` flag
# from the configuration instead of being forced on.
training_set = DataGenerator(
    inputs=x_train,
    outputs=y_train,
    input_vars=input_vars,
    output_vars=output_vars,
    do_crop=do_crop,
    crop_x=crop_x,
    crop_y=crop_y,
    shuffle=True,
    load=False,
    mean=None,
    std=None,
    y_mean=None,
    y_std=None,
    tp_log=None,
)
loader_train = torch.utils.data.DataLoader(training_set, batch_size=BATCH_SIZE)

# Validation / test generators reuse the training input statistics
# (training_set.mean / .std) so every split is normalised identically.
# NOTE(review): y_mean, y_std and tp_log are not forwarded from the training
# set here; if DataGenerator normalises targets, validation/test may use their
# own statistics — confirm this is intended.
valid_set = DataGenerator(
    inputs=x_valid,
    outputs=y_valid,
    input_vars=input_vars,
    output_vars=output_vars,
    do_crop=do_crop,
    crop_x=crop_x,
    crop_y=crop_y,
    shuffle=False,
    load=False,
    mean=training_set.mean,
    std=training_set.std,
)
loader_val = torch.utils.data.DataLoader(valid_set, batch_size=BATCH_SIZE)

test_set = DataGenerator(
    inputs=x_test,
    outputs=y_test,
    input_vars=input_vars,
    output_vars=output_vars,
    do_crop=do_crop,
    crop_x=crop_x,
    crop_y=crop_y,
    shuffle=False,
    load=False,
    mean=training_set.mean,
    std=training_set.std,
)
loader_test = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE)