### Import modules

In [1]:
import os
os.environ["OMP_NUM_THREADS"] = "1"

import xarray as xr
import numpy as np
import pandas as pd
import dask

import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm, LinearSegmentedColormap
import matplotlib.dates as mdates
from matplotlib import gridspec
import cartopy.crs as ccrs

from tqdm import tqdm

from function_tools import *
from pft_params import *

### Dask setup

In [2]:
# SLURM Cluster
# cluster, client = dask_slurm_cluster(queue='smp', cores=32, scale=40)

# Distributed Cluster
# client = dask_distributed_client(n_workers=8, threads_per_worker=None)

In [3]:
# cluster.close()
# client.close()

# Input Data

In [4]:
# PS113: input locations, output location and analysis period.

# Glob pattern of the CMEMS input files; it is handed to xr.open_mfdataset.
file_dir = '/albedo/work/projects/p_phytooptics/emehdipo/PS113/CMEMS/data/*'
# Root output folder; results are written to one sub-folder per region index.
output_dir = '/albedo/work/projects/p_phytooptics/emehdipo/PS113/CMEMS/regions'
# SST file consumed by read_sst (variables analysed_sst and analysis_error).
sst_file = '/albedo/work/projects/p_phytooptics/emehdipo/PS113/GHRSST/METOFFICE-GLO-SST-L4-NRT-OBS-SST-V2_2016_2019.nc'
# Region table indexed by its first CSV column; crop_ds reads the columns
# lat_min, lat_max, lon_min and lon_max. The relative path is resolved against
# the kernel's working directory.
regions = pd.read_csv('regions.csv', index_col=0)

# Time window passed to extract_region to subset the data along `time`.
start_date = np.datetime64('2016-04-25')
end_date = np.datetime64('2019-04-25')

In [None]:
# Open every file matching `file_dir` as one dask-chunked dataset.
# NOTE: extract_region reads this notebook-level `ds_read`, so this cell must
# run before extract_region is called.
ds_read = xr.open_mfdataset(file_dir, chunks='auto')
ds_read

# Regions

## Extract

In [None]:
def crop_ds(ds_read, regions, idx, start_date, end_date):
    """Cut one region's bounding box and time window out of a dataset.

    Parameters
    ----------
    ds_read : dataset exposing ``lat``, ``lon`` and ``time``.
    regions : table with columns ``lat_min``, ``lat_max``, ``lon_min`` and
        ``lon_max``, indexed by region id.
    idx : label of the region row to use.
    start_date, end_date : bounds of the ``time`` slice.

    Returns
    -------
    The subset whose coordinates lie strictly inside the box (the box edges
    themselves are excluded), restricted to the requested time slice.
    """
    box = regions.loc[idx]

    inside_lat = (ds_read.lat > box.lat_min) & (ds_read.lat < box.lat_max)
    inside_lon = (ds_read.lon > box.lon_min) & (ds_read.lon < box.lon_max)

    cropped = ds_read.where(inside_lat & inside_lon, drop=True)
    return cropped.sel(time=slice(start_date, end_date))

In [None]:
def read_sst(sst_file, ds):
    """Load SST, standardise it and resample it onto the grid of ``ds.CHL``.

    Parameters
    ----------
    sst_file : path of the SST file; it must provide ``analysed_sst`` and
        ``analysis_error`` on ``latitude``/``longitude`` coordinates.
    ds : dataset whose ``CHL`` variable defines the target coordinates.

    Returns
    -------
    (sst_norm_interp, sst_unc_interp) : float32 arrays on the ``ds.CHL`` grid.
        The first is ``(analysed_sst - mean) / std``; the second is
        ``analysis_error / std``.

    Notes
    -----
    The mean and the sample standard deviation (``ddof=1``) are taken over the
    whole content of ``sst_file`` (no dimension is given to the reductions),
    not over the region covered by ``ds``.
    """
    # Load eagerly inside a context manager so the file handle is released as
    # soon as the data are in memory; this function is called once per region.
    with xr.open_dataset(sst_file) as sst_on_disk:
        sst = sst_on_disk.load()
    sst = sst.rename({'latitude':'lat','longitude':'lon'})

    sst_mean = sst.analysed_sst.mean()
    sst_std = sst.analysed_sst.std(ddof=1)
    sst_norm = (sst.analysed_sst - sst_mean)/sst_std
    # The error is only rescaled (not centred) so it stays in the same units
    # as the standardised SST.
    sst_unc = sst.analysis_error/sst_std

    # Nearest-neighbour resampling onto every coordinate of ds.CHL.
    sst_norm_interp = sst_norm.interp_like(ds.CHL, method = 'nearest').astype('float32')
    sst_unc_interp = sst_unc.interp_like(ds.CHL, method = 'nearest').astype('float32')

    return sst_norm_interp, sst_unc_interp

In [None]:
def extract_region(
    regions_list, regions ,sst_file ,start_date, end_date, params, output_dir,
    ds_source=None,
):
    """Extract, transform and save the data of each requested region.

    For every region index the function crops the source dataset, takes
    log10 of the PFT variables, converts the relative uncertainties with
    ``unc_transform``, attaches the standardised SST and its uncertainty,
    applies a 2-D mask and writes the merged result to
    ``<output_dir>/<idx>/ds_pft.nc``.

    Parameters
    ----------
    regions_list : iterable of region indices (row labels of ``regions``).
    regions : region table passed on to ``crop_ds``.
    sst_file : path passed on to ``read_sst``.
    start_date, end_date : time window passed on to ``crop_ds``.
    params : mapping with the variable-name lists ``'ALL_data'``, ``'PFT'``
        and ``'UNC'``. ``'ALL_data'`` must include ``flags``, which is used
        to build the mask.
    output_dir : root folder; one sub-folder per region is created if missing.
    ds_source : dataset to crop from. When omitted, the notebook-level
        ``ds_read`` is used, which is what the original code always did.

    Notes
    -----
    The masked relative uncertainties are no longer computed because they
    were never written to disk.
    """
    if ds_source is None:
        # Backward-compatible fallback to the notebook global.
        ds_source = ds_read

    for idx in tqdm(regions_list):
        ds = crop_ds(ds_source, regions, idx, start_date, end_date)

        ## All data, loaded into memory once per region
        ds_all = ds[params['ALL_data']].compute()

        ## PFT (stored as log10 of the original values)
        ds_pft = np.log10(ds_all[params['PFT']])

        ## Relative uncertainty -> SD uncertainty (unc_transform is defined
        ## outside this cell)
        ds_unc_rel = ds_all[params['UNC']]
        ds_unc_sd = unc_transform(ds_unc_rel)

        ## adding SST
        sst_norm_interp, sst_unc_interp = read_sst(sst_file, ds)
        ds_pft['sst'] = sst_norm_interp
        ds_unc_sd['sst_uncertainty'] = sst_unc_interp

        ## mask: 0 where flags == 1 AND the SST is NaN at the first time step,
        ## 1 everywhere else.
        # NOTE(review): with `&` a pixel is excluded only when both conditions
        # hold; confirm that `|` was not intended.
        mask = xr.where(
            (ds_all.flags.isel(time=0) == 1)&
            (np.isnan(sst_norm_interp.isel(time=0)))
            , 0, 1)

        ds_pft = ds_pft.where(mask==1)
        ds_unc_sd = ds_unc_sd.where(mask==1)

        ds_pft['mask'] = mask
        ds_unc_sd['mask'] = mask

        ## Save: to_netcdf does not create missing folders.
        region_dir = os.path.join(output_dir, str(idx))
        os.makedirs(region_dir, exist_ok=True)

        ds_merged = xr.merge([ds_pft, ds_unc_sd])
        ds_merged.to_netcdf(os.path.join(region_dir, 'ds_pft.nc'))

In [None]:
# Run the extraction for region 10 only; widen the range to process more regions.
extract_region(
    regions_list=range(10, 11),
    regions=regions,
    sst_file=sst_file,
    start_date=start_date,
    end_date=end_date,
    params=params,
    output_dir=output_dir,
)

## Fill and Add cloud

In [None]:
def remove_missing_dates(perc, ds, params):
    """Drop the dates whose PFT data coverage is not above ``perc`` percent.

    Coverage of a date is the number of finite values over all PFT variables
    and pixels, divided by ``ds.mask.sum()`` times the number of PFT
    variables, expressed in percent. Dates between
    ``params['expedition_start_date']`` and ``params['expedition_end_date']``
    (both inclusive) are always kept.

    Parameters
    ----------
    perc : coverage threshold in percent; a date is kept when its coverage is
        strictly greater.
    ds : dataset holding the PFT variables and ``mask``.
    params : mapping with ``'PFT'`` and the two expedition dates.

    Returns
    -------
    ``ds`` restricted to the kept dates.
    """
    pft_names = params['PFT']

    finite_count = np.isfinite(ds[pft_names].to_array('pft')).sum(dim=['lon','lat','pft'])
    max_count = ds.mask.sum().values * len(pft_names)
    coverage = finite_count / max_count * 100

    keep = coverage > perc

    # Expedition dates are kept whatever their coverage.
    during_expedition = (
        (keep.time >= params['expedition_start_date'])
        & (keep.time <= params['expedition_end_date'])
    )
    keep[during_expedition] = True

    return ds.sel(time=keep)

In [None]:
def NaN_fill(ds, fill_value=-999):
    """Return a copy of ``ds`` with NaN replaced by ``fill_value``.

    Every data variable is filled and tagged with a ``_FillValue`` attribute.
    The ``mask`` variable is then replaced by the original, unfilled
    ``ds.mask``.
    """
    ds_filled = xr.Dataset()

    for name in ds.data_vars:
        filled = ds[name].fillna(fill_value)
        filled.attrs['_FillValue'] = fill_value
        ds_filled[name] = filled

    # The mask is stored as in the input, without fill value or attribute.
    ds_filled['mask'] = ds.mask
    return ds_filled

In [None]:
def cloud_plot(idx, ds, ds_addcloud, cloud_mask, clouded_time, labels, extension):
    """Save one five-panel figure per date that received an artificial cloud.

    Panel layout: top-left DIATO observation, centre (both rows) the cloud
    mask, top-right DIATO with the added cloud, bottom-left CHL observation,
    bottom-right CHL with the added cloud. The fields are plotted as
    ``10**value`` on a logarithmic colour scale.

    Parameters
    ----------
    idx : region index, used in the title and in the output folder name.
    ds : dataset before cloud insertion (provides ``DIATO`` and ``CHL``).
    ds_addcloud : dataset after cloud insertion.
    cloud_mask : cloud mask with a ``time`` dimension; one figure is produced
        per entry.
    clouded_time : dates that received a cloud; entry ``t`` is paired with
        entry ``t`` of ``cloud_mask``.
    labels : candidate colour-bar values. The ones lying strictly between the
        plotted quantiles give the colour limits of the "+ cloud" panels.
    extension : file-name prefix (the notebook uses 'dev' and 'train').

    Notes
    -----
    Files are written to ``fig/regions/<idx>/``, which is created if missing.
    The observation panels use the quantile limits, whereas the "+ cloud"
    panels use the first and last label inside those limits, so the two colour
    scales can differ slightly. If no label falls inside the limits, the
    quantile limits are used instead of raising ``IndexError``.
    """
    n_frames = len(cloud_mask.time)
    if n_frames == 0:
        # Nothing to plot; also avoids quantiles over an empty selection.
        return

    # savefig does not create missing folders.
    out_dir = f'fig/regions/{str(idx)}'
    os.makedirs(out_dir, exist_ok=True)

    # Everything below is identical for all frames, so compute it once.
    diato_obs = ds.DIATO.sel(time=clouded_time)
    diato_cloud = ds_addcloud.DIATO.sel(time=clouded_time)
    chl_obs = ds.CHL.sel(time=clouded_time)
    chl_cloud = ds_addcloud.CHL.sel(time=clouded_time)

    diato_vmin = 10**diato_obs.quantile(0.01).values
    diato_vmax = 10**diato_obs.quantile(0.99).values
    diato_label = labels[(labels>diato_vmin)&(labels<diato_vmax)]
    if len(diato_label):
        diato_lo, diato_hi = diato_label[0], diato_label[-1]
    else:
        diato_lo, diato_hi = diato_vmin, diato_vmax

    chl_vmin = 10**chl_obs.quantile(0.02).values
    chl_vmax = 10**chl_obs.quantile(0.98).values
    chl_label = labels[(labels>chl_vmin)&(labels<chl_vmax)]
    if len(chl_label):
        chl_lo, chl_hi = chl_label[0], chl_label[-1]
    else:
        chl_lo, chl_hi = chl_vmin, chl_vmax

    for t in tqdm(range(n_frames)):
        # Create a figure and a gridspec
        fig = plt.figure(figsize=(17, 8), tight_layout=True)
        gs = gridspec.GridSpec(2, 3)
        fig.suptitle(f'Observation and added cloud - Cluster {str(idx)}')

        # Create subplots; the cloud mask spans both rows of the middle column.
        ax = [
            fig.add_subplot(gs[0, 0]),
            fig.add_subplot(gs[:2,1]),
            fig.add_subplot(gs[0, 2]),
            fig.add_subplot(gs[1, 0]),
            fig.add_subplot(gs[1, 2]),
        ]

        # Dates
        data_date = diato_obs.isel(time=t).time.values.astype('datetime64[D]')
        cloud_date = cloud_mask.isel(time=t).time.values.astype('datetime64[D]')

        ## DIATO
        (10**diato_obs.isel(time=t)).plot(ax=ax[0],norm=LogNorm(vmin=diato_vmin, vmax=diato_vmax), cmap='viridis')
        cloud_mask.isel(time=t).plot(ax=ax[1], cmap='Blues_r', add_colorbar=False)
        (10**diato_cloud.isel(time=t)).plot(ax=ax[2],norm=LogNorm(vmin=diato_lo, vmax=diato_hi),cmap='viridis')

        for axis in ax:
            axis.set_aspect('equal')

        ax[0].set_title(f'Diatom Obs. [{data_date}]')
        ax[1].set_title(f'Cloud mask with modification derived from [{cloud_date}]')
        ax[2].set_title(f'Diatom Obs. + cloud [{cloud_date}]')

        ## CHL
        (10**chl_obs.isel(time=t)).plot(ax=ax[3],norm=LogNorm(vmin=chl_vmin, vmax=chl_vmax),cmap='viridis')
        (10**chl_cloud.isel(time=t)).plot(ax=ax[4],norm=LogNorm(vmin=chl_lo, vmax=chl_hi),cmap='viridis')

        ax[3].set_title(f'TChla Obs. [{data_date}]')
        ax[4].set_title(f'TChla Obs. + cloud [{cloud_date}]')

        fig.savefig(f'{out_dir}/{extension}_addedcloud_{data_date}.jpg',dpi=600, bbox_inches='tight')
        # Close this figure explicitly so open figures do not pile up.
        plt.close(fig)

In [None]:
def Addcloud_monthly(
    ds,
    params,
    cv_number_monthly=1,
    threshold = 10,
    keep_expedition_dates = True,
):
    """Hide data on well-covered dates using gap patterns of poorly covered dates.

    For each calendar month, every date gets a coverage score: the percentage
    of ``ds.mask.sum()`` pixels on which all PFT variables are present. Dates
    at or below 5 % are ignored. Then:

    * "cloud" dates are the ``cv_number_monthly`` lowest-coverage dates among
      those strictly above the ``threshold``-th percentile of the month;
    * "clouded" dates are the ``cv_number_monthly`` highest-coverage dates
      among those strictly below the monthly maximum;
    * the pixels missing on a cloud date but complete on the paired clouded
      date form the cloud mask, and are set to NaN on the clouded date.

    Parameters
    ----------
    ds : dataset with the PFT variables (including ``CHL``), ``mask`` and a
        ``time`` coordinate.
    params : mapping with ``'PFT'`` and, when ``keep_expedition_dates`` is
        set, ``'expedition_start_date'``, ``'expedition_end_date'`` and
        ``'delta'``.
    cv_number_monthly : number of dates selected per month.
    threshold : percentile used to discard the emptiest cloud candidates.
    keep_expedition_dates : when true, dates within ``delta`` of the
        expedition period never receive a cloud.

    Returns
    -------
    ds_addcloud : copy of ``ds`` with the clouds applied (``mask`` re-attached).
    cloud_time : dates whose gaps were used as clouds.
    clouded_time : dates that received a cloud, paired by position.
    cloud_mask : boolean mask of removed pixels (based on all PFT variables),
        with ``ds.mask`` attached as the ``sea_mask`` coordinate.
    cloud_mask_chl : same, but using the gaps of ``CHL`` only.
    ds_cloudindex : dataset with ``clouds_index`` of shape (2, n): row 0 is
        the 1-based position of each removed value among the stacked
        (pft, lon, lat) sea points, row 1 the 1-based position of the clouded
        date in ``ds.time``.

    Notes
    -----
    Cloud and clouded dates are paired by position, so each month must yield
    the same number of both.
    """
    ## create the lists
    # times
    cloud_time = []
    clouded_time = []
    # masks
    cloud_mask = []
    cloud_mask_chl = []

    # Work on a deep copy; the mask is re-attached at the end so that
    # Dataset.where does not touch it.
    ds_addcloud = ds.copy(deep=True).drop_vars('mask')

    # True where at least one PFT variable is missing
    any_missing = np.isnan(ds[params['PFT']]).to_array(dim='var').any(dim='var')

    # Loop-invariant denominator of the coverage percentage
    n_sea_pixels = ds.mask.sum().values

    for month in range(1,13):
        ## Separate Monthly Data
        monthly_data = (~any_missing).where(ds.time.dt.month==month, drop=True)

        monthly_percent = monthly_data.sum(dim=['lon','lat']) / n_sea_pixels * 100
        monthly_percent = monthly_percent.where(monthly_percent > 5, drop=True) # ignore dates with <= 5% data

        ## Cloud dates: lowest coverage above the `threshold` percentile
        cloud_candidates = monthly_percent[monthly_percent > np.nanpercentile(monthly_percent,threshold)]
        monthly_lowest_time = (cloud_candidates.sortby(cloud_candidates)[:cv_number_monthly]).time.values
        cloud_time.append(monthly_lowest_time)

        # not removing data from the expedition duration
        if keep_expedition_dates:
            monthly_percent = monthly_percent.where(
                (monthly_percent.time<params['expedition_start_date'] - params['delta']) |
                (monthly_percent.time>params['expedition_end_date']   + params['delta']) ,
                drop=True,
            )

        ## Clouded dates: highest coverage strictly below the monthly maximum
        ## (the 100th percentile), so the best-covered date itself is excluded.
        clouded_candidates = monthly_percent[monthly_percent < np.nanpercentile(monthly_percent,100)]
        monthly_highest_time = (clouded_candidates.sortby(clouded_candidates)[-cv_number_monthly:]).time.values
        clouded_time.append(monthly_highest_time)

        # masks
        mask = any_missing.sel(time=monthly_lowest_time) ## places where any of the data is missing
        mask_chl = np.isnan(ds.CHL.sel(time=monthly_lowest_time))

        base = (~any_missing).sel(time=monthly_highest_time) ## places where all the data are present

        # `.values` combines the two by position; the results keep the time
        # coordinate of the cloud dates.
        monthly_mask = mask & base.values
        monthly_mask_chl = mask_chl & base.values

        cloud_mask.append(monthly_mask)
        cloud_mask_chl.append(monthly_mask_chl)

        ## add cloud
        ds_addcloud.loc[{'time':monthly_highest_time}] = ds_addcloud.sel(time=monthly_highest_time).where(~monthly_mask.values)

    cloud_time = np.concatenate(cloud_time)
    clouded_time = np.concatenate(clouded_time)

    cloud_mask = xr.concat(cloud_mask, dim='time')
    cloud_mask_chl = xr.concat(cloud_mask_chl, dim='time')
    ds_addcloud['mask'] = ds.mask
    cloud_mask['sea_mask'] = ds.mask

    # extend the cloud_mask to the number of pfts
    cloud_mask_pfts = xr.concat([cloud_mask]*len(params['PFT']), dim='pft')

    # One column per (pft, lon, lat) sea point
    cloud_mask_stack = cloud_mask_pfts.stack(points=['pft','lon','lat'])
    cloud_mask_stack = cloud_mask_stack.where(cloud_mask_stack.sea_mask==1, drop=True)

    img_num = []
    pixel_num = []
    for t_idx, t in enumerate(cloud_mask_stack.time.values):
        # flatnonzero always returns a 1-D array. argwhere(...).squeeze()
        # collapsed to 0-d when exactly one point was clouded, which broke
        # len() and np.concatenate.
        pixel_num_time = np.flatnonzero(cloud_mask_stack.sel(time=t).values)
        pixel_num.append(pixel_num_time + 1)

        # 1-based position of the clouded date in ds.time, repeated per point.
        # np.full keeps an integer dtype even when no point is clouded.
        image_index = (ds.time==clouded_time[t_idx]).argmax(dim='time').values + 1
        img_num.append(np.full(pixel_num_time.size, image_index))

    pixel_num = np.concatenate(pixel_num)
    img_num = np.concatenate(img_num)

    ds_cloudindex = xr.Dataset()
    ds_cloudindex['clouds_index'] = xr.DataArray(data=np.vstack([pixel_num, img_num]),dims=['index','nbpoints'])

    return ds_addcloud, cloud_time, clouded_time , cloud_mask, cloud_mask_chl, ds_cloudindex

In [None]:
# Per-region report tables, concatenated after the loop.
dev_nums, dev_percs = [], []
train_nums, train_percs = [], []

# SST layers never receive artificial clouds: they are dropped before
# Addcloud_monthly and re-attached from the original dataset afterwards.
sst_vars = ['sst', 'sst_uncertainty']
pft_vars = params['PFT']

# Replace the range by a short list (e.g. [1]) to process a single region.
for idx in range(1, 11):
    print('Cluster', idx)
    region_dir = os.path.join(output_dir, str(idx))

    ## Read
    ds = xr.open_dataset(os.path.join(region_dir, 'ds_pft.nc')).compute()
    ds = remove_missing_dates(2, ds, params)
    print(f'number of dates region {idx}:',ds.time.size)

    # ---------------------------------------------------------------------------------
    ## DEV: clouds are added to the observations
    (
        ds_dev,
        cloud_time_dev,
        clouded_time_dev,
        cloud_mask_dev,
        cloud_mask_chl_dev,
        ds_cloudindex_dev,
    ) = Addcloud_monthly(
        ds.drop(sst_vars),
        params=params,
        cv_number_monthly=1,
        threshold=40,
        keep_expedition_dates=True,
    )
    for sst_name in sst_vars:
        ds_dev[sst_name] = ds[sst_name].astype('float32')

    ## Report: number and percentage of values removed per PFT variable
    count_obs = ds.count()
    count_dev = ds_dev.count()

    dev_num = (count_obs - count_dev)[pft_vars].to_array(dim='PFT', name=idx).to_dataframe()
    dev_nums.append(dev_num)

    dev_perc = ((1 - (count_dev / count_obs)) * 100)[pft_vars].to_array(dim='PFT', name=idx).to_dataframe()
    dev_percs.append(dev_perc)

    # fill NaN values
    ds_dev_filled = NaN_fill(ds_dev, fill_value=-999)

    # ---------------------------------------------------------------------------------
    ## Train: a second layer of clouds is added on top of the DEV dataset
    (
        ds_train,
        cloud_time_train,
        clouded_time_train,
        cloud_mask_train,
        cloud_mask_chl_train,
        ds_cloudindex_train,
    ) = Addcloud_monthly(
        ds_dev.drop(sst_vars),
        params=params,
        cv_number_monthly=1,
        threshold=10,
        keep_expedition_dates=True,
    )
    for sst_name in sst_vars:
        ds_train[sst_name] = ds[sst_name].astype('float32')

    count_train = ds_train.count()

    train_num = (count_dev - count_train)[pft_vars].to_array(dim='PFT', name=idx).to_dataframe()
    train_nums.append(train_num)

    train_perc = ((1 - (count_train / count_dev)) * 100)[pft_vars].to_array(dim='PFT', name=idx).to_dataframe()
    train_percs.append(train_perc)

    # ---------------------------------------------------------------------------------
    ## Save DEV
    cloud_dates_dev = pd.DataFrame(
        {'cloud_date':cloud_time_dev, 'clouded_date':clouded_time_dev}
    )
    cloud_dates_dev.to_csv(os.path.join(region_dir, 'cloud_date_dev.csv'), index=False)

    for nc_name, nc_data in (
        ('ds_pft_dev.nc', ds_dev),
        ('ds_pft_dev_filled.nc', ds_dev_filled),
        ('cloud_mask_dev.nc', cloud_mask_dev),
        ('cloud_mask_chl_dev.nc', cloud_mask_chl_dev),
        ('cloud_index_dev.nc', ds_cloudindex_dev),
    ):
        nc_data.to_netcdf(os.path.join(region_dir, nc_name))

    dev_num.to_csv(os.path.join(region_dir, 'dev_num.csv'))
    dev_perc.to_csv(os.path.join(region_dir, 'dev_perc.csv'))

    # ---------------------------------------------------------------------------------
    ## Save Train
    cloud_dates_train = pd.DataFrame(
        {'cloud_date':cloud_time_train, 'clouded_date':clouded_time_train}
    )
    cloud_dates_train.to_csv(os.path.join(region_dir, 'cloud_date_train.csv'), index=False)

    for nc_name, nc_data in (
        ('ds_pft_train.nc', ds_train),
        ('cloud_mask_train.nc', cloud_mask_train),
        ('cloud_mask_chl_train.nc', cloud_mask_chl_train),
        ('cloud_index_train.nc', ds_cloudindex_train),
    ):
        nc_data.to_netcdf(os.path.join(region_dir, nc_name))

    train_num.to_csv(os.path.join(region_dir, 'train_num.csv'))
    train_perc.to_csv(os.path.join(region_dir, 'train_perc.csv'))

    # ---------------------------------------------------------------------------------
    # Plot
    cloud_plot(idx, ds, ds_dev, cloud_mask_dev, clouded_time_dev, params['plot_labels'], extension = 'dev')
    cloud_plot(idx, ds, ds_train, cloud_mask_train,clouded_time_train, params['plot_labels'], extension = 'train')
    # ---------------------------------------------------------------------------------


# One column per region in each summary table
dev_nums = pd.concat(dev_nums, axis=1)
dev_percs = pd.concat(dev_percs, axis=1)
train_nums = pd.concat(train_nums, axis=1)
train_percs = pd.concat(train_percs, axis=1)

for csv_name, table in (
    ('dev_nums.csv', dev_nums),
    ('dev_percs.csv', dev_percs),
    ('train_nums.csv', train_nums),
    ('train_percs.csv', train_percs),
):
    table.to_csv(os.path.join(output_dir, csv_name))

## Region Observation Uncertainty

In [5]:
# Only region 9 is inspected here; iterate over range(len(regions)) instead
# to visit every region.
for idx in [9]:
    print('Cluster', idx)

    train_path = os.path.join(output_dir, str(idx), 'ds_pft_train.nc')
    ds = xr.open_dataset(train_path).compute()

    # TODO: an unfinished, disabled draft built a per-variable table of the
    # mean, median and standard deviation of `ds` (without `mask`) and wrote
    # it to <output_dir>/<idx>/unc_sd_description.csv.

Cluster 9


In [16]:
# Mean of each uncertainty variable, reduced over all dimensions.
# Relies on `ds` left by the loading cell above (train dataset of the last
# region it processed).
unc_sd_mean  = ds[params['UNC']].mean()
unc_sd_mean

In [15]:
# Median of each uncertainty variable, reduced over all dimensions.
# Relies on `ds` left by the loading cell above.
unc_sd_median = ds[params['UNC']].median()
unc_sd_median

In [24]:
# Show the per-variable medians as a pandas object for a compact table.
unc_sd_median.to_pandas()

CHL_uncertainty       0.118000
DIATO_uncertainty     0.381079
DINO_uncertainty      0.327400
HAPTO_uncertainty     0.342107
GREEN_uncertainty     0.254379
PROKAR_uncertainty    0.299965
dtype: float64

In [32]:
# Median of the SST uncertainty over all dimensions of the loaded dataset.
ds.sst_uncertainty.median()