In [None]:
# %load_ext nb_black
%load_ext autoreload
%autoreload 2

## Initialization

### Imports

In [None]:
import os
import sys
import cv2
import json
import glob
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import plotly.express as px

from tqdm.notebook import tqdm
from matplotlib import pyplot as plt

sys.path.append("../code/")

In [None]:
from params import *
from utils.rle import *

from data.dataset import load_image

from utils.metrics import dice_scores_img
from utils.plots import plot_heatmap_preds, plot_contours_preds

### Load

In [None]:
# Dataset metadata, plus the ground-truth masks stored as rle encodings (decoded with `enc2mask` below).
df_info = pd.read_csv(DATA_PATH + "HuBMAP-20-dataset_information.csv")
df_mask = pd.read_csv(DATA_PATH + "train_4.csv")

### Data

In [None]:
# NOTE(review): `rles` re-reads the same file as `df_mask`, and `root` / `reduce_factor` / `rles`
# do not appear to be used further down in this notebook.
rle_path = DATA_PATH + "train_4.csv"
rles = pd.read_csv(rle_path)
root = TIFF_PATH_4
reduce_factor = 1

### Experiment

In [None]:
log_folder = "../logs/2021-04-05/4/"  # b1 -- experiment folder holding config.json and pred_*.npy

In [None]:
class Config:
    """Attribute-style access to a set of keyword arguments (here: the parsed config.json)."""

    def __init__(self, **entries):
        # Expose every entry as an instance attribute, e.g. Config(lr=1e-3).lr
        for key, value in entries.items():
            setattr(self, key, value)

# Load the experiment configuration; `with` guarantees the file handle is closed.
with open(log_folder + 'config.json', 'r') as f:
    config = Config(**json.load(f))

In [None]:
# NOTE(review): not referenced below; the cells use THRESHOLD instead.
global_threshold = 0.4

In [None]:
# Prediction probability maps of the experiment. glob returns files in arbitrary order,
# so sort them to make `preds` (and any index into it) reproducible.
preds = sorted(glob.glob(log_folder + "pred_*.npy"))

In [None]:
preds  # display the prediction files found in log_folder

### Image, truth & pred

In [None]:
# Binarisation threshold for the single-image check below (reassigned later before box extraction).
THRESHOLD = 0.4

In [None]:
# Image ids covered by this notebook.
NAMES = """
    b9a3865fc aaa6a05cc e79de561c 8242609fa 2f6ecfcdf
    0486052bb 26dc41664 afa5e8098 54f2eec69 cb2d976f4
    4ef6695ce 095bf7a1f 1e2425f28 c68fe75ea b2dc8411c
""".split()

In [None]:
mask_name = "b9a3865fc"

# Index of the prediction file for `mask_name`. Only the file name is matched (not the folder path),
# and a missing file gives an explicit error instead of a bare IndexError.
idx = next((i for i, path in enumerate(preds) if mask_name in os.path.basename(path)), None)
if idx is None:
    raise ValueError(f"No prediction file found for {mask_name} in {log_folder}")

In [None]:
# Probability map for the selected image and its binarisation (1 where probs > THRESHOLD).
probs = np.load(preds[idx]).astype(np.float32)
pred = np.where(probs > THRESHOLD, 1, 0).astype(np.uint8)

In [None]:
# Load the (not full-size) image matching `mask_name`.
img = load_image(
    os.path.join(TIFF_PATH_4, f"{mask_name}.tiff"),
    full_size=False,
)

In [None]:
# Decode the ground-truth rle of this image; img.shape[1::-1] is (img.shape[1], img.shape[0]).
rle = df_mask.loc[df_mask["id"] == mask_name, "encoding"]
mask = enc2mask(rle, img.shape[1::-1])

In [None]:
# Sanity check: the decoded mask and the prediction are expected to agree with the image's spatial size.
mask.shape, img.shape, pred.shape

In [None]:
# Dice between the thresholded prediction and the ground truth at the loaded (downscaled) resolution.
score = dice_scores_img(pred, mask)
print('Score for downscaled image is {:.4f}'.format(score))

### Get bounding boxes

In [None]:
def extract_components(probs, threshold=0.5, plot=True):
    """
    Extracts the bounding boxes of the connected components of a thresholded probability map.
    Args:
        probs (numpy array): 2D probability map (H, W).
        threshold (float, optional): Pixels with probs > threshold are foreground. Defaults to 0.5.
        plot (bool, optional): Whether to display every 100th component. Defaults to True.
    Returns:
        list of [int, int, int, int]: One box per component, as
            [row_start, row_end, col_start, col_end] with exclusive ends (usable as slices).
    """
    mask = (probs > threshold).astype(np.uint8)

    # connectedComponentsWithStats gives each component's bounding box directly, instead of
    # scanning the whole label image once per component (`components == c`).
    num_component, components, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)

    boxes = []
    for c in range(1, num_component):  # label 0 is the background
        top = int(stats[c, cv2.CC_STAT_TOP])
        left = int(stats[c, cv2.CC_STAT_LEFT])
        height = int(stats[c, cv2.CC_STAT_HEIGHT])
        width = int(stats[c, cv2.CC_STAT_WIDTH])

        boxes.append([top, top + height, left, left + width])

        if plot and not (c % 100):
            plt.figure(figsize=(5, 5))
            plt.imshow(components[top: top + height, left: left + width])
            plt.show()

    print(f'Found {len(boxes)} candidates')
    return boxes

In [None]:
# Settings for box extraction: a low threshold keeps weak candidates; PLOT shows every 100th one.
PLOT = True
THRESHOLD = 0.1

In [None]:
# Candidate boxes for every prediction file, keyed by image id.
boxes_dic = {}

# `pred_path` (not `pred`) so the binary prediction array used for scoring above is not overwritten.
for pred_path in preds:
    # File names match "pred_*.npy"; strip the prefix and the ".npy" suffix to get the image id
    # (assumes the id itself contains no '_').
    mask_name = os.path.basename(pred_path).split('_')[1][:-4]
    print(f'\n  -> Mask {mask_name}')

    probs = np.load(pred_path)

    boxes_dic[mask_name] = extract_components(
        probs,
        threshold=THRESHOLD,
        plot=PLOT,
    )
    

### Save crops (image, probabilities, mask) as `.npy` arrays

In [None]:
def extend_box(box, size=64):
    """
    Resizes a bounding box to `size` x `size` around its centre.
    Boxes smaller than `size` are grown, and larger ones are shrunk. When the size difference
    is odd, the extra pixel goes to the end side.
    Args:
        box (numpy array or list): [row_start, row_end, col_start, col_end], exclusive ends.
        size (int, optional): Target side length. Defaults to 64.
    Returns:
        numpy array: Resized box (int). It may extend outside the image; see adapt_to_shape.
    """
    def _grow(lo, hi):
        # Split the missing length between both sides: floor before, ceil after.
        margin = (size - (hi - lo)) / 2
        return lo - np.floor(margin), hi + np.ceil(margin)

    row_lo, row_hi = _grow(box[0], box[1])
    col_lo, col_hi = _grow(box[2], box[3])
    return np.array([row_lo, row_hi, col_lo, col_hi]).astype(int)


def adapt_to_shape(box, shape):
    """
    Shifts a bounding box so that it lies inside a given shape.
    The box keeps its size whenever it fits. It is clipped only when it is larger
    than the shape itself.
    Args:
        box (numpy array or list): [row_start, row_end, col_start, col_end], exclusive ends.
        shape (tuple): Shape (H, W, ...); only the first two entries are used.
    Returns:
        numpy array: Adapted bounding box. This is a new array; the input box is not modified.
    """
    box = np.array(box)  # copy, so the caller's box is left untouched

    for axis, (start, end) in enumerate(((0, 1), (2, 3))):
        limit = shape[axis]
        if box[start] < 0:
            # Slide the box forward so it starts at 0.
            box[end] -= box[start]
            box[start] = 0
        elif box[end] > limit:
            # Slide the box backward so it ends at the border.
            diff = box[end] - limit
            box[end] -= diff
            box[start] -= diff

        # Box larger than the image: clip it. A negative start would otherwise be read by
        # numpy slicing as an index from the end and give a wrong crop.
        box[start] = max(box[start], 0)
        box[end] = min(box[end], limit)

    return box

In [None]:
# boxes_dic = json.load(open(log_folder + f"boxes_{str(THRESHOLD)[-1]}.json", 'r'))

In [None]:
# Crop settings: side length of every crop, whether to write them to disk, whether to preview some.
SIZE, SAVE, PLOT = 192, False, True

In [None]:
# Output folder for the crops, named after the last character of THRESHOLD.
# NOTE(review): thresholds such as 0.5 and 0.05 map to the same folder name.
SAVE_DIR = f"{log_folder}boxes_{str(THRESHOLD)[-1]}/"

try:
    os.mkdir(SAVE_DIR)
except FileExistsError:
    print('Folder already exists !')

In [None]:
# For every image: cut a SIZE x SIZE crop around each candidate box and stack
# image / probability / ground-truth channels. Optionally save them and preview every 100th.
for pred_path in tqdm(preds):  # `pred_path` avoids overwriting the `pred` array used above
    mask_name = os.path.basename(pred_path).split('_')[1][:-4]
    print(f'\n  -> Mask {mask_name}')

    # Only images processed by the box-extraction cell have entries in boxes_dic.
    if mask_name not in boxes_dic:
        print(f'No boxes for {mask_name}, skipping')
        continue

    rle = df_mask[df_mask['id'] == mask_name]['encoding']
    img = load_image(os.path.join(TIFF_PATH_4, mask_name + ".tiff"), full_size=False)
    mask = enc2mask(rle, (img.shape[1], img.shape[0]))

    probs = np.load(pred_path)

    for i, box in enumerate(boxes_dic[mask_name]):
        # Resize the box to SIZE x SIZE, then slide it back inside the image.
        box = extend_box(box, SIZE)
        box = adapt_to_shape(box, img.shape)
        rows, cols = slice(box[0], box[1]), slice(box[2], box[3])

        crop_img = img[rows, cols]
        crop_prob = probs[rows, cols].astype(np.float32)
        crop_mask = mask[rows, cols]

        # Channels: image / 255, probability map, ground-truth mask.
        crop = np.concatenate([
            crop_img.astype(np.float32) / 255,
            crop_prob[:, :, None],
            crop_mask[:, :, None],
        ], -1)

        if SAVE:
            np.save(SAVE_DIR + f'{mask_name}_{i}.npy', crop)

        # Preview only; no `break` here, otherwise SAVE would write just the first crop.
        if PLOT and not (i % 100):
            fig, axes = plt.subplots(1, 3, figsize=(15, 5))
            for ax, (title, data) in zip(axes, [("image", crop_img), ("mask", crop_mask), ("prob", crop_prob)]):
                ax.imshow(data)
                ax.set_title(title)
            plt.show()
            plt.close(fig)

In [None]:
boxes_path = log_folder + f"boxes_{str(THRESHOLD)[-1]}.json"

# json.dump rejects numpy integer types, so convert every box to plain Python ints first.
for key, value in boxes_dic.items():
    boxes_dic[key] = np.array(value).astype(int).tolist()

with open(boxes_path, "w") as f:
    json.dump(boxes_dic, f)

print(f'Saved boxes to {boxes_path}')