# Week 1: Introduction to Computer Vision

## Notebook 5: Instance Segmentation with a Pre-Trained Model using Torchvision

Welcome to the fifth notebook of this week's Applied AI Study Group! In this notebook we study the instance segmentation problem. Our aim is to identify each object in the given images and to tell the instances apart, even those that belong to the same category.

### 1. Instance Segmentation

Instance segmentation treats objects of the same class as distinct objects and therefore labels them separately, e.g. object 1, object 2, etc. This is what distinguishes it from [Semantic Segmentation](https://github.com/inzva/Applied-AI-Study-Group/blob/add-cv-week1/Applied%20AI%20Study%20Group%20%236%20-%20January%202022/Week%201/4-segmentation_uNet_PyTorch_final.ipynb): semantic segmentation assigns a class to every pixel, so all objects of the same class merge into one region, whereas instance segmentation separates every single object in the image. This makes it a more challenging problem, since each object must be identified individually instead of grouping several objects into one entity.
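To make the difference concrete, here is a minimal NumPy sketch on a hypothetical 6x6 scene containing two cats. The scene, the variable names and the mask layout are illustrative assumptions; the class id 8 for "cat" matches the Pascal VOC label list used later in this notebook. The semantic map collapses both cats into one class region, while the instance representation keeps one mask per object.

```python
import numpy as np

# Hypothetical 6x6 scene: two cats (class id 8, "cat" in the Pascal VOC
# list used later in this notebook) on background (class id 0).
H, W = 6, 6
cat_a = np.zeros((H, W), dtype=bool)
cat_a[1:3, 0:2] = True          # first cat: top-left
cat_b = np.zeros((H, W), dtype=bool)
cat_b[3:5, 3:6] = True          # second cat: bottom-right

# Semantic segmentation: a single (H, W) map of class ids.
# Both cats receive id 8, so the two objects can no longer be told apart.
semantic = np.zeros((H, W), dtype=np.int64)
semantic[cat_a | cat_b] = 8

# Instance segmentation: one binary mask per object plus a class label per mask.
instance_masks = np.stack([cat_a, cat_b])   # shape (N, H, W) = (2, 6, 6)
instance_labels = np.array([8, 8])          # same class, two different instances

# Equivalent "instance id" map: 0 = background, 1 = object 1, 2 = object 2, ...
instance_ids = np.zeros((H, W), dtype=np.int64)
for i, mask in enumerate(instance_masks, start=1):
    instance_ids[mask] = i

print(np.unique(semantic))      # [0 8]
print(np.unique(instance_ids))  # [0 1 2]
```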

### 2. `Some model...` Convolutional Neural Network

TODO: will add model explanation.

### 3. Torchvision

The Torchvision package is built on top of PyTorch. It provides popular datasets, model architectures (including pretrained models), and image transformations commonly used in computer vision.

To install Torchvision:

    !pip install torchvision

### 4. Imports and Checks

You should already have installed NumPy and Matplotlib using `pip`, and PyTorch by following [Week 0 - Notebook 2](https://github.com/inzva/Applied-AI-Study-Group/blob/add-frameworks-week/Applied%20AI%20Study%20Group%20%236%20-%20January%202022/Week%200/2-mnist_classification_convnet_pytorch.ipynb).


The following two cells download the model we will use and import the libraries and packages required to run this notebook.

TODO: will add model name and 1 sentence explanation.

In [None]:
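# TODO: will update to instance segmentation
# In this example we use a pretrained FCN-ResNet101 model to perform *semantic* segmentation.
# Here is the original code: https://www.learnopencv.com/pytorch-for-beginners-semantic-segmentation-using-torchvision/
from torchvision import models

# .eval() switches the model to inference mode (e.g. batch norm uses its stored statistics).
# Note: torchvision >= 0.13 deprecates pretrained=True in favour of
# weights=models.segmentation.FCN_ResNet101_Weights.DEFAULT.
fcn = models.segmentation.fcn_resnet101(pretrained=True).eval()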

In [None]:
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as T

import numpy as np

In [None]:
# We define a label color for each type of object.
# For example, aeroplane pixels are colored (dark) red
# and bicycle pixels are painted green.
def decode_segmap(image, nc=21):
    """Convert an (H, W) array of class ids into an (H, W, 3) RGB image."""
    # Row l holds the RGB color of class l (the 21 Pascal VOC classes).
    label_colors = np.array([(0, 0, 0),  # 0=background
                             # 1=aeroplane, 2=bicycle, 3=bird, 4=boat, 5=bottle
                             (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128),
                             # 6=bus, 7=car, 8=cat, 9=chair, 10=cow
                             (0, 128, 128), (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0),
                             # 11=dining table, 12=dog, 13=horse, 14=motorbike, 15=person
                             (192, 128, 0), (64, 0, 128), (192, 0, 128), (64, 128, 128), (192, 128, 128),
                             # 16=potted plant, 17=sheep, 18=sofa, 19=train, 20=tv/monitor
                             (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128)])

    # One array per color channel, with the same height and width as the label map.
    r = np.zeros_like(image).astype(np.uint8)
    g = np.zeros_like(image).astype(np.uint8)
    b = np.zeros_like(image).astype(np.uint8)

    # Paint every pixel of class l with that class's color.
    for l in range(0, nc):
        idx = image == l
        r[idx] = label_colors[l, 0]
        g[idx] = label_colors[l, 1]
        b[idx] = label_colors[l, 2]

    # Stack the three channels into an (H, W, 3) image.
    rgb = np.stack([r, g, b], axis=2)
    return rgb
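The loop in `decode_segmap` can also be written as a single lookup: if the palette is a `(21, 3)` array, NumPy fancy indexing `palette[label_map]` returns one palette row per pixel. The sketch below is an alternative formulation under that assumption, with a hypothetical name `decode_segmap_vectorized` and the same Pascal VOC colors. One behavioral difference: class ids of 21 or higher raise an `IndexError` here, whereas the loop version leaves such pixels black.

```python
import numpy as np

# Same 21-entry Pascal VOC palette as decode_segmap (row i = RGB color of class i).
VOC_PALETTE = np.array([
    (0, 0, 0),
    (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128),
    (0, 128, 128), (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0),
    (192, 128, 0), (64, 0, 128), (192, 0, 128), (64, 128, 128), (192, 128, 128),
    (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128),
], dtype=np.uint8)


def decode_segmap_vectorized(label_map):
    """Hypothetical alternative: map an (H, W) array of class ids to an (H, W, 3) RGB image."""
    # Fancy indexing looks up one palette row per pixel in a single step.
    return VOC_PALETTE[label_map]


labels = np.array([[0, 1], [15, 20]])
rgb = decode_segmap_vectorized(labels)
print(rgb.shape)   # (2, 2, 3)
print(rgb[1, 0])   # [192 128 128]  (class 15 = person)
```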

In [None]:
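def segment(net, path):
    # convert('RGB') guarantees 3 channels (e.g. for PNGs with alpha or grayscale images),
    # which Normalize below expects.
    img = Image.open(path).convert('RGB')
    plt.imshow(img); plt.axis('off'); plt.show()
    # PyTorch transforms apply preprocessing operations to the data step by step.
    # Resize(256) scales the shorter side to 256 px and CenterCrop(224) keeps the
    # central 224x224 patch; Normalize uses the ImageNet mean and std.
    trf = T.Compose([T.Resize(256),
                     T.CenterCrop(224),
                     T.ToTensor(),
                     T.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])])

    # The model expects a batch of images rather than a single image, so we add
    # a batch dimension: (3, H, W) -> (1, 3, H, W).
    inp = trf(img).unsqueeze(0)
    # Run a forward pass without tracking gradients and take the 'out' tensor,
    # shaped (1, 21, H, W): one score per class for every pixel.
    with torch.no_grad():
        out = net(inp)['out']
    # Pick the class with the highest score for each pixel -> (H, W) label map.
    om = torch.argmax(out.squeeze(), dim=0).detach().cpu().numpy()
    # Pass the label map to the coloring function.
    rgb = decode_segmap(om)
    # Draw the colored segmentation.
    plt.imshow(rgb); plt.axis('off'); plt.show()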

In [None]:
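# Optionally download the sample image (-o sets the output file, --create-dirs creates ./datasets):
#!curl --create-dirs -o ./datasets/segmentation.jpeg https://images.pexels.com/photos/2385051/pexels-photo-2385051.jpeg
#!ls
# Note the resize: the segmentation is shown at 224x224 because of Resize(256) + CenterCrop(224).
segment(fcn, './datasets/segmentation.jpeg')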