In [1]:
import numpy as np
import pathlib
import os
import random
import time
import numba
import gc 
import sys
from tqdm.notebook import tqdm 
import cv2
import warnings
warnings.filterwarnings('ignore')

#data structure
import pandas as pd

#tiff file
import rasterio 
from rasterio.windows import Window 

#models
import torch
import torch.nn as nn

#data augmentation
import albumentations as A 
import torchvision
from torchvision import transforms as T

In [2]:
import sys

# Make the packages stored under ../input/segmentation-models-pytorch
# importable. Entries are appended in this order, after everything that is
# already on sys.path.
sys.path.extend([
    '../input/segmentation-models-pytorch/EfficientNet-PyTorch',
    '../input/segmentation-models-pytorch/pretrained-models.pytorch',
    '../input/segmentation-models-pytorch/pytorch-image-models',
    '../input/segmentation-models-pytorch/segmentation_models.pytorch',
])
import segmentation_models_pytorch as smp

In [3]:
def set_seeds(seed=21):
    """Seed every random number source used by this notebook.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU, the current CUDA device and all CUDA devices), stores the seed in
    the ``PYTHONHASHSEED`` environment variable, and switches cuDNN to
    deterministic mode.

    Args:
        seed: Integer seed handed to every generator (default 21).
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.deterministic = True

In [4]:
# --- Run configuration -------------------------------------------------------
BASE_DIR = '../input/hubmap-kidney-segmentation'  # root of the input data
SAVE_DIR = "/kaggle/working/"  # not referenced by the cells visible here
SEED = 21      # seed passed to set_seeds() below
WINDOW = 1024  # edge length of the tiles read from each test image
NEW_SIZE = 512 # edge length tiles are resized to before entering the models
OVERLAP = 32   # min_overlap handed to make_grid()
TH = 0.7       # threshold applied to the fold-averaged prediction
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

torch.cuda.empty_cache()
# Pass SEED explicitly so that editing the constant above actually changes the
# seeding (the bare call silently used set_seeds' own default instead).
set_seeds(SEED);

In [5]:
@numba.njit()
def rle_numba(pixels):
    """Run-length encode a flat binary mask.

    Args:
        pixels: 1-D array; entries equal to 1 are foreground.

    Returns:
        List ``[start_1, length_1, start_2, length_2, ...]`` where each start
        is the 1-based position of the first pixel of a run and each length is
        the number of pixels in that run. Empty when ``pixels`` is empty or
        contains no run.
    """
    size = len(pixels)
    points = []
    # flag is True while we are outside a run (the next change opens a run)
    # and False while inside one (the next change closes it).
    flag = True
    if size > 0 and pixels[0] == 1:
        # A mask that begins with foreground opens its first run at position 1.
        # The flag must be flipped as well, otherwise the following 1 -> 0
        # change would be recorded as a new start instead of this run's length.
        points.append(1)
        flag = False
    for i in range(1, size):
        if pixels[i] != pixels[i-1]:
            if flag:
                # i is 0-based, stored starts are 1-based.
                points.append(i+1)
                flag = False
            else:
                # points[-1] holds the start of the open run.
                points.append(i+1 - points[-1])
                flag = True
    if size > 0 and pixels[-1] == 1:
        # The last run reaches the end of the array: close it here.
        points.append(size-points[-1]+1)
    return points

def rle_numba_encode(image):
    """Encode a mask array as a space-separated run-length string.

    The array is flattened column by column (the same order as flattening its
    transpose), encoded with ``rle_numba`` and the resulting numbers are
    joined with single spaces.

    Args:
        image: mask array to encode.

    Returns:
        String of the numbers returned by ``rle_numba``; empty string when no
        run is found.
    """
    column_major = image.flatten(order='F')
    runs = rle_numba(column_major)
    return ' '.join(map(str, runs))

def make_grid(shape, window=WINDOW, min_overlap=OVERLAP):
    """
        Split a 2-D extent into a grid of overlapping tiles.

        Args:
            shape: pair (x, y) with the extent along the first and second axis.
            window: tile edge length.
            min_overlap: used to derive the number of tiles per axis,
                length // (window - min_overlap) + 1; must be smaller
                than `window`.

        Returns:
            Array of size (N,4), where N - number of tiles,
            2nd axis represents slices: x1,x2,y1,y2

        Raises:
            ValueError: if `window` is not larger than `min_overlap`.
    """
    if window <= min_overlap:
        raise ValueError("window must be larger than min_overlap")

    def _axis_tiles(length):
        # Evenly spaced tile starts along one axis; the last tile is moved so
        # that it ends exactly at `length`.
        count = length // (window - min_overlap) + 1
        starts = np.linspace(0, length, num=count, endpoint=False, dtype=np.int64)
        # Clamp at 0: when the axis is shorter than `window`, length - window
        # is negative and would yield an invalid (negative) slice start.
        starts[-1] = max(length - window, 0)
        stops = (starts + window).clip(0, length)
        return count, starts, stops

    x, y = shape
    nx, x1, x2 = _axis_tiles(x)
    ny, y1, y2 = _axis_tiles(y)

    # Cartesian product of the per-axis tiles, filled by broadcasting.
    slices = np.zeros((nx, ny, 4), dtype=np.int64)
    slices[:, :, 0] = x1[:, None]
    slices[:, :, 1] = x2[:, None]
    slices[:, :, 2] = y1[None, :]
    slices[:, :, 3] = y2[None, :]
    return slices.reshape(nx*ny, 4)

In [6]:
def to_tensor(x, **kwargs):
    """Move the last axis of a 3-D array to the front and cast to float32.

    Args:
        x: array with three axes; axis 2 becomes axis 0.
        **kwargs: accepted and ignored.

    Returns:
        float32 array with axes ordered (2, 0, 1) relative to ``x``.
    """
    channels_first = x.transpose((2, 0, 1))
    return channels_first.astype(np.float32)

def get_preprocessing():
    """Build the albumentations pipeline applied to a tile before inference.

    Returns:
        ``A.Compose`` that first normalises the image with the fixed
        per-channel mean/std below and then applies ``to_tensor`` to both the
        image and the mask.
    """
    # Per-channel statistics; these values match the standard ImageNet
    # mean/std. Alternatives that are kept disabled:
    #   A.Resize(380)
    #   mean=(0.65459856, 0.48386562, 0.69428385),
    #   std=(0.15167958, 0.23584107, 0.13146145)
    normalize = A.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
        max_pixel_value=255.0,
        always_apply=True,
        p=1.0,
    )
    channels_first = A.Lambda(image=to_tensor, mask=to_tensor)
    return A.Compose([normalize, channels_first])

In [7]:
# Directory holding one '*.pth' state dict per fold.
p_model = pathlib.Path('../input/overlap-and-inference')
ENCODER = 'efficientnet-b4' #resnext50_32x4d
ACTIVATION = 'sigmoid'
fold_models = []

# Sorted so the ensemble is assembled in the same order on every run
# (glob order is filesystem dependent).
for model_file in sorted(p_model.glob('*.pth')):
    best_model = smp.FPN( #UnetPlusPlus
        encoder_name=ENCODER,
        encoder_weights=None,  # weights come from the checkpoint below
        activation=ACTIVATION,
        in_channels=3,
        classes=1)
    # map_location remaps the stored tensors onto DEVICE, so a checkpoint
    # saved on a GPU still loads when DEVICE falls back to 'cpu'.
    best_model.load_state_dict(torch.load(model_file, map_location=DEVICE))
    best_model.to(DEVICE)
    # Inference only: put dropout / batch-norm layers into evaluation mode.
    best_model.eval()
    fold_models.append(best_model)

# Fail here with a clear message instead of much later, inside the tile loop.
if not fold_models:
    raise FileNotFoundError(f"No '*.pth' checkpoints found in {p_model}")

In [8]:
p_base = pathlib.Path(BASE_DIR)
identity = rasterio.Affine(1, 0, 0, 0, 1, 0)
subm = {}

# Globbed once, so the progress-bar total and the loop see the same files.
test_files = list(p_base.glob('test/*.tiff'))
# Built once instead of once per tile.
preprocess = get_preprocessing()

for i, filename in tqdm(enumerate(test_files), total=len(test_files)):
    print(filename)
    # The context manager closes the dataset handle even if a tile fails.
    with rasterio.open(filename, transform=identity) as test_image_ds:
        slices = make_grid(test_image_ds.shape, window=WINDOW, min_overlap=OVERLAP)
        # Full-resolution binary mask, filled tile by tile.
        preds = np.zeros(test_image_ds.shape, dtype=np.uint8)

        for (x1, x2, y1, y2) in tqdm(slices):
            # Read the tile: (x1, x2) selects rows, (y1, y2) selects columns.
            # NOTE(review): assumes every TIFF exposes bands 1-3 directly - confirm.
            image = test_image_ds.read([1, 2, 3],
                        window=Window.from_slices((x1, x2), (y1, y2)))
            image = np.moveaxis(image, 0, -1)
            image = cv2.resize(image, (NEW_SIZE, NEW_SIZE), interpolation=cv2.INTER_AREA)
            image = preprocess(image=image)['image']
            image = torch.from_numpy(image).to(DEVICE).unsqueeze(0)

            # Average the predictions of all fold models.
            pred = None
            with torch.no_grad():
                for fold_model in fold_models:
                    pred_fold = fold_model.predict(image).squeeze().cpu().numpy()
                    pred = pred_fold if pred is None else pred + pred_fold
            pred = pred / len(fold_models)

            # Back to the tile's own size; cv2 expects (width, height).
            # Equal to (WINDOW, WINDOW) for every full-size tile.
            pred = cv2.resize(pred, (int(y2 - y1), int(x2 - x1)),
                              interpolation=cv2.INTER_AREA)

            # Merge with OR: a pixel is foreground if any tile covering it says
            # so. Writing the OR in place avoids full-image temporaries.
            preds[x1:x2, y1:y2] |= (pred > TH).astype(np.uint8)

    subm[i] = {'id': filename.stem, 'predicted': rle_numba_encode(preds)}
    # Release the full-resolution mask before the next image is allocated.
    del preds
    gc.collect()

HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))

../input/hubmap-kidney-segmentation/test/afa5e8098.tiff


HBox(children=(FloatProgress(value=0.0, max=1710.0), HTML(value='')))


../input/hubmap-kidney-segmentation/test/b9a3865fc.tiff


HBox(children=(FloatProgress(value=0.0, max=1312.0), HTML(value='')))


../input/hubmap-kidney-segmentation/test/c68fe75ea.tiff


HBox(children=(FloatProgress(value=0.0, max=1428.0), HTML(value='')))


../input/hubmap-kidney-segmentation/test/b2dc8411c.tiff


HBox(children=(FloatProgress(value=0.0, max=480.0), HTML(value='')))


../input/hubmap-kidney-segmentation/test/26dc41664.tiff


HBox(children=(FloatProgress(value=0.0, max=1677.0), HTML(value='')))





In [9]:
# One row per entry of `subm`; the integer keys become the index, which is
# left out of the CSV.
submission = pd.DataFrame.from_dict(data=subm, orient='index')
submission.to_csv(path_or_buf='submission.csv', index=False)
# Preview of the first rows as the cell output.
submission.head()

Unnamed: 0,id,predicted
0,afa5e8098,
1,b9a3865fc,
2,c68fe75ea,
3,b2dc8411c,
4,26dc41664,
