In [None]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import rasterio
from rasterio.windows import Window
from tqdm import tqdm
import torchvision.transforms as T
import segmentation_models_pytorch as smp

In [None]:
# Inference configuration
# Trained checkpoint that is loaded into the network.
MODEL_PATH = 'models/resnext101-32x8d-sentinel2-8band/best_model_ce_tv_d13_m07_y2025_h16_m56_s23.pth'

# Network definition: encoder backbone, input channels and output classes.
ENCODER_NAME = "resnext101_32x8d"
ENCODER_WEIGHTS = "imagenet"
NUM_CHANNELS = 8
NUM_CLASSES = 11

# Default number of windows per forward pass.
BATCH_SIZE = 32

# Run on the GPU when one is visible to torch, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
import rasterio
from rasterio.windows import Window

# Cut a square tile out of the top-left corner of the full composite and save
# it as its own GeoTIFF (all bands, georeferencing shifted to the window).
IMG_PATH = "data/Sentinel2_Annual_Composite_2025-0000189440-0000246272_8band.tif"
CROP_PATH = "data/sentinel2_4864_top_left.tif"
CROP_SIZE = 4864  # side length of the square cut-out, in pixels

with rasterio.open(IMG_PATH) as src:
    H, W = src.height, src.width
    # Fail early with a clear message instead of a shape mismatch on write.
    if H < CROP_SIZE or W < CROP_SIZE:
        raise ValueError(
            f"Source raster is {W}x{H} px, smaller than the requested "
            f"{CROP_SIZE}x{CROP_SIZE} crop"
        )

    # Window(col_off, row_off, width, height): anchored at the top-left pixel.
    win = Window(0, 0, CROP_SIZE, CROP_SIZE)
    kwargs = src.meta.copy()
    kwargs.update({
        'height': CROP_SIZE,
        'width': CROP_SIZE,
        # Affine transform of the window so the crop stays georeferenced.
        'transform': rasterio.windows.transform(win, src.transform),
    })
    with rasterio.open(CROP_PATH, 'w', **kwargs) as dst:
        # Copy band by band (rasterio band indices are 1-based) to keep the
        # peak memory at one band of the window.
        for i in range(1, src.count + 1):
            dst.write(src.read(i, window=win), i)
print('Cutout saved to:', CROP_PATH)

In [None]:
# Model setup
# Build the architecture only. ENCODER_WEIGHTS is deliberately not passed:
# load_state_dict() below is strict by default and replaces every parameter and
# buffer with the checkpoint's values, so fetching pretrained encoder weights
# first would only add a download (and a network dependency) for nothing.
model = smp.UnetPlusPlus(
    encoder_name=ENCODER_NAME,
    encoder_weights=None,
    in_channels=NUM_CHANNELS,
    classes=NUM_CLASSES,
    activation=None  # raw logits; softmax is applied at inference time
).to(DEVICE)

# The checkpoint is a dict whose 'model_state_dict' entry holds the weights.
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
print("Model loaded successfully")

In [None]:
class FillInvalid:
    """Replace non-finite pixels (NaN, +inf, -inf) of a channel-first image.

    Fill rule per channel ``c`` of an input indexed as ``image[c]``:

    * the last ``last_k`` channels are filled with the median of that
      channel's own finite pixels;
    * every other channel is filled with ``global_means[c]``.

    If one of the last ``last_k`` channels has no finite pixel at all, its
    median is undefined, so the channel falls back to ``global_means[c]``
    (or ``0.0`` when ``global_means`` has no entry for it). This guarantees a
    fully finite output.

    Args:
        global_means: Indexable sequence of per-channel fill values.
        last_k: Number of trailing channels filled by their own median.
    """

    def __init__(self, global_means, last_k=2):
        self.global_means = global_means
        self.last_k = last_k

    def __call__(self, image):
        """Return ``image`` as a float32 tensor with all invalid pixels filled.

        Note: ``.float()`` returns the same object for a float32 tensor, so
        such an input is modified in place; always use the returned tensor.
        """
        img = torch.as_tensor(image).float()
        invalid = ~torch.isfinite(img)
        C = img.shape[0]

        for c in range(C):
            mask_c = invalid[c]
            if not mask_c.any():
                continue

            if c >= C - self.last_k:
                # Median over finite pixels only: ignores NaN *and* +-inf, and
                # lets us detect the "nothing valid" case explicitly.
                valid = img[c][~mask_c]
                if valid.numel() > 0:
                    fill = torch.median(valid)
                elif c < len(self.global_means):
                    fill = self.global_means[c]
                else:
                    fill = 0.0
            else:
                fill = self.global_means[c]

            # img[c] is a view, so the masked assignment writes into img.
            img[c][mask_c] = fill

        return img

# Per-channel fill values handed to FillInvalid, one entry per input channel.
GLOBAL_MEANS = [
    0.03402944654226303,
    0.04915359988808632,
    0.056084536015987396,
    0.1244724690914154,
    0.12229487299919128,
    0.09260836988687515,
    0.17983973026275635,
    -0.011018575169146061,
]
fill_invalid = FillInvalid(GLOBAL_MEANS)

# Three-channel normalisation built from the ImageNet mean / std statistics.
imagenet_mean, imagenet_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
normalize = T.Normalize(imagenet_mean, imagenet_std)

In [None]:
class SlidingWindowDataset(Dataset):
    """Dataset of square windows tiled over a raster file.

    Windows start on a regular ``stride`` grid. When the grid does not land
    exactly on the bottom / right edge, one extra edge-aligned window is added
    per axis so that every pixel of the raster is covered by at least one
    window. A raster smaller than ``window_size`` along either axis yields an
    empty dataset.

    Args:
        img_path: Path of the raster opened with rasterio.
        fill_invalid: Callable applied to each window tensor to replace
            invalid pixels.
        window_size: Side length of each square window, in pixels.
        stride: Step between consecutive window origins, in pixels (> 0).

    Each item is ``(tensor, y, x)`` where ``(y, x)`` is the window's top-left
    pixel (row, column) in the raster.
    """

    def __init__(self, img_path, fill_invalid, window_size=256, stride=128):
        self.img_path = img_path
        self.window_size = window_size
        self.stride = stride
        self.fill_invalid = fill_invalid
        with rasterio.open(img_path) as src:
            self.H, self.W = src.height, src.width
        row_starts = self._window_starts(self.H, window_size, stride)
        col_starts = self._window_starts(self.W, window_size, stride)
        # Row-major order: all columns of the first row, then the next row.
        self.window_coords = [(y, x) for y in row_starts for x in col_starts]

    @staticmethod
    def _window_starts(length, window_size, stride):
        """Return window origins along one axis, always including the edge."""
        if stride <= 0:
            raise ValueError(f"stride must be positive, got {stride}")
        last = length - window_size
        if last < 0:
            return []
        starts = list(range(0, last + 1, stride))
        # The regular grid stops short of the edge unless it divides evenly;
        # add one final window flush with the edge to cover the remainder.
        if starts[-1] != last:
            starts.append(last)
        return starts

    def __len__(self):
        return len(self.window_coords)

    def __getitem__(self, idx):
        y, x = self.window_coords[idx]
        # The file is reopened for every item, so no handle is kept on self.
        with rasterio.open(self.img_path) as src:
            patch = src.read(window=Window(x, y, self.window_size, self.window_size)).astype(np.float32)
        tensor = torch.from_numpy(patch)   # shape: (bands, window, window)
        # NOTE(review): normalisation runs before the invalid-pixel fill, so
        # fill values land in already-normalised channels. Presumably this
        # mirrors the training / test pipeline -- confirm before reordering.
        tensor[:3] = normalize(tensor[:3]) # Normalize first 3 channels
        tensor = self.fill_invalid(tensor) # Fill invalids (on tensor)
        return tensor, y, x

In [None]:
def batched_sliding_window_inference(
    IMG_PATH,
    OUT_PATH,
    model,
    fill_invalid,
    NUM_CLASSES,
    DEVICE,
    window_size=256,
    stride=128,
    batch_size=16,
    verbose=True
):
    """
    Sliding window inference using DataLoader batching (this implementation should match the test pipeline).

    Every window is pushed through ``model``; the per-class softmax
    probabilities of overlapping windows are summed per pixel, divided by the
    number of windows that touched the pixel, and the arg-max class is written
    to ``OUT_PATH`` as a single-band uint8 raster that reuses the input's
    metadata (with ``nodata`` set to 0).

    Args:
        IMG_PATH: Path of the input raster.
        OUT_PATH: Path of the class-index raster to write.
        model: Network called as ``model(batch)`` returning per-class logits
            with the same spatial size as the window.
        fill_invalid: Callable forwarded to ``SlidingWindowDataset``.
        NUM_CLASSES: Number of output classes (size of the logits' dim 1).
        DEVICE: Device the batches are moved to (``torch.device`` or string).
        window_size: Side length of each square window, in pixels.
        stride: Step between window origins, in pixels.
        batch_size: Number of windows per forward pass.
        verbose: Show a progress bar and print the output path.

    Returns:
        The ``(H, W)`` uint8 array of predicted class indices. Pixels that no
        window covered keep zero probability everywhere and come out as 0.
    """
    # Prepare dataset and loader
    dataset = SlidingWindowDataset(IMG_PATH, fill_invalid, window_size, stride)
    # Pinned host memory only helps host-to-GPU copies; requesting it without
    # an accelerator just makes the DataLoader emit a warning.
    use_pinned = torch.device(DEVICE).type == "cuda"
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=use_pinned)

    with rasterio.open(IMG_PATH) as src:
        H, W = src.height, src.width
        meta = src.meta.copy()
        meta.update({"count": 1, "dtype": 'uint8', "nodata": 0})

    # Running sum of class probabilities and number of contributing windows.
    prob_accumulator = np.zeros((NUM_CLASSES, H, W), dtype=np.float32)
    count_mask = np.zeros((H, W), dtype=np.float32)

    model.eval()
    with torch.no_grad():
        for batch, ys, xs in tqdm(loader, total=len(loader), disable=not verbose, desc="Sliding Window"):
            batch = batch.to(DEVICE)  # shape: (B, bands, window, window)
            logits = model(batch)  # shape: (B, NUM_CLASSES, window, window)
            probs = torch.softmax(logits, dim=1).cpu().numpy()

            # ys / xs are the collated window origins for this batch.
            for prob, y, x in zip(probs, ys.tolist(), xs.tolist()):
                prob_accumulator[:, y:y+window_size, x:x+window_size] += prob
                count_mask[y:y+window_size, x:x+window_size] += 1

    # Normalize accumulated probabilities and compute final mask
    count_mask[count_mask == 0] = 1  # avoid division by zero
    # Divide in place: a separate result array would double the (C, H, W)
    # float32 footprint for no benefit.
    prob_accumulator /= count_mask[None, :, :]
    final_mask = np.argmax(prob_accumulator, axis=0).astype(np.uint8)

    with rasterio.open(OUT_PATH, 'w', **meta) as dst:
        dst.write(final_mask, 1)

    if verbose:
        print(f"Saved sliding window prediction to {OUT_PATH}")

    return final_mask

In [None]:
import os

# Set parameters
STRIDE = 16
IMG_PATH = "data/sentinel2_4864_top_left.tif"
OUT_PATH = f"predictions/prediction_mask_{STRIDE}.tif"
NUM_CLASSES = 11
BATCH_SIZE = 16

# Create whatever directory OUT_PATH points into, so the two cannot drift
# apart if the output location is changed.
os.makedirs(os.path.dirname(OUT_PATH), exist_ok=True)

# DEVICE is the torch device selected in the configuration cell.
mask = batched_sliding_window_inference(
    IMG_PATH,
    OUT_PATH,
    model,
    fill_invalid,
    NUM_CLASSES,
    DEVICE,
    window_size=256,
    stride=STRIDE,
    batch_size=BATCH_SIZE,
)

In [None]:
# RGBA colour per class index, components in [0, 1]; row i colours class i.
PALETTE = np.array([
    (0.0, 0.0, 0.0, 0.0),   # 0: transparent background
    (1.0, 0.0, 0.0, 0.7),   # 1: red
    (0.0, 1.0, 0.0, 0.7),   # 2: green
    (0.0, 0.0, 1.0, 0.7),   # 3: blue
    (1.0, 1.0, 0.0, 0.7),   # 4: yellow
    (1.0, 0.0, 1.0, 0.7),   # 5: magenta
    (0.0, 1.0, 1.0, 0.7),   # 6: cyan
    (1.0, 0.5, 0.0, 0.7),   # 7: orange
    (0.5, 0.0, 1.0, 0.7),   # 8: purple
    (0.0, 0.5, 0.5, 0.7),   # 9: teal
    (0.5, 0.5, 0.0, 0.7),   # 10: olive
])

In [None]:
import rasterio
import numpy as np
import matplotlib.pyplot as plt

def save_colored_mask(pred_mask_path, output_path, palette):
    """Write a 3-band uint8 RGB raster that colours a class-index mask.

    Args:
        pred_mask_path: Raster whose first band holds class indices.
        output_path: Destination raster; it reuses the mask's profile with
            ``count`` set to 3 and ``dtype`` to uint8.
        palette: Array-like with one row per class. Only the first three
            columns (R, G, B) are used. A palette whose maximum is <= 1.01 is
            treated as 0-1 floats and scaled by 255; otherwise values are
            taken as 0-255 already.

    Raises:
        IndexError: If the mask contains an index with no palette row.
    """
    with rasterio.open(pred_mask_path) as src:
        mask = src.read(1)
        profile = src.profile.copy()

    # Convert the small palette to uint8 once and index with it, instead of
    # building a full-resolution float RGBA image and casting that. The cast
    # is element-wise, so the resulting bytes are identical.
    palette = np.asarray(palette)
    if palette.max() <= 1.01:
        lut = (palette * 255).astype(np.uint8)
    else:
        lut = palette.astype(np.uint8)
    color_mask = lut[mask, :3]  # (H, W, 3)

    profile.update({'count': 3, 'dtype': 'uint8'})

    with rasterio.open(output_path, 'w', **profile) as dst:
        # rasterio band indices are 1-based.
        for band in range(3):
            dst.write(color_mask[:, :, band], band + 1)

    print(f"Colored mask saved to: {output_path}")


def save_overlay_png(img_path, mask_path, palette, out_path="overlay.png", alpha=0.9, dpi=300):
    """Blend a class mask over a contrast-stretched RGB view and save an image.

    Bands 3, 2, 1 of ``img_path`` are stacked as R, G, B and stretched between
    the 2nd and 98th percentile computed over all three bands together.
    Wherever the mask is non-zero, the pixel becomes
    ``(1 - alpha) * rgb + alpha * palette[class][:3]``; class 0 pixels keep the
    plain image.

    Args:
        img_path: Raster providing the background image.
        mask_path: Raster whose first band holds class indices; must have the
            same height and width as the image.
        palette: Array-like with one row per class, RGB in the first three
            columns, expected in the 0-1 range.
        out_path: Destination image file.
        alpha: Weight of the class colour in the blend.
        dpi: Figure resolution; the figure is sized to ``pixels / dpi`` inches.
    """
    with rasterio.open(img_path) as src:
        rgb = np.stack([src.read(3), src.read(2), src.read(1)], axis=-1).astype(np.float32)

    # Non-finite pixels would turn np.percentile into NaN and blank the whole
    # picture: mark them as NaN, take NaN-aware percentiles, and render them
    # black after the stretch.
    rgb[~np.isfinite(rgb)] = np.nan
    lo, hi = np.nanpercentile(rgb, [2, 98])
    rgb = np.clip((rgb - lo) / (hi - lo + 1e-8), 0, 1)
    rgb = np.nan_to_num(rgb, nan=0.0)

    with rasterio.open(mask_path) as src:
        mask = src.read(1)

    overlay = rgb.copy()
    non_bg = mask != 0
    # Look up colours only for the pixels that are blended, rather than
    # materialising a full-resolution colour image.
    class_rgb = np.asarray(palette)[mask[non_bg], :3]
    overlay[non_bg] = (1 - alpha) * rgb[non_bg] + alpha * class_rgb

    overlay = np.clip(overlay, 0, 1)
    h, w = overlay.shape[:2]
    figsize = (w / dpi, h / dpi)

    fig = plt.figure(figsize=figsize, dpi=dpi)
    try:
        plt.axis('off')
        plt.imshow(overlay)
        plt.tight_layout(pad=0)
        plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
        plt.savefig(out_path, bbox_inches='tight', pad_inches=0)
    finally:
        # Release the figure even if rendering or saving fails.
        plt.close(fig)
    print(f"Overlay saved to: {out_path}")


In [None]:
# Render the prediction produced with this stride as a coloured raster and as
# a PNG overlay on top of the cropped image.
STRIDE = 16
IMG_PATH = "data/sentinel2_4864_top_left.tif"
PRED_PATH = f"predictions/prediction_mask_{STRIDE}.tif"
MASK_RGB_PATH = f"predictions/colored_mask_{STRIDE}.tif"
OVERLAY_PNG_PATH = f"predictions/overlay_visualization_{STRIDE}.png"

save_colored_mask(
    pred_mask_path=PRED_PATH,
    output_path=MASK_RGB_PATH,
    palette=PALETTE,
)
save_overlay_png(
    img_path=IMG_PATH,
    mask_path=PRED_PATH,
    palette=PALETTE,
    out_path=OVERLAY_PNG_PATH,
)