In [2]:
import os
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import numpy as np
from sklearn.metrics import confusion_matrix, precision_score, recall_score, accuracy_score
import segmentation_models_pytorch as smp
from tqdm import tqdm

# Check for available device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [4]:
class CityscapesDataset(Dataset):
    def __init__(self, left_img_dir, gt_dir, transform=None):
        self.left_img_dir = left_img_dir
        self.gt_dir = gt_dir
        self.transform = transform
        
        # Collect image paths from the leftImg8bit directory
        self.images = []
        for city in os.listdir(left_img_dir):
            city_path = os.path.join(left_img_dir, city)
            if os.path.isdir(city_path):
                for filename in os.listdir(city_path):
                    if filename.endswith('.png') or filename.endswith('.jpg'):
                        self.images.append(os.path.join(city_path, filename))
        
        print(f"Found {len(self.images)} images in {self.left_img_dir}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_name = self.images[idx]
        image = Image.open(img_name).convert("RGB")

        city_name = img_name.split(os.sep)[-2]
        base_filename = os.path.basename(img_name).replace('leftImg8bit', 'gtFine_labelIds')
        label_name = os.path.join(self.gt_dir, 'val', city_name, base_filename)

        label = Image.open(label_name)

        if self.transform:
            image = self.transform(image)
            label = self.transform(label)

        return image, label


In [6]:
val_left_img_dir = 'C:\\LeftImg\\val'  # Path to validation images
val_gt_dir = 'C:\\gtFine'  # Path to ground truth masks

transform = transforms.Compose([
    transforms.ToTensor()
])

val_dataset = CityscapesDataset(val_left_img_dir, val_gt_dir, transform=transform)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=0)


Found 500 images in C:\LeftImg\val


In [12]:
# Initialize the model
model = smp.DeepLabV3Plus(
    encoder_name='resnet34',  # Encoder type
    in_channels=3,            # Input channels (RGB)
    classes=34,               # Number of classes
    activation=None           # No activation (raw logits)
).to(device)

# Load the state dictionary, suppressing unnecessary warnings
state_dict = torch.load("deeplabv3plus_cityscapes.pth", map_location=device)

# Handle missing or unexpected keys during loading
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if missing or unexpected:
    print(f"Missing keys: {missing}\nUnexpected keys: {unexpected}")

# Set the model to evaluation mode
model.eval()


  state_dict = torch.load("deeplabv3plus_cityscapes.pth", map_location=device)


DeepLabV3Plus(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=Tr

In [24]:
from sklearn.metrics import confusion_matrix, precision_score, recall_score, accuracy_score

# Assuming you have the list of unique class labels
# Replace these with your actual class labels
unique_labels = [0, 1, 2]  # Example for a 3-class problem; adjust as necessary

all_preds = []
all_labels = []

for batch_idx, (inputs, labels) in enumerate(val_loader):
    preds = model(inputs).argmax(dim=1)  # Convert logits to class indices

    # Check if labels are one-hot encoded
    if labels.dim() > 1 and labels.shape[1] > 1:
        labels = labels.argmax(dim=1)  # Convert one-hot encoded labels to class indices

    # Collect the predictions and labels for evaluation, flatten to 1D
    all_preds.append(preds.cpu().numpy().flatten())
    all_labels.append(labels.cpu().numpy().flatten())

# Convert lists to numpy arrays and concatenate
all_preds = np.concatenate(all_preds)
all_labels = np.concatenate(all_labels)

# Calculate evaluation metrics
conf_matrix = confusion_matrix(all_labels, all_preds, labels=unique_labels)

# Calculate precision, recall, and accuracy
precision = precision_score(all_labels, all_preds, average='weighted', zero_division=1)
recall = recall_score(all_labels, all_preds, average='weighted', zero_division=1)
accuracy = accuracy_score(all_labels, all_preds)

# Print the evaluation metrics
print(f"Confusion Matrix:\n{conf_matrix}")
print(f"Precision: {precision}")
print(f"Recall: {recall}")
print(f"Accuracy: {accuracy}")


Confusion Matrix:
[[131072000         0         0]
 [        0         0         0]
 [        0         0         0]]
Precision: 1.0
Recall: 1.0
Accuracy: 1.0
