**About**: This notebook runs inference on the validation data, tunes the post-processing thresholds, and visualizes the predictions.

In [None]:
# %load_ext nb_black
# Enable IPython's autoreload extension in mode 2, so that edited project modules
# are reloaded before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
cd ../src/

## Initialization

### Imports

In [None]:
import gc
import os
import ast
import sys
import cv2
import glob
import json
import torch
import warnings
import numpy as np
import pandas as pd
import seaborn as sns
import plotly.express as px
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm
from collections import Counter
warnings.simplefilter("ignore", UserWarning)

In [None]:
from params import *

from utils.plots import *
from utils.metrics import *
from inference.post_process import *
from utils.logger import Config
from inference.validation import *
from inference.tweaking import *

from data.preparation import prepare_data
from data.dataset import SartoriusDataset
from data.transforms import define_pipelines
from inference.validation import inference_val

## Experiments

In [None]:
# NOTE: EXP_FOLDERS is reassigned by the following cells; when the notebook is run
# top to bottom, this value is overwritten before the configs are loaded.
EXP_FOLDERS = [  # single models - LB 0.321
    LOG_PATH + "2021-11-10/10/",  # 0.3119 / 0.3081  \
    LOG_PATH + "2021-11-10/11/",  # 0.3109 / 0.3079  |-> 0.3153 / 0.3112
    LOG_PATH + "2021-11-10/12/",  # 0.3122 / 0.3091 /
]

# EXP_FOLDERS = [  # single models
#     LOG_PATH + "2021-11-10/16/",  # 0.3108 / 0.3075   \
#     LOG_PATH + "2021-11-10/15/",  # 0.3107 / 0.3077   |-> 0.3165 / 0.3133
#     LOG_PATH + "2021-11-10/19/",  # 0.3101 / 0.3074   |
#     LOG_PATH + "2021-11-10/20/",  # 0.3151 / 0.3116  /
# ]

In [None]:
# NOTE: EXP_FOLDERS is reassigned again by the following cells; when the notebook is
# run top to bottom, this value is overwritten before the configs are loaded.
EXP_FOLDERS = [  # single models - livecell (r50)
#     LOG_PATH + "2021-11-12/2/",  # 0.3151 / 0.3118   - pretrain
#     LOG_PATH + "2021-11-13/0/",  # 0.3130 / 0.3093   - 700 ext
    LOG_PATH + "2021-11-13/1/",  # 0.3141 / 0.3112   - schedule
    LOG_PATH + "2021-11-13/3/",  # 0.3149 / 0.3119   - schedule + pretrain
#     LOG_PATH + "2021-11-11/7/",  # 0.3111 / 0.3084
#     LOG_PATH + "2021-11-10/21/",  # 0.3118 / 0.3102

#     LOG_PATH + "2021-11-13/5/",  # 0.3130 / 0.3100   - schedule + single
#     LOG_PATH + "2021-11-15/1/",   # 0.3139 / 0.3097  - schedule + pretrain r101
]

In [None]:
# NOTE: EXP_FOLDERS is reassigned again by the following cell; when the notebook is
# run top to bottom, this value is overwritten before the configs are loaded.
EXP_FOLDERS = [
    LOG_PATH + "2021-11-13/1/",  # 0.3142 / 0.3129   - schedule
#     LOG_PATH + "2021-11-13/3/",  # 0.3150 / 0.3135   - schedule + pretrain
]

In [None]:
# EXP_FOLDERS = [  # 1st batch
#     LOG_PATH + "2021-11-10/21/",  # 0.3078 / 0.3042  \
#     LOG_PATH + "2021-11-11/0/",   # 0.3068 / 0.3040  |-> 0.3121x / 0.309x
#     LOG_PATH + "2021-11-11/1/",   # 0.3088 / 0.3045 /
# #     LOG_PATH + "2021-11-11/3/",  # 0.3084 / 0.3046
# #     LOG_PATH + "2021-11-11/7/",  # 0.3045 / 0.3012
# #     LOG_PATH + "2021-11-12/0/",  # 0.3077 / 0.3044
#     LOG_PATH + "2021-11-15/3/",  # 0.3077 / 0.3044
# ]

# This is the last EXP_FOLDERS assignment before the config-loading cell of the
# "Inference" section, i.e. the one used when the notebook is run top to bottom.
EXP_FOLDERS = [  # 2nd batch
#     LOG_PATH + "2021-11-15/3/",  # rx101 pretrain + extra - 0.3153
    LOG_PATH + "2021-11-16/0/",  # r50 pretrain + extra - 0.3105
#     LOG_PATH + "2021-11-16/3/",  # rx101 pretrain  - 0.3155
#     LOG_PATH + "2021-11-17/0/",  # r50 pretrain - 0.3102
]

## Inference

In [None]:
# Forwarded as `use_tta` to `inference_val` below.
USE_TTA = True

In [None]:
# Load the config of every experiment folder and collect its checkpoint paths.
configs, weights = [], []

for exp_folder in EXP_FOLDERS:
    # Context manager: the config file handle is closed as soon as it is parsed.
    with open(exp_folder + "config.json", 'r') as f:
        config = Config(json.load(f))

    # Re-root the model / data config paths into the experiment folder (file name kept).
    config.model_config = exp_folder + config.model_config.split('/')[-1]
    config.data_config = exp_folder + config.data_config.split('/')[-1]
    configs.append(config)

    # All checkpoints of the experiment, in sorted order.
    weights.append(sorted(glob.glob(exp_folder + "*.pt")))
#     weights.append(sorted(glob.glob(exp_folder + "*.pt"))[:1])

In [None]:
%%time
# Build the dataframe, then run validation inference with the loaded configs / weights.
# `all_results` and `dfs_val` are iterated in parallel further down
# (presumably one entry per validation fold -- confirm in `inference_val`).
df = prepare_data(fix=False)
all_results, dfs_val = inference_val(df, configs, weights, use_tta=USE_TTA)

## Evaluation

In [None]:
# Concatenate the per-split validation dataframes into a single dataframe.
df_oof = pd.concat(dfs_val).reset_index(drop=True)

# Read the data config from the last loaded experiment explicitly, instead of relying
# on the `config` variable leaked by the config-loading loop.
pipelines = define_pipelines(configs[-1].data_config)

In [None]:
# One dataset per validation dataframe, built with the 'val_viz' pipeline and
# without precomputing the masks.
datasets = [
    SartoriusDataset(
        fold_df,
        transforms=pipelines['val_viz'],
        precompute_masks=False,
    )
    for fold_df in dfs_val
]
# dataset = SartoriusDataset(df_oof, transforms=pipelines['val_viz'], precompute_masks=False)

In [None]:
# Default post-processing thresholds, indexed like CELL_TYPES (one value per class).
# The "Tweak thresholds" section overwrites these entries in place when it is run.
best_thresholds_mask = [0.45, 0.45, 0.45]
best_thresholds_nms = [0.45, 0.1, 0.05]
best_thresholds_conf = [0.35, 0.45, 0.7]

### Tweak thresholds

In [None]:
# Search grids for the threshold tuning.
thresholds_mask = [0.45]

# NMS thresholds: 0.05, 0.10, ..., 0.45
thresholds_nms = [np.round(step * 0.05, 2) for step in range(1, 10)]

# Confidence thresholds: 0.05, 0.10, ..., 0.85
thresholds_conf = [np.round(step * 0.05, 2) for step in range(1, 18)]

In [None]:
# Run the threshold grid search on every (dataset, results) pair.
all_scores = []
for dataset, results in zip(datasets, all_results):
    scores = tweak_thresholds(
        results, dataset, thresholds_mask, thresholds_nms, thresholds_conf, remove_overlap=True
    )
    all_scores.append(scores)
#     break

# For each cell type, join the score arrays of all pairs along axis 2.
scores_tweak = [
    np.concatenate([fold_scores[class_idx] for fold_scores in all_scores], axis=2)
    for class_idx in range(len(CELL_TYPES))
]

In [None]:
# For every cell type, pick the (mask, nms, conf) thresholds maximizing the mean score,
# store them in the `best_thresholds_*` lists, and report the class-weighted CV score.
best_scores = []
class_counts = []

for c in range(len(CELL_TYPES)):
    print(f' -> Cell type {CELL_TYPES[c]} : ')

    # Axis 2 is the one averaged out below; its length is used as the class weight.
    n_scored = scores_tweak[c].shape[2]
    if not n_scored:
        # Nothing to average for this class: keep its current thresholds.
        print("No score available for this cell type, thresholds left unchanged\n")
        continue

    scores_class = scores_tweak[c].mean(2)
    idx = np.unravel_index(np.argmax(scores_class, axis=None), scores_class.shape)
    best_score = scores_class[idx]
    best_scores.append(best_score)
    class_counts.append(n_scored)

    best_thresholds_c = (thresholds_mask[idx[0]], thresholds_nms[idx[1]], thresholds_conf[idx[2]])
    best_thresholds_mask[c] = best_thresholds_c[0]
    best_thresholds_nms[c] = best_thresholds_c[1]
    best_thresholds_conf[c] = best_thresholds_c[2]

    print(f"Best score {best_score:.4f} for thresholds (mask, nms, conf): {best_thresholds_c}\n")

# The weights are collected in the same loop as the scores, so they are aligned by
# construction. `Counter(df_oof['cell_type']).values()` is ordered by first appearance
# in the dataframe, which is not guaranteed to match the CELL_TYPES order.
best_score = np.average(best_scores, weights=class_counts)
print(f'CV score : {best_score:.4f}')

In [None]:
# for c in range(len(CELL_TYPES)):
#     print(f"\nClass {CELL_TYPES[c]}")
#     for idx_mask, threshold_mask in enumerate(thresholds_mask):
#         for idx_nms, threshold_nms in enumerate(thresholds_nms):
#             print(f"\n-> Threshold mask = {threshold_mask} - Threshold nms = {threshold_nms}")
#             try:
#                 for s, conf in zip(np.mean(scores_tweak[c][idx_mask, idx_nms], 0) , thresholds_conf):
#                     print(f"Threshold conf = {conf} - score = {s:.4f}")
#             except:
#                 break

In [None]:
# Print the selected thresholds as `NAME = [...]` lines.
for label, values in (
    ('THRESHOLDS_MASK', best_thresholds_mask),
    ('THRESHOLDS_NMS', best_thresholds_nms),
    ('THRESHOLDS_CONF', best_thresholds_conf),
):
    print(f'{label} = {values}')

In [None]:
# Force a garbage-collection pass; the returned count is shown as the cell output.
gc.collect()

### Score

In [None]:
# Score accumulators: three lists, filled from the sequence returned by `evaluate`
# (presumably one entry per cell type -- confirm in `evaluate`).
all_scores = [[] for _ in range(3)]

for results, dataset in zip(all_results, datasets):
    # Post-process the raw results with the selected thresholds.
    masks_pred, boxes_pred, cell_types = process_results(
        results,
        best_thresholds_mask,
        best_thresholds_nms,
        best_thresholds_conf,
        remove_overlap=True,
    )

    scores = evaluate(masks_pred, dataset, cell_types)

    for class_idx, class_scores in enumerate(scores):
        all_scores[class_idx] += class_scores

    # Release the predictions before processing the next pair.
    del masks_pred, boxes_pred, cell_types
    gc.collect()

In [None]:
# Overall score: mean over all the collected scores, all classes together.
score = np.mean(np.concatenate(all_scores))

print(f' -> IoU mAP : {score:.4f}\n')

# Per-class scores. Names and scores are paired BEFORE skipping empty classes:
# filtering first and zipping afterwards would shift the remaining scores onto
# the wrong class names.
scores_class = []
for c, class_scores in zip(CELL_TYPES, all_scores):
    if not len(class_scores):
        continue
    s = np.mean(class_scores)
    scores_class.append(s)
    print(f'{c} : {s:.4f}')

In [None]:
# Force a garbage-collection pass; the returned count is shown as the cell output.
gc.collect()

### Dice

In [None]:
# Build one merged mask per image (max over the instance axis) for predictions and
# ground truth. Every (results, dataset) pair is processed, instead of relying on
# the `results` / `dataset` variables leaked by the scoring loop, which only
# reference the last pair.
masks_preds, masks_truth = [], []
for fold_results, fold_dataset in zip(all_results, datasets):
    for result in fold_results:
        masks, _, _ = post_process_preds(
            result,
            thresholds_conf=best_thresholds_conf,
            thresholds_mask=best_thresholds_mask,
            remove_overlap=False
        )
        masks_preds.append(masks.max(0))

    masks_truth += [masks.masks.max(0) for masks in fold_dataset.masks]

In [None]:
# Dice score between the merged predicted masks and the merged ground-truth masks.
# NOTE(review): np.array(...) presumably requires all masks to share one shape -- confirm.
dice_score(np.array(masks_preds), np.array(masks_truth))

## Viz

In [None]:
# Only referenced by the commented-out box-size filter in the visualization cell below.
max_size = 1500

In [None]:
# TODO

# Plot predictions and ground truth; the trailing `break` stops after the first sample.
for idx in range(10):
    data = dataset[idx]

    img = data['img']
    boxes_truth = data['gt_bboxes']

    # preds
    masks, boxes, c = post_process_preds(
        results[idx], best_thresholds_conf, best_thresholds_mask, remove_overlap=False
    )

#     sizes = np.max([boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]], 0)
#     masks = masks[sizes < max_size]
#     boxes = boxes[sizes < max_size]

    # Score
    # Label maps: the k-th instance (1-based) is scaled by k in place, then the
    # instances are merged with a max over the first axis.
    truth = data['gt_masks'].masks.copy().astype(int)
    for label, instance in enumerate(truth, start=1):
        instance *= label
    truth = truth.max(0)

    pred = masks.copy().astype(int)
    for label, instance in enumerate(pred, start=1):
        instance *= label
    pred = pred.max(0)

    score = iou_map([truth], [pred])

    # Same figure layout for the prediction, then for the ground truth.
    title = f'{CELL_TYPES[c]} - iou_map={score:.3f}'
    for label_map, label_boxes in ((pred, boxes), (truth, boxes_truth)):
        plt.figure(figsize=(15, 15))
        plot_sample(img, label_map, label_boxes, plotly=False)
        plt.axis(False)
        plt.title(title)
        plt.show()

    break

In [None]:
# Figure comparing the last plotted prediction with its ground truth.
fig = plot_preds_iou(img, pred, truth, plot_tp=True)
fig.update_layout(autosize=False, width=900, height=700)
fig.show()

## Single image exploration

In [None]:
# Experiments used for the single-image exploration; replaces the EXP_FOLDERS value
# used by the validation section above.
EXP_FOLDERS = [  # 2nd batch
    LOG_PATH + "2021-11-15/3/",  # rx101 pretrain + extra - 0.3117
    LOG_PATH + "2021-11-16/0/",  # r50 pretrain + extra - 0.3105
    LOG_PATH + "2021-11-16/3/",  # rx101 pretrain  - 0.3129
    LOG_PATH + "2021-11-17/0/",  # r50 pretrain - 0.3102
]

In [None]:
# Load the config of every experiment folder and keep only its first checkpoint.
configs, weights = [], []

for exp_folder in EXP_FOLDERS:
    # Context manager: the config file handle is closed as soon as it is parsed.
    with open(exp_folder + "config.json", 'r') as f:
        config = Config(json.load(f))

    # Re-root the model / data config paths into the experiment folder (file name kept).
    config.model_config = exp_folder + config.model_config.split('/')[-1]
    config.data_config = exp_folder + config.data_config.split('/')[-1]
    configs.append(config)

#     weights.append(sorted(glob.glob(exp_folder + "*.pt")))
    # Only the first checkpoint (in sorted order) of each experiment is used here.
    weights.append(sorted(glob.glob(exp_folder + "*.pt"))[:1])

## Inference

In [None]:
# Forwarded as `use_tta` to `inference_single` below.
USE_TTA = True

In [None]:
# Rebuild the dataframe and run single-sample inference with all the loaded models.
# `idx=0` presumably selects the sample to process -- confirm in `inference_single`.
df = prepare_data(fix=False)
results_s, all_stuff, df_oof_s = inference_single(df, configs, weights, idx=0, use_tta=USE_TTA)

In [None]:
# Read the data config from the last loaded experiment explicitly, instead of relying
# on the `config` variable leaked by the config-loading loop.
pipelines = define_pipelines(configs[-1].data_config)
dataset_s = SartoriusDataset(df_oof_s, transforms=pipelines['val_viz'])

In [None]:
# Search grids for the threshold tuning.
thresholds_mask = [0.45]

# NMS thresholds: 0.05, 0.10, ..., 0.45
thresholds_nms = [np.round(step * 0.05, 2) for step in range(1, 10)]

# Confidence thresholds: 0.05, 0.10, ..., 0.85
thresholds_conf = [np.round(step * 0.05, 2) for step in range(1, 18)]

In [None]:
%%time
# Grid-search the (mask, nms, conf) thresholds on the single-sample results.
# `scores` is indexed by cell type in the next cell.
scores = tweak_thresholds(
    results_s,
    dataset_s,
    thresholds_mask,
    thresholds_nms,
    thresholds_conf,
    remove_overlap=True
)


In [None]:
# Report the best thresholds of every cell type that has scores.
# `idx` keeps the best grid position of the last such class and is used further down.
for c in range(len(CELL_TYPES)):
    scores_class = scores[c]

    # Nothing to average when axis 2 is empty.
    if not scores_class.shape[2]:
        continue

    scores_class = scores_class.mean(2)

    idx = np.unravel_index(scores_class.argmax(), scores_class.shape)
    best_score = scores_class[idx]

    print(f"Best score {best_score:.4f} for thresholds : ")
    print(f'Threshold mask : {thresholds_mask[idx[0]]}')
    print(f'Threshold nms  : {thresholds_nms[idx[1]]}')
    print(f'Threshold conf : {thresholds_conf[idx[2]]}')

In [None]:
# for c in range(len(CELL_TYPES)):
#     print(f"\nClass {CELL_TYPES[c]}")
#     for idx_mask, threshold_mask in enumerate(thresholds_mask):
#         for idx_nms, threshold_nms in enumerate(thresholds_nms):
#             print(f"\n-> Threshold mask = {threshold_mask} - Threshold nms = {threshold_nms}")
#             try:
#                 for s, conf in zip(np.mean(scores[c][idx_mask][idx_nms], 0) , thresholds_conf):
#                     print(f"Threshold conf = {conf} - score = {s:.4f}")
#             except:
#                 break
#     break

In [None]:
# Best thresholds found by the grid search. They are stored under new names so that
# the search grids (`thresholds_mask`, `thresholds_nms`, `thresholds_conf`) stay lists:
# overwriting them with scalars made this cell, and the one printing the best
# thresholds, fail when re-run.
best_threshold_mask = thresholds_mask[idx[0]]
best_threshold_nms = thresholds_nms[idx[1]]
best_threshold_conf = thresholds_conf[idx[2]]

# NOTE(review): scalars are passed here, whereas the validation section passes
# per-class lists -- confirm that `process_results` accepts both.
masks_pred, boxes_pred, cell_types = process_results(
    results_s, best_threshold_mask, best_threshold_nms, best_threshold_conf, remove_overlap=True
)

In [None]:
# Score the post-processed single-sample predictions and print the mean over all scores.
scores_single = evaluate(masks_pred, dataset_s, cell_types)

mean_score = np.mean(np.concatenate(scores_single))
print(f' -> IoU mAP : {mean_score:.4f}\n')

In [None]:
# Plot the prediction of the first sample, with its score in the title.
idx = 0
data = dataset_s[idx]

img = data['img']
boxes_truth = data['gt_bboxes']

# Label maps: the k-th instance (1-based) is scaled by k in place, then the
# instances are merged with a max over the first axis.
truth = data['gt_masks'].masks.copy().astype(int)
for label, instance in enumerate(truth, start=1):
    instance *= label
truth = truth.max(0)

pred = remove_overlap_naive(masks_pred[idx].copy()).astype(int)
for label, instance in enumerate(pred, start=1):
    instance *= label
pred = pred.max(0)

s = iou_map([truth], [pred])

plt.figure(figsize=(15, 10))
plot_sample(img, mask=pred, boxes=boxes_pred[idx])
# plot_sample(img, mask=truth)
plt.title(f'{CELL_TYPES[cell_types[idx]]} - iou_map={s:.3f}')
plt.axis(False)
plt.show()

In [None]:
# Plot the prediction of the first sample, with its score in the title.
# NOTE: this cell repeats the previous one.
idx = 0
data = dataset_s[idx]

img = data['img']
boxes_truth = data['gt_bboxes']

# Label maps: the k-th instance (1-based) is scaled by k in place, then the
# instances are merged with a max over the first axis.
truth = data['gt_masks'].masks.copy().astype(int)
for label, instance in enumerate(truth, start=1):
    instance *= label
truth = truth.max(0)

pred = remove_overlap_naive(masks_pred[idx].copy()).astype(int)
for label, instance in enumerate(pred, start=1):
    instance *= label
pred = pred.max(0)

s = iou_map([truth], [pred])

plt.figure(figsize=(15, 10))
plot_sample(img, mask=pred, boxes=boxes_pred[idx])
# plot_sample(img, mask=truth)
plt.title(f'{CELL_TYPES[cell_types[idx]]} - iou_map={s:.3f}')
plt.axis(False)
plt.show()

## Viz stuff

In [None]:
# Unpack the intermediate outputs returned by `inference_single`.
(
    proposal_list, aug_proposals,
    bboxes, merged_bboxes, aug_bboxes,
    masks, merged_masks, aug_masks,
) = all_stuff

# Convert to numpy for plotting / indexing (the `.cpu()` call suggests torch tensors).
bboxes = bboxes.cpu().numpy()
proposals = proposal_list[0].cpu().numpy()

In [None]:
# Column 4 of `bboxes` is the value compared against the thresholds below.
box_scores = bboxes[:, 4]

print(f'Number of proposals : {[len(prop) for prop in aug_proposals[0]]}')
print(f'Number of merged proposals : {len(proposals)}')
print(f'Number of merged boxes : {len(merged_bboxes)}')
print(f'Number of detected boxes (th=0.): {(box_scores > 0.).sum()}')
print(f'Number of detected boxes (th=0.1): {(box_scores > 0.1).sum()}')
print(f'Number of detected boxes (th=0.2): {(box_scores > 0.2).sum()}')
print(f'Number of detected boxes (th=0.3): {(box_scores > 0.3).sum()}')
print(f'Number of detected masks: {len(masks)}')

In [None]:
# Draw the proposals as boxes on the image (no mask overlay).
plt.figure(figsize=(15, 10))
plot_sample(img, mask=None, boxes=proposals)
plt.axis(False)
plt.show()

In [None]:
# A ground-truth box counts as "missed" when its best IoU with the candidates is
# below `threshold_hit`.
threshold_hit = 0.4

plt.figure(figsize=(15, 5))

# missed[0]: truth boxes missed by the proposals; missed[1]: missed by the final boxes.
missed = []
for i, preds in enumerate((proposals, bboxes)):
    # Keep candidates whose column 4 (presumably the score) is positive.
    # The filter does not depend on the truth box, so it is computed once.
    candidates = preds[preds[:, 4] > 0.]

    max_ious = []
    for b in boxes_truth:
        ious = [bbox_iou(b, prop) for prop in candidates]
        # Without any candidate the truth box is unmatched: use an IoU of 0 instead of
        # letting np.max raise on an empty list.
        max_ious.append(np.max(ious) if len(ious) else 0.)

    max_ious = np.array(max_ious)
    missed.append(boxes_truth[max_ious < threshold_hit])

    plt.subplot(1, 2, i + 1)
    sns.histplot(max_ious, bins=20)
    plt.axvline(threshold_hit, c="salmon")
    t = 'proposals' if i == 0 else "bboxes"
    plt.title(t + f' - missed {len(missed[-1])}')

plt.show()

In [None]:
# Figure comparing the single-sample prediction with its ground truth.
fig = plot_preds_iou(
    img, pred, truth,
#     boxes=missed[1],
#     boxes_2=missed[0],
    plot_tp=True,
)
fig.update_layout(autosize=False, width=900, height=700)
fig.show()

In [None]:
# Figure comparing the single-sample prediction with its ground truth.
# NOTE: this cell repeats the previous one.
fig = plot_preds_iou(
    img, pred, truth,
#     boxes=missed[1],
#     boxes_2=missed[0],
    plot_tp=True,
)
fig.update_layout(autosize=False, width=900, height=700)
fig.show()