# Assimilate GLSD data with DIESEL

This notebook runs assimilation of GLSD data using the DIESEL version of the Ensemble Kalman filter. 

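It also compares two assimilation schemes for the ensemble mean: all-at-once (aao), where all observations are assimilated in a single update, and sequential (seq). An earlier variant that assimilates cell-averaged observations is kept below but is deprecated.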

In [None]:
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import dask
import pandas as pd
import dask.array as da
import xarray as xr
from climate.utils import load_dataset

from dask.distributed import Client, wait, progress                             
import diesel as ds                                                             
from diesel.scoring import compute_RE_score, compute_CRPS, compute_energy_score 
from diesel.estimation import localize_covariance 

In [None]:
# Input data and output locations. Both can be overridden with environment variables so
# the notebook can run outside the original cluster home directory; the defaults are the
# original hard-coded paths.
base_folder = os.environ.get(
    "CLIMATE_BASE_FOLDER", "/storage/homefs/ct19x463/Dev/Climate/Data/")
results_folder = os.environ.get(
    "CLIMATE_RESULTS_FOLDER",
    "/storage/homefs/ct19x463/Dev/Climate/reporting/all_at_once_vs_sequential")

## Build Cluster

In [None]:
import builtins

# Start the Ubelix Dask cluster and connect a client to it.
# NOTE(review): n_nodes=12 but scale(9) — confirm the intended number of workers.
cluster = ds.cluster.UbelixCluster(n_nodes=12, mem_per_node=64, cores_per_node=3,
                                   partition="gpu", qos="job_gpu")
cluster.scale(9)
client = Client(cluster)

# Add the client to builtins so there is one global client.
# This is necessary before importing the EnsembleKalmanFilter module, so that the module
# is aware of the cluster.
# Go through the `builtins` module rather than `__builtins__`: the latter is a CPython
# implementation detail (a module in __main__, but a plain dict inside imported modules).
builtins.CLIENT = client

In [None]:
# Import the filter only after CLIENT has been published in builtins (cluster cell above):
# the module needs to be aware of the global Dask client when it is imported.
from diesel.kalman_filtering import EnsembleKalmanFilter 
from dask.diagnostics import ProgressBar
# NOTE(review): dask.diagnostics.ProgressBar targets dask's local schedulers; work submitted
# through the distributed `client` is probably not reported by it — confirm it is still needed.
ProgressBar().register()

In [None]:
# Show the cluster object as the cell output to inspect its current state.
cluster

In [None]:
TOT_ENSEMBLES_NUMBER = 30

# Load the ensemble mean, the ensemble members, the instrumental data, the reference field
# and the zarr-backed members (ignore_members=True is passed through to load_dataset).
(
    dataset_mean,
    dataset_members,
    dataset_instrumental,
    dataset_reference,
    dataset_members_zarr,
) = load_dataset(base_folder, TOT_ENSEMBLES_NUMBER, ignore_members=True)
print("Loading done.")

In [None]:
# Helper filter: in this notebook it gives access to dataset_mean / dataset_members and
# their window-vector (stack/unstack) utilities.
from climate.kalman_filter import EnsembleKalmanFilterScatter

helper_filter = EnsembleKalmanFilterScatter(
    dataset_mean,
    dataset_members_zarr,
    dataset_instrumental,
    client,
)

## Prepare vectors.

In [None]:
assimilation_date = '1990-10-16'

# Temperature of the assimilation month as a stacked vector, first for the ensemble mean,
# then for the ensemble members (window starts and ends on the same date).
mean_ds, ensemble_ds = (
    source.get_window_vector(assimilation_date, assimilation_date, variable='temperature')
    for source in (helper_filter.dataset_mean, helper_filter.dataset_members)
)

In [None]:
# Keep both vectors in cluster memory so the following cells reuse them.
mean_ds = client.persist(mean_ds)
ensemble_ds = client.persist(ensemble_ds)

## Load Data

In [None]:
# GLSD data is stored as one CSV per year: load the year of the assimilation date.
year = int(assimilation_date[:4])
data_df = pd.read_csv(
    os.path.join(base_folder, "Instrumental/GLSD/yearly_csv/temperature_{}.csv".format(year)),
    index_col=0,
)

# Convert to xarray, rename the date variable to `time`, turn time/latitude/longitude into
# coordinates and keep only the temperature values.
data_ds = (
    xr.Dataset.from_dataframe(data_df)
    .rename({'date': 'time'})
    .set_coords(['time', 'latitude', 'longitude'])['temperature']
)

## Prepare Forward Operator

In [None]:
# Select one month.
# GLSD dates its monthly values on the 1st of the month (01 instead of 16), so build the
# GLSD-style date string before selecting.
assimilation_date_datasel = assimilation_date[:-2] + '01'
data_month_ds = data_ds.where(data_ds.time == assimilation_date_datasel, drop=True)

# The dataset contains erroneous measurements: extreme values (~10^30) and values that are
# (essentially) zero. Keep only plausible values.
# data_month_ds is an xarray DataArray, not a dask array, so use the builtin abs()
# (DataArray.__abs__) rather than dask.array.abs.
MAX_ABS_VALUE = 100.0   # values outside (-100, 100) are treated as erroneous
MIN_ABS_VALUE = 1e-4    # values this close to zero are treated as erroneous
data_month_ds = data_month_ds.where(
    (data_month_ds > -MAX_ABS_VALUE)
    & (data_month_ds < MAX_ABS_VALUE)
    & (abs(data_month_ds) > MIN_ABS_VALUE),
    drop=True)

In [None]:
# Get the model cell index corresponding to each observation.
from climate.utils import match_vectors_indices
matched_inds = np.asarray(match_vectors_indices(mean_ds, data_month_ds))

# Forward operator G (n_obs x n_cells): row i holds a single 1 in the column of the model
# cell matched to observation i, so G @ x picks the model value at each observation.
# Build it eagerly in NumPy and wrap it in dask only afterwards: element-wise loops on a
# dask array exceed the maximal graph depth.
n_obs, n_cells = data_month_ds.shape[0], mean_ds.shape[0]
assert matched_inds.shape == (n_obs,), "expected exactly one model cell index per observation"
G = np.zeros((n_obs, n_cells))
G[np.arange(n_obs), matched_inds] = 1.0

G = da.from_array(G)
G = client.persist(G)

## (Deprecated) Stabilize the filter by assimilating one averaged observation for model cells that contain multiple observations.

The idea here is that assimilating many observations (in this case around 50) in a single model cell can lead to numerical instabilities. 
We therefore work with one observation per cell, namely the mean of all observations in that cell. 
The idea was eventually abandoned, since the instability only affects the update of the ensemble members; the update of the mean works fine either way.

In [None]:
# Number of observations in each model cell. Cells without observations get a count of 1
# instead of 0 so the normalisation below never divides by zero (their column of G is all
# zeros anyway). G only holds 0/1 entries, so the counts are whole numbers and
# max(count, 1) is exactly "replace 0 by 1".
# da.maximum avoids in-place masked assignment on a dask array: the cell stays idempotent
# and also works on older dask versions without masked item assignment.
obs_per_cell = da.maximum(da.sum(G, axis=0), 1)
# Divide each column of G by its count, so that G_norm.T @ d averages the observations per cell.
G_norm = G / obs_per_cell
G_norm = client.persist(G_norm)

In [None]:
# The operator G_avg contains one single observation per model cell (or zero), which is the
# average of all observations belonging to that cell.
averaged_data = (G_norm.T @ data_month_ds.values).T.compute()

# Select the observed cells from the observation counts, not from the averaged values:
# a cell whose observations average to exactly 0.0 would otherwise be dropped.
observed_cells = np.flatnonzero(G.sum(axis=0).compute())
G_avg = da.eye(averaged_data.shape[0])[observed_cells, :]
d_avg = averaged_data[observed_cells]

## Estimate Covariance 

In [None]:
 # Estimate the covariance as the empirical covariance of the ensemble members,
# after rechunking the ensemble to (1, 1800).
raw_estimated_cov_lazy = ds.estimation.empirical_covariance(
    ensemble_ds.chunk((1, 1800)))

# Persist the result on the cluster and display its progress.
raw_estimated_cov = client.persist(raw_estimated_cov_lazy)
progress(raw_estimated_cov)

In [None]:
# Squared-exponential kernel used to build the (lazy) localization matrix.
lambda0 = 1500  # Localization length scale, in kilometers.
lengthscales = da.from_array([lambda0])
kernel = ds.covariance.squared_exponential(lengthscales)

In [None]:
# Covariance localization: evaluate the kernel between the model grid points using the
# haversine distance on (latitude, longitude) pairs.
grid_pts = client.persist(
    da.vstack([mean_ds.latitude, mean_ds.longitude]).T.rechunk((1800, 2)))
localization_matrix = client.persist(
    kernel.covariance_matrix(grid_pts, grid_pts, metric='haversine'))
progress(localization_matrix)

In [None]:
# Localize the empirical ensemble covariance with the localization matrix and persist it.
# TODO(review): an earlier note said multiplicative inflation was added here, but no
# inflation is visible in this cell — confirm whether localize_covariance applies any.
loc_estimated_cov = localize_covariance(raw_estimated_cov, localization_matrix)
loc_estimated_cov = client.persist(loc_estimated_cov)
progress(loc_estimated_cov)

# Run Assimilation: All-at-once (aao) vs sequential (seq).

In [None]:
 # Ensemble Kalman filter used for both assimilation schemes below.
my_filter = EnsembleKalmanFilter()

# Standard deviation passed to the filter for the data (presumably the observation error),
# and the observations as a persisted dask array.
data_std = 3.0
data_vector = client.persist(da.from_array(data_month_ds.data))

In [None]:
# Assimilate all data in a single update (all-at-once, "aao"); only the updated mean is kept.
mean_updated_aao, _ = my_filter.update_ensemble(
    mean_ds.data, ensemble_ds.data, G,
    data_vector, data_std, loc_estimated_cov)

# Trigger computations and block until they finish. Otherwise they will clutter the
# scheduler while the next cells submit work. persist() only submits the graph;
# wait() is what actually blocks. The trailing `;` suppresses the wait() result display.
mean_updated_aao = client.persist(mean_updated_aao)
wait(mean_updated_aao);

In [None]:
# Sequential version (non-dask implementation) of the update of the mean, using the same
# operator, data, data std and localized covariance as the all-at-once update.
mean_updated_seq = my_filter.update_mean_sequential_nondask(
    mean_ds.data,
    G,
    data_vector,
    data_std,
    loc_estimated_cov,
)

## Compare the different updates.

In [None]:
# Basic plotting functions.
%matplotlib inline 
import matplotlib.pyplot as plt
plt.rcParams.update({'font.size': 22})
plt.rcParams['figure.dpi'] = 100
import cartopy.crs as ccrs
from shapely import geometry

def plot(unstacked_data, outfile=None, vmin=None, vmax=None,
         cbar_ticks=None, cbar_label='temperature'):
    """Draw a filled-contour map of a field on a global Mollweide projection.

    Parameters
    ----------
    unstacked_data : xarray.DataArray
        Field to plot with ``.plot.contourf``; its coordinates are interpreted in the
        PlateCarree (longitude/latitude) projection.
    outfile : str, optional
        If given, the figure is also saved to this path.
    vmin, vmax : float, optional
        Colour-scale limits passed to ``contourf``.
    cbar_ticks : sequence of float, optional
        Colour-bar ticks. By default the ticks -30, -20, ..., 30 are used, restricted to
        [vmin, vmax]; if fewer than two of them fall in that range (e.g. difference plots
        with vmin=-7, vmax=7) matplotlib chooses the ticks instead of showing only 0.
    cbar_label : str, optional
        Colour-bar label, ``'temperature'`` by default.
    """
    cm = 1 / 2.54  # centimeters in inches
    fig = plt.figure(figsize=(40 * cm, 25 * cm))
    ax = fig.add_subplot(1, 1, 1, projection=ccrs.Mollweide())
    ax.set_global()

    if cbar_ticks is None:
        lower = float('-inf') if vmin is None else vmin
        upper = float('inf') if vmax is None else vmax
        cbar_ticks = [t for t in (-30, -20, -10, 0, 10, 20, 30) if lower <= t <= upper]
        if len(cbar_ticks) < 2:
            cbar_ticks = None  # let matplotlib pick ticks inside the colour range

    cbar_kwargs = {'label': cbar_label}
    if cbar_ticks is not None:
        cbar_kwargs['ticks'] = list(cbar_ticks)

    unstacked_data.plot.contourf(levels=30, ax=ax, transform=ccrs.PlateCarree(),
                                 vmin=vmin, vmax=vmax, cmap='RdBu_r',
                                 cbar_kwargs=cbar_kwargs)
    ax.coastlines()
    if outfile is not None:
        fig.savefig(outfile, bbox_inches='tight', dpi=120)

In [None]:
# Bring the all-at-once updated mean back onto the latitude/longitude grid and plot it.
unstacked_updated_mean_aao = helper_filter.dataset_mean.unstack_window_vector(
    mean_updated_aao.compute(),
    time=assimilation_date,
    variable_name='temperature',
)
plot(unstacked_updated_mean_aao, vmin=-40, vmax=40)

In [None]:
# Bring the sequentially updated mean back onto the latitude/longitude grid and plot it.
unstacked_updated_mean_seq = helper_filter.dataset_mean.unstack_window_vector(
    mean_updated_seq,
    time=assimilation_date,
    variable_name='temperature',
)
plot(unstacked_updated_mean_seq, vmin=-40, vmax=40)

In [None]:
# Plot difference between the all-at-once and the sequential updated means.
plot(unstacked_updated_mean_aao - unstacked_updated_mean_seq, vmin=-7, vmax=7)

In [None]:
# Plot the prior ensemble mean (before updating), flattened to a vector and unstacked
# back onto the latitude/longitude grid.
unstacked_mean = helper_filter.dataset_mean.unstack_window_vector(mean_ds.values.reshape(-1), time=assimilation_date, variable_name='temperature')
plot(unstacked_mean, vmin=-40, vmax=40)

In [None]:
# Plot the station observations of the assimilation month as coloured dots on a
# global Mollweide map (coordinates given in PlateCarree).
df = data_month_ds.to_dataframe()

cm = 1 / 2.54  # centimeters in inches
fig = plt.figure(figsize=(40 * cm, 25 * cm))
ax = fig.add_subplot(projection=ccrs.Mollweide())
ax.set_global()
ax.coastlines()

df.plot.scatter(x='longitude', y='latitude', c=data_month_ds.name, cmap='jet',
                ax=ax, transform=ccrs.PlateCarree())

In [None]:
# Plot error wrt reference: all-at-once updated mean minus the reference field.
plot(unstacked_updated_mean_aao - dataset_reference.temperature.sel(time=assimilation_date), vmin=-7, vmax=7)

In [None]:
# Plot error wrt reference: sequentially updated mean minus the reference field.
plot(unstacked_updated_mean_seq - dataset_reference.temperature.sel(time=assimilation_date), vmin=-7, vmax=7)

In [None]:
# Plot original error: prior ensemble mean (before updating) minus the reference field.
plot(unstacked_mean - dataset_reference.temperature.sel(time=assimilation_date), vmin=-7, vmax=7)

## Compute accuracy metrics.

In [None]:
from diesel.scoring import compute_RE_score, compute_CRPS, compute_energy_score, compute_RMSE

# RMSE of the prior ensemble mean against the reference field.
# The stacked reference is built here so that this cell also runs on a fresh kernel
# (it previously read `stacked_ref` before the cell below defined it).
ref = dataset_reference.temperature.sel(time=assimilation_date)
stacked_ref = ref.stack(stacked_dim=('latitude', 'longitude'))

# Flatten the prior mean to a 1-D vector, as is done before unstacking it for plotting.
# NOTE(review): assumes mean_ds uses the same (latitude, longitude) ordering as stacked_ref;
# the next cell computes the same score from unstacked_mean without relying on this.
compute_RMSE(mean_ds.values.reshape(-1), stacked_ref, min_lat=-70, max_lat=70)

In [None]:
# Reference field of the assimilation month, stacked over (latitude, longitude).
ref = dataset_reference.temperature.sel(time=assimilation_date)
stacked_ref = ref.stack(stacked_dim=('latitude', 'longitude'))

# Stack the prior mean and both updated means the same way.
stacked_prior_mean, stacked_updated_mean_seq, stacked_updated_mean_aao = (
    field.stack(stacked_dim=('latitude', 'longitude'))
    for field in (unstacked_mean, unstacked_updated_mean_seq, unstacked_updated_mean_aao)
)

# Print the RMSE against the reference (min_lat=-70, max_lat=70) for prior, sequential
# and all-at-once, in that order.
for stacked_field in (stacked_prior_mean, stacked_updated_mean_seq, stacked_updated_mean_aao):
    print(compute_RMSE(stacked_field.values, stacked_ref, min_lat=-70, max_lat=70))

## Save results (updated temperature fields).

In [None]:
# Accumulators for the fields collected below; re-running this cell empties them.
updated_means_aao, updated_means_seq, prior_means = [], [], []

In [None]:
# Append copies of the current fields so later runs do not alias them.
# NOTE: re-running this cell without changing assimilation_date appends duplicates.
updated_means_aao.append(unstacked_updated_mean_aao.copy())
updated_means_seq.append(unstacked_updated_mean_seq.copy())
prior_means.append(unstacked_mean.copy())

In [None]:
# Concatenate the collected all-at-once updated means along a `time` dimension.
# NOTE(review): nothing in the visible cells writes these results to disk (results_folder
# is defined but unused here) — add e.g. a to_netcdf call if they should be saved.
xr.concat(updated_means_aao, dim='time')