In [None]:
import odc.geo.xr  # noqa
from odc.stac import load, configure_s3_access
from pystac_client import Client

from coastlines.combined import http_to_s3_url

from odc.algo import to_f32, mask_cleanup

In [None]:
# STAC endpoint that the search in the following cell is run against
catalog = "https://landsatlook.usgs.gov/stac-server/"
client = Client.open(url=catalog)

# Configure S3 access with cloud defaults and requester-pays reads enabled.
# The return value is not needed here, so it is assigned to `_`.
_ = configure_s3_access(requester_pays=True, cloud_defaults=True)

In [None]:
# Area of interest; reused as the `bbox` of both the search and the load
bbox = (118.59041, 9.89017, 118.92507, 10.14109)

# Find every item of the "landsat-c2l2-sr" collection from 2023 that
# intersects the area of interest, and materialise the results as a list
item_search = client.search(
    collections=["landsat-c2l2-sr"],
    datetime="2023",
    bbox=bbox,
)
items = list(item_search.items())

print(f"Found {len(items)} items")

In [None]:
# Load the three reflectance bands used below plus the QA band, chunked
# one time step at a time; asset URLs are rewritten by `http_to_s3_url`
data = load(
    items,
    bands=["green", "nir08", "swir16", "qa_pixel"],
    bbox=bbox,
    chunks={"time": 1, "x": 2048, "y": 2048},
    patch_url=http_to_s3_url,
)

# Nodata mask, based on the two main bands only. It is evaluated on the
# raw values, i.e. before the scale and offset are applied further down.
nodata_mask = (data["green"] == 0) | (data["swir16"] == 0)

# Build an integer with one bit set per QA flag that should be masked
mask_bitfields = [3, 4]  # cloud, cloud shadow
bitmask = 0
for field in mask_bitfields:
    bitmask = bitmask | (1 << field)

# A pixel is flagged when any of the selected QA bits is set
cloud_mask = (data["qa_pixel"].astype(int) & bitmask) != 0

# Clean up the mask with an opening (5) followed by a dilation (6).
# DE Africa uses 10 and 5, which Alex does not like.
dilated_cloud_mask = mask_cleanup(cloud_mask, [("opening", 5), ("dilation", 6)])

# Convert the reflectance bands to float32 and scale them to 0-1
for var in ["green", "nir08", "swir16"]:
    data[var] = to_f32(data[var], offset=-0.2, scale=0.0000275)

# Flag pixels where any of the three scaled bands (green, swir16 or nir08)
# falls outside the valid 0-1 range
invalid_ard_values = (
    ((data["green"] < 0) | (data["green"] > 1))
    | ((data["swir16"] < 0) | (data["swir16"] > 1))
    | ((data["nir08"] < 0) | (data["nir08"] > 1))
)

# Average green with inverted nir08, and swir16 with nir08, then take the
# normalised difference of the two averages as the "alex" index
data["scaled_green"] = (data["green"] + (1 - data["nir08"])) / 2
data["scaled_swir"] = (data["swir16"] + data["nir08"]) / 2
data["alex"] = (data["scaled_green"] - data["scaled_swir"]) / (
    data["scaled_green"] + data["scaled_swir"]
)

# Mask the data: pixels that are nodata, cloud-affected or out of range
# are set to `nan`
final_mask = nodata_mask | dilated_cloud_mask | invalid_ard_values
masked = data.where(~final_mask)
masked

In [None]:
# Median over the time dimension of the reflectance bands and the derived
# layers; `.compute()` evaluates the result
median = masked[
    ["green", "nir08", "swir16", "alex", "scaled_green", "scaled_swir"]
].median(dim="time").compute()
median

In [None]:
# Distribution of the median `scaled_swir` values, in 100 bins
median["scaled_swir"].plot.hist(bins=100)

In [None]:
# Explore the median "alex" index via the `odc` accessor, using the RdBu colormap
median["alex"].odc.explore(cmap="RdBu")