# Cell 1: Imports & Configuration

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing import image_dataset_from_directory
import matplotlib.pyplot as plt
import numpy as np
import os

# --- Configuration ---
# Root folder of the image dataset. image_dataset_from_directory (Cell 2) infers
# class labels from the names of its sub-folders. The default below is a
# machine-specific path; set the NATURAL_IMAGES_DATA_PATH environment variable
# to point at another folder without editing this cell.
DATA_PATH = os.environ.get(
    "NATURAL_IMAGES_DATA_PATH",
    r"D:\Machine Learning project\Natural Images Project\processed_images",
)

# Every image is resized to ORIGINAL_HEIGHT x ORIGINAL_WIDTH when loaded.
ORIGINAL_WIDTH = 256
ORIGINAL_HEIGHT = 256
BATCH_SIZE = 32
SEED = 123  # shared seed for dataset loading and splitting

print("TensorFlow Version:", tf.__version__)

# Cell 2: Data Loading & Splitting

In [None]:
# --- Load the Full Dataset ---
# Reads every image under DATA_PATH as RGB, resizes it to
# ORIGINAL_HEIGHT x ORIGINAL_WIDTH and groups the images into batches of BATCH_SIZE.
# Class labels come from the sub-folder names.
# NOTE(review): with shuffle=True, image_dataset_from_directory also applies a
# tf.data shuffle that is (by tf.data default) redone on each pass. The take/skip
# split in the next step may therefore see slightly different batch contents
# across passes, letting a few images near split boundaries move between
# train/val/test. Confirm against the installed TF/Keras version.
dataset = image_dataset_from_directory(
    DATA_PATH,
    shuffle=True,
    seed=SEED,
    color_mode='rgb',
    batch_size=BATCH_SIZE,
    image_size=(ORIGINAL_HEIGHT, ORIGINAL_WIDTH),
)

# Sorted class names discovered by the loader; their positions are the label ids.
class_names = dataset.class_names
class_num = len(class_names)
for caption, value in (("Class Names:", class_names), ("Number of Classes:", class_num)):
    print(caption, value)

# --- Define Split Function ---
def get_dataset_partitions(ds, train_split=0.8, val_split=0.1, test_split=0.1,
                           shuffle=True, shuffle_size=10000, seed=None):
    """Split a dataset into train / validation / test partitions.

    The split uses ``take``/``skip``, so sizes are counted in elements of ``ds``
    (batches, for the batched dataset built in Cell 2). ``train_split`` and
    ``val_split`` set the train and validation sizes (rounded down). The test
    partition receives every remaining element, so rounding never drops data.

    Args:
        ds: Dataset with a known length that supports ``shuffle``, ``take`` and
            ``skip`` (e.g. a ``tf.data.Dataset``).
        train_split, val_split, test_split: Non-negative fractions summing to 1.
        shuffle: If True, shuffle the elements once before splitting.
        shuffle_size: Shuffle buffer size.
        seed: Shuffle seed; defaults to the notebook-wide ``SEED``.

    Returns:
        Tuple ``(train_data, val_data, test_data)``.

    Raises:
        ValueError: If a fraction is negative or the fractions do not sum to 1.
    """
    splits = (train_split, val_split, test_split)
    if any(s < 0 for s in splits) or abs(sum(splits) - 1.0) > 1e-6:
        raise ValueError(
            f"train/val/test splits must be non-negative and sum to 1, got {splits}"
        )

    ds_size = len(ds)

    if shuffle:
        # reshuffle_each_iteration=False fixes the order once. With the tf.data
        # default (True), every new pass over `ds` is reshuffled, so take/skip
        # would hand different elements to train/val/test on each epoch, which
        # leaks validation/test data into training.
        ds = ds.shuffle(
            shuffle_size,
            seed=SEED if seed is None else seed,
            reshuffle_each_iteration=False,
        )

    train_size = int(ds_size * train_split)
    val_size = int(ds_size * val_split)

    train_data = ds.take(train_size)
    val_data = ds.skip(train_size).take(val_size)
    test_data = ds.skip(train_size + val_size)  # everything that is left

    return train_data, val_data, test_data

# --- Create Partitions ---
# 70% train / 15% validation / 15% test, counted in batches.
train_data, val_data, test_data = get_dataset_partitions(
    dataset,
    train_split=0.7,
    val_split=0.15,
    test_split=0.15,
)

# Report how many batches ended up in each partition.
for split_title, split_ds in (
    ("Training", train_data),
    ("Validation", val_data),
    ("Testing", test_data),
):
    print(f"{split_title} batches: {len(split_ds)}")

# Cell 3: Visualizing the Training Data

In [None]:
# Show one batch of training images in an 8 x 4 grid, titled with class names.
# A batch can hold fewer than 32 images (a partial final batch, or a smaller
# BATCH_SIZE). Slots without an image are hidden instead of raising IndexError.
fig, axs = plt.subplots(8, 4, figsize=(12, 24))

# Flatten the 2-D axes array so a single loop can walk every subplot.
axs_flat = axs.flatten()

# Take exactly one batch from the training split.
for image_batch, label_batch in train_data.take(1):
    images = image_batch.numpy().astype("uint8")
    labels = label_batch.numpy()

    for i, ax in enumerate(axs_flat):
        if i >= len(images):
            ax.axis("off")  # no image left for this slot
            continue
        ax.imshow(images[i])
        ax.set_title(class_names[int(labels[i])])
        ax.grid(True)  # grid overlay as a visual reference

    plt.tight_layout()
    plt.show()


# Cell 4: Performance Optimization

In [None]:
# --- Optimization ---
# cache(): the first full pass stores the elements so later epochs skip decoding.
# prefetch(): prepares upcoming batches while the model works on the current one.
autotune = tf.data.AUTOTUNE

# Only the training split is shuffled. The shuffle comes after cache() so that
# each epoch gets a fresh order of the cached batches.
train_data = (
    train_data
    .cache()
    .shuffle(1000)
    .prefetch(buffer_size=autotune)
)

# Validation/test keep a fixed order so evaluation stays consistent.
val_data, test_data = [
    split.cache().prefetch(buffer_size=autotune) for split in (val_data, test_data)
]

# Cell 5: Preprocessing & Augmentation (newly added cell)

In [None]:
# --- ADDED: Define Rescaling and Augmentation ---

# MobileNetV2 expects pixels in [-1, 1]. x * (1/127.5) - 1 sends 0 -> -1 and
# 255 -> 1.
rescale = layers.Rescaling(scale=1. / 127.5, offset=-1)

# Light, label-preserving augmentation to reduce overfitting. Keras random
# preprocessing layers only transform inputs in training mode (e.g. inside fit());
# at inference they pass inputs through unchanged.
augmentation_steps = [
    layers.RandomFlip("horizontal"),  # mirror left/right
    layers.RandomRotation(0.06),      # rotate by up to +/-6% of a full turn
    layers.RandomZoom(0.06),          # zoom in/out by up to 6%
]
data_augmentation = keras.Sequential(augmentation_steps)

# Cell 6: Model Definition

In [None]:
# Network input shape: (height, width, RGB channels). Uses the Cell 1 constants;
# the previous lowercase names were undefined and raised NameError.
input_img_shape = (ORIGINAL_HEIGHT, ORIGINAL_WIDTH, 3)

# --- 1. Load the Pre-trained Base Model ---
# We use MobileNetV2, but you could also use EfficientNetB0, ResNet50, etc.
# NOTE(review): 256x256 is not one of the ImageNet input sizes Keras ships for
# MobileNetV2. Keras may warn and load the 224x224 weights. With
# include_top=False the convolutional weights still apply to this size.
base_model = keras.applications.MobileNetV2(
    input_shape=input_img_shape,
    include_top=False,  # Do not include the final classifier (1000 classes)
    weights='imagenet'  # Load weights pre-trained on ImageNet
)

# --- 2. Freeze the Base Model ---
# This stops its weights from being updated during initial training.
# We only want to train our *new* layers.
base_model.trainable = False

# --- 3. Create Your New Model ---
# Pipeline: raw [0, 255] pixels -> [-1, 1] -> random augmentation (training only)
# -> frozen MobileNetV2 features -> new classification head.
model_transfer = keras.Sequential([
    keras.Input(shape=input_img_shape),
    rescale,
    data_augmentation,
    base_model,  # The frozen pre-trained base

    # --- New Classifier Head ---
    layers.GlobalAveragePooling2D(),  # collapse spatial dims to one feature vector
    layers.Dropout(0.5),  # Regularization
    layers.Dense(128, activation='relu'),  # A dense layer to learn from the features
    layers.Dense(class_num, activation='softmax')  # one probability per class
])


# Cell 7: Training

In [None]:
# --- 4. Compile ---
# Labels from image_dataset_from_directory are integer class ids (its default
# label_mode), which is what sparse categorical cross-entropy expects alongside
# the softmax head.
head_optimizer = keras.optimizers.Adam(learning_rate=0.001)
model_transfer.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=head_optimizer,
    metrics=['accuracy'],
)

# --- 5. Train ---
# Only the new head learns here because the base model is frozen. The returned
# History object is kept so its per-epoch curves can be plotted.
num_epochs = 10
print("--- Training the new classifier head ---")
history = model_transfer.fit(
    train_data,
    epochs=num_epochs,
    validation_data=val_data,
)

# Cell 8: Evaluation & Plots

In [None]:
# --- Plotting ---
# Two side-by-side panels: accuracy curves (left) and loss curves (right).
# Each row: subplot position, train key, val key, train label, val label,
# title, y-axis label.
plot_panels = (
    (1, 'accuracy', 'val_accuracy', 'Training Accuracy', 'Validation Accuracy',
     'Model Accuracy', 'Accuracy'),
    (2, 'loss', 'val_loss', 'Training Loss', 'Validation Loss',
     'Model Loss', 'Loss'),
)

plt.figure(figsize=(12, 4))
for position, train_key, val_key, train_label, val_label, panel_title, y_label in plot_panels:
    plt.subplot(1, 2, position)
    plt.plot(history.history[train_key], label=train_label)
    plt.plot(history.history[val_key], label=val_label)
    plt.title(panel_title)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.legend()

plt.show()

# --- Test Evaluation ---
# evaluate() returns [loss, accuracy] for this compile() setup, so index 1 is
# the accuracy.
print("Evaluating on Test Data...")
acc_on_test = model_transfer.evaluate(test_data)
test_accuracy_pct = acc_on_test[1] * 100
print(f"Test Accuracy: {test_accuracy_pct:.2f}%")

# Cell 9: Saving

In [None]:
# --- Save Model ---
# The .keras extension selects Keras's native single-file format.
model_path = 'MobileNetV2_classifier.keras'
model_transfer.save(model_path)
print(f"Model saved to {model_path}")