# Preprocess ERA5 data

In this notebook, we will preprocess the data downloaded from the [ECMWF website](https://www.ecmwf.int/en/forecasts/dataset/ecmwf-reanalysis-v5).

In [1]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
%matplotlib widget
matplotlib.rc('font', size=18)
default_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

import os
import xarray as xr

HOME = '../../'

## Download the data

Run the scripts `fetch_t2m.py`, `fetch_zg.py` and `fetch_lsm.py`. They may take a while.


## Compute cell area

In [None]:
raw_lsm = xr.open_dataset('raw/lsm.nc')
raw_lsm

In [None]:
lsm = raw_lsm.lsm.isel(time=0,drop=True)
lsm

In [None]:
lsm.to_netcdf('land_sea_mask_fullres.nc')

In [None]:
R = 6.371e6 # Earth radius in metres (the resulting area is labelled m**2 below)
# template with the same grid and coordinates as the land-sea mask, filled with ones
cell_area = xr.ones_like(lsm).rename('cell_area')

# grid spacing, taken from the first two coordinate values
# (the (np.pi/180)**2 factor below treats both spacings as degrees)
dlon = cell_area.longitude.data[1] - cell_area.longitude.data[0]
# sign flipped with respect to dlon: positive when latitude is stored in decreasing order
# NOTE(review): presumably the file lists latitude from north to south -- confirm
dlat = -cell_area.latitude.data[1] + cell_area.latitude.data[0]

# area of a lon-lat cell: R^2 * dlon * dlat * cos(lat), with the angles converted to radians;
# np.maximum clips negative cosine values to zero.
# The factor depends only on latitude and is broadcast over the whole grid by the in-place product
cell_area *= R**2*(np.pi/180)**2*dlon*dlat*np.maximum(np.cos(np.pi/180*cell_area.latitude),0)

cell_area.attrs = {
    'units': 'm**2',
    'long_name': 'cell area',
}

cell_area

In [None]:
cell_area.to_netcdf('cell_area_fullres.nc')

## Compute daily means

Run the script `daily_mean.py`

If you have storage space limitations, after computing the daily means you can delete the whole `raw` directory

## Compute area integral over France of the temperature field

In [None]:
def standardize_dim_names(xa:xr.DataArray) -> xr.DataArray:
    '''
    Renames the dimensions of `xa` so that they comply with the standard names:
    longitude: 'lon'
    latitude:  'lat'
    time:      'time'

    Dimensions that already have a standard name, or that are not a known variant
    of one, are left untouched.

    Parameters
    ----------
    xa : xr.DataArray
        object whose dimensions should be renamed

    Returns
    -------
    xr.DataArray
        the renamed object, or `xa` itself if there was nothing to rename
    '''
    # known spelling variant -> standard name
    variant_to_standard = {
        'longitude': 'lon',
        'Longitude': 'lon',
        'latitude': 'lat',
        'Latitude': 'lat',
        'Time': 'time',
        't': 'time',
        'T': 'time',
    }
    renamings = {
        dim: variant_to_standard[dim]
        for dim in xa.dims
        if dim in variant_to_standard
    }
    return xa.rename(renamings) if renamings else xa

def is_above_line(da:xr.DataArray, lon1:float, lat1:float, lon2:float, lat2:float) -> xr.DataArray:
    '''
    Returns a boolean mask that is True strictly north of the straight line (in lon-lat space)
    passing through the points (lon1, lat1) and (lon2, lat2).

    Multiplying the outputs of several calls of this function (possibly negated)
    yields a polygonal mask.

    Parameters
    ----------
    da : xr.DataArray
        input object with longitude and latitude dimensions
    lon1 : float
        longitude of the first point
    lat1 : float
        latitude of the first point
    lon2 : float
        longitude of the second point, it must differ from `lon1`
    lat2 : float
        latitude of the second point

    Returns
    -------
    xr.DataArray
        the mask; points lying exactly on the line are False
    '''
    grid = standardize_dim_names(da)
    lon, lat = grid.lon, grid.lat
    # latitude of the line at every longitude of the grid (linear interpolation between the two points)
    line_lat = (lat1*(lon2 - lon) + lat2*(lon - lon1))/(lon2 - lon1)
    return lat - line_lat > 0

def masked_average(xa:xr.DataArray,
                   dim=None,
                   weights:xr.DataArray=None,
                   mask:xr.DataArray=None) -> xr.DataArray:
    '''
    Computes the average of `xa` over given dimensions `dim`, weighting with `weights` and masking with `mask`

    Parameters
    ----------
    xa : xr.DataArray
        data
    dim : str or list of str, optional
        dimensions over which to perform the average, by default None
    weights : xr.DataArray, optional
        weights for the average, for example the cell area, by default None
    mask : xr.DataArray, optional
        True over the data to keep, False over the data to ignore, by default None

    Returns
    -------
    xr.DataArray
        masked and averaged array
    '''
    if weights is None and mask is None:
        return xa.mean(dim=dim)

    if weights is None:
        # only a mask: uniform weights over the kept points.
        # Float values are needed: integer weights cannot store the normalized (fractional) values
        _weights = xr.where(mask, 1.0, 0.0)
    elif mask is None:
        _weights = weights
    else:
        # zero weight wherever the mask is False
        _weights = weights.where(mask, 0)

    # normalize the weights. This is done out of place, so the caller's `weights` is never modified
    # and the result can be promoted to float whatever the dtype of the weights
    _weights = _weights / _weights.sum(dim=dim)
    return (xa*_weights).sum(dim=dim)

### Load data

In [None]:
t2m = xr.open_dataarray('t2m_MJJA_fullres.nc')
t2m = standardize_dim_names(t2m)
t2m

In [None]:
# land-sea mask
lsm = xr.open_dataarray('land_sea_mask_fullres.nc')
lsm = standardize_dim_names(lsm)
lsm = lsm.sel(lat=t2m.lat, lon=t2m.lon)
lsm = (lsm > 0.5).astype(lsm.dtype) # make it binary
lsm

In [None]:
cell_area = xr.open_dataarray('cell_area_fullres.nc')
cell_area = standardize_dim_names(cell_area)
cell_area = cell_area.sel(lat=t2m.lat, lon=t2m.lon)
cell_area

### Compute mask for France

In [None]:
_mask = lsm > 0.5 # make it bool

# express the longitude in [-180,180), so that the box below can use negative (western) longitudes
newlon = _mask.lon.data % 360 # first make sure longitude is in [0,360)
newlon = newlon - 360*(newlon >= 180) # then put it in [-180,180)
# dims are passed explicitly instead of being inferred from the order of the keys of `coords`;
# the data is copied because the in-place products below would otherwise overwrite `_mask` as well
mask = xr.DataArray(_mask.data.copy(), dims=_mask.dims, coords={'lat':_mask.lat, 'lon':newlon})

mask *= (mask.lat < 52)*(mask.lat > 42)*(mask.lon > -5)*(mask.lon < 8.3) # identify the rough region
# cut the edges: each line removes the land on one side of it
mask *= ~is_above_line(mask, 1.65, 51, -4.5, 49.2)
mask *= is_above_line(mask, -1.86, 43.34, 3.4, 42.2)
mask *= ~is_above_line(mask, 2.26, 51.2, 8.27, 49)
mask *= is_above_line(mask, 8.1, 48.8, 6, 43)

# restore the original longitude
mask = xr.DataArray(mask.data, dims=mask.dims, coords={'lat':mask.lat, 'lon':_mask.lon})
mask

### Compute area integral

In [None]:
## this cell takes a while

# area integral
# masked_average normalizes the weights, so `ai` is the weighted mean of t2m over the masked
# points, computed along 'lat' and 'lon' (the other dimensions of t2m are kept)
ai = masked_average(t2m,
                    dim=['lat', 'lon'],
                    weights=lsm*cell_area, # land area weights
                    mask=mask # mask over France
                   )
ai

In [None]:
ai = ai.convert_calendar('noleap') # fix the calendar
ai.attrs = {'units': 'K', 'long_name': '2 metre temperature'}
ai.name = 't2m'
ai

In [None]:
ai.to_netcdf('t2m_MJJA_France.nc')

## Figure S1 + Temperature detrended file

### Compute anomaly

In [None]:
ai_full = xr.open_dataarray('t2m_MJJA_France.nc')

In [None]:
# compute climatology for each day of the year
# the years are read from `ai_full` itself, so that this section only needs the file loaded above
# (and not the full resolution `t2m` field)
y = ai_full.time.dt.year.data
years = y[-1] - y[0] + 1
_data = ai_full.data.reshape((years, -1)) # (years, days in the season)
clim = _data.mean(axis=0) # average over the years, for each day of the season

ano = (_data - clim).reshape((-1)) # back to a single time axis

ai = ai_full.copy(deep=True)
ai.data = ano

ai

### Compute quadratic trend

In [None]:
ai_seasonal = ai.groupby(ai.time.dt.year).mean()
ai_seasonal

In [None]:
y = ai.time.dt.year.data
years = y[-1] - y[0] + 1

b = ai.time.dt.dayofyear.data
days = b[-1] - b[0] + 1

years, days

In [None]:
# years are taken from the data rather than from a hard-coded start year
y = ai_seasonal.year.data
v = ai_seasonal.data

assert y.shape == v.shape
# the reshape to (years, days) used when removing the trend needs one entry for every year
assert (np.diff(y) == 1).all(), 'the years in the dataset are not contiguous'

# quadratic fit of the seasonal mean anomaly as a function of the year
p = np.poly1d(np.polyfit(y, v, 2))
p

In [None]:
v_trend = p(y) # quadratic trend evaluated at every year

plt.close(1) # discard the previous instance of figure 1 when the cell is re-run
fig, ax = plt.subplots(num=1, figsize=(9, 6))

ax.plot(y, v)
ax.plot(y, v_trend)
ax.set_ylabel('Seasonal $T_{2m}$ anomaly')

fig.tight_layout()

# plt.savefig(f'{HOME}/t2m_France.pdf')

### Remove trend

In [None]:
ANO_ai = ai.copy(deep=True)
_data = ANO_ai.data.reshape((years, days))
_data.shape, v_trend.shape

In [None]:
# Subtract the trend value of each year from every day of that year.
# The computation starts again from the anomaly `ai` (and not from `_data`),
# so re-running this cell cannot remove the trend twice
_data = ai.data.reshape((years, days)) - v_trend[:, np.newaxis]
ANO_ai.data = _data.reshape(-1)
ANO_ai

In [None]:
ANO_ai.to_netcdf('ANO_t2m_France.nc')

In [None]:
plt.figure()
ai.plot()
ANO_ai.plot()

## Figure S2 + Geopotential height detrended file

### Regrid to the PlaSim grid

In [None]:
## Copy cell area and land sea mask with PlaSim resolution
import shutil

# shutil.copy raises an error if a file is missing, while `os.system('cp ...')` fails silently
for fname in ('land_sea_mask.nc', 'cell_area.nc'):
    shutil.copy(os.path.join(HOME, fname), '.')

In [None]:
import xesmf as xe

In [None]:
zg_fullres = xr.open_dataarray('zg_MJJA_fullres.nc')
zg_fullres

In [None]:
plasim_lon = np.sort(np.load('../../lon.npy'))
plasim_lat = np.load('../../lat.npy')

plasim_lon, plasim_lat

In [None]:
zg_regrid= xr.Dataset(
    {
        "latitude": (["latitude"], plasim_lat),
        "longitude": (["longitude"], plasim_lon),
    }
)

regridder = xe.Regridder(zg_fullres,zg_regrid, "bilinear")
regridder

In [None]:
## This cell takes a while
zg_regrid = regridder(zg_fullres)
zg_regrid

In [None]:
# NOTE(review): the division below is in place, so running this cell twice divides by 9.81 twice;
# re-run the regridding cell before re-running this one
# NOTE(review): 9.81 is a rounded value of the standard gravity 9.80665 m s**-2 -- confirm the rounding is intended
zg_regrid /= 9.81 # geopotential to geopotential height
zg_regrid = zg_regrid.convert_calendar('noleap') # get rid of leap years: this makes dayofyear uniform across all years
zg_regrid.name = 'z'
zg_regrid.attrs = {
    'units': 'm',
    'long_name': 'Geopotential height'
}
zg_regrid

In [None]:
zg_regrid.to_netcdf('zg500.nc')

### Detrend

#### Compute anomalies

In [4]:
zg = xr.open_dataarray('zg500.nc')
zg

In [5]:
y = zg.time.dt.year.data
years = y[-1] - y[0] + 1

b = zg.time.dt.dayofyear.data
days = b[-1] - b[0] + 1

years, days

(83, 122)

In [None]:
zg_data = zg.data.reshape(years, days, *zg.data.shape[1:])
zg_data.shape

In [None]:
clim = zg_data.mean(axis=0) # mean on the year axis
# anomaly with respect to the climatology, flattened back to the original time axis.
# dims are passed explicitly instead of being inferred from the order of `zg.coords`
zg_ano = xr.DataArray((zg_data - clim).reshape(zg.data.shape), dims=zg.dims, coords=zg.coords, attrs=zg.attrs)
zg_ano.attrs['long_name'] = 'Geopotential height anomaly'

zg_ano

#### Compute zonal and seasonal mean

In [None]:
zg_zonal_mean = zg_ano.mean('longitude') # no weighting needed since cell area depends only on latitude
zg_zonal_mean

In [None]:
zg_seasonal = zg_zonal_mean.groupby(zg_zonal_mean.time.dt.year).mean()
zg_seasonal

#### Compute latitude dependent trends

In [None]:
# years are taken from the data rather than from a hard-coded start year
y = zg_seasonal.year.data
assert y.shape == zg_seasonal.data.shape[:1]
# the reshape to (years, days) used for the anomalies needs one entry for every year
assert (np.diff(y) == 1).all(), 'the years in the dataset are not contiguous'

# independent quadratic fit of the seasonal zonal mean at each latitude
v_trend = []
for i_lat in range(len(zg_seasonal.latitude)):
    v = zg_seasonal.data[:,i_lat] # time series over the years at this latitude
    p = np.poly1d(np.polyfit(y, v, 2))
    v_trend.append(p(y))

v_trend = np.stack(v_trend).T # (years, latitude)
v_trend.shape

In [None]:
i = -1

plt.close(1)
fig = plt.figure(1)

zg_seasonal.isel(latitude=i).plot()
plt.plot(y, v_trend[:,i])

fig.tight_layout()

#### Figure A.2

In [None]:
plt.close(2)
# num=2 matches the plt.close(2) above: without it every re-run of the cell creates a new figure
# and the old ones are never closed
fig,ax = plt.subplots(figsize=(9,6), num=2)

# trend as a function of the year (x axis) and of the latitude (y axis)
plt.contourf(y, zg_seasonal.latitude.data, v_trend.T)

plt.colorbar(label='trend [m]')
plt.ylabel('Latitude [degree N]')

fig.tight_layout()

# fig.savefig(f'{HOME}/zg_trend.pdf')

#### Remove the trend

In [None]:
v_trend.shape

In [None]:
# expand the time dimension: the trend value of each year is repeated for every day of that year
# (years, latitude) -> (years, days, latitude)
v_trend_broad = np.repeat(v_trend[:, np.newaxis, :], days, axis=1)

# check that the broadcasting worked correctly.
# Exact comparisons are used: a test like `np.std(...) == 0` is subject to rounding errors in the mean
assert v_trend_broad.shape == (years, days, len(zg.latitude.data))
assert (v_trend_broad == v_trend[:, np.newaxis, :]).all()

v_trend_broad.shape

In [None]:
# the trend does not depend on longitude: it takes the coordinates of a single longitude slice of the anomaly.
# dims are passed explicitly instead of being inferred from the order of the coordinates
zonal_template = zg_ano.isel(longitude=0,drop=True)
trend_xr = xr.DataArray(
    v_trend_broad.reshape(-1,v_trend_broad.shape[-1]), # merge (years, days) back into a single time axis
    dims=zonal_template.dims,
    coords=zonal_template.coords,
)
trend_xr

In [None]:
zg_detrended = zg_ano - trend_xr
zg_detrended.attrs = zg_ano.attrs
zg_detrended.name = 'z'
zg_detrended

In [None]:
zg_detrended.to_netcdf('ANO_zg500.nc')

In [None]:
plt.close(1)
fig = plt.figure(1)

zg_detrended.isel(time=-1).plot()

fig.tight_layout()

---

Old stuff (legacy exploration cells, kept for reference only; nothing above depends on them)

In [None]:
zg_np = zg.data.reshape([years,days,zg.shape[1],zg.shape[2]]) # split the time axis of length years*days into two axes (years, days)
zg_np.shape

In [None]:
zg_np_ano = (zg_np - zg_np.mean(axis=0)) #remove seasonal cycle
zg_np_ano.shape

In [None]:
zg_np_ano_season = zg_np_ano.mean(axis=1)
zg_np_ano_season.shape

In [None]:
zg_latitudinal = zg_np_ano_season.mean(axis=2)
zg_latitudinal.shape

In [None]:
# one column of `trends` per index along axis 1 of `zg`
# NOTE(review): `remove_trend` is not defined in the visible part of this notebook, so this cell
# presumably fails on a fresh kernel -- confirm where it was meant to come from
trends = np.zeros([zg_latitudinal.shape[0],zg_latitudinal.shape[1]])
for lat in range(zg.shape[1]):
    trends[:,lat] = remove_trend(zg_latitudinal[:,lat])

In [None]:
plt.figure()
c = plt.contourf(trends.T)
plt.xlabel('year')
plt.ylabel('latitude')
plt.title('ZG seasonal trend [m]',loc='left')
# plt.yticks( np.arange(0,360,40),np.arange(0,90,10),fontsize =15)
plt.yticks(np.arange(0,zg.shape[1]-1,40), np.arange(np.int64(zg.latitude.values[0]),np.int64(zg.latitude.values[-1]),10))
plt.xticks(np.arange(0,years,10),np.arange(y[0],y[-1],10),fontsize = 14, rotation = 340)
plt.colorbar(c)
plt.show()

In [None]:
np.arange(zg.latitude.values[0],zg.latitude.values[-1],10)

In [None]:
zg.latitude.values[0]

In [None]:
np.int64(zg.latitude.values[-1])

In [None]:
test = xr.open_dataarray('/ClimateDynamics/MediumSpace/ClimateLearningFR/vmascolo/Data_ERA5/ANO_zg_latitudinal_fullres_MJJA.nc')

In [None]:
test

In [None]:
test2 = xr.open_dataarray('/ClimateDynamics/MediumSpace/ClimateLearningFR/vmascolo/Data_ERA5/ANO_zg.nc')

In [None]:
test2