## Evaluate model using these strategies

Each technique adds significant time cost. You may play around which combos give the best balance

#### CRFs and Morphological Operations

Refine boundaries and clean up noises

Use PyDenseCRF, employ dilation and erosion

#### Thresholding

Fine tune uncertain regions around boundaries

Apply threshold on logits after softmax

#### Test-Time Augmentation and Multi-Scale

Improve robustness and accuracy, reduce reliance on specific features

Apply augmentations like flipping, rotation, scaling, etc ... Then aggregate results by max/most/mean/weighted on logits/probs/classes

In [1]:
import sys
from pathlib import Path
from pprint import pprint

import numpy as np
import torch
from PIL.Image import Image
from pydensecrf import densecrf as dcrf  # type: ignore
from pydensecrf.utils import unary_from_softmax
from scipy.ndimage import binary_dilation, binary_erosion
from sklearn.metrics import confusion_matrix
from torch import Tensor
from torchvision.datasets import VOCSegmentation
from torchvision.transforms.v2 import functional as TF

sys.path.append(str(Path("..").resolve()))
from src.datasets import resolve_metadata
from src.models import FCN_ResNet34_Weights, fcn_resnet34
from src.pipeline import forward_batch
from src.utils.metrics import metrics_from_confusion
from src.utils.transform import DataTransform
from src.utils.visual import combine_images, draw_mask_on_image

In [2]:
def refine_segmentation_with_crf_and_morphology(
    logits, image, apply_crf=True, apply_morphology=True
) -> np.ndarray:
    """
    Refine semantic segmentation logits using CRF and morphological operations.

    Args:
        logits (torch.Tensor): Logits tensor of shape (num_class, H, W).
        image (np.ndarray): Original image as a NumPy array of shape (H, W, 3).
        apply_crf (bool): Whether to apply CRF.
        apply_morphology (bool): Whether to apply morphological operations.

    Returns:
        np.ndarray: Refined segmentation map of shape (H, W).
    """
    H, W = logits.shape[1], logits.shape[2]
    num_classes = logits.shape[0]

    # Convert logits to probabilities
    probs = torch.softmax(logits, dim=0).cpu().numpy()  # Shape: (num_class, H, W)
    predicted_mask = np.argmax(probs, axis=0)  # Initial prediction, shape: (H, W)

    # Apply CRF
    if apply_crf:
        d = dcrf.DenseCRF2D(W, H, num_classes)  # Initialize CRF
        unary = unary_from_softmax(probs)  # Convert probabilities to unary potentials
        d.setUnaryEnergy(unary)

        # Add pairwise terms using the image
        d.addPairwiseGaussian(sxy=3, compat=3)  # Spatial smoothness
        d.addPairwiseBilateral(
            sxy=50, srgb=13, rgbim=image, compat=10
        )  # Appearance-based smoothness

        # Perform CRF inference
        refined_probs = np.array(d.inference(10)).reshape(
            num_classes, H, W
        )  # Refined probabilities
        predicted_mask = np.argmax(refined_probs, axis=0)  # Update prediction

    # Apply morphological operations
    if apply_morphology:
        binary_mask = (predicted_mask > 0).astype(
            np.uint8
        )  # Binary mask for non-background
        binary_mask = binary_dilation(binary_mask, iterations=1)  # Dilation
        binary_mask = binary_erosion(binary_mask, iterations=1)  # Erosion

        # Apply the refined binary mask to the predicted mask
        predicted_mask[binary_mask == 0] = 0

    return predicted_mask

In [3]:
def apply_thresholding(logits, threshold=0.5, uncertain=255):
    """
    Apply thresholding to refine class boundaries in semantic segmentation.

    Args:
        logits (torch.Tensor): Logits tensor of shape (num_class, H, W).
        threshold (float): Confidence threshold (between 0 and 1).
        uncertain: index for uncertain regions

    Returns:
        torch.Tensor: Refined segmentation map of shape (H, W).
    """
    # Convert logits to probabilities
    probs = torch.softmax(logits, dim=0)  # Shape: (num_class, H, W)

    # Get the predicted class for each pixel
    predicted_classes = torch.argmax(probs, dim=0)  # Shape: (H, W)

    # Apply thresholding
    max_probs, _ = torch.max(probs, dim=0)  # Maximum probability for each pixel
    refined_mask = torch.where(max_probs >= threshold, predicted_classes, uncertain)

    return refined_mask  # Shape: (H, W), where 255 indicates uncertain pixels

In [4]:
weights = FCN_ResNet34_Weights.VOC2012
model = fcn_resnet34(weights=weights)
transforms = DataTransform()
augment = weights.value.transforms()
dataset = VOCSegmentation(r"D:\_Dataset", image_set="val", transforms=transforms)
metadata = resolve_metadata("VOC")
device = "cuda" if torch.cuda.is_available() else "cpu"

In [5]:
data: tuple[Tensor, Tensor] = dataset[1]
image, mask = data
model.eval().to(device)
with torch.no_grad():
    images, masks = image.unsqueeze(0), mask.unsqueeze(0)
    logits, _ = forward_batch(model, images, masks, augment, None, device)
logit = logits["out"].squeeze(0)
pred = logit.argmax(0)

In [6]:
colors = metadata.colors
mask_overlay = draw_mask_on_image(image, mask, colors)
pred_overlay = draw_mask_on_image(image, pred, colors)

crf_image = (
    TF.to_dtype(image, torch.uint8, scale=True)
    .permute(1, 2, 0)
    .contiguous()
    .numpy(force=True)
)
refined_arr = refine_segmentation_with_crf_and_morphology(logit, crf_image)
refined = torch.tensor(refined_arr, dtype=torch.uint8)
refined_overlay = draw_mask_on_image(image, refined, colors)

# combined = combine_images([image, mask_overlay, pred_overlay, refined_overlay])
# combined_pil: Image = TF.to_pil_image(combined)
# display(combined_pil.reduce(3))

In [7]:
ignore_index = metadata.ignore_index
mask_np = mask.numpy(force=True).flatten()
pred_np = pred.numpy(force=True).flatten()
not_ignored = (mask_np != ignore_index) & (pred_np != ignore_index)

cm = confusion_matrix(mask_np[not_ignored], pred_np[not_ignored])
print(cm)
pprint(metrics_from_confusion(cm))

refined_np = refined.numpy(force=True).flatten()
not_ignored = (mask_np != ignore_index) & (refined_np != ignore_index)

cm = confusion_matrix(mask_np[not_ignored], refined_np[not_ignored])
print(cm)
pprint(metrics_from_confusion(cm))

[[55254    27  8570]
 [    0     0     0]
 [ 8262  7731 77660]]
{'acc': 0.8438769809020723,
 'dice': 0.5770282016901661,
 'fwiou': 0.7623476883155276,
 'macc': 0.5648632412873146,
 'miou': 0.5086419156730985}
[[52617     0 11234]
 [    0     0     0]
 [10547  1364 81742]]
{'acc': 0.8530513510767981,
 'dice': 0.5681667402847043,
 'fwiou': 0.7501062599352178,
 'macc': 0.5656256265827229,
 'miou': 0.4955235678430326}
