In [7]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate
import matplotlib.pyplot as plt
import os
import keras.backend as K
from keras.losses import binary_crossentropy
from tensorflow.keras.preprocessing.image import load_img, img_to_array

# Running every tf.function eagerly (including the compiled train step inside
# model.fit) is a debugging aid with a significant performance cost; set this
# to True only when you need to step through or print tensors.
DEBUG_RUN_EAGERLY = False
tf.config.run_functions_eagerly(DEBUG_RUN_EAGERLY)

# An empty list means no GPU is visible and training will run on the CPU.
print(tf.config.list_physical_devices('GPU'))

def extract_channel(mask, channel_index):
    """Return one channel of a channel-last mask, keeping a trailing channel axis.

    Args:
        mask: Array/tensor indexed as ``mask[:, :, c]`` (presumably
            ``(height, width, channels)`` -- confirm against callers).
        channel_index: Zero-based index of the channel to extract.

    Returns:
        The selected channel with a size-1 last axis appended, e.g.
        ``(height, width, 1)`` for a ``(height, width, channels)`` input.
    """
    # Index `channel_index` directly: an extra `+ 1` would return the
    # neighbouring channel and run out of range for the last one.
    channel = mask[:, :, channel_index]
    return tf.expand_dims(channel, axis=-1)

def load_and_preprocess_image(image_path, mask_path, target_size=(768, 768)):
    """Load an image and its grayscale mask with Keras' loaders, scaled to [0, 1].

    Uses Keras' ``load_img``/``img_to_array`` (Python-side file I/O), so
    inside a ``tf.data`` pipeline it has to be wrapped in ``tf.py_function``.

    Args:
        image_path: Path to the image file as ``str``, ``bytes`` or an eager
            string tensor (what ``tf.py_function`` passes in).
        mask_path: Path to the mask file; same accepted types as ``image_path``.
        target_size: ``(height, width)`` both image and mask are resized to.

    Returns:
        ``(image, mask)`` arrays from ``img_to_array`` divided by 255. The
        image is loaded in load_img's default color mode and the mask with
        ``color_mode='grayscale'``. Mask values are scaled, not binarized.
    """
    def _as_str(path):
        # tf.py_function passes eager string tensors; direct callers may pass
        # str or bytes. Normalise all of them to str for load_img.
        if hasattr(path, 'numpy'):
            path = path.numpy()
        if isinstance(path, bytes):
            path = path.decode('utf-8')
        return path

    image = load_img(_as_str(image_path), target_size=target_size)
    image = img_to_array(image) / 255.0  # Normalize to [0, 1]

    mask = load_img(_as_str(mask_path), color_mode='grayscale', target_size=target_size)
    mask = img_to_array(mask) / 255.0  # Normalize to [0, 1]

    return image, mask

def _parse_function(image_path, mask_path):
   # print(mask_path)
    image_string = tf.io.read_file(image_path)
    image_decoded = tf.image.decode_jpeg(image_string, channels=3)
    #image_decoded = tf.image.resize(image_decoded, [128, 128])
    image = tf.cast(image_decoded, tf.float32)/255.0
    mask_string = tf.io.read_file(mask_path)
    mask_decoded = tf.image.decode_jpeg(mask_string, channels=3)
    #plt.imshow(mask_decoded)
    #plt.show()
    mask_decoded = tf.image.rgb_to_grayscale(mask_decoded)
    #print(mask_decoded.numpy().tolist())
    #mask_decoded = tf.image.resize(mask_decoded, [128, 128])
   # plt.imshow(mask_decoded)
    #plt.show()
   # mask_decoded = extract_channel(mask_decoded,0)
    mask = tf.cast(mask_decoded, tf.float32)//255
    #print()
    #print(mask.numpy().tolist())
    return image, mask


#Define the U-Net architecture
def unet_model(input_shape=(768, 768, 3)):
    """Build a two-level U-Net with a single sigmoid output channel.

    Args:
        input_shape: ``(height, width, channels)`` of the input images. Height
            and width must be divisible by 4 so that the two 2x2 poolings and
            upsamplings line up for the skip concatenations.

    Returns:
        An uncompiled ``tf.keras.Model`` whose output is a one-channel sigmoid
        map at the input's spatial resolution.
    """
    def conv_pair(tensor, filters):
        # Two same-padded 3x3 ReLU convolutions: the basic U-Net unit.
        for _ in range(2):
            tensor = Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)
        return tensor

    inputs = Input(shape=input_shape)

    # Contracting path; each level's output is kept for a skip connection.
    skip_full = conv_pair(inputs, 128)
    skip_half = conv_pair(MaxPooling2D((2, 2))(skip_full), 256)

    # Bottleneck at quarter resolution.
    bottleneck = conv_pair(MaxPooling2D((2, 2))(skip_half), 512)

    # Expanding path: upsample, then fuse with the matching encoder features.
    decoded = conv_pair(concatenate([skip_half, UpSampling2D((2, 2))(bottleneck)], axis=-1), 256)
    decoded = conv_pair(concatenate([skip_full, UpSampling2D((2, 2))(decoded)], axis=-1), 256)

    # 1x1 convolution to a single sigmoid channel (binary segmentation).
    outputs = Conv2D(1, (1, 1), activation='sigmoid')(decoded)
    return tf.keras.Model(inputs=inputs, outputs=outputs)

# def unet_model(input_shape=(768, 768, 3)):
#     inputs = Input(shape=input_shape)

#     # Encoder
#     conv1 = Conv2D(128, (3, 3), activation='relu', padding='same')(inputs)
#     conv1 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv1)
#     pool1 = MaxPooling2D((2, 2))(conv1)

#     conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(pool1)
#     conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv2)
#     pool2 = MaxPooling2D((2, 2))(conv2)

#     # Bottleneck
#     conv3 = Conv2D(32, (3, 3), activation='relu', padding='same')(pool2)
#     conv3 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv3)

#     # Decoder
#     up1 = UpSampling2D((2, 2))(conv3)
#     concat1 = concatenate([conv2, up1], axis=-1)
#     conv4 = Conv2D(16, (3, 3), activation='relu', padding='same')(concat1)
#     conv4 = Conv2D(16, (3, 3), activation='relu', padding='same')(conv4)

#     up2 = UpSampling2D((2, 2))(conv4)
#     concat2 = concatenate([conv1, up2], axis=-1)
#     conv5 = Conv2D(8, (3, 3), activation='relu', padding='same')(concat2)
#     conv5 = Conv2D(8, (3, 3), activation='relu', padding='same')(conv5)

#     outputs = Conv2D(1, (1, 1), activation='sigmoid')(conv5)  # Assuming binary segmentation

#     model = tf.keras.Model(inputs=inputs, outputs=outputs)
#     return model

# def unet_model(input_shape=(768, 768, 3)):
#     inputs = Input(shape=input_shape)

#     # Encoder
#     conv1 = Conv2D(16, (3, 3), activation='relu', padding='same')(inputs)
#     conv1 = Conv2D(16, (3, 3), activation='relu', padding='same')(conv1)
#     pool1 = MaxPooling2D((2, 2))(conv1)

#     # Bottleneck
#     conv2 = Conv2D(32, (3, 3), activation='relu', padding='same')(pool1)
#     conv2 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv2)

#     # Decoder
#     up1 = UpSampling2D((2, 2))(conv2)
#     concat1 = concatenate([conv1, up1], axis=-1)
#     conv3 = Conv2D(16,(3, 3), activation='relu', padding='same')(concat1)
#     conv3 = Conv2D(16, (3, 3), activation='relu', padding='same')(conv3)

#     outputs = Conv2D(1, (1, 1), activation='sigmoid')(conv3)

#     model = tf.keras.Model(inputs=inputs, outputs=outputs)
#     return model

# Binarization cutoff for predicted probabilities. Only the commented-out
# dice_coef variant below references it; the active dice_coef uses soft values.
threshold = 0.5

# def dice_coef(y_true, y_pred, smooth=1e-6):
#     #print(y_true.numpy)
#     #print(y_pred.numpy())
#     y_true_f = tf.cast(K.flatten(y_true),'float32')
#     y_pred_f = tf.cast(K.flatten(y_pred), 'float32')
#     #tf.keras.backend.clear_session()
#     y_pred_f = tf.where(y_pred_f < threshold, 0.0, y_pred_f )
#     #tf.keras.backend.clear_session()
#     y_pred_f = tf.where(y_pred_f >= threshold, 1.0, y_pred_f )
#     #print(y_pred_f.shape)
#     #print(y_true_f.shape)
#     #print(y_true.numpy().tolist())
#     #y_pred_f = tf.round(y_pred_f)
#     intersection = K.sum(y_true_f * y_pred_f)
#     #print(intersection)
#     #print(intersection.numpy())
#     #print(intersection.numpy())
#     dice = (2 * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
#     return dice

def dice_coef(y_true, y_pred, smooth=1e-6):
    """Soft Dice coefficient between ground-truth masks and predicted probabilities.

    All elements of the batch are pooled into a single vector, so this returns
    one scalar for the whole batch rather than one value per sample.

    Args:
        y_true: Ground-truth mask values; cast to ``y_pred``'s dtype.
        y_pred: Predicted probabilities with the same number of elements.
        smooth: Small constant added to numerator and denominator so that two
            empty masks give 1.0 instead of 0/0.

    Returns:
        Scalar ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)``.
    """
    # Plain TF ops instead of keras.backend: `keras.backend.flatten`/`sum` are
    # gone in Keras 3, and this keeps the loss on the same tf.keras stack as
    # the model.
    y_true_f = tf.reshape(tf.cast(y_true, y_pred.dtype), [-1])
    y_pred_f = tf.reshape(y_pred, [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    return (2 * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)

def dice_coef_loss(y_true, y_pred):
    """Keras loss defined as ``1 - dice_coef``; perfect overlap gives 0."""
    overlap = dice_coef(y_true, y_pred)
    return 1 - overlap

# Root directory against which the processed_data_train/{train,validation}/
# {images,masks} folders passed to getDs are resolved.
path = '.'
# train_dataset = tf.keras.utils.image_dataset_from_directory(
#     path+'/processed_data_train/train',
#     image_size=(768, 768),  # Set the desired image size
#     seed=123,  # Set a seed for reproducibility (optional)
#     labels="inferred",
#     color_mode='rgb',
#     label_mode=None,  # No labels as we are loading masks
#     image_folder='images',
#     mask_folder='masks',
#     interpolation="nearest"
# )

# validation_dataset = tf.keras.utils.image_dataset_from_directory(
#     path+'/processed_data_train/validation',
#     image_size=(768, 768),  # Set the desired image size
#     seed=123,  # Set a seed for reproducibility (optional)
#     labels="inferred",
#     color_mode='rgb',
#     label_mode=None,  # No labels as we are loading masks
#     image_folder='images',
#     mask_folder='masks',
#     interpolation="nearest"
# )

def getDs(path_img, path_masks, cnt, batch_size=2, shuffle=True):
    """Build a batched ``tf.data`` pipeline of (image, mask) pairs.

    Pair *i* (1-based) is ``path_img/image{i}.jpg`` with
    ``path_masks/mask{i}.jpg`` for ``i`` in ``1..cnt``.

    Args:
        path_img: Directory containing ``image1.jpg`` ... ``image{cnt}.jpg``.
        path_masks: Directory containing ``mask1.jpg`` ... ``mask{cnt}.jpg``.
        cnt: Number of pairs to load; must be >= 1.
        batch_size: Pairs per batch (default 2).
        shuffle: Reshuffle the pair order each epoch (default True).

    Returns:
        A prefetched ``tf.data.Dataset`` of ``(images, masks)`` batches
        produced by ``_parse_function``.

    Raises:
        ValueError: if ``cnt`` is less than 1.
    """
    if cnt < 1:
        raise ValueError(f'cnt must be >= 1, got {cnt}')

    image_paths = [f'{path_img}/image{i}.jpg' for i in range(1, cnt + 1)]
    mask_paths = [f'{path_masks}/mask{i}.jpg' for i in range(1, cnt + 1)]

    dataset = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))
    if shuffle:
        # Shuffle the cheap path strings before decoding. Shuffling after the
        # map would keep up to `cnt` decoded float images in the shuffle buffer.
        dataset = dataset.shuffle(buffer_size=len(image_paths))
    # _parse_function uses only TF ops, so it can be mapped directly. Mapping it
    # directly rather than through tf.py_function keeps static output shapes,
    # avoids the Python GIL, and lets tf.data decode files in parallel.
    dataset = dataset.map(_parse_function, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset.batch(batch_size).prefetch(buffer_size=tf.data.AUTOTUNE)

# Instantiate the model
model = unet_model()

# `learning_rate` is the supported keyword. The legacy `lr` alias is deprecated
# and, depending on the Keras version, is ignored with only a warning or
# rejected outright; either way the intended 1e-5 would not take effect.
optimizer = keras.optimizers.Adam(learning_rate=1e-5, clipvalue=1.0)

# Compile the model
model.compile(optimizer=optimizer, loss=dice_coef_loss)

# Display the model summary
model.summary()

AUTOTUNE = tf.data.AUTOTUNE
N_TRAIN_PAIRS = 100
N_VALIDATION_PAIRS = 25
train_dataset = getDs(path + '/processed_data_train/train/images',
                      path + '/processed_data_train/train/masks',
                      N_TRAIN_PAIRS)
validation_dataset = getDs(path + '/processed_data_train/validation/images',
                           path + '/processed_data_train/validation/masks',
                           N_VALIDATION_PAIRS)

# Number of training batches per epoch.
print(len(train_dataset))

[]




Model: "model_3"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_4 (InputLayer)        [(None, 768, 768, 3)]        0         []                            
                                                                                                  
 conv2d_33 (Conv2D)          (None, 768, 768, 128)        3584      ['input_4[0][0]']             
                                                                                                  
 conv2d_34 (Conv2D)          (None, 768, 768, 128)        147584    ['conv2d_33[0][0]']           
                                                                                                  
 max_pooling2d_6 (MaxPoolin  (None, 384, 384, 128)        0         ['conv2d_34[0][0]']           
 g2D)                                                                                       

In [None]:
BATCH = 2  # batch size the datasets above were built with (getDs default)
# Derive step counts from the datasets themselves. Hard-coding 1000//BATCH and
# 250//BATCH would assume 1000/250 images, but only 100 training and 25
# validation pairs are loaded above.
STEPS_PER_EPOCH = len(train_dataset)
VALIDATION_STEPS = len(validation_dataset)

history = model.fit(train_dataset, validation_data=validation_dataset, epochs=10)
# NOTE: the model uses the custom loss `dice_coef_loss`. Reloading it with
# compile=True therefore needs custom_objects={'dice_coef_loss': dice_coef_loss}.
model.save('./model.hdf5')