# In [None]:
import xarray as xr
from pathlib import Path
import netCDF4

import datetime
import pickle

import cartopy.crs as ccrs
import cartopy
import proplot as pplt

import numpy as np
import copy
import math
import random

# Seed both NumPy's global generator and the standard-library `random`
# module so that repeated runs draw the same random samples.
np.random.seed(0)
random.seed(0)

def load_years(start_year, years):
    """Load consecutive years of monthly NetCDF files into a list of datasets.

    Files are expected at ``<parent of cwd>/ERA5-downloader/YYYY-MM.nc``.

    Args:
        start_year: First year to load.
        years: Number of consecutive years to load, starting at ``start_year``.

    Returns:
        A list with one fully loaded xarray dataset per month, ordered by
        year and then by month (12 entries per year).
    """
    # The data lives in a sibling directory of the current working directory.
    source_dir = Path.cwd().parent / "ERA5-downloader"

    datasets = []
    for yr in range(start_year, start_year + years):
        print("Loading " + str(yr))
        for mon in range(1, 13):
            # File names always carry a two-digit month, e.g. "1990-03.nc".
            nc_path = source_dir / f"{yr}-{mon:02d}.nc"
            datasets.append(xr.load_dataset(nc_path, engine="netcdf4"))

    return datasets

def select_cell(data, cell_lat, cell_lon):
    """Cut a randomly positioned window out of ``data``.

    Args:
        data: xarray object with ``latitude`` and ``longitude`` dimensions.
        cell_lat: Window size along ``latitude``, in grid points.
        cell_lon: Window size along ``longitude``, in grid points.

    Returns:
        The selected window, or ``data`` itself (unchanged) when the
        requested window is larger than ``data`` along either dimension.
    """
    n_lon = data.sizes["longitude"]
    n_lat = data.sizes["latitude"]

    # The requested window does not fit: hand the input back untouched.
    if cell_lat > n_lat or cell_lon > n_lon:
        return data

    # Draw the start index of the window along each dimension. randint is
    # inclusive on both ends, so the window may touch the far edge. The
    # longitude offset is drawn first; the order matters for seeded runs.
    lon_start = random.randint(0, n_lon - cell_lon)
    lat_start = random.randint(0, n_lat - cell_lat)

    return data.isel(
        longitude=slice(lon_start, lon_start + cell_lon),
        latitude=slice(lat_start, lat_start + cell_lat),
    )

def create_patch(data, patch_lat, patch_lon, pre_steps, post_steps, filter_prec):
    """Draw a random space-time patch from ``data``.

    A spatial window is chosen with ``select_cell``; then a random central
    time index is drawn and ``pre_steps`` steps before it plus ``post_steps``
    steps starting at it are kept.

    Args:
        data: xarray dataset with ``latitude``, ``longitude`` and ``time``
            dimensions (and a ``tp`` variable when ``filter_prec`` is used).
        patch_lat: Patch size along ``latitude``, in grid points.
        patch_lon: Patch size along ``longitude``, in grid points.
        pre_steps: Number of time steps kept before the central step.
        post_steps: Number of time steps kept from the central step onwards
            (the central step is included in this count).
        filter_prec: When equal to ``True``, patches are redrawn until the
            summed ``tp`` over the last ``post_steps`` steps exceeds the
            threshold ``post_steps * patch_lat * 0.0005``.

    Returns:
        A single dataset with ``pre_steps + post_steps`` time steps.

    Note:
        With ``filter_prec`` enabled the loop has no retry limit; it keeps
        drawing until a patch passes the threshold.
    """
    while True:
        cell = select_cell(data, patch_lat, patch_lon)

        # Pick the central time step so that the full window fits.
        # NOTE(review): randint is inclusive, so with this upper bound the
        # window end is at most n_times - 1 and the final time index of the
        # dataset is never part of a patch - confirm whether that is intended.
        n_times = cell.sizes["time"]
        centre = random.randint(pre_steps, n_times - post_steps - 1)
        window = cell.isel(time=slice(centre - pre_steps, centre + post_steps))

        # Explicit comparison kept on purpose: only values equal to True
        # enable the filter, everything else accepts the first draw.
        if not (filter_prec == True):
            return window

        # Total precipitation over all steps that are to be predicted.
        forecast = window.isel(time=slice(pre_steps, pre_steps + post_steps))
        total_tp = forecast["tp"].sum(dim=["latitude", "longitude", "time"])

        # NOTE(review): the threshold scales with post_steps and patch_lat
        # only; patch_lon is not part of it. Presumably 0.0005 is 0.5 mm
        # expressed in metres - verify against the units of 'tp'.
        if total_tp > post_steps * patch_lat * 0.0005:
            return window

def write_list(a_list, filename):
    """Pickle ``a_list`` into the file at ``filename``.

    Any existing file is overwritten. A confirmation line is printed once
    the object has been dumped.

    Args:
        a_list: Object to serialise (any picklable object works).
        filename: Destination path.
    """
    # Pickle is a binary format, hence the 'wb' mode.
    with open(filename, 'wb') as out_file:
        pickle.dump(a_list, out_file)
        print('Done writing list into a binary file')

def read_list(filename):
    """Load and return the pickled object stored at ``filename``.

    Warning:
        Unpickling can execute arbitrary code. Only use this on files that
        were produced by a trusted source (e.g. by ``write_list``).

    Args:
        filename: Path of the pickle file to read.

    Returns:
        The unpickled object.
    """
    # Pickle is a binary format, hence the 'rb' mode.
    with open(filename, 'rb') as in_file:
        return pickle.load(in_file)

def get_patch_list(data, sample_count, pixel_width, pixel_height, prior_timesteps, past_timesteps, training_data):
    """Draw ``sample_count`` random patches, redrawing all of them on duplicates.

    Each patch comes from ``create_patch`` applied to a randomly chosen
    element of ``data``. If any two patches of a batch are equal, the whole
    batch is discarded and drawn again, so that no sample is overrepresented.

    Args:
        data: List of xarray datasets to sample from.
        sample_count: Number of patches to return.
        pixel_width: Forwarded to ``create_patch`` as ``patch_lat``.
        pixel_height: Forwarded to ``create_patch`` as ``patch_lon``.
        prior_timesteps: Forwarded to ``create_patch`` as ``pre_steps``.
        past_timesteps: Forwarded to ``create_patch`` as ``post_steps``.
        training_data: Forwarded to ``create_patch`` as ``filter_prec``.

    Returns:
        A list of ``sample_count`` pairwise distinct patches.
    """
    patches = []

    while True:
        # Start every attempt from an empty batch.
        patches.clear()
        for _ in range(sample_count):
            source = random.choice(data)
            patches.append(
                create_patch(source, pixel_width, pixel_height, prior_timesteps, past_timesteps, training_data)
            )

        # Compare every unordered pair once and report each duplicate found.
        duplicate_found = False
        for first in range(sample_count - 1):
            for second in range(first + 1, sample_count):
                if patches[first].equals(patches[second]):
                    duplicate_found = True
                    print("Duplicate at i = " + str(first) + " and j = " + str(second))
                    print(patches[first])
                    print(patches[second])

        if not duplicate_found:
            return patches

def xarray_dataset_list_to_numpy(dataset_list):
    """Convert a list of datasets into one stacked numpy array.

    From every dataset the variables ``u10``, ``v10``, ``t2m``, ``sp`` and
    ``tp`` are taken as plain arrays (coordinates are dropped) and stacked
    along a new leading feature axis, in that order. ``tp`` is multiplied by
    1000 (conversion to mm). The per-dataset stacks are then stacked along a
    new leading sample axis.

    Args:
        dataset_list: Non-empty list of datasets that all share one shape.

    Returns:
        Array indexed as ``[sample][feature][...]``, where the trailing axes
        are those of the individual variables (presumably time and the two
        spatial dimensions - verify against the input data).
    """
    unscaled_names = ("u10", "v10", "t2m", "sp")

    per_sample = []
    for dataset in dataset_list:
        features = [dataset[name].values for name in unscaled_names]
        features.append(dataset["tp"].values * 1000)  # convert to mm
        per_sample.append(np.stack(features, axis=0))

    return np.stack(per_sample)
    
def find_min_max_ptp(data_array):
    """Compute per-feature minimum, maximum and range for normalisation.

    The reduction runs over every axis except axis 1, so ``data_array`` must
    be 5-dimensional with the features on axis 1. Any number of features is
    supported.

    A warning is printed for every feature whose range is zero, because
    dividing by that range during scaling cannot produce meaningful values.

    Args:
        data_array: Array indexed as ``[sample][feature][a][b][c]``.

    Returns:
        Float array of shape ``(n_features, 3)`` whose columns are
        ``[min, max, max - min]``.
    """
    reduce_axes = (0, 2, 3, 4)
    arr_max = np.max(data_array, axis=reduce_axes)
    arr_min = np.min(data_array, axis=reduce_axes)

    # Keep the range in float64 so the stacked result is always floating
    # point, independent of the input dtype.
    ptp = (arr_max - arr_min).astype(np.float64)

    for span in ptp:
        if span == 0:
            print("Error, no difference between min and max value!")
            print("Calculation will fail!")

    return np.stack((arr_min, arr_max, ptp), axis=1)

def min_max_scaling(data_array, min_max_values):
    """Min-max scale every feature except the last one, in place.

    For each scaled feature the minimum is subtracted and the result is
    divided by the range. The last feature on axis 1 (tp) is left untouched.
    The shape of ``data_array`` is printed before scaling.

    Args:
        data_array: 5-dimensional array with the features on axis 1. It is
            modified in place.
        min_max_values: Per-feature rows of ``[min, max, range]``.

    Returns:
        The same ``data_array`` object, after scaling.
    """
    shape = data_array.shape
    print(shape)

    scaled_features = shape[1] - 1  # Do not scale tp!
    for idx in range(scaled_features):
        # Basic slicing yields a view, so the in-place operations below
        # write straight into data_array.
        channel = data_array[:, idx, :, :, :]
        channel -= min_max_values[idx][0]
        channel /= min_max_values[idx][2]

    return data_array

def print_stats(ds):
    """Print minimum, maximum, mean and median of the first five features.

    One header line is printed, followed by one line per feature with the
    four statistics separated by " | ".

    Args:
        ds: 5-dimensional array with the features on axis 1.
    """
    print("Min, max, mean, median")
    reducers = (np.min, np.max, np.mean, np.median)
    for idx in range(5):
        channel = ds[:, idx, :, :, :]
        # str() is applied explicitly so numpy scalars are rendered exactly
        # as they would be by plain string conversion.
        print(" | ".join(str(reduce(channel)) for reduce in reducers))