# In [None]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
import os
import numpy as np

# Mini-batch size and the spatial size every image is resized to
# (224x224 is the input resolution the MobileNetV2 backbone is built with).
batch_size = 32
img_size = (224, 224)

# Root folders of the three dataset splits
test_dir = '/Users/kshitijverma/Downloads/Dataset/Test'
val_dir = '/Users/kshitijverma/Downloads/Dataset/Validation'
train_dir = '/Users/kshitijverma/Downloads/Dataset/Train'

# Define function to limit the number of images in each folder and exclude non-image files
def limit_images(directory, max_images_per_class):
    """Collect at most ``max_images_per_class`` image paths per class folder.

    Every sub-directory of ``directory`` is treated as one class. Inside each
    class folder only files whose name ends in ``.png``, ``.jpg`` or ``.jpeg``
    (case-insensitive) are considered; everything else is ignored.

    Args:
        directory: Path of the folder that contains one sub-folder per class.
        max_images_per_class: Upper bound on the number of files kept from
            each class folder (applied as a slice, so ``None`` keeps all).

    Returns:
        A tuple ``(images, labels)`` of two equally long lists: the full path
        of every selected image and the name of the class folder it came from.
        Entries are ordered by class name, then by file name.
    """
    image_extensions = ('.png', '.jpg', '.jpeg')
    images = []
    labels = []

    # os.listdir() returns entries in arbitrary order. Sorting makes the
    # truncation below pick the same files on every run and every machine.
    for class_name in sorted(os.listdir(directory)):
        class_dir = os.path.join(directory, class_name)
        if not os.path.isdir(class_dir):
            # Stray files next to the class folders are not classes.
            continue

        image_files = sorted(
            f for f in os.listdir(class_dir)
            if f.lower().endswith(image_extensions)
        )[:max_images_per_class]

        images.extend(os.path.join(class_dir, f) for f in image_files)
        labels.extend([class_name] * len(image_files))

    return images, labels


# Cap on how many training files are taken from each class folder
max_images_per_class = 10000
train_images, train_labels = limit_images(
    directory=train_dir,
    max_images_per_class=max_images_per_class,
)

# Keras generators. Validation and test images are only rescaled to [0, 1];
# the training generator additionally applies random geometric augmentation.
test_datagen = ImageDataGenerator(rescale=1.0 / 255)
val_datagen = ImageDataGenerator(rescale=1.0 / 255)

# NOTE(review): no code visible in this file consumes train_datagen. The
# training stream is built by create_data_generator(), so these augmentations
# are not applied during training as written. Confirm intent.
train_datagen = ImageDataGenerator(
    rescale=1.0 / 255,
    horizontal_flip=True,
    rotation_range=20,
    zoom_range=0.2,
    shear_range=0.2,
    width_shift_range=0.2,
    height_shift_range=0.2,
    fill_mode='nearest',
)

# Convert the image paths to a tf.data.Dataset
def create_data_generator(image_paths, labels, batch_size):
    """Build a shuffled, batched ``tf.data.Dataset`` of (image, label) pairs.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of class-name strings aligned with ``image_paths``.
            ``'Real'`` becomes label 1; any other value becomes label 0.
        batch_size: Number of examples per batch.

    Returns:
        A ``tf.data.Dataset`` yielding ``(images, labels)`` batches. Images
        are resized to the module-level ``img_size`` and scaled to [0, 1];
        labels are int32 tensors of 0/1.
    """
    dataset = tf.data.Dataset.from_tensor_slices((image_paths, labels))

    # Shuffle individual examples BEFORE batching. The path list produced by
    # limit_images() is grouped by class, so batching first and shuffling
    # afterwards only reorders batches that each hold a single class.
    # The buffer holds just (path, label) strings, so covering the whole
    # list is cheap and gives a uniform shuffle that is redone every epoch.
    dataset = dataset.shuffle(
        buffer_size=max(len(image_paths), 1),
        reshuffle_each_iteration=True,
    )

    def process_image(image_path, label):
        img = tf.io.read_file(image_path)
        # Per the TensorFlow docs decode_jpeg also decodes PNG input, which
        # matters because the file filter upstream admits '.png' files.
        img = tf.image.decode_jpeg(img, channels=3)
        img = tf.image.resize(img, img_size)
        img = img / 255.0  # Scale pixel values to [0, 1]
        # Inside Dataset.map the label is a string tensor, so compare with a
        # tensor op rather than a Python conditional: 'Real' -> 1, else 0.
        label = tf.cast(tf.equal(label, 'Real'), tf.int32)
        return img, label

    return (
        dataset
        .map(process_image, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size)
        # Overlap input preprocessing with model execution.
        .prefetch(tf.data.AUTOTUNE)
    )


# Training stream: tf.data pipeline over the capped list of training files
train_dataset = create_data_generator(train_images, train_labels, batch_size)

# Validation and test streams are read straight from their folders. The
# explicit class list pins the label indices to Fake -> 0 and Real -> 1.
val_data = val_datagen.flow_from_directory(
    directory=val_dir,
    classes=['Fake', 'Real'],
    class_mode='binary',
    target_size=img_size,
    batch_size=batch_size,
)
test_data = test_datagen.flow_from_directory(
    directory=test_dir,
    classes=['Fake', 'Real'],
    class_mode='binary',
    target_size=img_size,
    batch_size=batch_size,
)

# Backbone: ImageNet-pretrained MobileNetV2 without its classifier head,
# frozen so that only the layers stacked on top of it are trained.
# NOTE(review): the Keras docs state MobileNetV2 expects inputs scaled to
# [-1, 1] (mobilenet_v2.preprocess_input), while every data stream here is
# scaled to [0, 1]. Confirm this is intended.
base_model = MobileNetV2(
    weights='imagenet',
    include_top=False,
    input_shape=(224, 224, 3),
)
base_model.trainable = False

# Binary classification head on top of the frozen backbone
model = Sequential()
model.add(base_model)
model.add(GlobalAveragePooling2D())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(1, activation='sigmoid'))

model.compile(
    loss='binary_crossentropy',
    optimizer=Adam(learning_rate=0.0001),
    metrics=['accuracy'],
)

# Stop once validation accuracy has not improved for 3 epochs (restoring the
# best weights), and keep the best model seen so far on disk.
early_stopping = EarlyStopping(
    monitor='val_accuracy',
    restore_best_weights=True,
    patience=3,
)
model_checkpoint = ModelCheckpoint(
    filepath='new_model.h5',
    save_best_only=True,
    monitor='val_accuracy',
)

# 50 epochs is an upper bound; early stopping normally ends training sooner.
model.fit(
    x=train_dataset,
    validation_data=val_data,
    epochs=50,
    callbacks=[early_stopping, model_checkpoint],
)

# Final score on the held-out test split
loss, accuracy = model.evaluate(test_data)
print(f"Test accuracy: {accuracy:.4f}")


