In [None]:
from pystac_client import Client
import geopandas as gpd
import rasterio
from pathlib import Path
import pandas as pd
import requests
import planetary_computer as pc
import rioxarray as rxr
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# STAC API endpoint of Microsoft Planetary Computer; opened below with
# pystac_client's Client.open to search the Landsat catalogue.
stac_url = "https://planetarycomputer.microsoft.com/api/stac/v1"

In [None]:
# Inventory the locally produced LST rasters (*_LSTC.tif under out/).
# Underscore-separated filename field 2 is taken as the WRS-2 path/row and
# field 3 as the acquisition date (YYYYMMDD).
# NOTE(review): layout inferred from the indexing below; confirm against the
# code that writes these files.
in_dir = Path('out')

files = sorted(in_dir.rglob("*_LSTC.tif"))

records = []

for f in files:
    parts = f.name.split("_")
    if len(parts) < 4:
        raise ValueError(f"Unexpected LST filename layout: {f.name}")
    pathrow = parts[2]
    acq_date = parts[3]

    # Parsed only to validate the date field (raises on a malformed date); the
    # raw string is kept as the join key. Deliberately not bound to a name
    # like `datetime`, which would shadow the stdlib module.
    pd.to_datetime(acq_date, format="%Y%m%d")

    records.append({
        "file": f,
        "pathrow": pathrow,
        "acq_date": acq_date,
    })

# Explicit columns keep the frame mergeable on pathrow/acq_date even when no
# files were found (a record-less DataFrame would otherwise have no columns).
df = pd.DataFrame(records, columns=["file", "pathrow", "acq_date"])
df

In [None]:
# Area of interest: reproject the AOI layer to WGS84 lon/lat (the CRS the STAC
# `bbox` search parameter is given in here) and take its overall extent as
# [minx, miny, maxx, maxy].
aoi_path = "bbox_sm.gpkg"
bbox_gpkg = gpd.read_file(aoi_path)
bbox_4326 = bbox_gpkg.to_crs("EPSG:4326")

bbox = [*bbox_4326.total_bounds]
minx, miny, maxx, maxy = bbox

In [None]:
# Search Landsat Collection 2 Level-2 for low-cloud (<10 %) scenes that
# intersect the AOI within the study window.
# NOTE(review): the collection mixes Landsat platforms; the validation step
# later reads an "lwir11" asset, which presumably exists only for some of
# them — confirm whether a platform filter is wanted.
client = Client.open(stac_url)

start_date = "2019-11-01"
end_date = "2020-02-01"
search_window = f"{start_date}/{end_date}"

query = client.search(
    collections="landsat-c2-l2",
    bbox=bbox,
    datetime=search_window,
    query={"eo:cloud_cover": {"lt": 10}},
)

items = [item for item in query.items()]
print(f"Found: {len(items)} datasets")

In [None]:
# Peek at the first search result to see its id and which asset keys it
# exposes. Guarded so an empty search does not raise IndexError.
if items:
    first = items[0]
    print(first.id)
    print(list(first.assets.keys()))
else:
    print("No STAC items found; nothing to inspect.")

In [None]:
# Tabulate the STAC scenes with the same join keys as the local file table:
# WRS-2 path/row as a zero-padded 6-digit string and acquisition date as
# a YYYYMMDD string.
stac_records = []

for item in items:
    wrs_path = item.properties.get("landsat:wrs_path")
    wrs_row = item.properties.get("landsat:wrs_row")
    if wrs_path is None or wrs_row is None:
        # int(None) would fail with an opaque TypeError; name the scene instead.
        raise ValueError(f"{item.id}: missing landsat:wrs_path / landsat:wrs_row")

    # Only the calendar date of the item timestamp is used for matching.
    # NOTE(review): presumably UTC — confirm it agrees with the date encoded
    # in the local filenames.
    date = pd.to_datetime(item.datetime).strftime("%Y%m%d")

    stac_records.append({
        "scene_id": item.id,
        "pathrow": f"{int(wrs_path):03d}{int(wrs_row):03d}",
        "acq_date": date,
    })

# Explicit columns keep the merge on pathrow/acq_date working when the search
# returned no items.
stac_df = pd.DataFrame(stac_records, columns=["scene_id", "pathrow", "acq_date"])
stac_df.head()

In [None]:
# Pair each local LST raster with the STAC scene that shares its WRS-2
# path/row and acquisition date. Inner join: rows without a counterpart on
# either side are dropped.
join_keys = ["pathrow", "acq_date"]
matches = pd.merge(df, stac_df, how="inner", on=join_keys)

In [None]:
# Re-query the STAC API for exactly the scenes that matched a local raster, so
# that `items` holds only those scenes (the download step iterates `items`
# and writes into `out_dir`).
out_dir = Path('data/validate')
out_dir.mkdir(exist_ok=True, parents=True)

scene_ids = matches["scene_id"].dropna().unique().tolist()
print(f"Matching scenes: {len(scene_ids)}")

if scene_ids:
    # Reuses `client`, `start_date`, `end_date` and `bbox` from the search
    # cells above instead of re-declaring them here.
    query2 = client.search(
        bbox=bbox,
        collections="landsat-c2-l2",
        datetime=f"{start_date}/{end_date}",
        ids=scene_ids,
    )
    items = list(query2.items())
else:
    # Nothing matched, so there is nothing to fetch. Do not send an empty
    # `ids` list: the client may drop it as "no id filter" and return every
    # scene in the bbox/date window instead.
    items = []

print(len(items))

In [None]:
# first = items[0]
# print(first.id)
# print(list(first.assets.keys()))

In [None]:
# asset_keys = ["atran"]

# for item in items:
#     sid = item.id
#     for asset_key in asset_keys:
#         signed = pc.sign(item.assets[asset_key])
#         url = signed.href
#         out_path = out_dir / f"{sid}_{asset_key}.tif"

#         with requests.get(url, stream=True) as r:
#             r.raise_for_status()
#             with open(out_path, "wb") as f:
#                 for chunk in r.iter_content(8192):
#                     f.write(chunk)

In [None]:
# Index the matched rows by STAC scene id for the validation loop. Guarded so
# that re-running this cell does not raise KeyError once `scene_id` has
# already been moved into the index.
if "scene_id" in matches.columns:
    matches = matches.set_index("scene_id")

In [None]:
from rasterio.plot import show

# Directory holding the downloaded USGS surface-temperature rasters, named
# "<scene_id>_lwir11.tif".
v_dir = Path('data/validate')

# Corner label drawn on each panel.
# NOTE(review): hard-coded, not read from the raster — confirm the modelled
# LST grid really is EPSG:3577.
CRS_LABEL = "GDA94, EPSG:3577"

# Raw-DN -> kelvin scaling applied to the USGS raster, and its zero fill value.
ST_SCALE = 0.00341802
ST_OFFSET = 149.0
ST_NODATA = 0
# Fill value in the modelled LST rasters. The model is treated as degrees
# Celsius (converted to K below) — NOTE(review): confirm model units.
MODEL_NODATA = -9999


def _plot_lst_panel(fig, ax, data, title, vmin, vmax):
    """Draw one LST array on ``ax`` with a shared stretch, CRS label and colourbar.

    Args:
        fig: Figure that owns ``ax`` (used to attach the colourbar).
        ax: Axes to draw into.
        data: 2-D array of LST values in kelvin (NaN = no data).
        title: Panel title.
        vmin, vmax: Colour limits shared by both panels so they are comparable.
    """
    ax.set_title(title)
    show(data, ax=ax, cmap="inferno", vmin=vmin, vmax=vmax, origin="upper")
    ax.annotate(
        CRS_LABEL,
        xy=(0.01, 0.01),
        xycoords='axes fraction',
        fontsize=6,
        ha='left',
        va='bottom',
        bbox=dict(facecolor='white', alpha=0.0, edgecolor='none'),
        zorder=1000
    )
    fig.colorbar(ax.get_images()[0], ax=ax)


rmse_list = []

# Running totals for pixel-weighted global RMSE / MAE across all scenes.
sse_total = 0.0
mae_total = 0.0
n_total = 0

# Iterate rows rather than index labels: if a scene_id matched more than one
# local file, `.loc[sid, "file"]` would return a Series and open_rasterio
# would fail. Each matched file is now validated on its own.
for sid, out_path in matches["file"].items():
    v_path = v_dir / f"{sid}_lwir11.tif"
    if not v_path.exists():
        print(f"{sid}: validation raster not found ({v_path}); skipping")
        continue

    model = rxr.open_rasterio(out_path, masked=True).squeeze(drop=True)
    v = rxr.open_rasterio(v_path, masked=True).squeeze(drop=True)

    # Resample the USGS raster onto the model grid so pixels line up 1:1.
    v_match = v.rio.reproject_match(model)

    model_data = model.values.astype("float64")
    v_data = v_match.values.astype("float64")

    model_data[model_data == MODEL_NODATA] = np.nan
    v_data[v_data == ST_NODATA] = np.nan
    mask = np.isfinite(model_data) & np.isfinite(v_data)

    if not mask.any():
        print(f"{sid}: no overlapping valid pixels; skipping")
        continue

    # Display only the validation pixels where the model also has data.
    v_data = np.where(mask, v_data, np.nan)

    model_data = model_data + 273.15          # degC -> K
    v_data = v_data * ST_SCALE + ST_OFFSET    # raw DN -> K

    # All masked pixels are finite, so plain sum/mean/size are exact here.
    diff = model_data[mask] - v_data[mask]
    abs_diff = np.abs(diff)
    rmse = np.sqrt(np.mean(diff ** 2))
    mae = np.mean(abs_diff)

    sse_total += np.sum(diff ** 2)
    mae_total += np.sum(abs_diff)
    n_total += diff.size

    rmse_list.append({
        "scene_id": sid,
        "file": out_path,
        "rmse": rmse,
        "mae": mae
    })

    print(f"{sid}: RMSE = {rmse:.3f}, MAE = {mae:.3f}")

    # Both arrays are in kelvin at this point (previously mislabelled "C").
    print(
        sid,
        "model K:", np.nanmin(model_data), np.nanmax(model_data),
        "vali K:", np.nanmin(v_data), np.nanmax(v_data)
    )

    # Shared 2-98 % stretch over both rasters so the panels are comparable.
    common_limits = np.concatenate([
        model_data[np.isfinite(model_data)],
        v_data[np.isfinite(v_data)]
    ])
    vmin = np.nanpercentile(common_limits, 2)
    vmax = np.nanpercentile(common_limits, 98)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
    _plot_lst_panel(fig, ax1, model_data, "Modelled LST (K)", vmin, vmax)
    _plot_lst_panel(fig, ax2, v_data, "USGS LST (K)", vmin, vmax)

    # Make sure both subplots share the same x/y limits exactly
    ax2.set_xlim(ax1.get_xlim())
    ax2.set_ylim(ax1.get_ylim())

    plt.tight_layout()
    plt.show()
    # Release the figure explicitly; many scenes would otherwise accumulate.
    plt.close(fig)

# Explicit columns so downstream filtering on "rmse"/"file" works even when
# no scene could be validated.
rmse_df = pd.DataFrame(rmse_list, columns=["scene_id", "file", "rmse", "mae"])


In [None]:
# Pixel-weighted error over all scenes: every compared pixel counts equally,
# so larger scenes weigh more than in a plain mean of per-scene RMSEs.
# Guarded because n_total stays 0 when no scene was compared.
if n_total > 0:
    global_rmse = np.sqrt(sse_total / n_total)
    global_mae = mae_total / n_total
else:
    global_rmse = global_mae = float("nan")
    print("No overlapping valid pixels were compared; global metrics are undefined.")

print(f"Global RMSE across all scenes: {global_rmse:.3f} K")
print(f"Global MAE across all scenes: {global_mae:.3f} K")

# Per-scene table as the cell's displayed output (a bare name is only
# rendered when it is the last expression of the cell).
rmse_df

In [None]:
# good_threshold = 8
# bad_threshold = 8

# good_df = rmse_df[rmse_df["rmse"] <= good_threshold]
# bad_df = rmse_df[rmse_df["rmse"] >= bad_threshold]

# good_scenes = good_df["file"].tolist()
# bad_scenes  = bad_df["file"].tolist()