<a href="https://colab.research.google.com/github/VladMorarK19032334/Project24/blob/main/CTR_Seg%26KeypointLocal_ExperimentalEnvi.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# show which NVIDIA GPU (if any) this runtime provides
!nvidia-smi

Libraries

In [None]:
if MODEL_AUGMENTATION == 'complex': # due to piecewise affine => different import is required
  !python -m pip install --upgrade opencv-contrib-python
  !pip uninstall opencv-python
  !pip install git+https://github.com/albumentations-team/albumentations
  !pip install opencv-python
else:
  !pip install albumentations==0.4.6

  import albumentations as A
  from albumentations.pytorch import ToTensorV2

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
import numpy as np
import os
from torch.utils.data import Dataset
from PIL import Image
import torch
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import torch
import torchvision
from torch.utils.data import DataLoader

import cv2

import statistics as stat
import math

from IPython.display import clear_output
import matplotlib.pyplot as plt

import warnings

Experimental Parameters - choose which experiment to run

In [None]:

# ---------------------------------------------------------------------------
# Loss weighting
#   WEIGHTED_UNET = True  -> the segmentation cross-entropy uses CLASS_WEIGHTS
#   WEIGHTED_UNET = False -> plain, unweighted cross-entropy
# Applies to the multi-learning models as well.
WEIGHTED_UNET = True

# ---------------------------------------------------------------------------
# Model to run:
#   'MultiSFPNet'  Multi-learning Network with Segmentation Feedback propagation
#   'DH-Unet'      Double Headed UNet
#   'MultiNCONet'  Multi-Learning Network Non-Connected Output
EXPERIMENTAL_MODEL = 'MultiSFPNet'

# Training augmentation:
#   'none'     no augmentation
#   'full'     the full augmentation used in the latest models (no piecewise affine)
#   'complex'  'full' plus a PiecewiseAffine transformation
MODEL_AUGMENTATION = 'none'

# ---------------------------------------------------------------------------
# Dataset / input-size selection
LARGE_DATASET = False   # which set of training arrays to load (paths below)
LARGE_IMAGESIZE = True  # selects the localization head's input size in MultiSFPNet

# ---------------------------------------------------------------------------
# Hyperparameters
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 5
NUM_EPOCHS = 100
NUM_WORKERS = 2
IMAGE_HEIGHT, IMAGE_WIDTH = 480, 640  # original image size: 480 x 640 px
PIN_MEMORY = True
LOAD_MODEL = False

# Segmentation-loss weight and its decay schedule (see seglossfactor)
SEG_LOSS_WEIGHT = 1e+4             # starting weight
SEG_LOSS_WEIGHT_MIN = 1            # minimum value of the weight
SEG_LOSS_WEIGHT_FUNCTION = 'step'  # 'step', 'linear' or 'exp'

# Per-class weights of the segmentation cross-entropy: [bg, tip, middle, base]
CLASS_WEIGHTS = torch.tensor([0.5, 2.0, 1.0, 1.0])

# Keypoint normalization switch
KEYPOINT_NORM = False

# ---------------------------------------------------------------------------
# Data arrays: .npy files that CTRDatasetNpy loads with np.load (despite the
# _DIR suffix). Fill these in before running.
if LARGE_DATASET:
  TRAIN_IMG_DIR, TRAIN_MASK_DIR, TRAIN_LOCAL_DIR = "", "", ""  # large training set
else:
  TRAIN_IMG_DIR, TRAIN_MASK_DIR, TRAIN_LOCAL_DIR = "", "", ""  # small training set

VAL_IMAGE_DIR, VAL_MASK_DIR, VAL_LOCAL_DIR = "", "", ""

# Output locations for all saved visual results
GLOBAL_FOLDER = ""  # root folder for the experiment outputs
EXPERIMENT_NAME = ""
SAVE_FOLDER = f'{GLOBAL_FOLDER}/{EXPERIMENT_NAME}'


Main Runner

In [None]:
# main algorithm run function

def _keypoint_params():
  # 'xy' pixel keypoints; remove_invisible=False keeps keypoints that leave the frame
  # after augmentation, so every sample keeps its full set of keypoints
  return A.KeypointParams(format='xy', remove_invisible=False)


def _to_unit_tensor_steps():
  # scale uint8 pixels to [0, 1] (mean 0, std 1, divide by 255) and convert to a tensor
  return [
      A.Normalize(
          mean=[0.0, 0.0, 0.0],
          std=[1.0, 1.0, 1.0],
          max_pixel_value=255.0,
      ),
      ToTensorV2(),
  ]


def _build_transforms(augmentation=None):
  """Return ``(train_transform, val_transforms)`` for an augmentation mode.

  Args:
    augmentation: 'full' (rotation + flips), 'complex' ('full' plus
      PiecewiseAffine) or anything else for no augmentation. Defaults to
      MODEL_AUGMENTATION.

  The validation pipeline is the same for every mode: resize + normalize only.
  """
  if augmentation is None:
    augmentation = MODEL_AUGMENTATION

  train_steps = [A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH)]
  if augmentation in ('full', 'complex'):
    train_steps += [
        A.Rotate(limit=90, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.25),
    ]
  if augmentation == 'complex':
    # bilinear interpolation for the image, nearest neighbour for the mask
    train_steps.append(
        A.augmentations.geometric.transforms.PiecewiseAffine(
            scale=(0.03, 0.05), nb_rows=8, nb_cols=8,
            interpolation=1, mask_interpolation=0, cval=0,
            cval_mask=0, mode='constant', absolute_scale=False,
            always_apply=False, keypoints_threshold=0.01, p=0.2))
  train_steps += _to_unit_tensor_steps()
  train_transform = A.Compose(train_steps, keypoint_params=_keypoint_params())

  val_transforms = A.Compose(
      [A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH)] + _to_unit_tensor_steps(),
      keypoint_params=_keypoint_params(),
  )
  return train_transform, val_transforms


def main():
  """Train the selected UNET variant for NUM_EPOCHS epochs.

  Each epoch: update the segmentation-loss weight (seglossfactor), train one
  epoch (train_fn), save a checkpoint, report validation accuracy, and save
  example predictions on epochs 0, 10, 20, ... and on the final epoch.

  NOTE: uses UNET, train_fn, seglossfactor, accuracy, get_loaders, ... which are
  defined in later notebook cells; those cells must be executed first.
  """
  train_transform, val_transforms = _build_transforms()

  model = UNET(in_channels=3, out_channels=4, out_channels_local=8).to(DEVICE)

  # segmentation loss, optionally class-weighted ([bg, tip, middle, base])
  class_weights = CLASS_WEIGHTS.to(DEVICE) if WEIGHTED_UNET else None
  loss_fn = nn.CrossEntropyLoss(weight=class_weights)
  loss_local = nn.MSELoss() # for keypoints
  optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

  train_loader, val_loader = get_loaders(
      TRAIN_IMG_DIR,
      TRAIN_MASK_DIR,
      TRAIN_LOCAL_DIR,
      VAL_IMAGE_DIR,
      VAL_MASK_DIR,
      VAL_LOCAL_DIR,
      BATCH_SIZE,
      train_transform,
      val_transforms,
      NUM_WORKERS,
      PIN_MEMORY,
  )

  if LOAD_MODEL:
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only runtime
    load_checkpoint(torch.load("my_checkpoint.pth.tar", map_location=DEVICE), model)

  accuracy(val_loader, model, device=DEVICE) # baseline before any training
  # loss scaling is only meaningful with CUDA autocast; disable it explicitly on CPU
  scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == "cuda"))

  last_epoch = NUM_EPOCHS - 1
  for epoch in range(NUM_EPOCHS):
    seglossfactor(SEG_LOSS_WEIGHT_FUNCTION, SEG_LOSS_WEIGHT_MIN, epoch)

    print(f'EPOCH: {epoch}')
    torch.cuda.empty_cache() # empty cuda cache
    train_fn(train_loader, model, optimizer, loss_fn, loss_local, scaler)

    # save model
    save_checkpoint({
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    })

    # check accuracy
    accuracy(val_loader, model, device=DEVICE)

    # save example predictions every 10 epochs and on the final epoch
    if epoch % 10 == 0 or epoch == last_epoch:
      save_predictions_as_imgs(
          val_loader, model, epoch, SAVE_FOLDER, device=DEVICE
      )


#torch.backends.cudnn.enabled = False
# train model
main()
# train_fn appends one mean loss per epoch to loss_arr; size the x-axis from it
save_loss_propagation(loss_arr, epochs=len(loss_arr))


Training

In [None]:
loss_arr = []  # mean training loss of each epoch (one value per train_fn call)

loss_arr_epoch = []  # per-batch losses of the epoch currently being trained

# does 1 Epoch of training
def train_fn(loader, model, optimizer, loss_fn, loss_local, scaler):
    """Train *model* for one epoch using torch.cuda.amp autocast and *scaler*.

    Per batch: loss = w * loss_fn(mask logits, mask) + loss_local(keypoints, target),
    where w is the global SEG_LOSS_WEIGHT (or 0 for batches treated as mask-less).
    Appends the epoch's mean total loss to the global loss_arr and prints the mean
    weighted segmentation loss and the mean localization loss.
    """
    # Reset the per-batch history; without this loss_arr received a running mean
    # over all epochs so far instead of the mean of this epoch.
    loss_arr_epoch.clear()

    loop = tqdm(loader) # progress bar

    seg_loss_arr = []
    local_loss_arr = []

    loop_idx = 0 # how many times we went through the unique data (for duplicated datasets)

    for batch_idx, (data, mask_target, keypoints_target) in enumerate(loop):
        data = data.to(device=DEVICE)
        mask_target = mask_target.long().to(device=DEVICE)
        keypoints_target = keypoints_target.float().to(device=DEVICE)

        if batch_idx % 32 == 0 and batch_idx != 0: # check if run through all unique data
            loop_idx += 1

        # forward
        with torch.cuda.amp.autocast():
            output_mask, output_keypoints = model(data)
            loss1 = loss_fn(output_mask, mask_target)
            loss2 = loss_local(output_keypoints.float(), keypoints_target)

            # Batches past index 14 of every 32 are treated as the extra set without
            # masks, so their segmentation loss is cancelled (non-duplicated dataset only).
            # NOTE(review): this assumes a fixed sample order, but get_loaders builds the
            # training loader with shuffle=True, so the zeroed batches are effectively
            # random — confirm the intended gating.
            if batch_idx > 14 + loop_idx*32:
                seg_term = 0*loss1
            else:
                seg_term = SEG_LOSS_WEIGHT*loss1
            loss = seg_term + loss2

            loss_arr_epoch.append(loss.detach().cpu())

            # append to separate arrays
            seg_loss_arr.append(float(SEG_LOSS_WEIGHT*loss1))
            local_loss_arr.append(float(loss2))

        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update tqdm loop
        loop.set_postfix(loss=loss.item())

    if not loss_arr_epoch:
        # an empty loader would otherwise crash on the division / stat.mean below
        warnings.warn('train_fn: the loader yielded no batches; no loss recorded')
        return

    loss_arr.append(float(sum(loss_arr_epoch))/len(loss_arr_epoch))
    # print losses
    print(f"\n The segmentation loss is: {stat.mean(seg_loss_arr)}")
    print(f"The localization loss is: {stat.mean(local_loss_arr)}\n")

Segmentation Loss Factor

In [None]:
def seglossfactor(format, min, epoch, seg=SEG_LOSS_WEIGHT):
  """Set the global SEG_LOSS_WEIGHT for *epoch* from a decay schedule and return it.

  train_fn reads the module-level SEG_LOSS_WEIGHT when weighting the
  segmentation loss, so the global must be rebound here. The previous version
  only assigned a local variable, so no schedule ever took effect.

  Args:
    format: schedule name - 'step', 'linear' or 'exp'. Any other value warns and
      leaves the weight unchanged.
    min: lower bound the weight decays towards.
    epoch: current (0-based) epoch.
    seg: starting weight. The default is SEG_LOSS_WEIGHT as it was when this cell
      was executed (default values are evaluated once, at definition time).

  Returns:
    The segmentation-loss weight in effect after the call.
  """
  global SEG_LOSS_WEIGHT

  if format == 'step':
      # Focus on segmentation first, then shift towards localization. The weight is
      # derived from the epoch (not only set at the switch epochs), so resuming
      # training or re-running main() gives the correct weight.
      if epoch < 20:
        SEG_LOSS_WEIGHT = seg
      elif epoch < 40:
        SEG_LOSS_WEIGHT = 10
      else: # after 40 epochs segmentation of CTR vs BG is good enough => balance
        SEG_LOSS_WEIGHT = min

  elif format == 'linear':
      # inverse-square decay: seg / (epoch+1)^2 + min (not linear, despite the name)
      a = 1/((1+abs(epoch))*(1+abs(epoch)))
      SEG_LOSS_WEIGHT = a*seg + min

  elif format == 'exp':
      # Exponential decay from seg + min towards min. The previous formula,
      # math.exp((1 - epoch)*seg + min), overflowed (OverflowError) at epoch 0.
      SEG_LOSS_WEIGHT = seg*math.exp(-abs(epoch)) + min

  # fail case
  else:
      warnings.warn(f'NO SEG LOSS FACTOR FUNCTION DETECTED')

  return SEG_LOSS_WEIGHT
      

Save the loss curve of an experiment run - may be used to compare loss propagation between networks

In [None]:
def save_loss_propagation(loss_arr, title='Test Loss', epochs=100,save_folder=SAVE_FOLDER):
  """Plot a per-epoch loss curve, save it as <save_folder>/loss_propagation/<title>.png, then show it.

  Args:
    loss_arr: sequence of per-epoch loss values.
    title: plot title; also used as the file name.
    epochs: expected number of epochs. The x-axis is always built from
      len(loss_arr) (a mismatch used to crash plt.plot); a differing value
      (other than None) only triggers a warning.
    save_folder: root output folder (previously ignored in favour of SAVE_FOLDER).
  """
  loss = np.asarray(loss_arr)
  if epochs is not None and epochs != len(loss):
    warnings.warn(
        f'save_loss_propagation: expected {epochs} epochs, got {len(loss)} loss values')
  epoch_axis = np.arange(len(loss))

  # plot settings: epochs on the x-axis, loss on the y-axis
  fig = plt.figure()
  plt.title(f"{title}")
  plt.xlabel("Epoch")
  plt.ylabel("loss")
  plt.plot(epoch_axis, loss)

  # Save BEFORE showing: with inline / non-interactive backends plt.show() closes
  # the figure, so a later savefig would write an empty image.
  out_dir = os.path.join(save_folder, 'loss_propagation')
  os.makedirs(out_dir, exist_ok=True)
  fig.savefig(os.path.join(out_dir, f'{title}.png'))

  # show plot of loss
  plt.show()

Segmentation Algorithm and Multi-learning Algorithm

In [None]:

if EXPERIMENTAL_MODEL == 'MultiSFPNet': # Multi-learning Network with Segmentation Feedback propagation
  class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    Stride 1 with padding 1 keeps the spatial size; only the channel count
    changes, from in_channels to out_channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, 3, 1, 1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        # one Sequential named `conv`, so the state_dict keys stay conv.0 ... conv.5
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        out = self.conv(x)
        return out

  # Double headed UNET
  class UNET(nn.Module):
    """MultiSFPNet: UNet segmentation plus a keypoint branch fed back by the decoder.

    * Segmentation: downs -> bottleneck -> ups -> final_conv (standard UNet).
    * Localization: a second encoder (downs_local); its output is fused with the
      segmentation decoder's feature maps by transfer blocks (tb) with additive
      fusion, reduced to features_local[0] channels (deco_local) and regressed by
      fc1 -> fc2.

    forward returns [segmentation logits, keypoint vector of length out_channels_local].

    NOTE: conv1x1, seg_to_concat and final_conv_local are built but not used in
    forward. They are kept so that parameter-initialisation order and checkpoint
    keys stay unchanged.
    """
    # output 3 channels (3 tubes)
    def __init__(
            self, in_channels=3, out_channels=1, out_channels_local=2, features=[64, 128, 256, 512], features_local=[64, 128, 256]):
        super(UNET, self).__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.downs_local = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.features = features
        self.features_local = features_local
        self.conv1x1 = nn.ModuleList()

        self.tb = nn.ModuleList() # transfer block => fusing module for segmentation and localization information
        self.deco_local = nn.ModuleList()

        # Down part of UNET (downsampling)
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        in_channels = 3

        # encoder of the localization branch (same input image)
        for feature in features_local:
            self.downs_local.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # one 1x1 projection to features_local[-1] channels per feature except the last
        for feature in features_local:
            if (feature == features_local[-1]):
                break
            self.conv1x1.append(nn.Conv2d(feature, features_local[-1], 1))

        # Up part of UNET (upsampling): (ConvTranspose2d, DoubleConv) pairs
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(
                    feature*2, feature, kernel_size=2, stride=2
                )
            )
            self.ups.append(DoubleConv(feature*2, feature))

        # transfer blocks: (ConvTranspose2d, DoubleConv) pairs for 256 and 64 channels only
        for feature in reversed(features_local):
            if feature == 128:
                continue
            self.tb.append(
                nn.ConvTranspose2d(
                    feature, feature, kernel_size=2, stride=2
                )
            )
            self.tb.append(DoubleConv(feature*2, feature))

        # decoding to 64 channels from 256 channels
        self.deco_local.append(
            nn.Conv2d(features_local[-1], features_local[0], 3, 1, 1, bias=False))
        self.deco_local.append(
            nn.BatchNorm2d(features_local[0]))
        self.deco_local.append(
            nn.ReLU(inplace=True))

        self.bottleneck = DoubleConv(features[-1], features[-1]*2)

        # final 2d layer
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

        # double convolutions (unused in forward, see class docstring)
        self.seg_to_concat = DoubleConv(self.features[0], self.features_local[-1])
        self.final_conv_local = DoubleConv(self.features_local[-1]*2, self.features_local[0])

        # Fully connected localization head. The localization encoder pools three
        # times, so its maps are (H/8, W/8): presumably 60x80 for 480x640 input
        # (LARGE_IMAGESIZE) and 15x20 otherwise - TODO confirm against IMAGE_HEIGHT/WIDTH.
        if LARGE_IMAGESIZE == True:
            self.fc1 = nn.Linear(64*60*80, 32)
        else:
            self.fc1 = nn.Linear(64*15*20, 32)
        self.fc2 = nn.Linear(32, out_channels_local)

    def forward(self, x):
        skip_connections = []
        maps = [] # decoder outputs, reused to fuse into the localization branch

        x_local = torch.clone(x) # localization input (same image)

        # segmentation encoder
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        # second encoder for localization
        for down in self.downs_local:
            x_local = down(x_local)
            x_local = self.pool(x_local)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1] # reverse list

        # segmentation decoder
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx//2]

            # resize when pooling floored an odd size, so the concat still works
            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx+1](concat_skip)

            maps.append(x) # channels per stage with default features: [512, 256, 128, 64]

        # Transfer blocks: fuse decoder maps into the localization features, then
        # add the result onto the localization input (fusing by addition).
        for idx in range(0, len(self.tb), 2):
            local_input = torch.clone(x_local)

            # apply trans conv on localization part
            x_local = self.tb[idx](x_local)

            # idx 0 -> maps[1] (256 ch), idx 2 -> maps[3] (64 ch); 512/128-ch maps are skipped
            seg_map = maps[idx+1]

            if x_local.shape != seg_map.shape:
                x_local = TF.resize(x_local, size=seg_map.shape[2:])

            concat_map = torch.cat((seg_map, x_local), dim=1)
            x_local = self.tb[idx+1](concat_map)

            # back to the localization resolution for the addition
            if x_local.shape != local_input.shape:
                x_local = TF.resize(x_local, size=local_input.shape[2:])

            x_local = torch.add(x_local, local_input)

            # after the 256-channel stage, reduce to features_local[0] channels
            if idx == 0:
                for layer in self.deco_local:
                    x_local = layer(x_local)

        # flatten(1) keeps the batch axis; the former view(-1, 64*H*W) could silently
        # regroup samples across the batch when the spatial size did not match fc1.
        x_local = torch.flatten(x_local, 1)

        return [self.final_conv(x), self.fc2(self.fc1(x_local))]

elif EXPERIMENTAL_MODEL == 'MultiNCONet': # Multi-Learning Network Non-Connected Output
  
  class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)

  # Double headed UNET
  class UNET(nn.Module):
    """MultiNCONet: UNet segmentation whose final decoder features also feed the keypoint regressor.

    forward runs a standard UNet (downs -> bottleneck -> ups). The last decoder
    map is used twice: final_conv produces the segmentation logits, and the same
    map resized to 60x80, flattened, and passed through fc1 -> fc2 gives the
    keypoint vector (length out_channels_local).

    NOTE(review): downs_local, conv1x1, seg_to_concat and final_conv_local are
    built in __init__ but never used in forward (they still add parameters).
    """
    # output 3 channels (3 tubes)
    def __init__(
            self, in_channels=3, out_channels=1, out_channels_local=2, features=[64, 128, 256, 512], features_local=[64, 128, 256]):
        super(UNET, self).__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.downs_local = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.features = features
        self.features_local = features_local
        self.conv1x1 = nn.ModuleList()

        # Down part of UNET (downsampling)
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        in_channels = 3

        # localization encoder (unused in forward)
        for feature in features_local:
            self.downs_local.append(DoubleConv(in_channels, feature))
            in_channels = feature

        for feature in features_local: # one 1x1 projection per feature except the last
            # stop at the last feature
            if (feature == features_local[-1]):
              break
            self.conv1x1.append(nn.Conv2d(feature, features_local[-1], 1))

        # Up part of UNET (upsampling): (ConvTranspose2d, DoubleConv) pairs
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(
                    feature*2, feature, kernel_size=2, stride=2
                )
            )
            self.ups.append(DoubleConv(feature*2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1]*2)

        # final 2d layer
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

        # double convolutions (unused in forward)
        self.seg_to_concat = DoubleConv(self.features[0], self.features_local[-1])
        self.final_conv_local = DoubleConv(self.features_local[-1]*2, self.features_local[0])

        # final 2d layer for localization (fully connected layer)
        # 64*60*80 matches the forward's resize to 60x80; assumes features[0] == 64
        self.fc1 = nn.Linear(64*60*80, 32)
        self.fc2 = nn.Linear(32, out_channels_local)

    def forward(self, x):
        skip_connections = []
        maps = [] # unused in this variant

        #x_local = torch.clone(x) # localization

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1] # reverse list

        # decoding part
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx//2] # integer division

            # resize when pooling floored an odd size, so the concat still works
            if x.shape != skip_connection.shape:
                # do resizing
                x = TF.resize(x, size=skip_connection.shape[2:])

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx+1](concat_skip)

        #print(x.shape)

        #x_local = torch.clone(x).view(-1, 64*480*640)


        # keypoint head: the decoder map is resized to a fixed 60x80, so fc1's input size
        # does not depend on the input image size
        return [self.final_conv(x), self.fc2(self.fc1(TF.resize(x, size=(60, 80)).view(-1, 64*60*80)))]

elif EXPERIMENTAL_MODEL == 'DH-Unet': # Double Headed UNet
  
    class DoubleConv(nn.Module):
      def __init__(self, in_channels, out_channels):
          super(DoubleConv, self).__init__()
          self.conv = nn.Sequential(
              nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
              nn.BatchNorm2d(out_channels),
              nn.ReLU(inplace=True),
              nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
              nn.BatchNorm2d(out_channels),
              nn.ReLU(inplace=True)
          )

      def forward(self, x):
          return self.conv(x)

    # Double headed UNET
    class UNET(nn.Module):
        """DH-Unet: UNet segmentation head plus a second, keypoint-localization head.

        The localization encoder's intermediate maps are pooled, projected to
        features_local[-1] channels (conv1x1) and added onto the encoder output.
        The segmentation decoder output is projected (seg_to_concat) and
        concatenated with that fused map, then reduced to features_local[0]
        channels (final_conv_local) and regressed by fc1 -> fc2.

        forward returns [segmentation logits, keypoint vector of length out_channels_local].
        """
        # output 3 channels (3 tubes)
        def __init__(
                self, in_channels=3, out_channels=1, out_channels_local=2, features=[64, 128, 256, 512], features_local=[64, 128, 256]):
            super(UNET, self).__init__()
            self.ups = nn.ModuleList()
            self.downs = nn.ModuleList()
            self.downs_local = nn.ModuleList()
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
            self.features = features
            self.features_local = features_local
            self.conv1x1 = nn.ModuleList()

            # Down part of UNET (downsampling)
            for feature in features:
                self.downs.append(DoubleConv(in_channels, feature))
                in_channels = feature

            in_channels = 3

            # encoder of the localization head (same input image)
            for feature in features_local:
                self.downs_local.append(DoubleConv(in_channels, feature))
                in_channels = feature

            # one 1x1 projection to features_local[-1] channels per feature except the last
            for feature in features_local:
                if (feature == features_local[-1]):
                    break
                self.conv1x1.append(nn.Conv2d(feature, features_local[-1], 1))

            # Up part of UNET (upsampling): (ConvTranspose2d, DoubleConv) pairs
            for feature in reversed(features):
                self.ups.append(
                    nn.ConvTranspose2d(
                        feature*2, feature, kernel_size=2, stride=2
                    )
                )
                self.ups.append(DoubleConv(feature*2, feature))

            self.bottleneck = DoubleConv(features[-1], features[-1]*2)

            # final 2d layer
            self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

            # double convolutions
            self.seg_to_concat = DoubleConv(self.features[0], self.features_local[-1])
            self.final_conv_local = DoubleConv(self.features_local[-1]*2, self.features_local[0])

            # Fully connected localization head. The localization encoder pools three
            # times, so its maps are (H/8, W/8). Sized like MultiSFPNet (60x80 with
            # LARGE_IMAGESIZE, else 15x20) instead of always assuming 60x80.
            if LARGE_IMAGESIZE == True:
                self.fc1 = nn.Linear(64*60*80, 32)
            else:
                self.fc1 = nn.Linear(64*15*20, 32)
            self.fc2 = nn.Linear(32, out_channels_local)

        def forward(self, x):
            skip_connections = []
            maps = [] # localization encoder outputs before pooling

            x_local = torch.clone(x) # localization

            for down in self.downs:
                x = down(x)
                skip_connections.append(x)
                x = self.pool(x)

            # second encoder for localization
            for down in self.downs_local:
                x_local = down(x_local)
                maps.append(x_local)
                x_local = self.pool(x_local)

            x = self.bottleneck(x)
            skip_connections = skip_connections[::-1] # reverse list

            # decoding part
            for idx in range(0, len(self.ups), 2):
                x = self.ups[idx](x)
                skip_connection = skip_connections[idx//2]

                # resize when pooling floored an odd size, so the concat still works
                if x.shape != skip_connection.shape:
                    x = TF.resize(x, size=skip_connection.shape[2:])

                concat_skip = torch.cat((skip_connection, x), dim=1)
                x = self.ups[idx+1](concat_skip)

            # Second head for localization: add the pooled + projected intermediate
            # encoder maps onto the encoder output
            fused_map = torch.clone(x_local)
            output_seg = torch.clone(x)

            # the last map is the un-pooled version of x_local itself, so it is skipped
            for idx in range(len(maps)-1):
                enc_map = self.conv1x1[idx](self.pool(maps[idx]))

                if x_local.shape != enc_map.shape:
                    enc_map = TF.resize(enc_map, size=x_local.shape[2:])

                fused_map = torch.add(fused_map, enc_map)

            x_local = torch.clone(fused_map)

            # project the segmentation decoder output (features[0] channels) to
            # features_local[-1] channels and match the localization resolution
            seg_pred_map = self.seg_to_concat(output_seg)
            if x_local.shape != seg_pred_map.shape:
                seg_pred_map = TF.resize(seg_pred_map, size=x_local.shape[2:])

            concat_local = torch.cat((seg_pred_map, x_local), dim=1)
            # double conv down to features_local[0] channels
            x_local = self.final_conv_local(concat_local)

            # flatten(1) keeps the batch axis; the former view(-1, 64*60*80) could
            # silently regroup samples across the batch when the size did not match fc1.
            x_local = torch.flatten(x_local, 1)

            return [self.final_conv(x), self.fc2(self.fc1(x_local))]

else:
  warnings.warn(f'NO MODEL WAS DETECTED WITH THE NAME: {EXPERIMENTAL_MODEL}')
  warnings.warn(f'The available models are: MultiSFPNet  DH-Unet  MultiNCONet') 


Dataset

In [None]:

class CTRDatasetNpy(Dataset):
    """Dataset over three stacked ``.npy`` arrays: images, masks and keypoints.

    All three arrays are loaded into memory with ``np.load`` at construction and
    indexed along their first axis; ``__len__`` uses the image array.
    """

    def __init__(self, images_file, masks_file, locals_file, transform=None):
        self.images_file = images_file
        self.masks_file = masks_file
        self.keypoints_file = locals_file
        self.transform = transform
        self.images = np.load(images_file) # load numpy data of all images
        self.masks = np.load(masks_file) # load numpy data of all masks
        self.keypoints = np.load(locals_file) # load npy file with keypoints

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        """Return ``(image, mask, keypoints)`` for sample *index*.

        keypoints is a flat float32 tensor [x0, y0, x1, y1, ...], rounded to the
        nearest integer pixel. When a transform is set, image and mask are whatever
        it returns under the "image" / "mask" keys.
        """
        image = self.images[index]
        mask = self.masks[index] # mask data is already normalized
        # Copy: the previous code edited self.keypoints in place on every access.
        # Rows are (x, y) pairs (format='xy').
        keypoints = np.array(self.keypoints[index], dtype=np.float32)

        # Some annotations reach the bottom/right edge (e.g. where the CTR base
        # leaves the frame). Pull any coordinate >= the image size back onto the
        # last pixel, presumably so the augmentation's keypoint validation accepts it.
        # The old "- 1" left values >= size + 1 out of bounds.
        height, width = image.shape[:2]
        keypoints[:, 0] = np.where(keypoints[:, 0] >= width, width - 1, keypoints[:, 0])
        keypoints[:, 1] = np.where(keypoints[:, 1] >= height, height - 1, keypoints[:, 1])

        if self.transform is not None: # do data augmentation
            augmentations = self.transform(image=image, mask=mask, keypoints=keypoints)
            image = augmentations["image"]
            mask = augmentations["mask"]
            keypoints = augmentations["keypoints"] # outputs list object

        # list / array -> float tensor, rounded to the closest integer pixel
        keypoints = torch.round(torch.as_tensor(np.asarray(keypoints, dtype=np.float32)))

        # K points as a flat vector of 2*K values
        keypoints = torch.reshape(keypoints, (-1,))

        return image, mask, keypoints


Utils

In [None]:
KEYPOINT_COLOR = (0, 255, 0) # Green

def vis_keypoints(image, keypoints, color=KEYPOINT_COLOR, diameter=5):
    """Return a copy of *image* scaled by 255 with filled circles drawn at *keypoints*.

    Args:
        image: channels-last 3-channel image, converted with cv2.COLOR_BGR2RGB
            before drawing (presumably float values in [0, 1] from the
            normalized tensor - TODO confirm).
        keypoints: flat sequence [x0, y0, x1, y1, ...]. Every (x, y) pair is drawn
            (previously only the first four).
        color: circle colour. Now actually used; the old code always drew KEYPOINT_COLOR.
        diameter: passed to cv2.circle as the circle *radius*.

    Returns:
        The annotated float image (the caller casts it to uint8).
    """
    image = image.copy()

    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # remove normalization from image
    image = image*255

    for x, y in zip(keypoints[0::2], keypoints[1::2]):
        cv2.circle(image, (int(x), int(y)), diameter, color, -1)

    return image


def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Announce, then serialize *state* (typically model/optimizer state dicts) to *filename*."""
    print("=> Saving checkpoint")
    torch.save(obj=state, f=filename)


def load_checkpoint(checkpoint, model, optimizer=None):
    """Restore *model* (and optionally *optimizer*) from a checkpoint dict.

    Args:
        checkpoint: dict with a "state_dict" entry and, optionally, an
            "optimizer" entry (the layout main() saves each epoch).
        model: module that receives checkpoint["state_dict"].
        optimizer: optional optimizer. When given and the checkpoint has an
            "optimizer" entry, its state (e.g. Adam moments, learning rate) is
            restored too, so training can resume instead of restarting the optimizer.
    """
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["state_dict"])
    if optimizer is not None and "optimizer" in checkpoint:
        optimizer.load_state_dict(checkpoint["optimizer"])


def get_loaders(
    train_dir,
    train_maskdir,
    train_keypointdir,
    val_dir,
    val_maskdir,
    val_keypointdir,
    batch_size,
    train_transform,
    val_transform,
    num_workers=4,
    pin_memory=True,
):
    """Build the training and validation DataLoaders.

    Each image/mask/keypoint argument is passed to CTRDatasetNpy (which loads it
    with np.load). Both loaders share batch size, worker count and pin_memory.
    Only the training loader shuffles.

    Returns:
        (train_loader, val_loader)
    """

    def build(images, masks, keypoints, transform, shuffle):
        # one CTRDatasetNpy wrapped in a DataLoader with the shared settings
        dataset = CTRDatasetNpy(images, masks, keypoints, transform)
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
            shuffle=shuffle,
        )

    train_loader = build(train_dir, train_maskdir, train_keypointdir, train_transform, shuffle=True)
    val_loader = build(val_dir, val_maskdir, val_keypointdir, val_transform, shuffle=False)
    return train_loader, val_loader


def save_predictions_as_imgs(
    loader, model, saved_image_index, folder=f'{SAVE_FOLDER}/output_masks/', device="cuda"):
    """Run ``model`` over ``loader`` and write keypoint and mask visualisations to ``folder``.

    For every batch ``B`` and sample ``I`` this writes:
      * ``kpt_pred_epoch{E}_batch{B}_img{I}.png``: predicted keypoints drawn on the input;
      * ``mask_pred_epoch{E}_batch{B}_img{I}.png``: colour-coded predicted mask (``seg2img``).
    When ``saved_image_index == 0`` it also writes, for the first batch only, the
    ground-truth keypoint overlays (``groundtruth_0_{I}.png``) and one grid image
    of the ground-truth masks (``mask_0.png``).

    Args:
        loader: iterable of ``(image, mask, keypoints)`` batches.
        model: network returning ``[mask_scores, keypoints]``. It is put in eval
            mode for the pass and always switched back to train mode afterwards.
        saved_image_index: tag (e.g. the epoch) embedded in the file names.
        folder: output directory. This function does not create it.
        device: device the images are moved to before inference.
    """
    model.eval()
    try:
        for batch_idx, (x, y, z) in enumerate(loader):  # image, mask, keypoints
            x = x.to(device=device)
            with torch.no_grad():
                [output_mask, output_keypoints] = model(x)

                if KEYPOINT_NORM:
                    # squashes the keypoint head output into (-0.5, 0.5)
                    output_keypoints = 0.5 * torch.tanh(output_keypoints)

                ### keypoints processing and saving
                images = x.detach().cpu().numpy()
                outputs_keypoints = output_keypoints.detach().cpu()
                gt_keypoints = z.detach().cpu()

                if not KEYPOINT_NORM:
                    outputs_keypoints = outputs_keypoints.numpy()
                    gt_keypoints = gt_keypoints.numpy()

                # predicted keypoints drawn on every image of the batch
                for pred_idx in range(len(outputs_keypoints)):
                    image = np.moveaxis(images[pred_idx], 0, -1)  # CHW -> HWC
                    pred_image = vis_keypoints(image, outputs_keypoints[pred_idx])
                    im_pred = Image.fromarray(pred_image.astype(np.uint8))
                    im_pred.save(f"{folder}/kpt_pred_epoch{saved_image_index}_batch{batch_idx}_img{pred_idx}.png")

                # ground truths are written only once: first batch of the first call
                if batch_idx == 0 and saved_image_index == 0:
                    for pred_idx in range(len(gt_keypoints)):
                        image = np.moveaxis(images[pred_idx], 0, -1)
                        gt_image = vis_keypoints(image, gt_keypoints[pred_idx])
                        im_gt = Image.fromarray(gt_image.astype(np.uint8))
                        im_gt.save(f"{folder}/groundtruth_{saved_image_index}_{pred_idx}.png")

                    # the whole batch of masks goes into a single grid image, so save it once
                    torchvision.utils.save_image(y.float().unsqueeze(1), f"{folder}/mask_{saved_image_index}.png", normalize=True)

                print(f"The model predicted: {outputs_keypoints}; and the actual keypoints are: {gt_keypoints}")
                print("-------------------------------------------------------------------")
                print("-------------------------------------------------------------------")
                print("-------------------------------------------------------------------")
                print(f"The model predicted: {outputs_keypoints[-1]}; and the actual keypoints are: {gt_keypoints[-1]}")

                ### mask processing and saving
                # The mask head presumably returns raw scores (accuracy() applies a
                # softmax to the same output). seg2img's saturation shading assumes
                # non-negative class probabilities, so normalise over the class axis.
                mask_probs = torch.softmax(output_mask, dim=1).detach().cpu().numpy()

                for out_mask_idx in range(len(mask_probs)):
                    seg = np.moveaxis(mask_probs[out_mask_idx], 0, -1)  # (C, H, W) -> (H, W, C)

                    print(f"Unique values of seg: {np.unique(seg)}")

                    ## coloring
                    my_seg = seg2img(seg)  # uint8, BGR channel order

                    print(f"Unique values of coloured seg: {np.unique(my_seg)}")

                    # PIL interprets arrays as RGB; convert so class colours are not swapped
                    im = Image.fromarray(cv2.cvtColor(my_seg, cv2.COLOR_BGR2RGB))
                    im.save(f"{folder}/mask_pred_epoch{saved_image_index}_batch{batch_idx}_img{out_mask_idx}.png")

                print(f"Saved prediction: {saved_image_index}")
    finally:
        model.train()


def seg2img(seg: np.ndarray) -> np.ndarray:
    """Colour-code a segmentation map for visualisation.

    Two layouts are accepted:

    * ``[H, W]`` integer class indices in ``0..3``: each pixel gets its class
      colour (black, red, green, blue; stored in BGR order).
    * ``[H, W, C]`` per-class scores: each pixel gets the colour of its arg-max
      class, with saturation ``(p1 - p2) / p1``, where ``p1``/``p2`` are the
      highest and second-highest scores. Uncertain pixels therefore look
      washed out. Class-0 pixels are drawn grey with value ``255 * (1 - saturation)``.
      Scores are expected to be non-negative probabilities (e.g. softmax
      output). Pixels whose top score is not positive get zero saturation, and
      saturation is clipped to ``[0, 1]``.

    Returns:
        ``uint8`` image of shape ``[H, W, 3]`` in OpenCV's BGR channel order.

    Raises:
        ValueError: if class indices fall outside the colour table, or if ``seg``
            is neither 2-D nor 3-D.
    """
    colours = np.array(  # Colour triplets in cv2 convention (BGR instead of RGB)
        [[0, 0, 0],      # Black
         [0, 0, 255],    # Red
         [0, 255, 0],    # Green
         [255, 0, 0]],   # Blue
        dtype='uint8'
    )

    if seg.ndim == 2:  # [H, W] containing class indices
        if np.min(seg) < 0 or np.max(seg) >= len(colours):
            raise ValueError("Incorrect number of classes in seg array")
        return colours[seg]

    if seg.ndim != 3:
        raise ValueError(
            f"seg must be 2-D [H, W] or 3-D [H, W, C], got {seg.ndim}-D"
        )

    # [H, W, C] with C containing class scores
    class_map = np.argmax(seg, axis=2)
    img = seg2img(class_map)

    # HSV gives direct access to the saturation channel
    hsv_img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)

    scores = np.copy(seg)  # copy so the caller's array is not mutated
    max_vals = np.max(scores, axis=-1)  # score of the dominant class

    # zero the dominant class so the next max is the runner-up score
    ind = tuple(np.indices(scores.shape[:-1])) + (class_map,)
    scores[ind] = 0
    sec_vals = np.max(scores, axis=-1)

    # the closer the runner-up is to the dominant score, the lower the saturation;
    # guard against dividing by a zero / negative top score
    positive = max_vals > 0
    safe_max = np.where(positive, max_vals, 1)
    sat_vals = np.clip(np.where(positive, (max_vals - sec_vals) / safe_max, 0), 0, 1)

    hsv_img[..., 1] = np.uint8(sat_vals * 255)
    hsv_img[..., 2] = 255
    img = cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR)

    # background (black) pixels become grey: value 255 * (1 - saturation)
    background = class_map == 0
    img[background] = 255 * (1 - sat_vals[background])[..., np.newaxis]

    return img

Accuracy Utils

In [None]:
# segmentation accuracy (dice + good pixels)
def accuracy(loader, model, device="cuda"):
    """Evaluate ``model`` on ``loader`` and print segmentation and keypoint metrics.

    Printed metrics:
      * pixel accuracy: share of pixels whose arg-max class equals the mask value;
      * Dice score: ``dice_coef2`` on the softmaxed masks, averaged over batches;
      * per-keypoint mean endpoint distance: ``mean_endpoint_dist`` summed over
        all samples, divided by the number of samples;
      * global mean / std keypoint error: ``med_std_keypoint_err`` averaged
        over batches.

    The model is put in eval mode for the pass and always switched back to
    train mode afterwards, even if an exception is raised.

    Args:
        loader: iterable of ``(image, mask, keypoints)`` batches.
        model: network returning ``[mask_scores, keypoints]``, with the class
            axis of ``mask_scores`` at dim 1.
        device: device images and masks are moved to for inference.
    """
    softmax = nn.Softmax(dim=1)

    num_correct = 0
    num_pixels = 0
    dice_score = 0
    mean_dist = np.zeros([4])  # summed endpoint distance per keypoint
    g_dist_err = 0
    g_std_err = 0
    validation_size = 0  # number of samples seen
    num_batches = 0

    model.eval()
    try:
        with torch.no_grad():
            for x, y, z in loader:  # image, mask, keypoints
                x = x.to(device)
                y = y.to(device)

                [output_mask, keypoints_preds] = model(x)
                output_mask = softmax(output_mask)

                # pixel accuracy; assumes y holds one class index per pixel in the
                # same (B, H, W) layout that dice_coef2 compares against. TODO confirm
                pred_classes = output_mask.argmax(dim=1)
                num_correct += (pred_classes == y).sum().item()
                num_pixels += pred_classes.numel()

                keypoints_preds = keypoints_preds.detach().cpu()
                gt_keypoints = z.detach().cpu()
                mask = y.detach().cpu()

                validation_size += len(gt_keypoints)  # add batch size
                num_batches += 1

                # dice_coef2 wants class scores on the last axis: (B, C, H, W) -> (B, H, W, C)
                channels_last = output_mask.permute(0, 2, 3, 1)
                dice_score += dice_coef2(mask.numpy(), channels_last.cpu().numpy())

                # keypoint accuracy
                mean_dist += mean_endpoint_dist(gt_keypoints, keypoints_preds)

                # overall keypoint error
                mean_err, std_err = med_std_keypoint_err(gt_keypoints, keypoints_preds)
                g_dist_err += float(mean_err)
                g_std_err += float(std_err)
    finally:
        model.train()

    if num_batches == 0:
        print("No validation batches - nothing to report")
        return

    pixel_acc = num_correct / num_pixels * 100 if num_pixels else 0.0
    print(f"Got {num_correct}/{num_pixels} with acc {pixel_acc:.2f}")
    print(f"Dice score: {dice_score*100/num_batches:.2f} %")

    for idx in range(len(mean_dist)):
        print(f"Keypoint (Mean Endpoint Distance) {idx}: {mean_dist[idx]/validation_size:.2f} pixels (tip, tube 2, tube 1, base)")

    print(f"Global Mean Err score: {g_dist_err/num_batches:.2f} pixels")
    print(f"Global Std Err score: {g_std_err/num_batches:.2f} pixels")



def dice_coef2(y_true, y_pred, no_classes=4):
    """Mean per-class Dice coefficient between a label map and class scores.

    Args:
        y_true: array of class labels.
        y_pred: per-class scores with the class axis last. Its arg-max over that
            axis is used as the predicted label map, which must match
            ``y_true``'s shape.
        no_classes: classes ``0 .. no_classes - 1`` that are scored.

    Returns:
        Unweighted mean over classes of
        ``(2 * |pred & gt| + 1e-4) / (|pred| + |gt| + 1e-4)``. A class absent
        from both maps therefore scores 1.
    """
    smooth = 0.0001
    predicted_labels = np.argmax(y_pred, axis=-1)
    labels = np.asarray(y_true)

    total = 0
    for cls in range(no_classes):
        # binary masks of where this class is predicted / annotated
        pred_mask = predicted_labels == cls
        true_mask = labels == cls
        overlap = np.sum(pred_mask * true_mask)
        total += (2. * overlap + smooth) / (np.sum(pred_mask) + np.sum(true_mask) + smooth)

    return total / no_classes


def dice_coef(y_true, y_pred, index=0):
    """Smoothed Dice coefficient of two same-shaped (binary) masks.

    Computes ``(2 * sum(y_true * y_pred) + 1e-4) / (sum(y_true) + sum(y_pred) + 1e-4)``.
    The inputs are used as given, with no flattening or thresholding. The
    smoothing term makes two empty masks score exactly 1.

    Args:
        y_true: first mask.
        y_pred: second mask, same shape as ``y_true``.
        index: unused; kept for backward compatibility.
    """
    eps = 0.0001
    overlap = np.sum(y_true * y_pred)
    denominator = np.sum(y_true) + np.sum(y_pred) + eps
    return (2. * overlap + eps) / denominator



# mean endpoint distance for each keypoints (for each tip, tube 2, tube 1, base)
def mean_endpoint_dist(keypoints_gt, keypoint_preds):  # for batch
    """Per-keypoint Euclidean endpoint error, summed over the batch.

    Args:
        keypoints_gt: ground-truth tensor that reshapes to ``(B, K, 2)``, e.g.
            flat ``(B, 2*K)`` rows ``[x0, y0, x1, y1, ...]``.
        keypoint_preds: predictions in a layout with the same element order.

    Returns:
        float64 numpy array of length ``K``. For each keypoint it holds the sum
        over the batch of ``sqrt(dx**2 + dy**2)`` (divide by the sample count
        for the mean). An empty batch returns ``np.zeros(4)``.
    """
    batch_size = len(keypoints_gt)
    if batch_size == 0:
        return np.zeros([4])

    gt = keypoints_gt.detach().to(torch.float64).reshape(batch_size, -1, 2)
    preds = keypoint_preds.detach().to(torch.float64).reshape(batch_size, -1, 2)

    # full Euclidean distance per keypoint: (B, K)
    dists = ((preds - gt) ** 2).sum(dim=-1).sqrt()

    return dists.sum(dim=0).cpu().numpy()


def med_std_keypoint_err(keypoints_gt, keypoint_preds):
    """Mean and standard deviation of the absolute keypoint coordinate error.

    Args:
        keypoints_gt: ground-truth keypoints, viewable as ``(B, 8)``.
        keypoint_preds: predicted keypoints of shape ``(B, 8)``.

    Returns:
        ``[mean, std]`` as 0-d tensors, taken over every x and y coordinate of
        every sample. ``std`` uses the ``torch.std`` default (Bessel-corrected).
    """
    flat_gt = keypoints_gt.float().view(len(keypoints_gt), 8)
    abs_err = (keypoint_preds.float() - flat_gt).abs()
    return [abs_err.mean(), abs_err.std()]
