# Author information
## Credit:
© Ehsan Mehdipour (ehsan.mehdipour@awi.de)
Alfred Wegener Insitute for Polar and Marine Research, Bremerhaven, Germany


## Objective:
This code is provided for data analysis and gap-filling of ocean color product especially Total Chlorophyll-a and Phytoplankton Functional Type datasets.
The dataset is accessable throguh Copernicus Marine Service with the following DOI:


# Setup and configuration

## Import Modules

In [1]:
# Modules for data analysis
import os
os.environ["OMP_NUM_THREADS"] = "1"
import xarray as xr
import numpy as np
import pandas as pd
import dask
import dask.array as da

# Modules for data visualisation
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.colors import LogNorm, LinearSegmentedColormap
import matplotlib.dates as mdates
plt.rcParams['text.usetex'] = True
plt.rcParams['font.size'] = 12
from matplotlib import gridspec
import matplotlib.ticker as mticker
import cartopy.crs as ccrs
import cartopy.feature

# miscellaneous Modules
import string
from tqdm import tqdm

# Manual modules or parameters
from function import *
from params import *

## Dask setup for heavy computation and parallalization

In [None]:
# Creating SLURM Cluster
# cluster, client = dask_slurm_cluster(queue='smp', cores=16, scale=16)

# Creating Distributed Cluster
# client = dask_distributed_client(n_workers=8, threads_per_worker=None)

# 
# cluster.close()
# client.close()

# Importing Satellite Data

## Input and output parameters configuration

In [2]:
# CSV file detailing the boundries of regions of interest (ROI)
regions = pd.read_csv('regions.csv', index_col=0)

In [3]:
# Reading the metadata files for all satellite images in the data_dir folder

ds_read = xr.open_mfdataset(params['data_dir'], chunks='auto')
ds_read

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 173.39 GiB 127.98 MiB Shape (1247, 4320, 8640) (1, 4095, 8193) Dask graph 4988 chunks in 2495 graph layers Data type float32 numpy.ndarray",8640  4320  1247,

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 173.39 GiB 127.98 MiB Shape (1247, 4320, 8640) (1, 4095, 8193) Dask graph 4988 chunks in 2495 graph layers Data type float32 numpy.ndarray",8640  4320  1247,

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,43.35 GiB,35.60 MiB
Shape,"(1247, 4320, 8640)","(1, 4320, 8640)"
Dask graph,1247 chunks in 2495 graph layers,1247 chunks in 2495 graph layers
Data type,int8 numpy.ndarray,int8 numpy.ndarray
"Array Chunk Bytes 43.35 GiB 35.60 MiB Shape (1247, 4320, 8640) (1, 4320, 8640) Dask graph 1247 chunks in 2495 graph layers Data type int8 numpy.ndarray",8640  4320  1247,

Unnamed: 0,Array,Chunk
Bytes,43.35 GiB,35.60 MiB
Shape,"(1247, 4320, 8640)","(1, 4320, 8640)"
Dask graph,1247 chunks in 2495 graph layers,1247 chunks in 2495 graph layers
Data type,int8 numpy.ndarray,int8 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 173.39 GiB 127.98 MiB Shape (1247, 4320, 8640) (1, 4095, 8193) Dask graph 4988 chunks in 2495 graph layers Data type float32 numpy.ndarray",8640  4320  1247,

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 173.39 GiB 127.98 MiB Shape (1247, 4320, 8640) (1, 4095, 8193) Dask graph 4988 chunks in 2495 graph layers Data type float32 numpy.ndarray",8640  4320  1247,

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 173.39 GiB 127.98 MiB Shape (1247, 4320, 8640) (1, 4095, 8193) Dask graph 4988 chunks in 2495 graph layers Data type float32 numpy.ndarray",8640  4320  1247,

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 173.39 GiB 127.98 MiB Shape (1247, 4320, 8640) (1, 4095, 8193) Dask graph 4988 chunks in 2495 graph layers Data type float32 numpy.ndarray",8640  4320  1247,

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 173.39 GiB 127.98 MiB Shape (1247, 4320, 8640) (1, 4095, 8193) Dask graph 4988 chunks in 2495 graph layers Data type float32 numpy.ndarray",8640  4320  1247,

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 173.39 GiB 127.98 MiB Shape (1247, 4320, 8640) (1, 4095, 8193) Dask graph 4988 chunks in 2495 graph layers Data type float32 numpy.ndarray",8640  4320  1247,

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 173.39 GiB 127.98 MiB Shape (1247, 4320, 8640) (1, 4095, 8193) Dask graph 4988 chunks in 2495 graph layers Data type float32 numpy.ndarray",8640  4320  1247,

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 173.39 GiB 127.98 MiB Shape (1247, 4320, 8640) (1, 4095, 8193) Dask graph 4988 chunks in 2495 graph layers Data type float32 numpy.ndarray",8640  4320  1247,

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 173.39 GiB 127.98 MiB Shape (1247, 4320, 8640) (1, 4095, 8193) Dask graph 4988 chunks in 2495 graph layers Data type float32 numpy.ndarray",8640  4320  1247,

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 173.39 GiB 127.98 MiB Shape (1247, 4320, 8640) (1, 4095, 8193) Dask graph 4988 chunks in 2495 graph layers Data type float32 numpy.ndarray",8640  4320  1247,

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 173.39 GiB 127.98 MiB Shape (1247, 4320, 8640) (1, 4095, 8193) Dask graph 4988 chunks in 2495 graph layers Data type float32 numpy.ndarray",8640  4320  1247,

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 173.39 GiB 127.98 MiB Shape (1247, 4320, 8640) (1, 4095, 8193) Dask graph 4988 chunks in 2495 graph layers Data type float32 numpy.ndarray",8640  4320  1247,

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 173.39 GiB 127.98 MiB Shape (1247, 4320, 8640) (1, 4095, 8193) Dask graph 4988 chunks in 2495 graph layers Data type float32 numpy.ndarray",8640  4320  1247,

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 173.39 GiB 127.98 MiB Shape (1247, 4320, 8640) (1, 4095, 8193) Dask graph 4988 chunks in 2495 graph layers Data type float32 numpy.ndarray",8640  4320  1247,

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 173.39 GiB 127.98 MiB Shape (1247, 4320, 8640) (1, 4095, 8193) Dask graph 4988 chunks in 2495 graph layers Data type float32 numpy.ndarray",8640  4320  1247,

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 173.39 GiB 127.98 MiB Shape (1247, 4320, 8640) (1, 4095, 8193) Dask graph 4988 chunks in 2495 graph layers Data type float32 numpy.ndarray",8640  4320  1247,

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 173.39 GiB 127.98 MiB Shape (1247, 4320, 8640) (1, 4095, 8193) Dask graph 4988 chunks in 2495 graph layers Data type float32 numpy.ndarray",8640  4320  1247,

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 173.39 GiB 127.98 MiB Shape (1247, 4320, 8640) (1, 4095, 8193) Dask graph 4988 chunks in 2495 graph layers Data type float32 numpy.ndarray",8640  4320  1247,

Unnamed: 0,Array,Chunk
Bytes,173.39 GiB,127.98 MiB
Shape,"(1247, 4320, 8640)","(1, 4095, 8193)"
Dask graph,4988 chunks in 2495 graph layers,4988 chunks in 2495 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [5]:
# Function for extracting the overall interesed region (e.g. whole atlantic ocean)

def ds_crop(ds_read, params):
    ds = ds_read.where(
        (ds_read.lat > params['boundaries']['lat_min'])&
        (ds_read.lon > params['boundaries']['lon_min'])&
        (ds_read.lat < params['boundaries']['lat_max'])&
        (ds_read.lon < params['boundaries']['lon_max']),
        drop=True
    ).sel(time=slice(params['start_date'],params['end_date']))
    
    flags = ds[params['flags']]
    ds_unc_rel= ds[params['UNC']]
    ds = ds[params['PFT']]
    
    return ds, ds_unc_rel, flags

In [6]:
# data_dir = '/albedo/work/projects/p_phytooptics/emehdipo/PS113/CMEMS/data/20180[5-6]*_cmems_obs-oc_glo_bgc-plankton_my_l3-multi-4km_P1D.nc'
# ds_input = xr.open_mfdataset(data_dir, chunks='auto')
# ds_input = ds_input[params['PFT']]

# ds_input = ds_input.where(
#     (ds_input.lat > -50)&
#     (ds_input.lon > -64)&
#     (ds_input.lat < 52)&
#     (ds_input.lon < 3),
#     drop=True
# ).sel(time=slice('20160425','20190425'))

# # ds_input.to_netcdf('/albedo/work/projects/p_phytooptics/emehdipo/PS113/CMEMS/regions/merged/ds_input.nc', compute=True)

In [7]:
ds, ds_unc_rel, flags = ds_crop(ds_read, params)

# Uncertainty

Converting the relative uncertainty in linear scale to logscale standard deviation (SD) uncertainty

In [None]:
# An empirical function that convert the relative uncertainty in linear scale to standard deviation uncertainty in logscale (based on )

def Unc_LinRel_to_LogSD(ds_unc_rel):
    ds_unc_sd = np.log10((ds_unc_rel/100)+1)
    return ds_unc_sd

In [None]:
# renaming uncertainty to match with the ds
ds_unc_rel = ds_unc_rel.rename(params['UNC_dict'])

# converting the relative uncertainty to SD
ds_unc_sd = Unc_LinRel_to_LogSD(ds_unc_rel)

# Temporal mean of the uncertainty
ds_unc_rel_mean = ds_unc_rel.mean(dim='time')
ds_unc_sd_mean = ds_unc_sd.mean(dim='time')

In [None]:
# Saving the temporal average uncertainty

# ds_unc_rel_mean.to_netcdf((os.path.join(params['work_dir'],'/spatial_avg_unc_rel_L3.nc'), compute=True)
# ds_unc_sd_mean.to_netcdf((os.path.join(params['work_dir'],'spatial_avg_unc_sd_L3.nc'), compute=True)

In [None]:
# reopen if already saved.

ds_unc_rel_mean = xr.open_dataset(os.path.join(params['work_dir'],'spatial_avg_unc_rel_L3.nc')).compute()
ds_unc_sd_mean = xr.open_dataset(os.path.join(params['work_dir'],'spatial_avg_unc_sd_L3.nc')).compute()

In [None]:
## Plotting the average temporal uncertainty

fig, ax = plt.subplots(2,3, constrained_layout=True,figsize=(8,6), subplot_kw={'projection': ccrs.PlateCarree()}, sharex=True, sharey=True)
ax = ax.flatten()
for i ,p in enumerate(params['PFT']):
    f = ds_unc_rel_mean[p].plot.contourf(ax=ax[i], cmap = 'coolwarm', robust=True, add_colorbar=False)
    ax[i].set_title(p, fontsize=14)
    
    gl = ax[i].gridlines(draw_labels=True,alpha=0.3,facecolor='grey',edgecolor='dimgrey', linestyle='--')
    gl.right_labels = False
    gl.top_labels = False

    if i<3:
        gl.bottom_labels = False
    if i not in [0,3]:
        gl.left_labels = False
    
    gl.xformatter = mticker.FuncFormatter(lambda x, _: f"{abs(x):.0f}° {'E' if x >= 0 else 'W'}")
    gl.yformatter = mticker.FuncFormatter(lambda y, _: f"{abs(y):.0f}° {'N' if y >= 0 else 'S'}")
    
    cbar = fig.colorbar(f, ax=ax[i], orientation='vertical',extend='both', shrink=0.6, pad=0.01)
    # cbar.set_label(r'Relative Uncertainty [%]',fontsize=12)
    fig.text(1.01, 0.5, r'Relative uncertainty [\%]', ha='center', va='center', rotation=90, fontsize=12)

# plt.savefig('fig/spatial_avg_unc_rel_L3_3.jpg',dpi=300, bbox_inches='tight')

In [None]:
## Plotting the average temporal uncertainty

fig, ax = plt.subplots(2,3, constrained_layout=True,figsize=(8,6), subplot_kw={'projection': ccrs.PlateCarree()}, sharex=True, sharey=True)
ax = ax.flatten()
for i ,p in enumerate(params['PFT']):
    f = ds_unc_sd_mean[p].plot.contourf(ax=ax[i], cmap = 'coolwarm', robust=True, add_colorbar=False)
    ax[i].set_title(p, fontsize=14)
    
    gl = ax[i].gridlines(draw_labels=True,alpha=0.3,facecolor='grey',edgecolor='dimgrey', linestyle='--')
    gl.right_labels = False
    gl.top_labels = False

    if i<3:
        gl.bottom_labels = False
    if i not in [0,3]:
        gl.left_labels = False
    
    gl.xformatter = mticker.FuncFormatter(lambda x, _: f"{abs(x):.0f}° {'E' if x >= 0 else 'W'}")
    gl.yformatter = mticker.FuncFormatter(lambda y, _: f"{abs(y):.0f}° {'N' if y >= 0 else 'S'}")
    
    cbar = fig.colorbar(f, ax=ax[i], orientation='vertical',extend='both', shrink=0.6, pad=0.01)
    fig.text(1.01, 0.5, r'Standard deviation (SD) [$\frac{mg}{m^{3}}]$', ha='center', va='center', rotation=90, fontsize=12)

# plt.savefig('fig/spatial_avg_unc_rel_L3_3.jpg',dpi=300, bbox_inches='tight')

# Missing rates

In [8]:
spatial_count = ds[['CHL','DIATO']].count(dim=['time'])
temporal_count = ds[['CHL','DIATO']].count(dim=['lat','lon'])

#### Temporal

In [None]:
temporal_count = temporal_count.compute()

In [10]:
sea_count = ds.flags.isel(time=0).where(ds.flags.isel(time=0)==0).count().compute().values

type: 'Dataset' object has no attribute 'flags'

In [None]:
temporal_perc = 100 - (temporal_count / sea_count * 100)

In [None]:
window_size = 30

temporal_perc_mov = temporal_perc.rolling(time=window_size, center=True).mean()
# temporal_perc_mov

In [None]:
temporal_perc.to_netcdf('/albedo/work/projects/p_phytooptics/emehdipo/PS113/CMEMS/temporal_missing_perc.nc')

In [None]:
temporal_perc_mov.to_netcdf('/albedo/work/projects/p_phytooptics/emehdipo/PS113/CMEMS/temporal_missing_perc_mov.nc')

In [None]:
fig,ax = plt.subplots(1,1,figsize=(12,4))

temporal_perc.CHL.plot(ax=ax, label='TChla', c = 'lightblue')
plt.axhline(y=temporal_perc.mean().CHL.values, color='blue', linestyle='--', label='TChla avg')
temporal_perc_mov.CHL.plot(ax=ax, label='TChla moving avg', c = 'blue')

temporal_perc.DIATO.plot(ax=ax, label='PFTs', c = 'lightcoral', alpha=0.8)
plt.axhline(y=temporal_perc.mean().DIATO.values, color='darkred', linestyle='--', label='PFTs avg')
temporal_perc_mov.DIATO.plot(ax=ax, label='PFTs moving avg', c = 'darkred')

ax.set_ylabel('Temporal Missing Rate [%]')
ax.set_xlabel('')
ax.legend(fontsize=8, loc='lower left',ncol=2)

ax.xaxis.set_major_locator(mdates.MonthLocator(interval=3))
ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))


ax.xaxis.set_minor_locator(mdates.MonthLocator(interval=1))

ax.text(x=temporal_perc.time.values[-10], y= 83, s = f'{temporal_perc.DIATO.mean().values:.1f}%',color='darkred');
ax.text(x=temporal_perc.time.values[-10], y= 53, s = f'{temporal_perc.CHL.mean().values:.1f}%',color='blue');
# ax.xaxis.set_minor_formatter(mdates.DateFormatter('%m'))

# plt.savefig('fig/temporal_missing_rate.jpg',dpi=600, bbox_inches='tight')

#### Spatial

In [None]:
spatial_count = spatial_count.compute()

In [None]:
msk = ds.flags.isel(time=0).compute()

In [None]:
spatial_perc = (100 - (spatial_count / ds.time.size * 100)).where(msk==0)

In [None]:
spatial_perc.to_netcdf('/albedo/work/projects/p_phytooptics/emehdipo/PS113/CMEMS/spatial_missing_perc.nc')

In [None]:
fig,ax = plt.subplots(1,2,figsize=(5,5),subplot_kw={'projection': ccrs.PlateCarree()}, constrained_layout=True, sharey=True)
levels = [20,30,40,50,60,70,80,90,100]

p1 = spatial_perc.CHL.plot.contourf(ax=ax[0], levels=levels, cmap = 'viridis', add_colorbar=False,transform=ccrs.PlateCarree())
ax[0].set_aspect('equal')
ax[0].set_title('TChla');
# plt.colorbar(p1, label='Spatial Missing Rate [%]', extend='both')
ax[0].coastlines()
gl = ax[0].gridlines(draw_labels=True,alpha=0.3, linestyle='--')
gl.right_labels = False
gl.top_labels = False


p2 = spatial_perc.DIATO.plot.contourf(ax=ax[1], levels=levels , cmap = 'viridis', add_colorbar=False,transform=ccrs.PlateCarree())
ax[1].set_aspect('equal')
ax[1].set_title('PFTs');
# plt.colorbar(p2, label='Spatial Missing Rate [%]', extend='both')
ax[1].coastlines()
gl = ax[1].gridlines(draw_labels=True,alpha=0.3, linestyle='--')
gl.right_labels = False
gl.top_labels = False

cbar = fig.colorbar(p1, ax=ax, orientation='horizontal', label='Spatial Missing Rate [%]')

# plt.savefig('fig/spatial_missing_rate.jpg',dpi=600, bbox_inches='tight')

### Missing all together

In [None]:
ds_read = xr.open_mfdataset(["/albedo/work/projects/p_phytooptics/emehdipo/PS113/CMEMS/data/20180101_cmems_obs-oc_glo_bgc-plankton_my_l3-multi-4km_P1D.nc",
                            "/albedo/work/projects/p_phytooptics/emehdipo/PS113/CMEMS/data/20190101_cmems_obs-oc_glo_bgc-plankton_my_l3-multi-4km_P1D.nc"], chunks='auto')
ds_read

In [None]:
ds = ds_read.where(
    (ds_read.lat > -50)&
    (ds_read.lon > -64)&
    (ds_read.lat < 52)&
    (ds_read.lon < 3),
    drop=True
).sel(time=slice('20160425','20190425'))

flags = ds['flags']
ds_unc_rel= ds[params['UNC']]
ds = ds[params['PFT']]

In [None]:
temporal_perc= xr.open_dataset('/albedo/work/projects/p_phytooptics/emehdipo/PS113/CMEMS/temporal_missing_perc.nc').compute()
temporal_perc_mov=xr.open_dataset('/albedo/work/projects/p_phytooptics/emehdipo/PS113/CMEMS/temporal_missing_perc_mov.nc').compute()
spatial_perc=xr.open_dataset('/albedo/work/projects/p_phytooptics/emehdipo/PS113/CMEMS/spatial_missing_perc.nc').compute()

In [None]:
merged_input = xr.open_dataset('/albedo/work/projects/p_phytooptics/emehdipo/PS113/CMEMS/regions/merged/ds_input.nc')
merged_dincae = xr.open_dataset('/albedo/work/projects/p_phytooptics/emehdipo/PS113/CMEMS/regions/merged/ds_dincae.nc')

In [None]:
merged_input = merged_input.DIATO.sel(time='20180526')
merged_dincae = merged_dincae.DIATO.sel(time='20180526')
merged_input = merged_input.where(np.isfinite(merged_dincae)).compute()

In [None]:
regions = pd.read_csv('regions.csv')


In [None]:
plt.rcParams['text.usetex'] = True

In [None]:
# Create a figure
fig = plt.figure(figsize=(9, 7), constrained_layout=True)


# Create a GridSpec with 2 rows and 5 columns
gs = gridspec.GridSpec(2, 5, height_ratios=[1, 1], figure=fig) 


#----------------------------------------------------------------------------------------------------
# Second row: 5 smaller figures
axes = [fig.add_subplot(gs[0, i], projection=ccrs.PlateCarree()) for i in range(5)]

#--------------------------------------------------------------------------------------------------
# Samples

t1 = '2018-05-26'
t2 = '2019-01-01'
vmin=0.01
vmax=1

p1 = ds.CHL.sel(time=t1).plot(norm=LogNorm(vmin=vmin,vmax=vmax), ax=axes[0], add_colorbar=False)
axes[0].set_title(f'TChla\n[{t1}]', fontsize=14)
p2 = ds.DIATO.sel(time=t1).plot(norm=LogNorm(vmin=vmin,vmax=vmax), ax=axes[1], add_colorbar=False)
axes[1].set_title(f'Diatoms\n[{t1}]', fontsize=14)
p3 = ds.DIATO.sel(time=t2).plot(norm=LogNorm(vmin=vmin,vmax=vmax), ax=axes[2], add_colorbar=False)
axes[2].set_title(f'Diatoms\n[{t2}]', fontsize=14)

p=[p1,p2,p3]
    
cbar = plt.colorbar(p[2], ax=axes[0:3] , orientation='horizontal', shrink=0.6, extend='both')
cbar.set_ticks([0.01, 0.1, 1])
cbar.set_ticklabels([0.01, 0.1, 1])
cbar.set_label(r'Chla Concentration $[mg\cdot m^{-3}]$', fontsize=14)
cbar.ax.tick_params(labelsize=11)

#-----------------------------------------------------------------------------------------------------
# spatial missing
levels = [20,30,40,50,60,70,80,90,100]
p1 = spatial_perc.CHL.plot.contourf(ax=axes[3], levels=levels, cmap = 'coolwarm', add_colorbar=False,transform=ccrs.PlateCarree())
axes[3].set_title('TChla', fontsize=14);


p2 = spatial_perc.DIATO.plot.contourf(ax=axes[4], levels=levels , cmap = 'coolwarm', add_colorbar=False,transform=ccrs.PlateCarree())
axes[4].set_title('PFT', fontsize=14);

cbar = fig.colorbar(p1, ax=axes[3:5], orientation='horizontal')
cbar.set_label('Spatial variation of average missing rate $[\%]$', fontsize=14)
cbar.ax.tick_params(labelsize=11)
#------------------------
for i, ax in enumerate(axes):
    if i<3:
        ax.add_feature(cartopy.feature.LAND, facecolor='lightgrey', edgecolor='dimgrey', alpha=0.5)
    else:
        ax.add_feature(cartopy.feature.LAND, facecolor='white', edgecolor='k', alpha=0.8)
        
    gl = ax.gridlines(draw_labels=True,alpha=0.3, linestyle='-')
    gl.right_labels = False
    gl.top_labels = False
    
    gl.xformatter = mticker.FuncFormatter(lambda x, _: f"{abs(x):.0f}° {'E' if x >= 0 else 'W'}")
    gl.yformatter = mticker.FuncFormatter(lambda y, _: f"{abs(y):.0f}° {'N' if y >= 0 else 'S'}")
    
    if i in [1,2,3,4]:
        gl.left_labels = False
        
    gl.xlocator = mticker.FixedLocator([-60,-30,0])
    ax.text(-0.03, 1.03, r'$\textbf{('+ f'{string.ascii_lowercase[i]}' + ')}$', transform=ax.transAxes, 
            size=12, weight='bold')
    
#-----------------------------------------------------------------------------
# temporal missing
ax1 = fig.add_subplot(gs[1, :])

# temporal_perc.CHL.plot(ax=ax1, label='TChla', c = 'lightblue')
# ax1.axhline(y=temporal_perc.mean().CHL.values, color='blue', linestyle='--', label='TChla avg.')
# temporal_perc_mov.CHL.plot(ax=ax1, label='TChla moving avg.', c = 'blue')

# temporal_perc.DIATO.plot(ax=ax1, label='PFT', c = 'lightcoral', alpha=0.8)
# ax1.axhline(y=temporal_perc.mean().DIATO.values, color='darkred', linestyle='--', label='PFT avg.')
# temporal_perc_mov.DIATO.plot(ax=ax1, label='PFT moving avg.', c = 'darkred')



temporal_perc.CHL.plot(ax=ax1, label='TChla', c = 'grey', alpha=0.6)
ax1.axhline(y=temporal_perc.mean().CHL.values, color='black', linestyle='--', label='TChla avg.')
temporal_perc_mov.CHL.plot(ax=ax1, label='TChla moving avg.', c = 'black')

temporal_perc.DIATO.plot(ax=ax1, label='PFT', c = 'lightcoral', alpha=0.6)
ax1.axhline(y=temporal_perc.mean().DIATO.values, color='darkred', linestyle='--', label='PFT avg.')
temporal_perc_mov.DIATO.plot(ax=ax1, label='PFT moving avg.', c = 'darkred')




ax1.set_title('Temporal variation of average missing rate $[\%]$', fontsize=14)
ax1.set_ylabel('')
ax1.set_xlabel('')

legend = ax1.legend(fontsize=10, loc='lower left',ncol=2)
legend.set_alpha(0.3)
for line in legend.get_lines():
    line.set_linewidth(2.0)

ax1.tick_params(axis='y', labelsize=11)
ax1.set_ylim(20,100)

ax1.xaxis.set_major_locator(mdates.MonthLocator(bymonth=[1,4,7,10]))
ax1.xaxis.set_major_formatter(mdates.DateFormatter('%Y–%m–%d'))
ax1.xaxis.set_minor_locator(mdates.MonthLocator(interval=1))

ax1.text(x=temporal_perc.time.values[-20], y= 84, s = f'{temporal_perc.DIATO.mean().values:.1f}\%',color='saddlebrown', size=14);
ax1.text(x=temporal_perc.time.values[-20], y= 54, s = f'{temporal_perc.CHL.mean().values:.1f}\%',color='black', size=14);

ax1.text(0, 1.05, r'$\textbf{('+ f'{string.ascii_lowercase[5]}' + ')}$', transform=ax1.transAxes, 
        size=12, weight='bold');
gl = ax1.grid(alpha=0.2, linestyle='-')

# plt.savefig('fig/spatiotempiral_missing5.png',dpi=300, bbox_inches='tight')

In [None]:
from matplotlib.patches import Rectangle
import geopandas as gpd


In [None]:
gdf = gpd.read_file("clusters/PS113-path.shp")

In [None]:
# Create a figure
fig = plt.figure(figsize=(10, 7), constrained_layout=True)


# Create a GridSpec with 2 rows and 5 columns
gs = gridspec.GridSpec(2, 5, height_ratios=[1.5, 1], figure=fig) 


#----------------------------------------------------------------------------------------------------
# Second row: 5 smaller figures
axes = [fig.add_subplot(gs[0, i], projection=ccrs.PlateCarree()) for i in range(5)]

#--------------------------------------------------------------------------------------------------
# Samples

t1 = '2018-01-01'
t2 = '2019-01-01'
vmin=0.01
vmax=1

p1 = ds.CHL.sel(time=t1).plot(norm=LogNorm(vmin=vmin,vmax=vmax), ax=axes[0], add_colorbar=False)
axes[0].set_title(f'CHL\n[{t1}]', fontsize=14)
p2 = ds.DIATO.sel(time=t1).plot(norm=LogNorm(vmin=vmin,vmax=vmax), ax=axes[1], add_colorbar=False)
axes[1].set_title(f'DIATO\n[{t1}]', fontsize=14)
p3 = ds.DIATO.sel(time=t2).plot(norm=LogNorm(vmin=vmin,vmax=vmax), ax=axes[2], add_colorbar=False)
axes[2].set_title(f'DIATO\n[{t2}]', fontsize=14)

p=[p1,p2,p3]
    
cbar = plt.colorbar(p[2], ax=axes[0:3] , orientation='horizontal', shrink=0.6, extend='both')
cbar.set_ticks([0.01, 0.1, 1])
cbar.set_ticklabels([0.01, 0.1, 1])
cbar.set_label(r'Chla Concentration $[mg\cdot m^{-3}]$', fontsize=14)
cbar.ax.tick_params(labelsize=11)

###-------------------------------------
## regions
for idx in range(len(regions)):
    lon_min = regions.iloc[idx].lon_min
    lon_max = regions.iloc[idx].lon_max

    lat_min = regions.iloc[idx].lat_min
    lat_max = regions.iloc[idx].lat_max

    lat_center = (lat_min+lat_max)/2
    lon_center = (lon_min+lon_max)/2
    
    
    patch = axes[2].add_patch(Rectangle((lon_min, lat_min), (lon_max - lon_min), (lat_max - lat_min),
             edgecolor = 'k',
             fill=False,
             lw=1.5,))
    if idx == 0:
        patch.set_label('ROI')

    axes[2].text(x = lon_center, y = lat_center, s = f'{idx+1}', color='white',
            horizontalalignment='center', verticalalignment='center', fontsize=14, fontweight='bold',
              bbox=dict(facecolor='k', alpha=1, edgecolor='none', boxstyle='round,pad=0.05'))

gdf.plot(ax=axes[1], label='PS113 track', linewidth=1.5, color='sienna')
axes[1].legend(loc = 'lower right', fontsize=8)
axes[2].legend(loc = 'lower right', fontsize=10)
#-----------------------------------------------------------------------------------------------------
# spatial missing
levels = [20,30,40,50,60,70,80,90,100]
p1 = spatial_perc.CHL.plot.contourf(ax=axes[3], levels=levels, cmap = 'coolwarm', add_colorbar=False,transform=ccrs.PlateCarree())
axes[3].set_title('CHL', fontsize=14);


p2 = spatial_perc.DIATO.plot.contourf(ax=axes[4], levels=levels , cmap = 'coolwarm', add_colorbar=False,transform=ccrs.PlateCarree())
axes[4].set_title('PFT', fontsize=14);

cbar = fig.colorbar(p1, ax=axes[3:5], orientation='horizontal')
cbar.set_label('Spatial variation of average missing rate $[\%]$', fontsize=14)
cbar.ax.tick_params(labelsize=11)
#------------------------
for i, ax in enumerate(axes):
    if i<3:
        ax.add_feature(cartopy.feature.LAND, facecolor='lightgrey', edgecolor='dimgrey', alpha=0.3)
    else:
        ax.add_feature(cartopy.feature.LAND, facecolor='white', edgecolor='k', alpha=0.8)
        
    gl = ax.gridlines(draw_labels=True,alpha=0.3, linestyle='-')
    gl.right_labels = False
    gl.top_labels = False
    
    gl.xformatter = mticker.FuncFormatter(lambda x, _: f"{abs(x):.0f}° {'E' if x >= 0 else 'W'}")
    gl.yformatter = mticker.FuncFormatter(lambda y, _: f"{abs(y):.0f}° {'N' if y >= 0 else 'S'}")
    
    if i in [1,2,3,4]:
        gl.left_labels = False
        
    gl.xlocator = mticker.FixedLocator([-60,-30,0])
    ax.text(-0.03, 1.03, r'$\textbf{('+ f'{string.ascii_lowercase[i]}' + ')}$', transform=ax.transAxes, 
            size=12, weight='bold')
    
#-----------------------------------------------------------------------------
# temporal missing
ax1 = fig.add_subplot(gs[1, :])

# temporal_perc.CHL.plot(ax=ax1, label='TChla', c = 'lightblue')
# ax1.axhline(y=temporal_perc.mean().CHL.values, color='blue', linestyle='--', label='TChla avg.')
# temporal_perc_mov.CHL.plot(ax=ax1, label='TChla moving avg.', c = 'blue')

# temporal_perc.DIATO.plot(ax=ax1, label='PFT', c = 'lightcoral', alpha=0.8)
# ax1.axhline(y=temporal_perc.mean().DIATO.values, color='darkred', linestyle='--', label='PFT avg.')
# temporal_perc_mov.DIATO.plot(ax=ax1, label='PFT moving avg.', c = 'darkred')



temporal_perc.CHL.plot(ax=ax1, label='CHL', c = 'grey', alpha=0.6)
ax1.axhline(y=temporal_perc.mean().CHL.values, color='black', linestyle='--', label='CHL avg.')
temporal_perc_mov.CHL.plot(ax=ax1, label='CHL moving avg.', c = 'black')

temporal_perc.DIATO.plot(ax=ax1, label='PFT', c = 'lightcoral', alpha=0.6)
ax1.axhline(y=temporal_perc.mean().DIATO.values, color='darkred', linestyle='--', label='PFT avg.')
temporal_perc_mov.DIATO.plot(ax=ax1, label='PFT moving avg.', c = 'darkred')




ax1.set_title('Temporal variation of average missing rate $[\%]$', fontsize=14)
ax1.set_ylabel('')
ax1.set_xlabel('')

legend = ax1.legend(fontsize=10, loc='lower left',ncol=2)
legend.set_alpha(0.3)
for line in legend.get_lines():
    line.set_linewidth(2.0)

ax1.tick_params(axis='y', labelsize=11)
ax1.set_ylim(20,100)

ax1.xaxis.set_major_locator(mdates.MonthLocator(bymonth=[1,4,7,10]))
ax1.xaxis.set_major_formatter(mdates.DateFormatter('%Y–%m–%d'))
ax1.xaxis.set_minor_locator(mdates.MonthLocator(interval=1))

ax1.text(x=temporal_perc.time.values[-20], y= 84, s = f'{temporal_perc.DIATO.mean().values:.1f}\%',color='saddlebrown', size=14);
ax1.text(x=temporal_perc.time.values[-20], y= 54, s = f'{temporal_perc.CHL.mean().values:.1f}\%',color='black', size=14);

ax1.text(0, 1.05, r'$\textbf{('+ f'{string.ascii_lowercase[5]}' + ')}$', transform=ax1.transAxes, 
        size=12, weight='bold');
gl = ax1.grid(alpha=0.2, linestyle='-')

plt.savefig('fig/spatiotempiral_missing6.png',dpi=300, bbox_inches='tight')

In [None]:
# Create a figure
fig = plt.figure(figsize=(8, 7), constrained_layout=True)

# Create a GridSpec with 2 rows and 5 columns
gs = gridspec.GridSpec(2, 4, height_ratios=[1, 1], figure=fig) 

#----------------------------------------------------------------------------------------------------
# Second row: 5 smaller figures
axes = [fig.add_subplot(gs[0, i], projection=ccrs.PlateCarree()) for i in range(4)]

#--------------------------------------------------------------------------------------------------
# Samples

t1 = '2018-01-01'
t2 = '2019-01-01'
vmin=0.01
vmax=1

p1 = ds.CHL.sel(time=t1).plot(norm=LogNorm(vmin=vmin,vmax=vmax), ax=axes[0], add_colorbar=False)
axes[0].set_title(f'TChla\n[{t1}]', fontsize=14)
p2 = ds.DIATO.sel(time=t1).plot(norm=LogNorm(vmin=vmin,vmax=vmax), ax=axes[1], add_colorbar=False)
axes[1].set_title(f'PG [e.g. Diatoms]\n[{t1}]', fontsize=14)

p=[p1,p2]
    
cbar = plt.colorbar(p[1], ax=axes[0:2] , orientation='horizontal', shrink=0.6, extend='both')
cbar.set_ticks([0.01, 0.1, 1])
cbar.set_ticklabels([0.01, 0.1, 1])
cbar.set_label(r'Chla Concentration $[mg.m^{-3}]$', fontsize=14)
cbar.ax.tick_params(labelsize=11)

#-----------------------------------------------------------------------------------------------------
# spatial missing
levels = [0,10,20,30,40,50,60,70,80]
p1 = (100 - spatial_perc).CHL.plot.contourf(ax=axes[2], levels=levels, cmap = 'coolwarm_r', add_colorbar=False,transform=ccrs.PlateCarree())
axes[2].set_title(r'Total Chlorophyll-a' + '\n' + r'$\mathbf{(TChla)}$', fontsize=14, loc='center');


p2 = (100 - spatial_perc).DIATO.plot.contourf(ax=axes[3], levels=levels , cmap = 'coolwarm_r', add_colorbar=False,transform=ccrs.PlateCarree())
axes[3].set_title(r'Phytoplankton groups'+ '\n' + r'$\mathbf{(PG)}$', fontsize=14, loc='center');

cbar = fig.colorbar(p1, ax=axes[2:4], orientation='horizontal')
cbar.set_label('Spatial variation of data availability $[\%]$', fontsize=14)
cbar.ax.tick_params(labelsize=11)
#------------------------
for i, ax in enumerate(axes):
    if i<2:
        ax.add_feature(cartopy.feature.LAND, facecolor='lightgrey', alpha=0.5)
        ax.coastlines('110m', color='black', alpha=0.7)
    else:
        ax.add_feature(cartopy.feature.LAND, facecolor='white', alpha=1)
        ax.coastlines('110m', color='black', alpha=0.7)
    gl = ax.gridlines(draw_labels=True,alpha=0.3, linestyle='-')
    gl.right_labels = False
    gl.top_labels = False
    
    if i in [1,2,3]:
        gl.left_labels = False
        
    gl.xlocator = mticker.FixedLocator([-60,-30,0])
    ax.text(-0.05, 1.02, string.ascii_uppercase[i], transform=ax.transAxes, 
            size=16, weight='bold')
    
#-----------------------------------------------------------------------------
# temporal missing
ax1 = fig.add_subplot(gs[1, :])

# temporal_perc.CHL.plot(ax=ax1, label='TChla', c = 'lightblue')
# ax1.axhline(y=temporal_perc.mean().CHL.values, color='blue', linestyle='--', label='TChla avg.')
# temporal_perc_mov.CHL.plot(ax=ax1, label='TChla moving avg.', c = 'blue')

# temporal_perc.DIATO.plot(ax=ax1, label='PFT', c = 'lightcoral', alpha=0.8)
# ax1.axhline(y=temporal_perc.mean().DIATO.values, color='darkred', linestyle='--', label='PFT avg.')
# temporal_perc_mov.DIATO.plot(ax=ax1, label='PFT moving avg.', c = 'darkred')



(100 - temporal_perc).CHL.plot(ax=ax1, label='TChla', c = 'grey', alpha=0.6)
ax1.axhline(y=(100 - temporal_perc).mean().CHL.values, color='black', linestyle='--', label='TChla avg.')
(100 - temporal_perc_mov).CHL.plot(ax=ax1, label='TChla moving avg.', c = 'black')

(100 - temporal_perc).DIATO.plot(ax=ax1, label='PG', c = 'lightcoral', alpha=0.6)
ax1.axhline(y=(100 - temporal_perc).mean().DIATO.values, color='darkred', linestyle='--', label='PG avg.')
(100 - temporal_perc_mov).DIATO.plot(ax=ax1, label='PG moving avg.', c = 'darkred')




ax1.set_title('Temporal variation of average data availability $[\%]$', fontsize=14)
ax1.set_ylabel('')
ax1.set_xlabel('')

legend = ax1.legend(fontsize=10, loc='upper left',ncol=2)
legend.set_alpha(0.3)
for line in legend.get_lines():
    line.set_linewidth(2.0)

ax1.tick_params(axis='y', labelsize=11)
ax1.set_ylim(0,100)

ax1.xaxis.set_major_locator(mdates.MonthLocator(bymonth=[1,4,7,10]))
ax1.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
ax1.xaxis.set_minor_locator(mdates.MonthLocator(interval=1))

ax1.text(x=temporal_perc.time.values[-1], y= 20 , s = f'{(100 - temporal_perc).DIATO.mean().values:.0f}\%',color='saddlebrown', size=14);
ax1.text(x=temporal_perc.time.values[-1], y= 55, s = f'{(100 - temporal_perc).CHL.mean().values:.0f}\%',color='black', size=14);


ax1.text(0, 1.05, string.ascii_uppercase[5], transform=ax1.transAxes, 
        size=16, weight='bold');
gl = ax1.grid(alpha=0.2, linestyle='-')

plt.savefig('fig/spatiotempiral_missing5.png',dpi=600, bbox_inches='tight')

### region missing rate

In [None]:
for region in range(1,11):
    ds = xr.open_dataset(f'/albedo/work/projects/p_phytooptics/emehdipo/PS113/CMEMS/regions/{region}/ds_pft.nc').compute()
    print(f'region {region} size: {ds.lon.size} × {ds.lat.size}')
    print(f'region {region} time: {ds.time.size}')
    total_valid_pixels = (ds.mask==1).sum().values * ds.time.size
    missing_rate = (1 - (ds.count() / total_valid_pixels)) * 100
    print(region,'CHL',f'{missing_rate.CHL.values:.1f}')
    print(region,'PFT',f'{missing_rate.DIATO.values:.1f}')

### Percentage of Cross-validation

In [None]:
test_percs = []
dev_percs = []
test_counts = []
dev_counts = []
for region in range(1,11):
    print(region)
    ds_test = xr.open_dataset(f'/albedo/work/projects/p_phytooptics/emehdipo/PS113/CMEMS/regions/{region}/ds_pft.nc').compute()
    ds_dev = xr.open_dataset(f'/albedo/work/projects/p_phytooptics/emehdipo/PS113/CMEMS/regions/{region}/ds_pft_dev.nc').compute()
    ds_train = xr.open_dataset(f'/albedo/work/projects/p_phytooptics/emehdipo/PS113/CMEMS/regions/{region}/ds_pft_train.nc').compute()
    
    test_perc = (ds_test.sel(time = ds_dev.time.values).count() - ds_dev.count())/ds.count()*100
    test_percs.append(test_perc[params['PFT']].to_pandas().rename(region))
    
    dev_perc = (ds_dev.count() - ds_train.count())/ds_dev.count()*100
    dev_percs.append(dev_perc[params['PFT']].to_pandas().rename(region))
    
    
#     test_count = ds_test.sel(time = ds_dev.time.values).count() - ds_dev.count()
#     test_counts.append(test_count[params['PFT']].to_pandas().rename(region))
#     print(test_count.CHL.values)
    
#     dev_count = ds_dev.count() - ds_train.count()
#     dev_counts.append(dev_count[params['PFT']].to_pandas().rename(region))
#     print(dev_count.CHL.values)

In [None]:
for i in [np.round(i.CHL,1) for i in test_percs]:
    print(i)

In [None]:
pd.concat(CV_percs,axis=1).to_csv(os.path.join(output_dir,'CV_points_rate.csv'))

### TChla & PFT relation

In [None]:
pft_list_unc = ['CHL_uncertainty','DIATO_uncertainty','DINO_uncertainty','HAPTO_uncertainty','GREEN_uncertainty','PROKAR_uncertainty']
pft_list = ['CHL','DIATO','DINO','HAPTO','GREEN','PROKAR','flags']

In [None]:
pft_labels = {'CHL':'Total Chlorophyll a', 
              'DIATO':'Diatoms', 
              'DINO':'Dinoflagellates', 
              'HAPTO':'Haptophyte',
              'GREEN':'Green Algae', 
              # 'PROCHLO':r"$\it{Prochlorococcus}$",
              'PROKAR':'Prokaryotic Phytop.'}

In [None]:
ds = ds_read.where(
    (ds_read.lat > -50)&
    (ds_read.lon > -64)&
    (ds_read.lat < 52)&
    (ds_read.lon < 3),
    drop=True
).sel(time=slice('20160425','20190425'))

ds = ds[pft_list]
# ds = ds.drop('flags')

In [None]:
tchla = ds.CHL
pft = ds.drop('CHL')

In [None]:
pft_sum = pft.to_array(dim='PFT').sum('PFT')
# pft_from_tchla = np.log10(tchla * (pft / pft_sum))
# pft = np.log10(pft)
tchla = np.log10(tchla)
pft_sum = np.log10(pft_sum)
pft_sum = pft_sum.where(np.isfinite(pft_sum))

In [None]:
# tchla = tchla.compute()
# pft_sum = pft_sum.compute()

In [None]:
da_corr = xr.corr(pft_sum, tchla, dim='time')

In [None]:
da_corr.to_netcdf('/albedo/work/projects/p_phytooptics/emehdipo/PS113/CMEMS/corr_CHL_sumpft_logscale.nc')

In [None]:
da_corr = xr.open_dataarray('/albedo/work/projects/p_phytooptics/emehdipo/PS113/CMEMS/corr_CHL_sumpft_logscale.nc').compute()

In [None]:
mask = xr.where(ds.flags.isel(time=0)==0,True, False).compute()

In [None]:
# for pf in pft_list_corr:
#     ds_corr[pf] = ds_corr[pf].where(np.isfinite(ds_corr[pf]))

da_corr = da_corr.where(da_corr <= 1, 1)
da_corr = da_corr.where(da_corr >= -1, -1)
da_corr = da_corr.where(mask)

In [None]:
fig, ax = plt.subplots(1,1, constrained_layout=True,figsize=(4,5), subplot_kw={'projection': ccrs.PlateCarree()})
# ax = ax.flatten()

vmin = np.nanpercentile(da_corr,1)
vmax = np.nanpercentile(da_corr,99)
cbar_levels=[-1,-0.7,-0.5,-0.3,0.3,0.5,0.7,1]

diver_color = LinearSegmentedColormap.from_list("", ["maroon", "white", "darkcyan"])
# diver_color = LinearSegmentedColormap.from_list("", ["cadetblue", "firebrick", "cadetblue"])

p = da_corr.plot.contourf(ax=ax,levels= cbar_levels, cmap = diver_color, add_colorbar=False)
ax.set_title('Observed TChla vs. sum of all PFTs')
ax.set_aspect('equal')
ax.coastlines()
gl = ax.gridlines(draw_labels=True,alpha=0.3, linestyle='--')
gl.right_labels = False
gl.top_labels = False
ax.set_facecolor('lightgrey')
    
cbar = fig.colorbar(p, ax=ax, orientation='vertical', shrink=0.7)
# cbar.set_ticks(cbar_levels)
# cbar.set_ticklabels(cbar_levels)
cbar.set_label(r'Pearson Correlation Coefficient',fontsize=12)

plt.savefig('fig/corr_CHL_sumpft_logscale.jpg',dpi=600, bbox_inches='tight')

## L3

### L3 vs HPLC

In [None]:
# %%time
# l3_points_list = []
# HPLC_points_list = []

# for i in tqdm(range(len(HPLC_subset.time.values))):
#     # get the deteils of the HPLC point
#     HPLC_point = HPLC_subset.isel(time=i,drop=False)
    
#     l3_point = ds_subset.sel(time=HPLC_point.time.values ,lat=HPLC_point.lat.values,lon=HPLC_point.lon.values, method='nearest')
#     # get the index of the middle pixel
    
#     lat_center = (ds_subset.lat==l3_point.lat).argmax().values
#     lon_center = (ds_subset.lon==l3_point.lon).argmax().values
#     # Compute the CV
#     s33 = ds.sel(time=HPLC_point.time.values).isel(lat = slice(lat_center-1,lat_center+2),lon= slice(lon_center-1,lon_center+2))
#     cv = abs(s33.std(ddof=1)/s33.mean()).where(s33.count()>=5)
#     s33_mean = s33.mean().where(cv<0.2)
#     # assign the middle pixel lat and lon to the matchup
#     s33_mean = s33_mean.assign_coords({'lat':l3_point.lat})
#     s33_mean = s33_mean.assign_coords({'lon':l3_point.lon})
        
#     HPLC_points_list.append(HPLC_point)
#     l3_points_list.append(l3_point)
    
# HPLC_points = xr.concat(HPLC_points_list, dim='time')
# l3_points = xr.concat(l3_points_list, dim='time')
# HPLC_points_list = None
# l3_points_list = None

In [None]:
# HPLC_validation = (l3_points - HPLC_points)
# RMSE = np.sqrt(np.mean(np.square(HPLC_validation)))
# RMSE['Total'] = np.sqrt(np.mean(np.square(HPLC_validation.to_array())))
# RMSE_df = RMSE.to_array('PFT').rename('RMSE').to_dataframe()
# RMSE_df.loc['Count'] =  HPLC_validation.to_array().count().values.astype('int')

In [None]:
# RMSE_df.to_csv(os.path.join(output_dir, str(region), f'RMSE/RMSE_L3_HPLC_experiment.nc'))

### L3 vs ACS

In [None]:
# %%time
# l3_points_list = []
# ACS_points_list = []

# for i in tqdm(tqdm(range(len(ACS_subset.time.values))),position=0, leave=True):
#     # get the deteils of the ACS point
#     ACS_point = ACS_subset.isel(time=i,drop=False)
    
#     l3_point = ds_subset.sel(time=ACS_point.time.values ,lat=ACS_point.lat.values,lon=ACS_point.lon.values, method='nearest')
    
#     # get the index of the middle pixel
#     lat_center = (ds_subset.lat==l3_point.lat).argmax().values
#     lon_center = (ds_subset.lon==l3_point.lon).argmax().values
    
#     # Compute the CV
#     s33 = ds_subset.sel(time=ACS_point.time.values ).isel(lat = slice(lat_center-1,lat_center+2),lon= slice(lon_center-1,lon_center+2))
#     cv = abs(s33.std(ddof=1)/s33.mean()).where(s33.count()>=5);
#     s33_mean = s33.mean().where(cv<0.2);
#     # assign the middle pixel lat and lon to the matchup
#     s33_mean = s33_mean.assign_coords({'lat':l3_point.lat})
#     s33_mean = s33_mean.assign_coords({'lon':l3_point.lon})
        
#     ACS_points_list.append(ACS_point)
#     l3_points_list.append(l3_point)
    

# ACS_points = xr.concat(ACS_points_list, dim='time')
# l3_points = xr.concat(l3_points_list, dim='time')
# ACS_points_list = None
# l3_points_list = None

In [None]:
# ACS_validation = (l3_points - ACS_points)
# RMSE = np.sqrt(np.mean(np.square(ACS_validation)))
# RMSE['Total'] = np.sqrt(np.mean(np.square(ACS_validation.to_array())))
# RMSE_df = RMSE.to_array('PFT').rename('RMSE').to_dataframe()
# RMSE_df.loc['Count'] =  ACS_validation.to_array().count().values.astype('int')

In [None]:
# RMSE_df.to_csv(os.path.join(output_dir, str(region), f'RMSE/RMSE_L3_ACS_experiment.nc'))