In [None]:
# --- 0) Install (if needed) ---
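# %pip install torchgeo segmentation-models-pytorch kornia pystac-client planetary-computer rioxarray rasterio pyarrow ipywidgets
# (%pip targets the running kernel's environment; pin versions for reproducible runs. pyarrow backs the Parquet label file.)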

# --- 1) Imports & Config ---
import os, json, tempfile, uuid, math
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output

import pystac_client
import planetary_computer
import rioxarray as rxr
import rasterio
from rasterio.windows import from_bounds
from torchgeo.datasets import RasterDataset, stack_samples
from torchgeo.samplers import GridGeoSampler, RandomGeoSampler
from segmentation_models_pytorch.encoders import get_preprocessing_fn
from segmentation_models_pytorch.unet import Unet
import kornia.augmentation as K


In [None]:

# Config
AOI_GEOJSON = "aoi.geojson"            # GeoJSON FeatureCollection; first feature's polygon in EPSG:4326
DATE_RANGE = "2024-06-01/2024-06-30"   # STAC datetime interval "start/end"
BANDS = ["B02", "B03", "B04", "B08"]   # Sentinel-2 asset keys, in the band order written to OUT_TIF
OUT_TIF = "data/s2_aoi.tif"
LABELS_PARQUET = "data/labels.parquet"  # patch-level labels
PATCH_SIZE = 256                       # pixels
STRIDE = 256                           # pixels between grid patches (== PATCH_SIZE -> no overlap)
CLASSES = ["water", "vegetation", "urban", "bare_soil"]
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42                              # seeds torch (weight init, shuffling, augmentation) and numpy

# segmentation_models_pytorch's U-Net (default encoder depth 5) only accepts inputs whose
# height/width are multiples of 32; fail fast here instead of in the first forward pass.
assert PATCH_SIZE % 32 == 0, f"PATCH_SIZE must be a multiple of 32, got {PATCH_SIZE}"

np.random.seed(SEED)
torch.manual_seed(SEED)

# Create every output directory referenced above rather than a hard-coded "data".
for _out_dir in {os.path.dirname(OUT_TIF), os.path.dirname(LABELS_PARQUET)}:
    if _out_dir:
        os.makedirs(_out_dir, exist_ok=True)

# --- 2) Download Sentinel-2 L2A (least-cloudy) ---
def download_s2(aoi_path, out_path, max_cloud=20, max_items=50):
    """Download the least-cloudy Sentinel-2 L2A scene over an AOI as a multi-band GeoTIFF.

    Searches the Planetary Computer STAC API for ``sentinel-2-l2a`` items that
    intersect the AOI within ``DATE_RANGE`` and have cloud cover below
    ``max_cloud`` percent. It then picks the candidate with the lowest
    ``eo:cloud_cover``, clips each band in ``BANDS`` to the AOI and writes the
    bands (in ``BANDS`` order) to ``out_path``.

    Args:
        aoi_path: GeoJSON FeatureCollection; the geometry of its first feature
            (EPSG:4326) is the AOI.
        out_path: Destination GeoTIFF path. The parent directory is created if needed.
        max_cloud: Exclusive upper bound on scene cloud cover, in percent.
        max_items: Maximum number of candidate scenes fetched before choosing.

    Raises:
        RuntimeError: if no scene matches the search.
        ValueError: if the clipped bands do not share one pixel grid
            (e.g. a 20 m band mixed with 10 m bands).

    NOTE(review): only one scene is used, so an AOI spanning several
    Sentinel-2 tiles will be only partially covered.
    """
    with open(aoi_path) as f:
        geom = json.load(f)["features"][0]["geometry"]
    client = pystac_client.Client.open("https://planetarycomputer.microsoft.com/api/stac/v1")
    search = client.search(
        collections=["sentinel-2-l2a"],
        intersects=geom,
        datetime=DATE_RANGE,
        query={"eo:cloud_cover": {"lt": max_cloud}},
        max_items=max_items,
    )
    items = list(search.items())
    if not items:
        raise RuntimeError(
            f"No sentinel-2-l2a scene with cloud cover < {max_cloud}% "
            f"intersects {aoi_path} in {DATE_RANGE}"
        )
    # Search results are not ordered by cloud cover, so pick the clearest one explicitly.
    item = min(items, key=lambda it: it.properties["eo:cloud_cover"])
    signed = planetary_computer.sign(item)

    # open_rasterio reads ONE file, so open and clip each band separately. The AOI is in
    # EPSG:4326 while the scene has its own projected CRS, hence the explicit `crs`.
    # from_disk=True reads only the AOI window instead of the full tile.
    clipped = [
        rxr.open_rasterio(signed.assets[band].href).rio.clip([geom], crs="EPSG:4326", from_disk=True)
        for band in BANDS
    ]
    ref = clipped[0]
    for band, band_arr in zip(BANDS, clipped):
        if band_arr.shape[-2:] != ref.shape[-2:]:
            raise ValueError(f"Band {band} grid {band_arr.shape[-2:]} differs from {BANDS[0]} {ref.shape[-2:]}")

    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    profile = dict(
        driver="GTiff",
        height=ref.rio.height,
        width=ref.rio.width,
        count=len(clipped),
        dtype=str(ref.dtype),
        crs=ref.rio.crs,
        transform=ref.rio.transform(),
        nodata=ref.rio.nodata,
    )
    with rasterio.open(out_path, "w", **profile) as dst:
        for band_no, (band, band_arr) in enumerate(zip(BANDS, clipped), start=1):
            dst.write(band_arr.values[0], band_no)  # each opened band is (1, H, W)
            dst.set_band_description(band_no, band)
    print(f"Saved {out_path}")

# Download the scene only on the first run; later runs reuse the cached GeoTIFF at OUT_TIF.
if not os.path.exists(OUT_TIF):
    download_s2(aoi_path=AOI_GEOJSON, out_path=OUT_TIF)

# --- 3) TorchGeo datasets ---
class S2Dataset(RasterDataset):
    """TorchGeo raster dataset over the clipped Sentinel-2 GeoTIFF(s) in the data root.

    All bands listed in ``BANDS`` are stored together in one multi-band file
    (``separate_files = False``), and samples are treated as imagery
    (``is_image = True``).
    """

    # File discovery: GeoTIFFs matching this pattern under the dataset root are indexed.
    filename_glob = "*.tif"
    separate_files = False

    # Band metadata: file band order follows BANDS; rgb_bands gives the red/green/blue preview order.
    all_bands = BANDS
    rgb_bands = ["B04", "B03", "B02"]

    is_image = True

# Pass the root directory positionally. TorchGeo < 0.5 calls this parameter `root`;
# TorchGeo >= 0.5 renamed it to `paths`, so either keyword breaks on one of them.
# The directory is derived from OUT_TIF so it always matches where the download is written.
img_ds = S2Dataset(os.path.dirname(OUT_TIF) or ".")

# --- 4) Patch sampler (for unlabeled sampling & later prediction) ---
# Regular grid of PATCH_SIZE x PATCH_SIZE windows, STRIDE apart.
# This sampler yields one bounding box at a time (it is not a batch sampler).
sampler = GridGeoSampler(img_ds, size=PATCH_SIZE, stride=STRIDE)

# --- 5) Simple labeling UI (patch-level class) ---
def load_chip(sample, bands=None, low_pct=2, high_pct=98):
    """Build a contrast-stretched RGB preview of one image chip.

    Args:
        sample: Mapping whose ``"image"`` entry is an array or CPU tensor shaped
            (C, H, W), with the C axis ordered like ``bands``.
        bands: Band names for the C axis. Defaults to ``BANDS``.
        low_pct, high_pct: Percentiles, computed jointly over R, G and B, that
            map to 0 and 1 respectively.

    Returns:
        np.ndarray of shape (H, W, 3) with values in [0, 1]. A chip with no
        dynamic range (e.g. entirely nodata) gives an all-zero image instead of NaNs.
    """
    if bands is None:
        bands = BANDS
    arr = np.asarray(sample["image"], dtype=np.float32)
    rgb_idx = [bands.index(b) for b in ("B04", "B03", "B02")]
    rgb = np.moveaxis(arr[rgb_idx], 0, -1)  # (3, H, W) -> (H, W, 3)
    lo, hi = np.percentile(rgb, [low_pct, high_pct])
    span = hi - lo
    if not span > 0:
        # Flat chip (e.g. a grid cell outside the clipped AOI): avoid 0/0 -> NaN.
        return np.zeros_like(rgb)
    return np.clip((rgb - lo) / span, 0, 1)

def label_session(num_samples=20):
    """Interactively assign a patch-level class to image chips.

    Chips are visited in ``sampler`` (grid) order. Each click on "Save label"
    appends a record and advances to the next chip. The session ends once
    ``num_samples`` labels are saved or the grid is exhausted; the button is
    then disabled.

    The returned list is filled *asynchronously* by the button callback, so
    save/inspect it only after labelling is finished.

    Args:
        num_samples: Number of labels to collect before the session stops.

    Returns:
        list[dict]: One record per label with these keys:
            ``uuid``; ``class`` (from ``CLASSES``);
            ``bounds`` as ``[minx, miny, maxx, maxy]`` in the raster CRS;
            ``crs`` as a string.
    """
    # Index the dataset directly. GridGeoSampler yields one bounding box at a time,
    # so using it as a DataLoader `batch_sampler` would iterate the box's floats.
    samples = (img_ds[bbox] for bbox in sampler)
    records = []
    class_dd = widgets.Dropdown(options=CLASSES, description="Class:")
    save_btn = widgets.Button(description="Save label", button_style="success")
    status = widgets.Label()
    # Output produced inside widget callbacks only reaches the notebook when it is
    # captured by an Output widget.
    chip_out = widgets.Output()
    current_sample = None

    def finish(message):
        save_btn.disabled = True
        status.value = message

    def show_next():
        nonlocal current_sample
        try:
            current_sample = next(samples)
        except StopIteration:
            finish("No more samples")
            return
        rgb = load_chip(current_sample)
        fig, ax = plt.subplots(figsize=(4, 4))
        ax.imshow(rgb)
        ax.axis("off")
        chip_out.clear_output(wait=True)
        with chip_out:
            display(fig)
        plt.close(fig)  # avoid accumulating one open figure per chip

    def on_save(_):
        bbox = current_sample["bbox"]
        records.append({
            "uuid": str(uuid.uuid4()),
            "class": class_dd.value,
            # Plain floats (not a BoundingBox object) so the record serialises to Parquet.
            # Order matches rasterio.windows.from_bounds(left, bottom, right, top).
            "bounds": [float(bbox.minx), float(bbox.miny), float(bbox.maxx), float(bbox.maxy)],
            "crs": str(current_sample["crs"]),
        })
        if len(records) >= num_samples:
            finish(f"Saved {len(records)} / {num_samples} (done)")
            return
        status.value = f"Saved {len(records)} / {num_samples}"
        show_next()

    save_btn.on_click(on_save)
    display(class_dd, save_btn, status, chip_out)
    show_next()
    return records

# Run labeling UI in notebook to collect labels
# labeled = label_session(num_samples=30)

# After labeling, save labels (bbox + class)
def save_labels(records, path=LABELS_PARQUET):
    """Write collected label records to a Parquet file.

    Args:
        records: Sequence of dicts, one per labelled patch (as produced by
            ``label_session``); each dict becomes one row.
        path: Destination ``.parquet`` path.
    """
    pd.DataFrame(data=records).to_parquet(path=path)
    print(f"Saved labels to {path}")

# --- 6) Training dataset using saved labels ---
class PatchLabelDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing labelled patch footprints with image chips.

    For each row of ``labels_df``, the ``"bounds"`` footprint is read from the
    GeoTIFF backing ``raster_ds``. The read is resampled to
    ``patch_size x patch_size``, so every chip has the same shape and batches
    collate cleanly. The chip is randomly flipped and returned as
    ``(image (C, H, W) float32, class index)``.

    Args:
        raster_ds: TorchGeo raster dataset whose (single) file holds the imagery.
        labels_df: DataFrame with ``"bounds"`` and ``"class"`` columns.
            ``"bounds"`` may be ``[minx, miny, maxx, maxy]``, a 6-value TorchGeo
            ``(minx, maxx, miny, maxy, mint, maxt)`` sequence, or an object with
            ``minx/miny/maxx/maxy`` attributes.
        patch_size: Output chip edge length in pixels.
        raster_path: Optional explicit GeoTIFF path. When omitted, it is taken
            from ``raster_ds.files`` (TorchGeo >= 0.5), falling back to ``OUT_TIF``.

    NOTE(review): pixel values are not normalised. ``predict_map`` also feeds raw
    values, so keep both paths in sync if normalisation is added.
    """

    def __init__(self, raster_ds, labels_df, patch_size=PATCH_SIZE, raster_path=None):
        self.raster_ds = raster_ds
        self.labels = labels_df.reset_index(drop=True)
        self.patch_size = patch_size
        self.band_idx = list(range(len(BANDS)))
        self.preproc = None  # could add normalization
        self.raster_path = raster_path or self._resolve_raster_path(raster_ds)
        self.aug = K.AugmentationSequential(
            K.RandomHorizontalFlip(),
            K.RandomVerticalFlip(),
            data_keys=["input"],
        )

    @staticmethod
    def _resolve_raster_path(raster_ds):
        """Return the GeoTIFF path behind ``raster_ds`` (first file, sorted)."""
        files = getattr(raster_ds, "files", None) or getattr(raster_ds, "filepaths", None)
        if files:
            return sorted(files)[0]
        return OUT_TIF

    @staticmethod
    def _to_ltrb(bounds):
        """Normalise a stored footprint to ``(left, bottom, right, top)``."""
        if all(hasattr(bounds, a) for a in ("minx", "miny", "maxx", "maxy")):
            return float(bounds.minx), float(bounds.miny), float(bounds.maxx), float(bounds.maxy)
        values = [float(v) for v in bounds]
        if len(values) == 4:
            return tuple(values)  # already minx, miny, maxx, maxy
        if len(values) == 6:
            minx, maxx, miny, maxy = values[:4]  # TorchGeo BoundingBox field order
            return minx, miny, maxx, maxy
        raise ValueError(f"Unsupported bounds with {len(values)} values: {values}")

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        row = self.labels.iloc[idx]
        left, bottom, right, top = self._to_ltrb(row["bounds"])
        with rasterio.open(self.raster_path) as src:
            window = from_bounds(left, bottom, right, top, transform=src.transform)
            # from_bounds yields fractional windows. A fixed out_shape guarantees
            # identical chip sizes so the default collate can stack them.
            chip = src.read(
                window=window,
                out_shape=(src.count, self.patch_size, self.patch_size),
                boundless=True,
                fill_value=0,
            )
        x = torch.from_numpy(chip.astype(np.float32))
        # Kornia returns batched (B, C, H, W) output. Batch explicitly, then drop the
        # batch dim, so the DataLoader collates to (B, C, H, W) rather than 5-D.
        x = self.aug(x.unsqueeze(0)).squeeze(0)
        y = torch.tensor(CLASSES.index(row["class"]), dtype=torch.long)
        return x, y

# --- 7) Train classifier (patch-level) ---
def train_model(labels_path=LABELS_PARQUET, epochs=5, batch_size=8):
    """Train a patch-level classifier from saved labels.

    A U-Net (ResNet-34 encoder, no pretrained weights) serves as a dense feature
    extractor. Its segmentation head is removed, the final decoder feature map is
    global-average-pooled, and a linear head maps the pooled vector to
    ``CLASSES`` logits. Both parts are trained jointly with Adam (lr 1e-3) and
    cross-entropy.

    Args:
        labels_path: Parquet file written by ``save_labels`` (``bounds`` and ``class`` columns).
        epochs: Number of passes over the labelled patches.
        batch_size: Patches per optimisation step.

    Returns:
        tuple: ``(model, clf_head)``, both on ``DEVICE``, as consumed by ``predict_map``.

    Raises:
        ValueError: if ``labels_path`` contains no labels.
    """
    labels_df = pd.read_parquet(labels_path)
    if labels_df.empty:
        raise ValueError(f"No labels in {labels_path}; collect and save labels before training")
    ds = PatchLabelDataset(img_ds, labels_df)
    dl = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=0)
    model = Unet(encoder_name="resnet34", in_channels=len(BANDS), classes=len(CLASSES), encoder_weights=None)
    # Size the classifier from the head being removed: its first conv consumes the last
    # decoder feature map. With smp's default decoder_channels this is 16, not 64,
    # so a hard-coded width breaks the Linear layer.
    feat_channels = model.segmentation_head[0].in_channels
    model.segmentation_head = nn.Identity()  # model(x) now returns (B, feat_channels, H, W)
    clf_head = nn.Linear(feat_channels, len(CLASSES))

    model.to(DEVICE)
    clf_head.to(DEVICE)
    optimizer = torch.optim.Adam(list(model.parameters()) + list(clf_head.parameters()), lr=1e-3)
    for epoch in range(epochs):
        model.train()
        clf_head.train()
        losses = []
        for x, y in dl:
            x, y = x.to(DEVICE), y.to(DEVICE)
            optimizer.zero_grad()
            feats = model(x)
            pooled = F.adaptive_avg_pool2d(feats, 1).flatten(1)  # (B, C)
            logits = clf_head(pooled)
            loss = F.cross_entropy(logits, y)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        print(f"Epoch {epoch+1}: loss={np.mean(losses):.4f}")
    return model, clf_head

# --- 8) Sliding-window prediction ---
def predict_map(model, clf_head, raster_path=OUT_TIF, out_csv="data/predictions.csv", batch_size=8):
    """Classify every grid patch and write per-patch class probabilities to CSV.

    Patches come from the module-level ``img_ds``/``sampler`` pair, the same grid
    used for labelling.

    Args:
        model: Feature extractor returned by ``train_model`` (already on ``DEVICE``).
        clf_head: Linear head returned by ``train_model``.
        raster_path: Not used; imagery is read through ``img_ds``. Kept for
            interface compatibility.
        out_csv: Destination CSV path.
        batch_size: Patches per forward pass.

    The CSV has one row per patch with these columns:
        ``bbox`` (repr of the box);
        numeric ``minx``, ``miny``, ``maxx``, ``maxy``;
        one ``prob_<class>`` column per entry of ``CLASSES``.
    """
    model.eval()
    clf_head.eval()
    # GridGeoSampler yields single boxes, so it is a `sampler` (plus a batch size),
    # not a `batch_sampler`.
    loader = DataLoader(img_ds, sampler=sampler, batch_size=batch_size, collate_fn=stack_samples)
    preds = []
    with torch.no_grad():
        for batch in loader:
            x = batch["image"].to(DEVICE)
            feats = model(x)
            pooled = F.adaptive_avg_pool2d(feats, 1).flatten(1)  # (B, C)
            logits = clf_head(pooled)
            prob = F.softmax(logits, dim=1).cpu().numpy()
            for p, bbox in zip(prob, batch["bbox"]):
                preds.append({
                    "bbox": bbox,
                    # Numeric footprint columns so the CSV can be re-read without parsing a repr.
                    "minx": bbox.minx,
                    "miny": bbox.miny,
                    "maxx": bbox.maxx,
                    "maxy": bbox.maxy,
                    **{f"prob_{c}": p[i] for i, c in enumerate(CLASSES)},
                })
    pd.DataFrame(preds).to_csv(out_csv, index=False)
    print(f"Saved predictions (per-patch) to {out_csv}")

# --- 9) Usage flow (inside notebook) ---
# 1. labeled = label_session(num_samples=30)
# 2. save_labels(labeled)
# 3. model, head = train_model(epochs=5, batch_size=8)
# 4. predict_map(model, head)
