In [1]:
%load_ext autoreload
%autoreload 2

import os
import numpy as np
from odc.stac import configure_rio
from dask_jobqueue import SLURMCluster
from dask.distributed import Client as daskClient
# from dask import compute as dask_compute
# from xarray import open_zarr as xr_open_zarr

from stacathome import MaxiCube

configure_rio(cloud_defaults=True, aws={"aws_unsigned": True})

# Construct Cube Object

In [2]:
# Area of interest and the bands to request.
aoi = 'somalia'
requested_bands = ['B02', 'B03', 'B04', 'B8A']
# Target grid. The inline alternatives suggest two setups: 0.00018 with EPSG:4326,
# or 20 with EPSG:32737 -- presumably the resolution is expressed in CRS units
# (degrees vs. metres); TODO confirm against MaxiCube.
resolution = 0.00018  # 20  # 0.00018
crs = 4326  # 32737  # 4326
# Chunk sizes passed to MaxiCube for the spatial (x/y) and the time dimension.
chunksize_xy = 256
chunksize_t = 1000

# Directory handed to MaxiCube as `path`; created here if it does not exist yet.
out_path = '/Net/Groups/BGI/data/DataStructureMDI/DATA/Incoming/Sentinel/_2/S2A_L2A/ForSites/Sentinel2tiles'
os.makedirs(out_path, exist_ok=True)

# Location of the large zarr cube. Not passed to MaxiCube below (zarr_path is commented out).
zarr_store = '/Net/Groups/BGI/scratch/mzehner/VCI_Somalia/Somalia_S2_2014_2026.zarr'

mxc = MaxiCube(aoi=aoi,
               requested_bands=requested_bands,
               crs=crs,
               resolution=resolution,
               chunksize_xy=chunksize_xy,
               chunksize_t=chunksize_t,
               path=out_path,
               # zarr_path=zarr_store,
               )

# Optional clean-up for scenes that no longer exist but are still stored in the item file
# (kept commented out; run manually when needed):
# mxc.req_items = mxc.items_local_global
# mxc.compare_local(report=True)
# mxc.items_local_global = mxc.req_items_local
# mxc.req_items = mxc.items_local_global
# mxc.compare_local(report=True)

# mxc.save_items()

# Number of items currently tracked by the cube (displayed as the cell output).
len(mxc.items_local_global)

19136

In [42]:
# Persist the cube object; presumably this writes the file that is reloaded
# later from out_path + '/saved.maxicube' -- TODO confirm the target path.
mxc.save()

In [4]:
from stacathome import load_maxicube
# Reload the saved cube and mark every item it tracks as requested, so that
# compare_local() reports on the complete item list.
mxc_loaded = load_maxicube(out_path + '/saved.maxicube')
mxc_loaded.req_items = mxc_loaded.items_local_global
mxc_loaded.compare_local(report=True)

All data already downloaded.


# Parallel request and download of tiles using SLURM

In [None]:
# Resources requested for every SLURM job backing the download cluster.
slurm_job_spec = dict(
    queue='work',          # SLURM queue to submit to
    cores=1,               # cores per job
    memory='768MB',        # memory per job
    walltime='03:00:00',   # job duration (hh:mm:ss)
)
cluster = SLURMCluster(**slurm_job_spec)

# Let the cluster scale adaptively between 1 and 20 workers
# (fixed-size alternative: cluster.scale(jobs=8)).
cluster.adapt(minimum=1, maximum=20)

# Dask client connected to the cluster.
client = daskClient(cluster)

# Display the cluster status as the cell output.
cluster

In [22]:
# Request and download everything for January 2021. The return value is kept
# because it is passed to mxc.save_items() in a later cell.
process = mxc.download_all('2021-01-01', '2021-01-31')

In [24]:
# Shut down the Dask client first, then the SLURM cluster behind it.
client.close()
cluster.close()

In [None]:
# Store the result of the download run in the item file
# (presumably merging it into the locally tracked items -- TODO confirm).
mxc.save_items(process)

In [10]:
# Overview plot of the cube; subset_chunks_by=50 presumably thins the drawn chunks
# to keep the figure readable -- TODO confirm.
mxc.plot(subset_chunks_by=50)

# Load the requested data as an on-the-fly xarray cube

In [None]:
# Build the on-the-fly cube for subset 4, enlarged by 2 chunks
# (presumably in every direction -- TODO confirm); the bare name displays it.
otf_cube = mxc.load_otf_cube(subset=4, enlarge_by_n_chunks=2)
otf_cube

In [None]:
# Average B02 over both spatial dimensions (names looked up in mxc.dimension_names)
# and plot what remains.
otf_cube.B02.mean(dim=[mxc.dimension_names['longitude'],
                  mxc.dimension_names['latitude']]).plot()

In [None]:
# Replace B02 values of 0 with NaN, take the per-pixel median over time and
# divide by 10000 before plotting.
b02_without_zeros = otf_cube.B02.where(otf_cube.B02 != 0, np.nan)
b02_median = b02_without_zeros.median(dim='time')
(b02_median / 10000).plot()

# Insert data into a larger consistent cube

In [3]:
# Set up the large datacube. With overwrite=False an existing store is left
# untouched and creation is skipped (see the message printed below the cell).
zarr_store = '/Net/Groups/BGI/scratch/mzehner/VCI_Somalia/Somalia_S2_2014_2026.zarr'
mxc.construct_large_cube(zarr_store, overwrite=False)

Zarr already exists at /Net/Groups/BGI/scratch/mzehner/VCI_Somalia/Somalia_S2_2014_2026.zarr. Skipping creation. Set overwrite=True to overwrite.


In [None]:
# Resources requested for every SLURM job backing the cube-filling cluster.
slurm_job_spec = dict(
    queue='work',          # SLURM queue to submit to
    cores=1,               # cores per job
    memory='8GB',          # memory per job
    walltime='03:00:00',   # job duration (hh:mm:ss)
)
cluster = SLURMCluster(**slurm_job_spec)

# Let the cluster scale adaptively between 1 and 40 workers
# (fixed-size alternative: cluster.scale(jobs=8)).
cluster.adapt(minimum=1, maximum=40)

# Dask client connected to the cluster.
client = daskClient(cluster)

# Display the cluster status as the cell output.
cluster

In [None]:
# Fill the large cube using the Dask client created above.
res = mxc.fill_large_cube(client=client)
# TODO: save the filled-in time ranges into the chunk table?

# Unfinished sketches for that TODO. They reference names that do not exist in this
# notebook (`delayed_subsets`, `self`, `subset`, `min_time`, `max_time`), so they
# cannot be run as-is.
# for d in delayed_subsets:
#     mxc.chunk_table.loc[d[0], 'timerange_in_zarr'].append([np.datetime_as_string(d[1], unit='D'),
#                                                            np.datetime_as_string(d[2], unit='D')])
# mxc.chunk_table.loc[
#     mxc.chunk_table.
#     clip(
#         subset.boundingbox).index
#         ]['timerange_in_zarr'] = self.chunk_table.loc[
#             self.chunk_table.clip(
#                 subset.boundingbox).index
#                 ]['timerange_in_zarr'].apply(
#                     lambda x: x.append([min_time,max_time
#                                         ]))

In [9]:
# Shut down the Dask client first, then the SLURM cluster behind it.
client.close()
cluster.close()

# Addressing the large dataset as a minicube

In [None]:
# Read chunk 0 for January 2021 and load it into memory with .compute();
# the bare name displays the result.
mc_int = mxc.get_chunk(0, ('2021-01-01', '2021-01-31')).compute()
mc_int

In [None]:
# Average B02 over both spatial dimensions (names looked up in mxc.dimension_names)
# and plot what remains.
mc_int.B02.mean(dim=[mxc.dimension_names['latitude'],
                mxc.dimension_names['longitude']]).plot()

In [None]:
def ndvi(band_red, band_nir):
    """Return the normalized difference of the two bands: (nir - red) / (nir + red).

    Works on anything that supports ``-``, ``+`` and ``/`` (scalars, arrays,
    xarray objects). No special handling is done for a zero denominator.
    """
    difference = band_nir - band_red
    total = band_nir + band_red
    return difference / total


# Replace everything with NaN where B02 is 0, then drop time steps that are left
# without any data.
b02_is_nonzero = mc_int.B02 != 0
mc_float = mc_int.where(b02_is_nonzero, np.nan).dropna(dim='time', how='all')

# NDVI of the temporal means: B04 is passed as the red band, B8A as the NIR band.
red_mean = mc_float.B04.mean(dim='time')
nir_mean = mc_float.B8A.mean(dim='time')
ndvi(red_mean, nir_mean).plot.imshow()

In [None]:
# Same chunk and time range as above, but with drop_fill=True
# (presumably removes fill-value-only time steps -- TODO confirm); displayed below.
mc_drop_fill = mxc.get_chunk(
    0, ('2021-01-01', '2021-01-31'), drop_fill=True).compute()
mc_drop_fill

In [None]:
# Issue a fresh item request (new_request=True) for subset 750 over 2015-01-01 .. 2026-01-31.
mxc.request_items('2015-01-01', '2026-01-31', subset=750, new_request=True)

In [None]:
# On-the-fly cube built from the locally available requested items for chunk 750;
# mxc.subset() returns a sequence whose first element is used as the subset.
otf_cube = mxc.load_otf_cube(mxc.req_items_local, mxc.subset(chunk_id=750)[0])
otf_cube

In [None]:
# Request items for one subset (2021-01-01 .. 2024-05-01) and load them as an
# on-the-fly cube without enlarging the subset and with drop_fill enabled.
subset = 1070
mxc.request_items('2021-01-01', '2024-05-01', subset=subset, new_request=True)
otf_cube = mxc.load_otf_cube(
    mxc.req_items_local, subset=subset, enlarge_by_n_chunks=0, drop_fill=True)

# Again: handle faulty downloads

In [26]:
from rasterio.windows import Window
from rasterio.errors import RasterioIOError, WarpOperationError
from rasterio import open as rio_open


def check_asset(item):
    """Probe every asset of ``item`` and sort problematic paths into two lists.

    An asset href counts as *not found* when it does not start with ``/Net``,
    does not exist on disk, or raises ``RasterioIOError`` when opened or read.
    It counts as *read failed* when reading raises ``WarpOperationError``
    (path only) or any other exception (path followed by the error text).
    Only a 256x256 window at the top-left corner of band 1 is read.

    Returns
    -------
    tuple[list[str], list[str]]
        ``(not_found, read_failed)``
    """
    missing = []
    unreadable = []
    for asset in item.assets.values():
        href = asset.href
        available_locally = href.startswith('/Net') and os.path.exists(href)
        if not available_locally:
            missing.append(href)
            continue
        try:
            with rio_open(href) as dataset:
                dataset.read(1, window=Window(0, 0, 256, 256))
        except RasterioIOError:
            missing.append(href)
        except WarpOperationError:
            unreadable.append(href)
        except Exception as err:
            unreadable.append(href + f' {err}')
    return missing, unreadable

In [None]:
from tqdm import tqdm
# Check the assets of every locally tracked item, one after another.
# This takes a while; an attempt to run it on SLURM failed.
# A plain loop is used so that partial results stay in `res` if the cell is interrupted.
res = []
for i in tqdm(mxc.items_local_global):
    res.append(check_asset(i))

In [35]:
# Flatten the per-item (not_found, read_failed) pairs into two global lists.
not_found, read_failed = [], []
for item_not_found, item_read_failed in res:
    not_found.extend(item_not_found)
    read_failed.extend(item_read_failed)

# DESTRUCTIVE: every "not found" path that still exists on disk is deleted
# after printing its size in units of 1,000,000 bytes.
for path in not_found:
    if not os.path.exists(path):
        print('not exist:', path)
        continue
    print(path, os.path.getsize(path)//1000000)
    os.remove(path)

In [None]:
import os
import rasterio

def is_valid_tiff_rasterio(file_path):
    """Return True if band 1 of ``file_path`` can be opened and fully read with rasterio.

    Any exception raised while opening or reading is printed together with
    the file path, and False is returned instead.
    """
    try:
        with rasterio.open(file_path) as dataset:
            dataset.read(1)  # reading the whole first band is the validity check
    except Exception as err:
        print(f"Error with file {file_path}: {err}")
        return False
    return True

def check_tiff_files_recursively_rasterio(directory):
    """Walk ``directory`` recursively and validate every ``.tif`` / ``.tiff`` file.

    The suffix match is case-sensitive. Each matching file is checked with
    ``is_valid_tiff_rasterio``; progress is reported per visited directory.

    Returns
    -------
    tuple[list[str], list[str]]
        ``(valid_files, invalid_files)`` as full paths in walk order.
    """
    tiff_suffixes = ('.tif', '.tiff')
    valid_files, invalid_files = [], []

    for folder, _subdirs, filenames in tqdm(os.walk(directory)):
        for filename in filenames:
            if not filename.endswith(tiff_suffixes):
                continue
            candidate = os.path.join(folder, filename)
            target = valid_files if is_valid_tiff_rasterio(candidate) else invalid_files
            target.append(candidate)

    return valid_files, invalid_files

# Validate every TIFF below out_path. This fully reads band 1 of each file,
# so it is slow on a large store.
directory = out_path
valid_tiffs, invalid_tiffs = check_tiff_files_recursively_rasterio(directory)

# Report counts only for the valid files: printing every valid path would flood
# the cell output. The invalid paths are the actionable part and are listed in full.
print("Valid TIFFs:", len(valid_tiffs))
print(f"Invalid TIFFs ({len(invalid_tiffs)}):", invalid_tiffs)

In [None]:
from tqdm import tqdm
import logging
import io

from stacathome.walltowall import _load_otf_cube_bulk
from rasterio.errors import WarpOperationError

# Items of the cube as a (presumably geopandas) GeoDataFrame, clipped per chunk below.
gdf = mxc.items_as_geodataframe()

# Capture ERROR-level log records in memory through a dedicated handler on the root
# logger. logging.basicConfig() is deliberately not used: it does nothing once the
# root logger already has handlers, so re-running this cell would capture nothing
# and would leave a handler that writes to the closed stream of the previous run.
log_stream = io.StringIO()
log_handler = logging.StreamHandler(log_stream)
log_handler.setLevel(logging.ERROR)
log_handler.setFormatter(logging.Formatter('%(levelname)s:%(message)s'))
root_logger = logging.getLogger()
root_logger.addHandler(log_handler)

# (chunk index, error text) for every WarpOperationError swallowed below.
e_msg = []
try:
    for i in tqdm([10000]):  # range(0, len(mxc.chunk_table), 5500)):
        subset, _ = mxc.subset(i)

        items = gdf.clip(subset.boundingbox)['asset_items'].to_list()

        if len(items) == 0:
            continue

        cube = _load_otf_cube_bulk(subset=subset,
                                   filtered_items=items,
                                   requested_bands=None)

        try:
            # The spatial mean forces every asset of the chunk to be read.
            cube.mean(dim=[mxc.dimension_names['latitude'],
                           mxc.dimension_names['longitude']]).compute()
            # cube.compute()
        except WarpOperationError as e:
            # Best effort: the failing file is reported through the captured log,
            # so keep going, but remember which chunk failed and why.
            e_msg.append((i, str(e)))
    logged_message = log_stream.getvalue()
finally:
    # Always detach the handler so later cells do not log into a closed stream.
    root_logger.removeHandler(log_handler)
    log_stream.close()

# Extract the unique file paths from the captured "Aborting load" error lines:
# the last space-separated token of the line, cut at its first ':'.
abort_prefix = 'ERROR:Aborting load due to failure while reading:'
msgs = logged_message.split('\n')
erred_scenes = {
    line.split(' ')[-1].split(':')[0]
    for line in msgs
    if line.startswith(abort_prefix)
}

# Ids of all locally tracked items, index-aligned with mxc.items_local_global.
# Only used by the commented-out clean-up below.
ids = [i.id for i in mxc.items_local_global]
# Positions of items that would end up without assets (filled only by the commented-out code).
empty_items = []

for i in erred_scenes:
    path_sep = i.split('/')
    # Sixth path component from the end -- presumably the scene directory name; TODO confirm.
    item_name = path_sep[-6]
    # Drop characters 27..32 and the last 5 characters -- presumably to turn the
    # directory name into the item id; verify against the ids in `ids`.
    item_name = item_name[:27] + item_name[33:-5]
    # Three characters of the file name at positions -11..-9 -- presumably the band
    # code (e.g. 'B02'); TODO confirm the file naming scheme.
    asset = path_sep[-1][-11:-8]
    print(item_name, asset)
    # Disabled follow-up: remove the failing asset from its item and remember items
    # that are left without any asset.
    # try:
    #     pos = ids.index(item_name)
    #     if asset in mxc.items_local_global[pos].assets:
    #         del mxc.items_local_global[pos].assets[asset]
    #         if len(mxc.items_local_global[pos].assets) == 0:
    #             empty_items.append(pos)
    # except ValueError:
    #     continue
