# Image Segmentation U-Net model for retina blood vessel segmentation

## Environment settings

### libraries import

In [None]:
import os
import numpy as np
import random

# resizing and basic image processing
import cv2

# data extraction
from glob import glob

# display progress bar
from tqdm import tqdm

# read the gif masks
import imageio

# data augmentation library
from albumentations import HorizontalFlip, VerticalFlip, Rotate

# model
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset


# training
import time

# test
from operator import add
import imageio
from sklearn.metrics import accuracy_score, f1_score, jaccard_score, precision_score, recall_score

Define function to fix the random seed

In [None]:
def seeding(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and every CUDA
    device), exports ``PYTHONHASHSEED`` and configures cuDNN so that it does
    not pick algorithms non-deterministically.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all GPUs, not only the current one (safe to call without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Autotuning may select a different convolution algorithm on each run,
    # which defeats the deterministic flag above.
    torch.backends.cudnn.benchmark = False

Define function to calculate running time

In [None]:
def epoch_time(start_time, end_time):
    """Split the interval between two timestamps into minutes and seconds.

    Args:
        start_time: Start timestamp in seconds.
        end_time: End timestamp in seconds.

    Returns:
        Tuple ``(minutes, seconds)``, both truncated to integers.
    """
    duration = end_time - start_time
    whole_minutes = int(duration / 60)
    leftover_seconds = int(duration - (whole_minutes * 60))
    return whole_minutes, leftover_seconds

Define function to create directories for data sorting and saving

In [None]:
def create_dir(path):
    """Create directory *path*, including missing parents, if it is absent.

    Args:
        path: Directory path to create.

    Raises:
        FileExistsError: If *path* exists but is not a directory.
    """
    # exist_ok avoids the race between a separate existence check and the
    # creation itself.
    os.makedirs(path, exist_ok=True)

### Data loading

In [None]:
def load_data(path):
    """Collect the image and mask file paths found below *path*.

    Images are the ``*.tif`` files in ``<split>/images`` and masks are the
    ``*.gif`` files in ``<split>/1st_manual``, for the ``training`` and
    ``test`` splits.

    Args:
        path: Root directory of the dataset.

    Returns:
        ``((train_x, train_y), (test_x, test_y))`` where every element is a
        sorted list of paths using forward slashes.
    """
    def _collect(split, folder, pattern):
        # Normalise Windows separators before sorting, as callers split on "/".
        matches = glob(os.path.join(path, split, folder, pattern))
        return sorted(match.replace('\\', '/') for match in matches)

    training = (
        _collect("training", "images", "*.tif"),
        _collect("training", "1st_manual", "*.gif"),
    )
    testing = (
        _collect("test", "images", "*.tif"),
        _collect("test", "1st_manual", "*.gif"),
    )
    return training, testing

Data augmentation function

In [None]:
def augment_data(images, masks, save_path, augment=True, size=(512, 512)):
    """Resize (and optionally augment) image/mask pairs and save them as PNG.

    For every pair the original sample is written and, when *augment* is
    true, three extra samples: a horizontal flip, a vertical flip and a
    rotation by a random angle in [-45, 45] degrees. Files are written to
    ``<save_path>/image/<name>_<k>.png`` and ``<save_path>/mask/<name>_<k>.png``
    where ``k`` is 0 for the original sample.

    Args:
        images: Paths of the input images.
        masks: Paths of the matching GIF masks, in the same order as *images*.
        save_path: Output root; the ``image`` and ``mask`` sub-folders are
            created if missing.
        augment: Whether to add the three augmented variants.
        size: Target ``(width, height)`` passed to ``cv2.resize``.

    Raises:
        OSError: If an image cannot be read or an output file cannot be written.
    """
    image_dir = os.path.join(save_path, "image")
    mask_dir = os.path.join(save_path, "mask")
    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(mask_dir, exist_ok=True)

    # p=1.0 makes every transform fire on every call.
    transforms = []
    if augment:
        transforms = [
            HorizontalFlip(p=1.0),
            VerticalFlip(p=1.0),
            Rotate(limit=45, p=1.0),  # random angle within +/- 45 degrees
        ]

    for image_file, mask_file in tqdm(zip(images, masks), total=len(images)):
        # File name without directory and extension.
        name = image_file.replace('\\', '/').split("/")[-1].split('.')[0]

        image = cv2.imread(image_file, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError("Could not read image: {}".format(image_file))
        # The mask is a GIF; take its first frame.
        mask = imageio.mimread(mask_file)[0]

        samples = [(image, mask)]
        for transform in transforms:
            augmented = transform(image=image, mask=mask)
            samples.append((augmented["image"], augmented["mask"]))

        for index, (sample_image, sample_mask) in enumerate(samples):
            sample_image = cv2.resize(sample_image, size)
            # Nearest-neighbour keeps the mask at its original label values;
            # the default bilinear interpolation would blend vessel and
            # background pixels into intermediate grey levels.
            sample_mask = cv2.resize(sample_mask, size, interpolation=cv2.INTER_NEAREST)

            file_name = "{}_{}.png".format(name, index)
            outputs = (
                (os.path.join(image_dir, file_name), sample_image),
                (os.path.join(mask_dir, file_name), sample_mask),
            )
            for target, array in outputs:
                # cv2.imwrite returns False on failure instead of raising.
                if not cv2.imwrite(target, array):
                    raise OSError("Could not write file: {}".format(target))

## Model

In [None]:
class conv_block(nn.Module):
    """Two 3x3 convolutions, each followed by batch normalisation and ReLU.

    Padding of 1 keeps the spatial size unchanged; only the channel count
    goes from ``in_c`` to ``out_c``.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        conv_options = dict(kernel_size=3, padding=1)

        # First stage maps in_c -> out_c channels.
        self.conv1 = nn.Conv2d(in_c, out_c, **conv_options)
        self.bn1 = nn.BatchNorm2d(out_c)

        # Second stage consumes the first stage's output, hence out_c -> out_c.
        self.conv2 = nn.Conv2d(out_c, out_c, **conv_options)
        self.bn2 = nn.BatchNorm2d(out_c)

        # One stateless activation module shared by both stages.
        self.relu = nn.ReLU()

    def forward(self, inputs):
        """Apply conv -> batch norm -> ReLU twice and return the result."""
        first_stage = self.relu(self.bn1(self.conv1(inputs)))
        return self.relu(self.bn2(self.conv2(first_stage)))

In [None]:
class encoder_block(nn.Module):
    """Encoder stage: a ``conv_block`` followed by 2x2 max pooling."""

    def __init__(self, in_c, out_c):
        super().__init__()

        self.conv = conv_block(in_c, out_c)
        self.pool = nn.MaxPool2d(kernel_size=(2, 2))

    def forward(self, inputs):
        """Return ``(features, pooled)``.

        ``features`` is the convolution output before pooling (used as the
        skip connection) and ``pooled`` is the same tensor after max pooling.
        """
        features = self.conv(inputs)
        return features, self.pool(features)

In [None]:
class decoder_block(nn.Module):
    """Decoder stage: transposed-conv upsampling, skip concatenation, ``conv_block``."""

    def __init__(self, in_c, out_c):
        super().__init__()

        # Kernel 2 / stride 2 doubles the spatial size and maps in_c -> out_c.
        self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0)
        # After concatenation with the skip tensor the channel count is
        # 2 * out_c, which the conv block reduces back to out_c.
        self.conv = conv_block(2 * out_c, out_c)

    def forward(self, inputs, skip):
        """Upsample *inputs*, concatenate *skip* on the channel axis and convolve."""
        upsampled = self.up(inputs)
        merged = torch.cat((upsampled, skip), dim=1)
        return self.conv(merged)

In [None]:
class build_unet(nn.Module):
    """U-Net with four encoder stages, a bottleneck and four decoder stages.

    Takes a 3-channel input and returns a 1-channel map of raw scores; no
    sigmoid is applied inside the network.
    """

    def __init__(self):
        super().__init__()

        # Encoder: 3 input channels (RGB image), channel count doubles per stage.
        self.e1 = encoder_block(3, 64)
        self.e2 = encoder_block(64, 128)
        self.e3 = encoder_block(128, 256)
        self.e4 = encoder_block(256, 512)

        # Bottleneck (bridge): a plain convolution block.
        self.b = conv_block(512, 1024)

        # Decoder: mirrors the encoder, channel count halves per stage.
        self.d1 = decoder_block(1024, 512)
        self.d2 = decoder_block(512, 256)
        self.d3 = decoder_block(256, 128)
        self.d4 = decoder_block(128, 64)

        # Classifier: 1x1 convolution down to a single channel (binary mask).
        self.outputs = nn.Conv2d(64, 1, kernel_size=1, padding=0)

    def forward(self, inputs):
        """Run the encoder, bottleneck and decoder and return the score map."""
        skips = []
        x = inputs
        for encoder in (self.e1, self.e2, self.e3, self.e4):
            skip, x = encoder(x)
            skips.append(skip)

        x = self.b(x)

        # Decoders consume the skip tensors in reverse order (deepest first).
        for decoder in (self.d1, self.d2, self.d3, self.d4):
            x = decoder(x, skips.pop())

        return self.outputs(x)

In [None]:
class DriveDataset(Dataset):
    """Dataset of image/mask file pairs loaded from disk with OpenCV.

    Args:
        images_path: Sequence of image file paths.
        masks_path: Sequence of mask file paths, aligned with *images_path*.
    """

    def __init__(self, images_path, masks_path):
        self.images_path = images_path
        self.masks_path = masks_path
        self.n_samples = len(images_path)

    @staticmethod
    def _read(path, flag):
        """Read *path* with ``cv2.imread`` and fail loudly if it is unreadable."""
        array = cv2.imread(path, flag)
        if array is None:
            # cv2.imread returns None rather than raising on a missing or
            # undecodable file; without this check the failure would surface
            # as an unrelated TypeError on the division below.
            raise OSError("Could not read image file: {}".format(path))
        return array

    def __getitem__(self, index):
        """Return ``(image, mask)`` as float32 tensors scaled by 1/255.

        The image is converted from (H, W, 3) to (3, H, W); the mask is read
        as greyscale and given a leading channel axis, i.e. (1, H, W).

        Raises:
            OSError: If the image or the mask cannot be read.
        """
        # reading image
        image = self._read(self.images_path[index], cv2.IMREAD_COLOR)
        image = image/255.0
        image = np.transpose(image, (2, 0, 1))  # channels first
        image = image.astype(np.float32)
        image = torch.from_numpy(image)

        # reading mask
        mask = self._read(self.masks_path[index], cv2.IMREAD_GRAYSCALE)
        mask = mask/255.0
        mask = np.expand_dims(mask, axis=0)  # add the channel axis
        mask = mask.astype(np.float32)
        mask = torch.from_numpy(mask)

        return image, mask

    def __len__(self):
        """Return the number of image paths given at construction."""
        return self.n_samples

## Training

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss computed on sigmoid probabilities.

    ``weight`` and ``size_average`` are accepted for signature compatibility
    and are not used.
    """

    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        """Return ``1 - dice`` over all elements of the batch.

        Args:
            inputs: Raw scores (logits); a sigmoid is applied here, so skip
                this loss's sigmoid if the model already outputs probabilities.
            targets: Ground-truth tensor with the same number of elements.
            smooth: Constant added to numerator and denominator so the ratio
                is defined when both tensors sum to zero.
        """
        # squash the raw scores to probabilities
        inputs = torch.sigmoid(inputs)

        # Flatten; reshape (unlike view) also accepts non-contiguous tensors.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (inputs * targets).sum()
        dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)

        return 1 - dice

In [None]:
class DiceBCELoss(nn.Module):
    """Sum of binary cross-entropy and soft Dice loss, computed from logits.

    ``weight`` and ``size_average`` are accepted for signature compatibility
    and are not used.
    """

    def __init__(self, weight=None, size_average=True):
        super(DiceBCELoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        """Return ``BCE + (1 - dice)`` over all elements of the batch.

        Args:
            inputs: Raw scores (logits); the model must not apply a sigmoid
                itself.
            targets: Ground-truth tensor with the same number of elements.
            smooth: Constant added to the Dice numerator and denominator.
        """
        # Flatten; reshape (unlike view) also accepts non-contiguous tensors.
        logits = inputs.reshape(-1)
        targets = targets.reshape(-1)

        # Computing BCE directly from the logits is numerically stable; a
        # separate sigmoid followed by binary_cross_entropy saturates for
        # large-magnitude logits and its gradient vanishes.
        BCE = F.binary_cross_entropy_with_logits(logits, targets, reduction='mean')

        probabilities = torch.sigmoid(logits)
        intersection = (probabilities * targets).sum()
        dice_loss = 1 - (2.*intersection + smooth)/(probabilities.sum() + targets.sum() + smooth)

        return BCE + dice_loss

In [None]:
def train(model, loader, optimizer, loss_function, device):
    """Run one optimisation pass over *loader* and return the mean batch loss.

    Args:
        model: Network to optimise; switched to training mode.
        loader: Iterable of ``(inputs, targets)`` batches that supports ``len``.
        optimizer: Optimiser holding the model's parameters.
        loss_function: Callable ``(predictions, targets) -> scalar tensor``.
        device: Device the batches are moved to.

    Returns:
        Sum of the per-batch losses divided by ``len(loader)``.
    """
    model.train()
    batch_losses = []

    for inputs, targets in loader:
        inputs = inputs.to(device, dtype=torch.float32)
        targets = targets.to(device, dtype=torch.float32)

        # Clear gradients left over from the previous step before backprop.
        optimizer.zero_grad()
        batch_loss = loss_function(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()

        batch_losses.append(batch_loss.item())

    return sum(batch_losses, 0.0) / len(loader)

In [None]:
def evaluate(model, loader, loss_function, device):
    """Compute the mean batch loss over *loader* without updating the model.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        loader: Iterable of ``(inputs, targets)`` batches that supports ``len``.
        loss_function: Callable ``(predictions, targets) -> scalar tensor``.
        device: Device the batches are moved to.

    Returns:
        Sum of the per-batch losses divided by ``len(loader)``.
    """
    model.eval()
    batch_losses = []

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device, dtype=torch.float32)
            targets = targets.to(device, dtype=torch.float32)

            batch_loss = loss_function(model(inputs), targets)
            batch_losses.append(batch_loss.item())

    return sum(batch_losses, 0.0) / len(loader)

Fix the random seed to get the same data

In [None]:
# Seed all random number generators once, before augmentation and model creation.
seeding(42)


Create the directories used to store the data

In [None]:
# Root folders of the dataset; create_dir does nothing for folders that already exist.
create_dir('data/')
create_dir('data/retina_segmentation_data/')

In [None]:
# One image folder and one mask folder per split, matching the layout
# written by augment_data.
for split in ('train', 'test'):
    for kind in ('image', 'mask'):
        create_dir('data/retina_segmentation_data/augmented_data/{}/{}/'.format(split, kind))

Load the dataset

In [None]:
# Collect the raw image/mask paths of both splits and report how many were found.
data_path = 'data/retina_segmentation_data/'
train_pairs, test_pairs = load_data(data_path)
train_x, train_y = train_pairs
test_x, test_y = test_pairs

print('Dataset size:')
print(f'Train: \n x: {len(train_x)} \n y: {len(train_y)}')
print(f'Test: \n x: {len(test_x)} \n y: {len(test_y)}')

# MODIFY: create separate TRAINING, VALIDATION and TEST sets. To augment the data even more:
- many more rotation augmentations (at least 360 rotations in 1-degree steps)
- all the flipping possibilities
- blurring
- probably other techniques to enlarge the dataset

In [None]:
# create the training data: each pair is saved together with its three augmented variants
augment_data(images=train_x, masks=train_y, save_path='data/retina_segmentation_data/augmented_data/train/', augment=True)
# create the test data: augment=False, so the pairs are only resized and saved
augment_data(images=test_x, masks=test_y, save_path='data/retina_segmentation_data/augmented_data/test/', augment=False)

In [None]:
# Re-read the generated files; sorting keeps images and masks aligned because
# both folders use identical file names.
# NOTE(review): [:40] keeps only the first 40 files in name order, so part of
# the augmented training set may be left unused -- confirm this is intended.
train_x = sorted(glob("data/retina_segmentation_data/augmented_data/train/image/*"))[:40]
train_y = sorted(glob("data/retina_segmentation_data/augmented_data/train/mask/*"))[:40]

# The (non-augmented) test folder is used as the validation set.
valid_x = sorted(glob("data/retina_segmentation_data/augmented_data/test/image/*"))
valid_y = sorted(glob("data/retina_segmentation_data/augmented_data/test/mask/*"))

print('Dataset size:')
print('Train: \n x: {} \n y: {}'.format(len(train_x), len(train_y)))
print('Test: \n x: {} \n y: {}'.format(len(valid_x), len(valid_y)))

Hyperparameters

In [None]:
# Spatial size the images are resized to.
HEIGHT = 512
WIDTH = 512
# NOTE(review): SIZE is passed to cv2.resize, which expects (width, height);
# harmless while HEIGHT == WIDTH -- revisit if they ever differ.
SIZE = (HEIGHT, WIDTH)
# Number of samples per batch in both data loaders.
BATCH_SIZE = 2
# Number of passes over the training set.
NUM_EPOCHS = 5
# Initial learning rate given to the optimizer.
LEARNING_RATE = 1e-4
# File that receives the weights of the best model (lowest validation loss).
CHECKPOINT_PATH = "data/retina_segmentation_data/model_saved/checkpoint.pth"

Dataset and Loader

In [None]:
# Wrap the file lists in datasets that load and normalise the images on access.
train_dataset = DriveDataset(train_x, train_y)
valid_dataset = DriveDataset(valid_x, valid_y)

# Training batches are reshuffled every epoch; validation keeps a fixed order.
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

In [None]:
# Train on the GPU when one is available (same device selection as the test
# cell); fall back to the CPU otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = build_unet()
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# ReduceLROnPlateau only lowers the learning rate when scheduler.step(metric)
# is called with the monitored value ('min': lower is better).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, verbose=True)
# The model returns raw scores; DiceBCELoss applies the sigmoid itself.
loss_function = DiceBCELoss()

Training the model

In [None]:
best_valid_loss = float('inf')

# torch.save does not create missing parent folders, so make sure the
# checkpoint directory exists before the first save.
checkpoint_dir = os.path.dirname(CHECKPOINT_PATH)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(NUM_EPOCHS):
    start_time = time.time()

    train_loss = train(model, train_loader, optimizer, loss_function, device)
    valid_loss = evaluate(model, valid_loader, loss_function, device)

    # Let the plateau scheduler see the monitored metric; without this call
    # the learning rate is never reduced.
    scheduler.step(valid_loss)

    # Keep only the weights with the lowest validation loss seen so far.
    if valid_loss < best_valid_loss:
        print('Validation loss improved from {:2.4f} to {:2.4f}. Saving checkpoint: {}'.format(best_valid_loss, valid_loss, CHECKPOINT_PATH))
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), CHECKPOINT_PATH)

    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    print('Epoch: {} | Epoch time: {}min {}s\n\tTraining loss: {:.3f}\n\tValidation loss: {:.3f}\n'.format(epoch, epoch_mins, epoch_secs, train_loss, valid_loss))

## Testing

In [None]:
def calculate_metrics(y_true, y_pred):
    """Score a predicted mask against the ground truth.

    Both tensors are moved to the CPU, thresholded at 0.5 and flattened
    before scoring.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Prediction tensor with the same number of elements.

    Returns:
        ``[jaccard, f1, recall, precision, accuracy]``.
    """
    def _binarize(tensor):
        # Values above 0.5 become 1, everything else 0; result is a flat uint8 vector.
        return (tensor.cpu().numpy() > 0.5).astype(np.uint8).reshape(-1)

    truth = _binarize(y_true)
    prediction = _binarize(y_pred)

    # Order defines the layout of the returned list.
    scorers = (jaccard_score, f1_score, recall_score, precision_score, accuracy_score)
    return [scorer(truth, prediction) for scorer in scorers]

In [None]:
def mask_parse(mask):
    """Replicate a 2-D mask into three identical channels on a new last axis.

    Args:
        mask: Array of shape (H, W).

    Returns:
        Array of shape (H, W, 3) with the same dtype.
    """
    return np.stack([mask, mask, mask], axis=-1)

In [None]:
# Folder that receives the side-by-side comparison images written by the test loop.
create_dir('data/retina_segmentation_data/results')

Load the test dataset

In [None]:
# Sorting keeps images and masks aligned because both folders use identical file names.
test_x = sorted(glob("data/retina_segmentation_data/augmented_data/test/image/*"))
test_y = sorted(glob("data/retina_segmentation_data/augmented_data/test/mask/*"))

print('Dataset size:')
# These are the test files, so label them as such (the label used to read "Train").
print('Test: \n x: {} \n y: {}'.format(len(test_x), len(test_y)))

Load the checkpoint file

In [None]:
# Prefer the GPU when one is available.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# map_location lets weights saved on one device be loaded onto another.
state_dict = torch.load(CHECKPOINT_PATH, map_location=device)

model = build_unet().to(device)
model.load_state_dict(state_dict)
# Switch to evaluation mode for inference.
model.eval()

Testing the model and creating images for output comparison

In [None]:


# Running sums of [jaccard, f1, recall, precision, accuracy] over the test set.
metrics_score = [0.0, 0.0, 0.0, 0.0, 0.0]
time_taken = []

# Grey vertical separator between the panels of the comparison image.
# cv2.resize(..., SIZE) produces SIZE[1] rows; uint8 keeps the concatenated
# image 8-bit instead of promoting everything to float64.
line = np.full((SIZE[1], 10, 3), 128, dtype=np.uint8)

for image_file, mask_file in tqdm(zip(test_x, test_y), total=len(test_x)):
    # File name without directory and extension.
    name = image_file.replace('\\', '/').split("/")[-1].split(".")[0]
    print(name)

    # reading image
    image = cv2.imread(image_file, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread returns None instead of raising on an unreadable file.
        raise OSError("Could not read test image: {}".format(image_file))
    image = cv2.resize(image, SIZE)
    x = np.transpose(image, (2, 0, 1))  # channels first
    x = x/255.0  # normalize the pixels
    x = np.expand_dims(x, axis=0)  # add the batch axis
    x = x.astype(np.float32)
    x = torch.from_numpy(x).to(device)

    # reading mask
    mask = cv2.imread(mask_file, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise OSError("Could not read test mask: {}".format(mask_file))
    mask = cv2.resize(mask, SIZE)
    y = mask[np.newaxis, np.newaxis, ...]  # add batch and channel axes
    y = y/255.0  # normalize the pixels
    y = y.astype(np.float32)
    y = torch.from_numpy(y).to(device)

    with torch.no_grad():
        # CUDA kernels run asynchronously: synchronise before and after the
        # forward pass so the measured time covers the actual computation.
        if device.type == 'cuda':
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        pred_y = torch.sigmoid(model(x))
        if device.type == 'cuda':
            torch.cuda.synchronize()
        time_taken.append(time.perf_counter() - start_time)

        # calculate the metrics
        score = calculate_metrics(y, pred_y)
        metrics_score = [total + value for total, value in zip(metrics_score, score)]

        # Drop the batch and channel axes and binarise the prediction.
        pred_y = pred_y[0].cpu().numpy()
        pred_y = np.squeeze(pred_y, axis=0)
        pred_y = (pred_y > 0.5).astype(np.uint8)

    # Saving masks: input image | ground truth | prediction, side by side.
    original_mask = mask_parse(mask)
    pred_y = mask_parse(pred_y)

    concatenated_images = np.concatenate([image, line, original_mask, line, pred_y*255], axis=1)
    cv2.imwrite("data/retina_segmentation_data/results/{}.png".format(name), concatenated_images)

num_samples = len(test_x)
if num_samples:
    jaccard = metrics_score[0]/num_samples
    f1 = metrics_score[1]/num_samples
    recall = metrics_score[2]/num_samples
    precision = metrics_score[3]/num_samples
    accuracy = metrics_score[4]/num_samples
    print(f"Jaccard: {jaccard:1.4f} - F1: {f1:1.4f} - Recall: {recall:1.4f} - Precision: {precision:1.4f} - Acc: {accuracy:1.4f}")

    fps = 1/np.mean(time_taken)
    print("FPS: ", fps)
else:
    # Avoid a ZeroDivisionError when the test folder is empty.
    print("No test images found; metrics were not computed.")