In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T

import rasterio
from rasterio import features
import geopandas as gpd
import shapely.geometry as geom

In [2]:
# ---------------------------------------------
# PART A: DATA PREPARATION (ROI + GeoJSON)
# ---------------------------------------------
def create_mask_from_geojson(roi_tif_path, geojson_path):
    """
    Build a binary segmentation mask for an ROI raster from vector polygons.

    Steps:
        1) Load the first band of the ROI raster with rasterio.
        2) Reproject the vector geometries to the raster CRS (when both CRSs
           are known) and clip them to the raster bounds.
        3) Rasterize the clipped geometries (0=background, 1=object).

    Args:
        roi_tif_path: Path of the ROI raster, opened with ``rasterio.open``.
        geojson_path: Path of the vector file, read with ``gpd.read_file``.

    Returns:
        roi_img (np.ndarray): First band of the raster.
        mask (np.ndarray): uint8 mask of shape (height, width); all zeros
            when no geometry overlaps the ROI.
        metadata (dict): 'transform', 'crs', 'width' and 'height' of the ROI.
    """
    # -- Load ROI TIF --
    with rasterio.open(roi_tif_path) as roi_src:
        roi_img = roi_src.read(1)  # read first band; adapt if multi-band
        roi_transform = roi_src.transform
        roi_crs = roi_src.crs
        roi_width = roi_src.width
        roi_height = roi_src.height
        roi_bounds = roi_src.bounds

    # -- Read GeoJSON as GeoDataFrame --
    geo_df = gpd.read_file(geojson_path)

    # -- Bring the polygons into the raster's coordinate system --
    # Clipping and rasterizing compare raw coordinates, so geometries in a
    # different CRS (e.g. lon/lat GeoJSON vs. a projected raster) would
    # otherwise miss the ROI entirely. Skipped when either CRS is unknown.
    if geo_df.crs is not None and roi_crs is not None:
        geo_df = geo_df.to_crs(roi_crs.to_wkt())

    # -- Clip polygons to ROI bounding box --
    roi_polygon = geom.box(*roi_bounds)
    geo_df_clipped = gpd.clip(geo_df, roi_polygon)

    # -- Rasterize polygons to a mask --
    # The loop variable is deliberately not called `geom`, to avoid shadowing
    # the `shapely.geometry as geom` module alias used above.
    shapes_to_rasterize = [
        (shape, 1)
        for shape in geo_df_clipped.geometry
        if shape is not None and not shape.is_empty
    ]
    if shapes_to_rasterize:
        mask = features.rasterize(
            shapes=shapes_to_rasterize,
            out_shape=(roi_height, roi_width),
            transform=roi_transform,
            fill=0,
            all_touched=True,
            dtype="uint8"
        )
    else:
        # rasterize() raises when given no valid shapes; an ROI without any
        # labelled object is simply an all-background mask.
        mask = np.zeros((roi_height, roi_width), dtype=np.uint8)

    # -- Prepare metadata to help with debugging or saving --
    metadata = {
        'transform': roi_transform,
        'crs': roi_crs,
        'width': roi_width,
        'height': roi_height
    }
    return roi_img, mask, metadata


class ROISegDataset(Dataset):
    """
    Minimal PyTorch Dataset that serves a single ROI image/mask pair.

    The whole image is exposed as one sample; for large rasters you would
    typically tile the image into patches instead.
    """
    def __init__(self, roi_img, mask, transform=None):
        """
        Args:
            roi_img (np.ndarray): The ROI image, shape (H, W)
            mask (np.ndarray): The segmentation mask, shape (H, W)
            transform (callable, optional): Joint augmentation applied to the
                sample. It is called as ``transform(image, mask)`` with the
                two float32 tensors of shape (1, H, W) and must return the
                transformed ``(image, mask)`` pair, so that geometric
                augmentations stay aligned between image and mask.
        """
        super().__init__()
        self.roi_img = roi_img
        self.mask = mask
        self.transform = transform

    def __len__(self):
        return 1  # only one image/mask pair in this minimal example

    def __getitem__(self, idx):
        """
        Return the ``(image, mask)`` tensors for sample ``idx``.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
                This also lets plain ``for sample in dataset`` terminate.
        """
        size = len(self)
        position = idx + size if idx < 0 else idx
        if not 0 <= position < size:
            raise IndexError(f"index {idx} is out of range for dataset of size {size}")

        # Convert to float32 and add a leading channel axis: (1, H, W)
        image = np.expand_dims(self.roi_img.astype(np.float32), axis=0)
        label = np.expand_dims(self.mask.astype(np.float32), axis=0)

        image = torch.from_numpy(image)
        label = torch.from_numpy(label)

        if self.transform is not None:
            # Image and mask go through the transform together so random
            # flips/crops cannot desynchronise them.
            image, label = self.transform(image, label)

        return image, label

In [3]:
# ---------------------------------------------
# PART B: DEFINE A SIMPLE U-NET
# ---------------------------------------------
# Option 1: Implement a minimal U-Net from scratch (short version)
# Option 2: Use segmentation_models_pytorch (commented below)

class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions (padding=1), each followed by an in-place ReLU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        stage_in = in_channels
        # Conv -> ReLU, twice; only the first conv changes the channel count.
        for _ in range(2):
            stages.append(nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1))
            stages.append(nn.ReLU(inplace=True))
            stage_in = out_channels
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        activated = self.conv(x)
        return activated

class SimpleUNet(nn.Module):
    """
    Small two-level U-Net producing per-pixel logits.

    The encoder halves the resolution twice (64 -> 128 channels), the
    bottleneck works at 256 channels, and the decoder upsamples twice while
    concatenating the matching encoder features (skip connections).

    Input height/width do not need to be multiples of 4: upsampled maps are
    zero-padded to the size of their skip connection before concatenation.
    Each spatial side must still be at least 4 so that both pooling stages
    produce a non-empty feature map.

    Args:
        in_channels (int): Channels of the input image.
        out_channels (int): Channels of the output logit map.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super(SimpleUNet, self).__init__()
        # Down
        self.dc1 = DoubleConv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.dc2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)

        # Bottleneck
        self.dc3 = DoubleConv(128, 256)

        # Up (each DoubleConv sees upsampled + skip channels concatenated)
        self.up1 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dc4 = DoubleConv(256, 128)
        self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dc5 = DoubleConv(128, 64)

        # Output
        self.out_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    @staticmethod
    def _pad_to(upsampled, skip):
        """Zero-pad `upsampled` so its last two dims match those of `skip`."""
        # MaxPool2d floors odd sizes, so after the stride-2 transpose conv the
        # map can be one pixel short of the encoder feature it is joined with.
        diff_h = skip.size(-2) - upsampled.size(-2)
        diff_w = skip.size(-1) - upsampled.size(-1)
        if diff_h == 0 and diff_w == 0:
            return upsampled
        # F.pad order for the last two dims: (left, right, top, bottom)
        return nn.functional.pad(
            upsampled,
            [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
        )

    def forward(self, x):
        """Map an input batch (N, in_channels, H, W) to logits (N, out_channels, H, W)."""
        # Down
        x1 = self.dc1(x)
        x2 = self.pool1(x1)
        x3 = self.dc2(x2)
        x4 = self.pool2(x3)

        # Bottleneck
        x5 = self.dc3(x4)

        # Up: upsample, align with the skip tensor, concat along channels
        x6 = self._pad_to(self.up1(x5), x3)
        x6 = torch.cat([x6, x3], dim=1)
        x7 = self.dc4(x6)
        x8 = self._pad_to(self.up2(x7), x1)
        x8 = torch.cat([x8, x1], dim=1)
        x9 = self.dc5(x8)

        out = self.out_conv(x9)
        return out

"""
# ALTERNATIVE: Using segmentation_models_pytorch
# import segmentation_models_pytorch as smp
# model = smp.Unet(
#     encoder_name="resnet34",  # or any other encoder
#     encoder_weights="imagenet",
#     in_channels=1,
#     classes=1,
# )
"""

'\n# ALTERNATIVE: Using segmentation_models_pytorch\n# import segmentation_models_pytorch as smp\n# model = smp.Unet(\n#     encoder_name="resnet34",  # or any other encoder\n#     encoder_weights="imagenet",\n#     in_channels=1,\n#     classes=1,\n# )\n'

In [4]:
# ---------------------------------------------
# PART C: TRAINING AND VALIDATION LOOP
# ---------------------------------------------
def dice_coefficient(pred, target, smooth=1.0):
    """
    Soft Dice coefficient for binary segmentation, computed over the whole batch.

    Args:
        pred: Raw logits, expected shape (N, 1, H, W); a sigmoid is applied here.
        target: Ground-truth mask with the same number of elements as `pred`.
        smooth: Constant added to numerator and denominator; it avoids a
            division by zero and makes an empty prediction on an empty
            target score 1.

    Returns:
        Scalar tensor: (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth).
    """
    probs = torch.sigmoid(pred)  # convert logits to [0..1]
    # reshape() instead of view(): view() raises on non-contiguous tensors
    # (e.g. after a transpose/permute), reshape() copies only when required.
    probs_flat = probs.reshape(-1)
    target_flat = target.reshape(-1)

    intersection = (probs_flat * target_flat).sum()
    return (2.0 * intersection + smooth) / (probs_flat.sum() + target_flat.sum() + smooth)


def train_one_epoch(model, loader, optimizer, criterion, device):
    """
    Run one optimisation pass over `loader` and return the mean batch loss.

    Args:
        model: Network to train; switched to training mode.
        loader: Iterable yielding ``(images, masks)`` tensor batches. It does
            not need to implement ``__len__``.
        optimizer: Optimizer stepping the model parameters.
        criterion: Loss called as ``criterion(outputs, masks)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        float: Unweighted mean of the per-batch loss values.

    Raises:
        ValueError: If `loader` yields no batches.
    """
    model.train()
    epoch_loss = 0.0
    num_batches = 0  # counted here so loaders without __len__ also work
    for images, masks in loader:
        images, masks = images.to(device), masks.to(device)
        # Forward
        outputs = model(images)
        loss = criterion(outputs, masks)
        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        num_batches += 1
    if num_batches == 0:
        raise ValueError("train_one_epoch received a loader that yielded no batches")
    return epoch_loss / num_batches


def validate_one_epoch(model, loader, device):
    """
    Evaluate `model` on `loader` and return the mean per-batch Dice score.

    The model is put in eval mode and gradients are disabled for the pass.
    """
    model.eval()
    per_batch_dice = []
    with torch.no_grad():
        for batch_images, batch_masks in loader:
            batch_images = batch_images.to(device)
            batch_masks = batch_masks.to(device)
            logits = model(batch_images)
            score = dice_coefficient(logits, batch_masks)
            per_batch_dice.append(score.item())
    mean_dice = np.mean(per_batch_dice)
    return mean_dice


In [5]:
# ---------------------------------------------
# PART D: MAIN SCRIPT TO RUN EVERYTHING
# ---------------------------------------------
def main(
    roi_tif_path="data/roi.tif",
    geojson_path="data/labels.geojson",
    num_epochs=10,
    learning_rate=1e-4,
    checkpoint_path="simple_unet_roi_seg.pth",
    seed=None,
):
    """
    Build the ROI mask, train a SimpleUNet on it and save the weights.

    Args:
        roi_tif_path: ROI raster passed to ``create_mask_from_geojson``.
        geojson_path: Polygon labels passed to ``create_mask_from_geojson``.
        num_epochs: Number of passes over the (single-sample) dataset.
        learning_rate: Learning rate of the Adam optimizer.
        checkpoint_path: Where the trained ``state_dict`` is written.
        seed: Optional integer passed to ``torch.manual_seed`` before the
            model is created, for reproducible runs. ``None`` leaves the
            random state untouched.
    """
    if seed is not None:
        torch.manual_seed(seed)

    # 1. Create ROI image + mask
    roi_img, roi_mask, metadata = create_mask_from_geojson(roi_tif_path, geojson_path)

    # 2. Create Dataset
    dataset = ROISegDataset(roi_img=roi_img, mask=roi_mask)
    # For demonstration, we'll use the same dataset for train & val, so the
    # reported "Val Dice" measures fit to the training image, not
    # generalisation. In real practice, use a separate hold-out set.
    loader = DataLoader(dataset, batch_size=1, shuffle=True)

    # 3. Initialize Model, Criterion, Optimizer
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = SimpleUNet(in_channels=1, out_channels=1).to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # 4. Train
    for epoch in range(num_epochs):
        train_loss = train_one_epoch(model, loader, optimizer, criterion, device)
        val_dice = validate_one_epoch(model, loader, device)

        print(f"Epoch [{epoch+1}/{num_epochs}] - Loss: {train_loss:.4f}, Val Dice: {val_dice:.4f}")

    # 5. Save the trained model
    torch.save(model.state_dict(), checkpoint_path)

    # Optionally, reload and do inference later
    # model.load_state_dict(torch.load(checkpoint_path))

if __name__ == "__main__":
    main()

  dataset = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)


Epoch [1/10] - Loss: 0.4834, Val Dice: 0.9212
Epoch [2/10] - Loss: 0.1630, Val Dice: 0.9632
Epoch [3/10] - Loss: 0.0768, Val Dice: 0.9784
Epoch [4/10] - Loss: 0.0451, Val Dice: 0.9860
Epoch [5/10] - Loss: 0.0292, Val Dice: 0.9906
Epoch [6/10] - Loss: 0.0195, Val Dice: 0.9936
Epoch [7/10] - Loss: 0.0133, Val Dice: 0.9956
Epoch [8/10] - Loss: 0.0091, Val Dice: 0.9970
Epoch [9/10] - Loss: 0.0062, Val Dice: 0.9979
Epoch [10/10] - Loss: 0.0042, Val Dice: 0.9986
