## OMC (OmniCloudMask): downloading cloud-masked images per year

In [None]:
# =========================================
# CONFIG
# =========================================
# Years to process (2017..2020 inclusive) and the month-day window searched in each year.
YEARS            = list(range(2017, 2021))
MONTH_START_END  = ("07-01", "08-31")
GRID             = "MGRS-05WMU"                 # value matched against the STAC "grid:code" property
MAX_CLOUD_COVER  = 70                           # upper bound for the STAC "eo:cloud_cover" property
BBOX_LL          = (-153.5, 70.5, -153, 71)     # AOI as (min lon, min lat, max lon, max lat)
OUT_DIR          = "CDSE_scenes_masked/coverage70"

import os

# SECURITY(review): an access key id is hard-coded below. Prefer exporting
# AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY in the environment before starting
# the notebook. `setdefault` keeps any credentials that are already exported
# instead of overwriting them with the placeholder values in this cell.
os.environ.setdefault("AWS_ACCESS_KEY_ID", "C364NPCJK6JQ64OIMZJR")
os.environ.setdefault("AWS_SECRET_ACCESS_KEY", "...")   # fill in securely

# Endpoint settings are always forced, because the readers below rely on them.
os.environ["AWS_REGION"]            = "us-east-1"
os.environ["AWS_S3_ENDPOINT"]       = "eodata.dataspace.copernicus.eu"
os.environ["AWS_VIRTUAL_HOSTING"]   = "FALSE"

# =========================================
# IMPORTS
# =========================================
from pathlib import Path
import math
import numpy as np
import xarray as xr
import rioxarray
from pyproj import Transformer
import rasterio as rio
from rasterio.enums import Resampling
from pystac_client import Client
from pystac import Item
from omnicloudmask import predict_from_array
from shapely.geometry import shape, box, mapping
from shapely.ops import transform as shapely_transform
# =========================================
# HELPERS
# =========================================
def search_s2_stac(start_date: str, end_date: str, grid: str, max_cloud_cover: int = 100) -> list[Item]:
    """Search the CDSE STAC catalogue for ``sentinel-2-l2a`` items.

    Args:
        start_date: First day of the search window, joined with ``end_date``
            into a ``start/end`` datetime interval.
        end_date: Last day of the search window.
        grid: Value the item's ``grid:code`` property must equal.
        max_cloud_cover: Upper bound (inclusive) on ``eo:cloud_cover``.

    Returns:
        Every matching item, fully materialised into a list.
    """
    catalog = Client.open("https://stac.dataspace.copernicus.eu/v1/")

    # Property filters applied on top of the collection/date restriction.
    filters = {
        "eo:cloud_cover": {"lte": max_cloud_cover},
        "grid:code": {"eq": grid},
    }
    results = catalog.search(
        collections=["sentinel-2-l2a"],
        datetime=f"{start_date}/{end_date}",
        query=filters,
    )

    items = list(results.items())
    print(f"  üîé Found {len(items)} items")
    return items

def prefer_s3_assets(items):
    """Return clones of ``items`` whose asset hrefs point to S3 where possible.

    Each item is cloned first, so the caller's items are left untouched. For
    every asset, the ``alternate`` (or ``alternates``) entry of its
    ``extra_fields`` is inspected; when an S3 href is found there it replaces
    the asset's ``href``. Assets without a usable alternate keep their href.

    Args:
        items: Iterable of objects exposing ``clone()`` and an ``assets``
            mapping whose values carry ``href`` and optionally ``extra_fields``.

    Returns:
        List of cloned items, in input order.
    """
    out = []
    for original in items:
        item = original.clone()
        for asset in item.assets.values():
            extra = getattr(asset, "extra_fields", None) or {}
            s3_href = _s3_href_from_alternate(extra.get("alternate") or extra.get("alternates"))
            if s3_href:
                asset.href = s3_href
        out.append(item)
    return out


def _s3_href_from_alternate(alt):
    """Extract an S3 href from an ``alternate`` value, or return None.

    Two layouts are understood: a dict keyed by ``"s3"``/``"S3"`` whose value
    holds ``href``, and a list of dicts where the first ``href`` starting with
    ``s3://`` wins. Malformed entries (non-dict, non-string href) are skipped
    rather than raising.
    """
    if isinstance(alt, dict):
        entry = alt.get("s3") or alt.get("S3") or {}
        return entry.get("href") if isinstance(entry, dict) else None

    if isinstance(alt, list):
        for entry in alt:
            if not isinstance(entry, dict):
                continue
            href = entry.get("href")
            if isinstance(href, str) and href.startswith("s3://"):
                return href

    return None

def detect_epsg_and_bounds(items, bbox_ll_override=None):
    """Pick a projected CRS for ``items`` and project the AOI bounds into it.

    Args:
        items: Non-empty sequence of STAC items (``bbox`` and ``properties``
            are read).
        bbox_ll_override: Optional ``(min_lon, min_lat, max_lon, max_lat)``.
            When None, the union of all item bounding boxes is used.

    Returns:
        ``(epsg, bbox_ll, bounds_proj)`` where ``bbox_ll`` is the lon/lat box
        actually used and ``bounds_proj`` is ``(minx, miny, maxx, maxy)`` in
        ``EPSG:<epsg>``.

    Raises:
        ValueError: If ``items`` is empty.
    """
    if not items:
        raise ValueError("No items")

    if bbox_ll_override is None:
        boxes = [it.bbox for it in items]
        bbox_ll = (
            min(b[0] for b in boxes),
            min(b[1] for b in boxes),
            max(b[2] for b in boxes),
            max(b[3] for b in boxes),
        )
    else:
        bbox_ll = bbox_ll_override

    # First item that declares a projection wins.
    epsg = None
    for it in items:
        epsg = _epsg_from_properties(it.properties)
        if epsg is not None:
            break

    if epsg is None:
        # Fall back to the UTM zone containing the centre of the AOI.
        lon = (bbox_ll[0] + bbox_ll[2]) / 2.0
        lat = (bbox_ll[1] + bbox_ll[3]) / 2.0
        # Clamp so that lon == 180 maps to zone 60 rather than a non-existent 61.
        zone = min(int(math.floor((lon + 180) / 6) + 1), 60)
        epsg = 32600 + zone if lat >= 0 else 32700 + zone

    tx = Transformer.from_crs("EPSG:4326", f"EPSG:{epsg}", always_xy=True)
    # The edges of a lon/lat box are curved after projection, so projecting only
    # two opposite corners can under-estimate the extent. transform_bounds
    # densifies the edges and returns the enclosing (minx, miny, maxx, maxy).
    bounds_proj = tuple(
        tx.transform_bounds(bbox_ll[0], bbox_ll[1], bbox_ll[2], bbox_ll[3], densify_pts=21)
    )
    return epsg, bbox_ll, bounds_proj


def _epsg_from_properties(props):
    """Return the EPSG code declared in STAC item properties, or None.

    Reads ``proj:epsg`` first and then ``proj:code`` of the form
    ``"EPSG:<digits>"``. Missing, null or non-EPSG values yield None.
    """
    value = props.get("proj:epsg")
    if value is not None:
        return int(value)

    code = props.get("proj:code")
    if isinstance(code, str) and code.upper().startswith("EPSG:"):
        digits = code.split(":", 1)[1]
        if digits.isdigit():
            return int(digits)

    return None

def projected_intersection_ratio(item_geom, aoi_bounds, epsg_out):
    """Return the share of the AOI's projected area covered by an item footprint.

    Args:
        item_geom: GeoJSON-like geometry of the item footprint, in lon/lat.
        aoi_bounds: AOI as ``(min_lon, min_lat, max_lon, max_lat)``.
        epsg_out: EPSG code of the projected CRS in which areas are measured.

    Returns:
        ``intersection area / AOI area`` in the projected CRS, or ``0.0`` when
        the two shapes do not intersect.
    """
    # One lon/lat -> projected transform shared by both geometries.
    to_projected = Transformer.from_crs(
        "EPSG:4326", f"EPSG:{epsg_out}", always_xy=True
    ).transform

    aoi_poly = shapely_transform(to_projected, box(*aoi_bounds))
    footprint = shapely_transform(to_projected, shape(item_geom))

    overlap = footprint.intersection(aoi_poly)
    return 0.0 if overlap.is_empty else overlap.area / aoi_poly.area

def rasterio_env():
    """Build a ``rasterio.Env`` from the S3 settings found in ``os.environ``.

    The endpoint, region and virtual-hosting flag are copied from the process
    environment (a missing variable raises ``KeyError``); the two GDAL options
    are fixed values.
    """
    options = {
        name: os.environ[name]
        for name in ("AWS_S3_ENDPOINT", "AWS_REGION", "AWS_VIRTUAL_HOSTING")
    }
    options["GDAL_DISABLE_READDIR_ON_OPEN"] = "EMPTY_DIR"
    options["CPL_VSIL_CURL_ALLOWED_EXTENSIONS"] = "tif,gtiff,jp2,xml"
    return rio.Env(**options)

# =========================================
# PROCESS
# =========================================
BAND_ORDER  = ["B02_10m","B03_10m","B04_10m","B08_10m","B11_20m","B12_20m"]
BAND_LABELS = ["Blue","Green","Red","NIR","SWIR1","SWIR2"]


def _open_band_clipped(href, epsg, bounds):
    """Open a single-band raster, ensure it carries a CRS, and clip it to ``bounds``."""
    da = rioxarray.open_rasterio(href, masked=True).squeeze("band", drop=True)
    if da.rio.crs is None:
        da = da.rio.write_crs(f"EPSG:{epsg}")
    return da.rio.clip_box(*bounds)


def process_year(year: int, grid: str, max_cloud: int, bbox_ll, min_aoi_coverage: float = 0.4):
    """Download, cloud-mask and save every matching scene of one year.

    For each STAC item found for ``year`` the six bands of ``BAND_ORDER`` are
    read, clipped to the AOI, stacked on the grid of the first 10 m band,
    masked with ``predict_from_array`` and written as a ``uint16`` GeoTIFF
    below ``OUT_DIR/<year>``.

    Args:
        year: Calendar year; combined with ``MONTH_START_END`` into the search
            window.
        grid: Value the item's ``grid:code`` must equal.
        max_cloud: Maximum ``eo:cloud_cover`` accepted by the search.
        bbox_ll: AOI as ``(min_lon, min_lat, max_lon, max_lat)``, or None to
            use the union of the item bounding boxes.
        min_aoi_coverage: Minimum fraction (0-1) of the AOI that an item's
            footprint must cover for the scene to be processed.

    Returns:
        Number of scenes written successfully.
    """
    print(f"\n==== Year {year} | grid={grid} | clouds‚â§{max_cloud}% ====")
    start_date = f"{year}-{MONTH_START_END[0]}"
    end_date   = f"{year}-{MONTH_START_END[1]}"

    items = search_s2_stac(start_date, end_date, grid, max_cloud_cover=max_cloud)
    if not items:
        print("  ‚ö†Ô∏è No items for this year.")
        return 0

    items_s3 = prefer_s3_assets(items)
    epsg_out, bbox_ll_used, bounds_out = detect_epsg_and_bounds(items, bbox_ll_override=bbox_ll)
    print(f"  EPSG={epsg_out} | bounds_proj={tuple(round(v,2) for v in bounds_out)}")

    bands_10m = [b for b in BAND_ORDER if b.endswith("10m")]
    bands_20m = [b for b in BAND_ORDER if b.endswith("20m")]

    # Bands are loaded 10 m first, then 20 m; derive the labels in that same
    # order so that each layer is always paired with its own name.
    load_order = bands_10m + bands_20m
    label_for = dict(zip(BAND_ORDER, BAND_LABELS))
    labels = [label_for[b] for b in load_order]

    out_dir = Path(OUT_DIR) / str(year)
    out_dir.mkdir(parents=True, exist_ok=True)

    n_ok = 0

    with rasterio_env():
        for it_s3, it_orig in zip(items_s3, items):
            # "datetime" may be present but null; treat that like a missing value.
            scene_date = (it_orig.properties.get("datetime") or "").split("T")[0]
            print(f"\n‚Üí Scene {it_orig.id} ({scene_date})")

            # Everything scene-specific runs inside the try so that a single
            # bad item is reported and skipped instead of aborting the year.
            try:
                # ------------------------------------------------
                # AOI intersection check (BEFORE loading bands)
                # ------------------------------------------------
                coverage_ratio = projected_intersection_ratio(
                    item_geom=it_orig.geometry,
                    aoi_bounds=bbox_ll_used,   # lon/lat box actually in use (never None)
                    epsg_out=epsg_out,         # detected for the tile
                )

                print(f"   ‚ÑπÔ∏è AOI intersection coverage: {coverage_ratio:.2%}")

                if coverage_ratio < min_aoi_coverage:
                    print("   ‚ö†Ô∏è Scene skipped due to low AOI coverage.")
                    continue

                # The output always has one layer per configured band, so check
                # for missing assets before any data is downloaded.
                missing = [b for b in load_order if b not in it_s3.assets]
                if missing:
                    print(f"   ‚ö†Ô∏è Scene skipped, missing band assets: {missing}")
                    continue

                ref = None
                pieces = []

                for bname in bands_10m:
                    da = _open_band_clipped(it_s3.assets[bname].href, epsg_out, bounds_out)
                    if ref is None:
                        ref = da   # first 10 m band defines the output grid
                    pieces.append(da.expand_dims("band"))

                if bands_20m and ref is None:
                    raise ValueError("no 10 m band available as resampling reference")

                for bname in bands_20m:
                    da20 = _open_band_clipped(it_s3.assets[bname].href, epsg_out, bounds_out)
                    # Resample the 20 m band onto the 10 m reference grid.
                    da20u = da20.rio.reproject_match(ref, resampling=Resampling.bilinear)
                    pieces.append(da20u.expand_dims("band"))

                if not pieces:
                    print("   ‚ö†Ô∏è No usable bands.")
                    continue

                scene = xr.concat(pieces, dim="band")
                scene = scene.assign_coords(band=labels)
                if scene.rio.crs is None:
                    scene = scene.rio.write_crs(f"EPSG:{epsg_out}")

                # MASKING: the predictor is fed Red, Green, NIR (in that order).
                red   = scene.sel(band="Red").values
                green = scene.sel(band="Green").values
                nir   = scene.sel(band="NIR").values
                input_array = np.stack([red, green, nir], axis=0)

                # Best effort: if masking fails the scene is still written,
                # unmasked, and the failure is reported.
                try:
                    pred_mask = predict_from_array(input_array)

                    # Handle shape (1, H, W) or (3, H, W)
                    if pred_mask.ndim == 3:
                        if pred_mask.shape[0] == 1:
                            pred_mask = pred_mask[0]
                        elif pred_mask.shape[0] == 3:
                            pred_mask = pred_mask[1]  # assume class 1 = cloud

                    # Ensure mask shape matches (y, x)
                    if pred_mask.shape != (scene.sizes["y"], scene.sizes["x"]):
                        raise ValueError(f"‚ùå Mask shape {pred_mask.shape} does not match scene shape {(scene.sizes['y'], scene.sizes['x'])}")

                    # Keep only pixels where class == 0
                    mask_keep = pred_mask == 0

                    mask_da = xr.DataArray(
                        mask_keep,
                        dims=("y", "x"),
                        coords={"y": scene.coords["y"], "x": scene.coords["x"]}
                    )

                    scene = scene.where(mask_da)
                    print("   ‚úî Cloud mask applied.")

                except Exception as e:
                    print("   ‚ö†Ô∏è Cloud mask failed:", e)

                # Masked / missing pixels become 0, which is declared as nodata.
                scene_u16 = (
                    scene.fillna(0)
                    .clip(0, 10000)
                    .astype("uint16")
                    .rio.write_nodata(0)
                    .rio.write_crs(f"EPSG:{epsg_out}")
                )

                out_path = out_dir / f"{it_orig.id}_{scene_date}_masked.tif"
                print("   üíæ Saving ‚Üí", out_path)
                scene_u16.transpose("band", "y", "x").rio.to_raster(
                    out_path,
                    driver="GTiff",
                    compress="deflate",
                    tiled=True,
                    predictor=2,
                    BIGTIFF="IF_SAFER",
                    blockxsize=512,
                    blockysize=512,
                )
                n_ok += 1

            except Exception as e:
                print("   ‚ùå Scene failed:", e)

    return n_ok

# =========================================
# RUN ALL YEARS
# =========================================
# Process every configured year and remember how many scenes each one produced.
all_counts = {
    yr: process_year(yr, GRID, MAX_CLOUD_COVER, bbox_ll=BBOX_LL)
    for yr in YEARS
}

# Per-year summary.
print("\n‚úÖ Done. Scenes written per year:")
for yr, n in all_counts.items():
    print(f" ‚Ä¢ {yr}: {n} scenes")


## Stack to median

In [None]:
from pathlib import Path
import xarray as xr
import rioxarray
import numpy as np

# --------------------------------------------
# CONFIG
# --------------------------------------------
YEAR        = 2025                                               # the one place to change the year
TIF_DIR     = Path("CDSE_scenes_masked/coverage70") / str(YEAR)  # Folder with the TIFFs
OUT_PATH    = Path(f"CDSE_{YEAR}_median_70.tif")                 # Output file
BAND_LABELS = ["Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2"]  # Optional
# NOTE(review): the tasseled-cap cell reads its inputs from
# "omc_medians_70/CDSE_<year>_median_70.tif", whereas this cell writes to the
# working directory; the file has to be moved there before that cell runs.

# --------------------------------------------
# LOAD TIFFS
# --------------------------------------------
tif_files = sorted(TIF_DIR.glob("*.tif"))
print(f"üóÇ Found {len(tif_files)} TIFFs")

scenes = []

for f in tif_files:
    # Best effort: an unreadable file is reported and left out of the stack.
    try:
        # masked=True turns the declared nodata value into NaN so it can be
        # skipped by the median below.
        ds = rioxarray.open_rasterio(f, masked=True)  # shape: (band, y, x)
        if "band" not in ds.coords:
            ds = ds.assign_coords(band=range(1, ds.sizes["band"] + 1))

        # Optionally set band names
        if len(BAND_LABELS) == ds.sizes["band"]:
            ds = ds.assign_coords(band=BAND_LABELS)

        # The file name serves as a unique label along the stacking dimension.
        scenes.append(ds.expand_dims(time=[f.name]))
        print(f"   ‚úì Loaded {f.name}")
    except Exception as e:
        print(f"   ‚ö†Ô∏è Failed to load {f.name}: {e}")

if not scenes:
    raise RuntimeError("‚ùå No scenes could be loaded.")

# --------------------------------------------
# STACK + MEDIAN
# --------------------------------------------
stack = xr.concat(scenes, dim="time")
print("üìä Stack shape:", stack.shape)

# Per-pixel, per-band median over all scenes, ignoring NaN (masked) samples.
median_img = stack.median(dim="time", skipna=True)

# --------------------------------------------
# SAVE MEDIAN STACK
# --------------------------------------------
# The median of an even number of samples can end in .5; round to the nearest
# integer instead of letting the uint16 cast truncate it. Pixels without any
# valid sample (NaN) become 0, which is declared as nodata.
median_img_u16 = (
    median_img
    .clip(0, 10000)
    .round()
    .fillna(0)
    .astype("uint16")
    .rio.write_nodata(0)
)

print(f"üíæ Saving median image to {OUT_PATH}")
median_img_u16.rio.to_raster(
    OUT_PATH,
    driver="GTiff",
    compress="deflate",
    tiled=True,
    predictor=2,
    BIGTIFF="IF_SAFER",
    blockxsize=512,
    blockysize=512,
)

print("‚úÖ Done.")


## Calculating TC images

In [None]:
# === tasseled_cap_mosaic_generation.py ===
from pathlib import Path
import numpy as np
import xarray as xr
import rioxarray
from dask.diagnostics import ProgressBar
import warnings

warnings.filterwarnings("ignore", category=UserWarning, message=".*coordinate precision.*")

median_dir = Path("omc_medians_70")   # input: yearly median mosaics
tc_dir = Path("omc_tc_70")            # output: tasseled cap mosaics
tc_dir.mkdir(exist_ok=True)
years = list(range(2025, 2026))

# Tasseled cap weights per component (brightness, greenness, wetness), keyed
# by band label.
# NOTE(review): these values look like the Crist (1985) Landsat TM
# reflectance-factor coefficients applied to the matching Sentinel-2 bands -
# confirm this is the intended coefficient set.
coeffs = {
    "tcb": dict(Blue=0.3037, Green=0.2793, Red=0.4743, NIR=0.5585, SWIR1=0.5082, SWIR2=0.1863),
    "tcg": dict(Blue=-0.2848, Green=-0.2435, Red=-0.5436, NIR=0.7243, SWIR1=0.0840, SWIR2=-0.1800),
    "tcw": dict(Blue=0.1509, Green=0.1973, Red=0.3279, NIR=0.3406, SWIR1=-0.7112, SWIR2=-0.4572),
}

band_names = ["Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2"]

for year in years:
    in_file = median_dir / f"CDSE_{year}_median_70.tif"
    out_file = tc_dir / f"tc_CDSE_{year}_median_70.tif"

    # Skip years without an input mosaic or with an existing output.
    if not in_file.exists():
        print(f"‚ùå Missing median mosaic for {year}")
        continue
    if out_file.exists():
        print(f"‚è≠Ô∏è Already exists, skipping {out_file}")
        continue

    print(f"‚úÖ Loading: {in_file}")
    # masked=True makes rioxarray turn the declared nodata value (0) into NaN.
    raw = rioxarray.open_rasterio(in_file, chunks={"x": 1024, "y": 1024}, masked=True)

    # Name the bands and scale the stored integers by 1/10000 (reflectance 0-1).
    # NOTE(review): the original line carried the remark "not needed - 0.1";
    # its meaning is unclear.
    da = raw.assign_coords(band=band_names).astype("float32") / 10000.0

    # Any remaining exact zeros are treated as missing as well (covers medians
    # that were written with fillna(0)).
    da = da.where(da != 0)

    # One 2-D layer per band label.
    planes = {name: da.sel(band=name) for name in band_names}

    # Weighted sum of the six bands for each component, accumulated left to
    # right in band_names order.
    components = []
    for key in ("tcb", "tcg", "tcw"):
        weights = coeffs[key]
        terms = [weights[name] * planes[name] for name in band_names]
        components.append(sum(terms[1:], terms[0]))

    # Stack the three components along a new "band" dimension.
    tc_stack = xr.concat(components, dim="band").assign_coords(band=["TCB", "TCG", "TCW"])
    tc_stack = tc_stack.rio.write_crs(da.rio.crs)

    # Keep NaN as the nodata marker in the float output.
    tc_stack = tc_stack.astype("float32").rio.write_nodata(np.nan)

    print(f"üíæ Saving tasseled cap mosaic: {out_file}")
    with ProgressBar():
        computed = tc_stack.compute(scheduler="threads")
        computed.transpose("band", "y", "x").rio.to_raster(
            out_file,
            driver="GTiff",
            tiled=True,
            compress="deflate",
            BIGTIFF="IF_SAFER",
            predictor=3,
            blockxsize=1024,
            blockysize=1024,
        )

print("‚úÖ All tasseled cap mosaics saved.")


## Trend calculation (fixed visualization range)

In [None]:
# =========================================
# TREND CALCULATION FOR TC STACKS
# =========================================

from pathlib import Path
import numpy as np
import xarray as xr
import rioxarray
from dask.diagnostics import ProgressBar
import dask
import logging

# -----------------------------------------
# CONFIG
# -----------------------------------------
tc_dir     = Path("omc_tc_70")          # input mosaics
trend_dir  = Path("omc_trends")  # output directory
trend_dir.mkdir(exist_ok=True)

years = list(range(2017, 2026))
bands_tc = ["TCB", "TCG", "TCW"]

# -----------------------------------------
# 1. LOAD ALL TASSELED CAP MOSAICS
# -----------------------------------------
arrays = []

for year in years:
    fp = tc_dir / f"tc_CDSE_{year}_median_70.tif"
    if not fp.exists():
        # Missing years are simply left out of the regression.
        print(f"‚ùå Missing {fp}")
        continue

    print(f"‚úÖ Loading {fp}")
    da = rioxarray.open_rasterio(fp, chunks={"x": 1024, "y": 1024})

    # Assign TC band names
    da = da.assign_coords(band=bands_tc)

    # Tag the mosaic with a nominal mid-season date for its year
    da = da.expand_dims(time=[np.datetime64(f"{year}-07-15")])

    arrays.append(da)

if not arrays:
    raise RuntimeError("No tasseled cap mosaics found!")
if len(arrays) < 2:
    # A degree-1 fit through a single point has no defined slope.
    raise RuntimeError("At least two tasseled cap mosaics are required to fit a trend.")

# Concatenate stack; keep the whole time axis in one chunk because the fit
# below runs along it.
stack = xr.concat(arrays, dim="time").transpose("time", "band", "y", "x")
stack = stack.chunk({"time": -1, "x": 1024, "y": 1024})
stack.name = "tc"

print(f"üß© Stack shape: {stack.shape} (time, band, y, x)")

# -----------------------------------------
# 2. FIX THE TIME AXIS FOR REGRESSION
# -----------------------------------------
# Convert datetime64 to integer years so the slope is expressed per year
years_numeric = stack["time"].dt.year

# Replace time dim with 'year'
stack = stack.assign_coords(year=("time", years_numeric.data))
stack = stack.swap_dims({"time": "year"})

print(f"üìÖ Using year values for regression: {list(years_numeric.values)}")

# -----------------------------------------
# 3. TREND REGRESSION (PER YEAR)
# -----------------------------------------
results = []

for band in bands_tc:
    print(f"üìà Computing trend for {band}...")

    sub = stack.sel(band=band)

    # Fit a first-degree polynomial across the 'year' axis
    fit = sub.to_dataset(name="tc").polyfit(dim="year", deg=1)

    # Extract slope (degree 1 coefficient)
    slope = fit["tc_polyfit_coefficients"].sel(degree=1)

    # OPTIONAL:
    # match GEE visualization intensity (your GEE script did "*10")
    slope = slope * 10

    slope = slope.expand_dims(band=[f"{band}_slope"])
    results.append(slope)

# Combine all slope bands
trend = xr.concat(results, dim="band")
trend.rio.write_crs(stack.rio.crs, inplace=True)

# -----------------------------------------
# 4. COMPUTE THE ARRAY
# -----------------------------------------
out_path = trend_dir / "tc_trend_omc1.tif"
print(f"üíæ Saving trend raster: {out_path}")

# Threaded Dask scheduler
dask.config.set(scheduler="threads")
logging.getLogger("tornado.application").setLevel(logging.ERROR)
logging.getLogger("tornado.general").setLevel(logging.ERROR)

with ProgressBar(dt=30.0):
    trend = trend.compute()

# Map the fixed range [-0.3, 0.3] linearly onto 0..255.
trend_vis = trend.clip(-0.3, 0.3)
trend_vis = (trend_vis + 0.3) / 0.6 * 255
# Casting NaN to an integer type is undefined in NumPy, so pixels without a
# slope are pinned to 0 explicitly before the uint8 cast.
trend_vis = trend_vis.fillna(0).astype("uint8")

# Both outputs below contain this same 8-bit visualisation.
# NOTE(review): the unscaled float slopes in `trend` are not written anywhere.
trend_vis_out = trend_vis.transpose("band", "y", "x")
trend_vis_out.rio.to_raster("trend_visual_70_no2024.tif")

# -----------------------------------------
# 5. SAVE TO GEOTIFF
# -----------------------------------------
trend_vis_out.rio.to_raster(
    out_path,
    driver="GTiff",
    tiled=True,
    compress="deflate",
    BIGTIFF="IF_SAFER",
    predictor=2,
    blockxsize=1024,
    blockysize=1024,
)

print("‚úÖ Trend image saved successfully.")

