In [6]:
import numpy as np
import cv2
from detection import detect_objects
import os
import copy
import json
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
def select_frames(video_path, dst_path, n_frames=100):
    """Save ``n_frames`` evenly spaced frames of a video as JPEG files.

    Frames are written to ``dst_path`` as ``frame000000.jpg``,
    ``frame000001.jpg``, ... in video order. Sampled indices run from the
    first frame to the last frame reported by OpenCV. If ``n_frames`` exceeds
    the frame count, duplicate indices collapse and fewer files are written.

    Parameters:
        video_path - absolute or relative path to video file
        dst_path - absolute or relative path to result frames folder
            (created if it does not exist)
        n_frames - number of frames to sample
    """
    cap = cv2.VideoCapture(video_path)
    try:
        fps = int(cap.get(cv2.CAP_PROP_FPS))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        print('video fps:', fps)
        print('video total frames:', total_frames)
        wanted = _sample_frame_indices(total_frames, n_frames)
        # cv2.imwrite does not create folders; it just fails silently.
        os.makedirs(dst_path, exist_ok=True)
        last_wanted = max(wanted, default=-1)
        saved = 0
        index = 0
        # Stop decoding once the last sampled frame has been passed.
        while index <= last_wanted:
            ok, frame = cap.read()
            if not ok:
                # NOTE(review): CAP_PROP_FRAME_COUNT may be an estimate for some
                # containers; frames past the real end are simply not written.
                break
            if index in wanted:
                cv2.imwrite(os.path.join(dst_path, f'frame{saved:06d}.jpg'), frame)
                saved += 1
            index += 1
    finally:
        cap.release()
    print('Frame selection: done')


def _sample_frame_indices(total_frames, n_frames):
    """Return the set of ``n_frames`` evenly spaced indices in [0, total_frames - 1]."""
    if total_frames <= 0 or n_frames <= 0:
        return set()
    # The last valid index is total_frames - 1. Sampling up to total_frames
    # would target a frame that never exists and lose one sample.
    return set(np.linspace(0, total_frames - 1, n_frames, dtype=int).tolist())

In [8]:
# Per-category pixel threshold for binarizing CO3D masks; presumably the value
# passed as ``threshold`` to get_co3d_masks for that category.
thresholds = dict(
    bowl=250,
    orange=200,
    teddybear=200,
)
def get_co3d_masks(mask_path, threshold):
    """Binarize every mask image in ``mask_path`` in place.

    Pixels with value >= ``threshold`` become 255 and all others become 0.
    Entries that OpenCV cannot read as images (non-image files,
    sub-folders) are skipped with a message instead of aborting the run.

    Parameters:
        mask_path - absolute or relative path to mask folder
        threshold - pixel value at or above which a pixel is set to 255
    """
    for file in tqdm(os.listdir(mask_path), desc='CO3D mask processing..'):
        file_path = os.path.join(mask_path, file)
        img = cv2.imread(file_path)
        if img is None:
            print('Skipping unreadable file:', file_path)
            continue
        # One vectorized pass; no copy of the source image is needed.
        binary = np.where(img >= threshold, 255, 0).astype(img.dtype)
        cv2.imwrite(file_path, binary)

        
def get_colmap_masks(mask_path):
    """Rename every mask in ``mask_path`` to the ``<frame name>.png`` scheme.

    ``mask000001.jpg`` becomes ``frame000001.jpg.png``: every occurrence of
    'mask' in the file name is replaced by 'frame', and '.png' is appended.
    Presumably this matches COLMAP's convention of looking up the mask of
    ``image.jpg`` as ``image.jpg.png``; confirm against the COLMAP setup.
    This is not idempotent: running it twice appends '.png' again.

    Parameters:
        mask_path - absolute or relative path to mask folder
    """
    for file in tqdm(os.listdir(mask_path), desc='Colmap processing..'):
        new_name = file.replace('mask', 'frame') + '.png'
        # os.path.join keeps this portable; the separator was hard-coded as '\\'.
        os.rename(os.path.join(mask_path, file), os.path.join(mask_path, new_name))
        
        
def get_masks(frame_path, mask_path, coco_class, frame_process=None, colmap_process=False,
              co3d_threshold=None):
    """Build rectangular object masks for the frames in ``frame_path``.

    ``detect_objects(frame_path)`` is run first. For each frame, the largest
    detected box of class ``coco_class`` is filled with 255 on a black
    single-channel image. The image is saved in ``mask_path`` under the
    frame's file name with 'frame' replaced by 'mask'. Every file already in
    ``mask_path`` is deleted first; the folder is created if missing.

    Parameters:
        frame_path - absolute or relative path to frames folder
        mask_path - absolute or relative path to mask folder
        coco_class - COCO class to recognize
        frame_process - what to do with frames on which object was not recognized:
            'delete' - delete such frames
            'mask' - make full-frame (all 255) masks for such frames
            None - leave them without a mask
        colmap_process - whether to rename masks with get_colmap_masks
        co3d_threshold - if not None, binarize the masks with get_co3d_masks
            using this threshold
    """
    result = detect_objects(frame_path)
    _clear_directory(mask_path)
    # Basenames of frames that received an object mask. Matching by basename
    # does not depend on the exact path format detect_objects returns
    # (relative vs absolute, '/' vs '\\'). A mismatch here would make the
    # 'delete' mode remove every frame.
    detected = set()
    for res in tqdm(result, desc='Mask processing..'):
        # NOTE(review): items are indexed as (frame image path, objects), where each
        # object is (class, box) and the box is (x_min, y_min, x_max, y_max) per the
        # slicing below; confirm against detect_objects.
        frame_file = res[0]
        coords = _largest_box(res[1], coco_class)
        if coords is None:
            continue
        img = cv2.imread(frame_file)
        if img is None:
            raise FileNotFoundError(f'Cannot read frame image: {frame_file}')
        mask = np.zeros(img.shape[:2], dtype=np.uint8)
        # Clamp the start at 0: a negative slice start would wrap to the far
        # edge of the image and leave the box empty.
        x_min = max(int(coords[0]), 0)
        y_min = max(int(coords[1]), 0)
        mask[y_min:coords[3], x_min:coords[2]] = 255
        cv2.imwrite(os.path.join(mask_path, _mask_name(frame_file)), mask)
        detected.add(os.path.basename(frame_file))

    if frame_process == 'delete':
        for name in tqdm(os.listdir(frame_path), desc='Frame deletion..'):
            if name not in detected:
                os.remove(os.path.join(frame_path, name))
    elif frame_process == 'mask':
        for name in tqdm(os.listdir(frame_path), desc='Frame masking..'):
            if name in detected:
                continue
            img = cv2.imread(os.path.join(frame_path, name))
            if img is None:
                # Not an image (e.g. a stray non-image file); nothing to mask.
                continue
            full = np.full(img.shape[:2], 255, dtype=np.uint8)
            cv2.imwrite(os.path.join(mask_path, _mask_name(name)), full)
    if colmap_process:
        get_colmap_masks(mask_path)
    if co3d_threshold is not None:
        get_co3d_masks(mask_path, co3d_threshold)
    print('Masks processing: done')


def _clear_directory(path):
    """Create ``path`` if needed and delete every file directly inside it."""
    os.makedirs(path, exist_ok=True)
    for name in os.listdir(path):
        os.remove(os.path.join(path, name))


def _largest_box(objects, coco_class):
    """Return the integer box of the largest-area object of ``coco_class``, or None."""
    best_coords = None
    best_area = -1
    for obj in objects:
        if obj[0] != coco_class:
            continue
        coords = np.array(obj[1], dtype='int')
        area = abs(coords[1] - coords[3]) * abs(coords[0] - coords[2])
        # Strict '>' keeps the first of equally large boxes.
        if area > best_area:
            best_coords = coords
            best_area = area
    return best_coords


def _mask_name(frame_file):
    """Return the mask file name for a frame: its basename with 'frame' replaced by 'mask'."""
    return os.path.basename(frame_file).replace('frame', 'mask')

In [4]:
# Sample/category name -> class id, presumably the value passed to get_masks
# as ``coco_class``.
# NOTE(review): the ids look like 0-based indices into the 80-class COCO label
# list (45 bowl, 49 orange, 56 chair, 75 vase, 77 teddy bear). 'snow_auto'
# reuses 77. Confirm against the labels detect_objects reports.
coco_classes = dict(
    orange=49,
    bowl=45,
    teddybear=77,
    chair_auto=56,
    chair_manual=56,
    snow_auto=77,
    vase_auto=75,
)

In [10]:
def get_specific_result(data_type, dst, results=None):
    """Extract one reconstruction type from an SSIM summary and save it as JSON.

    ``results`` has the layout produced by get_ssim_json::

        {folder: {param: {'index': {type: name}, 'best_result': {type: score}}}}

    The output keeps the folder / param / field levels but replaces each
    innermost dict by its ``data_type`` entry (an empty dict if absent).

    Parameters:
        data_type - type of modality to get specific results for
            ('overall', 'individual', 'video' or 'internet')
        dst - absolute or relative path to resulting .json file
        results - summary dict to filter. If omitted, a module-level
            ``results`` variable is used, as the original code did.
            Passing it explicitly is preferred.
    """
    if results is None:
        # Backward compatibility: the function used to read a global `results`.
        results = globals().get('results')
        if results is None:
            raise ValueError('get_specific_result: no results were given and no '
                             'module-level `results` exists')
    spec_results = {
        folder: {
            param: {field: by_type.get(data_type, {}) for field, by_type in fields.items()}
            for param, fields in params.items()
        }
        for folder, params in results.items()
    }
    with open(dst, 'w') as file:
        json.dump(spec_results, file, indent=2)


def get_ssim_json(src, folders, dst, specific_result=None, specific_dst=None):
    """Summarize per-sample SSIM logs into the best score per reconstruction type.

    For each folder and each parameter GEOM / CURV / NORM / COLOR, this reads
    ``<src>/<folder>/score_log_<PARAM>.json`` (NORM is read from the CURV log).
    For every reconstruction type it keeps the highest ``<param>AB`` score and
    the ``name`` of the entry that reached it:
    - 'overall' is taken over all entries.
    - 'individual', 'video' and 'internet' are taken only over entries whose
      name contains that word.
    Ties keep the first entry; types with no matching entry stay None.

    Parameters:
        src - absolute path to ssim results folder
        folders - prefixes of sample names (list)
        dst - absolute or relative path to resulting .json file
        specific_result - type of modality to get specific results for
        specific_dst - absolute or relative path to specific resulting .json file

    Returns:
        The summary dict that was written to ``dst``.
    """
    rec_types = ('overall', 'individual', 'video', 'internet')
    results = {}
    for folder in folders:
        results[folder] = {}
        for ssim_param in ('GEOM', 'CURV', 'NORM', 'COLOR'):
            # NOTE(review): NORM scores are read from the CURV log, as before;
            # presumably that log also stores 'normAB' - confirm.
            log_param = 'CURV' if ssim_param == 'NORM' else ssim_param
            file_path = os.path.join(src, folder, f'score_log_{log_param}.json')
            with open(file_path) as file:
                data = json.load(file)
            score_key = ssim_param.lower() + 'AB'
            best_result = dict.fromkeys(rec_types)
            best_result_index = dict.fromkeys(rec_types)
            for entry in data.values():
                score = entry['score'][score_key]
                name = entry['name']
                for rec_type in rec_types:
                    if rec_type != 'overall' and rec_type not in name:
                        continue
                    if best_result[rec_type] is None or score > best_result[rec_type]:
                        best_result[rec_type] = score
                        best_result_index[rec_type] = name
            results[folder][ssim_param] = {
                'index': best_result_index,
                'best_result': best_result
            }
    with open(dst, 'w') as file:
        json.dump(results, file, indent=2)
    if specific_result is not None and specific_dst is not None:
        # Pass the summary explicitly: `results` is local here, so the old
        # global lookup in get_specific_result raised NameError.
        get_specific_result(specific_result, specific_dst, results)
    return results

In [16]:
def get_plot(plot_data, folders, dst, linetypes, colors, data_type, param, plot_type='ssim',
             tick_labels=('None_None', 'None_On', 'Object_None', 'Object_On',
                          'Segment_None', 'Segment_On')):
    """Draw one line per folder, mark each line's maximum, and save to ``dst``.

    Parameters:
        plot_data - {folder: {'x': [...], 'y': [...]}}
        folders - folders to draw; zipped with ``linetypes`` and ``colors``
        dst - path of the image file to save
        linetypes - matplotlib format string per folder
        colors - line color per folder
        data_type, param - shown in the legend title
        plot_type - 'ssim' labels the y axis 'Score'; anything else labels it
            'Computation time'
        tick_labels - x tick labels, used when their count matches the
            number of x positions. Otherwise numeric ticks are kept.
    """
    text_param = 'Score' if plot_type == 'ssim' else 'Computation time'
    plt.style.use('ggplot')
    fig = plt.figure(figsize=(10, 6), tight_layout=True)
    try:
        ticks = []
        for folder, linetype, color in zip(folders, linetypes, colors):
            xs = plot_data[folder]['x']
            ys = plot_data[folder]['y']
            plt.plot(xs, ys, linetype, linewidth=2, label=folder, color=color)
            if len(ys):
                best = int(np.argmax(ys))
                # Mark the maximum at its x value, not at its list position.
                plt.plot(xs[best], ys[best], 'o', color='red')
            # As before, ticks follow the last drawn folder.
            ticks = xs
        if len(ticks) == len(tick_labels):
            plt.xticks(ticks, list(tick_labels))
        else:
            plt.xticks(ticks)
        plt.xlabel('Configs')
        plt.ylabel(text_param)
        plt.title(f'{text_param} by config')
        plt.legend(title=f'{data_type} {param}', loc='lower right')
        plt.grid(True)
        plt.savefig(dst, bbox_inches='tight')
    finally:
        # Close the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


def get_resulting_plots(src, folders, ssim_dst, linetypes, colors, param='GEOM', data_type='video', time_dst=None,
                        time_file=None):
    """Plot per-config SSIM scores and, optionally, computation times per folder.

    SSIM scores come from ``<src>/<folder>/score_log_<param>.json`` (NORM is
    read from the CURV log). Only entries whose ``name`` contains
    ``data_type`` are kept, in file order, with x = 0, 1, 2, ...

    Parameters:
        src - folder holding one sub-folder of score logs per sample
        folders - sample sub-folders to plot (one line each)
        ssim_dst - image path for the SSIM plot, or None to skip it
        linetypes - matplotlib format string per folder
        colors - line color per folder
        param - 'GEOM', 'CURV', 'NORM' or 'COLOR'
        data_type - substring selecting entries (and time-log types) to plot
        time_dst - image path for the timing plot (needs ``time_file``)
        time_file - JSON timing log; presumably
            {folder: {data type: {mask type: {vocab tree: time}}}}, as walked below

    Returns:
        The SSIM plot data {folder: {'x': [...], 'y': [...]}}.
    """
    # NOTE(review): NORM scores are read from the CURV log, as before;
    # presumably that log also stores 'normAB' - confirm.
    log_param = 'CURV' if param == 'NORM' else param
    score_key = param.lower() + 'AB'
    plot_data = {}
    for folder in folders:
        # Only the log for `param` is needed; the other logs are not opened.
        file_path = os.path.join(src, folder, f'score_log_{log_param}.json')
        with open(file_path) as file:
            data = json.load(file)
        ys = [entry['score'][score_key] for entry in data.values() if data_type in entry['name']]
        plot_data[folder] = {'x': list(range(len(ys))), 'y': ys}

    if ssim_dst is not None:
        get_plot(plot_data, folders, ssim_dst, linetypes, colors, data_type, param, plot_type='ssim')
    if time_dst is not None and time_file is not None:
        with open(time_file, 'r') as file:
            time_log = json.load(file)
        # Fresh containers, so folders absent from the time log are not
        # plotted with their SSIM scores.
        time_data = {folder: {'x': [], 'y': []} for folder in folders}
        for folder, by_type in time_log.items():
            if folder not in time_data:
                continue
            times = []
            for log_type, by_mask in by_type.items():
                # A distinct loop name: reusing `data_type` here made the
                # filter `data_type not in data_type`, which is never true.
                if data_type not in log_type:
                    continue
                for by_vocab in by_mask.values():
                    for total_time in by_vocab.values():
                        times.append(int(total_time))
            time_data[folder] = {'x': list(range(len(times))), 'y': times}
        get_plot(time_data, folders, time_dst, linetypes, colors, data_type, param, plot_type='time')
    return plot_data