In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
import torch.optim as optim
from torchvision import transforms as T

import PIL
from PIL import Image
from tqdm import tqdm

import matplotlib.pyplot as plt

In [2]:
# Get Image and mask data
class NucleusDataset(Dataset):
    def __init__(self, path, train = True, transform=None):
        self.path = path
        self.train = train
        self.transform = transform 

        self.IMG_WIDTH = 128
        self.IMG_HEIGHT = 128
        self.IMG_CHANNELS = 3


        if train:
            self.idx = next(os.walk(self.path))[1]
            self.images =[]
            self.masks = []

            for id_ in tqdm(self.idx):
                path = self.path + id_
                img = Image.open(path + '/images/' + id_ + '.png').convert('RGB')
                img = img.resize((self.IMG_HEIGHT, self.IMG_WIDTH))

                self.images.append(img)

                mask = np.zeros((self.IMG_HEIGHT, self.IMG_WIDTH, 1), dtype=np.bool)
                for mask_file in next(os.walk(path + '/masks/'))[2]:
                    mask_ = Image.open(path + '/masks/' + mask_file)
                    mask_ = np.expand_dims(mask_.resize((self.IMG_HEIGHT, self.IMG_WIDTH)), axis=-1)
                    mask = np.maximum(mask, mask_)

                    self.masks.append(mask)

        else:
            self.idx = next(os.walk(self.path))[1]
            self.test_imgs = []

            for id_ in tqdm(self.idx):
                path = self.path + id_
                img = Image.open(path + '/images/' + id_ + '.png').convert("RGB")
                img = img.resize((self.IMG_HEIGHT, self.IMG_WIDTH))

                self.test_imgs.append(img)

    def __len__(self):
        return len(self.idx)

    def __getitem__(self, item):
        if self.train:
            image, mask = self.images[item], self.masks[item]

            if self.transform:
                image = self.transform(image)
                mask = self.transform(mask)
            
            return image, mask
        
        else:
            image = self.test_imgs[item]

            if self.transform:
                image = self.transform(image)

            return image
            
class Normalizer:
    def __call__(self, image):
        image = np.array(image)
        image = image.astype(np.float32) / 255
        return image


class ToTensor:
    def __call__(self, data):
        if len(data.shape) == 2:
            data = np.expand_dims(data, axis=0)
        elif len(data.shape) == 3:
            data = data.transpose((2, 0, 1))
        else:
            print("Unsupported shape!")
        return torch.from_numpy(data)

In [3]:
# load Train and Test dataset
train_data = NucleusDataset(path="nuclei_datasets/stage1_train/",
                                     train=True,
                                     transform= T.Compose([Normalizer(), ToTensor()]))
test_data = NucleusDataset(path="nuclei_datasets/stage1_test/",
                                     train=False,
                                     transform= T.Compose([Normalizer(), ToTensor()]))

100%|████████████████████████████████████████████████████████████████████████████████| 670/670 [00:32<00:00, 20.87it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 65/65 [00:00<00:00, 141.07it/s]


In [4]:
# Building the U-net Model

class UNet(nn.Module):
    def __init__(self, kernel_size=3, padding=1):
        super(UNet, self).__init__()
        self.conv1_1 = nn.Conv2d(3, 16, kernel_size=kernel_size, padding=padding)
        self.conv1_2 = nn.Conv2d(16, 16, kernel_size=kernel_size, padding=padding)
        self.maxpool1 = nn.MaxPool2d(2)

        self.conv2_1 = nn.Conv2d(16, 32, kernel_size=kernel_size, padding=padding)
        self.conv2_2 = nn.Conv2d(32, 32, kernel_size=kernel_size, padding=padding)
        self.maxpool2 = nn.MaxPool2d(2)

        self.conv3_1 = nn.Conv2d(32, 64, kernel_size=kernel_size, padding=padding)
        self.conv3_2 = nn.Conv2d(64, 64, kernel_size=kernel_size, padding=padding)
        self.maxpool3 = nn.MaxPool2d(2)

        self.conv4_1 = nn.Conv2d(64, 128, kernel_size=kernel_size, padding=padding)
        self.conv4_2 = nn.Conv2d(128, 128, kernel_size=kernel_size, padding=padding)
        self.maxpool4 = nn.MaxPool2d(2)

        self.conv5_1 = nn.Conv2d(128, 256, kernel_size=kernel_size, padding=padding)
        self.conv5_2 = nn.Conv2d(256, 256, kernel_size=kernel_size, padding=padding)
        self.conv5_t = nn.ConvTranspose2d(256, 128, 2, stride=2)

        self.conv6_1 = nn.Conv2d(256, 128, kernel_size=kernel_size, padding=padding)
        self.conv6_2 = nn.Conv2d(128, 128, kernel_size=kernel_size, padding=padding)
        self.conv6_t = nn.ConvTranspose2d(128, 64, 2, stride=2)

        self.conv7_1 = nn.Conv2d(128, 64, kernel_size=kernel_size, padding=padding)
        self.conv7_2 = nn.Conv2d(64, 64, kernel_size=kernel_size, padding=padding)
        self.conv7_t = nn.ConvTranspose2d(64, 32, 2, stride=2)

        self.conv8_1 = nn.Conv2d(64, 32, kernel_size=kernel_size, padding=padding)
        self.conv8_2 = nn.Conv2d(32, 32, kernel_size=kernel_size, padding=padding)
        self.conv8_t = nn.ConvTranspose2d(32, 16, 2, stride=2)

        self.conv9_1 = nn.Conv2d(32, 16, kernel_size=kernel_size, padding=padding)
        self.conv9_2 = nn.Conv2d(16, 16, kernel_size=kernel_size, padding=padding)

        self.conv10_1 = nn.Conv2d(16, 1, kernel_size=kernel_size, padding=padding)
        self.conv10_2 = nn.Conv2d(1, 1, kernel_size=kernel_size, padding=padding)

    def forward(self, x):
        conv1 = F.elu(self.conv1_1(x))
        conv1 = F.elu(self.conv1_2(conv1))
        pool1 = self.maxpool1(conv1)

        conv2 = F.elu(self.conv2_1(pool1))
        conv2 = F.elu(self.conv2_2(conv2))
        pool2 = self.maxpool2(conv2)

        conv3 = F.elu(self.conv3_1(pool2))
        conv3 = F.elu(self.conv3_2(conv3))
        pool3 = self.maxpool3(conv3)

        conv4 = F.elu(self.conv4_1(pool3))
        conv4 = F.elu(self.conv4_2(conv4))
        pool4 = self.maxpool4(conv4)

        conv5 = F.elu(self.conv5_1(pool4))
        conv5 = F.elu(self.conv5_2(conv5))

        up6 = torch.cat((self.conv5_t(conv5), conv4), dim=1)
        conv6 = F.elu(self.conv6_1(up6))
        conv6 = F.elu(self.conv6_2(conv6))

        up7 = torch.cat((self.conv6_t(conv6), conv3), dim=1)
        conv7 = F.elu(self.conv7_1(up7))
        conv7 = F.relu(self.conv7_2(conv7))

        up8 = torch.cat((self.conv7_t(conv7), conv2), dim=1)
        conv8 = F.elu(self.conv8_1(up8))
        conv8 = F.elu(self.conv8_2(conv8))

        up9 = torch.cat((self.conv8_t(conv8), conv1), dim=1)
        conv9 = F.elu(self.conv9_1(up9))
        conv9 = F.elu(self.conv9_2(conv9))
        
        conv10 = F.elu(self.conv10_1(conv9))
        
        return torch.sigmoid(self.conv10_2(conv10))

In [5]:
# U-Net Architecture
UNet()

UNet(
  (conv1_1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv1_2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2_1): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2_2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (maxpool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3_1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (maxpool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv4_1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4_2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (maxpool4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (c

In [6]:
# Training Model
def train(train_loader, epochs, learning_rate):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet().to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(epochs):
        print('Epoch {} of {}'.format(epoch + 1, epochs))
        print('-><-' * 10)

        running_loss = 0.0
        for batch_idx, (images, masks) in tqdm(enumerate(train_loader), total = len(train_loader)):
            images, masks = images.to(device), masks.to(device)

            optimizer.zero_grad()

            output = model(images)
            loss = F.binary_cross_entropy(output, masks)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        epoch_loss = running_loss / len(train_loader)
        print("Loss: {:.4f}\n".format(epoch_loss))

    os.makedirs("models", exist_ok=True)
    torch.save(model, "models/model_1.pt")


In [8]:
# Train Model
train_loader = torch.utils.data.DataLoader(
        train_data, batch_size= 8, shuffle=True)

train(train_loader, epochs = 20 , learning_rate = 0.001)

Epoch 1 of 20
-><--><--><--><--><--><--><--><--><--><-


100%|██████████████████████████████████████████████████████████████████████████████████| 84/84 [03:06<00:00,  2.01s/it]


Loss: 0.3594

Epoch 2 of 20
-><--><--><--><--><--><--><--><--><--><-


100%|██████████████████████████████████████████████████████████████████████████████████| 84/84 [03:15<00:00,  2.09s/it]


Loss: 0.3186

Epoch 3 of 20
-><--><--><--><--><--><--><--><--><--><-


100%|██████████████████████████████████████████████████████████████████████████████████| 84/84 [03:10<00:00,  2.09s/it]


Loss: 0.3167

Epoch 4 of 20
-><--><--><--><--><--><--><--><--><--><-


100%|██████████████████████████████████████████████████████████████████████████████████| 84/84 [03:15<00:00,  2.03s/it]


Loss: 0.3161

Epoch 5 of 20
-><--><--><--><--><--><--><--><--><--><-


100%|██████████████████████████████████████████████████████████████████████████████████| 84/84 [03:19<00:00,  2.21s/it]


Loss: 0.3165

Epoch 6 of 20
-><--><--><--><--><--><--><--><--><--><-


100%|██████████████████████████████████████████████████████████████████████████████████| 84/84 [03:24<00:00,  2.15s/it]


Loss: 0.3168

Epoch 7 of 20
-><--><--><--><--><--><--><--><--><--><-


100%|██████████████████████████████████████████████████████████████████████████████████| 84/84 [03:25<00:00,  2.75s/it]


Loss: 0.3163

Epoch 8 of 20
-><--><--><--><--><--><--><--><--><--><-


100%|██████████████████████████████████████████████████████████████████████████████████| 84/84 [03:29<00:00,  2.43s/it]


Loss: 0.3161

Epoch 9 of 20
-><--><--><--><--><--><--><--><--><--><-


100%|██████████████████████████████████████████████████████████████████████████████████| 84/84 [03:38<00:00,  2.49s/it]


Loss: 0.3160

Epoch 10 of 20
-><--><--><--><--><--><--><--><--><--><-


100%|██████████████████████████████████████████████████████████████████████████████████| 84/84 [03:47<00:00,  2.32s/it]


Loss: 0.3172

Epoch 11 of 20
-><--><--><--><--><--><--><--><--><--><-


100%|██████████████████████████████████████████████████████████████████████████████████| 84/84 [03:43<00:00,  2.81s/it]


Loss: 0.3162

Epoch 12 of 20
-><--><--><--><--><--><--><--><--><--><-


100%|██████████████████████████████████████████████████████████████████████████████████| 84/84 [03:39<00:00,  2.43s/it]


Loss: 0.3163

Epoch 13 of 20
-><--><--><--><--><--><--><--><--><--><-


100%|██████████████████████████████████████████████████████████████████████████████████| 84/84 [03:35<00:00,  2.41s/it]


Loss: 0.3162

Epoch 14 of 20
-><--><--><--><--><--><--><--><--><--><-


100%|██████████████████████████████████████████████████████████████████████████████████| 84/84 [03:32<00:00,  2.34s/it]


Loss: 0.3161

Epoch 15 of 20
-><--><--><--><--><--><--><--><--><--><-


100%|██████████████████████████████████████████████████████████████████████████████████| 84/84 [03:45<00:00,  2.43s/it]


Loss: 0.3162

Epoch 16 of 20
-><--><--><--><--><--><--><--><--><--><-


100%|██████████████████████████████████████████████████████████████████████████████████| 84/84 [03:18<00:00,  2.07s/it]


Loss: 0.3160

Epoch 17 of 20
-><--><--><--><--><--><--><--><--><--><-


100%|██████████████████████████████████████████████████████████████████████████████████| 84/84 [03:13<00:00,  2.08s/it]


Loss: 0.3163

Epoch 18 of 20
-><--><--><--><--><--><--><--><--><--><-


100%|██████████████████████████████████████████████████████████████████████████████████| 84/84 [03:10<00:00,  2.09s/it]


Loss: 0.3163

Epoch 19 of 20
-><--><--><--><--><--><--><--><--><--><-


100%|██████████████████████████████████████████████████████████████████████████████████| 84/84 [03:10<00:00,  2.10s/it]


Loss: 0.3161

Epoch 20 of 20
-><--><--><--><--><--><--><--><--><--><-


100%|██████████████████████████████████████████████████████████████████████████████████| 84/84 [03:11<00:00,  2.09s/it]


Loss: 0.3165



  "type " + obj.__name__ + ". It won't be checked "


In [11]:
# Test the model
def test(test_loader , weights_path):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = torch.load(weights_path)
    model.eval()
    model.to(device)

    with torch.no_grad():
        images = next(iter(test_loader)).to(device)

        outputs = model(images)
        images = tensor_to_numpy(images)
        outputs = tensor_to_numpy(outputs)
        show_images(images, outputs)


def tensor_to_numpy(tensor):
    t_numpy = tensor.cpu().numpy()
    t_numpy = np.transpose(t_numpy, [0, 2, 3, 1])
    t_numpy = np.squeeze(t_numpy)
    
    return t_numpy


def show_images(images, masks, columns=4):
    fig = plt.figure()
    rows = np.ceil((images.shape[0] + masks.shape[0]) / columns)
    index = 1
    for image, mask in zip(images, masks):
        f1 = fig.add_subplot(rows, columns, index)
        f1.set_title('input')
        plt.axis('off')
        plt.imshow(image)
        index += 1
        

        f2 = fig.add_subplot(rows, columns, index)
        f2.set_title('prediction')
        plt.axis('off')     
        plt.imshow(mask)
        index += 1

    plt.show()
