In [2]:
import numpy as np
import pandas as pd
import torch

In [3]:
import os, struct, io, mmap
from pathlib import Path
import numpy as np
import tensorflow as tf
import torch
from torch.utils.data import Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2

def zero_pad_array(arr, target_hw):
    th, tw = target_hw
    h, w = arr.shape[:2]
    pad_h = max(th - h, 0)
    pad_w = max(tw - w, 0)
    pad = ((pad_h // 2, pad_h - pad_h // 2),
           (pad_w // 2, pad_w - pad_w // 2))
    if arr.ndim == 3: pad += ((0, 0),)
    return np.pad(arr, pad, mode='constant')

class BurnScarTFRecordDataset(Dataset):
    MEAN = np.array([1635.8452, 1584.4594, 1456.8425, 2926.6663, 2135.001, 1352.7313], dtype=np.float32)
    STD  = np.array([884.3994, 815.4016, 839.0293, 1055.6382, 751.4628, 628.5323], dtype=np.float32)

    _feature_desc = {                         
        "image_raw": tf.io.FixedLenFeature([], tf.string),
        "mask_raw":  tf.io.FixedLenFeature([], tf.string),
        "height":    tf.io.FixedLenFeature([], tf.int64),
        "width":     tf.io.FixedLenFeature([], tf.int64),
        "channels":  tf.io.FixedLenFeature([], tf.int64),
        "temporal_coords": tf.io.FixedLenFeature([2], tf.float32),
        "location_coords": tf.io.FixedLenFeature([2], tf.float32),
    }

    def __init__(self, tfrecord_file, transform=None, pad_to=(512, 512)):
        super().__init__()
        self.tfrecord_path = os.fspath(tfrecord_file)
        self.transform = transform
        self.pad_to = pad_to

        # ---- build a list of byte offsets ----------------------------------
        self._offsets = self._scan_index()
        self._fh = open(self.tfrecord_path, 'rb')

    def _scan_index(self):
        """Return a list with the starting byte of each record."""
        offsets = []
        with open(self.tfrecord_path, 'rb') as f:
            pos = 0
            while True:
                header = f.read(12)
                if not header: break
                rec_len = struct.unpack('<Q', header[:8])[0]
                offsets.append(pos)
                pos += 12 + rec_len + 4
                f.seek(pos)
        return offsets

    def _read_record(self, offset):
        """Seek & return the serialised Example bytes of one record."""
        self._fh.seek(offset)
        header = self._fh.read(12)
        rec_len = struct.unpack('<Q', header[:8])[0]
        data = self._fh.read(rec_len)
        _ = self._fh.read(4)
        return data

    def __len__(self):
        return len(self._offsets)

    def __getitem__(self, idx):
        serialised = self._read_record(self._offsets[idx])
        ex = tf.io.parse_single_example(serialised, self._feature_desc)

        h = int(ex["height"])
        w = int(ex["width"])
        c = int(ex["channels"])

        img = np.frombuffer(ex["image_raw"].numpy(), dtype=np.float32).reshape((h, w, c))
        msk = np.frombuffer(ex["mask_raw"].numpy(),  dtype=np.uint8).reshape((h, w))

        img = np.nan_to_num(img, nan=0.0)
        msk = np.nan_to_num(msk.astype(np.float32), nan=0.0).astype(np.uint8)

        img = (img - self.MEAN) / self.STD
        img = zero_pad_array(img, self.pad_to)
        msk = zero_pad_array(msk, self.pad_to)

        temporal_coords = ex['temporal_coords'].numpy().astype(np.float32)   # (2,)
        location_coords = ex['location_coords'].numpy().astype(np.float32)   # (2,)

        temporal_coords = np.expand_dims(temporal_coords, axis=0)            # (1, 2)

        if self.transform:
            augmented = self.transform(image=img, mask=msk)
            img, msk   = augmented["image"], augmented["mask"]
        else:
            img = torch.from_numpy(img.transpose(2, 0, 1))  # (C,H,W)
            msk = torch.from_numpy(msk)

        out = {
            "image": img.float(),
            "temporal_coords": torch.from_numpy(temporal_coords),
            "location_coords": torch.from_numpy(location_coords),
            "mask": msk.long()
        }
        return out

    def __del__(self):
        try:
            if hasattr(self, '_fh') and not self._fh.closed:
                self._fh.close()
        except Exception:
            pass


transform = A.Compose([
    A.RandomRotate90(p=0.7),       # p: probability of applying this transform
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    ToTensorV2(),
])

In [4]:
train_dataset = BurnScarTFRecordDataset(r"Z:\SPOT\2023\Asare\ssm_scars_unseen_val.tfrecord", transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0, drop_last=True)

In [5]:
from terratorch.tasks import SemanticSegmentationTask

model = SemanticSegmentationTask(
    model_factory="EncoderDecoderFactory",
    model_args=dict(
        backbone="terratorch_prithvi_eo_v2_300_tl",
        backbone_pretrained=True,
        backbone_img_size=512,
        backbone_bands=["BLUE","GREEN","RED","NIR_NARROW","SWIR_1","SWIR_2"],
        necks=[{"name":"SelectIndices", "indices":[5,11,17,23]},
               {"name":"ReshapeTokensToImage"}],
        decoder="FCNDecoder",              
        decoder_channels=256,
        num_classes=3,
        head_dropout=0.1,
    ),
    loss="dice",
    freeze_backbone=False,
    freeze_decoder=False,
    optimizer="AdamW",
    lr=1e-4,
)

  from .autonotebook import tqdm as notebook_tqdm
INFO:root:Loaded weights for HLSBands.BLUE in position 0 of patch embed
INFO:root:Loaded weights for HLSBands.GREEN in position 1 of patch embed
INFO:root:Loaded weights for HLSBands.RED in position 2 of patch embed
INFO:root:Loaded weights for HLSBands.NIR_NARROW in position 3 of patch embed
INFO:root:Loaded weights for HLSBands.SWIR_1 in position 4 of patch embed
INFO:root:Loaded weights for HLSBands.SWIR_2 in position 5 of patch embed


In [1]:
model = SemanticSegmentationTask(
    model_factory="EncoderDecoderFactory",
    model_args=dict(
        backbone="terratorch_prithvi_eo_v2_300_tl",
        backbone_pretrained=True,
        backbone_img_size=512,
        backbone_bands=["BLUE","GREEN","RED","NIR_NARROW","SWIR_1","SWIR_2"],
        necks=[
            {"name": "SelectIndices", "indices": [5, 11, 17, 23]},
            {"name": "ReshapeTokensToImage"}
        ],
        decoder="UNetDecoder",
        decoder_channels=[256, 128, 64, 32, 16],   # <--- THIS IS THE KEY
        num_classes=3,
        head_dropout=0.1,
    ),
    freeze_backbone=False,
    freeze_decoder=False,
    optimizer="AdamW",
    lr=1e-4,
)


NameError: name 'SemanticSegmentationTask' is not defined

In [6]:
model.load_state_dict(torch.load("prithvi_state_dict_300m.pt"))

<All keys matched successfully>

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )
    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)
    def forward(self, x1, x2):
        x1 = self.up(x1)
        # pad x1 to match x2 size (H, W)
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
    def forward(self, x):
        return self.conv(x)

class UNet(nn.Module):
    def __init__(self, n_channels=6, n_classes=1, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Channels: [64, 128, 256, 512, 512, 512, 512, 512] is memory-friendly
        self.inc    = DoubleConv(n_channels, 64)
        self.down1  = Down(64, 128)
        self.down2  = Down(128, 256)
        self.down3  = Down(256, 512)
        self.down4  = Down(512, 512)
        self.down5  = Down(512, 512)
        self.down6  = Down(512, 512)
        self.down7  = Down(512, 512)

        self.up1 = Up(512+512, 512, bilinear)
        self.up2 = Up(512+512, 512, bilinear)
        self.up3 = Up(512+512, 512, bilinear)
        self.up4 = Up(512+512, 512, bilinear)
        self.up5 = Up(512+256, 256, bilinear)
        self.up6 = Up(256+128, 128, bilinear)
        self.up7 = Up(128+64, 64, bilinear)

        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)       # 64, 512, 512
        x2 = self.down1(x1)    # 128, 256, 256
        x3 = self.down2(x2)    # 256, 128, 128
        x4 = self.down3(x3)    # 512, 64, 64
        x5 = self.down4(x4)    # 512, 32, 32
        x6 = self.down5(x5)    # 512, 16, 16
        x7 = self.down6(x6)    # 512, 8, 8
        x8 = self.down7(x7)    # 512, 4, 4

        u1 = self.up1(x8, x7)  # 512, 8, 8
        u2 = self.up2(u1, x6)  # 512, 16, 16
        u3 = self.up3(u2, x5)  # 512, 32, 32
        u4 = self.up4(u3, x4)  # 512, 64, 64
        u5 = self.up5(u4, x3)  # 256, 128, 128
        u6 = self.up6(u5, x2)  # 128, 256, 256
        u7 = self.up7(u6, x1)  # 64, 512, 512

        logits = self.outc(u7)
        return logits

# Example usage:
# model = UNet(n_channels=6, n_classes=1)
# x = torch.randn(1, 6, 512, 512)
# out = model(x)
# print(out.shape)  # Should print torch.Size([1, 1, 512, 512])


In [5]:
model = UNet(n_channels=6, n_classes=3, bilinear=True)
#

In [5]:
model.load_state_dict(torch.load("prithvi_state_dict_unet.pt"))
model.eval()

SemanticSegmentationTask(
  (model): PixelWiseModel(
    (encoder): PrithviViT(
      (patch_embed): PatchEmbed(
        (proj): Conv3d(6, 1024, kernel_size=(1, 16, 16), stride=(1, 16, 16))
        (norm): Identity()
      )
      (temporal_embed_enc): TemporalEncoder()
      (location_embed_enc): LocationEncoder()
      (blocks): ModuleList(
        (0-23): 24 x Block(
          (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (attn): Attention(
            (qkv): Linear(in_features=1024, out_features=3072, bias=True)
            (q_norm): Identity()
            (k_norm): Identity()
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=1024, out_features=1024, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (ls1): Identity()
          (drop_path1): Identity()
          (norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in

In [7]:
def calculate_precision_recall(ground_truth_mask, predicted_mask, num_classes):
    # print(ground_truth_mask.shape, predicted_mask.shape)
    precision = [None] * num_classes  # np.zeros(num_classes)
    recall = [None] * num_classes  # np.zeros(num_classes)
    precision_dict = dict()
    recall_dict = dict()
    for class_id in range(num_classes):
        true_positives = np.sum(
            (ground_truth_mask == class_id) & (predicted_mask == class_id))
        false_positives = np.sum(
            (ground_truth_mask != class_id) & (predicted_mask == class_id))
        false_negatives = np.sum(
            (ground_truth_mask == class_id) & (predicted_mask != class_id))
        p = None
        r = None
        if true_positives + false_positives != 0:
            p = true_positives / (true_positives + false_positives)
        if true_positives + false_negatives != 0:
            r = true_positives / (true_positives + false_negatives)
        precision_dict[class_id] = p
        recall_dict[class_id] = r
    return precision_dict, recall_dict


def calc_f1_score(precision, recall):
    if precision is None or recall is None or (precision + recall) == 0:
        return None
    return (2 * precision * recall) / (precision + recall)


def calculate_raw_scores(ground_truth_mask, predicted_mask, num_classes):
    # print(ground_truth_mask.shape, predicted_mask.shape)
    # precision = [None] * num_classes  # np.zeros(num_classes)
    # recall = [None] * num_classes  # np.zeros(num_classes)
    # precision_dict = dict()
    # recall_dict = dict()
    true_positives_dict = dict()
    false_positives_dict = dict()
    false_negatives_dict = dict()
    for class_id in range(num_classes):
        true_positives = np.sum(
            (ground_truth_mask == class_id) & (predicted_mask == class_id))
        false_positives = np.sum(
            (ground_truth_mask != class_id) & (predicted_mask == class_id))
        false_negatives = np.sum(
            (ground_truth_mask == class_id) & (predicted_mask != class_id))
        true_positives_dict[class_id] = true_positives
        false_positives_dict[class_id] = false_positives
        false_negatives_dict[class_id] = false_negatives
        # p = None
        # r = None
        # if true_positives + false_positives != 0:
        #     p = true_positives / (true_positives + false_positives)
        # if true_positives + false_negatives != 0:
        #     r = true_positives / (true_positives + false_negatives)
        # precision_dict[class_id] = p
        # recall_dict[class_id] = r
    return true_positives_dict, false_positives_dict, false_negatives_dict

def calculate_mean_precision_recall(model, val_datagen, num_classes):
    # Replace with the actual number of classes in your model
    # num_classes = 3
    # precision_total_dict = dict()
    # recall_total_dict = dict()
    total_true_positives_dict = dict()
    total_false_positives_dict = dict()
    total_false_negatives_dict = dict()
    precision_scores = np.zeros(num_classes)
    recall_scores = np.zeros(num_classes)
    f1_scores = np.zeros(num_classes)
    # class_iou_scores = np.zeros(num_classes)
    # class_counts = np.zeros(num_classes)

    # Iterate over the validation data generator
    for images, labels in val_datagen:
        for image, label in zip(images, labels):
            # Predict the labels using the model
            predictions = model.predict(np.array([image]), verbose=False)
            predicted_classes = np.argmax(predictions, axis=3)
            # print(predicted_classes.shape)
            true_positives_dict, false_positives_dict, false_negatives_dict = calculate_raw_scores(
                label, predicted_classes[0], num_classes)
            # print(precision_dict, recall_dict)
            for class_id in range(num_classes):
                # p = precision_dict.get(class_id)
                # r = recall_dict.get(class_id)
                ttp = total_true_positives_dict.get(class_id, 0.0)
                tfp = total_false_positives_dict.get(class_id, 0.0)
                tfn = total_false_negatives_dict.get(class_id, 0.0)
                ttp += true_positives_dict[class_id]
                tfp += false_positives_dict[class_id]
                tfn += false_negatives_dict[class_id]
                total_true_positives_dict[class_id] = ttp
                total_false_positives_dict[class_id] = tfp
                total_false_negatives_dict[class_id] = tfn
                # if p != None:
                #     pt = precision_total_dict.get(class_id, [0, 0])
                #     pt[0] += p
                #     pt[1] += 1
                #     precision_total_dict[class_id] = pt
                # if r != None:
                #     rt = recall_total_dict.get(class_id, [0, 0])
                #     rt[0] += r
                #     rt[1] += 1
                #     recall_total_dict[class_id] = rt
        # Compute IoU scores for each class
        # for class_id in range(num_classes):

    # Calculate average IoU scores for each class
    # class_iou_scores /= class_counts
    for class_id in range(num_classes):
        ttp = total_true_positives_dict.get(class_id)
        tfp = total_false_positives_dict.get(class_id)
        tfn = total_false_negatives_dict.get(class_id)
        p = None
        r = None
        if ttp + tfp != 0:
            p = ttp / (ttp + tfp)
        if ttp + tfn != 0:
            r = ttp / (ttp + tfn)
        # pt = precision_total_dict.get(x)
        # rt = recall_total_dict.get(x)
        # total_score = pt[0]
        precision_scores[class_id] = p
        # rt = recall_total_dict.get(x)
        # total_score = pt[0]
        recall_scores[class_id] = r
        f1_scores[class_id] = calc_f1_score(p, r)

    # Create a DataFrame with class labels and corresponding IoU scores
    iou_df = pd.DataFrame(
        {'class_val': range(num_classes), 'precision': precision_scores, 'recall': recall_scores, 'f1 score': f1_scores})
    # print(precision_total_dict)
    return iou_df


def calculate_mean_precision_recall(model_func, model, val_datagen, num_classes, device):
    total_true_positives_dict = {c: 0 for c in range(num_classes)}
    total_false_positives_dict = {c: 0 for c in range(num_classes)}
    total_false_negatives_dict = {c: 0 for c in range(num_classes)}

    # For batchwise prediction, create predictor
    predict = model_func(model, device)

    for batch in val_datagen:
        # images: batch of shape [B, C, H, W] (torch.Tensor or np.ndarray)
        # labels: batch of shape [B, H, W] (numpy or torch)
        images = batch['image'].to(device)
        masks = batch['mask'].to(device)
        temporal_coords=batch["temporal_coords"].to(device)
        location_coords=batch["location_coords"].to(device)
        # if torch.is_tensor(images):
        #     images = images.to(device)
        pred_masks = predict(images, temporal_coords=temporal_coords, location_coords=location_coords)  # shape: [B, H, W] (np.ndarray)
        if torch.is_tensor(masks):
            masks = masks.cpu().numpy()
        for pred_mask, label in zip(pred_masks, masks):
            true_positives_dict, false_positives_dict, false_negatives_dict = calculate_raw_scores(
                label, pred_mask, num_classes)
            for class_id in range(num_classes):
                total_true_positives_dict[class_id] += true_positives_dict[class_id]
                total_false_positives_dict[class_id] += false_positives_dict[class_id]
                total_false_negatives_dict[class_id] += false_negatives_dict[class_id]

    # Aggregate results
    precision_scores = np.zeros(num_classes)
    recall_scores = np.zeros(num_classes)
    f1_scores = np.zeros(num_classes)
    iou_scores = np.zeros(num_classes)

    for class_id in range(num_classes):
        ttp = total_true_positives_dict[class_id]
        tfp = total_false_positives_dict[class_id]
        tfn = total_false_negatives_dict[class_id]
        denom = ttp + tfp + tfn
        iou = ttp / denom if denom != 0 else None
        p = ttp / (ttp + tfp) if (ttp + tfp) != 0 else None
        r = ttp / (ttp + tfn) if (ttp + tfn) != 0 else None
        f1 = calc_f1_score(p, r)
        precision_scores[class_id] = p if p is not None else np.nan
        recall_scores[class_id] = r if r is not None else np.nan
        f1_scores[class_id] = f1 if f1 is not None else np.nan
        iou_scores[class_id] = iou if iou is not None else np.nan

    iou_df = pd.DataFrame({
        'class_val': range(num_classes),
        'iou': iou_scores,
        'precision': precision_scores,
        'recall': recall_scores,
        'f1 score': f1_scores
    })
    return iou_df




In [7]:
import torch
import numpy as np
import pandas as pd

def model_predict(model, device):
    """
    Returns a function that predicts class masks from input tensors using a model.
    Handles both single image and batch input.
    """
    def predict(images, temporal_coords, location_coords):
        model.to(device)
        model.eval()
        with torch.no_grad():
            if isinstance(images, np.ndarray):
                images = torch.from_numpy(images)
            images = images.to(device)
            if images.dim() == 3:  # Single image, shape: [C, H, W]
                images = images.unsqueeze(0)
            outputs = model(images).output  # [B, num_classes, H, W]
            pred_masks = torch.argmax(outputs, dim=1)  # [B, H, W]
            pred_masks = pred_masks.cpu().numpy()
        return pred_masks
    return predict


In [8]:
def model_predict(model, device):
    """
    Returns a function that predicts class masks from input tensors using a model.
    Handles both single image and batch input.
    Supports optional temporal_coords and location_coords.
    """
    def predict(images, temporal_coords=None, location_coords=None):
        model.to(device)
        model.eval()
        with torch.no_grad():
            if isinstance(images, np.ndarray):
                images = torch.from_numpy(images)
            images = images.to(device)
            if images.dim() == 3:  # Single image, shape: [C, H, W]
                images = images.unsqueeze(0)

            # Prepare optional kwargs
            kwargs = {}
            if temporal_coords is not None:
                if isinstance(temporal_coords, np.ndarray):
                    temporal_coords = torch.from_numpy(temporal_coords)
                kwargs["temporal_coords"] = temporal_coords.to(device)
            if location_coords is not None:
                if isinstance(location_coords, np.ndarray):
                    location_coords = torch.from_numpy(location_coords)
                kwargs["location_coords"] = location_coords.to(device)

            outputs = model(images, **kwargs).output  # [B, num_classes, H, W]
            pred_masks = torch.argmax(outputs, dim=1)  # [B, H, W]
            pred_masks = pred_masks.cpu().numpy()
        return pred_masks
    return predict


In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [12]:
# Suppose you have:
# val_datagen: yields (images, labels) where images.shape = [B, C, H, W], labels.shape = [B, H, W]
# model: your PyTorch segmentation model
# device: torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# num_classes: int

# Pass the predictor factory:
results_df = calculate_mean_precision_recall(model_predict, model, train_loader, 3, device)
print(results_df)

   class_val       iou  precision    recall  f1 score
0          0  0.987966   0.991099  0.996810  0.993947
1          1  0.666150   0.880803  0.732152  0.799628
2          2  0.303249   0.511626  0.426792  0.465374


In [10]:
# Pass the predictor factory:
results_df = calculate_mean_precision_recall(model_predict, model, train_loader, 3, device)
print(results_df)

   class_val       iou  precision    recall  f1 score
0          0  0.988314   0.991923  0.996332  0.994122
1          1  0.678351   0.870483  0.754503  0.808354
2          2  0.307280   0.472827  0.467415  0.470105


In [10]:
# Suppose you have:
# val_datagen: yields (images, labels) where images.shape = [B, C, H, W], labels.shape = [B, H, W]
# model: your PyTorch segmentation model
# device: torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# num_classes: int

# Pass the predictor factory:
results_df = calculate_mean_precision_recall(model_predict, model, train_loader, 3, device)
print(results_df)

   class_val       iou  precision    recall  f1 score
0          0  0.988514   0.992639  0.995814  0.994224
1          1  0.688667   0.858037  0.777224  0.815634
2          2  0.308345   0.489770  0.454268  0.471352


In [None]:
# Pass the predictor factory:
#Val
results_df = calculate_mean_precision_recall(model_predict, model, train_loader, 3, device)
print(results_df)

   class_val       iou  precision    recall  f1 score
0          0  0.997108   0.998244  0.998860  0.998552
1          1  0.834184   0.927985  0.891923  0.909597
2          2  0.834855   0.918178  0.901958  0.909996


In [11]:
# Pass the predictor factory:
results_df = calculate_mean_precision_recall(model_predict, model, train_loader, 3, device)
print(results_df)

   class_val       iou  precision    recall  f1 score
0          0  0.997204   0.998299  0.998901  0.998600
1          1  0.837536   0.929852  0.894024  0.911586
2          2  0.843570   0.922625  0.907793  0.915149


In [9]:
# Pass the predictor factory:
#Val
results_df = calculate_mean_precision_recall(model_predict, model, train_loader, 3, device)
print(results_df)

   class_val       iou  precision    recall  f1 score
0          0  0.997761   0.998847  0.998912  0.998879
1          1  0.871322   0.933205  0.929277  0.931237
2          2  0.869823   0.930450  0.930310  0.930380


In [10]:
# Pass the predictor factory:
#Val
results_df = calculate_mean_precision_recall(model_predict, model, train_loader, 3, device)
print('cnn val')
print(results_df)

cnn val
   class_val       iou  precision    recall  f1 score
0          0  0.994738   0.996022  0.998706  0.997362
1          1  0.698693   0.901775  0.756248  0.822625
2          2  0.676976   0.899543  0.732342  0.807377


In [12]:
#Val
results_df = calculate_mean_precision_recall(model_predict, model, train_loader, 3, device)
print('cnn train')
print(results_df)

cnn train
   class_val       iou  precision    recall  f1 score
0          0  0.994855   0.996136  0.998709  0.997421
1          1  0.702235   0.901574  0.760541  0.825074
2          2  0.689044   0.901299  0.745280  0.815898


In [12]:
# Pass the predictor factory:
#Val
results_df = calculate_mean_precision_recall(model_predict, model, train_loader, 3, device)
print('cnn val')
print(results_df)

cnn val
   class_val       iou  precision    recall  f1 score
0          0  0.996057   0.998178  0.997872  0.998025
1          1  0.785465   0.871080  0.888785  0.879844
2          2  0.778125   0.877069  0.873379  0.875220


In [7]:
import os
import re
import rasterio
from rasterio.windows import Window
import numpy as np
from PIL import Image
import torch

def extract_year_from_filename(filename):
    match = re.search(r"(20\d{2})", filename)
    return int(match.group(1)) if match else 2023

def get_lat_lon_from_metadata(meta):
    try:
        bounds = meta['bounds']
        lat = (bounds.top + bounds.bottom) / 2
        lon = (bounds.left + bounds.right) / 2
        return lat, lon
    except Exception:
        return 0.0, 0.0

def preprocess(img):
    mean = np.array([1635.8452, 1584.4594, 1456.8425, 2926.6663, 2135.001, 1352.7313], dtype=np.float32)
    std  = np.array([884.3994, 815.4016, 839.0293, 1055.6382, 751.4628, 628.5323], dtype=np.float32)
    img = (img - mean[:, None, None]) / std[:, None, None]
    return torch.from_numpy(img).float()

def process_geotiff_patchwise(
    real_path, mask_path, model, out_fp_path, out_fn_path,
    patch_size=512, preprocess_fn=None, device='cuda'
):
    with rasterio.open(real_path) as src:
        meta = src.meta.copy()
        H, W = src.height, src.width
        C = src.count
        lat, lon = get_lat_lon_from_metadata(src)

        # --- Load mask and resize to match image shape ---
        mask = np.array(Image.open(mask_path))
        if mask.shape != (H, W):
            mask = np.array(Image.fromarray(mask).resize((W, H), Image.NEAREST))

        # --- Prepare output arrays ---
        fp_full = np.zeros((H, W), dtype=np.uint8)
        fn_full = np.zeros((H, W), dtype=np.uint8)

        # --- Prepare temporal/location coords ---
        fname = os.path.splitext(os.path.basename(real_path))[0]
        year = extract_year_from_filename(fname)
        doy = 1

        # --- Sliding window over the image ---
        step = patch_size
        for row in range(0, H, step):
            for col in range(0, W, step):
                win_h = min(step, H - row)
                win_w = min(step, W - col)

                # Read patch
                img_patch = src.read(window=Window(col, row, win_w, win_h)).astype(np.float32)  # (C, h, w)
                # Mask patch
                mask_patch = mask[row:row+win_h, col:col+win_w]

                # Zero pad if needed
                pad_h = patch_size - win_h
                pad_w = patch_size - win_w
                if pad_h > 0 or pad_w > 0:
                    img_patch = np.pad(img_patch, ((0,0), (0,pad_h), (0,pad_w)), mode='constant')
                    mask_patch = np.pad(mask_patch, ((0,pad_h), (0,pad_w)), mode='constant')

                # Preprocess and model inference
                img_input = preprocess_fn(img_patch) if preprocess_fn else torch.from_numpy(img_patch)
                img_input = img_input.unsqueeze(0).to(device)

                temporal_coords = torch.tensor([[year, doy]], dtype=torch.float32).unsqueeze(1).to(device)
                location_coords = torch.tensor([[lat, lon]], dtype=torch.float32).to(device)

                with torch.no_grad():
                    outputs = model(
                        img_input,
                        temporal_coords=temporal_coords,
                        location_coords=location_coords,
                    )
                    logits = outputs.output if hasattr(outputs, "output") else outputs
                    pred_mask_patch = torch.argmax(logits, dim=1).cpu().numpy()[0]  # (h, w)

                # Crop back to original (non-padded) shape
                pred_mask_patch = pred_mask_patch[:win_h, :win_w]
                mask_patch = mask_patch[:win_h, :win_w]

                # Compute FP/FN patch
                fp_patch = ((pred_mask_patch != mask_patch) & (pred_mask_patch != 0)) * pred_mask_patch
                fn_patch = ((pred_mask_patch != mask_patch) & (mask_patch != 0)) * mask_patch

                # Paste patch into full output arrays
                fp_full[row:row+win_h, col:col+win_w] = fp_patch
                fn_full[row:row+win_h, col:col+win_w] = fn_patch

        # Write output GeoTIFFs
        meta.update(dtype=rasterio.uint8, count=1)
        with rasterio.open(out_fp_path, "w", **meta) as dst:
            dst.write(fp_full, 1)
        with rasterio.open(out_fn_path, "w", **meta) as dst:
            dst.write(fn_full, 1)

def batch_generate_fp_fn(
    real_dir, mask_dir, model, out_dir, patch_size=512,
    preprocess_fn=None, device='cuda'
):
    os.makedirs(out_dir, exist_ok=True)
    real_files = sorted([f for f in os.listdir(real_dir) if f.endswith(".tif")])

    for real_fname in real_files:
        real_path = os.path.join(real_dir, real_fname)
        mask_path = os.path.join(mask_dir, os.path.splitext(real_fname)[0] + ".png")
        if not os.path.exists(mask_path):
            print(f"Missing mask for {real_fname}, skipping.")
            continue

        fp_path = os.path.join(out_dir, os.path.splitext(real_fname)[0] + "_fp.tif")
        fn_path = os.path.join(out_dir, os.path.splitext(real_fname)[0] + "_fn.tif")
        process_geotiff_patchwise(
            real_path, mask_path, model, fp_path, fn_path,
            patch_size=patch_size, preprocess_fn=preprocess_fn, device=device
        )
        print(f"Processed {real_fname}")

# Example usage:
# batch_generate_fp_fn(
#     real_dir="real",
#     mask_dir="mask",
#     model=your_model,
#     out_dir="fpfn_output",
#     patch_size=512,
#     preprocess_fn=preprocess,
#     device='cuda'
# )


In [10]:
model = model.to(device)
batch_generate_fp_fn(
    real_dir=r"D:\sentinel_ssm\data\segmentation\original_val\real",
    mask_dir=r"D:\sentinel_ssm\data\segmentation\original_val\mask",
    model=model,
    out_dir=r"D:\sentinel_ssm\data\segmentation\fpfn",
    patch_size=512,
    preprocess_fn=preprocess,
    device='cuda'
)


Processed sentinel_h_ssm_aoi_2019.tif
Processed sentinel_sr_h_ssm_aoi_2019.tif
