**About**: This notebook is used to perform inference on validation data.

In [None]:
# Optional nb_black extension, kept disabled.
# %load_ext nb_black
# Load the autoreload extension and set it to mode 2, so that imported modules
# are reloaded before each cell is executed.
%load_ext autoreload
%autoreload 2

In [None]:
cd ../src/

## Initialization

### Imports

In [None]:
import gc
import os
import ast
import sys
import cv2
import glob
import json
import torch
import warnings
import numpy as np
import pandas as pd
import seaborn as sns
import plotly.express as px
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm
from collections import Counter
warnings.simplefilter("ignore", UserWarning)

In [None]:
from params import *

from utils.plots import *
from utils.metrics import *
from utils.logger import Config, save_to_folder
from utils.rle import rle_encode, rle_decode

from inference.tweaking import *
from inference.validation import *
from inference.post_process import *

from data.preparation import prepare_data
from data.dataset import SartoriusDataset
from data.transforms import define_pipelines
from inference.validation import inference_val

## Single image exploration

In [None]:
# Sub-folders (relative to LOG_PATH) of the experiments that make up the ensemble.
# Commented-out entries are other experiments that can be switched back on; the
# trailing note on each line is the author's label for that experiment
# (presumably an index, the model name and a score).
_EXP_SUBFOLDERS = [  # id #3 - Cort 0.4083 / 0.4091 pp
    # "2021-12-11/2/",  # 1. Cascade b5 - 0.3121
    # "2021-12-11/4/",  # 2. Cascade rx101 - 0.3141
    # "2021-12-12/0/",  # 3. Cascade r50 - 0.3125
    "seb/mrcnn_resnext101_lossdecay/",  # 11. mrcnn r101 0.3131
    # "seb/mrcnn_r50_lossdecay/", # 12. mrcnn r50 0.3125
    # "2021-12-15/0/",  # 14. Cascade b6 - 0.3121
]
EXP_FOLDERS = [LOG_PATH + subfolder for subfolder in _EXP_SUBFOLDERS]

In [None]:
# With at most 6 experiments, every model is used for each cell type and for
# the "cls" selection.
# NOTE(review): for larger ensembles the four names below must already be
# defined elsewhere, otherwise the checks that follow raise a NameError.
if len(EXP_FOLDERS) <= 6:  # small ensemble
    EXP_FOLDERS_CORT = EXP_FOLDERS_ASTRO = EXP_FOLDERS_SHSY5Y = EXP_FOLDER_CLS = EXP_FOLDERS

# Each per-cell-type selection must contain only folders of EXP_FOLDERS, listed
# in the same order as in EXP_FOLDERS.
for selection in (EXP_FOLDERS_CORT, EXP_FOLDERS_ASTRO, EXP_FOLDERS_SHSY5Y):
    in_ensemble_order = [folder for folder in EXP_FOLDERS if folder in selection]
    assert in_ensemble_order == selection

In [None]:
# Inference settings shared by the whole ensemble.
# NOTE(review): the three-element lists presumably hold one value per class
# (num_classes is 3) — confirm the class order against the inference code.
ENSEMBLE_CONFIG = {  # best sub
    "use_tta": True,
    "use_tta_masks": True,
    "num_classes": 3,
    # Proposal (rpn_*) settings
    "rpn_nms_pre": [5000, 2000, 1000],
    "rpn_iou_threshold": [0.7, 0.75, 0.6],
    "rpn_score_threshold": [0.9, 0.9, 0.95],
    "rpn_max_per_img": [None, None, None],
    # Box (bbox_* / rcnn_*) settings
    "bbox_nms": True,
    "rcnn_iou_threshold": [0.7, 0.9, 0.6],
    "rcnn_score_threshold": [0.2, 0.25, 0.5],
}

# One boolean per entry of EXP_FOLDERS, telling whether that experiment belongs
# to the corresponding selection.
ENSEMBLE_CONFIG.update({
    key: [folder in selection for folder in EXP_FOLDERS]
    for key, selection in (
        ("use_for_cort", EXP_FOLDERS_CORT),
        ("use_for_astro", EXP_FOLDERS_ASTRO),
        ("use_for_shsy5y", EXP_FOLDERS_SHSY5Y),
        ("use_for_cls", EXP_FOLDER_CLS),
    )
})

In [None]:
# Load, for every experiment folder, its saved configuration and its checkpoint path(s).
configs, weights = [], []

for exp_folder in EXP_FOLDERS:
    # The context manager closes the file handle as soon as the JSON is parsed.
    with open(exp_folder + "config.json", 'r') as config_file:
        config = Config(json.load(config_file))

    # Re-point the model / data configs to the files of the same name inside the experiment folder.
    config.model_config = exp_folder + config.model_config.split('/')[-1]
    config.data_config = exp_folder + config.data_config.split('/')[-1]
    config.split = "skf"
    configs.append(config)

    # Only the first "*.pt" file (in sorted order) is kept; the commented line keeps all of them.
#     weights.append(sorted(glob.glob(exp_folder + "*.pt")))
    weights.append(sorted(glob.glob(exp_folder + "*.pt"))[:1])

## Inference

In [None]:
# Run single-image inference, then grid-search the post-processing thresholds.
# The selected values (threshold_mask, threshold_nms, threshold_conf, min_size)
# are left in the notebook namespace for the following cells.
for idx in range(1, 2):
    print(idx)
    df = prepare_data(fix=False, remove_anomalies=True)
    results_s, all_stuff, df_oof_s = inference_single(
        df, configs, weights, ENSEMBLE_CONFIG, idx=idx, cell_type="cort"
    )

    # Use the data config of the last loaded experiment explicitly, instead of
    # relying on the `config` variable leaked by the loading loop.
    pipelines = define_pipelines(configs[-1].data_config)
    dataset_s = SartoriusDataset(df_oof_s, transforms=pipelines['val_viz'])

    # Candidate values explored by the search.
    thresholds_mask = [0.45]
    thresholds_nms = [0.05, 0.1, 0.15]
    thresholds_conf = [0.65]
    min_sizes = [0]  # , 75]

    scores, cell_types = tweak_thresholds(
        results_s,
        dataset_s,
        thresholds_mask,
        thresholds_nms,
        thresholds_conf,
        min_sizes,
        remove_overlap=True,
        corrupt=False,
    )

    # Reset so that the `if min_size` check below never reads an undefined name
    # or a value left over from a previous run when no class has scores.
    min_size = None

    for c in range(len(CELL_TYPES)):
        scores_class = scores[c]

        # Skip classes whose score array is empty along its last axis.
        if scores_class.shape[-1]:
            scores_class = scores_class.mean(-2)
            # Named `best_idx` so that the sample index `idx` of the outer loop is not overwritten.
            best_idx = np.unravel_index(np.argmax(scores_class, axis=None), scores_class.shape)
            best_score = scores_class[best_idx]

            threshold_mask = thresholds_mask[best_idx[0]]
            threshold_nms = thresholds_nms[best_idx[1]]
            # NOTE(review): position 3 is used for the confidence threshold and position 2 for
            # the min size, whereas tweak_thresholds receives thresholds_conf before min_sizes.
            # Both lists currently hold a single value, so both positions are 0 — confirm the
            # axis order of `scores` before adding more candidates.
            threshold_conf = thresholds_conf[best_idx[3]]
            min_size = min_sizes[best_idx[2]]

            print(f"Best score {best_score:.4f}\n")
            print(f'- Threshold mask : {threshold_mask}')
            print(f'- Threshold nms  : {threshold_nms}')
            print(f'- Threshold conf : {threshold_conf}')
            print(f'- Min size : {min_size}\n')

    # Stop at the first sample for which a non-zero min size was selected.
    if min_size:
        break

In [None]:
# Post-process the raw results with the thresholds selected above.
# NOTE(review): the literal 0 in fifth position is presumably the minimum size — confirm
# against the signature of process_results.
post_processed = process_results(
    results_s,
    threshold_mask,
    threshold_nms,
    threshold_conf,
    0,
    remove_overlap=True,
    corrupt=False,
)
masks_pred, boxes_pred, cell_types = post_processed

In [None]:
# Evaluate the post-processed masks on dataset_s and report the mean score.
scores_single, _ = evaluate(masks_pred, dataset_s, cell_types)
mean_iou_map = np.mean(scores_single)

print(f' -> IoU mAP : {mean_iou_map:.4f}\n')

In [None]:
# NOTE(review): this cell repeats the evaluation of the preceding cell on the same inputs.
scores_single, _ = evaluate(masks_pred, dataset_s, cell_types)

iou_map_line = f' -> IoU mAP : {np.mean(scores_single):.4f}\n'
print(iou_map_line)

In [None]:
# Display the shape of the first entry of masks_pred.
masks_pred[0].shape

In [None]:
# Select one sample of the dataset and gather its image, ground truth and prediction.
idx = 0
data = dataset_s[idx]
img, boxes_truth = data['img'], data['gt_bboxes']

# Work on integer copies of the ground-truth and predicted masks.
truth = data['gt_masks'].masks.copy().astype(int)
pred = masks_pred[idx].copy().astype(int)

cell_type_name = CELL_TYPES[cell_types[idx]]
title = f'{cell_type_name} - iou_map={np.mean(scores_single):.3f}'

# Show the image with the predicted masks and boxes.
plt.figure(figsize=(15, 10))
plot_sample(img, mask=pred, boxes=boxes_pred[idx])
# plot_sample(img, mask=truth)
plt.title(title)
plt.axis(False)
plt.show()

In [None]:
# Plot the predictions against the ground truth for the selected sample,
# in a fixed-size 900x700 figure.
fig = plot_preds_iou(img, pred, truth, plot_tp=True)
fig.update_layout(autosize=False, width=900, height=700)
fig.show()

## Viz stuff

In [None]:
# Unpack the intermediate outputs returned by inference_single.
proposal_list, merged_bboxes, bboxes, aug_masks, masks = all_stuff

# Move the boxes and the first proposal set to the CPU as numpy arrays.
proposals = proposal_list[0].cpu().numpy()
merged_bboxes, bboxes = (boxes.cpu().numpy() for boxes in (merged_bboxes, bboxes))

In [None]:
# Maximum along dimension 0 of the first entry of proposal_list.
# Assuming a torch tensor (as the .cpu() call on it elsewhere suggests), max along a
# dimension returns (values, indices) and [0] keeps the values.
proposal_list[0].max(0)[0]

In [None]:
# Summary counts for the selected sample.
# print(f'Number of proposals : {[len(prop) for prop in aug_proposals[0]]}')
print(f'Number of merged proposals : {len(proposals)}')
print(f'Number of merged boxes : {len(merged_bboxes)}')

# Count the boxes whose value in column 4 exceeds each threshold 0.00, 0.05, ..., 0.20.
box_scores = bboxes[:, 4]
for step in range(5):
    score_th = 0.05 * step
    n_detected = (box_scores > score_th).sum()
    print(f'Number of detected boxes (th={score_th:.2f}): {n_detected}')

print()
print(f'Number of pred masks after pp : {len(pred)}')
print(f'Number of gt masks : {len(truth)}')

In [None]:
# Draw the merged boxes on the image, without any mask.
plt.figure(figsize=(15, 10))
plot_sample(img, mask=None, boxes=merged_bboxes)
plt.axis(False)
plt.show()

In [None]:
# For the proposals and for the final boxes, compute the best IoU reached for every
# ground-truth box, plot its distribution, and collect in `missed` the ground-truth
# boxes whose best IoU is below `threshold_hit`.
threshold_hit = 0.4

plt.figure(figsize=(15, 5))

missed = []
for i, preds in enumerate((proposals, bboxes)):
    # Keep the candidates with a strictly positive value in column 4. Computed once per
    # prediction set, since it does not depend on the ground-truth box.
    candidates = preds[preds[:, 4] > 0.]

    max_ious = []
    for b in boxes_truth:
        ious = [bbox_iou(b, prop) for prop in candidates]
        # np.max raises ValueError on an empty sequence: without any candidate the best IoU is 0.
        max_ious.append(np.max(ious) if ious else 0.0)

    max_ious = np.array(max_ious)
    missed.append(boxes_truth[(max_ious < threshold_hit)])

    # One histogram per prediction set, with the hit threshold drawn as a vertical line.
    plt.subplot(1, 2, i + 1)
    sns.histplot(max_ious, bins=20)
    plt.axvline(threshold_hit, c="salmon")
    t = 'proposals' if i == 0 else "bboxes"
    plt.title(t + f' - missed {len(missed[-1])}')

plt.show()

In [None]:
# m = (pred.astype(int).max(0) > 0)[..., None]
# img_m = img * (1 - m) + 129 * m

# plt.figure(figsize=(15, 15))
# plt.imshow(img_m)
# #         plot_sample(img * (1 - m) + 131 * m, masks.astype(int))
# plt.axis(False)
# plt.title(img_id)
# plt.show()


# cv2.imwrite('test.png', img_m)

In [None]:
# Plot the predictions against the ground truth in a fixed-size 900x700 figure.
# The commented-out arguments would pass the missed boxes collected in `missed`.
fig = plot_preds_iou(
    img, pred, truth,
    # boxes=missed[1],
    # boxes_2=missed[0],
    plot_tp=True,
)
fig.update_layout(autosize=False, width=900, height=700)
fig.show()

In [None]:
# masks_comp = pred.copy()

In [None]:
# merged, _, picks = mask_nms(
#     np.concatenate([pred > 0, masks_comp > 0]),
#     np.concatenate([np.ones((len(pred), 5)), 0.1 * np.ones((len(masks_comp), 5))]),
#     0.0
# )

# # merged = pred > 0

In [None]:
# merged = remove_overlap_naive(merged)
# merged = merged.astype(int)
# for i in range(len(merged)):
#     merged[i] *= (i + 1)

# merged = merged.max(0)

In [None]:
# iou_map([truth.max(0)], [merged])

In [None]:
# iou_map([truth.max(0)], [merged])

In [None]:
# fig = plot_preds_iou(
#     img,
#     merged.astype(int),
#     truth,
# #     boxes=missed[1],
# #     boxes_2=missed[0],
#     plot_tp=True)

# fig.update_layout(
#     autosize=False,
#     width=900,
#     height=700,
# )

# fig.show()

In [None]:
# plt.imshow(masks_comp.max(0))