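**About**: This notebook runs inference on the out-of-fold validation data, tunes the post-processing thresholds per cell type, and evaluates and visualizes the resulting predictions.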

In [None]:
# %load_ext nb_black
# autoreload mode 2 re-imports modules before each cell runs, so edits to the
# project modules imported below are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
cd ../src/

## Initialization

### Imports

In [None]:
import gc
import os
import ast
import sys
import cv2
import glob
import json
import torch
import warnings
import numpy as np
import pandas as pd
import seaborn as sns
import plotly.express as px
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm
from collections import Counter
warnings.simplefilter("ignore", UserWarning)

In [None]:
from params import *

from utils.plots import *
from utils.metrics import *
from utils.logger import Config, save_to_folder
from utils.rle import rle_encode, rle_decode

from inference.tweaking import *
from inference.validation import *
from inference.post_process import *

from data.preparation import prepare_data
from data.dataset import SartoriusDataset
from data.transforms import define_pipelines
from inference.validation import inference_val

## Exps

### All Exps

In [None]:
# Every entry ends with "/" so that it matches the other EXP_FOLDERS lists exactly
# (the per-cell-type subsets are compared by string equality) and file names can be
# appended to it directly.
EXP_FOLDERS = [  # new folds
    LOG_PATH + "2021-12-11/2/",  # 1. Cascade b5 - 0.3134
    LOG_PATH + "2021-12-11/4/",  # 2. Cascade rx101 - 0.3154
    LOG_PATH + "2021-12-12/0/",  # 3. Cascade r50 - 0.3133
    LOG_PATH + "seb/mrcnn_resnext101_new_splits/",  # 7. maskrcnn rx101 - 0.3120
    LOG_PATH + "seb/mrcnn_resnet50_new_splits/",  # 8. maskrcnn r50 - 0.3118
    LOG_PATH + "seb/mrcnn_resnext101_lossdecay/",  # 11. mrcnn r101 0.3131
    LOG_PATH + "seb/mrcnn_r50_lossdecay/",  # 12. mrcnn r50 0.3125
    LOG_PATH + "2021-12-15/0/",  # 14. Cascade b6 - 0.3121
    LOG_PATH + "2021-12-15/1/",  # 15. htc r50 - 0.3121
    LOG_PATH + "2021-12-20/1/",  # 16. Cascade rx101_64x4 - 0.3130
    LOG_PATH + "2021-12-21/0/",  # 17. htc rx101 - 0.3119
    LOG_PATH + "seb/cascade_b4/",  # 18. cascade b4 c
    LOG_PATH + "seb/mrcnn_b5/",  # 19. mrcnn b5 - 0.3086
    LOG_PATH + "2021-12-22/2/",  # 20. cascade b6 192 crops - 0.3118
    LOG_PATH + "2021-12-22/6/",  # 21. htc b4 - 0.3083
    LOG_PATH + "seb/mrcnn_r101_64x4/",  # 22. mrcnn rx101_64x4 - 0.3127
    LOG_PATH + "seb/cascade_resnext101_32x8/",  # 23. cascade rx101_32x8 - 0.3121
    LOG_PATH + "seb/mrcnn_rx101_decay_bn_flip_aug/",  # 24. mrcnn rx101 - 0.3141
    LOG_PATH + "seb/mrcnn_r50_bn_flip_decay/",  # 25. mrcnn r50 0.3141
    LOG_PATH + "seb/mrcnn_rx101_64x4_flip_bn_decay_64x4/",  # 26. mrcnn rx101_64x4 - 0.3121
    LOG_PATH + "2021-12-28/2/",  # 28. mrcnn rx50 gnws - (score not recorded)
]

### Ensembles

In [None]:
# Ensemble id #3 - Cort 0.4083 / 0.4091 pp
EXP_FOLDERS = [
    LOG_PATH + sub_dir
    for sub_dir in (
        "2021-12-11/2/",  # 1. Cascade b5 - 0.3121
        "2021-12-11/4/",  # 2. Cascade rx101 - 0.3141
        "2021-12-12/0/",  # 3. Cascade r50 - 0.3125
        "seb/mrcnn_resnext101_lossdecay/",  # 11. mrcnn r101 0.3131
        "seb/mrcnn_r50_lossdecay/",  # 12. mrcnn r50 0.3125
        "2021-12-15/0/",  # 14. Cascade b6 - 0.3121
    )
]

In [None]:
# Shsy5y ensemble - 0.2399 / 0.2403 no tta mask
EXP_FOLDERS = [
    LOG_PATH + sub_dir
    for sub_dir in (
        "2021-12-12/0/",  # 3. Cascade r50 - 0.3133
        "seb/mrcnn_resnet50_new_splits/",  # 8. maskrcnn r50 - 0.3118
        "2021-12-15/1/",  # 15. htc r50 - 0.3121
        "2021-12-20/1/",  # 16. Cascade rx101_64x4 - 0.3130
        "2021-12-22/2/",  # 20. cascade b6 192 crops - 0.3118
        "seb/mrcnn_r101_64x4/",  # 22. mrcnn rx101_64x4 - 0.3127
    )
]

In [None]:
# Astro ensemble - 0.2145 / 0.2151 no tta mask / 0.2165 pp no tta mask
EXP_FOLDERS = [
    LOG_PATH + sub_dir
    for sub_dir in (
        "2021-12-11/2/",  # 1. Cascade b5 - 0.3134
        "2021-12-11/4/",  # 2. Cascade rx101 - 0.3154
        "2021-12-12/0/",  # 3. Cascade r50 - 0.3133
        "seb/mrcnn_resnext101_lossdecay/",  # 11. mrcnn r101 0.3131
        "2021-12-15/1/",  # 15. htc r50 - 0.3121
        "seb/mrcnn_r101_64x4/",  # 22. mrcnn rx101_64x4 - 0.3127
    )
]

In [None]:
# Cort ensemble - 0.4095
EXP_FOLDERS = [
    LOG_PATH + sub_dir
    for sub_dir in (
        "2021-12-11/2/",  # 1. Cascade b5 - 0.3134
        "2021-12-11/4/",  # 2. Cascade rx101 - 0.3154
        "2021-12-15/0/",  # 14. Cascade b6 - 0.3121
        "seb/mrcnn_b5/",  # 19. mrcnn b5 - 0.3086
        "2021-12-22/6/",  # 21. htc b4 - 0.3083
        "seb/mrcnn_rx101_decay_bn_flip_aug/",  # 24. mrcnn rx101 - 0.3141
    )
]

In [None]:
# Per-cell-type model subsets. Small ensembles (<= 6 models) use every model for
# every cell type; larger ensembles must define the subsets explicitly beforehand.
if len(EXP_FOLDERS) <= 6:  # small ensemble
    EXP_FOLDERS_CORT = EXP_FOLDERS
    EXP_FOLDERS_ASTRO = EXP_FOLDERS
    EXP_FOLDERS_SHSY5Y = EXP_FOLDERS
    EXP_FOLDER_CLS = EXP_FOLDERS
else:
    # Fail with an explicit message instead of an obscure NameError further down.
    _missing_subsets = [
        name
        for name in ("EXP_FOLDERS_CORT", "EXP_FOLDERS_ASTRO", "EXP_FOLDERS_SHSY5Y", "EXP_FOLDER_CLS")
        if name not in globals()
    ]
    if _missing_subsets:
        raise NameError(
            f"{len(EXP_FOLDERS)} experiments: define {_missing_subsets} before running this cell"
        )

# Each subset may only contain folders from EXP_FOLDERS, listed in the same order.
for _name, _subset in {
    "EXP_FOLDERS_CORT": EXP_FOLDERS_CORT,
    "EXP_FOLDERS_ASTRO": EXP_FOLDERS_ASTRO,
    "EXP_FOLDERS_SHSY5Y": EXP_FOLDERS_SHSY5Y,
    "EXP_FOLDER_CLS": EXP_FOLDER_CLS,
}.items():
    assert [f for f in EXP_FOLDERS if f in _subset] == _subset, (
        f"{_name} is not an ordered subset of EXP_FOLDERS"
    )

## Inference

In [None]:
# Inference / post-processing settings used when a single model is evaluated.
# The 3-element lists presumably hold one value per class (num_classes=3) - TODO confirm order.
ENSEMBLE_CONFIG_SINGLE = dict(
    use_tta=True,
    use_tta_masks=True,
    num_classes=3,
    # rpn_* settings
    rpn_nms_pre=[3000, 2000, 1000],
    rpn_iou_threshold=[0.75, 0.75, 0.6],
    rpn_score_threshold=[0.95, 0.9, 0.95],
    rpn_max_per_img=[None, None, None],  # [1500, 1000, 500],
    # bbox / rcnn_* settings
    bbox_nms=True,
    rcnn_iou_threshold=[0.75, 0.9, 0.6],
    rcnn_score_threshold=[0.2, 0.3, 0.5],
    # One boolean per experiment folder: whether it is in the corresponding subset.
    use_for_cort=[folder in EXP_FOLDERS_CORT for folder in EXP_FOLDERS],
    use_for_astro=[folder in EXP_FOLDERS_ASTRO for folder in EXP_FOLDERS],
    use_for_shsy5y=[folder in EXP_FOLDERS_SHSY5Y for folder in EXP_FOLDERS],
    use_for_cls=[folder in EXP_FOLDER_CLS for folder in EXP_FOLDERS],
)

In [None]:
# Inference / post-processing settings used when several models are ensembled.
# The 3-element lists presumably hold one value per class (num_classes=3) - TODO confirm order.
ENSEMBLE_CONFIG_ENS = dict(
    use_tta=True,
    use_tta_masks=True,  # False
    num_classes=3,
    # rpn_* settings
    rpn_nms_pre=[5000, 2000, 1000],
    rpn_iou_threshold=[0.7, 0.75, 0.6],
    rpn_score_threshold=[0.9, 0.9, 0.95],
    rpn_max_per_img=[None, None, None],
    # bbox / rcnn_* settings
    bbox_nms=True,
    rcnn_iou_threshold=[0.7, 0.9, 0.6],
    rcnn_score_threshold=[0.2, 0.25, 0.5],
    # One boolean per experiment folder: whether it is in the corresponding subset.
    use_for_cort=[folder in EXP_FOLDERS_CORT for folder in EXP_FOLDERS],
    use_for_astro=[folder in EXP_FOLDERS_ASTRO for folder in EXP_FOLDERS],
    use_for_shsy5y=[folder in EXP_FOLDERS_SHSY5Y for folder in EXP_FOLDERS],
    use_for_cls=[folder in EXP_FOLDER_CLS for folder in EXP_FOLDERS],
)

In [None]:
# A single experiment uses the single-model settings; anything larger uses the ensemble ones.
if len(EXP_FOLDERS) == 1:
    ENSEMBLE_CONFIG = ENSEMBLE_CONFIG_SINGLE
else:
    ENSEMBLE_CONFIG = ENSEMBLE_CONFIG_ENS

In [None]:
# Set to "cort", "shsy5y" or "astro" to run the inference on that cell type only;
# leave as None to run it on every cell type.
CELL_TYPE = None

In [None]:
configs, weights = [], []

for exp_folder in EXP_FOLDERS:
    # Context manager so the config file handle is closed right away.
    # os.path.join also copes with folders listed without a trailing "/".
    with open(os.path.join(exp_folder, "config.json"), "r") as config_file:
        config = Config(json.load(config_file))

    # Point the model / data configs at the copies stored inside the experiment folder.
    config.model_config = os.path.join(exp_folder, config.model_config.split('/')[-1])
    config.data_config = os.path.join(exp_folder, config.data_config.split('/')[-1])

    # Configs without a `split` entry fall back to "skf". Only a missing value is
    # defaulted; an existing `split` is never overwritten.
    # Assumes Config raises AttributeError for missing keys - TODO confirm.
    if not hasattr(config, "split"):
        config.split = "skf"
    # NOTE: reassigned on every iteration, so the last experiment's value is the one
    # later passed to prepare_data.
    remove_anomalies = getattr(config, "remove_anomalies", False)

    configs.append(config)
    # To evaluate only the first checkpoint of each experiment, slice this list with [:1].
    weights.append(sorted(glob.glob(os.path.join(exp_folder, "*.pt"))))

In [None]:
# Build the image dataframe; `remove_anomalies` comes from the experiment configs loaded above.
df = prepare_data(remove_anomalies=remove_anomalies, fix=False)

In [None]:
# Data pipelines built from the first experiment's data config.
first_data_config = configs[0].data_config
pipelines = define_pipelines(first_data_config)

In [None]:
%%time
# Run the configured models on the validation data. all_results and dfs_val are
# per-fold sequences (they are zipped together fold by fold further down);
# CELL_TYPE optionally restricts the run to one cell type.
all_results, dfs_val = inference_val(df, configs, weights, ENSEMBLE_CONFIG, cell_type=CELL_TYPE)

## Evaluation

In [None]:
# Out-of-fold frame: the validation rows of every fold, stacked in fold order.
df_oof = pd.concat(dfs_val).reset_index(drop=True)

# Use the first experiment's data config, as the inference pipelines above do,
# instead of the `config` variable left over from the config-loading loop (last experiment).
pipelines = define_pipelines(configs[0].data_config)

# One dataset per fold, with visualization transforms and without precomputed masks.
datasets = [
    SartoriusDataset(df_val, transforms=pipelines['val_viz'], precompute_masks=False)
    for df_val in dfs_val
]

In [None]:
# Default post-processing thresholds (presumably one value per cell type).
# They are replaced by the output of the threshold search further down.
best_thresholds_mask = [0.45] * 3
best_thresholds_nms = [0.1, 0.1, 0.05]
best_thresholds_conf = [0.3, 0.4, 0.7]
best_min_sizes = [0] * 3  # alternative: [50, 125, 75]

### Tweak thresholds

In [None]:
# Search grid explored by the threshold tweaking below.
thresholds_mask = [0.45]
thresholds_nms = [0.05, 0.1, 0.15]
# Confidence thresholds 0.20, 0.25, ..., 0.80 (rounded to 2 decimals).
thresholds_conf = list(np.round(np.arange(4, 17) * 0.05, 2))
min_sizes = [0, 50, 75, 100, 150]  # alternative: [0, 25, 50, 75, 100]

In [None]:
# Report validation images whose majority-vote predicted cell type differs from the
# ground truth. Column 5 of each box array is compared with a CELL_TYPES index, so it
# holds per-box class ids.
for fold, (df_, results) in enumerate(zip(dfs_val, all_results)):
    for i, (b, _) in enumerate(results):
        # Positional lookup: results[i] corresponds to the i-th row of the fold frame,
        # whatever that frame's index labels are.
        img_id = df_["id"].iloc[i]
        gt = CELL_TYPES.index(df_['cell_type'].iloc[i])

        if not len(b):
            # np.argmax of an empty bincount raises; report the image instead.
            print(f'Fold {fold}, img {img_id} (idx {i}), no predicted boxes, gt {gt}')
            continue

        pred = np.argmax(np.bincount(b[:, 5].astype(int)))

        if pred != gt:
            print(f'Fold {fold}, img {img_id} (idx {i}), pred {pred}, gt {gt}')

In [None]:
# Score every point of the threshold grid on each fold's validation set.
all_scores, all_cell_types = [], []

for results, dataset in zip(all_results, datasets):
    scores, cell_types = tweak_thresholds(
        results, dataset, thresholds_mask, thresholds_nms, thresholds_conf,
        min_sizes=min_sizes, remove_overlap=True, corrupt=True,
    )
    all_scores.append(scores)
    all_cell_types.extend(cell_types)

In [None]:
if CELL_TYPE is not None:
    # When only one cell type was run, a fold may hold scores for a single class.
    # Copy that class's scores into the empty ones so that every class has data
    # for the concatenation below. Iterate over the folds actually scored, since
    # the checkpoint count of the first experiment need not match their number.
    for scores_fold in all_scores:
        if not scores_fold[0].shape[-1] and not scores_fold[1].shape[-1]:
            scores_fold[0] = scores_fold[2].copy()
            scores_fold[1] = scores_fold[2].copy()

        if not scores_fold[2].shape[-1] and not scores_fold[1].shape[-1]:
            scores_fold[2] = scores_fold[0].copy()
            scores_fold[1] = scores_fold[0].copy()

        if not scores_fold[0].shape[-1] and not scores_fold[2].shape[-1]:
            scores_fold[0] = scores_fold[1].copy()
            scores_fold[2] = scores_fold[1].copy()

# For each cell type, stack the per-fold score arrays along axis -2; folds whose
# array for that class is empty (last axis of size 0) are skipped.
scores_tweak = [
    np.concatenate([scores_fold[c] for scores_fold in all_scores if scores_fold[c].shape[-1]], -2)
    for c in range(len(CELL_TYPES))
]

In [None]:
best_thresholds_mask, best_thresholds_nms, best_thresholds_conf, best_min_sizes = [], [], [], []
best_scores = []

for c, cell_type_name in enumerate(CELL_TYPES):
    print(f' -> Cell type {cell_type_name} : ')

    # Average over axis -2 (the axis the folds were concatenated along), then pick the
    # grid point with the highest mean score.
    scores_class = scores_tweak[c].mean(-2)
    idx = np.unravel_index(scores_class.argmax(), scores_class.shape)
    best_score = scores_class[idx]
    best_scores.append(best_score)

    # Grid axes as indexed here: 0 -> mask, 1 -> nms, 2 -> min_size, 3 -> conf
    # (presumably the layout produced by tweak_thresholds - TODO confirm).
    threshold_mask = thresholds_mask[idx[0]]
    threshold_nms = thresholds_nms[idx[1]]
    threshold_min_size = min_sizes[idx[2]]
    threshold_conf = thresholds_conf[idx[3]]
    best_thresholds_c = (threshold_mask, threshold_nms, threshold_conf, threshold_min_size)

    for chosen, value in zip(
        (best_thresholds_mask, best_thresholds_nms, best_thresholds_conf, best_min_sizes),
        best_thresholds_c,
    ):
        chosen.append(value)

    print(
        f"Best score {best_score:.4f} for thresholds (mask, nms, conf, min_size): {best_thresholds_c}\n"
    )

# Weight each class's best score by its number of occurrences in all_cell_types
# (presumably one entry per validation image).
cell_type_counts = Counter(all_cell_types)
ws = [cell_type_counts[c] for c in range(len(CELL_TYPES))]

best_score = np.average(best_scores, weights=ws)

print(f'CV score : {best_score:.4f}')

In [None]:
# for c in range(len(CLASSES)):
#     print(f"\nClass {CLASSES[c]}")
#     for idx_mask, threshold_mask in enumerate(thresholds_mask):
#         for idx_nms, threshold_nms in enumerate(thresholds_nms):
#             print(f"\n-> Threshold mask = {threshold_mask} - Threshold nms = {threshold_nms}")
#             for s, conf in zip(np.mean(scores_tweak[c][idx_mask, idx_nms], 0) , thresholds_conf):
#                 print(f"Threshold conf = {conf} - score = {s:.4f}")

In [None]:
# Print the tuned thresholds in a copy-pasteable form.
for label, values in (
    ("THRESHOLDS_MASK", best_thresholds_mask),
    ("THRESHOLDS_NMS", best_thresholds_nms),
    ("THRESHOLDS_CONF", best_thresholds_conf),
    ("MIN_SIZES", best_min_sizes),
):
    print(f"{label} = {values}")

In [None]:
# Run a garbage-collection pass to free unreachable objects before the scoring pass.
gc.collect()

### Score

In [None]:
# Fifth positional argument of process_results; presumably the minimum mask size
# (cf. `min_sizes` in tweak_thresholds) - TODO confirm. Note that the tuned
# `best_min_sizes` are not used here.
PROCESS_MIN_SIZE = 70

all_scores = [[], [], []]  # per-class scores accumulated over folds
metadata = []  # one record per validation image, turned into df_preds_oof below

for results, dataset in zip(all_results, datasets):
    masks_pred, boxes_pred, cell_types = process_results(
        results,
        best_thresholds_mask,
        best_thresholds_nms,
        best_thresholds_conf,
        PROCESS_MIN_SIZE,
        remove_overlap=True,
        corrupt=True
    )

    scores, scores_per_class = evaluate(
        masks_pred,
        dataset,
        cell_types
    )

    # Predictions are aligned positionally with the rows of dataset.df.
    for masks, boxes, cell_type_pred, img_id, score, cell_type in zip(
        masks_pred, boxes_pred, cell_types, dataset.df['id'].values, scores, dataset.df['cell_type'].values
    ):
        metadata.append({
            'id': img_id,
            'cell_type': cell_type,
            'cell_type_pred': cell_type_pred,
            'rles': [rle_encode(mask) for mask in masks],
            'boxes': boxes.tolist(),
            'score': score
        })

    for i, s in enumerate(scores_per_class):
        all_scores[i] += s

df_preds_oof = pd.DataFrame.from_dict(metadata)

In [None]:
# Overall mean IoU mAP, then the mean score per ground-truth cell type.
mean_iou_map = df_preds_oof.score.mean()
print(f' -> IoU mAP : {mean_iou_map:.4f}\n')
df_preds_oof.groupby('cell_type')[['score']].mean()

### Save without thresholding

#### Raw results

In [None]:
# Output folder for the raw predictions (name keeps its trailing "/" so file names
# can be appended directly).
name = "ens_12/"
SAVE_DIR = f"{OUT_PATH}{name}"

# Uncomment to refuse overwriting an existing export:
# assert not os.path.exists(SAVE_DIR)
# os.mkdir(SAVE_DIR)

In [None]:
# Export the raw (un-thresholded) masks and boxes of every validation image.
# Create the output folder if needed so that this cell also works on a fresh run.
os.makedirs(SAVE_DIR, exist_ok=True)

for results, dataset in zip(all_results, datasets):
    for i in tqdm(range(len(dataset))):
        # Positional lookup: results[i] corresponds to the i-th row of dataset.df.
        id_ = dataset.df['id'].iloc[i]
        boxes, masks = results[i][0], results[i][1]

        np.save(SAVE_DIR + f"masks_{id_}.npy", masks)
        np.save(SAVE_DIR + f"boxes_{id_}.npy", boxes)

## Visualization

In [None]:
# Visualization dataset over the whole out-of-fold frame.
dataset = SartoriusDataset(
    df_oof,
    precompute_masks=False,
    transforms=pipelines['val_viz'],
)

In [None]:
# Plot the predictions and the ground truth for the first "cort[density]" image.
for idx in range(len(dataset)):
    if df_oof['plate'][idx] != "cort[density]":
        continue

    score = df_preds_oof['score'][idx]
    # 'cell_type' holds the ground-truth name (a string taken from dataset.df), so it
    # cannot index CELL_TYPES; the predicted class is stored under 'cell_type_pred'.
    c = df_preds_oof['cell_type_pred'][idx]
    # Accept either a class index or an already-resolved name.
    pred_label = CELL_TYPES[c] if isinstance(c, (int, np.integer)) else c

    data = dataset[idx]
    img = data['img']

    # truth
    truth = data['gt_masks'].masks.copy().astype(int)
    boxes_truth = data['gt_bboxes']

    # preds: decode the run-length-encoded masks at ORIG_SIZE
    rles = df_preds_oof['rles'][idx]
    pred = np.array([rle_decode(enc, ORIG_SIZE) for enc in rles]).astype(int)
    boxes = df_preds_oof['boxes'][idx]

    plt.figure(figsize=(15, 15))
    plot_sample(img, pred, boxes, plotly=False)
    plt.axis(False)
    plt.title(f'Pred - {pred_label} - iou_map={score:.3f}')
    plt.show()

    plt.figure(figsize=(15, 15))
    plot_sample(img, truth, boxes_truth, plotly=False)
    plt.axis(False)
    plt.title(f'Truth - {df_oof["cell_type"][idx]}')
    plt.show()

    print('-' * 100)

    # Only the first matching image is shown; its img / pred / truth are reused by the next cell.
    break

In [None]:
# Interactive IoU view of the image selected by the previous cell.
fig = plot_preds_iou(img, pred, truth, plot_tp=True)
fig.update_layout(autosize=False, width=900, height=700)
fig.show()