In [None]:
# Extract ERA5 time series (T2M, MSLP, geopotential) averaged over NUTS country shapes

In [None]:
import math
import os

import dask.array as da
import geopandas as gpd
import numpy as np
import pandas as pd
import regionmask
# Operations on geometries
import shapely
import xarray as xr
# plotting
import matplotlib
import matplotlib.pyplot as plt

%matplotlib inline

In [None]:
# Input locations -- change these paths to match your own storage layout.
mypath = '/storage/workspaces/giub_hydro/hydro/data'  # data root (note: no trailing slash)
era_dir = f'{mypath}/ERA5/'  # directory holding the ERA5 NetCDF files
# Shapefile path relative to mypath; the leading '/' is needed because it is
# appended to mypath by plain string concatenation when the shapes are read.
sh_file = '/ref-nuts-2016-10m.shp/NUTS_RG_10M_2016_4326_LEVL_0.shp'
t2m = 'Daymean_era5_T2M_EU_19790101-20210905.nc'  # T2M file name inside era_dir

In [None]:
# Region polygons (sh_file already begins with '/', so concatenation yields a valid path)
nuts = gpd.read_file(f'{mypath}{sh_file}')

In [None]:
# Quick look at the first five rows of the attribute table
nuts.head(n=5)

In [None]:
# 2 m temperature (the file name indicates daily means), read lazily in chunks of 10 time steps
t2m_dat = xr.open_mfdataset(f'{era_dir}{t2m}', chunks={'time': 10})
lons, lats = t2m_dat['lon'], t2m_dat['lat']

In [None]:
# Mean-sea-level pressure: every NetCDF file in ERA5/mslp/, merged by coordinates
mslp_dat = xr.open_mfdataset(era_dir + 'mslp/*.nc', parallel=True, combine='by_coords')

In [None]:
# Geopotential: every NetCDF file in ERA5/geopotential/, merged by coordinates
gp_dat = xr.open_mfdataset(era_dir + 'geopotential/*.nc', parallel=True, combine='by_coords')

In [None]:
# Number of NUTS regions (one per row of the shapefile)
nn = nuts['NUTS_ID'].size

In [None]:
# One regionmask region per NUTS row: region number i is row i of `nuts`,
# labelled by its NUTS_ID (used both as name and abbreviation).
nuts_mask_poly = regionmask.Regions(
    name='nuts_mask',
    numbers=list(range(nn)),
    names=list(nuts.NUTS_ID),
    abbrevs=list(nuts.NUTS_ID),
    outlines=[geom for geom in nuts.geometry.values[:nn]],
)
print(nuts_mask_poly)

In [None]:
# Region number of Switzerland in `nuts_mask_poly`.
# Region numbers are the row positions of `nuts` (numbers=range(nn), names=NUTS_ID),
# so look the number up from the NUTS code instead of hard-coding a position
# (previously 11). A hard-coded position silently picks another country if the
# shapefile rows come in a different order. Raises ValueError if 'CH' is absent.
ID_CH = list(nuts.NUTS_ID).index('CH')

In [None]:
# Region mask on a Europe-wide window of the T2M grid (first time step only).
# Important: the latitude slice must follow the stored order of the latitude
# coordinate (here 80 -> 30, presumably decreasing); a slice that runs against
# that order selects an empty grid.
mask = nuts_mask_poly.mask(
    t2m_dat.isel(time=0).sel(lat=slice(80, 30), lon=slice(-40, 40)),
    lat_name='lat',
    lon_name='lon',
)

In [None]:
# Visual check: mask cells (coloured by region number) under the NUTS outlines.
fig, ax = plt.subplots(figsize=(12, 8))
mask.plot(ax=ax)
nuts.plot(ax=ax, alpha=0.8, facecolor='none', lw=1)
# Title so the figure stands alone; the trailing ';' suppresses the Axes repr.
ax.set_title('NUTS region mask on the T2M grid');

In [None]:
def _ordered_slice(ds, coord, bounds):
    """Return a slice over ``ds[coord]`` spanning ``bounds`` in the coordinate's stored order.

    ``.sel(coord=slice(a, b))`` gives an empty selection when ``a``/``b`` run against
    the stored order of the coordinate (e.g. latitudes stored north-to-south). The
    two bounds are therefore sorted, then flipped if the coordinate is decreasing.
    """
    lo, hi = sorted(bounds)
    values = ds[coord].values
    descending = values.size > 1 and values[0] > values[-1]
    return slice(hi, lo) if descending else slice(lo, hi)


def _region_spatial_mean(nc, region_mask, region_id, nvar, nam_lat, nam_lon):
    """Average ``nc`` over the grid cells of one region, per time step.

    Returns a Dataset without the horizontal dimensions. When ``nvar == 'z'`` it
    returns a list with that Dataset split into one entry per ``level``.
    Raises ValueError if no cell of ``region_mask`` belongs to ``region_id``.
    """
    inside = region_mask == region_id  # cells outside every region are NaN -> False
    if not bool(inside.any()):
        raise ValueError(f'region {region_id} has no grid cells inside the selected domain')
    # Bounding box of the region, taken in the grid's own coordinate order, so the
    # slices below select the box whatever the latitude direction.
    id_lat = inside[nam_lat].values[inside.any(dim=nam_lon).values]
    id_lon = inside[nam_lon].values[inside.any(dim=nam_lat).values]
    box = {nam_lat: slice(id_lat[0], id_lat[-1]), nam_lon: slice(id_lon[0], id_lon[-1])}
    # Load only the bounding box into memory, then blank the cells outside the region.
    region = nc.sel(box).compute().where(inside)
    # Unweighted mean of the remaining cells (NaNs skipped). Reducing lat/lon directly
    # gives the same per-time values as the former groupby('time').mean(...), without
    # building one group per time step.
    # NOTE(review): no cos(latitude) area weighting is applied -- confirm this is intended.
    spatial_mean = region.mean(dim=[nam_lat, nam_lon])
    if nvar == 'z':
        return [spatial_mean.isel(level=il) for il in range(spatial_mean.sizes['level'])]
    return spatial_mean


def extract_nuts_TS(nc, nuts_mask_poly, nvar, lim_lat, lim_lon, nam_lat, nam_lon, ID_country):
    """Area-average a gridded dataset over one NUTS region, or over every region.

    Parameters
    ----------
    nc : xarray.Dataset
        Gridded data with a ``time`` dimension and horizontal coordinates named
        ``nam_lat`` / ``nam_lon`` (plus a ``level`` dimension when ``nvar == 'z'``).
    nuts_mask_poly : regionmask.Regions
        Region definitions; their ``numbers`` are the region IDs.
    nvar : str
        Variable short name. Only ``'z'`` is special: its result is split per level.
    lim_lat, lim_lon : sequence of two numbers
        Domain on which the mask is built. The order is free; the bounds are matched
        to the stored direction of each coordinate.
    nam_lat, nam_lon : str
        Names of the latitude / longitude coordinates in ``nc`` (and in the mask).
    ID_country : int or None
        Region number to extract, or ``None`` to extract every region of ``nuts_mask_poly``.

    Returns
    -------
    For a single region: a Dataset of spatial means over time (for ``nvar == 'z'``,
    a list with one such Dataset per level). For ``ID_country is None``: a list with
    one such result per region, in ``nuts_mask_poly.numbers`` order. Previously only
    the last region's mean was returned in that case.

    Raises
    ------
    ValueError
        If a requested region has no grid cell inside the domain.
    """
    # Mask built on the first time step of the requested domain only.
    domain = nc.isel(time=0).sel({
        nam_lat: _ordered_slice(nc, nam_lat, lim_lat),
        nam_lon: _ordered_slice(nc, nam_lon, lim_lon),
    })
    mask = nuts_mask_poly.mask(domain, lat_name=nam_lat, lon_name=nam_lon)

    if ID_country is None:
        region_means = []
        for region_id in nuts_mask_poly.numbers:
            print(region_id)  # progress indicator
            region_means.append(
                _region_spatial_mean(nc, mask, region_id, nvar, nam_lat, nam_lon))
        return region_means

    print(ID_country)
    return _region_spatial_mean(nc, mask, ID_country, nvar, nam_lat, nam_lon)

In [None]:
# T2M time series averaged over the Swiss mask cells
t2m_CH = extract_nuts_TS(
    t2m_dat, nuts_mask_poly,
    nvar='T2MMEAN', lim_lat=[80, 30], lim_lon=[-40, 40],
    nam_lat='lat', nam_lon='lon', ID_country=ID_CH,
)

In [None]:
# MSLP time series averaged over the Swiss mask cells.
# NOTE(review): lim_lat is [30, 80] here but [80, 30] for T2M and Z; this only
# selects data if the MSLP files store latitude in increasing order -- confirm.
mslp_CH = extract_nuts_TS(
    mslp_dat, nuts_mask_poly,
    nvar='MSL', lim_lat=[30, 80], lim_lon=[-40, 40],
    nam_lat='lat', nam_lon='lon', ID_country=ID_CH,
)

In [None]:
# Geopotential averaged over the Swiss mask cells (one result per level)
z_CH = extract_nuts_TS(
    gp_dat, nuts_mask_poly,
    nvar='z', lim_lat=[80, 30], lim_lon=[-40, 40],
    nam_lat='latitude', nam_lon='longitude', ID_country=ID_CH,
)

In [None]:
def save_dataf(mdat, xvar, dout, country='CH', period='1979-2021'):
    """Write the series ``mdat[xvar]`` with its dates to a CSV file.

    The file is ``<dout>/df_<xvar>_<country>_<period>.csv``. ``os.path.join`` is used
    so that ``dout`` works with or without a trailing slash. Previously
    ``mypath + 'df_...'`` wrote ``.../datadf_*.csv`` next to the data folder.

    Parameters
    ----------
    mdat : Dataset-like
        Supports ``mdat['time'].values`` and ``mdat[xvar].values``.
    xvar : str
        Variable to write; also used as the column name and in the file name.
    dout : str
        Output directory.
    country, period : str
        File-name tags; the defaults reproduce the original file name.

    Returns
    -------
    str
        Path of the written CSV file.
    """
    df_out = pd.DataFrame({'date': mdat['time'].values, xvar: mdat[xvar].values})
    out_file = os.path.join(dout, f'df_{xvar}_{country}_{period}.csv')
    df_out.to_csv(out_file, index=False, header=True)
    return out_file

    
    

In [None]:
def save_dataf_levels(mdat, xvar, nlev, dout, country='CH', period='1979-2021'):
    """Write one column per level plus a ``date`` column to a CSV file.

    The file is ``<dout>/df_<xvar>_<country>_<period>.csv`` (joined with
    ``os.path.join``, so ``dout`` may lack a trailing slash).

    Parameters
    ----------
    mdat : sequence
        One entry per level, aligned with ``nlev``. Each entry supports
        ``entry['time']`` and ``entry[xvar]``. Entries that are plain floats
        (e.g. NaN placeholders) are skipped.
    xvar : str
        Variable to write from each entry; also used in the file name.
    nlev : sequence
        Level values, used as column names.
    dout : str
        Output directory.
    country, period : str
        File-name tags; the defaults reproduce the original file name.

    Returns
    -------
    str
        Path of the written CSV file.

    Raises
    ------
    ValueError
        If no entry of ``mdat`` is usable.
    """
    usable = [(lev, entry) for lev, entry in zip(nlev, mdat) if not isinstance(entry, float)]
    if not usable:
        raise ValueError(f'no level data to save for {xvar!r}')
    # Dates come from the first usable entry; all levels are assumed to share one
    # time axis. Columns are aligned by row position, as before.
    columns = [pd.Series(np.asarray(usable[0][1]['time']), name='date')]
    columns += [pd.Series(np.asarray(entry[xvar]), name=lev) for lev, entry in usable]
    # A single concat instead of re-concatenating inside the loop.
    df_out = pd.concat(columns, axis=1)
    out_file = os.path.join(dout, f'df_{xvar}_{country}_{period}.csv')
    df_out.to_csv(out_file, index=False, header=True)
    return out_file


In [None]:
# Save the Swiss T2M series
save_dataf(mdat=t2m_CH, xvar='T2MMEAN', dout=mypath)

In [None]:
# Save the Swiss MSLP series
save_dataf(mdat=mslp_CH, xvar='MSL', dout=mypath)

In [None]:
# Level values of the geopotential data: one CSV column each
nlev = gp_dat['level'].values
save_dataf_levels(mdat=z_CH, xvar='z', nlev=nlev, dout=mypath)