In [None]:
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import cv2
import matplotlib.pyplot as plt
import os 
import random
from tqdm import tqdm
import gc

import segmentation_models_pytorch as smp
import albumentations as A
import joblib
from empatches import EMPatches
import nibabel as nib
import cc3d
from scipy.ndimage import median_filter

from helper_scripts.surface_dice_score import compute_surface_dice_score

In [None]:
class CFG:
    """Notebook-wide inference settings (paths, loader, patching, model choice)."""

    # Root folder holding train_rles.csv and the train/<kidney>/{images,labels} tree.
    data_root_path = '/path/to/data/folder'

    # Reproducibility.
    seed = 42

    # Loader / run settings (`epochs` and `accelerator` are not read by the cells shown here).
    epochs = 10
    valid_batch_size = 32
    workers = 8
    accelerator = "gpu"

    # Patching: passed to EMPatches as `overlap` and `patchsize`; also the quadrant
    # edge used when splitting 2x2 grids.
    valid_overlap = 0.1
    patch_size = 256

    # Segmentation architecture (key into `seg_models`) and encoder name.
    seg_model = "Unet"
    encoder_name = 'tu-maxvit_tiny_tf_512'

    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Per-slice RLE table; the slice name is the last '_'-separated token of each id.
df = pd.read_csv(f'{CFG.data_root_path}/train_rles.csv')
df['image'] = df['id'].str.split('_').str[-1]
# Hold-out ids: kidney_2 rows from position 900 onward, 'id' column only.
# NOTE(review): the evaluation below encodes every slice of the full kidney_2
# volume against this frame — confirm both have the same number of rows.
valid_df = df.loc[df['id'].str.contains('kidney_2'), ['id']].iloc[900:].reset_index(drop=True)

In [None]:
def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    PYTHONHASHSEED is also exported. Setting it at runtime only affects Python
    processes started afterwards; it does not affect this interpreter's hash seed.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
    os.environ['PYTHONHASHSEED'] = str(seed)
set_seed(seed=CFG.seed)  # seed everything once, right after configuration

In [None]:
def rle_encode(mask):
    """Run-length encode a binary mask as 'start length start length ...'.

    The mask is flattened in row-major order and padded with a zero on each side.
    Run boundaries are the positions where the value changes, so the mask is
    assumed to hold 0 plus a single foreground value. Starts are 1-based.
    An all-background mask encodes as '1 0'.
    """
    padded = np.concatenate(([0], mask.flatten(), [0]))
    boundaries = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Convert every (start, end) pair into (start, length).
    boundaries[1::2] -= boundaries[::2]
    encoded = ' '.join(map(str, boundaries))
    return encoded if encoded else '1 0'

In [None]:
# Validation-time augmentation pipeline: intentionally empty (identity).
# NOTE(review): not referenced by the cells shown here.
valid_transform = A.Compose([])

In [None]:
# Keep kidney_1_dense, kidney_2 and kidney_3_dense rows; the kidney name is the id
# with its trailing '_<slice>' token removed.
df = df.loc[df['id'].str.contains('kidney_1_dense|kidney_2|kidney_3_dense')].reset_index(drop=True)
df['kidney'] = df['id'].str.rsplit('_', n=1).str[0]

def create_image_path(row):
    """Return the image .tif path for a slice row (uses its `kidney` and `image` fields).

    Images for kidney_3_dense are read from the kidney_3_sparse folder; every other
    kidney uses its own folder.
    """
    folder = 'kidney_3_sparse' if row.kidney == 'kidney_3_dense' else row.kidney
    return f'{CFG.data_root_path}/train/{folder}/images/{row.image}.tif'
def create_mask_path(row):
    """Return the label .tif path for a slice row (uses its `kidney` and `image` fields).

    Labels always come from the kidney's own folder, including kidney_3_dense, so no
    special case is needed (unlike the image path).
    """
    return f'{CFG.data_root_path}/train/{row.kidney}/labels/{row.image}.tif'

# Attach the per-slice image and label file paths.
for column, path_fn in (('image_path', create_image_path), ('mask_path', create_mask_path)):
    df[column] = df.apply(path_fn, axis=1)

In [None]:
def create_kidney_volume(kidney, df):
    """Load every slice of one kidney into stacked uint8 image and mask tensors.

    Rows are selected with ``df['kidney'].str.contains(kidney)`` (substring / regex
    match) and ordered by their ``image`` name. Each row's ``image_path`` and
    ``mask_path`` are read as grayscale.

    Args:
        kidney: pattern matched against the ``kidney`` column.
        df: frame with ``kidney``, ``image``, ``image_path`` and ``mask_path`` columns.

    Returns:
        ``(images, masks)``: two uint8 tensors stacked along a new leading slice dim.

    Raises:
        ValueError: if no row matches ``kidney``.
        FileNotFoundError: if a slice cannot be read (cv2.imread returns None).
    """
    kidney_df = (df[df['kidney'].str.contains(kidney)]
                 .sort_values('image', ascending=True)
                 .reset_index(drop=True))
    if kidney_df.empty:
        raise ValueError(f'No slices found for kidney {kidney!r}')

    def _read_slice(path):
        # Read one grayscale slice; cv2 signals failure by returning None, not raising.
        array = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if array is None:
            raise FileNotFoundError(f'Could not read slice: {path}')
        return torch.from_numpy(array).to(torch.uint8)

    all_images = []
    all_masks = []
    for row in tqdm(kidney_df.itertuples(index=False), total=len(kidney_df)):
        all_images.append(_read_slice(row.image_path))
        all_masks.append(_read_slice(row.mask_path))
    return torch.stack(all_images), torch.stack(all_masks)

In [None]:
# Load every kidney_2 slice as the validation volume (images, masks).
valid_images, valid_masks = create_kidney_volume(kidney='kidney_2', df=df)
valid_data = dict(kidney_2=[valid_images, valid_masks])

In [None]:
def create_grid(images):
    """Tile the first four 2-D images of a stack into one 2x2 grid.

    Layout::

        [0 | 1]
        [2 | 3]
    """
    top = np.concatenate((images[0], images[1]), axis=1)
    bottom = np.concatenate((images[2], images[3]), axis=1)
    return np.concatenate((top, bottom), axis=0)

def preprocess_image(image, lo, hi):
    """Cast a slice tensor to float32 and rescale it with the window (lo, hi).

    After ``(x - lo) / (hi - lo)``, values below 0.5 are raised to 0.5; there is no
    upper clamp.
    NOTE(review): a lower bound of 0.5 (rather than 0) is unusual; presumably it
    matches the preprocessing the weights were trained with — confirm before changing.
    """
    scaled = (image.float() - lo) / (hi - lo)
    return scaled.clamp(min=0.5)

def preprocess_mask(mask):
    """Return ``mask`` as a float32 tensor divided by 255 (0/255 labels -> 0/1).

    Always returns a new tensor. ``.to(torch.float32)`` returns the same tensor
    when the input is already float32, so an in-place divide would modify the
    caller's data.
    """
    return mask.to(torch.float32) / 255.0

def get_patch_id_list(data, truncate=0, return_indices=False):
    """List ``'<slice>_<patch>'`` ids for the first ``data.shape[0] - truncate`` slices.

    The patch layout is computed once from slice 0 with EMPatches
    (``CFG.patch_size`` / ``CFG.valid_overlap``) and reused for every slice. All
    slices are therefore assumed to have slice 0's shape.

    Returns:
        The id list, or ``(ids, patches_per_slice, patch_indices)`` when
        ``return_indices`` is true.
    """
    patches, patch_indices = EMPatches().extract_patches(
        data[0], patchsize=CFG.patch_size, overlap=CFG.valid_overlap)
    per_slice = len(patches)
    ids = [f'{slice_idx}_{patch_idx}'
           for slice_idx in range(data.shape[0] - truncate)
           for patch_idx in range(per_slice)]
    if not return_indices:
        return ids
    return ids, per_slice, patch_indices

def get_percentile_dict(kidneys=('kidney_2',)):
    """Return ``{kidney: [p2, p98]}`` intensity percentiles of each kidney's image volume.

    Each volume is ``valid_data[kidney][0]``. The previous per-kidney ``if`` guard
    only assigned lo/hi for 'kidney_2'. Any other kidney would have hit an unbound
    name or silently reused the previous kidney's values. Every requested kidney is
    now computed directly.

    Args:
        kidneys: keys of ``valid_data`` to process (default: only 'kidney_2').
    """
    percentile_dict = {}
    for kidney in kidneys:
        lo, hi = np.percentile(valid_data[kidney][0].numpy(), (2, 98))
        percentile_dict[kidney] = [lo, hi]
    return percentile_dict

In [None]:
class Dataset(torch.utils.data.Dataset):
    """2.5D patch dataset: each item is a 2x2 grid of one patch taken from 4 consecutive slices.

    ``patch_ids`` entries have the form ``'<start_slice>_<patch_index>'``. For each id:
    slices ``start_slice .. start_slice + 3`` are normalized with ``preprocess_image``,
    using the 2nd/98th percentiles of the whole volume. Each is cut into patches with
    EMPatches. Patch ``patch_index`` of each slice is then tiled with ``create_grid``
    (slice order: top-left, top-right, bottom-left, bottom-right).

    Args:
        data: torch tensor volume indexed by slice along dim 0 (``.numpy()`` must work,
            so presumably a CPU tensor).
        patch_ids: list of ``'<start_slice>_<patch_index>'`` strings.
        albumentation: ``'Brightness'`` applies a fixed -0.05 brightness shift; any
            other value disables augmentation.
    """

    def __init__(self, data, patch_ids, albumentation=False):
        self.data = data
        self.patch_ids = patch_ids
        self.emp = EMPatches()
        # One intensity window shared by every slice of the volume.
        self.lo, self.hi = np.percentile(data.numpy(), (2, 98))
        self.albumentation = albumentation

    def __getitem__(self, index):
        """Return ``(grid, start_slice, patch_index)``, each with a leading dim of size 1."""
        start_slice, patch_index = (int(part) for part in self.patch_ids[index].split('_'))
        images = []
        for offset in range(4):
            img = preprocess_image(self.data[start_slice + offset], self.lo, self.hi)
            img_patches, _ = self.emp.extract_patches(
                img, patchsize=CFG.patch_size, overlap=CFG.valid_overlap)
            images.append(img_patches[patch_index])
        image = create_grid(np.stack(images))
        if self.albumentation == 'Brightness':
            # NOTE(review): A.RandomBrightness is deprecated in newer albumentations
            # releases (RandomBrightnessContrast replaces it) — confirm the pinned version.
            image = A.RandomBrightness(limit=[-0.05, -0.05], p=1)(image=image)['image']

        image = torch.tensor(image)
        # int64 ids: int8 cannot hold more than 127 patches per slice (e.g. with a
        # smaller CFG.patch_size), and int16 cannot hold more than 32767 slices.
        orig_image_id = torch.tensor(start_slice, dtype=torch.int64)
        patch_id = torch.tensor(patch_index, dtype=torch.int64)
        return image.unsqueeze(0), orig_image_id.unsqueeze(0), patch_id.unsqueeze(0)

    def __len__(self):
        return len(self.patch_ids)

In [None]:
# Architecture name (selected by CFG.seg_model) -> segmentation_models_pytorch class.
seg_models = dict([
    ("Unet", smp.Unet),
    ("Unet++", smp.UnetPlusPlus),
    ("MAnet", smp.MAnet),
    ("Linknet", smp.Linknet),
    ("FPN", smp.FPN),
    ("PSPNet", smp.PSPNet),
    ("PAN", smp.PAN),
    ("DeepLabV3", smp.DeepLabV3),
    ("DeepLabV3+", smp.DeepLabV3Plus),
])

class Model(torch.nn.Module):
    """Wrapper around the smp architecture named by ``CFG.seg_model``.

    The network takes 1 input channel and produces 1 output channel with no final
    activation (raw logits). No pretrained encoder weights are used
    (``encoder_weights=None``).
    """

    def __init__(self):
        super().__init__()
        architecture = seg_models[CFG.seg_model]
        self.model = architecture(
            encoder_name=CFG.encoder_name,
            encoder_weights=None,
            in_channels=1,
            classes=1,
            activation=None,
        )

    def forward(self, images):
        return self.model(images)

In [None]:
# Build the network, put it in inference mode and load the trained weights.
weights_path = './results/model/tu-maxvit_tiny_tf_512-Unet_loss_V17_last_3epochs.pth'
model = Model()
model.to(CFG.device)
# eval() is required for inference. In train mode, BatchNorm layers normalize with
# per-batch statistics (and update their running stats), and Dropout stays active.
# The mode persists through load_state_dict below.
model.eval()
# NOTE: torch.load unpickles the checkpoint — only load trusted files
# (recent PyTorch versions accept weights_only=True to restrict this).
weights = torch.load(weights_path, map_location=CFG.device)
model.load_state_dict(weights)

In [None]:
def rotate_grid(images, k=0):
    """Rotate each quadrant of a batch of 2x2 patch grids by ``k`` quarter turns.

    ``images`` is 4-D; dims 2 and 3 are split at ``CFG.patch_size`` into four
    quadrants. Each quadrant is rotated independently with ``torch.rot90`` over
    dims (2, 3) and keeps its position in the grid. A negative ``k`` undoes a
    positive one.
    """
    p = CFG.patch_size
    head, tail = slice(None, p), slice(p, None)

    def turn(rows, cols):
        return torch.rot90(images[:, :, rows, cols], k=k, dims=[2, 3])

    top = torch.cat([turn(head, head), turn(head, tail)], dim=3)
    bottom = torch.cat([turn(tail, head), turn(tail, tail)], dim=3)
    return torch.cat([top, bottom], dim=2)

def rotate_grid_tta(model, images, rot90):
    """Run one rotation-TTA pass and return probabilities in the original orientation.

    Steps: rotate each grid quadrant by ``rot90`` quarter turns, run ``model``,
    apply a sigmoid, then rotate back by ``-rot90``. Runs without gradient tracking
    and with CUDA autocast enabled.
    """
    with torch.no_grad(), torch.cuda.amp.autocast(enabled=True):
        rotated = rotate_grid(images, k=rot90)
        probs = torch.sigmoid(model(rotated))
        return rotate_grid(probs, k=-rot90)
        
def predict_axis(kidney_volume, all_preds, axis=0, albumentation=False):
    """Predict a volume slice-by-slice along one axis and add the result into ``all_preds``.

    How it works:
    - The volume is re-oriented so ``axis`` becomes the slicing axis.
    - Each ``Dataset`` item is a 2x2 grid of one patch from 4 consecutive slices.
    - The global ``model`` is run on each grid with 4-fold rotation TTA
      (``rotate_grid_tta``), and every quadrant is scattered back to its
      ``'<slice>_<patch>'`` accumulator.
    - A slice can appear in up to 4 grids, so each accumulator is divided by its
      hit count.
    - Each slice's patches are merged with EMPatches (``mode='avg'``) and added
      to ``all_preds``.

    Args:
        kidney_volume: 3-D torch tensor of image slices.
        all_preds: numpy array with the volume's shape (float16 in predict_kidney).
            It is updated in place through transposed views and also returned.
        axis: 0, 1 or 2 — the axis along which 2-D slices are taken.
        albumentation: forwarded to ``Dataset`` ('Brightness' enables the fixed shift).

    Returns:
        ``all_preds`` in its original orientation.
    """
    emp = EMPatches()
    print('Predicting axis:', axis)
    # Re-orient so `axis` becomes axis 0. numpy .transpose returns a view, so the
    # writes into `all_preds` below land in the caller's array.
    if axis == 1:
        kidney_volume = kidney_volume.permute(1, 2, 0)
        all_preds = all_preds.transpose(1, 2, 0)
    elif axis == 2:
        kidney_volume = kidney_volume.permute(2, 0, 1)
        all_preds = all_preds.transpose(2, 0, 1)

    # A grid starting at slice i spans slices i..i+3, so the last 3 slices cannot start one.
    grid_ids, num_patches_kidney, indices_kidney = get_patch_id_list(
        kidney_volume, truncate=3, return_indices=True)
    test_dataset = Dataset(kidney_volume, grid_ids, albumentation=albumentation)
    test_dataloader = DataLoader(test_dataset, batch_size=CFG.valid_batch_size, shuffle=False,
                                 num_workers=CFG.workers, pin_memory=True)

    # One accumulator and hit counter per patch of every slice (including the last 3).
    slice_patch_ids = get_patch_id_list(kidney_volume, truncate=0)
    test_dict = {
        pid: torch.zeros((CFG.patch_size, CFG.patch_size), device='cpu', dtype=torch.float16)
        for pid in slice_patch_ids
    }
    test_num = dict.fromkeys(slice_patch_ids, 0)

    size = CFG.patch_size
    for images, orig_image_id, patch_id in tqdm(test_dataloader, total=len(test_dataloader), desc='Test '):
        images = images.to(CFG.device, dtype=torch.float)
        preds = torch.zeros(images.shape, device='cpu')
        # 4-fold rotation TTA (0/90/180/270 degrees), averaged.
        for k in range(4):
            preds += rotate_grid_tta(model, images, k).detach().cpu()
        preds /= 4
        preds = preds.to(torch.float16)
        orig_image_id = orig_image_id.cpu().numpy()
        patch_id = patch_id.cpu().numpy()
        for i, pred in enumerate(preds):
            pred = pred.squeeze(0)
            # Quadrant order matches create_grid: slice offsets 0, 1, 2, 3.
            quadrants = (pred[:size, :size], pred[:size, size:],
                         pred[size:, :size], pred[size:, size:])
            for offset, quadrant in enumerate(quadrants):
                key = f'{orig_image_id[i].item() + offset}_{patch_id[i].item()}'
                test_dict[key] += quadrant
                test_num[key] += 1
    torch.cuda.empty_cache()
    gc.collect()

    # Average the overlapping contributions of each slice patch.
    for pid in slice_patch_ids:
        test_dict[pid] /= test_num[pid]

    print('Adding to predictions:')
    for image_id in tqdm(range(kidney_volume.shape[0])):
        preds_image = []
        for i in range(num_patches_kidney):
            key = f'{image_id}_{i}'
            preds_image.append(test_dict[key].cpu().numpy())
            test_dict[key] = 0  # drop the tensor reference so it can be freed early
        merged_preds = emp.merge_patches(preds_image, indices_kidney, mode='avg')
        all_preds[image_id] += merged_preds.astype(np.float16)
    # Clean up once after the loop; gc.collect() per slice was needlessly expensive.
    torch.cuda.empty_cache()
    gc.collect()
    print(all_preds.shape)

    # Undo the re-orientation (the array still aliases the caller's buffer).
    if axis == 1:
        all_preds = all_preds.transpose(2, 0, 1)
    elif axis == 2:
        all_preds = all_preds.transpose(1, 2, 0)
    print(all_preds.shape)

    return all_preds

def predict_kidney(kidney='kidney_2', folder='train', axes=(0, 1, 2), albumentation=False):
    """Average ``predict_axis`` predictions over several slicing axes of one kidney volume.

    Args:
        kidney: key into ``valid_data`` selecting the image volume. Previously this
            argument was ignored and 'kidney_2' was always used.
        folder: unused; kept for backward compatibility.
        axes: slicing axes passed to ``predict_axis``; results are averaged over them.
        albumentation: forwarded to ``predict_axis`` (e.g. 'Brightness').

    Returns:
        float16 numpy array with the volume's shape.

    Raises:
        ValueError: if ``axes`` is empty.
    """
    if not axes:
        raise ValueError('axes must contain at least one axis')
    kidney_volume = valid_data[kidney][0]
    all_preds = np.zeros(kidney_volume.shape, dtype=np.float16)
    for axis in axes:
        all_preds = predict_axis(kidney_volume, all_preds, axis=axis, albumentation=albumentation)
    all_preds /= len(axes)

    return all_preds

In [None]:
# Run the three-axis TTA inference on the validation kidney, then release cached memory.
all_preds = predict_kidney(kidney='kidney_2')
torch.cuda.empty_cache()
gc.collect()

In [None]:
def calculate_surface_dice_score(all_preds, all_orig_masks, threshold, ids_df=None):
    """Binarize predictions and score them against the labels with surface Dice.

    Each slice of ``all_preds`` / ``all_orig_masks`` is RLE-encoded (``rle_encode``)
    and attached row by row to a copy of ``ids_df``. The global ``valid_df`` is no
    longer modified, so repeated calls behave identically.

    Args:
        all_preds: probability volume with slices along axis 0.
        all_orig_masks: label volume with the same layout.
        threshold: probabilities strictly above this become foreground.
        ids_df: frame with one row per slice (an ``id`` column); defaults to the
            global ``valid_df``.

    Returns:
        The value of ``compute_surface_dice_score(preds_df, labels_df)``.

    Raises:
        ValueError: if the prediction, label and id-row counts differ.
    """
    if ids_df is None:
        ids_df = valid_df
    num_slices = len(all_preds)
    if len(all_orig_masks) != num_slices or len(ids_df) != num_slices:
        raise ValueError(
            f'Slice count mismatch: {num_slices} prediction slices, '
            f'{len(all_orig_masks)} label slices, {len(ids_df)} id rows. '
            'Pass an ids_df with one row per slice, or slice the volumes to match.'
        )
    height = all_orig_masks.shape[1]
    width = all_orig_masks.shape[2]
    binary_preds = (all_preds > threshold).astype(np.int8)

    # Disabled post-processing options: cc3d.dust(threshold=16, connectivity=26),
    # median_filter(size=2), cc3d.largest_k(k=10, connectivity=26, delta=0).

    preds_df = ids_df.copy()
    preds_df['rle'] = [rle_encode(pred) for pred in binary_preds]
    del binary_preds

    labels_df = ids_df.copy()
    labels_df['rle'] = [rle_encode(mask) for mask in all_orig_masks]
    labels_df['width'] = width
    labels_df['height'] = height
    gc.collect()

    return compute_surface_dice_score(preds_df, labels_df)

In [None]:
calculate_surface_dice_score(all_preds, valid_data['kidney_2'][1].numpy(), threshold=0.2)  # threshold used for the best private score

In [None]:
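# Local validation surface Dice on kidney_2 by binarization threshold.
# Format: <surface dice> <threshold> [--> competition private leaderboard score]
# 0.7402400374412537 0.1
# 0.8727735280990601 0.2 --> 0.647912 private score, 13th place solution
# 0.8868693709373474 0.3
# 0.8942667841911316 0.4 --> 0.591234 private score
# 0.8997269868850708 0.5
# 0.9003251194953918 0.6
# 0.892861008644104 0.7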

In [None]:
# Export the binarized prediction volume and the ground-truth labels as NIfTI files
# for inspection. New names are used so `all_preds` keeps the float probabilities
# and this cell can be re-run.
# NOTE(review): exports at 0.5, while the scored submission used threshold 0.2.
EXPORT_THRESHOLD = 0.5
export_dir = './results/segmentations'
os.makedirs(export_dir, exist_ok=True)
pred_mask = (all_preds > EXPORT_THRESHOLD).astype(np.int8)
all_preds_nifti = nib.Nifti1Image(pred_mask, np.eye(4))
nib.save(all_preds_nifti, f'{export_dir}/all_preds_nifti.nii')
all_labels = nib.Nifti1Image(valid_data['kidney_2'][1].numpy(), np.eye(4))
nib.save(all_labels, f'{export_dir}/all_labels_nifti.nii')