In [1]:
import gdown
import requests
import zipfile
from functools import partial
from pathlib import Path
import rasterio as rio
from matplotlib import pyplot as plt
from omnicloudmask import (
    predict_from_load_func,
    predict_from_array,
    load_ls8,
    load_multiband,
    load_s2,
)

In [2]:
# Folder (relative to the notebook's working directory) that holds every
# downloaded archive and extracted scene used below.
test_data_dir = Path("test data")
# exist_ok=True makes re-running this cell a no-op once the folder exists.
test_data_dir.mkdir(parents=False, exist_ok=True)

In [3]:
# Google Drive file IDs of the zipped sample scenes. Each key is reused as the
# archive file name (with ".zip" appended) and as the extraction folder name.
test_data_liks = dict(
    [
        (
            "LC81960302014022LGN00",
            "1ewmbD2YzxUS2IibMW5GTbcQyZIoz0TNf",
        ),
        (
            "S2B_MSIL1C_20180302T150259_N0206_R125_T22WES_20180302T183800.SAFE",
            "1pGu_RdboqYcK4Q6_kjpnynCSzmNdUgcW",
        ),
        (
            "S2A_MSIL2A_20170725T142751_N9999_R053_T19GBQ_20240410T040247.SAFE",
            "1ZEfXnNpWi75OV6fVhNvzbe6MhxsvXSI3",
        ),
    ]
)

In [4]:
def download_file_from_google_drive(file_id: str, destination: Path) -> None:
    """Download a Google Drive file to ``destination`` using gdown.

    Args:
        file_id: Google Drive file identifier; inserted into the ``uc?id=`` URL.
        destination: Local path the downloaded file is written to.

    Raises:
        RuntimeError: If ``gdown.download`` returns ``None`` instead of an
            output path, i.e. the download did not produce a file.
    """
    url = f"https://drive.google.com/uc?id={file_id}"
    # gdown expects a string output path; quiet=False keeps its progress output.
    downloaded = gdown.download(url, str(destination), quiet=False)
    # Fail here rather than later with a confusing missing/invalid-zip error.
    if downloaded is None:
        raise RuntimeError(
            f"Download of Google Drive file {file_id!r} to {destination} failed"
        )

In [5]:
# Fetch each archive only if it is not already on disk, then unpack it into a
# folder named after the scene. Extraction runs on every execution of the cell.
for scene_name, drive_id in test_data_liks.items():
    archive_path = test_data_dir / f"{scene_name}.zip"
    extract_dir = test_data_dir / scene_name

    if not archive_path.exists():
        download_file_from_google_drive(drive_id, archive_path)

    with zipfile.ZipFile(archive_path, "r") as archive:
        archive.extractall(extract_dir)

In [None]:
# Multispectral GeoTIFF from the Maxar Open Data S3 bucket, cached locally as
# "maxar.tif" so the download only happens once.
maxar_url = "https://maxar-opendata.s3.us-west-2.amazonaws.com/events/Emilia-Romagna-Italy-flooding-may23/ard/32/120000303231/2023-05-23/1050010033C95B00-ms.tif"
maxar_path = test_data_dir / "maxar.tif"
if not maxar_path.exists():
    # Stream into a ".part" file and rename only after the transfer finishes,
    # so an interrupted download is never mistaken for a complete file by the
    # exists() check above on the next run.
    maxar_part_path = maxar_path.with_name(maxar_path.name + ".part")
    # timeout stops the cell hanging forever on a stalled connection.
    with requests.get(maxar_url, stream=True, timeout=60) as response:
        # Without this an HTTP error body would be saved as the GeoTIFF.
        response.raise_for_status()
        with open(maxar_part_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
    maxar_part_path.replace(maxar_path)

In [None]:
# Extraction folders of the three sample scenes (named after their archives).
s2_l2a_Path = test_data_dir.joinpath(
    "S2A_MSIL2A_20170725T142751_N9999_R053_T19GBQ_20240410T040247.SAFE"
)
s2_l1c_Path = test_data_dir.joinpath(
    "S2B_MSIL1C_20180302T150259_N0206_R125_T22WES_20180302T183800.SAFE"
)
ls_path = test_data_dir.joinpath("LC81960302014022LGN00")

# Sanity check shown as the cell output: (L1C exists, L2A exists, Landsat exists).
tuple(scene.exists() for scene in (s2_l1c_Path, s2_l2a_Path, ls_path))

In [None]:
# Scene lists passed as `scene_paths` to the prediction calls below,
# grouped by the loader each one is used with.
ls_items = [ls_path]
s2_items = [
    s2_l2a_Path,
    s2_l1c_Path,
]

In [None]:
# Loader preset for the Maxar GeoTIFF: `load_multiband` with the resampling
# resolution and band selection fixed.
# NOTE(review): presumably bands 1, 2, 4 are the red, green and NIR bands and
# resample_res is in the raster's CRS units — confirm against omnicloudmask docs.
load_multiband_maxar = partial(
    load_multiband,
    band_order=[1, 2, 4],
    resample_res=10,
)

In [None]:
# Predict both Sentinel-2 scenes with the stock `load_s2` loader; the call
# returns the paths that the following cell opens with rasterio.
pred_paths = predict_from_load_func(
    load_func=load_s2,
    scene_paths=s2_items,
    batch_size=2,
    inference_dtype="bf16",
)

In [None]:
# Show band 1 of each prediction raster side by side.
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
for ax, path in zip(axs, pred_paths):
    with rio.open(path) as src:
        # Pin the colour limits to 0-3, as the other prediction plots in this
        # notebook do, so both panels share one colour scale instead of each
        # being autoscaled to whatever values happen to be present.
        ax.imshow(src.read(1), vmin=0, vmax=3)
plt.show()

In [None]:
# Same Sentinel-2 prediction as before, but through a loader preset that
# passes resolution=21.0 to `load_s2`.
load_s2_21m = partial(
    load_s2,
    resolution=21.0,
)
pred_paths = predict_from_load_func(
    load_func=load_s2_21m,
    scene_paths=s2_items,
    batch_size=2,
    inference_dtype="bf16",
)

In [None]:
# Show band 1 of each prediction raster from the resolution=21.0 run.
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
for ax, path in zip(axs, pred_paths):
    with rio.open(path) as src:
        # Pin the colour limits to 0-3, as the other prediction plots in this
        # notebook do, so both panels share one colour scale instead of each
        # being autoscaled to whatever values happen to be present.
        ax.imshow(src.read(1), vmin=0, vmax=3)
plt.show()

In [None]:
# Predict the Landsat 8 scene with the stock `load_ls8` loader.
pred_paths = predict_from_load_func(
    load_func=load_ls8, scene_paths=ls_items, inference_dtype="bf16", batch_size=2
)

# Read band 1 inside a context manager so the rasterio dataset is closed
# instead of leaking an open file handle for the life of the kernel.
with rio.open(pred_paths[0]) as src:
    pred_array = src.read(1)
plt.imshow(pred_array, vmin=0, vmax=3)

In [None]:
# Predict the Maxar scene through the `load_multiband_maxar` preset.
pred_paths = predict_from_load_func(
    load_func=load_multiband_maxar,
    scene_paths=[maxar_path],
    inference_dtype="bf16",
    batch_size=2,
)

# Read band 1 inside a context manager so the rasterio dataset is closed
# instead of leaking an open file handle for the life of the kernel.
with rio.open(pred_paths[0]) as src:
    pred_array = src.read(1)
plt.imshow(pred_array, vmin=0, vmax=3)

In [None]:
# Run the Maxar prediction once per inference dtype. `pred_paths` is
# overwritten on every iteration, so the plot below shows the result of the
# last dtype in the list only.
for dtype in ["float32", "float16", "bfloat16"]:
    pred_paths = predict_from_load_func(
        load_func=load_multiband_maxar,
        scene_paths=[maxar_path],
        inference_dtype=dtype,
        batch_size=2,
    )

# Read band 1 inside a context manager so the rasterio dataset is closed
# instead of leaking an open file handle for the life of the kernel.
with rio.open(pred_paths[0]) as src:
    pred_array = src.read(1)
plt.imshow(pred_array, vmin=0, vmax=3)

In [None]:
# Load the Maxar scene with the preset loader and keep the first element of
# what it returns as the input array.
maxar_loaded = load_multiband_maxar(maxar_path)
RGNIR_array = maxar_loaded[0]

# Predict directly from the in-memory array (no files written by this cell),
# then display the first slice of the result.
pred_array = predict_from_array(RGNIR_array)
plt.imshow(pred_array[0], vmin=0, vmax=3)

In [None]:
# Same Maxar prediction, explicitly requesting the CPU as inference device.
pred_paths = predict_from_load_func(
    load_func=load_multiband_maxar,
    scene_paths=[maxar_path],
    inference_device="cpu",
)

# Read band 1 inside a context manager so the rasterio dataset is closed
# instead of leaking an open file handle for the life of the kernel.
with rio.open(pred_paths[0]) as src:
    pred_array = src.read(1)
plt.imshow(pred_array, vmin=0, vmax=3)