In [None]:
from datetime import datetime
import math
from pathlib import Path
import re
from typing import Union
from ipyfilechooser import FileChooser

import numpy as np
import pycrs
import s3fs
import xarray as xr
import zarr

from osgeo import gdal

import opensarlab_lib as osl

In [None]:
# # TODO Add to opensarlab-lib

# def stack_aligned(data_paths):
#     corner_coords = [osl.get_corner_coords(i) for i in data_paths]
#     ulx = set([c[0][0] for c in corner_coords])
#     uly = set([c[0][1] for c in corner_coords])
#     lrx = set([c[1][0] for c in corner_coords])
#     lry = set([c[1][1] for c in corner_coords])
#     return len(ulx) == len(uly) == len(lrx) == len(lry) == 1

def get_corners_gdal(file):
    """Return the upper-left and lower-right corner coordinates of a raster.

    The corners are derived from the GDAL geotransform origin, the pixel
    width/height and the raster size; the rotation terms of the geotransform
    are not used.

    Parameters
    ----------
    file : str or pathlib.Path
        Path to a raster that GDAL can open.

    Returns
    -------
    dict
        ``{'ul': [ulx, uly], 'lr': [lrx, lry]}``.
    """
    dataset = gdal.Open(str(file))
    origin_x, pixel_width, _, origin_y, _, pixel_height = dataset.GetGeoTransform()
    n_cols, n_rows = dataset.RasterXSize, dataset.RasterYSize

    upper_left = [origin_x, origin_y]
    lower_right = [origin_x + n_cols * pixel_width, origin_y + n_rows * pixel_height]
    return {'ul': upper_left, 'lr': lower_right}

def get_epsg(file):
    """Return the EPSG code of a raster's coordinate system as a string.

    The code is taken from the last EPSG identifier in the WKT reported by
    ``gdal.Info`` (the outermost CRS identifier comes last in the WKT).
    Both the WKT2 ``ID["EPSG",32650]`` and the WKT1
    ``AUTHORITY["EPSG","32650"]`` spellings are recognised, and codes of
    any length are returned intact.

    Parameters
    ----------
    file : str or pathlib.Path
        Path to a raster that GDAL can open.

    Returns
    -------
    str or None
        The EPSG code digits, or None when the WKT carries no EPSG identifier.
    """
    wkt = gdal.Info(str(file), format='json')['coordinateSystem']['wkt']
    # Capture only the digits so 4-digit codes do not drag the closing bracket along.
    codes = re.findall(r'(?:ID|AUTHORITY)\["EPSG",\s*"?(\d+)"?\]', wkt)
    return codes[-1] if codes else None

def polarization_from_filename(file):
    """Return the first polarization token found in a file name, upper-cased.

    Only all-lowercase or all-uppercase tokens (vv, vh, hh, hv) are
    recognised. Returns None when the name contains none of them.
    """
    found = re.search('vv|VV|vh|VH|hh|HH|hv|HV', str(file))
    return found.group(0).upper() if found else found

def product_type_from_filename(file):
    """Classify a product by the marker contained in its file name.

    Returns "RTC" when the name contains "RTC", "INSAR" when it contains
    "INT", and None otherwise. "RTC" takes precedence when both occur.
    """
    name = str(file)
    for marker, prod_type in (("RTC", "RTC"), ("INT", "INSAR")):
        if marker in name:
            return prod_type
    return None

def parse_proj_crs(proj_crs):
    """Convert a projected-CRS OGC WKT string into a dict of CF-style attributes.

    Parameters
    ----------
    proj_crs : str
        OGC WKT describing a projected coordinate system.

    Returns
    -------
    dict
        Grid-mapping attributes: projection parameters that are present in
        the CRS, plus datum, ellipsoid, prime meridian and unit descriptions.
    """
    crs = pycrs.parse.from_ogc_wkt(proj_crs)
    param_types = pycrs.elements.parameters

    # Attribute name written for each pycrs projection parameter type.
    cf_param_names = (
        (param_types.LatitudeOrigin, 'latitude_of_projection_origin'),
        (param_types.CentralMeridian, 'longitude_of_central_meridian'),
        (param_types.FalseEasting, 'false_easting'),
        (param_types.FalseNorthing, 'false_northing'),
        # Spelling fixed: this key was previously 'scale_factor_at_centeral_meridian'.
        (param_types.ScalingFactor, 'scale_factor_at_central_meridian'),
    )

    # NOTE(review): 'grid_mapping_name' receives the CRS name and 'crs_wkt' the
    # lower-cased projection name; confirm this is the intended assignment.
    cfg_p = {
        'grid_mapping_name': crs.name,
        'crs_wkt': crs.proj.name.ogc_wkt.lower(),
    }

    for p in crs.params:
        for param_type, cf_name in cf_param_names:
            if isinstance(p, param_type):
                cfg_p[cf_name] = p.value

    cfg_p['projected_coordinate_system_name'] = crs.name
    cfg_p['geographic_coordinate_system_name'] = crs.geogcs.name
    cfg_p['horizontal_datum_name'] = crs.geogcs.datum.name.ogc_wkt
    cfg_p['reference_ellipsoid_name'] = crs.geogcs.datum.ellips.name.ogc_wkt
    cfg_p['semi_major_axis'] = crs.geogcs.datum.ellips.semimaj_ax.value
    cfg_p['inverse_flattening'] = crs.geogcs.datum.ellips.inv_flat.value
    cfg_p['longitude_of_prime_meridian'] = crs.geogcs.prime_mer.value
    cfg_p['units'] = crs.unit.unitname.ogc_wkt
    cfg_p['projection_x_coordinate'] = "x"
    cfg_p['projection_y_coordinate'] = "y"

    return cfg_p

def dates_from_product_name(product_name: Union[str, Path]) -> Union[str, None]:
    """
    Takes: a string or posix path to a HyP3 product
    Returns: the first "<8 digits>T<6 digits>_<8 digits>T<6 digits>" timestamp pair
             found in the name, or None if the name contains no such pair
    """
    date_pair = re.compile("[0-9]{8}T[0-9]{6}_[0-9]{8}T[0-9]{6}").search(str(product_name))
    return date_pair.group(0) if date_pair is not None else None
    
def datetime_from_hyp3_dt_str(date_str):
    """Build a datetime from the leading 'YYYYMMDDTHHMMSS' part of a string.

    Characters 0-7 supply the date and characters 9-14 the time; the
    separator at index 8 and anything after index 14 are ignored.
    """
    field_slices = (slice(0, 4), slice(4, 6), slice(6, 8),
                    slice(9, 11), slice(11, 13), slice(13, 15))
    year, month, day, hour, minute, second = (int(date_str[s]) for s in field_slices)
    return datetime(year, month, day, hour, minute, second, 0)

def get_RTC_prod_hash(tiff):
    """Extract the 4-character product hash from an RTC GeoTIFF file name.

    The hash is the run of four digits/uppercase letters that follows
    ``gpuned_`` and precedes a polarization suffix such as ``_VV`` or ``_VH``.

    Parameters
    ----------
    tiff : str or pathlib.Path
        Path or file name of the RTC GeoTIFF.

    Returns
    -------
    str or None
        The hash, or None when the file stem does not match the pattern.
    """
    stem = Path(tiff).stem
    # [VH] rather than [V|H]: inside a character class '|' is a literal
    # character, so the old pattern also accepted suffixes such as '_V|'.
    # NOTE(review): only names containing 'gpuned_' are matched -- confirm
    # whether other processing-option strings should be supported.
    p_hash = re.search(r"(?<=gpuned_)[0-9A-Z]{4}(?=_[VH][VH])", stem)
    return p_hash.group(0) if p_hash else None

In [None]:
def get_insar_attrs(tiff):
    """Assemble the global attributes for an InSAR dataset from one GeoTIFF.

    Parameters
    ----------
    tiff : str or pathlib.Path
        Path to one GeoTIFF of the InSAR product. The file name supplies the
        platform (its first two characters), the product type and the
        acquisition timestamps; the raster supplies the pixel spacing and the
        nodata value.

    Returns
    -------
    dict
        Attribute names mapped to values. ``fill_value`` is None when the
        GDAL report for the first band has no ``noDataValue`` entry.
    """
    tiff_path = Path(tiff)

    # product type and acquisition timestamps come from the file name
    prod_type = product_type_from_filename(tiff_path)

    # pixel resolution
    ds = gdal.Open(str(tiff_path))
    geo_trans = ds.GetGeoTransform()
    res_x = geo_trans[1]
    res_y = geo_trans[5]

    # nodata value of the first band; .get() because the entry may be absent
    tiff_info = gdal.Info(str(tiff_path), format='json')
    no_data = tiff_info['bands'][0].get('noDataValue')

    attrs = {
        'institution': 'Alaska Satellite Facility (ASF)',
        'references': 'https://asf.alaska.edu/',
        'source': 'SAR observation',
        'Conventions': 'CF-1.8',
        'platform': tiff_path.name[:2],
        'product_type': prod_type,
        'fill_value': no_data,
        'sensor_band_identifier': 'C',
        'x_spacing': res_x,
        'y_spacing': res_y,
        'title': 'SAR InSAR',
        'long_name': 'SAR InSAR',
        'description': 'SAR InSAR data',
        'times': dates_from_product_name(tiff_path),
    }
    return attrs

In [None]:
def get_insar_product_type_from_filename(path):
    """Identify which InSAR layer a GeoTIFF holds from its file name.

    Parameters
    ----------
    path : str or pathlib.Path
        Path or file name to classify.

    Returns
    -------
    str or None
        One of 'corr', 'dem', 'lv_phi', 'lv_theta', 'unw_phase' or
        'water_mask', or None when the name matches none of them. The types
        are tested in that order and the first match wins.
    """
    name = str(path)
    for p_type in ('corr', 'dem', 'lv_phi', 'lv_theta', 'unw_phase', 'water_mask'):
        # Raw string keeps \w a regex class; \. requires a literal dot before 'tif'.
        if re.search(rf'\w+_{p_type}_\w*\.tif', name):
            return p_type
    return None

In [None]:
def hyp3_mintpy_InSAR_to_xarray(insar):
    """Build an xarray.Dataset from the GeoTIFF layers of one InSAR product.

    Parameters
    ----------
    insar : sequence of str or pathlib.Path
        Paths to the GeoTIFFs of a single product. Each file is classified
        with ``get_insar_product_type_from_filename``; unrecognised files are
        ignored. The first path supplies the grid (geotransform and raster
        size) and the dataset attributes.

    Returns
    -------
    xarray.Dataset
        Variables 'corr', 'dem', 'lv_phi', 'lv_theta', 'unw_phase' and
        'water_mask' on ('y', 'x'), with invalid samples (NaN/inf) replaced
        by 0.0, and 'x'/'y' coordinates derived from the geotransform.

    Raises
    ------
    KeyError
        If any of the six expected layers is missing from ``insar``.
    """
    product_names = ('corr', 'dem', 'lv_phi', 'lv_theta', 'unw_phase', 'water_mask')

    # Read band 1 of every recognised layer into a dict keyed by product type.
    # (A dict replaces exec()/locals(): names created by exec() inside a
    # function are not reliably visible through locals().)
    layers = {}
    for f in insar:
        prod_type = get_insar_product_type_from_filename(f)
        if not prod_type:
            continue
        # Keep the dataset referenced while its band is being read.
        layer_ds = gdal.Open(str(f))
        data = layer_ds.GetRasterBand(1).ReadAsArray()
        layers[prod_type] = np.ma.masked_invalid(data, copy=True)

    missing = [name for name in product_names if name not in layers]
    if missing:
        raise KeyError(
            f"missing InSAR layer(s) {missing} in {[str(f) for f in insar]}"
        )

    # Grid definition, all taken from the same (first) file.
    ds = gdal.Open(str(insar[0]))
    ulx, res_x, _, uly, _, res_y = ds.GetGeoTransform()

    # One coordinate per pixel column/row. Built from the integer raster size
    # because np.arange with a float step can yield one element too many or
    # too few, which would not match the data shape.
    x_coords = ulx + res_x * np.arange(ds.RasterXSize)
    y_coords = uly + res_y * np.arange(ds.RasterYSize)

    # create xarray dataset
    data_set = xr.Dataset(
        data_vars={
            name: (('y', 'x'), layers[name].filled(0.0)) for name in product_names
        },
        coords={'y': y_coords, 'x': x_coords},
        attrs=get_insar_attrs(insar[0]),
    )

    # Set x and y coord attributes
    data_set.x.attrs.update({
        'axis': 'X',
        'units': 'm',
        'standard_name': 'projection_x_coordinate',
        'long_name': 'Easting',
    })
    data_set.y.attrs.update({
        'axis': 'Y',
        'units': 'm',
        'standard_name': 'projection_y_coordinate',
        'long_name': 'Northing',
    })

    return data_set

In [None]:
# Show a file chooser starting at the current working directory; a later cell
# reads the chosen location from `fc.selected_path`.
fc = FileChooser(Path.cwd())
display(fc)

In [None]:
# Gather the clipped GeoTIFFs of every product directory under the selected
# path and convert each directory into one xarray.Dataset.
if fc.selected_path is None:
    raise ValueError("Select the directory holding the InSAR product folders in the file chooser first.")

# Sorted so the processing order does not depend on filesystem listing order.
tiff_dirs = sorted(p for p in Path(fc.selected_path).glob('*') if p.is_dir())
stack_paths = []
for d in tiff_dirs:
    tiffs = sorted(d.glob('*_clip.tif'))
    # Skip directories without clipped GeoTIFFs; an empty list cannot be converted.
    if tiffs:
        stack_paths.append(tiffs)

insar_arrays = [hyp3_mintpy_InSAR_to_xarray(insar) for insar in stack_paths]

In [None]:
# Sort the per-product datasets by time and build the list of time labels used
# as the stack's `times` dimension.
def time_sort(a):
    """Sort key: the 'times' attribute stored on a dataset."""
    return a.attrs['times']

if not insar_arrays:
    raise ValueError("No InSAR datasets were loaded; nothing to stack.")

insar_arrays.sort(key=time_sort)
times = [a.attrs['times'] for a in insar_arrays]

product_names = ['corr', 'dem', 'lv_phi', 'lv_theta', 'unw_phase', 'water_mask']

# Create the stack: start from the first dataset's coordinates and attributes,
# without its 2-D layers.
stack = insar_arrays[0].drop_vars(product_names)

stack = stack.assign_coords(times=times)
stack.times.attrs['axis'] = "T"
stack.times.attrs['units'] = "timestamp in format %Y%m%dT%H%M%S"
stack.times.attrs['calendar'] = "proleptic_gregorian"
stack.times.attrs['long_name'] = "Time"

# Concatenate each layer across all datasets along the new `times` dimension.
for name in product_names:
    stack[name] = xr.concat([d[name] for d in insar_arrays], dim=stack.times)

# The per-product datasets are no longer needed once the stack holds the data.
del insar_arrays

In [None]:
# stack.to_netcdf(Path(fc.selected_path).parent/"full_stack.nc4")

In [None]:
# Display the assembled stack as the cell output for inspection.
stack

**Calculate the time optimized chunk shape given:**

- 1MB < Optimal chunk size < ?MB
    - https://zarr.readthedocs.io/en/stable/tutorial.html
        - "In general, chunks of at least 1 megabyte (1M) uncompressed size seem to provide better performance, at least when using the Blosc compression library."
- The depth of the stack
- The number of xarray.DataArray variables
- square spatial chunks
    - x chunk dimension == y chunk dimension

In [None]:
# Destination for the Zarr groups written by the cells below.
s3_path = "s3://asf-jupyter-data-west/zarr_test/indonesia"
# NOTE(review): anon=True requests unauthenticated access, yet this store is
# used as a write target by the to_zarr calls below -- confirm the bucket
# accepts anonymous writes or supply credentials.
s3 = s3fs.S3FileSystem(anon=True)
# Key-value view of the bucket prefix, passed to to_zarr() as the store.
store = s3fs.S3Map(root=s3_path, s3=s3, check=False)

In [None]:
# Time-optimized chunks: every chunk spans the full time axis and a square
# spatial tile, sized so that one tile across all data variables holds about
# `mb_per_chunk` MB of uncompressed samples.
bits_per_mb = 8000000
mb_per_chunk = 100
bits_per_chunk = bits_per_mb * mb_per_chunk
bits_per_pixel = 32  # assumes 32-bit samples for every variable -- TODO confirm
pixels_per_chunk = bits_per_chunk // bits_per_pixel
print(f'Desired pixels per chunk: {pixels_per_chunk}')
depth = len(times)
print(f'depth: {depth}')
# Count the variables actually in the stack instead of hard-coding the number.
data_array_var_count = len(stack.data_vars)
spatial_pixels = pixels_per_chunk // (depth * data_array_var_count)
print(f'spatial_pixels: {spatial_pixels}')
# A very deep stack must not shrink the spatial chunk edge to zero.
x_y_pixels = max(1, math.floor(math.sqrt(spatial_pixels)))
print(f'x_y_pixels: {x_y_pixels}')
print(f'Actual pixels per chunk: {depth * data_array_var_count * x_y_pixels * x_y_pixels}')
time_optimized_chunk = (depth, x_y_pixels, x_y_pixels)
print(time_optimized_chunk)

In [None]:
# Write the stack to the store using the time-optimized chunk shape.
group = 'stack_time_optimized_100MB_chunks'
compressor = zarr.Blosc(cname='zstd', clevel=3)
# Use the tuple computed for this layout: the `depth` and `x_y_pixels` globals
# are overwritten by the spatially optimized calculation further down, so
# rebuilding the shape from them is wrong when cells run out of order.
encoding = {vname: {'compressor': compressor, 'chunks': time_optimized_chunk} for vname in stack.data_vars}

zarr_stack = stack.to_zarr(store=store, encoding=encoding, consolidated=True, group=group)

**Calculate the spatially optimized chunk shape**



In [None]:
# Spatially optimized chunks: every chunk holds a single time step and a square
# spatial tile, sized so that one tile across all data variables holds about
# `mb_per_chunk` MB of uncompressed samples.
bits_per_mb = 8000000
mb_per_chunk = 100
bits_per_chunk = bits_per_mb * mb_per_chunk
bits_per_pixel = 32  # assumes 32-bit samples for every variable -- TODO confirm
pixels_per_chunk = bits_per_chunk // bits_per_pixel
print(f'Desired pixels per chunk: {pixels_per_chunk}')
depth = 1
print(f'depth: {depth}')
# Count the variables actually in the stack instead of hard-coding the number.
data_array_var_count = len(stack.data_vars)
spatial_pixels = pixels_per_chunk // (depth * data_array_var_count)
print(f'spatial_pixels: {spatial_pixels}')
# Keep the spatial chunk edge at least one pixel wide.
x_y_pixels = max(1, math.floor(math.sqrt(spatial_pixels)))
print(f'x_y_pixels: {x_y_pixels}')
print(f'Actual pixels per chunk: {depth * data_array_var_count * x_y_pixels * x_y_pixels}')
space_optimized_chunk = (depth, x_y_pixels, x_y_pixels)
print(space_optimized_chunk)

In [None]:
# Write the stack to the store using the spatially optimized chunk shape.
group = 'stack_space_optimized_100MB_chunks'
compressor = zarr.Blosc(cname='zstd', clevel=3)
# Use the tuple computed for this layout: the `depth` and `x_y_pixels` globals
# are shared with the time-optimized calculation, so rebuilding the shape from
# them is wrong when cells run out of order.
encoding = {vname: {'compressor': compressor, 'chunks': space_optimized_chunk} for vname in stack.data_vars}

zarr_stack = stack.to_zarr(store=store, encoding=encoding, consolidated=True, group=group)