# Importing the modules

In [None]:
import numpy as np
import torch
import requests

import matplotlib.pyplot as plt

# torchvision related imports.
import torchvision.transforms.functional as F
from torchvision.io import read_image
from torchvision.utils import draw_bounding_boxes
from torchvision.utils import make_grid

# models and transforms
from torchvision.transforms.functional import convert_image_dtype
from torchvision.models.segmentation import fcn_resnet50

# Defining the utilities

In [None]:
## Utility for displaying one image or several images side by side.
def img_show(images):
    """Display one or more image tensors side by side in a single row.

    Args:
        images: a single image tensor, or a list/tuple of image tensors.
            Each tensor is passed to ``torchvision.transforms.functional.to_pil_image``,
            so it must be something that function accepts (e.g. a ``[C, H, W]``
            image or a 2-D ``[H, W]`` map).
    """
    if not isinstance(images, (list, tuple)):
        # Wrap a lone tensor so the loop below handles both cases uniformly.
        images = [images]
    # squeeze=False keeps `axes` 2-D even for a single column, so axes[0, i] always works.
    _, axes = plt.subplots(ncols=len(images), squeeze=False)
    for i, image in enumerate(images):
        # Detach from the autograd graph (no gradient tracking) before converting.
        pil_image = F.to_pil_image(image.detach())
        axes[0, i].imshow(np.asarray(pil_image))
        # Hide ticks and tick labels; only the picture matters here.
        axes[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

## Getting the image to be segmented

In [None]:
url = "https://raw.githubusercontent.com/Apress/computer-vision-projects-with-pytorch/main/chapter4/semantic_example_highway.jpg"

# Download first, and only then save locally for read_image (which reads from a path).
# The timeout prevents hanging forever on a stalled connection.
response = requests.get(url, timeout=30)
# Fail loudly on HTTP errors instead of writing an error page into the .jpg file.
response.raise_for_status()
with open("semantic_example_highway.jpg", "wb") as f:
    f.write(response.content)

# torchvision.io.read_image returns a Tensor laid out as [C, H, W].
img1 = read_image("semantic_example_highway.jpg")
print(img1.shape)

In [None]:
# One box in (xmin, ymin, xmax, ymax) pixel coordinates, the format draw_bounding_boxes expects.
box_car = torch.tensor([[170, 70, 220, 120]], dtype=torch.float)
colors = ["blue"]
# Overlay the box on the original image to visually sanity-check the coordinates.
check_box = draw_bounding_boxes(img1, box_car, colors=colors, width=2)
img_show(check_box)

In [None]:
# Stack the image(s) along a new leading batch axis: [C, H, W] -> [N, C, H, W] (N == 1 here).
batch_imgs = torch.stack((img1,), dim=0)
# Convert the integer pixel values to floating point for the model.
batch_torch = convert_image_dtype(batch_imgs, torch.float)

# Loading the model and running inference.

In [None]:
# Load FCN-ResNet50 with pretrained weights.
# NOTE(review): newer torchvision releases deprecate `pretrained=True` in favour of
# `weights=...`; kept as-is here for compatibility with older torchvision versions.
model = fcn_resnet50(pretrained=True, progress=False)

# Switch to evaluation mode (inference behaviour for layers like dropout / batch-norm).
model = model.eval()

# Standard normalization using the mean/std from the pretrained model's training config.
normalized_batch_torch = F.normalize(batch_torch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

# Inference only: disable autograd so no computation graph is built (saves memory and time).
with torch.no_grad():
    # The segmentation model returns a dict; 'out' holds the per-class output map.
    result = model(normalized_batch_torch)['out']

## Converting the model output into per-class masks.

In [None]:
# PASCAL VOC label set used by torchvision's pretrained segmentation models.
# The position of each name in this list is its channel index in `result`, so the list
# must contain every class in order. Note 'bicycle' at index 2: omitting it shifts every
# later index by one (e.g. 'car' would then select the 'bus' channel).
classes = [
    '__background__', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
    'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
    'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
]

# Guard against the label list drifting out of sync with the model's class channels.
if len(classes) != result.shape[1]:
    raise ValueError(
        f"expected {result.shape[1]} class names for the model output, got {len(classes)}"
    )

class_to_idx = {cls: idx for (idx, cls) in enumerate(classes)}

# Softmax over dim 1 (the class channel indexed below) gives per-pixel class probabilities.
normalized_out_masks = torch.nn.functional.softmax(result, dim=1)

# One probability map per (image, class) pair, ordered image-major:
# (img0, car), (img0, pottedplant), (img0, bus), (img1, car), ...
# NOTE: despite its name, `car_mask` also holds the pottedplant and bus maps.
car_mask = [
    normalized_out_masks[img_idx, class_to_idx[cls]]
    for img_idx in range(batch_torch.shape[0])
    for cls in ('car', 'pottedplant', 'bus')
]

img_show(car_mask)