# UNet For Retina Blood Vessel Segmentation

In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np

import os
from glob import glob

import gc
import yaml
from tqdm import tqdm

import matplotlib.pyplot as plt
from IPython.display import clear_output

import cv2

import torch
from torch.utils.data import DataLoader

from utils import init_random_seed, \
    get_params_number, train_eval_loop, \
    calculate_metrics, mask_parse, predict_with_model
from model import UNet
from data import RetinaDataset, augment_data
from loss import DiceLoss, DiceBCELoss

In [None]:
# Load the experiment configuration (seed, dataset and network settings)
# from the YAML file that sits next to the notebook.
option_path = 'config.yml'
# Pass the encoding explicitly so the read does not depend on the platform's
# locale default (YAML configs are normally stored as UTF-8).
with open(option_path, 'r', encoding='utf-8') as file_option:
    option = yaml.safe_load(file_option)

In [None]:
# Take the seed from the config and hand it to the project's
# init_random_seed helper (presumably seeds the Python/NumPy/PyTorch RNGs —
# confirm in utils).
seed = option['seed']
init_random_seed(seed)

In [None]:
# Root folder of the raw dataset, as given in the config.
path_dataset = option['dataset']['path']
# Folder holding the augmented copy of the dataset. Built with os.path.join,
# like every other path in this notebook, so a trailing separator in the
# configured path does not produce a doubled separator.
path_new_dataset = os.path.join(path_dataset, 'Augmented Data')

In [None]:
# X_train = sorted(glob(os.path.join(path_dataset, 'train', 'images', '*.tif')))
# y_train = sorted(glob(os.path.join(path_dataset, 'train', 'targets', '*.gif')))
# 
# X_test = sorted(glob(os.path.join(path_dataset, 'test', 'images', '*.tif')))
# y_test = sorted(glob(os.path.join(path_dataset, 'test', 'targets', '*.gif')))

In [None]:
# Square target resolution: the configured edge length repeated twice,
# i.e. a two-element (size_image, size_image) tuple.
size = (option['dataset']['size_image'],) * 2

# augment_data(X_train, y_train, path_new_dataset+'/train', size, augment=True)
# augment_data(X_test, y_test, path_new_dataset+'/test', size, augment=False)

In [None]:
# Collect the sorted file lists of the augmented dataset in one pass.
# Iteration order is train/image, train/mask, test/image, test/mask, which
# matches the unpacking targets on the left-hand side.
X_train, y_train, X_test, y_test = (
    sorted(glob(os.path.join(path_new_dataset, split, kind, '*')))
    for split in ('train', 'test')
    for kind in ('image', 'mask')
)

In [None]:
# Wrap the image/mask path lists of each split in the project's dataset class.
train_dataset, test_dataset = (
    RetinaDataset(X_train, y_train),
    RetinaDataset(X_test, y_test),
)

In [None]:
# Show (train size, test size) as the cell output.
tuple(map(len, (train_dataset, test_dataset)))

In [None]:
# Pull the training hyper-parameters out of the loaded config.
network_cfg = option['network']

batch_size = network_cfg['batch_size']
num_epochs = network_cfg['num_epochs']
# Cast explicitly: the learning rate may arrive as a string (presumably when
# written in scientific notation in the YAML file — confirm against config.yml).
lr = float(network_cfg['lr'])
device = network_cfg['device']
scheduler_patience = network_cfg['scheduler_patience']
early_stopping_patience = network_cfg['early_stopping_patience']
save_path_model = network_cfg['save_path_model']
num_workers_dataloader = option['dataset']['num_workers_dataloader']
# Name of the loss as written in the config; a later cell replaces it with
# the loss object itself.
loss_fn = network_cfg['loss']

In [None]:
# Keep the configured device when CUDA is usable; otherwise run on the CPU.
if not torch.cuda.is_available():
    device = 'cpu'

In [None]:
# Settings shared by both loaders.
loader_options = {
    'batch_size': batch_size,
    'num_workers': num_workers_dataloader,
}

# Only the training loader reshuffles its samples; the test loader keeps the
# dataset order.
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_options)
test_dataloader = DataLoader(test_dataset, shuffle=False, **loader_options)

In [None]:
# Build the network and move it to the selected device.
model = UNet()
model = model.to(device)

# Adam over all model parameters, with a plateau scheduler that is given the
# configured patience.
# NOTE(review): recent PyTorch releases deprecate the `verbose` argument of
# the LR schedulers — confirm against the installed version.
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    patience=scheduler_patience,
    verbose=True,
)

# Read the loss name straight from the config instead of from `loss_fn`.
# `loss_fn` is rebound to a loss object below, so comparing it against the
# string on a re-run of this cell would always fail and silently switch the
# loss to DiceBCELoss. Reading from `option` keeps the cell idempotent.
loss_name = option['network']['loss']
if loss_name == 'DiceLoss':
    loss_fn = DiceLoss()
else:
    # Any other configured name falls back to the combined Dice + BCE loss.
    loss_fn = DiceBCELoss()

In [None]:
# Inspect the model size with the project's helper; as the last expression of
# the cell, any value it returns is shown as the cell output.
get_params_number(model)

In [None]:
# Run a full garbage collection and report how many unreachable objects it
# found, then release the cached GPU memory held by PyTorch's allocator.
n_collected = gc.collect()
print("Number of unreachable objects collected by GC:", n_collected)
torch.cuda.empty_cache()

In [None]:
# best_model, loss_train, loss_test = train_eval_loop(model,
#                                                     train_dataloader, test_dataloader,
#                                                     optimizer, loss_fn,
#                                                     num_epochs, device,
#                                                     early_stopping_patience,
#                                                     scheduler)

In [None]:
# clear_output(True)
# plt.plot(loss_train, label='Train loss')
# plt.plot(loss_test, label='Test loss')
# plt.legend(loc='upper right')
# plt.show()

In [None]:
# torch.save(best_model.state_dict(), save_path_model)

In [None]:
# Restore the trained weights from the checkpoint path given in the config,
# remapping the stored tensors onto the active device.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source; consider weights_only=True if the installed PyTorch
# supports it.
model.load_state_dict(torch.load(save_path_model, map_location=device))

## Testing the model

In [None]:
# Run the model over the test loader. With return_labels=True the helper
# yields a (predictions, labels) pair; use_sigmoid=True presumably converts
# the raw outputs to probabilities — confirm in utils.predict_with_model.
prediction_pair = predict_with_model(
    model, test_dataloader, device,
    use_sigmoid=True, return_labels=True,
)
y_pred, y_true = prediction_pair

In [None]:
# Score the predictions against the ground-truth masks. The result is read by
# index in the report that follows: Jaccard, F1, Recall, Precision, Accuracy.
metrics_score = calculate_metrics(y_true, y_pred)

In [None]:
# Labels in the same order as the entries of metrics_score.
metric_labels = ('Jaccard', 'F1', 'Recall', 'Precision', 'Accuracy')

# One "<label>: <value>" line per metric; print() adds the trailing blank line.
metric_report = ''.join(
    f'{label}: {metrics_score[position]:1.4f}\n'
    for position, label in enumerate(metric_labels)
)
print(metric_report)

# Plain mean over every entry of metrics_score.
overall_average = sum(metrics_score) / len(metrics_score)
print(f'Average of all: {overall_average:1.3f}')

In [None]:
# for i, (x, y) in tqdm(enumerate(zip(X_test, y_test)), total=len(X_test)):
#     name = x.split('\\')[-1].split('.')[0]
#     
#     image = cv2.imread(x, cv2.IMREAD_COLOR)
#     mask = cv2.imread(y, cv2.IMREAD_GRAYSCALE)
#     
#     pred_y = np.squeeze(y_pred[i], axis=0)
#     pred_y = pred_y > 0.5
#     pred_y = np.array(pred_y, dtype=np.uint8)
# 
#     ori_mask = mask_parse(mask)
#     pred_y = mask_parse(pred_y)
#     line = np.ones((pred_y.shape[0], 10, 3)) * 128
#     
#     cat_images = np.concatenate(
#         [image, line, ori_mask, line, pred_y * 255],
#         axis=1
#     )
# 
#     cv2.imwrite(f'./Data/Results-512x/{name}.png', cat_images)