In [2]:
!pip install -q ../input/pytorch-segmentation-models-lib/pretrainedmodels-0.7.4/pretrainedmodels-0.7.4
!pip install -q ../input/pytorch-segmentation-models-lib/efficientnet_pytorch-0.6.3/efficientnet_pytorch-0.6.3
!pip install -q ../input/pytorch-segmentation-models-lib/timm-0.4.12-py3-none-any.whl
!pip install -q ../input/pytorch-segmentation-models-lib/segmentation_models_pytorch-0.2.0-py3-none-any.whl
!mkdir -p /root/.cache/torch/hub/checkpoints
# !cp ../input/effb7-pth/efficientnet-b7-dcc49843.pth /root/.cache/torch/hub/checkpoints/efficientnet-b7-dcc49843.pth

In [3]:
import os
import torch
import numpy as np
import pandas as pd
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp

from glob import glob

In [4]:
def get_class_names(df):
    """Return the distinct values found in the ``class`` column of ``df``.

    The result is whatever ``Series.unique()`` produces for that column,
    i.e. the unique labels in order of first appearance.
    """
    return df['class'].unique()

def make_test_augmenter(conf):
    """Build the albumentations pipeline applied to test images.

    The pipeline centre-crops a square whose side is
    ``round(conf.image_size * conf.crop_size)`` and then converts the
    result with ``ToTensorV2(transpose_mask=True)``.
    """
    side = round(conf.image_size * conf.crop_size)
    steps = [
        A.CenterCrop(height=side, width=side),
        ToTensorV2(transpose_mask=True),
    ]
    return A.Compose(steps)

def get_id(filename):
    """Derive the submission id from a slice path.

    Example:
        ``case123_day20/scans/slice_0001_266_266_1.50_1.50.png``
        becomes ``case123_day20_slice_0001``.

    The id is the third-from-last path component joined with the first two
    underscore-separated fields of the file name.
    """
    parts = filename.split('/')
    case_day = parts[-3]
    slice_tag = '_'.join(parts[-1].split('_')[:2])
    return case_day + '_' + slice_tag

In [5]:
import torch.utils.data as data
class VisionDataset(data.Dataset):
    """Dataset that serves each slice together with its four neighbours.

    Every item is built from the slice at ``index`` plus the two slices
    before and the two after it (same directory, slice number shifted by
    -2..+2). The five CLAHE-equalised slices are stacked as channels,
    scaled by the stack maximum, resized to ``conf.image_size`` and passed
    through ``transform``.
    """

    def __init__(
            self, df, conf, input_dir, imgs_dir,
            class_names, transform, is_test=False, subset=100):
        """Store settings and resolve the image / mask paths.

        Args:
            df: frame with an ``img_files`` column of paths relative to
                ``input_dir/imgs_dir``.
            conf: settings object; ``image_size`` is read when resizing.
            input_dir, imgs_dir: joined in front of every image path.
            class_names: accepted for interface compatibility; not used here.
            transform: callable invoked as ``transform(image=...)`` or
                ``transform(image=..., mask=...)``.
            is_test: when true no mask is loaded and ``0`` is returned in
                its place.
            subset: percentage (<= 100) of leading rows of ``df`` to keep.
        """
        self.conf = conf
        self.transform = transform
        self.is_test = is_test
        self.CLAHE = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))

        if subset != 100:
            assert subset < 100
            # keep only the first `subset` percent of the rows
            keep = df.shape[0] * subset // 100
            df = df.iloc[:keep]

        rel_paths = df['img_files']
        self.files = [os.path.join(input_dir, imgs_dir, rel) for rel in rel_paths]
        # Mask paths are the relative image paths with 'train' swapped for
        # 'masks'; they are only read when is_test is false.
        self.masks = [rel.replace('train', 'masks') for rel in rel_paths]

    def resize(self, img, interp):
        """Resize ``img`` to a ``conf.image_size`` square with ``interp``."""
        side = self.conf.image_size
        return cv2.resize(img, (side, side), interpolation=interp)

    def load_slice(self, img_file, diff):
        """Load the slice ``diff`` positions away from ``img_file``.

        The neighbour's path is obtained by rewriting the 4-digit slice
        number inside ``img_file``. Returns the CLAHE-equalised image, or
        ``None`` when no such file exists.
        """
        current = os.path.basename(img_file).split('_')[1]
        shifted = str(int(current) + diff).zfill(4)
        neighbour = img_file.replace('slice_' + current, 'slice_' + shifted)
        if not os.path.exists(neighbour):
            return None
        return self.CLAHE.apply(cv2.imread(neighbour, cv2.IMREAD_UNCHANGED))

    def __getitem__(self, index):
        """Return ``(image, mask)`` for ``index``; ``mask`` is ``0`` in test mode."""
        img_file = self.files[index]
        # offsets -2..+2 -> positions 0..4, position 2 is the requested slice
        stack = [self.load_slice(img_file, offset) for offset in range(-2, 3)]

        # Missing neighbours are replaced by the nearest slice towards the
        # centre: fill outwards on the right (3 then 4), then on the left
        # (1 then 0).
        for pos in (3, 4):
            if stack[pos] is None:
                stack[pos] = stack[pos - 1]
        for pos in (1, 0):
            if stack[pos] is None:
                stack[pos] = stack[pos + 1]

        img = np.stack(stack, axis=2).astype(np.float32)
        peak = img.max()
        if peak != 0:
            # scale by the stack maximum (skipped for an all-zero stack)
            img /= peak
        img = self.resize(img, cv2.INTER_AREA)

        if self.is_test:
            return self.transform(image=img)['image'], 0

        # training / validation: load the matching mask as well
        msk = cv2.imread(self.masks[index], cv2.IMREAD_UNCHANGED)
        msk = self.resize(msk, cv2.INTER_NEAREST).astype(np.float32)
        augmented = self.transform(image=img, mask=msk)
        return augmented['image'], augmented['mask']

    def __len__(self):
        """Number of image files in the dataset."""
        return len(self.files)

In [6]:
import torch.nn as nn
import segmentation_models_pytorch as smp

class ModelWrapper(nn.Module):
    """``nn.Module`` wrapper around a segmentation_models_pytorch network.

    The wrapped network is stored in ``self.model`` and takes 5-channel
    input; it is built with ``activation=None``, so ``forward`` returns the
    network's raw output.
    """

    # Values accepted for ``conf.arch``; each one is the name of a model
    # class looked up on the segmentation_models_pytorch module.
    SUPPORTED_ARCHS = ('FPN', 'Unet', 'DeepLabV3')

    def __init__(self, conf, num_classes):
        """Build the network described by ``conf``.

        Args:
            conf: settings object providing ``arch``, ``backbone`` and
                ``pretrained``.
            num_classes: number of output channels.

        Raises:
            ValueError: if ``conf.arch`` is not one of ``SUPPORTED_ARCHS``.
        """
        super().__init__()
        # An explicit exception instead of `assert`: assertions are removed
        # under `python -O`, which would let an unknown architecture through.
        if conf.arch not in self.SUPPORTED_ARCHS:
            raise ValueError(
                f'Unknown architecture {conf.arch}; '
                f'expected one of {self.SUPPORTED_ARCHS}')
        arch = getattr(smp, conf.arch)

        weights = 'imagenet' if conf.pretrained else None
        self.model = arch(
            encoder_name=conf.backbone, encoder_weights=weights, in_channels=5,
            classes=num_classes, activation=None)

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output."""
        return self.model(x)


In [7]:
from scipy.ndimage.morphology import binary_dilation
def drop_small_mask(mask, th):
    """Remove small connected components from a binary mask.

    Args:
        mask: 2-D array; it is cast to ``uint8`` before labelling, so
            non-zero (after the cast) pixels are foreground.
        th: size threshold in pixels. A component is kept only when it has
            strictly more than ``th`` pixels.

    Returns:
        Array with the shape and dtype of ``mask`` holding 1 on the pixels
        of the kept components and 0 elsewhere.
    """
    ret, labels = cv2.connectedComponents(mask.astype(np.uint8), connectivity=4)
    # One pass over the label image yields the size of every component,
    # instead of rescanning the whole image once per component.
    sizes = np.bincount(labels.ravel(), minlength=ret)
    keep = sizes > th
    keep[:1] = False  # label 0 is the background and is never kept
    return keep[labels].astype(mask.dtype)

In [8]:
class Config():
    """Settings object passed around as ``conf`` by the helpers in this notebook."""

    # one of 'FPN', 'Unet' or 'DeepLabV3' (the values ModelWrapper accepts)
    arch = 'FPN'
    # encoder name handed to segmentation_models_pytorch
    backbone = 'efficientnet-b7'
    # request ImageNet encoder weights; create_model() sets this to False
    # before building the network because it loads a checkpoint instead
    pretrained = True

    # resize images to this size on the fly
    image_size = 512
    # crop to this fraction of image_size
    crop_size = 1.0

    # optimizer settings
    # NOTE(review): optim / lr / weight_decay are not read by any cell visible
    # here — presumably kept in sync with the training notebook; confirm.
    optim = 'adam'
    lr = 0.001
    weight_decay = 0.01
    # batch size used by the test DataLoader
    batch_size = 48

    # scheduler settings
    # NOTE(review): not read by any cell visible here.
    gamma = 0.96

    # data augmentation
    # NOTE(review): not read by any cell visible here.
    aug_prob = 0.4
    strong_aug = True
    max_cutout = 0
    
# Module-level settings instance and compute device used by the cells below.
conf = Config()
# prefer the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [9]:
def _find_png_files(input_dir, img_dir, max_depth=4):
    """Return the sorted ``*.png`` paths at the shallowest populated depth.

    ``input_dir/img_dir`` is searched level by level (0 up to ``max_depth``
    intermediate directories); the first level that contains any PNG wins.
    An empty list is returned when nothing is found.
    """
    for depth in range(max_depth + 1):
        pattern = f'{input_dir}/{img_dir}/' + '*/' * depth + '*.png'
        found = sorted(glob(pattern))
        if found:
            return found
    return []


def create_test_loader(conf, input_dir, class_names):
    """Create the inference ``DataLoader`` and the frame listing its files.

    PNG slices are looked up under ``input_dir/test``. When that directory
    holds none, the first 1000 (sorted) slices found under
    ``input_dir/train`` are used instead.

    Args:
        conf: settings object; ``batch_size`` is read here and the object is
            forwarded to the augmenter and the dataset.
        input_dir: dataset root directory.
        class_names: forwarded to ``VisionDataset``.

    Returns:
        ``(loader, test_df)`` where ``test_df['img_files']`` holds the slice
        paths relative to the directory that was searched, in loader order.
    """
    test_aug = make_test_augmenter(conf)

    img_dir = 'test'
    img_files = _find_png_files(input_dir, img_dir)
    if not img_files:
        # no test images available: fall back to a slice of the train set
        img_dir = 'train'
        img_files = _find_png_files(input_dir, img_dir)[:1000]

    # delete common prefix from paths
    prefix = f'{input_dir}/{img_dir}/'
    img_files = [f.replace(prefix, '') for f in img_files]

    test_df = pd.DataFrame()
    test_df['img_files'] = img_files
    test_dataset = VisionDataset(
        test_df, conf, input_dir, img_dir,
        class_names, test_aug, is_test=True)
    print(f'{len(test_dataset)} examples in test set')
    loader = data.DataLoader(
        test_dataset, batch_size=conf.batch_size, shuffle=False,
        num_workers=2, pin_memory=False)
    return loader, test_df

In [10]:
def create_model(conf, model_dir, num_classes):
    """Build a ``ModelWrapper`` and load checkpoint weights into it.

    Args:
        conf: settings object. Its ``pretrained`` attribute is set to
            ``False`` as a side effect, so the wrapper is built without
            requesting ImageNet weights.
        model_dir: path of the checkpoint file; the weights are read from
            its ``'model'`` entry.
        num_classes: number of output channels of the network.

    Returns:
        The model on ``device`` with the checkpoint weights loaded.
    """
    state = torch.load(model_dir, map_location=device)['model']

    # Remove every 'module.' fragment from the parameter names so they match
    # the names of a plain ModelWrapper.
    cleaned = {}
    for key, tensor in state.items():
        cleaned[key.replace('module.', '')] = tensor

    conf.pretrained = False
    net = ModelWrapper(conf, num_classes).to(device)
    net.load_state_dict(cleaned)
    return net


# def create_model(conf, model_dir, num_classes):
# #     checkpoint = torch.load(model_dir, map_location=device)['model']
# #     pretrained_dict = {k.replace('module.', '') : v for k, v in checkpoint.items()}
#     conf.pretrained = False  
#     model = ModelWrapper(conf, num_classes)
#     model = model.to(device)
# #     model.load_state_dict(pretrained_dict)
#     return model

In [11]:
def rle_encode(img):
    '''
    Run-length encode a binary mask.

    Adapted from
    https://www.kaggle.com/code/stainsby/fast-tested-rle/notebook

    img: numpy array, 1 - mask, 0 - background
    Returns the run lengths as a space-separated string of
    "start length" pairs; starts are 1-based positions in the flattened
    array. An all-background mask gives an empty string.
    '''
    # Zero sentinels on both ends make every run produce a rising and a
    # falling edge, even when it touches the array border.
    padded = np.concatenate([[0], img.flatten(), [0]])
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # even entries are run starts, odd entries are run ends -> lengths
    edges[1::2] -= edges[::2]
    return ' '.join(map(str, edges))

def get_img_shape(filename):
    """Read the slice size encoded in a file name.

    The base name is split on ``_``; field 3 is returned as the height and
    field 2 as the width, e.g. ``slice_0001_266_360_1.50_1.50.png`` gives
    ``(360, 266)``.
    """
    fields = os.path.basename(filename).split('_')
    width = int(fields[2])
    height = int(fields[3])
    return (height, width)

def pad_mask(conf, mask):
    """Centre ``mask`` on a zero-filled ``conf.image_size`` square.

    The returned array has the dtype of ``mask``. When the free space is
    odd, the extra row / column ends up on the bottom / right side.
    """
    side = conf.image_size
    rows, cols = mask.shape[0], mask.shape[1]
    row0 = (side - rows) // 2
    col0 = (side - cols) // 2

    canvas = np.zeros((side, side), dtype=mask.dtype)
    canvas[row0:row0 + rows, col0:col0 + cols] = mask
    return canvas

def resize_mask(mask, height, width):
    """Resize ``mask`` to ``height`` x ``width`` with nearest-neighbour sampling."""
    target = (width, height)  # cv2.resize takes the size as (width, height)
    return cv2.resize(mask, target, interpolation=cv2.INTER_NEAREST)

In [18]:
import cv2
import numpy as np

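'''
Image notes:
The input is a binarized image: 255 (white) marks the target object and 0 (black) the background.
The goal is to fill the black holes inside the white target object.
'''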


def FillHole(im_in):
    """Fill the background holes enclosed by foreground in a binary mask.

    A hole is a group of zero pixels that cannot be reached from the image
    border through zero pixels (4-connected in 2-D). Non-zero pixels are
    treated as foreground, so both 0/1 and 0/255 masks work.

    Args:
        im_in: numpy mask; it is not modified.

    Returns:
        A copy of ``im_in`` (same shape and dtype) in which every hole pixel
        is set to the largest value present in ``im_in``.
    """
    # Function-scope import keeps the block self-contained.
    from scipy.ndimage import binary_fill_holes

    im_out = im_in.copy()
    foreground = im_in != 0
    if not foreground.any():
        # nothing to enclose a hole (also covers empty arrays)
        return im_out

    # Work on a boolean view so the result does not depend on the mask's
    # dtype or on which value encodes foreground.
    holes = binary_fill_holes(foreground) & ~foreground
    im_out[holes] = im_in.max()
    return im_out

In [22]:
import PIL
import torchvision
import torchvision.transforms.functional as F
def _tta_probabilities(model, images):
    """Yield the sigmoid probabilities of ``model`` for each test-time view.

    Four views are produced, each mapped back to the input orientation:
    the original batch, a flip along the last axis, and rotations by
    +25 and -25 degrees. Every item is a numpy array on the CPU.
    """
    rotate = torchvision.transforms.functional.rotate
    # rotation centre as [x, y]
    center = [images.shape[-1] // 2, images.shape[-2] // 2]

    yield torch.sigmoid(model(images)).cpu().numpy()

    flipped = torch.flip(images, [-1])
    yield torch.sigmoid(torch.flip(model(flipped), [-1])).cpu().numpy()

    for degree in (25.0, -25.0):
        probs = torch.sigmoid(model(rotate(images, degree, center=center)))
        # Rotate the probabilities, not the logits: the corners exposed by
        # the inverse rotation are zero-filled, which must mean "probability
        # 0" rather than "logit 0" (= probability 0.5).
        yield rotate(probs, -degree, center=center).cpu().numpy()


def _encode_class_mask(binary_mask, height, width):
    """Post-process one binary class mask and return its RLE string.

    Small components are dropped (128-pixel threshold), holes are filled,
    the mask is padded back to ``conf.image_size`` and resized to the
    original slice size. An empty mask is encoded as ``''``.
    """
    mask = drop_small_mask(binary_mask, 128)
    mask = FillHole(mask)
    mask = pad_mask(conf, mask)
    mask = resize_mask(mask, height, width)
    return '' if mask.sum() == 0 else rle_encode(mask)


def run(input_dir, ckpt_paths, thresh):
    """Run ensemble inference with TTA and write ``submission.csv``.

    Args:
        input_dir: dataset root holding ``train.csv``,
            ``sample_submission.csv`` and the image folders.
        ckpt_paths: checkpoint files; the probabilities of all of them (and
            of all test-time views) are averaged.
        thresh: probability at or above which a pixel is foreground.

    Raises:
        ValueError: if ``ckpt_paths`` is empty.

    Side effects:
        Writes ``submission.csv`` to the working directory when at least one
        prediction row was produced. Uses the module-level ``conf`` and
        ``device``.
    """
    ckpt_paths = list(ckpt_paths)
    if not ckpt_paths:
        # averaging over zero models would divide by zero and give NaN masks
        raise ValueError('run() needs at least one checkpoint path')

    train_df = pd.read_csv(os.path.join(input_dir, 'train.csv'), dtype=str)
    class_names = np.array(get_class_names(train_df))
    num_classes = len(class_names)
    loader, df = create_test_loader(conf, input_dir, class_names)
    img_files = df['img_files']
    subm = pd.read_csv(f'{input_dir}/sample_submission.csv')
    del subm['predicted']

    # Load every checkpoint once instead of once per batch. All ensemble
    # members stay resident on `device` for the whole run.
    models = [create_model(conf, path, num_classes).eval() for path in ckpt_paths]

    ids = []
    classes = []
    masks = []
    img_idx = 0

    with torch.no_grad():
        for images, _ in loader:
            images = images.to(device)
            batch_size, _, in_height, in_width = images.shape
            masks_tta = np.zeros(
                (batch_size, num_classes, in_height, in_width), dtype=np.float32)
            n_views = 0
            for model in models:
                for probs in _tta_probabilities(model, images):
                    masks_tta += probs
                    n_views += 1
            masks_tta /= n_views
            # binarise, keeping float32 so downstream dtypes are unchanged
            masks_tta = (masks_tta >= thresh).astype(np.float32)

            for pred in masks_tta:
                img_file = img_files.iloc[img_idx]
                img_idx += 1
                img_id = get_id(img_file)
                height, width = get_img_shape(img_file)
                for class_id, class_name in enumerate(class_names):
                    ids.append(img_id)
                    classes.append(class_name)
                    masks.append(_encode_class_mask(pred[class_id], height, width))

    pred_df = pd.DataFrame({'id': ids, 'class': classes, 'predicted': masks})
    if pred_df.shape[0] > 0:
        # sort according to the given order and save to a csv file
        subm = subm.merge(pred_df, on=['id', 'class'])

        subm.to_csv('submission.csv', index=False)

In [23]:
# probability threshold applied to the averaged ensemble output
test_thresh = 0.35
# glob pattern matching every checkpoint of the ensemble
ckpt_paths ='../input/checkpoint/*.pth'
# glob() returns matches in arbitrary, filesystem-dependent order. Sorting
# fixes the order in which the models' float32 predictions are accumulated,
# so repeated runs add them up identically.
ckpt_paths_list = sorted(glob(ckpt_paths))
print(ckpt_paths_list)
run('../input/uw-madison-gi-tract-image-segmentation', ckpt_paths_list, test_thresh)

In [19]:
# NOTE(review): `mask` is not defined at notebook scope in any visible cell, and
# this cell's execution count (19) is lower than that of the cells above it.
# It looks like a leftover interactive shape check that would raise NameError on
# a fresh "Restart & Run All" — confirm and remove.
FillHole(mask).shape