In [None]:
import os
import shutil
import zipfile
import cv2
import numpy as np
import torch
import time
from torch import nn
import random
from torchvision.io import read_image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import matplotlib.pyplot as plt
import torch.nn.functional as F
from scipy.ndimage import distance_transform_bf
from torchvision.transforms.functional import center_crop
# Make sure autograd is globally enabled (PyTorch's default); the evaluation
# helpers below switch it off locally with `torch.no_grad()`.
torch.set_grad_enabled(True)

In [9]:
# Unpack the dataset archive stored on Google Drive into the working directory.
# NOTE(review): this needs Drive to be mounted, but the drive.mount cell comes
# after this one; also this path uses "My Drive" while the checkpoint paths
# later in the notebook use "MyDrive" — confirm which one exists.
with zipfile.ZipFile(file='/content/drive/My Drive/orig_dataset.zip', mode='r') as zip_ref:
    zip_ref.extractall(path='./')

In [None]:
# Mount Google Drive so the dataset archive and model checkpoint are reachable.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

In [11]:
def get_random_click(ground_truth, prediction):
    """Sample the next simulated user click from the current prediction error.

    The error type with the larger total is corrected: where the ground truth
    exceeds the prediction (false negatives for binary masks) a positive click
    is produced, otherwise a negative click. Inside the chosen error region a
    pixel is drawn with probability proportional to its distance to the
    nearest pixel outside the region (``cv2.distanceTransform``), so clicks
    favour region interiors.

    Args:
        ground_truth: label mask; cast to ``int`` (fractional values truncate).
        prediction: predicted mask; cast to ``int``.

    Returns:
        ``(index, click_type)``: ``index`` is a 2-D pixel index into the
        distance map, ``click_type`` is ``True`` for a positive click and
        ``False`` for a negative one. ``index`` is ``(-1, -1)`` when the chosen
        error region is empty (e.g. a perfect prediction).
    """
    diff = ground_truth.astype(int) - prediction.astype(int)

    # positive part: gt > pred (false negatives); negative part: pred > gt (false positives)
    false_neg = np.clip(diff, 0, None)
    false_pos = np.abs(np.clip(diff, None, 0))

    # correct whichever error is larger; ties produce a positive click
    click_type = not (np.sum(false_pos) > np.sum(false_neg))
    region = (false_neg if click_type else false_pos).astype(np.uint8)

    # distance of every region pixel to the closest pixel outside the region
    distances = cv2.distanceTransform(region, cv2.DIST_L2, 5)
    total = np.sum(distances)
    if total == 0:
        # nothing to correct
        return (-1, -1), click_type

    probabilities = distances / total
    flat = probabilities.flatten()
    chosen = np.random.choice(np.arange(len(flat)), p=flat)
    return np.unravel_index(chosen, probabilities.shape), click_type

Convolutional block which combines 2D convolution, batch normalization and ReLU activation into one.

In [12]:
class CNNBlock(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d -> ReLU, wrapped in one ``nn.Sequential``.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the convolution.
        kernel_size: convolution kernel size (default 3).
        stride: convolution stride (default 1).
        padding: zero padding of the convolution (default 0, i.e. "valid").
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=0):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        norm = nn.BatchNorm2d(out_channels)
        act = nn.ReLU(inplace=True)
        # The positional order 0/1/2 determines the state_dict keys
        # (seq_block.0.weight, seq_block.1.*); the checkpoint loaded later relies on them.
        self.seq_block = nn.Sequential(conv, norm, act)

    def forward(self, x):
        return self.seq_block(x)

Stack of N convolutional blocks applied in sequence (N = 2 in the basic UNET, as used by the encoder and decoder below).

In [13]:
class CNNBlocks(nn.Module):
    """Sequence of ``n_conv`` ``CNNBlock`` layers.

    The first block maps ``in_channels`` -> ``out_channels``; every following
    block keeps ``out_channels`` -> ``out_channels``.

    Args:
        n_conv: number of stacked ``CNNBlock`` layers.
        in_channels: channels of the input feature map.
        out_channels: channels produced by every block.
        padding: padding passed to each ``CNNBlock``.
    """

    def __init__(self,
                 n_conv,
                 in_channels,
                 out_channels,
                 padding):
        super().__init__()
        self.layers = nn.ModuleList(
            CNNBlock(in_channels if idx == 0 else out_channels, out_channels, padding=padding)
            for idx in range(n_conv)
        )

    def forward(self, x):
        for block in self.layers:
            x = block(x)
        return x

Encoder part of the UNET.

In [14]:
class Encoder(nn.Module):
    """UNET contracting path.

    ``downhill`` times: a two-conv ``CNNBlocks`` followed by a 2x2 max-pool,
    doubling the channel count after each step; then a final two-conv
    ``CNNBlocks`` bottleneck.

    Args:
        in_channels: channels of the network input.
        out_channels: channels of the first ``CNNBlocks``.
        padding: padding for all 3x3 convolutions.
        downhill: number of down-sampling steps.

    ``forward`` returns ``(x, skips)``: the bottleneck output and the output of
    every ``CNNBlocks`` in order (the bottleneck output is the last entry).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 padding,
                 downhill=4):
        super().__init__()
        self.enc_layers = nn.ModuleList()

        channels_in, channels_out = in_channels, out_channels
        for _ in range(downhill):
            self.enc_layers.append(
                CNNBlocks(n_conv=2, in_channels=channels_in, out_channels=channels_out, padding=padding))
            self.enc_layers.append(nn.MaxPool2d(2, 2))
            # next step consumes this step's output and doubles the width
            channels_in, channels_out = channels_out, channels_out * 2

        # bottleneck
        self.enc_layers.append(
            CNNBlocks(n_conv=2, in_channels=channels_in, out_channels=channels_out, padding=padding))

    def forward(self, x):
        skips = []
        for layer in self.enc_layers:
            x = layer(x)
            if isinstance(layer, CNNBlocks):
                # keep conv outputs for the decoder's skip connections
                skips.append(x)
        return x, skips

Decoder part of the UNET

In [15]:
class Decoder(nn.Module):
    """UNET expanding path.

    ``uphill`` times: a 2x ``ConvTranspose2d`` up-sampling followed by a two-conv
    ``CNNBlocks`` applied to ``[upsampled, skip]`` concatenated along the channel
    axis; channel counts are halved after every step. A final 1x1 convolution
    maps to ``exit_channels`` and a sigmoid turns the logits into per-pixel
    probabilities.

    Args:
        in_channels: channels of the bottleneck feature map.
        out_channels: channels produced by the first up-sampling step
            (``UNET`` passes ``in_channels // 2``).
        exit_channels: number of output maps.
        padding: padding of the 3x3 convolutions (also used by the final 1x1 conv).
        uphill: number of up-sampling steps; must match the encoder's ``downhill``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 exit_channels,
                 padding,
                 uphill=4):
        super(Decoder, self).__init__()
        self.exit_channels = exit_channels
        self.layers = nn.ModuleList()

        for _ in range(uphill):
            self.layers += [
                nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
                # input = upsampled map (out_channels) + skip (out_channels) = in_channels
                CNNBlocks(n_conv=2, in_channels=in_channels,
                          out_channels=out_channels, padding=padding),
            ]
            # halve input and output channels on each step upwards
            in_channels //= 2
            out_channels //= 2

        # Final 1x1 convolution to the requested number of output maps.
        # NOTE(review): with padding=1 this 1x1 conv adds a one-pixel border
        # (output is (H+2) x (W+2)); callers center_crop it back. Kept so the
        # output shape is unchanged.
        self.layers.append(
            nn.Conv2d(in_channels, exit_channels, kernel_size=1, padding=padding),
        )

        # sigmoid turns logits into probabilities of pixels belonging to the mask
        self.layers.append(nn.Sigmoid())

    def forward(self, x, routes_connection):
        """Decode bottleneck ``x`` using the encoder's skip outputs ``routes_connection``.

        The caller's list is not modified.
        """
        skips = list(routes_connection)
        # the last route is the bottleneck output itself (it is `x`), not a skip
        skips.pop(-1)
        for layer in self.layers:
            if isinstance(layer, CNNBlocks):
                # crop the skip to the current height AND width so non-square inputs work
                skip = center_crop(skips.pop(-1), list(x.shape[-2:]))
                x = torch.cat([x, skip], dim=1)
            x = layer(x)
        return x

UNET model class

In [16]:
class UNET(nn.Module):
    """UNET built from ``Encoder`` and ``Decoder``.

    Args:
        in_channels: channels of the network input.
        first_out_channels: channels of the first encoder ``CNNBlocks``.
        exit_channels: number of output maps.
        downhill: number of down/up-sampling steps.
        padding: padding of the 3x3 convolutions (default 1).
    """

    def __init__(self,
                 in_channels,
                 first_out_channels,  # output of the first CNN block
                 exit_channels,
                 downhill,
                 padding=1
                 ):
        super().__init__()
        bottleneck_channels = first_out_channels * (2 ** downhill)
        first_up_channels = first_out_channels * (2 ** (downhill - 1))
        self.encoder = Encoder(in_channels, first_out_channels, padding=padding, downhill=downhill)
        self.decoder = Decoder(bottleneck_channels, first_up_channels,
                               exit_channels, padding=padding, uphill=downhill)

    def forward(self, x):
        bottleneck, skips = self.encoder(x)
        return self.decoder(bottleneck, skips)

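Dataset class that pairs each image with its label mask and appends simulated click maps (first positive click, empty negative map) as extra input channels.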

In [17]:
class CustomImageDataset(Dataset):
    """Image / label-mask pairs for click-based interactive segmentation.

    Directory layout (from the path construction below)::

        <dataset_dir>/<mode>/<sample_dir>/img/<image files>
        <dataset_dir>/<mode>/<sample_dir>/label/<label files>

    A label belongs to an image when the label filename starts with the image
    stem (text before the first '.'); see ``_is_label_for``.

    Each item is ``(image, label)`` as numpy arrays in (H, W, C) layout:
    ``image`` is the min-max normalised image with two channels appended (a
    Gaussian positive-click map simulating the first click, and an all-zero
    negative-click map); ``label`` is the min-max normalised mask.

    Args:
        dataset_dir: dataset root directory.
        mode: sub-directory to read (e.g. "training" / "testing").
        transform, target_transform: stored but NOT applied by ``__getitem__``.
        total_cnt: if given, keep only the first ``total_cnt`` pairs after shuffling.
        seed: optional seed for the shuffle; ``None`` uses the global ``random`` state.
    """

    def __init__(self, dataset_dir, mode, transform=None, target_transform=None, total_cnt=None, seed=None):
        self.dataset_dir = dataset_dir
        self.transform = transform  # NOTE(review): never applied in __getitem__
        self.target_transform = target_transform  # NOTE(review): never applied either
        self.image_data = []  # list of (image_path, label_path) tuples
        mode_dir = os.path.join(dataset_dir, mode)
        for imgnum_dir in os.listdir(mode_dir):
            img_dir = os.path.join(mode_dir, imgnum_dir, "img")
            label_dir = os.path.join(mode_dir, imgnum_dir, "label")
            labels = os.listdir(label_dir)  # listed once per sample directory
            for img in os.listdir(img_dir):
                stem = img.split(".")[0]
                for label in labels:
                    if self._is_label_for(stem, label):
                        self.image_data.append((os.path.join(img_dir, img), os.path.join(label_dir, label)))
        # shuffle before truncating so total_cnt picks a random subset
        if seed is None:
            random.shuffle(self.image_data)
        else:
            random.Random(seed).shuffle(self.image_data)
        if total_cnt is not None:
            self.image_data = self.image_data[:total_cnt]

    @staticmethod
    def _is_label_for(stem, label_name):
        """Return True if ``label_name`` belongs to the image whose stem is ``stem``.

        Prefix rule, except that a stem ending in a digit does not match a label
        continuing with another digit ("1" must not claim "10.png").
        """
        if not label_name.startswith(stem):
            return False
        rest = label_name[len(stem):]
        return not (stem[-1:].isdigit() and rest[:1].isdigit())

    @staticmethod
    def _min_max_normalize(arr):
        """Scale ``arr`` to [0, 1]; a constant array maps to 1.0 where non-zero, else 0.0 (instead of NaN)."""
        lo = np.min(arr)
        span = np.max(arr) - lo
        if span == 0:
            return (arr > 0).astype(np.float64)
        return (arr - lo) / span

    def __len__(self):
        return len(self.image_data)

    def __getitem__(self, idx):
        img_path, label_path = self.image_data[idx]

        # read_image gives (C, H, W); convert to (H, W, C) numpy arrays
        image = np.transpose(read_image(img_path).numpy(), (1, 2, 0))
        label = np.transpose(read_image(label_path).numpy(), (1, 2, 0))

        # normalize from 0-255 to 0-1 float range
        image = self._min_max_normalize(image)
        label = self._min_max_normalize(label)

        # simulate the first click against an empty prediction
        click_index, _ = get_random_click(label, np.zeros(label.shape))
        pos_click_map = np.zeros(label.shape[:2])
        if click_index != (-1, -1):  # (-1, -1) means there is nothing to click on
            pos_click_map[click_index] = 1
            pos_click_map = gauss_filter(pos_click_map, 2)
            pos_click_map = self._min_max_normalize(pos_click_map)
        pos_click_map = np.expand_dims(pos_click_map, axis=2)
        neg_click_map = np.zeros(label.shape)

        # channels: image, positive clicks, negative clicks
        image = np.concatenate((image, pos_click_map, neg_click_map), axis=2)

        return image, label

In [18]:
def gauss_filter(image, sigma):
    """Gaussian-blur ``image`` with standard deviation ``sigma``.

    ``ksize=(0, 0)`` lets OpenCV derive the kernel size from ``sigma``.
    """
    return cv2.GaussianBlur(src=image, ksize=(0, 0), sigmaX=sigma)

In [None]:
# Shuffled train / validation datasets (capped at 7500 / 2200 pairs) and loaders.
# NOTE(review): CustomImageDataset stores `transform` but never applies it.
train_dataset = CustomImageDataset("dataset", "training", transform=transforms.ToTensor(), total_cnt=7500)
val_dataset = CustomImageDataset("dataset", "testing", transform=transforms.ToTensor(), total_cnt=2200)
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=True)
# number of batches in each loader
print(len(train_dataloader), len(val_dataloader), sep="\n")

In [20]:
class diceloss(torch.nn.Module):
    """Soft Dice loss: ``1 - (2*sum(p*g) + s) / (sum(p^2) + sum(g^2) + s)``.

    Both tensors are flattened, so the loss is pooled over the whole batch.

    Args:
        smooth: additive smoothing term ``s`` (default 1) that keeps the ratio
            defined when both inputs are all zeros.
    """

    def __init__(self, smooth=1):
        super(diceloss, self).__init__()
        self.smooth = smooth

    def forward(self, pred, label):
        pred_flat = pred.contiguous().view(-1)
        label_flat = label.contiguous().view(-1)
        intersection = (pred_flat * label_flat).sum()
        pred_sq_sum = torch.sum(pred_flat * pred_flat)
        label_sq_sum = torch.sum(label_flat * label_flat)
        return 1 - ((2. * intersection + self.smooth) / (pred_sq_sum + label_sq_sum + self.smooth))

In [21]:
def _gaussian_click_map(shape, index, sigma=2):
    """Return a float map of ``shape`` with a Gaussian blob at ``index``, min-max scaled to [0, 1]."""
    click_map = np.zeros(shape)
    click_map[index] = 1
    click_map = gauss_filter(click_map, sigma)
    lo, hi = np.min(click_map), np.max(click_map)
    return (click_map - lo) / (hi - lo) if hi > lo else click_map


def add_clicks_to_batch(input, pred, label, k, total_k, probabilistic_clicks=False):
    """Simulate one corrective click per image and merge it into the click channels of ``input``.

    For each image, ``get_random_click`` picks a pixel in the larger error
    region of ``pred`` vs ``label``; a Gaussian blob around it is added to
    channel 1 (positive click) or channel 2 (negative click) of ``input``.
    Both click channels are then re-normalised to [0, 1].

    Args:
        input: batch tensor (N, C, H, W); channels 1 and 2 hold the positive
            and negative click maps. Modified in place.
        pred: thresholded prediction (N, 1, H, W); channel 0 is used.
        label: target mask (N, 1, H, W); channel 0 is used.
        k: current click iteration (used only when ``probabilistic_clicks``).
        total_k: total number of click iterations.
        probabilistic_clicks: if True, each image gets a click with probability
            ``1 - k / total_k``; otherwise every image gets one.

    Returns:
        ``input`` (the same tensor, updated in place).
    """
    click_probability = 1.0
    if probabilistic_clicks and total_k:
        click_probability = 1.0 - (float(k) / float(total_k))
    eps = 1e-7  # avoids 0/0 when a click channel is still all zeros

    for i in range(pred.shape[0]):
        if np.random.rand() > click_probability:
            continue
        # 2-D numpy views of prediction and label (no gradients needed here)
        pred_np = pred[i, 0].detach().cpu().numpy()
        label_np = label[i, 0].detach().cpu().numpy()
        click_index, is_positive = get_random_click(label_np, pred_np)
        if click_index == (-1, -1):  # nowhere to place a click
            continue

        new_click = _gaussian_click_map(pred_np.shape, click_index)
        no_click = np.zeros(pred_np.shape)
        pos_add, neg_add = (new_click, no_click) if is_positive else (no_click, new_click)

        for channel, added in ((1, pos_add), (2, neg_add)):
            merged = input[i, channel].detach().cpu().numpy() + added
            merged = (merged - np.min(merged)) / ((np.max(merged) - np.min(merged)) + eps)
            # write back with the input's own dtype/device
            input[i, channel] = torch.as_tensor(merged, dtype=input.dtype, device=input.device)
    return input

In [22]:
def get_iou(pred, label):
    """Intersection-over-union of two mask tensors, pooled over all elements (whole batch).

    Any non-zero value counts as foreground. Returns 1.0 when both masks are
    empty (zero union) instead of NaN, so batch averages stay finite.
    """
    pred_np = pred.detach().cpu().numpy()
    label_np = label.detach().cpu().numpy()
    union = np.logical_or(pred_np, label_np).sum()
    if union == 0:
        return 1.0
    intersection = np.logical_and(pred_np, label_np).sum()
    return float(intersection / union)

In [24]:
def eval(dataloader, model, loss_fn, device, k, click_checkpoints=(2, 5, 10, 15)):
    """Evaluate segmentation quality as simulated clicks are added.

    NOTE: this name shadows Python's built-in ``eval``; kept because the
    evaluation cell calls it by this name.

    For every batch, the model first runs on the input that already carries
    one click (added by ``CustomImageDataset``). Then, ``k`` times, one
    corrective click per image is added with ``add_clicks_to_batch`` from the
    current thresholded prediction, and the model is re-run. Metrics are
    recorded after click 1 and after each click count in ``click_checkpoints``.

    Args:
        dataloader: yields ``(X, y)`` batches in (N, H, W, C) layout.
        model: segmentation model returning per-pixel probabilities.
        loss_fn: ``loss_fn(probs, target)`` returning a scalar tensor.
        device: device to move batches to.
        k: number of extra clicks per batch (click counts 2 .. k + 1).
        click_checkpoints: click counts (>= 2) at which metrics are recorded;
            counts above ``k + 1`` are never reached and their lists stay empty.

    Returns:
        ``(iou_total, loss_total)``: lists with ``1 + len(click_checkpoints)``
        entries of per-batch values; entry 0 is after 1 click, entry j after
        ``click_checkpoints[j - 1]`` clicks.
    """
    slot_for_clicks = {n_clicks: slot for slot, n_clicks in enumerate(click_checkpoints, start=1)}
    iou_total = [[] for _ in range(len(click_checkpoints) + 1)]
    loss_total = [[] for _ in range(len(click_checkpoints) + 1)]

    def record(slot, probs, target):
        loss_total[slot].append(loss_fn(probs, target).item())
        iou_total[slot].append(get_iou(probs > 0.5, target))

    with torch.no_grad():
        for X, y in dataloader:
            # (N, H, W, C) -> (N, C, H, W)
            X = X.to(device).permute(0, 3, 1, 2).float()
            y = (y.to(device).permute(0, 3, 1, 2) > 0.5).float()
            out_size = list(y.shape[-2:])

            probs = center_crop(model(X), out_size)
            record(0, probs, y)

            for i in range(k):
                # `probs` is already the prediction for the current X, so the
                # model is not re-run before placing the next click
                X = add_clicks_to_batch(X, probs > 0.5, y, i, k, False)
                probs = center_crop(model(X), out_size)
                slot = slot_for_clicks.get(i + 2)
                if slot is not None:
                    record(slot, probs, y)

    return iou_total, loss_total

In [None]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

model = UNET(in_channels=3, first_out_channels=32, exit_channels=1, downhill=4)
# map_location lets a checkpoint saved on a GPU load on a CPU/MPS-only runtime
model.load_state_dict(torch.load("/content/drive/MyDrive/InterSegModelV1.pth", map_location=device))
model.to(device)
loss_fn = diceloss()

# 14 extra clicks -> up to 15 clicks in total, the largest count that eval records
N_EXTRA_CLICKS = 14

model.eval()
iou_total, loss_total = eval(val_dataloader, model, loss_fn, device, N_EXTRA_CLICKS)

# mean over batches for each click count (NaN if a count was never reached)
all_iou = [float(np.mean(v)) if len(v) else float("nan") for v in iou_total]
all_loss = [float(np.mean(v)) if len(v) else float("nan") for v in loss_total]

print(all_iou)
print(all_loss)

labels = ['1', '2', '5', '10', '15']
values = all_iou[:len(labels)]
colors = ['green', 'red', 'cyan', 'magenta', 'yellow']

fig, ax = plt.subplots(figsize=(10, 10))
for i, (label, value) in enumerate(zip(labels, values)):
    ax.bar(i, value, align='center', color=colors[i], label=label, width=0.5)
    ax.text(i, value - 0.04, f'{value:.2f}', ha='center')
ax.set_ylim(0, 1)
ax.set_xticks(range(len(labels)))
ax.set_xticklabels(labels)
ax.set_xlabel('Number of clicks')
ax.set_ylabel('Average IoU')
ax.set_title('Average IoUs per number of clicks')

fig.savefig('average_iou.png')





In [None]:
# Training vs. validation curves (loss and IoU), one figure per metric.
# NOTE: train_loss_record / val_loss_record / train_iou_record / val_iou_record
# come from a training loop that is not part of this notebook; skip with a
# message instead of raising NameError on a fresh kernel.
try:
    curves = {
        "Loss": (train_loss_record, val_loss_record),
        "IoU": (train_iou_record, val_iou_record),
    }
except NameError as err:
    print(f"Skipping training curves ({err}); run the training loop first.")
else:
    for metric, (train_record, val_record) in curves.items():
        fig, ax = plt.subplots()
        ax.plot(train_record, label=f"Training {metric}")
        ax.plot(val_record, label=f"Validation {metric}")
        ax.legend()
        ax.set_title(metric)

In [26]:
def _show_overlay_grid(X, pred, max_images=16):
    """Show up to ``max_images`` inputs (channels as RGB) with the binary prediction overlaid, on a 4x4 grid."""
    plt.figure(figsize=(10, 10))
    for idx in range(min(max_images, pred.shape[0])):
        img = X[idx].permute(1, 2, 0).clone().detach().cpu().numpy()
        mask = pred[idx].permute(1, 2, 0).clone().detach().cpu().numpy()
        plt.subplot(4, 4, idx + 1)
        plt.axis('off')
        # make images as close to each other as possible
        plt.subplots_adjust(wspace=0, hspace=0)
        plt.imshow(img)
        plt.imshow(mask, alpha=0.4, cmap='hot', interpolation="spline16")
    plt.show()


def visual_eval(dataloader, model, device, k, click_checkpoints=(2, 5, 10, 15), max_batches=None):
    """Plot predictions as simulated clicks are added, printing batch IoU.

    For each batch: show the prediction after 1 click, then add ``k`` corrective
    clicks one at a time (``add_clicks_to_batch``), showing a grid and the
    batch IoU whenever the click count is in ``click_checkpoints``.

    Args:
        dataloader: yields ``(X, y)`` batches in (N, H, W, C) layout.
        model: segmentation model returning per-pixel probabilities.
        device: device to move batches to.
        k: number of extra clicks per batch.
        click_checkpoints: click counts (>= 2) at which a grid is shown.
        max_batches: stop after this many batches (``None`` = all batches).
    """
    with torch.no_grad():
        for batch_idx, (X, y) in enumerate(dataloader):
            if max_batches is not None and batch_idx >= max_batches:
                break
            # (N, H, W, C) -> (N, C, H, W)
            X = X.to(device).permute(0, 3, 1, 2).float()
            y = (y.to(device).permute(0, 3, 1, 2) > 0.5).float()
            out_size = list(y.shape[-2:])

            pred = center_crop(model(X), out_size) > 0.5
            print("Click 1")
            _show_overlay_grid(X, pred)
            print(get_iou(pred, y))

            for click_iter in range(k):
                # `pred` already belongs to the current X; no extra forward pass needed
                X = add_clicks_to_batch(X, pred, y, click_iter, k, False)
                pred = center_crop(model(X), out_size) > 0.5
                if click_iter + 2 in click_checkpoints:
                    print("Click: ", click_iter + 2)
                    _show_overlay_grid(X, pred)
                    print(get_iou(pred, y))
                    print("-----------------")



In [None]:
# Visualise predictions after 1, 2, 5, 10 and 15 clicks (14 simulated extra clicks).
visual_eval(val_dataloader, model, device, k=14)

In [None]:
# Copy the local model checkpoint to Google Drive (the path the evaluation cell loads from).
# NOTE(review): nothing in this notebook writes /content/InterSegModelV1.pth — it must come from a training run.
!cp /content/InterSegModelV1.pth /content/drive/MyDrive/InterSegModelV1.pth