In [1]:
import os
import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers, callbacks
from sklearn.model_selection import train_test_split




In [2]:
# Function to define the SegNet model
def segnet(input_shape=(256, 256, 1), num_classes=1):
    """Build a SegNet-style encoder/decoder as a Keras ``Sequential`` model.

    Unlike the original SegNet paper, the decoder upsamples with plain
    ``UpSampling2D`` rather than max-unpooling with stored pooling indices.

    Architecture:
      * encoder: conv stages with 64, 128, 256 filters, each followed by 2x2
        max-pooling, then a 512-filter conv with no pooling;
      * decoder: conv stages with 512, 256, 128 filters, each followed by 2x
        upsampling;
      * head: a 3x3 conv to ``num_classes`` channels plus a sigmoid.
    Every conv stage is Conv2D(3x3, 'same') -> BatchNormalization -> ReLU.

    Args:
        input_shape: (H, W, C) of the input images. Three 2x poolings are
            undone by three 2x upsamplings, so the output spatial size matches
            the input only when H and W are divisible by 8.
        num_classes: number of output channels. Each channel gets an
            independent sigmoid.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.
    """
    net = models.Sequential()

    def _conv_bn_relu(filters, **conv_kwargs):
        # One Conv -> BN -> ReLU stage appended to `net`.
        net.add(layers.Conv2D(filters, (3, 3), padding='same', **conv_kwargs))
        net.add(layers.BatchNormalization())
        net.add(layers.Activation('relu'))

    # Encoder: only the very first layer carries the input shape.
    for stage, filters in enumerate((64, 128, 256)):
        first_layer_kwargs = {'input_shape': input_shape} if stage == 0 else {}
        _conv_bn_relu(filters, **first_layer_kwargs)
        net.add(layers.MaxPooling2D(pool_size=(2, 2)))

    # Deepest encoder stage (no pooling afterwards).
    _conv_bn_relu(512)

    # Decoder: mirror the encoder, upsampling after each stage.
    for filters in (512, 256, 128):
        _conv_bn_relu(filters)
        net.add(layers.UpSampling2D(size=(2, 2)))

    # Per-pixel prediction head.
    net.add(layers.Conv2D(num_classes, (3, 3), padding='same'))
    net.add(layers.Activation('sigmoid'))

    return net

In [3]:
# Data loading and preprocessing.
# Expected layout: <train_dir>/<id>/images/<id>.jpeg plus one or more mask
# files under <train_dir>/<id>/masks/.
train_dir = './train/'
X_train1 = []
y_train1 = []


def _read_grayscale(path):
    """Read ``path`` as a single-channel uint8 image.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file. ``cv2.imread``
            returns None instead of raising, which would otherwise surface
            later as an opaque ``cv2.resize`` assertion.
    """
    image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        raise FileNotFoundError(f'Could not read image: {path}')
    return image


def _load_binary_mask(mask_dir, size=(256, 256)):
    """Union every mask file in ``mask_dir``, resized to ``size``, as float32 {0, 1}.

    Binarizing matters because the model is trained with binary
    cross-entropy, which expects targets in [0, 1], not raw 0-255 pixels.

    Raises:
        FileNotFoundError: if ``mask_dir`` holds no mask files or one is unreadable.
    """
    combined = None
    for mask_name in sorted(os.listdir(mask_dir)):
        mask = _read_grayscale(os.path.join(mask_dir, mask_name))
        # Nearest-neighbour keeps label values intact (no blended edge pixels).
        mask = cv2.resize(mask, size, interpolation=cv2.INTER_NEAREST)
        # Several mask files per sample are merged rather than keeping an
        # arbitrary one.
        combined = mask if combined is None else np.maximum(combined, mask)
    if combined is None:
        raise FileNotFoundError(f'No mask files found in: {mask_dir}')
    # Masks may be stored as 0/255 images (possibly with JPEG noise) or as 0/1
    # label maps; choose a threshold that handles both encodings.
    threshold = 127 if combined.max() > 1 else 0
    return (combined > threshold).astype(np.float32)


# Sorted so sample order, and therefore the seeded train/val split, is reproducible.
for folder in sorted(os.listdir(train_dir)):
    sample_dir = os.path.join(train_dir, folder)
    if not os.path.isdir(sample_dir):
        # Skip stray files (e.g. .DS_Store) that are not sample folders.
        continue

    img = _read_grayscale(os.path.join(sample_dir, 'images', folder + '.jpeg'))
    X_train1.append(cv2.resize(img, (256, 256)))

    y_train1.append(_load_binary_mask(os.path.join(sample_dir, 'masks')))

# Add a trailing channel axis: (N, 256, 256) -> (N, 256, 256, 1).
X_train = np.array(X_train1).reshape(-1, 256, 256, 1)
y_train = np.array(y_train1).reshape(-1, 256, 256, 1)

In [None]:
# Split the data into training and validation sets.
# NOTE: this overwrites X_train/y_train with the training split, so re-running
# this cell without first re-running the data-loading cell splits the data again
# and silently shrinks the training set.
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=42)

# Model compilation. Seed TF so weight initialisation is reproducible across runs.
tf.random.set_seed(42)
model = segnet()
# `learning_rate` replaces the deprecated `lr` keyword, which newer
# tf.keras optimizers reject.
model.compile(optimizer=optimizers.Adam(learning_rate=0.001), loss='binary_crossentropy', metrics=['accuracy'])

# Define callbacks for early stopping and model checkpoint
early_stopping = callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
model_checkpoint = callbacks.ModelCheckpoint('best_model.h5', save_best_only=True)

# Model training with callbacks
history = model.fit(X_train, y_train, epochs=50, batch_size=8, validation_data=(X_val, y_val), callbacks=[early_stopping, model_checkpoint])

# Evaluate the model on the validation set
val_loss, val_accuracy = model.evaluate(X_val, y_val)
print(f'Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy}')

# Data loading for testing. Sample folders are listed exactly once, sorted, so
# output file names are paired with the right predictions below.
test_dir = './test/'
test_folders = sorted(
    name for name in os.listdir(test_dir)
    if os.path.isdir(os.path.join(test_dir, name))
)

test_images = []
for folder in test_folders:
    img_path = os.path.join(test_dir, folder, 'images', folder + '.jpeg')
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread returns None rather than raising on a missing/unreadable file.
        raise FileNotFoundError(f'Could not read image: {img_path}')
    test_images.append(cv2.resize(img, (256, 256)))

X_test = np.array(test_images).reshape(-1, 256, 256, 1)

# Model prediction on test data
predictions = model.predict(X_test)

# Save the predicted masks. cv2.imwrite does not create directories and
# signals failure only through its return value, so handle both explicitly.
output_dir = './predictions/'
os.makedirs(output_dir, exist_ok=True)
for folder, prediction in zip(test_folders, predictions):
    # Scale sigmoid probabilities in [0, 1] to 8-bit grayscale for the image file.
    pred_mask = np.round(prediction.reshape(256, 256) * 255).astype(np.uint8)
    out_path = os.path.join(output_dir, f'{folder}_pred_mask.jpeg')
    if not cv2.imwrite(out_path, pred_mask):
        raise IOError(f'Failed to write prediction: {out_path}')







Epoch 1/50












  1/110 [..............................] - ETA: 34:36 - loss: 0.4854 - accuracy: 0.4986