In [None]:
import numpy as np
from utils import *
import dask
import dask.distributed
import geopandas as gpd
from odc.algo import geomedian_with_mads
from odc.geo import BoundingBox
from odc.stac import configure_rio, stac_load
from functools import reduce
from matplotlib.gridspec import GridSpec

# Dask cluster sizing (swap in os.cpu_count() - 2 to use more worker processes).
dask_n_workers = 1  # os.cpu_count() - 2
gmed_n_workers = 1  # os.cpu_count() - 2
threads_per_worker = os.cpu_count()

# Fresh scratch directory, used below as the workers' local_directory.
shutil.rmtree("temp_dask_dir", ignore_errors=True)
os.makedirs("temp_dask_dir", exist_ok=True)

# Generous communication and nanny (worker management) timeouts, set in one call.
dask.config.set(
    {
        "distributed.comm.timeouts.connect": "60s",
        "distributed.comm.timeouts.tcp": "300s",
        "distributed.nanny.timeouts.startup": "300s",
        "distributed.nanny.timeouts.connect": "300s",
        "distributed.nanny.timeouts.terminate": "300s",
    }
)

client = dask.distributed.Client(
    local_directory="temp_dask_dir",
    n_workers=dask_n_workers,
    threads_per_worker=threads_per_worker,
)
display(client)

Specifying Mission and configuring AWS

In [None]:
# Requester-pays S3 access: a rasterio session for direct reads, plus the same
# setting applied through odc-stac's rasterio configuration (also passed the Dask client).
boto_session = boto3.Session()
aws_session = rasterio.session.AWSSession(boto_session, requester_pays=True)
configure_rio(
    client=client,
    aws={"requester_pays": True},
    cloud_defaults=True,
)

Runtime parameters

In [None]:
# Mission and area-of-interest labels (also used to build output paths).
MISSION = "LANDSAT-8"
AOI = "AMERY_ROCK_DEMO"

# Scene selection thresholds passed to the STAC query / find_scenes_dict.
max_cloud_cover = 5
min_scenes_per_id = 10

# AOI bounding box, presumably (min_lon, min_lat, max_lon, max_lat) — TODO confirm.
amery_rock = [67.45, -72.55, 67.55, -72.45]

# Bands to load, target pixel size (CRS units), and the linear scaling applied to raw values.
bands = ["red", "green", "blue"]
resolution = 200
band_scale = np.float32(2e-5)
band_offset = np.float32(-0.1)

In [None]:
# All inputs/outputs for this mission + AOI live under one directory.
output_dir = f"data/inputs/{MISSION}_{AOI}"
items_file = f"{output_dir}/items.json"
process_dir = f"{output_dir}/true_colour"

# A cached items.json lets the query/filter cells below be skipped.
items_exist = os.path.exists(items_file)
print(items_file)
items_exist

Querying and processing data

In [None]:
# Only build the STAC search when no cached items.json exists.
if not items_exist:
    server_url = "https://landsatlook.usgs.gov/stac-server/search"
    use_pystac = False
    query = get_search_query(
        amery_rock,
        platform=["LANDSAT_8"],
        start_date="2013-01-01T00:00:00",
        end_date="2017-01-01T00:00:00",
        cloud_cover=max_cloud_cover,
        collections=None,
        collection_category=None,
    )
    display(query)

In [None]:
# Run the search against the STAC server (skipped when items are cached).
if not items_exist:
    items = query_stac_server(
        query,
        server_url,
        max_cloud_cover=max_cloud_cover,
        use_pystac=use_pystac,
    )
    n_found = len(items)
    print(f"Found {n_found} items.")

In [None]:
# Filter the raw search results into per-path/row scenes and persist them as CSVs.
if not items_exist and len(items) > 0:
    scenes_csv = f"data/inputs/{MISSION}_{AOI}_scenes.csv"
    counts_csv = f"data/inputs/{MISSION}_{AOI}_scene_counts.csv"

    scene_dict, scene_list = find_scenes_dict(
        items.copy(),
        id_filter="L1GT",
        one_per_month=True,
        remove_duplicate_times=True,
        duplicate_idx=1,
        min_scenes_per_id=min_scenes_per_id,
        acceptance_list=bands + ["thumbnail"],
    )
    pd.DataFrame(scene_list).to_csv(scenes_csv, index=False)

    path_rows = list(scene_dict.keys())
    print("Found IDs: ", path_rows)

    # (position, path/row ID, number of scenes kept for it)
    path_row_list = [(pos, pr, len(scene_dict[pr])) for pos, pr in enumerate(path_rows)]
    counts_df = pd.DataFrame(path_row_list, columns=["index", "path_row", "count"])
    counts_df.to_csv(counts_csv, index=False)
    print("Found scene counts: ", path_row_list)
    print("Found scenes counts after filtering: ", len(scene_list))

    items = pystac.ItemCollection(items)

In [None]:
# Reload the per-path/row scene counts written by the filtering step.
scene_counts_csv = f"data/inputs/{MISSION}_{AOI}_scene_counts.csv"
path_row_list = pd.read_csv(scene_counts_csv)
print("Found scene counts: \n", path_row_list)

In [None]:
# Optional tile restriction; "" means use every scene.
tile_id = ""
full_items_file = f"{output_dir}/items.json"
tile_suffix = f"_{tile_id}" if tile_id else ""
items_file = f"{output_dir}/items{tile_suffix}.json"
items_exist = os.path.exists(items_file)
# Substring used to filter scene names/ids; the empty string matches everything.
condition = tile_id

In [None]:
# Build (or reload) the item collection and scene list restricted to the selected scenes.
if not items_exist:
    if tile_id != "":
        # A tile-specific subset is built from the full, unfiltered collection.
        items = pystac.ItemCollection.from_file(full_items_file)
    scenes = pd.read_csv(f"data/inputs/{MISSION}_{AOI}_scenes.csv")
    scene_list = scenes.to_dict("records")
    bands_suffixes = get_band_suffixes(scene_list[0], bands)
    print(len(scene_list), "scenes found in the CSV file.")
    scene_names = [
        scene["scene_name"] for scene in scene_list if condition in scene["scene_name"]
    ]
    scene_ids = [
        scene["scene_id"] for scene in scene_list if condition in scene["scene_id"]
    ]

    # Footprints of the selected scenes, ordered like scene_ids.
    gdf = gpd.GeoDataFrame.from_features(items, "epsg:4326")
    id_col = "landsat:scene_id"
    item_names = list(gdf[id_col].apply(lambda x: x.split("/")[-1]))
    # First row of each scene id (same result as item_names.index, without the O(n^2) scan).
    first_row = {}
    for row, name in enumerate(item_names):
        first_row.setdefault(name, row)
    missing_ids = [i for i in scene_ids if i not in first_row]
    if missing_ids:
        raise ValueError(f"Scene ids missing from the item collection: {missing_ids}")
    gdf = gdf.iloc[[first_row[i] for i in scene_ids]].reset_index(drop=True)
    print(len(gdf), "items found in the GeoDataFrame.")

    # Keep only items/scenes whose id AND name passed the filter, preserving original order.
    scene_id_set = set(scene_ids)
    scene_name_set = set(scene_names)
    items.items = [
        item
        for item in items.items
        if item.properties["landsat:scene_id"] in scene_id_set
        and item.id in scene_name_set
    ]
    scene_list = [
        scene
        for scene in scene_list
        if scene["scene_name"] in scene_name_set and scene["scene_id"] in scene_id_set
    ]
    # items_file already carries the optional "_<tile_id>" suffix.
    items.save_object(items_file)
else:
    items = pystac.ItemCollection.from_file(items_file)
    # Rebuild scene records (band hrefs, their S3 alternates, scene name) from the cached items.
    scene_list = []
    features = items.to_dict()["features"]
    for feature in features:
        s = {}
        for b in bands:
            if b in feature["assets"]:
                s[b] = feature["assets"][b]["href"]
                s[b + "_alternate"] = feature["assets"][b]["alternate"]["s3"]["href"]
        s["scene_name"] = feature["id"]
        scene_list.append(s)
    bands_suffixes = get_band_suffixes(scene_list[0], bands)
    print(f"Loaded {len(items.items)} items from {items_file}.")

In [None]:
# Optional: uncomment to interactively inspect the selected scene footprints.
# NOTE: gdf is only created when the items file did not already exist.
# gdf.explore()

In [None]:
# Start from the processed true-colour directory; later cells repoint this (e.g. to the warped outputs).
images_dir = process_dir

Downloading metadata

In [None]:
# Read only the metadata of the first item's first band (via its S3 alternate href).
first_band_href = items[0].assets[bands[0]].to_dict()["alternate"]["s3"]["href"]
_, meta = stream_scene(first_band_href, aws_session, metadata_only=True)

# Ratio of native pixel size to the target resolution, for x and y.
src_transform = meta["profile"]["transform"]
resolution_ratio = [src_transform.a / resolution, -src_transform.e / resolution]
print(f"Resolution ratio: {resolution_ratio}")

Downloading original files

In [None]:
def scene_name_map(name):
    """Return ``name`` with every "_SR" substring removed."""
    return name.replace("_SR", "")


# Download the original band files only (no processing), downscaled by resolution_ratio.
download_and_process_series(
    scene_list,
    bands,
    bands_suffixes,
    output_dir,
    process_dir,
    "temp",
    aws_session=aws_session,
    scene_name_map=scene_name_map,
    keep_original_band_scenes=True,
    download_only=True,
    extra_bands=None,
    stream_out_scale_factor=resolution_ratio,
);

In [None]:
# Basenames of the local files for each selected scene.
scene_files = list(map(lambda scene: os.path.basename(scene["local_path"]), scene_list))

Processing original files and making true colour composites for manual loading

In [None]:
ext = "TIF"
orig_dir = f"{output_dir}/Originals"


def _find_band_file(candidates, suffix, ext):
    """Return the first path in ``candidates`` ending with ``f"{suffix}.{ext}"``.

    Raises:
        FileNotFoundError: if no candidate matches the band suffix.
    """
    for path in candidates:
        if path.endswith(f"{suffix}.{ext}"):
            return path
    raise FileNotFoundError(f"No '*{suffix}.{ext}' file found among {candidates}")


# One sub-directory per scene. Sorting makes the processing order reproducible.
# Plain files sitting directly in Originals/ are skipped (the non-recursive "**" also matches them).
scene_dirs = sorted(d for d in glob.glob(f"{orig_dir}/**") if os.path.isdir(d))
# For each scene, its band files ordered like bands_suffixes.
dir_list = [
    [
        _find_band_file(sorted(glob.glob(f"{scene_dir}/**")), suffix, ext)
        for suffix in bands_suffixes
    ]
    for scene_dir in scene_dirs
]
process_existing_outputs(
    dir_list,
    output_dir,
    scale_factor=1.0,  # resolution_ratio,
    preserve_depth=True,  # True if you want to preserve the depth of the original dataset
    min_max_scaling=False,  # True if you want to apply min-max scaling
    stretch_contrast=False,
    gamma=1.0,
    three_channel=False,
    remove_nans=False,
    num_cpu=-1,
    write_pairs=False,
)

In [None]:
ext = "TIF"
# All original band files (any depth), restricted to the optional tile filter.
originals = [
    path
    for path in glob.glob(f"{output_dir}/Originals/**/*.{ext}", recursive=True)
    if condition in path
]
print(len(originals), "original scenes found.")

Output files

In [None]:
# Geometric-median GeoTIFFs: one from odc-stac loading, one from manual loading.
tile_suffix = f"_{tile_id}" if tile_id else ""
gmed_file_odc_stac = f"data/inputs/{MISSION}_{AOI}/geometric_median{tile_suffix}_odc_stac.tif"
gmed_file_odc = f"data/inputs/{MISSION}_{AOI}/geometric_median{tile_suffix}_odc.tif"

Loading data via odc-stac

In [None]:
orig_dir = f"{output_dir}/Originals"
print(f"Originals directory: {orig_dir}")


def patch_url(url):
    """Map an asset URL onto its local copy: ``<orig_dir>/<parent>/<filename>``."""
    return os.path.join(orig_dir, *url.split("/")[-2:])


ds_stac = stac_load(
    items,
    bands=bands,
    groupby="id",
    resolution=resolution,
    crs=meta["crs"],
    chunks={},  # {"x": 500, "y": 500},
    patch_url=patch_url,
    preserve_original_order=True,
    # bbox=amery_rock,
)
# Apply the linear scale/offset, then mask non-positive values (a raw 0 becomes -0.1 and is masked).
ds_stac[bands] = ds_stac[bands] * band_scale + band_offset
ds_stac = ds_stac.where(ds_stac > 0)
ds_stac

Geomedian via odc-stac

In [None]:
gmed_odc_stac = geomedian_with_mads(
    ds_stac,
    num_threads=gmed_n_workers,
    compute_mads=False,  # True if you want triple MADs
    reshape_strategy="yxbt",  # 'yxbt' if data is larger than RAM
)
stac_epsg = ds_stac.rio.crs.to_epsg()
gmed_odc_stac = gmed_odc_stac.rio.write_crs(f"epsg:{stac_epsg}")

# Replace any earlier GeoTIFF with the first three bands of the median.
if os.path.exists(gmed_file_odc_stac):
    os.remove(gmed_file_odc_stac)
gmed_odc_stac[bands[:3]].rio.to_raster(gmed_file_odc_stac)

Geo-referencing input data for manual loading, optionally applying masking filters to remove shadow effects around the edges.

In [None]:
ext = "TIF"
# Open each selected scene exactly once (and close it). Collecting shape and geotransform
# in the same pass keeps the two lists aligned.
img_shapes = []
transforms = []
for path in glob.glob(images_dir + f"/*.{ext}"):
    if os.path.basename(path) not in scene_files:
        continue
    with rasterio.open(path) as src:
        img_shapes.append((src.count, src.height, src.width))
        transforms.append(src.transform)
print(len(img_shapes), "images found in the downsampled directory.")

if img_shapes:
    shape_diffs = np.abs(np.diff(img_shapes, axis=0))
    transform_diffs = np.abs(np.diff(transforms, axis=0))
    # True when any consecutive pair differs in (count, height, width) / in its geotransform.
    shape_condition = np.any(shape_diffs != np.array([0, 0, 0]))
    origin_condition = np.any(transform_diffs != np.zeros(9))
else:
    # No scenes to compare; np.diff of an empty list cannot broadcast against the templates.
    shape_condition = origin_condition = False
shape_condition, origin_condition

In [None]:
universal_masking = True
cluster_masks = True
n_rgb_bands = 3  # bands written per warped file (matches the original hard-coded value)

images_dir = process_dir
# Only claim a shape mismatch when the previous check actually found one.
if shape_condition or origin_condition:
    print("Images have different shapes, warping them to the same shape.")
warps_dir = f"{output_dir}/warped/"

# Start from an empty directory so stale warps from earlier runs are never reused.
shutil.rmtree(warps_dir, ignore_errors=True)
os.makedirs(warps_dir, exist_ok=True)

imgs_list = [
    f for f in glob.glob(images_dir + f"/*.{ext}") if os.path.basename(f) in scene_files
]
mosaic, warps, profiles = make_mosaic(
    imgs_list,
    return_warps=True,
    return_profile_only=True,
    output_type="uint16",
    universal_masking=universal_masking,
    cluster_masks=cluster_masks,
    nodata=0,
)
if universal_masking:
    # With universal masking, warps appears to be (warps, masks) — keep the warps only.
    warps = warps[0]
# Copy so the profile returned by make_mosaic is not mutated in place.
warp_profile = dict(profiles[1])
warp_profile.update(
    blockxsize=warp_profile["width"], blockysize=1, tiled=False, interleave="pixel"
)
for img_path, warp in zip(imgs_list, warps):
    warp_path = os.path.join(warps_dir, os.path.basename(img_path))
    if not os.path.exists(warp_path):
        with rasterio.open(warp_path, "w", **warp_profile) as warp_ds:
            for band_idx in range(n_rgb_bands):
                warp_ds.write(warp[:, :, band_idx], band_idx + 1)
images_dir = warps_dir

plt.imshow(mosaic / mosaic.max())
# Release the large in-memory arrays; later cells read the warped files from disk.
mosaic = None
warps = None

Manual loading of data

In [None]:
images_dir = f"{output_dir}/warped/"
print(images_dir)
ext = "TIF"
# Warped file for every item, in item order, restricted to the selected scenes.
files = [f"{os.path.join(images_dir, item.id)}_PROC.{ext}" for item in items]
files = [f for f in files if os.path.basename(f) in scene_files]
assert all(os.path.exists(f) for f in files), "Not all files exist!"
# Derive dates from the SAME ordered list as `files` so each file gets its own timestamp.
# A separate glob returns files in filesystem order, which can misalign the two lists.
# The 4th "_"-separated field of the name starts with YYYYMMDD.
times = [
    datetime.strptime(os.path.basename(f).split("_")[3][0:8], "%Y%m%d") for f in files
]

crs = meta["crs"].to_epsg()  # EPSG code of the source CRS
ds = create_dataset_from_files(
    files,
    times,
    crs,
    bands,
    band_scale=band_scale,
    band_offset=band_offset,
    chunks={},  # {"x": 500, "y": 500},
    # bbox=amery_rock,
    bbox_crs=4326,
)
ds

Geomedian for manually loaded data

In [None]:
# Geometric median of the manually loaded stack, tagged with the source CRS.
gmed_odc = geomedian_with_mads(
    ds,
    num_threads=gmed_n_workers,
    compute_mads=False,  # True if you want triple MADs
    reshape_strategy="yxbt",  # 'yxbt' if data is larger than RAM
).rio.write_crs(f"epsg:{crs}")
gmed_odc

In [None]:
# Overwrite any earlier GeoTIFF with the first three bands of the median.
rgb_bands = bands[:3]
if os.path.exists(gmed_file_odc):
    os.remove(gmed_file_odc)
gmed_odc[rgb_bands].rio.to_raster(gmed_file_odc)

Plotting results

In [None]:
# Which geometric median to plot: "odc" selects the manually loaded result (gmed_file_odc);
# any other value selects the odc-stac result (gmed_file_odc_stac).
file_name_suffix = "odc"  # odc

In [None]:
enhance = True

gm_outputs = [gmed_file_odc] if file_name_suffix == "odc" else [gmed_file_odc_stac]

images_dir = process_dir


def _read_raster(path):
    """Read every band of the raster at ``path`` and close the file handle."""
    with rasterio.open(path) as src:
        return src.read()


# First four selected scenes, sorted so the same samples are shown on every run.
imgs_files = sorted(
    f for f in glob.glob(images_dir + f"/*.{ext}") if os.path.basename(f) in scene_files
)[:4]

# Sample scenes: divide by 255, clip, and convert to uint8 for display.
to_plot = [
    np.clip(flip_img(_read_raster(f)) / 255, 0, 255).astype("uint8") for f in imgs_files
]

# Geometric median(s): NaNs -> 0, copied into a float32 3-channel canvas.
for gm_path in gm_outputs:
    gm_img = np.nan_to_num(flip_img(_read_raster(gm_path)))
    out_img = np.zeros((gm_img.shape[0], gm_img.shape[1], 3), dtype="float32")
    for ch in range(gm_img.shape[2]):
        out_img[:, :, ch] = gm_img[:, :, ch]
    to_plot.append(out_img)

if enhance:
    # to_plot = [apply_gamma(img, stretch_hist=True) for img in to_plot]
    # Normalise to [0, 1]; all-zero images are left as-is instead of becoming NaN (0/0).
    to_plot = [img / img.max() if img.max() > 0 else img for img in to_plot]

In [None]:
fig = plt.figure(figsize=(10, 20), dpi=300)
gs = GridSpec(4, 2, figure=fig)
# Large top panel for the geometric median; 2x2 grid of sample scenes underneath.
ax0 = fig.add_subplot(gs[0:2, 0:2])
sample_axes = [fig.add_subplot(gs[row, col]) for row in (2, 3) for col in (0, 1)]
for panel, ax in enumerate(sample_axes):
    ax.imshow(to_plot[panel])
    ax.set_title(f"Image {panel}")
ax0.imshow(to_plot[4])
ax0.set_title(
    f"Geometric Median of {len(scene_files)} {MISSION} images from {AOI} AOI, {'ID: ' + tile_id if tile_id else ''}, ({file_name_suffix.replace('_', ' ')})"
)
for ax in [ax0, *sample_axes]:
    ax.axis("off")
plt.tight_layout()
plt.savefig(
    f"{output_dir}/geometric_median_{MISSION}_{AOI}{'_' + tile_id if tile_id else ''}_{file_name_suffix}.png",
    dpi=300,
)

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(10, 5), dpi=300)
# (GeoTIFF path, panel title) for the side-by-side comparison.
comparison = [
    ("temp_dem_imgs/geometric_median_odc_stac_incorrect.tif", "Incorrect reflectance"),
    ("temp_dem_imgs/geometric_median_odc_stac_corrected.tif", "Corrected reflectance"),
]
for ax, (path, title) in zip(axes, comparison):
    # Close each file as soon as its pixels are read.
    with rasterio.open(path) as src:
        img = np.nan_to_num(flip_img(src.read()))
    ax.imshow(img / img.max())
    ax.set_title(title)
    ax.axis("off")
plt.tight_layout()

In [None]:
# Shut down the Dask client/workers before deleting their local_directory.
# Removing it while they are still running can break worker scratch/spill files.
client.close()
shutil.rmtree("temp_dask_dir", ignore_errors=True)

In [None]:
shutil.rmtree("temp_data", ignore_errors=True)
# Synthetic 5x5 RGB reference: a 3x3 block of ones in the centre.
ref = np.zeros((5, 5, 3), dtype=np.float32)
ref[1:4, 1:4, :] = 1.0
# Three empty targets; target 2 holds the same block shifted to the top-left corner.
tgts = [np.zeros_like(ref) for _ in range(3)]
tgts[2][:3, :3, :] = 1.0
plt.imshow(ref / ref.max())

In [None]:
# Display target 2 (the 3x3 block of ones in the top-left corner), scaled by its maximum.
plt.imshow(tgts[2] / tgts[2].max())

In [None]:
tgts[0] = ref.copy()
tgts[1] = ref.copy()
# Minimal GeoTIFF profile for the 5x5, 3-band synthetic rasters.
# NOTE(review): the arrays are float32 while dtype is "uint16" and nodata is 0.
# The 0.0/1.0 values are presumably cast on write, and zeros may be read back as nodata — confirm.
profile = {
    "driver": "GTiff",
    "dtype": "uint16",
    "nodata": 0,
    "width": 5,
    "height": 5,
    "count": 3,
    "crs": "EPSG:3031",
    "transform": rasterio.Affine(10.0, 0.0, 1000.0, 0.0, -10.0, 1000.0),
    "blockxsize": 5,
    "blockysize": 5,
    "tiled": False,
    "interleave": "band",
}


def _write_bands(path, arr, profile):
    """Write ``arr`` (indexed [row, col, channel]) to ``path``, one raster band per channel."""
    with rasterio.open(path, "w", **profile) as dst:
        for band_idx in range(profile["count"]):
            dst.write(arr[:, :, band_idx], band_idx + 1)


os.makedirs("temp_data", exist_ok=True)
ref_fp = "temp_data/ref.tif"
_write_bands(ref_fp, ref, profile)

tgt_fps = [f"temp_data/tgt_{i}.tif" for i in range(len(tgts))]
for tgt_fp, tgt in zip(tgt_fps, tgts):
    _write_bands(tgt_fp, tgt, profile)
paths = [ref_fp] + tgt_fps
ds = create_dataset_from_files(paths, crs=3031, remove_val=None)
gmed = geomedian_with_mads(
    ds,
    reshape_strategy="yxbt",  #'yxbt' if data is larger than RAM
    compute_mads=False,  # True if you want triple MADs
)
gmed = gmed[["band_1", "band_2", "band_3"]]
gmed = gmed.rio.write_crs("epsg:3031")
gmed_img = gmed.to_array().to_numpy()
plt.imshow(flip_img(np.nan_to_num(gmed_img)))

In [None]:
print(gmed_img[0])
# Binarise in place: strictly positive pixels become 1.0 (NaN compares False and is untouched).
np.putmask(gmed_img, gmed_img > 0.0, 1.0)
plt.imshow(flip_img(gmed_img))