# In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models

# Input and training configuration.
IMG_SIZE = 256                               # images are IMG_SIZE x IMG_SIZE pixels
CHANNELS = 3                                 # colour channels per pixel
INPUT_SHAPE = (IMG_SIZE,) * 2 + (CHANNELS,)  # (height, width, channels)
BATCH_SIZE = 32
EPOCHS = 100                                 # upper bound on training epochs

# Transfer learning backbone: MobileNetV2 with ImageNet-pretrained weights.
# include_top=False drops the ImageNet classifier so a custom head can be
# stacked on the convolutional features.
base_model = tf.keras.applications.MobileNetV2(
    weights='imagenet',
    include_top=False,
    input_shape=INPUT_SHAPE,
)
# Freeze every pretrained weight; only the layers added after it are trained.
base_model.trainable = False

# Define the model: augmentation -> MobileNetV2 input scaling -> frozen
# backbone -> small trainable classification head.
#
# Layer order matters:
#   * Augmentation runs on the raw pixel values (assumed to be [0, 255], which is
#     what the previous 1/255 rescaling implied). RandomBrightness defaults to
#     value_range=(0, 255): factor 0.2 adds a delta of up to +/-0.2 * 255 and
#     clips to that range. Applying it after scaling to [0, 1] erased the image.
#   * MobileNetV2's ImageNet weights expect inputs in [-1, 1]
#     (tf.keras.applications.mobilenet_v2.preprocess_input computes
#     x / 127.5 - 1), not [0, 1]. Rescaling(1/127.5, offset=-1) matches that.
# The Random* augmentation layers only act during training; at inference
# they pass inputs through unchanged.
model = models.Sequential([
    layers.Input(shape=INPUT_SHAPE),

    # Data augmentation (on raw [0, 255] pixels)
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
    layers.RandomZoom(0.1),
    layers.RandomContrast(0.2),
    layers.RandomBrightness(0.2),

    # Scale [0, 255] -> [-1, 1], the range MobileNetV2 was trained on
    layers.Rescaling(1./127.5, offset=-1),

    # Frozen feature extractor. NOTE(review): if base_model is later unfrozen
    # for fine-tuning, its BatchNorm layers should stay in inference mode. The
    # Keras transfer-learning guide does this with base_model(x, training=False)
    # in a functional model, which a Sequential list cannot express.
    base_model,
    layers.GlobalAveragePooling2D(),  # pool the feature map to one vector per image
    layers.BatchNormalization(),
    layers.Dense(256, activation='relu'),
    layers.Dropout(0.3),  # regularization; active only during training
    layers.Dense(1, activation='sigmoid'),  # single probability (binary task)
])

# Learning-rate schedule: cosine decay from 1e-3 across the whole training run.
# decay_steps must equal the total number of optimizer steps
# (EPOCHS * batches per epoch). Batches per epoch is read from train_ds rather
# than hard-coded, so the schedule still spans the full run if the dataset size
# or batch size changes.
_FALLBACK_STEPS_PER_EPOCH = 36  # previous hard-coded value
steps_per_epoch = int(train_ds.cardinality())
if steps_per_epoch <= 0:
    # cardinality() returns tf.data.UNKNOWN_CARDINALITY (-2) or
    # tf.data.INFINITE_CARDINALITY (-1) when the size cannot be determined
    # statically (e.g. after .filter()). Fall back to the old constant.
    steps_per_epoch = _FALLBACK_STEPS_PER_EPOCH

lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=0.001,
    decay_steps=EPOCHS * steps_per_epoch,
    # alpha is the floor as a FRACTION of initial_learning_rate,
    # i.e. the minimum LR is 0.001 * 0.0001 = 1e-7.
    alpha=0.0001,
)

# Compile: Adam driven by the cosine learning-rate schedule, binary
# cross-entropy loss for the single sigmoid output, accuracy reported per epoch.
model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    optimizer=tf.keras.optimizers.Adam(learning_rate=lr_schedule),
    metrics=['accuracy'],
)

# Stop training once validation loss has not improved for 10 consecutive
# epochs, then roll the model back to the weights from its best epoch.
early_stopping = tf.keras.callbacks.EarlyStopping(
    restore_best_weights=True,
    patience=10,
    monitor='val_loss',
)

# NOTE: ReduceLROnPlateau is intentionally not used. The optimizer's learning
# rate comes from the CosineDecay `lr_schedule`, and a schedule-backed learning
# rate is not settable. Depending on the Keras version, the callback either
# raises TypeError the first time it tries to lower the LR, or its change is
# overwritten by the schedule on the next step. Cosine decay already anneals
# the LR, so the plateau callback is redundant. To use ReduceLROnPlateau
# instead, compile with a float learning_rate (e.g. 0.001) rather than
# `lr_schedule`.

# Train the model. EarlyStopping ends training when val_loss stops improving;
# `history.history` holds the per-epoch loss/metric values.
history = model.fit(
    train_ds,
    epochs=EPOCHS,
    validation_data=val_ds,
    callbacks=[early_stopping],
)
