# DeepLabV3

In [4]:
import os
import cv2
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from glob import glob
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
from torch import optim
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns

# ResNet-based classes
import torchvision.models as tvmodels
from torchvision.models.resnet import ResNet, Bottleneck
from torch.optim.lr_scheduler import CosineAnnealingLR

# ------------------- CONFIG -------------------
# Dataset root; the split and class-weight JSON files are resolved relative to it.
BASE_PATH = "C:/Users/User/Desktop/ai4mars/msl"
TRAIN_SPLIT = os.path.join(BASE_PATH, 'train_split.json')
VAL_SPLIT   = os.path.join(BASE_PATH, 'val_split.json')
CLASS_WEIGHTS_PATH = os.path.join(BASE_PATH, 'class_weights.json')
# Directory that receives model checkpoints; created here if it does not exist.
SAVED_MODEL_DIR    = "C:/Users/User/Desktop/saved_model"
os.makedirs(SAVED_MODEL_DIR, exist_ok=True)

NUM_CLASSES        = 4    # Soil, Bedrock, Sand, Big Rock
BATCH_SIZE         = 16
EPOCHS             = 15
LEARNING_RATE      = 1e-4   # Cosine schedule base
L2_REGULARIZATION  = 1e-5   # handed to Adam as weight_decay by train_model
PATIENCE           = 5      # epochs without validation-loss improvement before early stopping
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---------------------------------------------
# 1) DATA LOADING
def load_splits(split_path):
    """Read a dataset split description from a JSON file.

    Args:
        split_path: Path of the JSON file to read.

    Returns:
        The decoded JSON content, unchanged.
    """
    # JSON is UTF-8 by specification; do not depend on the platform's locale
    # encoding (e.g. cp1252 on Windows), which corrupts non-ASCII paths.
    with open(split_path, 'r', encoding='utf-8') as f:
        return json.load(f)

# Load the split data (each entry is later indexed by its 'image' and 'mask'
# keys in RealMarsDataset).
train_split = load_splits(TRAIN_SPLIT)
val_split   = load_splits(VAL_SPLIT)

# Load class weights: the JSON file must provide a "class_weights" entry,
# which is converted to a float32 tensor for the weighted cross-entropy term.
# NOTE(review): the tensor is created without an explicit device - confirm it
# is moved to DEVICE before being combined with GPU logits.
with open(CLASS_WEIGHTS_PATH, 'r') as f:
    class_weights_json = json.load(f)
class_weights = torch.tensor(class_weights_json["class_weights"], dtype=torch.float32)

class RealMarsDataset(Dataset):
    """Map-style dataset yielding (image, label mask, one-hot mask) triples.

    ``data`` is an indexable collection of records, each providing an
    ``'image'`` and a ``'mask'`` file path. Both files are read with OpenCV
    and resized to 256x256.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        image_path, mask_path = record['image'], record['mask']

        # Image: 3-channel read, scaled to [0, 1], resized, then HWC -> CHW.
        # An unreadable file is replaced by an all-zero image.
        pixels = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if pixels is not None:
            pixels = cv2.resize(pixels.astype(np.float32)/255., (256,256))
        else:
            pixels = np.zeros((256,256,3), dtype=np.float32)
        image_tensor = torch.from_numpy(pixels.transpose(2,0,1)).float()

        # Mask: single-channel read, nearest-neighbour resize so label values
        # are never blended. An unreadable file becomes an all-ignored mask.
        labels = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if labels is not None:
            labels = cv2.resize(labels.astype(np.int64), (256,256),
                                interpolation=cv2.INTER_NEAREST)
        else:
            labels = np.full((256,256), fill_value=-1, dtype=np.int64)
        # 255 is rewritten to the ignore label -1.
        labels[labels==255] = -1

        # One plane per class, 1.0 where the label equals that class; pixels
        # labelled -1 are zero in every plane.
        class_ids = np.arange(NUM_CLASSES).reshape(-1, 1, 1)
        mask_one_hot = torch.from_numpy((labels[None] == class_ids).astype(np.float32))

        return image_tensor, torch.tensor(labels, dtype=torch.long), mask_one_hot

# Wrap both splits in datasets.
train_dataset = RealMarsDataset(train_split)
val_dataset   = RealMarsDataset(val_split)

# Batch the datasets; only the training loader shuffles. num_workers=0 means
# samples are loaded in the main process (no worker subprocesses).
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader   = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# ---------------------------------------------
# 2) HYBRID LOSS
def hybrid_loss(pred, target, target_one_hot, class_weights, ignore_index=-1):
    """Sum of weighted cross-entropy, Dice, Tversky and focal losses.

    Pixels whose label equals ``ignore_index`` contribute to none of the four
    terms: they are skipped by the cross-entropy and masked out of the
    softmax probabilities and one-hot targets used by the region-based terms.

    Args:
        pred: Raw logits indexed as (batch, class, H, W).
        target: Integer class labels indexed as (batch, H, W).
        target_one_hot: One-hot form of ``target`` with the same shape as
            ``pred``.
        class_weights: Per-class weights for the cross-entropy term; moved to
            ``pred``'s device and dtype if necessary.
        ignore_index: Label value marking pixels excluded from the loss.

    Returns:
        Scalar tensor ``wce + dice + tversky + focal``.
    """
    valid = target != ignore_index

    # Weighted CE. The labels are passed through untouched so that
    # ignore_index actually takes effect (remapping them to 0 would train
    # every unlabeled pixel as class 0).
    if valid.any():
        weights = class_weights.to(device=pred.device, dtype=pred.dtype)
        wce = F.cross_entropy(pred, target, weight=weights, ignore_index=ignore_index)
    else:
        # With no labelled pixel the mean CE is 0/0; contribute a zero that
        # is still attached to the graph.
        wce = pred.sum() * 0.0

    # Zero the probabilities and targets at ignored pixels so predictions
    # there count neither as false positives nor towards the denominators.
    valid_mask = valid.unsqueeze(1).to(pred.dtype)
    pred_softmax = F.softmax(pred, dim=1) * valid_mask
    target_one_hot = target_one_hot * valid_mask

    # Soft true positives per (sample, class), shared by Dice and Tversky.
    tp = (pred_softmax*target_one_hot).sum(dim=(2,3))

    # Dice
    dice_den = (pred_softmax+target_one_hot).sum(dim=(2,3)) + 1e-6
    dice = 1.0 - (2.0*tp/dice_den).mean()

    # Tversky (alpha weights false negatives, beta weights false positives)
    alpha, beta = 0.7, 0.3
    fn = (target_one_hot*(1-pred_softmax)).sum(dim=(2,3))
    fp = ((1-target_one_hot)*pred_softmax).sum(dim=(2,3))
    tversky_idx = tp/(tp+alpha*fn+beta*fp+1e-6)
    tversky_loss = 1.0 - tversky_idx.mean()

    # Focal (gamma = 2), averaged over every element of the tensor
    focal = -(target_one_hot*((1-pred_softmax)**2)*torch.log(pred_softmax+1e-6)).mean()

    return wce + dice + tversky_loss + focal

# ---------------------------------------------
# 3) ENCODER with DROPOUT BOTTLENECK
class BottleneckWithDropout(Bottleneck):
    """
    ResNet bottleneck block with a Dropout2d layer (p=0.2 unless a ``dropout``
    keyword is supplied) applied to the batch-normalised output of conv2,
    just before the ReLU.
    """
    def __init__(self, inplanes, planes, *args, **kwargs):
        # 'dropout' is not a Bottleneck argument, so strip it before delegating.
        drop_prob = kwargs.pop('dropout', 0.2)
        super().__init__(inplanes, planes, *args, **kwargs)
        self.dropout_p = drop_prob
        self.dropout = nn.Dropout2d(p=drop_prob)

    def forward(self, x):
        # conv1 -> bn1 -> relu
        out = self.relu(self.bn1(self.conv1(x)))

        # conv2 -> bn2 -> dropout -> relu
        out = self.bn2(self.conv2(out))
        out = self.relu(self.dropout(out))

        # conv3 -> bn3
        out = self.bn3(self.conv3(out))

        # Residual connection; the shortcut is projected only when the block
        # owns a downsample path.
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)

class ResNet50Encoder(nn.Module):
    """
    ResNet-50 feature extractor built from BottleneckWithDropout blocks, with
    the strides of layer3 and layer4 replaced by dilation (2 and 4).

    Args:
        pretrained: When true, copy every shape-compatible tensor from
            torchvision's pretrained ResNet-50 into the backbone.
        dropout: Dropout2d probability applied inside every bottleneck block.

    Raises:
        ValueError: If ``dropout`` is outside [0, 1].
    """
    def __init__(self, pretrained=True, dropout=0.2):
        super().__init__()
        if not 0.0 <= dropout <= 1.0:
            raise ValueError(f"dropout must be in [0, 1], got {dropout}")

        class ResNetDrop(ResNet):
            def __init__(self):
                super().__init__(
                    block=BottleneckWithDropout,
                    layers=[3,4,6,3],
                    zero_init_residual=False
                )
                # remove final FC, avgpool (forward() below never uses them)
                del self.fc
                del self.avgpool

        self.base = ResNetDrop()

        # torchvision's ResNet instantiates its blocks with positional
        # arguments only, so the 'dropout' keyword never reaches
        # BottleneckWithDropout. Apply the requested probability afterwards;
        # otherwise every block silently keeps its 0.2 default.
        for module in self.base.modules():
            if isinstance(module, BottleneckWithDropout):
                module.dropout_p = dropout
                module.dropout.p = dropout

        if pretrained:
            # Copy only tensors whose name and shape match; the reference
            # model's fc.* entries have no counterpart in self.base.
            official_dict = tvmodels.resnet50(pretrained=True).state_dict()
            base_dict = self.base.state_dict()
            filtered = {
                k: v for k, v in official_dict.items()
                if k in base_dict and v.shape == base_dict[k].shape
            }
            base_dict.update(filtered)
            self.base.load_state_dict(base_dict)

        # Trade stride for dilation in the two deepest stages so the feature
        # maps stay at the resolution produced by layer2.
        self._convert_to_dilated(self.base.layer3, dilation=2)
        self._convert_to_dilated(self.base.layer4, dilation=4)

    def _convert_to_dilated(self, layer, dilation):
        """Set the first block's strides to 1 and dilate conv2 of every block."""
        for i, block in enumerate(layer):
            if i == 0:
                # The first block is the only one that may carry a stride:
                # on conv2 and on the shortcut projection.
                if block.downsample is not None:
                    block.downsample[0].stride = (1,1)
                block.conv2.stride = (1,1)
            # padding == dilation keeps a 3x3 conv's output size unchanged
            block.conv2.dilation = (dilation,dilation)
            block.conv2.padding  = (dilation,dilation)

    def forward(self, x):
        """Return the outputs of layer1..layer4 as a tuple (f1, f2, f3, f4)."""
        # The base has no pool/FC, so the stem and stages are run manually.
        x = self.base.conv1(x)
        x = self.base.bn1(x)
        x = self.base.relu(x)
        x = self.base.maxpool(x)

        f1 = self.base.layer1(x)      # => [B,256,H/4,W/4]
        f2 = self.base.layer2(f1)     # => [B,512,H/8,W/8]
        f3 = self.base.layer3(f2)     # => [B,1024,H/8,W/8] (stride removed, dilation 2)
        f4 = self.base.layer4(f3)     # => [B,2048,H/8,W/8] (stride removed, dilation 4)

        return f1, f2, f3, f4

# ---------------------------------------------
# 4) ASPP MODULE
class ASPP(nn.Module):
    """
    Atrous Spatial Pyramid Pooling.

    Five parallel paths see the same input: a 1x1 conv, three 3x3 atrous convs
    (rates 6, 12 and 18) and a globally average-pooled path. Their outputs are
    concatenated along the channel axis and fused by a 1x1 projection. Every
    conv is followed by BatchNorm, ReLU and Dropout2d.
    """
    def __init__(self, in_channels=2048, out_channels=128, dropout=0.2):
        super().__init__()
        self.dropout_p = dropout

        # 1x1 path followed by the three dilated 3x3 paths
        self.branch1 = self._make_branch(in_channels, out_channels, dropout)
        self.branch2 = self._make_branch(in_channels, out_channels, dropout, rate=6)
        self.branch3 = self._make_branch(in_channels, out_channels, dropout, rate=12)
        self.branch4 = self._make_branch(in_channels, out_channels, dropout, rate=18)

        # Image-level path: pool to 1x1, then a 1x1 conv
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.global_conv = self._make_branch(in_channels, out_channels, dropout)

        # Fuses the five concatenated paths back to out_channels
        self.project = self._make_branch(out_channels*5, out_channels, dropout)

    @staticmethod
    def _make_branch(in_channels, out_channels, dropout, rate=None):
        """Conv -> BN -> ReLU -> Dropout2d; a 1x1 conv unless ``rate`` is given."""
        if rate is None:
            conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        else:
            # padding == dilation keeps the spatial size of a 3x3 conv
            conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1,
                             padding=rate, dilation=rate, bias=False)
        return nn.Sequential(
            conv,
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=dropout)
        )

    def forward(self, x):
        # x => [B,in_channels,H,W]
        _, _, height, width = x.shape

        paths = [self.branch1(x), self.branch2(x), self.branch3(x), self.branch4(x)]

        # Image-level context, broadcast back to HxW by bilinear resizing
        pooled = self.global_conv(self.global_pool(x))   # => [B,out_channels,1,1]
        paths.append(
            F.interpolate(pooled, size=(height, width), mode='bilinear', align_corners=False)
        )

        return self.project(torch.cat(paths, dim=1))

# ---------------------------------------------
# 5) DeeplabV3 HEAD
class DeepLabV3Head(nn.Module):
    """
    Classification head applied to the ASPP output: Dropout2d, a 3x3 conv that
    halves the channel count (with BatchNorm and ReLU), then a 1x1 conv that
    emits one logit map per class. Spatial size is left unchanged.
    """
    def __init__(self, aspp_out=256, num_classes=4, dropout=0.2):
        super().__init__()
        hidden = aspp_out//2
        self.dropout = nn.Dropout2d(p=dropout)
        self.conv1 = nn.Conv2d(aspp_out, hidden, kernel_size=3, padding=1, bias=False)
        self.bn1   = nn.BatchNorm2d(hidden)
        self.relu  = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(hidden, num_classes, kernel_size=1)

    def forward(self, x):
        features = self.conv1(self.dropout(x))
        features = self.relu(self.bn1(features))
        return self.conv2(features)

# ---------------------------------------------
# 6) Full DeepLabV3
class DeepLabV3(nn.Module):
    """
    DeepLabV3 segmentation network:
      - ResNet50Encoder (BottleneckWithDropout + atrous convs)
      - ASPP with multiple atrous rates
      - Lightweight head producing per-class logits, bilinearly resized to
        the spatial size of the input

    Args:
        num_classes: Number of output classes (logit channels).
        pretrained: Forwarded to ResNet50Encoder.
        dropout: Dropout2d probability shared by encoder, ASPP and head.
    """
    def __init__(self, num_classes=4, pretrained=True, dropout=0.2):
        super().__init__()
        print(f"Initializing DeepLabV3 with dropout={dropout} and pretrained={pretrained}")
        self.encoder = ResNet50Encoder(pretrained=pretrained, dropout=dropout)
        self.aspp    = ASPP(in_channels=2048, out_channels=256, dropout=dropout)
        self.head    = DeepLabV3Head(aspp_out=256, num_classes=num_classes, dropout=dropout)

    def forward(self, x):
        """Return logits shaped [B, num_classes, H, W] for an input [B, 3, H, W]."""
        input_size = tuple(x.shape[-2:])

        # Only the deepest encoder feature map feeds the ASPP module.
        *_, f4 = self.encoder(x)          # f4 => [B,2048,~H/8,~W/8]
        x = self.aspp(f4)                 # => [B,256,~H/8,~W/8]
        logits = self.head(x)             # => [B,num_classes,~H/8,~W/8]

        # Resize to the exact input size rather than by a fixed factor of 8:
        # the encoder rounds odd sizes up, so scale_factor=8 would return
        # logits larger than the label mask for inputs not divisible by 8.
        logits = F.interpolate(logits, size=input_size, mode='bilinear', align_corners=False)
        return logits

# ---------------------------------------------
#         TRAIN FUNCTION (Cosine LR)
# ---------------------------------------------
def _run_epoch(model, loader, loss_weights, description, optimizer=None):
    """Run one pass over ``loader``; train when ``optimizer`` is given, else evaluate.

    Returns:
        (mean per-batch loss, pixel accuracy over pixels whose label is not -1).
    """
    training = optimizer is not None
    model.train(training)

    loss_sum, correct, total = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for images, masks, masks_one_hot in tqdm(loader, desc=description):
            images, masks, masks_one_hot = images.to(DEVICE), masks.to(DEVICE), masks_one_hot.to(DEVICE)

            if training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = hybrid_loss(outputs, masks, masks_one_hot, loss_weights)
            if training:
                loss.backward()
                optimizer.step()

            loss_sum += loss.item()

            # Accuracy counts labelled pixels only (-1 marks "ignore").
            preds = outputs.argmax(dim=1)
            valid_indices = (masks != -1)
            correct += (preds[valid_indices] == masks[valid_indices]).sum().item()
            total   += valid_indices.sum().item()

    mean_loss = loss_sum / len(loader)
    accuracy = correct / total if total>0 else 0.0
    return mean_loss, accuracy


def train_model(
    model, train_loader, val_loader,
    num_epochs, save_path,
    base_lr=1e-4,  # initial LR
    l2_reg=1e-7,
    early_stopping_patience=5
):
    """Train ``model`` with Adam, cosine LR annealing and early stopping.

    The state dict is written to ``save_path`` each time the validation loss
    reaches a new minimum; training stops once it has failed to improve for
    ``early_stopping_patience`` consecutive epochs.

    Args:
        model: Network returning per-pixel class logits.
        train_loader: Loader yielding (images, masks, one-hot masks) batches.
        val_loader: Loader with the same batch structure, used for validation.
        num_epochs: Maximum number of epochs; also the cosine period T_max.
        save_path: Destination file for the best checkpoint.
        base_lr: Initial Adam learning rate.
        l2_reg: Adam weight decay.
        early_stopping_patience: Non-improving epochs tolerated before stopping.

    Returns:
        Dict with per-epoch lists under 'loss', 'val_loss', 'accuracy' and
        'val_accuracy'.

    Raises:
        ValueError: If either loader yields no batches.
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each yield at least one batch.")

    model.to(DEVICE)
    # The class weights are created on the CPU; the loss needs them on the
    # same device as the logits, so move them once up front.
    loss_weights = class_weights.to(DEVICE)

    optimizer = optim.Adam(model.parameters(), lr=base_lr, weight_decay=l2_reg)

    # Setup Cosine Annealing LR for 'num_epochs' cycles
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)

    best_val_loss = float('inf')
    early_stopping_counter = 0

    history = {
        'loss': [],
        'val_loss': [],
        'accuracy': [],
        'val_accuracy': []
    }

    for epoch in range(num_epochs):
        # Capture the LR this epoch trains with; after scheduler.step() the
        # scheduler reports the rate of the *next* epoch.
        epoch_lr = optimizer.param_groups[0]['lr']

        train_loss, train_accuracy = _run_epoch(
            model, train_loader, loss_weights,
            f"Epoch {epoch+1}/{num_epochs} [Train]", optimizer=optimizer
        )
        val_loss, val_accuracy = _run_epoch(
            model, val_loader, loss_weights,
            f"Epoch {epoch+1}/{num_epochs} [Val]"
        )

        # Cosine LR step after each epoch
        scheduler.step()

        print(f"[Epoch {epoch+1}/{num_epochs}] "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.4f}, "
              f"LR: {epoch_lr:.6f}")

        history['loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['accuracy'].append(train_accuracy)
        history['val_accuracy'].append(val_accuracy)

        # Checkpoint on improvement, otherwise count towards early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            early_stopping_counter = 0
            torch.save(model.state_dict(), save_path)
            print(f"Model saved to {save_path}")
        else:
            early_stopping_counter += 1
            if early_stopping_counter >= early_stopping_patience:
                print("Early stopping triggered. No improvement in validation loss.")
                break

    return history


# ---------------- PLOT LEARNING CURVES --------------
def plot_learning_curves(history):
    """Show loss and accuracy curves (train vs. validation) side by side.

    ``history`` must provide the per-epoch sequences 'loss', 'val_loss',
    'accuracy' and 'val_accuracy'.
    """
    # (panel title, training key, validation key), left to right
    panels = (
        ("Loss", 'loss', 'val_loss'),
        ("Accuracy", 'accuracy', 'val_accuracy'),
    )

    plt.figure(figsize=(12, 5))
    for position, (title, train_key, val_key) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(history[train_key], label=f"Train {title}", marker='o')
        plt.plot(history[val_key], label=f"Validation {title}", marker='o')
        plt.title(title)
        plt.xlabel("Epochs")
        plt.ylabel(title)
        plt.legend()
        plt.grid()

    plt.tight_layout()
    plt.show()

# ---------------- MAIN SCRIPT ----------------
if __name__ == "__main__":
    model = DeepLabV3(num_classes=NUM_CLASSES)

    # train_model builds its own Adam optimizer and cosine-annealing schedule
    # from base_lr / l2_reg, so no optimizer or scheduler is created here (a
    # second pair would never be stepped and would only mislead readers).
    history = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=EPOCHS,
        save_path=os.path.join(SAVED_MODEL_DIR, 'DeepLabV3.pth'),
        base_lr=LEARNING_RATE,
        l2_reg=L2_REGULARIZATION,
        early_stopping_patience=PATIENCE
    )

    # Plot the final learning curves
    plot_learning_curves(history)

    print("Training complete. Results have been saved.")

Initializing DeepLabV3 with dropout=0.2 and pretrained=True


Epoch 1/15 [Train]:   0%|          | 2/703 [00:53<5:09:56, 26.53s/it]


KeyboardInterrupt: 

Model evaluation

In [None]:
import torch
import numpy as np
import cv2
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# 1. Load the test split and wrap it in a DataLoader, mirroring train/val.
test_split_path = os.path.join(BASE_PATH, "test_split.json")

test_split_data = load_splits(test_split_path)
test_dataset = RealMarsDataset(test_split_data)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# 2. Load the trained model and set it to eval mode.
#    - pretrained=False: every weight is overwritten by the checkpoint, so
#      fetching ImageNet weights first would be wasted work.
#    - The file name matches the checkpoint written by the training cell.
#    - map_location lets a GPU-trained checkpoint load on a CPU-only machine.
model = DeepLabV3(num_classes=NUM_CLASSES, pretrained=False)
model_path = os.path.join(SAVED_MODEL_DIR, 'DeepLabV3.pth')
model.load_state_dict(torch.load(model_path, map_location=DEVICE))
model.to(DEVICE)
model.eval()

# 3. Inference on the test set; gather predictions and targets for every
#    labelled pixel. Per-batch NumPy arrays are collected instead of extending
#    Python lists pixel by pixel, which keeps memory use manageable.
pred_chunks = []
target_chunks = []

with torch.no_grad():
    for images, masks, _ in test_loader:  # masks_one_hot is not needed for evaluation
        outputs = model(images.to(DEVICE))
        preds_np = outputs.argmax(dim=1).cpu().numpy()  # shape: (batch_size, H, W)
        masks_np = masks.numpy()

        # Boolean indexing flattens in sample-major order and drops the
        # pixels labelled -1 (ignore).
        valid_indices = (masks_np != -1)
        pred_chunks.append(preds_np[valid_indices])
        target_chunks.append(masks_np[valid_indices])

all_preds = np.concatenate(pred_chunks) if pred_chunks else np.empty(0, dtype=np.int64)
all_targets = np.concatenate(target_chunks) if target_chunks else np.empty(0, dtype=np.int64)

# 4. Classification report (multi-class)
#    Class indices and names are given explicitly for interpretability.
class_names = ["Soil", "Bedrock", "Sand", "Big Rock"]
report = classification_report(
    all_targets,
    all_preds,
    labels=[0, 1, 2, 3],
    target_names=class_names,
    digits=4
)
print("Classification Report:")
print(report)

# 5. Confusion matrix (rows: true class, columns: predicted class)
cm = confusion_matrix(all_targets, all_preds, labels=[0, 1, 2, 3])
plt.figure(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=class_names,
            yticklabels=class_names)
plt.title("Confusion Matrix")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.show()

# 6. Compute class-wise and mean IoU
#    The per-batch mean_iou idea is adapted to the aggregated data here

# 6. Compute class-wise and mean IoU
#    We can adapt the mean_iou approach to aggregated data here
def compute_iou(preds, targets, num_classes):
    """Compute the intersection-over-union of every class.

    Args:
        preds: Sequence or array of predicted class indices.
        targets: Sequence or array of ground-truth class indices, aligned
            element-wise with ``preds``.
        num_classes: Classes ``0 .. num_classes - 1`` are evaluated.

    Returns:
        List of ``num_classes`` floats. An entry is NaN when the class occurs
        in neither ``preds`` nor ``targets``.

    Raises:
        ValueError: If ``preds`` and ``targets`` differ in length.
    """
    pred_arr = np.asarray(preds).ravel()
    target_arr = np.asarray(targets).ravel()
    if pred_arr.shape != target_arr.shape:
        raise ValueError(
            f"preds and targets must have the same length, "
            f"got {pred_arr.size} and {target_arr.size}."
        )

    ious = []
    for cls in range(num_classes):
        # Vectorised membership tests replace the per-pixel Python loop.
        pred_is_cls = pred_arr == cls
        target_is_cls = target_arr == cls
        union = int(np.count_nonzero(pred_is_cls | target_is_cls))
        if union == 0:
            ious.append(float('nan'))
        else:
            intersection = int(np.count_nonzero(pred_is_cls & target_is_cls))
            ious.append(intersection / union)
    return ious

# Per-class IoU over all labelled test pixels, then the NaN-aware mean
# (classes absent from both predictions and targets are skipped).
ious = compute_iou(all_preds, all_targets, NUM_CLASSES)
print("Per-Class IoU:")
for position, score in enumerate(ious):
    print(f"  {class_names[position]}: {score:.4f}")
mean_iou = np.nanmean(ious)
print(f"Mean IoU: {mean_iou:.4f}")

Segmentation map prediction

In [None]:
import os
import json
import time
import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches


# Paths: dataset root, test split description and the trained checkpoint.
BASE_PATH       = "C:/Users/User/Desktop/ai4mars/msl"
TEST_SPLIT      = os.path.join(BASE_PATH, "test_split.json")
SAVED_MODEL_DIR = "C:/Users/User/Desktop/saved_model"
# Must name the checkpoint written by the training cell ('DeepLabV3.pth');
# a checkpoint of a different architecture cannot be loaded into DeepLabV3.
MODEL_PATH      = os.path.join(SAVED_MODEL_DIR, "DeepLabV3.pth")

NUM_CLASSES     = 4
DEVICE          = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Display names, indexed by class id.
class_names = ["Soil", "Bedrock", "Sand", "Big Rock"]

# Consistent BGR color map for display (class id -> (B, G, R))
color_map = {
    0: (0, 0, 128),    # Soil
    1: (0, 128, 0),    # Bedrock
    2: (0, 128, 128),  # Sand
    3: (128, 0, 128),  # Big Rock
}

def load_test_split(test_split_path):
    """Read the test split description from a JSON file.

    Args:
        test_split_path: Path of the JSON file to read.

    Returns:
        The decoded JSON content, unchanged.
    """
    # JSON is UTF-8 by specification; do not depend on the platform's locale
    # encoding (e.g. cp1252 on Windows), which corrupts non-ASCII paths.
    with open(test_split_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def load_preprocess_single_image(img_path):
    """Read an image and return it as a float tensor laid out (C, H, W).

    The image is read as 3-channel colour, scaled to [0, 1] and resized to
    256x256. An unreadable file yields an all-zero image instead.
    """
    pixels = cv2.imread(img_path, cv2.IMREAD_COLOR)
    if pixels is not None:
        pixels = cv2.resize(pixels.astype(np.float32) / 255.0, (256, 256))
    else:
        pixels = np.zeros((256, 256, 3), dtype=np.float32)
    # HWC -> CHW
    channels_first = pixels.transpose(2, 0, 1)
    return torch.from_numpy(channels_first).float()

def load_ground_truth_mask(mask_path):
    """Read a label mask and return it as a 256x256 integer array.

    The file is read as single-channel grayscale and resized with
    nearest-neighbour interpolation so label values are never blended. An
    unreadable file yields a mask filled with -1.
    """
    labels = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if labels is not None:
        return cv2.resize(labels.astype(np.int64), (256, 256),
                          interpolation=cv2.INTER_NEAREST)
    return np.full((256, 256), fill_value=-1, dtype=np.int64)

def colorize_mask(mask):
    """
    Paint a class-index mask with the colors of the module-level 'color_map'.
    mask: (H, W) array of integer class IDs.
    Returns: (H, W, 3) uint8 image; pixels whose ID is not in 'color_map'
    stay black.
    """
    rows, cols = mask.shape
    canvas = np.zeros((rows, cols, 3), dtype=np.uint8)
    for class_id in color_map:
        canvas[mask == class_id] = color_map[class_id]
    return canvas

def main(image_id=1239):
    """Segment one test image and show it next to its ground truth.

    Args:
        image_id: Index of the sample within the test split.

    Raises:
        ValueError: If the test split is empty or ``image_id`` is out of range.
    """
    # 1. Load test split
    test_data = load_test_split(TEST_SPLIT)
    if not test_data:
        raise ValueError("Test split is empty or missing.")

    # 2. Select an image from the test split
    if image_id < 0 or image_id >= len(test_data):
        raise ValueError(f"Invalid image_id: {image_id}, must be between 0 and {len(test_data)-1}.")

    sample = test_data[image_id]
    image_path = sample["image"]
    mask_path  = sample["mask"]  # Ground-truth label path

    print(f"Selected Image: {image_path}")
    print(f"Ground Truth Label: {mask_path}")

    # 3. Load the trained DeepLabV3 model. pretrained=False because every
    #    weight is overwritten by the checkpoint right away.
    model = DeepLabV3(num_classes=NUM_CLASSES, pretrained=False)
    model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
    model.to(DEVICE)
    model.eval()

    # 4. Preprocess the single image
    image_tensor = load_preprocess_single_image(image_path).unsqueeze(0).to(DEVICE)

    # Load original BGR image for display
    original_bgr = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if original_bgr is None:
        original_bgr = np.zeros((256, 256, 3), dtype=np.uint8)
    else:
        original_bgr = cv2.resize(original_bgr, (256, 256))

    # Load ground truth mask; label values missing from color_map stay black
    gt_mask = load_ground_truth_mask(mask_path)  # shape [256,256], class IDs
    gt_color = colorize_mask(gt_mask)

    # 5. Inference and measure time. CUDA kernels run asynchronously, so the
    #    device is synchronized on both sides of the forward pass; otherwise
    #    only the kernel launch would be timed. This is a single, un-warmed
    #    run, so one-time initialisation cost is included.
    on_gpu = DEVICE.type == 'cuda'
    with torch.no_grad():
        if on_gpu:
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        output = model(image_tensor)  # shape [1, num_classes, 256,256]
        if on_gpu:
            torch.cuda.synchronize()
        end_time = time.perf_counter()
    inference_time_ms = (end_time - start_time) * 1000

    # 6. Convert to predicted mask
    pred_mask = output.argmax(dim=1).squeeze(0).cpu().numpy()
    pred_color = colorize_mask(pred_mask)

    # 7. Visualization
    # Four subplots: (1) original, (2) ground-truth, (3) predicted, (4) overlay
    original_rgb = cv2.cvtColor(original_bgr, cv2.COLOR_BGR2RGB)
    gt_rgb       = cv2.cvtColor(gt_color, cv2.COLOR_BGR2RGB)
    pred_rgb     = cv2.cvtColor(pred_color, cv2.COLOR_BGR2RGB)

    # Overlay: blend original and predicted
    overlay_bgr = cv2.addWeighted(original_bgr, 0.5, pred_color, 0.5, 0)
    overlay_rgb = cv2.cvtColor(overlay_bgr, cv2.COLOR_BGR2RGB)

    # Build legend patches (matplotlib expects RGB in [0, 1]; color_map is BGR)
    patches = []
    for i, name in enumerate(class_names):
        b, g, r = color_map[i]
        color_rgb = (r/255, g/255, b/255)
        patches.append(mpatches.Patch(color=color_rgb, label=name))

    plt.figure(figsize=(16, 6))

    # Subplot 1: Original
    plt.subplot(1, 4, 1)
    plt.imshow(original_rgb)
    plt.title("Original")
    plt.axis('off')

    # Subplot 2: Ground Truth
    plt.subplot(1, 4, 2)
    plt.imshow(gt_rgb)
    plt.title("Ground Truth")
    plt.axis('off')

    # Subplot 3: Predicted
    plt.subplot(1, 4, 3)
    plt.imshow(pred_rgb)
    plt.title("Predicted")
    plt.axis('off')
    plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)

    # Subplot 4: Overlay
    plt.subplot(1, 4, 4)
    plt.imshow(overlay_rgb)
    plt.title("Overlay")
    plt.axis('off')

    plt.suptitle(f"Inference Time: {inference_time_ms:.2f} ms", fontsize=14)
    plt.tight_layout()
    plt.show()

if __name__ == "__main__":
    main()