In [None]:
from geom_median.numpy import compute_geometric_median as gm
import numpy as np
from utils import *
import dask
import dask.distributed
import xarray as xr
import rioxarray as rxr
from hdstats import nangeomedian_pcm
import geopandas as gpd
from odc.algo import (
    enum_to_bool,
    geomedian_with_mads,
    erase_bad,
    mask_cleanup,
    keep_good_only,
)
from odc.geo import BoundingBox
from odc.geo.xr import assign_crs
from odc.io.cgroups import get_cpu_quota
from odc.stac import configure_rio, stac_load
import logging

MISSION = "SENTINEL-2"  # or "LANDSAT-8" / "LANDSAT-4-5"

# Local dask cluster; configure_rio(client=...) below also applies the rasterio
# settings on its workers.
client = dask.distributed.Client(
    n_workers=4, threads_per_worker=1, silence_logs=logging.ERROR
)

# Both Landsat missions are searched on the USGS landsatlook STAC server and read
# from its S3 assets, which need requester-pays access; Sentinel-2 assets are
# read unsigned.
# NOTE(review): rasterio and boto3 are not imported explicitly in this notebook;
# they are assumed to be provided by `from utils import *`.
if MISSION.startswith("LANDSAT"):
    aws_session = rasterio.session.AWSSession(boto3.Session(), requester_pays=True)
    configure_rio(cloud_defaults=True, aws={"requester_pays": True}, client=client)
else:
    aws_session = rasterio.session.AWSSession(boto3.Session())
    configure_rio(cloud_defaults=True, aws={"aws_unsigned": True}, client=client)
display(client)

In [None]:
# Candidate areas of interest; the config cell picks one via `aoi_index`.
# KML extents are taken from the polygon bounds (minx, miny, maxx, maxy) and
# passed through resize_bbox(..., 0.1).
wa_bbox = resize_bbox(BoundingBox(*kml_to_poly("data/inputs_old/WA.kml").bounds), 0.1)
geo_1_box = resize_bbox(
    BoundingBox(*kml_to_poly("data/inputs_old/geo1.kml").bounds), 0.1
)
bbox_list = [
    wa_bbox,  # 0: extent of WA.kml
    # 1: small 0.1 x 0.1 box around ~72.5 S (Antarctic), assuming
    #    (min_lon, min_lat, max_lon, max_lat) order like the KML bounds above.
    [67.45, -72.55, 67.55, -72.45],
    geo_1_box,  # 2: extent of geo1.kml
]

In [None]:
AOI = "MOUNT"  # label used in output directory / file names
aoi_index = 2  # index into bbox_list
bbox = bbox_list[aoi_index]
# Extra band(s) requested only for SENTINEL-2 (presumably the scene-classification layer).
masking_band = ["scl"]
measurements = ["red", "green", "blue"]
# NOTE(review): mask_filters is not referenced in the visible notebook cells.
mask_filters = [("opening", 10), ("dilation", 1)]
# crs = "EPSG:3031"
# Target pixel size, in the same units as the source raster transform
# (resolution_ratio divides the native pixel size by this value).
resolution = 100 if MISSION == "SENTINEL-2" else 200
tile_id = "32DNF"  # set to "" to disable tile filtering
# Substring filter applied to scene names / file paths; "" matches everything.
condition = tile_id
output_suffix = "manual_loader"

In [None]:
# Sentinel-2 additionally fetches the masking band(s); other missions only the measurements.
if MISSION == "SENTINEL-2":
    bands = measurements + masking_band
else:
    bands = measurements

# Every artefact for this mission/AOI lives under output_dir.
output_dir = f"data/inputs/{MISSION}_{AOI}"
process_dir, process_ds_dir, ds_dir = (
    f"{output_dir}/{subdir}"
    for subdir in ("true_colour", "true_colour_ds", "downsampled")
)
# Cached STAC item collection, keyed by tile_id when one is configured.
items_file = f"{output_dir}/items" + (f"_{tile_id}" if tile_id else "") + ".json"
items_exist = os.path.exists(items_file)

In [None]:
if not items_exist:
    # Mission-specific STAC search parameters and endpoint.
    if MISSION == "SENTINEL-2":
        query = get_search_query(
            bbox,
            collections=["sentinel-2-l2a"],
            start_date="2016-01-01T00:00:00",
            end_date="2021-01-01T00:00:00",
            pystac_query=True,
        )
        use_pystac = True
        server_url = "https://earth-search.aws.element84.com/v1"
    elif MISSION == "LANDSAT-8":
        query = get_search_query(
            bbox,
            start_date="2013-01-01T00:00:00",
            end_date="2017-01-01T00:00:00",
            platform=["LANDSAT-8"],
            collection_category=None,
            collections=None,
        )
        use_pystac = False
        server_url = "https://landsatlook.usgs.gov/stac-server/search"
    elif MISSION == "LANDSAT-4-5":
        query = get_search_query(
            bbox,
            start_date="1985-01-01T00:00:00",
            end_date="2010-12-30T00:00:00",
            platform=["LANDSAT_4", "LANDSAT_5"],
            collection_category=None,
            collections=None,
            cloud_cover=None,
        )
        use_pystac = False
        server_url = "https://landsatlook.usgs.gov/stac-server/search"
    else:
        # Fail here rather than with a NameError on `query` further down.
        raise ValueError(
            f"Unsupported MISSION {MISSION!r}; expected 'SENTINEL-2', "
            "'LANDSAT-8' or 'LANDSAT-4-5'."
        )

    display(query)
    items = query_stac_server(query, server_url, pystac=use_pystac)
    print(f"Found {len(items)} items.")

    # Reduce the search results to (at most) one scene per month.
    scene_dict, scene_list = find_scenes_dict(
        items,
        one_per_month=True,
        acceptance_list=bands + ["thumbnail"],
        remove_duplicate_times=True,
        duplicate_idx=1,
    )
    # to_csv does not create parent directories; output_dir also creates data/inputs.
    os.makedirs(output_dir, exist_ok=True)
    pd.DataFrame(scene_list).to_csv(
        f"data/inputs/{MISSION}_{AOI}_scenes.csv", index=False
    )
    path_rows = list(scene_dict.keys())
    print("Found IDs: ", path_rows)

    items = pystac.ItemCollection(items)

In [None]:
# Re-derive the substring filter from the tile_id set in the config cell instead of
# hard-coding it again here (a second hard-coded value silently overrides the config).
# An empty tile_id gives "", which matches every scene.
condition = tile_id

In [None]:
if not items_exist:
    # Re-load the per-month scene selection written by the search cell.
    scenes = pd.read_csv(f"data/inputs/{MISSION}_{AOI}_scenes.csv")
    scene_list = scenes.to_dict("records")
    bands_suffixes = get_band_suffixes(scene_list[0], bands)
    print(len(scene_list), "scenes found in the CSV file.")
    # `condition` is a substring filter; "" keeps every scene.
    scene_names = [
        scene["scene_name"] for scene in scene_list if condition in scene["scene_name"]
    ]
    scene_ids = None
    if MISSION != "SENTINEL-2":
        scene_ids = [
            scene["scene_id"] for scene in scene_list if condition in scene["scene_id"]
        ]

    # Footprint table of the selected items (only printed / used for inspection here).
    gdf = gpd.GeoDataFrame.from_features(items, "epsg:4326")
    id_col = "earthsearch:s3_path" if MISSION == "SENTINEL-2" else "landsat:scene_id"
    item_names = list(gdf[id_col].apply(lambda x: x.split("/")[-1]))
    checklist = scene_names if MISSION == "SENTINEL-2" else scene_ids
    # list.index raises ValueError if a selected scene is missing from the search results.
    idx = [item_names.index(i) for i in checklist]
    gdf = gdf.iloc[idx].reset_index(drop=True)
    print(len(gdf), "items found in the GeoDataFrame.")

    # gdf.explore()
    # Keep only the STAC items matching the selected scenes; sets make each
    # membership test O(1) instead of a list scan.
    wanted_names = set(scene_names)
    if MISSION == "SENTINEL-2":
        new_items = [item for item in items.items if item.id in wanted_names]
    else:
        wanted_ids = set(scene_ids)
        new_items = [
            item
            for item in items.items
            if item.properties["landsat:scene_id"] in wanted_ids
            and item.id in wanted_names
        ]
    items.items = new_items
    # Save to the same path `items_exist` is checked against, so later runs reuse it.
    items.save_object(items_file)
else:
    items = pystac.ItemCollection.from_file(items_file)
    # Rebuild scene_list (band -> href, plus an "_alternate" href) from the cached items.
    scene_list = []
    features = items.to_dict()["features"]
    for feature in features:
        s = {}
        for b in bands:
            if b in feature["assets"]:
                s[b] = feature["assets"][b]["href"]
                # Sentinel-2 reuses the primary href; Landsat uses the S3 alternate.
                s[b + "_alternate"] = (
                    s[b]
                    if MISSION == "SENTINEL-2"
                    else feature["assets"][b]["alternate"]["s3"]["href"]
                )
        s["scene_name"] = feature["id"]
        scene_list.append(s)
    bands_suffixes = get_band_suffixes(scene_list[0], bands)
    print(f"Loaded {len(items.items)} items from {items_file}.")
items

In [None]:
# Directory the stacking / median cells read images from (the process_ds_dir
# passed to download_and_process_series; presumably its downsampled output).
# NOTE(review): the image-stacking cell reassigns images_dir to the "warped"
# directory when image shapes differ — re-run this cell before re-running that one.
images_dir = process_ds_dir

In [None]:
# Read only the metadata of the first item's red band to learn the native pixel size.
ref_red_asset = items[0].assets["red"]
if MISSION == "SENTINEL-2":
    ref_red_href = ref_red_asset.href
else:
    # Landsat assets are streamed through their S3 alternate href.
    ref_red_href = ref_red_asset.to_dict()["alternate"]["s3"]["href"]
_, meta = stream_scene_from_aws(ref_red_href, aws_session, metadata_only=True)

# Native pixel width/height divided by the target resolution (transform.e is negative).
native_transform = meta["profile"]["transform"]
resolution_ratio = [native_transform.a / resolution, -native_transform.e / resolution]
print(f"Resolution ratio: {resolution_ratio}")

In [None]:
# Fetch the selected scenes' bands and write processed rasters under
# process_dir / process_ds_dir (see utils.download_and_process_series for details).
# NOTE(review): scale_factor=resolution_ratio presumably resamples to the configured
# `resolution`; keep_original_band_scenes=True presumably keeps the raw band files.
# The trailing ";" suppresses the cell's output repr.
download_and_process_series(
    scene_list,
    bands,
    bands_suffixes,
    output_dir,
    process_dir,
    process_ds_dir,
    aws_session=aws_session,
    keep_original_band_scenes=True,
    scale_factor=resolution_ratio,
);

In [None]:
# originals = glob.glob(f"{output_dir}/Originals/**/TCI.tif", recursive=True)

In [None]:
# os.makedirs(ds_dir, exist_ok=True)
# for original in originals:
#     ds_path = os.path.join(ds_dir, f"{original.split("/")[4]}.tif")
#     if not os.path.exists(ds_path):
#         ds = downsample_dataset(original, resolution_ratio, ds_path)
#         print(f"Downsampled {original} to {ds_path}")

In [None]:
ext = "tif" if MISSION == "SENTINEL-2" else "TIF"


def _read_raster(path):
    """Read every band of the raster at `path` as a (bands, rows, cols) array, closing the file."""
    with rasterio.open(path) as src:
        return src.read()


# One sorted, filtered glob so image order (and the warp <-> file pairing below)
# is deterministic.
imgs_list = sorted(f for f in glob.glob(images_dir + f"/*.{ext}") if condition in f)
if not imgs_list:
    raise FileNotFoundError(
        f"No *.{ext} images matching {condition!r} found in {images_dir}"
    )
imgs = [_read_raster(f) for f in imgs_list]
print(len(imgs), "images found in the downsampled directory.")
img_shapes = [img.shape for img in imgs]

# Any change between consecutive image shapes means the stack is ragged and must be warped.
shape_condition = bool(np.any(np.diff(img_shapes, axis=0) != 0))
if shape_condition:
    print("Images have different shapes, warping them to the same shape.")
    warps_dir = f"{output_dir}/warped/"
    os.makedirs(warps_dir, exist_ok=True)
    mosaic, warps, profiles = make_mosaic(
        imgs_list, return_warps=True, return_profile_only=True
    )
    for src_path, warp in zip(imgs_list, warps):
        warp_path = os.path.join(warps_dir, os.path.basename(src_path))
        if not os.path.exists(warp_path):
            # NOTE(review): every warp is written with profiles[1]; confirm this is the
            # shared warped profile and not a per-image list that should use its own entry.
            with rasterio.open(warp_path, "w", **profiles[1]) as warp_ds:
                # warps come back band-last (rows, cols, bands) here.
                for band_idx in range(3):
                    warp_ds.write(warp[:, :, band_idx], band_idx + 1)
    # Later cells glob images_dir, so point them at the equal-shaped warps.
    images_dir = warps_dir
    warps = [np.moveaxis(warp, -1, 0) for warp in warps]

to_concat = warps if shape_condition else imgs
# Stack along a new trailing axis: flip_img output per image, then time last.
img_data = np.stack([flip_img(img) for img in to_concat], axis=3).astype("float32")

# img_data = np.where(np.isnan(img_data), 0, img_data)  # replace NaNs with 0s

In [None]:
# Geometric median of the image stack with hdstats (4 threads, convergence tolerance 1e-4).
# NOTE(review): img_data is assumed to be (rows, cols, bands, time) float32 as built by
# the stacking cell; the writer cell indexes the result as gmed_pcm[:, :, band].
gmed_pcm = nangeomedian_pcm(img_data, num_threads=4, eps=1e-4)

In [None]:
# Geometric median via the geom_median package, for comparison with hdstats and odc.
# NOTE(review): geom_median treats each list element as one point, so this is
# presumably a single weighted combination of whole images rather than a per-pixel
# median — confirm this is the intended comparison.
# The writer cell reads the result as gmed_gm.median[band, row, col].
gmed_gm = gm(to_concat, maxiter=1000)

In [None]:
def _write_rgb_geotiff(path, bands_first, profile):
    """Write a (bands, rows, cols) float array to `path` as uint8 bands, replacing any old file.

    Values are rounded to the nearest integer and clipped to [0, 255] (NaN -> 0) rather
    than truncated by a bare astype, which biases medians down and maps NaN arbitrarily.
    Only the first profile["count"] bands are written.
    """
    if os.path.exists(path):
        os.remove(path)
    data = np.clip(
        np.nan_to_num(np.rint(np.asarray(bands_first, dtype="float64")), nan=0.0),
        0,
        255,
    ).astype("uint8")
    with rasterio.open(path, "w", **profile) as dst:
        for band_idx in range(profile["count"]):
            dst.write(data[band_idx], band_idx + 1)


# Use the first (sorted) input image as the georeferencing/profile template.
with rasterio.open(sorted(glob.glob(images_dir + f"/*.{ext}"))[0]) as template:
    profile = template.profile

tile_part = f"_{tile_id}" if tile_id else ""

gmed_file_pcm = f"data/inputs/{MISSION}_{AOI}/geometric_median{tile_part}_pcm_{output_suffix}.tif"
# hdstats returns bands last; move them first to match the writer's layout.
_write_rgb_geotiff(gmed_file_pcm, np.moveaxis(gmed_pcm, -1, 0), profile)

gmed_file_gm = f"data/inputs/{MISSION}_{AOI}/geometric_median{tile_part}_gm_{output_suffix}.tif"
_write_rgb_geotiff(gmed_file_gm, gmed_gm.median, profile)

In [None]:
# Position of the YYYYMMDD date among the "_"-separated file-name fields.
time_idx = 2 if MISSION == "SENTINEL-2" else 3


def _acquisition_time(path):
    """Parse the YYYYMMDD date stored in field `time_idx` of the file's base name."""
    return datetime.strptime(os.path.basename(path).split("_")[time_idx], "%Y%m%d")


# One glob, filtered once and sorted chronologically, so `files` and `times` are
# guaranteed to line up and the stacked time axis is monotonic.
files = sorted(
    (f for f in glob.glob(images_dir + f"/*.{ext}") if condition in f),
    key=_acquisition_time,
)
if not files:
    raise FileNotFoundError(
        f"No *.{ext} images matching {condition!r} found in {images_dir}"
    )
times = [_acquisition_time(f) for f in files]
# Lazily (dask-chunked) open each image as a Dataset with one variable per band,
# tagged with its acquisition time as a new "time" dimension.
dsl = [
    rxr.open_rasterio(f, band_as_variable=True, chunks={})
    .assign_coords(time=t)
    .expand_dims("time", axis=2)
    for f, t in zip(files, times)
]
print(len(dsl), "datasets found in the target directory.")
dsl[0]

In [None]:
# resampled_dsl = [
#     resample_xarray_dataset(
#         ds.transpose("time", "y", "x"), scale_factor=resolution_ratio
#     )
#     for ds in dsl
# ]

In [None]:
# Stack the per-date datasets along "time" with dims ordered (time, y, x).
ds = xr.concat(dsl, dim="time")
ds = ds.transpose("time", "y", "x")
# .chunk({"x": dsl[0].to_array().shape[2], "y": dsl[0].to_array().shape[1], "time": 1})
ds = ds.drop_attrs()

# Record the source scene's EPSG code as an integer "spatial_ref" variable.
# NOTE(review): drop_attrs() also strips CRS attributes from the spatial_ref
# coordinate, and this int carries none; confirm downstream rio/odc calls still
# resolve the CRS (odc.geo.xr.assign_crs is imported and may be the safer route).
ds["spatial_ref"] = meta["crs"].to_epsg()

# band_1, band_2, band_3 -> the first three band names.
band_renames = {}
for position, band_name in enumerate(bands[:3], start=1):
    band_renames[f"band_{position}"] = band_name
ds = ds.rename_vars(band_renames)

ds = ds[["y", "x", "spatial_ref", "time"] + measurements[:3]]
ds

In [None]:
# Geometric median of the time-stacked dataset `ds` with odc-algo, for comparison
# with the hdstats and geom_median results (named s2_gm regardless of MISSION).
s2_gm = geomedian_with_mads(
    ds,
    reshape_strategy="yxbt",  # 'yxbt' if data is larger than RAM
    compute_mads=False,  # True if you want triple MADs
)
s2_gm

In [None]:
# Export the odc-algo median's first three measurements as a GeoTIFF, replacing any
# previous export of the same name.
gmed_file_odc = (
    f"data/inputs/{MISSION}_{AOI}/geometric_median"
    + (f"_{tile_id}" if tile_id else "")
    + f"_odc_{output_suffix}.tif"
)
if os.path.exists(gmed_file_odc):
    os.remove(gmed_file_odc)
odc_rgb = s2_gm[measurements[:3]]
odc_rgb.rio.to_raster(gmed_file_odc)

In [None]:
fig, axes = plt.subplots(4, 2, figsize=(10, 20))


def _read_raster_bands(path):
    """Read every band of the raster at `path`, closing the file afterwards."""
    with rasterio.open(path) as src:
        return src.read()


# Rows 0-1: the three geometric-median results (axes[1, 1] stays blank).
median_panels = [
    (axes[0, 0], gmed_file_pcm, "hdstats"),
    (axes[0, 1], gmed_file_gm, "geom_median"),
    (axes[1, 0], gmed_file_odc, "odc"),
]
for ax, path, label in median_panels:
    ax.imshow(flip_img(_read_raster_bands(path)))
    ax.set_title(f"Geometric Median of {len(imgs)} images ({label})")

# Rows 2-3: up to four input images, filled row by row; fewer images leave blanks
# instead of raising IndexError.
for ax, (img_idx, img) in zip(axes[2:].flat, enumerate(imgs[:4])):
    ax.imshow(flip_img(img))
    ax.set_title(f"Image {img_idx}")

for ax in axes.flat:
    ax.axis("off")
plt.suptitle(
    f"Geometric Median of {len(imgs)} {MISSION} images from {AOI} AOI, {'ID: ' + tile_id if tile_id else ''}, ({output_suffix.replace('_', ' ')})",
    fontsize=14,
    y=1.01,
)
plt.tight_layout()
# The suptitle sits above the axes area (y > 1), so a tight bounding box is needed
# for it to appear in the saved PNG.
plt.savefig(
    f"{output_dir}/geometric_median_{MISSION}_{AOI}{'_' + tile_id if tile_id else ''}_{output_suffix}.png",
    dpi=300,
    bbox_inches="tight",
)