In [1]:
import os
import json
import cv2
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate
from tensorflow.keras.models import Model

def load_data(data_folder, target_size=(128, 128)):
    images = []
    masks = []
    names = []

    for root, _, files in os.walk(data_folder):
        for file in files:
            if file.endswith('.json'):
                json_file_path = os.path.join(root, file)
                with open(json_file_path, 'r') as jsonfile:
                    json_data = json.load(jsonfile)
                    for img in json_data:
                        val = json_data[img]
                        image_path = os.path.join(root, val['filename'])
                        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
                        if image is not None:
                            # Resize the image
                            resized_image = cv2.resize(image, target_size)
                            # Create a mask
                            mask = np.zeros(target_size, dtype=np.uint8)
                            for region in val['regions']:
                                shape_attributes = region.get('shape_attributes', {})
                                region_attributes = region.get('region_attributes', {})
                                if region_attributes.get('Eye') in ['iris', 'pupil']:
                                    x_points = shape_attributes.get('all_points_x', [])
                                    y_points = shape_attributes.get('all_points_y', [])
                                    points = np.array(list(zip(x_points, y_points)), dtype=np.int32)
                                    if region_attributes.get('Eye') == 'iris':
                                        cv2.fillPoly(mask, [points], 1)  # Iris label
                                    elif region_attributes.get('Eye') == 'pupil':
                                        cv2.fillPoly(mask, [points], 2)  # Pupil label
                            # Append the image and mask
                            images.append(resized_image)
                            masks.append(mask)
                            names.append(val['filename'])
                        else:
                            print(f"Warning: Unable to read image {image_path}")

    # Convert lists to NumPy arrays
    images = np.array(images)
    masks = np.array(masks)
    names = np.array(names)

    return images, masks, names

def unet_model(input_size=(128, 128, 1)):
    inputs = Input(input_size)
    
    # Encoder
    conv1 = Conv2D(64, 3, activation='relu', padding='same')(inputs)
    conv1 = Conv2D(64, 3, activation='relu', padding='same')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    
    conv2 = Conv2D(128, 3, activation='relu', padding='same')(pool1)
    conv2 = Conv2D(128, 3, activation='relu', padding='same')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    
    conv3 = Conv2D(256, 3, activation='relu', padding='same')(pool2)
    conv3 = Conv2D(256, 3, activation='relu', padding='same')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    
    conv4 = Conv2D(512, 3, activation='relu', padding='same')(pool3)
    conv4 = Conv2D(512, 3, activation='relu', padding='same')(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)
    
    conv5 = Conv2D(1024, 3, activation='relu', padding='same')(pool4)
    conv5 = Conv2D(1024, 3, activation='relu', padding='same')(conv5)
    
    # Decoder
    up6 = concatenate([UpSampling2D(size=(2, 2))(conv5), conv4], axis=-1)
    conv6 = Conv2D(512, 3, activation='relu', padding='same')(up6)
    conv6 = Conv2D(512, 3, activation='relu', padding='same')(conv6)
    
    up7 = concatenate([UpSampling2D(size=(2, 2))(conv6), conv3], axis=-1)
    conv7 = Conv2D(256, 3, activation='relu', padding='same')(up7)
    conv7 = Conv2D(256, 3, activation='relu', padding='same')(conv7)
    
    up8 = concatenate([UpSampling2D(size=(2, 2))(conv7), conv2], axis=-1)
    conv8 = Conv2D(128, 3, activation='relu', padding='same')(up8)
    conv8 = Conv2D(128, 3, activation='relu', padding='same')(conv8)
    
    up9 = concatenate([UpSampling2D(size=(2, 2))(conv8), conv1], axis=-1)
    conv9 = Conv2D(64, 3, activation='relu', padding='same')(up9)
    conv9 = Conv2D(64, 3, activation='relu', padding='same')(conv9)
    
    conv10 = Conv2D(3, 1, activation='softmax')(conv9)
    
    model = Model(inputs=[inputs], outputs=[conv10])
    
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    
    return model

def display_images(image, mask, masked_image):
    plt.figure(figsize=(15, 5))
    plt.subplot(1, 3, 1)
    plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    plt.title('Original Image')
    plt.subplot(1, 3, 2)
    plt.imshow(mask, cmap='gray')
    plt.title('Iris Mask')
    plt.subplot(1, 3, 3)
    plt.imshow(masked_image, cmap='gray')
    plt.title('Masked Image')
    plt.show()


In [2]:
data_folder = "data"
images, masks, names = load_data(data_folder)

error: OpenCV(4.9.0) /Users/xperience/GHA-OpenCV-Python2/_work/opencv-python/opencv-python/opencv/modules/imgproc/src/drawing.cpp:2432: error: (-215:Assertion failed) p.checkVector(2, CV_32S) >= 0 in function 'fillPoly'


In [None]:


# Split the data into training and testing sets
images_train, images_test, masks_train, masks_test, names_train, names_test = train_test_split(
    images, masks, names, test_size=0.2, random_state=42
)

# Expand dimensions to add the channel dimension
images_train = np.expand_dims(images_train, axis=-1)
images_test = np.expand_dims(images_test, axis=-1)

# Build the model
model = unet_model()

# Train the model
history = model.fit(images_train, masks_train, epochs=20, batch_size=16, validation_data=(images_test, masks_test))

# Save the model
model.save('segmentation_model.h5')

# Evaluate the model
loss, accuracy = model.evaluate(images_test, masks_test)
print(f'Test loss: {loss}')
print(f'Test accuracy: {accuracy}')

# Predict on a few test images
predictions = model.predict(images_test[:5])

# Plot the results
for i in range(5):
    plt.figure(figsize=(12, 4))
    
    plt.subplot(1, 3, 1)
    plt.title('Input Image')
    plt.imshow(images_test[i].squeeze(), cmap='gray')
    
    plt.subplot(1, 3, 2)
    plt.title('True Mask')
    plt.imshow(masks_test[i], cmap='gray')
    
    plt.subplot(1, 3, 3)
    plt.title('Predicted Mask')
    plt.imshow(np.argmax(predictions[i], axis=-1), cmap='gray')
    
    plt.show()