In [None]:
import cv2
import numpy as np
import os

# Load images and masks
def load_images_from_folder(folder):
    """Read every decodable image in ``folder`` with ``cv2.imread``'s default flags.

    Files are visited in sorted filename order. ``os.listdir`` returns entries
    in arbitrary order. The images returned here are paired index-by-index
    with the masks from the mask loader. Sorting makes that pairing
    deterministic, provided image and mask filenames sort the same way.

    Files that ``cv2.imread`` cannot decode (it returns ``None``) are skipped
    silently, so the result may be shorter than the directory listing.

    Args:
        folder: Path to a directory of image files.

    Returns:
        List of image arrays in sorted filename order.
    """
    images = []
    for filename in sorted(os.listdir(folder)):
        img = cv2.imread(os.path.join(folder, filename))
        if img is not None:
            images.append(img)
    return images

def load_masks_from_folder(folder):
    """Read every decodable file in ``folder`` as a single-channel grayscale mask.

    Files are visited in sorted filename order. ``os.listdir`` order is
    arbitrary, and these masks are paired index-by-index with the images
    loaded from the image folder. Sorting keeps that pairing deterministic,
    provided image and mask filenames sort the same way.

    Files that ``cv2.imread`` cannot decode (it returns ``None``) are skipped
    silently, so the result may be shorter than the directory listing.

    Args:
        folder: Path to a directory of mask files.

    Returns:
        List of grayscale mask arrays in sorted filename order.
    """
    masks = []
    for filename in sorted(os.listdir(folder)):
        mask = cv2.imread(os.path.join(folder, filename), cv2.IMREAD_GRAYSCALE)
        if mask is not None:
            masks.append(mask)
    return masks

# Paths to dataset folders.
# NOTE(review): both paths point at the DRIVE `test` split, yet the arrays
# built from them become X_train / y_train below. If a separate training
# split is available, train on it and keep `test` for evaluation only;
# otherwise these test images can no longer give an unbiased evaluation.
image_dir = './data/DRIVE/test/images/'
mask_dir = './data/DRIVE/test/1st_manual/'

# Images and masks are paired by list index later on, so both loaders must
# return files in matching order. Files cv2 cannot decode are skipped
# silently; compare len(images) and len(masks) if results look wrong.
images = load_images_from_folder(image_dir)
masks = load_masks_from_folder(mask_dir)

# Example function to preprocess images and masks
def preprocess_data(images, masks, size=(256, 256)):
    """Resize and scale image/mask pairs into arrays ready for ``model.fit``.

    Args:
        images: Sequence of image arrays (as returned by ``cv2.imread``),
            paired index-by-index with ``masks``.
        masks: Sequence of single-channel mask arrays. Values above 127 are
            treated as foreground (0/255 coding presumed; confirm for new data).
        size: Target ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        ``(X, y)``:
        - ``X`` is float32 scaled to [0, 1], shaped ``(N, height, width, C)``
          for colour inputs.
        - ``y`` is float32 of shape ``(N, height, width, 1)`` containing only
          0.0 and 1.0. The trailing axis matches a single-channel sigmoid output.

    Raises:
        ValueError: if the inputs are empty or their lengths differ.
    """
    if len(images) != len(masks):
        raise ValueError(
            f"got {len(images)} images but {len(masks)} masks; "
            "they must pair one-to-one")
    # len() check rather than truthiness so NumPy arrays are accepted too.
    if len(images) == 0:
        raise ValueError("no images to preprocess; check that the input folders were read")

    resized_images = [cv2.resize(img, size) for img in images]
    # Nearest-neighbour keeps mask labels crisp; the default bilinear
    # interpolation blends vessel edges into intermediate grey values.
    resized_masks = [cv2.resize(mask, size, interpolation=cv2.INTER_NEAREST) for mask in masks]

    normalized_images = np.asarray(resized_images, dtype=np.float32) / 255.0
    # Threshold so the targets are truly binary, then add the channel axis.
    binarized_masks = (np.asarray(resized_masks) > 127).astype(np.float32)[..., np.newaxis]
    return normalized_images, binarized_masks

# Build the training arrays. X_train[i] is trained against y_train[i],
# so images and masks must already be aligned by index.
X_train, y_train = preprocess_data(images, masks)


In [None]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, concatenate, UpSampling2D

# Define U-Net model
def unet_model(input_size=(256, 256, 3)):
    """Build a three-level U-Net that emits one sigmoid probability map.

    Architecture:
    - Encoder: three stages of two same-padded 3x3 ReLU convolutions (64, 128
      and 256 filters), each followed by 2x2 max-pooling.
    - Bottleneck: a 512-filter double convolution.
    - Decoder: three stages. Each upsamples by 2, concatenates the matching
      encoder features on the channel axis, and applies another double
      convolution (256, 128, 64 filters).
    - Head: a 1x1 sigmoid convolution producing a single output channel.

    Args:
        input_size: ``(height, width, channels)`` of the input tensor. Height
            and width should be divisible by 8. Otherwise the upsampled and
            skip tensors differ in size after three poolings, and the
            concatenation fails.

    Returns:
        An uncompiled Keras ``Model``.
    """
    def double_conv(tensor, filters):
        # Two stacked 3x3 ReLU convolutions; 'same' padding keeps H and W.
        for _ in range(2):
            tensor = Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)
        return tensor

    inputs = Input(input_size)

    skip_connections = []
    features = inputs
    for filters in (64, 128, 256):
        features = double_conv(features, filters)
        skip_connections.append(features)
        features = MaxPooling2D(pool_size=(2, 2))(features)

    features = double_conv(features, 512)

    # Walk back up, pairing each decoder stage with the deepest unused skip.
    for filters, skip in zip((256, 128, 64), reversed(skip_connections)):
        features = concatenate([UpSampling2D(size=(2, 2))(features), skip], axis=3)
        features = double_conv(features, filters)

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(features)
    return Model(inputs=[inputs], outputs=[outputs])

# Compile the model.
# Binary cross-entropy pairs with the single-channel sigmoid output built by unet_model.
# NOTE(review): if vessel pixels are a small fraction of each mask, pixel
# accuracy is dominated by background; consider also tracking Dice/IoU.
model = unet_model()
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])


In [None]:
# Train the U-Net model.
# validation_split=0.1 holds out part of X_train/y_train. Keras takes the
# validation samples from the end of the arrays, before shuffling, so the
# held-out set depends on the order in which the files were loaded.
# NOTE(review): no random seed is set, so weights and metrics vary between runs.
history = model.fit(X_train, y_train, batch_size=16, epochs=50, validation_split=0.1)

# Save the trained model (the .h5 extension selects the legacy HDF5 format).
model.save('unet_retinal_segmentation_model.h5')


In [None]:
def enhance_image(image, clip_limit=3.0, tile_grid_size=(8, 8)):
    """Boost local contrast with CLAHE applied to the lightness channel only.

    The BGR image is converted to L*a*b*. CLAHE (contrast-limited adaptive
    histogram equalization) is applied to L, and the result is converted back,
    so the a/b colour channels pass through unchanged.

    Args:
        image: 3-channel BGR image as returned by ``cv2.imread`` (8-bit presumed).
        clip_limit: CLAHE contrast clip limit; the default keeps the original 3.0.
        tile_grid_size: CLAHE tile grid size; the default keeps the original (8, 8).

    Returns:
        The contrast-enhanced BGR image.

    Raises:
        ValueError: if ``image`` is None (e.g. ``cv2.imread`` could not read the file).
    """
    # cv2.imread signals failure with None; catch it here instead of letting
    # cvtColor raise a less specific error.
    if image is None:
        raise ValueError("enhance_image received None; check that cv2.imread found the file")

    lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
    l_channel, a_channel, b_channel = cv2.split(lab)

    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    l_equalized = clahe.apply(l_channel)

    merged = cv2.merge((l_equalized, a_channel, b_channel))
    return cv2.cvtColor(merged, cv2.COLOR_LAB2BGR)

# Predict using the trained U-Net model
def predict_vessels(image, model, input_size=(256, 256), threshold=0.5, enhance=True):
    """Segment vessels in one BGR image with a trained model and return a binary mask.

    Args:
        image: Image array as returned by ``cv2.imread`` (BGR, uint8 presumed).
        model: Object with a Keras-style ``predict`` that maps a ``(1, H, W, C)``
            batch to a ``(1, H, W, 1)`` probability map, e.g. the U-Net built
            by ``unet_model``.
        input_size: ``(width, height)`` passed to ``cv2.resize``. It should
            match the size the model was trained on.
        threshold: Probability above which a pixel is labelled as vessel.
        enhance: If True (the original behaviour), apply ``enhance_image``
            (CLAHE) before resizing.
            NOTE(review): ``preprocess_data`` does not apply this enhancement
            to the training images, so the default creates a train/inference
            preprocessing mismatch. Consider ``enhance=False``, or enhance the
            training set too.

    Returns:
        uint8 mask of 0/255 values at ``input_size`` resolution (not the
        original image size). Its shape is ``(height, width, 1)`` for a model
        with a single output channel.

    Raises:
        ValueError: if ``image`` is None (e.g. ``cv2.imread`` could not read the file).
    """
    if image is None:
        raise ValueError("predict_vessels received None; check that cv2.imread found the file")

    if enhance:
        image = enhance_image(image)
    resized = cv2.resize(image, input_size)
    # Same [0, 1] scaling as the training images; add a batch axis of size 1.
    batch = np.expand_dims(resized.astype(np.float32) / 255.0, axis=0)

    probabilities = np.squeeze(model.predict(batch), axis=0)
    return (probabilities > threshold).astype(np.uint8) * 255

# Example prediction on one image from the DRIVE test folder.
test_image_path = './data/DRIVE/test/images/01_test.png'
test_image = cv2.imread(test_image_path)
if test_image is None:
    # cv2.imread reports a missing or undecodable file by returning None,
    # which would otherwise fail later inside predict_vessels with a less
    # specific error.
    raise FileNotFoundError(f"Could not read test image: {test_image_path}")
predicted_mask = predict_vessels(test_image, model)

# Display result in a native OpenCV window.
# NOTE(review): cv2.imshow needs a GUI display, and waitKey(0) blocks until a
# key is pressed; on a headless Jupyter server this cell may fail or hang.
cv2.imshow('Predicted Mask', predicted_mask)
cv2.waitKey(0)
cv2.destroyAllWindows()
