In [None]:
import numpy as np
import cv2
import os
import glob
import warnings
import matplotlib.pyplot as plt

In [None]:
def bb_midpoint_to_corner(bb):
    """Convert one YOLO box row from center/size form to corner form.

    Param:
        bb: indexable row ``[label, cx, cy, bw, bh]`` (extra trailing items are ignored).
    Returns:
        np.array ``[label, x1, x2, y1, y2, area]`` where x1/x2 and y1/y2 are the
        horizontal and vertical extents and ``area = bw * bh`` (only used for sorting).
    """
    label, cx, cy, bw, bh = bb[0], bb[1], bb[2], bb[3], bb[4]
    half_w = bw / 2
    half_h = bh / 2
    return np.array([label, cx - half_w, cx + half_w, cy - half_h, cy + half_h, bw * bh])

def open_yolo_sort(path, image_name):
    """Load an image and its YOLO label file, returning the boxes sorted by area.

    The label file sits next to the image, with the same base name and a ``.txt``
    extension. Each row is ``class cx cy w h``; coordinates are presumably
    normalized to [0, 1] (callers scale them by the image width/height).

    Param:
        path: directory holding both the image and its label file. It is joined
            with ``os.path.join``, so a trailing separator is optional.
        image_name: file name of the image inside ``path``.
    Returns:
        (image, boxes, width, height):
        image: the array returned by ``cv2.imread`` (None if it could not be read);
        boxes: (N, 6) array of ``[label, x1, x2, y1, y2, area]`` rows, largest area first;
        width, height: image size in pixels.
        If the image or its labels cannot be used, a message is printed and
        ``boxes``, ``width`` and ``height`` are all None.
    """
    image = cv2.imread(os.path.join(path, image_name))
    if image is None:
        # cv2.imread signals an unreadable / non-image file by returning None, not by raising
        print(f"{image_name}: could not be read as an image")
        return image, None, None, None
    height, width = image.shape[:2]

    label_file = os.path.join(path, os.path.splitext(image_name)[0] + ".txt")
    try:
        with warnings.catch_warnings():
            # an empty label file is reported below; silence genfromtxt's own warning about it
            warnings.simplefilter("ignore", UserWarning)
            boxes = np.genfromtxt(label_file, delimiter=' ')
        if boxes.size == 0:
            raise ValueError("label file has no boxes")
        # genfromtxt returns a 1-D array when the file holds a single box; force (N, 5)
        boxes = boxes.reshape(boxes.size // 5, 5)
        boxes = np.apply_along_axis(bb_midpoint_to_corner, axis=1, arr=boxes)
        # sort by area (column 5), bigger areas first
        boxes = boxes[boxes[:, 5].argsort()[::-1]]
    except Exception as e:
        # best effort: a missing or malformed label file skips the image instead of stopping the run
        print(f"{image_name}: {e}")
        return image, None, None, None
    return image, boxes, width, height
    

# One color per class; the list index is the YOLO class id (callers index it
# with the box label / filter_class).
colors = [(162, 0, 255),   # 0  blade disconnect switch, open (Chave seccionadora lamina - Aberta)
          (97, 16, 162),   # 1  blade disconnect switch, closed (Chave seccionadora lamina - Fechada)
          (81, 162, 0),    # 2  tandem disconnect switch, open (Chave seccionadora tandem - Aberta)
          (48, 97, 165),   # 3  tandem disconnect switch, closed (Chave seccionadora tandem - Fechada)
          (121, 121, 121), # 4  circuit breaker (Disjuntor)
          (255, 97, 178),  # 5  fuse (Fusivel)
          (154, 32, 121),  # 6  glass disc insulator (Isolador disco de vidro)
          (255, 255, 125), # 7  porcelain pin insulator (Isolador pino de porcelana)
          (162, 243, 162), # 8  cable termination (Mufla)
          (143, 211, 255), # 9  surge arrester (Para-raio)
          (40, 0, 186),    # 10 recloser (Religador)
          (255, 182, 0),   # 11 transformer (Transformador)
          (138, 138, 0),   # 12 current transformer, CT (Transformador de Corrente - TC)
          (162, 48, 0),    # 13 potential transformer, PT (Transformador de Potencial - TP)
          (162, 0, 96)     # 14 three-pole switch (Chave tripolar)
          ]

# The same colors with the channel order reversed. The masks painted with these
# are written through cv2.imwrite, which treats channels as BGR, so despite the
# name the saved PNG files end up with the `colors` tuples above as RGB values.
colors_rgb = [color[::-1] for color in colors]

In [None]:
_ITER_COUNT = 10  # GrabCut iterations run for every (image, class) pair


def grabcut(image_path, save_path, image_name, filter_class):
    """Use GrabCut to create a binary segmentation mask of all objects of a given class in an image,
       based on YOLO box annotations.
    Param:
        image_path: path to the dir where the images (and their YOLO .txt labels) are stored
        save_path: not used by this function; kept so existing calls keep working
        image_name: file name of the image inside image_path
        filter_class: index of the class to make a GrabCut mask from.
    Returns:
        1 if there was at least one box of filter_class and a mask was made, 0 otherwise;
        The number of boxes that belonged to filter_class in the image, which may be 0.
        A uint8 image with the same shape as the input image: pixels GrabCut labelled as
        (probable) foreground are set to colors_rgb[filter_class], all others are 0. The image
        is pure black if no boxes of filter_class were present, and None if the image file
        itself could not be read.
    Based on:
        https://stackoverflow.com/questions/12810405/opencv-set-color-to-a-foreground-marked-pixel-gc-pr-fgd
        https://docs.opencv.org/3.4/d8/d83/tutorial_py_grabcut.html
        https://pyimagesearch.com/2020/09/28/image-segmentation-with-mask-r-cnn-grabcut-and-opencv/
    """
    image, bb, w, h = open_yolo_sort(image_path, image_name)
    if image is None:
        # open_yolo_sort already reported the file; without an image there is no mask to build
        return 0, 0, None
    black = np.zeros_like(image)
    if bb is None:
        return 0, 0, black

    # Single-channel seed mask (grabCut requires it): 100 marks each whole box
    # (probable foreground), 255 its innermost 20% (certain foreground), 0 background.
    seed = np.zeros(image.shape[:2], np.uint8)
    count = 0
    for label, x1, x2, y1, y2, _area in bb:
        if label != filter_class:
            continue
        count += 1
        px1, px2, py1, py2 = x1 * w, x2 * w, y1 * h, y2 * h
        rw, rh = px2 - px1, py2 - py1
        cv2.rectangle(seed, (int(px1), int(py1)), (int(px2), int(py2)), 100, -1)
        # trimming 40% from each side leaves the central 20% of the width and height
        cv2.rectangle(seed, (int(px1 + rw * 0.4), int(py1 + rh * 0.4)),
                      (int(px2 - rw * 0.4), int(py2 - rh * 0.4)), 255, -1)

    # grabcut is only performed if there was any bounding box of filter_class
    if count == 0:
        return 0, 0, black

    # GC_BGD: certainly background, GC_FGD: certainly foreground,
    # GC_PR_BGD / GC_PR_FGD: probably background / probably foreground
    seed[seed == 255] = cv2.GC_FGD     # the 20% innermost
    seed[seed == 100] = cv2.GC_PR_FGD  # the rest of each box
    seed[seed == 0] = cv2.GC_BGD       # everything outside the boxes

    bgd_model = np.zeros((1, 65), np.float64)
    fgd_model = np.zeros((1, 65), np.float64)
    try:
        gc_mask, _, _ = cv2.grabCut(image, seed.copy(), None, bgd_model, fgd_model,
                                    _ITER_COUNT, cv2.GC_INIT_WITH_MASK)
    except cv2.error as e:
        # GrabCut cannot initialise its colour models without both background and
        # foreground samples (e.g. a box covering the whole image leaves no
        # background). Fall back to the seed: every box pixel counts as foreground.
        print(f"{image_name}: GrabCut failed for class {filter_class}, using the box mask instead ({e})")
        gc_mask = seed

    foreground = ~((gc_mask == cv2.GC_BGD) | (gc_mask == cv2.GC_PR_BGD))
    output = np.zeros_like(image)
    output[foreground] = colors_rgb[filter_class]
    return 1, count, output

In [None]:
def apply_grabcut_multiclass(image_path, save_path, image_name, filter_list, mask_list):
    """
    Creates a segmentation mask out of independent masks for each class.

    Boxes are visited in the order returned by open_yolo_sort (largest area first),
    so pixels of smaller boxes overwrite those of larger overlapping boxes.
    Param:
        image_path: path to the dir where the images are stored
        save_path: path to the dir where the masks will be saved
        image_name: name for the image proper
        filter_list: array with the labels that will be accounted for
        mask_list: array of segmentation masks indexed by class label. Each is an image with
        the same resolution as the one that will be opened using image_path and image_name;
        entries are only read for labels in filter_list.
    Returns:
        Nothing, but it does save the mask as <save_path>/<image base name>.png. If no objects
        from filter_list were found (including when the label file is missing or empty), a pitch
        black image (0, 0, 0) with the same resolution as the one that will be opened using
        image_path and image_name will be saved instead. Nothing is saved if the image itself
        could not be read.
    """
    image, bb, w, h = open_yolo_sort(image_path, image_name)
    if image is None:
        # no resolution to build a mask from; open_yolo_sort already reported the file
        return
    gc_mask = np.zeros_like(image)

    if bb is not None:
        for label, x1, x2, y1, y2, _area in bb:
            label = int(label)
            if label not in filter_list:
                continue
            color = colors_rgb[label]
            # Boxes may extend past the image border; a negative start index would
            # wrap around in a numpy slice, so clamp at 0 (the upper end clips itself).
            sx1, sx2 = max(int(x1 * w), 0), max(int(x2 * w), 0)
            sy1, sy2 = max(int(y1 * h), 0), max(int(y2 * h), 0)
            # copy only the pixels of this class's mask that fall inside this box
            match = np.all(mask_list[label][sy1:sy2, sx1:sx2] == color, axis=-1)
            region = gc_mask[sy1:sy2, sx1:sx2]  # a view, so writing it updates gc_mask
            region[match] = color

    cv2.imwrite(os.path.join(save_path, os.path.splitext(image_name)[0] + ".png"), gc_mask)

In [None]:
# alternative input set: paths = ["/home/jovyan/work/yolo_og/yolov3/Data/15_classes/test/"]
paths = ["/home/jovyan/work/deeplab/data/train/"]

# filter_list works in inclusion mode: only the listed class ids are segmented
filter_list = [0, 1, 2, 3, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14] # excluding only porcelain pin insulator (class 7)
save_path = "../data/14_class/train_gci/"

import sys
np.set_printoptions(threshold=sys.maxsize)

if not os.path.isdir(save_path):
    os.makedirs(save_path)
    print("created folder: ", save_path)

images = 0        # images with at least one box of a class in filter_list
boxes = 0         # boxes whose class is in filter_list
total_images = 0  # every non-label file found in paths
for image_path in paths:
    print(image_path)
    # everything except the YOLO .txt label files is treated as an image;
    # compare the extension, not a substring of the whole path
    image_list = [
        os.path.basename(name)
        for name in glob.glob(os.path.join(image_path, "*.*"))
        if os.path.splitext(name)[1].lower() != ".txt"
    ]
    total_images += len(image_list)
    mask_list = [None] * len(colors)
    for file in image_list:
        found_any = 0  # becomes 1 once any filtered class is present in this image
        for object_class in filter_list:
            im, box_num, mask_list[object_class] = grabcut(image_path, save_path, file, object_class)
            found_any |= im
            boxes += box_num
        images += found_any
        apply_grabcut_multiclass(image_path, save_path, file, filter_list, mask_list)

text = f'Images: {images}\nBoxes: {boxes}\nTotal images: {total_images}'
print(text)
# the stats file sits next to the mask folder: "<save_path without trailing slash>.txt"
with open(os.path.normpath(save_path) + ".txt", 'w') as f:
    f.write(text)