In [None]:
import xarray as xr
from pathlib import Path
import netCDF4

import datetime
import pickle

import cartopy.crs as ccrs
import cartopy
import proplot as pplt

import numpy as np
import copy
import math
import random

# Seed both global random number generators so that the random patch/month
# selection below (done with the stdlib `random` module) is reproducible.
random.seed(0)
np.random.seed(0)

def load_years(start_year, years):
    """Load ``years`` consecutive years of monthly netCDF files, starting at ``start_year``.

    Files are read from ``<parent of the current working directory>/ERA5-downloader``
    and must be named ``YYYY-MM.nc``. Each file is loaded fully into memory with
    ``xr.load_dataset`` (netcdf4 engine).

    Returns:
        list: one xarray Dataset per month in chronological order (``12 * years`` entries).
    """
    # The data lives next to this notebook's directory, not inside it
    data_dir = Path.cwd().parent / "ERA5-downloader"

    monthly_data = []
    for year in range(start_year, start_year + years):
        print(f"Loading {year}")
        monthly_data.extend(
            xr.load_dataset(data_dir / f"{year}-{month:02d}.nc", engine="netcdf4")
            for month in range(1, 13)
        )
    return monthly_data

def select_cell(data, cell_lat, cell_lon):
    """Cut a randomly positioned ``cell_lat`` x ``cell_lon`` window out of ``data``.

    The upper-left corner is drawn uniformly (stdlib ``random``) among all
    positions where the window fits completely. If the requested cell is larger
    than ``data`` along either axis, ``data`` is returned unchanged (uncropped).
    """
    n_lon = data.sizes['longitude']
    n_lat = data.sizes['latitude']
    if cell_lat > n_lat or cell_lon > n_lon:
        return data

    # randint is inclusive, so the last start index still leaves room for the full cell
    lon_start = random.randint(0, n_lon - cell_lon)
    lat_start = random.randint(0, n_lat - cell_lat)

    return data.isel(
        longitude = slice(lon_start, lon_start + cell_lon),
        latitude = slice(lat_start, lat_start + cell_lat),
    )

def create_patch(data, patch_lat, patch_lon, pre_steps, post_steps, filter_prec):
    """Cut a random spatio-temporal patch out of one (monthly) xarray Dataset.

    A random ``patch_lat`` x ``patch_lon`` window is chosen with ``select_cell`` together
    with a random target hour ``t``. The returned Dataset holds the ``pre_steps + post_steps``
    time steps ``t - pre_steps`` ... ``t + post_steps - 1``. That is ``pre_steps`` steps of
    history followed by ``post_steps`` steps starting at (and including) the target hour.

    Args:
        data: Dataset with ``time``, ``latitude`` and ``longitude`` dimensions and a ``tp`` variable.
        patch_lat: window size along latitude, in grid cells.
        patch_lon: window size along longitude, in grid cells.
        pre_steps: number of time steps before the target hour.
        post_steps: number of time steps from the target hour onwards (target included).
        filter_prec: if truthy, redraw (window and hour) until the summed ``tp`` over the
            ``post_steps`` block exceeds ``post_steps * patch_lat * 0.0005``. There is no
            attempt limit, so this loops forever if no window in ``data`` reaches it.

    Returns:
        The selected Dataset (a single object, not a tuple).

    Raises:
        ValueError: from ``random.randint`` if ``data`` has fewer than
            ``pre_steps + post_steps`` time steps.
    """
    while True:
        sel = select_cell(data, patch_lat, patch_lon)

        # randint is inclusive at both ends and the slice end below is exclusive, so the
        # latest target hour whose window still fits is n_time - post_steps. (An upper
        # bound of n_time - post_steps - 1 would never select the final window.)
        n_time = sel.sizes['time']
        time_sel = random.randint(pre_steps, n_time - post_steps)
        sel = sel.isel(time = slice(time_sel - pre_steps, time_sel + post_steps))

        if not filter_prec:
            return sel

        # Does the prediction block (target hour onwards) contain enough precipitation to learn from?
        target_block = sel.isel(time = slice(pre_steps, pre_steps + post_steps))
        patch_prec = target_block['tp'].sum(dim = ["latitude", "longitude", "time"])
        # 0.0005 is 0.5 mm if tp is in metres (xarray_dataset_list_to_numpy multiplies tp
        # by 1000 "to mm", which suggests it is): 0.5 mm * post_steps * patch_lat.
        # NOTE(review): the limit scales with patch_lat only, not patch_lat * patch_lon,
        # although the sum runs over the whole window — confirm this is intended.
        limit = post_steps * patch_lat * 0.0005
        if patch_prec > limit:
            return sel

def write_list(a_list, filename):
    """Pickle ``a_list`` (any picklable object, e.g. a NumPy array) into ``filename``.

    The file is opened in binary write mode, so an existing file is overwritten.
    A confirmation message is printed once the object has been dumped.
    """
    with open(filename, mode='wb') as out_file:
        pickle.dump(a_list, out_file)
        print('Done writing list into a binary file')

def read_list(filename):
    """Load and return the object pickled in ``filename`` (the counterpart of ``write_list``).

    Security: ``pickle.load`` can execute arbitrary code while unpickling. Only use
    it on files this notebook (or another trusted source) produced.
    """
    with open(filename, mode='rb') as in_file:
        return pickle.load(in_file)

def get_patch_list(data, sample_count, pixel_width, pixel_height, prior_timesteps, past_timesteps, training_data):
    """Draw ``sample_count`` pairwise-distinct random patches from a list of monthly Datasets.

    For every sample a month is picked with ``random.choice(data)`` and a patch is cut from
    it with ``create_patch``. The arguments are forwarded positionally:
    ``pixel_width``/``pixel_height`` become its ``patch_lat``/``patch_lon``,
    ``prior_timesteps``/``past_timesteps`` become ``pre_steps``/``post_steps``, and
    ``training_data`` becomes ``filter_prec``.

    Duplicates (same exact time and place) would over-represent samples. If any two
    patches are identical (``Dataset.equals``), each duplicate pair is printed and the
    whole list is discarded and drawn again.

    Returns:
        list of ``sample_count`` xarray Datasets.
    """
    while True:
        patch_list = [
            create_patch(random.choice(data), pixel_width, pixel_height,
                         prior_timesteps, past_timesteps, training_data)
            for _ in range(sample_count)
        ]

        duplicate_pairs = _find_duplicate_pairs(patch_list)
        for i, j in duplicate_pairs:
            print("Duplicate at i = " + str(i) + " and j = " + str(j))
            print(patch_list[i])
            print(patch_list[j])
        if not duplicate_pairs:
            return patch_list


def _find_duplicate_pairs(patches):
    """Return every index pair (i, j), i < j, with ``patches[i].equals(patches[j])``, sorted.

    Comparing all pairs with ``Dataset.equals`` is O(n^2): about 12.5 million comparisons
    for 5000 patches. ``equals`` requires equal coordinates. Patches cut from the same
    source files share coordinate dtypes, so identical patches have byte-identical
    time/latitude/longitude values. The patches are therefore bucketed by those bytes,
    and ``equals`` is evaluated only within a bucket.
    """
    buckets = {}
    pairs = []
    for j, patch in enumerate(patches):
        key = tuple(patch[dim].values.tobytes() for dim in ("time", "latitude", "longitude"))
        bucket = buckets.setdefault(key, [])
        pairs.extend((i, j) for i in bucket if patches[i].equals(patch))
        bucket.append(j)
    # Sorted so pairs are reported in (i, then j) order
    return sorted(pairs)

def xarray_dataset_list_to_numpy(dataset_list):
    """Convert a list of patch Datasets into one NumPy array, dropping all coordinates.

    Each Dataset must provide the variables u10, v10, t2m, sp and tp. Their raw
    values are stacked in that order along a new feature axis, with tp multiplied
    by 1000 (to mm). The time order inside each patch is preserved.

    Returns:
        Array of shape [sample][feature][...]. The trailing axes are the variables'
        own dimensions, presumably (time, latitude, longitude) — confirm against the files.
    """
    def _feature_stack(ds):
        return np.stack(
            (
                ds['u10'].values,
                ds['v10'].values,
                ds['t2m'].values,
                ds['sp'].values,
                ds['tp'].values * 1000,  # convert to mm
            ),
            axis = 0,
        )

    return np.stack([_feature_stack(ds) for ds in dataset_list])
    
def find_min_max_ptp(data_array):
    """Per-feature minimum, maximum and range of a 5-D [sample, feature, ...] array.

    Needed for min-max normalisation. Works for any number of features (the reduction
    runs over axes 0, 2, 3 and 4, so the input must be 5-D).

    Returns:
        np.ndarray of shape (n_features, 3). Row f is [min, max, max - min] of feature f
        (feature order as in the input, e.g. u10, v10, t2m, sp, tp).
    """
    arr_max = np.max(data_array, axis = (0, 2, 3, 4))
    arr_min = np.min(data_array, axis = (0, 2, 3, 4))
    # float64 range, matching the previous np.zeros(...) buffer, so the stacked result is float64
    ptp = np.asarray(arr_max - arr_min, dtype = np.float64)

    # A zero range makes min_max_scaling divide by zero for that feature
    for _ in np.flatnonzero(ptp == 0):
        print("Error, no difference between min and max value!")
        print("Calculation will fail!")

    return np.stack((arr_min, arr_max, ptp), axis = 1)

def min_max_scaling(data_array, min_max_values):
    """Min-max normalise every feature except the last one (tp), IN PLACE.

    For feature f: ``x = (x - min_max_values[f][0]) / min_max_values[f][2]``, i.e.
    (x - min) / range with rows as produced by ``find_min_max_ptp``. Precipitation
    (last feature) is left unscaled. The array's shape is printed first.

    Returns:
        The same (modified) ``data_array`` object.
    """
    print(data_array.shape)

    n_scaled = data_array.shape[1] - 1  # everything but tp
    for idx in range(n_scaled):
        feature_view = data_array[:, idx, :, :, :]  # a view: in-place ops write through
        feature_view -= min_max_values[idx][0]
        feature_view /= min_max_values[idx][2]

    return data_array

def print_stats(ds):
    """Print min, max, mean and median of each of the first five features of ``ds``.

    ``ds`` is indexed [sample, feature, ...]. A header line is printed, then one
    " | "-separated line per feature.
    """
    print("Min, max, mean, median")
    statistics_fns = (np.min, np.max, np.mean, np.median)
    for feature in range(5):
        values = ds[:, feature, :, :, :]
        print(" | ".join(str(stat(values)) for stat in statistics_fns))

def PGW(data, warming, prec_m=2.1):
    """
    Takes an existing, already generated dataset and applies pseudo-global-warming deltas to a copy of it.

    ``data`` itself is not modified. Feature order is u10, v10, t2m, sp, tp:

    * t2m: shifted by ``warming`` K
    * u10, v10: deviations from the separate means of the positive and the
      negative values are scaled by ``1 + 0.02 * warming`` (``adjust_std_separately``)
    * sp: raised by ``35.7 * warming`` Pa
    * tp: multiplied element-wise by the factor ``warming_prec(tp, warming, prec_m)``
      would return, computed vectorised over the whole array

    Args:
        data (Array): Generated train/test numpy dataset as historical input
                        of shape [*, 5, 54, x, x]
        warming (float): The warming to be applied in K
        prec_m (float): Centre of the precipitation transition between the damped
                        and the intensified regime (default 2.1, the previous constant)

    Returns:
        ds (Array): numpy dataset as of shape [*, 5, 54, x, x]
    """

    # np.copy returns an independent array, so the caller's data is never modified
    ds = np.copy(data)

    print(ds.shape)
    print(ds[0,4,0,0,0])

    #u10,v10,t2m,sp,tp

    #temperature
    ds[:,2,:,:,:] += warming

    #wind
    #Pielke 2022:4 RCP 8.5 equals 5K warming
    #Jung2019: RCP 8.5, all far future
        #Mean: No change
        #Skew and kurtosis too difficult
        #Stdev increases by 10%
        #-> preserve mean, preserve sign
    stdev_factor = 1 + (0.02 * warming)

    ds[:,0,:,:,:] = adjust_std_separately(ds[:,0,:,:,:], stdev_factor)
    ds[:,1,:,:,:] = adjust_std_separately(ds[:,1,:,:,:], stdev_factor)

    #surface pressure
    #Schmidt 2017:10575 - increase around 0.1 hPa/decade in AOI between 1980 and 2010
    #NOAA 2023: 0.29 K/decade
    # => 0.36 hPa/K
    # NOTE(review): 0.1 / 0.29 is about 0.345 hPa/K, while 35.7 Pa/K corresponds to
    # 0.1 / 0.28 — confirm which rate is intended.
    ds[:,3,:,:,:] += warming * 35.7 #1 hPa = 100 Pa

    #precipitation
    prec_med = np.median(ds[:,4,:,:,:])
    print(prec_med)
    prec_mean = np.mean(ds[:,4,:,:,:])
    print(prec_mean)
    print("-------")

    # Vectorised instead of a per-element Python loop (millions of warming_prec calls)
    prec = ds[:,4,:,:,:]
    ds[:,4,:,:,:] = prec * _warming_prec_factor(prec, warming, prec_m)

    return ds


def _warming_prec_factor(prec, warming, prec_m):
    """Array version of ``warming_prec``: the precipitation scale factor for every element of ``prec``."""
    wetter = pow(1.07, warming)
    drier = pow(0.93, warming)
    # Sine ramp joining the two regimes; only used where 0.5*prec_m <= prec <= 1.5*prec_m
    ramp = ((wetter - drier) / 2) * np.sin((np.pi / prec_m) * (prec - prec_m)) + ((wetter + drier) / 2)
    return np.where(prec > 1.5*prec_m, wetter, np.where(prec < 0.5*prec_m, drier, ramp))

def adjust_std_separately(arr, factor):
    """Scale the spread of positive and negative values separately, keeping both means.

    Strictly positive and strictly negative entries form two groups. Within each
    group, every value's deviation from the group mean is multiplied by ``factor``.
    The group mean is therefore unchanged and its standard deviation is scaled by
    ``|factor|``. Zeros are left untouched, and ``arr`` itself is not modified.

    Note: for ``factor > 1`` a value close to zero can cross zero (e.g. positives
    [1, 3] with factor 2 become [0, 4]), so signs are not strictly preserved.

    Returns:
        A new array with the same shape and dtype as ``arr``.
    """
    adjusted = arr.copy()
    for group_mask in (arr > 0, arr < 0):
        if not np.any(group_mask):
            continue
        group_values = arr[group_mask]
        group_mean = np.mean(group_values)
        deviations = group_values - group_mean
        deviations *= factor
        adjusted[group_mask] = deviations + group_mean
    return adjusted

def warming_prec(prec, warming, prec_m):
    '''
    Multiplicative precipitation scale factor for a given warming.

    Values above 1.5 * prec_m are intensified by 1.07 ** warming (7 % per K).
    Values below 0.5 * prec_m are damped by 0.93 ** warming. In between, the
    factor follows a sine ramp that joins both regimes continuously and equals
    their average at prec == prec_m.

    Args:
        prec: precipitation value (mm/h in this notebook's usage).
        warming: warming in K.
        prec_m: centre of the transition.

    Returns:
        The factor to multiply ``prec`` by.
    '''
    if prec > 1.5*prec_m:
        return pow(1.07, warming)
    if prec < 0.5*prec_m:
        return pow(0.93, warming)

    wetter = pow(1.07, warming)
    drier = pow(0.93, warming)
    amplitude = (wetter - drier) / 2
    midpoint = (wetter + drier) / 2
    phase = (math.pi / prec_m) * (prec - prec_m)
    return amplitude * math.sin(phase) + midpoint

# Plot the precipitation scale factor of warming_prec (prec_m = 2.1) for 1, 2 and 4 K warming
numbers = np.linspace(0, 5, num = 200)
scaling = np.array([warming_prec(value, 1, 2.1) for value in numbers])
scaling2 = np.array([warming_prec(value, 2, 2.1) for value in numbers])
scaling3 = np.array([warming_prec(value, 4, 2.1) for value in numbers])

with pplt.rc.context(fontsize='11px'):
    fig, ax = pplt.subplot(xlabel='Precipitation: mm/h', ylabel='Scale Factor',
                           figheight='5cm', figwidth="14cm")
for curve, curve_label in ((scaling, "Warming 1°C"),
                           (scaling2, "Warming 2°C"),
                           (scaling3, "Warming 4°C")):
    ax.plot(numbers, curve, label = curve_label)
ax.legend(loc="upper left")
fig.savefig("transfer")
pplt.show()
pplt.close()

def plot_hist(ds, name):
    """Plot per-feature value histograms of one dataset and save the figure as ``name``.

    ``ds`` is indexed [sample, feature, ...] with features u10, v10, t2m, sp, tp in that
    order (matching the panel titles). Precipitation is shown twice, split at 0.5 mm/h.
    """
    with pplt.rc.context(fontsize='11px'):
        fig, axes = pplt.subplots(nrows = 2, ncols = 3,figheight='8cm', figwidth="14cm", share=False)
    axes.format(yticks='null', abc = True)

    wind_bins = np.linspace(-5, 5, 200)
    t_bins = np.linspace(285, 310, 200)
    p_bins = np.linspace(85000, 105000, 200)
    prec_bins1 = np.linspace(0, 0.5, 200)
    prec_bins2 = np.linspace(0.5, 10, 200)

    # (feature index, bins, title, x label) for each of the six panels
    panels = (
        (0, wind_bins, "Wind u (10m)", "m/s"),
        (1, wind_bins, "Wind v (10m)", "m/s"),
        (2, t_bins, "Temperature (2m)", "K"),
        (3, p_bins, "Surface Pressure", "Pa"),
        (4, prec_bins1, "Precipitation < 0.5 mm/h", "mm/h"),
        # The bins start at 0.5 mm/h, so the title says "> 0.5" (it used to read "> 5")
        (4, prec_bins2, "Precipitation > 0.5 mm/h", "mm/h"),
    )
    for index, (feature, bins, title, unit) in enumerate(panels):
        axes[index].hist(ds[:, feature, :, :, :].flatten(), bins = bins)
        axes[index].format(title = title, xlabel = unit)

    # Save before showing: some backends (e.g. Jupyter inline) close figures on show
    fig.savefig(name)
    pplt.show()
    pplt.close()

def plot_hist_comp(ds1, ds2, name):
    """Overlay the per-feature histograms of two datasets and save the figure as ``name``.

    ``ds1`` is drawn as "Base" and ``ds2`` as "5K warming". Both are 5-D arrays
    indexed [sample, feature, ...] with features u10, v10, t2m, sp, tp (matching the
    panel titles). They must hold the same number of values per feature, because the
    flattened features are stacked side by side before plotting.
    """
    def _flatten_per_feature(arr):
        # [sample, feature, a, b, c] -> [feature, sample * a * b * c]
        swapped = np.swapaxes(arr, 0, 1)
        return np.reshape(swapped, (swapped.shape[0], swapped.shape[1]*swapped.shape[2]*swapped.shape[3]*swapped.shape[4]))

    # Shape [feature, value, dataset]: each panel gets one histogram column per dataset
    ds = np.stack((_flatten_per_feature(ds1), _flatten_per_feature(ds2)), axis=2)

    with pplt.rc.context(fontsize='11px'):
        fig, axes = pplt.subplots(nrows = 2, ncols = 3,figheight='8cm', figwidth="14cm", share=False)

    axes.format(yticks='null', abc = True)

    wind_bins = np.linspace(-5, 5, 200)
    t_bins = np.linspace(285, 310, 200)
    p_bins = np.linspace(85000, 105000, 200)
    prec_bins1 = np.linspace(0, 0.5, 200)
    prec_bins2 = np.linspace(0.5, 10, 200)

    colours = ('indigo9', 'red9')
    # (feature index, bins, title, x label) for each of the six panels
    panels = (
        (0, wind_bins, "Wind u (10m)", "m/s"),
        (1, wind_bins, "Wind v (10m)", "m/s"),
        (2, t_bins, "Temperature (2m)", "K"),
        (3, p_bins, "Surface Pressure", "Pa"),
        (4, prec_bins1, "Precipitation < 0.5 mm/h", "mm/h"),
        (4, prec_bins2, "Precipitation > 0.5 mm/h", "mm/h"),
    )
    for index, (feature, bins, title, unit) in enumerate(panels):
        # Only the first panel carries the legend labels
        label_kwargs = {'labels': ("Base", "5K warming")} if index == 0 else {}
        axes[index].hist(ds[feature], bins = bins, alpha = 0.5, filled=True, cycle=colours, **label_kwargs)
        axes[index].format(title = title, xlabel = unit)

    fig.legend(loc='t')

    pplt.show()
    fig.savefig(name)
    pplt.close()


#arr1 = read_list("test_array_10_1000_np")
#arr2 = read_list("PGW_10_1000_50_np")
#arr3 = read_list("train_array_10_10000_np")

#plot_hist_comp(arr1, arr2, "Hist-test-data")
#plot_hist(arr3, "Hist-training")

#print_stats(arr1)
#print_stats(arr2)
#print_stats(arr3)

In [None]:
#This cell actually calls all functions needed to transform the data
#This gives us several years of monthly xarrays in one list
train_data_A = load_years(1980, 2)
train_data_B = load_years(2020, 2)
test_data = load_years(2000, 2)
print("Loading finished")

#How many patches do we want each?
train_number = 5000 #for each of the two datasets
test_number = 1000
#How large should our patches be? One pixel = 0.25°
pixel_width = 10
pixel_height = 10
#How many hours before the target hour, and from the target hour onwards (target included)?
prior_timesteps = 48
past_timesteps = 6

# Output file tags "<patch size>_<sample count>", derived from the parameters above so the
# file names cannot drift out of sync with the data ("10_10000" and "10_1000" here).
train_tag = f"{pixel_width}_{2 * train_number}"
test_tag = f"{pixel_width}_{test_number}"

print("Creating train list")
train_list = (
    get_patch_list(train_data_A, train_number, pixel_width, pixel_height, prior_timesteps, past_timesteps, True)
    + get_patch_list(train_data_B, train_number, pixel_width, pixel_height, prior_timesteps, past_timesteps, True)
)
print("Train list created")
print("Creating test list")
test_list = get_patch_list(test_data, test_number, pixel_width, pixel_height, prior_timesteps, past_timesteps, False)
print("Test list created")

train_array = xarray_dataset_list_to_numpy(train_list)
test_array = xarray_dataset_list_to_numpy(test_list)

#Due to conversion errors some precipitation values are slightly below 0, so we clip them to 0
train_array[:,4,:,:,:] = np.clip(train_array[:,4,:,:,:], 0, None)
test_array[:,4,:,:,:] = np.clip(test_array[:,4,:,:,:], 0, None)
#Also save the non-scaled arrays
write_list(train_array, f"train_array_{train_tag}_np")
# NOTE: this file used to be named "test_array_10_2000_np" although only test_number (1000)
# patches are drawn; "test_array_10_1000_np" is the name the analysis code reads back.
write_list(test_array, f"test_array_{test_tag}_np")

# Perturbed copies of the unscaled test set, one per warming level (K)
PGW_array_05 = PGW(test_array, 0.5)
PGW_array_10 = PGW(test_array, 1)
PGW_array_20 = PGW(test_array, 2)
PGW_array_30 = PGW(test_array, 3)
PGW_array_40 = PGW(test_array, 4)
PGW_array_50 = PGW(test_array, 5)

# Key = warming * 10 as two digits, used as the file-name suffix
pgw_arrays = {"05": PGW_array_05, "10": PGW_array_10, "20": PGW_array_20,
              "30": PGW_array_30, "40": PGW_array_40, "50": PGW_array_50}
for warming_tag, pgw_array in pgw_arrays.items():
    write_list(pgw_array, f"PGW_{test_tag}_{warming_tag}_np")

#We want to use the min-max-values of the training data to scale both train and test data -> no data flow from test to train
min_max = find_min_max_ptp(train_array)
# write_list requires a file name (calling it without one raised a TypeError here).
# NOTE(review): name chosen to match the other outputs — rename if a consumer expects another.
write_list(min_max, f"min_max_{train_tag}_np")

print("Now scaling")
train_array_scaled = min_max_scaling(copy.deepcopy(train_array), min_max)
test_array_scaled = min_max_scaling(copy.deepcopy(test_array), min_max)

PGW_05_scaled = min_max_scaling(copy.deepcopy(PGW_array_05), min_max)
PGW_10_scaled = min_max_scaling(copy.deepcopy(PGW_array_10), min_max)
PGW_20_scaled = min_max_scaling(copy.deepcopy(PGW_array_20), min_max)
PGW_30_scaled = min_max_scaling(copy.deepcopy(PGW_array_30), min_max)
PGW_40_scaled = min_max_scaling(copy.deepcopy(PGW_array_40), min_max)
PGW_50_scaled = min_max_scaling(copy.deepcopy(PGW_array_50), min_max)

write_list(train_array_scaled, f"train_array_{train_tag}")
write_list(test_array_scaled, f"test_array_{test_tag}")

pgw_scaled = {"05": PGW_05_scaled, "10": PGW_10_scaled, "20": PGW_20_scaled,
              "30": PGW_30_scaled, "40": PGW_40_scaled, "50": PGW_50_scaled}
for warming_tag, pgw_array in pgw_scaled.items():
    write_list(pgw_array, f"PGW_{test_tag}_{warming_tag}")
print("Lists saved")