In [1]:
# Enable IPython's autoreload extension. Mode 2 reloads imported modules
# before each cell executes, so edits to the project packages are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os

import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.callbacks import EarlyStopping # type: ignore

from data_pipeline.preprocessing.data_processing import create_tf_dataset, load_and_preprocess_data
from deep_lab.learning_rate import PolyDecay
from deep_lab.metrics import MeanIoU
from deep_lab.model import DeepLabV3Plus


In [3]:
#####################
##  PREPROCESSING  ##
#####################

In [4]:

# Root of the Pascal VOC 2012 dataset. Set the VOC2012_ROOT environment
# variable to run this notebook on another machine; the default keeps the
# original local location.
VOC2012_ROOT = os.environ.get(
    'VOC2012_ROOT',
    r'D:\01_Arnaud\Etudes\04_CNAM\RCP209\Projet\DeepLab\data\VOCdevkit\VOC2012',
)

# Image and mask directories, both derived from the single dataset root.
image_dir = os.path.join(VOC2012_ROOT, 'JPEGImages')
mask_dir = os.path.join(VOC2012_ROOT, 'SegmentationClass')

# Fail early with a clear message instead of deep inside the data pipeline.
for required_dir in (image_dir, mask_dir):
    if not os.path.isdir(required_dir):
        raise FileNotFoundError(
            f"Dataset directory not found: {required_dir!r}. "
            "Set the VOC2012_ROOT environment variable to the VOC2012 folder."
        )

# Load the data and split it into training and validation images/masks.
train_images, val_images, train_masks, val_masks = load_and_preprocess_data(image_dir, mask_dir)

# Create TensorFlow datasets
train_dataset = create_tf_dataset(train_images, train_masks)
val_dataset = create_tf_dataset(val_images, val_masks)

In [None]:
############################
## FINE-TUNING & TRAINING ##
############################

In [16]:
# Optimizer: Adam whose learning rate follows the project's PolyDecay
# schedule, configured from the initial rate and the epoch budget below.
initial_lr = 0.001
epochs = 50
poly_decay = PolyDecay(
    initial_learning_rate=initial_lr,
    max_epochs=epochs,
)
optimizer = tf.keras.optimizers.Adam(
    learning_rate=poly_decay,
    weight_decay=0.0005,
)

# Network to train.
model = DeepLabV3Plus(dropout_rate=0.3)

# Full fine-tuning: every layer is marked trainable, nothing is frozen.
for current_layer in model.layers:
    current_layer.trainable = True

# The loss is configured with from_logits=True; accuracy and the project's
# MeanIoU are tracked as metrics.
model.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy', MeanIoU()],
)


KeyboardInterrupt: 

In [None]:

# Early stopping: stop when val_loss has not improved for 3 consecutive
# epochs and restore the weights from the best epoch.
early_stopping = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)

# Train the model. The callback must be passed to fit() to have any effect.
# NOTE(review): the PolyDecay schedule above is configured with
# max_epochs=50 while this run uses epochs=5 — confirm this is intended.
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=5,
    callbacks=[early_stopping],
)

Epoch 1/5




[1m  1/292[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m5:54:56[0m 73s/step - accuracy: 0.0257 - loss: 3.8100 - mean_iou: 0.0073

KeyboardInterrupt: 

In [None]:
# Optional step: write the trained model to disk.
saved_model_path = 'results/my_model.h5'
model.save(saved_model_path)

In [None]:
IMG_SIZE = 224
NUM_CLASSES = 21

# Symbolic input tensor. Named `model_input` so the Python builtin `input`
# is not shadowed for the rest of the kernel session.
model_input = tf.keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3))

deeplab_model_training = DeepLabV3Plus()
# Call the model once on the symbolic input; the result is not used afterwards.
deeplab_model_training(model_input, training=True)

# Use the same loss configuration as the other compile() call in this
# notebook (from_logits=True). The 'sparse_categorical_crossentropy' string
# alias would instead treat the model outputs as probabilities.
deeplab_model_training.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)

# Stop when val_loss has not improved for 3 epochs and restore the best weights.
early_stopping = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
deeplab_model_training.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=5,
    callbacks=[early_stopping],
)

In [None]:
##################
##  EVALUATION  ##
##################

In [None]:

# Evaluate the compiled model on the validation dataset and display the result.
validation_results = model.evaluate(val_dataset)
validation_results

In [None]:
########################
## INFERENCE EXAMPLE  ##
########################

In [None]:
# Display the result for one validation image.
# The batch gets its own names so that the `val_images` / `val_masks`
# returned by load_and_preprocess_data are not overwritten (re-running the
# dataset-creation cell would otherwise receive batch tensors).
batch_images, batch_masks = next(iter(val_dataset))
image = batch_images[0]
true_mask = tf.squeeze(batch_masks[0])

# predict() expects a batch, so add a leading batch axis to the single image.
prediction = model.predict(tf.expand_dims(image, axis=0))
# The class with the highest score along the last axis is the predicted label.
predicted_mask = tf.squeeze(tf.argmax(prediction, axis=-1))

# One row, three panels: input image, ground-truth mask, predicted mask.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
panels = [
    ("Input Image", image, None),
    ("True Mask", true_mask, 'gray'),
    ("Predicted Mask", predicted_mask, 'gray'),
]
for ax, (title, data, cmap) in zip(axes, panels):
    ax.set_title(title)
    ax.imshow(data, cmap=cmap)

plt.show()

In [None]:
## Evaluate on the official test set