The notebook presents cell segmentation by integrating U-Net architecture. Data was taken from [here](https://www.kaggle.com/competitions/data-science-bowl-2018/data). Unfortunately, the project still lacks from success due to non-updated errors in the training, encouraging more investigation. 

In [3]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision.io import read_image
from torchvision import transforms
from torch import nn
from torch.nn import functional as F
from torch import optim
import os

# Data Preparation

Defining constants

In [4]:
TRAIN_PATH = 'stage1_train'
TEST_PATH = 'stage1_test'
DEVICE = 'cpu'

IMG_WIDTH = 128
IMG_HEIGHT = 128
IMG_CHANNELS = 3
SIZE_IMAGE = (3, 270, 290)
SIZE_MASK = (1, 194, 226)

Creating Dataset class to import data to notebook with data loader.

In [5]:
class ImageDataset(Dataset):
    def __init__(self, filepath, device = 'cpu', transform_image = None, train = True, transform_mask = None,
                 size_image = (128, 128), size_mask = (128, 128)):
        self.device = device
        self.filepath = filepath
        self.filenames = os.listdir(self.filepath)
        self.train = train
        self.size_image = size_image
        self.size_mask = size_mask
        self.transform_image = transforms.Normalize
        self.transform_mask = transforms.Normalize

        # height, width
        self.resizer_image = transforms.Resize(size_image)
        self.resizer_mask = transforms.Resize(size_mask)

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, index):
        img_path = os.path.join(self.filepath, self.filenames[index], 'images', self.filenames[index] + '.png')

        image = read_image(img_path)[:3, :, :].to(self.device)
        image = image.float()
        image = image / 255
        image = self.resizer_image(image)

        if self.train:
            mask_path = os.path.join(self.filepath, self.filenames[index], 'masks')
            mask = torch.zeros((1, *self.size_mask)).to(self.device)
            for filename in os.listdir(mask_path):
                mask_file_path = os.path.join(mask_path, filename)

                mask_image = read_image(mask_file_path).to(self.device)
                mask_image = self.resizer_mask(mask_image)
                mask = torch.maximum(mask, mask_image)

            mask = mask / 255

            return image, mask

        return image

Creating the dataset object. It is worth to add that the train data was previously and randomly splitted inside stage1_train and 10% of the data was collected in a new folder called stage1_val to evaluate the model performance later. 

In [6]:
training_dataset = ImageDataset('stage1_train', train = True, 
                                size_image = SIZE_IMAGE[1:], size_mask = SIZE_MASK[1:])

In [7]:
validation_dataset = ImageDataset('stage1_val', train = True, 
                                size_image = SIZE_IMAGE[1:], size_mask = SIZE_MASK[1:])

In [8]:
test_dataset = ImageDataset('stage1_test', train = False, 
                                size_image = SIZE_IMAGE[1:])

In [9]:
train_loader = DataLoader(training_dataset, batch_size = 16)
validation_loader = DataLoader(validation_dataset, batch_size = 16)
test_loader = DataLoader(test_dataset, batch_size = 16)

# Modeling

Creating U-Net architecture 

In [76]:
class UNet(nn.Module):
    def __init__(self, image_shape, kernel_size_conv = 3, kernel_size_pool = 2,
                 kernel_size_conv_transpose = 2, conv_padding = 0,
                 stride = 2, pool_padding = 0):
        super().__init__()

        # dropout
        c, h, w = image_shape

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        
        self.dropout = nn.Dropout(p = 0.1)
        self.pool = nn.MaxPool2d(kernel_size = kernel_size_pool, stride = stride, padding = pool_padding)

        self.conv_c1_c = nn.Conv2d(in_channels = c, out_channels = 16, kernel_size = kernel_size_conv, padding = conv_padding)
        self.conv_c1 = nn.Conv2d(in_channels = 16, out_channels = 16, kernel_size = kernel_size_conv, padding = conv_padding)

        self.conv_c2_1 = nn.Conv2d(in_channels = 16, out_channels = 32, kernel_size = kernel_size_conv, padding = conv_padding)
        self.conv_c2_2 = nn.Conv2d(in_channels = 32, out_channels = 32, kernel_size = kernel_size_conv, padding = conv_padding)

        self.conv_c3_1 = nn.Conv2d(in_channels = 32, out_channels = 64, kernel_size = kernel_size_conv, padding = conv_padding)
        self.conv_c3_2 = nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = kernel_size_conv, padding = conv_padding)

        self.conv_c4_1 = nn.Conv2d(in_channels = 64, out_channels = 128, kernel_size = kernel_size_conv, padding = conv_padding)
        self.conv_c4_2 = nn.Conv2d(in_channels = 128, out_channels = 128, kernel_size = kernel_size_conv, padding = conv_padding)

        self.conv_c5_1 = nn.Conv2d(in_channels = 128, out_channels = 256, kernel_size = kernel_size_conv, padding = conv_padding)
        self.conv_c5_2 = nn.Conv2d(in_channels = 256, out_channels = 256, kernel_size = kernel_size_conv, padding = conv_padding)

        self.conv_u1_1 = nn.Conv2d(in_channels = 32, out_channels = 16, kernel_size = kernel_size_conv)
        self.conv_u1_2 = nn.Conv2d(in_channels = 16, out_channels = 16, kernel_size = kernel_size_conv)
        self.conv_u1_3 = nn.Conv2d(in_channels = 16, out_channels = 1, kernel_size = kernel_size_conv)
        self.conv_transpose_u1 = nn.ConvTranspose2d(in_channels = 32, out_channels = 16, stride = stride, kernel_size = kernel_size_conv_transpose)

        self.conv_u2_1 = nn.Conv2d(in_channels = 64, out_channels = 32, kernel_size = kernel_size_conv)
        self.conv_u2_2 = nn.Conv2d(in_channels = 32, out_channels = 32, kernel_size = kernel_size_conv)
        self.conv_transpose_u2 = nn.ConvTranspose2d(in_channels = 64, out_channels = 32, stride = stride, kernel_size = kernel_size_conv_transpose)

        self.conv_u3_1 = nn.Conv2d(in_channels = 128, out_channels = 64, kernel_size = kernel_size_conv)
        self.conv_u3_2 = nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = kernel_size_conv)
        self.conv_transpose_u3 = nn.ConvTranspose2d(in_channels = 128, out_channels = 64, stride = stride, kernel_size = kernel_size_conv_transpose)

        self.conv_u4_1 = nn.Conv2d(in_channels = 256, out_channels = 128, kernel_size = kernel_size_conv)
        self.conv_u4_2 = nn.Conv2d(in_channels = 128, out_channels = 128, kernel_size = kernel_size_conv)
        self.conv_transpose_u4 = nn.ConvTranspose2d(in_channels = 256, out_channels = 128, stride = stride, kernel_size = kernel_size_conv_transpose)

    def forward(self, x):
        xc1 = self.conv_c1_c(x)
        xc1 = self.relu(xc1)
        xc1 = self.dropout(xc1)
        xc1 = self.conv_c1(xc1)
        xc1 = self.relu(xc1)
        xc2 = self.pool(xc1)

        xc2 = self.conv_c2_1(xc2)
        xc2 = self.relu(xc2)
        xc2 = self.dropout(xc2)
        xc2 = self.conv_c2_2(xc2)
        xc2= self.relu(xc2)
        xc3 = self.pool(xc2)

        xc3 = self.conv_c3_1(xc3)
        xc3 = self.relu(xc3)
        xc3 = self.dropout(xc3)
        xc3 = self.conv_c3_2(xc3)
        xc3 = self.relu(xc3)
        xc4 = self.pool(xc3)

        xc4 = self.conv_c4_1(xc4)
        xc4 = self.relu(xc4)
        xc4 = self.dropout(xc4)
        xc4 = self.conv_c4_2(xc4)
        xc4 = self.relu(xc4)
        xc5 = self.pool(xc4)

        xc5 = self.conv_c5_1(xc5)
        xc5 = self.relu(xc5)
        xc5 = self.dropout(xc5)
        xc5 = self.conv_c5_2(xc5)
        xc5 = self.relu(xc5)

        xu4 = self.conv_transpose_u4(xc5)

        k_dim_2 = (xc4.shape[2] - xu4.shape[2]) // 2
        k_residual_dim_2 = (xc4.shape[2] - xu4.shape[2]) % 2
        k_dim_3 = (xc4.shape[3] - xu4.shape[3]) // 2
        k_residual_dim_3 = (xc4.shape[3] - xu4.shape[3]) % 2
        size_dim_2 = xc4.shape[2] - k_residual_dim_2
        size_dim_3 = xc4.shape[3] - k_residual_dim_3

        xc4_portion = xc4[:, :, k_dim_2:size_dim_2 - k_dim_2, k_dim_3:size_dim_3 - k_dim_3]
        xu4 = torch.cat((xc4_portion, xu4), dim = 1)
        xu4 = self.conv_u4_1(xu4)
        xu4 = self.relu(xu4)
        xu4 = self.dropout(xu4)
        xu4 = self.conv_u4_2(xu4)
        xu4 = self.relu(xu4)

        xu3 = self.conv_transpose_u3(xu4)

        k_dim_2 = (xc3.shape[2] - xu3.shape[2]) // 2
        k_residual_dim_2 = (xc3.shape[2] - xu3.shape[2]) % 2
        k_dim_3 = (xc3.shape[3] - xu3.shape[3]) // 2
        k_residual_dim_3 = (xc3.shape[3] - xu3.shape[3]) % 2
        size_dim_2 = xc3.shape[2] - k_residual_dim_2
        size_dim_3 = xc3.shape[3] - k_residual_dim_3

        xc3_portion = xc3[:, :, k_dim_2:size_dim_2 - k_dim_2, k_dim_3:size_dim_3 - k_dim_3]
        xu3 = torch.cat((xc3_portion, xu3), dim = 1)
        xu3 = self.conv_u3_1(xu3)
        xu3 = self.relu(xu3)
        xu3 = self.dropout(xu3)
        xu3 = self.conv_u3_2(xu3)
        xu3 = self.relu(xu3)

        xu2 = self.conv_transpose_u2(xu3)

        k_dim_2 = (xc2.shape[2] - xu2.shape[2]) // 2
        k_residual_dim_2 = (xc2.shape[2] - xu2.shape[2]) % 2
        k_dim_3 = (xc2.shape[3] - xu2.shape[3]) // 2
        k_residual_dim_3 = (xc2.shape[3] - xu2.shape[3]) % 2
        size_dim_2 = xc2.shape[2] - k_residual_dim_2
        size_dim_3 = xc2.shape[3] - k_residual_dim_3

        xc2_portion = xc2[:, :, k_dim_2:size_dim_2 - k_dim_2, k_dim_3:size_dim_3 - k_dim_3]
        xu2 = torch.cat((xc2_portion, xu2), dim = 1)
        xu2 = self.conv_u2_1(xu2)
        xu2 = self.relu(xu2)
        xu2 = self.dropout(xu2)
        xu2 = self.conv_u2_2(xu2)
        xu2 = self.relu(xu2)

        xu1 = self.conv_transpose_u1(xu2)

        k_dim_2 = (xc1.shape[2] - xu1.shape[2]) // 2
        k_residual_dim_2 = (xc1.shape[2] - xu1.shape[2]) % 2
        k_dim_3 = (xc1.shape[3] - xu1.shape[3]) // 2
        k_residual_dim_3 = (xc1.shape[3] - xu1.shape[3]) % 2
        size_dim_2 = xc1.shape[2] - k_residual_dim_2
        size_dim_3 = xc1.shape[3] - k_residual_dim_3

        xc1_portion = xc1[:, :, k_dim_2:size_dim_2 - k_dim_2, k_dim_3:size_dim_3 - k_dim_3]
        xu1 = torch.cat((xc1_portion, xu1), dim = 1)
        xu1 = self.conv_u1_1(xu1)
        xu1 = self.relu(xu1)
        xu1 = self.dropout(xu1)
        xu1 = self.conv_u1_2(xu1)
        xu1 = self.relu(xu1)
        xu1 = self.conv_u1_3(xu1)
        xu1 = self.sigmoid(xu1)
        
        return xu1

Initializing model and checking the random performance.

In [88]:
model = UNet(SIZE_IMAGE, conv_padding = 'same').to(DEVICE)

In [78]:
x = torch.randn((16, 3, 270, 290))
model(x).shape

torch.Size([16, 1, 194, 226])

Number of total parameters in the model.

In [79]:
iterator = iter(model.parameters())

total_parameters = 0

for subparameters in iterator:
    total_parameters += subparameters.numel()

In [80]:
total_parameters

1941233

Defining the optimizer, loss function, train and evaluation function.

In [86]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [87]:
loss_fn = F.binary_cross_entropy

In [89]:
def train_model(model, train_loader, optimizer, loss_fn, n):
    model.train()
    m = 5
    train_loss = 0
    total_train_loss = 0
    
    for i, (X, y) in enumerate(train_loader):
        optimizer.zero_grad()
        
        print('Prediction')
        pred = model(X)
        
        loss = loss_fn(y, pred)
        
        print('Backwarding\n')
        loss.backward()
        
        optimizer.step()
        train_loss += loss.item()

        if i % m == m - 1:
            print('{} Epoch --> {}/{} with loss of {:.5f}'.format(
                    epoch + 1, i + 1, n, train_loss / m 
                    ))
            total_train_loss += train_loss
            train_loss = 0
    
    if i % m != m - 1:   
        total_train_loss += train_loss
    return total_train_loss / n
            
def evaluate_model(model, validation_loader, loss_fn, n):
    model.eval()
    m = 5
    
    with torch.no_grad():
        total_val_loss = 0

        for i, (X, y) in enumerate(validation_loader):
            pred = model(X)
            loss = loss_fn(pred, y)
            total_val_loss += loss.item()
            
            if i % m == m - 1:
                print(i)
                
        return total_val_loss / n

Training and Testing.

In [90]:
epochs = 5

for epoch in range(epochs):
    print('-' * 50)
    
    train_loss = train_model(model, train_loader, optimizer, loss_fn, n = len(train_loader))
    print('\nAverage train loss error --> {:.5f}'.format(train_loss))
    
    test_loss = evaluate_model(model, validation_loader, loss_fn, n = len(validation_loader))
    print('Average test loss error --> {:.5f}\n'.format(test_loss))
    

--------------------------------------------------
Prediction
Backwarding

Prediction
Backwarding

Prediction
Backwarding

Prediction
Backwarding

Prediction
Backwarding

1 Epoch --> 5/38 with loss of 49.51272
Prediction
Backwarding

Prediction
Backwarding

Prediction
Backwarding

Prediction
Backwarding

Prediction
Backwarding

1 Epoch --> 10/38 with loss of 49.56812
Prediction
Backwarding

Prediction
Backwarding

Prediction
Backwarding

Prediction
Backwarding

Prediction
Backwarding

1 Epoch --> 15/38 with loss of 49.38841
Prediction
Backwarding

Prediction
Backwarding

Prediction
Backwarding

Prediction
Backwarding

Prediction
Backwarding

1 Epoch --> 20/38 with loss of 49.34517
Prediction
Backwarding

Prediction
Backwarding

Prediction
Backwarding

Prediction
Backwarding

Prediction
Backwarding

1 Epoch --> 25/38 with loss of 49.69519
Prediction
Backwarding

Prediction
Backwarding

Prediction
Backwarding

Prediction
Backwarding

Prediction
Backwarding

1 Epoch --> 30/38 with loss of

Saving the final model.

In [91]:
torch.save(model.state_dict(), './models/UNet')

In [93]:
model = UNet(SIZE_IMAGE).to('cpu')
model.load_state_dict(torch.load('./models/UNet'))

<All keys matched successfully>

The project is still under the development since errors in each corresponding iteration of epochs seem not be updated properly. This suggests that further investigation is necessary and final and more successfuly project is expected to be uploaded within 1-2 weeks. 