In [None]:
from geom_median.numpy import compute_geometric_median as gm
import numpy as np
from utils import *
import dask
import dask.distributed
import xarray as xr
import rioxarray as rxr
from hdstats import nangeomedian_pcm
import geopandas as gpd
from odc.algo import (
    enum_to_bool,
    geomedian_with_mads,
    erase_bad,
    mask_cleanup,
    keep_good_only,
)
from odc.geo import BoundingBox
from odc.geo.xr import assign_crs
from odc.io.cgroups import get_cpu_quota
from odc.stac import configure_rio, stac_load
import logging
from functools import reduce
from matplotlib.gridspec import GridSpec

# Local Dask cluster shared by the odc-stac loads and geomedian computations below.
n_workers = 30  # worker processes
threads_per_worker = 1  # threads per worker process
cluster_kwargs = dict(
    n_workers=n_workers,
    threads_per_worker=threads_per_worker,
    silence_logs=logging.ERROR,
)
client = dask.distributed.Client(**cluster_kwargs)

display(client)

In [None]:
# Mission selector; later cells branch on this string
# ("LANDSAT-8", "LANDSAT-4-5", "SENTINEL-1", "SENTINEL-2").
MISSION = "LANDSAT-8"  # "LANDSAT-8" or "SENTINEL-2"

# Landsat assets are read with requester-pays credentials; everything else unsigned.
requester_pays = MISSION in ("LANDSAT-8", "LANDSAT-9", "LANDSAT-4-5")
if requester_pays:
    aws_session = rasterio.session.AWSSession(boto3.Session(), requester_pays=True)
    rio_aws_opts = {"requester_pays": True}
else:
    aws_session = rasterio.session.AWSSession(boto3.Session())
    rio_aws_opts = {"aws_unsigned": True}
configure_rio(cloud_defaults=True, aws=rio_aws_opts, client=client)

In [None]:
def _padded_kml_bbox(kml_path):
    """Return the KML polygon's bounds as a BoundingBox passed through resize_bbox(..., 0.1)."""
    return resize_bbox(BoundingBox(*kml_to_poly(kml_path).bounds), 0.1)


# Areas of interest. The hard-coded lists follow the BoundingBox argument order
# (left, bottom, right, top); presumably lon/lat degrees — TODO confirm.
wa_bbox = _padded_kml_bbox("data/inputs_old/WA.kml")
tas_bbox = _padded_kml_bbox("data/inputs_old/TAS.kml")
mount_box = _padded_kml_bbox("data/inputs_old/geo1.kml")
amery_rock = [67.45, -72.55, 67.55, -72.45]
amery_shelf = [73.47, -69.66, 74.71, -69.22]
hillary_coast = BoundingBox(*kml_to_poly("data/inputs_old/hillary.kml").bounds)
flincher_shelf = [-46, -84, -36, -74]
inland_ice = [126.0, -90, 136, -80]

# (name, bbox) pairs in the order that aoi_index in the configuration cell refers to.
_named_aois = [
    ("WA", wa_bbox),
    ("AMERY_ROCK", amery_rock),
    ("MOUNT", mount_box),
    ("TAS", tas_bbox),
    ("AMERY_SHELF", amery_shelf),
    ("HILLARY_COAST", hillary_coast),
    ("FLINCHER_SHELF", flincher_shelf),
    ("INLAND_ICE", inland_ice),
]
bbox_list = [box for _, box in _named_aois]
# Keyed by str(bbox) so the selected bbox can be mapped back to its AOI name.
aoi_dict = {str(box): name for name, box in _named_aois}

# The "Flincher" (presumably Filchner) and Amery shelves are high-velocity ice.

In [None]:
# --- Run configuration ---
scale_factor = 10.0  # factor passed to resize_bbox for the "_LARGE" AOI variants
max_cloud_cover = 5
min_scenes_per_id = 10

aoi_index = 6
bbox = bbox_list[aoi_index]
AOI = aoi_dict[str(bbox)]
masking_band = ["scl"]

# Sentinel-1 uses VH/VV (TAS) or HH bands; every other mission uses red/green/blue.
if MISSION != "SENTINEL-1":
    bands = ["red", "green", "blue"]
elif AOI == "TAS":
    bands = ["VH", "VV"]
else:
    bands = ["HH"]

mask_filters = [("opening", 10), ("dilation", 1)]
# crs = "EPSG:3031"
resolution = 100 if MISSION in ["SENTINEL-1", "SENTINEL-2"] else 200
output_suffix = "manual_loader"
file_name_suffix = "odc_stac"  # odc
aoi_suffix = "_LARGE"

if aoi_suffix:
    AOI = AOI + aoi_suffix
use_all_items = False

if aoi_suffix == "_LARGE":
    if AOI in ("AMERY_ROCK_LARGE", "AMERY_SHELF_LARGE"):
        # These AOIs are plain lists, so wrap them in a BoundingBox first.
        bbox = resize_bbox(BoundingBox(*bbox), scale_factor)
    elif AOI not in ("INLAND_ICE_LARGE", "FLINCHER_SHELF_LARGE"):
        bbox = resize_bbox(bbox, scale_factor)
    # INLAND_ICE / FLINCHER_SHELF keep their original extent.
print(AOI, bbox, bands)

In [None]:
# Only for Sentinel-2, Landsat-8 does not have SCL band
extra_bands = masking_band if MISSION == "SENTINEL-2" else None

# Per-mission/AOI working directories and the cached STAC items file.
output_dir = f"data/inputs/{MISSION}_{AOI}"
process_dir, process_ds_dir, ds_dir = (
    f"{output_dir}/true_colour",
    f"{output_dir}/true_colour_ds",
    f"{output_dir}/downsampled",
)
items_file = f"{output_dir}/items.json"
items_exist = os.path.exists(items_file)
print(items_file)
items_exist

In [None]:
# Optionally combine every items_<id>.json in output_dir into a single items.json.
if use_all_items and not items_exist:
    merged_target = f"{output_dir}/items.json"
    tile_item_files = [
        path
        for path in glob.glob(f"{output_dir}/items*.json")
        if merged_target not in path
    ]
    if tile_item_files:
        collections = [
            pystac.ItemCollection.from_file(path) for path in tile_item_files
        ]
        items = reduce(lambda acc, nxt: acc + nxt, collections)
        items_exist = True
        items.save_object(merged_target)

In [None]:
# Query the mission's STAC server when no cached items file exists, thin the results
# with find_scenes_dict, and record the kept scenes and per-ID scene counts as CSVs.
if not items_exist:
    if MISSION == "SENTINEL-2":
        query = get_search_query(
            bbox,
            collections=["sentinel-2-l2a"],
            start_date="2016-01-01T00:00:00",
            end_date="2021-01-01T00:00:00",
            pystac_query=True,
        )
        use_pystac = True
        server_url = "https://earth-search.aws.element84.com/v1"
    elif MISSION == "SENTINEL-1":
        query = get_search_query(
            bbox,
            collections=["ga_s1_iw_vv_vh_c0", "ga_s1_iw_hh_c0"],
            start_date="2016-01-01T00:00:00",
            end_date="2021-01-01T00:00:00",
            pystac_query=True,
        )
        use_pystac = True
        server_url = "https://explorer.dev.dea.ga.gov.au/stac"
    elif MISSION == "LANDSAT-8":
        query = get_search_query(
            bbox,
            start_date="2013-01-01T00:00:00",
            end_date="2017-01-01T00:00:00",
            platform=["LANDSAT_8"],
            collection_category=None,
            collections=None,
            # Use the configured threshold rather than a duplicated literal.
            cloud_cover=max_cloud_cover,
        )
        use_pystac = False
        server_url = "https://landsatlook.usgs.gov/stac-server/search"
    elif MISSION == "LANDSAT-4-5":
        query = get_search_query(
            bbox,
            start_date="1985-01-01T00:00:00",
            end_date="2010-12-30T00:00:00",
            platform=["LANDSAT_4", "LANDSAT_5"],
            collection_category=None,
            collections=None,
            cloud_cover=max_cloud_cover,
        )
        use_pystac = False
        server_url = "https://landsatlook.usgs.gov/stac-server/search"
    else:
        # Without this, a mission accepted elsewhere (e.g. "LANDSAT-9") would fail
        # below with a NameError on `query`.
        raise ValueError(
            f"No STAC search is configured for MISSION={MISSION!r}; expected one of "
            "'SENTINEL-2', 'SENTINEL-1', 'LANDSAT-8', 'LANDSAT-4-5'."
        )

    display(query)
    items = query_stac_server(
        query,
        server_url,
        use_pystac=use_pystac,
        max_cloud_cover=max_cloud_cover if MISSION != "SENTINEL-1" else None,
    )
    print(f"Found {len(items)} items.")

    if len(items) > 0:
        # Landsat 4/5 disables find_scenes_dict's one_per_month and
        # remove_duplicate_times flags; all other missions enable them.
        is_landsat_4_5 = MISSION == "LANDSAT-4-5"
        scene_dict, scene_list = find_scenes_dict(
            items,
            one_per_month=not is_landsat_4_5,
            acceptance_list=bands + ["thumbnail"],
            remove_duplicate_times=not is_landsat_4_5,
            duplicate_idx=1,
            min_scenes_per_id=min_scenes_per_id,
        )
        pd.DataFrame(scene_list).to_csv(
            f"data/inputs/{MISSION}_{AOI}_scenes.csv", index=False
        )
        path_rows = list(scene_dict.keys())
        print("Found IDs: ", path_rows)

        path_row_list = [
            (i, path_row, len(scene_dict[path_row]))
            for i, path_row in enumerate(path_rows)
        ]
        pd.DataFrame(path_row_list, columns=["index", "path_row", "count"]).to_csv(
            f"data/inputs/{MISSION}_{AOI}_scene_counts.csv", index=False
        )
        print("Found scene counts: ", path_row_list)
        print("Found scenes counts after filtering: ", len(scene_list))

        items = pystac.ItemCollection(items)

In [None]:
# Reuse the per-ID scene counts written when the items were first queried.
if items_exist:
    counts_csv = f"data/inputs/{MISSION}_{AOI}_scene_counts.csv"
    path_row_list = pd.read_csv(counts_csv)
    print("Found scene counts: \n", path_row_list)

In [None]:
# Decide which items file this run uses: items_<tile_id>.json when a tile is selected,
# otherwise (and always when use_all_items is on) the full items.json.
tile_id = ""
full_items_file = f"{output_dir}/items.json"
if use_all_items:
    tile_id = ""
items_file = f"{output_dir}/items{'_' + tile_id if tile_id else ''}.json"
# Check existence only after items_file is final, so the flag describes the file
# the next cell actually loads.
items_exist = os.path.exists(items_file)
# Substring filter applied to scene names/IDs; "" matches every scene.
condition = tile_id

In [None]:
# Build (or reload) the STAC ItemCollection for this run plus `scene_list`.
# - No items file yet: restrict `items` to the scenes recorded in the scenes CSV
#   (optionally filtered by `condition`), then save the result as the items file.
# - Items file exists: load it and rebuild `scene_list` from the items' assets.
if not items_exist:
    # Tile / all-items runs start from the full items.json; otherwise `items` is
    # whatever the STAC-query cell left in memory.
    if use_all_items or tile_id != "":
        items = pystac.ItemCollection.from_file(full_items_file)
    scenes = pd.read_csv(f"data/inputs/{MISSION}_{AOI}_scenes.csv")
    scene_list = scenes.to_dict("records")
    bands_suffixes = get_band_suffixes(scene_list[0], bands)
    print(len(scene_list), "scenes found in the CSV file.")
    # Scene names containing the tile filter ("" keeps all).
    scene_names = [
        scene["scene_name"] for scene in scene_list if condition in scene["scene_name"]
    ]
    scene_ids = None
    # Landsat scenes are additionally matched on their scene_id column.
    if MISSION not in ["SENTINEL-1", "SENTINEL-2"]:
        scene_ids = [
            scene["scene_id"] for scene in scene_list if condition in scene["scene_id"]
        ]

    # Footprints of the items; in this cell gdf is only used for the count printed below.
    gdf = gpd.GeoDataFrame.from_features(items, "epsg:4326")
    # Property identifying an item per catalogue; only its last "/" segment is compared.
    if MISSION == "SENTINEL-2":
        id_col = "earthsearch:s3_path"
    elif MISSION == "SENTINEL-1":
        id_col = "title"
    elif MISSION == "LANDSAT-8":
        id_col = "landsat:scene_id"
    else:
        id_col = "landsat:scene_id"
    item_names = list(gdf[id_col].apply(lambda x: x.split("/")[-1]))
    checklist = scene_names if MISSION in ["SENTINEL-1", "SENTINEL-2"] else scene_ids
    # NOTE(review): list.index is O(n) per lookup (O(n^2) overall) and raises
    # ValueError if a CSV scene has no matching item; a dict lookup would be faster
    # and could report the missing scenes.
    idx = [item_names.index(i) for i in checklist]
    gdf = gdf.iloc[idx].reset_index(drop=True)
    print(len(gdf), "items found in the GeoDataFrame.")

    # gdf.explore()
    # print(len(scene_list), "scenes found in the CSV file.")
    # Keep the items (and scene_list entries) selected by the mission-specific key:
    # item id (Sentinel-2), "title" property (Sentinel-1), or both scene_id and id (Landsat).
    if MISSION == "SENTINEL-2":
        idx = [i for i in range(len(items.items)) if items.items[i].id in scene_names]
        scene_list = [
            scene for scene in scene_list if scene["scene_name"] in scene_names
        ]
    elif MISSION == "SENTINEL-1":
        idx = [
            i
            for i in range(len(items.items))
            if items.items[i].properties["title"] in scene_names
        ]
        scene_list = [
            scene for scene in scene_list if scene["scene_name"] in scene_names
        ]
    else:
        idx = [
            i
            for i in range(len(items.items))
            if (
                items.items[i].properties["landsat:scene_id"] in scene_ids
                and items.items[i].id in scene_names
            )
        ]
        scene_list = [
            scene
            for scene in scene_list
            if (scene["scene_name"] in scene_names and scene["scene_id"] in scene_ids)
        ]
    new_items = [items.items[i] for i in idx]
    # Replace the collection's contents in place, then persist it as the items file.
    items.items = new_items
    items.save_object(f"{output_dir}/items{'_' + tile_id if tile_id else ''}.json")
else:
    items = pystac.ItemCollection.from_file(items_file)
    # Rebuild scene_list from the saved items: "<band>" holds the asset href and
    # "<band>_alternate" the asset's S3 alternate href (Landsat) or the same href (Sentinel).
    scene_list = []
    features = items.to_dict()["features"]
    for feature in features:
        s = {}
        for b in bands:
            if b in feature["assets"]:
                s[b] = feature["assets"][b]["href"]
                s[b + "_alternate"] = (
                    s[b]
                    if MISSION in ["SENTINEL-1", "SENTINEL-2"]
                    else feature["assets"][b]["alternate"]["s3"]["href"]
                )
        # Sentinel-1 scenes are named by their "title" property; others by item id.
        s["scene_name"] = (
            feature["properties"]["title"] if MISSION == "SENTINEL-1" else feature["id"]
        )
        scene_list.append(s)
    bands_suffixes = get_band_suffixes(scene_list[0], bands)
    print(f"Loaded {len(items.items)} items from {items_file}.")
# items

In [None]:
# Working raster directory, initially <output_dir>/true_colour (process_dir). Later
# cells repoint it (Sentinel-1 processing, warping, plotting).
images_dir = process_dir

In [None]:
# Expected raster file names: "<parent folder of the first band's href>.<ext>".
ext = "tif" if MISSION in ["SENTINEL-1", "SENTINEL-2"] else "TIF"
scene_files = []
for scene_record in scene_list:
    scene_folder = os.path.basename(os.path.dirname(scene_record[bands[0]]))
    scene_files.append(f"{scene_folder}.{ext}")

In [None]:
# Number of items handed to stac_load; None loads every item. Kept at 2 to match the
# previous hard-coded items[0:2] slice (quick experiments) — raise for real runs.
stac_load_limit = 2

# USGS Landsat hrefs point at the landsatlook HTTPS host; rewrite them to the
# requester-pays S3 bucket configured earlier. This applies to every Landsat mission
# (they share the same STAC server), not only LANDSAT-8.
if MISSION.startswith("LANDSAT"):
    patch_url = lambda x: x.replace(
        "https://landsatlook.usgs.gov/data", "s3://usgs-landsat"
    )
else:
    patch_url = None

# Integer-valued optical products load as uint16 with 0 as nodata; others as float32/NaN.
integer_product = MISSION == "SENTINEL-2" or MISSION.startswith("LANDSAT")
ds_stac = stac_load(
    items[:stac_load_limit],
    bands=bands,
    chunks={"x": 500, "y": 500},
    resolution=resolution,
    bbox=bbox,
    patch_url=patch_url,
    groupby="id",
    dtype="uint16" if integer_product else "float32",
    nodata=0 if integer_product else np.nan,
)
ds_stac

In [None]:
# Write each loaded time slice to <images_dir>/<scene folder>.<ext>, then mask
# non-positive pixels to NaN for the geomedian.
ext = "tif" if MISSION in ["SENTINEL-1", "SENTINEL-2"] else "TIF"
os.makedirs(images_dir, exist_ok=True)
for time_pos in range(len(ds_stac.time)):
    time_slice = ds_stac.isel(time=time_pos)
    # NOTE(review): assumes time slice N corresponds to items[N]; stac_load may order
    # slices by timestamp rather than input order — confirm file names match the data.
    scene_href = items[time_pos].assets[bands[0]].href
    out_name = f"{os.path.basename(os.path.dirname(scene_href))}.{ext}"
    time_slice.rio.to_raster(os.path.join(images_dir, out_name))
# NOTE(review): rebinding ds_stac makes this cell non-idempotent — re-running it
# writes the already-masked (float) data.
ds_stac = ds_stac.where(ds_stac > 0)

In [None]:
# Sentinel-1 only: run process_existing_outputs (contrast stretch, gamma 0.5) into the
# true_colour_ds subdirectory and use that as the working image directory.
if MISSION == "SENTINEL-1":
    s1_inputs = glob.glob(f"{images_dir}/**")
    s1_options = dict(
        scale_factor=1.0,
        preserve_depth=True,  # True if you want to preserve the depth of the original dataset
        min_max_scaling=False,  # True if you want to apply min-max scaling
        stretch_contrast=True,
        gamma=0.5,
        three_channel=True,
        remove_nans=True,
        num_cpu=-1,
        write_pairs=False,
        subdir="true_colour_ds",
        file_name_suffix="",
    )
    process_existing_outputs(s1_inputs, output_dir, **s1_options)
    images_dir = process_ds_dir

In [None]:
# Rasters in the working directory whose path contains the tile filter.
originals = [path for path in glob.glob(f"{images_dir}/**") if condition in path]
print(len(originals), "original scenes found.")

In [None]:
# Output GeoTIFF paths for the geomedian of the odc-stac cube and of the file-loaded dataset.
_tile_part = "_" + tile_id if tile_id else ""
_full_part = "_full" if use_all_items else ""
gmed_file_odc_stac = f"data/inputs/{MISSION}_{AOI}/geometric_median{_tile_part}_odc_stac_{output_suffix}{_full_part}.tif"
gmed_file_odc = f"data/inputs/{MISSION}_{AOI}/geometric_median{_tile_part}_odc_{output_suffix}{_full_part}.tif"

In [None]:
# Geomedian of the odc-stac cube, saved to gmed_file_odc_stac as an 8-bit GeoTIFF.
gmed_odc_stac = geomedian_with_mads(
    ds_stac,
    reshape_strategy="yxbt",  #'yxbt' if data is larger than RAM
    compute_mads=False,  # True if you want triple MADs
)
gmed_odc_stac = gmed_odc_stac.rio.write_crs(f"epsg:{ds_stac.rio.crs.to_epsg()}")

if os.path.exists(gmed_file_odc_stac):
    os.remove(gmed_file_odc_stac)

if MISSION == "SENTINEL-1":
    gmed_bands = bands[:3]
    gmed_odc_stac_img = gmed_odc_stac[gmed_bands].to_array().to_numpy()
    gmed_odc_stac_img = np.nan_to_num(gmed_odc_stac_img, nan=0)
    gmed_odc_stac_img = apply_gamma(gmed_odc_stac_img, stretch_hist=True).astype(
        "uint8"
    )

    # Borrow the raster profile (georeferencing) of one of this run's scene rasters.
    template_path = next(
        (
            f
            for f in glob.glob(images_dir + f"/*.{ext}")
            if os.path.basename(f) in scene_files
        ),
        None,
    )
    if template_path is None:
        raise FileNotFoundError(
            f"No raster listed in scene_files was found in {images_dir}"
        )
    # Context manager closes the template handle once the profile is copied.
    with rasterio.open(template_path) as template_src:
        profile = template_src.profile
    # One output band per band exported via to_array (bands[:3]).
    profile["count"] = len(gmed_bands)
    # profile["transform"] = gmed_odc_stac.rio.transform()

    with rasterio.open(gmed_file_odc_stac, "w", **profile) as dst:
        for i in range(profile["count"]):
            dst.write(gmed_odc_stac_img[i, :, :], i + 1)
else:
    # Divide by 255, clip to [0, 255], cast to uint8.
    # NOTE(review): uint16 -> uint8 rescaling would normally divide by 257; confirm intent.
    (gmed_odc_stac[bands[:3]] / 255).clip(0, 255).astype("uint8").rio.to_raster(
        gmed_file_odc_stac
    )

In [None]:
# Compare band count / height / width and geotransforms across this run's scene
# rasters to decide whether they need warping onto a common grid.
ext = "tif" if MISSION in ["SENTINEL-1", "SENTINEL-2"] else "TIF"
scene_paths = [
    f
    for f in glob.glob(images_dir + f"/*.{ext}")
    if os.path.basename(f) in scene_files
]
img_shapes = []
transforms = []
for path in scene_paths:
    # Open each raster once; the handle is closed when the with-block exits.
    with rasterio.open(path) as src:
        img_shapes.append((src.count, src.height, src.width))
        transforms.append(src.transform)
print(len(img_shapes), "images found in the downsampled directory.")

if len(img_shapes) > 1:
    shape_diffs = np.abs(np.diff(img_shapes, axis=0))
    # Affine transforms are 9-element tuples, so they diff element-wise.
    transform_diffs = np.abs(np.diff(transforms, axis=0))
    shape_condition = np.any(shape_diffs != 0)
    origin_condition = np.any(transform_diffs != 0)
else:
    # Zero or one raster: there is nothing to compare.
    shape_condition = origin_condition = False
shape_condition, origin_condition

In [None]:
# Warp the scene rasters onto a common grid with make_mosaic, write them to
# <output_dir>/warped/, and use that directory as the working image directory.
universal_masking = True
cluster_masks = True  # True if (MISSION == "SENTINEL-2") or (use_all_items) else False
force_warping = True

print(force_warping, shape_condition or origin_condition or force_warping)

if force_warping or shape_condition or origin_condition or MISSION == "SENTINEL-2":
    print("Images have different shapes, warping them to the same shape.")
    warps_dir = f"{output_dir}/warped/"
    if force_warping:
        shutil.rmtree(warps_dir, ignore_errors=True)
    os.makedirs(warps_dir, exist_ok=True)
    imgs_list = [
        f
        for f in glob.glob(images_dir + f"/*.{ext}")
        if os.path.basename(f) in scene_files
    ]
    mosaic, warps, profiles = make_mosaic(
        imgs_list,
        return_warps=True,
        return_profile_only=True,
        output_type="uint16",
        universal_masking=universal_masking,
        cluster_masks=cluster_masks,
        nodata=0,
        no_affine=True,
    )
    if universal_masking:
        # Presumably (warps, masks) in this mode — keep only the warped arrays.
        # masks = warps[1]
        warps = warps[0]
    for scene_idx, warp in enumerate(warps):
        warp_path = os.path.join(warps_dir, os.path.basename(imgs_list[scene_idx]))
        if not os.path.exists(warp_path):
            with rasterio.open(warp_path, "w", **profiles[1]) as warp_ds:
                # Distinct loop variable so the outer scene index is not shadowed.
                for band in range(3):
                    warp_ds.write(warp[:, :, band], band + 1)
    images_dir = warps_dir
    # warps = [np.moveaxis(warp, -1, 0) for warp in warps]

    # Preview only here: `mosaic` exists only when this branch ran.
    plt.imshow(mosaic / mosaic.max())

mosaic = None
warps = None

In [None]:
# Build an xarray Dataset from the (possibly warped) scene rasters, with acquisition
# dates parsed from the "_"-separated file names.
print(images_dir)
ext = "tif" if MISSION in ["SENTINEL-1", "SENTINEL-2"] else "TIF"
# Index of the date token (starts with YYYYMMDD) in the split file name.
if MISSION == "SENTINEL-1":
    time_idx = 4
elif MISSION == "SENTINEL-2":
    time_idx = 2
else:
    time_idx = 3

# Glob once so `files` and `times` line up element-for-element.
files = [
    f for f in glob.glob(images_dir + f"/*.{ext}") if os.path.basename(f) in scene_files
]
if not files:
    raise FileNotFoundError(f"No rasters listed in scene_files found in {images_dir}")
times = [
    datetime.strptime(os.path.basename(f).split("_")[time_idx][0:8], "%Y%m%d")
    for f in files
]

# Read the EPSG code from the first raster instead of relying on a `meta` variable
# defined outside this cell.
with rasterio.open(files[0]) as first_src:
    crs = first_src.crs.to_epsg()
ds = create_dataset_from_files(
    files,
    times,
    crs,
    bands,
    chunks={"x": 500, "y": 500},
)
ds

In [None]:
# Geomedian of the file-loaded dataset, tagged with the inputs' EPSG code.
geomedian_kwargs = {
    "reshape_strategy": "yxbt",  # 'yxbt' if data is larger than RAM
    "compute_mads": False,  # True if you want triple MADs
}
gmed_odc = geomedian_with_mads(ds, **geomedian_kwargs).rio.write_crs(f"epsg:{crs}")
gmed_odc

In [None]:
# Replace gmed_file_odc with the first (up to) three geomedian bands as uint8;
# non-Sentinel-1 data is divided by 255 and clipped first.
if os.path.exists(gmed_file_odc):
    os.remove(gmed_file_odc)
gmed_odc_rgb = gmed_odc[bands[:3]]
if MISSION != "SENTINEL-1":
    gmed_odc_rgb = (gmed_odc_rgb / 255).clip(0, 255)
gmed_odc_rgb.astype("uint8").rio.to_raster(gmed_file_odc)

In [None]:
# Selects the geomedian the next cell plots: "odc" -> gmed_file_odc,
# any other value -> gmed_file_odc_stac.
file_name_suffix = "odc"  # odc

In [None]:
# Prepare up to four sample scenes plus the selected geomedian as display arrays.
enhance = True
# gm_outputs = [gmed_file_pcm, gmed_file_gm, gmed_file_odc, gmed_file_odc_stac]

gm_outputs = [gmed_file_odc] if file_name_suffix == "odc" else [gmed_file_odc_stac]

images_dir = process_ds_dir
imgs_files = [
    f for f in glob.glob(images_dir + f"/*.{ext}") if os.path.basename(f) in scene_files
][:4]


def _read_raster(path):
    """Read every band of the raster at `path`, closing the file handle afterwards."""
    with rasterio.open(path) as src:
        return src.read()


imgs = [_read_raster(f) for f in imgs_files]

img_samples = imgs
to_plot = []
if MISSION == "SENTINEL-1":
    for img in img_samples:
        img[2, :, :] = (
            0  # ((img[1, :, :] + img[0, :, :]) / 2).astype("uint8")  # Create a 3-channel image
        )
        img = flip_img(img).astype("uint8")
        to_plot.append(img)
else:
    for img in img_samples:
        img = np.clip(flip_img(img) / 255, 0, 255).astype("uint8")
        to_plot.append(img)

gm_imgs = [flip_img(_read_raster(f)).astype("uint8") for f in gm_outputs]

for img in gm_imgs:
    # Copy into a 3-channel buffer (flip_img output is indexed as (H, W, C) here);
    # missing channels stay zero and channels beyond the third are ignored.
    out_img = np.zeros((img.shape[0], img.shape[1], 3), dtype="uint8")
    for i in range(min(img.shape[2], 3)):
        out_img[:, :, i] = img[:, :, i]
    to_plot.append(out_img)

if enhance:
    # to_plot = [apply_gamma(img, stretch_hist=True) for img in to_plot]
    # Scale each image to [0, 1]; an all-zero image is kept as-is rather than
    # turned into NaNs by 0/0.
    to_plot = [img / img.max() if img.max() > 0 else img for img in to_plot]

In [None]:
# Geomedian spanning the top half of the figure, four sample scenes in a 2x2 grid below.
# (GridSpec is imported in the first cell.)
fig = plt.figure(figsize=(10, 20), dpi=300, constrained_layout=True)
gs = GridSpec(4, 2, figure=fig)
ax0 = fig.add_subplot(gs[0:2, 0:2])
sample_axes = [fig.add_subplot(gs[row, col]) for row in (2, 3) for col in (0, 1)]
ax1, ax2, ax3, ax4 = sample_axes
for sample_idx, ax in enumerate(sample_axes):
    ax.imshow(to_plot[sample_idx])
    ax.set_title(f"Image {sample_idx}")
ax0.imshow(to_plot[4])
ax0.set_title(
    f"Geometric Median of {len(scene_files)} {MISSION} images from {AOI} AOI, {'ID: ' + tile_id if tile_id else ''}, ({output_suffix.replace('_', ' ')}) ({file_name_suffix.replace('_', ' ')})"
)
for ax in [ax0, ax1, ax2, ax3, ax4]:
    ax.axis("off")
# No plt.tight_layout(): it conflicts with constrained_layout=True, which already
# handles spacing.
plt.savefig(
    f"{output_dir}/geometric_median_{MISSION}_{AOI}{'_' + tile_id if tile_id else ''}_{file_name_suffix}_{output_suffix}.png",
    dpi=300,
)

In [None]:
# Collect per-tile geomedian GeoTIFFs (file names carrying a tile/ID token) for mosaicking.
file_name_suffix = "odc"
# Follow output_suffix so the glob tracks the names the geomedian cells write.
gm_tifs = glob.glob(output_dir + f"/*_{file_name_suffix}_{output_suffix}.tif")
pattern = r"_T\d+" if MISSION == "SENTINEL-1" else r"_\d+"
# Match against the file name only; the tile id lives there, and directory names
# must not cause false matches.
gm_tifs = [f for f in gm_tifs if re.search(pattern, os.path.basename(f))]
gm_tifs

In [None]:
# Warp the per-tile geomedians onto a common grid in warped_gms/ (recreated from
# scratch) and point gm_tifs at the warped copies.
warps_dir = f"{output_dir}/warped_gms/"
shutil.rmtree(warps_dir, ignore_errors=True)
os.makedirs(warps_dir, exist_ok=True)
mosaic, warps, profiles = make_mosaic(
    gm_tifs,
    return_warps=True,
    return_profile_only=True,
    # output_type="uint16",
)
profile = profiles[1]
profile["nodata"] = 0
for tif_idx, warp in enumerate(warps):
    warp_path = os.path.join(warps_dir, os.path.basename(gm_tifs[tif_idx]))
    if not os.path.exists(warp_path):
        with rasterio.open(warp_path, "w", **profile) as warp_ds:
            # Distinct loop variable so the outer file index is not shadowed.
            for band in range(3):
                warp_ds.write(warp[:, :, band], band + 1)
# warps = [np.moveaxis(warp, -1, 0) for warp in warps]

plt.imshow(mosaic / mosaic.max())
mosaic = None
warps = None

gm_tifs = glob.glob(warps_dir + f"/*_{file_name_suffix}_{output_suffix}.tif")

In [None]:
# Stack the warped per-tile geomedians into one dataset, using the first file's CRS.
with rasterio.open(gm_tifs[0]) as first_gm:
    gm_epsg = first_gm.crs.to_epsg()
ds = create_dataset_from_files(gm_tifs, crs=gm_epsg, bands=bands)
ds

In [None]:
# Geomedian across the per-tile geomedians, tagged with the first tile file's CRS.
gmed_mosaic = geomedian_with_mads(
    ds,
    reshape_strategy="yxbt",  #'yxbt' if data is larger than RAM
    compute_mads=False,  # True if you want triple MADs
)
# Close the handle after reading the EPSG code.
with rasterio.open(gm_tifs[0]) as first_gm:
    mosaic_epsg = first_gm.crs.to_epsg()
gmed_mosaic = gmed_mosaic.rio.write_crs(f"epsg:{mosaic_epsg}")
gmed_mosaic

In [None]:
# Write the mosaic geomedian's first (up to) three bands as uint8, replacing any old file.
gmed_file_mosaic = f"data/inputs/{MISSION}_{AOI}/odc_geometric_median_mosaic_of_{file_name_suffix}_mosaics.tif"
if os.path.exists(gmed_file_mosaic):
    os.remove(gmed_file_mosaic)
mosaic_rgb = gmed_mosaic[bands[:3]].astype("uint8")
mosaic_rgb.rio.to_raster(gmed_file_mosaic)

In [None]:
# Preview the mosaic geomedian and save it as a PNG in output_dir.
plt.figure(figsize=(10, 10), dpi=300)
with rasterio.open(gmed_file_mosaic) as mosaic_src:
    gm_img = flip_img(mosaic_src.read())
if gm_img.shape[2] == 2:
    # Two-band output (e.g. Sentinel-1 VH/VV): repeat the first band as a third
    # channel so imshow receives a 3-channel image.
    gm_img = np.concatenate(
        [
            gm_img,
            gm_img[:, :, 0:1],
        ],
        axis=2,
    )
# Normalise for display; an all-zero image is shown as-is instead of 0/0 NaNs.
gm_peak = gm_img.max()
plt.imshow(gm_img / gm_peak if gm_peak > 0 else gm_img)
plt.axis("off")
plt.savefig(
    f"{output_dir}/odc_geometric_median_mosaic_of_{file_name_suffix}_mosaics_{MISSION}_{AOI}.png",
    dpi=300,
)

In [None]:
# fig, axes = plt.subplots(4, 2, figsize=(10, 20))
# axes[2, 0].imshow((to_plot[0]))
# axes[2, 0].set_title("Image 0")
# axes[2, 1].imshow(to_plot[1])
# axes[2, 1].set_title("Image 1")
# axes[3, 0].imshow(to_plot[2])
# axes[3, 0].set_title("Image 2")
# axes[3, 1].imshow(to_plot[3])
# axes[3, 1].set_title("Image 3")
# axes[0, 0].imshow(to_plot[4])
# axes[0, 0].set_title(f"Geometric Median of {len(imgs)} images (hdstats)")
# axes[0, 1].imshow((to_plot[5]))
# axes[0, 1].set_title(f"Geometric Median of {len(imgs)} images (geom median)")
# axes[1, 0].imshow(to_plot[6])
# axes[1, 0].set_title(f"Geometric Median of {len(imgs)} images (odc)")
# axes[1, 1].imshow(to_plot[7])
# axes[1, 1].set_title(f"Geometric Median of {len(imgs)} images (odc stac)")
# for ax in axes.flat:
#     ax.axis("off")
# plt.suptitle(
#     f"Geometric Median of {len(imgs)} {MISSION} images from {AOI} AOI, {'ID: ' + tile_id if tile_id else ''}, ({output_suffix.replace('_', ' ')})",
#     fontsize=14,
#     y=1.01,
# )
# plt.tight_layout()
# plt.savefig(
#     f"{output_dir}/geometric_median_{MISSION}_{AOI}{'_' + tile_id if tile_id else ''}_{output_suffix}.png",
#     dpi=300,
# )

In [None]:
# img = flip_img(rasterio.open("data/inputs/LANDSAT-8_MOUNT/geometric_median_odc_stac_manual_loader.tif").read())
# fig, ax = plt.subplots(figsize=(10, 5), dpi=300)
# ax.imshow(img)
# ax.axis("off")
# plt.suptitle("LANDSAT-8 Geometric Median Mosaic for 58 images over Antarctica", fontsize=14, y=1.01)
# plt.tight_layout()
# plt.savefig("AN_gm.png", dpi=300, bbox_inches='tight')