In [None]:
from geom_median.numpy import compute_geometric_median as gm
import numpy as np
from utils import *
import dask
import dask.distributed
import xarray as xr
import rioxarray as rxr
from hdstats import nangeomedian_pcm
import geopandas as gpd
from odc.algo import (
    enum_to_bool,
    geomedian_with_mads,
    erase_bad,
    mask_cleanup,
    keep_good_only,
)
from odc.geo import BoundingBox
from odc.geo.xr import assign_crs
from odc.io.cgroups import get_cpu_quota
from odc.stac import configure_rio, stac_load
import pystac
import logging

# AWS access for reading the Sentinel-2 COGs. Requester-pays buckets bill the caller
# and therefore need signed (credentialed) requests, so unsigned access is only
# enabled when requester_pays is False.
requester_pays = False
aws_unsigned = not requester_pays
aws_session = rasterio.session.AWSSession(
    boto3.Session(), aws_unsigned=aws_unsigned, requester_pays=requester_pays
)

# Local dask cluster used by the lazy stac_load / ODC geomedian cells below.
client = dask.distributed.Client(
    n_workers=4, threads_per_worker=1, silence_logs=logging.ERROR
)
# Give odc-stac's rasterio environment the same AWS settings as `aws_session`
# (the dict is forwarded as AWSSession keyword arguments).
configure_rio(
    cloud_defaults=True,
    aws={"aws_unsigned": aws_unsigned, "requester_pays": requester_pays},
    client=client,
)
display(client)

In [None]:
# Candidate AOIs, indexed by `aoi_idx` in the config cell.
#   0 -> Western Australia: bounds of the KML outline, passed through
#        resize_bbox(..., 0.1) (presumably a 10% buffer - confirm in utils).
#   1 -> a ~0.1 degree test box. In the same (min_x, min_y, max_x, max_y) =
#        (min_lon, min_lat, max_lon, max_lat) order as the WA box, this is
#        lat ~ -72.5, lon ~ 67.5, i.e. Antarctica, NOT the Arctic.
#        TODO(review): if an Arctic site (~67.5N, 72.5W) was intended, the values
#        are in lat/lon order and must be swapped.
wa_bbox = resize_bbox(BoundingBox(*kml_to_poly("data/inputs_old/WA.kml").bounds), 0.1)
bbox_list = [
    wa_bbox,
    [67.45, -72.55, 67.55, -72.45],
]

In [None]:
# --- Run configuration -------------------------------------------------------
# Area of interest / sensor
AOI = "WA"
MISSION = "SENTINEL-2"
aoi_idx = 0
bbox = bbox_list[aoi_idx]

# Bands: colour measurements plus the Sen2Cor scene-classification band for masking
measurements = ["red", "green", "blue"]
masking_band = ["scl"]
# Morphological clean-up applied to the cloud mask (only used if masking is enabled)
mask_filters = [("opening", 10), ("dilation", 1)]

# Output grid
# crs = "EPSG:3031"
resolution = 100

# Optional tile filter: empty string keeps every scene
tile_id = ""
condition = tile_id or ""  # substring each CSV scene_name must contain

# Tag appended to output file names
output_suffix = "odc_loader"

In [None]:
# Paths derived from the config: per-mission/AOI input folder and the cached,
# filtered STAC item collection (suffixed with the tile id when one is set).
bands = [*measurements, *masking_band]
output_dir = f"data/inputs/{MISSION}_{AOI}"
items_suffix = f"_{tile_id}" if tile_id else ""
items_file = f"{output_dir}/items{items_suffix}.json"
items_exist = os.path.exists(items_file)

In [None]:
if not items_exist:
    # Build the search with the project helper, then adapt it to the
    # Element84 Earth Search v1 STAC API queried below.
    query = get_search_query(
        bbox,
        collections=["SENTINEL-2"],
        start_date="2016-01-01T00:00:00",
        end_date="2021-01-01T00:00:00",
        is_landsat=False,
    )
    query["collections"] = ["sentinel-2-l2a"]
    # Drop the helper's paging key; pop() tolerates the key being absent.
    query.pop("page", None)
    server_url = "https://earth-search.aws.element84.com/v1"
    display(query)
    items = query_stac_server(query, server_url, pystac=True, return_pystac_items=True)
    print(f"Found {len(items)} items.")

In [None]:
if not items_exist:
    # Scenes to keep come from the per-AOI CSV, optionally narrowed to `tile_id`.
    s2_scenes = pd.read_csv(f"data/inputs/{MISSION}_{AOI}_scenes.csv")
    scene_list = s2_scenes.to_dict("records")
    print(len(scene_list), "scenes found in the CSV file.")
    scene_names = [
        scene["scene_name"] for scene in scene_list if condition in scene["scene_name"]
    ]
    scene_name_set = set(scene_names)  # O(1) membership tests below

    # Cross-check the CSV against the search results by S3 path basename.
    gdf = gpd.GeoDataFrame.from_features(items, "epsg:4326")
    item_names = list(gdf["earthsearch:s3_path"].apply(lambda x: x.split("/")[-1]))
    name_to_row = {}
    for row, name in enumerate(item_names):
        name_to_row.setdefault(name, row)  # first occurrence, like list.index
    missing = [name for name in scene_names if name not in name_to_row]
    if missing:
        # Previously list.index raised ValueError here; report instead so the
        # item filter below (which already skips unknown ids) can proceed.
        print(
            f"{len(missing)} CSV scenes not returned by the STAC search, e.g. {missing[:5]}"
        )
    gdf = gdf.iloc[
        [name_to_row[name] for name in scene_names if name in name_to_row]
    ].reset_index(drop=True)
    print(len(gdf), "items found in the GeoDataFrame.")

    # Keep only the STAC items whose id is one of the CSV scenes, then cache them.
    items.items = [item for item in items.items if item.id in scene_name_set]
    items.save_object(items_file)
else:
    items = pystac.ItemCollection.from_file(items_file)
    print(f"Loaded {len(items.items)} items from {items_file}.")
items

In [None]:
# Lazily load the colour bands plus the SCL band for every item; passing
# `chunks` makes the result dask-backed rather than read eagerly.
ds = stac_load(
    items,
    bands=[*measurements, *masking_band],
    resolution=resolution,
    chunks={},
    # crs=crs,
    # groupby="solar_day",
)
# Optional rescaling experiments:
# ds[measurements] =  ds[measurements] - 1000
# ds[measurements] = (ds[measurements] / 256).clip(0, 255).astype("uint8")
ds

In [None]:
# --- Cloud / nodata masking: currently DISABLED (every line below is commented out) ---
# While disabled, `ds` still contains the "scl" band loaded above, and no SCL-based
# cloud masking, nodata (== 0) removal, or > 10000 filtering is applied before the
# geometric-median cells. Uncomment to re-enable; the block drops "scl" itself.
# ds.scl.attrs = {
#     "units": "1",
#     "nodata": 0,
#     "flags_definition": {
#         "qa": {
#             "bits": [0, 1, 2, 3, 4, 5, 6, 7],
#             "values": {
#                 "0": "no data",
#                 "1": "saturated or defective",
#                 "2": "dark area pixels",
#                 "3": "cloud shadows",
#                 "4": "vegetation",
#                 "5": "bare soils",
#                 "6": "water",
#                 "7": "unclassified",
#                 "8": "cloud medium probability",
#                 "9": "cloud high probability",
#                 "10": "thin cirrus",
#                 "11": "snow or ice",
#             },
#             "description": "Sen2Cor Scene Classification",
#         }
#     },
#     # "crs": crs,
#     "grid_mapping": "spatial_ref",
# }
# pq_mask = enum_to_bool(
#     mask=ds["scl"],
#     categories=(
#         "cloud high probability",
#         "cloud medium probability",
#         "thin cirrus",
#         "cloud shadows",
#         # "saturated or defective",
#     ),
# )
# # apply morphological filters (might improve cloud mask)
# pq_mask = mask_cleanup(pq_mask, mask_filters=mask_filters)

# # apply the cloud mask and drop scl layers
# ds = erase_bad(ds, where=pq_mask)
# ds = ds.drop_vars("scl")

# # remove nodata which is == 0
# ds = ds.where(ds > 0)

# # and remove any data that's above 10,000 (very dodgy)
# ds = ds.where(ds <= 10000)
# ds

In [None]:
# Geometric median over time with odc.algo. Only the colour measurements are passed:
# `ds` may still carry the categorical "scl" band (the masking cell above is disabled),
# and because the geometric median is computed jointly across bands, including it
# would add an SCL output band and bias the RGB result. This also matches the
# RGB-only input given to the hdstats and geom_median comparisons below.
s2_gm = geomedian_with_mads(
    ds[measurements],
    reshape_strategy="yxbt",  # 'yxbt' if data is larger than RAM
    compute_mads=False,  # True if you want triple MADs
)
# s2_gm = s2_gm.fillna(0)  # fill NaNs with 0s
s2_gm

In [None]:
# ODC result as a (band, y, x) numpy array, scaled to 8 bit, passed through the
# utils helpers flip_img (presumably reorders to (y, x, band); the write cell
# indexes the result as [:, :, i]) and apply_gamma for display.
odc_stack = s2_gm[measurements[:3]].to_array().to_numpy()
odc_u8 = (odc_stack / 256).clip(0, 255).astype("uint8")
gmed_odc = apply_gamma(flip_img(odc_u8), stretch_hist=True)

In [None]:
# Materialise the colour bands in memory for the non-ODC implementations.
# (Optional: ds = ds.fillna(0))
rgb_bands = measurements[:3]
ds_loaded = ds[rgb_bands].compute()

In [None]:
# to_array() yields (variable, time, y, x); reorder to (y, x, variable, time).
stacked = ds_loaded.to_array().to_numpy()
img_data = np.transpose(stacked, (2, 3, 0, 1))
img_data.shape

In [None]:
# hdstats geometric median on the (y, x, band, time) stack built above
# (assumes nangeomedian_pcm reduces the trailing time axis - confirm in hdstats docs).
pcm_raw = nangeomedian_pcm(img_data.astype("float32"), num_threads=4, eps=1e-4)
pcm_u8 = (pcm_raw / 256).clip(0, 255).astype("uint8")
gmed_pcm = apply_gamma(pcm_u8, stretch_hist=True)

In [None]:
# One (y, x, band) array per time step, split off the last axis of img_data.
imgs = list(np.moveaxis(img_data, -1, 0))
print(len(imgs), "images loaded for geometric median computation.")

In [None]:
# geom_median over the list of per-time images.
# NOTE(review): geom_median treats each list element as one point, so this appears
# to be a single median over whole images rather than a per-pixel median like
# hdstats/ODC - confirm this is the intended comparison.
gm_result = gm(imgs)
gm_u8 = (gm_result.median / 256).clip(0, 255).astype("uint8")
gmed_gm = apply_gamma(gm_u8, stretch_hist=True)

In [None]:
# Derive a GeoTIFF profile (CRS, transform, size) matching the loaded grid by writing
# the first time slice of the colour bands to a temporary file and reading it back.
# The handle is closed before removal (an open handle blocks deletion on Windows).
tmp_profile_path = f"{output_dir}/temp.tif"
ds[measurements].to_array()[:, 0, :, :].rio.to_raster(tmp_profile_path)
with rasterio.open(tmp_profile_path) as src:
    profile = src.profile
os.remove(tmp_profile_path)
# Outputs are 3-band 8-bit RGB with 0 as nodata.
profile.update({"dtype": "uint8", "nodata": 0, "count": 3})
profile

In [None]:
tile_part = f"_{tile_id}" if tile_id else ""
gmed_file_pcm = f"{output_dir}/geometric_median{tile_part}_pcm_{output_suffix}.tif"
gmed_file_gm = f"{output_dir}/geometric_median{tile_part}_gm_{output_suffix}.tif"
gmed_file_odc = f"{output_dir}/geometric_median{tile_part}_odc_{output_suffix}.tif"


def write_rgb_geotiff(path, rgb, profile):
    """Write a (y, x, band) array to `path` as a GeoTIFF, replacing any existing file.

    Args:
        path: Destination file path.
        rgb: Array indexed as rgb[:, :, band]; must have at least profile["count"] bands.
        profile: rasterio profile (size, CRS, transform, dtype, count) for the output.
    """
    if os.path.exists(path):
        os.remove(path)
    with rasterio.open(path, "w", **profile) as dst:
        for band in range(profile["count"]):
            dst.write(rgb[:, :, band], band + 1)  # rasterio bands are 1-based


for out_path, rgb in (
    (gmed_file_pcm, gmed_pcm),
    (gmed_file_gm, gmed_gm),
    (gmed_file_odc, gmed_odc),
):
    write_rgb_geotiff(out_path, rgb, profile)

In [None]:
def _read_rgb(path):
    """Read all bands of a GeoTIFF as a (band, y, x) array, closing the file afterwards."""
    with rasterio.open(path) as src:
        return src.read()


fig, axes = plt.subplots(4, 2, figsize=(10, 20))

# Rows 0-1: the three geometric-median results (axes[1, 1] stays blank).
median_panels = [
    (axes[0, 0], gmed_file_pcm, "hdstats"),
    (axes[0, 1], gmed_file_gm, "geom_median"),
    (axes[1, 0], gmed_file_odc, "odc"),
]
for ax, path, label in median_panels:
    ax.imshow(flip_img(_read_rgb(path)))
    ax.set_title(f"Geometric Median of {len(imgs)} images ({label})")

# Rows 2-3: up to four input images; slicing avoids IndexError for shorter stacks.
sample_axes = [axes[2, 0], axes[2, 1], axes[3, 0], axes[3, 1]]
for i, ax in enumerate(sample_axes[: len(imgs)]):
    ax.imshow(
        apply_gamma((imgs[i] / 256).clip(0, 255).astype("uint8"), stretch_hist=True)
    )
    ax.set_title(f"Image {i}")

for ax in axes.flat:
    ax.axis("off")

# Only mention the tile id when one is set (avoids a dangling ", ," in the title).
id_part = f", ID: {tile_id}" if tile_id else ""
plt.suptitle(
    f"Geometric Median of {len(imgs)} {MISSION} images from {AOI} AOI{id_part} ({output_suffix.replace('_', ' ')})",
    fontsize=14,
    y=1.01,
)
plt.tight_layout()
# bbox_inches="tight" keeps the suptitle (placed above the axes at y=1.01) in the PNG.
plt.savefig(
    f"{output_dir}/geometric_median_{MISSION}_{AOI}{'_' + tile_id if tile_id else ''}_{output_suffix}.png",
    dpi=300,
    bbox_inches="tight",
)