In [None]:
from geom_median.numpy import compute_geometric_median as gm
import numpy as np
from utils import *
import dask
import dask.distributed
import xarray as xr
import rioxarray as rxr
from hdstats import nangeomedian_pcm
import geopandas as gpd
from odc.algo import (
    enum_to_bool,
    geomedian_with_mads,
    erase_bad,
    mask_cleanup,
    keep_good_only,
)
from odc.geo import BoundingBox
from odc.geo.xr import assign_crs
from odc.io.cgroups import get_cpu_quota
from odc.stac import configure_rio, stac_load

# AWS session (default boto3 settings) wrapped for rasterio; passed to the
# scene streaming / download helpers in later cells.
boto_session = boto3.Session()
aws_session = rasterio.session.AWSSession(boto_session)

# Local Dask client with four single-threaded workers.
client = dask.distributed.Client(n_workers=4, threads_per_worker=1)

# Cloud-read defaults with unsigned S3 requests; the client is handed over as well.
configure_rio(cloud_defaults=True, aws={"aws_unsigned": True}, client=client)
display(client)

In [None]:
# Bounds of the WA KML outline, passed through resize_bbox with a factor of 0.1.
wa_outline = kml_to_poly("data/inputs_old/WA.kml")
wa_bbox = resize_bbox(BoundingBox(*wa_outline.bounds), 0.1)

# Candidate search extents: index 0 is WA, index 1 is a small 0.1 x 0.1 degree box.
# NOTE(review): the original comment called the second box "Arctic"; read as
# [west, south, east, north] its latitude is about -72.5, i.e. Antarctica - confirm.
small_bbox = [67.45, -72.55, 67.55, -72.45]
bbox_list = [wa_bbox, small_bbox]

In [None]:
# --- Run configuration ---------------------------------------------------
MISSION = "SENTINEL-2"
AOI = "AMERY_ROCK"

# Bands to load: the colour measurements plus the band used for masking.
measurements = ["red", "green", "blue", "visual"]
masking_band = ["scl"]
mask_filters = [("opening", 10), ("dilation", 1)]

# crs = "EPSG:3031"
# Target pixel size; a later cell divides the native pixel size by this value.
resolution = 100

# Tile filter. `condition` is used as a substring test against scene names and
# file paths; an empty tile_id yields an empty string, which matches everything.
tile_id = "41CPV"
condition = tile_id
output_suffix = "manual_loader"

In [None]:
# Every band to fetch: the colour measurements followed by the masking band.
bands = [*measurements, *masking_band]

# Directory layout for this mission / AOI combination.
output_dir = f"data/inputs/{MISSION}_{AOI}"
process_dir = f"{output_dir}/true_colour"
process_ds_dir = f"{output_dir}/true_colour_ds"
ds_dir = f"{output_dir}/downsampled"

# Cached STAC item collection; when it already exists the search is skipped.
items_file = f"{output_dir}/items.json"
items_exist = os.path.exists(items_file)

In [None]:
# Query the STAC server only when no cached item collection is on disk.
if not items_exist:
    server_url = "https://earth-search.aws.element84.com/v1"
    search_bbox = bbox_list[1]
    query = get_search_query(
        search_bbox,
        collections=["SENTINEL-2"],
        start_date="2016-01-01T00:00:00",
        end_date="2021-01-01T00:00:00",
        is_landsat=False,
    )
    # Override the collection id and drop the "page" key before sending the query.
    query["collections"] = ["sentinel-2-l2a"]
    del query["page"]
    display(query)
    items = query_stac_server(
        query,
        server_url,
        pystac=True,
        return_pystac_items=True,
    )
    print(f"Found {len(items)} items.")

In [None]:
# Build (first run) or reload (later runs) the scenes to process.
#   * First run: read the scene CSV, keep the STAC items whose id matches a
#     scene name containing `condition`, and cache them in `items_file`.
#   * Later runs: rebuild `scene_list` from the cached item collection.
# Both paths leave `items`, `scene_list`, `bands_suffixes` and `times` defined.
if not items_exist:
    s2_scenes = pd.read_csv(f"data/inputs/{MISSION}_{AOI}_scenes.csv")
    scene_list = s2_scenes.to_dict("records")
    print(len(scene_list), "scenes found in the CSV file.")
    scene_names = [
        scene["scene_name"] for scene in scene_list if condition in scene["scene_name"]
    ]

    # Footprints of the selected scenes, in scene_names order, for a visual check.
    gdf = gpd.GeoDataFrame.from_features(items, "epsg:4326")
    item_names = list(gdf["earthsearch:s3_path"].apply(lambda x: x.split("/")[-1]))
    idx = [item_names.index(i) for i in scene_names]
    gdf = gdf.iloc[idx].reset_index(drop=True)
    print(len(gdf), "items found in the GeoDataFrame.")
    # Only the last expression of a cell is rendered automatically, so the map
    # has to be displayed explicitly or it is silently discarded.
    display(gdf.explore())

    # Keep only the items that correspond to a selected scene, then cache them.
    # A set makes each membership test constant-time.
    wanted_ids = set(scene_names)
    new_items = [item for item in items.items if item.id in wanted_ids]
    items.items = new_items
    items.save_object(items_file)
else:
    items = pystac.ItemCollection.from_file(items_file)
    scene_list = []
    features = items.to_dict()["features"]
    for feature in features:
        # One record per item: asset href for each requested band (also stored
        # under "<band>_alternate") plus the item id as the scene name.
        s = {}
        for b in bands:
            if b in feature["assets"]:
                s[b] = feature["assets"][b]["href"]
                s[b + "_alternate"] = s[b]
        s["scene_name"] = feature["id"]
        scene_list.append(s)
    print(f"Loaded {len(items.items)} items from {items_file}.")

# Shared by both branches (previously duplicated in each).
if not scene_list:
    raise ValueError("No scenes available: the scene list is empty.")
bands_suffixes = get_band_suffixes(scene_list[0], bands)
# Acquisition date is the third "_"-separated field of the scene name (YYYYMMDD).
times = [
    datetime.strptime(x["scene_name"].split("_")[2], "%Y%m%d") for x in scene_list
]
items

In [None]:
# Directory the compositing cells below read their per-scene *.tif files from.
images_dir = process_ds_dir

In [None]:
# Fetch only the metadata of the first item's red band to learn the native grid.
red_href = items[0].assets["red"].href
_, meta = stream_scene_from_aws(red_href, aws_session, metadata_only=True)

# Native pixel size divided by the target `resolution`, as [x, y]. The `e`
# coefficient is negated so the y ratio is positive when `e` is negative
# (as in north-up rasters).
native_transform = meta["profile"]["transform"]
x_ratio = native_transform.a / resolution
y_ratio = -native_transform.e / resolution
resolution_ratio = [x_ratio, y_ratio]
print(f"Resolution ratio: {resolution_ratio}")

In [None]:
# Run the download/processing helper over every scene in `scene_list` for the
# requested bands, handing it output_dir, process_dir and process_ds_dir plus the
# resolution ratio as the scale factor.
# NOTE(review): the helper's outputs are not visible here; later cells read
# *.tif files from `process_ds_dir` (via `images_dir`) - confirm it writes them there.
# The trailing `;` suppresses the cell's output.
download_and_process_series(
    scene_list,
    bands,
    bands_suffixes,
    output_dir,
    process_dir,
    process_ds_dir,
    aws_session=aws_session,
    keep_original_band_scenes=True,
    scale_factor=resolution_ratio,
);

In [None]:
# originals = glob.glob(f"{output_dir}/Originals/**/TCI.tif", recursive=True)

In [None]:
# os.makedirs(ds_dir, exist_ok=True)
# for original in originals:
#     ds_path = os.path.join(ds_dir, f"{original.split("/")[4]}.tif")
#     if not os.path.exists(ds_path):
#         ds = downsample_dataset(original, resolution_ratio, ds_path)
#         print(f"Downsampled {original} to {ds_path}")

In [None]:
# Load every *.tif in `images_dir` whose path contains `condition`.
# Sorting makes the order reproducible (glob returns files in arbitrary order),
# and the context manager closes each dataset instead of leaking its handle.
image_paths = sorted(
    path for path in glob.glob(images_dir + "/*.tif") if condition in path
)
if not image_paths:
    raise FileNotFoundError(
        f"No .tif files matching {condition!r} found in {images_dir}."
    )

imgs = []
for path in image_paths:
    with rasterio.open(path) as src:
        imgs.append(src.read())
print(len(imgs), "images found in the downsampled directory.")

# Stack the flipped images along a new trailing axis (axis 3) as float32.
img_data = np.stack([flip_img(img) for img in imgs], axis=3).astype("float32")

In [None]:
# Geometric median via hdstats over the 4-D float32 stack built above (images
# stacked along axis 3). The result is indexed as [:, :, band] when written out.
gmed_pcm = nangeomedian_pcm(img_data, num_threads=4, eps=1e-4)

In [None]:
# Geometric median via geom_median over the list of arrays exactly as read by
# rasterio (not flipped). The result's `.median` is indexed as [band, :, :] below.
gmed_gm = gm(imgs, maxiter=1000)

In [None]:
def _write_uint8_bands(path, band_arrays, raster_profile):
    """Write 2-D arrays as consecutive uint8 bands of a new raster at `path`.

    Any existing file at `path` is removed first. `band_arrays` is an iterable of
    2-D arrays written to bands 1, 2, ... in order; `raster_profile` supplies the
    rasterio creation options.
    """
    if os.path.exists(path):
        os.remove(path)
    with rasterio.open(path, "w", **raster_profile) as dst:
        for band_index, band in enumerate(band_arrays, start=1):
            dst.write(band.astype("uint8"), band_index)


# Take the raster profile from a file that passes the same `condition` filter as
# the input images; the unfiltered first glob hit could belong to another tile.
# The context manager also closes the dataset after the profile has been copied.
template_paths = sorted(
    path for path in glob.glob(images_dir + "/*.tif") if condition in path
)
if not template_paths:
    raise FileNotFoundError(
        f"No .tif files matching {condition!r} found in {images_dir}."
    )
with rasterio.open(template_paths[0]) as src:
    profile = src.profile

tile_suffix = "_" + tile_id if tile_id else ""
band_count = profile["count"]

# hdstats result: bands on the last axis.
gmed_file_pcm = f"{output_dir}/geometric_median{tile_suffix}_pcm_{output_suffix}.tif"
_write_uint8_bands(
    gmed_file_pcm, [gmed_pcm[:, :, i] for i in range(band_count)], profile
)

# geom_median result: bands on the first axis.
gmed_file_gm = f"{output_dir}/geometric_median{tile_suffix}_gm_{output_suffix}.tif"
_write_uint8_bands(
    gmed_file_gm, [gmed_gm.median[i, :, :] for i in range(band_count)], profile
)

In [None]:
# Open each matching *.tif lazily (one variable per band) and tag it with a time
# coordinate so the scenes can be concatenated along "time".
# NOTE(review): `times` is indexed by the file's position in the *unfiltered*,
# unsorted glob listing, so the date attached to a file only matches its scene if
# that listing happens to line up with `scene_list` - confirm the intended mapping.
dsl = []
for file_idx, tif_path in enumerate(glob.glob(images_dir + "/*.tif")):
    if condition not in tif_path:
        continue
    scene_ds = rxr.open_rasterio(tif_path, band_as_variable=True, chunks={})
    scene_ds = scene_ds.assign_coords(time=times[file_idx])
    dsl.append(scene_ds.expand_dims("time", axis=2))
print(len(dsl), "datasets found in the target directory.")
dsl[0]

In [None]:
# resampled_dsl = [
#     resample_xarray_dataset(
#         ds.transpose("time", "y", "x"), scale_factor=resolution_ratio
#     )
#     for ds in dsl
# ]

In [None]:
# Inputs for the final Dataset: EPSG code of the source grid and the mapping from
# rioxarray's band_1..band_3 variable names to the first three band names.
epsg_code = meta["crs"].to_epsg()
band_renames = {
    f"band_{band_no}": band_name
    for band_no, band_name in enumerate(bands[:3], start=1)
}

# Concatenate the per-scene datasets along time, order dims as (time, y, x) and
# strip attributes. (Explicit re-chunking to one full frame per time step was
# tried here previously and is disabled.)
stacked = xr.concat(dsl, dim="time")
ds = stacked.transpose("time", "y", "x").drop_attrs()
ds["spatial_ref"] = epsg_code
ds = ds.rename_vars(band_renames)

# Keep only the coordinates and the three colour bands.
kept_names = ["y", "x", "spatial_ref", "time"] + measurements[:3]
ds = ds[kept_names]
ds

In [None]:
# Third implementation: odc.algo's geomedian_with_mads applied to the stacked
# Dataset built above. The result is displayed and written to disk in the next cell.
s2_gm = geomedian_with_mads(
    ds,
    reshape_strategy="yxbt",  # "yxbt" is the choice for data larger than RAM
    compute_mads=False,  # set to True to also compute the triple MADs
)
s2_gm

In [None]:
# Output path for the odc.algo composite; the tile id is included only when set.
tile_suffix = "_" + tile_id if tile_id else ""
gmed_file_odc = (
    f"data/inputs/{MISSION}_{AOI}/geometric_median{tile_suffix}_odc_{output_suffix}.tif"
)

# Remove any previous output, then write the three colour bands.
if os.path.exists(gmed_file_odc):
    os.remove(gmed_file_odc)
rgb_composite = s2_gm[measurements[:3]]
rgb_composite.rio.to_raster(gmed_file_odc)

In [None]:
# Comparison figure: up to four input scenes on the top two rows, the three
# geometric-median composites below them; the last panel stays empty.
fig, axes = plt.subplots(4, 2, figsize=(10, 20))
panels = axes.ravel()  # row-major: panels[0:4] are rows 0-1, panels[4:7] the composites

# zip() stops at the shorter sequence, so having fewer than four images leaves
# panels blank instead of raising IndexError.
for k, (ax, img) in enumerate(zip(panels[:4], imgs)):
    ax.imshow(flip_img(img))
    ax.set_title(f"Image {k}")

composite_files = [
    (gmed_file_pcm, "hdstats"),
    (gmed_file_gm, "geom_median"),
    (gmed_file_odc, "odc"),
]
for ax, (composite_path, library) in zip(panels[4:], composite_files):
    # Read inside a context manager so the dataset handle is closed afterwards.
    with rasterio.open(composite_path) as src:
        ax.imshow(flip_img(src.read()))
    ax.set_title(f"Geometric Median of {len(imgs)} images ({library})")

for ax in axes.flat:
    ax.axis("off")
plt.suptitle(
    f"Geometric Median of {len(imgs)} {MISSION} images from {AOI} AOI, {'ID: ' + tile_id if tile_id else ''}, ({output_suffix.replace('_', ' ')})",
    fontsize=14,
    y=1.01,
)
plt.tight_layout()
plt.savefig(
    f"{output_dir}/geometric_median_{MISSION}_{AOI}{'_' + tile_id if tile_id else ''}_{output_suffix}.png",
    dpi=300,
)