In [1]:
from utils import UNetResNet50, check_band_coverage, pansharpen_to_10m_and_save, read_pansharpened_tiff
import torch
import numpy as np
from tqdm import tqdm

In [19]:
check_band_coverage("data/S2A_MSIL1C_20210508T092031_N0500_R093_T34SEH_20230303T013318.SAFE/GRANULE/L1C_T34SEH_A030692_20210508T092338/IMG_DATA")


 L1C_T34SEH_A030692_20210508T092338

Band   Resolution Shape           Valid Pixels (%)  
------------------------------------------------------------
B01    60.0       (1830, 1830)    100.00            
B02    10.0       (10980, 10980)  100.00            
B03    10.0       (10980, 10980)  100.00            
B04    10.0       (10980, 10980)  100.00            
B05    20.0       (5490, 5490)    100.00            
B06    20.0       (5490, 5490)    100.00            
B07    20.0       (5490, 5490)    100.00            
B08    10.0       (10980, 10980)  100.00            
B09    60.0       (1830, 1830)    100.00            
B10    60.0       (1830, 1830)    100.00            
B11    20.0       (5490, 5490)    100.00            
B12    20.0       (5490, 5490)    100.00            
B8A    20.0       (5490, 5490)    100.00            


In [20]:
pansharpen_to_10m_and_save("data/S2A_MSIL1C_20210508T092031_N0500_R093_T34SEH_20230303T013318.SAFE/GRANULE/L1C_T34SEH_A030692_20210508T092338/IMG_DATA", output_tiff="data/T34SEH_pansharpened.tif")

['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12']
1 B01
2 B02
3 B03
4 B04
5 B05
6 B06
7 B07
8 B08
9 B8A
10 B09
11 B10
12 B11
13 B12

Saved pansharpened image to: data/T34SEH_pansharpened.tif


In [None]:
def predict_full_tile(model, tile_array, device, patch_size=128, stride=64, num_classes=10):
    """
    Run sliding window inference on a large tile.

    Args:
        model: Trained UNetResNet50 model.
        tile_array: Numpy array of shape (13, H, W).
        device: 'cuda' or 'cpu'.
        patch_size: Size of each square patch (default 128).
        stride: Step size between patches (default 64).
        num_classes: Number of classes in output mask.

    Returns:
        final_prediction: 2D numpy array (H, W) of predicted class IDs.
    """
    model.eval()
    _, H, W = tile_array.shape
    output_probs = np.zeros((num_classes, H, W), dtype=np.float32)
    count = np.zeros((H, W), dtype=np.float32)

    with torch.no_grad():
        for row in tqdm(range(0, H - patch_size + 1, stride)):
            for col in range(0, W - patch_size + 1, stride):
                patch = tile_array[:, row:row + patch_size, col:col + patch_size]
                patch_tensor = torch.from_numpy(patch).unsqueeze(0).to(device).float()  # (1, 13, 128, 128)

                logits = model(patch_tensor)  # (1, num_classes, 128, 128)
                probs = torch.softmax(logits, dim=1).squeeze(0).cpu().numpy()  # (num_classes, 128, 128)

                output_probs[:, row:row + patch_size, col:col + patch_size] += probs
                count[row:row + patch_size, col:col + patch_size] += 1

    avg_probs = output_probs / count.clip(min=1e-8)  # Avoid divide by zero
    final_prediction = np.argmax(avg_probs, axis=0).astype(np.uint8)
    np.save("predicted_mask.npy", final_prediction)
    print("Prediction Saved")
    return final_prediction


In [21]:
tile = read_pansharpened_tiff("data/T34SEH_pansharpened_aligned.tif")
tile.shape

Stack shape: (13, 10979, 10979)


(13, 10979, 10979)

In [22]:
tile = read_pansharpened_tiff("data/T34SEH_pansharpened.tif")
tile.shape

Stack shape: (13, 10980, 10980)


(13, 10980, 10980)

In [2]:
import rasterio
from tqdm import tqdm
import numpy as np
import torch

def predict_full_tile_streaming(model, tiff_path, device, patch_size=128, stride=64, num_classes=8, output_path="predicted_mask.npy"):
    """
    Perform sliding window inference on a large TIFF tile using disk streaming.

    Args:
        model: Trained model.
        tiff_path: Path to the input GeoTIFF.
        device: 'cuda' or 'cpu'.
        patch_size: Size of square patches.
        stride: Step size between patches.
        num_classes: Number of classes.
        output_path: Where to save the prediction.
    """
    model.eval()

    with rasterio.open(tiff_path) as src:
        H, W = src.height, src.width
        output_probs = np.zeros((num_classes, H, W), dtype=np.float32)
        count = np.zeros((H, W), dtype=np.float32)

        with torch.no_grad():
            for row in tqdm(range(0, H - patch_size + 1, stride)):
                for col in range(0, W - patch_size + 1, stride):
                    # Read patch from file: (bands, patch_size, patch_size)
                    patch = src.read(window=rasterio.windows.Window(col, row, patch_size, patch_size))
                    if patch.shape[1] != patch_size or patch.shape[2] != patch_size:
                        continue  # Skip edge cases if you don't pad

                    # Normalize and convert to torch tensor
                    patch = patch.astype(np.float32) / 10000.0
                    patch = np.clip(patch, 0.0, 1.0)
                    patch_tensor = torch.from_numpy(patch).unsqueeze(0).to(device)  # (1, 13, H, W)

                    logits = model(patch_tensor)
                    probs = torch.softmax(logits, dim=1).squeeze(0).cpu().numpy()

                    output_probs[:, row:row + patch_size, col:col + patch_size] += probs
                    count[row:row + patch_size, col:col + patch_size] += 1

        avg_probs = output_probs / np.clip(count, 1e-8, None)
        final_prediction = np.argmax(avg_probs, axis=0).astype(np.uint8)
        np.save(output_path, final_prediction)
        print(f"Prediction saved to {output_path}")
        return final_prediction


In [3]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = UNetResNet50(num_classes=8).to(device)
model.load_state_dict(torch.load("best_model_lr_0.0001.pt", map_location=device, weights_only=True))
pred_mask = predict_full_tile_streaming(model, "data/T34SEH_pansharpened.tif", device, patch_size=128, stride=64, num_classes=8, output_path="predicted_mask.npy")

100%|██████████| 170/170 [04:33<00:00,  1.61s/it]


Prediction saved to predicted_mask.npy


In [4]:
import rasterio
from rasterio.transform import from_origin

def save_prediction_geotiff(prediction, reference_tiff_path, output_path):
    """
    Save the predicted mask as a GeoTIFF using the georeferencing info from the original tile.

    Args:
        prediction (np.ndarray): 2D array of predicted class IDs (H, W).
        reference_tiff_path (str): Path to the original input GeoTIFF (e.g., Sentinel-2 tile).
        output_path (str): Path to save the output GeoTIFF.
    """
    with rasterio.open(reference_tiff_path) as src:
        profile = src.profile
        transform = src.transform
        crs = src.crs

    # Update profile for single-band uint8 mask
    profile.update({
        'driver': 'GTiff',
        'height': prediction.shape[0],
        'width': prediction.shape[1],
        'count': 1,
        'dtype': 'uint8',
        'transform': transform,
        'crs': crs
    })

    with rasterio.open(output_path, 'w', **profile) as dst:
        dst.write(prediction, 1)

    print(f"GeoTIFF saved to: {output_path}")


save_prediction_geotiff(pred_mask, "data/T34SEH_pansharpened.tif", "pred_mask.tif")

GeoTIFF saved to: pred_mask.tif


In [5]:
id2idx = {10: 0, 20: 1, 30: 2, 40: 3, 50: 4, 60: 5, 80: 6, 90: 7}
idx2id = {v: k for k, v in id2idx.items()}

# remapping
remapped_mask = np.vectorize(idx2id.get)(pred_mask).astype(np.uint8)
np.unique(pred_mask), np.unique(remapped_mask)

(array([0, 1, 2, 3, 4, 5, 6], dtype=uint8),
 array([10, 20, 30, 40, 50, 60, 80], dtype=uint8))

In [8]:
from sklearn.metrics import confusion_matrix
import numpy as np
import rasterio

def load_reference_mask(path):
    with rasterio.open(path) as src:
        mask = src.read(1)  # Read first band
    return mask


def compute_metrics(pred_mask, true_mask, class_ids):
    """
    Evaluate segmentation results.

    Args:
        pred_mask: 2D numpy array of predicted class IDs.
        true_mask: 2D numpy array of true class IDs.
        class_ids: List of valid class IDs (e.g., [10, 20, 30, ..., 90]).

    Returns:
        Dictionary with pixel accuracy, per-class IoU, and mean IoU.
    """
    # Flatten for easier computation
    print("Prediction shape:", pred_mask.shape)
    print("Ground truth shape:", true_mask.shape)

    pred = pred_mask.flatten()
    true = true_mask.flatten()

    # Mask out invalid labels (e.g., background or nodata in ground truth)
    valid = np.isin(true, class_ids)
    pred = pred[valid]
    true = true[valid]

    # Confusion matrix using class IDs directly
    cm = confusion_matrix(true, pred, labels=class_ids)

    intersection = np.diag(cm)
    ground_truth_set = cm.sum(axis=1)
    predicted_set = cm.sum(axis=0)
    union = ground_truth_set + predicted_set - intersection

    iou_per_class = intersection / np.clip(union, 1e-8, None)
    mean_iou = np.mean(iou_per_class)
    pixel_accuracy = np.sum(intersection) / np.clip(np.sum(cm), 1e-8, None)

    return {
        "pixel_accuracy": pixel_accuracy,
        "iou_per_class": dict(zip(class_ids, iou_per_class)),
        "mean_iou": mean_iou
    }


In [9]:
true_mask = load_reference_mask("data/GBDA24_ex2_34SEH_ref_data.tif")

# List of all valid class IDs
class_ids = [10, 20, 30, 40, 50, 60, 80, 90]

# Evaluate
metrics = compute_metrics(remapped_mask, true_mask, class_ids)
print("Pixel Accuracy:", metrics["pixel_accuracy"])
print("Per-Class IoU:", metrics["iou_per_class"])
print("Mean IoU:", metrics["mean_iou"])


Prediction shape: (10980, 10980)
Ground truth shape: (10980, 10980)
Pixel Accuracy: 0.7508253886707643
Per-Class IoU: {10: np.float64(0.7258283394437622), 20: np.float64(0.04968237770667938), 30: np.float64(0.18715757604524064), 40: np.float64(0.2584323683147882), 50: np.float64(0.1721664531071201), 60: np.float64(0.04611749934373815), 80: np.float64(0.956728629252614), 90: np.float64(0.0)}
Mean IoU: 0.29951415540174287
