<a href="https://colab.research.google.com/github/JalalSayed1/DL_CW/blob/master/DL_CW_damages.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## GUID: 2571964S

## To describe and motivate your design choices: architecture, pre-processing, training regime

## To analyse, describe and comment on your results

## To provide some discussion on what you think are the limitations of your solution and what could be future work

Why U-Net:
1. It is a well-established architecture for image segmentation (and is also widely used for image generation).
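As a sketch of what motivates this choice, the toy two-level U-Net below (PyTorch) shows the two properties that matter for per-pixel damage labelling: the output is a class map at the input resolution, and skip connections concatenate high-resolution encoder features into the decoder, which helps preserve fine detail lost through pooling. The class name `TinyUNet`, the channel widths and the 17 outputs (Clean, 15 damage classes, and Background remapped to 16) are illustrative assumptions, not the model trained in this notebook.

```python
import torch
import torch.nn as nn

def conv_block(in_ch, out_ch):
    # Two 3x3 convolutions with padding=1 keep the spatial size unchanged
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )

class TinyUNet(nn.Module):
    """Hypothetical two-level U-Net: encoder -> bottleneck -> decoder with skip connections."""

    def __init__(self, in_channels=3, num_classes=17, base=16):
        super().__init__()
        self.pool = nn.MaxPool2d(2)
        self.enc1 = conv_block(in_channels, base)          # full resolution
        self.enc2 = conv_block(base, base * 2)             # 1/2 resolution
        self.bottleneck = conv_block(base * 2, base * 4)   # 1/4 resolution
        self.up2 = nn.ConvTranspose2d(base * 4, base * 2, kernel_size=2, stride=2)
        self.dec2 = conv_block(base * 4, base * 2)         # upsampled + skip from enc2
        self.up1 = nn.ConvTranspose2d(base * 2, base, kernel_size=2, stride=2)
        self.dec1 = conv_block(base * 2, base)             # upsampled + skip from enc1
        self.head = nn.Conv2d(base, num_classes, kernel_size=1)  # per-pixel class logits

    def forward(self, x):
        s1 = self.enc1(x)
        s2 = self.enc2(self.pool(s1))
        b = self.bottleneck(self.pool(s2))
        d2 = self.dec2(torch.cat([self.up2(b), s2], dim=1))   # skip connection by concatenation
        d1 = self.dec1(torch.cat([self.up1(d2), s1], dim=1))
        return self.head(d1)

model = TinyUNet()
logits = model(torch.randn(2, 3, 64, 64))
print(logits.shape)  # torch.Size([2, 17, 64, 64])
```

Here the input height and width must be divisible by 4 because of the two pooling steps; the 512 x 512 crops used later in this notebook satisfy that.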

## Points to discuss:
Choice of Architecture: Explain why you chose a specific architecture over others. For instance, U-Net might be chosen for its efficiency and its effectiveness on small datasets, while DeepLab could be chosen for its state-of-the-art performance on segmentation tasks.

Preprocessing: Describe any preprocessing steps you take, such as resizing images, normalising pixel values, data augmentation, etc.

Postprocessing: If you apply any postprocessing to the segmentation maps, such as conditional random fields (CRFs) to sharpen the boundaries, explain why and how this improves the results.

Loss Functions: Discuss the choice of loss function, which for segmentation could be cross-entropy, Dice loss (based on the Dice coefficient), or a combination of several loss functions.
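A minimal sketch of a combined loss, assuming logits of shape (N, C, H, W) and integer targets of shape (N, H, W). `CEDiceLoss` and its `dice_weight`/`smooth` parameters are illustrative names; the Dice term is the usual soft Dice computed on softmax probabilities and averaged over classes, which gives small classes more influence than pixel-wise cross-entropy alone.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CEDiceLoss(nn.Module):
    """Weighted sum of cross-entropy and soft Dice loss for multi-class segmentation (sketch)."""

    def __init__(self, num_classes, dice_weight=0.5, smooth=1.0):
        super().__init__()
        self.num_classes = num_classes
        self.dice_weight = dice_weight
        self.smooth = smooth  # avoids 0/0 for classes absent from both prediction and target

    def forward(self, logits, target):
        # logits: (N, C, H, W) raw scores; target: (N, H, W) integer labels in [0, C)
        ce = F.cross_entropy(logits, target)
        probs = logits.softmax(dim=1)
        one_hot = F.one_hot(target, self.num_classes).permute(0, 3, 1, 2).float()
        dims = (0, 2, 3)  # sum over batch and pixels, keep one Dice score per class
        intersection = (probs * one_hot).sum(dims)
        cardinality = probs.sum(dims) + one_hot.sum(dims)
        dice = (2 * intersection + self.smooth) / (cardinality + self.smooth)
        dice_loss = 1 - dice.mean()
        return (1 - self.dice_weight) * ce + self.dice_weight * dice_loss

torch.manual_seed(0)
loss_fn = CEDiceLoss(num_classes=17)
logits = torch.randn(2, 17, 32, 32, requires_grad=True)
target = torch.randint(0, 17, (2, 32, 32))
loss = loss_fn(logits, target)
loss.backward()
print(loss.item())
```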

Metrics: Describe the metrics you'll use to evaluate the performance of your model, such as pixel accuracy, mean Intersection over Union (IoU), etc.
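Per-class IoU is TP / (TP + FP + FN) and mean IoU averages it over classes. A NumPy sketch that builds a confusion matrix first, so counts can be accumulated over the whole test set before dividing; the function names are assumptions, and classes that appear in neither prediction nor ground truth are left as NaN and skipped in the mean.

```python
import numpy as np

def confusion_matrix(pred, target, num_classes):
    """(num_classes x num_classes) pixel counts: rows = ground truth, columns = prediction."""
    pred = np.asarray(pred).ravel()
    target = np.asarray(target).ravel()
    idx = target * num_classes + pred
    return np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)

def segmentation_metrics(cm):
    tp = np.diag(cm).astype(float)
    pixel_acc = tp.sum() / cm.sum()
    # union = predicted pixels of class c + ground-truth pixels of class c - their intersection
    union = cm.sum(axis=0) + cm.sum(axis=1) - tp
    with np.errstate(divide="ignore", invalid="ignore"):
        iou = tp / union  # NaN for classes absent from both
    return pixel_acc, iou, np.nanmean(iou)

target = np.array([[0, 0, 1], [1, 2, 2]])
pred = np.array([[0, 1, 1], [1, 2, 0]])
cm = confusion_matrix(pred, target, num_classes=3)
acc, iou, miou = segmentation_metrics(cm)
print(acc, iou, miou)  # miou == 0.5
```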

Training Strategy: Detail the training process, including the choice of optimiser, learning rate, batch size, and any other hyperparameters.
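To illustrate what such a description covers, a plain-PyTorch sketch of one common setup: AdamW with cosine learning-rate decay and pixel-wise cross-entropy. The optimiser, learning rate, weight decay and schedule are example values, not tuned settings; batches are assumed to be dicts with `image` and `annotation` keys, as produced by the `CustomDataset` class in this notebook.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def make_optimiser(model, lr=1e-3, weight_decay=1e-4, epochs=50):
    optimiser = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # Cosine decay from lr towards 0 over `epochs` calls to scheduler.step() (one per epoch)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimiser, T_max=epochs)
    return optimiser, scheduler

def train_one_epoch(model, loader, optimiser, device="cpu"):
    model.train()
    total = 0.0
    for batch in loader:
        images = batch["image"].to(device)
        targets = batch["annotation"].to(device)
        optimiser.zero_grad()
        loss = F.cross_entropy(model(images), targets)
        loss.backward()
        optimiser.step()
        total += loss.item()
    return total / len(loader)  # mean training loss over the epoch

# Toy usage: a 1x1 convolution as a stand-in segmentation model
torch.manual_seed(0)
model = nn.Conv2d(3, 17, kernel_size=1)
loader = [{"image": torch.randn(4, 3, 16, 16), "annotation": torch.randint(0, 17, (4, 16, 16))}]
optimiser, scheduler = make_optimiser(model, epochs=5)
losses = []
for epoch in range(5):
    losses.append(train_one_epoch(model, loader, optimiser))
    scheduler.step()
```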

Results and Analysis: Present the results and provide an analysis of what worked and what didn’t. Discuss any challenges you faced and how you addressed them.

Visualisation: Explain the importance of visualisation in understanding the performance of your model. For example, overlay images help to see where the model is performing well and where it is making mistakes.

---

# Damages - Deep Learning Coursework 2024

The aim of this coursework is for you to design, implement and test a deep learning architecture to detect and identify damage in images. Digitization makes historical pictures and art much more widely available to the public. Many such pictures have suffered some form of damage due to time, storage conditions and the fragility of the original medium. For example, the image below (A) shows a digitized parchment that has suffered significant damage over time.

**The aim of this project is for you to design, implement and evaluate a deep learning model to detect and identify damage present in images.**

<table>
<tr>
<td>
<div>
<img src="damage_data/image_path/cljmrkz5n341f07clcujw105j.png" width="500"/>
</div>
</td>
<td>
<div>
<img src="damage_data/annotation_rgb_path/cljmrkz5n341f07clcujw105j.png" width="500"/>
</div>
</td>
</tr>
<tr><td><center>(A) Image</center></td><td><center>(B) Damage labels</center></td></tr>
</table>
*(Note that the images will only show once you have downloaded the dataset)*


The labels in this figure (B) identify a scattering of peeling paint, a large stained area in the bottom left and a missing part in the top left. Each colour in these images corresponds to a different category of damage, such as `fold`, `writing` or `burn marks`. You are provided with a dataset of a variety of damaged images, from parchment to ceramic or wood paintings, with detailed annotations of a range of damages.

You are free to use any architecture you prefer from those we have seen in class. You can use unsupervised pre-training or only supervised end-to-end training; the choice is yours.

### Hand-in date: Friday 15th of March before 4:30pm (on Moodle)

### Steps & Hints
* First, look at the data. What are the different types of images (content), what types of material, what types of damage? How different are they? What transformations for data augmentation do you think would be acceptable here?
* Second, check the provided helper functions for loading the data and splitting it into training, test and cross-validation sets.
* Design a network for the task. What output? What layers? How many? Do you want to use an Autoencoder for unsupervised pre-training?
* Choose a loss function for your network
* Select an optimiser and training parameters (batch size, learning rate)
* Optimise your model, and tune hyperparameters (especially learning rate, momentum etc)
* Analyse the results on the test data. How to measure success? Which classes are recognised well, which are not? Is there confusion between some classes? Look at failure cases.
* If time allows, go back to drawing board and try a more complex, or better, model.
* Explain your thought process, justify your choices and discuss the results!
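Geometric augmentations have to be applied identically to the image and its label mask, otherwise the labels no longer line up with the damage; photometric ones (e.g. colour jitter) would be applied to the image only. A NumPy sketch of random flips and 90-degree rotations under that constraint; the function name is hypothetical, and whether rotations are acceptable for this data (e.g. for `Writing`) is exactly the judgement the first hint asks you to make.

```python
import numpy as np

def random_flip_rotate(image, mask, rng):
    """Apply the same random flips and 90-degree rotation to an (H, W, C) image and its (H, W) mask."""
    if rng.random() < 0.5:
        image, mask = image[:, ::-1], mask[:, ::-1]   # horizontal flip
    if rng.random() < 0.5:
        image, mask = image[::-1, :], mask[::-1, :]   # vertical flip
    k = int(rng.integers(0, 4))                       # number of 90-degree rotations
    image = np.rot90(image, k, axes=(0, 1))
    mask = np.rot90(mask, k, axes=(0, 1))
    # Copy to contiguous memory so the arrays can be turned into tensors without negative strides
    return np.ascontiguousarray(image), np.ascontiguousarray(mask)

rng = np.random.default_rng(0)
image = rng.random((512, 512, 3)).astype(np.float32)
mask = rng.integers(0, 17, size=(512, 512))
aug_image, aug_mask = random_flip_rotate(image, mask, rng)
```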

### Submission
* submit ONE zip file on Moodle containing:
  * **your notebook**: use `File -> download .ipynb` to download the notebook file locally from colab.
  * **a PDF file** of your notebook's output as you see it: use `File -> print` to generate a PDF.
* your notebook must clearly contain separate cells for:
  * setting up your model and data loader
  * training your model from data
  * loading your pretrained model from github/gitlab/any other online storage you like!
  * testing your model on test data.
* The training cells must be disabled by a flag, so that *Run all* on your notebook will only:
  * load the data
  * load your model
  * apply the model to the test data
  * analyse and display the results and accuracy
* In addition, provide Markdown cells:
  * containing your student number at the top
  * to describe and motivate your design choices: architecture, pre-processing, training regime
  * to analyse, describe and comment on your results
  * to provide some discussion on what you think are the limitations of your solution and what could be future work

* **Note that you must put your trained model online so that your code can download it.**
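One way to meet the flag and online-model requirements, sketched under the assumption that the weights were saved with `torch.save(model.state_dict(), path)` and uploaded to a direct-download URL. `TRAIN_MODEL`, `WEIGHTS_URL`, `get_model` and `train_fn` are hypothetical names; `torch.hub.load_state_dict_from_url` downloads the file once into the torch hub cache and loads it.

```python
import torch
import torch.nn as nn

TRAIN_MODEL = False  # keep False so "Run all" skips training and loads the hosted weights
WEIGHTS_URL = "https://example.com/unet_weights.pth"  # hypothetical placeholder for your hosted file

def get_model(model, train_fn, train=TRAIN_MODEL, weights_url=WEIGHTS_URL):
    if train:
        # e.g. a Lightning Trainer.fit(...) call followed by torch.save(model.state_dict(), path)
        train_fn(model)
    else:
        # Downloads once into the torch hub cache, then loads the state dict onto the CPU
        state_dict = torch.hub.load_state_dict_from_url(weights_url, map_location="cpu")
        model.load_state_dict(state_dict)
    return model.eval()  # eval() returns the module itself

# Toy usage of the training branch (no download needed)
calls = []
model = get_model(nn.Conv2d(3, 17, 1), train_fn=lambda m: calls.append("trained"), train=True)
```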


### Assessment criteria
* In order to get a pass mark, you will need to demonstrate that you have designed and trained a deep NN to solve the problem, using a sensible approach and reasonable efforts to tune hyper-parameters, and that you have analysed the results. It is NOT necessary to reach any particular level of accuracy (a network that predicts poorly will still yield a pass mark if it is designed, tuned and analysed sensibly).
* In order to get a good mark, you will show good understanding of the approach and provide a working solution.
* In order to get a high mark, you will demonstrate a working approach of gradual improvement between different versions of your solution.
* Bonus marks for attempting something original if well motivated, even if it does not yield increased performance.
* Bonus marks for getting high performance, with additional points available for the best performance in the class.

### Notes
* You are provided with code to isolate the test set and cross-validation folds; make sure to keep the separation clean so that all hyperparameters are set properly.
* I recommend starting with small models, which are easier to train, to set a baseline performance before attempting more complex ones.
* Be mindful of the time!

In [None]:
using_colab = False
if using_colab:
    from google.colab import drive
    drive.mount('/content/drive')

## Housekeeping

In [None]:
!pip install gdown pytorch_lightning

In [None]:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

import os
import pandas as pd
from PIL import Image
# Raise Pillow's decompression-bomb limit so the very large scans in the dataset can be opened
Image.MAX_IMAGE_PIXELS = 243748701
import numpy as np
import matplotlib.pyplot as plt
import gdown
import shutil

DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
DEVICE

# Load dataset

We then load the metadata into a dataframe for convenience.

In [None]:
if using_colab:
    !pwd

In [None]:
download_dataset = False
if download_dataset:
  if not os.path.exists("damage_data"):
      # !gdown 1v8aUId0-tTW3ln3O2BE4XajQeCToOEiS -O damages.zip

      # FILE_ID = "1771AIemjPvrIGjf_87tGMqwRhxAUuuw-"
      # !wget 'https://drive.google.com/uc?id=1771AIemjPvrIGjf_87tGMqwRhxAUuuw-&export=download' -O damages.zip

      !cp -r "/content/drive/MyDrive/Colab Notebooks/CW/damages.zip" "."

In [None]:
# set that to wherever you want to store the data (eg, your Google Drive), choose a persistent location!
root_dir = '.'
# root_dir = '/content/drive/MyDrive/Colab Notebooks/CW/'

data_dir = os.path.join(root_dir, "damage_data")
csv_path = os.path.join(data_dir, 'metadata.csv')
zip_path = os.path.join(root_dir, 'damages.zip')

try:
    if not os.path.exists(data_dir):
        # the dataset has not been downloaded yet: fetch the archive if allowed
        if download_dataset and not os.path.exists(zip_path):
            gdown.download(id='1v8aUId0-tTW3ln3O2BE4XajQeCToOEiS', output=zip_path)
        # the .zip file is not unpacked yet, so unpack it
        shutil.unpack_archive(zip_path, root_dir)
    df = pd.read_csv(csv_path)
except Exception as e:
    print(e)

# later cells only run when the metadata is available
dataset_exist = os.path.exists(csv_path)

This dataframe holds the paths to the dataset images and annotation labels, plus the classification labels.

In [None]:
if dataset_exist:
  # an expression inside an if block is not echoed by Jupyter, so display it explicitly
  display(df)

The images in the dataset are categorised by `material`, i.e. what the original picture was made on, e.g. Parchment, Glass or Textile.

In [None]:
if dataset_exist:
  print(df['material'].unique())

Images are also categorised by `content`, i.e. what is depicted, e.g. Line art, geometric patterns, etc.

In [None]:
if dataset_exist:
  print(df['content'].unique())

## Labels
Segmentation labels are saved as a PNG image, where each number from 1 to 15 corresponds to a damage class like Peel, Scratch etc; the Background class is set to 255, and the Clean class (no damage) is set to 0. We also provide code to convert these annotation values to RGB colours for nicer visualisation, but for training you should use the original annotations.
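Because the annotation ids all fit in 0 to 255, the conversion to RGB can be done with a 256-entry lookup table indexed by the label array instead of looping over pixels. A sketch, where `label_to_rgb` is a hypothetical helper and its palette argument is assumed to have the same form as the `id_to_rgb` dictionary built in the next cell (id to RGB tuple); the example uses Peel (id 2, `#FF34FF`) and Background (id 255, `#5A0007`).

```python
import numpy as np

def label_to_rgb(label_map, id_to_rgb):
    """Map an (H, W) array of class ids in 0..255 to an (H, W, 3) uint8 RGB image."""
    lut = np.zeros((256, 3), dtype=np.uint8)  # ids without a colour stay black
    for class_id, rgb in id_to_rgb.items():
        lut[class_id] = rgb
    # Fancy indexing looks up every pixel's colour at once
    return lut[np.asarray(label_map, dtype=np.uint8)]

palette = {0: (0, 0, 0), 2: (255, 52, 255), 255: (90, 0, 7)}
labels = np.array([[0, 2], [255, 2]])
rgb = label_to_rgb(labels, palette)
```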

In [None]:
name_color_mapping = {
    "Material loss": "#1CE6FF",
    "Peel": "#FF34FF",
    "Dust": "#FF4A46",
    "Scratch": "#008941",
    "Hair": "#006FA6",
    "Dirt": "#A30059",
    "Fold": "#FFA500",
    "Writing": "#7A4900",
    "Cracks": "#0000A6",
    "Staining": "#63FFAC",
    "Stamp": "#004D43",
    "Sticker": "#8FB0FF",
    "Puncture": "#997D87",
    "Background": "#5A0007",
    "Burn marks": "#809693",
    "Lightleak": "#f6ff1b",
}

class_names = [ 'Material loss', 'Peel', 'Dust', 'Scratch',
                'Hair', 'Dirt', 'Fold', 'Writing', 'Cracks', 'Staining', 'Stamp',
                'Sticker', 'Puncture', 'Burn marks', 'Lightleak', 'Background']

class_to_id = {class_name: idx+1 for idx, class_name in enumerate(class_names)}
class_to_id['Background'] = 255  # Set the Background ID to 255

def hex_to_rgb(hex_color: str) -> tuple:
    hex_color = hex_color.lstrip('#')
    return tuple(int(hex_color[i:i+2], 16) for i in (0, 2, 4))

id_to_rgb = {class_to_id[class_name]: hex_to_rgb(color) for class_name, color in name_color_mapping.items()}
id_to_rgb[0] = (0,0,0)

# Create id2label mapping: ID to class name
id2label = {idx: class_name for class_name, idx in class_to_id.items()}

# Create label2id mapping: class name to ID, which is the same as class_to_id
label2id = class_to_id

# Non-damaged pixels
id2label[0] = 'Clean'
label2id['Clean'] = 0

def print_mapping(title, mapping):
    print(title)
    for k, v in mapping.items():
        print(f"{k}: {v}")
    print()

print_mapping("id_to_rgb", id_to_rgb)
print_mapping("id2label", id2label)
print_mapping("label2id", label2id)
print_mapping("name_color_mapping", name_color_mapping)
print_mapping("class_to_id", class_to_id)

In [None]:
from IPython.display import Markdown, display

legend='#### Colour labels for each damage type\n'
for damage in class_names:
    legend += '- <span style="color: {color}">{damage}</span>.\n'.format(color=name_color_mapping[damage], damage=damage)
display(Markdown(legend))

## Create dataset splits

Here is an example of how to split the dataset for Leave-one-out cross validation (LOOCV) based on material.

In [None]:
def create_leave_one_out_splits(df, criterion='material'):

    grouped = df.groupby(criterion)
    content_splits = {name: group for name, group in grouped}
    unique_val = df[criterion].unique()

    # Initialize a dictionary to hold the train and validation sets for each LOOCV iteration
    loocv_splits = {}

    for value in unique_val:
        # Create the validation set
        val_set = content_splits[value]

        # Create the training set
        train_set = pd.concat([content_splits[c] for c in unique_val if c != value])

        # Add these to the loocv_splits dictionary
        loocv_splits[value] = {'train_set': train_set, 'val_set': val_set}

    return loocv_splits


For this coursework, we want to assess how well the method generalises, so we keep one type of material (`Canvas`) as the test set and train only on the remaining ones.

In [None]:
# split the dataset according to material type
full_splits = create_leave_one_out_splits(df, 'material')

# use Canvas as test set
test_set = full_splits['Canvas']['val_set']

# use the rest as training set
train_set = full_splits['Canvas']['train_set']

# prepare a leave-one-out cross validation for the training set
loocv_splits = create_leave_one_out_splits(train_set, 'material')

# identify the different types of material in the training set
unique_material = train_set['material'].unique()

print("Training set materials:", unique_material)
print("Test set material:", test_set['material'].unique())


Here are some helper functions to crop and process images.

In [None]:
import random

def random_square_crop_params(image, target_size):
    width, height = image.size
    min_edge = min(width, height)

    # Conditionally set the range for random crop size
    lower_bound = min(min_edge, target_size)
    upper_bound = max(min_edge, target_size)

    # Generate crop_size
    crop_size = random.randint(lower_bound, upper_bound)

    # Check and adjust if crop_size is larger than any dimension of the image
    if crop_size > width or crop_size > height:
        crop_size = min(width, height)

    # Generate random coordinates for the top-left corner of the crop
    x = random.randint(0, width - crop_size)
    y = random.randint(0, height - crop_size)

    return (x, y, x + crop_size, y + crop_size)

def apply_crop_and_resize(image, coords, target_size):
    image_crop = image.crop(coords)
    image_crop = image_crop.resize((target_size, target_size), Image.NEAREST)
    return image_crop

We also provide a simple class for holding the dataset.

In [None]:
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import random
import numpy as np
from PIL import Image

class CustomDataset(Dataset):
    def __init__(self, dataframe, target_size, is_train=True):
        self.dataframe = dataframe
        self.target_size = target_size
        self.is_train = is_train

        self.to_tensor = transforms.ToTensor()

        # Define the normalization transform
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])

    def __len__(self):
            return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        image = Image.open(row['image_path']).convert('RGB')
        annotation = Image.open(row['annotation_path']).convert('L')
        annotation_rgb = Image.open(row['annotation_rgb_path']).convert('RGB')
        id = row['id']
        material = row['material']
        content = row['content']

        if self.is_train:
            # Generate random square cropping coordinates
            crop_coords = random_square_crop_params(image, self.target_size)

            # Apply the same cropping and resizing to all
            image = apply_crop_and_resize(image, crop_coords, self.target_size)
            annotation = apply_crop_and_resize(annotation, crop_coords, self.target_size)
            annotation_rgb = apply_crop_and_resize(annotation_rgb, crop_coords, self.target_size)
        else:  # Validation
            # Alternative (currently disabled): instead of cropping, downsize the images so that the longest edge is 1024 or less
            # max_edge = max(image.size)
            # if max_edge > 1024:
            #     downsample_ratio = 1024 / max_edge
            #     new_size = tuple([int(dim * downsample_ratio) for dim in image.size])

            #     image = image.resize(new_size, Image.BILINEAR)
            #     annotation = annotation.resize(new_size, Image.NEAREST)
            #     annotation_rgb = annotation_rgb.resize(new_size, Image.BILINEAR)

            # Generate random square cropping coordinates
            crop_coords = random_square_crop_params(image, self.target_size)

            # Apply the same cropping and resizing to all
            image = apply_crop_and_resize(image, crop_coords, self.target_size)
            annotation = apply_crop_and_resize(annotation, crop_coords, self.target_size)
            annotation_rgb = apply_crop_and_resize(annotation_rgb, crop_coords, self.target_size)

        # Convert PIL images to PyTorch tensors
        image = self.to_tensor(image)
        annotation = torch.tensor(np.array(annotation), dtype=torch.long)
        annotation_rgb = self.to_tensor(annotation_rgb)

        # Normalize the image
        image = self.normalize(image)

        # Remap Background (255) to 16 so the labels are contiguous in [0, 16] (17 classes):
        # nn.CrossEntropyLoss expects class indices in [0, num_classes)
        annotation[annotation == 255] = 16

        return {
            'image': image,
            'annotation': annotation,
            'annotation_rgb': annotation_rgb,
            'id': id,
            'material': material,
            'content': content
        }
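With the labels contiguous in [0, 16], per-class pixel counts can be taken with `torch.bincount`. If damage pixels turn out to be much rarer than Clean or Background pixels (worth checking on the training set), weighting cross-entropy by inverse class frequency is a common remedy. A sketch; `inverse_frequency_weights` is a hypothetical helper, and giving classes that never occur a weight of 0 is a design assumption (such classes never appear as targets, so their weight has no effect).

```python
import torch
import torch.nn as nn

def inverse_frequency_weights(label_batches, num_classes=17):
    """Per-class weights proportional to 1 / pixel frequency, normalised to mean 1 over present classes."""
    counts = torch.zeros(num_classes, dtype=torch.float64)
    for labels in label_batches:
        counts += torch.bincount(labels.flatten(), minlength=num_classes).double()
    present = counts > 0
    weights = torch.zeros(num_classes, dtype=torch.float64)
    weights[present] = counts.sum() / counts[present]   # inverse frequency
    weights[present] /= weights[present].mean()          # mean weight 1 over classes that occur
    return weights.float()

labels = [torch.tensor([[0, 0, 0, 1]]), torch.tensor([[0, 0, 0, 1]])]
w = inverse_frequency_weights(labels, num_classes=3)  # class 2 never occurs, so it gets weight 0
criterion = nn.CrossEntropyLoss(weight=w)
loss = criterion(torch.randn(1, 3, 4), labels[0])
```

On the real data this could be called as `inverse_frequency_weights(batch['annotation'] for batch in train_loader)`; since the training crops are random, the result is an estimate.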


Here we create a `LightningDataModule` which encapsulates our training and validation DataLoaders; you can also do this manually using only the PyTorch `DataLoader` class, as in the `train_dataloader` and `val_dataloader` methods.

In [None]:
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.nn.functional as F

class CustomDataModule(pl.LightningDataModule):
    def __init__(self, loocv_splits, current_material, target_size, batch_size=32, num_workers=4):
        super().__init__()
        self.loocv_splits = loocv_splits
        self.current_material = current_material
        self.target_size = target_size
        self.batch_size = batch_size
        self.num_workers = num_workers

    def prepare_data(self):
        pass

    def setup(self, stage=None):
        # Load current train and validation set based on LOOCV iteration
        train_df = self.loocv_splits[self.current_material]['train_set']
        val_df = self.loocv_splits[self.current_material]['val_set'].sample(frac=1).reset_index(drop=True)

        self.train_dataset = CustomDataset(dataframe=train_df, target_size=self.target_size, is_train=True)
        self.val_dataset = CustomDataset(dataframe=val_df, target_size=self.target_size, is_train=False)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=1, shuffle=False, num_workers=self.num_workers)

    def test_dataloader(self):
        pass


The following creates a data module that validates on the first material in the list (`Parchment`) and trains on all the other materials (you will want to do this for each fold).

In [None]:
num_workers = 4 if using_colab else 0
data_module = CustomDataModule(loocv_splits=loocv_splits,
                               current_material=unique_material[0],
                               target_size=512,
                               batch_size=4,
                               num_workers=num_workers)
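Since the data module has to be rebuilt for each fold, the outer cross-validation loop can be sketched against a plain DataFrame so it runs on its own. `run_loocv` and `train_and_evaluate` are hypothetical; in this notebook `train_and_evaluate` would build a `CustomDataModule` for the fold, train the model and return a validation score such as mIoU. The toy call below just returns each fold's validation fraction.

```python
import pandas as pd

def run_loocv(df, criterion, train_and_evaluate):
    """Hold out each value of `criterion` once, train on the rest and collect the validation score."""
    scores = {}
    for held_out in df[criterion].unique():
        train_df = df[df[criterion] != held_out]
        val_df = df[df[criterion] == held_out]
        scores[held_out] = train_and_evaluate(train_df, val_df)
    return scores

toy = pd.DataFrame({"material": ["Parchment", "Glass", "Glass", "Textile"], "id": [0, 1, 2, 3]})
scores = run_loocv(toy, "material", lambda tr, va: len(va) / (len(tr) + len(va)))
print(scores)  # {'Parchment': 0.25, 'Glass': 0.5, 'Textile': 0.25}
```

Averaging the per-fold scores gives a single cross-validation number for comparing hyperparameter settings, while the per-material scores show which materials generalise poorly.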

Finally, we can get the train and validation data loaders from the data module.

In [None]:
if dataset_exist:
  data_module.setup()
  train_loader = data_module.train_dataloader()
  val_loader = data_module.val_dataloader()

  print("Number of training batches:", len(train_loader))
  print("Number of training samples:", len(train_loader.dataset))
  # val dataset is set to have batch size of 1:
  print("Number of validation batches:", len(val_loader))
  print("Number of validation samples:", len(val_loader.dataset))
  print("image size:", train_loader.dataset[-1]['image'].shape)
  print("annotation size:", train_loader.dataset[-1]['annotation'].shape)
  print("number of materials in training set:", len(train_loader.dataset.dataframe['material'].unique()))


# Dataset visualisation

We need to denormalise the images so we can display them.

In [None]:
# Mean and std used for normalization
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

def denormalize(image, mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]):
    # Undo transforms.Normalize on an (H, W, C) array, then clip to [0, 1] so imshow does not have to clip (and warn)
    img = image * np.asarray(std) + np.asarray(mean)
    return np.clip(img, 0.0, 1.0)

## Visualise training samples
Random square crops of the images and the corresponding RGB annotations, shown on their own and overlaid onto the image.

In [None]:
num_to_visualise = len(train_loader.dataset) # all images
print("Number of images that can be visualized:", num_to_visualise)

example_batch = next(iter(train_loader))
print("Shape of the image batch:", example_batch['image'].shape)

example_images = example_batch['image']
example_annotations = example_batch['annotation']
example_annotation_rgbs = example_batch['annotation_rgb']

# Number of examples to visualize
# N = min(4, len(example_images))
N = min(num_to_visualise, len(example_images))
print("Number of examples to visualize:", N)

fig, axes = plt.subplots(N, 3, figsize=(15, 5 * N))

for ax, col in zip(axes[0], ['Image', 'Annotation', 'Overlay']):
    ax.set_title(col, fontsize=24)

for i in range(N):
    example_image = denormalize(example_images[i].numpy().transpose((1, 2, 0)), mean, std)  # C, H, W -> H, W, C
    example_annotation = Image.fromarray(np.uint8(example_annotations[i].numpy()), 'L')
    example_annotation_rgb = example_annotation_rgbs[i].numpy().transpose((1, 2, 0))  # C, H, W -> H, W, C

    # Create an alpha (transparency) channel where black pixels in annotation_rgb are fully transparent
    alpha_channel = np.all(example_annotation_rgb == [0, 0, 0], axis=-1)
    example_annotation_rgba = np.dstack((example_annotation_rgb, np.where(alpha_channel, 0, 1)))

    axes[i, 0].imshow(example_image)
    axes[i, 0].axis('off')

    #axes[i, 1].imshow(example_annotation, cmap='gray', vmin=0, vmax=255)
    axes[i, 1].imshow(example_annotation_rgb)
    axes[i, 1].axis('off')

    axes[i, 2].imshow(example_image)
    axes[i, 2].imshow(example_annotation_rgba)
    axes[i, 2].axis('off')

plt.tight_layout()
plt.show()


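Visualising the validation set, i.e. the left-out material. Since the downsizing branch in `CustomDataset` is commented out, validation images are also random square crops rather than whole images.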

In [None]:
val_iter = iter(val_loader)
example_batches = [next(val_iter) for _ in range(N)]

# Initialize empty lists to collect different parts of each batch
example_images = []
example_annotations = []
example_annotation_rgbs = []
example_materials = []
example_contents = []

# Populate the lists with the data from the N batches
for batch in example_batches:
    example_images.append(batch['image'].squeeze())
    example_annotations.append(batch['annotation'].squeeze())
    example_annotation_rgbs.append(batch['annotation_rgb'].squeeze())
    example_materials.append(batch['material'][0])
    example_contents.append(batch['content'][0])

    print("batch image shape:", batch['image'].shape)
    print("batch annotation shape:", batch['annotation'].shape)
    # shapes after squeezing out the batch dimension of size 1
    print("Shape of the squeezed image:", example_images[-1].shape)
    print("Shape of the squeezed annotation:", example_annotations[-1].shape)

# Number of examples to visualize
# N = min(4, len(example_images))
N = min(num_to_visualise, len(example_images))

fig, axes = plt.subplots(N, 3, figsize=(15, 5 * N))

for ax, col in zip(axes[0], ['Image', 'Annotation', 'Overlay']):
    ax.set_title(col, fontsize=24)

for i in range(N):
    example_image = denormalize(example_images[i].numpy().transpose((1, 2, 0)), mean, std)  # C, H, W -> H, W, C
    example_annotation = example_annotations[i].numpy()
    example_annotation_rgb = example_annotation_rgbs[i].numpy().transpose((1, 2, 0))  # C, H, W -> H, W, C
    example_material = example_materials[i]
    example_content = example_contents[i]
    # Create an alpha (transparency) channel where black pixels in annotation_rgb are fully transparent
    alpha_channel = np.all(example_annotation_rgb == [0, 0, 0], axis=-1)
    example_annotation_rgba = np.dstack((example_annotation_rgb, np.where(alpha_channel, 0, 1)))
    axes[i, 0].imshow(example_image)
    axes[i, 0].axis('off')

    axes[i, 1].imshow(example_annotation_rgb)
    axes[i, 1].axis('off')

    axes[i, 2].imshow(example_image)
    axes[i, 2].imshow(example_annotation_rgba)
    axes[i, 2].axis('off')

plt.tight_layout()
plt.show()

# Evaluation

For the final evaluation of the model, make sure to test performance on the left-out material, `Canvas`, to get a fair idea of how well the model generalises.

In [None]:
test_module = CustomDataModule(loocv_splits=full_splits,
                               current_material='Canvas',
                               target_size=512,
                               batch_size=4,
                               num_workers=num_workers)

test_module.setup()

test_loader = test_module.val_dataloader()
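A sketch of an evaluation loop over `test_loader`, assuming the model returns (N, C, H, W) logits and that batches are dicts with `image` and `annotation` keys as in `CustomDataset`. It accumulates one confusion matrix over all test pixels and reports per-class IoU and mIoU; `evaluate` is a hypothetical name. Because the test images are random 512 x 512 crops with the current dataset code, the numbers will vary slightly between runs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

@torch.no_grad()
def evaluate(model, loader, num_classes=17, device="cpu"):
    """Return per-class IoU (NaN for classes absent from both) and mean IoU over a loader."""
    model.eval()
    cm = torch.zeros(num_classes, num_classes, dtype=torch.long)  # rows = truth, cols = prediction
    for batch in loader:
        logits = model(batch["image"].to(device))
        pred = logits.argmax(dim=1).cpu().flatten()
        target = batch["annotation"].flatten()
        cm += torch.bincount(target * num_classes + pred,
                             minlength=num_classes ** 2).reshape(num_classes, num_classes)
    tp = cm.diag().double()
    union = cm.sum(0).double() + cm.sum(1).double() - tp
    iou = tp / union
    miou = iou[~torch.isnan(iou)].mean()
    return iou, miou.item()

# Toy check: "logits" that exactly encode the labels give IoU 1 for every present class
target = torch.tensor([[[0, 1], [1, 2]]])
image = F.one_hot(target, 4).permute(0, 3, 1, 2).float()
loader = [{"image": image, "annotation": target}]
iou, miou = evaluate(nn.Identity(), loader, num_classes=4)
```

For the real model, something like `evaluate(model.to(DEVICE), test_loader, num_classes=17, device=DEVICE)` would apply, with the per-class IoU showing which damage types are recognised well.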


In [None]:
# print(test_loader.dataset[0]['annotation'].shape)
# max_all_images = torch.max(torch.tensor([torch.max(test_loader.dataset[i]['annotation']) for i in range(len(test_loader.dataset))]))
# print(max_all_images)
# min_all_images = torch.min(torch.tensor([torch.min(test_loader.dataset[i]['annotation']) for i in range(len(test_loader.dataset))]))
# print(min_all_images)

---

# My Solution:

---

### Network Design

In [None]:
# class UNet(nn.Module):
#     def __init__(self, in_channels, out_channels, features=[8, 16, 32, 64, 128, 256, 512, 1024, 2048], verbose=False):
#         super().__init__()
#         self.verbose = verbose
#         self.in_channels = in_channels
#         self.out_channels = out_channels
#         self.features = features
        
#         self.upsample = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
#                                       # nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
#                         )

#         self.in_conv = nn.Sequential(
#             nn.Conv2d(in_channels, features[0], kernel_size=3, padding=1),
#             nn.ReLU()
#             # nn.Conv2d(features[0], features[0], kernel_size=3, padding=1),
#             # nn.ReLU()
#         )

#         self.encoder_layers = nn.ModuleList()
#         self.decoder_layers = nn.ModuleList()

#         #' Create encoder path
#         prev_channels = features[0]

#         for index, feature in enumerate(features[1:]):
#             self.encoder_layers.append(self.create_encoder_layer(prev_channels, feature))
#             if (index + 1) % 2 == 0:
#                 # Reduce the spatial dimensions by half every 2 layers:
#                 self.encoder_layers.append(nn.MaxPool2d(2)) 
#             prev_channels = feature

#         #' Create decoder path
#         for index, feature in enumerate(features[::-1][:-1]):
#             # + feature because of the skip connection cat operation:
#             in_channels = prev_channels*2 if index == 0 else prev_channels + features[::-1][index]
#             # feature * 2 # if index > 0 else feature
#             self.decoder_layers.append(self.create_decoder_layer(prev_channels, feature))
#             if (index + 1) % 2 == 0:
#                 # upsample by x2:
#                 self.decoder_layers.append(self.upsample)
#             prev_channels = feature * 2

#         # Final layer of the decoder
#         self.final_layer = nn.Sequential(
#             nn.Conv2d(prev_channels//2, out_channels, kernel_size=3, stride=2, padding=1)
#         )

#     def create_encoder_layer(self, in_channels, out_channels):
#         return nn.Sequential(
#             nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
#             nn.BatchNorm2d(out_channels), # normalize the output of the previous layer for faster training
#             nn.ReLU()
#             # nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
#             # nn.BatchNorm2d(out_channels),
#             # nn.ReLU()
#         )

#     def create_decoder_layer(self, in_channels, out_channels):
#         return nn.Sequential(
#             nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
#             nn.BatchNorm2d(out_channels),
#             nn.ReLU()
#             # nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
#             # nn.ReLU()
#         )

#     def enable_verbose(self, enabled):
#         self.verbose = enabled

#     def forward(self, x):
        
#         model_dtypes = [param.dtype for param in self.parameters()]
#         if any([model_dtype == torch.float16 for model_dtype in model_dtypes]):
#           if self.verbose:
#               print(f"Using {model_dtypes[0]} precision..")
#           x = x.to('cuda', dtype=torch.float16)
#         elif self.verbose:
#           print(f"Using {model_dtypes[0]} precision..")

#         if self.verbose:
#             print(f"Input shape: {x.shape}, dtype: {x.dtype }")

#         # Store the output from each encoder layer for use in the decoder path (skip connections):
#         encoder_outputs = []

#         x = self.in_conv(x)
#         encoder_outputs.append(x)
#         if self.verbose:
#             print(f"X shape after in_conv: {x.shape}")

#         # Apply encoder layers
#         for encoder_layer in self.encoder_layers:
#             x = encoder_layer(x)
#             encoder_outputs.append(x)
#             if self.verbose:
#                 print(f"X shape after encoder layer: {x.shape}")

#         # upsample x to match the size decoder input:
#         # if self.verbose:
#         #     print(f"Upsampling x by 2..")
#         # x = self.upsample(x)

#         # Apply decoder layers
#         for index, decoder_layer in enumerate(self.decoder_layers):
#             # apply the decoder layer:
#             x = decoder_layer(x)
#             if self.verbose:
#                 print(f"X shape after decoder layer: {x.shape}")
            
#             # get the corresponding encoder output:
#             encoder_output = encoder_outputs.pop()
#             if self.verbose:
#                 print(f"Concatenating: x: {x.shape}, encoder_output: {encoder_output.shape}")
#             x = torch.cat((x, encoder_output), dim=1)
#             if self.verbose:
#                 print(f"X shape after concatenation: {x.shape}")

#         x = self.final_layer(x)
#         if self.verbose:
#             print(f"X shape after final layer: {x.shape}")

#         return x


In [None]:
""" Parts of the U-Net model """
#' Resource: https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py

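class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2, currently reduced to a single convolution => BN => ReLU"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        # With the second convolution disabled, the first one must output out_channels directly;
        # mapping to mid_channels would make the block return the wrong number of channels.
        # mid_channels is only needed again if the second convolution is re-enabled.
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
            # To re-enable, make the layers above output mid_channels and add the missing commas:
            # nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            # nn.BatchNorm2d(out_channels),
            # nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)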

class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
    
class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = (DoubleConv(n_channels, 64))
        self.down1 = (Down(64, 128))
        self.down2 = (Down(128, 256))
        self.down3 = (Down(256, 512))
        factor = 2 if bilinear else 1
        self.down4 = (Down(512, 1024 // factor))
        self.up1 = (Up(1024, 512 // factor, bilinear))
        self.up2 = (Up(512, 256 // factor, bilinear))
        self.up3 = (Up(256, 128 // factor, bilinear))
        self.up4 = (Up(128, 64, bilinear))
        self.outc = (OutConv(64, n_classes))

    def forward(self, x):
        # Match the input to the model's device and dtype (e.g. float16 after model.half()):
        param = next(self.parameters())
        x = x.to(device=param.device, dtype=param.dtype)

        x1 = self._run(self.inc, x)
        x2 = self._run(self.down1, x1)
        x3 = self._run(self.down2, x2)
        x4 = self._run(self.down3, x3)
        x5 = self._run(self.down4, x4)
        x = self._run(self.up1, x5, x4)
        x = self._run(self.up2, x, x3)
        x = self._run(self.up3, x, x2)
        x = self._run(self.up4, x, x1)
        logits = self._run(self.outc, x)
        return logits

    def _run(self, block, *inputs):
        # With checkpointing enabled, activations inside `block` are recomputed during
        # the backward pass instead of being stored, trading compute for memory.
        if getattr(self, 'checkpointing', False) and self.training:
            return torch.utils.checkpoint.checkpoint(block, *inputs, use_reentrant=False)
        return block(*inputs)

    def use_checkpointing(self):
        # torch.utils.checkpoint is a module, not a layer wrapper, so the blocks cannot be
        # replaced with torch.utils.checkpoint(block). Instead, set a flag that makes
        # forward() call torch.utils.checkpoint.checkpoint on each block.
        import torch.utils.checkpoint  # make sure the submodule is loaded
        self.checkpointing = True
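
Two rules set the spatial sizes and channel widths through this UNet. First, each `Down` halves the height and width with `MaxPool2d(2)`, using floor division for odd sizes. Second, each `Up` upsamples by 2 and then pads to the size of the matching encoder output. The helpers below trace these rules in plain Python for a square input, using the channel widths configured in `UNet.__init__`. `unet_shapes` and `up_padding` are hypothetical and not part of the notebook. They are a sketch for checking shapes by hand, not a replacement for running `torchsummary`.

```python
def unet_shapes(n_channels, n_classes, size, bilinear=False):
    """Trace (block name, channels, height/width) through the UNet for a square input.

    Assumes each Down halves the size with floor division (MaxPool2d(2)) and each Up
    restores the size of the matching encoder output (upsample + F.pad).
    """
    factor = 2 if bilinear else 1
    enc_channels = [64, 128, 256, 512, 1024 // factor]
    dec_channels = [512 // factor, 256 // factor, 128 // factor, 64]

    shapes = [("input", n_channels, size)]
    skip_sizes = []
    for i, channels in enumerate(enc_channels):
        if i > 0:
            size = size // 2  # MaxPool2d(2) inside Down
        skip_sizes.append(size)
        shapes.append(("inc" if i == 0 else f"down{i}", channels, size))

    for i, channels in enumerate(dec_channels):
        size = skip_sizes[-(i + 2)]  # Up pads to the size of the matching skip connection
        shapes.append((f"up{i + 1}", channels, size))

    shapes.append(("outc", n_classes, size))
    return shapes


def up_padding(skip_size, low_res_size):
    """(before, after) padding that Up.forward adds after upsampling by 2 along one axis."""
    diff = skip_size - 2 * low_res_size
    return diff // 2, diff - diff // 2


for name, channels, hw in unet_shapes(3, 17, 512):
    print(f"{name:<6} {channels:>5} x {hw} x {hw}")
```

For a 512x512 input, the bottleneck (`down4`) is 32x32. When an input size is not divisible by 16, `Up` pads asymmetrically: for a 25-pixel skip connection and a 12-pixel lower-resolution map, `up_padding(25, 12)` gives `(0, 1)`, so the extra pixel goes on the right or bottom, as in `Up.forward`.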

#### Limit training dataset to only one material for simplicity

In [None]:
limit_dataset = False # set to True to limit the dataset to one material for faster training

if limit_dataset:
    # limit the training dataset to only one material:
    train_loader.dataset.dataframe = train_loader.dataset.dataframe[train_loader.dataset.dataframe['material'] == 'Wood']
    print("Number of materials in the training set:", len(train_loader.dataset.dataframe['material'].unique()))
    print(f"Number of images in the training set: {len(train_loader.dataset)}")

    image_path = train_loader.dataset.dataframe.iloc[-1]['image_path']
    image_annotation_path = train_loader.dataset.dataframe.iloc[-1]['annotation_rgb_path']


    # from IPython.display import display, Image
    # display(Image(filename=image_path))
    # display(Image(filename=image_annotation_path))


### Loss Function

In [None]:

class DiceLoss(nn.Module):
    def __init__(self, smooth=1, class_weights=None):
        super(DiceLoss, self).__init__()
        self.smooth = smooth
        self.class_weights = class_weights  # Optional per-class weights (0 excludes a class)

    def forward(self, prediction, target):
        """
        Calculates the soft Dice coefficient per class and returns 1 - (weighted) mean Dice.

        Args:
            prediction: Tensor of raw logits (N, C, H, W).
            target: Tensor of ground truth labels (N, H, W) or (N, 1, H, W), where each
                    value is an integer representing the class label.
        """
        num_classes = prediction.size(1)

        # Softmax turns logits into per-class probabilities. Computing in float32 avoids
        # float16 overflow (and the resulting NaN/inf) when summing over H*W pixels.
        probs = torch.softmax(prediction.float(), dim=1)

        # One-hot encode the target: (N, H, W) -> (N, H, W, C) -> (N, C, H, W)
        if target.dim() == 4:
            target = target.squeeze(1)
        target = F.one_hot(target.long(), num_classes).permute(0, 3, 1, 2).float()

        # Intersection and union per image and per class, summed over H and W: (N, C)
        intersection = (probs * target).sum(dim=(2, 3))
        union = probs.sum(dim=(2, 3)) + target.sum(dim=(2, 3))

        numerator = 2.0 * intersection + self.smooth
        denominator = union + self.smooth
        dice_coefficient = numerator / denominator  # (N, C)

        if self.class_weights is not None:
            # Accepts weights shaped (C,) or (1, C, 1, 1); average the per-class Dice with them
            weights = self.class_weights.reshape(1, -1).to(device=dice_coefficient.device, dtype=dice_coefficient.dtype)
            dice_per_image = (dice_coefficient * weights).sum(dim=1) / weights.sum()
        else:
            dice_per_image = dice_coefficient.mean(dim=1)

        # Dice loss averaged over the batch
        dice_loss = 1 - dice_per_image.mean()

        return dice_loss
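
For a single image, the soft Dice score of class $c$ compares the predicted probabilities $p_{ic}$ with the one-hot labels $g_{ic}$ over pixels $i$, using a smoothing term $s$:

$$D_c = \frac{2\sum_i p_{ic} g_{ic} + s}{\sum_i p_{ic} + \sum_i g_{ic} + s}$$

The loss is one minus the mean of $D_c$ over classes, optionally weighted. The pure-Python reference below is a sketch for checking values by hand on tiny inputs. `soft_dice_loss` is a hypothetical helper, not the notebook's `DiceLoss`. It applies class weights as a weighted average of the per-class scores, so a weight of 0 removes that class from the average.

```python
def soft_dice_loss(probs, onehot, weights=None, smooth=1.0):
    """Soft Dice loss for one image.

    probs, onehot: per-class lists of pixel values (C x P), i.e. predicted
    probabilities and one-hot ground truth.
    weights: optional per-class weights; a weight of 0 excludes the class.
    Returns (loss, per-class Dice scores).
    """
    if weights is None:
        weights = [1.0] * len(probs)
    dices = []
    for p_c, g_c in zip(probs, onehot):
        intersection = sum(p * g for p, g in zip(p_c, g_c))
        union = sum(p_c) + sum(g_c)
        dices.append((2.0 * intersection + smooth) / (union + smooth))
    weighted_dice = sum(w * d for w, d in zip(weights, dices)) / sum(weights)
    return 1.0 - weighted_dice, dices


# Two classes (0 = background, 1 = damage) over four pixels; background is ignored.
probs = [[0.9, 0.8, 0.1, 0.2], [0.1, 0.2, 0.9, 0.8]]
onehot = [[1, 1, 0, 0], [0, 0, 1, 1]]
loss, dices = soft_dice_loss(probs, onehot, weights=[0.0, 1.0])
print(loss, dices)  # each class scores (2*1.7 + 1) / (4 + 1) = 0.88, so the loss is about 0.12

# A class absent from both prediction and label: the smoothing term gives s/s = 1.
absent_loss, absent_dices = soft_dice_loss([[0.0, 0.0]], [[0, 0]])
```

The second call shows why the smoothing term matters. Without it, a class missing from both the prediction and the label would produce $0/0$.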

In [None]:

# #' Shape of prediction: torch.Size([1, 17, 512, 512]), dtype: torch.float32
# #' Shape of target: torch.Size([1, 512, 512]), dtype: torch.int64
# # Sample tensors with the shapes used in training
# prediction = torch.rand(4, 17, 512, 512)  # 4 images, 17 classes, 512x512
# target = torch.randint(0, 17, (4, 512, 512))  # 4 images, 512x512 class labels
# prediction = prediction.to(DEVICE)
# target = target.to(DEVICE)

# print(f"Shape of prediction: {prediction.shape}, dtype: {prediction.dtype}")
# print(f"Shape of target: {target.shape}, dtype: {target.dtype}")

# # --- CASE 1: Without class weights ---
# dice_loss_no_weights = DiceLoss()
# loss_no_weights = dice_loss_no_weights(prediction, target)
# print("Loss without weights:", loss_no_weights.item())

# # --- CASE 2: With class weights (ignore class 0, the background) ---
# class_weights = torch.ones(17)  # one weight per class
# class_weights[0] = 0.0
# class_weights = class_weights.view(1, 17, 1, 1)  # Expand dimensions
# print("Class weights:", class_weights)
# print(f"class weights shape: {class_weights.shape}")
# dice_loss_with_weights = DiceLoss(class_weights=class_weights)
# loss_with_weights = dice_loss_with_weights(prediction, target)
# print("Loss with weights (ignoring class 0):", loss_with_weights.item())


#### Helper Functions

In [None]:
def train_epoch(model, train_loader, loss_fn, optimiser):
    model.train()
    train_loss = []
    for batch in train_loader:
        images = batch['image'].to(DEVICE)
        targets = batch['annotation'].to(DEVICE)
        outputs = model(images)
        loss = loss_fn(outputs, targets)
        
        # Backward pass:
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

        train_loss.append(loss.detach().cpu().numpy())

    return np.mean(train_loss)

def test_epoch(model, test_loader, loss_fn):
    model.eval()
    test_loss = []
    with torch.no_grad():
        for batch in test_loader:
            images = batch['image'].to(DEVICE)
            targets = batch['annotation'].to(DEVICE)

            outputs = model(images)
            loss = loss_fn(outputs, targets)

            test_loss.append(loss.detach().cpu().numpy())
    return np.mean(test_loss)
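
`train_epoch` and `test_epoch` return `np.mean` of the per-batch losses, which gives every batch equal weight. By default, `DataLoader` keeps the incomplete final batch (`drop_last=False`). When the number of images is not a multiple of the batch size, that smaller last batch therefore counts as much as a full one. The sketch below shows a sample-weighted average instead. `epoch_mean_loss` is a hypothetical helper; inside the loops, the batch sizes could be collected with `images.size(0)`.

```python
def epoch_mean_loss(batch_losses, batch_sizes):
    """Average per-batch mean losses, weighting each batch by its number of samples."""
    total = sum(loss * size for loss, size in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)


batch_losses = [0.5, 1.0]  # mean loss of each batch
batch_sizes = [4, 1]       # the last batch holds only one image

unweighted = sum(batch_losses) / len(batch_losses)     # what np.mean(train_loss) computes
weighted = epoch_mean_loss(batch_losses, batch_sizes)  # mean over the 5 images
print(f"unweighted: {unweighted:.2f}, weighted: {weighted:.2f}")
```

With many full batches the difference is small. It matters most when the dataset is small relative to the batch size, for example after limiting the training set to a single material.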


### Hyperparameters

#### Hyperparameters optimisation

In [None]:

torch.manual_seed(0)

num_classes = len(class_names) + 1  # 16 damage classes + 1 background class

in_channels = train_loader.dataset[-1]['image'].shape[-3]
out_channels = num_classes

# class_weights = None
class_weights = torch.ones(num_classes).to(DEVICE)
class_weights[0] = 0  # ignore the background class
class_weights = class_weights.float()
class_weights = class_weights.view(1, num_classes, 1, 1)  # Expand dimensions
# print("Class weights shape:", class_weights.shape)

hyperparameters = { 
    "lr": 1e-3,
    "momentum": 0.2,
    "smooth": 1
}


In [None]:
optimise_parameters = True

if optimise_parameters:

    from ax.service.managed_loop import optimize
    from ax.plot.trace import optimization_trace_single_method
    from ax.utils.notebook.plotting import render, init_notebook_plotting

    def train_evaluate(parameterisation):
        net = UNet(in_channels, out_channels).to(DEVICE)
        # net.enable_verbose(True)
        optimiser = optim.SGD(net.parameters(), lr=parameterisation.get('lr'), momentum=parameterisation.get('momentum'))
        # ignore the background class when calculating the loss:
        loss_fn = DiceLoss(class_weights=class_weights).to(DEVICE)

        try:
            train_epoch(model=net, train_loader=train_loader, loss_fn=loss_fn, optimiser=optimiser)
        except Exception as e:
            print(f"Error happened: {e}")

        # Evaluate on the validation set with BatchNorm in inference mode:
        net.eval()
        prediction_dice_error = []
        with torch.no_grad():  # no gradients are needed during validation
            for batch in val_loader:
                images = batch['image'].to(DEVICE)
                targets = batch['annotation'].to(DEVICE)
                outputs = net(images)
                loss = loss_fn(outputs, targets)
                prediction_dice_error.append(loss.detach().cpu().numpy())

        prediction_dice_error_mean = np.mean(prediction_dice_error)
        print(f"Parameters: {parameterisation}, Validation Dice loss: {prediction_dice_error_mean}")

        if prediction_dice_error_mean < 1:
            os.makedirs('intermediate_models', exist_ok=True)
            torch.save(net.state_dict(), f'intermediate_models/unet_{prediction_dice_error_mean}_prec.pth')

        return {"Loss": (prediction_dice_error_mean, 0.0)}

    print("Optimising hyperparameters..")
    parameters = [
        {"name": "lr", "type": "range", "bounds": [1e-6, 1e-1], "log_scale": True},
        # SGD momentum must stay below 1, otherwise past gradients never decay:
        {"name": "momentum", "type": "range", "bounds": [0.0, 0.99]}
    ]
    total_trials = 1

    best_parameters, values, experiment, optimiser_model = optimize(
        parameters=parameters,
        evaluation_function=train_evaluate,
        objective_name="Loss",
        minimize=True,
        total_trials=total_trials
    )

    hyperparameters["lr"] = round(best_parameters['lr'], 5)
    hyperparameters["momentum"] = round(best_parameters['momentum'], 2)

    print("\nFinished optimising hyperparameters.")
    print("Best parameters:")
    for key, value in best_parameters.items():
        print(f"\t{key}: {value}")

    init_notebook_plotting(offline=True)
    # Validation Dice loss (%) of each trial, then the best (lowest) value seen so far:
    trial_losses = np.array([[trial.objective_mean * 100 for trial in experiment.trials.values()]])
    best_so_far = np.minimum.accumulate(trial_losses, axis=1)
    data = optimization_trace_single_method(best_so_far, title="Optimization trace", ylabel="Validation Dice loss (%)")
    render(data)

else:
    print("Using default hyperparameters..")
    for key, value in hyperparameters.items():
        print(f"{key}: {value}")
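
Setting `"log_scale": True` for the learning rate makes the search treat each decade of `[1e-6, 1e-1]` equally. On a linear scale, about 90% of uniformly drawn values would fall in `[1e-2, 1e-1]`. The sketch below illustrates the idea with plain random search over the same bounds. This is a swapped-in technique for illustration, not what Ax does: Ax's `optimize` chooses trials with its own generation strategy, typically quasi-random points followed by Bayesian optimisation. `toy_objective` is a hypothetical stand-in for `train_evaluate`. Momentum is sampled from `[0, 0.99]` because SGD momentum is normally kept below 1; at 1 or above, past gradients are never decayed and the updates tend to diverge.

```python
import math
import random


def log_uniform(rng, low, high):
    """Sample uniformly in log10 space, so each decade of [low, high] is equally likely."""
    return 10 ** rng.uniform(math.log10(low), math.log10(high))


def random_search(objective, num_trials, seed=0):
    """Minimise objective(params) by random sampling over the notebook's search space.

    Returns the best parameters and the best-so-far loss after each trial.
    """
    rng = random.Random(seed)
    best_params, best_loss, trace = None, math.inf, []
    for _ in range(num_trials):
        params = {
            "lr": log_uniform(rng, 1e-6, 1e-1),
            "momentum": rng.uniform(0.0, 0.99),
        }
        loss = objective(params)
        if loss < best_loss:
            best_params, best_loss = params, loss
        trace.append(best_loss)
    return best_params, trace


def toy_objective(params):
    # Hypothetical stand-in: pretends lr = 1e-3 and momentum = 0.9 are optimal.
    return abs(math.log10(params["lr"]) + 3) + abs(params["momentum"] - 0.9)


best, trace = random_search(toy_objective, num_trials=50)
print(best, trace[-1])
```

The returned `trace` is the best-so-far loss after each trial, which is the non-increasing curve an optimisation trace usually shows for a minimisation problem.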


In [None]:

print(f"Initialising model with {in_channels} input channels and {out_channels} output channels.")
model = UNet(in_channels, out_channels, bilinear=False).to(DEVICE)
model = model.half()  # cast the weights to float16 (half precision) to reduce memory use

optimiser = optim.SGD(model.parameters(), lr=hyperparameters['lr'], momentum=hyperparameters['momentum'])

loss_fn = DiceLoss(smooth=hyperparameters['smooth'], class_weights=class_weights).to(DEVICE)
# loss_fn = nn.MSELoss().to(DEVICE)
print("########## Finished initialising model and training parameters. ##########")
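
After `model.half()`, the model's parameters and outputs are float16, whose largest finite value is 65504. A Dice loss computed directly in float16 can therefore overflow when a per-class sum runs over all pixels of a 512x512 image (262,144 values). This is one plausible source of NaN/inf values. The sketch below uses the half-precision format (`'e'`) of Python's `struct` module to show the limit. Note that PyTorch overflows to `inf` in this situation, whereas `struct` raises `OverflowError`. Casting the predictions to float32 before summing (e.g. `prediction.float()`) avoids the problem.

```python
import struct


def to_float16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value.

    Raises OverflowError if x is beyond the float16 range.
    """
    return struct.unpack("<e", struct.pack("<e", x))[0]


num_pixels = 512 * 512      # pixels in one 512x512 mask
print(to_float16(65504.0))  # 65504.0, the largest finite float16

try:
    to_float16(float(num_pixels))
except OverflowError as err:
    print(f"{num_pixels} does not fit in float16: {err}")
```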

### Training

In [None]:
test_model = False # set to True to test the model correctness before training

if test_model:
    from torchsummary import summary
    # model.enable_verbose(True)
    image_size = train_loader.dataset[-1]['image'].shape[-1] # square images so only one dimension is needed
    summary(model, (in_channels, image_size, image_size))
    # model.enable_verbose(False)


#### Train model

In [None]:

train_model_again = True # set to True to train the model again.

training_performance = {'train_loss': []}
validation_performance = {'val_loss': []}

if train_model_again:
    num_epochs = 200
    for epoch in range(num_epochs):
        train_loss = train_epoch(model, train_loader, loss_fn, optimiser)
        val_loss = test_epoch(model, val_loader, loss_fn)

        if (epoch < 10) or ((epoch + 1) % 5 == 0):
            print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

        training_performance['train_loss'].append(train_loss)
        validation_performance['val_loss'].append(val_loss)

    print("Training complete!")
    # Plot the training and validation losses
    plt.figure(figsize=(20, 5))
    
    plt.plot(training_performance['train_loss'], label='Train Loss')
    plt.plot(validation_performance['val_loss'], label='Val Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training and Validation Losses')
    plt.legend()

    # Save the plot (creating the output folder if needed):
    os.makedirs('img/training_results', exist_ok=True)
    plt.savefig(f'img/training_results/losses_lr={hyperparameters["lr"]}_momentum={hyperparameters["momentum"]}.png')

    plt.show()
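
The loop keeps the weights from the last of the 200 epochs, while `validation_performance['val_loss']` records the loss of every epoch. Finding the epoch with the lowest validation loss shows whether later epochs overfit. If the best epoch is much earlier, keeping a copy of the weights whenever the validation loss improves (for example with `copy.deepcopy(model.state_dict())`) would preserve that model. Below is a minimal helper for the first step; `best_epoch` is hypothetical and not part of the notebook.

```python
def best_epoch(val_losses):
    """Return the 1-based epoch with the lowest validation loss, and that loss."""
    index = min(range(len(val_losses)), key=val_losses.__getitem__)
    return index + 1, val_losses[index]


epoch, loss = best_epoch([0.92, 0.71, 0.68, 0.70])  # illustrative values, not real results
print(f"Lowest validation loss {loss:.2f} at epoch {epoch}")
```

In the notebook, this would be called as `best_epoch(validation_performance['val_loss'])`.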


In [None]:
save_model = True # set to True to save the model to github after training.

# torch.save(model.state_dict(), 'unet.pth')

if save_model:
    # Save model locally:
    torch.save(model.state_dict(), 'unet.pth')
    # Save model to github:
    !git add unet.pth
    !git commit -m "Add trained UNet model"
    !git push

### Load model from github

In [None]:
load_model = False # set to True to download the trained model from GitHub.

if load_model:

    import requests

    def download_model(url, model_path):
        r = requests.get(url, allow_redirects=True)
        if r.status_code == 200:
            # Overwrite the local model file if it already exists:
            with open(model_path, 'wb') as f:
                f.write(r.content)
            print(f"Model downloaded to {model_path}")
            return model_path
        else:
            print(f"Failed to download model. Status code: {r.status_code}")
            return None

    # Update the commit hash in the URL to download a different version of the model:
    model_url = 'https://raw.githubusercontent.com/JalalSayed1/Image-Damage-Classification/3baa846923307ae573f4fbecd7c8f00fc269fad4/unet.pth'

    model_path = download_model(model_url, 'unet.pth')
    if model_path:
        print("Model downloaded successfully!")
    else:
        print("Failed to download model.")
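
The download URL pins a specific commit of the repository, following GitHub's raw-file pattern `https://raw.githubusercontent.com/<owner>/<repo>/<ref>/<path>`. Building the URL from its parts makes switching to another commit hash, branch or tag a single-argument change. `raw_github_url` is a hypothetical helper.

```python
def raw_github_url(owner, repo, ref, path):
    """Build a raw.githubusercontent.com URL for a file at a commit hash, branch or tag."""
    return f"https://raw.githubusercontent.com/{owner}/{repo}/{ref}/{path}"


model_url = raw_github_url(
    owner="JalalSayed1",
    repo="Image-Damage-Classification",
    ref="3baa846923307ae573f4fbecd7c8f00fc269fad4",  # commit that holds the trained weights
    path="unet.pth",
)
print(model_url)
```

Passing a branch name such as `master` as `ref` always fetches the latest committed weights. A commit hash keeps the downloaded model, and therefore the results, reproducible.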


### Test Model performance

In [None]:
import torch
torch.cuda.empty_cache()

In [None]:
def map_class_ids_to_rgb(input_tensor, id_to_rgb, id2label):
    """Convert class-ID maps [batch, 1, H, W] into RGB images [batch, 3, H, W].

    Also prints the number of pixels assigned to each class.
    """
    batch_size, _, H, W = input_tensor.shape

    # Allocate the RGB output on the same device as the class IDs, so that the boolean
    # class masks below can index it:
    rgb_image = torch.zeros((batch_size, 3, H, W), dtype=torch.uint8, device=input_tensor.device)

    for batch_idx in range(batch_size):
        # The class IDs are stored in the single channel of each image:
        class_ids_2d = input_tensor[batch_idx, 0, :, :]

        classes_freq = {}  # number of pixels per class

        for class_id, rgb in id_to_rgb.items():
            # Class ID 17 is looked up as pixel value 255:
            if class_id == 17:
                class_id = 255

            class_mask = (class_ids_2d == class_id)
            classes_freq[class_id] = class_mask.sum().item()

            # Paint every pixel of this class with its colour, one channel at a time:
            for channel, color_value in enumerate(rgb):
                rgb_image[batch_idx, channel, class_mask] = color_value

        for k, v in sorted(classes_freq.items(), key=lambda item: item[1], reverse=True):
            print(f"Class ID: {k:<5} \t Damage {id2label[k]:<10} \t Num of pixels: {v:<10}")
        print()

    return rgb_image
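
`map_class_ids_to_rgb` loops over every class and every colour channel. An alternative is a colour lookup table: an array whose row `k` holds the colour for pixel value `k`, indexed directly with the class-ID map. The NumPy sketch below assumes the IDs are integers in `[0, 255]`. The helper names and the three-colour palette are hypothetical and are not the dataset's colours. The function above looks up class 17 as pixel value 255; in a lookup table, that corresponds to storing class 17's colour in row 255.

```python
import numpy as np


def build_palette(id_to_rgb, size=256):
    """Lookup table: row k holds the RGB colour for pixel value k (unlisted values stay black)."""
    palette = np.zeros((size, 3), dtype=np.uint8)
    for class_id, rgb in id_to_rgb.items():
        palette[class_id] = rgb
    return palette


def class_ids_to_rgb(class_ids, palette):
    """Map an integer class-ID array of shape (H, W) to an RGB image of shape (H, W, 3)."""
    return palette[class_ids]


# Hypothetical palette: background black, one class red, label value 255 blue.
palette = build_palette({0: (0, 0, 0), 1: (255, 0, 0), 255: (0, 0, 255)})
ids = np.array([[0, 1], [255, 1]])
rgb = class_ids_to_rgb(ids, palette)
print(rgb.shape)  # (2, 2, 3)
```

The result is already in `(H, W, 3)` order for `plt.imshow`. PyTorch tensors support the same integer-array indexing, so a palette tensor on the GPU could be indexed with the predicted class IDs in the same way.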


In [None]:
from torchvision.transforms.functional import to_pil_image

model.load_state_dict(torch.load('unet.pth', map_location=DEVICE))
model.eval()
print("Model loaded successfully!")


N = min(5, len(test_loader.dataset))
print(f"Number of examples to visualize: {N}\n")

# print(f"example image shape: {test_loader.dataset[-1]['image'].shape}")
test_batches = [test_loader.dataset[i] for i in range(N)]

class_weights = torch.ones(num_classes).to(DEVICE)
class_weights[0] = 0  # ignore the background class
class_weights = class_weights.float()
class_weights = class_weights.view(1, num_classes, 1, 1)  # Expand dimensions
loss_fn = DiceLoss(class_weights=class_weights).to(DEVICE)

# empty the predictions folder (creating it if needed) to save the new predictions:
predictions_folder = 'img/predictions/test_results'
os.makedirs(predictions_folder, exist_ok=True)
for file in os.listdir(predictions_folder):
    file_path = os.path.join(predictions_folder, file)
    if os.path.isfile(file_path):
        os.unlink(file_path) # delete the file

for index, batch in enumerate(test_batches):
    test_image = batch['image'].unsqueeze(0).to(DEVICE)
    test_annotation = (batch['annotation'].squeeze()).to(DEVICE)
    test_annotation_rgb = batch['annotation_rgb'].squeeze()
    with torch.no_grad():  # inference only, no gradients needed
        logits = model(test_image)  # (1, C, H, W)

    # print(f"Shapes:")
    # print(f"\ttest_image: {test_image.shape}")
    # print(f"\ttest_annotation: {test_annotation.shape}")
    # print(f"\ttest_annotation_rgb: {test_annotation_rgb.shape}")
    # print(f"\tlogits: {logits.shape}")

    # Most likely class per pixel (argmax over the class dimension), as a (1, 1, H, W) class-ID map:
    predicted_annotation = torch.argmax(logits, dim=1, keepdim=True)
    # predicted_annotation_rgb = indices_to_rgb(predicted_annotation, id_to_rgb)

    fig, axes = plt.subplots(1, 5, figsize=(20, 5))

    # The Dice similarity is computed from the logits, as during training:
    similarity = 1 - loss_fn(logits, test_annotation.unsqueeze(0)).item()
    plt.suptitle(f"{index}. Annotation Similarity: {similarity * 100:.1f}%", fontsize=24, ha='center', va='top', y=1.1)

    titles = ['Image', 'True Annotation', 'True Overlay', 'Predicted Anno.', 'Predicted Overlay']
    for ax, title in zip(axes, titles):
        ax.set_title(title, fontsize=24)

    test_image = denormalize(test_image.squeeze().cpu().numpy().transpose((1, 2, 0)), mean, std)  # C, H, W -> H, W, C
    
    test_annotation_rgb = test_annotation_rgb.numpy().transpose((1, 2, 0))  # C, H, W -> H, W, C

    predicted_annotation_rgb = map_class_ids_to_rgb(predicted_annotation, id_to_rgb, id2label).squeeze(0)
    # C, H, W -> H, W, C
    predicted_annotation_rgb = predicted_annotation_rgb.cpu().numpy().transpose((1, 2, 0))
    
    plt.imsave(os.path.join(predictions_folder, f"{index}.png"), predicted_annotation_rgb)

    axes[0].imshow(test_image)
    axes[0].axis('off')

    axes[1].imshow(test_annotation_rgb)
    axes[1].axis('off')

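    # Transparent where the true annotation is background (black), opaque elsewhere. The alpha
    # channel must use the colours' scale: 0-255 for integer images, 0-1 for float images.
    alpha_channel = np.all(test_annotation_rgb == [0, 0, 0], axis=-1)
    opaque = 255 if np.issubdtype(test_annotation_rgb.dtype, np.integer) else 1.0
    test_annotation_rgba = np.dstack((test_annotation_rgb, np.where(alpha_channel, 0, opaque))).astype(test_annotation_rgb.dtype)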
    axes[2].imshow(test_image)
    axes[2].imshow(test_annotation_rgba)
    axes[2].axis('off')

    axes[3].imshow(predicted_annotation_rgb)
    axes[3].axis('off')

    # predicted_annotation_rgb is uint8, so the alpha channel uses the 0-255 scale:
    predicted_annotation_rgba = np.dstack((predicted_annotation_rgb, np.where(alpha_channel, 0, 255))).astype(np.uint8)
    axes[4].imshow(test_image)
    axes[4].imshow(predicted_annotation_rgba)
    axes[4].axis('off')
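
The overlays above stack an alpha channel onto the annotation colours. The result therefore depends on matplotlib reading the alpha value on the same scale as the colours: 0-1 for floats, 0-255 for integers. Blending the colours into the image explicitly avoids that dependency and lets the overlay strength be tuned. The sketch below assumes the image is a float array in `[0, 1]`, the mask is a `uint8` RGB array, and background pixels are black. `overlay_mask` is a hypothetical helper.

```python
import numpy as np


def overlay_mask(image, mask_rgb, background=(0, 0, 0), alpha=0.5):
    """Blend mask colours onto an image wherever the mask is not the background colour.

    image: float array (H, W, 3) in [0, 1]; mask_rgb: uint8 array (H, W, 3) in [0, 255].
    Returns a new float array; the input image is left unchanged.
    """
    colours = mask_rgb.astype(np.float64) / 255.0
    is_foreground = np.any(mask_rgb != np.array(background, dtype=mask_rgb.dtype), axis=-1)
    blended = image.copy()
    blended[is_foreground] = (1 - alpha) * image[is_foreground] + alpha * colours[is_foreground]
    return blended


# A 1x2 grey image: the left pixel is background, the right pixel is labelled red.
image = np.full((1, 2, 3), 0.5)
mask = np.array([[[0, 0, 0], [255, 0, 0]]], dtype=np.uint8)
blended = overlay_mask(image, mask, alpha=0.5)
print(blended)
```

The blended array can be shown with a single `plt.imshow(blended)` call. Raising `alpha` towards 1 makes the labelled regions more prominent.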

'''
{'Material loss': ('#1CE6FF', (0, 0, 0), 'Black'),
 'Peel': ('#FF34FF', (28, 230, 255), 'Cyan / Aqua'),
 'Dust': ('#FF4A46', (255, 52, 255), 'Magenta / Fuchsia'),
 'Scratch': ('#008941', (255, 74, 70), 'Red'),
 'Hair': ('#006FA6', (0, 137, 65), 'Teal'),
 'Dirt': ('#A30059', (0, 111, 166), 'Teal'),
 'Fold': ('#FFA500', (163, 0, 89), 'Purple'),
 'Writing': ('#7A4900', (255, 165, 0), 'Yellow'),
 'Cracks': ('#0000A6', (122, 73, 0), 'Olive'),
 'Staining': ('#63FFAC', (0, 0, 166), 'Navy'),
 'Stamp': ('#004D43', (99, 255, 172), 'Silver'),
 'Sticker': ('#8FB0FF', (0, 77, 67), 'Teal'),
 'Puncture': ('#997D87', (143, 176, 255), 'Silver'),
 'Background': ('#5A0007', (153, 125, 135), 'Gray'),
 'Burn marks': ('#809693', (128, 150, 147), 'Gray'),
 'Lightleak': ('#f6ff1b', (246, 255, 27), 'Yellow')}
'''
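
The colour table above lists a hex code, an RGB tuple and a colour name for each class. Converting the hex codes to RGB makes each row easy to check. For 'Burn marks' and 'Lightleak', the tuple matches the hex code. However, '#1CE6FF' on the 'Material loss' row converts to (28, 230, 255), which is the tuple listed on the 'Peel' row. `hex_to_rgb` is a hypothetical helper.

```python
def hex_to_rgb(hex_colour):
    """Convert '#RRGGBB' (case-insensitive) to an (R, G, B) tuple of ints."""
    value = hex_colour.lstrip("#")
    return tuple(int(value[i:i + 2], 16) for i in (0, 2, 4))


for name, hex_colour in [("Material loss", "#1CE6FF"), ("Burn marks", "#809693"), ("Lightleak", "#f6ff1b")]:
    print(f"{name:<14} {hex_colour} -> {hex_to_rgb(hex_colour)}")
```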

In [None]:
# !pip freeze > requirements.txt