**About**: This notebook runs inference on the validation data, tunes the post-processing thresholds, and visualizes the predictions.

In [None]:
# %load_ext nb_black
# Enable autoreload (mode 2) so that edits to imported modules are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Switch the working directory to ../src/ (presumably so that the local packages
# imported below resolve from there).
cd ../src/

## Initialization

### Imports

In [None]:
import gc
import os
import ast
import sys
import cv2
import glob
import json
import torch
import warnings
import numpy as np
import pandas as pd
import seaborn as sns
import plotly.express as px
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm
from collections import Counter
warnings.simplefilter("ignore", UserWarning)

In [None]:
from params import *

from utils.plots import *
from utils.metrics import *
from utils.logger import Config
from utils.rle import rle_encode, rle_decode

from inference.tweaking import *
from inference.validation import *
from inference.post_process import *

from data.preparation import prepare_data
from data.dataset import SartoriusDataset
from data.transforms import define_pipelines
from inference.validation import inference_val

## Experiments

In [None]:
# Ensemble "ENS_7": run directories, relative to LOG_PATH (original per-run notes kept).
# NOTE: EXP_FOLDERS is reassigned by the following cells, so on a top-to-bottom run
# this definition is overwritten before it is used.
ENS_7_RUNS = (
    "2021-12-11/2/",  # 1. Cascade b5 - 0.3127
    "2021-12-11/4/",  # 2. Cascade rx101 - 0.3141
    "2021-12-12/0/",  # 3. Cascade r50 - 0.3125
    "seb/mrcnn_resnext101_new_splits/",  # 7. maskrcnn rx101 - 0.3120
    "seb/mrcnn_resnet50_new_splits/",  # 8. maskrcnn r50 - 0.3118
)
EXP_FOLDERS = [LOG_PATH + run for run in ENS_7_RUNS]  # ENS_7

In [None]:
# Ensemble "ENS_8": run directories, relative to LOG_PATH (original per-run notes kept).
# NOTE: EXP_FOLDERS is reassigned by the following cell, so on a top-to-bottom run
# this definition is overwritten before it is used.
ENS_8_RUNS = (
    "2021-12-11/2/",  # 1. Cascade b5 - 0.3127
    "2021-12-11/4/",  # 2. Cascade rx101 - 0.3141
    "2021-12-12/0/",  # 3. Cascade r50 - 0.3125
    "seb/mrcnn_resnext101_lossdecay/",  # 11. mrcnn r101 0.3131
    "seb/mrcnn_r50_lossdecay/",  # 12. mrcnn r50 0.3125
    "2021-12-15/0/",  # 14. Cascade b6 - 0.3121
)
EXP_FOLDERS = [LOG_PATH + run for run in ENS_8_RUNS]  # ENS_8

In [None]:
# Active selection ("new folds - fix tta"): uncomment entries to add runs to the ensemble.
# This is the last EXP_FOLDERS assignment before the config-loading cell.
EXP_FOLDERS = [
    LOG_PATH + run
    for run in (
        # "2021-12-11/2/",  # 1. Cascade b5 - 0.3134
        # "2021-12-11/4/",  # 2. Cascade rx101 - 0.3154
        # "2021-12-12/0/",  # 3. Cascade r50 - 0.3133
        # "seb/mrcnn_resnext101_lossdecay/",  # 11. mrcnn r101 0.3131
        # "seb/mrcnn_r50_lossdecay/",  # 12. mrcnn r50 0.3125
        # "2021-12-15/0/",  # 14. Cascade b6 - 0.3121
        "2021-12-15/1/",  # 14. htc r50 - 0.3121
    )
]

## Inference

In [None]:
# Ensemble inference settings. The list-valued options hold three entries each
# (presumably one per class, since num_classes is 3 — confirm against the inference code).
# NOTE: the next cell reassigns ENSEMBLE_CONFIG, so on a top-to-bottom run this
# definition is overwritten before it is used.
ENSEMBLE_CONFIG = dict(
    use_tta=True,
    num_classes=3,
    # RPN stage
    rpn_nms_pre=[5000, 2000, 1000],
    rpn_iou_threshold=[0.7, 0.75, 0.6],
    rpn_score_threshold=[0.9, 0.9, 0.95],
    rpn_max_per_img=[None, None, None],  # [1500, 1000, 500],
    # RCNN stage
    bbox_nms=True,
    rcnn_iou_threshold=[0.7, 0.9, 0.6],
    rcnn_score_threshold=[0.2, 0.25, 0.5],
)

In [None]:
# Ensemble inference settings passed to inference_val. The list-valued options hold
# three entries each (presumably one per class, since num_classes is 3 — confirm).
ENSEMBLE_CONFIG = dict(
    use_tta=True,
    num_classes=3,
    # RPN stage
    rpn_nms_pre=[3000, 2000, 1000],
    rpn_iou_threshold=[0.75, 0.75, 0.6],
    rpn_score_threshold=[0.95, 0.9, 0.95],
    rpn_max_per_img=[None, None, None],  # [1500, 1000, 500],
    # RCNN stage
    bbox_nms=True,
    rcnn_iou_threshold=[0.75, 0.9, 0.6],
    rcnn_score_threshold=[0.2, 0.3, 0.5],
)

In [None]:
# Load one Config and one list of checkpoint paths per experiment folder.
configs, weights = [], []

for exp_folder in EXP_FOLDERS:
    # Use a context manager so the config file handle is closed deterministically.
    with open(exp_folder + "config.json", 'r') as config_file:
        config = Config(json.load(config_file))

    # Re-root the model/data config paths onto the experiment folder, keeping only the file name.
    config.model_config = exp_folder + config.model_config.split('/')[-1]
    config.data_config = exp_folder + config.data_config.split('/')[-1]

    # Configs may lack `split` and/or `remove_anomalies`. Each field gets its own
    # fallback: previously a single bare `except:` reset an existing `config.split`
    # to "skf" whenever only `remove_anomalies` was missing, and hid unrelated errors.
    try:
        _ = config.split
    except (AttributeError, KeyError):
        config.split = "skf"

    # NOTE: `remove_anomalies` keeps the value of the LAST experiment; it is the value
    # handed to prepare_data in the next cell.
    try:
        remove_anomalies = config.remove_anomalies
    except (AttributeError, KeyError):
        remove_anomalies = False

    configs.append(config)

    # All "*.pt" files of the folder, in sorted order.
    weights.append(sorted(glob.glob(exp_folder + "*.pt")))
#     weights.append(sorted(glob.glob(exp_folder + "*.pt"))[:1])

In [None]:
# Build the main dataframe. `remove_anomalies` is the value left over from the
# config-loading loop above (i.e. taken from the last experiment folder).
df = prepare_data(fix=False, remove_anomalies=remove_anomalies)

In [None]:
%%time
# Run validation inference for every loaded experiment. The two returned sequences
# are consumed fold-by-fold later (zipped with each other / concatenated into df_oof).
all_results, dfs_val = inference_val(df, configs, weights, ENSEMBLE_CONFIG)

## Evaluation

In [None]:
# Stack the per-fold validation frames into one out-of-fold frame;
# `class` is a plain copy of `cell_type`.
df_oof = pd.concat(dfs_val).reset_index(drop=True)
df_oof['class'] = df_oof['cell_type']

# `config` is the variable left over from the config-loading loop (last experiment).
pipelines = define_pipelines(config.data_config)

# One visualization dataset per validation fold, in the same order as dfs_val.
datasets = []
for df_val in dfs_val:
    datasets.append(
        SartoriusDataset(df_val, transforms=pipelines['val_viz'], precompute_masks=False)
    )

In [None]:
# Hand-set thresholds, three entries each. They are used by the "Score" section
# unless the "Tweak thresholds" cells are run, which reassign all three lists.
best_thresholds_mask = [0.45] * 3
best_thresholds_nms = [0.1, 0.1, 0.05]
best_thresholds_conf = [0.3, 0.4, 0.7]

In [None]:
# CLASSES = [
#     "astro[hippo]",
#     "astros[cereb]",
#     "cort[6-OHDA_oka-low]",
#     "cort[density_pre-treat]",
#     "cort[oka-high_debris]",
#     "shsy5y",
# ]

# def plate_to_class(plate):
#     mapping = {
#         "astro[hippo]": 0,
#         "astros[cereb]": 1,
#         "cort[6-OHDA]": 2,
#         "cort[debris]": 4,
#         "cort[density]": 3,
#         "cort[oka-high]": 4,
#         "cort[oka-low]": 2,
#         "cort[pre-treat]": 3,
#         "shsy5y[diff]": 5,
#     }
#     return mapping[plate]

# df_oof['class'] = df_oof['plate'].apply(plate_to_class)
# pred_plate = np.load("../output/pred_plate.npy")
# df['pred_plate'] = pred_plate

### Tweak thresholds

In [None]:
# Search grids for the threshold sweep.
thresholds_mask = [0.45]
thresholds_nms = [0.05, 0.1, 0.15]

# Confidence thresholds 0.20, 0.25, ..., 0.80 (rounded to 2 decimals).
thresholds_conf = []
for step in range(4, 17):
    thresholds_conf.append(np.round(0.05 * step, 2))

In [None]:
# Sweep the threshold grids on every fold and collect the resulting score arrays.
all_scores = []
for results, dataset in zip(all_results, datasets):
#     cell_types = dataset.df['plate'].apply(plate_to_class).values  # GT
#     cell_types = dataset.df.merge(df, on="id", how="left")['pred_plate'].values  # PRED
    cell_types = None

    scores = tweak_thresholds(
        results,
        dataset,
        thresholds_mask,
        thresholds_nms,
        thresholds_conf,
#         num_classes=len(CLASSES),
        remove_overlap=True,
        corrupt=True,
        cell_types=cell_types,
    )
    all_scores.append(scores)

# Without an explicit class assignment, classes are the cell types.
if cell_types is None:
    CLASSES = CELL_TYPES

# For each class, join the fold arrays along axis 2, skipping folds whose array
# for that class has an empty last dimension.
scores_tweak = []
for c in range(len(CLASSES)):
    fold_arrays = [scores_fold[c] for scores_fold in all_scores if scores_fold[c].shape[-1]]
    scores_tweak.append(np.concatenate(fold_arrays, 2))

In [None]:
# Pick, for every class, the (mask, nms, conf) threshold triple with the best mean score,
# then report the class-weighted average of those best scores.
best_thresholds_mask, best_thresholds_nms, best_thresholds_conf = [], [], []
best_scores = []

for c in range(len(CLASSES)):
    print(f' -> Cell type {CLASSES[c]} : ')

    # Averaging over axis 2 leaves a grid indexed by (mask, nms, conf) thresholds.
    scores_class = scores_tweak[c].mean(2)
    idx = np.unravel_index(np.argmax(scores_class, axis=None), scores_class.shape)
    best_score = scores_class[idx]
    best_scores.append(best_score)

    best_thresholds_c = (thresholds_mask[idx[0]], thresholds_nms[idx[1]], thresholds_conf[idx[2]])
    best_thresholds_mask.append(best_thresholds_c[0])
    best_thresholds_nms.append(best_thresholds_c[1])
    best_thresholds_conf.append(best_thresholds_c[2])

    print(f"Best score {best_score:.4f} for thresholds (mask, nms, conf): {best_thresholds_c}\n")

# Per-class sample counts used as averaging weights. They are stored as `class_weights`:
# the previous code assigned them to `weights`, overwriting the list of checkpoint
# paths that inference_val consumes and breaking any re-run of the inference cell.
if cell_types is None:
    class_counts = Counter(df_oof['cell_type'])  # built once instead of once per class
    class_weights = [class_counts[c] for c in CELL_TYPES]
else:
#     class_weights = [Counter(df_oof['class'].values)[c] for c in range(len(CLASSES))]
    # NOTE(review): `pred_plate` is only defined in a commented-out cell of this notebook.
    class_counts = Counter(pred_plate)
    class_weights = [class_counts[c] for c in range(len(CLASSES))]

best_score = np.average(best_scores, weights=class_weights)

print(f'CV score : {best_score:.4f}')

In [None]:
# for c in range(len(CLASSES)):
#     print(f"\nClass {CLASSES[c]}")
#     for idx_mask, threshold_mask in enumerate(thresholds_mask):
#         for idx_nms, threshold_nms in enumerate(thresholds_nms):
#             print(f"\n-> Threshold mask = {threshold_mask} - Threshold nms = {threshold_nms}")
#             for s, conf in zip(np.mean(scores_tweak[c][idx_mask, idx_nms], 0) , thresholds_conf):
#                 print(f"Threshold conf = {conf} - score = {s:.4f}")

In [None]:
# Print the selected thresholds as `NAME = value` lines (one per threshold list).
print(f'THRESHOLDS_MASK = {best_thresholds_mask}')
print(f'THRESHOLDS_NMS = {best_thresholds_nms}')
print(f'THRESHOLDS_CONF = {best_thresholds_conf}')

In [None]:
# Force a garbage-collection pass before the scoring section.
gc.collect()

### Score

In [None]:
# Score every fold with the selected thresholds and collect per-image predictions.
all_scores = [[], [], []]  # one score list per class (three classes hard-coded)
metadata = []

for results, dataset in zip(all_results, datasets):
    masks_pred, boxes_pred, cell_types = process_results(
        results,
        best_thresholds_mask,
        best_thresholds_nms,
        best_thresholds_conf,
        remove_overlap=True,
        corrupt=True,
    )

    scores, scores_per_class = evaluate(masks_pred, dataset, cell_types)

    # One record per image: id, RLE-encoded masks, boxes, predicted cell type and score.
    image_ids = dataset.df['id'].values
    metadata.extend(
        {
            'id': img_id,
            'rles': [rle_encode(mask) for mask in masks],
            'boxes': boxes.tolist(),
            'cell_type': cell,
            'score': score,
        }
        for masks, boxes, cell, img_id, score in zip(
            masks_pred, boxes_pred, cell_types, image_ids, scores
        )
    )

    # Accumulate the per-class scores across folds.
    for i, s in enumerate(scores_per_class):
        all_scores[i] += s

    # Drop the fold's predictions before processing the next fold.
    del masks_pred, boxes_pred, cell_types
    gc.collect()

#     break

df_preds_oof = pd.DataFrame.from_dict(metadata)

In [None]:
# if len(EXP_FOLDERS) == 1:
#     save_folder = EXP_FOLDERS[0]
#     name = 'df_oof.csv'
# else:
#     save_folder = OUT_PATH
#     name = 'df_oof_blend'

#     for f in EXP_FOLDERS:
#         name += f"_{f.split('/')[-3][5:]}-{f.split('/')[-2]}"
#     name += ".csv"

# print(f'-> Saved results to "{save_folder + name}"')

# df_preds_oof.to_csv(save_folder + name, index=False)

In [None]:
# One panel per class: the per-image scores of that class, sorted ascending.
plt.figure(figsize=(15, 5))
for i, class_scores in enumerate(all_scores):
    plt.subplot(1, 3, i + 1)
    plt.title(CELL_TYPES[i], size=15)
    plt.grid(True)
    n_images = len(class_scores)
    plt.scatter(range(n_images), sorted(class_scores), s=20)
#     plt.ylim(0, 0.75)
    plt.xlabel('')
    plt.ylabel('IoU mAP')

In [None]:
# Overall score: mean over every image of every class.
score = np.mean(np.concatenate(all_scores))

# Pair each class name with its own scores BEFORE dropping empty classes. Filtering
# first and zipping afterwards shifted the labels whenever an earlier class was empty.
named_scores = [
    (name, np.mean(s)) for name, s in zip(CELL_TYPES, all_scores) if len(s)
]
scores_class = [s for _, s in named_scores]

print(f' -> IoU mAP : {score:.4f}\n')

for c, s in named_scores:
    print(f'{c} : {s:.4f}')

In [None]:
# Attach the predictions to the out-of-fold frame (left join on image id); clashing
# prediction columns get a "_pred" suffix. NOTE: this reassigns df_oof, so running
# the cell twice merges the prediction columns a second time.
df_oof = pd.merge(df_oof, df_preds_oof, on="id", how="left", suffixes=('', '_pred'))
# Length of each entry in the `annotation` column.
df_oof['n_cells'] = df_oof['annotation'].apply(len)

In [None]:
# Mean score per plate. Select `score` before aggregating: averaging every column
# fails on non-numeric columns with pandas >= 2.0 and does needless work otherwise.
df_oof.groupby('plate')[['score']].mean()

In [None]:
# Score distribution per plate, plates in sorted order.
plt.figure(figsize=(15, 10))
df_by_plate = df_oof.sort_values('plate')
sns.violinplot(data=df_by_plate, x='score', y='plate')
plt.show()

In [None]:
# Mean score per cell type. Select `score` before aggregating: averaging every column
# fails on non-numeric columns with pandas >= 2.0 and does needless work otherwise.
df_oof.groupby('cell_type')[['score']].mean()

In [None]:
# Force a garbage-collection pass after the scoring section.
gc.collect()

### Dice

In [None]:
# masks_preds, masks_truth = [], []

# for results, dataset in zip(all_results, datasets):
#     masks_pred, boxes_pred, cell_types = process_results(
#         results, best_thresholds_mask, best_thresholds_nms, best_thresholds_conf, remove_overlap=True
#     )
    
#     masks_preds.append(masks.max(0))
    
#     masks_truth += [masks.masks.max(0) for masks in dataset.masks]

# dice_score(np.array(masks_preds), np.array(masks_truth))

## Viz

In [None]:
# Single dataset over the whole out-of-fold frame, used for the visualizations below.
dataset = SartoriusDataset(df_oof, transforms=pipelines['val_viz'], precompute_masks=False)

In [None]:
# Plot prediction and ground truth for the first sample whose plate is "cort[density]"
# (the loop stops after one sample). `img`, `pred` and `truth` remain defined afterwards
# and are reused by the next cell.
# NOTE(review): predictions are read from df_preds_oof using the df_oof row index, which
# assumes both frames list the images in the same order — confirm.
for idx in range(len(dataset)):
    # Skip every sample that is not from the plate of interest.
    if df_oof['plate'][idx] != "cort[density]":
        continue
    
    score = df_preds_oof['score'][idx]
    c = df_preds_oof['cell_type'][idx]
    
    data = dataset[idx]
    img = data['img']
    
    # truth
    truth = data['gt_masks'].masks.copy().astype(int)
    boxes_truth = data['gt_bboxes']
    
    # preds
    rles = df_preds_oof['rles'][idx]
    # Decode each RLE string back into a mask of size ORIG_SIZE.
    pred = np.array([rle_decode(enc, ORIG_SIZE) for enc in rles]).astype(int)
    boxes = df_preds_oof['boxes'][idx]
    
    plt.figure(figsize=(15, 15))
    plot_sample(img, pred, boxes, plotly=False)
    plt.axis(False)
    plt.title(f'Pred - {CELL_TYPES[c]} - iou_map={score:.3f}')
    plt.show()
    
    plt.figure(figsize=(15, 15))
    plot_sample(img, truth, boxes_truth, plotly=False)
    plt.axis(False)
    plt.title(f'Truth - {df_oof["cell_type"][idx]}')
    plt.show()
    
    print('-' * 100)

    # Only the first matching sample is displayed.
    break

In [None]:
# Prediction-vs-truth figure for the sample left over from the previous cell
# (`img`, `pred`, `truth`), shown at a fixed 900x700 size.
figure_layout = dict(autosize=False, width=900, height=700)

fig = plot_preds_iou(img, pred, truth, plot_tp=True)
fig.update_layout(**figure_layout)
fig.show()

## Single image exploration

In [None]:
# Experiment selection for the single-image exploration ("new folds - fix tta").
# Uncomment entries to add runs; this reassigns the EXP_FOLDERS used earlier.
EXP_FOLDERS = [
    LOG_PATH + run
    for run in (
        # "2021-12-11/2/",  # 1. Cascade b5 - 0.3134
        "2021-12-11/4/",  # 2. Cascade rx101 - 0.3154
        # "2021-12-12/0/",  # 3. Cascade r50 - 0.3133
        # "seb/mrcnn_resnext101_lossdecay/",  # 11.
        # "seb/mrcnn_r50_lossdecay/",  # 12.
    )
]

In [None]:
# Inference settings for the single-image exploration. Compared with the validation
# settings above, only rcnn_score_threshold differs (all zeros here).
ENSEMBLE_CONFIG = dict(
    use_tta=True,
    num_classes=3,
    # RPN stage
    rpn_nms_pre=[3000, 2000, 1000],
    rpn_iou_threshold=[0.75, 0.75, 0.6],
    rpn_score_threshold=[0.95, 0.9, 0.95],
    rpn_max_per_img=[None, None, None],  # [1500, 1000, 500],
    # RCNN stage
    bbox_nms=True,
    rcnn_iou_threshold=[0.75, 0.9, 0.6],
    rcnn_score_threshold=[0., 0., 0.],
)

In [None]:
# Load one Config per experiment folder and keep only the first checkpoint of each.
configs, weights = [], []

for exp_folder in EXP_FOLDERS:
    # Use a context manager so the config file handle is closed deterministically.
    with open(exp_folder + "config.json", 'r') as config_file:
        config = Config(json.load(config_file))

    # Re-root the model/data config paths onto the experiment folder, keeping only the file name.
    config.model_config = exp_folder + config.model_config.split('/')[-1]
    config.data_config = exp_folder + config.data_config.split('/')[-1]
    # The split is forced to "skf" here, whatever the stored config says.
    config.split = "skf"
    configs.append(config)

#     weights.append(sorted(glob.glob(exp_folder + "*.pt")))
    # First "*.pt" file (in sorted order) only.
    weights.append(sorted(glob.glob(exp_folder + "*.pt"))[:1])

## Inference

In [None]:
# Rebuild the dataframe (anomalies removed) and run inference for a single image (idx=1).
# `all_stuff` is unpacked in the "Viz stuff" section below.
df = prepare_data(fix=False, remove_anomalies=True)
results_s, all_stuff, df_oof_s = inference_single(df, configs, weights, ENSEMBLE_CONFIG, idx=1)

In [None]:
# NOTE(review): this cell repeats the previous cell verbatim; running both
# performs the same single-image inference twice.
df = prepare_data(fix=False, remove_anomalies=True)
results_s, all_stuff, df_oof_s = inference_single(df, configs, weights, ENSEMBLE_CONFIG, idx=1)

In [None]:
# Dataset for the single inspected image. `config` is the variable left over from the
# config-loading loop above (last experiment folder).
pipelines = define_pipelines(config.data_config)
dataset_s = SartoriusDataset(df_oof_s, transforms=pipelines['val_viz'])

In [None]:
# Finer threshold grids for the single-image sweep.
thresholds_mask = [0.45]
thresholds_nms = [np.round(0.05 * k, 2) for k in range(1, 10)]
thresholds_conf = [np.round(0.05 * k, 2) for k in range(1, 18)]

scores = tweak_thresholds(
    results_s,
    dataset_s,
    thresholds_mask,
    thresholds_nms,
    thresholds_conf,
    remove_overlap=True,
)

# Report the best threshold triple for every class that has scores along axis 2.
# `threshold_mask` / `threshold_nms` / `threshold_conf` keep the values of the last
# such class and are reused by the following cell.
for c in range(len(CELL_TYPES)):
    scores_class = scores[c]

    if not scores_class.shape[2]:
        continue

    scores_class = scores_class.mean(2)

    idx = np.unravel_index(np.argmax(scores_class, axis=None), scores_class.shape)
    best_score = scores_class[idx]

    threshold_mask, threshold_nms, threshold_conf = (
        thresholds_mask[idx[0]],
        thresholds_nms[idx[1]],
        thresholds_conf[idx[2]],
    )

    print(f"Best score {best_score:.4f} for thresholds : ")
    print(f'- Threshold mask : {threshold_mask}')
    print(f'- Threshold nms  : {threshold_nms}')
    print(f'- Threshold conf : {threshold_conf}')

In [None]:
# Post-process the single-image results with the thresholds selected in the previous cell.
# NOTE(review): scalars are passed here, whereas the fold-level scoring cell passes
# per-class lists — confirm process_results accepts both.
masks_pred, boxes_pred, cell_types = process_results(
    results_s, threshold_mask, threshold_nms, threshold_conf, remove_overlap=True, corrupt=False
)

In [None]:
# Score the post-processed single-image prediction and print the mean.
scores_single, _ = evaluate(masks_pred, dataset_s, cell_types)

print(f' -> IoU mAP : {np.mean(scores_single):.4f}\n')

In [None]:
# NOTE(review): verbatim repeat of the previous evaluation cell.
scores_single, _ = evaluate(masks_pred, dataset_s, cell_types)

print(f' -> IoU mAP : {np.mean(scores_single):.4f}\n')

In [None]:
# NOTE(review): verbatim repeat of the two previous evaluation cells.
scores_single, _ = evaluate(masks_pred, dataset_s, cell_types)

print(f' -> IoU mAP : {np.mean(scores_single):.4f}\n')

In [None]:
# Plot the predicted masks and boxes for the inspected image. `img`, `truth`,
# `boxes_truth` and `pred` stay defined and are reused by the cells below.
idx = 0
data = dataset_s[idx]

img, boxes_truth = data['img'], data['gt_bboxes']
truth = data['gt_masks'].masks.copy().astype(int)
pred = masks_pred[idx].copy().astype(int)

plt.figure(figsize=(15, 10))
plot_sample(img, mask=pred, boxes=boxes_pred[idx])
# plot_sample(img, mask=truth)
plt.title(f'{CELL_TYPES[cell_types[idx]]} - iou_map={np.mean(scores_single):.3f}')
plt.axis(False)
plt.show()

In [None]:
# NOTE(review): verbatim repeat of the previous plotting cell.
idx = 0
data = dataset_s[idx]

img = data['img']
truth = data['gt_masks'].masks.copy().astype(int)
boxes_truth = data['gt_bboxes']
pred = masks_pred[idx].copy().astype(int)

plt.figure(figsize=(15, 10))
plot_sample(img, mask=pred, boxes=boxes_pred[idx])
# plot_sample(img, mask=truth)
# The title shows the mean over scores_single rather than the score at `idx`.
plt.title(f'{CELL_TYPES[cell_types[idx]]} - iou_map={np.mean(scores_single):.3f}')
plt.axis(False)
plt.show()

## Visualization of intermediate outputs

In [None]:
# Unpack the intermediate outputs returned by inference_single and move the
# box-like ones to NumPy. `aug_masks` and `masks` are unpacked but not converted here.
proposal_list, merged_bboxes, bboxes, aug_masks, masks = all_stuff

to_numpy = lambda tensor: tensor.cpu().numpy()

proposals = to_numpy(proposal_list[0])  # first entry of the proposal list only
merged_bboxes = to_numpy(merged_bboxes)
bboxes = to_numpy(bboxes)

In [None]:
# print(f'Number of proposals : {[len(prop) for prop in aug_proposals[0]]}')
# Summary counts for the inspected image: proposals, merged boxes, boxes above a few
# thresholds on column 4, and final predicted vs. ground-truth masks.
print(f'Number of merged proposals : {len(proposals)}')
print(f'Number of merged boxes : {len(merged_bboxes)}')

# Thresholds 0.00, 0.05, ..., 0.20 applied to column 4 of `bboxes`.
for i in range(5):
    print(f'Number of detected boxes (th={0.05 * i:.2f}): {(bboxes[:, 4] > 0.05 * i).sum()}')
    
print()
print(f'Number of pred masks after pp : {len(pred)}')
print(f'Number of gt masks : {len(truth)}')

In [None]:
# Draw the merged boxes over the image, without any mask.
plt.figure(figsize=(15, 10))
plot_sample(img, mask=None, boxes=merged_bboxes)
plt.axis(False)
plt.show()

In [None]:
# For the proposals and for the final boxes, compute the best IoU reached for every
# ground-truth box, plot its distribution, and collect the ground-truth boxes whose
# best IoU stays below `threshold_hit` ("missed" boxes).
threshold_hit = 0.4

plt.figure(figsize=(15, 5))

missed = []
for i, preds in enumerate((proposals, bboxes)):
    # Loop-invariant filter, hoisted out of the per-box loop: keep candidates whose
    # column 4 is strictly positive.
    candidates = preds[preds[:, 4] > 0.]

    max_ious = []
    for b in boxes_truth:
        ious = [bbox_iou(b, prop) for prop in candidates]
        # With no candidate at all the best overlap is 0; np.max raises on an empty list.
        max_ious.append(np.max(ious) if ious else 0.)

    max_ious = np.array(max_ious)
    missed.append(boxes_truth[(max_ious < threshold_hit)])

    plt.subplot(1, 2, i + 1)
    sns.histplot(max_ious, bins=20)
    plt.axvline(threshold_hit, c="salmon")
    t = 'proposals' if i == 0 else "bboxes"
    plt.title(t + f' - missed {len(missed[-1])}')

plt.show()

In [None]:
# m = (pred.astype(int).max(0) > 0)[..., None]
# img_m = img * (1 - m) + 129 * m

# plt.figure(figsize=(15, 15))
# plt.imshow(img_m)
# #         plot_sample(img * (1 - m) + 131 * m, masks.astype(int))
# plt.axis(False)
# plt.title(img_id)
# plt.show()


# cv2.imwrite('test.png', img_m)

In [None]:
# Prediction-vs-truth figure for the inspected image, shown at a fixed 900x700 size.
figure_layout = dict(autosize=False, width=900, height=700)

fig = plot_preds_iou(
    img,
    pred,
    truth,
#     boxes=missed[1],
#     boxes_2=missed[0],
    plot_tp=True,
)
fig.update_layout(**figure_layout)
fig.show()

In [None]:
# masks_comp = pred.copy()

In [None]:
# merged, _, picks = mask_nms(
#     np.concatenate([pred > 0, masks_comp > 0]),
#     np.concatenate([np.ones((len(pred), 5)), 0.1 * np.ones((len(masks_comp), 5))]),
#     0.0
# )

# # merged = pred > 0

In [None]:
# merged = remove_overlap_naive(merged)
# merged = merged.astype(int)
# for i in range(len(merged)):
#     merged[i] *= (i + 1)

# merged = merged.max(0)

In [None]:
# iou_map([truth.max(0)], [merged])

In [None]:
# iou_map([truth.max(0)], [merged])

In [None]:
# fig = plot_preds_iou(
#     img,
#     merged.astype(int),
#     truth,
# #     boxes=missed[1],
# #     boxes_2=missed[0],
#     plot_tp=True)

# fig.update_layout(
#     autosize=False,
#     width=900,
#     height=700,
# )

# fig.show()

In [None]:
# plt.imshow(masks_comp.max(0))