In [1]:
from torchvision import transforms
import torch
import numpy as np
import cv2
from matplotlib import pyplot as plt
from PIL import Image

import sys
sys.path.append('../')
import utils.image as im
import utils.cam as cam
import utils.grabcut as gc
import utils.metrics as m
import utils.json as json
import utils.path as path
from utils.VOCSegmentation import VOCSegmentation

In [3]:
# Windows-style paths. root_path comes from the project helper
# `goback_from_current_dir(1)`; it is assumed to end with a separator,
# because the other paths are appended to it directly — TODO confirm.
root_path = path.goback_from_current_dir(1)

json_path = root_path + 'json\\'

# Every backslash is escaped explicitly: the old literal relied on '\V' being an
# unrecognised escape (DeprecationWarning, SyntaxWarning since Python 3.12).
# The resulting string is unchanged.
sgm_path  = root_path + 'VOCdevkit\\VOC2012\\SegmentationObject\\'

In [4]:
# VOC 2012 trainval split with target='Object' (presumably the SegmentationObject
# instance masks read from disk below); image and mask are both converted to tensors.
voc_kwargs = dict(
    root=root_path,
    year='2012',
    image_set='trainval',
    download=False,
    transform=transforms.ToTensor(),
    target_transform=transforms.ToTensor(),
    transforms=None,
    target='Object',
)
data_tbl = VOCSegmentation(**voc_kwargs)

In [5]:
# Unshuffled, one image per batch. NOTE(review): the validation loop pairs the
# i-th batch with the i-th entry of the annotation dict, so this relies on the
# dataset order matching the annotation key order — confirm.
data = iter(
    torch.utils.data.DataLoader(data_tbl, batch_size=1, shuffle=False, num_workers=2)
)

In [6]:
# The same file is loaded twice to get two independent copies: `annotations` is
# only read, while `annotations_clean` has invalid boxes popped from it later.
# Note: `json` here is the project's utils.json, not the stdlib module.
annotations, annotations_clean = (
    json.open_json(json_path + 'voc-object-annotations') for _ in range(2)
)

In [7]:
# N: number of annotated images (dict keys); N_TOTAL: number of bbox annotations across all of them.
N = len(annotations)
N_TOTAL = sum(len(image_annots) for image_annots in annotations.values())

In [None]:
# From here on `annotations` is no longer the loaded dict but a one-shot iterator
# of (image name, annotation list) pairs that the validation loop consumes with next().
# NOTE(review): rebinding the name makes this cell non-idempotent — re-run the
# loading cell before running this one again.
annotations = iter(tuple(annotations.items()))

In [None]:
def idx_to_color(n):
    """Return the PASCAL VOC colour-map entry for index `n` as [R, G, B].

    Uses the VOC colour-map bit interleaving. Bits 0/1/2 of `n` set the most
    significant bit of R/G/B, bits 3/4/5 set the next bit, and so on, for 8
    rounds. For 0 <= n < 64 this matches a six-bit lookup. Indices up to 255
    also get their proper palette colour instead of a garbled one.

    Args:
        n: non-negative integer label / instance index.

    Returns:
        list[int]: [R, G, B], each component in 0..255.

    Raises:
        ValueError: if `n` is negative.
    """
    if n < 0:
        raise ValueError(f'color index must be non-negative, got {n}')
    r = g = b = 0
    c = n
    for shift in range(7, -1, -1):
        r |= (c & 1) << shift
        g |= ((c >> 1) & 1) << shift
        b |= ((c >> 2) & 1) << shift
        c >>= 3
    return [r, g, b]

In [None]:
def show_color(color):
    """Display a single-pixel swatch of the colour `idx_to_color` assigns to index `color`."""
    rgb = np.array(idx_to_color(color))
    swatch = rgb.reshape(1, 1, -1)  # (1, 1, 3): a one-pixel image
    im.show(swatch)
    
def notify_error(msg):
    """Print and display diagnostics for one bbox that failed the colour check.

    Shows `msg` with the image index/name and bbox index, the segmentation crop
    under the bbox, and a swatch of the expected instance colour. If the global
    `first` is truthy, it first draws every bbox of the image on both the RGB
    image and the segmentation mask.

    Args:
        msg: short failure description, prepended to the report line.

    NOTE(review): besides `msg`, this function reads module-level globals set by
    the validation loop: `name`, `first`, `annots`, `img_cv2`, `i`, `j`,
    `xmin`, `ymin`, `xmax`, `ymax`, `expected_color` (and `sgm_path`).
    Calling it anywhere else uses stale or undefined values.
    """
    # Reload this image's segmentation PNG from disk, converted to RGB so the
    # instance colours are visible when displayed.
    sgm_cv2 = transforms.ToTensor()(Image.open(sgm_path + name + '.png').convert("RGB"))
    sgm_cv2 = im.pil_to_cv2(sgm_cv2.numpy())
    # Per-image overview, shown only for the first error of the current image.
    if first:
        
        print(f'Number of bboxes for this image = {len(annots)}')
        for k, annot2 in enumerate(annots):
            _, _, bbox2 = annot2
            # Box k is drawn in the colour of instance id k+1.
            # NOTE(review): whether draw_cbbox mutates its input image is not visible here.
            im.show(im.draw_cbbox(img_cv2, bbox2, color = idx_to_color(k+1)))
            im.show(im.draw_cbbox(sgm_cv2, bbox2, color = idx_to_color(k+1)))
    # NOTE(review): this separator is outside `if first:`, so it prints on every
    # call — confirm whether it was meant to be inside the if.
    print('==============\n')
        
    print(msg + f' for img {i} {name} and bbox {j}')
    # Crop of the RGB segmentation under the failing bbox.
    im.show(sgm_cv2[ymin:ymax, xmin:xmax, :])
    print(f'Bbox associated color : expected color')
    show_color(expected_color)
    print('_____________________________________\n')

In [None]:
# Pass 1: check every bbox annotation against the instance mask. Any bbox whose
# crop contains no object pixels is removed from `annotations_clean`.
# NOTE(review): consumes the `data` and `annotations` iterators created in earlier
# cells, so re-running this cell alone raises StopIteration.
# `notify_error` reads name, annots, img_cv2, i, j, xmin/ymin/xmax/ymax,
# expected_color and first as globals, so they must stay top-level names here.
top1_match = 0
top2_match = 0
match      = 0
skipped    = 0
wrong      = 0

for i in range(N):
    first = True  # notify_error shows the per-image bbox overview only on the first error
    if i % 100 == 0:
        print(i)
    skip = 0  # annotations of this image removed so far

    img, sgm = next(data)
    img = torch.squeeze(img)
    sgm = torch.squeeze(sgm)
    img_cv2 = im.pil_to_cv2(img.numpy())
    sgm     = im.f1_to_f255(sgm.numpy())

    name, annots = next(annotations)

    for j, annot in enumerate(annots):
        _, _, bbox = annot

        # The check assumes the k-th kept annotation of an image is instance id k
        # (1-based) in the mask; removed annotations shift later ids down.
        expected_color = j + 1 - skip

        xmin, ymin, xmax, ymax = bbox
        crop_sgm = sgm[ymin:ymax, xmin:xmax]
        # Drop background (0) and undefined/border (255) pixels.
        crop_sgm = crop_sgm[(crop_sgm != 0) & (crop_sgm != 255)]
        if crop_sgm.size > 0:
            unique, counts = np.unique(crop_sgm, return_counts = True)
            # Instance ids present in the box, most frequent first.
            ranked = unique[np.argsort(-counts)].tolist()

            if expected_color == ranked[0]:
                top1_match += 1
            elif len(ranked) > 1 and expected_color == ranked[1]:
                # The length guard avoids the IndexError a single-id box used to raise here.
                top2_match += 1
            elif expected_color in ranked:
                match      += 1
            else:
                notify_error('Colors dont match')
                first = False

                print("Wrong")
                wrong += 1
        else:
            # No object pixels under the box: drop it from the clean copy.
            # j - skip is its index in the already-shortened clean list.
            annotations_clean[name].pop(j - skip)
            skip += 1

            notify_error('No true color')
            first = False

            print("Skipped")
            skipped += 1

In [None]:
# Tally of pass 1; every annotation lands in exactly one bucket, so the total should equal N_TOTAL.
summary_rows = [
    ('# top1_match \t', top1_match),
    ('# top2_match \t', top2_match),
    ('# match \t', match),
    ('# skipped \t', skipped),
    ('# wrong\t\t', wrong),
]
for label, count in summary_rows:
    print(f'{label} = {count}')
print('__________________ +')
print(f'Total \t\t = {sum(count for _, count in summary_rows)} / {N_TOTAL}')

In [None]:
# Number of bbox annotations remaining after the cleaning pass.
N_TOTAL_CLEAN = sum(map(len, annotations_clean.values()))

In [None]:
# Sanity check: each skipped annotation was popped from annotations_clean exactly
# once, so the counts must agree. The assert makes a mismatch fail loudly instead
# of being printed and overlooked.
print(f'{N_TOTAL - skipped} == {N_TOTAL_CLEAN}')
assert N_TOTAL - skipped == N_TOTAL_CLEAN, 'cleaned annotation count does not equal N_TOTAL - skipped'

In [None]:
# Persist the cleaned annotations in the same json directory as the original file.
clean_annotations_path = json_path + 'voc-object-annotations-clean'
json.save_json(clean_annotations_path, annotations_clean)

In [None]:
# Pass 2: repeat the colour check on the cleaned annotations. Fresh iterators
# are created here, so this cell can be re-run on its own. Nothing is removed in
# this pass, so instance ids are simply the 1-based annotation positions.
data = iter(torch.utils.data.DataLoader(data_tbl,
                                        batch_size = 1,
                                        shuffle = False,
                                        num_workers = 2))
annotations_clean_iter = iter(annotations_clean.items())

top1_match = 0
top2_match = 0
match      = 0
skipped    = 0
wrong      = 0

for i in range(N):
    if i % 100 == 0:
        print(i)

    _, sgm = next(data)  # only the mask is needed in this pass
    sgm = im.f1_to_f255(torch.squeeze(sgm).numpy())

    name, annots = next(annotations_clean_iter)

    for j, annot in enumerate(annots):
        _, _, bbox = annot

        expected_color = j + 1

        xmin, ymin, xmax, ymax = bbox
        crop_sgm = sgm[ymin:ymax, xmin:xmax]
        # Drop background (0) and undefined/border (255) pixels.
        crop_sgm = crop_sgm[(crop_sgm != 0) & (crop_sgm != 255)]
        if crop_sgm.size > 0:
            unique, counts = np.unique(crop_sgm, return_counts = True)
            # Instance ids present in the box, most frequent first.
            ranked = unique[np.argsort(-counts)].tolist()

            if expected_color == ranked[0]:
                top1_match += 1
            elif len(ranked) > 1 and expected_color == ranked[1]:
                # The length guard avoids the IndexError a single-id box used to raise here.
                top2_match += 1
            elif expected_color in ranked:
                match      += 1
            else:
                wrong += 1
        else:
            skipped += 1

In [None]:
# Tally of pass 2 on the cleaned annotations; the total should equal N_TOTAL_CLEAN.
clean_summary = {
    '# top1_match \t': top1_match,
    '# top2_match \t': top2_match,
    '# match \t': match,
    '# skipped \t': skipped,
    '# wrong\t\t': wrong,
}
for label, count in clean_summary.items():
    print(f'{label} = {count}')
print('__________________ +')
print(f'Total \t\t = {sum(clean_summary.values())} / {N_TOTAL_CLEAN}')