In [1]:
import numpy as np
import pandas as pd
from tqdm import tqdm
from PIL import Image
from glob import glob
import cv2
from scipy.sparse import csr_matrix, save_npz
from scipy.ndimage import label as measure_label
import matplotlib.pyplot as plt
plt.rcParams['image.origin'] = 'lower'

from collections import Counter
import gc
from concurrent.futures import ThreadPoolExecutor, as_completed

from ink_helpers import (load_image,seed_everything,
                         load_fragment, DiceLoss, FocalLoss)

In [2]:
def process_label_idx(split_idx):
    mask3D_labels, num_labels = measure_label(mask3D_binary_list[split_idx])
    component_sizes = np.bincount(mask3D_labels.flat)
    gc.collect()

    # remove small components
    def remove_label(label_idx):
        if component_sizes[label_idx] <= 10 or (
            component_sizes[label_idx] <= 100 
            and np.count_nonzero(mask3D_labels[:2]==label_idx) == 0
        ) or (
            component_sizes[label_idx] <= 1000
            and np.count_nonzero(mask3D_labels[:40]==label_idx) == 0
        ):
            mask3D_binary_list[split_idx][mask3D_labels==label_idx] = 0

    with tqdm(total=len(component_sizes)-1, mininterval=10, 
              desc=f'Remove small clusters in split {split_idx}') as pbar:
        with ThreadPoolExecutor(max_workers=20) as executor:
            futures = [executor.submit(remove_label, label_idx) 
                       for label_idx in range(1, len(component_sizes))]
            for future in as_completed(futures):
                pbar.update(1)

In [3]:
all_frag_ids = ['1', '2a', '2b', '2c', '3']
id2dir = {id:f'./frags/train_{id}' for id in all_frag_ids}

id2images,id2frag_mask,id2ink_mask = {},{},{}
for frag_id in tqdm(all_frag_ids):
    images,frag_mask,ink_mask = load_fragment(frag_id)
    id2images[frag_id] = images
    id2frag_mask[frag_id] = frag_mask
    id2ink_mask[frag_id] = ink_mask

100%|█████████████████████████████████████████████| 5/5 [01:13<00:00, 14.68s/it]


In [4]:
for frag_id in all_frag_ids:
    folder_prefix = f'train_{frag_id}'
    print(f'Working on Fragment {folder_prefix}')
    
    # Load the images 
    images = id2images[frag_id]
    frag_mask = id2frag_mask[frag_id]
    ink_mask = id2ink_mask[frag_id]
    
    mask3D = np.zeros_like(images, dtype=np.uint8)

    for slice_axis in range(2): # x and y
        for slice_idx in tqdm(
            range(images.shape[slice_axis+1]), mininterval=5, 
            desc=f"Scanning the {'h' if not slice_axis else 'w'} side"
        ):

            ### load vertical slice
            if slice_axis == 0: # h cut
                frag_cols = np.where(images[:, slice_idx, :].min(axis=0)>0)[0]
                if len(frag_cols) < 3: continue
                side_cut_img = images[:, slice_idx, frag_cols]
            else: # w cut
                frag_cols = np.where(images[:, :, slice_idx].min(axis=0)>0)[0]
                if len(frag_cols) < 3: continue
                side_cut_img = images[:, frag_cols, slice_idx]

            # normalization to increase contrast
            side_cut_img = ((side_cut_img-side_cut_img.min())/(side_cut_img.max()-side_cut_img.min())*255).astype(np.uint8)

            ### Type One: plain Ostu thresholded
            otsu_value, plain_otsu_thresh = cv2.threshold(side_cut_img, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
            otsu_value *= 0.98

            # Perform connected component analysis
            num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(plain_otsu_thresh)
            # Iterate over components and remove isolated bright points
            for i in range(1, num_labels): # skip background component (i.e. label 0)
                area = stats[i, cv2.CC_STAT_AREA]
                if area <= 3 or (centroids[i][1] > 30 and area < centroids[i][1] - 30):
                    labels[labels == i] = 0 # set component pixels to background
            remove_isolated_otsu_thresh = (labels>0).astype(bool).astype(np.uint8)

            ### Type Two: height_adjust_otsu_thresh
            top_surface_height = np.where(
                remove_isolated_otsu_thresh.mean(axis=1)>min(0.05, 0.5*remove_isolated_otsu_thresh[0].mean())
            )[0][-1]
            left_margin = max(min(20, len(frag_cols)//10), np.where(remove_isolated_otsu_thresh.sum(axis=0)>5)[0][0])
            right_margin = max(min(20, len(frag_cols)//10), len(frag_cols)-np.where(remove_isolated_otsu_thresh.sum(axis=0)>5)[0][-1])
            if len(frag_cols)-right_margin-left_margin <= 2:
                continue
            linspace_len = min(80, (len(frag_cols)-right_margin-left_margin)//5+2)
            linspace_arr = np.linspace(0, 1, linspace_len)

            height_adjust_otsu_thresh = np.zeros_like(side_cut_img, dtype=np.uint8)
            base_otsu_ratio = 0.9
            saturate_distance = 10
            for height in range(images.shape[0]):
                adjust_otsu_val = otsu_value * max(base_otsu_ratio, 
                                                   1-(1-base_otsu_ratio)/saturate_distance*(top_surface_height-height))
                adjust_otsu_val_rows = adjust_otsu_val * np.ones_like(height_adjust_otsu_thresh[height, :])
                adjust_otsu_val_rows[:left_margin] = max(otsu_value, adjust_otsu_val)
                adjust_otsu_val_rows[left_margin:left_margin+linspace_len] = (
                    max(otsu_value, adjust_otsu_val)*(1-linspace_arr) 
                    + adjust_otsu_val_rows[left_margin:left_margin+linspace_len]*linspace_arr
                )
                adjust_otsu_val_rows[-right_margin:] = max(otsu_value, adjust_otsu_val)
                adjust_otsu_val_rows[-right_margin-linspace_len:-right_margin] = (
                    adjust_otsu_val_rows[-right_margin-linspace_len:-right_margin]*(1-linspace_arr)
                    + max(otsu_value, adjust_otsu_val)*linspace_arr
                )

                height_adjust_otsu_thresh[height, :] = side_cut_img[height, :]>adjust_otsu_val_rows

            # Perform connected component analysis
            num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(height_adjust_otsu_thresh)

            # Iterate over components and remove isolated bright points
            for i in range(1, num_labels): # skip background component (i.e. label 0)
                area = stats[i, cv2.CC_STAT_AREA]
                if (area <= 3 
                    or area < 0.6*(centroids[i][1]-top_surface_height)
                    or area < 0.6*(left_margin-centroids[i][0])
                    or area < 0.6*(centroids[i][0]+right_margin-len(frag_cols))
                   ):
                    labels[labels == i] = 0 # set component pixels to background
            height_adjust_otsu_thresh = (labels>0).astype(bool).astype(np.uint8)

            ### Summary: add to 3D mask
            if slice_axis == 0: # h cut
                mask3D[:, slice_idx, frag_cols] += 2*remove_isolated_otsu_thresh + height_adjust_otsu_thresh
            else: # w cut
                mask3D[:, frag_cols, slice_idx] += 2*remove_isolated_otsu_thresh + height_adjust_otsu_thresh
                
                
    split_w_pos = [(mask3D.shape[2] // 2) * a for a in range(1,2)]

    mask3D_binary_list = np.split((mask3D >= 5).astype(np.uint8), split_w_pos, axis=2)
    print('split shapes: ', [a.shape for a in mask3D_binary_list])

    del mask3D, side_cut_img
    del plain_otsu_thresh, remove_isolated_otsu_thresh, height_adjust_otsu_thresh, labels
    gc.collect()
    
        
    # process each split_idx sequentially
    for split_idx in range(len(mask3D_binary_list)):
        process_label_idx(split_idx)

    mask3D_binary = np.concatenate(mask3D_binary_list, axis=2).astype(bool)
    del mask3D_binary_list
    gc.collect()

    
    np.save(f'./frags/{folder_prefix}/mask3D_binary.npy', mask3D_binary)
    del mask3D_binary
    gc.collect()

Working on Fragment train_1


Scanning the h side: 100%|█████████████████| 8181/8181 [00:37<00:00, 216.77it/s]
Scanning the w side: 100%|██████████████████| 6330/6330 [01:39<00:00, 63.51it/s]


split shapes:  [(65, 8181, 3165), (65, 8181, 3165)]


Remove small clusters in split 0: 100%|█| 30337/30337 [1:36:26<00:00,  5.24it/s]
Remove small clusters in split 1: 100%|█| 19762/19762 [1:02:48<00:00,  5.24it/s]


Working on Fragment train_2a


Scanning the h side: 100%|█████████████████| 6099/6099 [00:36<00:00, 167.50it/s]
Scanning the w side: 100%|██████████████████| 6903/6903 [01:38<00:00, 70.15it/s]


split shapes:  [(65, 6099, 3451), (65, 6099, 3452)]


Remove small clusters in split 0: 100%|█| 24810/24810 [1:03:07<00:00,  6.55it/s]
Remove small clusters in split 1: 100%|███| 17879/17879 [45:00<00:00,  6.62it/s]


Working on Fragment train_2b


Scanning the h side: 100%|██████████████████| 4500/4500 [01:22<00:00, 54.66it/s]
Scanning the w side: 100%|██████████████████| 9278/9278 [02:04<00:00, 74.51it/s]


split shapes:  [(65, 4500, 4639), (65, 4500, 4639)]


Remove small clusters in split 0: 100%|█| 45602/45602 [1:51:48<00:00,  6.80it/s]
Remove small clusters in split 1: 100%|█| 50318/50318 [2:06:24<00:00,  6.63it/s]


Working on Fragment train_2c


Scanning the h side: 100%|██████████████████| 4229/4229 [01:05<00:00, 64.09it/s]
Scanning the w side: 100%|██████████████████| 9504/9504 [01:39<00:00, 95.93it/s]


split shapes:  [(65, 4229, 4752), (65, 4229, 4752)]


Remove small clusters in split 0: 100%|█| 32429/32429 [1:17:01<00:00,  7.02it/s]
Remove small clusters in split 1: 100%|█| 41306/41306 [1:39:15<00:00,  6.94it/s]


Working on Fragment train_3


Scanning the h side: 100%|█████████████████| 7606/7606 [00:39<00:00, 194.36it/s]
Scanning the w side: 100%|██████████████████| 5249/5249 [01:30<00:00, 58.15it/s]


split shapes:  [(65, 7606, 2624), (65, 7606, 2625)]


Remove small clusters in split 0: 100%|███| 15351/15351 [36:36<00:00,  6.99it/s]
Remove small clusters in split 1: 100%|███| 20902/20902 [51:12<00:00,  6.80it/s]


## quick 3D mask 

In [6]:
for frag_id in all_frag_ids:
    folder_prefix = f'train_{frag_id}'
    print(f'Working on Fragment {folder_prefix}')
    
    # Load the images 
    images = id2images[frag_id]
    frag_mask = id2frag_mask[frag_id]
    ink_mask = id2ink_mask[frag_id]
    
    mask3D = np.zeros_like(images, dtype=np.uint8)

    for slice_axis in range(2): # x and y
        for slice_idx in tqdm(
            range(images.shape[slice_axis+1]), mininterval=5, 
            desc=f"Scanning the {'h' if not slice_axis else 'w'} side"
        ):

            ### load vertical slice
            if slice_axis == 0: # h cut
                frag_cols = np.where(images[:, slice_idx, :].min(axis=0)>0)[0]
                if len(frag_cols) < 3: continue
                side_cut_img = images[:, slice_idx, frag_cols]
            else: # w cut
                frag_cols = np.where(images[:, :, slice_idx].min(axis=0)>0)[0]
                if len(frag_cols) < 3: continue
                side_cut_img = images[:, frag_cols, slice_idx]

            ### Type One: plain Ostu thresholded
            otsu_value, plain_otsu_thresh = cv2.threshold(side_cut_img, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)

            # Perform connected component analysis
            num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(plain_otsu_thresh)
            # Iterate over components and remove isolated bright points
            for i in range(1, num_labels): # skip background component (i.e. label 0)
                area = stats[i, cv2.CC_STAT_AREA]
                if area <= 2 or (centroids[i][1] > 50 and area < centroids[i][1] - 50):
                    labels[labels == i] = 0 # set component pixels to background
            remove_isolated_otsu_thresh = (labels>0).astype(bool).astype(np.uint8)

            ### Type Two: height_adjust_otsu_thresh
            top_surface_height = np.where(
                remove_isolated_otsu_thresh.mean(axis=1)>min(0.05, 0.5*remove_isolated_otsu_thresh[0].mean())
            )[0][-1]
            left_margin = max(min(20, len(frag_cols)//10), np.where(remove_isolated_otsu_thresh.sum(axis=0)>5)[0][0])
            right_margin = max(min(20, len(frag_cols)//10), len(frag_cols)-np.where(remove_isolated_otsu_thresh.sum(axis=0)>5)[0][-1])
            if len(frag_cols)-right_margin-left_margin <= 2:
                continue
            linspace_len = min(80, (len(frag_cols)-right_margin-left_margin)//5+2)
            linspace_arr = np.linspace(0, 1, linspace_len)

            height_adjust_otsu_thresh = np.zeros_like(side_cut_img, dtype=np.uint8)
            base_otsu_ratio = 0.95
            saturate_distance = 10
            for height in range(images.shape[0]):
                adjust_otsu_val = otsu_value * max(base_otsu_ratio, 
                                                   1-(1-base_otsu_ratio)/saturate_distance*(top_surface_height-height))
                adjust_otsu_val_rows = adjust_otsu_val * np.ones_like(height_adjust_otsu_thresh[height, :])
                adjust_otsu_val_rows[:left_margin] = max(otsu_value, adjust_otsu_val)
                adjust_otsu_val_rows[left_margin:left_margin+linspace_len] = (
                    max(otsu_value, adjust_otsu_val)*(1-linspace_arr) 
                    + adjust_otsu_val_rows[left_margin:left_margin+linspace_len]*linspace_arr
                )
                adjust_otsu_val_rows[-right_margin:] = max(otsu_value, adjust_otsu_val)
                adjust_otsu_val_rows[-right_margin-linspace_len:-right_margin] = (
                    adjust_otsu_val_rows[-right_margin-linspace_len:-right_margin]*(1-linspace_arr)
                    + max(otsu_value, adjust_otsu_val)*linspace_arr
                )

                height_adjust_otsu_thresh[height, :] = side_cut_img[height, :]>adjust_otsu_val_rows

            # Perform connected component analysis
            num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(height_adjust_otsu_thresh)

            # Iterate over components and remove isolated bright points
            for i in range(1, num_labels): # skip background component (i.e. label 0)
                area = stats[i, cv2.CC_STAT_AREA]
                if (area <= 2 
                    or area < 0.7*(centroids[i][1]-top_surface_height)
                    or area < 0.7*(left_margin-centroids[i][0])
                    or area < 0.7*(centroids[i][0]+right_margin-len(frag_cols))
                   ):
                    labels[labels == i] = 0 # set component pixels to background
            height_adjust_otsu_thresh = (labels>0).astype(bool).astype(np.uint8)

            ### Summary: add to 3D mask
            if slice_axis == 0: # h cut
                mask3D[:, slice_idx, frag_cols] += 2*remove_isolated_otsu_thresh + height_adjust_otsu_thresh
            else: # w cut
                mask3D[:, frag_cols, slice_idx] += 2*remove_isolated_otsu_thresh + height_adjust_otsu_thresh
                

    # require two Otsu and at least one height adjested Otsu
    quick_mask3D_binary = (mask3D >= 5).astype(bool)
    del mask3D
    gc.collect()
    
    np.save(f'./frags/{folder_prefix}/quick_mask3D_binary.npy', quick_mask3D_binary)
    del quick_mask3D_binary
    gc.collect()

Working on Fragment train_1


Scanning the h side: 100%|█████████████████| 8181/8181 [00:33<00:00, 241.79it/s]
Scanning the w side: 100%|██████████████████| 6330/6330 [01:28<00:00, 71.13it/s]


Working on Fragment train_2a


Scanning the h side: 100%|█████████████████| 6099/6099 [00:30<00:00, 203.21it/s]
Scanning the w side: 100%|██████████████████| 6903/6903 [01:17<00:00, 89.09it/s]


Working on Fragment train_2b


Scanning the h side: 100%|██████████████████| 4500/4500 [01:06<00:00, 67.65it/s]
Scanning the w side: 100%|██████████████████| 9278/9278 [01:49<00:00, 84.47it/s]


Working on Fragment train_2c


Scanning the h side: 100%|██████████████████| 4229/4229 [00:53<00:00, 78.39it/s]
Scanning the w side: 100%|█████████████████| 9504/9504 [01:28<00:00, 107.87it/s]


Working on Fragment train_3


Scanning the h side: 100%|█████████████████| 7606/7606 [00:34<00:00, 221.30it/s]
Scanning the w side: 100%|██████████████████| 5249/5249 [01:21<00:00, 64.43it/s]
