# UNet for Vaihingen dataset

## 1. Setup

### 1.1 Install dependencies

Please re-run this as we now need another package (for downloading the data).

In [None]:
# Install the notebook's third-party dependencies into the interpreter that
# runs this kernel (sys.executable), using IPython `!` shell escapes.
import sys
!{sys.executable} -m pip install torch torchvision
!{sys.executable} -m pip install matplotlib
# Nice progress bar
!{sys.executable} -m pip install tqdm                      

# Albumentation for data augmentation (installed from the repository's default branch)
!{sys.executable} -m pip install -U git+https://github.com/albumentations-team/albumentations

# for downloading files from Google drive
!{sys.executable} -m pip install gdown

# OpenCV, upgraded in place (imported as cv2 in later cells)
!{sys.executable} -m pip install --upgrade opencv-python
# NOTE(review): `torchsummary` (section 3.2) and scikit-learn / pandas
# (section 6) are imported later but not installed here; presumably they are
# preinstalled on Colab — confirm.

# Mount Google Drive (Colab only) under /content/gdrive; later cells read and
# write model checkpoints below gdrive/MyDrive.
import glob
import os
from google.colab import drive
drive.mount('/content/gdrive')

### 1.2 Check if GPU available

It is highly recommended to have access to a GPU, as training the segmentation model is computationally heavy.

In [None]:
import torch

# Print whether PyTorch can see a CUDA-capable GPU; several later cells
# hard-code the 'cuda' device.
print(torch.cuda.is_available())

True


### 1.3 Random seed

In [None]:
import random

import numpy as np
import torch

# Fix every random number generator the pipeline may draw from so runs are
# reproducible: PyTorch (CPU and all GPUs) for weight initialisation and
# DataLoader shuffling, plus Python's `random` and NumPy, which the
# augmentation library may use for its random transforms.
seed = 324533
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

## 2. Dataset

For this project the [ISPRS Vaihingen](https://www2.isprs.org/commissions/comm2/wg4/benchmark/2d-sem-label-vaihingen/) semantic segmentation dataset will be used.
This is a set of fully-labelled aerial image / segmentation-mask pairs with a 9 cm ground resolution and six land-cover classes: Impervious surfaces, Buildings, Low vegetation, Tree, Car, Clutter. The images come from a large aerial orthophoto mosaic over the town of Vaihingen in Germany, which was divided into 33 patches, some of which are available with ground truth. As those patches are quite large, they have been split into multiple 512x512 tiles.

In [None]:
# Download the pre-tiled Vaihingen archive from Google Drive (file id below)
# and extract it into the working directory; the extracted dataset is
# expected under `data_root`.
!gdown --id 1S8oCD1fK4_l2L6lYwuHOHcNtKhRYg19b

!tar -xf vaihingen_512x512_full.tar.gz
data_root = 'dataset_512x512_full'


Downloading...
From: https://drive.google.com/uc?id=1S8oCD1fK4_l2L6lYwuHOHcNtKhRYg19b
To: /content/vaihingen_512x512_full.tar.gz
100% 853M/853M [00:15<00:00, 54.2MB/s]


### 2.1 Transformations for data augmentation & dataset initialisation

As the dataset is fairly small, data augmentation is applied. Since each sample is a 5-channel image paired with a segmentation mask that must undergo the same geometric transformation, the basic torchvision augmentations cannot be used directly. The [Albumentations](https://albumentations.ai/) library, which handles multi-channel images and masks jointly, is used instead.

Three sets of transformations exist:
* Training: a random 400x400 crop followed by vertical flips, horizontal flips, 90° rotations and transpositions, each applied with a probability of 50 %
* Validation: only a crop to 400x400 pixels, so that sizes match the training input
* Report analysis: no augmentation at all

In [None]:
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
from albumentations import VerticalFlip, Transpose, HorizontalFlip, RandomRotate90, RandomCrop

# Spatial size of the patches fed to the network during training/validation.
crop_height = 400
crop_width = 400

# Training: random crop, then independent vertical flip, 90-degree rotation,
# transposition and horizontal flip (each with p=0.5), then tensor conversion.
train_transform = A.Compose(
    [
        RandomCrop(height=crop_height, width=crop_width),
        VerticalFlip(p=0.5),
        RandomRotate90(p=0.5),
        Transpose(p=0.5),
        HorizontalFlip(p=0.5),
        ToTensorV2()
    ]
)

# Validation: a deterministic centre crop to the training size. A random crop
# would evaluate a different region of each tile every epoch, making the
# validation loss/accuracy noisy and hard to compare between epochs.
test_transform = A.Compose(
    [
        A.CenterCrop(height=crop_height, width=crop_width),
        ToTensorV2()
    ]
)

# Report/visualisation: full tiles, only converted to tensors.
report_transform = A.Compose(
    [
        ToTensorV2()
    ]
)

Dataset initialisation

In [None]:
import os
import torch
from torch.utils.data import dataset
from torch.utils.data import DataLoader
import torchvision.transforms as T      # transformations that can be used e.g. for data conversion or augmentation
import numpy as np
from PIL import Image


class VaihingenDataset(dataset.Dataset):
    '''
        Custom Dataset class that loads the 5-channel Vaihingen tiles
        (IR-R-G image + DSM + nDSM) and their ground-truth segmentation
        masks from a directory.

        ``data_root`` must contain a ``fileList.csv`` (a header line, then one
        row per sample: the label file name followed by the image file names)
        plus ``labels/`` and ``images/`` sub-directories holding those files.

        Each item is an ``(image, mask)`` pair: ``image`` is a float tensor
        with the 5 channels first and ``mask`` a tensor of class indices.
        When augmentation visualisation is enabled, lists of such tensors are
        returned instead (``N_VISU_DRAWS`` augmented draws when a transform is
        set, a single un-augmented sample otherwise).
    '''

    # image statistics, calculated in advance as averages across the full
    # training data set. Entry i normalises the i-th image file of a sample;
    # the DSM and nDSM entries are plain scalars (not 1-tuples).
    IMAGE_MEANS = (
        (121.03431026287558, 82.52572736507886, 81.92368178210943),     # IR-R-G tiles
        (285.34753853934154),                                           # DSM
        (31.005143030549313)                                            # nDSM
    )
    IMAGE_STDS = (
        (54.21029197978022, 38.434924159900554, 37.040640374137475),    # IR-R-G tiles
        (6.485453035150256),                                            # DSM
        (36.040236155124326)                                            # nDSM
    )

    # label class names (position == class index in the masks)
    LABEL_CLASSES = (
        'Impervious', 'Buildings', 'Low Vegetation', 'Tree', 'Car', 'Clutter'
    )

    # number of stacked input channels (3 IR-R-G + DSM + nDSM)
    N_CHANNELS = 5

    # number of augmented draws returned per item in visualisation mode
    N_VISU_DRAWS = 25

    def __init__(self, data_root, transform=None, augmentation_visu=None):
        '''
            Dataset class constructor: reads the list of samples from
            ``<data_root>/fileList.csv``.

            Args:
                data_root: split directory (e.g. ``<root>/train``).
                transform: optional albumentations-style callable, invoked as
                    ``transform(image=..., mask=...)`` and returning a dict
                    with ``'image'`` and ``'mask'`` entries.
                augmentation_visu: when ``None`` (default) the notebook-level
                    global ``augmentation_visu`` is read at access time
                    (treated as ``False`` if it is undefined), preserving the
                    original behaviour; pass a bool to decouple the dataset
                    from that global.
        '''
        super().__init__()

        self.data_root = data_root
        self.transform = transform
        self.augmentation_visu = augmentation_visu

        with open(os.path.join(self.data_root, 'fileList.csv'), 'r') as f:
            lines = f.readlines()

        # parse CSV lines into data tokens: first column is the label file,
        # the remaining ones are the image files. The header is skipped and
        # blank lines (e.g. a trailing empty line) are ignored so they do not
        # become bogus samples.
        self.data = [line.strip().split(',') for line in lines[1:] if line.strip()]

    def __len__(self):
        '''
            This function tells the Data Loader how many images there are in
            this dataset.
        '''
        return len(self.data)

    def _visu_enabled(self):
        '''Return whether several augmented draws should be returned per item.'''
        flag = getattr(self, 'augmentation_visu', None)
        if flag is None:
            flag = globals().get('augmentation_visu', False)
        return bool(flag)

    def _load_sample(self, item):
        '''Load one CSV row; return (normalised HxWxC float32 image, int64 mask).'''
        # `with` closes the PIL file handles, which would otherwise leak in
        # every DataLoader worker
        with Image.open(os.path.join(self.data_root, 'labels', item[0])) as label_file:
            labels = np.array(label_file, dtype=np.int64)

        channels = []
        for i, name in enumerate(item[1:]):
            with Image.open(os.path.join(self.data_root, 'images', name)) as image_file:
                img = np.array(image_file, dtype=np.float32)
            # normalise with the precomputed statistics of this input
            normalised = (img - self.IMAGE_MEANS[i]) / self.IMAGE_STDS[i]
            channels.append(np.asarray(normalised, dtype=np.float32))

        # np.dstack turns the 2-D DSM/nDSM rasters into single channels and
        # stacks them behind the 3 IR-R-G channels
        return np.dstack(channels), labels

    @classmethod
    def _to_image_tensor(cls, img):
        '''Convert an image (tensor or HxWxC array) to a float, channels-first tensor.'''
        if isinstance(img, torch.Tensor):
            tensor = img.float()
        else:
            tensor = torch.from_numpy(img).float()
        # ToTensorV2 already yields channels-first tensors; raw arrays are
        # channels-last and need permuting
        if tensor.size(0) != cls.N_CHANNELS:
            tensor = tensor.permute(2, 0, 1)
        return tensor

    @staticmethod
    def _to_mask_tensor(mask):
        '''Convert a mask to a tensor; tensors pass through, arrays become int64.'''
        if isinstance(mask, torch.Tensor):
            return mask
        return torch.from_numpy(mask).long()

    def __getitem__(self, idx):
        '''
            Load, normalise and (optionally) augment the sample at ``idx``.

            Returns ``(image, mask)`` tensors, or ``(list_of_images,
            list_of_masks)`` when augmentation visualisation is enabled.
        '''
        image, labels = self._load_sample(self.data[idx])
        visu = self._visu_enabled()

        if self.transform is not None:
            n_draws = self.N_VISU_DRAWS if visu else 1
            draws = [self.transform(image=image, mask=labels) for _ in range(n_draws)]
            pairs = [(draw['image'], draw['mask']) for draw in draws]
        else:
            pairs = [(image, labels)]

        images = [self._to_image_tensor(img) for img, _ in pairs]
        masks = [self._to_mask_tensor(mask) for _, mask in pairs]

        # Decide whether or not multiple augmented images are shown to the user
        if visu:
            return images, masks
        return images[0], masks[0]





# we also create a function for the data loader here (see Section 2.6 in Exercise 6)
def load_dataloader(batch_size, split, transform):
  '''
  Build a DataLoader over the ``split`` sub-directory of the global ``data_root``.

  Only the 'train' split is shuffled; loading runs in two worker processes.
  '''
  split_dataset = VaihingenDataset(os.path.join(data_root, split), transform)
  is_training = split == 'train'
  return DataLoader(split_dataset, batch_size=batch_size, shuffle=is_training, num_workers=2)

Image visualisation from the dataset

In [None]:
  # visualise
import os
%matplotlib inline

import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import albumentations as A
import cv2

#discrete color scheme
cMap = ListedColormap(['black', 'grey', 'lawngreen', 'darkgreen', 'orange', 'red'])     
#  'Impervious', 'Buildings', 'Low Vegetation', 'Tree', 'Car', 'Clutter'

dataset_train = VaihingenDataset(os.path.join(data_root, 'train'), transform=train_transform)
dataset_visu =VaihingenDataset(os.path.join(data_root, 'val'), transform=report_transform)

# Decide wether or not the augmentations are shown
augmentation_visu = False

# 1 is a good image which contains "complex" scene
aug_data, aug_trg = dataset_visu.__getitem__(1)

visualise = True
if augmentation_visu:
    for i in range(len(aug_data)):
        data_augmented = aug_data[i] 
        label_augmented = aug_trg[i]
        f, axarr = plt.subplots(nrows=1,ncols=3)
        plt.sca(axarr[0]); 
        plt.imshow(data_augmented[:3,...].permute(1,2,0).numpy()); plt.title('NIR-R-G')
        plt.sca(axarr[1]); 
        plt.imshow(data_augmented[3,...].squeeze().numpy()); plt.title('DSM')
        plt.sca(axarr[2]); 
        plt.imshow(data_augmented[4,...].squeeze().numpy()); plt.title('nDSM')
        plt.sca(axarr[3]); 
        cax = plt.imshow(label_augmented.squeeze().numpy(), cmap=cMap)                # target: segmentation mask
        cbar = f.colorbar(cax, ticks=list(range(len(dataset_train.LABEL_CLASSES))))
        cbar.ax.set_yticklabels(list(dataset_train.LABEL_CLASSES))
        plt.title('Target: segmentation mask')
        plt.show()
else:
    data_augmented = aug_data
    label_augmented = aug_trg
    f, axarr = plt.subplots(nrows=1,ncols=4)
    plt.sca(axarr[0]); 
    plt.imshow(data_augmented[:3,...].permute(1,2,0).numpy()); plt.title('NIR-R-G')
    plt.sca(axarr[1]); 
    plt.imshow(data_augmented[3,...].squeeze().numpy()); plt.title('DSM')
    plt.sca(axarr[2]); 
    plt.imshow(data_augmented[4,...].squeeze().numpy()); plt.title('nDSM')
    plt.sca(axarr[3]); 
    cax = plt.imshow(label_augmented.squeeze().numpy(), cmap=cMap)                # target: segmentation mask
    cbar = f.colorbar(cax, ticks=list(range(len(dataset_train.LABEL_CLASSES))))
    cbar.ax.set_yticklabels(list(dataset_train.LABEL_CLASSES))
    plt.title('Target: segmentation mask')
    plt.show()


The following cell computes the class imbalance. It only needs to be run when the dataset is first initialised; for this dataset the results are reported in the next cell, so it does not need to be run again.
To run it, set *compute_class_imbalance* to True.

In [None]:
compute_class_imbalance = False
if compute_class_imbalance:  
  import os
  %matplotlib inline

  import matplotlib.pyplot as plt
  from matplotlib.colors import ListedColormap
  #import albumentations as A
  import cv2
  from tqdm import tqdm

  #discrete color scheme
  cMap = ListedColormap(['black', 'grey', 'lawngreen', 'darkgreen', 'orange', 'red'])     #  'Impervious', 'Buildings', 'Low Vegetation', 'Tree', 'Car', 'Clutter'

  dataset_train = VaihingenDataset(os.path.join(data_root, 'train'))

  nbr_class = np.zeros(6)

  for k in range(len(dataset_train)):
    _, aug_trg = dataset_train.__getitem__(k)
    aug_trg.numpy()
    for i in range(np.shape(aug_trg)[0]):
      for j in range(np.shape(aug_trg)[1]):
        label_px = aug_trg[i,j]
        if label_px == 0:
          nbr_class[0] += 1
        elif label_px == 1:
          nbr_class[1] += 1
        elif label_px == 2:
          nbr_class[2] += 1
        elif label_px == 3:
          nbr_class[3] += 1
        elif label_px == 4:
          nbr_class[4] += 1
        elif label_px == 5:
          nbr_class[5] += 1

  # This shows the improtance of each class on the training dataset
  class_imbalance = np.zeros(6)
  print(nbr_class)
  class_imbalance = nbr_class/np.sum(nbr_class)
  print(class_imbalance)

Precomputed class frequencies for the chosen dataset's training set. This avoids having to run the previous cell.

In [None]:
# Fraction of training pixels per class, in the order 'Impervious',
# 'Buildings', 'Low Vegetation', 'Tree', 'Car', 'Clutter'; used unless the
# previous cell recomputed it.
if not compute_class_imbalance:
  class_imbalance = [0.37021928, 0.22978685, 0.18439881, 0.19673866, 0.01098469, 0.0078717]

## 3. Model


More details about the model blocks and implementation can be found in the report.

This part focuses on model selection. The weighted variant means that the class imbalance is taken into account (through a weighted loss) during training.

Please select only one of the models.


In [None]:
# Model selection: set exactly one flag to True.
#   choose_Unet          -> UNet trained with a plain cross-entropy loss
#   choose_Unet_weighted -> UNet trained with a class-weighted cross-entropy loss
choose_Unet = True
choose_Unet_weighted = False

# Selecting both (or neither) would silently pair the loss of one variant with
# the checkpoint folder of the other (or leave load_model undefined).
if bool(choose_Unet) == bool(choose_Unet_weighted):
  raise ValueError('Select exactly one model: set either choose_Unet or choose_Unet_weighted to True.')

## 3.1. UNet Model
Initialise the UNet model's blocks

In [None]:
import torch.nn as nn

class DoubleConv(nn.Module):
    '''Two successive (3x3 conv -> BatchNorm -> ReLU) stages.

    The padding of 1 keeps the spatial size unchanged. ``in_ch`` is the number
    of input channels; both convolutions produce ``out_ch`` channels.
    '''

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        for stage_in in (in_ch, out_ch):
            layers += [
                nn.Conv2d(stage_in, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        # the attribute name and Sequential indices (conv.0, conv.1, conv.3,
        # conv.4) form the state_dict keys, so saved checkpoints stay loadable
        self.conv = nn.Sequential(*layers)

    def forward(self, input):
        return self.conv(input)


class Unet(nn.Module):
    '''Classic 4-level UNet: a 64-128-256-512 encoder, a 1024-channel
    bottleneck and a symmetric decoder with skip connections, ending in a
    1x1 convolution to ``out_ch`` class scores.

    Submodules are registered as conv1..conv10, pool1..pool4 and up6..up9 (in
    that creation order), which defines the checkpoint state_dict keys.
    '''

    def __init__(self, in_ch, out_ch):
        super().__init__()
        widths = [64, 128, 256, 512]

        # encoder: conv{i} followed by a 2x2 max-pool{i}
        prev = in_ch
        for level, width in enumerate(widths, start=1):
            setattr(self, f'conv{level}', DoubleConv(prev, width))
            setattr(self, f'pool{level}', nn.MaxPool2d(2))
            prev = width

        # bottleneck
        self.conv5 = DoubleConv(widths[-1], 2 * widths[-1])

        # decoder: transposed conv up{i} halves the channels, then conv{i}
        # fuses it with the matching encoder output
        for level, width in zip(range(6, 10), reversed(widths)):
            setattr(self, f'up{level}', nn.ConvTranspose2d(2 * width, width, 2, stride=2))
            setattr(self, f'conv{level}', DoubleConv(2 * width, width))

        self.conv10 = nn.Conv2d(widths[0], out_ch, 1)

    def forward(self, x):
        # contracting path: remember each level's output for the skips
        skips = []
        for conv, pool in ((self.conv1, self.pool1), (self.conv2, self.pool2),
                           (self.conv3, self.pool3), (self.conv4, self.pool4)):
            x = conv(x)
            skips.append(x)
            x = pool(x)

        x = self.conv5(x)

        # expanding path: upsample, concatenate the deepest unused skip, convolve
        for up, conv in ((self.up6, self.conv6), (self.up7, self.conv7),
                         (self.up8, self.conv8), (self.up9, self.conv9)):
            x = conv(torch.cat([up(x), skips.pop()], dim=1))

        return self.conv10(x)


Check that the model outputs have the expected shape

In [None]:
# Sanity check: one forward pass on a training batch must give one output
# channel per class and keep the spatial size of the input.
dataloader_train = load_dataloader(2, 'train',train_transform)
n_channels = 5  # NIR - R - G - DSM - nDSM
n_classes = 6  #'Impervious', 'Buildings', 'Low Vegetation', 'Tree', 'Car', 'Clutter'

model = Unet(n_channels,n_classes)
data, _ = next(iter(dataloader_train))
# No autograd graph is needed for a shape check; skipping it saves a lot of
# memory for a full-resolution UNet forward pass (this runs on the CPU).
with torch.no_grad():
  pred = model(data)

assert pred.size(1) == len(dataset_train.LABEL_CLASSES), f'ERROR: invalid number of model output channels (should be # classes {len(dataset_train.LABEL_CLASSES)}, got {pred.size(1)})'
assert pred.size(2) == data.size(2), f'ERROR: invalid spatial height of model output (should be {data.size(2)}, got {pred.size(2)})'
assert pred.size(3) == data.size(3), f'ERROR: invalid spatial width of model output (should be {data.size(3)}, got {pred.size(3)})'

## 3.2. Model check

In [None]:
from torchsummary import summary

# Print a layer-by-layer summary for one 5-channel crop of the training size.
# Fall back to the CPU when no GPU is available instead of crashing.
summary_device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_sum  = model.to(device=summary_device)

print("Model:",model.__class__.__name__)
summary(model_sum,(5, crop_height, crop_width),device=summary_device)

Model: GSUnet
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 400, 400]           2,944
       BatchNorm2d-2         [-1, 64, 400, 400]             128
              ReLU-3         [-1, 64, 400, 400]               0
            Conv2d-4         [-1, 64, 400, 400]          36,928
       BatchNorm2d-5         [-1, 64, 400, 400]             128
              ReLU-6         [-1, 64, 400, 400]               0
        DoubleConv-7         [-1, 64, 400, 400]               0
         MaxPool2d-8         [-1, 64, 200, 200]               0
            Conv2d-9        [-1, 128, 200, 200]          73,856
      BatchNorm2d-10        [-1, 128, 200, 200]             256
             ReLU-11        [-1, 128, 200, 200]               0
           Conv2d-12        [-1, 128, 200, 200]         147,584
      BatchNorm2d-13        [-1, 128, 200, 200]             256
             ReLU-14     

## 4. Model training function initialisation

This section trains the chosen model with/without weighted loss function as defined before.

## 4.1. UNet loss functions

In [None]:
# Loss function choice
if choose_Unet_weighted:
  class_imbalance = [0.37021928, 0.22978685, 0.18439881, 0.19673866, 0.01098469, 0.0078717]
  # Weight each class by (1 - its pixel frequency) so that rare classes
  # (Car, Clutter) count relatively more in the loss.
  normedWeights = torch.tensor([(1 - x) for x in class_imbalance], dtype=torch.float32)
  # The weights must live on the same device as the predictions; fall back to
  # the CPU when no GPU is present instead of failing here.
  normedWeights = normedWeights.to('cuda' if torch.cuda.is_available() else 'cpu')

  criterion = nn.CrossEntropyLoss(weight=normedWeights)
else:
  criterion = nn.CrossEntropyLoss() 
  print("no weight")

no weight


## 4.2. Optimiser and training/validation loops

In [None]:
# Sets up the optimiser

from torch.optim import SGD

def setup_optimiser(model, learning_rate, momentum, weight_decay):
  '''
  Create an SGD optimiser over all parameters of ``model``.

  The hyper-parameters are passed by keyword: ``SGD``'s fourth positional
  parameter is ``dampening``, not ``weight_decay``. Passing them positionally
  silently disabled weight decay and dampened the momentum instead.

  Args:
    model: module whose parameters are optimised.
    learning_rate: SGD step size.
    momentum: momentum factor.
    weight_decay: L2 penalty coefficient.

  Returns:
    a torch.optim.SGD instance.
  '''
  return SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=momentum,
    weight_decay=weight_decay,
  )

In [None]:
# Training block implementation
from tqdm.notebook import trange 


def train_epoch(data_loader, model, optimiser, device, loss_fn=None):
  '''
  Run one training epoch over ``data_loader``.

  Args:
    data_loader: iterable of (data, target) batches that supports len().
    model: network to train; moved to ``device`` and put in training mode.
    optimiser: optimiser updating ``model``'s parameters.
    device: device the batches and the model are moved to (e.g. 'cuda').
    loss_fn: loss to minimise; defaults to the notebook-level ``criterion``.

  Returns:
    (model, mean batch loss, mean batch overall accuracy in [0, 1]).
  '''
  if loss_fn is None:
    loss_fn = globals()['criterion']

  # set model to training mode. This is important because some layers behave
  # differently during training and testing (e.g. BatchNorm)
  model.train(True)
  model.to(device)

  # stats
  loss_total = 0.0
  oa_total = 0.0
  n_batches = len(data_loader)

  # the context manager closes the progress bar even if a step raises
  with trange(n_batches) as pBar:
    for idx, (data, target) in enumerate(data_loader):
      # put data and target onto correct device
      data, target = data.to(device), target.to(device)

      optimiser.zero_grad()           # reset gradients
      pred = model(data)              # forward pass
      loss = loss_fn(pred, target)    # loss
      loss.backward()                 # backward pass
      optimiser.step()                # parameter update

      # stats update; overall accuracy = fraction of correctly labelled pixels
      loss_total += loss.item()
      oa_total += torch.mean((pred.argmax(1) == target).float()).item()

      # format progress bar
      pBar.set_description('Loss: {:.2f}, OA: {:.2f}'.format(
        loss_total/(idx+1),
        100 * oa_total/(idx+1)
      ))
      pBar.update(1)

  # normalise stats (an empty loader yields zeros instead of dividing by 0)
  if n_batches:
    loss_total /= n_batches
    oa_total /= n_batches

  return model, loss_total, oa_total

In [None]:
def validate_epoch(data_loader, model, device, loss_fn=None):
  '''
  Evaluate ``model`` on ``data_loader`` without updating it.

  Args:
    data_loader: iterable of (data, target) batches that supports len().
    model: network to evaluate; moved to ``device`` and put in eval mode.
    device: device the batches and the model are moved to (e.g. 'cuda').
    loss_fn: loss to report; defaults to the notebook-level ``criterion``.

  Returns:
    (mean batch loss, mean batch overall accuracy in [0, 1]).
  '''
  if loss_fn is None:
    loss_fn = globals()['criterion']

  # set model to evaluation mode (BatchNorm then uses its running statistics)
  model.train(False)
  model.to(device)

  # stats
  loss_total = 0.0
  oa_total = 0.0
  n_batches = len(data_loader)

  # no gradients are needed for evaluation; the context managers also close
  # the progress bar if a batch raises
  with torch.no_grad(), trange(n_batches) as pBar:
    for idx, (data, target) in enumerate(data_loader):
      # put data and target onto correct device
      data, target = data.to(device), target.to(device)

      # forward pass and loss
      pred = model(data)
      loss = loss_fn(pred, target)

      # stats update
      loss_total += loss.item()
      oa_total += torch.mean((pred.argmax(1) == target).float()).item()

      # format progress bar
      pBar.set_description('Loss: {:.2f}, OA: {:.2f}'.format(
        loss_total/(idx+1),
        100 * oa_total/(idx+1)
      ))
      pBar.update(1)

  # normalise stats (an empty loader yields zeros instead of dividing by 0)
  if n_batches:
    loss_total /= n_batches
    oa_total /= n_batches

  return loss_total, oa_total

The following cell defines the functions that load and save the model weights for the chosen model (with or without weighted loss).

### 1. UNet with or without weighted Loss

In [None]:
# Folder holding the checkpoints of the selected model variant.
if choose_Unet:
  checkpoint_dir = 'gdrive/MyDrive/SegVai/UNet_final_no_weight'
elif choose_Unet_weighted:
  checkpoint_dir = 'gdrive/MyDrive/SegVai/UNet_final'
else:
  checkpoint_dir = None

if checkpoint_dir is not None:
  import glob
  from google.colab import drive
  drive.mount('/content/gdrive')

  os.makedirs(checkpoint_dir, exist_ok=True)

  def _saved_epochs(folder=checkpoint_dir):
    '''Return the epoch numbers of the "<epoch>.pth" checkpoints in ``folder``.'''
    epochs = []
    for path in glob.glob(os.path.join(folder, '*.pth')):
      stem = os.path.splitext(os.path.basename(path))[0]
      if stem.isdigit():        # ignore files that save_model did not write
        epochs.append(int(stem))
    return epochs

  def load_model(n_channels=5, n_classes=6, epoch='latest', folder=checkpoint_dir):
    '''
    Build a UNet and load checkpoint ``epoch`` from ``folder`` if available.

    epoch: 'latest' for the most recent checkpoint, a positive int for that
      epoch's checkpoint, or 0 for a freshly initialised model. When no
      checkpoint exists, the fresh model is returned with epoch 0.

    Returns (model, epoch).
    '''
    model = Unet(n_channels, n_classes)
    saved = _saved_epochs(folder)
    if saved and (epoch == 'latest' or epoch > 0):
      if epoch == 'latest':
        epoch = max(saved)
      with open(os.path.join(folder, f'{epoch}.pth'), 'rb') as f:
        stateDict = torch.load(f, map_location='cpu')
      model.load_state_dict(stateDict)
    else:
      # fresh model
      epoch = 0
    return model, epoch

  def save_model(model, epoch, folder=checkpoint_dir):
    '''Write ``model``'s state_dict to ``<folder>/<epoch>.pth``.'''
    # `with` guarantees the file is flushed and closed, which matters on the
    # mounted Google Drive
    with open(os.path.join(folder, f'{epoch}.pth'), 'wb') as f:
      torch.save(model.state_dict(), f)

## 5. Training model cell

The next cell determines whether we want to train or just visualise results.

In [None]:
# Set to False to skip the training loop in the next cell (e.g. to only
# visualise results from previously saved checkpoints).
training = True

In [None]:
from torch import optim
if training:
  # define hyperparameters
  device = 'cuda'
  start_epoch =  0    # set to 0 to start from scratch again or to 'latest' to continue training from saved checkpoint
  batch_size = 2
  learning_rate = 0.1
  momentum = 0.5
  weight_decay = 0.001
  num_epochs = 70
  n_channels = 5  # NIR - R - G - DSM - nDSM
  n_classes = 6  #'Impervious', 'Buildings', 'Low Vegetation', 'Tree', 'Car', 'Clutter'

  # initialise data loaders (augmented training crops, validation crops)
  dl_train = load_dataloader(batch_size, 'train', train_transform)
  dl_val = load_dataloader(batch_size, 'val', test_transform)

  # load model: a fresh one (epoch 0) or the requested checkpoint.
  # NOTE(review): only model weights are checkpointed, so when resuming the
  # SGD momentum buffers start from zero again.
  model, epoch = load_model(n_channels, n_classes, epoch=start_epoch)
  optimi = setup_optimiser(model, learning_rate, momentum, weight_decay)

  # do epochs
  while epoch < num_epochs:

    # training
    model, loss_train, oa_train = train_epoch(dl_train, model, optimi, device)

    # validation
    loss_val, oa_val = validate_epoch(dl_val, model, device)

    # print stats
    print('[Ep. {}/{}] Loss train: {:.2f}, val: {:.2f}; OA train: {:.2f}, val: {:.2f}'.format(
        epoch+1, num_epochs,
        loss_train, loss_val,
        100*oa_train, 100*oa_val
    ))

    # save a checkpoint every 10 epochs and always after the final epoch, so
    # the last weights are kept even when num_epochs is not a multiple of 10
    epoch += 1
    if epoch % 10 == 0 or epoch == num_epochs:
      save_model(model, epoch)

## 6. Model prediction & Metrics evaluation


### 1. Basic function implementation for visualisation and Metrics computations

The following function computes the various metrics used, namely:
F-score, Cohen's kappa, Overall Accuracy, User's Accuracy, Producer's Accuracy, Intersection over Union, recall and precision.


In [None]:
from math import isnan
import torch
import numpy as np
from sklearn.metrics import confusion_matrix

def compute_metrics(cm):
    '''
    Compute per-class and global segmentation metrics from a confusion matrix.

    Adapted from:
        https://github.com/davidtvs/PyTorch-ENet/blob/master/metric/iou.py
        https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/keras/metrics.py#L2716-L2844

    Convention: ``cm.sum(axis=0)`` is treated as the reference (ground-truth)
    total per class and ``cm.sum(axis=1)`` as the predicted total, i.e. the
    matrix is expected as cm[predicted, true]. NOTE(review): sklearn's
    ``confusion_matrix(y_true, y_pred)`` returns cm[true, predicted]; with
    that orientation, precision/recall and UA/PA come out swapped — confirm
    the argument order at the call site.

    Args:
        cm: square (n_classes x n_classes) array of counts.

    Returns:
        iou, recall, precision, f1, kappa, OA, PA, UA. All are per-class
        arrays except ``kappa`` (Cohen's kappa) and ``OA`` (overall
        accuracy), which are scalars. Undefined (0/0) IoU, F1, PA and UA
        values are reported as 0; recall and precision keep NaN in that case.
    '''
    cm = np.asarray(cm)
    sum_over_row = cm.sum(axis=0)
    sum_over_col = cm.sum(axis=1)
    true_positives = np.diag(cm)

    with np.errstate(divide='ignore', invalid='ignore'):
        # sum_over_row + sum_over_col = 2 * TP + FP + FN, so the denominator is TP + FP + FN
        iou = true_positives / (sum_over_row + sum_over_col - true_positives)
        precision = true_positives / sum_over_col
        recall = true_positives / sum_over_row
        f1 = 2 * (precision * recall) / (precision + recall)
        PA = true_positives / sum_over_row
        UA = true_positives / sum_over_col

        # Cohen's kappa:
        #   (N * sum(TP) - sum_k row_k * col_k) / (N^2 - sum_k row_k * col_k)
        # evaluated in float64 so N * N cannot overflow int64 for large counts
        # (numpy scalars also return nan rather than raising on 0/0).
        N = np.float64(sum_over_row.sum())
        tp_sum = np.float64(true_positives.sum())
        chance = np.float64(np.dot(sum_over_row.astype(np.float64), sum_over_col.astype(np.float64)))
        kappa = (N * tp_sum - chance) / (N * N - chance)
        OA = tp_sum / N

    # classes with undefined scores (0/0) are reported as 0
    iou = np.where(np.isnan(iou), 0.0, iou)
    f1 = np.where(np.isnan(f1), 0.0, f1)
    PA = np.where(np.isnan(PA), 0.0, PA)
    UA = np.where(np.isnan(UA), 0.0, UA)

    return iou, recall, precision, f1, kappa, OA, PA, UA

In [None]:
def check_cmap(pred, old_cmap):
  '''Return the colours of `old_cmap` whose class index occurs in `pred`.

  `pred` is a tensor of class ids (moved to the CPU before inspection).
  Entry k of `old_cmap` is kept iff the value k appears somewhere in `pred`,
  and the surviving colours keep their original order. A new list is
  returned; `old_cmap` is not modified.

  NOTE(review): if the result is handed to imshow without vmin/vmax, the
  condensed list only lines up with the class ids when the present ids are
  evenly spaced (imshow rescales to the data's min/max). A full colormap
  with vmin/vmax avoids this.
  '''
  present_ids = set(np.unique(pred.cpu().numpy()).tolist())
  return [colour for class_id, colour in enumerate(old_cmap) if class_id in present_ids]

### 2. Model output visualisation and metrics evaluation


#### 1. UNet without weighted Loss

In [None]:
if choose_Unet:
  device = 'cuda'
  import warnings
  %matplotlib inline
  import matplotlib.pyplot as plt
  import math
  import pandas as pd
  from sklearn.metrics import accuracy_score, confusion_matrix, f1_score
  base_cmp = ['black', 'grey', 'lawngreen', 'darkgreen', 'orange', 'red']
  cMap = ListedColormap(base_cmp)     #  'Impervious', 'Buildings', 'Low Vegetation', 'Tree', 'Car', 'Clutter'
  warnings.filterwarnings("ignore")

  def visualize(dataLoader,n_channels, n_classes, epochs, show_metrics = False, numImages=5):
    models = [load_model(n_channels, n_classes,e)[0] for e in epochs]
    numModels = len(models)
    for idx, (data, labels) in enumerate(dataLoader):
      list_gt_labels = []
      if idx == numImages:
        break
      if idx == 0 or idx == 1:
        continue

      f, ax = plt.subplots(nrows=1, ncols=numModels+1, figsize = (10, 10))

      list_gt_labels.append(labels[0,...].cpu().numpy().flatten())

      # plot ground truth
      plt.sca(ax[0]);
      cax = plt.imshow(labels.squeeze().cpu().numpy(), cmap=cMap2)
      cbar = f.colorbar(cax, ticks=list(range(6)),orientation='horizontal')
      cbar.ax.set_xticklabels(['Impervious', 'Buildings', 'Low Vegetation', 'Tree', 'Car', 'Clutter'],rotation ='45')
      plt.title('Ground Truth') 
      
      
      ax[0].axis('off')
      if idx == 0:
        ax[0].set_title('Ground Truth')
      conf_matrix = []
      accuracy = []
      for mIdx, model in enumerate(models):
        list_predictions = []
        model = model.to(device)
        
        with torch.no_grad():
          pred = model(data.to(device))

          # get the prediction label 
          yhat = torch.argmax(pred, dim=1)

          list_predictions.append(yhat[0,...].cpu().numpy().flatten())     
          all_predictions = np.concatenate(list_predictions)
          all_gt_labels = np.concatenate(list_gt_labels)
          accuracy.append(accuracy_score(all_gt_labels, all_predictions))
          conf_matrix.append(confusion_matrix(all_gt_labels, all_predictions,labels=[0,1, 2, 3,4,5]))    
            
          new_cmap = check_cmap(yhat, base_cmp)
          plt.sca(ax[mIdx+1]); 
          cax = plt.imshow(yhat[0,...].cpu().numpy(), cmap=ListedColormap(new_cmap))


      # Define if the confusion matrix, and all the metrics are shown or only the model prediction
      if show_metrics:
        if numModels == 1:
          _, ax = plt.subplots(nrows=1, ncols=numModels, figsize = (8, 8))
        else:
          _, ax = plt.subplots(nrows=1, ncols=numModels, figsize = (15, 15))
        for mIdx, model in enumerate(models):
          conf_matrix_one = conf_matrix[mIdx]
          if numModels == 1:
            ax.matshow(conf_matrix_one, cmap=plt.cm.Blues, alpha=0.5)
          else:
            ax[mIdx].matshow(conf_matrix_one, cmap=plt.cm.Blues, alpha=0.5)

          # Metrics computation
          iou, recall, precision, f1, kappa, OA,PA, UA = compute_metrics(conf_matrix_one)
          
          
          for i in range(conf_matrix_one.shape[0]):
            for j in range(conf_matrix_one.shape[1]):
              if math.isnan(conf_matrix_one[i,j]):
                conf_matrix_one[i,j] = 0
              if numModels == 1:
                ax.text(x=j, y=i,s=conf_matrix_one[i, j], va='center', ha='center', size='x-large')
              else:
                ax[mIdx].text(x=j, y=i,s=conf_matrix_one[i, j], va='center', ha='center', size='x-large')
            
            if numModels == 1:
              ax.set_xlabel('Predictions', fontsize=18)
              ax.set_ylabel('Ground Truth', fontsize=18)
              ax.set_title('Confusion Matrix', fontsize=18)
            else:
              ax[mIdx].set_xlabel('Predictions', fontsize=18)
              ax[mIdx].set_ylabel('Ground Truth', fontsize=18)
              ax[mIdx].set_title('Confusion Matrix', fontsize=18)
          if idx == 2 and (epochs[mIdx] == 'latest' or epochs[mIdx] == 30):
                print("Epoch:",epochs[mIdx])
                print("F1",f1)
                print("IoU",iou)
                print("Kappa",kappa)
                print("OA",OA)
                print("UA",UA)
                print("PA",PA)
                print("Mean FScore",sum(f1)/len(f1))
                print("Mean IoU",sum(iou)/len(iou))
                print("Mean UA",sum(UA)/len(UA))
                print("Mean PA",sum(PA)/len(PA))
                print("===========================")
        
  # visualize predictions. For the sake of comparison, the image 1 of the validation set as been chosen as it contains all classes
  dl_val_single = load_dataloader(1, 'val',report_transform)

  # load model states at different epochs
  epochs = [30]                                         
  n_channels = 5  # NIR - R - G - DSM - nDSM
  n_classes = 6  #'Impervious', 'Buildings', ' Low Vegetation', 'Tree', 'Car', 'Clutter'

  visualize(dl_val_single, n_channels, n_classes,epochs, show_metrics = True, numImages=3)

#### 2. UNet with weighted Loss

In [None]:
if choose_Unet_weighted:
  device = 'cuda'
  import warnings
  %matplotlib inline
  import matplotlib.pyplot as plt
  import math
  import pandas as pd
  from sklearn.metrics import accuracy_score, confusion_matrix, f1_score
  base_cmp = ['black', 'grey', 'lawngreen', 'darkgreen', 'orange', 'red']
  cMap = ListedColormap(base_cmp)     #  'Impervious', 'Buildings', 'Low Vegetation', 'Tree', 'Car', 'Clutter'
  warnings.filterwarnings("ignore")

  def visualize(dataLoader,n_channels, n_classes, epochs, show_metrics = False, numImages=5):
    models = [load_model(n_channels, n_classes,e)[0] for e in epochs]
    numModels = len(models)
    for idx, (data, labels) in enumerate(dataLoader):
      list_gt_labels = []
      if idx == numImages:
        break
      if idx == 0 or idx == 1:
        continue

      f, ax = plt.subplots(nrows=1, ncols=numModels+1, figsize = (10, 10))

      list_gt_labels.append(labels[0,...].cpu().numpy().flatten())

      # plot ground truth
      plt.sca(ax[0]);
      cax = plt.imshow(labels.squeeze().cpu().numpy(), cmap=cMap)
      cbar = f.colorbar(cax, ticks=list(range(6)))
      cbar.ax.set_yticklabels(['Impervious', 'Buildings', 'Low Vegetation', 'Tree', 'Car', 'Clutter'])
      plt.title('Ground Truth')   
      #ax[0].imshow(labels[0,...].cpu().numpy(), cmap=cMap)
      
      
      ax[0].axis('off')
      if idx == 0:
        ax[0].set_title('Ground Truth')
      conf_matrix = []
      accuracy = []
      for mIdx, model in enumerate(models):
        list_predictions = []
        model = model.to(device)
        
        with torch.no_grad():
          pred = model(data.to(device))
          # get the label (i.e., the maximum position for each pixel along the class dimension)
          yhat = torch.argmax(pred, dim=1)

          list_predictions.append(yhat[0,...].cpu().numpy().flatten())     
          all_predictions = np.concatenate(list_predictions)
          all_gt_labels = np.concatenate(list_gt_labels)
          accuracy.append(accuracy_score(all_gt_labels, all_predictions))
          conf_matrix.append(confusion_matrix(all_gt_labels, all_predictions,labels=[0,1, 2, 3,4,5]))    
            
          new_cmap = check_cmap(yhat, base_cmp)
          plt.sca(ax[mIdx+1]); 
          cax = plt.imshow(yhat[0,...].cpu().numpy(), cmap=ListedColormap(new_cmap))     # target: segmentation mask



      if show_metrics:
        if numModels == 1:
          _, ax = plt.subplots(nrows=1, ncols=numModels, figsize = (8, 8))
        else:
          _, ax = plt.subplots(nrows=1, ncols=numModels, figsize = (15, 15))
        for mIdx, model in enumerate(models):
          conf_matrix_one = conf_matrix[mIdx]
          if numModels == 1:
            ax.matshow(conf_matrix_one, cmap=plt.cm.Blues, alpha=0.5)
          else:
            ax[mIdx].matshow(conf_matrix_one, cmap=plt.cm.Blues, alpha=0.5)

          

          iou, recall, precision, f1, kappa, OA,PA, UA = compute_metrics(conf_matrix_one)
          
          
          for i in range(conf_matrix_one.shape[0]):
            for j in range(conf_matrix_one.shape[1]):
              if math.isnan(conf_matrix_one[i,j]):
                conf_matrix_one[i,j] = 0
              if numModels == 1:
                ax.text(x=j, y=i,s=conf_matrix_one[i, j], va='center', ha='center', size='x-large')
              else:
                ax[mIdx].text(x=j, y=i,s=conf_matrix_one[i, j], va='center', ha='center', size='x-large')
            
            if numModels == 1:
              ax.set_xlabel('Predictions', fontsize=18)
              ax.set_ylabel('Ground Truth', fontsize=18)
              ax.set_title('Confusion Matrix', fontsize=18)
            else:
              ax[mIdx].set_xlabel('Predictions', fontsize=18)
              ax[mIdx].set_ylabel('Ground Truth', fontsize=18)
              ax[mIdx].set_title('Confusion Matrix', fontsize=18)
          if idx == 2 and (epochs[mIdx] == 'latest' or epochs[mIdx] == 30):
                print("Epoch:",epochs[mIdx])
                print("F1",f1)
                print("IoU",iou)
                print("Kappa",kappa)
                print("OA",OA)
                print("UA",UA)
                print("PA",PA)
                print("Mean FScore",sum(f1)/len(f1))
                print("Mean IoU",sum(iou)/len(iou))
                print("Mean UA",sum(UA)/len(UA))
                print("Mean PA",sum(PA)/len(PA))
                print("===========================")
        
  # visualize predictions for a number of epochs
  dl_val_single = load_dataloader(1, 'val',report_transform)

  # load model states at different epochs
  epochs = [30]                                          #TODO: modify this vector according to your wishes, resp. for how many model states you have trained
  n_channels = 5  # NIR - R - G - DSM - nDSM
  n_classes = 6  #'Impervious', 'Buildings', ' Low Vegetation', 'Tree', 'Car', 'Clutter'

  visualize(dl_val_single, n_channels, n_classes,epochs, show_metrics = True, numImages=3)