<a href="https://colab.research.google.com/github/easare377/Prithvi-EO-Segmentation/blob/main/model_evaluation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install terratorch

Collecting terratorch
  Downloading terratorch-1.0.2-py3-none-any.whl.metadata (10 kB)
Collecting torch==2.7.0 (from terratorch)
  Downloading torch-2.7.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (29 kB)
Collecting torchvision==0.22.0 (from terratorch)
  Downloading torchvision-0.22.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (6.1 kB)
Collecting rioxarray==0.19.0 (from terratorch)
  Downloading rioxarray-0.19.0-py3-none-any.whl.metadata (5.5 kB)
Collecting albumentations==1.4.6 (from terratorch)
  Downloading albumentations-1.4.6-py3-none-any.whl.metadata (37 kB)
Collecting albucore==0.0.16 (from terratorch)
  Downloading albucore-0.0.16-py3-none-any.whl.metadata (3.1 kB)
Collecting rasterio==1.4.3 (from terratorch)
  Downloading rasterio-1.4.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.1 kB)
Collecting torchmetrics==1.7.1 (from terratorch)
  Downloading torchmetrics-1.7.1-py3-none-any.whl.metadata (21 kB)
Collecting geopandas==1.0.1 (from terrato

In [4]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [1]:
import numpy as np
import pandas as pd
import torch

In [2]:
import os, struct, io, mmap
from pathlib import Path
import numpy as np
import tensorflow as tf
import torch
from torch.utils.data import Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2

def zero_pad_array(arr, target_hw):
    th, tw = target_hw
    h, w = arr.shape[:2]
    pad_h = max(th - h, 0)
    pad_w = max(tw - w, 0)
    pad = ((pad_h // 2, pad_h - pad_h // 2),
           (pad_w // 2, pad_w - pad_w // 2))
    if arr.ndim == 3: pad += ((0, 0),)
    return np.pad(arr, pad, mode='constant')

import os
import struct
import numpy as np
import tensorflow as tf
import torch
from torch.utils.data import Dataset

def zero_pad_array(arr, target_hw):
    th, tw = target_hw
    h, w = arr.shape[:2]
    pad_h = max(th - h, 0)
    pad_w = max(tw - w, 0)
    pad = ((pad_h // 2, pad_h - pad_h // 2),
           (pad_w // 2, pad_w - pad_w // 2))
    if arr.ndim == 3: pad += ((0, 0),)
    return np.pad(arr, pad, mode='constant')


class MineFootprintTFRecordDataset(Dataset):
    MEAN = np.array([1635.8452, 1584.4594, 1456.8425, 2926.6663, 2135.001, 1352.7313], dtype=np.float32)
    STD  = np.array([884.3994, 815.4016, 839.0293, 1055.6382, 751.4628, 628.5323], dtype=np.float32)

    _feature_desc = {
        "image_raw": tf.io.FixedLenFeature([], tf.string),
        "mask_raw":  tf.io.FixedLenFeature([], tf.string),
        "height":    tf.io.FixedLenFeature([], tf.int64),
        "width":     tf.io.FixedLenFeature([], tf.int64),
        "channels":  tf.io.FixedLenFeature([], tf.int64),
        "temporal_coords": tf.io.FixedLenFeature([2], tf.float32),
        "location_coords": tf.io.FixedLenFeature([2], tf.float32),
    }

    def __init__(self, tfrecord_file, transform=None, pad_to=(224, 224)):
        super().__init__()
        self.tfrecord_path = os.fspath(tfrecord_file)
        self.transform = transform
        self.pad_to = pad_to

        # ---- build a list of byte offsets ----------------------------------
        self._offsets = self._scan_index()
        self._fh = open(self.tfrecord_path, 'rb')

    def _scan_index(self):
        """Return a list with the starting byte of each record."""
        offsets = []
        with open(self.tfrecord_path, 'rb') as f:
            pos = 0
            while True:
                header = f.read(12)
                if not header: break
                rec_len = struct.unpack('<Q', header[:8])[0]
                offsets.append(pos)
                pos += 12 + rec_len + 4
                f.seek(pos)
        return offsets

    def _read_record(self, offset):
        """Seek & return the serialised Example bytes of one record."""
        self._fh.seek(offset)
        header = self._fh.read(12)
        rec_len = struct.unpack('<Q', header[:8])[0]
        data = self._fh.read(rec_len)
        _ = self._fh.read(4)
        return data

    def __len__(self):
        return len(self._offsets)

    def __getitem__(self, idx):
        serialised = self._read_record(self._offsets[idx])
        ex = tf.io.parse_single_example(serialised, self._feature_desc)

        h = int(ex["height"])
        w = int(ex["width"])
        c = int(ex["channels"])

        img = np.frombuffer(ex["image_raw"].numpy(), dtype=np.float32).reshape((h, w, c))
        msk = np.frombuffer(ex["mask_raw"].numpy(),  dtype=np.uint8).reshape((h, w))

        img = np.nan_to_num(img, nan=0.0)
        msk = np.nan_to_num(msk.astype(np.float32), nan=0.0).astype(np.uint8)

        img = (img - self.MEAN) / self.STD
        img = zero_pad_array(img, self.pad_to)
        msk = zero_pad_array(msk, self.pad_to)

        temporal_coords = ex['temporal_coords'].numpy().astype(np.float32)   # (2,)
        location_coords = ex['location_coords'].numpy().astype(np.float32)   # (2,)

        temporal_coords = np.expand_dims(temporal_coords, axis=0)            # (1, 2)

        if self.transform:
            augmented = self.transform(image=img, mask=msk)
            img, msk   = augmented["image"], augmented["mask"]
        else:
            img = torch.from_numpy(img.transpose(2, 0, 1))  # (C,H,W)
            msk = torch.from_numpy(msk)

        out = {
            "image": img.float(),
            "temporal_coords": torch.from_numpy(temporal_coords),
            "location_coords": torch.from_numpy(location_coords),
            "mask": msk.long()
        }
        return out

    def __del__(self):
        try:
            if hasattr(self, '_fh') and not self._fh.closed:
                self._fh.close()
        except Exception:
            pass



transform = A.Compose([
    A.RandomRotate90(p=0.7),       # p: probability of applying this transform
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    ToTensorV2(),
])

In [5]:
val_file_path = '/content/drive/MyDrive/SCO_training/ssm_footprint_val.tfrecord'
val_dataset = MineFootprintTFRecordDataset(val_file_path, transform=transform)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4, shuffle=True,drop_last=True)

In [7]:
from terratorch.tasks import SemanticSegmentationTask

model = SemanticSegmentationTask(
    model_factory="EncoderDecoderFactory",
    model_args=dict(
        backbone="terratorch_prithvi_eo_v2_300_tl",
        backbone_pretrained=False,
        backbone_img_size=224,
        backbone_bands=["BLUE","GREEN","RED","NIR_NARROW","SWIR_1","SWIR_2"],
        necks=[{"name":"SelectIndices", "indices":[1, 5,11,17,23]},
               {"name":"ReshapeTokensToImage"}],
        decoder="FCNDecoder",
        decoder_channels=256,
        num_classes=3,
        head_dropout=0.1,
    ),
    freeze_backbone=False,
    freeze_decoder=False,
)



In [8]:
model_save_path  = Path("/content/drive/MyDrive/SCO_training/prithvi_state_dict.pt")
model.load_state_dict(torch.load(model_save_path))
model.eval()

SemanticSegmentationTask(
  (model): PixelWiseModel(
    (encoder): PrithviViT(
      (patch_embed): PatchEmbed(
        (proj): Conv3d(6, 1024, kernel_size=(1, 16, 16), stride=(1, 16, 16))
        (norm): Identity()
      )
      (temporal_embed_enc): TemporalEncoder()
      (location_embed_enc): LocationEncoder()
      (blocks): ModuleList(
        (0-23): 24 x Block(
          (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (attn): Attention(
            (qkv): Linear(in_features=1024, out_features=3072, bias=True)
            (q_norm): Identity()
            (k_norm): Identity()
            (attn_drop): Dropout(p=0.0, inplace=False)
            (norm): Identity()
            (proj): Linear(in_features=1024, out_features=1024, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (ls1): Identity()
          (drop_path1): Identity()
          (norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (mlp): Ml

In [10]:
def calculate_precision_recall(ground_truth_mask, predicted_mask, num_classes):
    # print(ground_truth_mask.shape, predicted_mask.shape)
    precision = [None] * num_classes  # np.zeros(num_classes)
    recall = [None] * num_classes  # np.zeros(num_classes)
    precision_dict = dict()
    recall_dict = dict()
    for class_id in range(num_classes):
        true_positives = np.sum(
            (ground_truth_mask == class_id) & (predicted_mask == class_id))
        false_positives = np.sum(
            (ground_truth_mask != class_id) & (predicted_mask == class_id))
        false_negatives = np.sum(
            (ground_truth_mask == class_id) & (predicted_mask != class_id))
        p = None
        r = None
        if true_positives + false_positives != 0:
            p = true_positives / (true_positives + false_positives)
        if true_positives + false_negatives != 0:
            r = true_positives / (true_positives + false_negatives)
        precision_dict[class_id] = p
        recall_dict[class_id] = r
    return precision_dict, recall_dict


def calc_f1_score(precision, recall):
    if precision is None or recall is None or (precision + recall) == 0:
        return None
    return (2 * precision * recall) / (precision + recall)


def calculate_raw_scores(ground_truth_mask, predicted_mask, num_classes):
    # print(ground_truth_mask.shape, predicted_mask.shape)
    # precision = [None] * num_classes  # np.zeros(num_classes)
    # recall = [None] * num_classes  # np.zeros(num_classes)
    # precision_dict = dict()
    # recall_dict = dict()
    true_positives_dict = dict()
    false_positives_dict = dict()
    false_negatives_dict = dict()
    for class_id in range(num_classes):
        true_positives = np.sum(
            (ground_truth_mask == class_id) & (predicted_mask == class_id))
        false_positives = np.sum(
            (ground_truth_mask != class_id) & (predicted_mask == class_id))
        false_negatives = np.sum(
            (ground_truth_mask == class_id) & (predicted_mask != class_id))
        true_positives_dict[class_id] = true_positives
        false_positives_dict[class_id] = false_positives
        false_negatives_dict[class_id] = false_negatives
        # p = None
        # r = None
        # if true_positives + false_positives != 0:
        #     p = true_positives / (true_positives + false_positives)
        # if true_positives + false_negatives != 0:
        #     r = true_positives / (true_positives + false_negatives)
        # precision_dict[class_id] = p
        # recall_dict[class_id] = r
    return true_positives_dict, false_positives_dict, false_negatives_dict

def calculate_mean_precision_recall(model, val_datagen, num_classes):
    # Replace with the actual number of classes in your model
    # num_classes = 3
    # precision_total_dict = dict()
    # recall_total_dict = dict()
    total_true_positives_dict = dict()
    total_false_positives_dict = dict()
    total_false_negatives_dict = dict()
    precision_scores = np.zeros(num_classes)
    recall_scores = np.zeros(num_classes)
    f1_scores = np.zeros(num_classes)
    # class_iou_scores = np.zeros(num_classes)
    # class_counts = np.zeros(num_classes)

    # Iterate over the validation data generator
    for images, labels in val_datagen:
        for image, label in zip(images, labels):
            # Predict the labels using the model
            predictions = model.predict(np.array([image]), verbose=False)
            predicted_classes = np.argmax(predictions, axis=3)
            # print(predicted_classes.shape)
            true_positives_dict, false_positives_dict, false_negatives_dict = calculate_raw_scores(
                label, predicted_classes[0], num_classes)
            # print(precision_dict, recall_dict)
            for class_id in range(num_classes):
                # p = precision_dict.get(class_id)
                # r = recall_dict.get(class_id)
                ttp = total_true_positives_dict.get(class_id, 0.0)
                tfp = total_false_positives_dict.get(class_id, 0.0)
                tfn = total_false_negatives_dict.get(class_id, 0.0)
                ttp += true_positives_dict[class_id]
                tfp += false_positives_dict[class_id]
                tfn += false_negatives_dict[class_id]
                total_true_positives_dict[class_id] = ttp
                total_false_positives_dict[class_id] = tfp
                total_false_negatives_dict[class_id] = tfn
                # if p != None:
                #     pt = precision_total_dict.get(class_id, [0, 0])
                #     pt[0] += p
                #     pt[1] += 1
                #     precision_total_dict[class_id] = pt
                # if r != None:
                #     rt = recall_total_dict.get(class_id, [0, 0])
                #     rt[0] += r
                #     rt[1] += 1
                #     recall_total_dict[class_id] = rt
        # Compute IoU scores for each class
        # for class_id in range(num_classes):

    # Calculate average IoU scores for each class
    # class_iou_scores /= class_counts
    for class_id in range(num_classes):
        ttp = total_true_positives_dict.get(class_id)
        tfp = total_false_positives_dict.get(class_id)
        tfn = total_false_negatives_dict.get(class_id)
        p = None
        r = None
        if ttp + tfp != 0:
            p = ttp / (ttp + tfp)
        if ttp + tfn != 0:
            r = ttp / (ttp + tfn)
        # pt = precision_total_dict.get(x)
        # rt = recall_total_dict.get(x)
        # total_score = pt[0]
        precision_scores[class_id] = p
        # rt = recall_total_dict.get(x)
        # total_score = pt[0]
        recall_scores[class_id] = r
        f1_scores[class_id] = calc_f1_score(p, r)

    # Create a DataFrame with class labels and corresponding IoU scores
    iou_df = pd.DataFrame(
        {'class_val': range(num_classes), 'precision': precision_scores, 'recall': recall_scores, 'f1 score': f1_scores})
    # print(precision_total_dict)
    return iou_df


def calculate_mean_precision_recall(model_func, model, val_datagen, num_classes, device):
    total_true_positives_dict = {c: 0 for c in range(num_classes)}
    total_false_positives_dict = {c: 0 for c in range(num_classes)}
    total_false_negatives_dict = {c: 0 for c in range(num_classes)}

    # For batchwise prediction, create predictor
    predict = model_func(model, device)

    for batch in val_datagen:
        # images: batch of shape [B, C, H, W] (torch.Tensor or np.ndarray)
        # labels: batch of shape [B, H, W] (numpy or torch)
        images = batch['image'].to(device)
        masks = batch['mask'].to(device)
        temporal_coords=batch["temporal_coords"].to(device)
        location_coords=batch["location_coords"].to(device)
        # if torch.is_tensor(images):
        #     images = images.to(device)
        pred_masks = predict(images, temporal_coords=temporal_coords, location_coords=location_coords)  # shape: [B, H, W] (np.ndarray)
        if torch.is_tensor(masks):
            masks = masks.cpu().numpy()
        for pred_mask, label in zip(pred_masks, masks):
            true_positives_dict, false_positives_dict, false_negatives_dict = calculate_raw_scores(
                label, pred_mask, num_classes)
            for class_id in range(num_classes):
                total_true_positives_dict[class_id] += true_positives_dict[class_id]
                total_false_positives_dict[class_id] += false_positives_dict[class_id]
                total_false_negatives_dict[class_id] += false_negatives_dict[class_id]

    # Aggregate results
    precision_scores = np.zeros(num_classes)
    recall_scores = np.zeros(num_classes)
    f1_scores = np.zeros(num_classes)
    iou_scores = np.zeros(num_classes)

    for class_id in range(num_classes):
        ttp = total_true_positives_dict[class_id]
        tfp = total_false_positives_dict[class_id]
        tfn = total_false_negatives_dict[class_id]
        denom = ttp + tfp + tfn
        iou = ttp / denom if denom != 0 else None
        p = ttp / (ttp + tfp) if (ttp + tfp) != 0 else None
        r = ttp / (ttp + tfn) if (ttp + tfn) != 0 else None
        f1 = calc_f1_score(p, r)
        precision_scores[class_id] = p if p is not None else np.nan
        recall_scores[class_id] = r if r is not None else np.nan
        f1_scores[class_id] = f1 if f1 is not None else np.nan
        iou_scores[class_id] = iou if iou is not None else np.nan

    iou_df = pd.DataFrame({
        'class_val': range(num_classes),
        'iou': iou_scores,
        'precision': precision_scores,
        'recall': recall_scores,
        'f1 score': f1_scores
    })
    return iou_df




In [11]:
import torch
import numpy as np
import pandas as pd

def model_predict(model, device):
    """
    Returns a function that predicts class masks from input tensors using a model.
    Handles both single image and batch input.
    """
    def predict(images, temporal_coords, location_coords):
        model.to(device)
        model.eval()
        with torch.no_grad():
            if isinstance(images, np.ndarray):
                images = torch.from_numpy(images)
            images = images.to(device)
            if images.dim() == 3:  # Single image, shape: [C, H, W]
                images = images.unsqueeze(0)
            outputs = model(images).output  # [B, num_classes, H, W]
            pred_masks = torch.argmax(outputs, dim=1)  # [B, H, W]
            pred_masks = pred_masks.cpu().numpy()
        return pred_masks
    return predict


In [12]:
def model_predict(model, device):
    """
    Returns a function that predicts class masks from input tensors using a model.
    Handles both single image and batch input.
    Supports optional temporal_coords and location_coords.
    """
    def predict(images, temporal_coords=None, location_coords=None):
        model.to(device)
        model.eval()
        with torch.no_grad():
            if isinstance(images, np.ndarray):
                images = torch.from_numpy(images)
            images = images.to(device)
            if images.dim() == 3:  # Single image, shape: [C, H, W]
                images = images.unsqueeze(0)

            # Prepare optional kwargs
            kwargs = {}
            if temporal_coords is not None:
                if isinstance(temporal_coords, np.ndarray):
                    temporal_coords = torch.from_numpy(temporal_coords)
                kwargs["temporal_coords"] = temporal_coords.to(device)
            if location_coords is not None:
                if isinstance(location_coords, np.ndarray):
                    location_coords = torch.from_numpy(location_coords)
                kwargs["location_coords"] = location_coords.to(device)

            outputs = model(images, **kwargs).output  # [B, num_classes, H, W]
            pred_masks = torch.argmax(outputs, dim=1)  # [B, H, W]
            pred_masks = pred_masks.cpu().numpy()
        return pred_masks
    return predict


In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [14]:
results_df = calculate_mean_precision_recall(model_predict, model, val_loader, 3, device)
results_df

Unnamed: 0,class_val,iou,precision,recall,f1 score
0,0,0.942844,0.976513,0.96472,0.970581
1,1,0.600818,0.707257,0.799691,0.750639
2,2,0.0,,0.0,
