# Semantic segmentation using PyTorch and TorchVision

## Overview
In this lab session, we will explore semantic segmentation and learn how to obtain segmentation masks using the utilities provided by PyTorch and TorchVision.

## Learning objectives
- Understand how to visualise images using TorchVision utilities.
- Explore semantic segmentation using pre-trained models.
- Use segmentation masks for visualising model predictions.

## Steps
1. Import the required libraries and define helper functions.
2. Visualize a grid of images.
3. Use semantic segmentation models and visualise their outputs.
4. Apply segmentation masks to images.
5. Analyze the model predictions for multiple classes.


In [None]:
#### **Cell 1: Import Libraries**
# Import essential libraries for tensor manipulation, visualization, and image transformations.
import torch
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F

plt.rcParams["savefig.bbox"] = 'tight'

In [None]:
#### **Cell 2: Define `show` Utility**
# A utility function to display one or more images using Matplotlib.
def show(imgs):
    if not isinstance(imgs, list):
        imgs = [imgs]
    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for i, img in enumerate(imgs):
        img = img.detach()
        img = F.to_pil_image(img)
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])



In [None]:
#### **Cell 3: Visualizing a Grid of Images**
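# Load and decode the images, then display them side by side in a grid.

from pathlib import Path

from torchvision.io import read_file, decode_image
from torchvision.utils import make_grid

image_folder = Path('./images')
dog1_int = decode_image(read_file(str(image_folder / 'dog1.jpg')))
dog2_int = decode_image(read_file(str(image_folder / 'dog2.jpg')))
dog_list = [dog1_int, dog2_int]

# make_grid stacks the images, so they must all have the same size.
grid = make_grid(dog_list)
show(grid)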


### Semantic segmentation models

We will use torchvision's FCN ResNet-50 model, loaded with `torchvision.models.segmentation.fcn_resnet50` and its pre-trained `FCN_ResNet50_Weights`. Let's start by looking at the output of the model.
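Before the images reach the model, Cell 4 passes each one through `weights.transforms(resize_size=None)`. As far as we understand the preset for these weights, it converts the `uint8` image to floats in `[0, 1]` and normalises each channel with the ImageNet mean and standard deviation, skipping the resize when `resize_size=None`. The sketch below reproduces that step in plain PyTorch so you can see what the model actually receives; the statistics and the helper `preprocess_uint8` are assumptions for illustration, not the preset's own code.

```python
import torch

# ImageNet statistics; assumed to match the normalisation used by the
# FCN_ResNet50_Weights preset (print weights.transforms() to confirm).
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406])
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225])


def preprocess_uint8(img: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: uint8 (3, H, W) image -> normalised float (3, H, W)."""
    x = img.float() / 255.0  # rescale [0, 255] -> [0, 1]
    # Broadcast the per-channel statistics over the H and W dimensions.
    return (x - IMAGENET_MEAN[:, None, None]) / IMAGENET_STD[:, None, None]


# Toy 3x4x5 image in which every pixel is pure white (255).
white = torch.full((3, 4, 5), 255, dtype=torch.uint8)
x = preprocess_uint8(white)
print(x.shape)     # torch.Size([3, 4, 5])
print(x[:, 0, 0])  # per-channel (1 - mean) / std
```

Because no resize happens, the spatial size of the model output matches the original image, which is what lets the masks be drawn directly on `dog1_int` and `dog2_int` later on.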

In [None]:
#### **Cell 4: Semantic Segmentation with FCN ResNet-50**
# Use a pre-trained FCN ResNet-50 model to perform semantic segmentation and output prediction scores.

from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights

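weights = FCN_ResNet50_Weights.DEFAULT
# resize_size=None keeps the original resolution, so the masks align with the input images.
transforms = weights.transforms(resize_size=None)

model = fcn_resnet50(weights=weights, progress=False)
model = model.eval()

# Stacking requires all images in dog_list to have the same size.
batch = torch.stack([transforms(d) for d in dog_list])
# Inference only: disable gradient tracking to save memory.
with torch.no_grad():
    output = model(batch)['out']
print(output.shape, output.min().item(), output.max().item())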

As we can see above, the output of the segmentation model is a tensor of shape `(batch_size, num_classes, H, W)`. Each value is an unnormalized score (a logit). Applying a softmax over the class dimension (`dim=1`) maps these scores into `[0, 1]` so that, for every pixel, the values across all classes sum to 1. After the softmax, we can interpret each value as the probability that a given pixel belongs to a given class.
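To make the role of `dim=1` concrete, here is a small sketch on random logits shaped like the model output. It uses 21 classes, assumed here to match the 20 Pascal VOC categories plus background listed in `weights.meta["categories"]`; the values themselves are synthetic.

```python
import torch

torch.manual_seed(0)

# Toy logits shaped like the model output: (batch_size, num_classes, H, W).
logits = torch.randn(2, 21, 4, 5)

# Softmax over the class dimension (dim=1) turns scores into probabilities.
probs = torch.nn.functional.softmax(logits, dim=1)

# Every pixel now carries a probability distribution over the 21 classes.
per_pixel_sum = probs.sum(dim=1)  # shape (2, 4, 5)
print(per_pixel_sum.min().item(), per_pixel_sum.max().item())  # both within rounding of 1.0

# Softmax is monotonic, so the most likely class per pixel is unchanged.
print(torch.equal(logits.argmax(dim=1), probs.argmax(dim=1)))
```

The last line is why the argmax-based masks later in the lab would be identical whether computed from `output` or from `normalized_masks`; the softmax matters mainly for reading the values as probabilities.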

Let's plot the per-pixel probability maps for the dog class and for the boat class:

In [None]:
#### **Cell 5: Normalize and Visualize Masks**
# Normalize segmentation outputs to probabilities and visualize masks for specific classes.

sem_class_to_idx = {cls: idx for (idx, cls) in enumerate(weights.meta["categories"])}

normalized_masks = torch.nn.functional.softmax(output, dim=1)
print(normalized_masks.shape, normalized_masks.min().item(), normalized_masks.max().item())

dog_and_boat_masks = [
    normalized_masks[img_idx, sem_class_to_idx[cls]]
    for img_idx in range(len(dog_list))
    for cls in ('dog', 'boat')
]

show(dog_and_boat_masks)
# From left to right:
# image 1 dog class, image 1 boat class, image 2 dog class, image 2 boat class

### Visualizing binary segmentation masks for the dog class
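Cell 6 builds the binary mask by asking, for every pixel, whether the dog class has the highest probability among all classes. The toy example below applies the same `argmax(class_dim) == idx` comparison to a hand-written probability tensor with three classes; `dog_idx` is a hypothetical stand-in for `sem_class_to_idx['dog']`.

```python
import torch

# Toy probabilities for one image: 3 classes over a 2x3 grid, shape (1, 3, 2, 3).
probs = torch.tensor([[
    [[0.7, 0.1, 0.2], [0.6, 0.3, 0.1]],  # class 0 (background)
    [[0.2, 0.8, 0.1], [0.3, 0.6, 0.2]],  # class 1
    [[0.1, 0.1, 0.7], [0.1, 0.1, 0.7]],  # class 2
]])

class_dim = 1
dog_idx = 1  # hypothetical index standing in for sem_class_to_idx['dog']

# True wherever class 1 wins the argmax over the class dimension.
boolean_masks = probs.argmax(class_dim) == dog_idx
print(boolean_masks.shape, boolean_masks.dtype)  # torch.Size([1, 2, 3]) torch.bool
print(boolean_masks[0].tolist())  # [[False, True, False], [False, True, False]]
```

Note that the class dimension disappears after `argmax`, so the result has shape `(batch_size, H, W)`, one boolean mask per image.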

In [None]:
#### **Cell 6: Binary Masks and Segmentation Visualization**
# Create binary masks for the "dog" class and visualize the segmented regions.
class_dim = 1
boolean_dog_masks = (normalized_masks.argmax(class_dim) == sem_class_to_idx['dog'])
print(f"shape = {boolean_dog_masks.shape}, dtype = {boolean_dog_masks.dtype}")
show([m.float() for m in boolean_dog_masks])



### Visualizing segmentation masks
The `torchvision.utils.draw_segmentation_masks` function can be used to draw segmentation masks on images. Semantic segmentation and instance segmentation models produce different outputs; this lab only covers the semantic segmentation case.

Now that we have boolean masks, we can use them with `torchvision.utils.draw_segmentation_masks` to plot them on top of the original images:
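The `alpha` argument controls how strongly the mask colour covers the image. To build intuition, here is a simplified sketch of alpha blending a single boolean mask with plain tensor operations. `blend_mask` is a hypothetical helper, not part of torchvision, and the library's own implementation may differ in details such as colour selection and rounding.

```python
import torch


def blend_mask(img: torch.Tensor, mask: torch.Tensor, color, alpha: float) -> torch.Tensor:
    """Hypothetical helper: alpha-blend one colour into a uint8 (3, H, W) image where mask is True."""
    img_f = img.float()
    color_t = torch.tensor(color, dtype=torch.float32)[:, None, None]  # (3, 1, 1)
    # Convex combination of the original pixel and the mask colour.
    blended = img_f * (1 - alpha) + color_t * alpha
    # Keep blended values inside the (H, W) mask and original values elsewhere.
    out = torch.where(mask, blended, img_f)
    return out.round().to(torch.uint8)


# Toy 2x2 grey image with only the top-left pixel masked.
img = torch.full((3, 2, 2), 100, dtype=torch.uint8)
mask = torch.tensor([[True, False], [False, False]])
out = blend_mask(img, mask, color=(200, 0, 0), alpha=0.5)
print(out[:, 0, 0].tolist())  # [150, 50, 50]
print(out[:, 1, 1].tolist())  # [100, 100, 100]
```

With `alpha=0.7`, as used in the lab, the mask colour dominates more strongly while the underlying image still shows through.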

In [None]:
#### **Cell 7: Overlay Segmentation Masks**
# Overlay segmentation masks on original images to combine predictions with visual context.

from torchvision.utils import draw_segmentation_masks

dogs_with_masks = [
    draw_segmentation_masks(img, masks=mask, alpha=0.7)
    for img, mask in zip(dog_list, boolean_dog_masks)
]
show(dogs_with_masks)


### More than one mask
We can plot more than one mask per image! Remember that the model returns one score map per class. Let's ask the same query as above, but this time for *all* classes, not just the dog class: "For each pixel and each class C, is class C the most likely class?"

This one is a bit more involved, so we'll first show how to do it with a single image, and then we'll generalize to the batch.
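The key step in Cell 8 is a broadcasting comparison between the per-pixel winning class, of shape `(H, W)`, and a column of class indices, of shape `(C, 1, 1)`. A minimal sketch on random probabilities for a single image with four classes:

```python
import torch

torch.manual_seed(0)
num_classes = 4
probs = torch.softmax(torch.randn(num_classes, 3, 5), dim=0)  # (C, H, W), one image

class_dim = 0
labels = probs.argmax(class_dim)                      # (H, W): winning class per pixel
class_ids = torch.arange(num_classes)[:, None, None]  # (C, 1, 1)

# Broadcasting (H, W) against (C, 1, 1) gives one boolean mask per class: (C, H, W).
all_class_masks = labels == class_ids
print(all_class_masks.shape)  # torch.Size([4, 3, 5])

# Each pixel is True in exactly one class mask, since argmax picks a single winner.
print(all_class_masks.sum(dim=0).unique().tolist())  # [1]
```

Because the masks never overlap, drawing all of them at once does not blend colours from different classes on the same pixel.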

In [None]:
#### **Cell 8: Multi-class Segmentation**
# Visualize masks for all detected classes in the first image.
num_classes = normalized_masks.shape[1]
dog1_masks = normalized_masks[0]
class_dim = 0
dog1_all_classes_masks = dog1_masks.argmax(class_dim) == torch.arange(num_classes)[:, None, None]

print(f"dog1_masks shape = {dog1_masks.shape}, dtype = {dog1_masks.dtype}")
print(f"dog1_all_classes_masks shape = {dog1_all_classes_masks.shape}, dtype = {dog1_all_classes_masks.dtype}")

dog_with_all_masks = draw_segmentation_masks(dog1_int, masks=dog1_all_classes_masks, alpha=.7)
show(dog_with_all_masks)

### Masks for the batch
In the image above, only two masks were drawn: the mask for the background and the mask for the dog. This is because these are the only two classes that have the highest probability at one or more pixels; the masks for all other classes are entirely `False`, so nothing is drawn for them. If another class had been the most likely class for some pixels, its mask would also appear above.
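One way to check which classes actually appear in a prediction is to count how many pixels each class wins. The sketch below does this with `torch.bincount` on a hand-written label map (the output of `argmax` over the class dimension) for one small image; the values are illustrative only.

```python
import torch

num_classes = 4
# Toy argmax output for one 3x4 image: mostly background (0) with a blob of class 2.
labels = torch.tensor([
    [0, 0, 0, 0],
    [0, 2, 2, 0],
    [0, 2, 2, 0],
])

# Number of pixels won by each class; classes with a zero count yield empty masks.
counts = torch.bincount(labels.flatten(), minlength=num_classes)
print(counts.tolist())  # [8, 0, 4, 0]

# Indices of the classes that would produce a visible mask.
present = counts.nonzero().flatten()
print(present.tolist())  # [0, 2]
```

Mapping the indices in `present` back through `weights.meta["categories"]` gives the class names that the model predicted for the image.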

Removing the background mask is as simple as passing `masks=dog1_all_classes_masks[1:]`, because the background class is the class with index 0.
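A quick sanity check of the slicing on a toy 2x2 label map with three classes: after dropping index 0, the pixels not covered by any remaining mask are exactly the background pixels.

```python
import torch

num_classes = 3
labels = torch.tensor([[0, 1], [2, 0]])  # toy argmax output for one 2x2 image
masks = labels == torch.arange(num_classes)[:, None, None]  # (3, 2, 2)

# Background is class 0, so slicing from index 1 keeps only the foreground masks.
foreground_masks = masks[1:]
print(foreground_masks.shape)  # torch.Size([2, 2, 2])

# Pixels covered by no foreground mask are exactly the background pixels.
print((~foreground_masks.any(dim=0)).tolist())  # [[True, False], [False, True]]
```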

Let's now do the same but for an entire batch of images. The code is similar but involves a bit more juggling with the dimensions.
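The dimension juggling in the batch version can be checked on random data. The sketch below builds the masks the same way as the next cell, with a `(C, 1, 1, 1)` class column followed by `swapaxes(0, 1)`, and compares the result with an alternative formulation based on `torch.nn.functional.one_hot`. The one-hot variant is a swapped-in technique for comparison, not what the lab itself uses; it is imported as `nnf` to avoid clashing with the `F` alias already used for `torchvision.transforms.functional`.

```python
import torch
import torch.nn.functional as nnf

torch.manual_seed(0)
batch_size, num_classes, H, W = 2, 5, 3, 4
probs = torch.softmax(torch.randn(batch_size, num_classes, H, W), dim=1)

class_dim = 1
labels = probs.argmax(class_dim)                                  # (N, H, W)
masks = labels == torch.arange(num_classes)[:, None, None, None]  # (C, N, H, W)
masks = masks.swapaxes(0, 1)                                      # (N, C, H, W)
print(masks.shape)  # torch.Size([2, 5, 3, 4])

# Alternative: one-hot encode the labels, then move the class axis next to the batch axis.
one_hot_masks = nnf.one_hot(labels, num_classes).permute(0, 3, 1, 2).bool()
print(torch.equal(masks, one_hot_masks))  # True
```

Both versions give one `(num_classes, H, W)` stack of masks per image, which is the layout `draw_segmentation_masks` expects when iterating over the batch.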

In [None]:
#### **Cell 9: Multi-class Segmentation for the Batch**
# Visualize masks for all detected classes across the batch of images.
class_dim = 1
all_classes_masks = normalized_masks.argmax(class_dim) == torch.arange(num_classes)[:, None, None, None]
print(f"shape = {all_classes_masks.shape}, dtype = {all_classes_masks.dtype}")
# The first dimension now indexes classes: swap it with the batch dimension
# to get (batch_size, num_classes, H, W), i.e. one stack of masks per image.
all_classes_masks = all_classes_masks.swapaxes(0, 1)

dogs_with_masks = [
    draw_segmentation_masks(img, masks=mask, alpha=.7)
    for img, mask in zip(dog_list, all_classes_masks)
]
show(dogs_with_masks)


## Instructions
Modify the code to load and visualise the images `Mexico-City.jpg` and `Street.jpg` included in the images folder, and complete the following tasks:

### **Task 1. Visualise more classes (probability maps)**
Extend the code in cell 5 to correctly visualise three different classes from each image. Explore the different classes in `sem_class_to_idx` and select at least one that is not present in the images.

### **Task 2. Visualise binary segmentation masks**
Extend the code in cell 6 to correctly visualise the binary map belonging to the class with the highest probability on each image.


