In [None]:
import numpy as np
import pandas as pd
import ee

import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset

from sklearn.metrics import accuracy_score, precision_score, recall_score, roc_auc_score, confusion_matrix
import random

from datetime import datetime, timedelta
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.cm import ScalarMappable
import matplotlib
matplotlib.rcParams['figure.dpi'] = 250

from patch_classifier_module import PatchClassifier

# Authenticate against Google Earth Engine and initialise the client library.
# Every later `ee.*` call in this notebook relies on this cell having run.
ee.Authenticate()
ee.Initialize()

## Input-Channel Construction

For each scene (Chesapeake and CGSM), compute (B2, B3, B4, B8) and two robust indices:

- NDTI = (B4 – B3) / (B4 + B3)

- NDWI = (B3 – B8) / (B3 + B8)

### Ciénaga Grande Santa Marta (CGSM)

#### 1. Initial parameters

- patch_size: the side length (in pixels) of each square patch we’ll extract around a station

- scale: ground resolution of Sentinel-2 in meters per pixel

- half_m / half_deg: how many meters (and degrees) to go out from the station coordinate to build a square of size patch_size

- bands: the Sentinel-2 bands (and two indices) we’ll pull for each patch

- window_days: temporal window around each in-situ sample date

- cloud_thresh: maximum allowed cloud cover when filtering images

- debug_level: controls printing progress messages

In [None]:
# --- Patch geometry ---------------------------------------------------------
scale = 10                          # ground resolution, metres per pixel
patch_size = 128                    # side length of a square patch, in pixels
meters_per_deg = 111320.0           # metres-per-degree factor used for the conversion
half_m = scale * (patch_size // 2)  # half of the patch footprint, in metres
half_deg = half_m / meters_per_deg  # the same half-width expressed in degrees

# --- Imagery selection ------------------------------------------------------
bands = ['B2', 'B3', 'B4', 'B8', 'NDTI', 'NDWI']  # channel order of every patch
cloud_thresh = 50                   # upper bound on CLOUDY_PIXEL_PERCENTAGE
window_days = 3                     # days searched on each side of a sample date

# --- Diagnostics ------------------------------------------------------------
debug_level = 1                     # truthy -> print progress messages

#### 2. Z-score normalization function

- Takes a 3D array with shape (C, H, W)

- Computes per-band mean and std, then standardizes so each band has mean 0 and unit variance

- Prevents divide-by-zero with a small epsilon

In [None]:
def zscore_normalize(x: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Standardise every band of a patch, or of a batch of patches.

    Statistics are taken over the last two (spatial) axes, so the function
    works for a single patch ``(C, H, W)`` as well as a batch ``(N, C, H, W)``.
    Hard-coding axes ``(1, 2)`` would silently normalise a batch over the
    channel and height axes instead of height and width.

    Args:
        x: array with at least three dimensions; the last two are spatial.
        eps: small constant added to the standard deviation so that constant
            bands do not trigger a division by zero.

    Returns:
        Array of the same shape whose bands have (approximately) zero mean and
        unit variance.

    Raises:
        ValueError: if ``x`` has fewer than three dimensions.
    """
    x = np.asarray(x)
    if x.ndim < 3:
        raise ValueError(f"zscore_normalize expects ndim >= 3, got {x.ndim}")
    spatial_axes = (-2, -1)
    mean = x.mean(axis=spatial_axes, keepdims=True)
    std = x.std(axis=spatial_axes, keepdims=True) + eps
    return (x - mean) / std

#### 3. Load and filter the CSV of in-situ data

1. Read the CSV, parsing the muestreo column as dates

2. Select only the latitude, longitude, station name, sample date, and SST measurement

3. Print a sample of station names if debugging is on.

4. Keep only the four stations located in the Ciénaga Grande de Santa Marta—and exit with an error if none match

In [None]:
# Read the in-situ table; `muestreo` holds the sampling date.
df = pd.read_csv('./data/icam.csv',
                parse_dates=['muestreo'],
                low_memory=False)

# Keep only the columns used downstream and drop rows missing any of them.
df = df[['latitud','longitud','estacion','muestreo','sst']].dropna().copy()

if debug_level:
    print("Available stations (sample):", df['estacion'].unique()[:10])

# Stations retained for the CGSM extraction.
estaciones_cienaga = ['F. Costa Verde',
                      'F. La Barra',
                      'F. Sevilla',
                      'F. Palma Sola',]

df = df[df['estacion'].isin(estaciones_cienaga)].reset_index(drop=True)
if df.empty:
    print("ERROR: No stations match filter. Check station names." )
    # `sys` is never imported in this notebook, so `sys.exit(1)` would die
    # with a NameError; raising SystemExit directly has the intended effect.
    raise SystemExit(1)

#### 4. Initialize the Earth Engine collection

Build a base Sentinel-2 collection:

- Filter out scenes whose cloud cover is ≥ cloud_thresh % (only scenes strictly below the threshold are kept)

- Keep only the four bands we need (B2, B3, B4 in the visible range and B8 in the near-infrared)

Prepare empty lists for the extracted patches and corresponding labels, plus a counter for skipped rows

In [None]:
# Sentinel-2 surface-reflectance collection, restricted step by step:
# first by scene-level cloud cover, then to the four raw bands that the
# patches and the two indices are built from.
collection = ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED')
collection = collection.filter(
    ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', cloud_thresh)
)
collection = collection.select(['B2', 'B3', 'B4', 'B8'])

# Accumulators filled by the extraction loop.
patches = []   # one normalised array per extracted sample
labels = []    # matching in-situ value for every patch
skipped = 0    # number of rows that produced no patch

#### 5. Iterate over each in-situ sample and extract a patch

1. Loop over each valid row of in-situ measurements

2. Define a ±window_days window around the sample date

3. Define a square region of size patch_size×patch_size around the station coordinates

4. Filter the Sentinel-2 collection by that region and date window; if no images, skip

5. Compute the band-wise median (base) and two indices (NDTI, NDWI)

6. Reproject to ensure consistent pixel scale

7. Sample all bands into a Python dict, fill missing data with zeros

8. Stack into a NumPy array of shape (6, 128, 128)

9. Normalize each band, then append to the list with its corresponding SST label

In [None]:
def fit_to_patch(mat, size):
    """Center-crop and/or zero-pad a 2-D band array to ``(size, size)``.

    ``np.resize`` must not be used for this: it refills the target from the
    flattened input (repeating or truncating it), which scrambles the rows and
    columns of an image whenever the sampled rectangle is not exactly
    ``size`` x ``size``. Padding uses 0, the same fill value already used for
    masked pixels.

    Args:
        mat: 2-D array-like holding one band (assumed 2-D, as returned by
            ``sampleRectangle`` for a single band).
        size: target side length in pixels.

    Returns:
        A new ``(size, size)`` array with the input centred in it.
    """
    mat = np.asarray(mat)
    out = np.zeros((size, size), dtype=mat.dtype)
    rows, cols = min(mat.shape[0], size), min(mat.shape[1], size)
    src_r, src_c = (mat.shape[0] - rows) // 2, (mat.shape[1] - cols) // 2
    dst_r, dst_c = (size - rows) // 2, (size - cols) // 2
    out[dst_r:dst_r + rows, dst_c:dst_c + cols] = \
        mat[src_r:src_r + rows, src_c:src_c + cols]
    return out


for i, row in df.iterrows():
    lat, lon    = float(row['latitud']), float(row['longitud'])
    date        = row['muestreo']
    sst_val     = float(row['sst'])
    station     = row['estacion']

    # filterDate() treats its end date as exclusive, so one extra day is
    # needed to really cover date - window_days ... date + window_days.
    start = (date - timedelta(days=window_days)).strftime('%Y-%m-%d')
    end   = (date + timedelta(days=window_days + 1)).strftime('%Y-%m-%d')

    # Square footprint of half-width `half_deg` centred on the station.
    region = ee.Geometry.Rectangle([
        lon - half_deg, lat - half_deg,
        lon + half_deg, lat + half_deg
    ])

    sentinel = collection.filterBounds(region).filterDate(start, end)
    try:
        count = sentinel.size().getInfo()
    except Exception as e:
        # Best-effort extraction: report the failure and move on.
        if debug_level:
            print(f"Row {i}: image count request failed ({e}); skipping")
        skipped += 1
        continue

    if debug_level:
        print(f"Row {i}: Station={station}, Date={date.date()}, Images={count}")
    if count == 0:
        skipped += 1
        continue

    # Median composite of the window plus the two normalised-difference indices.
    base = sentinel.median()
    ndti = base.normalizedDifference(['B4','B3']).rename('NDTI')
    ndwi = base.normalizedDifference(['B3','B8']).rename('NDWI')
    comp = base.addBands([ndti, ndwi]).unmask(0)
    reproj = comp.reproject(comp.projection().atScale(scale))

    try:
        data_dict = reproj.sampleRectangle(region=region, defaultValue=0) \
                          .toDictionary().getInfo()
    except Exception as e:
        if debug_level:
            print(f"Row {i}: sampleRectangle failed ({e}); skipping")
        skipped += 1
        continue

    # Bands missing from the response fall back to an all-zero patch; every
    # band is then brought to the common size without disturbing its layout.
    arrs = []
    for b in bands:
        mat = np.array(data_dict.get(b, np.zeros((patch_size, patch_size))))
        if mat.shape != (patch_size, patch_size):
            mat = fit_to_patch(mat, patch_size)
        arrs.append(mat)
    patch = np.stack(arrs, axis=0)

    patch = zscore_normalize(patch)
    patches.append(patch)
    labels.append(sst_val)

#### 6. Save the extracted dataset

- Verify that at least one patch was extracted, otherwise exit with error

- Stack all patches into a single array of shape (N, C, H, W) and labels into (N,)

- Save both to .npy files for subsequent model training or inference

- Report how many patches were created versus skipped

In [None]:
if not patches:
    print("ERROR: No patches extracted. Adjust filters or check data.")
    # `sys` is never imported in this notebook, so `sys.exit(1)` would die
    # with a NameError; raising SystemExit directly has the intended effect.
    raise SystemExit(1)

# Stack the per-sample arrays into one array of patches and one of labels.
X_cgs = np.stack(patches, axis=0)
y_cgs = np.array(labels)

# Persist both arrays for the fine-tuning stage.
np.save('./data/cgs_patches.npy', X_cgs)
np.save('./data/cgs_labels.npy', y_cgs)
print(f"Extracted {len(patches)} patches, skipped {skipped} entries.")

### Chesapeake Bay (CB)

#### 1. Initial parameters

- patch_size: width/height in pixels of each square patch.

- scale: meters on the ground represented by one pixel.

- half_deg: how many degrees of latitude/longitude correspond to half the patch’s ground footprint.

- THRESHOLD: converts a continuous measurement into a binary label.

In [None]:
# Patch geometry (pixels, metres per pixel, and the derived half-widths).
scale = 10
patch_size = 128
meters_per_deg = 111320.0
half_m = scale * (patch_size // 2)
half_deg = half_m / meters_per_deg

# Measurements strictly above this value are labelled 1, the rest 0.
# NOTE(review): the name THRESHOLD is reassigned to 0.5 in the model section,
# so re-running the labelling cell afterwards would pick up the wrong cutoff.
THRESHOLD = 25.0

#### 2. Z-score normalization function

- Takes a 3D array with shape (C, H, W)

- Computes per-band mean and std, then standardizes so each band has mean 0 and unit variance

- Prevents divide-by-zero with a small epsilon

In [None]:
def zscore_normalize(x: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Z-score each band over its spatial axes.

    Accepts a single patch ``(C, H, W)`` or a batch ``(N, C, H, W)``; in both
    cases the mean and standard deviation are taken over the last two axes.

    Args:
        x: 3-D or 4-D array.
        eps: added to the standard deviation to avoid dividing by zero.

    Returns:
        Array with the same shape as ``x``.

    Raises:
        ValueError: if ``x`` is neither 3-D nor 4-D.
    """
    if x.ndim not in (3, 4):
        raise ValueError(f"zscore_normalize espera ndim 3 o 4, got {x.ndim}")
    # (1, 2) for a single patch, (2, 3) for a batch of patches.
    spatial_axes = (x.ndim - 2, x.ndim - 1)
    centered = x - x.mean(axis=spatial_axes, keepdims=True)
    return centered / (x.std(axis=spatial_axes, keepdims=True) + eps)

#### 3. Load & Filter the Chesapeake In-Situ Data

1. Read the CSV containing many water-quality measurements around Chesapeake Bay

2. Select only rows where the parameter is total suspended solids ('TSS')

3. Extract the relevant columns—timestamp, coordinates, and numeric value—dropping any exact duplicates to get a clean list of sampling events

In [None]:
# Read the Chesapeake Bay in-situ table; `SampleDate` is parsed as a date.
df = pd.read_csv('./data/cb_in_situ.csv',
                 parse_dates=['SampleDate'],
                 low_memory=False)

# Keep only the total-suspended-solids measurements.
df_turb = df[df['Parameter'] == 'TSS'].copy()

# One row per sampling event. Rows with a missing date, coordinate or value
# are dropped: a NaN value would otherwise compare False against the
# threshold and be labelled 0, and NaN dates/coordinates cannot be queried.
samples = (
    df_turb[['SampleDate', 'Latitude', 'Longitude', 'MeasureValue']]
      .dropna()
      .drop_duplicates()
      .reset_index(drop=True)
)

#### 4. Prepare Containers & Sentinel-2 Collection

- patches, masks: empty lists that will hold the image data and their binary labels

- collection: a filtered Earth Engine image collection selecting only scenes with < 30 % cloud cover and the four bands B2, B3, B4 (visible) and B8 (near-infrared)

In [None]:
# Accumulators filled by the extraction loop.
patches = []   # normalised image patches
masks = []     # per-patch uniform label masks

# Sentinel-2 surface-reflectance scenes with CLOUDY_PIXEL_PERCENTAGE below 30,
# limited to the four raw bands the patches are built from.
collection = ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED')
collection = collection.filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', 30))
collection = collection.select(['B2', 'B3', 'B4', 'B8'])

#### 5. Iterate Over Each Sample, Extract & Label Patches

1. Define a one-day window around each sample and a square region around its coordinates

2. Filter the Sentinel-2 data both in space and time

3. Build a median composite image and compute two water-quality indices (NDTI, NDWI)

4. Reproject to ensure consistent pixel size

5. Sample the pixel grid into memory

6. Stack the six bands into a 3D array

7. Apply Z-score normalization

8. Create a uniform binary mask (0 or 1) based on the turbidity threshold, same size as the patch

In [None]:
def fit_to_patch(mat, size):
    """Center-crop and/or zero-pad a 2-D band array to ``(size, size)``.

    ``np.resize`` must not be used for this: it refills the target from the
    flattened input (repeating or truncating it), which scrambles the rows and
    columns of an image whenever the sampled rectangle is not exactly
    ``size`` x ``size``. Padding uses 0, the same fill value already used for
    masked pixels.

    Args:
        mat: 2-D array-like holding one band (assumed 2-D, as returned by
            ``sampleRectangle`` for a single band).
        size: target side length in pixels.

    Returns:
        A new ``(size, size)`` array with the input centred in it.
    """
    mat = np.asarray(mat)
    out = np.zeros((size, size), dtype=mat.dtype)
    rows, cols = min(mat.shape[0], size), min(mat.shape[1], size)
    src_r, src_c = (mat.shape[0] - rows) // 2, (mat.shape[1] - cols) // 2
    dst_r, dst_c = (size - rows) // 2, (size - cols) // 2
    out[dst_r:dst_r + rows, dst_c:dst_c + cols] = \
        mat[src_r:src_r + rows, src_c:src_c + cols]
    return out


# Channel order of every patch (loop-invariant, so defined once).
bands = ['B2','B3','B4','B8','NDTI','NDWI']

for _, row in samples.iterrows():
    sd    = row['SampleDate']
    label = float(row['MeasureValue'])

    # One-day window starting on the sample date (the end date is exclusive).
    start = sd.strftime('%Y-%m-%d')
    end   = (sd + timedelta(days=1)).strftime('%Y-%m-%d')

    # Square footprint of half-width `half_deg` centred on the sample.
    lon, lat = float(row['Longitude']), float(row['Latitude'])
    region = ee.Geometry.Rectangle([
        lon - half_deg, lat - half_deg,
        lon + half_deg, lat + half_deg
    ])

    sentinel = collection.filterBounds(region).filterDate(start, end)
    try:
        count = sentinel.size().getInfo()
    except Exception as e:
        # A single failed request should not abort the whole extraction.
        print(f"Skipping sample {start} ({lat:.4f}, {lon:.4f}): {e}")
        continue
    if count == 0:
        continue

    # Median composite plus the two normalised-difference indices.
    base = sentinel.median()
    ndti = base.normalizedDifference(['B4','B3']).rename('NDTI')
    ndwi = base.normalizedDifference(['B3','B8']).rename('NDWI')
    comp = base.addBands([ndti, ndwi]).unmask(0)
    reproj = comp.reproject(comp.projection().atScale(scale))

    try:
        data_dict = reproj.sampleRectangle(
            region=region, defaultValue=0
        ).toDictionary().getInfo()
    except Exception as e:
        print(f"Skipping sample {start} ({lat:.4f}, {lon:.4f}): {e}")
        continue

    # Bands missing from the response fall back to an all-zero patch; every
    # band is then brought to the common size without disturbing its layout.
    arrays = []
    for b in bands:
        arr = np.array(data_dict.get(b, np.zeros((patch_size, patch_size))))
        if arr.shape != (patch_size, patch_size):
            arr = fit_to_patch(arr, patch_size)
        arrays.append(arr)

    patch = np.stack(arrays, axis=0)

    patches.append(zscore_normalize(patch))

    # Uniform mask carrying the binary class of the whole patch.
    binary_class = 1 if label > THRESHOLD else 0
    mask = np.full((patch_size, patch_size), binary_class, dtype=np.uint8)
    masks.append(mask)

#### 6. Stack & Save the Dataset

- Combine all individual patches into one big array X_data of shape (N, 6, 128, 128)

- Combine all masks into y_mask of shape (N, 128, 128)

- Save both arrays as .npy files for use in model training

In [None]:
# np.stack() fails with a generic message on an empty list; fail early with a
# message that says what actually went wrong (same exception type).
if not patches:
    raise ValueError("No Chesapeake Bay patches were extracted; nothing to save.")

# Stack the per-sample arrays along a new leading axis.
X_data = np.stack(patches, axis=0)
y_mask = np.stack(masks, axis=0)

# Persist both arrays for the training stage.
np.save('./data/patches_array.npy', X_data)
np.save('./data/labels_array.npy', y_mask)

## Model

#### 1. Hyperparameters

- BATCH_SIZE: number of samples per training batch

- LR: initial learning rate for the Adam optimizer

- WEIGHT_DECAY: L2 regularization strength

- NUM_EPOCHS: maximum epochs to train

- PATIENCE: how many epochs without validation improvement before early stopping

- AUG_PROB & NOISE_STD: probability and strength of data augmentations (flips, rotations, Gaussian noise)

- FINETUNE_LR & FINETUNE_EPOCHS: settings for the later fine‐tuning stage (see "Fine-Tuning on CGSM Data" below)

- THRESHOLD: cutoff on sigmoid outputs to convert probabilities into binary predictions

In [None]:
# --- Optimisation -----------------------------------------------------------
LR = 5e-5              # learning rate of the main training run
WEIGHT_DECAY = 1e-4    # weight decay passed to the optimizer
BATCH_SIZE = 16        # samples per batch

# --- Training schedule ------------------------------------------------------
NUM_EPOCHS = 100       # upper bound on training epochs
PATIENCE = 10          # epochs without validation improvement before stopping

# --- Augmentation -----------------------------------------------------------
AUG_PROB = 0.5         # probability of applying each augmentation
NOISE_STD = 0.02       # standard deviation of the additive Gaussian noise

# --- Fine-tuning ------------------------------------------------------------
FINETUNE_LR = 1e-6
FINETUNE_EPOCHS = 20

# --- Decision rule ----------------------------------------------------------
THRESHOLD = 0.5        # cutoff applied to sigmoid outputs

#### 2. Z-score Normalization Utility

Standardizes each channel of every patch to zero mean and unit variance, preventing any one band from dominating.

In [None]:
def zscore_normalize(x: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Standardise every channel of a patch, or of a batch of patches.

    Statistics are taken over the last two (spatial) axes, so both a batch
    ``(N, C, H, W)`` and a single patch ``(C, H, W)`` are handled. Hard-coding
    axes ``(2, 3)`` would make the function fail on a single patch.

    Args:
        x: array with at least three dimensions; the last two are spatial.
        eps: small constant added to the standard deviation so that constant
            channels do not trigger a division by zero.

    Returns:
        Array of the same shape whose channels have (approximately) zero mean
        and unit variance.

    Raises:
        ValueError: if ``x`` has fewer than three dimensions.
    """
    x = np.asarray(x)
    if x.ndim < 3:
        raise ValueError(f"zscore_normalize expects ndim >= 3, got {x.ndim}")
    spatial_axes = (-2, -1)
    mean = x.mean(axis=spatial_axes, keepdims=True)
    std  = x.std(axis=spatial_axes, keepdims=True) + eps
    return (x - mean) / std

#### 3. Custom Dataset Class

Wraps NumPy arrays into a PyTorch Dataset

Optionally applies spatial augmentations and small noise on the fly for training

In [None]:
class TurbidityDataset(Dataset):
    """In-memory dataset of image patches and their labels.

    Args:
        X: NumPy array of patches; indexing along axis 0 yields one sample.
        y: NumPy array with one label per sample.
        augment: when True, each sample is randomly flipped, rotated and
            perturbed with Gaussian noise every time it is fetched.
    """

    def __init__(self, X, y, augment=False):
        self.augment = augment
        self.X = torch.from_numpy(X).float()
        # Labels get a trailing axis so each item has shape (1,).
        self.y = torch.from_numpy(y).float().unsqueeze(1)

    def __len__(self):
        return len(self.X)

    def _augment(self, sample):
        """Apply each augmentation independently with probability AUG_PROB."""
        # Flip along axis 1 (rows), then along axis 2 (columns).
        for axis in (1, 2):
            if random.random() < AUG_PROB:
                sample = sample.flip(axis)
        # Rotate by one, two or three quarter turns in the spatial plane.
        if random.random() < AUG_PROB:
            quarter_turns = random.choice([1, 2, 3])
            sample = sample.rot90(quarter_turns, (1, 2))
        # Additive Gaussian noise with standard deviation NOISE_STD.
        if random.random() < AUG_PROB:
            sample = sample + torch.randn_like(sample) * NOISE_STD
        return sample

    def __getitem__(self, idx):
        sample, target = self.X[idx], self.y[idx]
        if not self.augment:
            return sample, target
        return self._augment(sample), target

#### 4. Load Data & Create Splits

- Loads the saved patch and mask arrays

- Reduces each mask to a single scalar label by sampling the center pixel

- Normalizes the input patches

- Splits into train/val/test, preserving class balance (stratification)

- Creates DataLoaders for efficient batching and optional shuffling

In [None]:
# Load the saved patches and their per-pixel label masks.
X = np.load('./data/patches_array.npy')
y_mask = np.load('./data/labels_array.npy')

# Collapse every mask to one scalar label by reading its centre pixel.
H, W = y_mask.shape[1], y_mask.shape[2]
y = y_mask[:, H//2, W//2]
X = zscore_normalize(X)

# 10% hold-out test split, stratified
X_temp, X_test, y_temp, y_test = train_test_split(
    X, y, test_size=0.1, random_state=42, stratify=y
)

# of the remaining 90%, 80% train and 20% validation
X_train, X_val, y_train, y_val = train_test_split(
    X_temp, y_temp, test_size=0.2, random_state=42, stratify=y_temp
)

# The data already lives in memory as tensors, so worker processes bring no
# benefit. With the "spawn" start method (macOS, Windows) they cannot even
# unpickle a Dataset class that is defined inside the notebook, so loading
# is kept in the main process.
NUM_WORKERS = 0

train_loader = DataLoader(TurbidityDataset(X_train, y_train, augment=True),
                          batch_size=BATCH_SIZE, shuffle=True,  num_workers=NUM_WORKERS)
val_loader   = DataLoader(TurbidityDataset(X_val,   y_val,   augment=False),
                          batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
test_loader  = DataLoader(TurbidityDataset(X_test,  y_test,  augment=False),
                          batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

#### 5. Model Definition

DoubleConv: two Conv→BN→ReLU blocks in sequence

PatchClassifier: a small CNN encoder that downsamples twice, then global pools to a 256-dim vector per patch and applies a linear layer to predict one logit

In [None]:
class DoubleConv(nn.Module):
    """Two consecutive 3x3 Conv -> BatchNorm -> ReLU stages.

    The first stage maps ``in_ch`` to ``out_ch`` channels, the second keeps
    ``out_ch``. Padding of 1 preserves the spatial size, and the convolutions
    carry no bias term.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        for stage_in in (in_ch, out_ch):
            layers += [
                nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        # Attribute name and layer order are kept so saved checkpoints still load.
        self.seq = nn.Sequential(*layers)

    def forward(self, x):
        return self.seq(x)

class PatchClassifier(nn.Module):
    """Small CNN that maps a patch to a single logit.

    The encoder stacks three DoubleConv blocks (64, 128, 256 channels) with a
    2x max-pool after the first two and a global average pool after the last;
    a linear layer then turns the 256 pooled features into one output.

    NOTE(review): a class with this name is also imported from
    ``patch_classifier_module`` at the top of the notebook; this definition
    replaces that import once the cell runs.
    """

    def __init__(self, in_channels=6):
        super().__init__()
        stages = [
            DoubleConv(in_channels, 64),
            nn.MaxPool2d(2),
            DoubleConv(64, 128),
            nn.MaxPool2d(2),
            DoubleConv(128, 256),
            nn.AdaptiveAvgPool2d(1),
        ]
        self.encoder = nn.Sequential(*stages)
        self.fc = nn.Linear(256, 1)

    def forward(self, x):
        pooled = self.encoder(x)                       # (B, 256, 1, 1)
        return self.fc(torch.flatten(pooled, 1))       # (B, 1)

#### 6. Training & Validation Loop

1. Move model and data to GPU if available

2. Compute a pos_weight to counter class imbalance

3. Train for up to NUM_EPOCHS, tracking training and validation loss

4. Adjust learning rate on plateau of validation loss

5. Save the best checkpoint and stop early if no improvement for PATIENCE epochs

In [None]:
# Prefer CUDA, then Apple MPS, then CPU.
device = torch.device('cuda' if torch.cuda.is_available()
                      else 'mps' if torch.backends.mps.is_available()
                      else 'cpu')
print(f'Using: {device}')
model = PatchClassifier(in_channels=X.shape[1]).to(device)

# Weight positives by the negative/positive ratio to counter class imbalance.
pos = (y_train==1).sum(); neg = (y_train==0).sum()
# The ratio is a NumPy float64; torch.tensor() would inherit that dtype, which
# the MPS backend does not support. Build an explicit float32 tensor instead.
pos_weight = torch.tensor(float(neg) / (float(pos) + 1e-8),
                          dtype=torch.float32, device=device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
# Multiply the learning rate by 0.7 after 5 epochs without a lower validation loss.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,'min', factor=0.7, patience=5
)

In [None]:
best_val = float('inf')
# Must start at 0: starting at PATIENCE would stop training right after the
# first epoch whenever that epoch fails to improve on `best_val`
# (for example when the validation loss is NaN).
epochs_no_imp = 0
train_losses, val_losses = [], []

for ep in range(1, NUM_EPOCHS+1):
    # ---- training pass ----
    model.train()
    train_loss = 0
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        loss = criterion(model(xb), yb)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Weight by batch size so the epoch figure is a per-sample mean.
        train_loss += loss.item() * xb.size(0)
    train_loss /= len(train_loader.dataset)
    train_losses.append(train_loss)

    # ---- validation pass ----
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for xb, yb in val_loader:
            xb, yb = xb.to(device), yb.to(device)
            val_loss += criterion(model(xb), yb).item() * xb.size(0)
    val_loss /= len(val_loader.dataset)
    val_losses.append(val_loss)

    scheduler.step(val_loss)

    # Checkpoint on improvement, otherwise count towards early stopping.
    if val_loss < best_val:
        best_val, epochs_no_imp = val_loss, 0
        torch.save(model.state_dict(), 'best_model.pth')
    else:
        epochs_no_imp += 1

    print(f"Epoch {ep} | Train {train_loss:.4f} | Val {val_loss:.4f}")
    if epochs_no_imp >= PATIENCE:
        print("Stopping early.")
        break

In [None]:
# Training and validation loss per epoch (epochs numbered from 1).
plt.figure(figsize=(5,4))
plt.plot(np.arange(1, len(train_losses)+1), train_losses, label='Train loss')
plt.plot(np.arange(1, len(val_losses)+1), val_losses, label='Validation loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training history')
# The curves carry labels, so draw the legend that displays them.
plt.legend()
plt.show()

In [None]:
# Restore the best checkpoint written by the training loop. `map_location`
# keeps the load working when the checkpoint was written from a different
# device than the one currently selected.
model.load_state_dict(torch.load('best_model.pth', map_location=device))

#### 7. Final Evaluation on Test Set

- Make predictions on the held-out test set

- Threshold sigmoid outputs to generate binary labels

- Calculate Accuracy, Precision, Recall, AUC, and print a Confusion Matrix to quantify final performance

In [None]:
model.eval()
all_probs, all_preds, all_labels = [], [], []

with torch.no_grad():
    for xb, yb in test_loader:
        xb = xb.to(device)
        probs = torch.sigmoid(model(xb)).cpu().view(-1)
        preds = (probs > THRESHOLD).int().numpy()
        all_probs.extend(probs.numpy())
        all_preds.extend(preds)
        all_labels.extend(yb.view(-1).numpy())

# Threshold-dependent metrics use the hard predictions.
test_acc = accuracy_score(all_labels, all_preds)
test_prec = precision_score(all_labels, all_preds)
test_rec = recall_score(all_labels, all_preds)
# ROC-AUC is a ranking metric and must be fed the continuous scores; with
# already-thresholded 0/1 predictions it collapses to a single operating point.
test_auc = roc_auc_score(all_labels, all_probs)
conf_mat = confusion_matrix(all_labels, all_preds)

print("Test Accuracy: ", test_acc)
print("Precision: ", test_prec)
print("Recall: ", test_rec)
print("AUC: ", test_auc)
print("Confusion Matrix:\n", conf_mat)

## Fine-Tuning on CGSM Data


- Objective: Adapt the pre-trained model (on Chesapeake) to the local Ciénaga domain

- Load the saved NumPy arrays of patches (.npy) and their continuous SST labels

- Apply the same Z-score normalization to match the training pipeline

In [None]:
X_cgsm = np.load('./data/cgs_patches.npy')
y_cgsm_raw = np.load('./data/cgs_labels.npy')   # continuous in-situ values

# The model is trained with a binary cross-entropy loss and the split below is
# stratified by class, so both need 0/1 targets rather than raw measurements.
# NOTE(review): 25.0 mirrors the cutoff used to label the Chesapeake samples;
# confirm that the CGSM `sst` column is the same quantity in the same units.
SST_THRESHOLD = 25.0
y_cgsm = (y_cgsm_raw > SST_THRESHOLD).astype(np.uint8)

X_cgsm = zscore_normalize(X_cgsm)

#### 1. Create a Train/Hold-out Split with Graceful Fallback

- Stratified split keeps the same proportion of “contaminated” vs. “clean” patches in both fine-tune and hold-out sets

- Fallback: if one class has only a single example (making stratification impossible), it automatically uses a standard random split instead

In [None]:
# Arguments shared by the stratified attempt and the random fallback.
split_kwargs = dict(test_size=0.2, random_state=1)

try:
    # Preferred: keep the label proportions equal in both subsets.
    X_ft, X_hold, y_ft, y_hold = train_test_split(
        X_cgsm, y_cgsm, stratify=y_cgsm, **split_kwargs
    )
except ValueError:
    # train_test_split rejected the stratification; split at random instead.
    print("Stratified split failed, using random split")
    X_ft, X_hold, y_ft, y_hold = train_test_split(
        X_cgsm, y_cgsm, shuffle=True, **split_kwargs
    )
    print("Using random split")
else:
    print("Using stratified split")

#### 2. Wrap in Datasets & Loaders

- dataset_ft: the fine-tuning set with augmentations turned on

- dataset_hold: a small hold-out set (no augmentations) to monitor potential overfitting during fine-tuning

- DataLoaders batch and shuffle the data appropriately

In [None]:
# Fine-tuning subset: augmented and reshuffled every epoch.
dataset_ft = TurbidityDataset(X_ft, y_ft, augment=True)
loader_ft = DataLoader(dataset_ft, shuffle=True, batch_size=BATCH_SIZE)

# Hold-out subset: no augmentation, fixed order.
dataset_hold = TurbidityDataset(X_hold, y_hold, augment=False)
loader_hold = DataLoader(dataset_hold, shuffle=False, batch_size=BATCH_SIZE)

#### 3. Fine-Tune the Full Model

- Lower learning rate (FINETUNE_LR) ensures small weight updates to preserve the previously learned features

- No scheduler or early stopping here—just a fixed number of fine-tuning epochs

- Print the average fine-tuning loss each epoch to verify convergence on the new domain

In [None]:
# All parameters are updated, at the (much smaller) fine-tuning learning rate.
# NOTE(review): the loss (`criterion`, including its pos_weight) is the one
# configured for the Chesapeake training run.
# NOTE(review): the fine-tuned weights only live in memory; the inference
# section further down reloads 'best_model.pth'.
optimizer_f = torch.optim.Adam(model.parameters(), lr=FINETUNE_LR)
for ep in range(1, FINETUNE_EPOCHS+1):
    # ---- fine-tuning pass ----
    model.train()
    loss_ft = 0
    for xb, yb in loader_ft:
        xb, yb = xb.to(device), yb.to(device)
        loss = criterion(model(xb), yb)
        optimizer_f.zero_grad()
        loss.backward()
        optimizer_f.step()
        # Weight by batch size so the epoch figure is a per-sample mean.
        loss_ft += loss.item() * xb.size(0)
    loss_ft /= len(loader_ft.dataset)
    print(f"Fine-tune Epoch {ep}, Loss: {loss_ft:.4f}")

    # ---- hold-out pass ----
    # The hold-out loader exists to monitor overfitting, so evaluate it.
    model.eval()
    loss_hold = 0.0
    with torch.no_grad():
        for xb, yb in loader_hold:
            xb, yb = xb.to(device), yb.to(device)
            loss_hold += criterion(model(xb), yb).item() * xb.size(0)
    loss_hold /= max(len(loader_hold.dataset), 1)
    print(f"Fine-tune Epoch {ep}, Hold-out loss: {loss_hold:.4f}")

#### 4. Visual Inspection of Predictions

- Select one batch from your original test set for inspection

- Compute sigmoid probabilities and threshold them into binary predictions

- Plot the first four patches as RGB composites (using bands B4, B3, B2)

- Overlay the ground-truth (center-pixel) label vs. the model’s prediction in the title for quick visual validation

In [None]:
# Score the first batch of the test loader.
model.eval()
batch = next(iter(test_loader))
xb, yb = batch
xb = xb.to(device)
with torch.no_grad():
    probs = torch.sigmoid(model(xb)).cpu().view(-1)
preds = (probs > THRESHOLD).int()

# Show up to four patches, one figure each.
xb_cpu = xb.cpu()
for i in range(min(4, len(xb))):
    patch = xb_cpu[i]
    # Channels 2, 1, 0 are B4, B3, B2 -> channel-last RGB composite.
    rgb = patch[[2, 1, 0]].permute(1, 2, 0).numpy()
    lo, hi = rgb.min(), rgb.max()
    plt.figure(figsize=(3, 3))
    # Min-max stretch to [0, 1] for display.
    plt.imshow((rgb - lo) / (hi - lo))
    plt.title(f"True: {int(yb[i].item())}, Pred: {int(preds[i].item())}")
    plt.axis('off')
plt.show()

## Plot predictions

#### 1. Parameter Definitions

- Compute how many degrees correspond to half a patch

- Define which bands (and indices) to sample

- Parameters controlling image quality and the visual “fade” effect when plotting

In [None]:
# --- Patch geometry (same construction as for the extraction) ---------------
scale_patch = 10
patch_size = 128
meters_per_deg = 111320.0
half_m = scale_patch * (patch_size // 2)
half_deg = half_m / meters_per_deg

# --- Imagery selection ------------------------------------------------------
bands = ['B2', 'B3', 'B4', 'B8', 'NDTI', 'NDWI']
cloud_thresh = 50      # upper bound on CLOUDY_PIXEL_PERCENTAGE
window_days = 3        # days searched on each side of a date

# --- Map rendering ----------------------------------------------------------
map_scale = 30         # scale used for the full-region background image
r_pixels = 20          # radius of the outermost circle, in image pixels
num_rings = 20         # number of stacked circles per station
max_alpha = 0.7        # upper bound on the opacity of a circle

#### 2. Build a List of Dates

Creates a daily list of `YYYY-MM-DD` strings from `start_date` to `end_date` (both inclusive), to loop over and fetch composites for each day

In [None]:
# First and last day of the period (both included in the list).
start_date = datetime(2025, 5, 1)
end_date = datetime(2025, 5, 20)

# One 'YYYY-MM-DD' string per day; an inverted range yields an empty list.
n_days = (end_date - start_date).days
dates_list = [
    (start_date + timedelta(days=offset)).strftime('%Y-%m-%d')
    for offset in range(n_days + 1)
]

#### 3. Load the Trained Model

Loads your previously‐trained CNN onto CPU/GPU, ready for inference

In [None]:
# Pick the inference device.
# NOTE(review): unlike the training section, MPS is not considered here.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Rebuild the network, restore the saved weights, then move it to the device.
# NOTE(review): 'best_model.pth' is the checkpoint written by the main
# training loop - confirm that this is the intended set of weights.
model = PatchClassifier(in_channels=6)
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model = model.to(device)
model.eval()

#### 4. Load In-Situ Station Coordinates

Reads the INVEMAR CSV, filters to just the four station names, and extracts unique lat/lon pairs

In [None]:
# Stations to draw on the map.
stations = ['F. Costa Verde', 'F. La Barra', 'F. Sevilla', 'F. Palma Sola']

# NOTE(review): this read passes dayfirst=True while the extraction section
# reads the same file without it - confirm which date convention is right.
df_inv = pd.read_csv('./data/icam.csv', parse_dates=['muestreo'], dayfirst=True)

# Complete coordinate rows of the selected stations, one per distinct lat/lon.
df_coords = (
    df_inv[['latitud', 'longitud', 'estacion']]
      .dropna()
      .loc[lambda frame: frame['estacion'].isin(stations)]
      .drop_duplicates(['latitud', 'longitud'])
)
lats, lons = df_coords['latitud'].values, df_coords['longitud'].values

#### 5. Define the Full‐Region Bounding Box

Expands that station envelope by one patch radius to cover the entire area of interest in a single rectangle

In [None]:
# Envelope of all stations, grown by half a patch on every side.
min_lon, max_lon = lons.min() - half_deg, lons.max() + half_deg
min_lat, max_lat = lats.min() - half_deg, lats.max() + half_deg

# Rectangle corners are given as [west, south, east, north].
region = ee.Geometry.Rectangle([min_lon, min_lat, max_lon, max_lat])

#### 6. Loop Over Dates & Fetch Composites

1. Filter Sentinel-2 by region, date window, and cloud cover

2. Compute a median composite and two indices (NDTI, NDWI)

3. Reproject to map_scale for a manageable full-region array

4. Sample into a NumPy dict and assemble a normalized RGB image for plotting

In [None]:
# NOTE(review): only the first three dates are processed, and the variables
# assigned here (`date`, `coll`, `comp`, `ndti`, `ndwi`, `rgb_norm`) keep the
# values of the last date that had imagery.
for date_str in dates_list[:3]:
    # The strings are produced with '%Y-%m-%d'; parse them with that exact
    # format instead of the ambiguous `dayfirst` heuristic.
    date = pd.to_datetime(date_str, format='%Y-%m-%d')
    # filterDate() treats its end date as exclusive, so one extra day is
    # needed to really cover date - window_days ... date + window_days.
    start = (date - timedelta(days=window_days)).strftime('%Y-%m-%d')
    end   = (date + timedelta(days=window_days + 1)).strftime('%Y-%m-%d')

    coll = (
        ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED')
          .filterBounds(region)
          .filterDate(start, end)
          .filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', cloud_thresh))
          .select(['B2','B3','B4','B8'])
    )
    if coll.size().getInfo() == 0:
        continue

    # build a median composite + compute NDTI & NDWI
    base = coll.median()
    ndti = base.normalizedDifference(['B4','B3']).rename('NDTI')
    ndwi = base.normalizedDifference(['B3','B8']).rename('NDWI')
    comp = base.addBands([ndti, ndwi]).unmask(0)

    # downsample for full-region display
    coarse = comp.reproject(comp.projection().atScale(map_scale))
    info = coarse.sampleRectangle(region=region, defaultValue=0).toDictionary().getInfo()

    # build a normalized RGB background
    R = np.array(info['B4']); G = np.array(info['B3']); B = np.array(info['B2'])
    rgb = np.stack([R, G, B], axis=2).astype(float)
    # Min-max stretch to [0, 1]; a completely flat image would divide by
    # zero, so it is rendered as all zeros instead.
    value_range = rgb.max() - rgb.min()
    if value_range > 0:
        rgb_norm = (rgb - rgb.min()) / value_range
    else:
        rgb_norm = np.zeros_like(rgb)

#### 7. Compute Predictions at Station Locations

- For each station coordinate, extract the same 128×128 patch (with the two indices) and normalize

- Run it through the CNN to get a contamination probability

In [None]:
# The composite does not depend on the station, so the Earth Engine
# expression is built once instead of once per station.
# NOTE(review): `coll`, `ndti`, `ndwi` and `comp` come from the date loop and
# therefore describe the last date that had imagery.
patch_img = (
    coll.median()
        .addBands([ndti, ndwi])
        .unmask(0)
        .reproject(comp.projection().atScale(scale_patch))
)

preds = []
for lat, lon in zip(lats, lons):
    # Square footprint of half-width `half_deg` centred on the station.
    box = ee.Geometry.Rectangle([
        lon-half_deg, lat-half_deg,
        lon+half_deg, lat+half_deg
    ])
    sampled = patch_img.sampleRectangle(region=box, defaultValue=0).toDictionary().getInfo()
    # Bands missing from the response fall back to an all-zero patch.
    arr = np.stack([np.array(sampled.get(b, np.zeros((patch_size, patch_size)))) for b in bands], axis=0)
    # Per-band z-score over the spatial axes.
    arr = (arr - arr.mean(axis=(1,2), keepdims=True)) / (arr.std(axis=(1,2), keepdims=True) + 1e-8)
    # Add the batch axis and run the classifier; sigmoid gives a probability.
    tensor = torch.from_numpy(arr).unsqueeze(0).to(device).float()
    with torch.no_grad():
        prob = torch.sigmoid(model(tensor)).cpu().item()
    preds.append(prob)

#### 8. Convert to Pixel Coordinates

Maps longitude/latitude to pixel x/y in the downsampled background image

In [None]:
# Height and width of the background image, in pixels.
h, w = rgb_norm.shape[:2]

# Linear mapping from the bounding box to image coordinates: x grows with
# longitude, y grows downwards from the northern edge.
x_pix, y_pix = [], []
for lon, lat in zip(lons, lats):
    x_pix.append((lon - min_lon) / (max_lon - min_lon) * w)
    y_pix.append((max_lat - lat) / (max_lat - min_lat) * h)

#### 9. Plot the Background + fading Circles

1. Display the true-color Sentinel-2 background

2. Overlay at each station a stack of transparent circles whose opacity fades outward, colored by the predicted probability

3. Add a vertical colorbar to interpret the heatmap of contamination probability

In [None]:
# Background image with the date in the title.
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(rgb_norm)
ax.set_title(f"Contamination Map & GEE Image for {date.date()}", y=1.03)
ax.axis('off')

cmap = plt.get_cmap('coolwarm')
for x0, y0, p in zip(x_pix, y_pix, preds):
    # Stations without a usable prediction are not drawn.
    if not np.isnan(p):
        color = cmap(p)
        # Largest circle first; opacity grows as the radius shrinks, which
        # produces the outward fade.
        for ring in range(num_rings, 0, -1):
            frac = ring / num_rings
            ax.add_patch(
                mpatches.Circle(
                    (x0, y0),
                    radius=r_pixels * frac,
                    facecolor=color,
                    edgecolor=None,
                    alpha=(1 - frac) * max_alpha,
                    linewidth=0,
                    zorder=2,
                )
            )

# Colour bar for the 0-1 probability scale used by the circles.
sm = ScalarMappable(cmap='coolwarm', norm=plt.Normalize(vmin=0, vmax=1))
sm.set_array([])
plt.colorbar(sm, ax=ax, fraction=0.046, pad=0.04, label='Probability')
plt.show()

## OTRO