# In [None]:
import tensorflow as tf
from tensorflow.keras import layers

# Location of the image dataset on disk.
data_dir = "path/to/your/dataset"  # <-- Replace with your dataset path

# Target image size (height and width are both 224) and images per batch.
img_height = img_width = 224
batch_size = 32

# Passed to map()/prefetch() further down so tf.data tunes those values itself.
AUTOTUNE = tf.data.AUTOTUNE

# Preprocessing function
def preprocess_image(image, label):
    """Resize an image tensor to the configured size and divide it by 255.

    Args:
        image: Image tensor to resize; presumably holds pixel values in
            [0, 255] -- TODO confirm against the loader.
        label: Label associated with ``image``; returned unchanged.

    Returns:
        A ``(image, label)`` tuple in which ``image`` has been resized to
        ``[img_height, img_width]``, cast to float32 and divided by 255.
    """
    target_size = [img_height, img_width]
    resized = tf.image.resize(image, size=target_size)
    scaled = tf.cast(resized, tf.float32) / 255.0  # Normalize to [0,1]
    return scaled, label

# Augmentation function
def augment(image, label):
    """Apply random training-time augmentation to an image tensor.

    The image is randomly flipped left/right and its brightness is shifted
    by a random delta bounded by ``max_delta=0.1``. The result is then
    clamped back to [0, 1].

    Args:
        image: Float image tensor. Assumed to be already scaled to [0, 1]
            (in this file ``augment`` is mapped after ``preprocess_image``).
        label: Label associated with ``image``; returned unchanged.

    Returns:
        A ``(image, label)`` tuple with the augmented, clipped image.
    """
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.1)
    # The brightness shift is additive, so pixels near 0 or 1 can leave the
    # normalized range. Clamp them so training inputs stay in the same
    # [0, 1] range as the (un-augmented) validation inputs.
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image, label

# Load datasets
# The training and validation loaders take identical options and differ only
# in `subset`, so the shared keyword arguments are declared once.
_split_options = dict(
    validation_split=0.2,
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size,
    label_mode="binary",
)

train_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir, subset="training", **_split_options
)

val_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir, subset="validation", **_split_options
)

# Prepare datasets
# Training pipeline: normalize -> augment -> shuffle -> prefetch.
# The loader was given a batch_size, so shuffle(100) operates on dataset
# elements that are already batches.
_train_scaled = train_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
_train_augmented = _train_scaled.map(augment, num_parallel_calls=AUTOTUNE)
_train_shuffled = _train_augmented.shuffle(buffer_size=100)
train_ds = _train_shuffled.prefetch(buffer_size=AUTOTUNE)

# Validation pipeline: normalize -> prefetch (no augmentation, no shuffle).
_val_scaled = val_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
val_ds = _val_scaled.prefetch(buffer_size=AUTOTUNE)

# Define model
# Small CNN: two Conv2D + MaxPooling2D stages, then Flatten, a 64-unit dense
# layer and a single sigmoid output unit.
model = tf.keras.Sequential([
    # Declare the input with keras.Input(shape=...). The previous
    # InputLayer(input_shape=...) spelling relies on an argument that Keras 3
    # deprecates in favour of `shape`; keras.Input works on both Keras 2 and 3.
    tf.keras.Input(shape=(img_height, img_width, 3)),
    layers.Conv2D(32, 3, activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(64, 3, activation='relu'),
    layers.MaxPooling2D(),
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(1, activation='sigmoid'),  # Binary classification
])

# Configure training: Adam optimizer, binary cross-entropy loss, and accuracy
# reported as the metric.
model.compile(
    loss='binary_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

# Fit for 10 epochs, validating against val_ds.
model.fit(x=train_ds, epochs=10, validation_data=val_ds)

# Score the trained model on the validation set; evaluate() yields the loss
# followed by the compiled metric.
_eval_results = model.evaluate(x=val_ds)
loss, accuracy = _eval_results
print(f"Validation accuracy: {accuracy:.4f}")

# Persist the model to a single .keras file.
model.save("my_image_classifier.keras")
print("Model saved as my_image_classifier.keras")
