In [2]:
import albumentations as A
import cv2 as cv
import os
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
import wandb
from albumentations.pytorch import ToTensorV2
from functools import partial

from modules import *
from networks import *
from training import *
from ROI import CenterNet, detect_roi, preprocess_centernet_input

## CenterNet

In [5]:
# Collect the fundus images that will be segmented.
images_dir = '../data/ImagesForSegmentation'
# os.listdir() returns entries in arbitrary order, so sort the paths to make the
# processing order reproducible across runs and platforms.
# Only regular files are kept: sub-directories of `images_dir` are not input images.
test_images = sorted(
    path
    for path in (os.path.join(images_dir, name) for name in os.listdir(images_dir))
    if os.path.isfile(path)
)
test_images

['../data/ImagesForSegmentation\\1082715_T_1082715_20151212_111331_Color_L_01.jpg',
 '../data/ImagesForSegmentation\\1085032_T_1085032_20151219_113812_Color_R_01.jpg',
 '../data/ImagesForSegmentation\\1107477_T_1107477_20160225_122700_Color_L_01.jpg',
 '../data/ImagesForSegmentation\\1119879_T_1119879_20160328_111556_Color_R_01.jpg',
 '../data/ImagesForSegmentation\\1120696_T_1120696_20160330_104233_Color_L_01.jpg',
 '../data/ImagesForSegmentation\\1205301_T_1205301_20161027_134554_Color_R_01.jpg',
 '../data/ImagesForSegmentation\\184303_T_184303_20170217_112031_Color_R_01.jpg',
 '../data/ImagesForSegmentation\\28877_T_28877_20151224_095119_Color_R_01.jpg',
 '../data/ImagesForSegmentation\\535502_T_535502_20170124_155432_Color_L_01.jpg',
 '../data/ImagesForSegmentation\\833620_T_833620_20160128_102311_Color_R_01.jpg']

In [36]:
MODEL_PATH = '../models/roi/centernet.pth'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the CenterNet ROI detector and restore its trained weights.
model = CenterNet(n_classes=1, scale=4, base='resnet18', custom=True)
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
# map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
model.load_state_dict(state_dict)
print('Model loaded from', MODEL_PATH)
# The detector is only used for inference in this notebook, so put it in eval mode
# (layers such as dropout / batch norm behave differently while training).
# Assigning the result keeps the cell from echoing the module repr.
model = model.to(DEVICE).eval()


Model loaded from ../models/roi/centernet.pth


In [None]:
# Side length (pixels) the CenterNet input is resized to.
INPUT_SIZE = 512

# Pre-processing handed to detect_roi: resize, per-channel normalisation, tensor conversion.
transform = A.Compose([
    A.Resize(INPUT_SIZE, INPUT_SIZE, interpolation=cv.INTER_AREA),
    A.Normalize(mean=(0.9400, 0.6225, 0.3316), std=(0.1557, 0.1727, 0.1556)),
    ToTensorV2(),
], bbox_params=A.BboxParams(format='coco', label_fields=['labels']))

# Cropped ROI images are written next to the originals, in an `ROI` sub-folder.
dst_dir = os.path.join(images_dir, 'ROI')
os.makedirs(dst_dir, exist_ok=True)

# Paths of the saved ROI crops; consumed by the segmentation cells further down.
roi_test_images = []
for file in test_images:
    new_img = preprocess_centernet_input(file)
    roi_image, _ = detect_roi(
        model, new_img, None, transform, INPUT_SIZE,
        device=DEVICE, small_margin=32, roi_size=512,
    )
    img_path = os.path.join(dst_dir, os.path.basename(file))
    # cv.imwrite signals failure through its return value instead of raising; fail
    # loudly here so a path to a file that was never written is not recorded.
    # presumably roi_image is RGB (it is converted to BGR for OpenCV) — verify against detect_roi
    if not cv.imwrite(img_path, cv.cvtColor(roi_image, cv.COLOR_RGB2BGR)):
        raise IOError(f'Could not write ROI image to {img_path}')
    roi_test_images.append(img_path)

## Dual

In [7]:
# NOTE: this rebinds `images_dir`, which earlier pointed at the ImagesForSegmentation folder.
images_dir = '../data/ORIGA/ROI/TrainImages'
masks_dir = '../data/ORIGA/ROI/TrainMasks'
# The two directories are listed independently and os.listdir() gives no ordering
# guarantee, so sort both lists to keep image i and mask i aligned.
# assumes each mask shares its image's file name and that load_dataset pairs by position — TODO confirm
images = sorted(os.path.join(images_dir, f) for f in os.listdir(images_dir))
masks = sorted(os.path.join(masks_dir, f) for f in os.listdir(masks_dir))

IMAGE_SIZE = 256
# Evaluation-time pre-processing: resize, contrast-limited histogram equalisation, tensor conversion.
# NOTE(review): `always_apply=True` looks redundant with p=1.0 and may be deprecated in
# recent albumentations releases — confirm against the installed version.
val_transform = A.Compose([
    A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE, interpolation=cv.INTER_AREA),
    A.CLAHE(p=1.0, clip_limit=2.0, tile_grid_size=(8, 8), always_apply=True),
    ToTensorV2(),
])

loader = load_dataset(images, masks, val_transform, batch_size=4, shuffle=True, num_workers=4)

Loaded dataset with 325 samples in 82 batches.


In [10]:
# Restore the dual-architecture segmentation network on the active device.
# NOTE: `model` is rebound here; it previously held the CenterNet ROI detector.
checkpoint = load_checkpoint(
    '../models/polar/ref/dual.pth',
    map_location=DEVICE,
)
model = checkpoint['model']
model = model.to(DEVICE)

=> Loading checkpoint: ../models/polar/ref/dual.pth


In [45]:
# NOTE(review): `thresh` is not referenced anywhere in this cell — confirm whether
# predict() was meant to receive it.
thresh = 0.5

# False-positive / false-negative counts accumulated over the whole loader.
fp_oc = 0
fn_oc = 0
fp_od = 0
fn_od = 0

# Make predictions
model.eval()
with torch.no_grad():
    # Batch tensors get their own names so the module-level `images` / `masks`
    # path lists (used to build `loader`) are not overwritten by the loop.
    for batch_images, batch_masks in loader:
        batch_images = batch_images.to(DEVICE).float()
        batch_masks = batch_masks.to(DEVICE).long()

        preds, *_ = predict('dual', model, batch_images, batch_masks)
        # presumably [[1, 2], [2]] are the label groups for OD and OC — verify against get_metrics
        met = get_metrics(batch_masks, preds, [[1, 2], [2]])

        fp_oc += met['fp_OC']
        fn_oc += met['fn_OC']
        fp_od += met['fp_OD']
        fn_od += met['fn_OD']

print(f'FP OC: {fp_oc}, FN OC: {fn_oc}, FP OD: {fp_od}, FN OD: {fn_od}')

FP OC: 207015, FN OC: 185796, FP OD: 261919, FN OD: 127937


In [None]:
# Segment every ROI crop with the dual model and save a contour overlay titled with its vCDR.
preds = []  # raw predictions, stored before any post-processing
save_dir = '../data/ImagesForSegmentation/DualArchitectureSegmentation'
os.makedirs(save_dir, exist_ok=True)

model.eval()
with torch.no_grad():
    for image in roi_test_images:
        img = cv.imread(image)
        # cv.imread returns None instead of raising when the file is missing or unreadable.
        if img is None:
            raise FileNotFoundError(f'Could not read image: {image}')
        img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
        input_img = val_transform(image=img)['image']
        # Add the batch dimension the model expects.
        input_img = input_img.unsqueeze(0).to(DEVICE).float()

        pred, *_ = predict('dual', model, input_img, None)
        pred = pred.squeeze(0).squeeze(0).cpu().numpy()
        preds.append(pred)

        # Post-process the mask before measuring the cup-to-disc ratio.
        pred = fill_holes(pred)
        pred = keep_largest_component(pred)

        vcdr = calculate_vCDR(pred)

        contours = get_contour_image(img, pred)
        fig = plt.figure(figsize=(10, 10))
        plt.imshow(contours)
        plt.title(f'vCDR: {vcdr:.4f}')
        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, os.path.basename(image)))
        plt.show()
        # Release the figure explicitly so figures do not pile up in memory across iterations.
        plt.close(fig)

## Cascade

In [85]:
# Load the two networks of the cascade: a binary first-stage model and the cascade model itself.
# map_location=DEVICE matches how the dual checkpoint is loaded and lets checkpoints
# saved on GPU be restored on a CPU-only machine.
# NOTE(review): these paths are relative to './models' while the other checkpoints in
# this notebook live under '../models' — confirm the intended location.
checkpoint = load_checkpoint('./models/binary-raunet.pth', map_location=DEVICE)
binary_model = checkpoint['model'].to(DEVICE)

# NOTE: `model` is rebound here to the cascade network.
checkpoint = load_checkpoint('./models/cascade-raunet.pth', map_location=DEVICE)
model = checkpoint['model'].to(DEVICE)

=> Loading checkpoint: ./models/binary-raunet.pth
=> Loading checkpoint: ./models/cascade-raunet.pth


In [87]:
# Same selection rule as DEVICE; kept under this name because the cascade cells use `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): `thresh` is not referenced anywhere in this cell — confirm whether
# predict() was meant to receive it.
thresh = 0.5

# False-positive / false-negative counts accumulated over the whole loader.
fp_oc = 0
fn_oc = 0
fp_od = 0
fn_od = 0

# Make predictions
model.eval()
binary_model.eval()
with torch.no_grad():
    # Batch tensors get their own names so the module-level `images` / `masks`
    # path lists (used to build `loader`) are not overwritten by the loop.
    for batch_images, batch_masks in loader:
        batch_images = batch_images.to(device).float()
        batch_masks = batch_masks.to(device).long()

        preds, *_ = predict('cascade', model, batch_images, batch_masks, model0=binary_model)
        # presumably [[1, 2], [2]] are the label groups for OD and OC — verify against get_metrics
        met = get_metrics(batch_masks, preds, [[1, 2], [2]])

        fp_oc += met['fp_OC']
        fn_oc += met['fn_OC']
        fp_od += met['fp_OD']
        fn_od += met['fn_OD']

print(f'FP OC: {fp_oc}, FN OC: {fn_oc}, FP OD: {fp_od}, FN OD: {fn_od}')

FP OC: 154049, FN OC: 139173, FP OD: 159846, FN OD: 113973


In [None]:
# Segment every ROI crop with the cascade (binary model + cascade model) and save a
# contour overlay titled with its vCDR.
preds = []  # raw predictions, stored before any post-processing
save_dir = '../data/ImagesForSegmentation/CascadeArchitectureSegmentation'
os.makedirs(save_dir, exist_ok=True)

model.eval()
binary_model.eval()
with torch.no_grad():
    for image in roi_test_images:
        img = cv.imread(image)
        # cv.imread returns None instead of raising when the file is missing or unreadable.
        if img is None:
            raise FileNotFoundError(f'Could not read image: {image}')
        img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
        input_img = val_transform(image=img)['image']
        # Add the batch dimension the model expects.
        input_img = input_img.unsqueeze(0).to(device).float()

        pred, *_ = predict('cascade', model, input_img, None, model0=binary_model)
        pred = pred.squeeze(0).squeeze(0).cpu().numpy()
        preds.append(pred)

        # Post-process the mask before measuring the cup-to-disc ratio.
        # NOTE(review): the second positional argument is False here but left at its
        # default in the dual cell — confirm the difference is intentional.
        pred = fill_holes(pred, False)
        pred = keep_largest_component(pred, False)

        vcdr = calculate_vCDR(pred)

        contours = get_contour_image(img, pred, colors=[(0, 0, 0), (0, 0, 255)])
        fig = plt.figure(figsize=(10, 10))
        plt.imshow(contours)
        plt.title(f'vCDR: {vcdr:.4f}')
        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, os.path.basename(image)))
        plt.show()
        # Release the figure explicitly so figures do not pile up in memory across iterations.
        plt.close(fig)