# Project: Dense Prediction: Monocular Depth Estimation and Semantic Segmentation

### Total points in the project = 120 = 65 (Part 1) + 35 (Part 2) + 20 (Competition)

<img src='https://i.imgur.com/I2rSgxd.png' width=200> <img src='https://i.imgur.com/1oP2EIg.png' width=200>

## Read Carefully!

<hr/>

**Note**: If you are using Google Colab, make sure that your Colab notebook is a GPU instance. Also, the first time you run the training, the instance might crash for exceeding the allocated memory. This is expected behaviour, especially with large batch sizes. Colab will then suggest restarting the session and offer an instance with more memory.

**Note**: This project is more open-ended. Multiple solutions can be considered _correct_. As implementations of various deep networks for this task already exist on the internet, **plagiarism will NOT be tolerated**. Your code will be checked for similarity against code available online and against other students' code. You are expected to justify every design decision when your project is evaluated. Any plagiarism detected will lead to a grade of ZERO WITHOUT ANY WARNINGS. The notebook will explicitly tell you when to look up information on the internet.

**Note**: The networks you will design/implement may be much larger than those you have designed before. Please bring hardware concerns to the attention of the TAs on Slack as required. Begin early: you will need time to test new ideas and hyperparameters, and training will take much longer. Best of luck!

<hr/>

# Competition [20 points]
This project contains two sub-parts (depth estimation and semantic segmentation), and you will develop a model for each. The performance of your models on both sub-tasks (on the given validation sets) will be compared against that of the other students in the class and ranked on the leaderboard. 20 points are reserved for the competition and will be awarded based on your position on the leaderboard. Your notebook will be re-run, and your pretrained models may also be selected for further testing and verification. Any form of cheating WILL be caught and will result in a TOTAL grade of ZERO without any warnings.

#### **Top 3 winners** will be given food coupons by Prof. Peter Wonka :\)

Good luck!

# Part 1 : Monocular Depth Estimation

## Introduction

- In this part of the project, you are tasked with creating a model that **estimates depth from a single input image**. The input is an RGB image and the output is a single-channel dense depth map in which each pixel holds the estimated distance from the 'camera sensor' to an object in the scene in real-world units (e.g. meters). Depth from a single image is a fundamental vision task with many useful applications, including scene understanding and reconstruction.

- You are to develop a convolutional neural network (CNN) that formulates the problem as a regression of the depth map from a single RGB image. 

- In this section, we provide all the source code needed for loading and evaluating your model. You will reuse the model in the next part.

- Your task in this section is to modify the code in order to:
    - Define a [UNet](https://arxiv.org/abs/1505.04597) model that takes an RGB image and outputs a single channel depth map. **[25 points]**
    - Define an appropriate loss function. **[15 points]**
    - Tune the model to achieve an RMSE of **0.035** or less on the given validation set. **[25 points]**


## Setup

In [1]:
!pip install tqdm datasets timm


## Downloading data

In [3]:
import datasets
ds = datasets.load_dataset("shariqfarooq/cs323_densepred_depth")  # DO NOT change this

In [4]:
ds

DatasetDict({
    train: Dataset({
        features: ['image', 'depth'],
        num_rows: 25356
    })
    test: Dataset({
        features: ['image', 'depth'],
        num_rows: 518
    })
})

In [5]:
from torchvision import transforms as T
from torch.utils.data import DataLoader


IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

image_transforms = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
])

depth_transforms = T.Compose([
    T.ToTensor()
])

def transform(batch):
  batch['image'] = ([image_transforms(im) for im in batch['image']])
  batch['depth'] = [depth_transforms(d)[:1] for d in batch['depth']]
  return batch

ds.set_transform(transform)


In [6]:
ds['train'][0]['image'].shape, ds['train'][0]['depth'].shape

(torch.Size([3, 256, 256]), torch.Size([1, 128, 128]))

## Sanity check 

In [None]:
# Examine training data
import numpy as np
import matplotlib.pyplot as plt
import torchvision
import torch
from torchvision.utils import make_grid

def denorm_imagenet(x):
    mean, std = torch.tensor(IMAGENET_MEAN).to(x.device), torch.tensor(IMAGENET_STD).to(x.device)
    if x.ndim == 3:
        mean, std = mean[:, None, None], std[:, None, None]
    elif x.ndim == 4:
        mean, std = mean[None, :, None, None], std[None, :,  None, None]

    return x * std + mean


def show_example_data(dataset, split='train', num=5):
    im_stacked = []
    depth_stacked = []
    for i in range(num):
        sample = dataset[split][i]
        im_stacked.append(denorm_imagenet(sample['image']))
        depth_stacked.append(sample['depth'])
    
    im_stacked = make_grid(torch.stack(im_stacked), nrow=num)
    depth_stacked = make_grid(torch.stack(depth_stacked), nrow=num, normalize=True, scale_each=True)

    fig, ax = plt.subplots(2, 1)
    ax[0].imshow(im_stacked.permute(1, 2, 0))
    ax[1].imshow(depth_stacked[0])
    for a in ax:
        a.axis('off')
    
    plt.tight_layout()
    plt.show()

show_example_data(ds)

## Getting started with timm

[`timm`](https://timm.fast.ai/) is a popular deep-learning library created by [Ross Wightman](https://twitter.com/wightmanr). It is a collection of SOTA computer vision models, layers, utilities, optimizers, schedulers, data loaders and augmentations, along with training/validation scripts that can reproduce ImageNet training results.

We will use `timm` to load encoder backbones pretrained on ImageNet, so we don't have to define encoder architectures (as you did in previous projects) or train from scratch.



In [None]:
# timm example for loading any encoder architecture
import timm
import torch

model = timm.create_model('resnet34') # load resnet34 architecture
x     = torch.randn(1, 3, 224, 224)  # create random input tensor
model(x).shape

In [None]:
# You can also load a pretrained model
pretrained_resnet_34 = timm.create_model('resnet34', pretrained=True)

In [None]:
# List all available models
avail_pretrained_models = timm.list_models(pretrained=True)

print(f"Number of available pretrained models: {len(avail_pretrained_models)}")
print(f"First 5 models: {avail_pretrained_models[:5]}")

As you can see, there is a vast variety of pretrained models available (750+)! In this project, you are free to choose any one of them.

In [None]:
# Searching through 750+ models can be a pain. timm supports glob (wildcard) filtering by model name
# Let's list all the models that have 'efficientnet' in their name
all_efficientnet_models = timm.list_models('*efficientnet*', pretrained=True)
all_efficientnet_models

### Feature extraction with timm

In [None]:
BACKBONE = "resnet34"  # TODO: Change this to your desired backbone

In [None]:
# Tell timm that we only need features from various layers (and not the final ImageNet class logits)
encoder = timm.create_model(BACKBONE, pretrained=True, features_only=True)

In [None]:
print("Channel info:", encoder.feature_info.channels())
x = torch.rand(1, 3, 224, 224)
out = encoder(x)
assert type(out) is list
[o.shape for o in out]

#### Note that now we are able to get multi-scale features from our encoder. You will use this to build your Unet
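As one possible illustration of how these multi-scale features can feed a decoder, here is a minimal sketch of a single UNet-style decoder stage: upsample the deeper feature map, concatenate it with the matching skip feature, and refine the result with two 3x3 convolutions. `DecoderBlock` and its arguments are hypothetical names, and the channel counts below are placeholders rather than the values of any particular backbone; read them from `encoder.feature_info.channels()` for your own choice.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DecoderBlock(nn.Module):
    """One UNet decoder stage (hypothetical helper): upsample, concatenate the skip, refine."""

    def __init__(self, in_ch: int, skip_ch: int, out_ch: int):
        super().__init__()
        self.refine = nn.Sequential(
            nn.Conv2d(in_ch + skip_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        # Resize to the skip's exact spatial size (robust to odd feature-map sizes)
        x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
        return self.refine(torch.cat([x, skip], dim=1))


# Usage with random tensors shaped like two adjacent encoder levels
deep = torch.randn(2, 512, 8, 8)     # e.g. stride-32 features of a 256x256 input
skip = torch.randn(2, 256, 16, 16)   # e.g. stride-16 features of the same input
block = DecoderBlock(in_ch=512, skip_ch=256, out_ch=256)
out = block(deep, skip)
print(out.shape)  # torch.Size([2, 256, 16, 16])
```

Stacking one such block per encoder level and ending with a 1x1 convolution to `num_classes` channels gives the overall UNet shape. Note that the depth targets in this dataset are 128x128 while the inputs are 256x256, so how the decoder reaches the output resolution is one of the design decisions to justify.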

## Model [25 points]

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm import create_model


class Unet(nn.Module):
    """
    TODO: Complete the docstring of the Unet class (description, parameters, returns, etc.)
    """

    def __init__(
            self,
            backbone='resnet50',
            num_classes=1,
            final_activation=nn.Identity(), # TODO: Change this to your desired final activation
            # TODO: Add any other relevant args you need
    ):
        super().__init__()
        # TODO: Complete the definition of the Unet class
        # Use timm to load the encoder backbone with imagenet pretrained weights
        # encoder **must** return a list of feature maps that will be used by the decoder
        self.encoder = None
        self.decoder = None
        self.final_activation = final_activation

    def forward(self, x: torch.Tensor):
        # TODO: Complete the forward function
        x = None
        return x

In [None]:
model = Unet(BACKBONE, num_classes=1)

In [None]:
# Always perform a sanity check on the models you define
x = torch.randn(1, 3, 256, 256)
out = model(x)
print("Input shape", x.shape)
print("Output shape", out.shape)

In [None]:
# Move to GPUs. Using all GPUs available by default. You can change this.
model = nn.DataParallel(model)
model = model.cuda()

## Loss Function [15 points]

Define a loss function that is suitable for depth estimation. Look up recent papers, for example on the [PapersWithCode leaderboards](https://paperswithcode.com/sota/monocular-depth-estimation-on-nyu-depth-v2).
Why will the current loss not work? Submit your answer in the notebook.
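One loss family that appears often in the monocular depth literature is the scale-invariant log loss of Eigen et al. (2014). The sketch below is one hedged reading of it, not the required answer: with $d_i = \log \hat{y}_i - \log y_i$ over the $n$ valid pixels, it computes $\frac{1}{n}\sum_i d_i^2 - \frac{\lambda}{n^2}\left(\sum_i d_i\right)^2$. The function name, the `lam=0.5` default, the `eps` clamp and the rule that pixels with `y <= 0` are invalid are all assumptions. Because the depth maps here are 128x128 while the inputs are 256x256, the sketch also resizes the prediction to the target's size.

```python
import torch
import torch.nn.functional as F


def silog_loss(pred: torch.Tensor, target: torch.Tensor,
               lam: float = 0.5, eps: float = 1e-6) -> torch.Tensor:
    """Scale-invariant log loss (after Eigen et al., 2014); a sketch, not a reference implementation.

    pred, target: (B, 1, H, W) depth maps; pred may have a different H, W than target.
    """
    if pred.shape[-2:] != target.shape[-2:]:
        # Bring the prediction to the ground-truth resolution
        pred = F.interpolate(pred, size=target.shape[-2:], mode="bilinear", align_corners=False)
    valid = target > 0                      # assumption: non-positive depth means "missing"
    n = int(valid.sum())
    if n == 0:
        return pred.sum() * 0.0             # keeps the graph alive, contributes nothing
    d = torch.log(pred[valid].clamp_min(eps)) - torch.log(target[valid].clamp_min(eps))
    # mean squared log error minus a fraction of the squared mean log error
    return (d ** 2).sum() / n - lam * d.sum() ** 2 / n ** 2


# Usage: a 256x256 prediction against a 128x128 target
pred = torch.rand(4, 1, 256, 256, requires_grad=True)
target = torch.rand(4, 1, 128, 128)
loss = silog_loss(pred + 0.1, target + 0.1)
loss.backward()
```

With `lam=1` a prediction that is off by a constant factor costs nothing, while `lam=0` reduces to a plain squared log error; values in between trade the two off. The log also means the model's final activation should keep predictions positive (a softplus or a scaled sigmoid are options), since the `clamp_min` gives zero gradient below `eps`.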

In [None]:
import torch
import torch.nn.functional as F


def loss_fn(pred_y, y):
    return torch.mean(y.sub(pred_y))

## Training + Evaluation [25 points]

Tune the hyperparameters and the architecture to achieve the target RMSE.
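A common tuning lever when the encoder is ImageNet-pretrained is to give it a smaller learning rate than the freshly initialized decoder, using optimizer parameter groups. The sketch below assumes the model exposes an `encoder` submodule (matching the attribute name in the `Unet` template, though your own model may differ); the 10x ratio and the base rate are placeholders, not tuned values. `make_param_groups` and `Toy` are hypothetical names.

```python
import torch
import torch.nn as nn


def make_param_groups(model: nn.Module, base_lr: float, encoder_lr_scale: float = 0.1):
    """Split parameters into a pretrained-encoder group and an everything-else group."""
    # nn.DataParallel keeps the wrapped model under .module
    core = model.module if isinstance(model, nn.DataParallel) else model
    encoder_ids = {id(p) for p in core.encoder.parameters()}
    encoder_params = [p for p in core.parameters() if id(p) in encoder_ids]
    other_params = [p for p in core.parameters() if id(p) not in encoder_ids]
    return [
        {"params": encoder_params, "lr": base_lr * encoder_lr_scale},
        {"params": other_params, "lr": base_lr},
    ]


# Toy stand-in with the same attribute names as the Unet template
class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Conv2d(3, 8, 3)
        self.decoder = nn.Conv2d(8, 1, 1)


toy = Toy()
toy_optimizer = torch.optim.AdamW(make_param_groups(toy, base_lr=1e-3), weight_decay=1e-4)
print([g["lr"] for g in toy_optimizer.param_groups])
```

One caveat: `OneCycleLR` overwrites each group's learning rate from `max_lr`, so to keep the ratio you would pass a list, e.g. `max_lr=[g["lr"] for g in groups]`, instead of a single float.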

In [None]:
### Hyperparameters
# TODO: Change these to your desired hyperparameters

epochs = 1
batch_size = 32
learning_rate = 10

workers = 4 # The number of parallel processes used to read data. Increase this if you have more cores.
train_loader = DataLoader(ds['train'], batch_size=batch_size, shuffle=True, num_workers=workers)
test_loader = DataLoader(ds['test'], batch_size=batch_size, shuffle=False, num_workers=workers)

In [None]:
import warnings
warnings.filterwarnings("ignore")

In [None]:
import numpy as np
from tqdm.auto import tqdm
import gc
import os

run_id = f'model_n{epochs}_bs{batch_size}_lr{learning_rate}'; print('\n\nTraining', run_id)
save_path = run_id + '.pkl'

# TODO: Experiment with different optimizers and learning rate schedulers
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=learning_rate, steps_per_epoch=len(train_loader), epochs=epochs)


class RMSE(object):
    def __init__(self):
        self.sq_errors = []
        self.num_pix = 0
        
    def get(self):
        return np.sqrt(
                    np.sum(np.array(self.sq_errors))/self.num_pix
                )
    
    def add_batch(self, pred, target):
        sqe = (pred-target)**2
        self.sq_errors.append(np.sum(sqe))
        self.num_pix += target.size
        
    def reset(self):
        self.sq_errors = []
        self.num_pix = 0



ITER_PER_EPOCH = len(train_loader)
TOTAL_STEPS = ITER_PER_EPOCH * epochs


metrics = RMSE()

@torch.no_grad()
def validate(model, valid_loader):
    model.eval()
    metrics.reset()
    for i, (sample) in tqdm(enumerate(valid_loader), total=len(valid_loader), desc='Validating'):
        x, y = sample['image'].float().cuda(), sample['depth'].numpy()
        y_pred = model(x).detach().cpu().numpy()
        metrics.add_batch(y_pred, y)
    print('\nValidation RMSE {avg_rmse}'.format(avg_rmse=metrics.get()))


# One validation before we start training (good practice to catch errors early)
validate(model, test_loader)
pbar = tqdm(total=TOTAL_STEPS, desc='Training')
for epoch in range(epochs):
    model.train()
    N = len(train_loader)

    for i, (sample) in enumerate(train_loader):

        # Load a batch and send it to GPU
        x = sample['image'].float().cuda()
        y = sample['depth'].float().cuda()

        # Forward pass: compute predicted y by passing x to the model.
        y_pred = model(x)

        # Compute and print loss.
        loss = loss_fn(y_pred, y)

        # Before the backward pass, use the optimizer object to zero all of the
        # gradients for the variables it will update (which are the learnable
        # weights of the model).
        optimizer.zero_grad()

        # Backward pass: compute gradient of the loss with respect to model parameters
        loss.backward()

        # Calling the step function on an Optimizer makes an update to its parameters
        optimizer.step()
        # learning rate scheduler step.
        # TODO: location of this call might change depending on your choice of scheduler. 
        scheduler.step()

        pbar.update(1)
        # Report progress. Add any extra logging info here
        pbar.set_postfix({'epoch': f'{epoch+1}/{epochs}', 'loss': loss.item(), 'epoch%': "{0:.1f}%".format(100*(i+1)/N)})
            
        #break # useful for quick debugging        
    torch.cuda.empty_cache(); del x, y; gc.collect()
    
    # Validation after each epoch
    validate(model, test_loader)
    

# Save model
torch.save(model.state_dict(), save_path)
print('\nTraining done. Model saved ({}).'.format(save_path))

## Visual Test of the Trained Model (no tasks required)

In [None]:
# Load model from disk (optional). The state dict was saved from the nn.DataParallel
# wrapper, so wrap the model the same way before loading it.
#model = nn.DataParallel(Unet(BACKBONE, num_classes=1)).cuda()
#model.load_state_dict(torch.load(save_path))


from torchvision.utils import make_grid
import matplotlib.pyplot as plt
# Visualize a validation batch
model.eval()  # set to evaluation mode
sample = next(iter(test_loader))
x = sample['image'].float().cuda()
with torch.no_grad():
    y_pred, y = model(x), sample['depth']

# plt.figure(figsize=(20,20))
plt.imshow(make_grid(sample['image'], padding=0, normalize=True).permute((1, 2, 0)))
plt.axis('off')
plt.title("Input")
plt.show()

# plt.figure(figsize=(20,20))
plt.imshow(make_grid(sample['depth'], padding=0, normalize=True, scale_each=True).permute((1, 2, 0))[:,:,0])
plt.axis('off')
plt.title("Ground Truth")
plt.show()

# plt.figure(figsize=(20,20))
plt.imshow(make_grid(y_pred.detach().cpu(), padding=0, normalize=True, scale_each=True).permute((1, 2, 0))[:,:,0])
plt.axis('off')
plt.title("Predicted")
plt.show()

####  At this point, you can restart your notebook for part 2

# Part 2 : Semantic Segmentation

In this part of the project, you will reuse the model you created in the previous part to perform semantic segmentation: instead of assigning a real number to each pixel, you will assign it a class.

The tasks are as follows:
- Implement data preparation: encoding and decoding of segmentation maps **[5 points]**
- Modify the UNet model so that it takes an RGB image and outputs a _label map_ of _N_ classes **[15 points]**
- Define an appropriate loss function. **[3 points]**
- Tune the model to achieve an mIOU of **0.45** or higher on the given validation set. **[10 points]**
- Visualization **[2 points]**

## Data Preparation [5 points]

We are going to use the [PASCAL VOC dataset](http://host.robots.ox.ac.uk/pascal/VOC/), a commonly used benchmark. To reduce the computational requirements, we will use [a variant](https://huggingface.co/datasets/shariqfarooq/cs323_densepred_seg256) that has a uniform and slightly lower resolution (256x256) than the official release.

In [None]:
# Let's load the dataset first
from datasets import load_dataset

ds_voc = load_dataset("shariqfarooq/cs323_densepred_seg256")  # DO NOT change this

In [None]:
ds_voc

In [None]:
# Examine training data
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

def make_pil_grid(images, nrow=8):
    grid = Image.new('RGB', (images[0].width*nrow, images[0].height*((len(images)-1)//nrow+1)))
    for i, im in enumerate(images):
        grid.paste(im, (i%nrow*images[0].width, i//nrow*images[0].height))
    return grid
    


def show_example_data(dataset, split='train', num=7, nrow=7):
    ims = [dataset[split][i]['image'] for i in range(num)]
    masks = [dataset[split][i]['mask'] for i in range(num)]
    im_grid = make_pil_grid(ims, nrow)
    mask_grid = make_pil_grid(masks, nrow)

    grid = make_pil_grid([im_grid, mask_grid], 1)
    plt.figure(figsize=(20,20))
    plt.imshow(grid)
    plt.axis('off')
    plt.show()

show_example_data(ds_voc)

Now you have to implement the encoding and decoding of the data.

Make sure that the labels are in the range `0..N-1`, where N is the number of classes (21 in our case). You may use one special label for unknown regions.

We provide the map from RGB to label for convenience in `get_pascal_color_palette()`. The map should be read as follows: if a pixel has color `[0, 0, 0]`, it has label 0; if the color is `[128, 0, 0]`, the label is 1, and so on.

You need to use the palette information to implement the `encode_segmap` and `decode_segmap` functions.
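One vectorized way to do the palette lookup is sketched below, under the assumption that masks arrive as `PIL.Image` objects or `(M, N, 3)` uint8 arrays. Each RGB triple is packed into a single integer, so comparing a pixel against a palette entry is one integer comparison; any colour not in the palette (for example the light boundary colour VOC uses for "void" pixels) becomes `unk_label`. The helper names `encode_segmap_sketch` and `decode_segmap_sketch` are hypothetical, and the palette is passed in explicitly so the block runs on its own; in the notebook it would be `get_pascal_color_palette()`.

```python
import numpy as np


def _pack_rgb(rgb: np.ndarray) -> np.ndarray:
    """Pack (..., 3) colour triples into single integers: R * 65536 + G * 256 + B."""
    rgb = rgb.astype(np.int32)
    return (rgb[..., 0] << 16) | (rgb[..., 1] << 8) | rgb[..., 2]


def encode_segmap_sketch(mask, palette, unk_label=255):
    """RGB mask (M, N, 3) -> int64 class map (M, N); colours not in the palette get unk_label."""
    if hasattr(mask, "convert"):          # PIL image: make sure we get three RGB channels
        mask = mask.convert("RGB")
    keys = _pack_rgb(np.asarray(mask)[..., :3])
    labels = np.full(keys.shape, unk_label, dtype=np.int64)
    for class_idx, key in enumerate(_pack_rgb(np.asarray(palette))):
        labels[keys == key] = class_idx
    return labels


def decode_segmap_sketch(labels, palette, unk_label=255):
    """Class map (..., M, N) -> uint8 RGB (..., M, N, 3); unk_label is drawn as class 0."""
    labels = np.asarray(labels).astype(np.int64)   # astype copies, so the input is untouched
    labels[labels == unk_label] = 0                  # same convention as the template's decode_segmap
    return np.asarray(palette, dtype=np.uint8)[labels]


# Toy check with only the first three VOC colours
palette = np.array([[0, 0, 0], [128, 0, 0], [0, 128, 0]])
rgb = np.array([[[0, 0, 0], [128, 0, 0]],
                [[0, 128, 0], [224, 224, 192]]], dtype=np.uint8)
labels = encode_segmap_sketch(rgb, palette)
print(labels.tolist())  # [[0, 1], [2, 255]]
```

Returning `int64` means that `torch.tensor(...)` in `transform_voc` already yields a `torch.long` tensor, which is what classification losses expect.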

In [1]:

def get_pascal_color_palette():
    """Load the mapping that associates pascal classes with label colors
    Returns:
        np.ndarray with dimensions (21, 3)
    """
    return np.asarray([[0, 0, 0],
                       [128, 0, 0],
                       [0, 128, 0],
                       [128, 128, 0],
                       [0, 0, 128],
                       [128, 0, 128],
                       [0, 128, 128],
                       [128, 128, 128],
                       [64, 0, 0],
                       [192, 0, 0],
                       [64, 128, 0],
                       [192, 128, 0],
                       [64, 0, 128],
                       [192, 0, 128],
                       [64, 128, 128],
                       [192, 128, 128],
                       [0, 64, 0],
                       [128, 64, 0],
                       [0, 192, 0],
                       [128, 192, 0],
                       [0, 64, 128]])

def get_pascal_class_names():
    return ['Background',
            'Aeroplane',
            'Bicycle',
            'Bird',
            'Boat',
            'Bottle',
            'Bus',
            'Car',
            'Cat',
            'Chair',
            'Cow',
            'Diningtable',
            'Dog',
            'Horse',
            'Motorbike',
            'Person',
            'Pottedplant',
            'Sheep',
            'Sofa',
            'Train',
            'Tvmonitor']


def encode_segmap(mask, unk_label=255):
    """Encode segmentation label images as pascal classes
    Args:
        mask (np.ndarray or PIL.Image.Image): raw segmentation label image of dimension
          (M, N, 3), in which the Pascal classes are encoded as colours.
        unk_label (int): label assigned to pixels whose colour is not in the palette.
    Returns:
        (np.ndarray): class map with dimensions (M,N), where the value at
        a given location is the integer denoting the class index.
    """
    # TODO: Complete this function
    raise NotImplementedError("TODO: Implement this function")

def decode_segmap(mask, unk_label=255):
    """Decode segmentation label prediction as RGB images
    Args:
        mask (np.ndarray): class map with dimensions (B, M, N), where the value at
        a given location is the integer denoting the class index.
        unk_label (int): label that is drawn with the background (class 0) colour.
    Returns:
        (np.ndarray): colored images of shape (B, M, N, 3)
    """
    mask = mask.astype(int)
    mask[mask == unk_label] = 0
    # TODO: Complete this function
    raise NotImplementedError("TODO: Implement this function")

In [None]:
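# TODO: Optionally add data augmentations to increase performance.
# Note that any augmentation must act jointly on the image and its label mask. Look up `albumentations`.
import torch
from torchvision import transforms as T  # re-import in case the notebook was restarted after Part 1

im_transforms = T.Compose([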
    T.ToTensor(),
    T.Normalize(mean=[0.4850, 0.4560, 0.4060], std=[0.2290, 0.2240, 0.2250])
])


def transform_voc(batch):
    batch['image'] = [im_transforms(i) for i in batch['image']]
    batch['mask'] = [torch.tensor(encode_segmap(m)) for m in batch['mask']]
    return batch

ds_voc.set_transform(transform_voc)

You should implement a few more sanity checks: the range of values in the RGB images, the range of values in the labels, whether the dataset returns tensors, whether the labels have the datatype `torch.long`, etc.
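A minimal sketch of such checks, written against one sample dict as produced by `transform_voc`. The function name `check_voc_sample` is hypothetical, and the thresholds are assumptions: ImageNet-normalized pixels that started in [0, 1] stay within roughly [-2.12, 2.64], and mask labels should lie in `0..20` or equal the unknown label 255.

```python
import torch


def check_voc_sample(sample: dict, num_classes: int = 21, unk_label: int = 255,
                     size: int = 256) -> None:
    """Raise AssertionError if a transformed VOC sample looks wrong (a sketch)."""
    image, mask = sample["image"], sample["mask"]
    assert isinstance(image, torch.Tensor) and isinstance(mask, torch.Tensor), "expected tensors"
    assert image.shape == (3, size, size), f"unexpected image shape {tuple(image.shape)}"
    assert mask.shape == (size, size), f"unexpected mask shape {tuple(mask.shape)}"
    assert image.dtype == torch.float32, f"image dtype is {image.dtype}"
    assert mask.dtype == torch.long, f"mask dtype is {mask.dtype}, classification losses need long"
    # ImageNet-normalized pixels from [0, 1] inputs stay within about [-2.12, 2.64]
    assert -2.2 <= image.min().item() and image.max().item() <= 2.7, "image not normalized as expected"
    labels = torch.unique(mask)
    assert labels.min().item() >= 0, "negative label"
    valid = (labels < num_classes) | (labels == unk_label)
    assert bool(valid.all()), f"unexpected labels {labels[~valid].tolist()}"


# In the notebook: check_voc_sample(ds_voc['train'][0])
```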

In [None]:
# Perform sanity tests as required
# TODO: Add any sanity tests

## Modifying Architecture and the loss [18 points]
You will have to perform some form of surgery on the network you constructed in Part 1.
Make sure you initialize the weights from the depth model you trained above, and then modify the model so that a new _Segmentation Head_ is attached as its final block.

1. The number of channels predicted by the last layer must change to the number of classes in the dataset. The final activation may also need to change.
2. The loss function must change to reflect the fact that we are now performing per-pixel classification. (What loss did you use for classification in Project 1?)
3. You might get a CUDA device-side assert error. This means that some label is greater than or equal to the number of channels in the _logits_. This is very common in semantic segmentation, where you may want to mark some regions as unknown because their label is uncertain, for example near object boundaries. Look up how to ignore a certain label in a classification loss.
4. Take care of input, label and logit sizes. We want the predictions to be 256x256 as well, so you may need an upsampling layer in the _Segmentation Head_.
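For points 2 to 4, a hedged sketch: `torch.nn.CrossEntropyLoss` takes raw logits of shape `(B, C, H, W)` and `long` targets of shape `(B, H, W)`, and its `ignore_index` argument excludes one label value from the loss, which avoids the device-side assert when the unknown label (255 here, matching the default `unk_label` of `encode_segmap`) is larger than the number of classes. The wrapper `seg_loss` is a hypothetical name; it upsamples the logits to the mask size before computing the loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

UNK_LABEL = 255  # must match the unk_label used when encoding masks
criterion = nn.CrossEntropyLoss(ignore_index=UNK_LABEL)


def seg_loss(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """logits: (B, C, h, w) raw scores; target: (B, H, W) long labels in 0..C-1 or UNK_LABEL."""
    if logits.shape[-2:] != target.shape[-2:]:
        logits = F.interpolate(logits, size=target.shape[-2:], mode="bilinear", align_corners=False)
    # No softmax here: CrossEntropyLoss applies log-softmax internally
    return criterion(logits, target)


logits = torch.randn(2, 21, 64, 64, requires_grad=True)
target = torch.randint(0, 21, (2, 256, 256))
target[:, :8] = UNK_LABEL          # a band of "unknown" pixels contributes nothing to the loss
loss = seg_loss(logits, target)
loss.backward()
print(loss.item())
```

Because the loss already includes the softmax, the final activation of the segmentation model would typically be the identity; `torch.argmax` over the class dimension, as in `validate`, gives the same labels either way.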

Good luck!

In [None]:
### Hyperparameters
# TODO: Change these to your desired hyperparameters
epochs = 100
batch_size = 32
learning_rate = 1

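from torch.utils.data import DataLoader  # re-import in case the notebook was restarted after Part 1

workers = 4 # The number of parallel processes used to read data. Increase this if you have more cores.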
train_loader = DataLoader(ds_voc['train'], batch_size=batch_size, shuffle=True, num_workers=workers)
test_loader = DataLoader(ds_voc['val'], batch_size=batch_size, num_workers=workers)

In [None]:
# IMPORTANT
# You must use exactly the same Unet class definition as in Part 1.
# TODO: Run the cell containing the Unet class definition from Part 1 before proceeding

In [None]:
BACKBONE = "resnet34"  # TODO: Change this to your desired backbone

In [None]:
# Load the pretrained depth model
depth_pretrained_path = "my_trained_model.pkl"  # TODO: Specify the path to your trained depth model
model = Unet(BACKBONE, num_classes=1)
model = nn.DataParallel(model)
model = model.cuda()
model.load_state_dict(torch.load(depth_pretrained_path))

### Segmentation Head

In [None]:
class SegHead(nn.Module):
    pass # TODO: Implement a segmentation head here. Remember to complete the docstrings as well


def change_classes(unet, target_classes=21):
    seg_head = SegHead()
    # TODO: Perform 'surgery' on the unet model to attach a segmentation head
    return unet

model = change_classes(model, target_classes=21)
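One possible shape for the segmentation head, sketched under assumptions: it takes the decoder's last feature map (the `in_channels=64` below is a placeholder for whatever your decoder produces), applies a 3x3 conv block and a 1x1 classifier with `num_classes` output channels, and bilinearly upsamples to the 256x256 label size. `SegHeadSketch` is a hypothetical name. How it gets attached depends on how your `Unet` stores its final layer: if that layer lived in an attribute called, say, `head` (hypothetical), the surgery on the `DataParallel`-wrapped model could be as small as `model.module.head = SegHeadSketch(...)` after loading the depth weights, followed by moving the model to the GPU again.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SegHeadSketch(nn.Module):
    """Maps decoder features to per-pixel class logits at a fixed output size (a sketch)."""

    def __init__(self, in_channels: int = 64, num_classes: int = 21,
                 hidden: int = 64, out_size: tuple = (256, 256)):
        super().__init__()
        self.out_size = out_size
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
        )
        # 1x1 classifier; no softmax, because CrossEntropyLoss expects raw logits
        self.classifier = nn.Conv2d(hidden, num_classes, kernel_size=1)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        logits = self.classifier(self.block(features))
        # Upsample so logits match the 256x256 label maps
        return F.interpolate(logits, size=self.out_size, mode="bilinear", align_corners=False)


head = SegHeadSketch(in_channels=64, num_classes=21)
print(head(torch.randn(2, 64, 128, 128)).shape)  # torch.Size([2, 21, 256, 256])
```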


In [None]:
import torch

def loss_fn(pred_y, y):
    # TODO: Change this to your desired loss function.
    return torch.mean(y.sub(pred_y))  # Why wouldn't this work?

## Training and Evaluation [10 points]
Tune the hyperparameters to get the highest possible score on the PASCAL VOC validation set, and answer the following questions:
1. What is the relationship between the _size_ of a class and its IOU? How would you quantify this relationship?
2. What is the relationship between the number of instances and the IOU, i.e. how many times a class appears in an image versus its IOU?
3. Which weights can you not transfer from the depth model?
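For question 1, one way to put a number on the relationship is to compute per-class IoU and per-class pixel share from a confusion matrix over the validation set, then measure their rank correlation. The sketch below is self-contained NumPy and does not rely on `utils.Metrics`, whose interface is not shown here; the function names are hypothetical, and Spearman correlation is only one reasonable choice. For question 2, the same approach works with a per-class count of images (or instances) in place of the pixel share.

```python
import numpy as np


def confusion_matrix(target, pred, num_classes, ignore_label=255):
    """(C, C) pixel counts; rows are ground-truth classes, columns are predicted classes."""
    target, pred = np.asarray(target).ravel(), np.asarray(pred).ravel()
    keep = target != ignore_label
    idx = num_classes * target[keep].astype(np.int64) + pred[keep].astype(np.int64)
    return np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)


def iou_and_share(cm: np.ndarray):
    """Per-class IoU = TP / (TP + FP + FN) and fraction of ground-truth pixels per class."""
    tp = np.diag(cm).astype(np.float64)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp
    with np.errstate(divide="ignore", invalid="ignore"):
        iou = tp / (tp + fp + fn)          # NaN for classes absent from both GT and prediction
    share = cm.sum(axis=1) / cm.sum()
    return iou, share


def spearman(a: np.ndarray, b: np.ndarray) -> float:
    """Spearman rank correlation (ties are not averaged; fine for a rough trend)."""
    ra = np.argsort(np.argsort(a)).astype(np.float64)
    rb = np.argsort(np.argsort(b)).astype(np.float64)
    return float(np.corrcoef(ra, rb)[0, 1])


# Toy example: 3 classes, one ignored pixel
target = np.array([[0, 0, 0, 1], [0, 0, 2, 1], [255, 0, 0, 0]])
pred = np.array([[0, 0, 0, 1], [0, 0, 1, 1], [0, 0, 0, 0]])
cm = confusion_matrix(target, pred, num_classes=3)
iou, share = iou_and_share(cm)
valid = ~np.isnan(iou)
print(iou, share, spearman(share[valid], iou[valid]))
```

In the notebook, `cm` would be accumulated with `+=` over all validation batches before computing the statistics. If SciPy is available, `scipy.stats.spearmanr` handles ties properly.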

In [None]:
from utils import Metrics

In [None]:
model = model.cuda()

In [None]:
from tqdm.auto import tqdm  # re-import in case the notebook was restarted after Part 1
import gc

run_id = f'seg_model_n{epochs}_bs{batch_size}_lr{learning_rate}'; print('\n\nTraining', run_id)
save_path = run_id + '.pkl'

# TODO: Experiment with different optimizers and learning rate schedulers
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=learning_rate, steps_per_epoch=len(train_loader), epochs=epochs)

metrics = Metrics(len(get_pascal_class_names()), get_pascal_class_names())


ITER_PER_EPOCH = len(train_loader)
TOTAL_STEPS = ITER_PER_EPOCH * epochs

@torch.no_grad()
def validate(model, valid_loader):
    model.eval()
    metrics.reset()
    for i, (sample) in tqdm(enumerate(valid_loader), total=len(valid_loader), desc='Validating'):
        x, y = sample['image'].float().cuda(), sample['mask'].numpy()
        y_pred = model(x)
        y_pred = torch.argmax(y_pred, dim=1) # get the most likely prediction
        metrics.add_batch(y, y_pred.detach().cpu().numpy())
    print('\nValidation stats ', metrics.get_table())


validate(model, test_loader)
pbar = tqdm(total=TOTAL_STEPS, desc='Training')
for epoch in range(epochs):
    model.train()

    N = len(train_loader)

    for i, (sample) in enumerate(train_loader):

        # Load a batch and send it to GPU
        x = sample['image'].float().cuda()
        y = sample['mask'].long().cuda()

        # Forward pass: compute predicted y by passing x to the model.
        y_pred = model(x)

        # Compute and print loss.
        loss = loss_fn(y_pred, y)


        # Before the backward pass, use the optimizer object to zero all of the
        # gradients for the variables it will update (which are the learnable
        # weights of the model).
        optimizer.zero_grad()

        # Backward pass: compute gradient of the loss with respect to model parameters
        loss.backward()

        # Calling the step function on an Optimizer makes an update to its parameters
        optimizer.step()
        scheduler.step() # TODO: location of this call might change depending on your choice of scheduler.
        pbar.update(1)
        pbar.set_postfix({'epoch': f"{epoch+1}/{epochs}", 'loss': loss.item(), 'epoch%': "{0:.1f}%".format(100*(i+1)/N)})




        #break # useful for quick debugging
    torch.cuda.empty_cache(); del x, y; gc.collect()

    # Validation after each epoch
    validate(model, test_loader)


# Save model
torch.save(model.state_dict(), save_path)
print('\nTraining done. Model saved ({}).'.format(save_path))

## Visualization  [2 points]
Use the `decode_segmap` function to visualize images and their Ground Truth and Predicted segmentation maps. The images must be from the validation set.
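A minimal plotting sketch, assuming you already have denormalized images as `(H, W, 3)` float arrays in [0, 1] (for example by undoing the ImageNet normalization as `denorm_imagenet` does in Part 1) and ground-truth/predicted label maps already converted to RGB with `decode_segmap`. The function name is hypothetical; it returns the figure instead of calling `plt.show()`, so it can also be saved with `fig.savefig(...)`.

```python
import numpy as np
import matplotlib.pyplot as plt


def plot_segmentation_rows(images, gt_rgb, pred_rgb):
    """One row per validation example: input | ground truth | prediction."""
    n = len(images)
    fig, axes = plt.subplots(n, 3, figsize=(9, 3 * n), squeeze=False)
    titles = ("Input", "Ground Truth", "Predicted")
    for row, panels in enumerate(zip(images, gt_rgb, pred_rgb)):
        for col, (panel, title) in enumerate(zip(panels, titles)):
            # Float images are clipped to [0, 1]; uint8 RGB maps are shown as-is
            axes[row, col].imshow(np.clip(panel, 0, 1) if panel.dtype.kind == "f" else panel)
            axes[row, col].set_title(title)
            axes[row, col].axis("off")
    fig.tight_layout()
    return fig
```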


In [2]:
# TODO: Implement visualization