<a href="https://colab.research.google.com/github/ThOpaque/Binary_Classifier_DeepLearning/blob/main/VGG16_FCN8_CamVid_TF_to_PT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import os
import zipfile
import PIL.Image, PIL.ImageFont, PIL.ImageDraw
import numpy as np

import torch
import torchvision
!pip install torch-summary
from torchsummary import summary

from matplotlib import pyplot as plt
import seaborn as sns


import tensorflow as tf

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torch-summary
  Downloading torch_summary-1.4.5-py3-none-any.whl (16 kB)
Installing collected packages: torch-summary
Successfully installed torch-summary-1.4.5


In [2]:
# download the dataset (zipped file)
!gdown --id 0B0d9ZiqAgFkiOHR1NTJhWVJMNEU -O /tmp/fcnn-dataset.zip 

Downloading...
From: https://drive.google.com/uc?id=0B0d9ZiqAgFkiOHR1NTJhWVJMNEU
To: /tmp/fcnn-dataset.zip
100% 126M/126M [00:01<00:00, 120MB/s]


In [4]:
# extract the downloaded dataset to a local directory: /tmp/fcnn
local_zip = '/tmp/fcnn-dataset.zip'
zip_ref = zipfile.ZipFile(local_zip, 'r')
zip_ref.extractall('/tmp/fcnn')
zip_ref.close()

In [5]:
# pixel labels in the video frames


In [None]:
### TODO
def map_filename_to_image_and_mask(t_filename, a_filename, height=224, width=224):
  '''
  Preprocesses the dataset by:
    * resizing the input image and label maps
    * normalizing the input image pixels
    * reshaping the label maps from (height, width, 1) to (height, width, 12)

  Args:
    t_filename (string) -- path to the raw input image
    a_filename (string) -- path to the raw annotation (label map) file
    height (int) -- height in pixels to resize to
    width (int) -- width in pixels to resize to

  Returns:
    image (tensor) -- preprocessed image
    annotation (tensor) -- preprocessed annotation
  '''

  # Convert image and mask files to tensors 
  img_raw = PIL.Image.open(t_filename)
  img_raw = torchvision.transforms.ToTensor()(img_raw)
  anno_raw = PIL.Image.open(a_filename)
  anno_raw = torchvision.transforms.ToTensor()(anno_raw)
  #img_raw = tf.io.read_file(t_filename)
  #anno_raw = tf.io.read_file(a_filename)
  #image = tf.image.decode_jpeg(img_raw)
  #annotation = tf.image.decode_jpeg(anno_raw)
 
  # Resize image and segmentation mask
  image = tf.image.resize(image, (height, width,))
  annotation = tf.image.resize(annotation, (height, width,))
  image = tf.reshape(image, (height, width, 3,))
  annotation = tf.cast(annotation, dtype=tf.int32)
  annotation = tf.reshape(annotation, (height, width, 1,))
  stack_list = []

  # Reshape segmentation masks
  for c in range(len(class_names)):
    mask = tf.equal(annotation[:,:,0], tf.constant(c))
    stack_list.append(tf.cast(mask, dtype=tf.int32))
  
  annotation = tf.stack(stack_list, axis=2)

  # Normalize pixels in the input image
  image = image/127.5
  image -= 1

  return image, annotation

In [10]:
# show folders inside the dataset you downloaded
!ls /tmp/fcnn/dataset1

annotations_prepped_test   images_prepped_test
annotations_prepped_train  images_prepped_train


In [54]:
class CamVid_Dataset(torch.utils.data.Dataset):
    def __init__(self, image_dir, anno_map_dir, height=224, width=224):      
        self.height = height
        self.width = width

        # generates the lists of image and label map paths
        image_file_list = os.listdir(image_dir)
        anno_map_file_list = os.listdir(anno_map_dir)
        self.image_paths = [os.path.join(image_dir, fname) for fname in image_file_list]
        self.anno_map_paths = [os.path.join(anno_map_dir, fname) for fname in anno_map_file_list]

        self.class_names = ['sky', 'building','column/pole', 'road', 'side walk', 
                            'vegetation', 'traffic light', 'fence', 'vehicle', 
                            'pedestrian', 'byciclist', 'void']

    def __len__(self):
        return len(self.image_paths)

    def preprocess(self, img_path, anno_path):
        # reading
        image = torchvision.io.read_image(img_path)
        annotation = torchvision.io.read_image(anno_path)

        # resizing
        image = torchvision.transforms.Resize((self.height, self.width))(image)
        annotation = torchvision.transforms.Resize((self.height, self.width))(annotation).to(torch.int32)

        # Reshape segmentation masks
        stack_list = []
        for c in range(len(self.class_names)):
            mask = torch.eq(b[0, :, :], torch.as_tensor(c, dtype=torch.int32))
            stack_list.append(mask)
        annotation = torch.stack(stack_list, dim=2)
        
        # Normalizing between -1 and 1
        image = image / 127.5
        image -= 1

        return image, annotation


    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.to_list()

        image, annotation = self.preprocess(self.image_paths[idx], self.anno_map_paths[idx])

        return image, annotation

In [59]:
BATCH_SIZE = 64

training_dataset = CamVid_Dataset('/tmp/fcnn/dataset1/images_prepped_train/','/tmp/fcnn/dataset1/annotations_prepped_train/')
validation_dataset = CamVid_Dataset('/tmp/fcnn/dataset1/images_prepped_test/','/tmp/fcnn/dataset1/annotations_prepped_test/')

training_dataloader = torch.utils.data.DataLoader(training_dataset, batch_size=BATCH_SIZE, shuffle=True)
validation_dataloader = torch.utils.data.DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [60]:
# generate a list that contains one color for each class
colors = sns.color_palette(None, len(class_names))

# print class name - normalized RGB tuple pairs
# the tuple values will be multiplied by 255 in the helper functions later
# to convert to the (0,0,0) to (255,255,255) RGB values you might be familiar with
for class_name, color in zip(class_names, colors):
  print(f'{class_name} -- {color}')

sky -- (0.12156862745098039, 0.4666666666666667, 0.7058823529411765)
building -- (1.0, 0.4980392156862745, 0.054901960784313725)
column/pole -- (0.17254901960784313, 0.6274509803921569, 0.17254901960784313)
road -- (0.8392156862745098, 0.15294117647058825, 0.1568627450980392)
side walk -- (0.5803921568627451, 0.403921568627451, 0.7411764705882353)
vegetation -- (0.5490196078431373, 0.33725490196078434, 0.29411764705882354)
traffic light -- (0.8901960784313725, 0.4666666666666667, 0.7607843137254902)
fence -- (0.4980392156862745, 0.4980392156862745, 0.4980392156862745)
vehicle -- (0.7372549019607844, 0.7411764705882353, 0.13333333333333333)
pedestrian -- (0.09019607843137255, 0.7450980392156863, 0.8117647058823529)
byciclist -- (0.12156862745098039, 0.4666666666666667, 0.7058823529411765)
void -- (1.0, 0.4980392156862745, 0.054901960784313725)
