In [None]:
# =======================
# IMPORTS
# =======================
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, BatchNormalization, GlobalAveragePooling2D, LeakyReLU, Conv2D, MaxPooling2D, Flatten
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNetV2, EfficientNetB0, ResNet50, VGG16
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras import backend as K
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, precision_recall_curve, auc
import numpy as np
import os
import random

In [None]:
# =======================
# CONFIGURATIONS
# =======================
IMG_SIZE = 224            # images are resized to IMG_SIZE x IMG_SIZE before being fed to the model
BATCH_SIZE = 16
EPOCHS = 15
TFLITE_PATH = "model.tflite"
SEED = 42
LOSS_FUNCTION = 'focal'  # 'binary_crossentropy' or 'focal'
ARCHITECTURE = 'CNN'     # 'MobileNetV2', 'EfficientNetB0', 'ResNet50', 'VGG16', 'CNN'
ACTIVATION = 'leaky_relu' # 'relu', 'leaky_relu', 'elu'
OPTIMIZER = 'rmsprop'    # 'adam', 'sgd', 'rmsprop'

# Set seeds for reproducibility (TensorFlow, NumPy and Python's `random` each
# keep their own generator state, so all three are seeded).
tf.random.set_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Paths: the dataset root defaults to the original local location but can be
# overridden without editing the notebook by setting the DATASET_DIR
# environment variable. The root is expected to contain `train` and `test`
# sub-directories.
DATASET_DIR = os.environ.get("DATASET_DIR", r"C:\Users\Adars\Desktop\dataset")
train_path = os.path.join(DATASET_DIR, "train")
test_path = os.path.join(DATASET_DIR, "test")

In [None]:
# =======================
# FUNCTIONS
# =======================
def focal_loss(gamma=2.0, alpha=0.25):
    """Build a binary focal-loss function for use with `model.compile`.

    For each element the loss is `alpha_t * (1 - p_t) ** gamma * -log(p_t)`,
    where `p_t` is the predicted probability of the true class and `alpha_t`
    is `alpha` for positive labels and `1 - alpha` for negative labels.

    Args:
        gamma: Focusing exponent applied to `(1 - p_t)`.
        alpha: Weight applied to positive (label == 1) elements; negatives
            receive `1 - alpha`.

    Returns:
        A function `focal_loss_fn(y_true, y_pred)` that returns a scalar: the
        per-element loss summed over the last axis and averaged over the rest.
    """
    def focal_loss_fn(y_true, y_pred):
        y_pred = tf.convert_to_tensor(y_pred)
        # Labels may arrive as ints/bools or with a different rank than the
        # predictions (e.g. (batch,) vs (batch, 1)). Cast and reshape so the
        # arithmetic with the float `alpha` is valid and `tf.where` cannot
        # silently broadcast to a (batch, batch) tensor.
        y_true = tf.reshape(tf.cast(y_true, y_pred.dtype), tf.shape(y_pred))

        # Clip to keep log() finite for saturated predictions.
        epsilon = K.epsilon()
        y_pred = K.clip(y_pred, epsilon, 1.0 - epsilon)

        is_positive = K.equal(y_true, 1)
        p_t = tf.where(is_positive, y_pred, 1 - y_pred)
        alpha_factor = K.ones_like(y_true) * alpha
        alpha_factor = tf.where(is_positive, alpha_factor, 1 - alpha_factor)
        cross_entropy = -K.log(p_t)
        loss = alpha_factor * K.pow(1 - p_t, gamma) * cross_entropy
        return K.mean(K.sum(loss, axis=-1))
    return focal_loss_fn

def get_activation_function(name):
    """Return a new Keras activation layer for the given name.

    Args:
        name: 'relu', 'leaky_relu' or 'elu'.

    Returns:
        `LeakyReLU(alpha=0.3)` for 'leaky_relu', `ELU()` for 'elu', and
        `ReLU()` for 'relu' as well as for any unrecognised name.
    """
    if name == 'leaky_relu':
        return LeakyReLU(alpha=0.3)
    if name == 'elu':
        return tf.keras.layers.ELU()
    # 'relu' doubles as the fallback for anything not matched above.
    return tf.keras.layers.ReLU()

def get_optimizer(name, learning_rate=1e-4):
    """Return a new Keras optimizer selected by name.

    Args:
        name: 'adam', 'sgd' (created with momentum 0.9) or 'rmsprop'. Any
            other value falls back to Adam.
        learning_rate: Learning rate passed to the optimizer. Defaults to
            1e-4, the value that was previously hard-coded for every choice.

    Returns:
        A freshly constructed `tf.keras.optimizers` instance.
    """
    if name == 'sgd':
        return tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9)
    if name == 'rmsprop':
        return tf.keras.optimizers.RMSprop(learning_rate=learning_rate)
    # 'adam' and every unrecognised name share the Adam default.
    return tf.keras.optimizers.Adam(learning_rate=learning_rate)

def build_model(architecture, activation_fn):
    """Assemble an (uncompiled) binary classifier for the chosen architecture.

    Args:
        architecture: 'CNN' for the small convolutional network defined here,
            or one of 'MobileNetV2', 'EfficientNetB0', 'ResNet50', 'VGG16'
            for an ImageNet-pretrained backbone with a new classification head.
        activation_fn: Keras activation layer instance inserted after the
            hidden Dense layer. The convolutional layers of the 'CNN' variant
            always use ReLU regardless of this argument.

    Returns:
        A `Sequential` model taking (IMG_SIZE, IMG_SIZE, 3) inputs and ending
        in a single sigmoid unit.

    Raises:
        ValueError: If `architecture` is not one of the supported names.
    """
    input_shape = (IMG_SIZE, IMG_SIZE, 3)

    # Pretrained backbones, matched by equality against the requested name.
    pretrained_backbones = (
        ('MobileNetV2', MobileNetV2),
        ('EfficientNetB0', EfficientNetB0),
        ('ResNet50', ResNet50),
        ('VGG16', VGG16),
    )
    backbone_cls = next(
        (candidate for key, candidate in pretrained_backbones if architecture == key),
        None,
    )

    if backbone_cls is None:
        if architecture != 'CNN':
            raise ValueError(f"Invalid architecture: {architecture}")

        # Small CNN trained from scratch: three conv/pool stages, then a
        # dense head using the caller-supplied activation.
        return Sequential([
            Conv2D(32, (3,3), activation='relu', input_shape=input_shape),
            MaxPooling2D(2,2),
            Conv2D(64, (3,3), activation='relu'),
            MaxPooling2D(2,2),
            Conv2D(128, (3,3), activation='relu'),
            MaxPooling2D(2,2),
            Flatten(),
            Dense(256),
            activation_fn,
            Dropout(0.5),
            Dense(1, activation='sigmoid')
        ])

    # Transfer learning: freeze the ImageNet backbone and train only the head.
    # NOTE(review): the data generators in this file rescale pixels to [0, 1];
    # confirm that matches the input preprocessing each backbone expects.
    base_model = backbone_cls(
        input_shape=input_shape,
        include_top=False,
        weights='imagenet'
    )
    base_model.trainable = False

    head = [
        GlobalAveragePooling2D(),
        Dense(128),
        activation_fn,
        BatchNormalization(),
        Dropout(0.5),
        Dense(1, activation='sigmoid')
    ]
    return Sequential([base_model] + head)

def create_data_generators():
    """Create the training, validation and test image iterators.

    Training and validation images both come from `train_path`, split 80/20
    via `validation_split`. Only the training subset is augmented; validation
    and test images are merely rescaled to [0, 1] so that their metrics
    reflect unmodified images.

    Returns:
        A tuple `(train_gen, val_gen, test_gen)` of directory iterators
        yielding batches of `BATCH_SIZE` images resized to
        (IMG_SIZE, IMG_SIZE) with binary labels. `test_gen` is not shuffled,
        so its batch order matches `test_gen.classes`.
    """
    validation_split = 0.2

    train_datagen = ImageDataGenerator(
        rescale=1./255,
        rotation_range=20,
        zoom_range=0.2,
        width_shift_range=0.1,
        height_shift_range=0.1,
        horizontal_flip=True,
        validation_split=validation_split
    )

    # Separate generator for validation: same rescaling and same split
    # fraction, but no random augmentation. The training/validation partition
    # is determined by the directory listing and the split fraction, so using
    # the same `validation_split` keeps the two subsets disjoint.
    val_datagen = ImageDataGenerator(
        rescale=1./255,
        validation_split=validation_split
    )

    test_datagen = ImageDataGenerator(rescale=1./255)

    # Arguments shared by every iterator.
    common_kwargs = dict(
        target_size=(IMG_SIZE, IMG_SIZE),
        batch_size=BATCH_SIZE,
        class_mode='binary',
    )

    train_gen = train_datagen.flow_from_directory(
        train_path,
        subset='training',
        seed=SEED,
        **common_kwargs
    )

    val_gen = val_datagen.flow_from_directory(
        train_path,
        subset='validation',
        seed=SEED,
        **common_kwargs
    )

    test_gen = test_datagen.flow_from_directory(
        test_path,
        shuffle=False,  # keep order aligned with test_gen.classes for evaluation
        **common_kwargs
    )

    return train_gen, val_gen, test_gen

In [None]:
# =======================
# TRAINING & EVALUATION
# =======================
if __name__ == "__main__":
    # End-to-end run: build generators and model, train, reload the best
    # checkpoint, evaluate on the test set, plot diagnostics, export TFLite.
    train_gen, val_gen, test_gen = create_data_generators()

    model = build_model(ARCHITECTURE, get_activation_function(ACTIVATION))

    # Select loss function. `custom_objects` is what load_model needs in order
    # to deserialize the checkpoint; the focal-loss closure that is compiled
    # into the model is reused there instead of building a second instance.
    custom_objects = {'LeakyReLU': LeakyReLU}
    if LOSS_FUNCTION == 'focal':
        loss = focal_loss()
        custom_objects['focal_loss_fn'] = loss
    else:
        loss = 'binary_crossentropy'

    model.compile(
        optimizer=get_optimizer(OPTIMIZER),
        loss=loss,
        metrics=[
            'accuracy',
            tf.keras.metrics.Recall(name='recall'),
            tf.keras.metrics.Precision(name='precision')
        ]
    )

    # NOTE(review): EarlyStopping selects on val_recall while the checkpoint
    # selects on val_loss (ModelCheckpoint's default, made explicit here).
    # The checkpoint is what gets reloaded below, so the final model is the
    # best-val_loss one and the weights restored by EarlyStopping are
    # discarded. Confirm this is the intended selection criterion.
    callbacks = [
        EarlyStopping(monitor='val_recall', patience=5, mode='max', restore_best_weights=True),
        ModelCheckpoint('best_model.keras', monitor='val_loss', save_best_only=True),
        ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3)
    ]

    history = model.fit(
        train_gen,
        epochs=EPOCHS,
        validation_data=val_gen,
        callbacks=callbacks
    )

    model = tf.keras.models.load_model(
        'best_model.keras',
        custom_objects=custom_objects
    )

    # Evaluation: results come back in compile order (loss, then the metrics
    # in the order they were passed to compile()).
    test_results = model.evaluate(test_gen)
    print("\nTest Metrics:")
    for metric_label, metric_value in zip(("Loss", "Accuracy", "Recall", "Precision"), test_results):
        print(f"{metric_label}: {metric_value:.4f}")

    # Run inference on the test set once and reuse the probabilities for both
    # the confusion matrix and the precision-recall curve. test_gen is created
    # with shuffle=False, so predictions line up with test_gen.classes.
    y_true = test_gen.classes
    y_probs = model.predict(test_gen).ravel()
    y_pred = (y_probs > 0.5).astype(int)

    # Visualization: explicit axes so every panel lands in the same figure.
    fig, axes = plt.subplots(2, 3, figsize=(20, 12))

    # Training-history panels (train vs. validation per epoch).
    history_panels = [
        ('accuracy', 'Accuracy'),
        ('precision', 'Precision'),
        ('recall', 'Recall'),
        ('loss', 'Loss'),
    ]
    for ax, (history_key, panel_title) in zip(axes.flat, history_panels):
        ax.plot(history.history[history_key], label='Train')
        ax.plot(history.history[f'val_{history_key}'], label='Validation')
        ax.set_title(panel_title, fontsize=14)
        ax.legend()

    # Confusion Matrix. Without `ax=`, ConfusionMatrixDisplay.plot() opens a
    # new figure, which left this grid slot empty and sent later pyplot calls
    # to the wrong figure. `labels` pins the matrix to one row/column per
    # class so it always matches the display labels.
    cm_ax = axes[1, 1]
    class_names = list(test_gen.class_indices.keys())
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))
    ConfusionMatrixDisplay(cm, display_labels=class_names).plot(cmap='Blues', ax=cm_ax)
    cm_ax.set_title('Confusion Matrix', fontsize=14)

    # Precision-Recall Curve
    pr_ax = axes[1, 2]
    precision, recall, _ = precision_recall_curve(y_true, y_probs)
    pr_ax.plot(recall, precision, label=f'AUC = {auc(recall, precision):.2f}')
    pr_ax.set_xlabel('Recall')
    pr_ax.set_ylabel('Precision')
    pr_ax.set_title('Precision-Recall Curve')
    pr_ax.legend()

    fig.tight_layout()
    plt.show()

    # TFLite Conversion (default optimizations; SELECT_TF_OPS lets the
    # converter fall back to TensorFlow ops that have no TFLite builtin).
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS
    ]
    tflite_model = converter.convert()

    with open(TFLITE_PATH, 'wb') as f:
        f.write(tflite_model)

    print(f"\nTFLite model saved to: {TFLITE_PATH}")
    print("Conversion complete with SELECT_TF_OPS support!")

    # Model Summary
    print("\nFinal Model Architecture:")
    model.summary()