In [None]:
%load_ext autoreload
%autoreload 2

import os
# Restrict this notebook to GPU 3. The variable must be set before CUDA is initialised
# (i.e. before the first torch CUDA call), so it lives at the very top of the notebook.
os.environ.update(CUDA_VISIBLE_DEVICES="3")

import pandas as pd
import numpy as np
import seaborn as sns
import os
import torch
import cv2
import torchvision
import ttach as tta
import gc
import albumentations as A

from glob import glob
from tqdm import tqdm
from matplotlib import pyplot as plt
from itertools import chain
from os.path import join as pjoin
from copy import deepcopy
from pprint import pprint
from itertools import product
from matplotlib.patches import Rectangle
from albumentations.pytorch import ToTensorV2

from code_base.models import SMPWrapper, TransformerWrapper
from code_base.datasets import HubMapDataset
from code_base.utils.mask import rle_decode, overlay_mask
from code_base.inference import apply_avarage_weights_on_swa_path, HubMapInference
from code_base.constants import CLASSES, PIXEL_SCALE
from code_base.utils.metrics import dice_coeff
from code_base.models import SegmentationTTAWrapperKwargs
from code_base.models.denis_models.custom_unet import get_model
from code_base.models.denis_models.upsample import (
    UpsampleNearestAdditiveUpsample2x, 
    UpsampleBilinearAdditiveUpsample2x, 
    UpsampleResidualDeconvolutionUpsample,
    BilinearUpsample4x, 
    BilinearUpsample2x, 
    PixelShuffle4x
)
from code_base.models.denis_models.unet_blocks import NestedInception
%matplotlib inline


# Config

In [None]:
# List experiment directories in a stable (sorted) order; os.listdir order is arbitrary.
print("Possible exps:\n\n{}".format("\n".join(sorted(os.listdir("../logdirs/")))))

In [None]:
EXP_NAME = "???"
# Unique checkpoint file names across all folds of EXP_NAME, sorted so the listing is reproducible.
checkpoint_names = sorted({
    os.path.basename(chkp_path)
    for chkp_path in glob(f"../logdirs/{EXP_NAME}/*/checkpoints/*.pt*")
    # Hide files whose name contains "train" (presumably training-state checkpoints).
    if "train" not in os.path.basename(chkp_path)
})
print("Possible checkpoints:\n\n{}".format("\n".join(checkpoint_names)))

In [None]:
conf_path = glob(f"../logdirs/{EXP_NAME}/code/train_configs___*.py")
assert len(conf_path) == 1
conf_path = conf_path[0]
!cat {conf_path}

In [None]:
# Sliding-window settings shared by the models that use windowed inference.
# NOTE(review): the keys look like MONAI `sliding_window_inference` arguments; how they are
# consumed is defined inside HubMapInference — confirm there.
SLIDING_CONFIG = {
    "roi_size": (1024, 1024),
    "sw_batch_size": 4,
    "overlap": 0.75,
    "padding_mode": "reflect"
}

# Inference / evaluation configuration for the main experiment (EXP_NAME).
CONFIG = {
    # Main
    # Sliding-window config list: four groups of 5 entries (None = no sliding window).
    # NOTE(review): the group order appears to follow the blend order (main experiment, then
    # ADITIONAL_MODELS), yet the inference loop passes this whole list while each fold's
    # nn_models holds only one model per experiment — confirm how HubMapInference indexes it.
    "sliding_window_config":(
        # Unet++ models
        [None] * 5 +
        # MiT-encoder models (labelled "mitb3"; NOTE(review): ADITIONAL_MODELS uses mit_b5)
        [deepcopy(SLIDING_CONFIG) for _ in range(5)] + 
        # Unet models
        [None] * 5 + 
        # Unet models
        [None] * 5 
    ),
    # Forwarded as fill_binary_holes to predict_test_loader.
    "fill_binary_holes": True,
    # Forwarded as `tresh` (binarisation cut-off) to predict_test_loader.
    "test_tresh": 0.5,
    # Forwarded to HubMapInference(use_amp=...).
    "use_amp": True,
    # Five values forwarded as min_area — presumably one per organ class; TODO confirm the
    # order against CLASSES.
    "min_area": [
        0.001, 0.0005, 0.0001, 
        0.001, 1e-06
    ],
    # Optional; the inference loop reads it via CONFIG.get("is_relative_min_area", False).
    # "is_relative_min_area": True,
    # Data config
    "train_df_path":"data/train.csv",
    # .npy loaded with allow_pickle=True; element 1 of each fold entry is used as the
    # validation row indices into train_df.
    "split_path":"data/cv_split5_v2.npy",
    "n_folds":5,
    "train_data_root":"data/train_images/",
    # batch_size / num_workers are forwarded to each fold's DataLoader.
    "batch_size": 1,
    "num_workers": 1,
    # Forwarded to A.PadIfNeeded and to predict_test_loader.
    "pad_config": dict(
        min_height=None, 
        min_width=None, 
        pad_height_divisor=32, 
        pad_width_divisor=32, 
    ),
    # use_one_channel_mask / to_rgb / additional_scalers are forwarded to HubMapDataset.
    "use_one_channel_mask": True,
    "to_rgb": True,
    # Per-organ scale factors; their exact meaning is defined by HubMapDataset.
    "additional_scalers": {
        'prostate': 0.15 * 2,
        'spleen': 1 * 2,
        'lung': 0.5 * 2,
        'kidney': 1 * 2,
        'largeintestine': 1 * 4
    },
    # Model config
    "exp_name":EXP_NAME,
    # Called as model_class(**model_config, device=...) when building each fold's model.
    # The commented-out alternatives are kept for quick switching between architectures.
    "model_class": SMPWrapper,
    # "model_class": get_model,
    # "model_class": TransformerWrapper,
    "model_config": { 
        # "backbone_name": "mit_b5",
        # "backbone_name": "timm-efficientnet-b5",
        "backbone_name": "timm-efficientnet-b7",
        # "num_classes": len(CLASSES),
        "num_classes": 1,
        # "arch_name": "Unet",
        "arch_name": "UnetPP",
        # "arch_name": "FPN",
        # "arch_name": "UnetAsymmetric",
        # "arch_name": "UnetMultiHead",
        # "arch_name": "UnetGC",
        "pretrained":False,
        # "use_slice_idx": True,
        # "case_embedding_dim": 64
        # "aux_params": {"classes": len(CLASSES)},
        # "return_only_mask": True
    },
    # "model_config": dict(
    #     model_name="tf_efficientnetv2_l_in21k", 
    #     in_channels=3, 
    #     out_channels=1,
    #     channel_attention=False, 
    #     positional_attention=False,
    #     norm=torch.nn.BatchNorm2d,
    #     bias=False,
    #     se='TRIPLET',
    #     attn_unet=True, 
    #     layers=range(0,4), 
    #     DO=0.0,
    #     multistage_upsample=True,
    #     n_blocks=1, 
    #     block=NestedInception, 
    #     upsample=UpsampleBilinearAdditiveUpsample2x,
    #     pretrained=False,
    #     drop_rate=0.0
    # ),
    # "model_config": {
    #     "model_name": "nvidia/segformer-b5-finetuned-cityscapes-1024-1024",
    #     "n_classes": 1,  # len(CLASSES),
    #     "pretrained": False
    # },
    # When set, each model is wrapped in SegmentationTTAWrapperKwargs with these transforms.
    "tta_transforms": tta.aliases.flip_transform(),
    # Checkpoint file name under ../logdirs/<exp>/fold_<i>/checkpoints/.
    "chkp_name":"swa_models_valid_dice_score.pt",
    # Forwarded as take_best to apply_avarage_weights_on_swa_path (SWA checkpoints only).
    "swa_checkpoint": None,
    # Forwarded as use_distributed to apply_avarage_weights_on_swa_path.
    "distributed_chkp": True,
    "use_sigmoid": True,
}

# None = evaluate every organ; otherwise a list of organ names used to filter df and each
# validation fold.
ORGANS_INCLUDE = None #['lung']

In [None]:
def _blend_member(exp_name, model_config):
    # One extra experiment in the blend: SWA checkpoint name, distributed checkpoint, flip TTA.
    return dict(
        exp_name=exp_name,
        model_config=model_config,
        distributed_chkp=True,
        chkp_name="swa_models_valid_dice_score.pt",
        tta_transforms=tta.aliases.flip_transform(),
    )


def _effnet_b7_unet_config():
    # Fresh dict on every call so the EfficientNet-B7 U-Net entries never share state.
    return dict(
        backbone_name="timm-efficientnet-b7",
        num_classes=1,
        arch_name="Unet",
        pretrained=False,
        point_rand_config=dict(in_ch=161, num_classes=1, backbone_type="effnet"),
    )


ADITIONAL_MODELS = [
    _blend_member(
        "???",
        dict(backbone_name="mit_b5", num_classes=1, arch_name="Unet", pretrained=False),
    ),
    _blend_member("???", _effnet_b7_unet_config()),
    _blend_member("???", _effnet_b7_unet_config()),
]

# ADITIONAL_MODELS = None

# Data

In [None]:
df = pd.read_csv(CONFIG["train_df_path"])
# NOTE: allow_pickle=True can execute code embedded in the file — only load trusted split files.
split = np.load(CONFIG["split_path"], allow_pickle=True)
# Validation rows of each fold: element 1 of every split entry holds row indices into df.
val_df = [
    df.iloc[fold_split[1]].reset_index(drop=True)
    for fold_split in split[:CONFIG["n_folds"]]
]
if ORGANS_INCLUDE is not None:
    df = df.loc[df["organ"].isin(ORGANS_INCLUDE)].reset_index(drop=True)
    print(f"Considering organs: {set(df['organ'])}")
    val_df = [
        fold_df.loc[fold_df["organ"].isin(ORGANS_INCLUDE)].reset_index(drop=True)
        for fold_df in val_df
    ]

In [None]:
# DataLoader settings: fixed order, keep the last partial batch.
loader_config = dict(
    batch_size=CONFIG["batch_size"],
    drop_last=False,
    shuffle=False,
    num_workers=CONFIG["num_workers"],
)
# Evaluation-mode dataset settings shared by all folds.
ds_config = dict(
    root=CONFIG["train_data_root"],
    img_size=None,
    test_mode=True,
    precompute=False,
    dynamic_resize_mode="scale_or",
    use_one_channel_mask=CONFIG["use_one_channel_mask"],
    additional_scalers=CONFIG.get("additional_scalers", None),
    to_rgb=CONFIG["to_rgb"],
    transform=A.Compose([
        # Reflect-pad (cv2.BORDER_REFLECT_101, numeric value 4) using the divisor-based
        # settings from CONFIG["pad_config"].
        A.PadIfNeeded(
            border_mode=cv2.BORDER_REFLECT_101,
            value=None,
            mask_value=None,
            always_apply=True,
            **CONFIG["pad_config"],
        ),
        A.Normalize(),
        ToTensorV2(transpose_mask=True),
    ]),
)
# One dataset and one loader per validation fold.
ds_test = [HubMapDataset(df=fold_df, **ds_config) for fold_df in val_df]
loader_test = [
    torch.utils.data.DataLoader(fold_ds, **loader_config)
    for fold_ds in ds_test
]

# Model

In [None]:
def create_model_and_upload_chkp(
    model_class,
    model_config,
    model_device,
    model_chkp,
    use_distributed=False,
    swa_checkpoint=None,
    tta_transform=None,
    tta_merge_mode="mean"
):
    """Build a model, load checkpoint weights into it and switch it to eval mode.

    Args:
        model_class: Callable that builds the model; called as
            ``model_class(**model_config, device=model_device)``.
        model_config: Keyword arguments for ``model_class``.
        model_device: Passed through as the ``device`` keyword of ``model_class``.
        model_chkp: Path to the checkpoint. If the checkpoint *file name* contains
            ``"swa"``, the weights come from ``apply_avarage_weights_on_swa_path``;
            otherwise the file is loaded with ``torch.load`` onto the CPU.
        use_distributed: Forwarded to ``apply_avarage_weights_on_swa_path``
            (SWA checkpoints only).
        swa_checkpoint: Forwarded as ``take_best`` to
            ``apply_avarage_weights_on_swa_path`` (SWA checkpoints only).
        tta_transform: Optional TTA transform; when given, the model is wrapped in
            ``SegmentationTTAWrapperKwargs``.
        tta_merge_mode: ``merge_mode`` for the TTA wrapper.

    Returns:
        The model in eval mode, TTA-wrapped if ``tta_transform`` is not None.
    """
    chkp_file = os.path.basename(model_chkp)
    # Decide on the file name only: a directory such as "../logdirs/swa_exp/..." must not
    # turn a vanilla checkpoint into an SWA one.
    if "swa" in chkp_file:
        print("swa by {}".format(os.path.splitext(chkp_file)[0]))
        t_chkp = apply_avarage_weights_on_swa_path(
            model_chkp, use_distributed=use_distributed, take_best=swa_checkpoint
        )
    else:
        print("vanilla model")
        t_chkp = torch.load(model_chkp, map_location="cpu")

    t_model = model_class(**model_config, device=model_device)
    t_model.load_state_dict(t_chkp)
    t_model.eval()
    if tta_transform is not None:
        print("Wrapping model in TTA")
        t_model = SegmentationTTAWrapperKwargs(t_model, tta_transform, merge_mode=tta_merge_mode)
    return t_model

In [None]:
# Blend container indexed as model[experiment][fold]; the main experiment is experiment 0.
model = [
    [
        create_model_and_upload_chkp(
            model_class=CONFIG["model_class"],
            model_config=CONFIG["model_config"],
            model_device="cuda",
            model_chkp=f"../logdirs/{CONFIG['exp_name']}/fold_{fold_idx}/checkpoints/{CONFIG['chkp_name']}",
            swa_checkpoint=CONFIG["swa_checkpoint"],
            use_distributed=CONFIG["distributed_chkp"],
            tta_transform=CONFIG.get("tta_transforms", None),
        )
        for fold_idx in range(CONFIG["n_folds"])
    ]
]

In [None]:
# model[0] holds the main experiment; drop extras appended by a previous run of this cell
# so re-running it does not duplicate experiments in the blend.
del model[1:]
if ADITIONAL_MODELS is not None:
    for add_conf in ADITIONAL_MODELS:
        model.append([
            create_model_and_upload_chkp(
                # Optional per-experiment overrides fall back to the main CONFIG values.
                model_class=add_conf.get("model_class", CONFIG["model_class"]),
                model_config=add_conf["model_config"],
                model_device="cuda",
                model_chkp=f"../logdirs/{add_conf['exp_name']}/fold_{m_i}/checkpoints/{add_conf['chkp_name']}",
                swa_checkpoint=add_conf.get("swa_checkpoint", CONFIG["swa_checkpoint"]),
                use_distributed=add_conf.get("distributed_chkp", CONFIG["distributed_chkp"]),
                tta_transform=add_conf.get("tta_transforms", None),
            )
            for m_i in range(CONFIG["n_folds"])
        ])

In [None]:
# The inference loop indexes model[experiment][fold], so every experiment needs one model per fold.
assert all(len(exp_models) == CONFIG["n_folds"] for exp_models in model), (
    "every experiment in the blend must provide exactly one model per fold"
)
print(f"Total Exps in Blend = {len(model)}")
gc.collect()

# Inference

In [None]:
inference_class = HubMapInference(
    device="cuda",
    verbose=True,
    verbose_tqdm=True,
    # Read the flag from CONFIG instead of hard-coding it (defaults to True, as before).
    use_sigmoid=CONFIG.get("use_sigmoid", True),
    use_amp=CONFIG["use_amp"]
)

In [None]:
# Echo which experiment / checkpoint is about to be evaluated.
(CONFIG["exp_name"], CONFIG["chkp_name"])

In [None]:
test_pred = []
for fold_idx in range(CONFIG["n_folds"]):
    # Blend: take this fold's model from every experiment (model is indexed [experiment][fold]).
    fold_models = [exp_models[fold_idx] for exp_models in model]
    fold_pred = inference_class.predict_test_loader(
        nn_models=fold_models,
        test_loader=loader_test[fold_idx],
        tresh=CONFIG["test_tresh"],
        pad_config=CONFIG["pad_config"],
        min_area=CONFIG.get("min_area", None),
        is_relative_min_area=CONFIG.get("is_relative_min_area", False),
        use_rescaled=True,
        scale_back=True,
        fill_binary_holes=CONFIG["fill_binary_holes"],
        sliding_window_config=CONFIG["sliding_window_config"],
        print_input_shape=False,
        # Masks dumped here are read back (PNG / 255) by the threshold-search cells.
        save_mask_path=f"./temp_1/fold_{fold_idx}",
    )
    # Predicted ids must match this fold's validation ids exactly (as sets).
    assert set(fold_pred["id"]) == set(val_df[fold_idx]["id"])
    test_pred.append(fold_pred)

In [None]:
# Stack fold predictions with a fresh RangeIndex so row labels are not repeated across folds.
test_pred = pd.concat(test_pred, ignore_index=True)
# Validation folds are expected to be disjoint; a repeated id would duplicate rows in the merge below.
assert test_pred["id"].is_unique, "an image id was predicted in more than one fold"

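# Optimize thresholds

Search, per organ, for the binarization threshold that maximizes mean Dice on the saved out-of-fold prediction masks.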

In [None]:
# Join ground-truth RLE ("real") with predicted RLE ("pred") per image id.
result_df = (
    pd.concat(val_df, ignore_index=True)
    [["id", "rle", "organ", "img_width", "img_height"]]
    .rename(columns={"rle": "real"})
    .merge(
        test_pred.rename(columns={"rle": "pred"}),
        on="id",
        how="inner",
        # Fail loudly on duplicated ids instead of silently multiplying rows.
        validate="one_to_one",
    )
    .reset_index(drop=True)
)
# An inner merge would silently drop validation images that have no prediction.
assert len(result_df) == sum(len(fold_df) for fold_df in val_df), (
    "some validation images have no prediction"
)
result_df.head()

In [None]:
# Map image id -> path of its dumped prediction mask (files are named "<id>.png").
pred_masks_pathes = glob("temp_1/*/*.png")
pred_masks_pathes = pd.DataFrame({
    "path": pred_masks_pathes,
    "id": [int(os.path.splitext(os.path.basename(el))[0]) for el in pred_masks_pathes],
}).set_index("id")["path"]
# Leftover dumps from an earlier run could put the same id under several fold dirs;
# .loc[id] would then return a Series instead of a single path.
assert pred_masks_pathes.index.is_unique, (
    "duplicate mask ids under temp_1/ - remove stale dumps and rerun inference"
)
pred_masks_pathes

In [None]:
# Decode ground-truth masks and load the dumped prediction masks (0..255 rescaled to 0..1),
# keeping all three lists row-aligned with result_df.
real_masks = []
real_organs = []
pred_masks = []
for img_id, rle_real, organ, w, h in tqdm(
    zip(result_df.id, result_df.real, result_df.organ, result_df.img_width, result_df.img_height),
    total=len(result_df),  # zip() has no len(), so give tqdm the total explicitly
):
    real_masks.append(rle_decode(rle_real, shape=(w, h)))
    # Loop variable is img_id (not `id`) so the builtin id() is not shadowed for the rest of the notebook.
    pred_path = pred_masks_pathes.loc[img_id]
    pred_mask = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE)
    if pred_mask is None:  # cv2.imread signals a missing/unreadable file by returning None
        raise FileNotFoundError(f"Could not read predicted mask for id={img_id}: {pred_path}")
    pred_masks.append(pred_mask.astype(float) / 255.0)
    real_organs.append(organ)
real_organs = np.array(real_organs)

In [None]:
# Grid of probability cut-offs tried for each organ.
thresh_search_space = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

# organ -> {"thr": thresholds tried, "score": mean Dice over that organ's images at each threshold}
organ_scores_treshes = {}
# sorted() makes the organ order (and the report printed later) reproducible; set order
# depends on per-process string hashing.
for organ in tqdm(sorted(set(real_organs))):
    organ_idxs = np.flatnonzero(real_organs == organ)
    # Ground truth does not depend on the threshold: binarise it once per organ.
    organ_real = [real_masks[idx] > 0.5 for idx in organ_idxs]
    organ_pred = [pred_masks[idx] for idx in organ_idxs]
    organ_scores_treshes[organ] = {"thr": [], "score": []}
    for thresh in thresh_search_space:
        dice_coefs = [
            dice_coeff(real_bin, pred > thresh)
            for real_bin, pred in zip(organ_real, organ_pred)
        ]
        organ_scores_treshes[organ]["thr"].append(thresh)
        organ_scores_treshes[organ]["score"].append(np.mean(dice_coefs))

In [None]:
# Report the best threshold per organ and the image-weighted mean Dice reachable with them.
best_score = 0
best_thresholds = {}
organ_vc = result_df["organ"].value_counts().to_dict()
for organ, stats in organ_scores_treshes.items():
    best_stat_id = np.argmax(stats["score"])
    best_thresholds[str(organ)] = stats["thr"][best_stat_id]
    print(
        f"For {organ} score reached {stats['score'][best_stat_id]} on tresh {stats['thr'][best_stat_id]}"
    )
    # Weight by the organ's image count so the total is a mean over images, not over organs.
    best_score += organ_vc[organ] * stats['score'][best_stat_id]
print(f"With opt tresh Mean score = {best_score / len(result_df)}")
# Keep the chosen thresholds as data (e.g. to reuse as per-organ thresholds) instead of only printing them.
best_thresholds

In [None]:
!rm temp_1 -rf

# Compute Final Metric

In [None]:
# result_df = pd.concat(val_df)
# result_df = result_df[["id", "rle", "organ", "img_width", "img_height"]].rename(columns={"rle":"real"}).merge(test_pred.rename(columns={"rle":"pred"}), on="id").reset_index(drop=True)
# result_df

In [None]:
# Final per-image Dice between ground truth and the blended predictions (both RLE-decoded
# at the original image size).
dice_coefs = []
for rle_real, rle_pred, organ, w, h in tqdm(
    zip(result_df.real, result_df.pred, result_df.organ, result_df.img_width, result_df.img_height),
    total=len(result_df),  # zip() has no len(), so give tqdm the total explicitly
):
    mask_real = rle_decode(rle_real, shape=(w, h))
    mask_pred = rle_decode(rle_pred, shape=(w, h))
    # Alternative (disabled): score at a per-organ rescaled size, (w, h) divided by PIXEL_SCALE[organ].
    # mask_real = cv2.resize(
    #     mask_real, 
    #     (int(w / PIXEL_SCALE[organ]), int(h / PIXEL_SCALE[organ])),
    #     interpolation=cv2.INTER_NEAREST
    # )
    # mask_pred = rle_decode(
    #     rle_pred, 
    #     shape=(int(w / PIXEL_SCALE[organ]), int(h / PIXEL_SCALE[organ]))
    # )
    dice_coefs.append(dice_coeff(mask_real > 0.5, mask_pred > 0.5))
result_df["dice"] = dice_coefs

In [None]:
print(f"{EXP_NAME}\nMean Dice = {result_df['dice'].mean()}")
# Per-organ mean Dice, shown as the cell output.
result_df.groupby("organ")["dice"].agg("mean")