In [None]:
'''
Main program
'''
from pathlib import Path
import os, shutil
import zipfile
import tarfile
import glob

import matplotlib
import matplotlib.pyplot as plt

import numpy as np
import xarray as xr
import logging

from s2driver import driver_S2_SAFE as S2
import grs
from grs import product, acutils, utils, cams_product, l2a_product
from grs.fortran.grs import main_algo as grs_solver

# Short alias for os.path.join.
opj = os.path.join

# The notebook reports progress through module-level ``logging.info`` calls.
# Those go to the root logger, whose default level is WARNING, so without
# explicit configuration every message is silently dropped.
# ``basicConfig`` is a no-op when a handler is already installed (e.g. by an
# imported package), hence the explicit level on the root logger as well.
logging.basicConfig(level=logging.INFO)
logging.getLogger().setLevel(logging.INFO)

# Last expression of the cell: displays the version of the imported grs package.
grs.__version__

In [None]:

# Root folder of the Sentinel-2 data tree. It can be overridden with the
# GRS_S2_ROOT environment variable so the notebook is not tied to a single
# machine; when the variable is unset the original local paths are used.
s2_root = os.environ.get('GRS_S2_ROOT', '/media/harmel/vol1/Dropbox/satellite/S2')

# Input L1C image (.SAFE folder) handed to the S2 driver.
file = os.path.join(s2_root, 'L1C', 'S2B_MSIL1C_20220731T103629_N0400_R008_T31TFJ_20220731T124834.SAFE')
# NetCDF file the L2A product is exported to at the end of the notebook.
ofile = os.path.join(s2_root, 'L2A', 'S2B_MSIL2Agrs_20220731T103629_N0400_R008_T31TFJ_20220731T124834.nc')
# Folder passed to cams_product as the location of the CAMS data.
cams_dir = os.path.join(s2_root, 'cnes', 'CAMS')

# Resolution requested from the S2 driver (presumably metres — TODO confirm).
resolution = 60
# When True the water masking step is skipped and all pixels are processed.
allpixels = False

In [None]:
# Indices of the bands to load: 0 to 12 inclusive.
bandIds = range(13)
logging.info('Open raw image and compute angle parameters')
# Build the S2 driver image object for the selected bands at the requested
# resolution, then load it. The loaded product is read from ``l1c.prod`` below.
l1c = S2.s2image(file, band_idx=bandIds, resolution=resolution)
l1c.load_product()

In [None]:
logging.info('pass raw image as grs product object')
# Wrap the loaded L1C product in the grs ``product`` class. ``prod`` is the
# object every later cell reads its rasters, thresholds and sensor data from.
prod = product(l1c.prod)

In [None]:
logging.info('get CAMS auxilliary data')
# CAMS product associated with this image; ``cams_dir`` is passed as its data
# directory (presumably where the CAMS files are looked up — TODO confirm).
cams = cams_product(prod, dir=cams_dir)

In [None]:
logging.info('loading look-up tables')
# Two LUT objects built for the same band names: one for the fine table, one
# for the coarse table.
lutf, lutc = acutils.lut(prod.band_names), acutils.lut(prod.band_names)
# Fill each object from its own source (fine first, then coarse), both with
# the band indices of the sensor.
for lut_obj, lut_source in ((lutf, prod.lutfine), (lutc, prod.lutcoarse)):
    lut_obj.load_lut(lut_source, prod.sensordata.indband)

In [None]:
logging.info('compute gaseous transmittance from cams data')
gas_trans = acutils.gaseous_transmittance(prod, cams)
Tg_raster = gas_trans.get_gaseous_transmittance()

# The correction overwrites prod.raster['bands'] in place, so running this cell
# a second time would divide by the transmittance twice. The flag written on
# the bands after the first pass is used to make the cell idempotent.
if prod.raster.bands.attrs.get('gas_absorption_correction') is True:
    logging.info('gaseous absorption already corrected, skipping')
else:
    logging.info('correct for gaseous absorption')
    prod.raster['bands'] = prod.raster.bands / Tg_raster
    # Set after the assignment: the division produces a new array, so the
    # flag has to be attached to the variable now stored in the dataset.
    prod.raster.bands.attrs['gas_absorption_correction'] = True

In [None]:
logging.info('compute spectral index (e.g., NDWI)')


def _norm_diff_attrs(wl_a, wl_b):
    """Attributes describing a normalized difference index between two bands."""
    text = 'Normalized difference spectral index between bands at ' + str(wl_a) + ' and ' + str(wl_b) + ' nm'
    return {'description': text, 'units': '-'}


# Single-band slices used for the indices and the masking thresholds.
_bands = prod.raster.bands
green = _bands.sel(wl=prod.b565)
nir = _bands.sel(wl=prod.b865)
swir = _bands.sel(wl=prod.b1600)
b2200 = _bands.sel(wl=prod.b2200)

# Normalized differences: green/NIR and green/SWIR.
ndwi = (green - nir) / (green + nir)
ndwi_swir = (green - swir) / (green + swir)

# Store both indices in the product raster together with their metadata.
prod.raster['ndwi'] = ndwi
prod.raster.ndwi.attrs = _norm_diff_attrs(prod.b565, prod.b865)
prod.raster['ndwi_swir'] = ndwi_swir
prod.raster.ndwi_swir.attrs = _norm_diff_attrs(prod.b565, prod.b1600)

# Start from the full band stack; unless all pixels are requested, keep only
# pixels passing the three successive threshold tests (others become NaN).
masked_raster = prod.raster.bands
if not allpixels:
    logging.info('apply water masking')
    masked_raster = masked_raster.where(ndwi > prod.ndwi_threshold)
    masked_raster = masked_raster.where(b2200 < prod.sunglint_threshold)
    masked_raster = masked_raster.where(ndwi_swir > prod.green_swir_index_threshold)

In [None]:
logging.info('round angles for speed up lut interpolation')
def rounding(xarr, resol=1):
    """Return the sorted unique values of ``xarr`` after rounding, NaN excluded.

    Parameters
    ----------
    xarr : array-like exposing a ``round`` method
        Values to round (the notebook passes the angle rasters).
    resol : int, optional
        Number of decimals passed to ``round`` (default 1).

    Returns
    -------
    numpy.ndarray
        One-dimensional array of the distinct rounded values, without NaN.
    """
    rounded = xarr.round(resol)
    distinct = np.unique(rounded)
    is_finite_entry = np.logical_not(np.isnan(distinct))
    return distinct[is_finite_entry]

# Distinct angle values found in the image, used further down as the target
# grids for the LUT interpolation.
# Sun zenith angle, rounded to one decimal.
sza_ = rounding(prod.raster.sza, 1)
# Azimuth expressed as (180 - raa) modulo 360, rounded to zero decimals.
# NOTE(review): presumably this matches the azimuth convention of the LUT — confirm.
azi_ = rounding((180 - prod.raster.raa) % 360, 0)
# Viewing zenith angle, rounded to one decimal.
vza_ = rounding(prod.raster.vza, 1)



## Set final parameters for grs processing


In [None]:
logging.info('set final parameters for grs processing')
# Values read straight from the product object.
wl_process = prod.wl_process
eps_sunglint, rot = prod.sensordata.rg, prod.sensordata.rot
rrs = prod.rrs
Nx, Ny = prod.width, prod.height

logging.info('slice raster for desired wavelengths')
raster = masked_raster.sel(wl=wl_process)
band_rad = raster.values
# Angle rasters as plain arrays, taken from the product restricted to the
# processed wavelengths.
_angles = prod.raster.sel(wl=wl_process)
vza, sza, razi = _angles.vza.values, _angles.sza.values, _angles.raa.values
solar_irr = prod.solar_irradiance.sel(wl=wl_process).values

logging.info('get/set aerosol parameters')
aotlut = np.array(lutf.aot, dtype=prod.type)


def _interp_on_angles(refl):
    """Interpolate a LUT array onto the rounded vza, azi and sza grids, one axis at a time."""
    return refl.interp(vza=vza_).interp(azi=azi_).interp(sza=sza_)


fine_refl = _interp_on_angles(lutf.refl)
coarse_refl = _interp_on_angles(lutc.refl)
lut_shape = fine_refl.shape
fine_Cext, coarse_Cext = lutf.Cext, lutc.Cext

# CAMS optical thickness: first interpolated spectrally to the processed
# wavelengths, then spatially onto the image grid.
_image_grid = dict(x=raster.x, y=raster.y)
aot_tot_cams_res = cams.cams_aod.interp(wavelength=wl_process)
aot_sca_cams_res = aot_tot_cams_res * cams.cams_ssa.interp(wavelength=wl_process)
aot_tot = aot_tot_cams_res.interp(**_image_grid)
aot_sca = aot_sca_cams_res.interp(**_image_grid)
aot550guess = cams.raster.aod550.interp(**_image_grid)

# Full-image arrays: constant 0.5 for fcoef, all-zero mask.
_image_shape = (prod.height, prod.width)
fcoef = np.full(_image_shape, 0.5)


# TODO implement pre-masking, now set to zero
maskpixels = np.zeros(_image_shape)

logging.info('get pressure full raster')
# Surface pressure on the image grid, scaled by 1e-2 (presumably Pa to hPa —
# TODO confirm) and normalised by the reference pressure of the product.
pressure_corr = cams.raster.sp.interp(**_image_grid) * 1e-2 / prod.pressure_ref


## Run grs processing

In [None]:
logging.info('run grs process')
# Call the compiled solver. All arguments are positional, so their order has to
# match the signature of the Fortran routine:
#   - Nx, Ny and the dimensions of the interpolated LUT (unpacked lut_shape)
#   - LUT axes: aotlut and the rounded sza_/azi_/vza_ grids
#   - LUT contents: fine/coarse reflectance and Cext
#   - per-pixel inputs: vza, sza, razi, band_rad, maskpixels
#   - wl_process, pressure_corr, eps_sunglint, solar_irr, rot
#   - CAMS fields aot_tot, aot_sca, aot550guess, then fcoef and rrs
# NOTE(review): the Fortran signature is not visible from this notebook —
# confirm the order against the grs source before reordering anything.
# The result ``p`` is unpacked into four arrays when the product is built.
p = grs_solver.grs.main_algo(Nx, Ny, *lut_shape,
                             aotlut, sza_, azi_, vza_,
                             fine_refl, coarse_refl, fine_Cext, coarse_Cext,
                             vza, sza, razi, band_rad, maskpixels,
                             wl_process, pressure_corr, eps_sunglint, solar_irr, rot,
                             aot_tot, aot_sca, aot550guess, fcoef, rrs)

## Construct the L2A product

In [None]:
logging.info('construct final product')
# The solver returns four arrays: two spectral outputs and two per-pixel maps.
rcorr, rcorrg, aot550pix, brdfpix = p

# Spectral outputs reuse every coordinate of the processed raster; the
# per-pixel maps only carry the y/x coordinates.
_pixel_coords = {'y': raster.y, 'x': raster.x}
Rrs = xr.DataArray(rcorr, coords=raster.coords, name='Rrs')
Rrs_g = xr.DataArray(rcorrg, coords=raster.coords, name='Rrs_g')
aot550 = xr.DataArray(aot550pix, coords=_pixel_coords, name='aot550')
brdfg = xr.DataArray(brdfpix, coords=_pixel_coords, name='BRDFg')

# Merge into one dataset and drop the 'pressure' variable before handing the
# result to the L2A product builder.
_merged = xr.merge([Rrs, Rrs_g, aot550, brdfg])
l2_prod = _merged.drop_vars('pressure')
l2a = l2a_product(prod, l2_prod, cams, gas_trans)

In [None]:
# Display the input product raster (bands plus the indices added above).
prod.raster

In [None]:
# Display the L2 dataset held by the L2A product object.
l2a.l2_prod

## Export the L2A product to NetCDF

In [None]:
# Write the L2A product to the output path defined in the configuration cell.
l2a.to_netcdf(ofile)


## Plot and interact


In [None]:
# Name of the L2A variable to display.
variable = 'Rrs'
# NOTE(review): this rebinds ``raster``, which earlier cells use for the masked
# L1C slice (coordinates of the final product). Re-running those cells after
# this one would pick up the L2A variable instead.
raster = l2a.l2_prod[variable]
# Upper bound of the colour scale and of the range slider.
vmax = 0.03
# Colormap of the image. 'RdBu_r' was previously assigned here and immediately
# overwritten; it is kept only as a documented alternative.
cmap = 'Spectral_r'

In [None]:
from holoviews import streams
import holoviews as hv
import panel as pn
import param
import numpy as np
import xarray as xr
hv.extension('bokeh')
from holoviews import opts


# Default plot options, registered per element type.
opts.defaults(
    opts.GridSpace(shared_xaxis=True, shared_yaxis=True),
    opts.Image(cmap='binary_r', width=800, height=700),
    opts.Labels(text_color='white', text_font_size='8pt', text_align='left', text_baseline='bottom'),
    opts.Path(color='white'),
    opts.Spread(width=900),
    opts.Overlay(show_legend=True))
# set the parameter for spectra extraction
# NOTE(review): the bokeh extension is already loaded right after the imports
# of this cell; this second call looks redundant — confirm before removing.
hv.extension('bokeh')
pn.extension()



# Dimension the ROI statistics are grouped by (one value per wavelength).
third_dim = 'wl'

# Wavelength coordinate values of the displayed variable, and their count.
wl= raster.wl.data
Nwl = len(wl)
# HoloViews dataset wrapping the raster selected for display.
ds = hv.Dataset(raster.persist())
# Image over x/y for the remaining dimension(s), with colour bar, colour limits
# 0..vmax and an adjoined histogram.
im= ds.to(hv.Image, ['x', 'y'], dynamic=True).opts(cmap=cmap ,colorbar=True,clim=(0.00,vmax)).hist(bin_range=(0,0.2)) 

# Empty polygon layer used as the source of the box-drawing stream; the boxes
# drawn on it are the regions of interest handed to the ROI callbacks.
polys = hv.Polygons([])
box_stream = hv.streams.BoxEdit(source=polys)
# NOTE(review): these two lists are not used anywhere in the visible cells.
dmap, dmap_std=[],[]

def roi_curves(data,ds=ds):
    """Return the mean spectrum of every drawn box as an overlay of curves.

    Parameters
    ----------
    data : dict or None
        Box corners supplied by the box-drawing stream, read through the
        keys 'x0', 'x1', 'y0' and 'y1'. May be empty or None before any box
        has been drawn.
    ds : hv.Dataset, optional
        Dataset the boxes are cut from; bound to the notebook-level ``ds``
        when the function is defined.

    Returns
    -------
    hv.NdOverlay
        One curve per box, keyed by the box index. A single empty curve is
        returned while no box exists.
    """
    if not data or not any(len(d) for d in data.values()):
        return hv.NdOverlay({0: hv.Curve([],'Wavelength (nm)', variable)})

    curves = {}
    boxes = zip(data['x0'], data['x1'], data['y0'], data['y1'])
    for i, (x0, x1, y0, y1) in enumerate(boxes):
        selection = ds.select(x=(x0, x1), y=(y0, y1))
        # Only the mean is plotted here; the standard deviation is handled by
        # roi_spreads, so it is not computed in this callback.
        mean = selection.aggregate(third_dim, np.mean).data
        curves[i] = hv.Curve((mean.wl, mean[variable]), 'Wavelength (nm)', variable)

    return hv.NdOverlay(curves)


# Near-duplicate of roi_curves on purpose: holoviews does not like mixing
# Curve and Spread elements for the same stream.
def roi_spreads(data,ds=ds):
    """Return the mean +/- standard deviation envelope of every drawn box.

    ``data`` holds the box corners from the box-drawing stream (keys 'x0',
    'x1', 'y0', 'y1'); ``ds`` is the dataset the boxes are cut from. The
    result is an overlay with one spread per box, keyed by the box index,
    or a single empty curve while no box exists.
    """
    nothing_drawn = not data or not any(len(d) for d in data.values())
    if nothing_drawn:
        return hv.NdOverlay({0: hv.Curve([],'Wavelength (nm)', variable)})

    spreads = {}
    corners = zip(data['x0'], data['x1'], data['y0'], data['y1'])
    for idx, (xa, xb, ya, yb) in enumerate(corners):
        roi = ds.select(x=(xa, xb), y=(ya, yb))
        roi_mean = roi.aggregate(third_dim, np.mean).data
        roi_std = roi.aggregate(third_dim, np.std).data
        spreads[idx] = hv.Spread(
            (roi_mean.wl, roi_mean[variable], roi_std[variable]),
            fill_alpha=0.3,
        )

    return hv.NdOverlay(spreads)

# ROI overlays driven by the box-drawing stream.
mean = hv.DynamicMap(roi_curves, streams=[box_stream])
std = hv.DynamicMap(roi_spreads, streams=[box_stream])
# One vertical line per wavelength, keyed on the spectral dimension.
hlines = hv.HoloMap({w: hv.VLine(w) for w in wl}, third_dim)

# Range slider linked on the JavaScript side to the colour mapper limits.
widget = pn.widgets.RangeSlider(start=0, end=vmax, step=0.001)

jscode = """
    color_mapper.low = cb_obj.value[0];
    color_mapper.high = cb_obj.value[1];
"""
link = widget.jslink(im, code={'value': jscode})

hv.output(widget_location='top_left')

# visualize and play
graphs = (mean * std * hlines).relabel(variable)
_layout_opts = [
    opts.Image(tools=['hover']),
    opts.Curve(width=750, height=500, framewise=True, xlim=(400, 1140), tools=['hover']),
    opts.Polygons(fill_alpha=0.2, color='green', line_color='black'),
    opts.VLine(color='black'),
]
layout = (im * polys + graphs).opts(*_layout_opts).cols(2)
layout