# PyTorch UNet Model for Pneumothorax Segmentation

In [1]:
# Imports
import numpy as np
import pandas as pd
from glob import glob
import pydicom
import random

# import image manipulation
from PIL import Image

# import matplotlib for visualization
from matplotlib.pyplot import imshow
import matplotlib.pyplot as plt

# Import PyTorch
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as TF
from torch.utils.data.sampler import SubsetRandomSampler
from torch.autograd import Variable

from skimage.morphology import binary_opening, disk, label

# Import rle utils
import sys
sys.path.insert(0, '../input/siim-acr-pneumothorax-segmentation')
from mask_functions import rle2mask, mask2rle # import mask utilities

from tqdm import tqdm_notebook

## Dataset Utility Functions

In [2]:
# Data loading utility
def load_data(datafilepath = '../input/siim-train-test/siim/', healthy_num = 2000):
    '''
    Function to load the dataset file lists and the mask table.
    INPUT:
        datafilepath - path to directory containing the dataset. It is used
            as a plain string prefix, so it must end with a path separator.
        healthy_num - maximum number of training images without pneumothorax
            (mask entry containing '-1') to keep in files_list.
    OUTPUT:
        train_fns - sorted list of all train DICOM file paths
        test_fns - sorted list of all test DICOM file paths
        df_masks - pandas dataframe containing masks for train dataset,
            indexed by ImageId
        files_list - train file paths selected for training: every image that
            has a mask plus at most healthy_num images without one. Images
            with no row in df_masks are left out.
    '''
    # Load full training and test sets
    train_fns = sorted(glob(datafilepath + 'dicom-images-train/*/*/*.dcm'))
    test_fns = sorted(glob(datafilepath + 'dicom-images-test/*/*/*.dcm'))
    # Load csv masks
    df_masks = pd.read_csv(datafilepath + 'train-rle.csv', index_col='ImageId')

    # Build the list of training files, capping the number of healthy images
    healthy_count = 0
    files_list = []
    for fname in train_fns:
        image_id = fname.split('/')[-1][:-4]  # file name without '.dcm'
        try:
            # For an image with several mask rows .loc returns a Series; the
            # membership test is then False and the image counts as diseased.
            is_healthy = '-1' in df_masks.loc[image_id, ' EncodedPixels']
        except (KeyError, TypeError):
            # No mask row for this image (or a non-string entry): skip it.
            continue

        if not is_healthy:
            files_list.append(fname)
        elif healthy_count < healthy_num:
            files_list.append(fname)
            healthy_count += 1

    return train_fns, test_fns, df_masks, files_list

In [3]:
def normalize(arr):
    """
    Function performs the linear normalization (min-max stretch to 0..255)
    of the array.
    https://stackoverflow.com/questions/7422204/intensity-normalization-of-image-using-pythonpil-speed-issues
    http://en.wikipedia.org/wiki/Normalization_%28image_processing%29

    A 1-D/2-D array is treated as one single-channel image and stretched as a
    whole. For arrays with a trailing channel axis only channel 0 is
    stretched and the remaining channels are left untouched.

    INPUT:
        arr - original numpy array
    OUTPUT:
        arr - normalized numpy array (a float copy; the input is not modified)
    """
    arr = arr.astype('float')
    if arr.size == 0:
        # nothing to stretch; min()/max() would raise on an empty array
        return arr

    # Indexing a 2-D grayscale image with [..., 0] would select only its
    # first column, so use the whole array unless a channel axis exists.
    channel = arr if arr.ndim < 3 else arr[..., 0]

    minval = channel.min()
    maxval = channel.max()
    if minval != maxval:
        # `channel` is a view on `arr`, so the in-place ops update `arr`
        channel -= minval
        channel *= (255.0 / (maxval - minval))
    return arr

def normalize_image(img):
    """
    Normalize the intensities of a PIL image.
    https://stackoverflow.com/questions/7422204/intensity-normalization-of-image-using-pythonpil-speed-issues
    INPUT:
        img - PIL image to be normalized
    OUTPUT:
        new PIL image in mode 'L' built from the normalized pixel values
    """
    pixels = np.array(img)
    stretched = normalize(pixels)
    # back to 8-bit so the data fits a single-channel 'L' image
    as_bytes = stretched.astype('uint8')
    return Image.fromarray(as_bytes, 'L')

In [4]:
# Define the dataset
class PneumothoraxDataset(Dataset):
    '''
    The dataset for pneumothorax segmentation.

    In 'train' mode an item is (image, mask): a float image tensor and an
    int64 mask tensor of shape (1, H, W) with values 0 (background) and
    1 (pneumothorax). In 'validation' mode an item is [image, image_id].
    '''

    # The DICOM images are reshaped to this fixed size, so other sizes fail.
    IM_HEIGHT = 1024
    IM_WIDTH = 1024

    def __init__(self, fns, df_masks, files_list, transform=True, size = (224, 224), mode = 'train'):
        '''
        INPUT:
            fns - list of image paths (used in 'validation' mode)
            df_masks - dataframe containing image masks, indexed by image id
            files_list - list of image paths used in 'train' mode
            transform (optional) - enable random rotation of image and mask
            size (optional) - output size passed to transforms.Resize
            mode (optional) - 'validation' returns image ids instead of masks;
                any other value behaves as training mode
        '''
        self.labels_frame = df_masks
        self.fns = fns
        self.transform = transform
        self.size = size
        self.transforms = transforms.Compose([transforms.Resize(self.size), transforms.ToTensor()])
        self.mode = mode
        self.files_list = files_list

    def __len__(self):
        if self.mode == 'validation':
            return len(self.fns)
        return len(self.files_list)

    @staticmethod
    def _image_id(path):
        '''Return the file name of `path` without its 4-character extension.'''
        return path.split('/')[-1][:-4]

    def _load_image(self, path):
        '''Read a DICOM file and return its pixels as a PIL 'L' image.'''
        dataset = pydicom.dcmread(path)
        pixels = dataset.pixel_array.reshape(self.IM_HEIGHT, self.IM_WIDTH)
        return Image.fromarray(pixels, 'L')

    def _load_mask(self, image_id):
        '''
        Build the boolean mask for `image_id` as the union of all its
        run-length encoded masks. A missing row or a '-1' entry yields an
        empty mask.
        '''
        mask = np.zeros((self.IM_HEIGHT, self.IM_WIDTH), dtype=bool)
        try:
            entry = self.labels_frame.loc[image_id, ' EncodedPixels']
        except KeyError:
            # couldn't find mask in dataframe: assume missing masks are empty
            return mask

        # one mask row -> str, several rows for the same image -> Series
        rles = [entry] if isinstance(entry, str) else list(entry)
        for rle in rles:
            if '-1' in rle:
                continue  # "no pneumothorax" marker
            # OR the masks together; adding them would overflow uint8 where
            # they overlap and turn those pixels into background.
            mask |= np.asarray(rle2mask(rle, self.IM_HEIGHT, self.IM_WIDTH)) > 0

        # The decoded mask is transposed relative to the image.
        # NOTE(review): relies on rle2mask's output orientation - confirm.
        return mask.T

    def __getitem__(self, idx):
        '''
        Function to get items from dataset by idx.
        INPUT:
            idx - id of the image in the dataset
        OUTPUT:
            [image, image_id] in 'validation' mode, otherwise (image, mask)
        '''
        # if validation then return the image id instead of a mask
        if self.mode == 'validation':
            path = self.fns[idx]
            # same intensity normalization as the training images
            image = normalize_image(self._load_image(path))
            return [self.transforms(image), self._image_id(path)]

        path = self.files_list[idx]
        image = self._load_image(path)
        np_mask = self._load_mask(self._image_id(path))
        mask = Image.fromarray(np.ascontiguousarray(np_mask).astype(np.uint8) * 255, 'L')

        # apply optional transforms (random rotation for the training set)
        if self.transform:
            # the same angle is used so image and mask stay aligned
            angle = random.uniform(-5, 5)
            image = TF.rotate(image, angle)
            mask = TF.rotate(mask, angle)

        # apply required transforms: normalization, resize, convert to tensor
        image = normalize_image(image)
        image = self.transforms(image)

        # ToTensor scales the 0/255 mask to [0, 1]; resizing interpolates the
        # edges, so threshold at 0.5 to get class indices 0/1 again.
        mask = (self.transforms(mask) > 0.5).long()

        return image, mask

## Training Utilities

U-Net Model:

In [5]:
# https://www.kaggle.com/witwitchayakarn/u-net-with-pytorch

def conv1x1(in_channels, out_channels, groups=1):
    """Build a 1x1 (pointwise) 2-D convolution with stride 1."""
    layer = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, groups=groups)
    return layer

def conv3x3(in_channels, out_channels, stride=1, padding=1, bias=True, groups=1):
    """Build a 3x3 2-D convolution; the default padding of 1 keeps the spatial size at stride 1."""
    options = dict(kernel_size=3, stride=stride, padding=padding, bias=bias, groups=groups)
    return nn.Conv2d(in_channels, out_channels, **options)

def upconv2x2(in_channels, out_channels, mode='transpose'):
    """
    Build a layer that doubles the spatial resolution.

    mode == 'transpose' gives a 2x2 transposed convolution with stride 2;
    any other value gives bilinear upsampling by 2 followed by a 1x1
    convolution that maps in_channels to out_channels.
    """
    if mode != 'transpose':
        upsample = nn.Upsample(mode='bilinear', scale_factor=2)
        project = conv1x1(in_channels, out_channels)
        return nn.Sequential(upsample, project)

    return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
    
class DownConv(nn.Module):
    """
    Encoder stage of the U-Net: two 3x3 convolutions, each followed by a
    ReLU, and an optional 2x2 max-pool at the end.
    """
    def __init__(self, in_channels, out_channels, pooling=True):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.pooling = pooling

        self.conv1 = conv3x3(in_channels, out_channels)
        self.conv2 = conv3x3(out_channels, out_channels)

        # the pooling layer only exists when it is going to be used
        if pooling:
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Return (output, features before pooling); both are equal when pooling is off."""
        features = F.relu(self.conv2(F.relu(self.conv1(x))))
        pooled = self.pool(features) if self.pooling else features
        return pooled, features

class UpConv(nn.Module):
    """
    Decoder stage of the U-Net: one up-convolution, a merge with the
    matching encoder features, then two 3x3 convolutions each followed by
    a ReLU.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 merge_mode='concat',
                 up_mode='transpose'):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.merge_mode = merge_mode
        self.up_mode = up_mode

        self.upconv = upconv2x2(in_channels, out_channels, mode=up_mode)

        # concatenation doubles the channels entering the first convolution,
        # any other merge mode keeps them at out_channels
        merged_channels = 2 * out_channels if merge_mode == 'concat' else out_channels
        self.conv1 = conv3x3(merged_channels, out_channels)
        self.conv2 = conv3x3(out_channels, out_channels)

    def forward(self, from_down, from_up):
        """ Forward pass
        Arguments:
            from_down: tensor from the encoder pathway
            from_up: tensor from the decoder pathway; it is up-convolved
                here before being merged
        """
        upsampled = self.upconv(from_up)
        if self.merge_mode == 'concat':
            merged = torch.cat((upsampled, from_down), 1)
        else:
            merged = upsampled + from_down
        hidden = F.relu(self.conv1(merged))
        return F.relu(self.conv2(hidden))
    
class UNet(nn.Module):
    """ `UNet` class is based on https://arxiv.org/abs/1505.04597
    The U-Net is a convolutional encoder-decoder neural network.
    Contextual spatial information (from the decoding,
    expansive pathway) about an input tensor is merged with
    information representing the localization of details
    (from the encoding, compressive pathway).
    Modifications to the original paper:
    (1) padding is used in 3x3 convolutions to prevent loss
        of border pixels
    (2) merging outputs does not require cropping due to (1)
    (3) residual connections can be used by specifying
        UNet(merge_mode='add')
    (4) if non-parametric upsampling is used in the decoder
        pathway (specified by up_mode='upsample'), then an
        additional 1x1 2d convolution occurs after upsampling
        to reduce channel dimensionality by a factor of 2.
        This channel halving happens with the convolution in
        the transpose convolution (specified by up_mode='transpose')
    """

    def __init__(self, num_classes, in_channels=1, depth=5,
                 start_filts=64, up_mode='transpose',
                 merge_mode='concat'):
        """
        Arguments:
            num_classes: int, number of output channels (one logit map
                per class).
            in_channels: int, number of channels in the input tensor.
                Default is 1 (single-channel images).
            depth: int, number of encoder blocks (>= 1); the network
                contains depth-1 MaxPools and depth-1 decoder blocks.
            start_filts: int, number of convolutional filters for the
                first conv; doubled in every following encoder block.
            up_mode: string, type of upconvolution. Choices: 'transpose'
                for transpose convolution or 'upsample' for upsampling
                followed by a 1x1 convolution.
            merge_mode: string, how encoder features are merged into the
                decoder. Choices: 'concat' or 'add'.
        Raises:
            ValueError: for an unknown up_mode or merge_mode, for the
                combination up_mode='upsample' with merge_mode='add',
                or for depth < 1.
        """
        super(UNet, self).__init__()

        if up_mode in ('transpose', 'upsample'):
            self.up_mode = up_mode
        else:
            raise ValueError("\"{}\" is not a valid mode for "
                             "upsampling. Only \"transpose\" and "
                             "\"upsample\" are allowed.".format(up_mode))

        if merge_mode in ('concat', 'add'):
            self.merge_mode = merge_mode
        else:
            # report the offending merge_mode (not up_mode)
            raise ValueError("\"{}\" is not a valid mode for "
                             "merging up and down paths. "
                             "Only \"concat\" and "
                             "\"add\" are allowed.".format(merge_mode))

        # NOTE: up_mode 'upsample' is incompatible with merge_mode 'add'
        if self.up_mode == 'upsample' and self.merge_mode == 'add':
            raise ValueError("up_mode \"upsample\" is incompatible "
                             "with merge_mode \"add\" at the moment "
                             "because it doesn't make sense to use "
                             "nearest neighbour to reduce "
                             "depth channels (by half).")

        if depth < 1:
            raise ValueError("depth must be at least 1, got {}".format(depth))

        self.num_classes = num_classes
        self.in_channels = in_channels
        self.start_filts = start_filts
        self.depth = depth

        # create the encoder pathway; the last block does not pool
        down_convs = []
        outs = self.in_channels
        for i in range(depth):
            ins = outs
            outs = self.start_filts*(2**i)
            down_convs.append(DownConv(ins, outs, pooling=i < depth-1))

        # create the decoder pathway
        # - careful! decoding only requires depth-1 blocks
        up_convs = []
        for _ in range(depth-1):
            ins = outs
            outs = ins // 2
            up_convs.append(UpConv(ins, outs, up_mode=up_mode,
                                   merge_mode=merge_mode))

        self.conv_final = conv1x1(outs, self.num_classes)

        # register the blocks as sub-modules of the current module
        self.down_convs = nn.ModuleList(down_convs)
        self.up_convs = nn.ModuleList(up_convs)

        self.reset_params()

    @staticmethod
    def weight_init(m):
        """Xavier-initialise the weight of a Conv2d and zero its bias, if it has one."""
        if isinstance(m, nn.Conv2d):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def reset_params(self):
        """Apply weight_init to every sub-module."""
        for m in self.modules():
            self.weight_init(m)

    def forward(self, x):
        """Return the raw per-class logits for input x."""
        encoder_outs = []

        # encoder pathway, save outputs for merging
        for module in self.down_convs:
            x, before_pool = module(x)
            encoder_outs.append(before_pool)

        # decoder pathway: the i-th decoder block is merged with the
        # encoder output one level above the current resolution
        for i, module in enumerate(self.up_convs):
            before_pool = encoder_outs[-(i+2)]
            x = module(before_pool, x)

        # No softmax is used. This means you need to use
        # nn.CrossEntropyLoss in your training script,
        # as this module includes a softmax already.
        x = self.conv_final(x)
        return x

## Train the Model

In [6]:
def train(model, device, trainloader, testloader, optimizer, criterion, epochs, print_every=100):
    """
    Train `model` and periodically report the loss on `testloader`.

    INPUT:
        model - network to train; it is moved to `device`
        device - torch device used for the model and all batches
        trainloader - iterable of (inputs, labels) training batches
        testloader - sized iterable of (inputs, labels) evaluation batches
        optimizer - optimizer already bound to the model parameters
        criterion - loss called as criterion(outputs, targets); the targets
            are passed as integer tensors of shape (N, H, W)
        epochs - number of passes over trainloader
        print_every (optional) - evaluate and print after this many
            training steps
    OUTPUT:
        history - list of dicts with keys 'Epoch', 'Train loss' and
            'Test loss', one per evaluation
    """
    model.to(device)
    model.train()
    history = []
    steps = 0
    running_loss = 0.0

    def as_targets(labels):
        # collapse any leading/channel dims: (N, 1, H, W) -> (N, H, W),
        # whatever the image size is
        return labels.reshape(-1, *labels.shape[-2:]).long()

    for epoch in range(epochs):
        for inputs, labels in trainloader:
            steps += 1
            # Move input and label tensors to the training device
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, as_targets(labels))
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            if steps % print_every != 0:
                continue

            # periodic evaluation without gradients
            test_loss = 0.0
            model.eval()
            with torch.no_grad():
                for test_inputs, test_labels in testloader:
                    test_inputs = test_inputs.to(device)
                    test_labels = test_labels.to(device)
                    test_outputs = model(test_inputs)
                    test_loss += criterion(test_outputs, as_targets(test_labels)).item()

            n_test_batches = len(testloader)
            # an empty test loader has no meaningful mean loss
            mean_test_loss = test_loss / n_test_batches if n_test_batches else float('nan')
            mean_train_loss = running_loss / print_every

            print(f"Epoch {epoch+1}/{epochs}.. "
                  f"Train loss: {mean_train_loss:.3f}.. "
                  f"Test loss: {mean_test_loss:.3f}.. ")

            history.append({'Epoch': epoch + 1,
                            'Train loss': mean_train_loss,
                            'Test loss': mean_test_loss})

            running_loss = 0.0
            model.train()

    return history

In [7]:
# ---- Data ----
print('Loading data: \n')
train_fns, test_fns, df_masks, files_list = load_data()

# ---- Training presets ----
batch_size = 16
epochs = 10
learning_rate = 0.001
test_split = .2

# side length of the raw images and the size fed to the network
original_size = 1024
height = 224
width = 224

# ---- Datasets and loaders ----
print('Preparing the dataset: \n')
train_ds = PneumothoraxDataset(
    train_fns, df_masks, files_list,
    transform=True, size=(height, width), mode='train')

# Shuffle the indices with a fixed seed, then hold out the first
# `test_split` fraction for testing and train on the rest.
dataset_size = len(train_ds)
split = int(np.floor(test_split * dataset_size))
indices = list(range(dataset_size))
np.random.seed(42)
np.random.shuffle(indices)
test_indices = indices[:split]
train_indices = indices[split:]

train_sampler = SubsetRandomSampler(train_indices)
test_sampler = SubsetRandomSampler(test_indices)

# both loaders read from the same dataset and differ only in the sampler
trainloader = DataLoader(train_ds, batch_size=batch_size, sampler=train_sampler, num_workers=4)
testloader = DataLoader(train_ds, batch_size=batch_size, sampler=test_sampler, num_workers=4)

# images without masks, used for the submission
valid_ds = PneumothoraxDataset(
    test_fns, None, None,
    transform=False, size=(height, width), mode='validation')
validloader = DataLoader(valid_ds, batch_size=8, shuffle=False, num_workers=1)

torch.cuda.empty_cache()

# ---- Model, loss function, optimizer ----
class param:
    unet_depth = 5
    unet_start_filters = 8

model = UNet(2,
             depth=param.unet_depth,
             start_filts=param.unet_start_filters,
             merge_mode='concat')

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# use the first GPU when one is available, otherwise the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# ---- Training ----
print('Train the model: \n')

train_stats_df = pd.DataFrame(columns=['Epoch', 'Train loss', 'Test loss'])
train(model, device, trainloader, testloader, optimizer, criterion, epochs)

# Saving the model is currently disabled:
#print('Save the model: \n')
#filepath = 'simple_unet.pth'
#checkpoint = {'state_dict': model.state_dict()}
#torch.save(checkpoint, filepath)

Loading data: 

Preparing the dataset: 

Train the model: 





Epoch 1/10.. Train loss: 0.120.. Test loss: 0.034.. 
Epoch 1/10.. Train loss: 0.030.. Test loss: 0.031.. 
Epoch 2/10.. Train loss: 0.029.. Test loss: 0.030.. 
Epoch 2/10.. Train loss: 0.028.. Test loss: 0.030.. 
Epoch 3/10.. Train loss: 0.030.. Test loss: 0.030.. 
Epoch 3/10.. Train loss: 0.029.. Test loss: 0.029.. 
Epoch 4/10.. Train loss: 0.029.. Test loss: 0.029.. 
Epoch 4/10.. Train loss: 0.029.. Test loss: 0.030.. 
Epoch 5/10.. Train loss: 0.026.. Test loss: 0.031.. 
Epoch 5/10.. Train loss: 0.027.. Test loss: 0.029.. 
Epoch 6/10.. Train loss: 0.029.. Test loss: 0.028.. 
Epoch 6/10.. Train loss: 0.029.. Test loss: 0.028.. 
Epoch 6/10.. Train loss: 0.026.. Test loss: 0.029.. 
Epoch 7/10.. Train loss: 0.027.. Test loss: 0.028.. 
Epoch 7/10.. Train loss: 0.026.. Test loss: 0.029.. 
Epoch 8/10.. Train loss: 0.026.. Test loss: 0.027.. 
Epoch 8/10.. Train loss: 0.028.. Test loss: 0.028.. 
Epoch 9/10.. Train loss: 0.026.. Test loss: 0.027.. 
Epoch 9/10.. Train loss: 0.026.. Test loss: 0.

## Create Submission

In [8]:
# Predict a mask for every test image and store it run-length encoded.
submission = {'ImageId': [], 'EncodedPixels': []}

model.eval()
torch.cuda.empty_cache()

# Run on whatever device the model lives on, so this also works without a GPU.
inference_device = next(model.parameters()).device

# No gradients are needed for inference; this keeps memory usage flat.
with torch.no_grad():
    for X, fns in tqdm_notebook(validloader):
        output = model(X.to(inference_device))
        # The model emits 2-class logits (it is trained with CrossEntropyLoss
        # on 0/1 masks), so the pneumothorax probability is the softmax of
        # channel 1, not a sigmoid of channel 0.
        probs = torch.softmax(output, dim=1)[:, 1].cpu().numpy()

        for prob, fname in zip(probs, fns):
            # threshold, then remove small isolated blobs
            mask = binary_opening(prob > 0.5, disk(2))

            # Nearest-neighbour resizing keeps the mask strictly 0/255;
            # the default filter would blur the edges into other values.
            im = Image.fromarray((mask*255).astype(np.uint8)).resize(
                (original_size, original_size), Image.NEAREST)
            im = np.asarray(im)

            # The training masks are the transposed output of rle2mask, so
            # predictions are transposed back before encoding.
            # NOTE(review): assumes mask2rle is the inverse of rle2mask - confirm.
            submission['EncodedPixels'].append(mask2rle(im.T, original_size, original_size))
            submission['ImageId'].append(fname)

HBox(children=(IntProgress(value=0, max=173), HTML(value='')))




In [9]:
# Assemble the submission table; images with an empty encoding get the
# '-1' marker before the table is written to disk.
submission_df = pd.DataFrame(data=submission, columns=['ImageId', 'EncodedPixels'])
submission_df.loc[submission_df['EncodedPixels'].eq(''), 'EncodedPixels'] = '-1'
submission_df.to_csv('submission.csv', index=False)
submission_df.sample(n=10)

Unnamed: 0,ImageId,EncodedPixels
602,1.2.276.0.7230010.3.1.4.8323329.6347.151787519...,-1
749,1.2.276.0.7230010.3.1.4.8323329.6482.151787519...,-1
1142,1.2.276.0.7230010.3.1.4.8323329.684.1517875164...,-1
590,1.2.276.0.7230010.3.1.4.8323329.6336.151787519...,-1
1028,1.2.276.0.7230010.3.1.4.8323329.6736.151787519...,-1
561,1.2.276.0.7230010.3.1.4.8323329.631.1517875163...,-1
50,1.2.276.0.7230010.3.1.4.8323329.5842.151787519...,-1
1092,1.2.276.0.7230010.3.1.4.8323329.6794.151787520...,-1
617,1.2.276.0.7230010.3.1.4.8323329.6360.151787519...,-1
177,1.2.276.0.7230010.3.1.4.8323329.5959.151787519...,-1


In [10]:
# Show another random sample of 10 rows of the submission table
submission_df.sample(10)

Unnamed: 0,ImageId,EncodedPixels
928,1.2.276.0.7230010.3.1.4.8323329.6645.151787519...,-1
820,1.2.276.0.7230010.3.1.4.8323329.6547.151787519...,-1
1038,1.2.276.0.7230010.3.1.4.8323329.6745.151787519...,-1
1132,1.2.276.0.7230010.3.1.4.8323329.6830.151787520...,-1
1123,1.2.276.0.7230010.3.1.4.8323329.6822.151787520...,-1
239,1.2.276.0.7230010.3.1.4.8323329.6015.151787519...,-1
238,1.2.276.0.7230010.3.1.4.8323329.6014.151787519...,-1
201,1.2.276.0.7230010.3.1.4.8323329.5980.151787519...,-1
630,1.2.276.0.7230010.3.1.4.8323329.6372.151787519...,-1
841,1.2.276.0.7230010.3.1.4.8323329.6566.151787519...,-1


In [11]:
# Number of rows in the submission table
len(submission_df)

1377

In [12]:
# Distinct encodings in the submission; only '-1' means no mask was predicted
submission_df.EncodedPixels.unique()

array(['-1'], dtype=object)