**About**: This notebook tunes the ensemble inference parameters (RPN / RCNN NMS and score thresholds) for each cell type.

In [None]:
# %load_ext nb_black
%load_ext autoreload
%autoreload 2

In [None]:
cd ../src/

## Initialization

### Imports

In [None]:
import gc
import os
import ast
import sys
import cv2
import glob
import json
import torch
import warnings
import numpy as np
import pandas as pd
import seaborn as sns
import plotly.express as px
import matplotlib.pyplot as plt

from itertools import product
from tqdm.notebook import tqdm
from collections import Counter
warnings.simplefilter("ignore", UserWarning)

In [None]:
from params import *

from utils.plots import *
from utils.metrics import *
from utils.logger import Config, prepare_log_folder, create_logger
from utils.rle import rle_encode, rle_decode

from inference.tweaking import *
from inference.validation import *
from inference.post_process import *

from data.preparation import prepare_data
from data.dataset import SartoriusDataset
from data.transforms import define_pipelines
from inference.validation import inference_val

## Experiments

In [None]:
# ENS_6 ensemble: experiment folders, each given relative to LOG_PATH.
# NOTE: the following cell assigns EXP_FOLDERS again ("new folds"), so when the
# notebook is run top to bottom this selection is replaced.
EXP_FOLDERS = [
    LOG_PATH + run_dir
    for run_dir in (
        "2021-12-02/7/",  # 7. Cascade b5 - 0.3179
        "2021-12-03/0/",  # 8. Cascade rx101 - 0.3189
        # "2021-11-30/2/",  # 2. Cascade r50   - 0.3168
        "seb/mrcnn_resnext101_aug_2021-12-06/",  # 18.  maskrcnn rx101 - 0.3197
        "seb/maskrcnn_resnet50_2021-12-01/",  # 6. maskrcnn r50 - 0.3173
    )
]

In [None]:
# "New folds" ensemble: experiment folders, each given relative to LOG_PATH.
# Commented-out entries are experiments currently left out of the ensemble.
EXP_FOLDERS = [
    LOG_PATH + run_dir
    for run_dir in (
        "2021-12-11/2/",  # 1. Cascade b5 - 0.3127
        "2021-12-11/4/",  # 2. Cascade rx101 - 0.3141
        "2021-12-12/0/",  # 3. Cascade r50 - 0.3125
        # "2021-12-12/1/",  # 4. Cascade rx101 - 0.3139
        # "2021-12-12/4/",  # 5. Cascade rx50 - 0.3124
        # "2021-12-13/1/",  # 6. Cascade rx101 - 0.3127
        "seb/mrcnn_resnext101_new_splits/",  # 7. maskrcnn rx101 - 0.3120
        "seb/mrcnn_resnet50_new_splits/",  # 8. maskrcnn r50 - 0.3118
    )
]

## Inference

In [None]:
# Collect, for every experiment of the ensemble, its configuration and checkpoint paths.
configs, weights = [], []

for exp_folder in EXP_FOLDERS:
    # Read the saved experiment configuration; the context manager closes the file.
    with open(exp_folder + "config.json", "r") as config_file:
        config = Config(json.load(config_file))

    # Use the copies of the model / data config files stored inside the experiment folder.
    config.model_config = exp_folder + config.model_config.split('/')[-1]
    config.data_config = exp_folder + config.data_config.split('/')[-1]

    # Some configs lack these fields. Default each one independently: previously a single
    # try/except meant a missing `remove_anomalies` also overwrote an existing `split`,
    # and a missing `split` discarded an existing `remove_anomalies`.
    # NOTE(review): relies on Config raising AttributeError for unknown fields - confirm.
    if not hasattr(config, "split"):
        config.split = "skf"
    # `remove_anomalies` is read by the data-preparation cell; after the loop it holds
    # the value of the last experiment in EXP_FOLDERS.
    remove_anomalies = getattr(config, "remove_anomalies", False)

    configs.append(config)

    # Keep only the first checkpoint (in sorted order) of each experiment.
    # To use every checkpoint instead:
    #     weights.append(sorted(glob.glob(exp_folder + "*.pt")))
    weights.append(sorted(glob.glob(exp_folder + "*.pt"))[:1])

In [None]:
# Build the dataframe passed to validation inference. `remove_anomalies` is the variable
# left behind by the config-loading loop, i.e. the value for the last experiment folder.
df = prepare_data(fix=False, remove_anomalies=remove_anomalies)

In [None]:
# Ensemble inference settings. List-valued entries hold three values, presumably one per
# class (num_classes is 3) - TODO confirm the class order.
# NOTE: the following cell assigns ENSEMBLE_CONFIG again, so on a top-to-bottom run
# these values are replaced.
ENSEMBLE_CONFIG = dict(
    use_tta=True,
    num_classes=3,
    # RPN stage
    rpn_nms_pre=[5000, 2000, 1000],
    rpn_iou_threshold=[0.7, 0.75, 0.7],
    rpn_score_threshold=[0.9, 0.8, 0.9],
    rpn_max_per_img=[None, None, None],  # [1500, 1000, 500],
    # RCNN stage
    bbox_nms=True,
    rcnn_iou_threshold=[0.7, 0.9, 0.7],
    rcnn_score_threshold=[0.2, 0.35, 0.55],
)

In [None]:
# Base ensemble inference settings handed to validation inference. List-valued entries
# hold three values, presumably one per class (num_classes is 3) - TODO confirm the order.
# The grid search further down substitutes its own values for rpn_nms_pre,
# rpn_iou_threshold, rpn_score_threshold and rcnn_iou_threshold on every run.
ENSEMBLE_CONFIG = dict(
    use_tta=True,
    num_classes=3,
    # RPN stage
    rpn_nms_pre=[3000, 2000, 1000],
    rpn_iou_threshold=[0.75, 0.75, 0.6],
    rpn_score_threshold=[0.95, 0.9, 0.95],
    rpn_max_per_img=[None, None, None],  # [1500, 1000, 500],
    # RCNN stage
    bbox_nms=True,
    rcnn_iou_threshold=[0.75, 0.9, 0.6],
    rcnn_score_threshold=[0.2, 0.25, 0.5],
)

In [None]:
# Candidate values explored by the grid search, one row per class index (the grids are
# indexed with the position of the cell type in CELL_TYPES).
# Columns: rpn_nms_pre, rpn_iou_threshold, rpn_score_threshold, rcnn_iou_threshold.
# An RCNN IoU of 0 means "reuse the RPN IoU threshold" (see the search loop).
_GRID_PER_CLASS = [
    ([5000], [0.65, 0.7], [0.8, 0.9], [0]),
    ([2000], [0.75], [0.8], [0.9, 0.95]),
    ([1000], [0.5, 0.6, 0.7], [0.8, 0.9, 0.95], [0]),
]

# Transpose the per-class table into one list of candidate lists per parameter.
RPN_NMS_PRES, RPN_IOU_THS, RPN_SCORE_THS, RCNN_IOU_THS = (
    list(column) for column in zip(*_GRID_PER_CLASS)
)

In [None]:
# Count the grid points the search loop will actually evaluate, over all classes.
# A combination is kept when its RCNN IoU is 0 (meaning "reuse the RPN IoU") or when
# it is not below the RPN IoU; every other combination is skipped.
count = sum(
    1
    for class_idx, _cell_type in enumerate(CELL_TYPES)
    for _nms_pre, rpn_iou, _rpn_score, rcnn_iou in product(
        RPN_NMS_PRES[class_idx], RPN_IOU_THS[class_idx], RPN_SCORE_THS[class_idx], RCNN_IOU_THS[class_idx]
    )
    if rcnn_iou == 0 or not rcnn_iou < rpn_iou
)

# NOTE(review): the factors 8, 60 and 3 are undocumented; presumably minutes per run,
# minutes per hour and a divisor of 3 whose meaning is unclear - confirm.
estimated_hours = count * 8 / 60 / 3
print(f'Duration : {estimated_hours:.1f}h')

In [None]:
# Post-processing thresholds searched for every grid point.
thresholds_mask = [0.45]  # [0.4, 0.45, 0.5]
thresholds_nms = [0.05, 0.1, 0.15]
thresholds_conf = [np.round(step * 0.05, 2) for step in range(4, 17)]  # 0.2, 0.25, ..., 0.8

# Previously selected thresholds, three values each (mask, NMS, confidence).
best_thresholds_mask, best_thresholds_nms, best_thresholds_conf = (
    [0.45, 0.45, 0.45],
    [0.1, 0.1, 0.05],
    [0.3, 0.4, 0.7],
)

CLASSES = CELL_TYPES

# `config` is the variable left behind by the config-loading loop (last experiment folder).
pipelines = define_pipelines(config.data_config)

In [None]:
%%time

# Grid search over RPN / RCNN parameters, one class at a time. For each grid point the
# ensemble is run on validation data, then the (mask, nms, conf) thresholds giving the
# best score for that class are reported.
log_folder = prepare_log_folder(LOG_PATH + "tweak/")
print(f"Logging results to {log_folder}\n")
create_logger(directory=log_folder, name="logs.txt")

# One record per evaluated grid point, so results remain available after the run
# instead of existing only in the printed output.
tweak_results = []

for class_idx, cell_type in enumerate(CELL_TYPES):
    banner = '#' * (len(cell_type) + 4)
    print(banner)
    print(f'# {cell_type} #')
    print(banner + "\n")

    for rpn_nms_pre, rpn_iou_threshold, rpn_score_threshold, rcnn_iou_threshold in product(
        RPN_NMS_PRES[class_idx], RPN_IOU_THS[class_idx], RPN_SCORE_THS[class_idx], RCNN_IOU_THS[class_idx]
    ):
        # An RCNN IoU of 0 in the grid means "reuse the RPN IoU threshold".
        if rcnn_iou_threshold == 0:
            rcnn_iou_threshold = rpn_iou_threshold

        # Skip combinations where the RCNN IoU is below the RPN IoU.
        if rcnn_iou_threshold < rpn_iou_threshold:
            continue

        print("rpn_nms_pre : ", rpn_nms_pre)
        print("rpn_iou_threshold : ", rpn_iou_threshold)
        print("rpn_score_threshold : ", rpn_score_threshold)
        print("rcnn_iou_threshold : ", rcnn_iou_threshold)

        # Build a per-run copy rather than mutating ENSEMBLE_CONFIG in place, so the base
        # configuration keeps its per-class values after the search. The tested value is
        # repeated for the three entries of each list.
        run_config = {
            **ENSEMBLE_CONFIG,
            'rpn_nms_pre': [rpn_nms_pre] * 3,
            'rpn_iou_threshold': [rpn_iou_threshold] * 3,
            'rpn_score_threshold': [rpn_score_threshold] * 3,
            'rcnn_iou_threshold': [rcnn_iou_threshold] * 3,
        }

        all_results, dfs_val = inference_val(
            df.copy(), configs, weights, run_config, verbose=0, cell_type=cell_type
        )

        # Only the first element of the returned lists is used.
        # NOTE(review): presumably one element per fold, matching the single checkpoint
        # kept per experiment - confirm.
        dataset = SartoriusDataset(dfs_val[0], transforms=pipelines['val_viz'], precompute_masks=False)

        scores_tweak = tweak_thresholds(
            all_results[0],
            dataset,
            thresholds_mask,
            thresholds_nms,
            thresholds_conf,
            num_classes=len(CLASSES),
            remove_overlap=True,
            corrupt=True,
            cell_types=None
        )

        # Average over axis 2; the remaining axes are indexed by (mask, nms, conf)
        # threshold position, as used to look up the best thresholds below.
        scores_class = scores_tweak[class_idx].mean(2)
        idx = np.unravel_index(np.argmax(scores_class, axis=None), scores_class.shape)
        best_score = scores_class[idx]
        best_thresholds = (thresholds_mask[idx[0]], thresholds_nms[idx[1]], thresholds_conf[idx[2]])

        print(f"Best score {best_score:.4f} for thresholds (mask, nms, conf): {best_thresholds}\n")

        tweak_results.append({
            "cell_type": cell_type,
            "rpn_nms_pre": rpn_nms_pre,
            "rpn_iou_threshold": rpn_iou_threshold,
            "rpn_score_threshold": rpn_score_threshold,
            "rcnn_iou_threshold": rcnn_iou_threshold,
            "best_score": best_score,
            "threshold_mask": best_thresholds[0],
            "threshold_nms": best_thresholds[1],
            "threshold_conf": best_thresholds[2],
        })

        # Release the inference outputs before the next grid point.
        del all_results, dfs_val
        gc.collect()