In [None]:
!python -m pip install --no-index --find-links=/kaggle/input/pip-download-for-segmentation-models-pytorch segmentation-models-pytorch
!python -m pip install --no-index --find-links=/kaggle/input/connected-components-3d connected-components-3d
!python -m pip install --no-index --find-links=/kaggle/input/empatches empatches

In [None]:
import pandas as pd
import numpy as np
import cv2
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import segmentation_models_pytorch as smp
import albumentations as A
import glob
import os
from tqdm import tqdm
from empatches import EMPatches
import joblib
import gc
import sys

In [None]:
class CFG:
    """Static configuration shared by every cell of the inference notebook."""

    # --- data / runtime -----------------------------------------------------
    data_root_path = '/kaggle/input/blood-vessel-segmentation'
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    workers = 4          # DataLoader worker processes
    seed = 42            # passed to torch.manual_seed
    debug = False        # True: run on a slice of train/kidney_2 instead of the test set

    # --- patching / batching ------------------------------------------------
    patch_size = 256     # side length of the square patches cut from each slice
    test_overlap = 0.1   # overlap handed to EMPatches.extract_patches
    batch_size = 16

    # --- model ----------------------------------------------------------------
    seg_model = "Unet"                        # key into the seg_models registry
    encoder_name = "tu-maxvit_tiny_tf_512"

    # --- post-processing ------------------------------------------------------
    threshold = 0.2      # averaged probability above which a pixel is foreground

In [None]:
# Collect every test slice: <root>/test/<kidney>/<subfolder>/<image>.tif
ls_images = glob.glob(os.path.join(CFG.data_root_path, "test", "*", "*", "*.tif"))
print(f"found images: {len(ls_images)}")

if CFG.debug:
    # Debug runs replace the test listing with the slices of train/kidney_2.
    print('Debugging..')
    ls_images = glob.glob(os.path.join(CFG.data_root_path, "train", "kidney_2", "images", "*.tif"))
    print(f"found images: {len(ls_images)}")

# Split each path into the pieces of the submission id. The parsing is the
# same for both listings, so it is done once after the listing is chosen.
kidney_ids = []
image_ids = []
for p_img in tqdm(ls_images):
    path_ = p_img.split(os.path.sep)
    # .../<kidney_id>/<subfolder>/<image_id>.tif
    kidney_ids.append(path_[-3])
    image_ids.append(os.path.splitext(path_[-1])[0])

In [None]:
# One row per slice, ordered by image id (ids are compared as strings).
df = (
    pd.DataFrame({'kidney_ids': kidney_ids, 'image_ids': image_ids})
    .sort_values('image_ids', ascending=True)
    .reset_index(drop=True)
)

In [None]:
# Record which dataset split the rows belong to. In debug mode the table is
# first cut down to a 256-row window starting at row 900.
if not CFG.debug:
    df['folder'] = 'test'
else:
    print('shortening df for debugging')
    df = df.iloc[900:900+256].copy()
    print(df.shape)
    df['folder'] = 'train'

In [None]:
# Create the scratch directories. exist_ok=True keeps the cell re-runnable:
# plain os.mkdir raises FileExistsError the second time the cell is executed.
os.makedirs('/kaggle/working/images/', exist_ok=True)
os.makedirs('/kaggle/working/indices/', exist_ok=True)

In [None]:
def create_kidney_volume(kidney, folder):
    """Load all slices of one kidney into a single stacked uint8 tensor.

    The slice ids are taken from the module-level ``df`` (rows whose
    ``kidney_ids`` equals ``kidney``), sorted by ``image_ids``. Each slice is
    read as a grayscale image from
    ``<CFG.data_root_path>/<folder>/<kidney>/images/<image_id>.tif``.

    Args:
        kidney: Kidney identifier, e.g. ``'kidney_5'``.
        folder: Dataset split directory, e.g. ``'train'`` or ``'test'``.

    Returns:
        Tuple ``(images, kidney_ids)`` where ``images`` is the slices stacked
        along dim 0 and ``kidney_ids`` is the pandas Series of image ids in
        the same order (it keeps the index of the rows selected from ``df``).

    Raises:
        FileNotFoundError: If a slice cannot be read by OpenCV.
        ValueError: If ``df`` contains no slice for ``kidney``.
    """
    kidney_ids = df[df['kidney_ids'] == kidney].sort_values('image_ids', ascending=True).image_ids
    # Build paths from the configured root instead of a second hard-coded copy.
    image_dir = os.path.join(CFG.data_root_path, folder, kidney, 'images')

    images = []
    for image_id in tqdm(kidney_ids):
        image_path = os.path.join(image_dir, f'{str(image_id).zfill(4)}.tif')
        img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f'Could not read image: {image_path}')
        images.append(torch.from_numpy(img).to(torch.uint8))

    if not images:
        raise ValueError(f'No images found for kidney {kidney!r} in folder {folder!r}')

    images = torch.stack(images)
    return images, kidney_ids

def preprocess_image(image, lo, hi):
    """Rescale a slice with the ``lo``/``hi`` intensity bounds and clamp it from below.

    Args:
        image: Tensor holding the raw slice.
        lo: Intensity mapped to 0.
        hi: Intensity mapped to 1.

    Returns:
        float32 tensor ``(image - lo) / (hi - lo)`` with every value below 0.5
        raised to 0.5. No upper bound is applied.
    """
    as_float = image.to(torch.float32)
    span = hi - lo
    scaled = (as_float - lo) / span
    # NOTE(review): a lower clamp at 0.5 is unusual; presumably it mirrors the
    # preprocessing used when the weights were trained - confirm before changing.
    return torch.clamp(scaled, min=0.5)

def get_patch_id_list(data, truncate=0, return_indices=False):
    """Enumerate ``'<image>_<patch>'`` ids for every patch of every slice.

    The first slice ``data[0]`` is patched once with ``CFG.patch_size`` and
    ``CFG.test_overlap`` to learn how many patches a slice yields; that count
    is reused for all slices.

    Args:
        data: Volume indexed by slice along its first dimension.
        truncate: Number of trailing slices to leave out of the id list.
        return_indices: Also return the patch count and the patch indices
            reported by EMPatches for the first slice.

    Returns:
        The list of ids, or ``(ids, patches_per_image, indices)`` when
        ``return_indices`` is true.
    """
    patcher = EMPatches()
    first_patches, first_indices = patcher.extract_patches(
        data[0], patchsize=CFG.patch_size, overlap=CFG.test_overlap
    )
    patches_per_image = len(first_patches)
    n_images = data.shape[0] - truncate

    ids = [
        f'{img_idx}_{patch_idx}'
        for img_idx in range(n_images)
        for patch_idx in range(patches_per_image)
    ]

    if not return_indices:
        return ids
    return ids, patches_per_image, first_indices

def create_grid(images):
    """Tile the first four 2-D arrays of ``images`` into one 2x2 mosaic.

    Layout (indices into ``images``)::

        0 | 1
        --+--
        2 | 3

    Args:
        images: Array indexed as ``images[k, :, :]`` for ``k`` in 0..3.

    Returns:
        A single array twice as tall and twice as wide as one tile.
    """
    tiles = [images[k, :, :] for k in range(4)]
    top = np.concatenate(tiles[:2], axis=1)
    bottom = np.concatenate(tiles[2:], axis=1)
    return np.concatenate([top, bottom], axis=0)

class Dataset(torch.utils.data.Dataset):
    """Serves 2x2 mosaics built from the same patch of four consecutive slices.

    Each entry of ``patch_ids`` is a string ``'<first_slice>_<patch>'``. The
    sample for that entry takes patch ``<patch>`` from slices ``first_slice``
    to ``first_slice + 3`` and tiles the four patches with ``create_grid``.

    Args:
        data: Torch tensor volume indexed by slice along its first dimension.
        patch_ids: List of ``'<first_slice>_<patch>'`` ids; ``first_slice + 3``
            must be a valid slice index of ``data``.
    """

    def __init__(self, data, patch_ids):
        self.data = data
        self.patch_ids = patch_ids
        self.emp = EMPatches()
        # 2nd / 98th percentile of the whole volume, used to rescale every slice.
        self.lo, self.hi = np.percentile(data.numpy(), (2, 98))

    def __getitem__(self, index):
        """Return ``(mosaic, first_slice, patch)`` for ``patch_ids[index]``.

        All three values carry a leading dimension of size 1.
        """
        orig_image_id, patch_id = self.patch_ids[index].split('_')
        first_slice = int(orig_image_id)
        patch_index = int(patch_id)

        images = []
        for i in range(4):
            img = preprocess_image(self.data[first_slice + i], self.lo, self.hi)
            img_patches, _ = self.emp.extract_patches(img, patchsize=CFG.patch_size, overlap=CFG.test_overlap)
            images.append(img_patches[patch_index])
        images = np.stack(images)
        image = torch.tensor(create_grid(images))

        # int64 instead of int16/int8: a slice can yield more than 127 patches
        # and a volume can be deeper than 32767 slices, neither of which the
        # narrow dtypes can represent.
        orig_image_id = torch.tensor(first_slice, dtype=torch.int64)
        patch_id = torch.tensor(patch_index, dtype=torch.int64)
        return image.unsqueeze(0), orig_image_id.unsqueeze(0), patch_id.unsqueeze(0)

    def __len__(self):
        """Number of ``patch_ids`` entries."""
        return len(self.patch_ids)

In [None]:
# Seed torch's random number generator with the configured seed.
torch.manual_seed(CFG.seed)

# Registry mapping an architecture name (see CFG.seg_model) to the
# segmentation_models_pytorch class that builds it.
seg_models = {
    "Unet": smp.Unet,
    "Unet++": smp.UnetPlusPlus,
    "MAnet": smp.MAnet,
    "Linknet": smp.Linknet,
    "FPN": smp.FPN,
    "PSPNet": smp.PSPNet,
    "PAN": smp.PAN,
    "DeepLabV3": smp.DeepLabV3,
    "DeepLabV3+": smp.DeepLabV3Plus,
}

class Model(torch.nn.Module):
    """Thin wrapper around the segmentation network selected by ``CFG``.

    The wrapped network is stored as ``self.model``. It is built with one
    input channel, one output class, no pretrained encoder weights and no
    final activation, so ``forward`` returns whatever the bare network emits.
    """

    def __init__(self):
        super().__init__()
        architecture = seg_models[CFG.seg_model]
        self.model = architecture(
            encoder_name=CFG.encoder_name,
            encoder_weights=None,
            in_channels=1,
            classes=1,
            activation=None,
        )

    def forward(self, images):
        """Run the wrapped network on ``images`` and return its output."""
        return self.model(images)

In [None]:
def rle_encode(mask):
    """Run-length encode a binary mask as ``'start length start length ...'``.

    The mask is flattened, starts are 1-based positions in the flattened
    array, and an all-zero mask is encoded as ``'1 0'``.

    Args:
        mask: Array of 0/1 values.

    Returns:
        The space-separated run-length string.
    """
    flat = mask.flatten()
    # Pad with zeros so runs touching either end still produce a transition.
    padded = np.concatenate([[0], flat, [0]])
    change_points = np.where(padded[1:] != padded[:-1])[0] + 1
    # Turn each (start, end) pair into (start, length).
    change_points[1::2] -= change_points[::2]
    encoded = ' '.join(map(str, change_points))
    return encoded if encoded != '' else '1 0'

In [None]:
def rotate_grid(images, k=0):
    """Rotate each quadrant of a 2x2 mosaic in place by ``k`` quarter turns.

    The mosaic is split at ``CFG.patch_size`` along dims 2 and 3; every
    quadrant is rotated on its own with ``torch.rot90`` and put back at the
    same position, so the quadrant layout is preserved.

    Args:
        images: 4-D tensor whose last two dims hold the mosaic.
        k: Number of 90-degree rotations (negative values rotate the other way).

    Returns:
        Tensor with the four quadrants individually rotated.
    """
    ps = CFG.patch_size

    def _turn(tile):
        return torch.rot90(tile, k=k, dims=[2, 3])

    top = torch.cat([_turn(images[:, :, :ps, :ps]), _turn(images[:, :, :ps, ps:])], dim=3)
    bottom = torch.cat([_turn(images[:, :, ps:, :ps]), _turn(images[:, :, ps:, ps:])], dim=3)
    return torch.cat([top, bottom], dim=2)

def rotate_grid_tta(model, images, rot90):
    """Predict on a quadrant-rotated mosaic and rotate the result back.

    Runs under CUDA autocast with gradients disabled.

    Args:
        model: Callable network returning logits for ``images``.
        images: Batch of 2x2 mosaics.
        rot90: Number of quarter turns applied to each quadrant before the
            forward pass and undone afterwards.

    Returns:
        Sigmoid probabilities in the orientation of the input ``images``.
    """
    with torch.cuda.amp.autocast(enabled=True), torch.no_grad():
        rotated = rotate_grid(images, k=rot90)
        probabilities = torch.sigmoid(model(rotated))
        return rotate_grid(probabilities, k=-rot90)
        
def predict_axis(kidney_volume, all_preds, axis=0):
    """Predict the volume slice-by-slice along one axis and add the result to ``all_preds``.

    The volume is re-oriented so the requested axis becomes the slice
    dimension. Windows of four consecutive slices are predicted as 2x2
    mosaics with four-fold rotation test-time augmentation, per-patch
    predictions are averaged over every window that covers them, patches are
    merged back into full slices and added onto ``all_preds``.

    Relies on the module-level ``model`` for inference.

    Args:
        kidney_volume: Torch tensor volume in its original axis order.
        all_preds: NumPy accumulator with the same shape as ``kidney_volume``.
            It is updated in place through a transposed view.
        axis: 0, 1 or 2 - which orientation to slice along.

    Returns:
        ``all_preds`` in the original axis order.

    Raises:
        ValueError: If ``axis`` is not 0, 1 or 2.
    """
    if axis not in (0, 1, 2):
        # Previously any other value was silently treated like axis 0.
        raise ValueError(f'axis must be 0, 1 or 2, got {axis!r}')

    print('Predicting axis:', axis)
    if axis == 1:
        kidney_volume = kidney_volume.permute(1,2,0)
        all_preds = all_preds.transpose(1,2,0)
    elif axis == 2:
        kidney_volume = kidney_volume.permute(2,0,1)
        all_preds = all_preds.transpose(2,0,1)

    # Window ids stop three slices early because each window spans 4 slices.
    window_ids, num_patches_kidney, indices_kidney = get_patch_id_list(kidney_volume, truncate=3, return_indices=True)

    test_dataset = Dataset(kidney_volume, window_ids)
    test_dataloader = DataLoader(test_dataset, batch_size=CFG.batch_size, shuffle=False, num_workers=CFG.workers,pin_memory=True)

    # One accumulator and one hit counter per (slice, patch) of the full volume.
    kidney_patched_ids = get_patch_id_list(kidney_volume, truncate=0)
    test_dict = {}
    test_num = {}
    for patch_key in kidney_patched_ids:
        test_dict[patch_key] = torch.zeros((CFG.patch_size,CFG.patch_size), device='cpu', dtype=torch.float16)
        test_num[patch_key] = 0

    pbar = tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Test ')
    for step, (images, orig_image_id, patch_id) in pbar:
        images = images.to(CFG.device, dtype=torch.float)
        preds = torch.zeros(images.shape, device='cpu')

        # Average the predictions of the four quadrant rotations.
        for rotation in range(4):
            preds_tta = rotate_grid_tta(model, images, rotation)
            preds += preds_tta.detach().cpu()
        preds /= 4
        preds = preds.to(torch.float16)

        orig_image_id = orig_image_id.cpu().numpy()
        patch_id = patch_id.cpu().numpy()
        for i, pred in enumerate(preds):
            pred = pred.squeeze(0)
            # Split the mosaic back into its quadrants; quadrant x belongs to
            # slice (first slice of the window + x).
            patches = torch.stack([pred[:CFG.patch_size, :CFG.patch_size], pred[:CFG.patch_size, CFG.patch_size:], pred[CFG.patch_size:, :CFG.patch_size], pred[CFG.patch_size:, CFG.patch_size:]])
            for x, patch in enumerate(patches):
                image_id = orig_image_id[i].item()+x
                image_patch_id = f'{str(image_id)}_{patch_id[i].item()}'
                test_dict[image_patch_id] += patch
                test_num[image_patch_id] += 1
    torch.cuda.empty_cache()
    gc.collect()

    # Mean over the windows that covered each patch. Patches that were never
    # predicted (volumes with fewer than 4 slices) stay at zero instead of
    # becoming NaN through a division by zero.
    for patch_key in kidney_patched_ids:
        if test_num[patch_key]:
            test_dict[patch_key] /= test_num[patch_key]

    print('Adding to predictions:')
    # A local patcher avoids depending on a global defined in a later cell.
    patcher = EMPatches()
    for image_id in tqdm(range(kidney_volume.shape[0])):
        # pop() hands over the tensor and releases the dict entry in one step.
        preds_image = [
            test_dict.pop(f'{image_id}_{i}').cpu().numpy()
            for i in range(num_patches_kidney)
        ]
        merged_preds = patcher.merge_patches(preds_image, indices_kidney, mode='max')
        merged_preds = merged_preds.astype(np.float16)
        all_preds[image_id] += merged_preds
    # Cleanup once after the loop; running it per slice only cost time.
    torch.cuda.empty_cache()
    gc.collect()

    print(all_preds.shape)
    # Undo the re-orientation so the caller gets the original axis order back.
    if axis == 1:
        all_preds = all_preds.transpose(2,0,1)
    elif axis == 2:
        all_preds = all_preds.transpose(1,2,0)
    print(all_preds.shape)

    return all_preds

def predict_kidney(kidney='kidney_2', folder='train'):
    """Predict one kidney along all three axes and return its submission rows.

    The per-axis predictions are summed by ``predict_axis``, divided by 3,
    thresholded with ``CFG.threshold`` and run-length encoded slice by slice.

    Args:
        kidney: Kidney identifier, e.g. ``'kidney_5'``.
        folder: Dataset split directory the slices are read from.

    Returns:
        DataFrame with columns ``id`` (``'<kidney>_<image_id>'``) and ``rle``.
    """
    kidney_volume, kidney_ids = create_kidney_volume(kidney, folder)

    all_preds = np.zeros(kidney_volume.shape, dtype=np.float16)
    for axis in (0, 1, 2):
        all_preds = predict_axis(kidney_volume, all_preds, axis=axis)
    # Average over the three orientations that were accumulated above.
    counter = 3
    all_preds /= counter

    binary_masks = (all_preds>CFG.threshold).astype(np.int8)
    all_rle = [rle_encode(mask) for mask in binary_masks]

    submission = pd.DataFrame.from_dict({
        "id": kidney_ids,
        "rle": all_rle,
    })
    submission['id'] = submission['id'].apply(lambda x: f'{kidney}_{x}')
    return submission

In [None]:
# Load the trained checkpoint onto the configured device.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
weights = torch.load('/kaggle/input/model2-5d-v17/tu-maxvit_tiny_tf_512-Unet_loss_V17_last_3epochs.pth', map_location=CFG.device)
model = Model()
model.to(CFG.device)
model.load_state_dict(weights)
emp = EMPatches()

# Inference only: switch the network to evaluation mode.
model.eval()

if CFG.debug:
    submission = predict_kidney(kidney='kidney_2', folder='train')
else:
    # Predict each test kidney in turn, releasing memory between kidneys.
    kidney_frames = []
    for test_kidney in ('kidney_5', 'kidney_6'):
        kidney_frames.append(predict_kidney(kidney=test_kidney, folder='test'))
        torch.cuda.empty_cache()
        gc.collect()
    submission = pd.concat(kidney_frames).reset_index(drop=True)


In [None]:
# Write the submission table to submission.csv in the current working
# directory, leaving out the DataFrame index column.
submission.to_csv("submission.csv", index=False)

In [None]:
# Display the submission table as the cell output.
submission