In [None]:
import geopandas as gpd
import pandas as pd
import tacoreader
import rasterio as rio
import matplotlib.pyplot as plt

# 1. Load the TACO dataset
# HINT: Every TACO dataset is a GeoDataFrame if it fulfills the STAC requirements
dataset = tacoreader.load("tacofoundation:cloudsen12-l1c")
# dataset = dataset.to_geodataframe() # From TortillaDataFrame to GeoDataFrame [geopandas]

# 2. Spatial query: keep only the rows whose "rai:admin0" field equals "Mexico"
subset_sp_temporal = dataset[dataset["rai:admin0"] == "Mexico"]

# 3. Keep only the images that contain cloud shadows
subset_final = subset_sp_temporal[subset_sp_temporal["cloud_shadow_percentage"] > 0]
# Draw the subset. Wrapping .plot() in print() only added the repr of the returned
# plot object to the cell output, so the figure is shown explicitly instead.
subset_final.plot()
plt.show()

# 4. Create a new TACO file based on the previous filters
tacoreader.compile(dataframe=subset_final, output="mini.taco", nworkers=4)

# 5. Load your new TACO file
dataset = tacoreader.load("mini.taco")

In [None]:
# Full U-Net pipeline for 9-band Sentinel-2 based cloud & shadow segmentation with checkpoints and stability fixes

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
import rasterio
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score, jaccard_score
import tacoreader
import os

# --- Dataset Class ---
class CloudSEN129BandDataset(Dataset):
    """Torch dataset yielding multi-band image patches and 3-class label maps.

    Each item is a pair ``(image, label)``:

    * ``image`` - float32 tensor with one channel per entry of ``band_indices``;
      pixel values are clipped to [0, 10000] and divided by 10000.
    * ``label`` - int64 tensor where source label 0 -> 0, 1 and 2 -> 1, 3 -> 2 and
      every other value -> 255 (``IGNORE_LABEL``).

    Only the top-left ``patch_size`` x ``patch_size`` window of each raster is read.

    Samples whose label map contains nothing but ``IGNORE_LABEL`` are skipped: the
    next index (wrapping around) is returned instead.
    NOTE(review): the replacement index refers to the underlying dataset, so when
    this dataset is wrapped by ``random_split`` the substitute sample may belong to
    a different split - confirm this is acceptable.

    Args:
        taco_dataset: Object supporting ``len()`` and ``.read(idx)``; each sample it
            returns must support ``.read(0)`` (image) and ``.read(1)`` (label), whose
            results are handed to ``rasterio.open``.
        patch_size: Side length, in pixels, of the window read from each raster.
        band_indices: Band indexes passed to the raster's ``read``. ``None`` selects
            ``DEFAULT_BAND_INDICES``.
    """

    # Bands named by the original author: B2, B3, B4, B5, B6, B7, B8, B11, B12.
    # NOTE(review): rasterio band indexes are 1-based. If the rasters store the 13
    # Sentinel-2 L1C bands in their natural order, these indexes actually select
    # B1, B2, B3, B4, B5, B6, B8, B10, B11 - confirm against the data before trusting
    # the band names (the values are left unchanged here).
    DEFAULT_BAND_INDICES = (1, 2, 3, 4, 5, 6, 8, 11, 12)

    # Label value marking pixels that carry no usable class.
    IGNORE_LABEL = 255

    def __init__(self, taco_dataset, patch_size=256, band_indices=None):
        self.dataset = taco_dataset
        self.patch_size = patch_size
        if band_indices is None:
            band_indices = self.DEFAULT_BAND_INDICES
        self.band_indices = list(band_indices)

    def __len__(self):
        return len(self.dataset)

    def _window(self):
        """Return the top-left ``patch_size`` x ``patch_size`` read window."""
        return rasterio.windows.Window(0, 0, self.patch_size, self.patch_size)

    def _read_label_map(self, label_path):
        """Read band 1 of the label raster and remap it to {0, 1, 2, IGNORE_LABEL}."""
        with rasterio.open(label_path) as label_src:
            label = label_src.read(1, window=self._window()).astype(np.int64)

        label_map = np.full_like(label, self.IGNORE_LABEL)
        label_map[label == 0] = 0
        label_map[np.isin(label, [1, 2])] = 1
        label_map[label == 3] = 2
        return label_map

    def _read_image(self, eo_path):
        """Read the selected bands and scale them to float32 in [0, 1]."""
        with rasterio.open(eo_path) as eo_src:
            eo = eo_src.read(self.band_indices, window=self._window())
        eo = np.nan_to_num(eo)
        return np.clip(eo, 0, 10000).astype(np.float32) / 10000.0

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx`` or for the next sample with a valid label.

        Raises:
            RuntimeError: If a full pass over the dataset finds no sample with at
                least one valid label pixel.
        """
        total = len(self)
        candidate = idx
        # Bounded scan instead of open-ended recursion: at most one full pass.
        for _ in range(max(total, 1)):
            sample = self.dataset.read(candidate)
            eo_path = sample.read(0)
            label_path = sample.read(1)

            # Check the label first so the image raster is only read when needed.
            label_map = self._read_label_map(label_path)
            if (label_map != self.IGNORE_LABEL).any():
                image = self._read_image(eo_path)
                return torch.from_numpy(image), torch.from_numpy(label_map)

            candidate = (candidate + 1) % total

        raise RuntimeError(
            "No sample with at least one valid label pixel was found in the dataset"
        )

# --- U-Net Model ---
class UNet(nn.Module):
    """U-Net style encoder/decoder producing per-pixel class logits.

    The encoder has four levels (64, 128, 256, 512 channels) plus a 1024-channel
    bottleneck; each level is a single 3x3 convolution + batch norm + ReLU, with a
    2x2 max-pool between levels. The decoder mirrors it with stride-2 transposed
    convolutions, concatenating the matching encoder output before each
    convolution block. A final 1x1 convolution maps to ``num_classes`` channels.

    Inputs whose height or width is not a multiple of 16 are supported: the
    upsampled tensor is zero-padded to the size of the skip connection before
    concatenation, so the output always has the spatial size of the input. The
    input must be at least 16 pixels on each side so that four poolings leave at
    least one pixel.

    Args:
        in_channels: Number of channels of the input tensor.
        num_classes: Number of output channels (one logit map per class).
    """

    def __init__(self, in_channels=9, num_classes=3):
        super().__init__()

        def CBR(in_ch, out_ch):
            # Conv -> BatchNorm -> ReLU; padding=1 keeps the spatial size.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True)
            )

        # Attribute names and creation order are kept as-is so that existing
        # state_dict checkpoints remain loadable.
        self.enc1 = CBR(in_channels, 64)
        self.enc2 = CBR(64, 128)
        self.enc3 = CBR(128, 256)
        self.enc4 = CBR(256, 512)
        self.pool = nn.MaxPool2d(2, 2)
        self.middle = CBR(512, 1024)

        # Each decoder block receives upsampled + skip channels (hence 2x inputs).
        self.up1 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.dec1 = CBR(1024, 512)
        self.up2 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec2 = CBR(512, 256)
        self.up3 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec3 = CBR(256, 128)
        self.up4 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec4 = CBR(128, 64)

        self.out = nn.Conv2d(64, num_classes, 1)

    @staticmethod
    def _merge(upsampled, skip):
        """Pad ``upsampled`` to the spatial size of ``skip`` and concatenate on channels."""
        # Max-pooling floors odd sizes, so the upsampled map can be one pixel short.
        dh = skip.size(2) - upsampled.size(2)
        dw = skip.size(3) - upsampled.size(3)
        if dh != 0 or dw != 0:
            # F.pad order for the last two dims: (left, right, top, bottom).
            upsampled = F.pad(upsampled, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])
        return torch.cat([upsampled, skip], dim=1)

    def forward(self, x):
        """Return logits with ``num_classes`` channels and the spatial size of ``x``."""
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))
        m = self.middle(self.pool(e4))

        d1 = self.dec1(self._merge(self.up1(m), e4))
        d2 = self.dec2(self._merge(self.up2(d1), e3))
        d3 = self.dec3(self._merge(self.up3(d2), e2))
        d4 = self.dec4(self._merge(self.up4(d3), e1))

        return self.out(d4)

# --- Metrics ---
def compute_metrics(preds, labels):
    """Score predictions against labels, ignoring pixels labelled 255.

    Args:
        preds: Tensor of predicted class ids.
        labels: Tensor of reference class ids, same number of elements as ``preds``.

    Returns:
        Dict with keys ``"accuracy"``, ``"f1"`` and ``"iou"``. F1 and IoU are
        macro-averaged by scikit-learn with ``zero_division=0``. All three are 0
        when no pixel remains after removing the 255 entries.
    """
    flat_preds = preds.flatten()
    flat_labels = labels.flatten()

    # Keep only the positions whose reference label is not the ignore value.
    keep = flat_labels != 255
    y_pred = flat_preds[keep]
    y_true = flat_labels[keep]

    n_valid = len(y_true)
    if n_valid == 0:
        return {"accuracy": 0, "f1": 0, "iou": 0}

    accuracy = (y_pred == y_true).sum().item() / n_valid

    # scikit-learn needs host memory; move each tensor across once.
    y_true_cpu = y_true.cpu()
    y_pred_cpu = y_pred.cpu()
    macro_f1 = f1_score(y_true_cpu, y_pred_cpu, average='macro', zero_division=0)
    macro_iou = jaccard_score(y_true_cpu, y_pred_cpu, average='macro', zero_division=0)

    return {"accuracy": accuracy, "f1": macro_f1, "iou": macro_iou}

# --- Training Loop with Checkpoints ---
def train_model(model, train_loader, val_loader, device, epochs=10, checkpoint_dir="checkpoints"):
    """Train ``model`` and keep the weights of the epoch with the best validation IoU.

    Each epoch runs one pass over ``train_loader`` with AdamW (lr 1e-4) and a
    cross-entropy loss that ignores label 255, then evaluates on ``val_loader`` by
    averaging the per-batch results of ``compute_metrics``. Whenever the averaged
    IoU improves, the model's ``state_dict`` is written to
    ``<checkpoint_dir>/best_model.pth``. The first evaluated epoch always counts
    as an improvement, so the checkpoint file exists after any run with
    ``epochs >= 1``.

    Args:
        model: Network mapping an image batch to per-pixel class logits.
        train_loader: Iterable of ``(images, masks)`` training batches.
        val_loader: Iterable of ``(images, masks)`` validation batches.
        device: Device the model and the batches are moved to.
        epochs: Number of passes over ``train_loader``.
        checkpoint_dir: Directory for ``best_model.pth``; created if missing.

    Returns:
        None. Progress is printed and the best weights are written to disk.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    criterion = nn.CrossEntropyLoss(ignore_index=255)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    model.to(device)
    # -inf rather than 0 so that an IoU of exactly 0 still produces a checkpoint.
    best_iou = float("-inf")

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        steps = 0  # batches that actually contributed to total_loss
        for images, masks in train_loader:
            images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)

            # A NaN loss would poison the weights; drop the batch instead.
            if torch.isnan(loss):
                print("⚠️ Skipping batch with NaN loss")
                continue

            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            steps += 1

        # Average over the batches that were used, not over skipped ones;
        # max(..., 1) also covers an empty loader.
        print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss / max(steps, 1):.4f}")

        # Validation and checkpoint
        model.eval()
        with torch.no_grad():
            val_metrics = {"accuracy": 0, "f1": 0, "iou": 0}
            count = 0
            for images, masks in val_loader:
                images, masks = images.to(device), masks.to(device)
                outputs = model(images)
                preds = outputs.argmax(1)
                metrics = compute_metrics(preds, masks)
                for k in val_metrics:
                    val_metrics[k] += metrics[k]
                count += 1
            for k in val_metrics:
                val_metrics[k] /= max(count, 1)
            print(f"Val Acc: {val_metrics['accuracy']:.3f}, F1: {val_metrics['f1']:.3f}, IoU: {val_metrics['iou']:.3f}")

            if val_metrics['iou'] > best_iou:
                best_iou = val_metrics['iou']
                torch.save(model.state_dict(), os.path.join(checkpoint_dir, "best_model.pth"))
                print("✅ Saved new best model!")

# --- Visualization ---
def visualize_predictions(model, dataloader, device):
    """Plot input, ground truth and prediction for up to two samples of the first batch.

    For each shown sample a three-panel figure is drawn: channels 2, 1, 0 of the
    input as an RGB image, the reference mask, and the arg-max prediction. Both
    mask panels share one gray scale running from class 0 to the highest class id
    the model can output, so equal class ids look the same in both panels. Mask
    values above that range (such as the 255 ignore value) are drawn at the
    brightest level.

    Args:
        model: Network mapping an image batch to per-pixel class logits.
        dataloader: Iterable of ``(images, masks)`` batches; only the first is used.
        device: Device the batch is moved to before inference.
    """
    model.eval()
    with torch.no_grad():
        for images, masks in dataloader:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            preds = outputs.argmax(1)
            # Highest class id the model can predict; fixes the shared colour scale.
            top_class = max(outputs.size(1) - 1, 1)
            for i in range(min(2, images.size(0))):
                # Channels 2, 1, 0 -> B4, B3, B2 according to the dataset's band comment.
                rgb = images[i, [2, 1, 0]].cpu().numpy().transpose(1, 2, 0)
                gt = masks[i].cpu().numpy()
                pred = preds[i].cpu().numpy()
                fig, ax = plt.subplots(1, 3, figsize=(12, 4))
                ax[0].imshow(rgb)
                ax[0].set_title("Sentinel-2 RGB")
                ax[1].imshow(gt, cmap="gray", vmin=0, vmax=top_class)
                ax[1].set_title("Ground Truth")
                ax[2].imshow(pred, cmap="gray", vmin=0, vmax=top_class)
                ax[2].set_title("Prediction")
                plt.show()
                # Release the figure so repeated calls do not accumulate open figures.
                plt.close(fig)
            break

# --- Run Everything ---
taco = tacoreader.load("mini.taco")
dataset = CloudSEN129BandDataset(taco)

# 70 % / 15 % / 15 % split; the test split absorbs the rounding remainder.
train_len = int(0.7 * len(dataset))
val_len = int(0.15 * len(dataset))
test_len = len(dataset) - train_len - val_len

# Seed the split so the same samples land in train/val/test on every run.
# Without it, re-running this cell and then evaluating a previously saved
# checkpoint could score the model on samples it was trained on.
split_generator = torch.Generator().manual_seed(42)
train_set, val_set, test_set = random_split(
    dataset, [train_len, val_len, test_len], generator=split_generator
)
train_loader = DataLoader(train_set, batch_size=4, shuffle=True)
val_loader = DataLoader(val_set, batch_size=4)
test_loader = DataLoader(test_set, batch_size=4)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(in_channels=9, num_classes=3)

train_model(model, train_loader, val_loader, device, epochs=40)


In [None]:
import pandas as pd

def save_metrics_csv(model, dataloader, device, csv_path="model_metrics.csv"):
    """Evaluate ``model`` on ``dataloader`` and write the averaged metrics to a CSV file.

    The per-batch results of ``compute_metrics`` are summed and divided by the
    number of batches (by 1 when there are none). The file holds a header row
    ``accuracy,f1,iou`` followed by a single row of values.

    Args:
        model: Network mapping an image batch to per-pixel class logits.
        dataloader: Iterable of ``(images, masks)`` batches.
        device: Device the batches are moved to before inference.
        csv_path: Destination of the CSV file.
    """
    metric_names = ("accuracy", "f1", "iou")
    totals = dict.fromkeys(metric_names, 0)
    n_batches = 0

    model.eval()
    with torch.no_grad():
        for images, masks in dataloader:
            images = images.to(device)
            masks = masks.to(device)
            batch_preds = model(images).argmax(1)
            batch_metrics = compute_metrics(batch_preds, masks)
            for name in metric_names:
                totals[name] += batch_metrics[name]
            n_batches += 1

    # Guard the division so an empty loader yields zeros instead of an error.
    divisor = max(n_batches, 1)
    averaged = {name: totals[name] / divisor for name in metric_names}

    pd.DataFrame([averaged]).to_csv(csv_path, index=False)
    print(f"📊 Saved metrics to {csv_path}")


In [None]:
# Save final weights
torch.save(model.state_dict(), "final_model_weights.pth")

# Optionally copy best checkpoint to user-accessible path
import os
import shutil

# The copy is optional: skip it with a message when no best checkpoint exists,
# instead of failing the cell after the final weights were already saved.
if os.path.exists("checkpoints/best_model.pth"):
    shutil.copy("checkpoints/best_model.pth", "best_model_weights.pth")
    print("💾 Saved model weights: final_model_weights.pth and best_model_weights.pth")
else:
    print("💾 Saved model weights: final_model_weights.pth (checkpoints/best_model.pth not found, nothing copied)")


In [None]:
from torchvision.utils import save_image
import os

def save_sample_predictions(model, dataloader, device, save_dir="sample_outputs", num_samples=20):
    """Write input / prediction / ground-truth PNGs for up to ``num_samples`` samples.

    For sample number ``k`` three files are written into ``save_dir``:
    ``input_k.png`` (channels 2, 1, 0 of the input), ``pred_k.png`` (arg-max
    class id divided by 2.0) and ``gt_k.png`` (reference mask divided by 2.0).
    Saving stops after ``num_samples`` samples or when the loader is exhausted,
    whichever comes first; the number actually saved is printed at the end.

    Args:
        model: Network mapping an image batch to per-pixel class logits.
        dataloader: Iterable of ``(images, masks)`` batches.
        device: Device the images are moved to before inference.
        save_dir: Output directory; created if missing.
        num_samples: Maximum number of samples to save; values <= 0 save nothing.
    """
    os.makedirs(save_dir, exist_ok=True)
    model.eval()
    count = 0

    if num_samples > 0:
        with torch.no_grad():
            for images, masks in dataloader:
                images = images.to(device)
                outputs = model(images)
                preds = outputs.argmax(1)

                # Never take more from this batch than is still wanted.
                for i in range(min(images.size(0), num_samples - count)):
                    rgb = images[i, [2, 1, 0]].cpu()
                    # Add a channel axis so the masks are saved as 1-channel images.
                    pred = preds[i].cpu().unsqueeze(0)
                    gt = masks[i].cpu().unsqueeze(0)

                    save_image(rgb, os.path.join(save_dir, f"input_{count}.png"))
                    # Dividing by 2.0 maps class ids 0/1/2 to 0.0/0.5/1.0.
                    # NOTE(review): mask value 255 becomes 127.5 here and is
                    # presumably saturated to white by save_image - confirm.
                    save_image(pred / 2.0, os.path.join(save_dir, f"pred_{count}.png"))
                    save_image(gt / 2.0, os.path.join(save_dir, f"gt_{count}.png"))

                    count += 1

                if count >= num_samples:
                    break

    # Report what was really written, also when the loader ran out early.
    print(f"📸 Saved {count} prediction samples to {save_dir}")


In [None]:
# Metrics and sample images for the model as it stands after the last training epoch.
# The metrics go to their own CSV: the best-checkpoint evaluation in the next cell
# writes the default model_metrics.csv and would otherwise overwrite these numbers.
save_metrics_csv(model, test_loader, device, csv_path="final_model_metrics.csv")
save_sample_predictions(model, test_loader, device)


In [None]:
# Load best model before evaluating test metrics
# map_location remaps the stored tensors to the current device, so a checkpoint
# written on a GPU machine also loads on a CPU-only one.
model.load_state_dict(torch.load("checkpoints/best_model.pth", map_location=device))
model.to(device)

# Then run metrics or visualize_predictions
save_metrics_csv(model, test_loader, device)  # e.g., saves model_metrics.csv
