In [None]:
import os
import sys
import re
from copy import deepcopy
from typing import List, Dict, Tuple
from collections import Counter, OrderedDict
from glob import glob
import json
import toml
from math import floor

from tqdm.auto import tqdm
import cv2
from PIL import Image
import numpy as np

sys.path.append("./modules/") # add path to scan customized module
from fileop import create_new_dir
from gallery_utils import draw_x_on_image, draw_predict_ans_on_image
from datasetop import sortFishNameForDataset
import plt_show

# print("="*100, "\n")

Load `make_cam_gallery.toml`

In [None]:
with open("make_cam_gallery.toml", mode="r") as f_reader:
    config = toml.load(f_reader)

# -- gallery layout
column = config["layout"]["column"]

# -- 'drop' images: the cross drawn over them
_drop_line_cfg = config["draw"]["drop_image"]["line"]
line_color = _drop_line_cfg["color"]
line_width = _drop_line_cfg["width"]

# -- CAM overlay images
_cam_cfg = config["draw"]["cam_image"]
cam_weight = _cam_cfg["weight"]
replace_cam_color = _cam_cfg["replace_color"]["enable"]
replaced_colormap  = getattr(cv2, _cam_cfg["replace_color"]["colormap"])  # e.g. name of a cv2.COLORMAP_* constant
_text_cfg = _cam_cfg["text"]
text_correct_color   = _text_cfg["color"]["correct"]
text_incorrect_color = _text_cfg["color"]["incorrect"]
text_shadow_color    = _text_cfg["color"]["shadow"]
text_font_style      = _text_cfg["font_style"]
# TOML has no null value, so "None" (do auto-detection) is requested by omitting the key;
# indexing with [...] here would raise KeyError instead.
text_font_size       = _text_cfg.get("font_size")

# -- which trained model's history to read
_model_cfg = config["model"]
load_dir_root = _model_cfg["history_root"]
model_name    = _model_cfg["model_name"]
model_history = _model_cfg["history"]
model_history = config["model"]["history"]

Load `train_config.toml`

In [None]:
# model-history directory, and the training config that was saved inside it
load_dir = os.path.join(load_dir_root, model_name, model_history)
train_config_path = os.path.join(load_dir, r"train_config.toml")

with open(train_config_path, mode="r") as f_reader:
    train_config = toml.load(f_reader)

# dataset variant the model was trained on (used to rebuild `dataset_dir` below)
_dataset_cfg = train_config["dataset"]
dataset_root, dataset_name, dataset_gen_method, dataset_stdev, dataset_param_name = (
    os.path.normpath(_dataset_cfg["root"]),
    _dataset_cfg["name"],
    _dataset_cfg["gen_method"],
    _dataset_cfg["stdev"],
    _dataset_cfg["param_name"],
)

Generate path variables, the rank (matching-ratio) directory names, and the gallery directories

In [None]:
dataset_dir = os.path.join(dataset_root, dataset_name, dataset_gen_method, dataset_stdev, dataset_param_name)

# test-set image roots; images are looked up as `<root>/<class>/<fish>_{selected,drop}_*.tiff`
test_selected_dir, test_drop_dir = (os.path.join(dataset_dir, "test", split_name)
                                    for split_name in ("selected", "drop"))

# CAM maps produced for this model, and the directory the gallery is written to
cam_result_root = os.path.join(load_dir, "cam_result")
cam_gallery_dir = os.path.join(load_dir, "!--- CAM Gallery")

# training-set summary; its keys are used as the per-class gallery directories
logs_path = os.path.join(dataset_dir, r"{Logs}_train_selected_summary.log")
with open(logs_path, 'r') as f_reader:
    class_counts: Dict[str, int] = json.load(f_reader)

# Gallery sub-directory name for each matching-percentage bucket 0, 10, ..., 100:
# buckets below 50 are tagged "(misMatch)", 100 is tagged "(Full)".
rank_dict = {}
for percent in range(0, 101, 10):
    if percent == 100:
        suffix = "_(Full)"
    elif percent < 50:
        suffix = "_(misMatch)"
    else:
        suffix = ""
    rank_dict[percent] = f"Match{percent}{suffix}"

# one directory per (class, matching bucket): `<cam_gallery_dir>/<class>/<rank dir>`
for cls_name in class_counts:
    for rank_dir_name in rank_dict.values():
        create_new_dir(os.path.join(cam_gallery_dir, cls_name, rank_dir_name), display_in_CLI=False)

In [None]:
# Read `predict_ans.log`: per-image prediction record, keyed by image name (file stem)

logs_path = os.path.join(load_dir, r"{Logs}_predict_ans.log")
with open(logs_path, 'r') as f_reader: 
    # each value is read as `record['gt']` / `record['pred']` (class names;
    # assumed to be str — TODO confirm against the writer of this log)
    predict_ans_dict: Dict[str, Dict[str, str]] = json.load(f_reader)

Run: build the CAM gallery for every fish in `cam_result_root`

In [None]:
def single_cam_gallery(fish_name_for_dataset:str, 
                       test_selected_dir:str, test_drop_dir:str, cam_result_root:str, pbar_n_fish:tqdm):
    
    pbar_n_fish.desc = f"Generate ' {fish_name_for_dataset} ' "
    pbar_n_fish.refresh()
    
    fish_name_for_dataset_split_list = re.split(" |_|-", fish_name_for_dataset)
    fish_cls = fish_name_for_dataset_split_list[0]
    
    test_selected_path_list = sorted(glob(os.path.normpath((f"{test_selected_dir}/{fish_cls}/"
                                                            f"{fish_name_for_dataset}_selected_*.tiff")))
                                     , key=sortFishNameForDataset)
    
    test_drop_path_list = sorted(glob(os.path.normpath((f"{test_drop_dir}/{fish_cls}/"
                                                        f"{fish_name_for_dataset}_drop_*.tiff")))
                                 , key=sortFishNameForDataset)

    if replace_cam_color:
        cam_result_path_list = sorted(glob(os.path.normpath(f"{cam_result_root}/{fish_name_for_dataset}/grayscale_map/*.tiff"))
                                      , key=sortFishNameForDataset)
    else:
        cam_result_path_list = sorted(glob(os.path.normpath(f"{cam_result_root}/{fish_name_for_dataset}/color_map/*.tiff"))
                                      , key=sortFishNameForDataset)
    
    # read images as Dict[path, cv2.Mat]
    test_selected_img_dict = { img_path: cv2.imread(img_path) for img_path in test_selected_path_list }
    test_drop_img_dict = { img_path: cv2.imread(img_path) for img_path in test_drop_path_list }
    cam_result_img_dict = { img_path: cv2.imread(img_path) for img_path in cam_result_path_list }
    
    
    # draw on 'drop' images
    for path, bgr_img in test_drop_img_dict.items():

        rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
        rgb_img = np.uint8(rgb_img * 0.5) # suppress brightness
        
        rgb_img = Image.fromarray(rgb_img)
        draw_x_on_image(rgb_img, line_color, line_width)

        test_drop_img_dict[path] = cv2.cvtColor(np.array(rgb_img), cv2.COLOR_RGB2BGR)
    
    
    # draw on `cam` images
    pred_cls_cnt = Counter()
    for (cam_path, cam_img), (selected_path, selected_bgr_img) in zip(cam_result_img_dict.items(), test_selected_img_dict.items()):
        
        selected_rgb_img = cv2.cvtColor(selected_bgr_img, cv2.COLOR_BGR2RGB)
        
        if replace_cam_color: cam_bgr_img = cv2.applyColorMap(cam_img, replaced_colormap) # BGR
        else: cam_bgr_img = cam_img
        cam_rgb_img = cv2.cvtColor(cam_bgr_img, cv2.COLOR_BGR2RGB)
            
        # overlay `cam` image on `selected` image
        cam_overlay = ((cam_rgb_img/255) * cam_weight + 
                       (selected_rgb_img/255) * (1-cam_weight))
        cam_overlay = np.uint8(255 * cam_overlay)
        
        # get param for `draw_predict_ans_on_image`
        selected_image_name = selected_path.split(os.sep)[-1].split(".")[0]
        gt_cls = predict_ans_dict[selected_image_name]['gt']
        pred_cls = predict_ans_dict[selected_image_name]['pred']
        pred_cls_cnt.update(pred_cls)
        if gt_cls != pred_cls:
            # create a red mask
            mask = np.zeros_like(cam_overlay) # black mask
            mask[:, :, 0] = 1 # modify to `red` mask
            mask_overlay = np.uint8(255 *((cam_overlay/255) * 0.7 + mask * 0.3)) # fusion with red mask
            # draw text
            rgb_img = Image.fromarray(mask_overlay)
            draw_predict_ans_on_image(rgb_img, pred_cls, gt_cls,
                                      text_font_style, text_font_size,
                                      text_correct_color,
                                      text_incorrect_color,
                                      text_shadow_color)
            cam_overlay = np.array(rgb_img)
        cam_result_img_dict[cam_path] = cv2.cvtColor(cam_overlay, cv2.COLOR_RGB2BGR)
    
    cls_matching_state = "TBA"
    matching_ratio = floor((pred_cls_cnt[fish_cls]/len(cam_result_path_list))*100)/100
    matching_ratio_percent = matching_ratio*100
    for key, value in rank_dict.items():
        if matching_ratio_percent >= key: cls_matching_state = value
    
    
    # concate dicts: `test_selected_img_dict`, `test_drop_img_dict`
    orig_img_dict = deepcopy(test_selected_img_dict)
    orig_img_dict.update(test_drop_img_dict)
    sorted_orig_img_dict = OrderedDict(sorted(list(orig_img_dict.items()), key=lambda x: sortFishNameForDataset(x[0])))
    orig_img_list = [ img for _, img in sorted_orig_img_dict.items() ]
    
    # plot with 'Auto Row Calculation'
    kwargs_plot_with_imglist_auto_row = {
        "img_list"   : orig_img_list,
        "column"     : column,
        "fig_dpi"    : 200,
        "figtitle"   : f"( original ) {fish_name_for_dataset} : {orig_img_list[-1].shape[:2]}",
        "save_path"  : f"{cam_gallery_dir}/{fish_cls}/{cls_matching_state}/{fish_name_for_dataset}_orig.png",
        "show_fig"   : False
    }
    plt_show.plot_with_imglist_auto_row(**kwargs_plot_with_imglist_auto_row)
    
    
    # concate dicts: `cam_result_img_dict`, `test_drop_img_dict`
    cam_overlay_img_dict = deepcopy(cam_result_img_dict)
    cam_overlay_img_dict.update(test_drop_img_dict)
    sorted_cam_overlay_img_dict = OrderedDict(sorted(list(cam_overlay_img_dict.items()), key=lambda x: sortFishNameForDataset(x[0])))
    cam_overlay_img_list = [ img for _, img in sorted_cam_overlay_img_dict.items() ]
    
    # plot with 'Auto Row Calculation'
    kwargs_plot_with_imglist_auto_row = {
        "img_list"   : cam_overlay_img_list,
        "column"     : column,
        "fig_dpi"    : 200,
        "figtitle"   : (f"( cam overlay ) {fish_name_for_dataset} : {cam_overlay_img_list[-1].shape[:2]}, "
                        f"correct : {pred_cls_cnt[fish_cls]}/{len(cam_result_path_list)} ({matching_ratio})") ,
        "save_path"  : f"{cam_gallery_dir}/{fish_cls}/{cls_matching_state}/{fish_name_for_dataset}_overlay.png",
        "show_fig"   : False
    }
    plt_show.plot_with_imglist_auto_row(**kwargs_plot_with_imglist_auto_row)
    
    
    pbar_n_fish.update(1)
    pbar_n_fish.refresh()

In [None]:
# every (non-hidden) entry of `cam_result_root` is treated as one fish to render, in name order
fish_name_for_dataset_list = sorted(os.path.basename(path)
                                    for path in glob(os.path.normpath(f"{cam_result_root}/*")))

# `with` closes the progress bar even if one fish raises
with tqdm(total=len(fish_name_for_dataset_list), desc="CAM Gallery ") as pbar_n_fish:
    for fish_name_for_dataset in fish_name_for_dataset_list:
        single_cam_gallery(fish_name_for_dataset, test_selected_dir, test_drop_dir, cam_result_root, pbar_n_fish)