In [None]:
import os

# =========================================
# IMPORTS
# =========================================
from pathlib import Path

import numpy as np
import rioxarray
import xarray as xr

from utils.download import process_year

## omc (downloading masked images per year)

In [None]:
# =========================================
# CONFIG
# =========================================
# Years to download. range() excludes its stop value, so this is 2025 only;
# an earlier run used range(2017, 2021).
YEARS = list(range(2025, 2026))
MONTH_START_END = ("07-01", "08-31")  # (start, end) as "MM-DD", handed to process_year
GRID = "MGRS-05WMU"
MAX_CLOUD_COVER = 70  # presumably percent -- confirm against utils.download.process_year
BBOX_LL = (-153.5, 70.5, -153, 71)  # presumably (min_lon, min_lat, max_lon, max_lat) -- confirm
OUT_DIR = "CDSE_scenes_masked/coverage70"

# Non-secret S3 settings for the Copernicus Data Space endpoint.
os.environ["AWS_REGION"] = "eu-central-1"
os.environ["AWS_S3_ENDPOINT"] = "eodata.dataspace.copernicus.eu"
os.environ["AWS_VIRTUAL_HOSTING"] = "FALSE"

# Credentials are never written into the notebook: they must already be
# exported in the environment that launched the kernel (or be set with
# getpass in a scratch cell that is not saved). Assigning a placeholder here
# would silently overwrite a correctly exported secret.
_missing_credentials = [
    var_name
    for var_name in ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY")
    if not os.environ.get(var_name)
]
if _missing_credentials:
    print(
        "WARNING: missing S3 credentials in the environment: "
        + ", ".join(_missing_credentials)
        + ". Export them before running the download cell."
    )

In [None]:
# The download target must exist before any scene is written into it.
Path(OUT_DIR).mkdir(parents=True, exist_ok=True)

# =========================================
# PROCESS
# =========================================
# Bands requested from process_year, and the matching human-readable labels
# (same order).
BAND_ORDER = ["B02_10m", "B03_10m", "B04_10m", "B08_10m", "B11_20m", "B12_20m"]
BAND_LABELS = ["Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2"]

# =========================================
# RUN ALL YEARS
# =========================================
# Keyword arguments that are identical for every year.
download_kwargs = dict(
    bbox_ll=BBOX_LL,
    band_order=BAND_ORDER,
    out_dir=OUT_DIR,
    month_start_end=MONTH_START_END,
)

# year -> value returned by process_year (reported below as the scene count).
all_counts = {}
for year in YEARS:
    all_counts[year] = process_year(year, GRID, MAX_CLOUD_COVER, **download_kwargs)

print("\n‚úÖ Done. Scenes written per year:")
for year, scene_count in all_counts.items():
    print(f" ‚Ä¢ {year}: {scene_count} scenes")


## Stack the downloaded scenes into a median composite

In [None]:
# --------------------------------------------
# CONFIG
# --------------------------------------------
TIF_DIR = Path("CDSE_scenes_masked/coverage70/2025")  # Folder with the per-scene TIFFs
# The tasseled cap step reads omc_medians_70/CDSE_<year>_median_70.tif, so the
# mosaic is written there instead of the working directory; otherwise a fresh
# "Run All" reports the mosaic as missing in the next stage.
MEDIAN_DIR = Path("omc_medians_70")
OUT_PATH = MEDIAN_DIR / "CDSE_2025_median_70.tif"  # Output file
BAND_LABELS = ["Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2"]  # Optional

MEDIAN_DIR.mkdir(parents=True, exist_ok=True)

# --------------------------------------------
# LOAD TIFFS
# --------------------------------------------
tif_files = sorted(TIF_DIR.glob("*.tif"))
print(f"üóÇ Found {len(tif_files)} TIFFs")

scenes = []

for tif_path in tif_files:
    # Best effort: an unreadable scene is reported and skipped so that one bad
    # file does not abort the whole composite.
    try:
        # masked=True turns the file's nodata value into NaN; shape: (band, y, x)
        scene = rioxarray.open_rasterio(tif_path, masked=True)
        if "band" not in scene.coords:
            scene = scene.assign_coords(band=range(1, scene.sizes["band"] + 1))

        # Name the bands only when the band count matches the label list.
        if len(BAND_LABELS) == scene.sizes["band"]:
            scene = scene.assign_coords(band=BAND_LABELS)

        # The file name serves as the label along the new "time" dimension.
        scenes.append(scene.expand_dims(time=[tif_path.name]))
        print(f"   ‚úì Loaded {tif_path.name}")
    except Exception as err:
        print(f"   ‚ö†Ô∏è Failed to load {tif_path.name}: {err}")

if not scenes:
    raise RuntimeError("‚ùå No scenes could be loaded.")

# --------------------------------------------
# STACK + MEDIAN
# --------------------------------------------
stack = xr.concat(scenes, dim="time")
print("üìä Stack shape:", stack.shape)

# skipna=True: masked pixels do not contribute; a pixel that is masked in
# every scene stays NaN.
median_img = stack.median(dim="time", skipna=True)

# --------------------------------------------
# SAVE MEDIAN STACK
# --------------------------------------------
# Clamp to 0..10000, then store as uint16 with 0 doubling as the nodata value.
# NOTE: a genuine value of 0 becomes indistinguishable from "no data".
median_img_u16 = (
    median_img.clip(0, 10000).fillna(0).astype("uint16").rio.write_nodata(0)
)

print(f"üíæ Saving median image to {OUT_PATH}")
median_img_u16.rio.to_raster(
    OUT_PATH,
    driver="GTiff",
    compress="deflate",
    tiled=True,
    predictor=2,
    BIGTIFF="IF_SAFER",
    blockxsize=512,
    blockysize=512,
)

print("‚úÖ Done.")


## Calculating Tasseled Cap (TC) images

In [None]:
# === tasseled_cap_mosaic_generation.py ===
# Converts each yearly median mosaic into a three-band tasseled cap raster
# (TCB, TCG, TCW).
import warnings
from pathlib import Path

import numpy as np
import rioxarray
import xarray as xr
from dask.diagnostics import ProgressBar

warnings.filterwarnings(
    "ignore", category=UserWarning, message=".*coordinate precision.*"
)

median_dir = Path("omc_medians_70")
tc_dir = Path("omc_tc_70")
tc_dir.mkdir(exist_ok=True)
years = list(range(2025, 2026))

# Names assigned to the six bands of the median mosaic, in file order.
band_names = ["Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2"]

# Tasseled cap coefficients. NOTE(review): these values are the Crist (1985)
# Landsat TM reflectance-factor coefficients, applied here to the matching
# Sentinel-2 bands; they are not a Sentinel-2-specific derivation.
coeffs = {
    "tcb": dict(
        Blue=0.3037, Green=0.2793, Red=0.4743, NIR=0.5585, SWIR1=0.5082, SWIR2=0.1863
    ),
    "tcg": dict(
        Blue=-0.2848,
        Green=-0.2435,
        Red=-0.5436,
        NIR=0.7243,
        SWIR1=0.0840,
        SWIR2=-0.1800,
    ),
    "tcw": dict(
        Blue=0.1509, Green=0.1973, Red=0.3279, NIR=0.3406, SWIR1=-0.7112, SWIR2=-0.4572
    ),
}


def tasseled_cap(reflectance, weights):
    """Return the weighted sum of the named bands of ``reflectance``.

    Parameters
    ----------
    reflectance : xarray.DataArray
        Array whose ``band`` coordinate contains every key of ``weights``.
    weights : dict
        Mapping of band name to coefficient.

    Returns
    -------
    xarray.DataArray
        Sum of ``weights[name] * reflectance.sel(band=name)`` over all names,
        without a ``band`` dimension. NaN in any input band propagates.
    """
    return sum(
        weight * reflectance.sel(band=name, drop=True)
        for name, weight in weights.items()
    )


for year in years:
    in_file = median_dir / f"CDSE_{year}_median_70.tif"
    out_file = tc_dir / f"tc_CDSE_{year}_median_70.tif"

    if not in_file.exists():
        print(f"‚ùå Missing median mosaic for {year}")
        continue
    if out_file.exists():
        print(f"‚è≠Ô∏è Already exists, skipping {out_file}")
        continue

    print(f"‚úÖ Loading: {in_file}")
    # Important: masked=True makes rioxarray treat nodata (0) as NaN
    da = rioxarray.open_rasterio(in_file, chunks={"x": 1024, "y": 1024}, masked=True)

    # Assign band names and scale the stored integers to 0-1 reflectance.
    da = da.assign_coords(band=band_names).astype("float32") / 10000.0

    # Ensure true zeros are NaN (in case old medians used fillna(0))
    da = da.where(da != 0)

    # One weighted sum per component, stacked along a new "band" dimension.
    tc_stack = xr.concat(
        [tasseled_cap(da, coeffs[key]) for key in ("tcb", "tcg", "tcw")],
        dim="band",
    )
    tc_stack = tc_stack.assign_coords(band=["TCB", "TCG", "TCW"])
    tc_stack = tc_stack.rio.write_crs(da.rio.crs)

    # Ensure NaNs are preserved
    tc_stack = tc_stack.astype("float32").rio.write_nodata(np.nan)

    print(f"üíæ Saving tasseled cap mosaic: {out_file}")
    # Write to a temporary name and rename afterwards: an interrupted run must
    # not leave a partial file that the "already exists" check above would
    # later accept as complete.
    part_file = out_file.with_name(out_file.name + ".part")
    with ProgressBar():
        (
            tc_stack.compute(scheduler="threads")
            .transpose("band", "y", "x")
            .rio.to_raster(
                part_file,
                driver="GTiff",
                tiled=True,
                compress="deflate",
                BIGTIFF="IF_SAFER",
                predictor=3,  # floating-point predictor for float32 data
                blockxsize=1024,
                blockysize=1024,
            )
        )
    part_file.replace(out_file)

print("‚úÖ All tasseled cap mosaics saved.")


## Trend Calculation (fixed vis)

In [None]:
# =========================================
# TREND CALCULATION FOR TC STACKS
# =========================================
# Fits a straight line through the yearly tasseled cap mosaics for every pixel
# and writes the scaled slope as an 8-bit raster.

import logging
from pathlib import Path

import dask
import numpy as np
import rioxarray
import xarray as xr
from dask.diagnostics import ProgressBar

# -----------------------------------------
# CONFIG
# -----------------------------------------
tc_dir = Path("omc_tc_70")  # input mosaics
trend_dir = Path("omc_trends")  # output directory
trend_dir.mkdir(exist_ok=True)

years = list(range(2017, 2026))
bands_tc = ["TCB", "TCG", "TCW"]

# Slopes are clipped to [-VIS_LIMIT, +VIS_LIMIT] before scaling to 0..255.
VIS_LIMIT = 0.3

# -----------------------------------------
# 1. LOAD ALL TASSELED CAP MOSAICS
# -----------------------------------------
arrays = []

for year in years:
    fp = tc_dir / f"tc_CDSE_{year}_median_70.tif"
    if not fp.exists():
        print(f"‚ùå Missing {fp}")
        continue

    print(f"‚úÖ Loading {fp}")
    da = rioxarray.open_rasterio(fp, chunks={"x": 1024, "y": 1024})

    # Assign TC band names
    da = da.assign_coords(band=bands_tc)

    # Tag the mosaic with a nominal mid-July date of its year.
    da = da.expand_dims(time=[np.datetime64(f"{year}-07-15")])

    arrays.append(da)

if not arrays:
    raise RuntimeError("No tasseled cap mosaics found!")
if len(arrays) < 2:
    # A first-degree fit through a single point has no defined slope.
    raise RuntimeError(
        f"Need at least 2 tasseled cap mosaics for a trend, found {len(arrays)}."
    )

# Concatenate stack
stack = xr.concat(arrays, dim="time").transpose("time", "band", "y", "x")
# Keep the full time series of each pixel inside one chunk for the fit.
stack = stack.chunk({"time": -1, "x": 1024, "y": 1024})
stack.name = "tc"

print(f"üß© Stack shape: {stack.shape} (time, band, y, x)")

# -----------------------------------------
# 2. FIX THE TIME AXIS FOR REGRESSION
# -----------------------------------------
# Convert datetime64 to integer years so the slope is expressed per year.
years_numeric = stack["time"].dt.year

# Replace time dim with 'year'
stack = stack.assign_coords(year=("time", years_numeric.data))
stack = stack.swap_dims({"time": "year"})

print(f"üìÖ Using year values for regression: {list(years_numeric.values)}")

# -----------------------------------------
# 3. TREND REGRESSION (PER YEAR)
# -----------------------------------------
results = []

for band in bands_tc:
    print(f"üìà Computing trend for {band}...")

    sub = stack.sel(band=band)

    # Fit a first-degree polynomial across the 'year' axis
    fit = sub.to_dataset(name="tc").polyfit(dim="year", deg=1)

    # Extract slope (degree 1 coefficient)
    slope = fit["tc_polyfit_coefficients"].sel(degree=1)

    # OPTIONAL: match the GEE visualization intensity (the GEE script did "*10")
    slope = slope * 10

    slope = slope.expand_dims(band=[f"{band}_slope"])
    results.append(slope)

# Combine all slope bands
trend = xr.concat(results, dim="band")
trend.rio.write_crs(stack.rio.crs, inplace=True)

# -----------------------------------------
# 4. COMPUTE THE ARRAY
# -----------------------------------------
out_path = trend_dir / "tc_trend_omc1.tif"
print(f"üíæ Saving trend raster: {out_path}")

logging.getLogger("tornado.application").setLevel(logging.ERROR)
logging.getLogger("tornado.general").setLevel(logging.ERROR)

# Select the threaded scheduler for this computation only, instead of changing
# the dask configuration for the rest of the kernel session.
with dask.config.set(scheduler="threads"), ProgressBar(dt=30.0):
    trend = trend.compute()

# Scale the clipped slope linearly: -VIS_LIMIT -> 0, +VIS_LIMIT -> 255.
# Pixels without a slope (NaN) are set to 0 explicitly, because casting NaN to
# uint8 is undefined. NOTE: 0 is therefore shared by "no data" and "slope at or
# below -VIS_LIMIT".
trend_clipped = trend.clip(-VIS_LIMIT, VIS_LIMIT)
trend_vis = (
    ((trend_clipped + VIS_LIMIT) / (2 * VIS_LIMIT) * 255).fillna(0).astype("uint8")
)
# NOTE(review): the file name says "no2024" although `years` includes 2024.
trend_vis.transpose("band", "y", "x").rio.to_raster("trend_visual_70_no2024.tif")

# -----------------------------------------
# 5. SAVE TO GEOTIFF
# -----------------------------------------
# NOTE(review): this writes the same 8-bit visualisation a second time; the
# float slopes in `trend` are never saved. Confirm whether `trend` was meant
# here (it would need predictor=3).
trend_vis.transpose("band", "y", "x").rio.to_raster(
    out_path,
    driver="GTiff",
    tiled=True,
    compress="deflate",
    BIGTIFF="IF_SAFER",
    predictor=2,
    blockxsize=1024,
    blockysize=1024,
)

print("‚úÖ Trend image saved successfully.")
