In [1]:
import os, gc
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import glob
import shutil

from time import time
import random

import warnings
warnings.filterwarnings('ignore')

import cv2
from skimage import io, util
import imgaug.augmenters as iaa

from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

import seaborn as sns
sns.set_style("whitegrid", {'axes.grid' : False})

from sklearn.metrics import accuracy_score, precision_score, recall_score, confusion_matrix

In [2]:
# Ordered class names; a class's position in this list is its integer class id.
LABELS = ['Branching', 'Fish', 'Massive', 'Not Massive', 'Substrate', 'Target', 'Water']

# Class name -> class id (0 .. 6), in the same order as LABELS.
labels = {name: index for index, name in enumerate(LABELS)}

# Class name -> id with 'Background' reserved as 0, so every class is shifted up by one.
UNIVERSIAL_COLORS = {'Background': 0}
UNIVERSIAL_COLORS.update({name: index + 1 for index, name in enumerate(LABELS)})

# The universal value for pixels that do not have a class label
NO_LABEL = 255

N_CLASSES = len(LABELS)

# Seven-entry RGB colour palette (one row per class id).
cp = np.array([(178, 24, 43),
               (239, 138, 98),
               (253, 219, 199),
               (247, 247, 247),
               (209, 229, 240),
               (103, 169, 207),
               (33, 102, 172)])

# Class name -> RGB colour (0-255), including an all-black 'Background' entry.
labelbox = {name: np.array(rgb) for name, rgb in [
    ("Background", (0, 0, 0)),
    ("Massive", (117, 219, 87)),
    ("Not Massive", (87, 219, 170)),
    ("Branching", (219, 95, 87)),
    ("Fish", (219, 208, 87)),
    ("Substrate", (87, 155, 219)),
    ("Target", (133, 87, 219)),
    ("Water", (219, 87, 192)),
]}

In [3]:
# Takes in a 2-D class mask (containing values from 0 - 6 denoting the class, in the
# order given by LABELS), and converts it into a pretty image using the labelbox colors
def colorize_prediction(pred):
    """Turn a mask of class ids into an RGB image.

    Parameters
    ----------
    pred : array indexed as (row, column) whose entries are class ids,
        where id ``i`` stands for ``LABELS[i]``.

    Returns
    -------
    numpy array of shape ``(pred.shape[0], pred.shape[1], 3)``. Each pixel
    holds the ``labelbox`` color (0-255 values, stored as floats) of its class.
    Pixels whose value is not a class id in ``0 .. N_CLASSES - 1`` (for
    example NO_LABEL) are left as zeros, i.e. black.
    """
    colored_mask = np.zeros(shape = (pred.shape[0], pred.shape[1], 3))

    # labelbox is keyed by class *name*, so go id -> name -> color.
    for class_index, class_name in enumerate(LABELS):

        colored_mask[pred == class_index] = labelbox[class_name]

    return colored_mask

# Extracts a square patch with lengths equal to ps * 2 from an image, centered on coordinates x, y
def extract_patch(img, x, y):
    """Cut the window centered on (x, y) out of `img` and prepare it for the CNN.

    The window spans `ps` pixels (module-level setting) on each side of the
    center along both axes. It is then passed through ``iaa.Resize(224)`` and
    multiplied by 1/255.

    Parameters
    ----------
    img : image array indexed as (row, column, ...).
    x, y : column and row of the patch center.

    Returns
    -------
    The resized patch scaled by 1/255.
    """
    top, bottom = y - ps, y + ps
    left, right = x - ps, x + ps

    window = img[top : bottom, left : right]
    resized = iaa.Resize(224).augment_image(window)

    return resized * (1./255.0)



# This does the majority of the work in this script:
# it is used to provide sparse annotations to an image automatically.
# When provided with an image, it will project X amount of points on
# the image, extract patches from the image centered on those points,
# and pass all of those patches to the CNN.
#
# The CNN will make predictions on each patch, and store the label,
# the confidence value (how sure the CNN is on its prediction), and 
# the location (x, y) from the image it was extracted from.
# 
# Input: an image
# Outputs a dataframe (saved as .csv by the caller) with all of the sparse annotations for that image
def get_sparse_points(img, percent):
    """Sample points on `img`, classify a patch around each, and tabulate the results.

    Points are laid out on a regular grid that stays `offset` pixels away from
    every image edge; any points still missing to reach the requested total
    are drawn at random inside the same margins. Relies on the module-level
    `offset`, `model`, `labels` and `extract_patch`.

    Parameters
    ----------
    img : image array indexed as (row, column, ...).
    percent : fraction of the image's pixel count to use as the number of points.

    Returns
    -------
    pandas.DataFrame with one row per point and the columns
    'X', 'Y' (point location), 'Labels' (name of the highest-scoring class) and
    'Confidence' (highest score minus second-highest score).
    """
    # determines if the points are sampled from a grid, or randomly
    # the larger the ratio, the more patches are sampled from grid
    ratio = 1.0

    height, width = img.shape[0], img.shape[1]

    num_points = int((height * width) * percent)
    density = int(np.sqrt(num_points))
    per_axis = int(density * ratio)

    # Regular grid of candidate centers, kept `offset` pixels inside the border.
    grid_x, grid_y = np.meshgrid(np.linspace(offset, width - offset, per_axis),
                                 np.linspace(offset, height - offset, per_axis))

    xy = np.dstack([grid_x, grid_y]).reshape(-1, 2).astype(int)

    x = list(xy[:, 0])
    y = list(xy[:, 1])

    # Whatever the grid could not supply is filled in with random locations
    # (x coordinates are drawn before y coordinates).
    remainder = num_points - len(xy)
    x += np.random.randint(offset, width - offset, remainder).tolist()
    y += np.random.randint(offset, height - offset, remainder).tolist()

    patches = np.array([extract_patch(img, px, py) for px, py in zip(x, y)])

    predictions = model.predict(patches)

    class_names = list(labels)
    predicted_labels = []
    confidence = []

    for prediction in predictions:
        predicted_labels.append(class_names[np.argmax(prediction, axis = 0)])

        # Margin between the best and the runner-up class score.
        ranked = sorted(prediction)
        confidence.append(ranked[-1] - ranked[-2])

    sparse_points = pd.DataFrame(list(zip(x, y, predicted_labels, confidence)),
                                 columns = ['X', 'Y', 'Labels', 'Confidence'])

    return sparse_points

# Works, but there is another metrics function in 
# Fast-MSS repo that you can use too.
def get_sparse_scores(gt, pred):
    """Score predicted labels against ground-truth labels.

    Per-class precision, recall, IoU and Dice are derived from the confusion
    matrix and then averaged. Classes with no true positives contribute 0 to
    each average, and the averages are taken over at least N_CLASSES entries,
    so classes that never occur also count as 0.

    Parameters
    ----------
    gt : sequence of ground-truth labels.
    pred : sequence of predicted labels, same length as `gt`.

    Returns
    -------
    tuple ``(accuracy, mean_precision, mean_recall, mean_dice, mean_iou)``
    where ``accuracy`` is ``accuracy_score(gt, pred)``.
    """
    cm = confusion_matrix(gt, pred)

    # The matrix has one row/column per label found in gt OR pred, so iterate
    # over its own size rather than over the labels found in gt alone.
    n_found = cm.shape[0]
    size = max(N_CLASSES, n_found)

    precision = np.zeros(shape = (size, ))
    recall = np.zeros(shape = (size, ))
    iou = np.zeros(shape = (size, ))
    dice = np.zeros(shape = (size, ))

    tp = np.diag(cm)
    fp = cm.sum(axis = 0) - tp   # predicted as the class, but actually another
    fn = cm.sum(axis = 1) - tp   # actually the class, but predicted as another

    # to avoid nans: classes without a single true positive keep a score of 0
    hit = np.flatnonzero(tp > 0)

    precision[hit] = tp[hit] / (tp[hit] + fp[hit])
    recall[hit] = tp[hit] / (tp[hit] + fn[hit])
    iou[hit] = tp[hit] / (tp[hit] + fp[hit] + fn[hit])
    dice[hit] = (2 * precision[hit] * recall[hit]) / (precision[hit] + recall[hit])

    return accuracy_score(gt, pred), np.mean(precision), np.mean(recall), np.mean(dice), np.mean(iou)



# This determines how accurate the sparse points are against the ground truth (Dense Annotations)
# Finds the label in the corresponding pixel-index of the dense annotations, compares against sparse
def get_accuracy(gt, sparse_points):
    """Compare the sparse predictions with the dense ground truth at the same pixels.

    Parameters
    ----------
    gt : dense annotation array indexed as (row, column); each value is used
        as a position into the `labels` class names.
    sparse_points : DataFrame with 'X', 'Y' and 'Labels' columns.

    Returns
    -------
    Whatever `get_sparse_scores` returns for (ground-truth names, predicted names).
    """
    rows = sparse_points['Y'].values
    cols = sparse_points['X'].values

    # Look up the ground-truth class id under every sparse point, then
    # translate the ids into class names so they match the 'Labels' column.
    class_names = list(labels)
    gt_labels = [class_names[value] for value in gt[rows, cols]]

    return get_sparse_scores(gt_labels, sparse_points['Labels'].values)
    

In [4]:
# Root folder holding the images and their dense annotations.
path = "ground_truth\\"

# Sorted so that images[i] and gts[i] line up by file name order.
image_pattern = path + "Images\\*.png"
gt_pattern = path + "dense\\*.png"

images = sorted(glob.glob(image_pattern))
gts = sorted(glob.glob(gt_pattern))

print(len(gts))


50


In [12]:
# This creates a CNN (efficientnetb0) and loads weights made in the image classification script
from keras.models import Sequential
from keras.layers import Dense, Activation, Dropout
from keras.applications.nasnet import NASNetMobile 
import efficientnet.keras as efn

# EfficientNetB0 backbone (no classification head, max pooling) followed by
# dropout and a softmax layer with one output per class.
model = Sequential([
        efn.EfficientNetB0(weights = 'noisy-student', include_top = False,  pooling = 'max'),
        Dropout(.80),
        Dense(N_CLASSES),   # one output per entry in LABELS instead of a hard-coded 7
        Activation('softmax')
])

# NOTE(review): "path_to_labels.h5" looks like a placeholder -- point this at the
# weights file produced by the image classification script.
model.load_weights("path_to_labels.h5")

In [13]:
# ps is half the patch side length: extract_patch slices ps pixels on either side of
# the center, so patches are (ps * 2) x (ps * 2) pixels (168 x 168 for ps = 84).
# offset is the margin get_sparse_points keeps between sampled points and the image edges.
# Keep offset the same as ps, but feel free to change ps;
# it will do best around 84 - 112.
ps = 84
offset = 84

In [14]:
scores = []

# Folder that receives one CSV of sparse annotations per image; created up front
# so a missing folder does not abort the run after the first image is processed.
output_dir = "Sparse\\Manual 75\\"
os.makedirs(output_dir, exist_ok = True)

# loops through each of the images, gets sparse points for each, compares the sparse points with the corresponding gt
# the scores should be about 90%. The goal is to be higher, but 90% is a good baseline.
# images and gts are both sorted, so they are paired by position.
for index, (image_path, gt_path) in enumerate(zip(images, gts)):

    # File name without folder or extension, used to name the output CSV.
    basename = os.path.splitext(os.path.basename(image_path))[0]
    print("image: ", str(index), basename)

    start = time()

    gt = io.imread(gt_path)
    image = io.imread(image_path)

    # percent is the fraction of the image's pixels that are given annotations
    sparse_points = get_sparse_points(image, percent = .00035)
    num_points = len(sparse_points)

    sparse_points.to_csv(output_dir + basename + ".csv")
    print(num_points)

    #score = get_accuracy(gt, sparse_points); scores.append(score)

    print("Time: ", round(time() - start, 2))


#print("Average Classification Accuracy:", np.mean(scores, axis = 0))



image:  0 cam_1_before_1_0
cam_1_before_1_0 
 cam_1_before_1_0
2809
Time:  14.6
image:  1 cam_1_before_1_15
cam_1_before_1_15 
 cam_1_before_1_15
2809
Time:  14.52
image:  2 cam_1_before_1_22
cam_1_before_1_22 
 cam_1_before_1_22
2809
Time:  14.52
image:  3 cam_1_before_1_31
cam_1_before_1_31 
 cam_1_before_1_31
2809
Time:  14.5
image:  4 cam_1_before_1_44
cam_1_before_1_44 
 cam_1_before_1_44
2809
Time:  14.51
image:  5 cam_1_before_1_66
cam_1_before_1_66 
 cam_1_before_1_66
2809
Time:  14.55
image:  6 cam_1_before_1_7
cam_1_before_1_7 
 cam_1_before_1_7
2809
Time:  14.6
image:  7 cam_1_before_1_71
cam_1_before_1_71 
 cam_1_before_1_71
2809
Time:  14.56
image:  8 cam_1_before_1_88
cam_1_before_1_88 
 cam_1_before_1_88
2809
Time:  14.58
image:  9 cam_1_before_1_96
cam_1_before_1_96 
 cam_1_before_1_96
2809
Time:  14.54
image:  10 cam_1_before_2_1041
cam_1_before_2_1041 
 cam_1_before_2_1041
2809
Time:  14.54
image:  11 cam_1_before_2_117
cam_1_before_2_117 
 cam_1_before_2_117
2809
Tim

In [10]:
# Just viewing the last image: overlay the confident sparse points on it,
# each drawn in its predicted class's labelbox color.
is_confident = sparse_points['Confidence'] >= .70
sparse_ = sparse_points[is_confident]

X = sparse_['X'].values
Y = sparse_['Y'].values

# matplotlib expects RGB components in the 0-1 range
C = [labelbox[class_name] / 255.0 for class_name in sparse_['Labels'].values]

plt.figure(figsize = (10, 10))
plt.imshow(image, alpha = .85)
plt.scatter(X, Y, c = C)
plt.xticks([])
plt.yticks([])