In [1]:
import os
import cv2
import plotly.express as px
import numpy as np
from PIL import Image
import yaml
from copy import copy
from tqdm import tqdm
from ultralytics import YOLO

In [2]:
# Read the dataset config and expose its class-name table as CLASSES.
with open('./edu_train.yaml') as f:
    config = yaml.safe_load(f)
CLASSES = config['names']
print(CLASSES)

['__background__', 'text', 'junction', 'crossover', 'terminal', 'gnd', 'vss', 'voltage.dc', 'voltage.ac', 'voltage.battery', 'resistor', 'resistor.adjustable', 'resistor.photo', 'capacitor.unpolarized', 'capacitor.polarized', 'capacitor.adjustable', 'inductor', 'inductor.ferrite', 'inductor.coupled', 'transformer', 'diode', 'diode.light_emitting', 'diode.thyrector', 'diode.zener', 'diac', 'triac', 'thyristor', 'varistor', 'transistor.bjt', 'transistor.fet', 'transistor.photo', 'operational_amplifier', 'operational_amplifier.schmitt_trigger', 'optocoupler', 'integrated_circuit', 'integrated_circuit.ne555', 'integrated_circuit.voltage_regulator', 'xor', 'and', 'or', 'not', 'nand', 'nor', 'probe.current', 'probe.voltage', 'switch', 'relay', 'socket', 'fuse', 'speaker', 'motor', 'lamp', 'microphone', 'antenna', 'crystal', 'mechanical', 'magnetic', 'optical', 'block', 'unknown']


In [3]:
# Collect the path of every regular file in the training image directory.
# os.listdir() returns entries in arbitrary order, so the paths are sorted to
# make the order in which images are visited reproducible between runs.
img_dir = './datasets/images/train/'
IMG_PATHS = sorted(
    os.path.join(img_dir, img_name)
    for img_name in os.listdir(img_dir)
    if os.path.isfile(os.path.join(img_dir, img_name))
)
print(IMG_PATHS[:5])

['./datasets/images/train/C100_D1_P1.jpg', './datasets/images/train/C100_D1_P1_aug0.jpg', './datasets/images/train/C100_D1_P1_aug1.jpg', './datasets/images/train/C100_D1_P1_aug2.jpg', './datasets/images/train/C100_D1_P1_aug3.jpg']


In [4]:
def loader(img_path):
    """Load an image together with its YOLO-style label file.

    The label file is located by replacing ``images`` with ``labels`` in the
    directory part of ``img_path`` and swapping the file extension for
    ``.txt``; the file name itself is never rewritten.

    Args:
        img_path: Path to the image file.

    Returns:
        A tuple ``(image, class_ids, boxes)`` where ``image`` is the decoded
        image as a numpy array, ``class_ids`` is a list of ints (first token
        of every label line) and ``boxes`` is a list of float lists (the
        remaining tokens of every label line). Any box value that is not
        below 1 is replaced by 0.99.

    Raises:
        OSError: If the image or the label file cannot be opened.
        ValueError: If a non-blank label line contains a non-numeric token.
    """
    # Close the underlying file handle as soon as the pixels are copied out.
    with Image.open(img_path) as pil_image:
        image = np.array(pil_image)

    # Only the directory is mapped images -> labels and only the real
    # extension is swapped, so substrings such as "jpg" or "images" elsewhere
    # in the path cannot corrupt the result.
    img_dir, img_name = os.path.split(img_path)
    stem = os.path.splitext(img_name)[0]
    labels_path = os.path.join(img_dir.replace('images', 'labels'), stem + '.txt')

    class_ids = []
    boxes = []
    with open(labels_path) as file:
        for line in file:
            items = line.split()
            if not items:
                # Blank line (e.g. a trailing newline at the end of the file).
                continue
            class_ids.append(int(items[0]))

            box = [float(point) for point in items[1:]]
            # Keep values strictly below 1; anything else becomes 0.99.
            box = [point if point < 1 else 0.99
                   for point in box]
            boxes.append(box)
    return image, class_ids, boxes


def parse_boxes(img_shape, boxes):
    """Lazily convert normalised centre-format boxes into pixel corner pairs.

    Args:
        img_shape: Image shape whose first two entries are (height, width).
        boxes: Iterable of boxes indexed as
            ``[x_center, y_center, width, height]``, each value expressed as
            a fraction of the corresponding image dimension.

    Yields:
        ``[(x1, y1), (x2, y2)]`` for every box: the top-left and bottom-right
        corners in pixels, truncated to ints.
    """
    img_h, img_w = img_shape[:2]
    for box in boxes:
        # Scale the relative centre and size up to pixel units.
        cx, cy = box[0] * img_w, box[1] * img_h
        half_w, half_h = (box[2] * img_w) / 2, (box[3] * img_h) / 2

        top_left = (int(cx - half_w), int(cy - half_h))
        bottom_right = (int(cx + half_w), int(cy + half_h))
        yield [top_left, bottom_right]
        
        
def put_box(img, class_ids, parsed_boxes, pred_confs):
    """Draw captioned rectangles on a copy of ``img``.

    Args:
        img: Image to annotate; it is copied first, so the input is untouched.
        class_ids: Indices into the module-level ``CLASSES`` table.
        parsed_boxes: ``[(x1, y1), (x2, y2)]`` pixel boxes.
        pred_confs: Value appended to each caption after the class name.

    Returns:
        The annotated copy. The three sequences are zipped together, so only
        as many boxes are drawn as the shortest of them has entries.
    """
    box_color = (0, 255, 255)
    text_color = (255, 0, 255)

    canvas = copy(img)
    for class_id, (top_left, bottom_right), conf in zip(class_ids, parsed_boxes, pred_confs):
        canvas = cv2.rectangle(canvas, top_left, bottom_right, box_color, 2)
        caption = f'{CLASSES[class_id]} {conf}'
        # The caption is anchored at the top-left corner of the box.
        canvas = cv2.putText(canvas, caption, top_left,
                             cv2.FONT_HERSHEY_COMPLEX_SMALL, 1,
                             text_color, 2)
    return canvas


def parse_pred(predicts):
    """Turn one raw prediction dict into drawable classes, boxes and scores.

    Args:
        predicts: Dict with keys ``'classes'``, ``'conf'`` and ``'boxes'``;
            every box is indexed as ``[x1, y1, x2, y2]``.

    Returns:
        ``(pred_classes, pred_boxes, pred_confs)``: int class ids,
        ``[(x1, y1), (x2, y2)]`` int boxes, and scores rounded to three
        decimals. When no class survives the filter, boxes and scores are
        empty as well.
    """
    classes = []
    for value in predicts['classes']:
        # Keep truthy entries and explicit zeros (class id 0 is valid).
        if value or value == 0:
            classes.append(int(value))

    raw_confs = predicts['conf']
    raw_boxes = predicts['boxes']
    if not classes:
        # Nothing detected: ignore whatever the conf/boxes entries hold.
        return classes, [], []

    confs = [round(score, 3) for score in raw_confs]
    boxes = []
    for corners in raw_boxes:
        top_left = (int(corners[0]), int(corners[1]))
        bottom_right = (int(corners[2]), int(corners[3]))
        boxes.append([top_left, bottom_right])
    return classes, boxes, confs

def get_pred(model, img):
    """Run ``model`` on ``img`` and yield plain-Python detections per result.

    Args:
        model: Object whose ``predict(source=..., verbose=...)`` returns an
            iterable of results exposing ``boxes.cls``, ``boxes.xyxy`` and
            ``boxes.conf``.
        img: Source handed to ``model.predict``.

    Yields:
        ``{'classes': [...], 'boxes': [[...], ...], 'conf': [...]}`` for each
        result. With no detections ``'boxes'`` is ``[[]]`` and the other two
        lists are empty.
    """
    def to_python(tensor):
        # squeeze() collapses a single detection to a scalar / flat list,
        # which is re-wrapped below.
        return tensor.cpu().numpy().squeeze().tolist()

    # Note (translated from Russian): the GPU availability was changed.
    for result in model.predict(source=img, verbose=True):
        detections = result.boxes
        classes = to_python(detections.cls)
        corners = to_python(detections.xyxy)
        scores = to_python(detections.conf)

        # A lone detection arrives as a bare scalar: wrap it back into a list.
        if not isinstance(classes, list):
            classes = [classes]
        if not isinstance(scores, list):
            scores = [scores]
        # Empty result or a single flat box: nest it one level deeper.
        if not corners or not isinstance(corners[0], list):
            corners = [corners]

        yield {'classes': classes, 'boxes': corners, 'conf': scores}

In [5]:
# Load the YOLO model from the saved weights file at this path.
MODEL = YOLO("runs/detect/yolov3-tiny8/weights/best.pt")

In [6]:
# Shared, stateful iterator over IMG_PATHS: every plot_result() call consumes
# the next path via next(IMG_ITER). Re-run this cell to start over from the
# first image.
IMG_ITER = iter(IMG_PATHS)

In [7]:
def show_plot(images):
    """Show the images as a plotly facet grid, two panels per row.

    Args:
        images: Sequence of image arrays; they are stacked with ``np.array``
            and faceted along the first axis. Each panel is titled from a
            fixed list of four names, looked up by the panel's facet index.
    """
    panel_titles = ['True', 'Predict', 'GrabCut_by_Pred', 'GrabCut_by_Full']
    titles_by_index = {str(index): title for index, title in enumerate(panel_titles)}

    def rename_facet(annotation):
        # The default facet caption has the form "<name>=<index>"; replace it
        # with the readable title for that index.
        facet_index = annotation.text.split("=")[1]
        annotation.update(text=titles_by_index[facet_index])

    fig = px.imshow(np.array(images), facet_col=0, facet_col_wrap=2, height=1200)
    fig.for_each_annotation(rename_facet)
    fig.show()


    
def grab_cut(img, boxes):
    """Black out everything GrabCut does not label as foreground.

    GrabCut is run once per box (rectangle initialisation, 5 iterations) and
    the per-box foreground masks are OR-ed together.

    Args:
        img: Image array handed to ``cv2.grabCut`` (OpenCV documents the
            input as an 8-bit 3-channel image).
        boxes: Iterable of ``[(x1, y1), (x2, y2)]`` pixel boxes, or ``None``
            to use one box spanning the image except its last row and column.
            Boxes are clamped to the image; boxes that are empty after
            clamping are skipped instead of being passed to OpenCV.

    Returns:
        A new array: ``img`` multiplied by the combined 0/1 foreground mask.
        If no box yields foreground, the result is entirely zero.
    """
    height, width = img.shape[:2]
    if boxes is None:
        # Leave a one-pixel margin so some pixels remain outside the rectangle.
        boxes = [[(0, 0), (width - 1, height - 1)]]

    total_mask = np.zeros((height, width), np.uint8)
    for (x1, y1), (x2, y2) in boxes:
        # Clamp the box to the image so the rectangle really is the region
        # between the two corners.
        left, top = max(x1, 0), max(y1, 0)
        right, bottom = min(x2, width), min(y2, height)
        if right <= left or bottom <= top:
            # Degenerate box: it would give GrabCut no foreground seed pixels.
            continue
        rect = (left, top, right - left, bottom - top)

        # Every box is an independent rectangle-initialised run, so it gets
        # blank model buffers and a blank mask.
        bgd_model = np.zeros((1, 65), np.float64)
        fgd_model = np.zeros((1, 65), np.float64)
        mask = np.zeros((height, width), np.uint8)
        cv2.grabCut(img, mask, rect, bgd_model, fgd_model, 5, cv2.GC_INIT_WITH_RECT)

        # Keep definite and probable foreground; drop both background labels.
        foreground = (mask == cv2.GC_FGD) | (mask == cv2.GC_PR_FGD)
        total_mask[foreground] = 1

    return img * total_mask[:, :, np.newaxis]
    
def plot_result():
    """Plot ground truth, predictions and GrabCut results for the next image.

    Takes the next path from the module-level ``IMG_ITER``, loads the image
    and its labels, runs ``MODEL`` on it and shows four panels through
    ``show_plot``: ground-truth boxes, predicted boxes, GrabCut seeded with
    the predicted boxes, and GrabCut seeded with (almost) the whole frame.

    Raises:
        StopIteration: When ``IMG_ITER`` has no paths left.
    """
    img_path = next(IMG_ITER)

    image, true_class_ids, true_boxes = loader(img_path)
    true_boxes = [*parse_boxes(image.shape, true_boxes)]

    # loader() decodes with PIL, so colour images arrive as RGB, while
    # Ultralytics documents numpy sources as HWC arrays in BGR order.
    # Flip the channels for inference only; the RGB original is kept for
    # drawing and display.
    if image.ndim == 3 and image.shape[2] == 3:
        model_input = np.ascontiguousarray(image[:, :, ::-1])
    else:
        model_input = image
    pred_classes, pred_boxes, pred_confs = parse_pred(*get_pred(MODEL, model_input))

    # One empty caption per ground-truth box. put_box() zips its arguments,
    # so sizing this list by the predictions would silently drop ground-truth
    # boxes whenever the model finds fewer objects than are labelled.
    true_img = put_box(image, true_class_ids, true_boxes, [''] * len(true_boxes))
    pred_img = put_box(image, pred_classes, pred_boxes, pred_confs)
    grab_cut_pred_img = grab_cut(image, pred_boxes)
    grab_cut_full_img = grab_cut(image, boxes=None)

    show_plot([true_img, pred_img, grab_cut_pred_img, grab_cut_full_img])

In [8]:
# Visualise the next image from IMG_ITER; each run of this cell advances the iterator.
plot_result()


0: 640x384 10 texts, 1 junction, 1 voltage.battery, 1 resistor, 1 diode.light_emitting, 1 transistor.bjt, 2 optocouplers, 1 switch, 1 speaker, 1207.3ms
Speed: 18.0ms preprocess, 1207.3ms inference, 13.8ms postprocess per image at shape (1, 3, 640, 640)
