In [1]:
# %load_ext nb_black
%load_ext autoreload
%autoreload 2

## Initialization

### Imports

In [2]:
import os
import sys
import cv2
import json
import glob
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import plotly.express as px

from tqdm.notebook import tqdm
from matplotlib import pyplot as plt

sys.path.append("../code/")

In [3]:
from params import *
from utils.rle import *

from data.dataset import load_image

from utils.metrics import dice_scores_img
from utils.plots import plot_heatmap_preds, plot_contours_preds

### Load

In [4]:
df_info = pd.read_csv(DATA_PATH + f"HuBMAP-20-dataset_information.csv")
df_mask = pd.read_csv(DATA_PATH + "train_4.csv")
# df = pd.read_csv(OUT_PATH + "df_images.csv")

### Data

In [5]:
root = TIFF_PATH_4
rle_path = DATA_PATH + "train_4.csv"
reduce_factor = 1
rles = pd.read_csv(rle_path)

### Experiment

In [6]:
log_folder = "../logs/2021-04-05/4/"  # b1

In [7]:
class Config:
    def __init__(self, **entries):
        self.__dict__.update(entries)

config = json.load(open(log_folder + 'config.json', 'r'))
config = Config(**config)

In [8]:
global_threshold = 0.4

In [9]:
preds = glob.glob(log_folder + "pred_*.npy")

In [10]:
preds

['../logs/2021-04-05/4/pred_b9a3865fc.npy',
 '../logs/2021-04-05/4/pred_aaa6a05cc.npy',
 '../logs/2021-04-05/4/pred_e79de561c.npy',
 '../logs/2021-04-05/4/pred_8242609fa.npy',
 '../logs/2021-04-05/4/pred_2f6ecfcdf.npy',
 '../logs/2021-04-05/4/pred_0486052bb.npy',
 '../logs/2021-04-05/4/pred_26dc41664.npy',
 '../logs/2021-04-05/4/pred_afa5e8098.npy',
 '../logs/2021-04-05/4/pred_54f2eec69.npy',
 '../logs/2021-04-05/4/pred_cb2d976f4.npy',
 '../logs/2021-04-05/4/pred_4ef6695ce.npy',
 '../logs/2021-04-05/4/pred_095bf7a1f.npy',
 '../logs/2021-04-05/4/pred_1e2425f28.npy',
 '../logs/2021-04-05/4/pred_c68fe75ea.npy',
 '../logs/2021-04-05/4/pred_b2dc8411c.npy']

### Image, truth & pred

In [11]:
THRESHOLD = 0.4

In [12]:
NAMES = [
    "b9a3865fc",
    "aaa6a05cc",
    "e79de561c",
    "8242609fa",
    "2f6ecfcdf",
    "0486052bb",
    "26dc41664",
    "afa5e8098",
    "54f2eec69",
    "cb2d976f4",
    "4ef6695ce",
    "095bf7a1f",
    "1e2425f28",
    "c68fe75ea",
    "b2dc8411c",
]

In [13]:
mask_name = "b9a3865fc"

idx = [i for i, path in enumerate(preds) if mask_name in path][0]

In [14]:
probs = np.load(preds[idx]).astype(np.float32)
pred = (probs > THRESHOLD).astype(np.uint8)

In [15]:
img = load_image(os.path.join(TIFF_PATH_4, mask_name + ".tiff"), full_size=False)

In [16]:
rle = df_mask[df_mask['id'] == mask_name]['encoding']
mask = enc2mask(rle, (img.shape[1], img.shape[0]))

In [17]:
mask.shape, img.shape, pred.shape

((7823, 10107), (7823, 10107, 3), (7823, 10107))

In [18]:
score = dice_scores_img(pred , mask)
print(f'Score for downscaled image is {score:.4f}')

Score for downscaled image is 0.9424


### Post-processing

In [49]:
def extract_components(probs, threshold=0.5, plot=True):
    x_coords = np.arange(probs.shape[0])
    y_coords = np.arange(probs.shape[1])
    
    mask = (probs > threshold).astype(np.uint8)
    num_component, components = cv2.connectedComponents(mask, connectivity=8)
    
    boxes = []
    
    for c in tqdm(range(1, num_component)):
        component = (components == c)
        
        x_comp = x_coords[component.sum(1) > 0][[0, -1]]
        y_comp = y_coords[component.sum(0) > 0][[0, -1]]
        
        boxes.append([x_comp[0], x_comp[1] + 1, y_comp[0], y_comp[1] + 1])
        
        if plot:
            plt.figure(figsize=(5, 5))
            plt.imshow(components[boxes[-1][0]: boxes[-1][1], boxes[-1][2]: boxes[-1][3]])
            plt.show()
    
    print(f'Found {len(boxes)} candidates')
    return boxes

In [50]:
THRESHOLD = 0.1
PLOT = False

In [51]:
boxes_dic = {}

for idx, pred in enumerate(preds):
    mask_name = pred.split('/')[-1].split('_')[1][:-4]
    print(f'\n  -> Mask {mask_name}')
    
#     rle = df_mask[df_mask['id'] == mask_name]['encoding']
#     img = load_image(os.path.join(TIFF_PATH_4, mask_name + ".tiff"), full_size=False)
#     mask = enc2mask(rle, (img.shape[1], img.shape[0]))
    
    probs = np.load(pred)

    boxes = extract_components(
        probs, 
        threshold=THRESHOLD,
        plot=PLOT,
    )
    
    boxes_dic[mask_name] = boxes


  -> Mask b9a3865fc


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=490.0), HTML(value='')))


Found 490 candidates

  -> Mask aaa6a05cc


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=117.0), HTML(value='')))


Found 117 candidates

  -> Mask e79de561c


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=191.0), HTML(value='')))


Found 191 candidates

  -> Mask 8242609fa


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=624.0), HTML(value='')))


Found 624 candidates

  -> Mask 2f6ecfcdf


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=174.0), HTML(value='')))


Found 174 candidates

  -> Mask 0486052bb


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=135.0), HTML(value='')))


Found 135 candidates

  -> Mask 26dc41664


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=255.0), HTML(value='')))


Found 255 candidates

  -> Mask afa5e8098


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=259.0), HTML(value='')))


Found 259 candidates

  -> Mask 54f2eec69


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=143.0), HTML(value='')))


Found 143 candidates

  -> Mask cb2d976f4


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=354.0), HTML(value='')))


Found 354 candidates

  -> Mask 4ef6695ce


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=462.0), HTML(value='')))


Found 462 candidates

  -> Mask 095bf7a1f


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=352.0), HTML(value='')))


Found 352 candidates

  -> Mask 1e2425f28


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=207.0), HTML(value='')))


Found 207 candidates

  -> Mask c68fe75ea


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=126.0), HTML(value='')))


Found 126 candidates

  -> Mask b2dc8411c


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=132.0), HTML(value='')))


Found 132 candidates


In [66]:
for k in boxes_dic.keys():
    boxes_dic[k] = np.array(boxes_dic[k]).astype(int).tolist()

In [71]:
with open(log_folder + f"boxes_{str(THRESHOLD)[-1]}.json", "w") as f:
    json.dump(boxes_dic, f)
    
print(f'Saved boxes to {log_folder + f"boxes_{str(THRESHOLD)[-1]}.json"}')

Saved boxes to ../logs/2021-04-05/4/boxes_1.json
