In [None]:
# %load_ext nb_black
%load_ext autoreload
%autoreload 2

## Initialization

### Imports

In [None]:
import os
import sys
import cv2
import json
import glob
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import plotly.express as px

from tqdm.notebook import tqdm
from matplotlib import pyplot as plt

sys.path.append("../code/")

In [None]:
from params import *
from utils.rle import *

from data.dataset import load_image

from utils.metrics import dice_scores_img
from utils.plots import plot_heatmap_preds, plot_contours_preds

### Load

In [None]:
# Dataset metadata and the RLE ground-truth table ("train_4.csv", presumably the
# annotations matching the downscaled "_4" TIFFs) used throughout this notebook.
df_info = pd.read_csv(DATA_PATH + "HuBMAP-20-dataset_information.csv")
df_mask = pd.read_csv(DATA_PATH + "train_4.csv")

### Data

In [None]:
# Image folder and RLE table for the "_4" image set (presumably the downscaled
# TIFFs). NOTE(review): `rles` re-reads the same CSV as `df_mask`.
reduce_factor = 1
root = TIFF_PATH_4
rle_path = f"{DATA_PATH}train_4.csv"
rles = pd.read_csv(rle_path)

### Experiment

In [None]:
# Run directory under analysis (tagged "b1"). The trailing "/" matters:
# file names are appended to it by plain string concatenation.
log_folder = "/".join(["..", "logs", "2021-04-05", "4", ""])

In [None]:
class Config:
    """Attribute-style wrapper around keyword arguments.

    Every keyword passed to the constructor becomes an instance attribute,
    e.g. ``Config(lr=1e-3).lr == 1e-3``.
    """

    def __init__(self, **entries):
        # vars(self) is the instance __dict__; merge all entries into it at once.
        vars(self).update(entries)

# Load the run's config.json into an attribute-style Config object; the `with`
# block guarantees the file handle is closed once the JSON is parsed.
with open(log_folder + 'config.json', 'r') as config_file:
    config = Config(**json.load(config_file))

In [None]:
# Probability cutoff. NOTE(review): appears unused in the cells shown here;
# THRESHOLD further down holds the same value.
global_threshold: float = 0.4

In [None]:
# Saved prediction maps of the run (files matching pred_*.npy).
# glob returns paths in arbitrary, filesystem-dependent order; sort them so the
# index lookup and the per-image evaluation order are reproducible.
preds = sorted(glob.glob(log_folder + "pred_*.npy"))

In [None]:
preds  # display the prediction files found in the run directory

### Image, truth & pred

In [None]:
# Cutoff turning predicted probabilities into a binary (0/1) mask.
THRESHOLD: float = 0.4

In [None]:
# Image ids available for inspection (one of them is picked as `mask_name` below).
NAMES = """
    b9a3865fc aaa6a05cc e79de561c 8242609fa 2f6ecfcdf
    0486052bb 26dc41664 afa5e8098 54f2eec69 cb2d976f4
    4ef6695ce 095bf7a1f 1e2425f28 c68fe75ea b2dc8411c
""".split()

In [None]:
mask_name = "b9a3865fc"

# Position of this image's prediction file in `preds`. Match on the file name
# only (not the directory part of the path), and fail with an explicit message
# rather than a bare IndexError when the image has no prediction.
matching_idx = [i for i, path in enumerate(preds) if mask_name in os.path.basename(path)]
if not matching_idx:
    raise FileNotFoundError(f"No prediction file for {mask_name!r} in {log_folder}")
idx = matching_idx[0]

In [None]:
# Probability map of the selected image (as float32) and its 0/1 binarisation.
probs = np.asarray(np.load(preds[idx]), dtype=np.float32)
pred = np.greater(probs, THRESHOLD).astype(np.uint8)

In [None]:
# Load the selected image from the "_4" TIFF folder (full_size=False).
img_file = os.path.join(TIFF_PATH_4, f"{mask_name}.tiff")
img = load_image(img_file, full_size=False)

In [None]:
# Ground truth of the selected image: its RLE encoding(s) decoded to a mask.
# enc2mask receives (img.shape[1], img.shape[0]), i.e. the first two shape
# entries in reverse order.
rle = df_mask.loc[df_mask['id'] == mask_name, 'encoding']
mask = enc2mask(rle, img.shape[1::-1])

In [None]:
# Sanity check: display the shapes of ground truth, image and thresholded prediction.
mask.shape, img.shape, pred.shape

In [None]:
# Score (utils.metrics.dice_scores_img) of the thresholded prediction vs. ground truth.
score = dice_scores_img(pred, mask)
print(f'Score for downscaled image is {score:.4f}')

## Plot

In [None]:
# Figure size for the plotly plots: fixed width, height following the mask's
# aspect ratio (rows / columns).
w = 1_000
h = int(mask.shape[0] * w / mask.shape[1])

In [None]:
# Contours of ground truth vs. thresholded prediction, sized to (w, h).
fig = plot_contours_preds(img, mask, pred, w=2, downsize=4)
fig.update_layout(autosize=False, width=w, height=h)
fig.show()

In [None]:
# Heatmap of the raw probabilities over the image, sized to (w, h).
fig = plot_heatmap_preds(img, mask, probs, w=1, downsize=2)
fig.update_layout(autosize=False, width=w, height=h)
fig.show()

### Post-processing

In [None]:
def post_process_mask(probs, threshold_max=0.5, threshold_prob=0.4, threshold_comp=0.3, plot=True):
    """Filter a probability map by connected components.

    Pixels with ``probs > threshold_comp`` are grouped into 8-connected
    components. A component is kept only if its peak probability is strictly
    greater than ``threshold_max``; inside kept components the output is
    ``probs > threshold_prob``. All other pixels are 0.

    Args:
        probs: per-pixel probability map (single channel, as required by
            ``cv2.connectedComponents``).
        threshold_max: exclusive minimum peak probability for keeping a component.
        threshold_prob: per-pixel cutoff applied inside kept components.
        threshold_comp: cutoff used to build the connected components.
        plot: if True, show a histogram of component peak probabilities with
            ``threshold_max`` marked.

    Returns:
        uint8 array of 0/1 values with the same shape as ``probs``.
    """
    mask = (probs > threshold_comp).astype(np.uint8)
    num_component, components = cv2.connectedComponents(mask, connectivity=8)

    # Peak probability of every label in a single pass over the foreground
    # pixels, instead of one full-image `components == c` scan per component.
    foreground = components > 0
    peak_per_label = np.full(num_component, -np.inf)
    np.maximum.at(peak_per_label, components[foreground], probs[foreground].astype(np.float64))
    maxs = peak_per_label[1:]  # label 0 is the background

    # Lookup table label -> keep?; the background (label 0) is never kept.
    keep = np.concatenate(([False], maxs > threshold_max))
    removed = len(maxs) - int(np.count_nonzero(keep))
    print(f'Removed {removed} components.')

    processed_mask = (keep[components] & (probs > threshold_prob)).astype(np.uint8)

    if plot:
        fig, ax = plt.subplots(figsize=(15, 5))
        sns.histplot(maxs, bins=50, ax=ax)
        ax.axvline(threshold_max, color="salmon")
        ax.set(xlabel="component max probability", ylabel="count")
        plt.show()

    return processed_mask

In [None]:
# Post-processing hyper-parameters passed to post_process_mask:
#   THRESHOLD_MAX  - keep a component only if its peak probability exceeds this
#   THRESHOLD_COMP - probability cutoff used to build connected components
#   THRESHOLD_PROB - per-pixel cutoff applied inside kept components
THRESHOLD_MAX, THRESHOLD_COMP, THRESHOLD_PROB = 0.9, 0.4, 0.2

PLOT = False  # whether post_process_mask draws its peak-probability histogram

In [None]:
scores_before = []
scores_after = []

# Score every prediction of the run before and after post-processing.
# The loop variable is `pred_path` so the global `pred` mask used by the plot
# cells above is not overwritten with a file path.
for pred_path in preds:
    # Files are named pred_<image id>.npy; parse the id from the file name with
    # os.path so the path separator does not matter.
    mask_name = os.path.splitext(os.path.basename(pred_path))[0][len("pred_"):]
    print(f'\n  -> Mask {mask_name}')

    rle = df_mask[df_mask['id'] == mask_name]['encoding']
    img = load_image(os.path.join(TIFF_PATH_4, mask_name + ".tiff"), full_size=False)
    mask = enc2mask(rle, (img.shape[1], img.shape[0]))

    # Same loading / binarisation as the single-image cells above, so the
    # per-image "before" score matches the one reported there.
    probs = np.load(pred_path).astype(np.float32)
    pred_mask = (probs > THRESHOLD).astype(np.uint8)

    pred_pp = post_process_mask(
        probs,
        threshold_comp=THRESHOLD_COMP,
        threshold_max=THRESHOLD_MAX,
        threshold_prob=THRESHOLD_PROB,
        plot=PLOT,
    )

    scores_before.append(dice_scores_img(pred_mask, mask))
    scores_after.append(dice_scores_img(pred_pp, mask))

    print(f'Score before PP : {scores_before[-1] :.4f}')
    print(f'Score after PP :  {scores_after[-1] :.4f}')

In [None]:
# Mean per-image score over all predictions, without and with post-processing.
cv_before, cv_after = np.mean(scores_before), np.mean(scores_after)
print(f'CV before PP : {cv_before :.4f}')
print(f'CV after PP :  {cv_after :.4f}')