In [None]:
import os
# Expose only the GPU with physical index 7 to this process. The variable is set
# here, before `torch` is imported further down, so that the restriction is already
# in place when CUDA gets initialised. Adjust the index for your machine.
os.environ["CUDA_VISIBLE_DEVICES"]="7"

import pandas as pd
import numpy as np
import seaborn as sns
import os
import torch
import cv2
import gc
import torchvision
import ttach as tta
import albumentations as A

from glob import glob
from tqdm import tqdm
from matplotlib import pyplot as plt
from itertools import chain
from os.path import join as pjoin
from copy import deepcopy
from albumentations.pytorch import ToTensorV2
from PIL import Image

from code_base.models import SMPWrapper, TransformerWrapper
from code_base.datasets import HubMapDataset
from code_base.utils.mask import rle_decode, overlay_mask
from code_base.utils.other import imread_rgb
from code_base.inference import apply_avarage_weights_on_swa_path, HubMapInference
from code_base.constants import CLASSES, PIXEL_SCALE
from code_base.utils.metrics import dice_coeff
from code_base.models import SegmentationTTAWrapperKwargs

In [None]:
!nvidia-smi

# Config

In [None]:
# Show every experiment directory available under ../logdirs/ so that EXP_NAME
# can be picked in the next cell.
available_exps = os.listdir("../logdirs/")
print("Possible exps:\n\n{}".format("\n".join(available_exps)))

In [None]:
# Name of the experiment directory under ../logdirs/ (placeholder - fill in).
EXP_NAME = "???"

# Unique checkpoint file names found in any fold of the experiment; files whose
# name contains "train" are left out.
checkpoint_names = {
    os.path.basename(chkp_path)
    for chkp_path in glob(f"../logdirs/{EXP_NAME}/*/checkpoints/*.pt*")
    if "train" not in os.path.basename(chkp_path)
}
print("Possible checkpoints:\n\n{}".format("\n".join(checkpoint_names)))

In [None]:
# Central configuration of this inference run. The cells below read these keys to
# build the data loader, the model ensemble and the inference call.
CONFIG = {
    # Main
    # Forwarded unchanged to `predict_test_loader`. The commented dict below is an
    # example of a non-None value.
    "sliding_window_config": None,
#     {
#         "roi_size": (512, 512),
#         "sw_batch_size": 16,
#         "overlap": 0.75,
#         "padding_mode": "reflect"
#     },
    # Forwarded to `predict_test_loader`.
    "fill_binary_holes": True,
    # Threshold configuration. A float is used as-is for both data sources; a dict
    # must be keyed by data source ("Hubmap" / "HPA") and then by class name, as in
    # the commented example below (see the threshold cell that follows).
    "test_tresh": 0.4,
#     {
#         'Hubmap': {
#             'kidney'        : 0.35,
#             'prostate'      : 0.35,
#             'largeintestine': 0.35,
#             'spleen'        : 0.35,
#             'lung'          : 0.075,
#         },
#         'HPA': {
#             'kidney'        : 0.45,
#             'prostate'      : 0.45,
#             'largeintestine': 0.45,
#             'spleen'        : 0.45,
#             'lung'          : 0.075,
#         },
#     },
    # Passed to the `HubMapInference` constructor.
    "use_amp": True,
    # One value per class, forwarded to `predict_test_loader`.
    # Original note: "class names refer to CLASSES.index" - presumably entry i
    # belongs to CLASSES[i]; TODO confirm against the inference code.
    "min_area": [
        0.001, 0.0005, 0.0001, 
        0.001, 1e-06
    ],
    # Forwarded to `predict_test_loader` together with "min_area".
    "is_relative_min_area": True,
    # Data config
    # Number of fold_{i} checkpoints loaded per experiment when building the blend.
    "n_folds":5,
    # NOTE(review): not read by any cell visible in this notebook.
    "train_data_root":"data/hpa/train_images/",
    # DataLoader settings.
    "batch_size": 1,
    "num_workers": 0,
    # Expanded into `A.PadIfNeeded(...)` in the dataset transform and also passed to
    # `predict_test_loader` as `pad_config`.
    "pad_config": dict(
        min_height=None, 
        min_width=None, 
        pad_height_divisor=32, 
        pad_width_divisor=32, 
    ),
    # The next three keys are forwarded to `HubMapDataset`.
    "use_one_channel_mask": True,
    "to_rgb": True,
    "additional_scalers": {
        'prostate': 0.15 * 2,
        'spleen': 1 * 2,
        'lung': 0.5 * 2,
        'kidney': 1 * 2,
        'largeintestine': 1 * 2
    },
    # Model config
    # NOTE(review): the model-loading cells build checkpoint paths from EXP_NAME
    # directly rather than from this key.
    "exp_name":EXP_NAME,
    # The model is created as `model_class(**model_config, device=...)`.
    "model_class": SMPWrapper,
    "model_config": { 
#         "backbone_name": "timm-efficientnet-b5",
        "backbone_name": "timm-efficientnet-b7",
        # "num_classes": len(CLASSES),
        "num_classes": 1,
#         "arch_name": "Unet",
        "arch_name": "UnetPP",
#         "arch_name": "UnetGC",
        # "arch_name": "FPN",
        "pretrained":False,
#         "use_slice_idx": True,
#         "aux_params": {"classes": len(CLASSES)},
#         "return_only_mask": True
        
    },
    # When not None, every loaded model is wrapped in `SegmentationTTAWrapperKwargs`
    # with these transforms.
    "tta_transforms": tta.aliases.d4_transform(),
    # NOTE(review): passed to `create_model_and_upload_chkp`, which accepts the
    # argument but does not use it.
    "batched_tta": True,
    # Checkpoint file name inside each fold's `checkpoints/` directory. A path that
    # contains "swa" is loaded through `apply_avarage_weights_on_swa_path`.
    "chkp_name":"swa_models_valid_dice_score.pt",
    # Passed as `take_best` to `apply_avarage_weights_on_swa_path`.
    "swa_checkpoint": None,
    # Passed as `use_distributed` to `apply_avarage_weights_on_swa_path`.
    "distributed_chkp": True,
    # Passed to the `HubMapInference` constructor.
    "use_sigmoid": True,
}


In [None]:
# Resolve the thresholds handed to the inference calls for each data source.
# (`humbap_tresh` keeps its historical spelling because other cells refer to it.)
test_tresh = CONFIG["test_tresh"]
if isinstance(test_tresh, dict):
    # Multi threshold: one value per class, looked up by data source and class name.
    humbap_tresh = [test_tresh["Hubmap"][cls_name] for cls_name in CLASSES]
    hpa_tresh = [test_tresh["HPA"][cls_name] for cls_name in CLASSES]
elif isinstance(test_tresh, (int, float)) and not isinstance(test_tresh, bool):
    # Solo threshold: the same scalar for both data sources. Plain ints (e.g. 0 or 1)
    # are accepted too; previously they matched no branch and left both names
    # undefined, which only surfaced later as a NameError.
    humbap_tresh = test_tresh
    hpa_tresh = test_tresh
else:
    # Fail here with an actionable message instead of at the inference call.
    raise TypeError(
        'CONFIG["test_tresh"] must be a number or a dict keyed by "Hubmap"/"HPA", '
        f"got {type(test_tresh).__name__}"
    )

In [None]:
# Extra experiments to add to the blend on top of the CONFIG experiment.
# Each entry is read by the "additional models" cell in the Model section:
#   - "exp_name", "chkp_name": used to build
#     ../logdirs/<exp_name>/fold_<i>/checkpoints/<chkp_name>
#   - "model_config": keyword arguments for the model constructor
#   - "distributed_chkp": passed on as `use_distributed`
#   - "tta_transforms", "batched_tta": optional, default to None / False
# The model class, the number of folds and `swa_checkpoint` are taken from CONFIG,
# not from the entry.
ADITIONAL_MODELS = [
    {
        "exp_name":"???",
        "model_config": {
            "backbone_name": "timm-efficientnet-b7",
            "num_classes": 1, 
            "arch_name": "Unet",
            "pretrained":False,
        },
        "distributed_chkp": True,
        "chkp_name":"swa_models_valid_dice_score.pt",
        "tta_transforms": tta.aliases.d4_transform(),
        "batched_tta": True,
    },
]

# Uncomment to run with the CONFIG experiment only (the additional-models cell
# skips its work when this is None):
# ADITIONAL_MODELS = None

# Data

In [None]:
# Build the metadata frame for the additional HPA prostate images.
hpa_image_dir = "data/hpa/hpa_add_prostate/prostate_hpa/prostate_images/"

# Image ids are the file names without extension. `glob` returns matches in
# arbitrary order, so sort them to make the row order (and everything indexed by
# position later on) reproducible across runs and machines.
hpa_ids = sorted(
    os.path.splitext(os.path.basename(img_path))[0]
    for img_path in glob(pjoin(hpa_image_dir, "*.jpg"))
)

# Read each image's (width, height). PIL opens files lazily, so `.size` is available
# without decoding pixels; the context manager closes every file handle right away
# instead of leaving that to garbage collection.
img_size = []
for img_id in tqdm(hpa_ids):
    with Image.open(pjoin(hpa_image_dir, img_id) + ".jpg") as pil_img:
        img_size.append(pil_img.size)

# All rows share the same constant metadata; only id and image size vary.
hpa_df = pd.DataFrame({
    "id": hpa_ids,
    "organ": ["prostate"] * len(hpa_ids),
    "tissue_thickness": [4] * len(hpa_ids),
    "pixel_size": [0.4] * len(hpa_ids),
    "data_source": ["HPA"] * len(hpa_ids),
})
# PIL reports size as (width, height).
hpa_df["img_height"] = [size[1] for size in img_size]
hpa_df["img_width"] = [size[0] for size in img_size]
hpa_df.head()

In [None]:
# DataLoader keyword arguments: fixed order, no sample dropped.
loader_config = dict(
    batch_size=CONFIG["batch_size"],
    drop_last=False,
    shuffle=False,
    num_workers=CONFIG["num_workers"],
)

# Test-time transform pipeline: pad according to CONFIG["pad_config"] with a
# reflected border, normalise, convert to tensors.
hpa_transforms = [
    A.PadIfNeeded(
        border_mode=cv2.BORDER_REFLECT_101,  # same value as the literal 4
        value=None,
        mask_value=None,
        always_apply=True,
        **CONFIG["pad_config"]
    ),
    A.Normalize(),
    ToTensorV2(transpose_mask=True),
]

# Keyword arguments for the HPA `HubMapDataset` (test mode, .jpg images).
hpa_dataset_config = {
    "root": "data/hpa/hpa_add_prostate/prostate_hpa/prostate_images",
    "img_size": None,
    "test_mode": True,
    "precompute": False,
    "dynamic_resize_mode": "scale_or",
    "use_one_channel_mask": CONFIG["use_one_channel_mask"],
    "additional_scalers": CONFIG.get("additional_scalers"),
    "to_rgb": CONFIG["to_rgb"],
    "ext": ".jpg",
    "transform": hpa_transforms,
}
# hubmap_dataset_config = {
#     "root": "../input/hubmap-organ-segmentation/test_images/",
#     "img_size": None,
#     "test_mode": True,
#     "precompute": False,
#     "dynamic_resize_mode": None,
#     "use_one_channel_mask": CONFIG["use_one_channel_mask"],
#     "additional_scalers": CONFIG.get("additional_scalers", None),
#     "to_rgb": CONFIG["to_rgb"],
#     "transform": [
#         A.PadIfNeeded(
#             border_mode=4, 
#             value=None, 
#             mask_value=None, 
#             always_apply=True,
#             **CONFIG["pad_config"]
#         ),
#         A.Normalize(), ToTensorV2(transpose_mask=True)
#     ]
# }
# Wrap the HPA metadata frame in a dataset and expose it through a DataLoader.
hpa_dataset = HubMapDataset(df=hpa_df, **hpa_dataset_config)
hpa_loader = torch.utils.data.DataLoader(hpa_dataset, **loader_config)
# hubmap_loader = torch.utils.data.DataLoader(
#     HubMapDataset(df=hubmap_df, **hubmap_dataset_config),
#     **loader_config
# )

# Model

In [None]:
def set_dropout_zero(model):
    """Disable dropout in ``model`` in place by setting every dropout probability to 0.

    Covers ``torch.nn.Dropout``, ``Dropout2d`` and ``Dropout3d`` layers at any
    nesting depth, including ``model`` itself when it is a dropout layer.

    Args:
        model: a ``torch.nn.Module`` whose dropout layers should be neutralised.

    Returns:
        None. The modules are modified in place.
    """
    dropout_types = (torch.nn.Dropout, torch.nn.Dropout2d, torch.nn.Dropout3d)
    # `modules()` yields the module itself followed by all of its descendants, so no
    # manual recursion is needed and a bare dropout layer passed as `model` is
    # handled as well.
    for module in model.modules():
        if isinstance(module, dropout_types):
            module.p = 0.0

def create_model_and_upload_chkp(
    model_class,
    model_config,
    model_device,
    model_chkp,
    use_distributed=False,
    swa_checkpoint=None,
    tta_transform=None,
    tta_merge_mode="mean",
    batched_tta=False,
    instance_norm_mode=False
):
    """Instantiate a model, load its checkpoint and optionally wrap it for TTA.

    Args:
        model_class: callable invoked as ``model_class(**model_config, device=model_device)``.
        model_config: keyword arguments for ``model_class``.
        model_device: value passed to ``model_class`` as ``device``.
        model_chkp: checkpoint path. If the path string contains "swa" (anywhere, not
            only in the file name) the weights come from
            ``apply_avarage_weights_on_swa_path``; otherwise the file is read with
            ``torch.load`` onto the CPU.
        use_distributed: forwarded to ``apply_avarage_weights_on_swa_path``.
        swa_checkpoint: forwarded to ``apply_avarage_weights_on_swa_path`` as ``take_best``.
        tta_transform: when not None, the model is wrapped in
            ``SegmentationTTAWrapperKwargs`` with these transforms.
        tta_merge_mode: merge mode handed to the TTA wrapper.
        batched_tta: accepted for interface compatibility; not used by this function.
        instance_norm_mode: if True the model is left in train mode with all dropout
            probabilities zeroed; otherwise it is switched to eval mode.

    Returns:
        The loaded model, or the TTA wrapper around it.
    """
    # 1. Obtain the state dict.
    if "swa" not in model_chkp:
        print("vanilla model")
        state_dict = torch.load(model_chkp, map_location="cpu")
    else:
        chkp_stem = os.path.splitext(os.path.basename(model_chkp))[0]
        print("swa by {}".format(chkp_stem))
        state_dict = apply_avarage_weights_on_swa_path(
            model_chkp, use_distributed=use_distributed, take_best=swa_checkpoint
        )

    # 2. Build the network and load the weights.
    net = model_class(**model_config, device=model_device)
    net.load_state_dict(state_dict)

    # 3. Select the forward mode.
    if not instance_norm_mode:
        net.eval()
    else:
        net.train()
        set_dropout_zero(net)

    # 4. Optional test-time-augmentation wrapper.
    if tta_transform is None:
        return net
    print("Wrapping model in TTA")
    return SegmentationTTAWrapperKwargs(net, tta_transform, merge_mode=tta_merge_mode)

In [None]:
# Base blend: one model per fold of the CONFIG experiment.
model = []
for fold_idx in range(CONFIG["n_folds"]):
    fold_chkp = f"../logdirs/{EXP_NAME}/fold_{fold_idx}/checkpoints/{CONFIG['chkp_name']}"
    fold_model = create_model_and_upload_chkp(
        model_class=CONFIG["model_class"],
        model_config=CONFIG["model_config"],
        model_device="cuda",
        model_chkp=fold_chkp,
        swa_checkpoint=CONFIG["swa_checkpoint"],
        use_distributed=CONFIG["distributed_chkp"],
        tta_transform=CONFIG.get("tta_transforms"),
        batched_tta=CONFIG.get("batched_tta", False),
    )
    model.append(fold_model)

In [None]:
# Extend the blend with the experiments listed in ADITIONAL_MODELS.
#
# The base-model cell creates exactly CONFIG["n_folds"] models. Rebuilding the list
# from that prefix makes this cell idempotent: the previous `model += ...` appended
# the same extra models again on every re-run, silently changing the blend.
n_base_models = CONFIG["n_folds"]
extra_models = []
if ADITIONAL_MODELS is not None:
    for add_conf in ADITIONAL_MODELS:
        # An entry may override "n_folds", "model_class" and "swa_checkpoint";
        # when a key is absent the CONFIG value is used, as before.
        for m_i in range(add_conf.get("n_folds", CONFIG["n_folds"])):
            extra_models.append(create_model_and_upload_chkp(
                model_class=add_conf.get("model_class", CONFIG["model_class"]),
                model_config=add_conf['model_config'],
                model_device="cuda",
                model_chkp=f"../logdirs/{add_conf['exp_name']}/fold_{m_i}/checkpoints/{add_conf['chkp_name']}",
                swa_checkpoint=add_conf.get("swa_checkpoint", CONFIG['swa_checkpoint']),
                use_distributed=add_conf['distributed_chkp'],
                tta_transform=add_conf.get("tta_transforms", None),
                batched_tta=add_conf.get("batched_tta", False)
            ))
model = model[:n_base_models] + extra_models


In [None]:
# Report how many models ended up in the blend, then run a garbage-collection pass.
# As the last expression of the cell, the value returned by gc.collect() is shown
# as the cell output.
print(f"Total models in Blend = {len(model)}")
gc.collect()

# Inference

In [None]:
# Inference helper running on the GPU with verbose progress output; sigmoid and
# AMP usage follow CONFIG.
inference_kwargs = dict(
    device="cuda",
    verbose=True,
    verbose_tqdm=True,
    use_sigmoid=CONFIG["use_sigmoid"],
    use_amp=CONFIG["use_amp"],
)
inference_class = HubMapInference(**inference_kwargs)

In [None]:
# hubmap_test_pred = inference_class.predict_test_loader(
#     nn_models=model,
#     test_loader=hubmap_loader,
#     tresh=humbap_tresh,
#     pad_config=CONFIG["pad_config"],
#     min_area=CONFIG.get("min_area", None),
#     is_relative_min_area=CONFIG.get("is_relative_min_area", False),
#     # For HubMap we do not need to resize something. Just crop padding (to X32)
#     # In case of additional_scalers we have to rescale
#     use_rescaled=CONFIG.get("additional_scalers", None) is not None,
#     scale_back=CONFIG.get("additional_scalers", None) is not None,
#     fill_binary_holes=CONFIG["fill_binary_holes"],
#     sliding_window_config=CONFIG["sliding_window_config"],
#     print_input_shape=True
# )

In [None]:
# Directory handed to `predict_test_loader` as `save_mask_path`.
hpa_mask_dir = "data/hpa/hpa_add_prostate/ensem_083"

hpa_predict_kwargs = dict(
    nn_models=model,
    test_loader=hpa_loader,
    tresh=hpa_tresh,
    pad_config=CONFIG["pad_config"],
    min_area=CONFIG.get("min_area"),
    is_relative_min_area=CONFIG.get("is_relative_min_area", False),
    # For HPA the padding (to a multiple of 32) is cropped in the rescaled domain
    # and the result is then scaled back to the original size.
    use_rescaled=True,
    scale_back=True,
    fill_binary_holes=CONFIG["fill_binary_holes"],
    sliding_window_config=CONFIG["sliding_window_config"],
    print_input_shape=False,
    mean_type=CONFIG.get("mean_type", "mean"),
    save_mask_path=hpa_mask_dir,
)
hpa_test_pred = inference_class.predict_test_loader(**hpa_predict_kwargs)

In [None]:
# Attach the predicted RLE strings to the image metadata, matching rows on "id".
hpa_rle = hpa_test_pred[["id", "rle"]]
hpa_test_pred = pd.merge(hpa_df, hpa_rle, on="id")
hpa_test_pred

In [None]:
# Persist the merged frame (image metadata plus the "rle" column) as CSV, without
# the DataFrame index.
hpa_test_pred.to_csv("data/hpa/hpa_add_prostate/ensem_083.csv", index=False)

# Visualise

In [None]:
# Position of the prediction row to inspect; must be a valid row position in
# `hpa_test_pred`.
idx = 6000

sample_id = hpa_test_pred.id.iloc[idx]
sample_rle = hpa_test_pred.rle.iloc[idx]
sample_row = hpa_test_pred.iloc[idx]

img = imread_rgb(
    f"data/hpa/hpa_add_prostate/prostate_hpa/prostate_images/{sample_id}.jpg"
)
# `rle_decode` is given the shape in (img_width, img_height) order.
mask = rle_decode(sample_rle, shape=(sample_row.img_width, sample_row.img_height))

# Decoded mask on its own.
fig, ax = plt.subplots(figsize=(10, 10))
ax.set_title("Mask")
ax.imshow(mask)
plt.show()

# Mask drawn on top of the source image.
fig, ax = plt.subplots(figsize=(10, 10))
ax.set_title("Mask Overlay")
ax.imshow(overlay_mask(img, mask, color_id=0))
plt.show()