In [1]:
import pickle
import glob
import os
from tqdm import tqdm
import numpy as np
import copy
import plotly.express as px

num_classes = 19
exp = 'images'

image_keys = ['image_np01', 'label', 'classes_certify', 'boundary_map', 'gt_adaptive_label']
def read_exp_dir(ds, exp='table'):
    exp_dir = f'/BS/mlcysec2/work/hierarchical-certification/log/{ds}/{exp}'
    overall_dict = {}
    for file in tqdm(glob.glob(os.path.join(exp_dir, '*.pkl'))[:30], desc=f'reading from {exp_dir}'):
        d = pickle.load(open(file, 'rb'))
        filename = os.path.basename(file).replace('.pkl', '')
        overall_dict[filename] = {}
        for image_name, image_d in d.items():
            for model_type, model_d in image_d.items():
                if model_type not in overall_dict:
                    overall_dict[filename][model_type] = {}
                for metric, value in model_d.items():
                    overall_dict[filename][model_type][metric] = value
                        
    return overall_dict

dir_dict = {}
for ds in ['cityscapes', 'acdc', 'cocostuff', 'pascal_ctx']:
    dir_dict[ds] = read_exp_dir(ds, exp)

reading from /BS/mlcysec2/work/hierarchical-certification/log/cityscapes/images:  70%|██████▉   | 348/500 [02:09<01:22,  1.85it/s]

In [47]:
def hex_to_rgb(h):
    if '#' in h:
        h = h.lstrip('#')
        return tuple(int(h[i:i+2], 16) for i in (0, 2, 4))
    else:
        return eval(h.replace('rgb', ''))
    
cs_color_dict = {'road': px.colors.qualitative.Dark24[5],  # street - dark gray
'sidewalk': px.colors.qualitative.Set1[8], # sidewalk - brighter gray
'building': px.colors.qualitative.Antique[5], # building - 
'wall': px.colors.qualitative.T10[9], # wall - 
'fence': px.colors.qualitative.Pastel1[6], # fence - 
'pole': px.colors.qualitative.Bold[4], # pole - 
'traffic light': px.colors.qualitative.Set2[5], # traffic light - 
'traffic sign': px.colors.qualitative.Set1[5], # traffic sign - 
'vegetation': px.colors.qualitative.Dark24[2], # vegetation
'terrain': px.colors.qualitative.G10[7], # terrain
'sky': px.colors.qualitative.Plotly[5], # sky
'person': px.colors.qualitative.Set1[0], # person
'rider': px.colors.qualitative.T10[6], # rider
'car': px.colors.qualitative.Dark24[9], # car
'truck': px.colors.qualitative.Alphabet[1], # truck
'bus': px.colors.qualitative.Dark24[19], # bus
'train': px.colors.qualitative.D3[9], # train
'motorcycle': px.colors.qualitative.Pastel[0], # motorcycle
'bicycle': px.colors.qualitative.Pastel[9], # bicycle
'construction & vegetation': px.colors.qualitative.Alphabet[5],
'traffic-sign': px.colors.qualitative.Set2[1],
'human':px.colors.qualitative.Vivid[9],
'vehicle': px.colors.qualitative.Alphabet[13],
'static obstacle': px.colors.qualitative.Pastel[5],
'dynamic obstacle': px.colors.qualitative.Set3[9],
'flat obstacle':px.colors.qualitative.Pastel[10],
'obstacle': px.colors.qualitative.T10[8] 
}
cs_color_pallette = np.array([hex_to_rgb(c) for c in cs_color_dict.values()]+[(255, 255, 255)]*(256-len(cs_color_dict.values())))

In [52]:
import cv2
graph_images_pth = '/BS/mlcysec2/work/hierarchical-certification/graph_images/images'

dataset_palette = {'cityscapes': cs_color_pallette}
for ds, ds_dict in dir_dict.items():
    for im_filename, im_dict in ds_dict.items():
        for model_type, model_d in im_dict.items():
            for k, v in model_d.items():
                if k == 'image_np01':
                    image_np01 = v
                if k == 'boundary_map':
                    boundary_map = v
                if k == 'label':
                    label = v
                if k == 'classes_certify':
                    if isinstance(model_type, tuple):
                        n, n0, f, hi, sigma, tau = model_type
                        if f == None:
                            seg_classes_certify = v
                        else:
                            ada_classes_certify = v
        outdir = os.path.join(graph_images_pth, im_filename.split('.')[0])
        w, h, c = image_np01.shape
        os.makedirs(outdir, exist_ok=True)
        cv2.imwrite(os.path.join(outdir, 'image_np01.png'), image_np01[...,::-1]*255)
        cv2.imwrite(os.path.join(outdir, 'boundary_map.png'), boundary_map.reshape((w, h))*255)
        cv2.imwrite(os.path.join(outdir, 'boundary_map.png'), (1-boundary_map).reshape((w, h))*255)
        boundary_idx = boundary_map == 1
        non_boundary_idx = boundary_map == 0
        ada_abstain_idx = ada_classes_certify == 254
        seg_abstain_idx = seg_classes_certify == 254

        zeros = np.zeros((w*h))
        zeros[seg_classes_certify == 254] = 255
        abstain_seg = zeros.reshape((w, h))
        cv2.imwrite(os.path.join(outdir, 'seg_abstain_map.png'), abstain_seg)
        
        zeros = np.zeros((w*h))
        zeros[ada_classes_certify == 254] = 255
        abstain_ada = zeros.reshape((w, h))
        cv2.imwrite(os.path.join(outdir, 'ada_abstain_map.png'), abstain_ada)

        zeros = np.zeros((w*h))
        zeros[ada_abstain_idx & boundary_idx] = 255
        cv2.imwrite(os.path.join(outdir, 'ada_inter_abstain_boundary.png'), zeros.reshape((w, h)))
        zeros = np.zeros((w*h))
        zeros[ada_abstain_idx & non_boundary_idx] = 255
        cv2.imwrite(os.path.join(outdir, 'ada_inter_abstain_non_boundary.png'), zeros.reshape((w, h)))
        
        zeros = np.zeros((w*h))
        zeros[seg_abstain_idx & boundary_idx] = 255
        cv2.imwrite(os.path.join(outdir, 'seg_inter_abstain_boundary.png'), zeros.reshape((w, h)))
        zeros = np.zeros((w*h))
        abstain_idx = seg_classes_certify == 254
        zeros[seg_abstain_idx & non_boundary_idx] = 255
        cv2.imwrite(os.path.join(outdir, 'seg_inter_abstain_non_boundary.png'), zeros.reshape((w, h)))   
        
        p = dataset_palette[ds]
        label = p[label].reshape((w, h, c))
        
        cv2.imwrite(os.path.join(outdir, 'label.png'), label[...,::-1])
        
        p = dataset_palette[ds]
        seg_classes_certify = p[seg_classes_certify].reshape((w, h, c))
        cv2.imwrite(os.path.join(outdir, 'seg_classes_certify.png'), seg_classes_certify[...,::-1])
        
        p = dataset_palette[ds]
        ada_classes_certify = p[ada_classes_certify].reshape((w, h, c))
        cv2.imwrite(os.path.join(outdir, 'ada_classes_certify.png'), ada_classes_certify[...,::-1])
            

        superimposed_gt = cv2.addWeighted(label.astype(np.uint), 0.6, (image_np01*255).astype(np.uint), 0.4, 0)
        superimposed_ada = cv2.addWeighted(ada_classes_certify.astype(np.uint), 0.6, (image_np01*255).astype(np.uint), 0.4, 0)
        superimposed_ada[ada_classes_certify == (255, 255, 255)] = 255
        superimposed_seg = cv2.addWeighted(seg_classes_certify.astype(np.uint), 0.6, (image_np01*255).astype(np.uint), 0.4, 0)
        superimposed_seg[seg_classes_certify == (255, 255, 255)] = 255
        
        cv2.imwrite(os.path.join(outdir, 'label_im.png'), superimposed_gt[...,::-1])
        cv2.imwrite(os.path.join(outdir, 'ada_im_classes_certify.png'), superimposed_ada[...,::-1])
        cv2.imwrite(os.path.join(outdir, 'seg_im_classes_certify.png'), superimposed_seg[...,::-1])

        print(image_np01.shape, boundary_map.shape, label.shape, seg_classes_certify.shape, ada_classes_certify.shape)

(1024, 2048, 3) (2097152,) (1024, 2048, 3) (1024, 2048, 3) (1024, 2048, 3)
(1024, 2048, 3) (2097152,) (1024, 2048, 3) (1024, 2048, 3) (1024, 2048, 3)
(1024, 2048, 3) (2097152,) (1024, 2048, 3) (1024, 2048, 3) (1024, 2048, 3)
(1024, 2048, 3) (2097152,) (1024, 2048, 3) (1024, 2048, 3) (1024, 2048, 3)
(1024, 2048, 3) (2097152,) (1024, 2048, 3) (1024, 2048, 3) (1024, 2048, 3)
(1024, 2048, 3) (2097152,) (1024, 2048, 3) (1024, 2048, 3) (1024, 2048, 3)
(1024, 2048, 3) (2097152,) (1024, 2048, 3) (1024, 2048, 3) (1024, 2048, 3)
(1024, 2048, 3) (2097152,) (1024, 2048, 3) (1024, 2048, 3) (1024, 2048, 3)
(1024, 2048, 3) (2097152,) (1024, 2048, 3) (1024, 2048, 3) (1024, 2048, 3)
(1024, 2048, 3) (2097152,) (1024, 2048, 3) (1024, 2048, 3) (1024, 2048, 3)
(1024, 2048, 3) (2097152,) (1024, 2048, 3) (1024, 2048, 3) (1024, 2048, 3)
(1024, 2048, 3) (2097152,) (1024, 2048, 3) (1024, 2048, 3) (1024, 2048, 3)
(1024, 2048, 3) (2097152,) (1024, 2048, 3) (1024, 2048, 3) (1024, 2048, 3)
(1024, 2048, 3) (2097152,