In [38]:
import torch
import cv2 as cv
import numpy as np
import sys
import os
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
from utils.utils import letterbox, driving_area_mask, lane_line_mask,\
    split_for_trace_model, non_max_suppression, plot_one_box, scale_coords, clip_coords
from time import time

In [41]:
def _prepare_input(frame, size):
    """Resize `frame` and build the half-precision CUDA batch fed to the model.

    Args:
        frame: image array; assumed HxWx3 uint8 (as produced by ``np.asarray(PIL.Image)``).
        size: target ``(width, height)`` passed to ``cv.resize``.

    Returns:
        ``(resized, batch)``: the resized HxWxC numpy image and a ``(1, C, H, W)``
        float16 CUDA tensor with values scaled to [0, 1].
    """
    resized = cv.resize(frame, size, interpolation=cv.INTER_NEAREST)
    # HWC -> CHW; upload as uint8 (smaller transfer) and cast on the GPU.
    batch = torch.from_numpy(resized.transpose(2, 0, 1)).cuda()
    batch = batch.half() / 255.0
    return resized, batch.unsqueeze(0)


def _paint_segmentation(canvas, seg, ll, alpha=0.5):
    """Blend the `seg` (drivable-area) and `ll` (lane-line) masks into `canvas` in place.

    Pixels where argmax over dim 1 of `seg` is 1 are painted green; pixels where the
    rounded `ll` is 1 are painted blue (RGB order), and lane lines win where both apply.
    Painted pixels become ``canvas * (1 - alpha) + color * alpha``.
    """
    da_mask = torch.max(seg, 1)[1].int().squeeze().cpu().numpy()
    ll_mask = torch.round(ll).squeeze(1).int().squeeze().cpu().numpy()

    colors = np.zeros((*da_mask.shape, 3), dtype=np.uint8)
    colors[da_mask == 1] = [0, 255, 0]
    colors[ll_mask == 1] = [0, 0, 255]
    painted = colors.any(axis=2)
    canvas[painted] = canvas[painted] * (1 - alpha) + colors[painted] * alpha


def _draw_detections(canvas, pred, anchor_grid):
    """Run NMS on the raw detection-head output and draw every surviving box on `canvas`."""
    boxes = non_max_suppression(split_for_trace_model(pred, anchor_grid))[0]
    clip_coords(boxes, canvas.shape)  # presumably clamps boxes to the canvas bounds
    for *xyxy, _, _ in boxes:
        plot_one_box(xyxy, canvas)  # draws into `canvas`, which is what gets saved


def detect_and_save(model, file, file_name, out_dir='../dane/img_zaznaczone/',
                    size=(640, 480), masking=True, obj_det=True, alpha=0.5, overlay=False):
    """Run the multi-task model on one frame and save a rendered prediction image.

    Args:
        model: TorchScript/torch module (e.g. the loaded ``yolopv2.pt``) returning
            ``([pred, anchor_grid], seg, ll)``. It is moved to CUDA here; it must
            accept float16 input because the batch is cast to half precision.
        file: the input frame as a numpy array; assumed HxWx3 uint8 RGB.
        file_name: name of the output file inside `out_dir`; its extension selects
            the format PIL writes.
        out_dir: output directory (created if missing).
        size: ``(width, height)`` the frame is resized to before inference; the saved
            image has this size. Presumably must suit the model's stride (e.g. multiples of 32).
        masking: paint the drivable-area / lane-line segmentation.
        obj_det: draw the detected boxes.
        alpha: blend weight of the segmentation colors.
        overlay: if True, draw on the resized frame; if False (default) draw on a
            black canvas so the saved image contains only the predictions.

    Returns:
        None. The image is written to ``os.path.join(out_dir, file_name)``.
    """
    with torch.no_grad():
        model = model.cuda()
        resized, batch = _prepare_input(file, size)
        [pred, anchor_grid], seg, ll = model(batch)

        # On the default black canvas the blend leaves painted pixels at alpha * color
        # (e.g. ~127 instead of 255 with alpha=0.5).
        height, width = resized.shape[:2]
        if overlay:
            output = resized.astype(np.float64)
        else:
            output = np.zeros((height, width, 3), dtype=np.float64)

        if masking:
            _paint_segmentation(output, seg, ll, alpha)
        if obj_det:
            _draw_detections(output, pred, anchor_grid)

    os.makedirs(out_dir, exist_ok=True)
    Image.fromarray(output.astype(np.uint8)).save(os.path.join(out_dir, file_name))

In [42]:
# Load the TorchScript checkpoint and prepare it for float16 GPU inference.
model = torch.jit.load('data/weights/yolopv2.pt').cuda().half().eval()

In [43]:
path_to_dir = '../dane/data/IMG/'
# Sorted for a deterministic processing order; sub-directories and other non-files are skipped.
dir_files = sorted(
    name for name in os.listdir(path_to_dir)
    if os.path.isfile(os.path.join(path_to_dir, name))
)
for file_name in tqdm(dir_files):
    # `with` closes the file handle promptly; convert('RGB') guarantees a 3-channel
    # array even for grayscale, RGBA or palette images.
    with Image.open(os.path.join(path_to_dir, file_name)) as frame:
        img = np.asarray(frame.convert('RGB'))
    detect_and_save(model, img, file_name)

100%|████████████████████████████████████████████████████████████████████████████| 24108/24108 [18:45<00:00, 21.43it/s]
