In [None]:
import logging
import os

from geom_median.numpy import compute_geometric_median as gm
import numpy as np
from utils import *
import dask
import dask.distributed
import xarray as xr
import rioxarray as rxr
from hdstats import nangeomedian_pcm
import geopandas as gpd
from odc.algo import (
    enum_to_bool,
    geomedian_with_mads,
    erase_bad,
    mask_cleanup,
    keep_good_only,
)
from odc.geo import BoundingBox
from odc.geo.xr import assign_crs
from odc.io.cgroups import get_cpu_quota
from odc.stac import configure_rio, stac_load
import pystac

# Sensor to process. The STAC query cell handles "SENTINEL-2", "SENTINEL-1",
# "LANDSAT-8" and "LANDSAT-4-5".
MISSION = "SENTINEL-2"

client = dask.distributed.Client(
    n_workers=4, threads_per_worker=1, silence_logs=logging.ERROR
)

# Landsat hrefs (LANDSAT-8 and LANDSAT-4-5 alike) are rewritten to
# s3://usgs-landsat/ in the load cell, so every Landsat mission gets the
# requester-pays setup. All other missions are read unsigned.
if MISSION.startswith("LANDSAT"):
    aws_session = rasterio.session.AWSSession(boto3.Session(), requester_pays=True)
    configure_rio(cloud_defaults=True, aws={"requester_pays": True}, client=client)
else:
    aws_session = rasterio.session.AWSSession(boto3.Session())
    configure_rio(cloud_defaults=True, aws={"aws_unsigned": True}, client=client)
display(client)

In [None]:
def _kml_bbox(kml_path, factor=0.1):
    """Bounding box of the polygon in ``kml_path``, passed through ``resize_bbox`` with ``factor``."""
    return resize_bbox(BoundingBox(*kml_to_poly(kml_path).bounds), factor)


wa_bbox = _kml_bbox("data/inputs_old/WA.kml")
geo_1_box = _kml_bbox("data/inputs_old/geo1.kml")

# Candidate AOIs, selected by ``aoi_idx`` in the config cell:
#   0 -> box around WA.kml (presumably Western Australia)
#   1 -> small [left, bottom, right, top] box, presumably lon/lat degrees:
#        ~67.5E, ~72.5S, i.e. the southern hemisphere (Antarctica), not the Arctic
#   2 -> box around geo1.kml
bbox_list = [
    wa_bbox,
    [67.45, -72.55, 67.55, -72.45],
    geo_1_box,
]

In [None]:
# Run configuration.
AOI = "AU_TEST"  # label used in output paths and figure titles
aoi_idx = 0  # index into bbox_list
bbox = bbox_list[aoi_idx]
masking_band = ["scl"]  # appended to the loaded bands for Sentinel-2 only
measurements = ["red", "green", "blue"]
# Morphological filters for the cloud-mask cleanup (that cell is currently commented out).
mask_filters = [("opening", 10), ("dilation", 1)]
# crs = "EPSG:3031"
resolution = 100 if MISSION == "SENTINEL-2" else 200  # stac_load output resolution
tile_id = ""  # substring filter applied to scene names; "" keeps every scene
condition = tile_id
output_suffix = "odc_loader"

In [None]:
# Sentinel-2 also loads ``masking_band``; other missions load only the measurements.
bands = measurements + masking_band if MISSION == "SENTINEL-2" else measurements
output_dir = f"data/inputs/{MISSION}_{AOI}"
# Later cells save the item cache, a temp GeoTIFF and the outputs here; create it
# so those writes do not fail on a fresh checkout.
os.makedirs(output_dir, exist_ok=True)
# Cached STAC item collection; when present, the query cell is skipped.
items_file = f"{output_dir}/items{'_' + tile_id if tile_id else ''}.json"
items_exist = os.path.exists(items_file)
items_exist

In [None]:
# MISSION = "SENTINEL-1"
# bands = ["VH", "VV"]
# aus_bbox = BoundingBox(
#     left=147.07251,
#     bottom=-42.22120,
#     right=147.24274,
#     top=-42.03035,
#     crs="EPSG:4326"
# )
# bbox = aus_bbox

In [None]:
# Query the mission's STAC server over ``bbox`` unless a cached item collection
# exists. Scenes are selected with find_scenes_dict(one_per_month=True, ...) and
# written to CSV for the filtering cell further down.
if not items_exist:
    if MISSION == "SENTINEL-2":
        query = get_search_query(
            bbox,
            collections=["sentinel-2-l2a"],
            start_date="2016-01-01T00:00:00",
            end_date="2021-01-01T00:00:00",
            pystac_query=True,
        )
        use_pystac = True
        server_url = "https://earth-search.aws.element84.com/v1"
    elif MISSION == "SENTINEL-1":
        query = get_search_query(
            bbox,
            collections=["ga_s1_iw_vv_vh_c0"],
            start_date="2016-01-01T00:00:00",
            end_date="2021-01-01T00:00:00",
            pystac_query=True,
        )
        use_pystac = True
        server_url = "https://explorer.dev.dea.ga.gov.au/stac"
    elif MISSION == "LANDSAT-8":
        query = get_search_query(
            bbox,
            start_date="2013-01-01T00:00:00",
            end_date="2017-01-01T00:00:00",
            platform=["LANDSAT_8"],
            collection_category=None,
            collections=None,
        )
        use_pystac = False
        server_url = "https://landsatlook.usgs.gov/stac-server/search"
    elif MISSION == "LANDSAT-4-5":
        query = get_search_query(
            bbox,
            start_date="1985-01-01T00:00:00",
            end_date="2010-12-30T00:00:00",
            platform=["LANDSAT_4", "LANDSAT_5"],
            collection_category=None,
            collections=None,
            cloud_cover=None,
        )
        use_pystac = False
        server_url = "https://landsatlook.usgs.gov/stac-server/search"
    else:
        # Fail fast instead of a confusing NameError on ``query`` below.
        raise ValueError(
            f"Unsupported MISSION {MISSION!r}; expected one of "
            "'SENTINEL-2', 'SENTINEL-1', 'LANDSAT-8', 'LANDSAT-4-5'."
        )

    display(query)
    items = query_stac_server(query, server_url, pystac=use_pystac)
    print(f"Found {len(items)} items.")

    scene_dict, scene_list = find_scenes_dict(
        items,
        one_per_month=True,
        acceptance_list=bands + ["thumbnail"],
        remove_duplicate_times=True,
        duplicate_idx=1,
    )
    pd.DataFrame(scene_list).to_csv(
        f"data/inputs/{MISSION}_{AOI}_scenes.csv", index=False
    )
    path_rows = list(scene_dict.keys())
    print("Found IDs: ", path_rows)

    items = pystac.ItemCollection(items)
else:
    # Load the cache here so the next cells, which iterate ``items``, also work
    # on a fresh kernel.
    items = pystac.ItemCollection.from_file(items_file)
    print(f"Loaded {len(items)} cached items from {items_file}.")

In [None]:
# Index every item whose assets contain all requested ``bands`` by its title.
# For each band, record the asset href and, when present, the S3 "alternate" href
# (falling back to the href itself).
scene_list = []
scene_dict = dict()
for feature in items:
    # pystac Items are converted to plain dicts; raw dict results pass through.
    if not isinstance(feature, dict):
        feature = feature.to_dict()

    scene_name = feature["properties"]["title"]
    scene_id = feature["properties"].get("landsat:scene_id")
    assets = feature["assets"]

    # all() over an empty ``bands`` is True, so every item is accepted then.
    if not all(band in assets for band in bands):
        continue

    scene_dict[scene_name] = dict(scene_id=scene_id)
    for band in bands:
        url = assets[band]["href"]
        if "alternate" in assets[band]:
            url_alternate = assets[band]["alternate"]["s3"]["href"]
        else:
            url_alternate = url
        scene_dict[scene_name][band] = url
        scene_dict[scene_name][f"{band}_alternate"] = url_alternate

if len(items) == 0:
    raise ValueError("No STAC items available; cannot derive path/row groups.")

# Derive a grouping key ("path/row") from each scene name. The token positions
# depend on the provider's naming scheme (NOTE(review): fixed offsets, verify
# against actual scene titles). ``time_ind`` is only read by the commented-out
# code below.
first_item = items[0]
f0 = first_item if isinstance(first_item, dict) else first_item.to_dict()
if "landsat:scene_id" in f0["properties"]:
    path_rows = [k.split("_")[2] for k in scene_dict]
    time_ind = 3
else:
    if isinstance(first_item, dict):
        path_rows = ["_".join(k.split("_")[3:6]) for k in scene_dict]
    else:
        path_rows = ["_".join(k.split("_")[3:5])[0:15] for k in scene_dict]
    time_ind = 2

# ``path_rows`` holds one (possibly repeated) entry per scene, so build each group
# only once. Membership is a substring test on the scene name.
scene_dict_pr = {}
for pr in dict.fromkeys(path_rows):
    scene_dict_pr[pr] = {k: v for k, v in scene_dict.items() if pr in k}

# scene_dict_pr_time = {}
# for pr in scene_dict_pr:
#     se = pd.Series(list(scene_dict_pr[pr].keys())).astype("str")
#     g = [s.split("_")[time_ind][0:6] for s in list(scene_dict_pr[pr].keys())]
#     if len(start_end_years) != 0:
#         years = [
#             int(s.split("_")[time_ind][0:4]) for s in list(scene_dict_pr[pr].keys())
#         ]
#         year_range = range(start_end_years[0], start_end_years[1] + 1)
#         valid_idx = list(
#             filter(lambda i: years[i] in year_range, range(len(years)))
#         )
#         g = [g[i] for i in range(len(g)) if i in valid_idx]
#         se = se.iloc[valid_idx]
#     groups = list(se.groupby(g))
#     temp_dict_time = {}
#     for i, t in enumerate([el[0] for el in groups]):
#         if type(t) == tuple:
#             t = t[0]
#         temp_list = []
#         if one_per_month:
#             temp_dict = scene_dict_pr[pr][groups[i][1].iloc[0]]
#             temp_dict["scene_name"] = groups[i][1].iloc[0]
#             temp_list.append(temp_dict)
#         else:
#             for k in list(groups[i][1]):
#                 temp_dict = scene_dict_pr[pr][k]
#                 temp_dict["scene_name"] = k
#                 temp_list.append(temp_dict)
#         if remove_duplicate_times:
#             if duplicate_idx == 0:
#                 times_idx = sorted(
#                     np.unique(
#                         [
#                             re.findall(r"\d{8}", d["scene_name"])[0]
#                             for d in temp_list
#                         ],
#                         return_index=True,
#                     )[1].tolist()
#                 )
#             else:
#                 times_list = [
#                     re.findall(r"\d{8}", d["scene_name"])[0] for d in temp_list
#                 ]
#                 unique_times = np.unique(times_list)
#                 unique_idx_list = [
#                     [i for i in range(len(times_list)) if times_list[i] == ut]
#                     for ut in unique_times
#                 ]
#                 times_idx = []
#                 for i, idx in enumerate(unique_idx_list):
#                     if len(idx) < duplicate_idx + 1:
#                         temp_idx = len(idx) - 1
#                     else:
#                         temp_idx = duplicate_idx
#                     times_idx.append(idx[temp_idx])
#             temp_list = [temp_list[idx] for idx in times_idx]

#         scene_list.extend(temp_list)
#         temp_dict_time[t] = temp_list
#     scene_dict_pr_time[pr] = temp_dict_time

In [None]:
# Ad-hoc variant for collections that expose their imagery under a single "data"
# asset (presumably the SENTINEL-1 collection, confirm). Only the first item is
# used, and scene_dict / path_rows / scene_dict_pr from the previous cell are
# overwritten.
first_item = items[0]
feature = first_item if isinstance(first_item, dict) else first_item.to_dict()
item_id = feature["id"]
assets = feature["assets"]
scene_dict = {}
scene_dict[item_id] = dict(scene_id=None)
for s in ["data"]:
    if s not in assets:
        # Collections without this asset are skipped instead of raising KeyError.
        continue
    url = assets[s]["href"]
    if "alternate" in assets[s]:
        url_alternate = assets[s]["alternate"]["s3"]["href"]
    else:
        url_alternate = url

    scene_dict[item_id][s] = url
    scene_dict[item_id][f"{s}_alternate"] = url_alternate
# Group key is the first "_"-separated token of the item id.
path_rows = [k.split("_")[0] for k in scene_dict]

scene_dict_pr = {}
for pr in dict.fromkeys(path_rows):
    scene_dict_pr[pr] = {k: v for k, v in scene_dict.items() if pr in k}
path_rows, scene_dict, scene_dict_pr

In [None]:
# [(i, len(scene_dict[path_row])) for i, path_row in enumerate(path_rows)]
# Preview a few path/row entries. Indices past the end are skipped because the
# previous cell may have rebuilt ``path_rows`` with fewer entries.
[path_rows[i] for i in [0, 6, 7] if i < len(path_rows)]

In [None]:
# Restrict the remaining cells to scenes whose name contains this substring
# ("" keeps every scene).
tile_id = "AMP"
condition = tile_id
# items_file / items_exist were computed before tile_id was set here. Refresh them
# so the cache check below looks at the same path the filtered items are saved to.
items_file = f"{output_dir}/items{'_' + tile_id if tile_id else ''}.json"
items_exist = os.path.exists(items_file)

In [None]:
if not items_exist:
    # Scenes chosen in the query cell (CSV), restricted to names containing ``condition``.
    scenes = pd.read_csv(f"data/inputs/{MISSION}_{AOI}_scenes.csv")
    scene_list = scenes.to_dict("records")
    print(len(scene_list), "scenes found in the CSV file.")
    scene_names = [
        scene["scene_name"] for scene in scene_list if condition in scene["scene_name"]
    ]
    scene_ids = None
    if MISSION != "SENTINEL-2":
        scene_ids = [
            scene["scene_id"] for scene in scene_list if condition in scene["scene_id"]
        ]

    # Footprints of the selected scenes, ordered like ``checklist``. Only used for
    # the count below (and the optional explore()).
    gdf = gpd.GeoDataFrame.from_features(items, "epsg:4326")
    id_col = "earthsearch:s3_path" if MISSION == "SENTINEL-2" else "landsat:scene_id"
    item_names = list(gdf[id_col].apply(lambda x: x.split("/")[-1]))
    checklist = scene_names if MISSION == "SENTINEL-2" else scene_ids
    idx = [item_names.index(i) for i in checklist]
    gdf = gdf.iloc[idx].reset_index(drop=True)
    print(len(gdf), "items found in the GeoDataFrame.")

    # gdf.explore()
    # Keep only the selected items, preserving collection order (set lookups
    # avoid a quadratic scan).
    wanted_names = set(scene_names)
    if MISSION == "SENTINEL-2":
        new_items = [item for item in items.items if item.id in wanted_names]
    else:
        wanted_ids = set(scene_ids)
        new_items = [
            item
            for item in items.items
            if item.properties["landsat:scene_id"] in wanted_ids
            and item.id in wanted_names
        ]
    items.items = new_items
    os.makedirs(output_dir, exist_ok=True)
    items.save_object(f"{output_dir}/items{'_' + tile_id if tile_id else ''}.json")
else:
    items = pystac.ItemCollection.from_file(items_file)
    print(f"Loaded {len(items.items)} items from {items_file}.")
items

In [None]:
def _landsatlook_to_s3(href):
    """Map a landsatlook HTTPS asset href onto the matching s3://usgs-landsat/ URL."""
    return href.replace("https://landsatlook.usgs.gov/data/", "s3://usgs-landsat/")


# Sentinel-2 hrefs are used untouched; every other mission gets the USGS rewrite.
patch_url = None if MISSION == "SENTINEL-2" else _landsatlook_to_s3

# chunks={} requests a lazy, dask-backed load; nothing is read until computed.
# Disabled options: crs=crs, groupby="solar_day".
load_options = dict(
    bands=bands,
    chunks={},
    resolution=resolution,
    patch_url=patch_url,
)
ds = stac_load(items, **load_options)
# ds[measurements] =  ds[measurements] - 1000
# ds[measurements] = (ds[measurements] / 256).clip(0, 255).astype("uint8")
ds

In [None]:
# ds.scl.attrs = {
#     "units": "1",
#     "nodata": 0,
#     "flags_definition": {
#         "qa": {
#             "bits": [0, 1, 2, 3, 4, 5, 6, 7],
#             "values": {
#                 "0": "no data",
#                 "1": "saturated or defective",
#                 "2": "dark area pixels",
#                 "3": "cloud shadows",
#                 "4": "vegetation",
#                 "5": "bare soils",
#                 "6": "water",
#                 "7": "unclassified",
#                 "8": "cloud medium probability",
#                 "9": "cloud high probability",
#                 "10": "thin cirrus",
#                 "11": "snow or ice",
#             },
#             "description": "Sen2Cor Scene Classification",
#         }
#     },
#     # "crs": crs,
#     "grid_mapping": "spatial_ref",
# }
# pq_mask = enum_to_bool(
#     mask=ds["scl"],
#     categories=(
#         "cloud high probability",
#         "cloud medium probability",
#         "thin cirrus",
#         "cloud shadows",
#         # "saturated or defective",
#     ),
# )
# # apply morphological filters (might improve cloud mask)
# pq_mask = mask_cleanup(pq_mask, mask_filters=mask_filters)

# # apply the cloud mask and drop scl layers
# ds = erase_bad(ds, where=pq_mask)
# ds = ds.drop_vars("scl")

# # remove nodata which is == 0
# ds = ds.where(ds > 0)

# # and remove any data that's above 10,000 (very dodgy)
# ds = ds.where(ds <= 10000)
# ds

In [None]:
# Compute the odc geomedian over the spectral measurements only. For Sentinel-2,
# ``ds`` also carries the categorical "scl" band (the masking cell is commented
# out). That band must not enter the multivariate median as if it were another
# spectral dimension.
s2_gm = geomedian_with_mads(
    ds[measurements],
    reshape_strategy="yxbt",  #'yxbt' if data is larger than RAM
    compute_mads=False,  # True if you want triple MADs
)
# s2_gm = s2_gm.fillna(0)  # fill NaNs with 0s
s2_gm

In [None]:
# Stack the first three measurements of the odc geomedian, scale to 8-bit, then
# reorient with flip_img and apply gamma (histogram stretch only for Sentinel-2).
odc_stack = s2_gm[measurements[:3]].to_array().to_numpy()
odc_8bit = (odc_stack / 256).clip(0, 255).astype("uint8")
gmed_odc = apply_gamma(flip_img(odc_8bit), stretch_hist=MISSION == "SENTINEL-2")

In [None]:
# ds = ds.fillna(0)
# Pull the first three measurements into memory once; the numpy-based medians
# below work on these values.
rgb_vars = measurements[:3]
ds_loaded = ds[rgb_vars].compute()

In [None]:
# to_array() puts the band ("variable") axis first. Moving axes 0 and 1 to the end
# gives (y, x, band, time), assuming ds dims are (time, y, x) (confirm).
band_first = ds_loaded.to_array().to_numpy()
img_data = np.moveaxis(band_first, [0, 1], [-2, -1])
nan_mask = np.isnan(img_data)
img_data = np.where(nan_mask, 0, img_data)  # replace NaNs with 0s
img_data.shape

In [None]:
# Per-pixel geomedian via hdstats on the float32 stack, then 8-bit + gamma for display.
pcm_raw = nangeomedian_pcm(img_data.astype("float32"), num_threads=4, eps=1e-4)
pcm_8bit = (pcm_raw / 256).clip(0, 255).astype("uint8")
gmed_pcm = apply_gamma(pcm_8bit, stretch_hist=MISSION == "SENTINEL-2")

In [None]:
# One (y, x, band) view per entry of the last axis, in order; this is the list
# handed to gm() below.
imgs = list(np.moveaxis(img_data, -1, 0))
print(len(imgs), "images loaded for geometric median computation.")

In [None]:
# NOTE(review): geom_median appears to treat each array in ``imgs`` as one point.
# That would make this the geometric median of whole images, not a per-pixel
# median like hdstats/odc (confirm). NaNs were replaced by 0 above, so those
# pixels also contribute to the distances.
gm_result = gm(imgs)
# gm_result.median is rescaled to 8-bit and gamma-corrected like the other products.
gmed_gm = apply_gamma(
    (gm_result.median / 256).clip(0, 255).astype("uint8"),
    stretch_hist=MISSION == "SENTINEL-2",
)

In [None]:
# Borrow a georeferenced profile by writing the first entry of axis 1 of the loaded
# bands to a throwaway GeoTIFF. Then adapt it for 3-band uint8 output with nodata 0.
os.makedirs(output_dir, exist_ok=True)
temp_tif = f"{output_dir}/temp.tif"
ds_loaded.to_array()[:, 0, :, :].rio.to_raster(temp_tif)
# Close the handle before deleting the file (an open handle can block removal).
with rasterio.open(temp_tif) as src:
    profile = src.profile
os.remove(temp_tif)
profile.update({"dtype": "uint8", "nodata": 0, "count": 3})
profile

In [None]:
def _write_rgb_geotiff(path, rgb, raster_profile):
    """Write ``rgb`` (indexed [row, col, band]) to ``path`` band by band.

    Any existing file at ``path`` is removed first. ``raster_profile["count"]``
    sets how many bands are written.
    """
    if os.path.exists(path):
        os.remove(path)
    with rasterio.open(path, "w", **raster_profile) as dst:
        for band_idx in range(raster_profile["count"]):
            dst.write(rgb[:, :, band_idx], band_idx + 1)


tile_part = f"_{tile_id}" if tile_id else ""
gmed_file_pcm = f"data/inputs/{MISSION}_{AOI}/geometric_median{tile_part}_pcm_{output_suffix}.tif"
gmed_file_gm = f"data/inputs/{MISSION}_{AOI}/geometric_median{tile_part}_gm_{output_suffix}.tif"
gmed_file_odc = f"data/inputs/{MISSION}_{AOI}/geometric_median{tile_part}_odc_{output_suffix}.tif"

for out_path, rgb in [
    (gmed_file_pcm, gmed_pcm),
    (gmed_file_gm, gmed_gm),
    (gmed_file_odc, gmed_odc),
]:
    _write_rgb_geotiff(out_path, rgb, profile)

In [None]:
def _read_raster(path):
    """Return all bands of the raster at ``path``, closing the file afterwards."""
    with rasterio.open(path) as src:
        return src.read()


# 4x2 grid: rows 0-1 hold the three geomedian products (axes[1, 1] stays empty),
# rows 2-3 hold up to four of the input images.
fig, axes = plt.subplots(4, 2, figsize=(10, 20))

median_panels = [
    (axes[0, 0], gmed_file_pcm, "hdstats"),
    (axes[0, 1], gmed_file_gm, "geom_median"),
    (axes[1, 0], gmed_file_odc, "odc"),
]
for ax, path, label in median_panels:
    ax.imshow(flip_img(_read_raster(path)))
    ax.set_title(f"Geometric Median of {len(imgs)} images ({label})")

# zip() stops early when fewer than four images are available.
input_axes = [axes[2, 0], axes[2, 1], axes[3, 0], axes[3, 1]]
for i, (ax, img) in enumerate(zip(input_axes, imgs)):
    ax.imshow(
        apply_gamma(
            (img / 256).clip(0, 255).astype("uint8"),
            stretch_hist=MISSION == "SENTINEL-2",
        )
    )
    ax.set_title(f"Image {i}")

for ax in axes.flat:
    ax.axis("off")

# Only add the ", ID: ..." segment when a tile filter is set (avoids ", ,").
tile_label = f", ID: {tile_id}" if tile_id else ""
plt.suptitle(
    f"Geometric Median of {len(imgs)} {MISSION} images from {AOI} AOI{tile_label} ({output_suffix.replace('_', ' ')})",
    fontsize=14,
    y=1.01,
)
plt.tight_layout()
# bbox_inches="tight" keeps the suptitle (placed above the axes at y=1.01) in the PNG.
plt.savefig(
    f"{output_dir}/geometric_median_{MISSION}_{AOI}{'_' + tile_id if tile_id else ''}_{output_suffix}.png",
    dpi=300,
    bbox_inches="tight",
)