# Setup

### Import necessary modules and do some basic setup.

In [None]:
# Python ≥3.5 is required
import sys
assert sys.version_info >= (3, 5)

# Scikit-Learn ≥0.20 is required
import sklearn
assert sklearn.__version__ >= '0.20'

from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.compose import ColumnTransformer
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report, ConfusionMatrixDisplay

# TensorFlow ≥2.0 is required
import tensorflow_addons as tfa
import tensorflow as tf
assert tf.__version__ >= '2.0'

from tensorflow import keras
from tensorflow.keras import layers

print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

# Common imports
import os
import glob
import numpy as np
import pandas as pd
import geopandas as gpd
import xarray as xr
import dask
import datetime
import math
dask.config.set({'array.slicing.split_large_chunks': False})

# To make this notebook's output stable across runs
np.random.seed(42)

# Config matplotlib
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt

mpl.rc('axes', labelsize=14)
mpl.rc('xtick', labelsize=12)
mpl.rc('ytick', labelsize=12)

# Dotenv
from dotenv import dotenv_values

# Custom utils
from utils.utils_data import *
from utils.utils_ml import *

### Define some paths and constants.

In [None]:
# Settings read from the local ".env" file
config = dotenv_values(".env")

# Paths
PATH_ERA5 = config['PATH_ERA5']  # base directory; the ERA5 sub-paths are appended to it when loading
PATH_EOBS = config['PATH_EOBS']  # not used in the cells shown here — presumably the E-OBS data root

# Some constants
G = 9.80665  # divisor that turns geopotential into geopotential height (m)
DATE_START = '1979-01-01'  # first date passed to get_era5_data
DATE_END = '2020-12-31'  # last date passed to get_era5_data
YY_TRAIN = [1979, 2015]  # first and last year of the training period (Jan 1st to Dec 31st)
YY_TEST = [2016, 2020]  # first and last year of the test period (Jan 1st to Dec 31st)
LEVELS = [500, 850, 1000]  # values selected along the 'level' coordinate — presumably hPa, TODO confirm

# Data preparation

## Target variable: precipitation field

In [None]:
# ERA5 precipitation files for the study period
pr_files = PATH_ERA5 + '/precipitation/day_grid1/*nc'
pr = get_era5_data(pr_files, DATE_START, DATE_END)

# Extremes are derived from the precipitation with a 0.95 (95th percentile) threshold
extreme_quantile = 0.95
pr95 = precip_exceedance_xarray(pr, extreme_quantile)

## Input data: meteorological fields

In [None]:
def _snap_time_to_dates(ds):
    """Replace the time coordinate of *ds* by its calendar dates and return *ds*."""
    ds['time'] = pd.DatetimeIndex(ds.time.dt.date)
    return ds


# Geopotential, restricted to the selected levels
z = get_era5_data(PATH_ERA5 + '/geopotential/grid1/*.nc', DATE_START, DATE_END)
z = z.sel(level=LEVELS)

# Divide by G to express Z as geopotential height (m)
z.z.values = z.z.values / G

# Grid axes
lats, lons = z.lat, z.lon

# Temperature (variable renamed from 'T2MMEAN' to 't')
t2m_file = PATH_ERA5 + '/temperature/grid1/Grid1_Daymean_era5_T2M_EU_19790101-20211231.nc'
t2m = _snap_time_to_dates(get_era5_data(t2m_file, DATE_START, DATE_END))
t2m = t2m.rename_vars({'T2MMEAN': 't'})

# Relative humidity, restricted to the selected levels
rh_files = PATH_ERA5 + '/relative_humidity/day_grid1/*.nc'
rh = _snap_time_to_dates(get_era5_data(rh_files, DATE_START, DATE_END))
rh = rh.sel(level=LEVELS)

# Wind components
u_files = PATH_ERA5 + '/U_wind/day_grid1/*.nc'
v_files = PATH_ERA5 + '/V_wind/day_grid1/*.nc'
u850 = _snap_time_to_dates(get_era5_data(u_files, DATE_START, DATE_END))
v850 = _snap_time_to_dates(get_era5_data(v_files, DATE_START, DATE_END))

# Report the dimensions of every dataset
for label, dataset in (('dimension of z', z),
                       ('dimension of t2m:', t2m),
                       ('dimension of rh:', rh),
                       ('dimension of u:', u850),
                       ('dimension of v:', v850),
                       ('dimension of pr:', pr)):
    print(label, dataset.dims)


In [None]:
# Combine all predictor datasets into a single one and display it
predictor_datasets = [z, t2m, rh, u850, v850]
X = xr.merge(predictor_datasets)
X

### Split data and transform

In [None]:
# Time windows of the two periods: Jan 1st of the first year to Dec 31st of the last one
train_span = slice('{}-01-01'.format(YY_TRAIN[0]), '{}-12-31'.format(YY_TRAIN[1]))
test_span = slice('{}-01-01'.format(YY_TEST[0]), '{}-12-31'.format(YY_TEST[1]))

# Predictors
X_train_full = X.sel(time=train_span)
X_test = X.sel(time=test_span)

# Precipitation
pr_train_full = pr.sel(time=train_span)
pr_test = pr.sel(time=test_span)

# Precipitation extremes
xtr_train_full = pr95.sel(time=train_span)
xtr_test = pr95.sel(time=test_span)

In [None]:
# Variables handed to the generators; each maps to a list of levels or None
# (presumably the levels to extract for that variable — confirm in DataGeneratorForPrecip)
dic = {'z': LEVELS, 't': None, 'r': LEVELS, 'u': None, 'v': None}

# The full training period is divided into a fitting part and a validation part
fit_years = slice('1979', '2010')
valid_years = slice('2011', '2015')

# Options shared by the three generators
generator_options = dict(batch_size=32, load=True)

data_gen_train = DataGeneratorForPrecip(X_train_full.sel(time=fit_years),
                                        pr_train_full.sel(time=fit_years),
                                        dic, **generator_options)

# Validation and test generators receive the mean/std of the training generator
train_stats = dict(mean=data_gen_train.mean, std=data_gen_train.std)

data_gen_valid = DataGeneratorForPrecip(X_train_full.sel(time=valid_years),
                                        pr_train_full.sel(time=valid_years),
                                        dic, **train_stats, **generator_options)
data_gen_test = DataGeneratorForPrecip(X_test, pr_test, dic, shuffle=False,
                                       **train_stats, **generator_options)

# Model creation

In [None]:
class EDM(tf.keras.Model):
    """Encoder decoder model.

    An encoder compresses the input fields into a latent vector of size
    ``latent_dim``; a decoder expands that vector back into a 2D field, which
    is finally cropped and reshaped to ``output_size``.

    Args:
        arch: Name of the architecture: 'cnn-v1', 'cnn-v2' or 'cnn-v3'.
        input_size: Shape of one input sample (without the batch axis), used
            as the input shape of the encoder.
        output_size: Shape of one output sample (without the batch axis). Its
            first two entries are the target height and width.
        for_extremes: Stored on the instance; not used by the model itself.
        latent_dim: Size of the latent vector between encoder and decoder.
        dropout_rate: Dropout rate (only used by 'cnn-v1').

    Raises:
        ValueError: If ``arch`` is not a known architecture, or if the decoder
            output is smaller than the requested output size.
    """

    def __init__(self, arch, input_size, output_size, for_extremes=False, latent_dim=128,
                 dropout_rate=0.2):
        super().__init__()
        self.arch = arch
        self.input_size = list(input_size)
        self.output_size = list(output_size)
        self.for_extremes = for_extremes
        self.latent_dim = latent_dim
        self.dropout_rate = dropout_rate

        if arch == 'cnn-v1':
            self.create_cnnv1()
        elif arch == 'cnn-v2':
            self.create_cnnv2()
        elif arch == 'cnn-v3':
            self.create_cnnv3()
        else:
            # Fail loudly instead of silently falling back to another architecture
            raise ValueError("The architecture was not correctly defined: got {!r}, expected "
                             "'cnn-v1', 'cnn-v2' or 'cnn-v3'".format(arch))

        self.crop_output()

    def create_cnnv1(self):
        """Build the 'cnn-v1' encoder/decoder (convolutions with max pooling and dropout)."""
        self.encoder = tf.keras.Sequential(
            [
                layers.InputLayer(input_shape=self.input_size),
                layers.Conv2D(16, 3, padding='same', activation='relu'),
                layers.BatchNormalization(),
                layers.MaxPooling2D(pool_size=2),
                layers.SpatialDropout2D(self.dropout_rate),
                layers.Conv2D(16, 3, padding='same', activation='relu'),
                layers.BatchNormalization(),
                layers.MaxPooling2D(pool_size=2),
                layers.SpatialDropout2D(self.dropout_rate),
                layers.Flatten(),
                layers.Dense(self.latent_dim, activation='sigmoid'),
                layers.Dropout(self.dropout_rate),
            ]
        )

        # layers[-3] is the Flatten layer: its input shape is the shape of the
        # last feature map, which the decoder has to rebuild from the latent vector
        preflat_shape = self.encoder.layers[-3].input.get_shape().as_list()[1:]

        print(preflat_shape)

        self.decoder = tf.keras.Sequential(
            [
                layers.InputLayer(input_shape=(self.latent_dim,)),
                layers.Dense(int(np.prod(preflat_shape)), activation='relu'),
                layers.Reshape(target_shape=preflat_shape),
                layers.Conv2DTranspose(64, 3, strides=2, padding='same', activation='relu'),
                layers.Conv2DTranspose(32, 3, strides=2, padding='same', activation='relu'),
                layers.Conv2DTranspose(1, 3, strides=1, padding='same', activation='relu'),
            ]
        )

    def _create_strided_cnn(self, encoder_filters, decoder_filters):
        """Build an encoder/decoder pair made of stride-2 convolutions.

        Args:
            encoder_filters: Number of filters of each stride-2 convolution of
                the encoder, in order (they follow an initial 8-filter
                convolution without stride).
            decoder_filters: Number of filters of each stride-2 transposed
                convolution of the decoder, in order (they are followed by a
                final single-filter convolution without stride).
        """
        conv_options = dict(padding='same', activation='relu', kernel_initializer='he_normal')

        self.encoder = tf.keras.Sequential(
            [
                layers.InputLayer(input_shape=self.input_size),
                layers.Conv2D(8, 3, **conv_options),
            ]
            + [layers.Conv2D(filters, 3, strides=(2, 2), **conv_options)
               for filters in encoder_filters]
            + [
                layers.Flatten(),
                layers.Dense(self.latent_dim, activation='sigmoid',
                             kernel_initializer='he_normal'),
            ]
        )

        # layers[-2] is the Flatten layer: its input shape is the shape of the
        # last feature map, which the decoder has to rebuild from the latent vector
        preflat_shape = self.encoder.layers[-2].input.get_shape().as_list()[1:]

        self.decoder = tf.keras.Sequential(
            [
                layers.InputLayer(input_shape=(self.latent_dim,)),
                layers.Dense(int(np.prod(preflat_shape)), activation='relu',
                             kernel_initializer='he_normal'),
                layers.Reshape(target_shape=preflat_shape),
            ]
            + [layers.Conv2DTranspose(filters, 3, strides=(2, 2), **conv_options)
               for filters in decoder_filters]
            + [layers.Conv2D(1, 3, strides=(1, 1), **conv_options)]
        )

    def create_cnnv2(self):
        """Build the 'cnn-v2' encoder/decoder (four stride-2 stages on each side)."""
        self._create_strided_cnn(encoder_filters=(8, 16, 16, 32),
                                 decoder_filters=(16, 16, 16, 8))

    def create_cnnv3(self):
        """Build the 'cnn-v3' encoder/decoder (six stride-2 stages on each side)."""
        self._create_strided_cnn(encoder_filters=(8, 16, 16, 32, 32, 64),
                                 decoder_filters=(32, 32, 16, 16, 16, 8))

    def crop_output(self):
        """Append to the decoder the layers that bring its output to ``output_size``.

        The reconstructed field is cropped symmetrically to the target height
        and width, then reshaped to ``output_size``.

        Raises:
            ValueError: If the reconstructed field is smaller than the target,
                in which case cropping cannot produce the requested size.
        """
        h, w = self.decoder.layers[-1].output.get_shape().as_list()[1:3]  # reconstructed height and width
        h_tgt, w_tgt = self.output_size[:2]
        dh = h - h_tgt  # deltas to be cropped away
        dw = w - w_tgt

        if dh < 0 or dw < 0:
            raise ValueError('The decoder output ({}x{}) is smaller than the target size ({}x{}) '
                             'and cannot be cropped to it'.format(h, w, h_tgt, w_tgt))

        # add to decoder cropping layer and final reshaping
        self.decoder.add(layers.Cropping2D(cropping=((dh // 2, dh - dh // 2),
                                                     (dw // 2, dw - dw // 2))))
        self.decoder.add(layers.Reshape(target_shape=self.output_size))

    def call(self, x):
        """Encode then decode ``x`` and return the decoder output."""
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

    def encode(self, x):
        """Return the encoder output for ``x``."""
        return self.encoder(x)

    def decode(self, z):
        """Return the decoder output for the latent input ``z``."""
        return self.decoder(z)