In [1]:
%load_ext nb_black
%load_ext autoreload
%autoreload 2

<IPython.core.display.Javascript object>

## Initialization

### Imports

In [2]:
import os
import sys
import torch
import zipfile
import numpy as np
import pandas as pd
import plotly.express as px

from matplotlib import pyplot as plt
from torch.utils.data import Dataset, DataLoader

sys.path.append("../code/")

<IPython.core.display.Javascript object>

In [5]:
from params import *

from data.dataset import InferenceDataset
from data.transforms import HE_preprocess
from model_zoo.models import define_model
from utils.metrics import dice_scores_img
from utils.plots import plot_global_pred, plot_thresh_scores
from training.predict import predict_entire_mask

<IPython.core.display.Javascript object>

### Load

In [17]:
# Read the three CSV tables from disk.
# DATA_PATH / OUT_PATH are not defined in this notebook; they are presumably
# provided by `from params import *` -- TODO confirm.
# Paths are built by plain concatenation, so both prefixes must end with a
# path separator.
df_info = pd.read_csv(DATA_PATH + "HuBMAP-20-dataset_information.csv")
df_mask = pd.read_csv(DATA_PATH + "train.csv")
# df_images is the table the later cells filter by fold and by `tile_name`.
df_images = pd.read_csv(OUT_PATH + "df_images.csv")

<IPython.core.display.Javascript object>

## Visualization

In [None]:
# Settings describing which pretrained checkpoints to use.
# Nothing is loaded in this cell; these values are only stored for later cells.

# Name of the cross-validation fold column.
cv_column = "4fold"

# Architecture identifiers ("resnet101" was the alternative encoder).
encoder = "se_resnext50_32x4d"
decoder = "Unet"

# The size of tiles.
tile_size = 256

# Reduce the original images by 4 times.
reduce_fac = 4

In [None]:
# For every CV fold: load that fold's TorchScript checkpoint, predict the full
# mask of each image belonging to the fold, plot it, and collect the scores
# returned by the threshold sweep.
all_scores = []

# Loop-invariant. With `threshold=None` the threshold-sweep plot below is
# produced; presumably predict_entire_mask then returns un-thresholded
# predictions -- TODO confirm against training.predict.
threshold = None

for fold_nb in df_images[cv_column].unique():
    model_name = f"{encoder}_{decoder}_{tile_size}_{reduce_fac}_fold{fold_nb}"
    model = torch.jit.load(f"../models/{model_name}.jit").to("cuda")
    # Inference only: make sure layers such as dropout / batch-norm are in
    # evaluation mode regardless of the mode the checkpoint was saved in.
    model.eval()

    # Here use validation images to visualize fair predictions.
    # The image name is the part of `tile_name` before the first underscore.
    fold_tiles = df_images[df_images[cv_column] == fold_nb]
    mask_names = fold_tiles.tile_name.str.split("_").str[0].unique()

    for maskname in mask_names:
        print(maskname)
        # NOTE(review): PredictFromImgDataset is not among the names imported
        # explicitly in this notebook (only InferenceDataset is). Unless
        # `from params import *` provides it, this raises NameError on a fresh
        # kernel -- confirm which dataset class is intended.
        predict_dataset = PredictFromImgDataset(
            f"../input/hubmap-kidney-segmentation/train/{maskname}.tiff",
            mask_name=maskname,
            overlap_factor=4,
            reduce_factor=reduce_fac,
            transforms=HE_preprocess(augment=False, visualize=False),
        )

        global_pred = predict_entire_mask(
            predict_dataset,
            model,
            batch_size=16,
            threshold=threshold,
        )

        # The prediction is binarised at 0 for display only.
        plot_global_pred(mask=predict_dataset.mask, pred=global_pred > 0)

        if threshold is None:
            # Only the scores are kept; the threshold grid itself is unused.
            _, scores = plot_thresh_scores(mask=predict_dataset.mask, pred=global_pred)
            all_scores.append(scores)

In [None]:
# Encode `global_pred`, i.e. the prediction left over from the LAST image
# processed by the prediction loop (this cell relies on that leaked state).
# NOTE(review): rle_encode_less_memory is not imported explicitly anywhere in
# this notebook; it must come from `from params import *` or this cell raises
# NameError -- confirm.
# NOTE(review): the loop ran with threshold=None, so `global_pred` may not be a
# binary mask here; presumably it should be thresholded before encoding --
# verify what the encoder expects.
encoded_mask = rle_encode_less_memory(global_pred)

<IPython.core.display.Javascript object>

## Visualize some results (raw predictions, not averaged)

In [None]:
# Show one random batch of validation tiles: the left column overlays the
# model prediction on the tile, the right column overlays the ground-truth mask.
#
# NOTE(review): `model` is whichever fold checkpoint the prediction loop loaded
# last; this cell depends on that leaked state.
model.eval()

BS = 16
# NOTE(review): this filters on a `fold` column, while the prediction loop uses
# `cv_column` ("4fold") -- confirm that `fold` exists and matches the loaded model.
val_df = df_images[df_images.fold == 0].reset_index()
# NOTE(review): TileDataset, IMG_FOLDER and MSK_FOLDER are not imported
# explicitly in this notebook; presumably they come from `from params import *`.
val_dataset = TileDataset(
    val_df,
    IMG_FOLDER,
    MSK_FOLDER,
    transforms=HE_preprocess(augment=False, visualize=False),
)

# shuffle=True without a fixed seed: a different batch is drawn on every run.
preds_dl = DataLoader(val_dataset, batch_size=BS, shuffle=True, num_workers=1)

imgs, masks = next(iter(preds_dl))

fig = plt.figure(figsize=(8, 4 * BS))
for i, (img, mask) in enumerate(zip(imgs, masks)):
    # Pure inference: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        raw_pred = model(img.unsqueeze(0).to("cuda"))
    # Sigmoid turns the raw output into probabilities; 0.5 is the cut-off.
    pred = torch.sigmoid(raw_pred) > 0.5
    pred = pred.cpu().squeeze().numpy()

    # Channel-first tensor -> channel-last array for imshow.
    img = img.permute(1, 2, 0).numpy()

    # Left panel: tile + predicted mask.
    plt.subplot(BS, 2, 2 * i + 1)
    plt.imshow(img, vmin=0, vmax=255)
    plt.imshow(pred, alpha=0.5)
    plt.axis("off")

    # Right panel: tile + ground-truth mask.
    plt.subplot(BS, 2, 2 * i + 2)
    plt.imshow(img, vmin=0, vmax=255)
    plt.imshow(mask.squeeze().numpy(), alpha=0.5)
    plt.axis("off")

# Loop-invariant layout call, applied once after all panels are drawn.
plt.subplots_adjust(wspace=None, hspace=None)