In [1]:
from model import Model
from torchvision.datasets import Cityscapes
from torch.utils.data import Subset, DataLoader
import torch.optim as optim
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
import utils

In [None]:
# Transformations for the input images and the label masks.
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# Masks are resized with nearest-neighbour interpolation so that resizing never
# blends neighbouring label ids into new values. ToTensor rescales the ids to
# floats in [0, 1]; consumers multiply by 255 to recover the integer ids.
target_transforms = transforms.Compose([
    transforms.Resize((32, 32), transforms.InterpolationMode.NEAREST),
    transforms.ToTensor()
])

# Load the dataset.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
data_path = "C:/Users/20193625/OneDrive - TU Eindhoven/Documents/5LSM0/5LSM0_Final_Project/CityScapes"
train_dataset = Cityscapes(root=data_path,
                           split='train',
                           mode='fine',
                           target_type='semantic',
                           transform=transform,
                           target_transform=target_transforms)

val_dataset = Cityscapes(root=data_path,
                         split='val',
                         mode='fine',
                         target_type='semantic',
                         transform=transform,
                         target_transform=target_transforms)

# Create small random subsets of both splits (seeded for reproducibility).
# Each split draws its own indices: reusing the training indices for the
# validation split would produce indices beyond the end of the validation set.
torch.manual_seed(1)
subset_fraction = 0.05

subset_size = int(subset_fraction * len(train_dataset))
indices = torch.randperm(len(train_dataset))[:subset_size].tolist()
train_subset = Subset(train_dataset, indices)

val_subset_size = int(subset_fraction * len(val_dataset))
val_indices = torch.randperm(len(val_dataset))[:val_subset_size].tolist()
val_subset = Subset(val_dataset, val_indices)

# Create data loaders.
# TODO: decide on num_workers / pin_memory.
train_loader = DataLoader(train_subset, batch_size=8, shuffle=True)
# Validation uses the subset created above and is not shuffled, so repeated
# runs look at the same samples in the same order.
val_loader = DataLoader(val_subset, batch_size=8, shuffle=False)


In [None]:
# Pull a single mini-batch from the training loader for inspection.
first_batch = next(iter(train_loader))
images, targets = first_batch

In [None]:
def visualize_samples(images, targets, n, mean=0.5, std=0.5):
    """
    Plots the first `n` image / segmentation-mask pairs of a batch side by side.

    Args:
        images (torch.Tensor): Batch of images; each item is permuted from
            channel-first to channel-last before plotting.
        targets (torch.Tensor): Batch of masks; singleton dimensions of each
            item are squeezed away before plotting.
        n (int): Number of pairs to show. Clamped to the batch size.
        mean (float or sequence, optional): Mean used when the images were
            normalized. Defaults to 0.5, matching the notebook's `Normalize`.
        std (float or sequence, optional): Std used when the images were
            normalized. Defaults to 0.5, matching the notebook's `Normalize`.

    Returns:
        None
    """
    # Never index past the end of the batch.
    n = min(n, len(images), len(targets))
    if n <= 0:
        return

    # squeeze=False keeps `axs` two-dimensional, so axs[i, 0] also works for n == 1.
    fig, axs = plt.subplots(n, 2, figsize=(6, n * 2), squeeze=False)

    mean = np.asarray(mean)
    std = np.asarray(std)

    for i in range(n):
        image = images[i].permute(1, 2, 0).numpy()
        # Undo the normalization so imshow receives values in [0, 1];
        # otherwise every negative value would be clipped to black.
        image = np.clip(image * std + mean, 0.0, 1.0)
        target = targets[i].squeeze().numpy()

        axs[i, 0].imshow(image)
        axs[i, 0].set_title('Image')
        axs[i, 0].axis('off')

        axs[i, 1].imshow(target)
        axs[i, 1].set_title('Segmentation')
        axs[i, 1].axis('off')

    plt.show()

# Show the sampled batch: eight image / mask pairs.
visualize_samples(images, targets, n=8)

In [None]:
# Instantiate the network defined in the local `model` module (default constructor arguments).
model = Model()

In [None]:
# Define the optimizer and the loss function.
# ignore_index=255 makes the loss skip every pixel whose target value is 255.
optimizer = optim.Adam(model.parameters(), lr=0.01)
loss_function = nn.CrossEntropyLoss(ignore_index=255) # TODO: implement a Dice loss

In [None]:
# training loop
def train_model_segmentation(model, train_loader, num_epochs=5, lr=0.01,
                             optimizer=None, criterion=None):
    """
    Trains a segmentation model and prints the mean batch loss of every epoch.

    Args:
        model (torch.nn.Module): Model to train; called as `model(inputs)`.
        train_loader (torch.utils.data.DataLoader): Yields `(inputs, masks)`
            batches. Masks are expected as produced by `ToTensor`, i.e. label
            ids scaled to [0, 1] with a singleton channel dimension.
        num_epochs (int, optional): Number of passes over the loader. Defaults to 5.
        lr (float, optional): Learning rate of the Adam optimizer that is
            created when `optimizer` is not given. Defaults to 0.01.
        optimizer (torch.optim.Optimizer, optional): Optimizer to use. When
            omitted, a fresh `Adam(model.parameters(), lr=lr)` is created, so
            `lr` takes effect and the optimizer always matches `model`.
        criterion (callable, optional): Loss taking `(outputs, masks)`.
            Defaults to `CrossEntropyLoss(ignore_index=255)`.

    Returns:
        None
    """
    if optimizer is None:
        optimizer = optim.Adam(model.parameters(), lr=lr)
    if criterion is None:
        criterion = nn.CrossEntropyLoss(ignore_index=255)

    # Make sure the model is in training mode (a previous call to
    # model.eval() would otherwise stay in effect).
    model.train()

    # Guard against division by zero for an empty loader.
    num_batches = max(len(train_loader), 1)

    for epoch in range(num_epochs):
        running_loss = 0.0
        for inputs, masks in train_loader:
            # The ids were scaled to [0, 1]; scale back exactly once. Rounding
            # before the integer cast prevents float error from truncating an
            # id to its lower neighbour. squeeze(1) removes only the channel
            # axis, so a batch holding a single sample keeps its batch axis.
            masks = (masks * 255).round().long().squeeze(1)
            # NOTE(review): assumes `utils` provides map_id_to_train_id — the
            # notebook imports `utils` but never defines this helper itself.
            masks = utils.map_id_to_train_id(masks)

            optimizer.zero_grad()
            outputs = model(inputs)

            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        epoch_loss = running_loss / num_batches
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}')


In [None]:
# Train for 20 epochs with a learning-rate argument of 0.01.
train_model_segmentation(model, train_loader, num_epochs=20, lr=0.01)

In [None]:
# Build the lookup tables used for visualization from the dataset's class list:
# `colors` maps each class id to its RGB color, `class_names` lists the class
# names in the order the dataset defines them.
cityscapes_classes = train_dataset.classes
colors = {entry.id: entry.color for entry in cityscapes_classes}
class_names = [entry.name for entry in cityscapes_classes]

print(colors)
print(class_names)

In [None]:
def mask_to_rgb(mask, class_to_color):
    """
    Renders a 2-D class mask as an RGB image.

    Parameters:
        mask (numpy.ndarray): 2-D array in which every value denotes a class.
        class_to_color (dict): Maps a class value to its RGB color tuple.

    Returns:
        numpy.ndarray: uint8 array of shape (height, width, 3). Pixels whose
        class has no entry in `class_to_color` stay black.
    """
    n_rows, n_cols = mask.shape

    # Start from an all-black canvas and paint one class at a time.
    canvas = np.zeros((n_rows, n_cols, 3), dtype=np.uint8)
    for label, rgb in class_to_color.items():
        canvas[mask == label] = rgb

    return canvas

def visualize_segmentation(model, dataloader, num_examples=5, pred_colors=None):
    """
    Visualizes segmentation results from a given model using a dataloader.

    For every shown example a figure with three panels is drawn: the input
    image, the ground-truth mask and the predicted mask.

    Args:
        model (torch.nn.Module): The segmentation model to visualize.
        dataloader (torch.utils.data.DataLoader): Dataloader providing
            image-mask pairs. Masks are expected as produced by `ToTensor`
            (label ids scaled to [0, 1]).
        num_examples (int, optional): Number of individual examples (not
            batches) to visualize. Defaults to 5.
        pred_colors (dict, optional): Maps a predicted class index to an RGB
            tuple. Defaults to the Cityscapes train-id palette, because the
            training loop maps the labels to train ids before computing the loss.

    Returns:
        None

    Notes:
        The ground-truth mask is colored with the notebook-level `colors`
        dictionary (raw Cityscapes id -> color). `renormalize_image` is used
        to undo the input normalization.
    """
    if pred_colors is None:
        # Predictions live in train-id space, so the raw-id `colors` palette
        # would paint them with the wrong classes' colors. Ids outside the
        # valid train-id range (the 255 "ignore" id and -1) are skipped.
        pred_colors = {
            entry.train_id: entry.color
            for entry in Cityscapes.classes
            if 0 <= entry.train_id < 255
        }

    shown = 0
    model.eval()
    with torch.no_grad():
        for images, masks in dataloader:
            if shown >= num_examples:
                break

            # argmax over the logits equals argmax over the softmax
            # probabilities, so no explicit softmax is needed.
            predicted = torch.argmax(model(images), dim=1).numpy()

            images = images.numpy()
            # Recover integer label ids; rounding keeps float error from
            # breaking the equality tests inside mask_to_rgb.
            masks = np.rint(masks.numpy() * 255).astype(np.int64)

            for j in range(images.shape[0]):
                if shown >= num_examples:
                    break

                image = renormalize_image(images[j].transpose(1, 2, 0))
                # Keep the image inside the range imshow accepts for floats.
                image = np.clip(image, 0.0, 1.0)

                mask = masks[j].squeeze()
                pred_mask = predicted[j]

                # Convert mask and predicted mask to RGB for visualization
                mask_rgb = mask_to_rgb(mask, colors)
                pred_mask_rgb = mask_to_rgb(pred_mask, pred_colors)

                plt.figure(figsize=(10, 5))
                plt.subplot(1, 3, 1)
                plt.imshow(image)
                plt.title('Image')
                plt.axis('off')

                plt.subplot(1, 3, 2)
                plt.imshow(mask_rgb)
                plt.title('Ground Truth Mask Classes:')
                plt.axis('off')

                plt.subplot(1, 3, 3)
                plt.imshow(pred_mask_rgb)
                plt.title('Predicted Mask Predicted Classes:')
                plt.axis('off')

                plt.show()
                shown += 1
                

def renormalize_image(image, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """
    Undoes a per-channel normalization, i.e. computes `image * std + mean`.

    Args:
        image (numpy.ndarray): Normalized image with the channels on the last axis.
        mean (sequence of float, optional): Per-channel mean that was
            subtracted during normalization. Defaults to (0.5, 0.5, 0.5).
        std (sequence of float, optional): Per-channel standard deviation the
            image was divided by. Defaults to (0.5, 0.5, 0.5).

    Returns:
        numpy.ndarray: Image mapped back to its range before normalization.
    """
    # Explicit arrays make the broadcast over the last (channel) axis obvious.
    channel_mean = np.asarray(mean)
    channel_std = np.asarray(std)
    return image * channel_std + channel_mean

In [None]:
# Qualitative check: plot image, ground truth and prediction for data from the validation loader.
visualize_segmentation(model, val_loader, num_examples=2)