## 1) Problem, SDG alignment, limitations, scalability

**Problem:** Detect recent tree-cover loss / woodland removal events in the UK using Sentinel‑2 optical imagery.

**SDG alignment:**
- SDG 15 (Life on Land): monitoring forest and woodland loss.
- SDG 13 (Climate Action): quantifying woodland loss and its impact on carbon sinks.

**Limitations & ethical considerations:**
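- GFC labels come from a global model and can be noisy in UK temperate landscapes. For example, rotational clear-felling in commercial plantations is recorded as loss, and some non-forest land uses may be mislabelled (label noise).
- Optical sensors cannot see through cloud, so cloud masking (e.g., with the L2A Scene Classification Layer) or multi-date compositing is recommended.
- False positives and false negatives have socio-economic consequences for landowners and for policy. Report uncertainty alongside predictions, and do not base enforcement actions solely on model output.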
- Ensure attribution to data providers and follow licenses for GFC and Sentinel data.

**Scalability & sustainability:**
- The pipeline can scale horizontally by processing tiles in parallel (Dask, Spark).
- For continual monitoring, deploy lightweight models with periodic retraining using new labels; consider model quantization for edge or serverless deployments.
- Use cloud-hosted STAC and object storage for large-scale ingestion; keep preprocessing simple (patching) to reduce compute.
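The per-tile parallelism described above can be sketched with the standard library before moving to Dask or Spark. In this minimal sketch, `process_tile` and `run_tiles_in_parallel` are hypothetical names. `process_tile` stands in for the real read, align and extract steps and only builds a toy mask. Threads are used on the assumption that the real work is dominated by I/O and GIL-releasing GDAL/NumPy calls. For CPU-heavy pure-Python steps, a process pool or Dask is the more appropriate choice.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def process_tile(tile_id: int) -> tuple:
    """Hypothetical per-tile job standing in for read -> align labels -> extract patches.
    Here it only builds a reproducible toy mask and counts its 'loss' pixels."""
    rng = np.random.default_rng(tile_id)
    toy_mask = rng.random((256, 256)) < 0.05
    return tile_id, int(toy_mask.sum())


def run_tiles_in_parallel(tile_ids, max_workers=4):
    """Run process_tile over many tiles concurrently and return {tile_id: result}."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return dict(pool.map(process_tile, tile_ids))


results = run_tiles_in_parallel(range(8))
print(results)
```

Because each tile is independent, the same function can later be mapped over a Dask bag or Spark RDD without changing its body.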

## 2) Dataset selection and access

Inputs (image): Sentinel‑2 L2A bands B02 (blue), B03 (green), B04 (red) and B08 (near-infrared), all at 10 m native resolution.

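Labels (ground truth): Hansen Global Forest Change (GFC) `lossyear` raster (or a binary `loss` layer) at roughly 30 m (1 arc-second) resolution. In `lossyear`, 0 means no loss and a value k means loss was detected in year 2000 + k. We derive a binary mask marking loss within a selected time window.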
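Because `lossyear` stores codes rather than calendar years, a window such as 2015–2024 has to be converted to codes 15–24 before thresholding. The sketch below shows that conversion. `lossyear_to_mask` and `GFC_BASE_YEAR` are illustrative names, not part of any GFC tooling.

```python
import numpy as np

GFC_BASE_YEAR = 2000  # lossyear code k means loss in year 2000 + k; 0 means no loss


def lossyear_to_mask(lossyear, year_min, year_max):
    """Binary mask (uint8) of pixels whose loss year falls in [year_min, year_max]."""
    code_min = year_min - GFC_BASE_YEAR
    code_max = year_max - GFC_BASE_YEAR
    if code_min < 1 or code_max < code_min:
        raise ValueError("expected 2001 <= year_min <= year_max")
    return ((lossyear >= code_min) & (lossyear <= code_max)).astype(np.uint8)


lossyear = np.array([[0, 5, 15], [19, 24, 3]], dtype=np.uint8)
print(lossyear_to_mask(lossyear, 2015, 2024))
# [[0 0 1]
#  [1 1 0]]
```

Comparing raw codes against calendar years (e.g., `data >= 2015`) would silently produce an all-zero mask. That is why the conversion is made explicit.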

Data access options:
- Sentinel‑2: the AWS public (requester-pays) bucket `s3://sentinel-s2-l2a`, the Copernicus Data Space Ecosystem (the successor to the Copernicus Open Access Hub), or a STAC API. This notebook assumes the relevant Sentinel‑2 tiles are available locally (download one sample tile and place it in `data/sentinel/`).
- GFC: download region subset from Global Forest Watch or Google Cloud public bucket and place in `data/gfc/`.

Rationale: Sentinel‑2 provides the same four bands used by the 4‑band model in the original paper, and GFC provides globally consistent loss labels.

## 3) Preprocessing pipeline — approach summary

Goals:
- Read Sentinel‑2 4‑band images and normalize to 0–1.
- Read GFC loss raster, reproject and resample to Sentinel grid (recommended: resample labels to 10 m using nearest-neighbour to avoid smoothing).
- Create binary mask for loss within a chosen year range (e.g., 2015–2024).
- Cut 512×512 patches into matching image/mask pairs and save them as NumPy arrays with the shapes used in the repository (512×512×4 images and 512×512×1 masks).
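To see what non-overlapping 512×512 patching yields, assume a full Sentinel‑2 L2A tile of 10980×10980 pixels at 10 m. The helper `patch_origins` below is a hypothetical name. It enumerates the same window origins as a `range(0, H - patch_size + 1, stride)` loop.

```python
def patch_origins(height, width, patch_size=512, stride=512):
    """Top-left (y, x) of every window that fits entirely inside the image."""
    ys = range(0, height - patch_size + 1, stride)
    xs = range(0, width - patch_size + 1, stride)
    return [(y, x) for y in ys for x in xs]


origins = patch_origins(10980, 10980)
print(len(origins))        # 441 (21 x 21 windows)
print(10980 - 21 * 512)    # 228 pixels on the right and bottom edges are never covered
```

If those edge strips matter for the AOI, either pad the tile or add a final window aligned to the far edge.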

Notes: The code below expects local files; replace paths with your downloaded tile/label files.

In [None]:

import os
import numpy as np
import rioxarray as rxr
import rasterio
from rasterio.enums import Resampling
from rasterio.warp import reproject

GFC_BASE_YEAR = 2000  # GFC lossyear code k means loss in year 2000 + k; 0 means no loss

def read_sentinel_tile(path_b02, path_b03, path_b04, path_b08):
    """Read 4 bands and stack into an (H, W, 4) float32 array scaled to 0-1.
    Also returns the transform and CRS of the shared 10 m grid."""
    bands = []
    for p in [path_b02, path_b03, path_b04, path_b08]:
        da = rxr.open_rasterio(p, masked=True)  # nodata -> NaN
        bands.append(da.squeeze().values.astype('float32'))
    img = np.stack(bands, axis=-1)
    # simple per-tile min/max scaling; NaN-aware so nodata pixels do not turn the whole tile into NaN
    lo, hi = np.nanmin(img), np.nanmax(img)
    img = (img - lo) / (hi - lo + 1e-9)
    img = np.nan_to_num(img, nan=0.0)  # remaining nodata pixels -> 0
    # all four bands share the same 10 m grid, so the last band's georeferencing applies to the stack
    return img, da.rio.transform(), da.rio.crs

def read_gfc_mask(path_gfc, target_crs, target_transform, target_shape, year_min=None, year_max=None):
    """Read a GFC lossyear raster, threshold it to a binary loss mask, and reproject it
    onto the Sentinel-2 grid. target_shape is (H, W); year_min/year_max are calendar years."""
    with rasterio.open(path_gfc) as src:
        # Reads the whole raster (careful with large files); subset to your AOI first.
        data = src.read(1)
        src_transform, src_crs = src.transform, src.crs
    if year_min is not None or year_max is not None:
        # Convert calendar years (e.g., 2015) to lossyear codes (e.g., 15); code 0 means no loss.
        code_min = max(1, (year_min or 2001) - GFC_BASE_YEAR)
        code_max = (year_max or 9999) - GFC_BASE_YEAR
        mask = ((data >= code_min) & (data <= code_max)).astype(np.uint8)
    else:
        mask = (data > 0).astype(np.uint8)  # loss in any year (also works for a binary 'loss' layer)

    # Reproject/resample to the target grid; nearest-neighbour keeps labels strictly binary
    dst = np.zeros(target_shape, dtype=np.uint8)
    reproject(
        source=mask,
        destination=dst,
        src_transform=src_transform,
        src_crs=src_crs,
        dst_transform=target_transform,
        dst_crs=target_crs,
        resampling=Resampling.nearest,
    )
    return dst
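`read_sentinel_tile` scales all four bands with a single tile-wide min/max. A few bright pixels (cloud, bare rock) then compress everything else, and the brightest band dominates. A common alternative is a per-band percentile stretch. The sketch below assumes 2nd/98th percentiles, which are a tunable choice rather than something the original specifies. `percentile_stretch` is a hypothetical helper name.

```python
import numpy as np


def percentile_stretch(img, low=2.0, high=98.0):
    """Per-band percentile stretch of an (H, W, C) array to [0, 1].
    NaN (nodata) pixels are ignored when computing percentiles and set to 0 in the output."""
    out = np.zeros(img.shape, dtype=np.float32)
    for b in range(img.shape[-1]):
        band = img[..., b].astype(np.float32)
        lo, hi = np.nanpercentile(band, [low, high])
        scaled = (band - lo) / (hi - lo + 1e-9)
        out[..., b] = np.nan_to_num(np.clip(scaled, 0.0, 1.0), nan=0.0)
    return out


demo = np.arange(100, dtype=np.float64).reshape(10, 10, 1)
demo[0, 0, 0] = np.nan
stretched = percentile_stretch(demo)
print(stretched.min(), stretched.max())
```

Any per-tile stretch makes values incomparable across tiles and dates. If the model is applied over many acquisitions, a fixed scaling of L2A reflectance (quantification value 10000) may generalise better.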

### Patch extraction function — creates 512×512 patches and saves numpy arrays

In [None]:
def extract_patches(image, mask, out_dir, prefix='uk', patch_size=512, stride=512,
                    min_mask_coverage=0.001, save_all=True):
    """Slide a window across the image and save image/mask patch pairs as .npy files.
    image: (H, W, C) float array; mask: (H, W) binary array.
    By default every patch is saved; set save_all=False to keep only patches whose
    loss fraction is at least min_mask_coverage."""
    os.makedirs(out_dir, exist_ok=True)
    H, W = image.shape[:2]
    idx = 0
    for y in range(0, H - patch_size + 1, stride):
        for x in range(0, W - patch_size + 1, stride):
            img_patch = image[y:y+patch_size, x:x+patch_size, :]
            m_patch = mask[y:y+patch_size, x:x+patch_size]
            coverage = m_patch.mean()  # fraction of loss pixels in this patch
            if save_all or coverage >= min_mask_coverage:
                # zero-padded index keeps image and mask files in the same order under sorted()
                np.save(os.path.join(out_dir, f"{prefix}_img_{idx:05d}.npy"), img_patch.astype('float32'))
                np.save(os.path.join(out_dir, f"{prefix}_mask_{idx:05d}.npy"), m_patch[..., None].astype('uint8'))
                idx += 1
    print(f'Wrote {idx} patches to {out_dir}')
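`extract_patches` only visits windows that fit entirely inside the image, so the right and bottom margins are dropped. One way to keep them is to zero-pad the image and mask up to a multiple of the patch size before patching. A `valid` mask then lets padded pixels be excluded from the loss or metrics later. This is a sketch under that assumption, and `pad_to_multiple` is a hypothetical name.

```python
import numpy as np


def pad_to_multiple(image, mask, patch_size=512):
    """Zero-pad (H, W, C) image and (H, W) mask so H and W become multiples of patch_size.
    Returns the padded arrays plus a uint8 'valid' mask (1 = real pixel, 0 = padding)."""
    H, W = image.shape[:2]
    pad_h = (-H) % patch_size
    pad_w = (-W) % patch_size
    image_p = np.pad(image, ((0, pad_h), (0, pad_w), (0, 0)), mode="constant")
    mask_p = np.pad(mask, ((0, pad_h), (0, pad_w)), mode="constant")
    valid = np.pad(np.ones((H, W), dtype=np.uint8), ((0, pad_h), (0, pad_w)), mode="constant")
    return image_p, mask_p, valid


img_p, mask_p, valid = pad_to_multiple(np.ones((1000, 700, 4), np.float32), np.zeros((1000, 700), np.uint8))
print(img_p.shape, int(valid.sum()))  # (1024, 1024, 4) 700000
```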

## 4) Model adaptation — Attention U‑Net for 4‑band input

We adapt the Attention U‑Net to accept 4 input channels (Sentinel‑2 B02, B03, B04, B08). The architecture and attention gates are the same as the original, with input shape (512, 512, 4).
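To make the additive attention gate concrete, the NumPy sketch below computes it for a single feature map. The 1×1 convolutions are written as per-pixel matrix products. The sketch assumes the gating signal has already been brought to the skip connection's resolution, and it omits biases for brevity. `attention_gate_np` and the weight names are illustrative, not the Keras layer weights.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def attention_gate_np(x, g, W_x, W_g, w_psi):
    """Additive attention gate on one feature map.
    x: (H, W, Cx) skip features; g: (H, W, Cg) gating features at the same resolution.
    W_x: (Cx, F), W_g: (Cg, F), w_psi: (F,). A 1x1 convolution is a per-pixel matmul."""
    q = np.maximum(x @ W_x + g @ W_g, 0.0)   # ReLU(theta_x + phi_g), shape (H, W, F)
    alpha = sigmoid(q @ w_psi)[..., None]    # (H, W, 1) attention coefficients in (0, 1)
    return x * alpha, alpha                  # alpha broadcasts across the channels of x


rng = np.random.default_rng(0)
x = rng.normal(size=(8, 8, 16))
g = rng.normal(size=(8, 8, 32))
W_x, W_g, w_psi = 0.1 * rng.normal(size=(16, 4)), 0.1 * rng.normal(size=(32, 4)), 0.1 * rng.normal(size=4)
out, alpha = attention_gate_np(x, g, W_x, W_g, w_psi)
print(out.shape, alpha.shape)  # (8, 8, 16) (8, 8, 1)
```

Each skip-connection pixel is scaled by a single coefficient, so the gate can suppress irrelevant regions without changing the channel count passed to the decoder.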

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, MaxPooling2D, UpSampling2D, concatenate, Activation, Multiply, Add
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
import tensorflow.keras.backend as K

def conv_block(x, filters, kernel_size=3, activation='relu'):
    x = Conv2D(filters, kernel_size, padding='same', kernel_initializer='he_normal')(x)
    x = Activation(activation)(x)
    x = Conv2D(filters, kernel_size, padding='same', kernel_initializer='he_normal')(x)
    x = Activation(activation)(x)
    return x

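def attention_gate(x, g, inter_channels):
    """Additive attention gate. x: skip connection at resolution H; g: gating signal at H/2.
    x is downsampled with a strided conv so it matches g, and the attention
    coefficients are upsampled back to H before weighting x."""
    theta_x = Conv2D(inter_channels, 2, strides=2, padding='same')(x)   # H -> H/2
    phi_g = Conv2D(inter_channels, 1, strides=1, padding='same')(g)     # already H/2
    add_xg = Activation('relu')(Add()([theta_x, phi_g]))
    psi = Conv2D(1, 1, padding='same')(add_xg)
    psi = Activation('sigmoid')(psi)
    psi = UpSampling2D(size=2)(psi)   # H/2 -> H, one coefficient per pixel
    return Multiply()([x, psi])       # psi (1 channel) broadcasts across the channels of x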

def build_attention_unet(input_shape=(512,512,4), base_filters=16, lr=5e-4):
    inputs = Input(shape=input_shape)
    # encoder
    c1 = conv_block(inputs, base_filters)
    p1 = MaxPooling2D()(c1)
    c2 = conv_block(p1, base_filters*2)
    p2 = MaxPooling2D()(c2)
    c3 = conv_block(p2, base_filters*4)
    p3 = MaxPooling2D()(c3)
    c4 = conv_block(p3, base_filters*8)
    p4 = MaxPooling2D()(c4)
    c5 = conv_block(p4, base_filters*16)
    # decoder with attention gates
    u4 = Conv2DTranspose(base_filters*8, 2, strides=2, padding='same')(c5)
    att4 = attention_gate(c4, c5, base_filters*8)
    m4 = concatenate([u4, att4])
    c6 = conv_block(m4, base_filters*8)
    u3 = Conv2DTranspose(base_filters*4, 2, strides=2, padding='same')(c6)
    att3 = attention_gate(c3, c6, base_filters*4)
    m3 = concatenate([u3, att3])
    c7 = conv_block(m3, base_filters*4)
    u2 = Conv2DTranspose(base_filters*2, 2, strides=2, padding='same')(c7)
    att2 = attention_gate(c2, c7, base_filters*2)
    m2 = concatenate([u2, att2])
    c8 = conv_block(m2, base_filters*2)
    u1 = Conv2DTranspose(base_filters, 2, strides=2, padding='same')(c8)
    att1 = attention_gate(c1, c8, base_filters)
    m1 = concatenate([u1, att1])
    c9 = conv_block(m1, base_filters)
    outputs = Conv2D(1, 1, activation='sigmoid')(c9)
    model = Model(inputs, outputs)
    model.compile(optimizer=Adam(lr), loss='binary_crossentropy', metrics=['accuracy'])
    return model

# Example build
model = build_attention_unet(input_shape=(512,512,4))
model.summary()

## 5) Training routine and hyperparameter tuning notes

Notes:
- Address class imbalance (loss pixels are rare) with a weighted loss, using class weights computed from the training patches, or with a focal loss.
- Suggested hyperparameter sweep: learning rate (5e-5, 1e-4, 5e-4), batch size (4,8), patch_size (256 or 512), augmentation strength.
- Use early stopping and ModelCheckpoint to keep best validation model.
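The two imbalance strategies in the notes above can be written out as plain formulas. This NumPy sketch is for illustration and checking. To train with them, the same expressions would need porting to TensorFlow ops, or you could use a built-in such as `tf.keras.losses.BinaryFocalCrossentropy` in recent TensorFlow releases. The helper names and the defaults `gamma=2.0`, `alpha=0.25` are assumptions, not values from the original.

```python
import numpy as np


def positive_class_weight(masks):
    """Negative-to-positive pixel ratio, a common pos_weight for weighted BCE."""
    pos = float(np.sum(masks))
    neg = float(np.size(masks)) - pos
    return neg / max(pos, 1.0)


def weighted_bce(y_true, y_prob, pos_weight, eps=1e-7):
    """Binary cross-entropy with the positive (loss) class up-weighted by pos_weight."""
    y = np.asarray(y_true, dtype=np.float64)
    p = np.clip(y_prob, eps, 1 - eps)
    return float(np.mean(-(pos_weight * y * np.log(p) + (1 - y) * np.log(1 - p))))


def binary_focal_loss(y_true, y_prob, gamma=2.0, alpha=0.25, eps=1e-7):
    """Focal loss: down-weights easy pixels by (1 - p_t) ** gamma."""
    y = np.asarray(y_true)
    p = np.clip(y_prob, eps, 1 - eps)
    p_t = np.where(y == 1, p, 1 - p)
    alpha_t = np.where(y == 1, alpha, 1 - alpha)
    return float(np.mean(-alpha_t * (1 - p_t) ** gamma * np.log(p_t)))


masks = np.zeros((1, 10, 10, 1), dtype=np.uint8)
masks[0, 0, 0, 0] = 1
print(positive_class_weight(masks))  # 99.0
```

With `gamma=0` and `alpha=0.5` the focal loss reduces to half the unweighted cross-entropy, which is a quick sanity check when porting it.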

The training cell below uses Keras `fit` with data loaded from saved `.npy` patch files.

In [None]:
import glob
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

def load_patches(patch_dir, n=None):
    imgs = sorted(glob.glob(os.path.join(patch_dir, '*_img_*.npy')))
    masks = sorted(glob.glob(os.path.join(patch_dir, '*_mask_*.npy')))
    if n is not None:
        imgs = imgs[:n]; masks = masks[:n]
    X = [np.load(p) for p in imgs]
    y = [np.load(p) for p in masks]
    X = np.stack(X).astype('float32')
    y = np.stack(y).astype('uint8')
    return X, y

# Example: load small subset for quick test
PATCH_DIR = 'data/patches'
if os.path.exists(PATCH_DIR):
    X, y = load_patches(PATCH_DIR, n=100)  # load first 100 pairs for quick run
    print('X', X.shape, 'y', y.shape)
    # random split; adjacent patches are spatially correlated, so this can overestimate performance
    from sklearn.model_selection import train_test_split
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
    model = build_attention_unet(input_shape=(512,512,4), base_filters=16, lr=5e-4)
    # .keras is the native Keras format (legacy .h5 checkpoints are deprecated in recent Keras versions)
    cb = [ModelCheckpoint('uk_unet_att.keras', monitor='val_loss', save_best_only=True), EarlyStopping(monitor='val_loss', patience=8, restore_best_weights=True)]
    # quick training for demonstration; set epochs higher for real runs
    model.fit(X_train, y_train, validation_data=(X_val, y_val), epochs=20, batch_size=4, callbacks=cb)
else:
    print('No patches found in', PATCH_DIR)
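The random `train_test_split` above can place neighbouring patches (often the same clear-fell) in both training and validation. A spatially blocked split avoids this by assigning whole blocks of patches to one side. The sketch assumes each patch's top-left pixel origin has been recorded. `extract_patches` does not store origins, so this is an assumption. `block_split` and `block_size=2048` are hypothetical choices.

```python
import numpy as np


def block_split(origins, block_size=2048, val_fraction=0.2, seed=42):
    """Split patches into train/val by spatial block so neighbouring patches share a split.
    origins: sequence of (y, x) pixel origins, one per patch. Returns (train_idx, val_idx)."""
    origins = np.asarray(origins)
    block_ids = origins // block_size                      # (N, 2) block row/col per patch
    unique_blocks = np.unique(block_ids, axis=0)
    rng = np.random.default_rng(seed)
    n_val = max(1, int(round(val_fraction * len(unique_blocks))))
    val_blocks = unique_blocks[rng.choice(len(unique_blocks), n_val, replace=False)]
    # a patch is in validation if its block matches any chosen validation block
    is_val = (block_ids[:, None, :] == val_blocks[None, :, :]).all(axis=-1).any(axis=-1)
    return np.where(~is_val)[0], np.where(is_val)[0]


origins = [(y, x) for y in range(0, 4096, 512) for x in range(0, 4096, 512)]  # 8 x 8 patches
train_idx, val_idx = block_split(origins)
print(len(train_idx), len(val_idx))  # 48 16
```

Larger blocks give a stricter test of spatial generalisation at the cost of a noisier validation estimate.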

## 6) Evaluation and statistical testing

We compute IoU (Jaccard), Dice, Precision, Recall and F1 on the validation/test set. To test whether differences from a baseline are statistically significant, we bootstrap-resample the paired per-patch IoU scores of both models to obtain a 95% confidence interval and an approximate one-sided p-value.
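For reference, the per-patch metrics are defined below, with TP, FP and FN counted over the pixels of one patch. For binary masks, F1 and Dice are the same quantity. The F1 and Dice arrays returned by `evaluate_model` should therefore agree, up to the epsilon terms and the handling of patches where both masks are empty.

$$
\mathrm{IoU} = \frac{TP}{TP + FP + FN}, \qquad
\mathrm{Dice} = \frac{2\,TP}{2\,TP + FP + FN} = \frac{2\,\mathrm{IoU}}{1 + \mathrm{IoU}}
$$

$$
P = \frac{TP}{TP + FP}, \qquad
R = \frac{TP}{TP + FN}, \qquad
F_1 = \frac{2PR}{P + R} = \frac{2\,TP}{2\,TP + FP + FN} = \mathrm{Dice}
$$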

In [None]:
from sklearn.metrics import precision_score, recall_score

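def iou_score(y_true, y_pred, eps=1e-7):
    """IoU of one binary mask vs. predicted probabilities (thresholded at 0.5)."""
    y_true = y_true.flatten().astype(bool)
    y_pred = y_pred.flatten() > 0.5
    inter = (y_true & y_pred).sum()
    union = (y_true | y_pred).sum()
    if union == 0:
        return 1.0  # both masks empty: perfect agreement
    return inter / (union + eps)

def dice_score(y_true, y_pred, eps=1e-7):
    """Dice coefficient (equal to F1 for binary masks)."""
    y_true = y_true.flatten().astype(bool)
    y_pred = y_pred.flatten() > 0.5
    inter = (y_true & y_pred).sum()
    denom = y_true.sum() + y_pred.sum()
    if denom == 0:
        return 1.0  # both masks empty; consistent with iou_score
    return (2 * inter) / (denom + eps)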

def evaluate_model(model, X, y):
    ious = []; dices = []; precisions = []; recalls = []; f1s = []
    preds = model.predict(X, batch_size=4)
    for i in range(len(X)):
        gt = y[i].reshape(-1)
        pr = preds[i].reshape(-1)
        iou = iou_score(gt, pr)
        dice = dice_score(gt, pr)
        p = precision_score(gt, (pr>0.5).astype(int), zero_division=0)
        r = recall_score(gt, (pr>0.5).astype(int), zero_division=0)
        f1 = 2*p*r / (p+r+1e-7) if (p+r)>0 else 0.0
        ious.append(iou); dices.append(dice); precisions.append(p); recalls.append(r); f1s.append(f1)
    return {'iou': np.array(ious), 'dice': np.array(dices), 'precision': np.array(precisions), 'recall': np.array(recalls), 'f1': np.array(f1s)}

def bootstrap_compare(metric_a, metric_b, n_boot=1000, seed=42):
    """Paired bootstrap of the mean per-patch difference (A - B).
    metric_a and metric_b must be scores for the SAME patches, in the same order."""
    diff = np.asarray(metric_a, dtype=float) - np.asarray(metric_b, dtype=float)
    n = len(diff)
    rng = np.random.default_rng(seed)
    idx = rng.integers(0, n, size=(n_boot, n))  # resample patches with replacement
    boot_diffs = diff[idx].mean(axis=1)
    ci_low, ci_high = np.percentile(boot_diffs, [2.5, 97.5])  # 95% percentile interval
    # approximate one-sided p-value for H1: A > B (share of bootstrap means <= 0)
    p_value = np.mean(boot_diffs <= 0)
    return {'mean_diff': diff.mean(), 'ci': (ci_low, ci_high), 'p_value': p_value}
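The bootstrap share of resampled means at or below zero is only an approximate p-value. A paired sign-flip permutation test is a simple alternative that tests the null hypothesis directly. Under the null of no difference, each per-patch difference is equally likely to have either sign. This is a sketch of that standard technique, not the method used in the original. The name `paired_permutation_test` and `n_perm=10000` are assumptions.

```python
import numpy as np


def paired_permutation_test(metric_a, metric_b, n_perm=10000, seed=0):
    """One-sided paired sign-flip permutation test of H1: mean(A - B) > 0."""
    diff = np.asarray(metric_a, dtype=float) - np.asarray(metric_b, dtype=float)
    observed = diff.mean()
    rng = np.random.default_rng(seed)
    signs = rng.choice([-1.0, 1.0], size=(n_perm, diff.size))  # random sign per patch
    null_means = (signs * diff).mean(axis=1)
    # add-one correction so the estimated p-value is never exactly zero
    p_value = (np.sum(null_means >= observed) + 1) / (n_perm + 1)
    return {"mean_diff": observed, "p_value": p_value}


iou_baseline = np.linspace(0.3, 0.6, 30)
result = paired_permutation_test(iou_baseline + 0.1, iou_baseline)
print(result["mean_diff"], result["p_value"])
```

Like the bootstrap, this treats patches as independent. Spatially adjacent patches are not fully independent, so p-values from either method are likely to be optimistic.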


## 7) Failure analysis and visualization helpers

- Visualize false positives and false negatives to understand common failure modes.
- Stratify errors by land cover or season if that metadata is available.
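Before plotting individual cases, it helps to rank patches by error type and aggregate errors by stratum. The sketch assumes predictions and masks shaped `(N, H, W, 1)`, as produced by `model.predict` and `load_patches`. The per-patch `strata` labels (e.g., a land-cover class) are hypothetical metadata that the pipeline does not currently produce.

```python
import numpy as np


def error_counts(y_true, y_prob, threshold=0.5):
    """Per-patch TP, FP and FN pixel counts for (N, H, W, 1) arrays."""
    gt = y_true.reshape(len(y_true), -1).astype(bool)
    pr = y_prob.reshape(len(y_prob), -1) > threshold
    return (gt & pr).sum(axis=1), (~gt & pr).sum(axis=1), (gt & ~pr).sum(axis=1)


def worst_patches(y_true, y_prob, k=5, kind="fp"):
    """Indices of the k patches with the most false positives (kind='fp') or false negatives."""
    _, fp, fn = error_counts(y_true, y_prob)
    scores = fp if kind == "fp" else fn
    return np.argsort(-scores, kind="stable")[:k]


def fp_fn_by_stratum(y_true, y_prob, strata, threshold=0.5):
    """Total (FP, FN) pixels per stratum label, e.g. a hypothetical per-patch land-cover class."""
    _, fp, fn = error_counts(y_true, y_prob, threshold)
    strata = np.asarray(strata)
    return {s: (int(fp[strata == s].sum()), int(fn[strata == s].sum())) for s in np.unique(strata)}


y_true = np.zeros((3, 2, 2, 1))
y_prob = np.zeros((3, 2, 2, 1))
y_prob[1] = 0.9          # patch 1: every pixel is a false positive
y_prob[2, 0, 0, 0] = 0.9  # patch 2: one false positive
print(worst_patches(y_true, y_prob, k=2))  # [1 2]
```

The indices returned by `worst_patches` can be passed straight to a plotting helper to inspect the worst false positives and false negatives.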

The cell below provides simple visualization utilities.

In [None]:
import matplotlib.pyplot as plt

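def show_case(X, y_true, y_pred, idx=0):
    """Plot an RGB preview, ground truth, thresholded prediction and signed error for one patch."""
    fig, axs = plt.subplots(1, 4, figsize=(16, 4))
    # channels are (B02, B03, B04, B08) = (blue, green, red, NIR); reorder to red, green, blue
    rgb = np.clip(X[idx][..., [2, 1, 0]], 0, 1)
    axs[0].imshow(rgb)
    axs[0].set_title('RGB preview')
    gt = y_true[idx].squeeze().astype(int)
    pred = (y_pred[idx].squeeze() > 0.5).astype(int)
    axs[1].imshow(gt, cmap='gray', vmin=0, vmax=1)
    axs[1].set_title('Ground truth')
    axs[2].imshow(pred, cmap='gray', vmin=0, vmax=1)
    axs[2].set_title('Prediction')
    # +1 = false positive (red), -1 = false negative (blue), 0 = agreement (white)
    axs[3].imshow(pred - gt, cmap='bwr', vmin=-1, vmax=1)
    axs[3].set_title('Error (pred - gt)')
    for a in axs:
        a.axis('off')
    plt.show()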