# Assimilate GLSD data with DIESEL over the 20th century

This notebook assimilates GLSD station temperature data using the DIESEL implementation of the ensemble Kalman filter. It compares sequential and all-at-once assimilation over the 20th century, processing one year (set by `year`) month by month.

In [1]:
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import dask
import pandas as pd
import dask.array as da
import xarray as xr
from climate.utils import load_dataset, match_vectors_indices


from dask.distributed import Client, wait, progress                             
import diesel as ds                                                             
from diesel.scoring import compute_RE_score, compute_CRPS, compute_energy_score 
from diesel.estimation import localize_covariance 

  from distributed.utils import tmpfile


In [2]:
# Input-data and output folders. The defaults are the UBELIX home-folder paths
# used for the original runs; set DIESEL_GLSD_BASE_FOLDER /
# DIESEL_GLSD_RESULTS_FOLDER to run elsewhere without editing the notebook.
# Both paths keep their trailing slash, as in the original defaults.
base_folder = os.environ.get(
    "DIESEL_GLSD_BASE_FOLDER",
    "/storage/homefs/ct19x463/Dev/Climate/Data/")
results_folder = os.environ.get(
    "DIESEL_GLSD_RESULTS_FOLDER",
    "/storage/homefs/ct19x463/Dev/DIESEL/reporting/paleoclimate/results/twentieth_century/")

## Build Cluster

In [3]:
# Start a DIESEL UBELIX cluster on the "gpu" partition with the "job_gpu" QOS.
# NOTE(review): mem_per_node units are defined by UbelixCluster (presumably GB).
cluster_config = dict(
    n_nodes=12,
    mem_per_node=64,
    cores_per_node=3,
    partition="gpu",
    qos="job_gpu",
)
cluster = ds.cluster.UbelixCluster(**cluster_config)
cluster.scale(18)  # Target size passed to scale (presumably the number of workers).
client = Client(cluster)

# Register the client as a builtin so there is a single global client.
# This must happen before the EnsembleKalmanFilter module is imported,
# so that the module is aware of the cluster.
setattr(__builtins__, 'CLIENT', client)

In [4]:
# Must run after the cluster cell has set __builtins__.CLIENT, so that the
# EnsembleKalmanFilter module is aware of the cluster.
from diesel.kalman_filtering import EnsembleKalmanFilter 
from dask.diagnostics import ProgressBar
# NOTE(review): dask.diagnostics.ProgressBar only reports computations run on
# dask's local schedulers; work submitted through the distributed `client`
# is tracked with distributed.progress instead.
ProgressBar().register()

In [5]:
# Display the cluster's status widget.
cluster

Tab(children=(HTML(value='<div class="jp-RenderedHTMLCommon jp-RenderedHTML jp-mod-trusted jp-OutputArea-outpu…

In [6]:
# Load the model and observation datasets for TOT_ENSEMBLES_NUMBER members.
# load_dataset returns a 5-tuple, which is unpacked into the handles used below.
# NOTE(review): ignore_members=True presumably skips eagerly loading the
# full member dataset — confirm in climate.utils.load_dataset.
TOT_ENSEMBLES_NUMBER = 30
loaded_datasets = load_dataset(base_folder, TOT_ENSEMBLES_NUMBER, ignore_members=True)
(dataset_mean, dataset_members, dataset_instrumental,
 dataset_reference, dataset_members_zarr) = loaded_datasets
print("Loading done.")

  sample = dates.ravel()[0]
  dataset_mean['time'] = dataset_mean.indexes['time'].to_datetimeindex()
  dataset_members['time'] = dataset_members.indexes['time'].to_datetimeindex()


Loading done.


In [7]:
# Helper wrapper around the mean / members / instrumental datasets. It is used
# below for its `dataset_mean` / `dataset_members` vector accessors.
from climate.kalman_filter import EnsembleKalmanFilterScatter

helper_filter = EnsembleKalmanFilterScatter(
    dataset_mean,
    dataset_members_zarr,
    dataset_instrumental,
    client,
)

Maximal distance to matched point: 120.54565778878536 km.


In [8]:
# Settings shared by the all-at-once and sequential assimilation loops.
year = 1901       # Year whose twelve months are assimilated below.
data_std = 0.1    # Observation-error standard deviation passed to the filter updates.
my_filter = EnsembleKalmanFilter()

In [None]:
# All at once: for each month of `year`, assimilate every GLSD observation of
# that month in a single ensemble Kalman update, using an empirical ensemble
# covariance localized by a squared-exponential taper. The updated mean and
# ensemble are saved to `results_folder` as .npy files named after the
# observation date.
months = ['01', '02', '03', '04', '05', '06', '07', '08', '09', '10', '11', '12']
lambda0 = 1500  # Localization length scale, in kilometers.

# --- Loop-invariant work: done once per year instead of once per month. ---

# Load the station observations for the whole year (one CSV file per year).
data_df = pd.read_csv(os.path.join(base_folder, "Instrumental/GLSD/yearly_csv/temperature_{}.csv".format(year)), index_col=0)
data_ds = xr.Dataset.from_dataframe(data_df)
# Rename the date variable and make time/latitude/longitude into coordinates.
data_ds = data_ds.rename({'date': 'time'})
data_ds = data_ds.set_coords(['time', 'latitude', 'longitude'])
data_ds = data_ds['temperature']

# Localization matrix. It depends only on the model grid.
# NOTE(review): assumes the grid is identical for every month of the year —
# it is built once from January's state vector.
lengthscales = da.from_array([lambda0])
kernel = ds.covariance.squared_exponential(lengthscales)
grid_date = '{}-01-16'.format(year)
grid_mean_ds = helper_filter.dataset_mean.get_window_vector(grid_date, grid_date, variable='temperature')
grid_pts = da.vstack([grid_mean_ds.latitude, grid_mean_ds.longitude]).T
grid_pts = client.persist(grid_pts.rechunk((1800, 2)))
localization_matrix = client.persist(
    kernel.covariance_matrix(grid_pts, grid_pts, metric='haversine'))
progress(localization_matrix)
wait(localization_matrix)

for month in months:
    # Model state (ensemble mean and members) on the 16th of the month.
    assimilation_date = '{}-{}-16'.format(year, month)
    mean_ds = helper_filter.dataset_mean.get_window_vector(assimilation_date, assimilation_date, variable='temperature')
    ensemble_ds = helper_filter.dataset_members.get_window_vector(assimilation_date, assimilation_date, variable='temperature')
    mean_ds, ensemble_ds = client.persist(mean_ds), client.persist(ensemble_ds)

    # Observations of this month (stamped on the 1st in the GLSD files).
    date = '{}-{}-01'.format(year, month)
    data_month_ds = data_ds.where(data_ds.time == date, drop=True)

    # The dataset contains erroneous measurements: extreme values (10^30) and
    # values that are exactly zero for a given station across time. Drop them.
    data_month_ds = data_month_ds.where(
        (data_month_ds > -100.0) & (data_month_ds < 100.0) & (abs(data_month_ds) > 0.0001),
        drop=True)
    data_vector = client.persist(da.from_array(data_month_ds.data))

    # Forward operator G: row i selects the model cell matched to observation i.
    matched_inds = match_vectors_indices(mean_ds, data_month_ds)
    # Built in NumPy on purpose: a Python loop over a dask array would exceed
    # the maximal graph depth.
    G = np.zeros((data_month_ds.shape[0], mean_ds.shape[0]))
    for obs_nr, model_cell_ind in enumerate(matched_inds):
        G[obs_nr, model_cell_ind] = 1.0
    G = client.persist(da.from_array(G))

    # Empirical ensemble covariance, persisted on the cluster, then localized.
    raw_estimated_cov = client.persist(
        ds.estimation.empirical_covariance(ensemble_ds.chunk((1, 1800))))
    progress(raw_estimated_cov)
    loc_estimated_cov = client.persist(localize_covariance(raw_estimated_cov, localization_matrix))
    progress(loc_estimated_cov)

    # Assimilate all observations of the month at once.
    mean_updated_aao, ensemble_updated_aao = my_filter.update_ensemble(
        mean_ds.data, ensemble_ds.data, G,
        data_vector, data_std, loc_estimated_cov)

    # Trigger the computation and block until it finishes; otherwise the
    # months pile up on the scheduler. In a notebook, `progress` only displays
    # a widget and returns immediately, so `wait` is what actually blocks.
    mean_updated_aao = client.persist(mean_updated_aao)
    ensemble_updated_aao = client.persist(ensemble_updated_aao)
    progress(ensemble_updated_aao)
    wait([mean_updated_aao, ensemble_updated_aao])

    # Save data.
    np.save(os.path.join(results_folder, "mean_updated_aao_{}.npy".format(date)),
        mean_updated_aao.compute())
    np.save(os.path.join(results_folder, "ensemble_updated_aao_{}.npy".format(date)),
        ensemble_updated_aao.compute())

In [None]:
# Localization matrix for the sequential filter. Only the model grid
# coordinates are needed, so the state vector is read at an arbitrary date.
mean_dummy = helper_filter.dataset_mean.get_window_vector(
    '1961-01-16', '1961-01-16', variable='temperature')

lambda0 = 1500  # Localization length scale, in kilometers.
lengthscales = da.from_array([lambda0])
kernel = ds.covariance.squared_exponential(lengthscales)

# (latitude, longitude) pairs of the grid cells, two columns, persisted on the cluster.
grid_pts = da.vstack([mean_dummy.latitude, mean_dummy.longitude]).T
grid_pts = client.persist(grid_pts.rechunk((1800, 2)))

localization_matrix = client.persist(
    kernel.covariance_matrix(grid_pts, grid_pts, metric='haversine'))
progress(localization_matrix)

In [10]:
# Sequential: for each month of `year`, assimilate the GLSD observations with
# `update_ensemble_sequential_nondask`. Per its name, it presumably processes
# the observations one at a time. The localization matrix is built in advance,
# and the updated mean and ensemble are saved to `results_folder`.
# NOTE: relies on `localization_matrix` from the localization cell above.
months = ['01', '02', '03', '04', '05', '06', '07', '08', '09', '10', '11', '12']

# Load the station observations for the whole year once (loop-invariant).
data_df = pd.read_csv(os.path.join(base_folder, "Instrumental/GLSD/yearly_csv/temperature_{}.csv".format(year)), index_col=0)
data_ds = xr.Dataset.from_dataframe(data_df)
# Rename the date variable and make time/latitude/longitude into coordinates.
data_ds = data_ds.rename({'date': 'time'})
data_ds = data_ds.set_coords(['time', 'latitude', 'longitude'])
data_ds = data_ds['temperature']

for month in months:
    # Model state (ensemble mean and members) on the 16th of the month.
    assimilation_date = '{}-{}-16'.format(year, month)
    mean_ds = helper_filter.dataset_mean.get_window_vector(assimilation_date, assimilation_date, variable='temperature')
    ensemble_ds = helper_filter.dataset_members.get_window_vector(assimilation_date, assimilation_date, variable='temperature')
    mean_ds, ensemble_ds = client.persist(mean_ds), client.persist(ensemble_ds)

    # Observations of this month (stamped on the 1st in the GLSD files).
    date = '{}-{}-01'.format(year, month)
    data_month_ds = data_ds.where(data_ds.time == date, drop=True)

    # The dataset contains erroneous measurements: extreme values (10^30) and
    # values that are exactly zero for a given station across time. Drop them.
    data_month_ds = data_month_ds.where(
        (data_month_ds > -100.0) & (data_month_ds < 100.0) & (abs(data_month_ds) > 0.0001),
        drop=True)
    data_vector = client.persist(da.from_array(data_month_ds.data))

    # Forward operator G: row i selects the model cell matched to observation i.
    matched_inds = match_vectors_indices(mean_ds, data_month_ds)
    # Built in NumPy on purpose: a Python loop over a dask array would exceed
    # the maximal graph depth.
    G = np.zeros((data_month_ds.shape[0], mean_ds.shape[0]))
    for obs_nr, model_cell_ind in enumerate(matched_inds):
        G[obs_nr, model_cell_ind] = 1.0
    G = client.persist(da.from_array(G))

    # Assimilate the month's observations sequentially.
    mean_updated_seq, ensemble_updated_seq = my_filter.update_ensemble_sequential_nondask(
        mean_ds.data, ensemble_ds.data, G,
        data_vector, data_std, localization_matrix)

    # Save data.
    np.save(os.path.join(results_folder, "mean_updated_seq_{}.npy".format(date)),
        mean_updated_seq)
    np.save(os.path.join(results_folder, "ensemble_updated_seq_{}.npy".format(date)),
        ensemble_updated_seq)

  return blockwise(*args, **kwargs)


Maximal distance to matched point: 113.08002097917435 km.
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
Maximal distance to matched point: 113.08002097917435 km.
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
Maximal distance to matched point: 113.08002097917435 km.
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
Maximal distance to matched point: 113.08002097917435 km.
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
Maximal distance to matched point: 113.08002097917435 km.
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
Maximal distance to matched point: 113.08002097917435 km.
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
Maximal distance to matched point: 113.08002097917435 km.
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
Maximal distance to matched point: 113.08002097917435 km.
0
1
2
3
4
5
6
7
8
9
10
1

In [11]:
# Display the year currently set for assimilation.
# NOTE(review): the saved output (1816) does not match the value set in the
# configuration cell (1901), so it comes from an earlier kernel state.
year

1816

# Run Assimilation: All-at-once (aao) vs sequential (seq).

## Compare the different updates.

In [None]:
# Plotting setup: inline figures, large fonts, and cartopy projections for maps.
%matplotlib inline 
import matplotlib.pyplot as plt
plt.rcParams.update({'font.size': 22})
plt.rcParams['figure.dpi'] = 100
import cartopy.crs as ccrs
# NOTE(review): `geometry` is not used in any of the visible cells.
from shapely import geometry

def plot(unstacked_data, ax=None, outfile=None, vmin=None, vmax=None,
         extent=(-25, 30, 30, 75)):
    """Draw a filled-contour map of a gridded field with coastlines.

    Parameters
    ----------
    unstacked_data : object with a ``.plot.contourf`` accessor
        Field to plot, presumably a 2-D latitude/longitude xarray DataArray.
    ax : cartopy GeoAxes, optional
        Axes to draw on. If None, a new figure with a PlateCarree GeoAxes is
        created, so the function can also be called as ``plot(field, vmin=.., vmax=..)``.
    outfile : str, optional
        If given, the figure that contains ``ax`` is saved to this path.
    vmin, vmax : float, optional
        Limits of the colour scale.
    extent : sequence of 4 floats or None, optional
        ``(lon_min, lon_max, lat_min, lat_max)`` in PlateCarree degrees.
        Defaults to a Europe-centred window; None keeps cartopy's default extent.

    Returns
    -------
    The axes that were drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(subplot_kw={'projection': ccrs.PlateCarree()})

    unstacked_data.plot.contourf(
        levels=30, ax=ax, transform=ccrs.PlateCarree(),
        vmin=vmin, vmax=vmax, cmap='RdBu_r',
        add_colorbar=False, add_labels=False,
        extend='both',
    )
    if extent is not None:
        ax.set_extent(list(extent), crs=ccrs.PlateCarree())
    ax.coastlines()
    ax.set_title('')
    ax.set_ylabel('')
    if outfile is not None:
        # Save the figure owning `ax`, not whatever pyplot considers current.
        ax.figure.savefig(outfile, bbox_inches='tight', dpi=120)
    return ax

In [None]:
# Side-by-side comparison, for the first six months of `comparison_year`:
# the updated mean from all-at-once assimilation (left), from sequential
# assimilation (middle), and the reference field (right). Results are read back
# from the .npy files written by the assimilation loops. Fresh names are used,
# so the in-memory loop results (mean_updated_aao / mean_updated_seq) are not
# overwritten.
comparison_year = 1816
comparison_months = ['01', '02', '03', '04', '05', '06']
cm = 1/2.54  # centimeters in inches
fig, axs = plt.subplots(len(comparison_months), 3, figsize=(60*cm, 50*cm),
                        subplot_kw={'projection': ccrs.PlateCarree()})

for i, month in enumerate(comparison_months):
    model_time = '{}-{}-16'.format(comparison_year, month)  # Model/reference time stamp.
    obs_date = '{}-{}-01'.format(comparison_year, month)    # Date used in the saved file names.
    mean_aao_month = np.load(os.path.join(results_folder, 'mean_updated_aao_{}.npy'.format(obs_date)))
    mean_seq_month = np.load(os.path.join(results_folder, 'mean_updated_seq_{}.npy'.format(obs_date)))

    fields = (
        helper_filter.dataset_mean.unstack_window_vector(mean_aao_month, time=model_time, variable_name='temperature'),
        helper_filter.dataset_mean.unstack_window_vector(mean_seq_month, time=model_time, variable_name='temperature'),
        dataset_reference.temperature.sel(time=model_time),
    )
    for j, field in enumerate(fields):
        plot(field, axs[i, j], vmin=-20, vmax=30)

# `plot` clears the titles, so label the columns afterwards.
for ax, title in zip(axs[0], ['All at once', 'Sequential', 'Reference']):
    ax.set_title(title)

In [None]:
# Updated ensemble mean (all at once), mapped back onto the model grid.
# np.asarray works whether mean_updated_aao is still a dask array (from the
# assimilation loop) or a NumPy array (reloaded from disk).
# NOTE(review): `time` should match the date the vector was assimilated for — confirm.
unstacked_updated_mean_aao = helper_filter.dataset_mean.unstack_window_vector(
    np.asarray(mean_updated_aao), time='1816-01-16', variable_name='temperature')
fig, ax = plt.subplots(subplot_kw={'projection': ccrs.PlateCarree()})
plot(unstacked_updated_mean_aao, ax, vmin=-40, vmax=40)

In [None]:
# Row 0 of the all-at-once updated ensemble (presumably the first member),
# mapped back onto the model grid. np.asarray materializes it for dask or NumPy input.
# NOTE(review): `time` should match the date the vector was assimilated for — confirm.
unstacked_updated_ensemble_0_aao = helper_filter.dataset_mean.unstack_window_vector(
    np.asarray(ensemble_updated_aao[0, :]), time='1961-01-16', variable_name='temperature')
fig, ax = plt.subplots(subplot_kw={'projection': ccrs.PlateCarree()})
plot(unstacked_updated_ensemble_0_aao, ax, vmin=-40, vmax=40)

In [None]:
# Updated ensemble mean (sequential), mapped back onto the model grid.
# NOTE(review): `time` should match the date the vector was assimilated for — confirm.
unstacked_updated_mean_seq = helper_filter.dataset_mean.unstack_window_vector(
    mean_updated_seq, time='1961-01-16', variable_name='temperature')
fig, ax = plt.subplots(subplot_kw={'projection': ccrs.PlateCarree()})
plot(unstacked_updated_mean_seq, ax, vmin=-40, vmax=40)

In [None]:
# Plot the difference between the all-at-once and sequential updated means.
fig, ax = plt.subplots(subplot_kw={'projection': ccrs.PlateCarree()})
plot(unstacked_updated_mean_aao - unstacked_updated_mean_seq, ax, vmin=-7, vmax=7)

In [None]:
# Plot the original ensemble mean (before updating).
# NOTE(review): `time` should match the date mean_ds was extracted for — confirm.
unstacked_mean = helper_filter.dataset_mean.unstack_window_vector(
    mean_ds.values.reshape(-1), time='1961-01-16', variable_name='temperature')
fig, ax = plt.subplots(subplot_kw={'projection': ccrs.PlateCarree()})
plot(unstacked_mean.temperature, ax, vmin=-40, vmax=40)

In [None]:
# Scatter the station observations of the most recently processed month on a
# global Mollweide map, coloured by their value.
df = data_month_ds.to_dataframe()

cm = 1/2.54  # centimeters in inches
fig, ax = plt.subplots(figsize=(40*cm, 25*cm),
                       subplot_kw={'projection': ccrs.Mollweide()})
ax.set_global()
ax.coastlines()

df.plot.scatter('longitude', 'latitude', c=data_month_ds.name, cmap='jet', ax=ax, transform=ccrs.PlateCarree())

In [None]:
# Plot the error of the all-at-once updated mean with respect to the reference.
# NOTE(review): the reference date should match the assimilated date — confirm.
fig, ax = plt.subplots(subplot_kw={'projection': ccrs.PlateCarree()})
plot(unstacked_updated_mean_aao - dataset_reference.temperature.sel(time='1961-01-16'), ax, vmin=-7, vmax=7)

In [None]:
# Plot the error of the sequential updated mean with respect to the reference.
# NOTE(review): the reference date should match the assimilated date — confirm.
fig, ax = plt.subplots(subplot_kw={'projection': ccrs.PlateCarree()})
plot(unstacked_updated_mean_seq - dataset_reference.temperature.sel(time='1961-01-16'), ax, vmin=-7, vmax=7)

In [None]:
# Plot the error of the original (non-updated) mean with respect to the reference.
# NOTE(review): the reference date should match the date of mean_ds — confirm.
fig, ax = plt.subplots(subplot_kw={'projection': ccrs.PlateCarree()})
plot(unstacked_mean.temperature - dataset_reference.temperature.sel(time='1961-01-16'), ax, vmin=-7, vmax=7)

In [None]:
# Inspect the time coordinate of the members dataset held by helper_filter.
# NOTE(review): the doubled `.dataset_members` access presumably goes from the
# wrapper object to the underlying dataset — confirm.
helper_filter.dataset_members.dataset_members.time.values

In [None]:
# Difference between two reference months: December 1816 minus June 1900.
reference_temperature = dataset_reference.temperature
december_1816 = reference_temperature.sel(time='1816-12-16')
june_1900 = reference_temperature.sel(time='1900-06-16')
(december_1816 - june_1900).plot()