In [1]:
# Packages
import os
import re
import sys
import rasterio
import numpy as np
import xarray as xr
import pandas as pd
from pathlib import Path
from datetime import timezone
import matplotlib.pyplot as plt
import planetary_computer as pc
from contextlib import nullcontext
from rasterio.plot import reshape_as_image
from typing import Optional, Tuple, Sequence

plt.style.use("~/geoscience/sail_project/MNRAS.mplstyle")
%matplotlib inline

  from pandas.core.computation.check import NUMEXPR_INSTALLED


In [5]:
# Directory holding the project's Sentinel-2 / TSI helper modules.
function_path = os.path.expanduser("~/geoscience/sail_project/sentinel_2/s2_functions")
# Guarded so that re-running this cell does not keep appending duplicate entries.
if function_path not in sys.path:
    sys.path.append(function_path)
# import all the helper functions.
from tsi_functions import *

In [9]:
# Make sure the helper directory is importable without adding a duplicate
# sys.path entry when it has already been registered.
if function_path not in sys.path:
    sys.path.append(function_path)
from s2_20m_download import *

Returned 185 items


In [10]:
######################################################################################
######################################################################################
"""
Steps:
-----
- Builds a 'time' coordinate if it is missing (from 'base_time' + 'time_offset', or from 'time_offset' alone if it is already datetime64).
- Masks the '-100' fill values.
- For each matched pair '(tsi_time, s2_item)', makes a figure with:
    - Left: a short time series (±k min window) of 'percent_thin' and 'percent_opaque' with a vertical line at the chosen time, plus large text labels of the exact values at that time.
    - Right: the Sentinel-2 visualization (RGB if available; otherwise grayscale), with cloud % if present.
- Auto-detects whether the 'percent_*' values are fractions [0–1] or percentages [0–100] and formats labels accordingly.
"""

# ---------- helpers ----------
def _date_to_tsi_file(cdf_dir: Path, date_str: str) -> Path | None:
    """Find a TSI .cdf file in cdf_dir whose filename contains YYYYMMDD for date_str."""
    ymd = pd.to_datetime(date_str).strftime("%Y%m%d")
    for fp in sorted(Path(cdf_dir).glob("*.cdf")):
        if ymd in fp.name:
            return fp
    return None

def _s2_cloud_pct(item):
    props = getattr(item, "properties", {}) or {}
    for key in ("eo:cloud_cover", "s2:cloud_percentage", "cloud_cover"):
        if key in props:
            try:
                return float(props[key])
            except Exception:
                pass
    return None

# def _stretch_01(a, p_low=2, p_high=98):
#     """Percentile stretch to [0,1] for display."""
#     a = np.asarray(a, dtype="float32")
#     finite = np.isfinite(a)
#     if not finite.any():
#         return np.zeros_like(a, dtype="float32")
#     lo, hi = np.percentile(a[finite], [p_low, p_high])
#     if hi <= lo:
#         hi = lo + 1e-6
#     return np.clip((a - lo) / (hi - lo), 0, 1)

def _load_s2_visual(item, prefer_rgb=True, use_aws_env=True):
    """
    Return (image, label) for a Sentinel-2 item.
    - Tries to sign (Planetary Computer) if available.
    - Tries RGB (B04,B03,B02) then any single band; else falls back to 'thumbnail'.
    - Applies percentile stretch to uint8 for display.
    """
    # Try to sign for MPC
    try:
        import planetary_computer as pc
        item = pc.sign(item)
    except Exception:
        pass

    import numpy as np
    import rasterio

    def _stretch_01(a, p_low=2, p_high=98):
        a = np.asarray(a, dtype="float32")
        m = np.isfinite(a)
        if not m.any():
            return np.zeros_like(a, dtype="float32")
        lo, hi = np.percentile(a[m], [p_low, p_high])
        if hi <= lo:
            hi = lo + 1e-6
        return np.clip((a - lo) / (hi - lo), 0, 1)

    def _open_band(href):
        with rasterio.open(href) as src:
            return src.read(1)

    # Optionally provide AWS env for s3://
    env = rasterio.Env(AWS_NO_SIGN_REQUEST="YES") if use_aws_env else None
    ctx = env if env is not None else nullcontext()

    from contextlib import nullcontext
    with ctx:
        # 1) Try RGB
        if prefer_rgb:
            for keys in (("B04","B03","B02"), ("red","green","blue")):
                hrefs = [item.assets[k].href for k in keys if k in item.assets]
                if len(hrefs) == 3:
                    try:
                        b = [_open_band(h) for h in hrefs]
                        rgb = np.dstack(b)               # H,W,3
                        rgb = (_stretch_01(rgb) * 255).astype("uint8")
                        return rgb, f"RGB: {keys}"
                    except Exception:
                        pass

        # 2) Any single band
        for k, asset in item.assets.items():
            try:
                g = _open_band(asset.href)
                g = (_stretch_01(g) * 255).astype("uint8")
                return g, f"Band: {k}"
            except Exception:
                continue

        # 3) Fallback: thumbnail (often HTTP)
        for k in ("thumbnail", "quicklook", "overview"):
            if k in item.assets:
                try:
                    import imageio.v3 as iio
                    img = iio.imread(item.assets[k].href)
                    # Ensure uint8
                    if img.dtype != np.uint8:
                        img = (_stretch_01(img) * 255).astype("uint8")
                    label = f"{k}"
                    return img, label
                except Exception:
                    pass

    raise ValueError(f"Could not read any visual asset for item {getattr(item, 'id', '<unknown>')}.")

# ---------- NEW: TSI time-series plotting per matched pair ----------
def _ensure_time_coord(ds: xr.Dataset) -> xr.Dataset:
    """
    Return ``ds`` with a 'time' coordinate, building one if it is absent.

    - If 'time' is already a coordinate, ``ds`` is returned unchanged.
    - Otherwise 'time_offset' is required. A datetime64 'time_offset' is used
      as absolute times; a numeric one is treated as seconds added to
      'base_time' (or to the Unix epoch when 'base_time' is missing).

    Raises ValueError when neither 'time' nor 'time_offset' is available.
    """
    if "time" in ds.coords:
        return ds
    if "time_offset" not in ds:
        raise ValueError("No 'time' coordinate and cannot construct it (need 'time_offset' and optional 'base_time').")

    offset_var = ds["time_offset"]
    raw_offsets = offset_var.values
    if np.issubdtype(offset_var.dtype, np.datetime64):
        # Offsets are already absolute datetimes.
        stamps = pd.to_datetime(raw_offsets)
    else:
        if "base_time" in ds:
            origin = pd.Timestamp(ds["base_time"].values)
        else:
            origin = pd.Timestamp(0, unit="s")
        # Numeric offsets are assumed to be in seconds.
        stamps = origin + pd.to_timedelta(raw_offsets, unit="s")
    return ds.assign_coords(time=("time", pd.DatetimeIndex(stamps)))

def _as_utc_timestamp(ts) -> pd.Timestamp:
    """Coerce ``ts`` to a tz-aware UTC Timestamp; tz-naive input is taken to be UTC."""
    stamp = pd.Timestamp(ts)
    if stamp.tzinfo is None:
        return stamp.tz_localize("UTC")
    return stamp.tz_convert("UTC")


def plot_matched_pairs_timeseries(
    cdf_dir,
    tsi_times,
    s2_items,
    thin_var="percent_thin",
    opaque_var="percent_opaque",
    fill_value=-100,
    window="30min",   # half-window on each side
    max_pairs=None,
):
    """
    Summarise each matched (tsi_time, s2_item) pair.

    For every pair this function:
      - finds and opens that date's TSI file and ensures it has a 'time' coord,
      - masks ``fill_value`` in the thin/opaque variables,
      - samples both variables at the time nearest to ``tsi_time``,
      - decides whether values are fractions or percentages (any finite value
        over 1.5 on that day means percent),
      - checks that a Sentinel-2 visual can be loaded for the item,
      - prints the item's cloud cover followed by the thin and opaque values.

    Pairs that fail at any step are reported with a ``[i] ...`` message and
    skipped. No figure is produced by this function as written, despite its
    name; the loaded Sentinel-2 image is not used further.

    Parameters
    ----------
    cdf_dir : str or Path
        Directory searched for the TSI ``.cdf`` files.
    tsi_times : sequence of timestamp-like
        Matched TSI times. tz-aware values are converted to UTC; tz-naive
        values are assumed to be UTC already.
    s2_items : sequence
        Sentinel-2 items, paired positionally with ``tsi_times``.
    thin_var, opaque_var : str
        Names of the dataset variables to sample.
    fill_value : number
        Value treated as missing in both variables.
    window : str
        Parsed with ``pd.Timedelta`` (an invalid value raises) but not
        otherwise used here.
    max_pairs : int or None
        If given, only the first ``max_pairs`` pairs are processed.

    Returns
    -------
    None
    """
    cdf_dir = Path(cdf_dir)
    pairs = list(zip(tsi_times, s2_items))
    if max_pairs is not None:
        pairs = pairs[:max_pairs]

    # Parsed up front so that an invalid `window` still fails immediately.
    pd.Timedelta(window)

    for i, (tsi_ts, item) in enumerate(pairs, 1):
        tsi_utc = _as_utc_timestamp(tsi_ts)
        date_str = tsi_utc.date().isoformat()
        tsi_fp = _date_to_tsi_file(cdf_dir, date_str)
        if tsi_fp is None:
            print(f"[{i}] No TSI file for {date_str}, skipping.")
            continue

        try:
            raw_ds = xr.open_dataset(tsi_fp)
        except Exception as e:
            print(f"[{i}] Failed TSI open/prepare for {tsi_fp.name}: {e}")
            continue

        # Everything below runs under try/finally so the opened dataset is
        # closed on every path, including the successful one.
        try:
            try:
                ds = _ensure_time_coord(raw_ds)
            except Exception as e:
                print(f"[{i}] Failed TSI open/prepare for {tsi_fp.name}: {e}")
                continue

            # get time to sample (tz-naive UTC for xarray .sel)
            tsel = tsi_utc.tz_localize(None)

            # mask fill values
            if thin_var not in ds or opaque_var not in ds:
                print(f"[{i}] Missing {thin_var}/{opaque_var} in {tsi_fp.name}, skipping.")
                continue
            thin = ds[thin_var].where(ds[thin_var] != fill_value)
            opaque = ds[opaque_var].where(ds[opaque_var] != fill_value)

            # extract nearest value at tsel
            try:
                thin_val = float(thin.sel(time=tsel, method="nearest").item())
                opaque_val = float(opaque.sel(time=tsel, method="nearest").item())
            except Exception as e:
                print(f"[{i}] Could not sample thin/opaque at {tsel} in {tsi_fp.name}: {e}")
                continue

            # determine display units (fraction vs percent)
            # heuristic: if max over the day > 1.5, treat as percent
            day_start = tsel.floor("D")
            day_mask = (ds["time"] >= day_start) & (ds["time"] < (day_start + pd.Timedelta("1D")))
            scale_is_percent = False
            for series in (thin, opaque):
                day_values = series.where(day_mask, drop=True).values
                finite = day_values[np.isfinite(day_values)]
                if finite.size and np.nanmax(finite) > 1.5:
                    scale_is_percent = True

            def _fmt(x, as_percent=scale_is_percent):
                return f"{x:.1f}%" if as_percent else f"{x:.3f}"

            # Load the Sentinel-2 visual; pairs whose imagery cannot be read
            # are skipped. The image itself is not used further here.
            try:
                _load_s2_visual(item)
            except Exception as e:
                print(f"[{i}] Could not load Sentinel-2 visual for {getattr(item, 'id','<unknown>')}: {e}")
                continue

            s2_cloud = _s2_cloud_pct(item)
            if s2_cloud is None:
                # No usable cloud-cover property; formatting None with :.1f would raise.
                print("cloud: n/a")
            else:
                print(f"cloud: {s2_cloud:.1f}%")
            print(_fmt(thin_val))
            print(_fmt(opaque_val))
        finally:
            raw_ds.close()
        

In [11]:
# Match TSI times to Sentinel-2 items.
cdf_dir = "/bsuhome/tnde/scratch/felix/tsi_sky_cover"
tsi_times, s2_items = match_tsi_to_s2(cdf_dir, items, thin_threshold=100, opaque_threshold=100)
print(len(tsi_times), len(s2_items))

# Verify lengths align and preview. The two lists are paired positionally
# below, and zip() would silently drop unpaired entries.
assert len(tsi_times) == len(s2_items), "match_tsi_to_s2 returned lists of different lengths"
print(f"Matched pairs: {len(tsi_times)}")
for t_tsi, s2 in list(zip(tsi_times, s2_items))[:5]:
    print(f"TSI: {t_tsi.isoformat()}  <->  S2: {s2.id} @ {s2.datetime.isoformat()}")

# Sign the items before their asset hrefs are read.
s2_items = [pc.sign(it) for it in s2_items]
plot_matched_pairs_timeseries(
    cdf_dir=cdf_dir,
    tsi_times=tsi_times,
    s2_items=s2_items,
    thin_var="percent_thin",
    opaque_var="percent_opaque",
    fill_value=-100,
    window="30min",
    max_pairs=10,   # optional
)

184 184
Matched pairs: 184
TSI: 2021-09-03T17:49:00+00:00  <->  S2: S2B_MSIL2A_20210903T174909_R141_T13SCD_20210904T042732 @ 2021-09-03T17:49:09.024000+00:00
TSI: 2021-09-08T17:49:00+00:00  <->  S2: S2A_MSIL2A_20210908T174911_R141_T13SCD_20210909T104611 @ 2021-09-08T17:49:11.024000+00:00
TSI: 2021-09-13T17:49:00+00:00  <->  S2: S2B_MSIL2A_20210913T174909_R141_T13SCD_20210914T085302 @ 2021-09-13T17:49:09.024000+00:00
TSI: 2021-09-18T17:50:00+00:00  <->  S2: S2A_MSIL2A_20210918T175011_R141_T13SCD_20210919T051017 @ 2021-09-18T17:50:11.024000+00:00
TSI: 2021-09-23T17:50:00+00:00  <->  S2: S2B_MSIL2A_20210923T174949_R141_T13SCD_20210924T125157 @ 2021-09-23T17:49:49.024000+00:00
cloud: 59.6%
2.2%
92.7%
cloud: 0.8%
20.1%
15.1%
cloud: 38.9%
6.5%
59.0%
cloud: 2.0%
0.4%
1.5%
cloud: 6.0%
2.2%
2.4%
cloud: 64.0%
6.6%
29.7%
cloud: 4.5%
4.0%
39.4%
cloud: 93.9%
0.1%
99.4%
cloud: 85.6%
16.6%
65.6%
cloud: 9.2%
4.4%
30.2%
