In [1]:
!pip install --no-index --find-links="../input/download-segmentation-model-pytorch-packages" segmentation-models-pytorch

Looking in links: ../input/download-segmentation-model-pytorch-packages
Processing /kaggle/input/download-segmentation-model-pytorch-packages/segmentation_models_pytorch-0.2.1-py3-none-any.whl
Processing /kaggle/input/download-segmentation-model-pytorch-packages/efficientnet_pytorch-0.6.3.tar.gz
  Preparing metadata (setup.py) ... [?25l- done
[?25hProcessing /kaggle/input/download-segmentation-model-pytorch-packages/timm-0.4.12-py3-none-any.whl
Processing /kaggle/input/download-segmentation-model-pytorch-packages/pretrainedmodels-0.7.4.tar.gz
  Preparing metadata (setup.py) ... [?25l- done
Building wheels for collected packages: efficientnet-pytorch, pretrainedmodels
  Building wheel for efficientnet-pytorch (setup.py) ... [?25l- \ done
[?25h  Created wheel for efficientnet-pytorch: filename=efficientnet_pytorch-0.6.3-py3-none-any.whl size=12421 sha256=e4d795b82c19958fd974a2a3744e66eff4e2ba319849fa2a28f9cba0c095060d
  Stored in directory: /root/.cache/pip/wheels/

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from skimage.io import imread
import math
import cv2
from matplotlib.patches import Rectangle

import torch
from torch.utils.data import random_split
from torch.utils.data import DataLoader
from torchvision import transforms

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import torch.optim as optim
from torch.cuda import amp

from tensorflow.keras.utils import to_categorical

import segmentation_models_pytorch as smp

import os, glob

In [3]:
# Data location and inference settings used by the cells below.
PATH = '../input/uw-madison-gi-tract-image-segmentation'
RESCALE_SIZE = 224  # default side length used by load_img when resizing
batch_size = 32
SEED = 42

# Seed torch's global RNG.
torch.manual_seed(SEED)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cuda')

In [4]:
# Load the submission template. When it contains no rows, switch to debug
# mode and use the ids from train.csv instead, so the rest of the notebook
# still has images to run on.
sub_df = pd.read_csv(PATH + '/sample_submission.csv')
if not len(sub_df):
    debug = True
    sub_df = pd.read_csv(PATH + '/train.csv')
    sub_df = sub_df.drop(columns=['class','segmentation']).drop_duplicates()
else:
    debug = False
    sub_df = sub_df.drop(columns=['class','predicted']).drop_duplicates()

# Ids look like 'case123_day20_slice_0001'; splitting on '_' gives
# ['case123', 'day20', 'slice', '0001'].
ids = [row.split('_') for row in sub_df['id']]
cases = [x[0][4:] for x in ids]   # strip the 'case' prefix
days = [x[1][3:] for x in ids]    # strip the 'day' prefix
slices = [x[3] for x in ids]

# Build the frame from plain arrays/lists so values line up by position.
# (Assigning default-indexed pd.Series to the de-duplicated frame would align
# on index labels, which have gaps after drop_duplicates(), and yield NaN.)
sub_df = pd.DataFrame({"id": sub_df['id'].values, "case": cases, "day": days, "slice": slices})
sub_df.head()

Unnamed: 0,id,case,day,slice
0,case123_day20_slice_0001,123,20,1
1,case123_day20_slice_0002,123,20,2
2,case123_day20_slice_0003,123,20,3
3,case123_day20_slice_0004,123,20,4
4,case123_day20_slice_0005,123,20,5


In [5]:
# Collect every PNG of the split that matches the mode chosen above.
if debug:
    paths = glob.glob('/kaggle/input/uw-madison-gi-tract-image-segmentation/train/**/*png', recursive=True)
else:
    paths = glob.glob('/kaggle/input/uw-madison-gi-tract-image-segmentation/test/**/*png', recursive=True)

all_imgs_info = {"case": [], "day": [], "slice": [], "path": [], "height": [], "width": []}

for img_path in paths:
    parts = img_path.split('/')
    # File name fields after splitting on '_':
    #   ['slice', <slice no>, <width>, <height>, ...]
    name_fields = parts[-1].split('_')

    all_imgs_info['case'].append(parts[5][4:])                 # 'case36' -> '36'
    all_imgs_info['day'].append(parts[6].split('_')[1][3:])    # 'case36_day14' -> '14'
    all_imgs_info['slice'].append(name_fields[1])
    # The competition's data description gives the two integers as slice
    # width then height. Storing them the other way round makes masks_pred
    # resize non-square slices to the transposed shape.
    all_imgs_info['width'].append(int(name_fields[2]))
    all_imgs_info['height'].append(int(name_fields[3]))
    all_imgs_info['path'].append(img_path)

all_imgs_info = pd.DataFrame(all_imgs_info)
all_imgs_info.head()

Unnamed: 0,case,day,slice,path,height,width
0,36,14,6,/kaggle/input/uw-madison-gi-tract-image-segmen...,266,266
1,36,14,82,/kaggle/input/uw-madison-gi-tract-image-segmen...,266,266
2,36,14,113,/kaggle/input/uw-madison-gi-tract-image-segmen...,266,266
3,36,14,76,/kaggle/input/uw-madison-gi-tract-image-segmen...,266,266
4,36,14,125,/kaggle/input/uw-madison-gi-tract-image-segmen...,266,266


In [6]:
# Attach path and size information to every id; the left join keeps each
# row of sub_df even when no matching image was found.
test_df = pd.merge(
    sub_df,
    all_imgs_info,
    how='left',
    on=['case', 'day', 'slice'],
)
test_df.head()

Unnamed: 0,id,case,day,slice,path,height,width
0,case123_day20_slice_0001,123,20,1,/kaggle/input/uw-madison-gi-tract-image-segmen...,266,266
1,case123_day20_slice_0002,123,20,2,/kaggle/input/uw-madison-gi-tract-image-segmen...,266,266
2,case123_day20_slice_0003,123,20,3,/kaggle/input/uw-madison-gi-tract-image-segmen...,266,266
3,case123_day20_slice_0004,123,20,4,/kaggle/input/uw-madison-gi-tract-image-segmen...,266,266
4,case123_day20_slice_0005,123,20,5,/kaggle/input/uw-madison-gi-tract-image-segmen...,266,266


In [7]:
def load_img(file_name, RESCALE_SIZE=RESCALE_SIZE):
    """Load an image file as a float32 array scaled by its own maximum.

    Args:
        file_name: Path of the image to open with PIL.
        RESCALE_SIZE: Side length, in pixels, of the square the image is
            resized to.

    Returns:
        A float32 array whose last axis is the input replicated three times
        (assumes a single-channel image, giving
        (RESCALE_SIZE, RESCALE_SIZE, 3) -- TODO confirm). Values are divided
        by the image maximum unless that maximum is zero.
    """
    # The context manager closes the underlying file handle deterministically
    # instead of leaving it to garbage collection.
    with Image.open(file_name) as src:
        resized = src.resize((RESCALE_SIZE, RESCALE_SIZE))
    image = np.array(resized, dtype='float32')
    # Replicate the single channel three times along a new trailing axis.
    image = np.tile(image[..., None], [1, 1, 3])
    mx = np.max(image)
    if mx:  # skip the division for all-zero images
        image /= mx
    return image

In [8]:
class BuildDataset(torch.utils.data.Dataset):
    """Dataset yielding (image tensor, id, height, width) for each row of a frame.

    Args:
        df: DataFrame with 'id', 'path', 'height' and 'width' columns.
    """

    def __init__(self, df):
        self.df = df
        self.id = df['id'].values
        self.img_paths = df['path'].values
        # Cached as arrays so __getitem__ is a positional lookup instead of a
        # full-frame filter on every sample.
        self.heights = df['height'].values
        self.widths = df['width'].values

    def __len__(self):
        """Number of rows in the wrapped frame."""
        return len(self.df)

    def __getitem__(self, index):
        """Return the channel-first image tensor, id, height and width at `index`."""
        img = load_img(self.img_paths[index])
        # load_img returns channel-last data; move channels to the front.
        img = np.transpose(img, (2, 0, 1))

        return torch.tensor(img), self.id[index], self.heights[index], self.widths[index]

In [9]:
def rle_encode(img):
    """Run-length encode a mask as a space-separated 'start length ...' string.

    The mask is flattened in NumPy's default order; starts are 1-based
    positions in that flattened sequence. An all-zero mask gives ''.
    """
    # Zero sentinels on both ends so runs touching either border are closed.
    border = np.array([0])
    padded = np.concatenate([border, img.flatten(), border])

    # 1-based positions at which the value changes from its predecessor.
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Turn every second entry (run end) into a run length.
    edges[1::2] = edges[1::2] - edges[::2]

    return ' '.join(map(str, edges))

                             
def masks_pred(msks, ids, heights, widths):
    """Resize each predicted mask to its target size and RLE-encode every class.

    Args:
        msks: Array indexed as msks[sample][..., channel] with three channels
            (presumably (N, H, W, 3) uint8 -- confirm against caller).
        ids: Per-sample ids, indexed like msks.
        heights, widths: Per-sample target sizes; each element must support
            .item().

    Returns:
        (pred_strings, pred_ids, pred_classes): three parallel lists holding
        three entries per sample, in the class order
        large_bowel, small_bowel, stomach.
    """
    class_names = ['large_bowel', 'small_bowel', 'stomach']
    pred_strings, pred_ids, pred_classes = [], [], []

    n_samples = msks.shape[0]
    for idx in range(n_samples):
        target_h = heights[idx].item()
        target_w = widths[idx].item()
        # cv2 takes dsize as (width, height); nearest keeps the mask binary.
        resized = cv2.resize(msks[idx],
                             dsize=(target_w, target_h),
                             interpolation=cv2.INTER_NEAREST)
        encoded = [rle_encode(resized[..., channel]) for channel in (0, 1, 2)]

        pred_strings.extend(encoded)
        pred_ids.extend([ids[idx]] * len(encoded))
        pred_classes.extend(class_names)

    return pred_strings, pred_ids, pred_classes

In [10]:
@torch.no_grad()
def test(model, test_data, th=0.5):
    """Run inference over a loader and collect RLE predictions.

    Args:
        model: Network returning per-class logits for an image batch.
        test_data: Iterable of (img, ids, heights, widths) batches.
        th: Probability threshold above which a pixel counts as foreground.

    Returns:
        (pred_strings, pred_ids, pred_classes): parallel lists accumulated
        from masks_pred over all batches.
    """
    model.eval()
    all_strings, all_ids, all_classes = [], [], []

    for batch_no, (img, ids, heights, widths) in enumerate(test_data, start=1):
        img = img.to(device, dtype=torch.float)

        probs = torch.sigmoid(model(img))

        # Move axis 1 (presumably channels) last, as masks_pred indexes
        # classes on the final axis, then binarise.
        binary = (probs.permute((0, 2, 3, 1)) > th).to(torch.uint8)
        msk = binary.cpu().detach().numpy()

        strings, sample_ids, classes = masks_pred(msk, ids, heights, widths)
        all_strings.extend(strings)
        all_ids.extend(sample_ids)
        all_classes.extend(classes)

        # Lightweight progress report every 100 batches.
        if batch_no % 100 == 0:
            print(batch_no)

    return all_strings, all_ids, all_classes

In [11]:
# Wrap the merged frame in a dataset and iterate it in file order
# (no shuffling) in batches of batch_size.
test_dataset = BuildDataset(df=test_df)
test_loader = DataLoader(
    dataset=test_dataset,
    shuffle=False,
    batch_size=batch_size,
)

In [12]:
# smp_UnetPlusPlusefficientnet_b1 = smp.UnetPlusPlus (
#         encoder_name='efficientnet-b1',
#         encoder_weights=None,
#         classes=3,
#         activation=None,
#     ).to(device)

# smp_UnetPlusPlus_resnet18 = smp.UnetPlusPlus (
#         encoder_name='resnet18',
#         encoder_weights=None,
#         classes=3,
#         activation=None,
#     ).to(device);

# UNet++ with an EfficientNet-B3 encoder and three output channels.
# No pretrained encoder weights are requested and no output activation is
# attached, so the network returns raw logits.
smp_UnetPlusPlus_efficientnet_b3 = smp.UnetPlusPlus(
    encoder_name='efficientnet-b3',
    encoder_weights=None,
    classes=3,
    activation=None,
)
smp_UnetPlusPlus_efficientnet_b3 = smp_UnetPlusPlus_efficientnet_b3.to(device)

In [13]:
# smp_UnetPlusPlusefficientnet_b1.load_state_dict(torch.load('../input/unetplusplus/smp_UnetPlusPlusefficientnet_b1', map_location=torch.device(device)))
# Read the saved weights onto the active device, then copy them into the network.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.
state_dict = torch.load('../input/unetplusplus/smp_UnetPlusPlus_efficientnet_b3', map_location=device)
smp_UnetPlusPlus_efficientnet_b3.load_state_dict(state_dict)
# smp_UnetPlusPlus_resnet18.load_state_dict(torch.load('../input/unetplusplus/smp_UnetPlusPlus_resnet18', map_location=torch.device(device)))

<All keys matched successfully>

In [14]:
# pred_strings, pred_ids, pred_classes = test(smp_UnetPlusPlusefficientnet_b1, test_loader)
# Run the EfficientNet-B3 UNet++ over the whole loader with the default threshold.
pred_strings, pred_ids, pred_classes = test(
    model=smp_UnetPlusPlus_efficientnet_b3,
    test_data=test_loader,
)
# pred_strings, pred_ids, pred_classes = test(smp_UnetPlusPlus_resnet18, test_loader)

100
200
300
400
500
600
700
800
900
1000
1100
1200


In [15]:
# One row per (id, class) with its run-length encoded prediction.
pred_df = pd.DataFrame({"id": pred_ids, "class": pred_classes, "predicted": pred_strings})

# Pick the template that defines which (id, class) rows are written, and the
# placeholder column to discard from it.
if debug:
    template_path = '../input/uw-madison-gi-tract-image-segmentation/train.csv'
    placeholder_col = 'segmentation'
else:
    template_path = PATH + '/sample_submission.csv'
    placeholder_col = 'predicted'

sub_df = pd.read_csv(template_path).drop(columns=placeholder_col)

# Join predictions onto the template rows and write the submission file.
sub_df = pd.merge(sub_df, pred_df, on=['id', 'class'])
sub_df.to_csv('submission.csv', index=False)
display(sub_df.head(5))

Unnamed: 0,id,class,predicted
0,case123_day20_slice_0001,large_bowel,
1,case123_day20_slice_0001,small_bowel,
2,case123_day20_slice_0001,stomach,
3,case123_day20_slice_0002,large_bowel,
4,case123_day20_slice_0002,small_bowel,
