In [1]:
import json
import time
import datetime
import numpy as np
#np.set_printoptions(threshold=np.inf)
from collections import defaultdict
import glob
import cv2
import os
import random
import matplotlib.pyplot as plt
import seaborn as sns

# DPT imports
import torch
import torch.nn.functional as F
import util.io
from torchvision.transforms import Compose
from dpt.models import DPTSegmentationModel
from dpt.transforms import Resize, NormalizeImage, PrepareForNet

# Classifier imports
from tensorflow.keras.applications import * #Efficient Net included here
from tensorflow.keras import models
from tensorflow.keras import layers
import tensorflow as tf

physical_devices = tf.config.list_physical_devices('GPU') 
for gpu_instance in physical_devices: 
    tf.config.experimental.set_memory_growth(gpu_instance, True)

In [2]:
height = 224
width = 224

conv_base = EfficientNetB0(weights="efficientnetb0_notop.h5", include_top=False, input_shape=(height,width,3))
spill_classifier = models.Sequential()
spill_classifier.add(conv_base)
spill_classifier.add(layers.GlobalMaxPooling2D(name="gap"))
#avoid overfitting
#model.add(layers.Dropout(rate=0.2, name="dropout_out"))
# Set NUMBER_OF_CLASSES to the number of your final predictions.
spill_classifier.add(layers.Dense(2, activation="softmax", name="fc_out"))
spill_classifier.load_weights("model_best_weights.h5")
#spill_classifier = EfficientNetB0(weights="model_best_weights.h5", include_top=False, input_shape=(height,width,3))

In [19]:
#candidate_spill_categories = torch.tensor([22,28,29,44,61,105,110,121,129,138,142,143,148,23]).reshape(1,-1,1)
candidate_spill_categories = torch.tensor([22,29,44,61,105,110,121,129,138,142,143,148,23]).reshape(1,-1,1).to('cuda')
floor_cat = torch.tensor([4]).reshape(1,1,1).to('cuda')
# water, mirror, rug, sign, river, fountain, swimming_pool, food, lake, tray, screen, plate, glass, painting
top_k = 3

In [23]:
def run(input_path, output_path, model_path, model_type="dpt_hybrid", optimize=True):
    """Run segmentation network

    Args:
        input_path (str): path to input folder
        output_path (str): path to output folder
        model_path (str): path to saved model
    """
    print("initialize")

    # select device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    #device = 'cpu'
    print("device: %s" % device)

    net_w = net_h = 480

    # load network
    if model_type == "dpt_large":
        model = DPTSegmentationModel(
            150,
            path=model_path,
            backbone="vitl16_384",
        )
    elif model_type == "dpt_hybrid":
        model = DPTSegmentationModel(
            150,
            path=model_path,
            backbone="vitb_rn50_384",
        )
    else:
        assert (
            False
        ), f"model_type '{model_type}' not implemented, use: --model_type [dpt_large|dpt_hybrid]"

    transform = Compose(
        [
            Resize(
                net_w,
                net_h,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method="minimal",
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            PrepareForNet(),
        ]
    )

    model.eval()

    #if optimize == True and device == torch.device("cuda"):
    #    model = model.to(memory_format=torch.channels_last)
    #    model = model.half()

    model.to(device)

    # get input
    img_names = glob.glob(os.path.join(input_path, "*"))
    num_images = len(img_names)

    # create output folder
    os.makedirs(output_path, exist_ok=True)

    print("start processing")

    for ind, img_name in enumerate(img_names):

        print("  processing {} ({}/{})".format(img_name, ind + 1, num_images))

        # input
        img = util.io.read_image(img_name)
        img_input = transform({"image": img})["image"]

        # compute
        with torch.no_grad():
            sample = torch.from_numpy(img_input).to(device).unsqueeze(0)
            #if optimize == True and device == torch.device("cuda"):
            #    sample = sample.to(memory_format=torch.channels_last)
            #    sample = sample.half()

            out = model.forward(sample)

            prediction = torch.nn.functional.interpolate(
                out, size=img.shape[:2], mode="bicubic", align_corners=False
            )
            max_pred = torch.argmax(prediction, dim=1, keepdim=True) + 1
            max_pred = max_pred.squeeze().cpu().numpy()
            
            sorted_k = torch.argsort(prediction, dim=1, descending=True)[:,:top_k] + 1
            '''top_k_cats = []
            for c in range(10):
                cat_plot = cv2.resize(sorted_k[:,c,:,:].float().squeeze().unsqueeze(-1).cpu().numpy(), (70, 40))
                cat_plot = cat_plot.astype(np.int32)
                top_k_cats.append(cat_plot)
            
            top_k_cats = np.stack(top_k_cats,axis=0)
            top1 = sorted_k[:,0,:,:].float().squeeze().unsqueeze(-1).cpu().numpy()
            top2 = sorted_k[:,1,:,:].float().squeeze().unsqueeze(-1).cpu().numpy()
            top3 = sorted_k[:,2,:,:].float().squeeze().unsqueeze(-1).cpu().numpy()'''

            spill_pix = (sorted_k.reshape(top_k,1,-1) == candidate_spill_categories).any(dim=1).any(dim=0)
            floor_pix = (sorted_k[:,:1].reshape(1,-1) == 4).any(dim=0)
            candidate_pix = torch.logical_and(spill_pix,floor_pix).reshape(img.shape[:2])
            candidate_pix = candidate_pix.long()
            seg_mask = candidate_pix.cpu().numpy() + 2
            candidate_pix = candidate_pix.cpu().numpy() * 255
            

        # output
        filename = os.path.join(
            output_path, os.path.splitext(os.path.basename(img_name))[0]
        )
        util.io.write_segm_img(filename, img, seg_mask, alpha=0.5)
        
        img = (img*255).astype(np.uint8)
        crops,bboxes = get_spill_crops(img, candidate_pix.astype(np.uint8))
        print(len(crops),len(bboxes))
        for crop,bbox in zip(crops,bboxes):
            resized = cv2.resize(crop, (224, 224)).reshape(1,224,224,3)
            inp = resized/255.
            outp = spill_classifier(inp)
            spill = tf.argmax(outp,axis=1).numpy()[0]
            if spill==0:
                cv2.rectangle(img, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0,255,0), 2)
                cv2.putText(img, "No Spill", (bbox[0],bbox[1]), cv2.FONT_HERSHEY_PLAIN, 1, (0,255,0), 2)
            else:
                cv2.rectangle(img, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (255,0,0), 2)
                cv2.putText(img, "Spill", (bbox[0],bbox[1]), cv2.FONT_HERSHEY_PLAIN, 1, (255,0,0), 2)
        
        #cv2.imshow('a',img[:,:,::-1])
        #key = cv2.waitKey(0)
        #if key == ord('q'):
        #    break
        cv2.imwrite(filename[:-4]+'2.png',img[:,:,::-1])

    cv2.destroyAllWindows()
    print("finished")


In [24]:
def get_spill_crops(image,spill_seg):
    cnts,_ = cv2.findContours(spill_seg, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    bboxes = []
    img_h,img_w,_ = image.shape
    img2 = image.copy()
    for c in cnts:
        x,y,w,h = cv2.boundingRect(c)
        bboxes.append((max(x-20,0),max(y-20,0),min(x+w+20,img_w-1),min(y+h+20,img_h-1)))
    #    cv2.rectangle(image, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (36,255,12), 2)
    
    no_overlap = True
    while no_overlap:
        if len(bboxes) < 2:
            no_overlap = False
            continue
        no_overlap = True
        bb_ix = -1
        while (bb_ix := bb_ix+1) < len(bboxes)-1:
            bb1 = bboxes[bb_ix]
            bb_ix2 = bb_ix
            while (bb_ix2 := bb_ix2+1) < len(bboxes):
                bb2 = bboxes[bb_ix2]

                x_left = max(bb1[0], bb2[0])
                y_top = max(bb1[1], bb2[1])
                x_right = min(bb1[2], bb2[2])
                y_bottom = min(bb1[3], bb2[3])

                if x_right <= x_left or y_bottom <= y_top:
                    continue
                else:
                    new_bbox = (min(bb1[0],bb2[0]),min(bb1[1],bb2[1]),max(bb1[2],bb2[2]),max(bb1[3],bb2[3]))
                    bboxes[bb_ix] = new_bbox
                    bb_ix = -1
                    del bboxes[bb_ix2]
                    no_overlap = False
                    break
    
    crops = []
    for bbox in bboxes:
        #cv2.rectangle(img2, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (36,255,12), 2)
        crops.append(img2[bbox[1]:bbox[3],bbox[0]:bbox[2]])
    
    #name = str(random.random())
    #cv2.imwrite('bbox_images/'+name+'.png', image)
    #cv2.imwrite('bbox_images/'+name[:-1]+'z.png', img2)
    
    return crops,bboxes

In [25]:
#spill_points1 = [(10,36),(18,32),(18,33),(18,34),(24,33),(24,34),(24,35),(24,36)]
#spill_points1 = [(33,33),(33,34),(34,33),(34,34),(34,35)]
spill_points1 = [(35,35),(35,36),(36,35),(36,36)]

start = datetime.datetime.now()

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

# compute segmentation maps
run(
    "input",
    "output_semseg",
    "weights/dpt_hybrid-ade20k-53898607.pt"
)

print((datetime.datetime.now()-start).total_seconds())

initialize
device: cuda
start processing
  processing input/vlcsnap-2021-08-12-21h45m41s356.png (1/18)
2 2
  processing input/vlcsnap-2021-08-12-21h47m34s914.png (2/18)
4 4
  processing input/milk.png (3/18)
3 3
  processing input/images.jpeg (4/18)
1 1
  processing input/vlcsnap-2021-08-12-21h48m56s063.png (5/18)
1 1
  processing input/vlcsnap-2021-08-12-21h48m12s767.png (6/18)
3 3
  processing input/images (2).jpeg (7/18)
1 1
  processing input/vlcsnap-2021-08-12-21h47m00s244.png (8/18)
5 5
  processing input/vlcsnap-2021-08-12-21h47m45s421.png (9/18)
4 4
  processing input/download (1).jpeg (10/18)
1 1
  processing input/vlcsnap-2021-08-12-21h49m05s700.png (11/18)
1 1
  processing input/vlcsnap-2021-08-12-21h48m51s562.png (12/18)
2 2
  processing input/1627967401592.jpeg (13/18)
4 4
  processing input/vlcsnap-2021-08-12-21h48m47s313.png (14/18)
1 1
  processing input/vlcsnap-2021-08-12-21h45m20s326.png (15/18)
3 3
  processing input/download.jpeg (16/18)
1 1
  processing input/image

In [None]:
height = 224
width = 224
spill_classifier = EfficientNetB0(weights="model_best_weights.h5", include_top=True, input_shape=(height,width,3))

In [None]:
def find_top_separators(preds,spill_points,k):
    spill_map = preds == -1
    spill_map[:,:,:] = False
    floor_map = preds == 4
    for pnt in spill_points:
        floor_map[:10,pnt[0],pnt[1]] = False
        spill_map[:10,pnt[0],pnt[1]] = True
    
    top_k = preds[spill_map]
    print(spill_map.shape,top_k.shape)
    print(np.unique(top_k))

In [None]:
#find_top_separators(top_k_cats,spill_points1,3)
        
        '''if ind == 0:
            print("-----------TOP 1-------------")
            cat_plot = cv2.resize(top1, (70, 40))
            cat_plot = cat_plot.astype(np.int32)
            fig, ax = plt.subplots(figsize=(35,20))
            sns.heatmap(cat_plot,annot=True,cmap='Blues', fmt='d', ax=ax)

            print("-----------TOP 2-------------")
            cat_plot = cv2.resize(top2, (70, 40))
            cat_plot = cat_plot.astype(np.int32)
            fig, ax = plt.subplots(figsize=(35,20))
            sns.heatmap(cat_plot,annot=True,cmap='Blues', fmt='d', ax=ax)

            print("-----------TOP 3-------------")
            cat_plot = cv2.resize(top3, (70, 40))
            cat_plot = cat_plot.astype(np.int32)
            fig, ax = plt.subplots(figsize=(35,20))
            sns.heatmap(cat_plot,annot=True,cmap='Blues', fmt='d', ax=ax)'''

In [3]:
coco_imgs = glob.glob("/media/petrus/Data/coco/train2017/*")
print(coco_imgs[-1])

/media/petrus/Data/coco/train2017/000000225087.jpg


In [4]:
annots = "/home/petrus/Downloads/stuff_annotations_trainval2017/annotations/stuff_train2017.json"
file = open(annots)
data = json.load(file)

In [5]:
cats = data['categories']
id_to_name = {}
for cat in cats:
    id_to_name[cat['id']] = cat['name']
    
cat_counts = defaultdict(int)
for annot in data['annotations']:
    cat_counts[id_to_name[annot['category_id']]] += 1

print(cat_counts)

defaultdict(<class 'int'>, {'food-other': 6672, 'plastic': 11137, 'table': 16282, 'other': 117266, 'clouds': 9886, 'grass': 22575, 'ground-other': 6252, 'tree': 36466, 'wood': 5053, 'flower': 3259, 'plant-other': 9522, 'wall-concrete': 31481, 'clothes': 27657, 'river': 2313, 'sea': 6598, 'sky-other': 31808, 'floor-other': 8893, 'floor-tile': 6618, 'metal': 22526, 'fence': 11303, 'straw': 1385, 'furniture-other': 17882, 'bush': 9849, 'ceiling-tile': 351, 'pavement': 18311, 'road': 15402, 'wall-stone': 2020, 'window-other': 14209, 'bridge': 1676, 'building-other': 23021, 'house': 6549, 'platform': 2009, 'railroad': 2720, 'roof': 4490, 'floor-stone': 1259, 'dirt': 10163, 'paper': 9521, 'textile-other': 13052, 'cabinet': 7176, 'cardboard': 3787, 'counter': 4589, 'wall-tile': 5290, 'napkin': 1405, 'fog': 2659, 'mountain': 4887, 'railing': 2068, 'gravel': 2613, 'wall-other': 19095, 'floor-marble': 1002, 'wall-wood': 6642, 'window-blind': 2297, 'desk-stuff': 2909, 'floor-wood': 6324, 'mat': 5

In [13]:
floor_cats = ['floor-other','floor-wood','floor-stone','floor-marble','floor-tile','carpet']
ground_cats = ['ground-other','playingfield','platform','railroad','pavement','road','gravel','mud','dirt','snow','sand']

In [17]:
img_names = {}
for img in data['images']:
    img_names[img['id']] = img['file_name']

In [27]:
water_images = {}
for annot in data['annotations']:
    if id_to_name[annot['category_id']] == 'water-other':
        water_images[img_names[annot['image_id']]] = [annot]

for annot in data['annotations']:
    if img_names[annot['image_id']] in list(water_images.keys()):
        water_images[img_names[annot['image_id']]].append(annot)

In [30]:
floor_images = []
ground_images = []
for fp,img_annots in water_images.items():
    for annot in img_annots:
        if id_to_name[annot['category_id']] in floor_cats:
            floor_images.append(fp)
            break
    
    for annot in img_annots:
        if id_to_name[annot['category_id']] in ground_cats:
            ground_images.append(fp)
            break

In [34]:
print(len(floor_images))
print(len(ground_images))

for fp in ground_images:
    img = cv2.imread("/media/petrus/Data/coco/train2017/"+fp)
    cv2.imshow(fp,img)
    key = cv2.waitKey(0)
    cv2.destroyAllWindows()
    if key==27:
        break

279
1415


In [16]:
print(data['images'][4])
print(data['annotations'][0])

{'license': 3, 'file_name': '000000554625.jpg', 'coco_url': 'http://images.cocodataset.org/train2017/000000554625.jpg', 'height': 640, 'width': 426, 'date_captured': '2013-11-14 16:03:19', 'flickr_url': 'http://farm5.staticflickr.com/4086/5094162993_8f59d8a473_z.jpg', 'id': 554625}
{'segmentation': {'counts': 'Z4l4T:00000O10000O1000000O10000O1000000O1000000O101O0fKoJdNR5\\1PKbNP5^1RK`Nn4`1RK`No4_1QK`NP5`1PK`NP5`1PK`NP5`1PK_NQ5a1oJ_NQ5a1oJ_NR5`1nJ`NR5`1nJ_NS5a1mJ_NS5a1mJ_NS5a1mJ^NU5a1kJ_NU5a1kJ_NU5a1kJ_NU5a1kJ^NV5b1jJ^NW5a1iJ_NW5a1iJ_NW5a1iJ^NX5b1hJ^NX5b1hJ^NY5a1gJ^NZ5b1fJ^NZ5b1fJ^NZ5b1fJ^NZ5b1fJ]N[5c1eJ]N\\5b1dJ^N\\5b1dJ]N]5c1cJ]N]5c1cJ]N]5c1cJ]N]5c1cJ\\N^5d1bJ\\N^5d1bJ\\N^5d1bJ\\N^5d1bJ[N_5e1aJ[N_5e1aJ[N_5e1aJ[N_5e1aJ[N_5e1aJ[N_5e1aJ[N_5e1aJZN`5f1`JZN`5f1`JZN`5f1`JZN`5f1`JZN`5f1`JZN`5f1`JZN`5f1`JZN`5f1`JYNa5g1_JYNa5g1_JYNa5g1_JYNa5g1_JYN`5h1`JXN`5h1`JXN_5i1aJWN_5i1aJWN_5i1aJVN_5k1aJUN_5k1aJUN^5l1bJTN^5l1bJTN]5m1cJSN]5m1cJSN\\5n1dJRN\\5n1dJQN\\5P2dJPN\\5P2dJPN[5Q2eJoM[5Q2eJoM[5Q2eJoMZ5