In [None]:
import rp
from easydict import EasyDict
from source.peekaboo import make_image_square
from rp import ic
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format='retina'

In [None]:
# Root of the Peekaboo-on-RefMatte experiment outputs; the next cell collects the trial folders two levels below it
outputs_folder='untracked/dep_peekaboo_RefMatte'

In [None]:
# Trial folders sit two levels below outputs_folder (<outputs_folder>/<preset folder>/<trial>),
# so list the preset folders and flatten the trial folders found inside each of them.
result_folders = rp.list_flatten(
    [rp.get_all_folders(preset_folder) for preset_folder in rp.get_all_folders(outputs_folder)]
)

In [None]:
# Peek at the first trial folder (raises IndexError if no folders were found)
result_folders[0]

In [None]:
# How many trial folders were found
rp.ic(len(result_folders))

In [None]:
class Result:
    """
    One Peekaboo trial loaded from its output folder, plus the dataset inputs it was run on.

    ``path`` is a trial folder whose name is an integer (like ``000``, ``001``); the
    preset name is the last '.'-separated component of the parent folder's name.
    The trial folder is read for ``params.json``, ``image.png``, ``preview_image.png``
    and ``alphas/0.png``; the ground-truth mask and original image are read from
    ``DATASET_ROOT`` using the paths stored in ``params.extra_data.entry``.

    After construction the instance holds the loaded images (``alpha``, ``image``,
    ``preview_image``, ``original_image``, ``original_mask``, ``original_other_mask``)
    and the ground-truth masks squared and resized to ``alpha``'s dimensions
    (``mask``, ``other_mask``). Any error while reading params or images propagates.
    """

    # Folder that params.extra_data.entry.mask_path / image_path are relative to
    DATASET_ROOT = 'datasets/RefMatte_RW_100'

    # Mask files come in _1 / _2 pairs; maps a mask's suffix to its sibling's suffix
    _OTHER_MASK_SUFFIXES = {'_1.png': '_2.png', '_2.png': '_1.png'}

    def __init__(self, path):
        self.path = path

        self.trial_num = int(rp.get_folder_name(self.path))  # Like 000, 001 etc
        self.params = EasyDict(rp.load_json(rp.path_join(self.path, 'params.json')))
        self.prompt = self.params.extra_data.prompt
        self.name = self.params.extra_data.entry.image_name
        self.preset_name = rp.get_folder_name(rp.get_parent_folder(self.path)).split('.')[-1]
        self.square_image_method = self.params.extra_data.square_image_method

        entry = self.params.extra_data.entry
        self.alpha_path          = rp.path_join(self.path, 'alphas', '0.png')
        self.image_path          = rp.path_join(self.path, 'image.png')
        self.preview_image_path  = rp.path_join(self.path, 'preview_image.png')
        self.mask_path           = rp.path_join(self.DATASET_ROOT, entry.mask_path)
        self.original_image_path = rp.path_join(self.DATASET_ROOT, entry.image_path)
        self.mask_name = rp.get_file_name(self.mask_path, include_file_extension=False)
        # The "other" mask is the sibling file with the _1 / _2 suffix swapped
        self.other_mask_path     = rp.path_join(self.DATASET_ROOT, self._other_mask_relpath(entry.mask_path))

        self.original_mask       = rp.load_image(self.mask_path, use_cache=True)
        self.original_other_mask = rp.load_image(self.other_mask_path, use_cache=True)

        # Uncomment to swap the two ground-truth masks (Do the Ol' switcheroo! This is for testing the reverse...)
        # self.original_mask, self.original_other_mask = self.original_other_mask, self.original_mask

        self.original_image = rp.load_image(self.original_image_path, use_cache=True)
        self.alpha          = rp.load_image(self.alpha_path, use_cache=True)
        self.image          = rp.load_image(self.image_path, use_cache=True)
        self.preview_image  = rp.load_image(self.preview_image_path, use_cache=True)

        # Bring both ground-truth masks into alpha's frame: square them with this trial's
        # square_image_method, then resize them to alpha's dimensions.
        alpha_dims = rp.get_image_dimensions(self.alpha)
        self.mask       = rp.cv_resize_image(make_image_square(self.original_mask, self.square_image_method), alpha_dims)
        self.other_mask = rp.cv_resize_image(make_image_square(self.original_other_mask, self.square_image_method), alpha_dims)

        print(self)

    @classmethod
    def _other_mask_relpath(cls, mask_relpath):
        """
        Return the path of the sibling mask by swapping a trailing '_1.png' <-> '_2.png'.

        Only the end of the path is inspected, so directory names are never rewritten.
        Paths that end in neither suffix are returned unchanged.
        """
        for suffix, other_suffix in cls._OTHER_MASK_SUFFIXES.items():
            if mask_relpath.endswith(suffix):
                return mask_relpath[:-len(suffix)] + other_suffix
        return mask_relpath

    def __repr__(self):
        return 'Result(%s, %s, %s)' % (self.mask_name, self.trial_num, self.preset_name)

    @property
    def best_iou(self):
        """Best IoU of this trial's alpha against ``mask``, as computed by get_best_iou."""
        return get_best_iou(self)

    @property
    def best_other_iou(self):
        """Like best_iou, but scored against ``other_mask`` instead of ``mask``."""
        return get_best_iou(self, self.other_mask)
    
def try_load_result(path):
    """
    Load a Result from ``path``, or return None if loading fails for any reason.

    Some folders are expected to fail (e.g. a params.json without ``extra_data``
    raises AttributeError: 'EasyDict' object has no attribute 'extra_data'), so
    failures are reported in red and skipped instead of raised. The message names
    the folder so the offending trial can be located.
    """
    try:
        return Result(path)
    except Exception as e:
        rp.fansi_print('Failed to load result %r: %s: %s' % (path, type(e).__name__, e), 'red')
        return None

In [None]:
from rp import *
def get_mask_iou(*masks, empty_iou=1.0):
    """
    Calculates the IOU (intersection over union) of multiple binary masks.

    ``masks`` is first passed through rp.detuple; every mask is then converted with
    as_grayscale_image and as_binary_image before comparison.

    Args:
        *masks: images (per rp.is_image) that all have the same dimensions.
        empty_iou: returned when the union is empty (every mask is all zeros).
            Without it the result is 0/0 -> nan with a RuntimeWarning, and a nan
            makes max() over thresholds depend on iteration order.

    Returns:
        (#pixels set in every mask) / (#pixels set in any mask), or ``empty_iou``.
    """
    masks = detuple(masks)
    assert all(is_image(mask) for mask in masks), 'All masks must be images as defined by rp.is_image'
    shapes = set(get_image_dimensions(mask) for mask in masks)
    assert len(shapes) == 1, 'All masks must have the same dimensions, but got shapes ' + repr(shapes)
    masks = as_numpy_array([as_binary_image(as_grayscale_image(mask)) for mask in masks])
    intersection = np.min(masks, axis=0)  # set where every mask is set
    union = np.max(masks, axis=0)         # set where at least one mask is set
    union_area = np.sum(union)
    if union_area == 0:
        return empty_iou
    return np.sum(intersection) / union_area

def get_iou(result, threshold=.1, mask=None):
    """
    IoU between a ground-truth mask and ``result.alpha`` binarized at ``threshold``.

    ``mask`` falls back to ``result.mask`` when omitted. Alpha pixels strictly greater
    than ``threshold`` (after rp.as_float_image) count as foreground.
    """
    if mask is None:
        mask = result.mask
    predicted = rp.as_float_image(result.alpha) > threshold
    return get_mask_iou(mask, predicted)

@memoized
def get_best_iou(result, mask=None):
    """
    Highest get_iou score over the alpha thresholds 0.1, 0.2, ..., 0.9.

    Wrapped in rp's @memoized so repeated calls can reuse earlier results.
    """
    thresholds = [.1, .2, .3, .4, .5, .6, .7, .8, .9]
    scores = (get_iou(result, threshold, mask) for threshold in thresholds)
    return max(scores)

In [None]:
from tqdm import tqdm
import random
from collections import defaultdict
from typing import List

@memoized
def alpha_filter_1(alpha):
    """
    Heuristic cleanup of an alpha map (constants are hand-tuned; TODO document why).

    Steps:
      1. Normalize: subtract 0.56 * mean, then divide by std and by 2.1.
      2. Gaussian blur (10), then dilate -> erode (closing) with a circular radius-45 kernel.
      3. Gaussian blur (80), then erode -> dilate (opening) with a circular radius-9 kernel.

    Assumes ``alpha`` is a numpy image array (it uses .std() / .mean()).
    Returns the result as an rp float image.
    """
    std = alpha.std()
    if std == 0:
        # Constant alpha map: dividing by zero would turn every pixel into inf/nan
        std = 1
    normalized = (alpha - alpha.mean() * .56) / std / (2 + .1)

    # Blur, then close with a large circular kernel (an erode of R-6 instead of R was also tried)
    pred_img = rp.cv_gauss_blur(normalized, 10)
    R = 45
    pred_img = rp.cv_dilate(pred_img, R, circular=True)
    pred_img = rp.cv_erode(pred_img, R, circular=True)

    # Heavier blur, then open with a small circular kernel
    R = 9
    pred_img = rp.cv_gauss_blur(pred_img, 80)
    pred_img = rp.cv_erode(pred_img, R, circular=True)
    pred_img = rp.cv_dilate(pred_img, R, circular=True)
    return rp.as_float_image(pred_img)

@memoized
def alpha_filter_2(alpha):
    """
    Morphological closing of an alpha map: dilate then erode with a circular radius-40 kernel.

    Unlike alpha_filter_1 there is no normalization or blurring (variants with a
    normalize/blur prefix and a second blur + radius-10 erode/dilate pass were tried
    here and left disabled). Returns the result as an rp float image.
    """
    R = 40
    closed = rp.cv_dilate(alpha, R, circular=True)
    closed = rp.cv_erode(closed, R, circular=True)
    return rp.as_float_image(closed)


def chunkoozle(alpha, r=10, iter=1):
    """
    Repeatedly close an alpha map while never lowering any pixel.

    ``alpha`` is converted to a grayscale float image. Then, ``iter`` times, it is
    dilated and eroded (morphological closing) with a circular kernel of radius ``r``,
    and the pixelwise max of the closed image and the previous image is kept.
    (Name is a placeholder - "Bad name maybe".)
    """
    current = as_float_image(as_grayscale_image(alpha))
    for _ in range(iter):
        closed = as_float_image(cv_erode(cv_dilate(current, r, circular=True), r, circular=True))
        current = np.maximum(closed, current)
    return current

def _float_chunkoozle(alpha, r, iterations):
    """Run chunkoozle(alpha, r, iterations) and convert the result with rp.as_float_image."""
    return rp.as_float_image(chunkoozle(alpha, r, iterations))

@memoized
def alpha_filter_3(alpha):
    """chunkoozle preset: kernel radius 60, 2 iterations."""
    return _float_chunkoozle(alpha, 60, 2)

@memoized
def alpha_filter_4(alpha):
    """chunkoozle preset: kernel radius 90, 3 iterations."""
    return _float_chunkoozle(alpha, 90, 3)

@memoized
def alpha_filter_5(alpha):
    """chunkoozle preset: kernel radius 70, 1 iteration."""
    return _float_chunkoozle(alpha, 70, 1)


In [None]:
# Load every trial folder in parallel; folders that fail to load come back as None and are dropped
results = [
    result
    for result in rp.par_map(try_load_result, result_folders)
    if result is not None
]

In [None]:
# Trials that squared their input by scaling, ranked by best IoU (highest first).
# Filtering before sorting avoids computing IoUs for results that would be discarded;
# sorted() stays stable with reverse=True, so the order matches sort-then-filter.
best_results = sorted(
    (result for result in results if result.square_image_method == 'scale'),
    key=get_best_iou,
    reverse=True,
)

In [None]:
# Top 10 scaled results: best IoU, then ground-truth mask beside the preview image.
# Slicing (instead of indexing range(10)) avoids an IndexError when there are fewer than 10.
for bri in best_results[:10]:
    print(get_best_iou(bri))
    rp.display_image(rp.horizontally_concatenated_images(bri.mask, bri.preview_image))
    print()
    print()

In [None]:
# Worst 20 scaled results, worst first: best IoU, then ground-truth mask beside the preview image.
# Slicing avoids an IndexError when there are fewer than 20 results.
for bri in reversed(best_results[-20:]):
    print(get_best_iou(bri))
    rp.display_image(rp.horizontally_concatenated_images(bri.mask, bri.preview_image))
    print()
    print()

In [None]:
# For each scaled result: best IoU against its own mask (x) vs. against the other mask (y).
# To restrict to a single preset, additionally require e.g.
#   result.preset_name == 'midas_raster_bilateral_low_grav_bilat0rgb_sd15__200iter'
plot_points = [
    (result.best_iou, result.best_other_iou)
    for result in results
    if result.square_image_method == 'scale'
]
rp.scatter_plot(plot_points, xlabel='IOU', ylabel='Other IOU')

In [None]:
# Best IoU of every loaded result (all presets and square methods).
# (A bare `sorted(ious)` used to follow; its value was discarded and never displayed.)
ious = [get_best_iou(result) for result in results]
import matplotlib.pyplot as plt

def display_histogram(scores: list, bins: int = 10, title: str = 'Histogram of IOUs', xlabel: str = 'IOU'):
    """
    Show a histogram of scores over the range [0, 1].

    Args:
        scores: values to bin; values outside [0, 1] are ignored (matplotlib ``range``).
        bins: number of equal-width bins spanning [0, 1].
        title, xlabel: plot labels; the defaults match the IoU plots in this notebook.

    Draws on a fresh figure so it never lands on axes left open by an earlier plot.
    """
    fig, ax = plt.subplots()
    ax.hist(scores, bins=bins, range=(0, 1), edgecolor='black')
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Frequency')
    ax.set_title(title)
    plt.show()
    
# Distribution of best IoUs across all results, plus their median and mean
display_histogram(ious)
ic(rp.median(ious),rp.mean(ious))

In [None]:
# Group results by (preset, square method) and print the mean best IoU of each group
clusters=cluster_by_key(results, key=lambda result: (result.preset_name, result.square_image_method), as_dict=True)
rp.pretty_print({group: mean(member.best_iou for member in members) for group, members in clusters.items()})

In [None]:
# Same histogram and summary, restricted to one preset run with 'scale' squaring
cluster_ious=[
    member.best_iou
    for member in clusters[('midas_raster_bilateral_low_grav_bilat0rgb_sd15__200iter', 'scale')]
]
display_histogram(cluster_ious)
ic(rp.median(cluster_ious),rp.mean(cluster_ious))