In [None]:
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

import torch
import torch.nn as nn
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
from torch.cuda import amp
import timm
from pytorch_toolbelt.inference import tta
from pytorch_toolbelt import losses as L

import random
from tqdm.auto import tqdm
from PIL import Image
import numpy as np
import cv2
import glob
from matplotlib import pyplot as plt
import os
import json
import pandas as pd
import segmentation_models_pytorch as smp
from monai.inferers import sliding_window_inference
import shutil
from tifffile import imread
from IPython.display import display
import warnings
warnings.filterwarnings('ignore')
import PIL
PIL.Image.MAX_IMAGE_PIXELS = 933120000

# Expose only GPU index 1 to this process (must be set before CUDA is initialised).
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

def seed_everything(seed=123):
    """Seed every random number generator used by the notebook.

    Covers Python's `random`, NumPy, and PyTorch (CPU and all visible CUDA
    devices), and switches cuDNN into deterministic mode.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    # Only influences hashing in subprocesses started after this point.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible CUDA device, not just the current one
    # (silently ignored when CUDA is unavailable).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode autotunes kernels per input shape and can select a
    # different algorithm from run to run, defeating `deterministic`.
    torch.backends.cudnn.benchmark = False
seed_everything()

In [None]:
def rle_decode(mask_rle, shape, color=1):
    '''
    Decode a run-length encoded mask.

    mask_rle: string of whitespace-separated "start length" pairs (starts are 1-based)
    shape: two-element size of the array to build before it is transposed
    color: value written into the masked pixels
    Returns a float32 numpy array: `color` inside the runs, 0 elsewhere
    '''
    tokens = mask_rle.split()
    run_starts = np.asarray(tokens[0::2], dtype=int)
    run_lengths = np.asarray(tokens[1::2], dtype=int)
    # RLE positions are 1-based; shift to 0-based flat indices.
    run_starts -= 1
    run_ends = run_starts + run_lengths

    flat = np.zeros(shape[0] * shape[1], dtype=np.float32)
    for begin, stop in zip(run_starts, run_ends):
        flat[begin : stop] = color
    # The flat buffer is filled in transposed order, hence the final `.T`.
    return flat.reshape(shape).T

def read_rgb_img(img_path):
    """Read an image with tifffile and report its first two dimensions reversed.

    When CFG['domain_shift'] is truthy, the 'train_images' part of the path is
    swapped for 'train_images_domain_shift' before reading.

    Returns:
        (image array, tuple of the first two array dimensions in reversed order)
    """
    source_path = img_path
    if CFG['domain_shift']:
        source_path = source_path.replace('train_images', 'train_images_domain_shift')
    pixels = imread(source_path)
    reversed_dims = tuple(reversed(pixels.shape[:2]))
    return pixels, reversed_dims

def get_test_transform():
    """Build the inference-time pipeline: pad up to the window size, then convert to a tensor."""
    window = CFG['window_size']
    steps = [
        # Guarantees the image is at least one sliding window in each dimension.
        A.PadIfNeeded(min_height=window, min_width=window, border_mode=0, p=1),
        ToTensorV2(p=1.0),
    ]
    return A.Compose(steps)

def get_crop_transform(img_size):
    """Build a pipeline that center-crops to an `img_size` x `img_size` square."""
    center_crop = A.CenterCrop(img_size, img_size, p=1)
    return A.Compose([center_crop])

class Customize_Dataset(Dataset):
    """Dataset that yields rescaled, transformed images together with their metadata.

    Args:
        df: DataFrame providing the 'image_path' and 'organ' columns.
        transforms: callable invoked as `transforms(image=...)`, returning a
            mapping with an "image" entry.
    """

    def __init__(self, df, transforms):
        self.df = df
        self.transforms = transforms
        self.image_path = df['image_path'].values
        self.organs = df['organ'].values

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        path = self.image_path[index]
        organ_name = self.organs[index]
        raw, ori_shape = read_rgb_img(path)

        # Shrink the shorter side by CFG['img_scale'] and resize to a square of that size.
        side = int(min(raw.shape[0], raw.shape[1]) / CFG['img_scale'])
        resized = cv2.resize(raw, (side, side))

        # Remember the pre-padding size when the image is smaller than one
        # window, so the prediction can be cropped back later; 0 means "no padding".
        pad = resized.shape[0] if resized.shape[0] < CFG['window_size'] else 0
        transformed = self.transforms(image=resized)["image"]

        sample = {
            'img_path': path,
            'image': torch.tensor(transformed / 255, dtype=torch.float32),
            'ori_shape': torch.tensor(ori_shape),
            'organ': organ_name,
            'pad': pad,
        }
        return sample

In [None]:
class customize_model(nn.Module):
    """Thin wrapper that returns the 'logits' entry of the wrapped model's output.

    NOTE(review): `self.model` is never assigned in `__init__`; presumably the
    attribute is restored when a whole pickled model is loaded with
    `torch.load` -- confirm against the training code.
    """

    def __init__(self, model_name):
        super().__init__()

    def forward(self, images):
        outputs = self.model(images)
        return outputs['logits']

# CFG

In [None]:
# Start every run from an empty scratch directory. The delete is guarded so
# the cell also works on a fresh checkout where the directory does not exist
# yet, while genuine deletion errors still surface.
if os.path.isdir('valid_temp_1'):
    shutil.rmtree('valid_temp_1')
os.makedirs('valid_temp_1/gt')  # ground-truth masks
os.makedirs('valid_temp_1/pt')  # predicted masks

In [None]:
# Run configuration shared by the helpers and cells of this notebook.
CFG = dict(
    fold=0,              # validation fold; also selects the checkpoint file
    img_scale=4,         # divisor applied to the shorter image side before resizing
    window_size=768,     # sliding-window size and minimum padded image size
    TTA=True,            # 8-way rotation/flip test-time augmentation in `inference`
    model=None,          # replaced below: first by the checkpoint path, then by the loaded model list
    domain_shift=True,   # read images from 'train_images_domain_shift'
    show_result=False,   # plot image / prediction / ground-truth panels
)
CFG['model'] = f"./train_model/model_cv{CFG['fold']}_best.pth"
# CFG['model']= f"./train_model/model_cv{CFG['fold']}_ep100.pth"
# CFG['model']= f"./test_model/effb7_w768_cv0_best/model_cv{CFG['fold']}_best.pth"

# NOTE(review): torch.load unpickles the checkpoint -- only load trusted files.
CFG['model'] = [torch.load(CFG['model'], map_location='cuda:0')]

# Prepare Dataset

In [None]:
df = pd.read_csv('./Data/train.csv')

# Keep only the rows assigned to the configured validation fold.
in_valid_fold = df['fold'] == CFG['fold']
valid_df = df[in_valid_fold].reset_index(drop=True)
print(f'valid dataset: {len(valid_df)}')

valid_dataset = Customize_Dataset(valid_df, get_test_transform())
valid_loader = DataLoader(
    valid_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
)
valid_df.head()

# Inference

In [None]:
def inference(model, img, ori_shape):
    """Run sliding-window inference with every entry of `model` and average the results.

    Args:
        model: sequence of modules; each is switched to eval mode and used as
            the sliding-window predictor.
        img: 4-D image tensor (the caller adds a leading batch dimension of 1);
            it is moved to the GPU here.
        ori_shape: accepted for call compatibility; not used inside this function.

    Returns:
        numpy array with the channel axis moved last, holding the mean over
        models of the sigmoid (single-channel output) or softmax
        (multi-channel output) activations.
    """
    window = (CFG['window_size'], CFG['window_size'])
    img = img.cuda()

    def run_windows(batch, net):
        # Single place for the sliding-window settings shared by both branches.
        return sliding_window_inference(batch,
                                        window,
                                        sw_batch_size=2,
                                        predictor=net,
                                        mode='gaussian',
                                        overlap=0.25)

    for idx, net in enumerate(model):
        with torch.no_grad():
            net.eval()
            if CFG['TTA']:
                mirrored = img.flip(-1)
                # Views 0-3: rotations of the image; views 4-7: rotations of its mirror.
                views = [img]
                views += [img.rot90(k, [2, 3]) for k in (1, 2, 3)]
                views.append(mirrored)
                views += [mirrored.rot90(k, [2, 3]) for k in (1, 2, 3)]
                stacked = run_windows(torch.cat(views, dim=0), net)

                # Undo each augmentation (rotate back, then un-mirror) before averaging.
                restored = [stacked[0]]
                restored += [stacked[k].rot90(-k, [1, 2]) for k in (1, 2, 3)]
                restored.append(stacked[4].flip(-1))
                restored += [stacked[4 + k].rot90(-k, [1, 2]).flip(-1) for k in (1, 2, 3)]

                merged = restored[0]
                for view_pred in restored[1:]:
                    merged = merged + view_pred
                pred = merged / 8
            else:
                pred = run_windows(img, net)[0]

        # One output channel -> sigmoid; several channels -> softmax over channels.
        if pred.shape[0] == 1:
            activated = pred.sigmoid()
        else:
            activated = pred.softmax(dim=0)

        if idx == 0:
            preds = activated
        else:
            preds += activated

    pred = preds / len(model)
    pred = pred.cpu().permute(1, 2, 0).numpy()
    return pred

In [None]:
# Organ order used to pick the prediction channel: channel index = list index + 1
# (channel 0 is presumably background -- TODO confirm against training).
organs= [
    'lung',
    'spleen',
    'prostate',
    'kidney',
    'largeintestine',
]

# Predict every validation image and write prediction / ground-truth PNG pairs
# into valid_temp_1/pt and valid_temp_1/gt.
for i, data in enumerate(tqdm(valid_loader)):
    for j in range(len(data['image'])):
        img_path= data['img_path'][j]
        img= data['image'][j]
        ori_shape= data['ori_shape'][j]
        organ= data['organ'][j]
        # The DataLoader collates `pad` into a tensor; use a plain int so it
        # can be handed to the crop transform safely.
        pad= int(data['pad'][j])

        # Only the size is needed here; close the file handle right away.
        with Image.open(img_path) as src:
            img_size= src.size[::-1]
        # basename() instead of splitting on a hard-coded '/' separator.
        id_= os.path.basename(img_path).split('.')[0]
        rle= valid_df.loc[valid_df['id']==int(id_), 'rle'].values[0]
        gt_mask= rle_decode(rle, img_size)

        ## inference
        img= torch.unsqueeze(img, dim= 0)
        pred_mask= inference(CFG['model'], img, ori_shape.numpy())

        # Multi-class output: keep only the channel of this image's organ.
        if pred_mask.shape[2]!=1:
            pred_mask= pred_mask[..., organs.index(organ)+1]

        ## if padding: crop the prediction back to the pre-padding size
        if pad:
            aug= get_crop_transform(pad)
            pred_mask= aug(image= pred_mask)['image']
        # cv2 expects plain Python ints for dsize; iterating a tensor yields
        # 0-d tensors, which strict OpenCV argument parsing may reject.
        pred_mask= cv2.resize(pred_mask, tuple(int(v) for v in ori_shape))

        pred_mask= pred_mask*255
        im= Image.fromarray(pred_mask.astype(np.uint8))
        im.save(f'valid_temp_1/pt/{id_}.png')
        gt_mask= gt_mask*255
        im= Image.fromarray(gt_mask.astype(np.uint8))
        im.save(f'valid_temp_1/gt/{id_}.png')

        if CFG['show_result'] and organ=='spleen':
            print(f'organ: {organ}')
            ## show result
            plt.figure(figsize=(15,15))
            plt.subplot(1,4,1)
            img= np.array(Image.open(img_path))
            plt.title('image', color= 'b')
            plt.imshow(img)

            plt.subplot(1,4,2)
            plt.title('predict_mask', color= 'b')
            plt.imshow(pred_mask)

            # Overlay: keep the mask in the first channel only and blend 50/50 with the image.
            pred_mask= cv2.cvtColor(pred_mask, cv2.COLOR_GRAY2RGB)
            pred_mask[:,:,1:]= 0
            mix_img= (img*0.5).astype(np.uint8) + (pred_mask*0.5).astype(np.uint8)
            plt.subplot(1,4,3)
            plt.title('mix_iamge', color= 'b')
            plt.imshow(mix_img)

            gt_mask= cv2.cvtColor(gt_mask, cv2.COLOR_GRAY2RGB)
            gt_mask[:,:,1:]= 0
            mix_img= (img*0.5).astype(np.uint8) + (gt_mask*0.5).astype(np.uint8)
            plt.subplot(1,4,4)
            plt.title('gt_mask', color= 'b')
            plt.imshow(mix_img)
            plt.show()

# Metric

In [None]:
# glob returns files in arbitrary order; sorting makes both lists
# deterministic and aligned by file name (the two folders hold identically
# named PNGs).
gt= sorted(glob.glob('valid_temp_1/gt/**.png'))
pt= sorted(glob.glob('valid_temp_1/pt/**.png'))

from pytorch_toolbelt import losses as L
# Binary Dice on already-thresholded masks (inputs are not logits).
dice_loss= L.DiceLoss(mode= 'binary', from_logits=False)

def evaluate(thr= 0.5):
    """Score the saved predictions against the saved ground truth at one threshold.

    Every prediction PNG listed in `pt` is binarised at `thr` and compared,
    with `dice_loss`, to the ground-truth PNG that carries the same image id.

    Args:
        thr: probability threshold applied to predictions scaled to [0, 1].

    Returns:
        (mean Dice over all images, dict mapping organ -> mean Dice rounded to
        3 decimals; NaN for an organ with no images).

    Raises:
        KeyError: if a prediction has no ground-truth file with the same id.
    """
    dice_organ= {
        'prostate': [],
        'spleen': [],
        'lung': [],
        'kidney': [],
        'largeintestine': [],
    }
    dices= []

    def _image_id(path):
        # basename() copes with the platform's path separators, unlike a
        # hard-coded '\\' split that only works on Windows.
        return os.path.basename(path).split('.')[0]

    # Pair files by image id instead of by list position: glob order is
    # arbitrary, so gt[i] and pt[i] are not guaranteed to match.
    gt_by_id= {_image_id(path): path for path in gt}

    for pt_path in tqdm(pt):
        id_= _image_id(pt_path)
        organ= df[df['id']==int(id_)]['organ'].values[0]

        with Image.open(pt_path) as pred_img:
            pt_mask= np.array(pred_img.convert('L'))/255
        # Binarise in a single step: 1 where the probability reaches thr, else 0.
        pt_mask= np.expand_dims((pt_mask>=thr).astype(pt_mask.dtype), axis=0)

        with Image.open(gt_by_id[id_]) as gt_img:
            gt_mask= np.array(gt_img.convert('L'))/255
        gt_mask= np.expand_dims(gt_mask, axis=0)

        loss= dice_loss( torch.tensor(pt_mask), torch.tensor(gt_mask) )
        # Store plain floats so the NumPy means below do not operate on tensors.
        score= float(1-loss)
        dices.append(score)

        dice_organ[organ].append(score)

    for key in dice_organ.keys():
        values= dice_organ[key]
        dice_organ[key]= round( np.mean(values), 3) if values else np.nan

    return np.mean(dices), dice_organ

# Sweep thresholds 0.1 .. 0.9, recording the organ-averaged Dice and each
# organ's own Dice at every step.
scores = []
thr_score_organ = {
    name: [] for name in ('prostate', 'spleen', 'lung', 'kidney', 'largeintestine')
}
for step in range(1, 10):
    thr = step / 10
    print(f'thr= {thr}')
    dice, organ_dice = evaluate(thr)

    # Mean over organs (not over images), so every organ weighs the same.
    avg_dice = list(organ_dice.values())
    scores.append(np.mean(avg_dice))

    for key, value in organ_dice.items():
        thr_score_organ[key].append(value)
print(thr_score_organ)

In [None]:
# Best threshold per organ; filled in while plotting below.
organ_thr = dict.fromkeys(
    ['lung', 'spleen', 'prostate', 'kidney', 'largeintestine'], 0
)

# Organ-averaged score at each sweep step.
plt.figure(figsize=(6,3))
plt.plot(scores, label='mean score')
plt.legend()
plt.show()

# One panel per organ; curve index k corresponds to threshold (k + 1) / 10.
plt.figure(figsize=(10,10))
for panel, (key, curve) in enumerate(thr_score_organ.items(), start=1):
    plt.subplot(3, 2, panel)
    best_thr = (np.argmax(curve) + 1) / 10
    organ_thr[key] = best_thr
    print(f'{key}: best thr= {best_thr}')
    plt.plot(curve, color='r', label=f'{key}')
    plt.legend()
plt.show()