# AI vs Real Image Detector Training Notebook

This notebook will guide you through training a Convolutional Neural Network (CNN) to detect AI-generated images.

## 1. Setup & Dataset
We will automatically download the **CIFAKE** dataset using `kagglehub`. This dataset contains labeled real and AI-generated images, which makes it well suited to our task.

### Instructions:
Just run the cells below! The dataset will be downloaded for you.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
import tensorflow as tf
import kagglehub
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
import pathlib

# Fetch the CIFAKE dataset through kagglehub. The returned location is kept in
# `path`, which later cells use to find the image folders.
print("Downloading CIFAKE dataset...")
dataset_handle = "birdy654/cifake-real-and-ai-generated-synthetic-images"
path = kagglehub.dataset_download(dataset_handle)
print("Path to dataset files:", path)

# Report how many GPU devices TensorFlow can see on this machine.
gpu_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_devices))

In [None]:
# Batch size and target image size used by the loaders and by the model.
batch_size = 32
img_height, img_width = 32, 32

# Training images are read from the 'train' folder of the downloaded dataset.
data_dir = pathlib.Path(path) / 'train'

if data_dir.exists():
    print("Dataset directory found at:", data_dir)
else:
    print(f"WARNING: Dataset directory not found at {data_dir.resolve()}. Please compare with instructions above.")

# Settings shared by both loader calls, kept in one place so the training and
# validation subsets are always built with the same split fraction and seed.
split_options = dict(
    validation_split=0.2,
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size,
)

# Training subset of the split
train_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir, subset="training", **split_options)

# Validation subset of the same split
val_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir, subset="validation", **split_options)

# Class labels as reported by the training dataset.
class_names = train_ds.class_names
print("Class names:", class_names)

In [None]:
# Preview the first nine images of one training batch in a 3x3 grid,
# each titled with its class name.
try:
    plt.figure(figsize=(10, 10))
    for image_batch, label_batch in train_ds.take(1):
        for idx in range(9):
            plt.subplot(3, 3, idx + 1)
            # imshow is given integer pixel data, so convert the tensor first.
            pixels = image_batch[idx].numpy().astype("uint8")
            plt.imshow(pixels)
            plt.title(class_names[label_batch[idx]])
            plt.axis("off")
except Exception as err:
    print("Could not visualize data (dataset might be missing):", err)

## 2. Build the Model
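We will use a Sequential model with three convolutional blocks (each a `Conv2D` layer followed by max pooling), then a fully connected hidden layer and a final output layer that produces one logit per class.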

In [None]:
# One output unit per class found in the training data.
num_classes = len(class_names)

model = Sequential([
    # Declare the input shape with an explicit Input object instead of passing
    # `input_shape=` to the first layer, which newer Keras versions warn about.
    keras.Input(shape=(img_height, img_width, 3)),
    # Multiply pixel values by 1/255 before they reach the convolutions.
    layers.Rescaling(1./255),
    # Three convolution + max-pooling blocks with 16, 32 and 64 filters.
    layers.Conv2D(16, 3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(32, 3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(64, 3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    # Classification head: flatten, one hidden layer, then the output layer.
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    # No activation here: the loss below is configured with from_logits=True.
    layers.Dense(num_classes),
])

model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

model.summary()

In [None]:
# Fit the model on the training dataset, passing the validation dataset so
# validation metrics are recorded alongside the training ones in `history`.
epochs = 10
history = model.fit(train_ds, validation_data=val_ds, epochs=epochs)

In [None]:
# Visualize Training Results
# Plot training vs. validation accuracy (top) and loss (bottom) per epoch.
try:
    # Only this lookup is guarded: a NameError here means the training cell
    # has not been run yet, so `history` does not exist.
    metrics = history.history
except NameError:
    print("Training history not found. Train the model first!")
else:
    acc = metrics['accuracy']
    val_acc = metrics['val_accuracy']

    loss = metrics['loss']
    val_loss = metrics['val_loss']

    # Size the x-axis from the recorded history rather than the `epochs`
    # variable, so the plot still works if training ended early or `epochs`
    # was edited after the model was fitted.
    epochs_range = range(len(acc))

    plt.figure(figsize=(12, 12))
    plt.subplot(2, 1, 1)
    plt.plot(epochs_range, acc, label='Training Accuracy')
    plt.plot(epochs_range, val_acc, label='Validation Accuracy')
    plt.legend(loc='lower right')
    plt.title('Training and Validation Accuracy')

    plt.subplot(2, 1, 2)
    plt.plot(epochs_range, loss, label='Training Loss')
    plt.plot(epochs_range, val_loss, label='Validation Loss')
    plt.legend(loc='upper right')
    plt.title('Training and Validation Loss')
    plt.show()

## 3. Save the Model
Once training is complete, we save the model so our Web App can use it.

In [None]:
# Write the trained model to disk under the file name announced below.
model_path = 'ai_detector_model.h5'
model.save(model_path)
print(f"Model saved as {model_path}")