In [None]:
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

import torch
import torch.nn as nn
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
from torch.cuda import amp
import timm
from pytorch_toolbelt.inference import tta
from pytorch_toolbelt import losses as L

import random
from tqdm.auto import tqdm
from PIL import Image
import numpy as np
import cv2
import glob
from matplotlib import pyplot as plt
import os
import json
import pandas as pd
import segmentation_models_pytorch as smp
from monai.inferers import sliding_window_inference
import shutil
from tifffile import imread
from IPython.display import display
import warnings
warnings.filterwarnings('ignore')
import PIL
PIL.Image.MAX_IMAGE_PIXELS = 933120000000

# Expose only physical GPU 1 to this process; inside the process it appears as 'cuda:0'.
# This must take effect before CUDA is first initialised (importing torch does not initialise it).
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

def seed_everything(seed=123):
    """Seed every RNG used by this notebook so inference runs are reproducible.

    Seeds Python's `random`, NumPy's global RNG and PyTorch (CPU and every
    visible CUDA device), and forces cuDNN into deterministic mode.

    Args:
        seed: integer seed applied to every generator (default 123).
    """
    random.seed(seed)
    # Only affects subprocesses: the current interpreter's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all seeds every visible GPU, not only the current device.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # With benchmark=True cuDNN may pick different algorithms per run, defeating determinism.
    torch.backends.cudnn.benchmark = False
seed_everything()

In [None]:
def rle_decode(mask_rle, shape, color=1):
    '''
    Decode a run-length-encoded mask string into a float32 array.

    mask_rle: whitespace-separated "start length start length ..." pairs,
              where starts are 1-based offsets into the flattened buffer
    shape: pair used to reshape the flat buffer; the reshaped array is then
           transposed, so the returned array has shape (shape[1], shape[0])
    color: value written into every run (default 1)
    Returns numpy float32 array: `color` inside runs, 0 elsewhere
    '''
    tokens = mask_rle.split()
    run_starts = np.asarray(tokens[::2], dtype=int) - 1  # 1-based -> 0-based
    run_ends = run_starts + np.asarray(tokens[1::2], dtype=int)
    flat = np.zeros(shape[0] * shape[1], dtype=np.float32)
    for begin, end in zip(run_starts, run_ends):
        flat[begin:end] = color
    return flat.reshape(shape).T

def read_rgb_img(img_path):
    """Load an image with tifffile and return it channels-last, plus its (width, height).

    Singleton axes are squeezed away, and an array whose first axis has length 3
    (channels-first) is moved to (H, W, 3).

    Returns:
        (img, orign_shape): the image array and a (width, height) tuple, the
        order cv2.resize expects for its dsize argument.
    """
    arr = np.squeeze(imread(img_path))
    if arr.shape[0] == 3:
        arr = arr.transpose(1, 2, 0)
    return arr, tuple(reversed(arr.shape[:2]))

def get_test_transform():
    """Build the inference-time pipeline: only converts the HWC numpy image to a CHW torch tensor."""
    steps = [ToTensorV2(p=1.0)]
    return A.Compose(steps)

class Customize_Dataset(Dataset):
    """Inference dataset: reads each image listed in `df['image_path']`, downscales it
    and returns it as a float tensor scaled to [0, 1].

    The downscale divisor is read from the global CFG as CFG['img_scale'][CFG['organ']].

    Args:
        df: DataFrame with an 'image_path' column.
        transforms: albumentations-style callable, invoked as transforms(image=img)["image"];
            expected to return a CHW tensor (e.g. ToTensorV2) — a numpy array also works.
    """
    def __init__(self, df, transforms):
        self.df = df
        self.image_path = df['image_path'].values
        self.transforms = transforms

    def __getitem__(self, index):
        img_path = self.image_path[index]
        # ori_shape is (width, height) of the full-resolution image.
        img, ori_shape = read_rgb_img(img_path)

        ## scale adjust
        height, width = img.shape[:2]
        # NOTE(review): int() truncates the configured factor (3.2 -> 3, 2.5 -> 2).
        # Kept as-is because the model was presumably trained with the same
        # truncated factor — confirm against the training pipeline before changing.
        scale = int(CFG['img_scale'][CFG['organ']])
        img = cv2.resize(img, (width // scale, height // scale))

        img = self.transforms(image=img)["image"]
        # ToTensorV2 already returns a tensor; torch.tensor(tensor) would make a
        # redundant copy and emit a UserWarning, so convert the dtype directly.
        img = torch.as_tensor(img).float() / 255
        return {
            'img_path': img_path,
            'image': img,
            'ori_shape': torch.tensor(ori_shape),
        }

    def __len__(self):
        return len(self.df)

# Output folders and configuration (CFG)

In [None]:
# Recreate an empty output tree for this run: copied images (img), ground-truth
# masks (gt) and predicted masks (pt). Later cells hard-code these same paths.
out_root = 'valid_temp_2'
if os.path.exists(out_root):
    shutil.rmtree(out_root)
for sub in ('img', 'gt', 'pt'):
    os.makedirs(os.path.join(out_root, sub))

In [None]:
CFG = {
    'fold': 0,
    # organ selects the validation CSV, the downscale divisor and (for
    # multi-channel models) the output channel used as the foreground map.
    # 'organ': 'kidney',
    'organ': 'largeintestine',

    # side length of the square sliding-window tile fed to the model
    'window_size': 768,
    # per-organ downscale divisor applied to images before inference
    'img_scale': {
        'kidney': (4/1.25),
        'largeintestine': (2.5),
    },

    'TTA': True,            # average predictions over horizontal/vertical flips
    'model_path': None,     # checkpoint path the models are loaded from
    'model': None,          # list of loaded model objects (filled below)
    'show_result': False,   # plot image / prediction / ground-truth overlays per sample
}
CFG['model_path'] = f"./train_model/model_cv{CFG['fold']}_best.pth"
# Alternative checkpoints:
# CFG['model_path']= f"./train_model/model_cv{CFG['fold']}_ep25.pth"
# CFG['model_path']= f"./test_model/effb7_w768_cv0_best/model_cv{CFG['fold']}_best.pth"

# torch.load unpickles a whole nn.Module, which can execute arbitrary code:
# only load trusted checkpoints.
# NOTE(review): torch >= 2.6 defaults to weights_only=True, which rejects pickled
# modules; pass weights_only=False there.
CFG['model'] = [torch.load(CFG['model_path'], map_location='cuda:0')]
# Ensemble variant (last 5 paths returned by glob, which are not sorted):
# CFG['model_path']= "./test_model/effb7_w768_0.58"
# CFG['model']= [torch.load(m, map_location= 'cuda:0') for m in glob.glob(f"{CFG['model_path']}/**/*pth", recursive=True)[-5:]]
# CFG['model']= [torch.load(m, map_location= 'cuda:0') for m in glob.glob(f"{CFG['model']}/**/*pth", recursive=True)[-5:]]

# Prepare Dataset

In [None]:
# Per-organ validation CSVs (columns used later: 'image_path', 'mask_path').
CSV_BY_ORGAN = {
    'kidney': './Data/ex_data_kidney.csv',
    'largeintestine': './Data/ex_data_largeintestine.csv',
}
# Only the first N_VALID rows are run through inference; set to None for all rows.
N_VALID = 5

if CFG['organ'] not in CSV_BY_ORGAN:
    raise ValueError(f"no validation CSV configured for organ {CFG['organ']!r}")
df = pd.read_csv(CSV_BY_ORGAN[CFG['organ']])
df['fold'] = 0
valid_df = df.fillna('')

subset_df = valid_df if N_VALID is None else valid_df.iloc[:N_VALID]
valid_dataset = Customize_Dataset(subset_df, get_test_transform())
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=0)
print(f'valid dataset: {len(valid_dataset)}')
valid_df

# Inference

In [None]:
def _sliding_predict(m, batch):
    """Run sliding-window inference of model `m` over `batch` (N, C, H, W) with the CFG window size."""
    return sliding_window_inference(batch,
                                    (CFG['window_size'], CFG['window_size']),
                                    sw_batch_size=2,
                                    predictor=m,
                                    mode='gaussian',
                                    overlap=0.25)


def _predict_one(m, img):
    """Return one model's raw prediction for a single-image batch as (C, H, W).

    With CFG['TTA'] the image is also run flipped along W, along H and along both;
    each output is flipped back and the four are averaged.
    """
    if not CFG['TTA']:
        return _sliding_predict(m, img)[0]
    imgs = torch.cat([img, img.flip(-1), img.flip(-2), img.flip(-1).flip(-2)], dim=0)
    pred = _sliding_predict(m, imgs)
    return (pred[0] + pred[1].flip(-1) + pred[2].flip(-2) + pred[3].flip(-1).flip(-2)) / 4


def inference(model, img, ori_shape):
    """Ensemble-predict a foreground probability map for one image.

    Args:
        model: non-empty list of segmentation models.
        img: image batch tensor, assumed to hold a single image (1, C, H, W);
            moved to the GPU here.
        ori_shape: (width, height) of the original image — the dsize order of cv2.resize.

    Returns:
        numpy array of shape (height, width): the mean over models of the sigmoid
        output (single-channel models) or of the organ's softmax channel
        (multi-channel models: channel 4 for kidney, channel 5 otherwise).

    Raises:
        ValueError: if `model` is empty.
    """
    if len(model) == 0:
        raise ValueError('inference() needs at least one model')
    img = img.cuda()
    preds = None
    with torch.no_grad():
        for m in model:
            m.eval()
            pred = _predict_one(m, img)
            # Multi-channel logits -> softmax over classes; single channel -> sigmoid.
            prob = pred.softmax(dim=0) if pred.shape[0] != 1 else pred.sigmoid()
            preds = prob if preds is None else preds + prob

    pred = (preds / len(model)).cpu().permute(1, 2, 0).numpy()
    if pred.shape[2] != 1:
        # NOTE(review): presumably the organ's class index in the training label map — confirm.
        channel = 4 if CFG['organ'] == 'kidney' else 5
        pred = pred[..., channel:channel + 1]
    # Pass plain Python ints as dsize; ori_shape may hold numpy integers.
    dsize = tuple(int(v) for v in ori_shape)
    return cv2.resize(pred, dsize)

In [None]:
def _overlay(img, mask):
    """Blend a uint8 grayscale mask (drawn in the red channel) 50/50 over an RGB image."""
    mask_rgb = cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB)
    mask_rgb[:, :, 1:] = 0
    return (img * 0.5).astype(np.uint8) + (mask_rgb * 0.5).astype(np.uint8)


def _show_result(img_path, pred_mask, gt_mask):
    """Plot the image, the predicted mask and the prediction / ground-truth overlays side by side."""
    # Same reader as the dataset, so the image matches the prediction's (height, width).
    img, _ = read_rgb_img(img_path)
    # The ground-truth PNG may differ in size from the image; nearest keeps it binary.
    gt_mask = cv2.resize(gt_mask, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_NEAREST)
    panels = [
        ('image', img),
        ('predict_mask', pred_mask),
        ('mix_image', _overlay(img, pred_mask)),
        ('gt_mask', _overlay(img, gt_mask)),
    ]
    fig, axes = plt.subplots(1, 4, figsize=(15, 15))
    for ax, (title, panel) in zip(axes, panels):
        ax.set_title(title, color='b')
        ax.imshow(panel)
    plt.show()
    plt.close(fig)


sample_idx = 0
for batch in tqdm(valid_loader):
    for j in range(len(batch['image'])):
        img_path = batch['img_path'][j]
        shutil.copy(img_path, f'valid_temp_2/img/{sample_idx}.tiff')
        img = batch['image'][j]
        ori_shape = batch['ori_shape'][j]

        # valid_df comes straight from read_csv (default RangeIndex) and the loader
        # is unshuffled, so the running sample index is also this image's row label.
        mask_path = valid_df.loc[sample_idx, 'mask_path']
        shutil.copy(mask_path, f'valid_temp_2/gt/{sample_idx}.png')

        ## inference
        pred_prob = inference(CFG['model'], torch.unsqueeze(img, dim=0), ori_shape.numpy())
        pred_mask = (pred_prob * 255).astype(np.uint8)
        Image.fromarray(pred_mask).save(f'valid_temp_2/pt/{sample_idx}.png')

        if CFG['show_result']:
            gt_mask = np.array(Image.open(mask_path).convert('L'))
            _show_result(img_path, pred_mask, gt_mask)

        # Free the full-resolution arrays only after they are no longer needed.
        del pred_prob, pred_mask
        sample_idx += 1

# Metric: mean Dice score vs. binarization threshold

In [None]:
# Pair each ground-truth mask with the prediction of the same file name. glob()
# returns files in arbitrary order, so two separate globs must not be zipped by position.
gt = sorted(glob.glob('valid_temp_2/gt/*.png'))
pt = [os.path.join('valid_temp_2/pt', os.path.basename(p)) for p in gt]
missing = [p for p in pt if not os.path.exists(p)]
assert not missing, f'missing predictions: {missing}'

dice_loss = L.DiceLoss(mode='binary', from_logits=False)


def _load_mask(path, size, interpolation):
    """Read a PNG as grayscale scaled to [0, 1], resize it to `size` (width, height) and add a leading axis."""
    mask = np.array(Image.open(path).convert('L')).astype(np.uint8) / 255
    mask = cv2.resize(mask, size, interpolation=interpolation)
    return np.expand_dims(mask, axis=0)


def evaluate(thr=0.5, size=(2048, 2048)):
    """Mean Dice score between thresholded predictions and ground truth over all (gt, pt) pairs.

    Args:
        thr: probability threshold; predicted pixels >= thr count as foreground.
        size: (width, height) both masks are resized to before scoring.

    Returns:
        Mean of 1 - DiceLoss over the image pairs (nan when there are none).
    """
    dices = []
    for gt_path, pt_path in tqdm(list(zip(gt, pt))):
        pt_mask = _load_mask(pt_path, size, cv2.INTER_LINEAR)
        pt_mask = (pt_mask >= thr).astype(np.float64)
        # Nearest-neighbour keeps the ground truth binary; linear interpolation
        # would turn its edges into fractional (soft) labels.
        gt_mask = _load_mask(gt_path, size, cv2.INTER_NEAREST)
        loss = dice_loss(torch.tensor(pt_mask), torch.tensor(gt_mask))
        dices.append(1 - float(loss))
    return np.mean(dices)


thresholds = [k / 10 for k in range(1, 10)]
scores = []
for thr in thresholds:
    print(f'thr= {thr}')
    dice = evaluate(thr)
    scores.append(dice)
    if thr == 0.5:
        print(f'thr_0.5: {dice}')

In [None]:
# scores[k] was computed at threshold (k + 1) / 10 in the metric cell.
thresholds = [(k + 1) / 10 for k in range(len(scores))]
fig, ax = plt.subplots()
ax.plot(thresholds, scores, marker='o')
ax.set(xlabel='binarization threshold', ylabel='mean Dice', title='Validation Dice vs. threshold')
plt.show()