In [11]:
import os
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import json
from tqdm.auto import tqdm
from collections import defaultdict
import cv2


# ---- Input / output locations ------------------------------------------------
# NOTE(review): absolute Google Drive / Colab paths; adjust when running elsewhere.
data_dir = "/content/drive/MyDrive/2021.summer_URP/PD/KAIST_PD"
pred_json_path = '/content/100epoch.json'

# Run identifier: used as the name of the per-run output sub-directory.
baselineID = 30
save_dir = os.path.join('/content/drive/MyDrive/2021.summer_URP/PD/check_prediction', f'{baselineID}')
video_name = f'{baselineID}.avi'  # not referenced by the visible cells

# ---- Drawing style -----------------------------------------------------------
ann_box_color = '#FF0000'   # annotation (ground-truth) boxes: red
pred_box_color = '#0000FF'  # predicted boxes and their score labels: blue
box_line_width = 2
font_size = 10

# ---- Video output ------------------------------------------------------------
fps = 5           # frame rate passed to cv2.VideoWriter
save_img = True   # not referenced by the visible cells

In [12]:
img_ids_txt = 'test-all-20.txt'
with open(os.path.join(data_dir, img_ids_txt), 'r') as f:
    # One image id per line. splitlines() removes "\n" and "\r\n" terminators;
    # slicing off the last character (x[:-1]) would instead chop a real character
    # from a final line that lacks a trailing newline. Blank lines (e.g. a trailing
    # empty line) are skipped: they would otherwise map to 'annotation_json/.json'.
    img_ids = [line.strip() for line in f.read().splitlines() if line.strip()]

# Relative paths (under data_dir) derived from each id.
ann_paths = [os.path.join('annotation_json', img_id + '.json') for img_id in img_ids]
# (head, tail) pairs from os.path.split, i.e. (directory part, file stem) of the id.
img_dir_names = [os.path.split(img_id) for img_id in img_ids]
color_img_paths = [os.path.join('images', dir_name, 'visible', stem + '.jpg')
                   for dir_name, stem in img_dir_names]
thermal_img_paths = [os.path.join('images', dir_name, 'lwir', stem + '.jpg')
                     for dir_name, stem in img_dir_names]

# Predictions: a list of dicts with 'image_id', 'category_id', 'bbox' and 'score'
# (presumably a COCO-style results file -- confirm against the exporter).
with open(pred_json_path, 'r') as j:
    pred_json = json.load(j)

# Group predictions by image_id. Category-1 boxes are converted from
# [x_min, y_min, w, h] to [x_min, y_min, x_max, y_max], the corner form that
# ImageDraw.rectangle takes; other categories are left untouched (they are not drawn).
preds = defaultdict(list)
for pred in pred_json:
    if pred['category_id'] == 1:
        x_min, y_min, w, h = pred['bbox']
        pred['bbox'] = [x_min, y_min, x_min + w, y_min + h]
    preds[pred['image_id']].append(pred)

In [13]:
_FONT_PATH = '/usr/share/fonts/truetype/liberation/LiberationMono-Regular.ttf'
_font_cache = {}


def _get_font(size):
    """Return the label font at ``size``, loading it from disk only once per size.

    Falls back to PIL's built-in default font when the TrueType file at
    ``_FONT_PATH`` cannot be opened.
    """
    if size not in _font_cache:
        try:
            _font_cache[size] = ImageFont.truetype(_FONT_PATH, size)
        except OSError:
            _font_cache[size] = ImageFont.load_default()
    return _font_cache[size]


def _text_size(font, text):
    """Return the (width, height) needed to draw ``text`` with ``font``.

    ``FreeTypeFont.getsize`` was removed in Pillow 10. On Pillow >= 9.2 the
    right/bottom coordinates of ``getbbox`` are the same numbers ``getsize``
    returned, so use that and fall back to ``getsize`` only on older Pillow.
    """
    if hasattr(font, 'getbbox'):
        _, _, right, bottom = font.getbbox(text)
        return right, bottom
    return font.getsize(text)


def draw_box(image, anns, preds_in_img, img_id):
    """Draw annotation boxes, predicted boxes with scores, and an id banner.

    The image is modified in place and also returned.

    Args:
        image: PIL image to draw on.
        anns: parsed annotation JSON; its ``'annotation'`` list holds dicts
            with ``'bbox'`` and ``'category_id'``.
        preds_in_img: prediction dicts with ``'bbox'``, ``'category_id'`` and
            ``'score'``. Boxes must be [x_min, y_min, x_max, y_max]; the
            loading cell converts category-1 boxes to that form.
        img_id: text drawn on a black banner in the top-left corner.

    Only entries with ``category_id == 1`` are drawn (presumably the
    pedestrian class -- confirm against the dataset's category map).

    Returns:
        The same ``image`` object.
    """
    draw = ImageDraw.Draw(image)
    font = _get_font(font_size)

    # White-on-black image id in the top-left corner.
    id_w, id_h = _text_size(font, img_id)
    draw.rectangle(xy=[0, 0, id_w, id_h], fill='black')
    draw.text(xy=[0, 0], text=img_id, fill='white', font=font)

    # NOTE(review): annotation boxes are drawn as-is, i.e. assumed to already be
    # [x_min, y_min, x_max, y_max]; only predictions are converted from
    # [x, y, w, h]. Confirm against the annotation_json format.
    for ann in anns['annotation']:
        if ann['category_id'] == 1:
            draw.rectangle(xy=ann['bbox'], outline=ann_box_color, width=box_line_width)

    for pred in preds_in_img:
        if pred['category_id'] != 1:
            continue
        box_loc = pred['bbox']
        draw.rectangle(xy=box_loc, outline=pred_box_color, width=box_line_width)

        # Score as a percentage, e.g. 0.98765 -> "98.77".
        score = str(round(pred['score'] * 100, 2))
        text_w, text_h = _text_size(font, score)
        # Label sits just above the box; clamp to y=0 so labels of boxes that
        # touch the top edge stay visible instead of being drawn off-image.
        label_top = max(box_loc[1] - text_h, 0)
        draw.rectangle(xy=[box_loc[0], label_top,
                           box_loc[0] + text_w + 4., label_top + text_h],
                       fill=pred_box_color)
        draw.text(xy=[box_loc[0] + 2., label_top], text=score, fill='white', font=font)

    return image

In [14]:
num_img = len(color_img_paths)


def _render_frame(img_path, anns, preds_in_img, img_id):
    """Load ``img_path``, draw boxes on it, and return a BGR uint8 array for OpenCV."""
    with Image.open(img_path) as src:
        # convert('RGB') guarantees 3 channels (a single-channel image would make
        # cvtColor below fail) and copies the pixels so the file can be closed.
        rgb_image = src.convert('RGB')
    rgb_image = draw_box(rgb_image, anns, preds_in_img, img_id)
    # PIL arrays are RGB; cv2.VideoWriter expects BGR.
    return cv2.cvtColor(np.array(rgb_image), cv2.COLOR_RGB2BGR)


# NOTE(review): every frame (H*W*3 bytes each, twice per image) is kept in memory
# until the video cell writes it out.
color_frame_list = []
thermal_frame_list = []
for i in tqdm(range(num_img)):
    with open(os.path.join(data_dir, ann_paths[i]), 'r') as j:
        anns = json.load(j)
    # NOTE(review): assumes a prediction's image_id equals the line index of the
    # image in img_ids_txt -- confirm against the prediction exporter.
    pred_in_img = preds.get(i, [])
    img_id = img_ids[i]

    color_frame_list.append(_render_frame(
        os.path.join(data_dir, color_img_paths[i]), anns, pred_in_img, img_id))
    thermal_frame_list.append(_render_frame(
        os.path.join(data_dir, thermal_img_paths[i]), anns, pred_in_img, img_id))

HBox(children=(FloatProgress(value=0.0, max=2252.0), HTML(value='')))




In [15]:
width = 640
height = 512
size = (width, height)
color_video_path = os.path.join(save_dir, 'color_prediction_video.mp4')
thermal_video_path = os.path.join(save_dir, 'thermal_prediction_video.mp4')

# cv2.VideoWriter does not create missing directories; it just fails to open.
os.makedirs(save_dir, exist_ok=True)


def _write_video(frames, video_path, fps, size):
    """Write BGR ``frames`` to ``video_path`` at ``fps`` with frame ``size`` (w, h).

    Raises:
        OSError: if OpenCV cannot open a writer for ``video_path``.
        ValueError: if a frame's size differs from ``size`` (OpenCV would
            otherwise drop it silently).
    """
    # 'mp4v' is an MPEG-4 tag the MP4 container accepts; with 'DIVX' the FFmpeg
    # backend warns that the tag is unsupported in MP4 and falls back.
    writer = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, size)
    if not writer.isOpened():
        raise OSError(f'could not open video writer for {video_path}')
    try:
        for frame in tqdm(frames):
            frame_size = (frame.shape[1], frame.shape[0])
            if frame_size != size:
                raise ValueError(
                    f'frame size {frame_size} does not match video size {size}')
            writer.write(frame)
    finally:
        writer.release()


_write_video(color_frame_list, color_video_path, fps, size)
_write_video(thermal_frame_list, thermal_video_path, fps, size)

HBox(children=(FloatProgress(value=0.0, max=2252.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2252.0), HTML(value='')))


