In [None]:
'''
Main program
'''
from pathlib import Path
import os, shutil
import zipfile
import tarfile
import glob

import matplotlib
import matplotlib.pyplot as plt


import numpy as np
import xarray as xr
import logging

import memory_profiler
%load_ext memory_profiler

from s2driver import driver_S2_SAFE as S2

from grs import product, acutils,cams_product, l2a_product,__version__
from grs.fortran.grs import main_algo as grs_solver

opj = os.path.join
__version__

In [None]:
# Level-1C input scene (SAFE directory) and the matching CAMS forecast file.
# Previously used scene, kept for reference:
# file = '/sat_data/satellite/sentinel2/L1C/31TFJ/S2A_MSIL1C_20201004T104031_N0209_R008_T31TFJ_20201004T125253.SAFE'
file = '/media/harmel/vol1/Dropbox/satellite/S2/L1C/S2B_MSIL1C_20220731T103629_N0400_R008_T31TFJ_20220731T124834.SAFE'
cams_file = '/media/harmel/vol1/Dropbox/satellite/S2/cnes/CAMS/2022-07-31-cams-global-atmospheric-composition-forecasts.nc'
# netCDF cache of the loaded L1C product, stored next to the SAFE directory.
# with_suffix swaps only the final extension; str.replace('.SAFE', '.nc') would
# also rewrite any '.SAFE' that happens to occur in a parent directory name.
file_nc = str(Path(file).with_suffix('.nc'))


bandIds = range(13)
resolution = 60  # target resolution passed to the S2 driver (presumably metres)

if not os.path.exists(file_nc):
    # First run: read the SAFE product, wrap it, and cache it as netCDF.
    l1c = S2.s2image(file, band_idx=bandIds, resolution=resolution)
    print(l1c.crs)
    l1c.load_product()
    prod = product(l1c.prod)
    # int16 packing of the cached variables.
    # NOTE(review): with scale 1e-5 and offset 0.3, 'bands' only spans about
    # [-0.0277, 0.6277]; brighter pixels (e.g. clouds, if these are reflectances)
    # would overflow int16 on write.
    # Angles use a 0.01 step (+/-327.67): the former 0.001 step capped them at
    # +/-32.767, too small for azimuths in degrees (the notebook uses 180 - raa).
    encoding = {'bands': {'dtype': 'int16', 'scale_factor': 0.00001, 'add_offset': .3, '_FillValue': -32768},
                'vza': {'dtype': 'int16', 'scale_factor': 0.01, '_FillValue': -9999},
                'raa': {'dtype': 'int16', 'scale_factor': 0.01, '_FillValue': -9999},
                'sza': {'dtype': 'int16', 'scale_factor': 0.01, '_FillValue': -9999}}
    l1c.prod.to_netcdf(file_nc, encoding=encoding)
    l1c.prod.close()

else:
    # Cache hit: reopen the previously written netCDF.
    prod = product(xr.open_dataset(file_nc))
prod.raster

In [None]:

##################################
# GET ANCILLARY DATA (pressure, O3, water vapor, NO2, ...)
##################################

# Build the CAMS ancillary product for this scene from the forecast file,
# then show its parameters with plot_params().
cams = cams_product(prod, cams_file=cams_file)
cams.plot_params()

In [None]:

##################################
## ADD ELEVATION AND PRESSURE BAND
##################################
# Disabled in this run: elevation is not added to the product.
#prod.get_elevation()


#####################################
# LOAD LUT FOR ATMOSPHERIC CORRECTION
#####################################
# Two LUTs are loaded, one from prod.lutfine and one from prod.lutcoarse.
# Lazy %-formatting lets logging stringify the path itself, so the call does not
# fail if the path is not a str (e.g. a pathlib.Path), unlike '+' concatenation.
logging.info('loading lut...%s', prod.lutfine)
lutf = acutils.lut(prod.band_names)
lutf.load_lut(prod.lutfine, prod.sensordata.indband)

logging.info('loading lut...%s', prod.lutcoarse)
lutc = acutils.lut(prod.band_names)
lutc.load_lut(prod.lutcoarse, prod.sensordata.indband)


In [None]:

####################################
#    absorbing gases correction
####################################
gas_trans = acutils.gaseous_transmittance(prod, cams)
Tg_raster = gas_trans.get_gaseous_transmittance()

# The division below overwrites prod.raster['bands'] in place, so re-running this
# cell would divide by the transmittance a second time. The flag set at the end of
# the correction is used as a guard to make the cell idempotent.
if prod.raster.bands.attrs.get('gas_absorption_correction', False):
    logging.info('gaseous absorption correction already applied, skipping')
else:
    logging.info('correct for gaseous absorption')
    for wl in prod.wl.values:
        # Transmittance of this band, interpolated onto the product x/y grid
        # (done per band to keep memory at one band at a time).
        Tg_wl = Tg_raster.sel(wl=wl).interp(x=prod.raster.x, y=prod.raster.y)
        prod.raster['bands'].loc[wl] = prod.raster.bands.sel(wl=wl) / Tg_wl
    prod.raster.bands.attrs['gas_absorption_correction'] = True

# Scene-mean transmittance per wavelength, then per-band transmittance maps.
plt.figure()
Tg_raster.mean('x').mean('y').plot()
p = Tg_raster.plot.imshow(col='wl', col_wrap=3, robust=True, cmap=plt.cm.Spectral_r,
                          subplot_kws=dict(xlabel='', ylabel='', xticks=[], yticks=[]))

In [None]:

######################################
# Water mask
######################################
def _normalized_difference(a, b):
    """Return the normalized difference index (a - b) / (a + b)."""
    return (a - b) / (a + b)


# Bands feeding the spectral indices and the masking thresholds
green, nir, swir, b2200 = [prod.raster.bands.sel(wl=band_wl)
                           for band_wl in (prod.b565, prod.b865, prod.b1600, prod.b2200)]

ndwi = _normalized_difference(green, nir)
ndwi_swir = _normalized_difference(green, swir)

prod.raster['ndwi'] = ndwi
prod.raster.ndwi.attrs = {
    'description': f'Normalized difference spectral index between bands at {prod.b565!s} and {prod.b865!s} nm',
    'units': '-'}
prod.raster['ndwi_swir'] = ndwi_swir
prod.raster.ndwi_swir.attrs = {
    'description': f'Normalized difference spectral index between bands at {prod.b565!s} and {prod.b1600!s} nm',
    'units': '-'}

# Keep only pixels passing all three tests (others become NaN): NDWI above its
# threshold, b2200 below the sunglint threshold, green/SWIR index above its threshold.
water_bands = prod.raster.bands.where(ndwi > prod.ndwi_threshold)
water_bands = water_bands.where(b2200 < prod.sunglint_threshold)
water_bands = water_bands.where(ndwi_swir > prod.green_swir_index_threshold)
prod.raster['bands'] = water_bands


In [None]:
# Wavelengths processed by the solver, as defined by the product.
# (A hard-coded list of S2 band centres used to be assigned here and was
# immediately overwritten by prod.wl_process, so it has been dropped.)
wl_process = prod.wl_process
Nband = len(wl_process)
vza = prod.raster.vza.sel(wl=wl_process)
sza = prod.raster.sza
# Relative azimuth flipped as 180 - raa (presumably to match the LUT azimuth
# convention -- TODO confirm)
raa = 180 - prod.raster.raa.sel(wl=wl_process)

# Number of nodes of the regular angular grids, spanning the scene's angle range,
# onto which the LUT reflectances are interpolated before running the solver.
N_SZA_NODES, N_VZA_NODES, N_RAA_NODES = 10, 20, 60
sza_ = np.linspace(sza.min(), sza.max(), N_SZA_NODES)
vza_ = np.linspace(vza.min(), vza.max(), N_VZA_NODES)
raa_ = np.linspace(raa.min(), raa.max(), N_RAA_NODES)

In [None]:
eps_sunglint = prod.sensordata.rg
rot = prod.sensordata.rot
rrs = prod.rrs
width = prod.width
height = prod.height
chunk = 512         # tile edge (pixels) for the chunked solver loop
ptype = np.float32  # dtype of the solver output buffers
logging.info('slice raster for desired wavelengths')
raster = prod.raster['bands'].sel(wl=wl_process)

solar_irr = prod.solar_irradiance.sel(wl=wl_process).values

logging.info('get/set aerosol parameters')
aotlut = np.array(lutf.aot, dtype=prod.type)
# LUT reflectances resampled onto the scene-specific angular grids
fine_refl = lutf.refl.interp(vza=vza_).interp(azi=raa_).interp(sza=sza_)
coarse_refl = lutc.refl.interp(vza=vza_).interp(azi=raa_).interp(sza=sza_)
lut_shape = fine_refl.shape
fine_Cext = lutf.Cext
coarse_Cext = lutc.Cext
# CAMS aerosol optical thickness at the processed wavelengths: total, and total
# times cams_ssa. Both, like the surface-pressure correction, are interpolated
# onto each tile's x/y grid inside the solver loop, so no full-raster copy is
# built here.
aot_tot_cams_res = cams.cams_aod.interp(wavelength=wl_process)
aot_sca_cams_res = aot_tot_cams_res * cams.cams_ssa.interp(wavelength=wl_process)

# TODO implement pre-masking, now set to zero
maskpixels = np.zeros((prod.height, prod.width))


In [None]:
# Solver output buffers, NaN-initialised, laid out (rows, cols) = (height, width)
# like `maskpixels` and like the y/x coords they are wrapped with afterwards;
# the solver loop writes them as [:, row_slice, col_slice]. They were previously
# allocated as (..., width, height), which only worked for square rasters.
# Assumes `raster` is ordered (wl, y, x) with len(y) == height -- TODO confirm.
rcorr = np.full((Nband, height, width), np.nan, dtype=ptype)
rcorrg = np.full((Nband, height, width), np.nan, dtype=ptype)
aot550pix = np.full((height, width), np.nan, dtype=ptype)
brdfpix = np.full((height, width), np.nan, dtype=ptype)

In [None]:
# Run the Fortran solver tile by tile: outer loop over columns (0..width),
# inner loop over rows (0..height); edge tiles are clipped to the raster size.
for col_start in range(0, width, chunk):
    col_stop = min(col_start + chunk, width)
    for row_start in range(0, height, chunk):
        row_stop = min(row_start + chunk, height)
        rows = slice(row_start, row_stop)
        cols = slice(col_start, col_stop)

        _sza = sza[rows, cols]
        nx, ny = _sza.shape
        if nx == 0 or ny == 0:
            continue
        _raa = raa[:, rows, cols]
        _vza = vza[:, rows, cols]
        _maskpixels = maskpixels[rows, cols]
        _band_rad = raster[:, rows, cols]

        # CAMS aerosol inputs and pressure term on this tile's x/y grid
        tile_x, tile_y = _band_rad.x, _band_rad.y
        aot_tot = aot_tot_cams_res.interp(x=tile_x, y=tile_y)
        aot_sca = aot_sca_cams_res.interp(x=tile_x, y=tile_y)
        aot550guess = cams.raster.aod550.interp(x=tile_x, y=tile_y)
        # constant first guess passed as fcoef (presumably the fine/coarse
        # mixing ratio -- TODO confirm)
        fcoef = np.full((nx, ny), 0.65)

        # surface pressure * 1e-2, normalised by the product reference pressure
        pressure_corr = cams.raster.sp.interp(x=tile_x, y=tile_y) * 1e-2 / prod.pressure_ref

        p = grs_solver.grs.main_algo(nx, ny, *lut_shape,
                                     aotlut, sza_, raa_, vza_,
                                     fine_refl, coarse_refl, fine_Cext, coarse_Cext,
                                     _vza, _sza, _raa, _band_rad.values, _maskpixels,
                                     wl_process, pressure_corr, eps_sunglint, solar_irr, rot,
                                     aot_tot, aot_sca, aot550guess, fcoef, rrs)
        rcorr_tile, rcorrg_tile, aot550_tile, brdf_tile = p[0], p[1], p[2], p[3]
        rcorr[:, rows, cols] = rcorr_tile
        rcorrg[:, rows, cols] = rcorrg_tile
        aot550pix[rows, cols] = aot550_tile
        brdfpix[rows, cols] = brdf_tile


In [None]:

# Wrap the solver outputs as DataArrays: spectral outputs reuse raster's coords,
# per-pixel outputs get the y/x coords only; then assemble the L2A product.
Rrs, Rrs_g = (xr.DataArray(values, coords=raster.coords, name=var_name)
              for values, var_name in ((rcorr, 'Rrs'), (rcorrg, 'Rrs_g')))
aot550 = xr.DataArray(aot550pix, coords={'y': raster.y, 'x': raster.x}, name='aot550')
brdfg = xr.DataArray(brdfpix, coords={'y': raster.y, 'x': raster.x}, name='BRDFg')
l2_prod = xr.merge([aot550, brdfg, Rrs, Rrs_g])
l2a = l2a_product(prod, l2_prod, cams, gas_trans)


In [None]:
# Write the L2A product with the l2a_product writer.
# NOTE(review): this path has no '.nc' extension and is the same path the
# per-variable export cell uses as an output *directory*; if this call creates a
# file there, that cell's mkdir is skipped and its writes fail. The file re-opened
# later ends in '.nc' -- confirm which path is intended.
l2a_output_path = '/sat_data/satellite/sentinel2/L2A/S2B_MSIL2Agrs_20220731T103629_N0400_R008_T31TFJ_20220731T124834'
l2a.to_netcdf(l2a_output_path)

In [None]:
complevel = 6


def _int16_encoding(scale_factor, fill_value, add_offset=None):
    """Return a zlib-compressed int16 packing spec for one netCDF variable.

    Decoded value = packed * scale_factor (+ add_offset when given); the packed
    `fill_value` marks missing data. Uses the module-level `complevel`.
    """
    spec = {'dtype': 'int16', 'scale_factor': scale_factor}
    if add_offset is not None:
        spec['add_offset'] = add_offset
    spec.update({'_FillValue': fill_value, 'zlib': True, 'complevel': complevel})
    return spec


# Reflectance-type variables: 1e-5 steps around 0.3, i.e. about [-0.0277, 0.6277].
_REFLECTANCE = dict(scale_factor=0.00001, fill_value=-32768, add_offset=.3)
# Normalized-difference indices: 1e-4 steps, i.e. +/-3.2767.
_INDEX = dict(scale_factor=0.0001, fill_value=-32768)
# Angles (degrees): a 0.01 step spans +/-327.67. The former 0.001 step capped
# values at +/-32.767, so azimuths (and larger zenith angles) overflowed int16.
_ANGLE = dict(scale_factor=0.01, fill_value=-9999)

encoding = {
    'aot550': _int16_encoding(scale_factor=0.001, fill_value=-9999),
    'BRDFg': _int16_encoding(**_REFLECTANCE),
    'Rrs': _int16_encoding(**_REFLECTANCE),
    'Rrs_g': _int16_encoding(**_REFLECTANCE),
    'o2_band': _int16_encoding(**_REFLECTANCE),
    'cirrus_band': _int16_encoding(**_REFLECTANCE),
    'ndwi': _int16_encoding(**_INDEX),
    'ndwi_swir': _int16_encoding(**_INDEX),
    'vza': _int16_encoding(**_ANGLE),
    'raa': _int16_encoding(**_ANGLE),
    'sza': _int16_encoding(**_ANGLE),
}

In [None]:
output_path = '/sat_data/satellite/sentinel2/L2A/S2B_MSIL2Agrs_20220731T103629_N0400_R008_T31TFJ_20220731T124834'
# export full raster data
basename = os.path.basename(output_path)
ofile = os.path.join(output_path, basename)

# makedirs(exist_ok=True) also creates missing parents, and raises a clear
# FileExistsError if output_path already exists as a regular file (instead of
# silently skipping the mkdir and failing later on write).
os.makedirs(output_path, exist_ok=True)

# clean up to avoid permission denied
for stale_file in (ofile + '.nc', ofile + '_anc.nc'):
    if os.path.exists(stale_file):
        os.remove(stale_file)

# Write one variable at a time: the first call creates the file ('w'), later ones
# append ('a'). Variables with no entry in `encoding` are written with xarray's
# default encoding instead of raising KeyError.
arg = 'w'
try:
    for variable in list(l2a.l2_prod.keys()):
        l2a.l2_prod[variable].to_netcdf(ofile + '.nc', mode=arg,
                                        encoding={variable: encoding.get(variable, {})})
        arg = 'a'
finally:
    l2a.l2_prod.close()


In [None]:
# Debug: show the encoding entry for the most recent value of `variable`.
dict([(variable, encoding[variable])])

In [None]:
l2afile = '/media/harmel/data_sat/satellite/sentinel2/L2A/test.nc'
# Only pass encodings for variables the dataset actually contains: xarray fails
# when `encoding` names a variable that is not in the dataset.
l2_encoding = {name: spec for name, spec in encoding.items() if name in l2a.l2_prod.variables}
l2a.l2_prod.to_netcdf(l2afile, encoding=l2_encoding)
# Append the ancillary dataset to the same file.
l2a.ancillary.to_netcdf(l2afile, mode='a')

In [None]:
# Rrs maps, one panel per wavelength, on a fixed colour range [0, 0.015].
# NOTE: with both vmin and vmax given, xarray does not use `robust`
# (percentiles only fill in unset limits).
l2a.l2_prod.Rrs.plot.imshow(
    col='wl',
    col_wrap=4,
    robust=True,
    vmin=0,
    vmax=0.015,
    cmap=plt.cm.Spectral_r,
)


In [None]:
# Re-open the written L2A file: the root group and the 'zancillary' group.
filenc = '/sat_data/satellite/sentinel2/L2A/S2B_MSIL2Agrs_20220731T103629_N0400_R008_T31TFJ_20220731T124834.nc'
open_kwargs = dict(decode_coords='all')
fdata = xr.open_dataset(filenc, **open_kwargs)

fanc = xr.open_dataset(filenc, group='zancillary', **open_kwargs)

In [None]:
# L2A variables to inspect and re-export
variables = [
    'aot550', 'BRDFg', 'Rrs', 'Rrs_g',
    'vza', 'sza', 'raa',
    'o2_band', 'cirrus_band',
    'ndwi', 'ndwi_swir',
]
fdata[variables]

In [None]:
# Display the in-memory L2 dataset held by the l2a product.
l2a.l2_prod


In [None]:
encoding  # display the full netCDF encoding dict

In [None]:
# Re-write the L2 variables one at a time into `l2afile`: the first call creates
# the file ('w'), later ones append ('a'). Variables with no entry in `encoding`
# use xarray's default encoding instead of raising KeyError.
arg = 'w'
for variable in list(l2a.l2_prod.keys()):
    l2a.l2_prod[variable].to_netcdf(l2afile, mode=arg,
                                    encoding={variable: encoding.get(variable, {})})
    arg = 'a'

In [None]:
# Save the ancillary group as a standalone *_anc.nc file (mode 'w' overwrites).
anc_output_file = '/sat_data/satellite/sentinel2/L2A/S2B_MSIL2Agrs_20220731T103629_N0400_R008_T31TFJ_20220731T124834_anc.nc'
fanc.to_netcdf(anc_output_file, mode='w')

In [None]:
# Combine the selected L2A variables with the ancillary fields, loaded into memory.
datasets_to_merge = [fdata[variables], fanc]
d = xr.merge(datasets_to_merge).compute()
d

In [None]:
# Compression-only encoding for the ancillary variables (added to `encoding`).
for variable in list(l2a.ancillary.keys()):
    encoding[variable] = {"zlib": True, "complevel": complevel}
# This encoding used to be commented out of the write below, so the loop above had
# no effect. Pass it, restricted to variables present in `d`, since xarray fails
# on encodings for variables the dataset does not contain.
d_encoding = {name: spec for name, spec in encoding.items() if name in d.variables}
d.to_netcdf(l2afile, encoding=d_encoding)

In [None]:
# Display the dataset read from the 'zancillary' group.
fanc

In [None]:
# Names of the data variables in the re-opened L2A file
print(list(fdata.keys()))
