In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os

import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.callbacks import EarlyStopping # type: ignore

from data_pipeline.preprocessing.data_processing import create_tf_dataset, load_and_preprocess_data
from deep_lab.model import DeepLabV3Plus
from training.learning_rate import PolyDecay
from training.trainer import Trainer


In [None]:
#####################
##  PREPROCESSING  ##
#####################

In [None]:

# Preprocess data paths.
# Set the VOC_ROOT environment variable to the local VOCdevkit/VOC2012 folder;
# when it is unset, the original author's absolute path is used unchanged.
VOC_ROOT = os.environ.get(
    'VOC_ROOT',
    r'D:\01_Arnaud\Etudes\04_CNAM\RCP209\Projet\DeepLab\data\VOCdevkit\VOC2012',
)
image_dir = os.path.join(VOC_ROOT, 'JPEGImages')
mask_dir = os.path.join(VOC_ROOT, 'SegmentationClass')

# Split into train / validation image and mask collections.
train_images, val_images, train_masks, val_masks = load_and_preprocess_data(image_dir, mask_dir)

# Create TensorFlow datasets
train_dataset = create_tf_dataset(train_images, train_masks)
val_dataset = create_tf_dataset(val_images, val_masks)

In [None]:
############################
## FINE-TUNING & TRAINING ##
############################

In [None]:
from tensorflow.keras.metrics import MeanIoU

# Disabled alternative kept for reference: wrap the subclassed model in a
# functional Model, freeze everything except BatchNormalization, then unfreeze
# the ASPP and decoder heads. NOTE: it references `optimizer`, `NUM_CLASSES`
# and `epochs`, which are not defined anywhere in this notebook.
# input_shape = (224, 224, 3)
# input = tf.keras.Input(shape=input_shape)
# model = Model(inputs=input, outputs=model.call(input))
# for layer in model.layers:
#     if not isinstance(layer, tf.keras.layers.BatchNormalization):
#         layer.trainable = False

# for layer in model.aspp.layers + model.decoder.layers:
#     layer.trainable = True

# model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy', MeanIoU(num_classes=NUM_CLASSES)])

# history = model.fit(train_dataset, validation_data=val_dataset, epochs=epochs)

# On a fresh kernel `model` is first assigned in a later cell; build it here so
# this cell does not raise NameError under Restart & Run All.
if 'model' not in globals():
    model = DeepLabV3Plus()

# Make every top-level layer trainable...
for layer in model.layers:
    layer.trainable = True

# ...then freeze each layer of the ResNet backbone. A dedicated loop variable is
# used: the previous loop bound `layers` but assigned to the stale `layer`, so
# the backbone was never actually frozen.
for backbone_layer in model.backbone.resnet_model.layers:
    backbone_layer.trainable = False


In [None]:
# Build a fresh model once (it was previously instantiated twice in this cell).
# NOTE(review): this replaces any `model` from an earlier cell, so layer
# freezing applied there does not carry over to this training run — confirm
# whether the backbone is meant to be frozen here.
model = DeepLabV3Plus()

# Training configuration consumed by `Trainer`.
config = {
    'learning_rate': 0.001,
    'epochs': 5,
    'checkpoint_path': 'results/checkpoints/model.keras',
    'model_save_path': 'results/models/model.h5',
    'num_classes': 21
}

# Train, report validation metrics, then persist via the Trainer.
trainer = Trainer(model=model, train_dataset=train_dataset, val_dataset=val_dataset, config=config)
history = trainer.train()
trainer.evaluate()
trainer.save_model()

In [None]:
# Optional: save the model to an additional path.
# The target directory is created first so the save does not fail on a fresh checkout.
# NOTE(review): if DeepLabV3Plus is a subclassed Keras model (suggested by the
# `model.call(...)` usage in the commented-out fine-tuning code), some Keras
# versions refuse to save it in HDF5 (.h5) format — switch to a `.keras` path
# if this raises.
MODEL_EXPORT_PATH = 'results/my_model.h5'
os.makedirs(os.path.dirname(MODEL_EXPORT_PATH), exist_ok=True)
model.save(MODEL_EXPORT_PATH)

In [None]:
##################
##  EVALUATION  ##
##################

In [None]:

# Validation loss and compiled metrics, labelled by metric name instead of an
# unlabelled list of scalars.
model.evaluate(val_dataset, return_dict=True)

In [None]:
########################
## INFERENCE EXAMPLE  ##
########################

In [None]:
# Display result for one image: input, ground-truth mask and predicted mask.
# Dedicated batch names avoid clobbering the `val_images` / `val_masks`
# produced by `load_and_preprocess_data` in the preprocessing cell.
batch_images, batch_masks = next(iter(val_dataset))
image = batch_images[0]
true_mask = tf.squeeze(batch_masks[0])

# `predict` expects a batch, so add a leading axis; argmax over the last axis
# picks the highest-scoring class per pixel, and squeeze drops the batch axis.
prediction = model.predict(tf.expand_dims(image, axis=0), verbose=0)
predicted_mask = tf.squeeze(tf.argmax(prediction, axis=-1))

# Pin both masks to the same class-index range so an equal class id gets the
# same shade in both panels (imshow otherwise rescales each to its own min/max).
# The class count is the size of the prediction's last (class-score) axis.
# NOTE(review): assumes true-mask labels lie in [0, n_classes - 1]; VOC "void"
# pixels (255) would be clipped to the top shade if the pipeline keeps them.
n_classes = prediction.shape[-1]
mask_style = dict(cmap='gray', vmin=0, vmax=n_classes - 1, interpolation='nearest')

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
panels = [
    ("Input Image", image, {}),
    ("True Mask", true_mask, mask_style),
    ("Predicted Mask", predicted_mask, mask_style),
]
for ax, (title, data, style) in zip(axes, panels):
    ax.set_title(title)
    ax.imshow(data, **style)

plt.show()

In [None]:
## Evaluate on the official test set