In [1]:
import os
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from Steganalysis.config.configuration import ConfigurationManager

In [2]:
config_manager = ConfigurationManager()
model_training_config = config_manager.get_model_training_config()

[2024-11-09 17:26:36,897: INFO: common: yaml file: E:\Projects\Steganalysis\config\config.yaml loaded successfully]
[2024-11-09 17:26:36,910: INFO: common: yaml file: E:\Projects\Steganalysis\params.yaml loaded successfully]
[2024-11-09 17:26:36,910: INFO: common: created directory at: artifacts]
[2024-11-09 17:26:36,918: INFO: common: created directory at: artifacts/training]


In [4]:
model = load_model(model_training_config.model_path)
model.summary()

ValueError: File not found: filepath=artifacts\prepare_base_model\base_model_updated.keras. Please ensure the file is an accessible `.keras` zip file.

In [None]:
data_gen = ImageDataGenerator(rescale=1.0/255.0, validation_split=0.2)

In [None]:
data_dir = "artifacts/data_ingestion/unzip_dir" 

train_generator = data_gen.flow_from_directory(
    directory=data_dir,
    target_size=model_training_config.params_image_size[:2],
    batch_size=model_training_config.batch_size,
    class_mode="binary",
    subset="training"
)

val_generator = data_gen.flow_from_directory(
    directory=data_dir,
    target_size=model_training_config.params_image_size[:2],
    batch_size=model_training_config.batch_size,
    class_mode="binary",
    subset="validation"
)


callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir=os.path.join(model_training_config.callbacks_dir, "tensorboard_log")),
    tf.keras.callbacks.ModelCheckpoint(filepath=model_training_config.model_path, save_best_only=True)
]


history = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=model_training_config.epochs,
    callbacks=callbacks
)

In [None]:

# Step 7: Display Results (Optional)
import matplotlib.pyplot as plt

# Plot accuracy and loss curves
plt.figure(figsize=(14, 5))

# Plot training & validation accuracy values
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('Model accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper left')

# Plot training & validation loss values
plt.subplot(1, 2, 2)
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper left')

plt.show()
