In [1]:
# %load_ext nb_black
%load_ext autoreload
%autoreload 2

In [2]:
DEBUG = True

## Initialization

### Imports

In [3]:
import os
import gc
import sys
import cv2
import json
import glob
import torch
import random
import zipfile
import numpy as np
import pandas as pd
import plotly.express as px

from tqdm.notebook import tqdm
from collections import Counter
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader, Dataset

sys.path.append("../code/")

In [4]:
from utils.torch import seed_everything
from utils.torch import load_model_weights
from model_zoo.models import define_model
from data.transforms import HE_preprocess_test
from utils.rle import rle_encode_less_memory, enc2mask
from utils.plots import plot_thresh_scores

### Seeding

In [5]:
SEED = 2021
seed_everything(SEED)

### Params

In [6]:
from params import *

In [7]:
IMG_PATH = DATA_PATH + 'test/'

## Data

In [8]:
class InferenceDataset(Dataset):
    def __init__(
        self,
        original_img_path,
        df_info,
        rle=None,
        overlap_factor=1,
        tile_size=256,
        reduce_factor=4,
        transforms=None,
    ):
        self.original_img = load_image(original_img_path, df_info, reduce_factor=reduce_factor)
        self.orig_size = self.original_img.shape

        self.raw_tile_size = tile_size
        self.reduce_factor = reduce_factor
        self.tile_size = tile_size

        self.overlap_factor = overlap_factor

        self.positions = self.get_positions()
        self.transforms = transforms

        if rle is not None:
            self.mask = enc2mask(rle, (self.orig_size[1], self.orig_size[0])) > 0
        else:
            self.mask = None

    def __len__(self):
        return len(self.positions)

    def get_positions(self):
        top_x = np.arange(
            0,
            self.orig_size[0],  # +self.tile_size,
            int(self.tile_size / self.overlap_factor),
        )
        top_y = np.arange(
            0,
            self.orig_size[1],  # +self.tile_size,
            int(self.tile_size / self.overlap_factor),
        )
        starting_positions = []
        for x in top_x:
            right_space = self.orig_size[0] - (x + self.tile_size)
            if right_space > 0:
                boundaries_x = (x, x + self.tile_size)
            else:
                boundaries_x = (x + right_space, x + right_space + self.tile_size)

            for y in top_y:
                down_space = self.orig_size[1] - (y + self.tile_size)
                if down_space > 0:
                    boundaries_y = (y, y + self.tile_size)
                else:
                    boundaries_y = (y + down_space, y + down_space + self.tile_size)
                starting_positions.append((boundaries_x, boundaries_y))

        return starting_positions

    def __getitem__(self, idx):
        pos_x, pos_y = self.positions[idx]
        img = self.original_img[pos_x[0]: pos_x[1], pos_y[0]: pos_y[1], :]
        
        img = self.transforms(image=img)["image"]
        
        return img

In [9]:
import tifffile as tiff


def load_image(img_path, df_info, reduce_factor=1):
    """
    Load image and make sure sizes matches df_info
    """
    image_fname = img_path.rsplit("/", -1)[-1]
    
    W = int(df_info[df_info.image_file == image_fname]["width_pixels"])
    H = int(df_info[df_info.image_file == image_fname]["height_pixels"])

    img = tiff.imread(img_path).squeeze()

    channel_pos = np.argwhere(np.array(img.shape) == 3)[0][0]
    W_pos = np.argwhere(np.array(img.shape) == W)[0][0]
    H_pos = np.argwhere(np.array(img.shape) == H)[0][0]

    img = np.moveaxis(img, (H_pos, W_pos, channel_pos), (0, 1, 2))
    
    if reduce_factor > 1:
        img = cv2.resize(
            img,
            (img.shape[1] // reduce_factor, img.shape[0] // reduce_factor),
            interpolation=cv2.INTER_AREA,
        )
        
    return img


## Model

### Define & load

In [10]:
class Config:
    def __init__(self, **entries):
        self.__dict__.update(entries)

In [11]:
def load_models(cp_folder):
    config = json.load(open(cp_folder + 'config.json', 'r'))
    config = Config(**config)
    
    weights = sorted(glob.glob(cp_folder + "*.pt"))
    models = []
    
    for weight in weights:
        model = define_model(
            config.decoder,
            config.encoder,
            num_classes=config.num_classes,
            encoder_weights=None,
        )
        
        model = load_model_weights(model, weight)
        models.append(model)
        
    return models

## Inference

### Tile weighting

In [12]:
def get_tile_weighting(size, sigma=1):
    half = size // 2
    w = np.ones((size, size), np.float32)

    x = np.concatenate([np.mgrid[-half:0], np.mgrid[1: half + 1]])[:, None]
    x = np.tile(x, (1, size))
    x = half + 1 - np.abs(x)
    y = x.T

    w = np.minimum(x, y)
    w = (w / w.max()) ** sigma
    w = np.minimum(w, 1)

    return w.astype(np.float16)

### Predict

In [13]:
def predict_entire_mask_no_thresholding(dataset, models, batch_size=8, upscale=True):
    models = [model.to(DEVICE).eval() for model in models]
    
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, pin_memory=True)
    w = get_tile_weighting(dataset.tile_size, sigma=1)

    preds = []
    with torch.no_grad():
        for batch in tqdm(loader):
            model_preds = []
            for model in models:
                pred = model(batch.to(DEVICE))
                _, _, h, w = pred.shape
                
                if upscale:
                    pred = torch.nn.functional.interpolate(
                        pred, (h * dataset.reduce_factor, w * dataset.reduce_factor)
                    )

                pred = pred.sigmoid().detach().cpu().view(-1, h, w).numpy()
                model_preds.append(pred)
                
            model_preds = np.mean(model_preds, 0).astype(np.float16)
            preds.append(model_preds)

    preds = np.concatenate(preds)

    global_pred = np.zeros(
        (dataset.orig_size[0], dataset.orig_size[1]), dtype=np.float16
    )
    global_counter = np.zeros(
        (dataset.orig_size[0], dataset.orig_size[1]), dtype=np.float16
    )

    for tile_idx, (pos_x, pos_y) in enumerate(dataset.positions):
        global_pred[pos_x[0]: pos_x[1], pos_y[0]: pos_y[1]] += preds[tile_idx, :, :] * w
        global_counter[pos_x[0]: pos_x[1], pos_y[0]: pos_y[1]] += w

    global_pred = np.divide(global_pred, global_counter).astype(np.float16)  # divide by overlapping tiles

    return global_pred

### Predictions to RLE

In [14]:
def threshold_resize_encode(preds, shape, threshold=0.5):
    preds = (preds > threshold).astype(np.uint8)
    
    preds = cv2.resize(
        preds,
        (shape[0], shape[1]),
        interpolation=cv2.INTER_AREA,
    )
    
    preds = rle_encode_less_memory(preds)
    
    return preds

## Main

In [15]:
REDUCE_FACTOR = 4
THRESHOLD = 0.4

CP_FOLDER = "../dataset/"

In [16]:
df = pd.read_csv(DATA_PATH + 'sample_submission.csv')
df_info = pd.read_csv(DATA_PATH + "HuBMAP-20-dataset_information.csv")
rles = df_mask = pd.read_csv(DATA_PATH + "train_4.csv")

config = json.load(open(CP_FOLDER + 'config.json', 'r'))
config = Config(**config)

In [17]:
models = load_models(CP_FOLDER)


 -> Loading weights from ../dataset/Unet_efficientnet-b5_0.pt


 -> Loading weights from ../dataset/Unet_efficientnet-b5_1.pt


 -> Loading weights from ../dataset/Unet_efficientnet-b5_2.pt


 -> Loading weights from ../dataset/Unet_efficientnet-b5_3.pt



In [18]:
for img in df['id'].unique():
    
    if DEBUG:
        # Check performances on a validation image
        img = "2f6ecfcdf"
        IMG_PATH = DATA_PATH + "train"
        models = models[:1]
    
    print(f'\n\t Image {img}')
    
    print(f'\n - Building dataset')
    
    rle = rles[rles['id'] == img]["encoding"] if DEBUG else None
    
    predict_dataset = InferenceDataset(
        f"{IMG_PATH}/{img}.tiff",
        df_info,
        rle=rle,
        overlap_factor=config.overlap_factor,
        reduce_factor=REDUCE_FACTOR,
        transforms=HE_preprocess_test(augment=False, visualize=False),
    )
    
    print(f'\n - Predicting masks')

    global_pred = predict_entire_mask_no_thresholding(
        predict_dataset, models, batch_size=config.val_bs // 4, upscale=False
    )
    
    if DEBUG:
        threshold, score = plot_thresh_scores(
            mask=predict_dataset.mask, pred=global_pred, plot=False
        ) 
        print(f" -> Scored {score :.4f} with threshold {threshold:.2f}")
    
    print('\n - Resizing & encoding')
    
    shape = df_info[df_info.image_file == img + ".tiff"][['width_pixels', 'height_pixels']].values.astype(int)[0]
    rle = threshold_resize_encode(global_pred, shape, threshold=THRESHOLD)
    df.loc[df.id == img, 'predicted'] = rle
    
    del global_pred, predict_dataset
    gc.collect()
    
    if DEBUG:
        break
    
# df.to_csv('submission.csv', index=False)


	 Image 2f6ecfcdf

 - Building dataset

 - Predicting masks


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=396.0), HTML(value='')))


 -> Scored 0.9493 with threshold 0.40

 - Resizing & encoding


In [19]:
df

Unnamed: 0,id,predicted
0,2ec3f1bb9,60762295 20 60786285 20 60810275 20 60834265 2...
1,3589adb90,68658950 68 68688383 68 68717816 68 68747249 6...
2,d488c759a,548575513 60 548622173 60 548668833 60 5487154...
3,aa05346ff,52856681 48 52887401 48 52918121 48 52948841 4...
4,57512b7f1,328952557 28 328985797 28 329019037 28 3290522...
