In [None]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout, Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import warnings

# Suppress every Python warning for the rest of the session to keep notebook output clean.
warnings.filterwarnings('ignore')

# Training hyper-parameters shared by all three models.
IMAGE_SIZE = (128, 128)  # target (height, width) each image is resized to
BATCH_SIZE = 32
EPOCHS = 15

# One data directory per attribute, each expected to hold one sub-directory per class.
BASE_DATA_DIR = 'dataset'
AGE_DATA_DIR, GENDER_DATA_DIR, HAIR_DATA_DIR = (
    os.path.join(BASE_DATA_DIR, attribute) for attribute in ('age', 'gender', 'hair')
)

# Output directory for trained models; created at import time if it does not exist.
MODEL_SAVE_DIR = 'saved_image_models'
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)

AGE_MODEL_PATH, GENDER_MODEL_PATH, HAIR_MODEL_PATH = (
    os.path.join(MODEL_SAVE_DIR, file_name)
    for file_name in ('age_model.h5', 'gender_model.h5', 'hair_model.h5')
)

# Class sub-directory names, passed as `classes=` to flow_from_directory, so the
# list order determines the label index assigned to each class.
AGE_CLASSES = ['lt_20', '20_30', 'gt_30']
GENDER_CLASSES = ['male', 'female']
HAIR_CLASSES = ['short', 'long']

# Training-subset generator: scales pixels by 1/255 and applies random geometric
# augmentation. validation_split reserves a 0.2 fraction of the images for the
# 'validation' subset.
train_datagen = ImageDataGenerator(
    rescale=1. / 255,
    validation_split=0.2,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
)

# Validation-subset generator: same rescaling, no augmentation. Its validation_split
# must equal the one above so the 'training' and 'validation' subsets drawn from the
# two generators are complementary.
validation_datagen = ImageDataGenerator(
    validation_split=0.2,
    rescale=1. / 255,
)

def build_model(input_shape, num_classes, base_model_trainable=False):
    """Build and compile a MobileNetV2 transfer-learning image classifier.

    Args:
        input_shape: (height, width, channels) of the input images.
        num_classes: number of target classes.
        base_model_trainable: if True the ImageNet backbone is fine-tuned,
            otherwise its weights are frozen.

    Returns:
        A compiled Keras ``Model``. With two (or fewer) classes the head is a
        single sigmoid unit trained with binary cross-entropy, which matches the
        scalar 0/1 labels produced by ``class_mode='binary'``. With more classes
        the head is a softmax over ``num_classes`` units trained with categorical
        cross-entropy, which matches one-hot ``class_mode='categorical'`` labels.
    """
    base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=input_shape)
    base_model.trainable = base_model_trainable

    inputs = Input(shape=input_shape)
    # With training=False the frozen backbone's BatchNormalization layers stay in
    # inference mode instead of updating their statistics on the new data.
    x = base_model(inputs, training=base_model_trainable)
    x = GlobalAveragePooling2D()(x)
    x = Dropout(0.5)(x)

    if num_classes > 2:
        outputs = Dense(num_classes, activation='softmax')(x)
        loss_func = 'categorical_crossentropy'
    else:
        # A single unit whose output is P(label == 1). A 2-unit sigmoid head would
        # have each unit broadcast against the same scalar label, so both units
        # would learn the identical probability.
        outputs = Dense(1, activation='sigmoid')(x)
        loss_func = 'binary_crossentropy'

    model = Model(inputs, outputs)
    model.compile(optimizer=Adam(learning_rate=0.0001),
                  loss=loss_func,
                  metrics=['accuracy'])
    return model

def train_evaluate_image_model(model_name, data_dir, classes, target_size, model_save_path):
    """Train, evaluate and save one image classifier.

    Images are read from ``data_dir/<class>/`` for every name in ``classes``. The
    module-level ``train_datagen`` / ``validation_datagen`` split them into the
    'training' and 'validation' subsets.

    Args:
        model_name: human-readable name used in log messages.
        data_dir: directory containing one sub-directory per class.
        classes: class sub-directory names; their order fixes the label indices.
        target_size: (height, width) every image is resized to.
        model_save_path: file the trained model is saved to.

    Returns:
        ``(model, history)`` on success. ``(None, None)`` if ``data_dir`` is missing,
        not a directory or empty, if loading the images fails, or if either subset
        contains no images.
    """
    print(f"\n--- Training {model_name} Model ---")

    # isdir (not exists) so a regular file at data_dir is reported and skipped
    # instead of making os.listdir raise NotADirectoryError.
    if not os.path.isdir(data_dir) or not os.listdir(data_dir):
        print(f"!!! ERROR: Data directory '{data_dir}' not found or is empty. !!!")
        print("!!! Please create the directory and populate it with subdirectories named like: ", classes)
        print(f"!!! Skipping training for {model_name}. !!!")
        return None, None

    num_classes = len(classes)
    # The label format must match the loss chosen in build_model (binary vs
    # categorical cross-entropy), which uses the same num_classes > 2 test.
    class_mode = 'categorical' if num_classes > 2 else 'binary'

    flow_kwargs = dict(
        target_size=target_size,
        batch_size=BATCH_SIZE,
        class_mode=class_mode,
        classes=classes,
    )
    try:
        train_generator = train_datagen.flow_from_directory(
            data_dir, subset='training', shuffle=True, **flow_kwargs
        )
        # No shuffling so validation batches come in a fixed, reproducible order.
        validation_generator = validation_datagen.flow_from_directory(
            data_dir, subset='validation', shuffle=False, **flow_kwargs
        )
    except Exception as e:
        print(f"!!! ERROR loading data from {data_dir}: {e} !!!")
        print("!!! Check if subdirectories match class names and contain images. !!!")
        return None, None

    if train_generator.samples == 0 or validation_generator.samples == 0:
        print(f"!!! ERROR: No training or validation samples found in {data_dir}. Check image files. !!!")
        return None, None

    print(f"Found {train_generator.samples} training images belonging to {num_classes} classes.")
    print(f"Found {validation_generator.samples} validation images belonging to {num_classes} classes.")
    print(f"Class Indices: {train_generator.class_indices}")

    # tuple() so a list target_size also works; flow_from_directory's default
    # color_mode is 'rgb', hence 3 channels.
    input_shape = tuple(target_size) + (3,)
    model = build_model(input_shape, num_classes)

    # len(generator) is ceil(samples / batch_size), so every image is used per pass.
    # Floor division (samples // BATCH_SIZE) would drop the last partial batch and
    # yield 0 steps for a subset smaller than one batch.
    steps_per_epoch = len(train_generator)
    validation_steps = len(validation_generator)

    print(f"Training {model_name} model for {EPOCHS} epochs...")
    history = model.fit(
        train_generator,
        steps_per_epoch=steps_per_epoch,
        validation_data=validation_generator,
        validation_steps=validation_steps,
        epochs=EPOCHS,
        verbose=1,
    )

    print(f"\nEvaluating {model_name} model...")
    loss, accuracy = model.evaluate(validation_generator, steps=validation_steps)
    print(f"{model_name} Validation Accuracy: {accuracy:.4f}")
    print(f"{model_name} Validation Loss: {loss:.4f}")

    print(f"Saving {model_name} model to {model_save_path}")
    model.save(model_save_path)
    print("-" * (20 + len(model_name)))

    return model, history

def _plot_accuracy_histories(named_histories):
    """Plot train/validation accuracy side by side, one panel per trained model.

    Args:
        named_histories: sequence of ``(label, history)`` pairs where ``history`` is
            the object returned by ``model.fit`` or ``None`` for a skipped model.
            ``None`` entries are left out of the figure.
    """
    trained = [(label, hist) for label, hist in named_histories if hist is not None]
    if not trained:
        print("No training histories to plot.")
        return

    # One figure for all panels, 4 inches wide per panel (12x4 when all three trained).
    plt.figure(figsize=(4 * len(trained), 4))
    for position, (label, hist) in enumerate(trained, start=1):
        plt.subplot(1, len(trained), position)
        plt.plot(hist.history['accuracy'], label=f'{label} Train Acc')
        # 'val_accuracy' is only recorded when validation actually ran.
        if 'val_accuracy' in hist.history:
            plt.plot(hist.history['val_accuracy'], label=f'{label} Val Acc')
        plt.title(f'{label} Model Accuracy')
        plt.legend()
    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    age_model, age_history = train_evaluate_image_model(
        "Age Group", AGE_DATA_DIR, AGE_CLASSES, IMAGE_SIZE, AGE_MODEL_PATH
    )

    gender_model, gender_history = train_evaluate_image_model(
        "Gender", GENDER_DATA_DIR, GENDER_CLASSES, IMAGE_SIZE, GENDER_MODEL_PATH
    )

    hair_model, hair_history = train_evaluate_image_model(
        "Hair Length", HAIR_DATA_DIR, HAIR_CLASSES, IMAGE_SIZE, HAIR_MODEL_PATH
    )

    named_histories = [("Age", age_history), ("Gender", gender_history), ("Hair", hair_history)]
    # train_evaluate_image_model returns (None, None) for a skipped model.
    skipped = [label for label, hist in named_histories if hist is None]
    if not skipped:
        print("\n--- All models trained and saved to", MODEL_SAVE_DIR, "---")
    else:
        trained = [label for label, hist in named_histories if hist is not None]
        print(f"\n--- Trained and saved to {MODEL_SAVE_DIR}: {', '.join(trained) or 'none'} ---")
        print(f"--- Skipped (missing or unusable data): {', '.join(skipped)} ---")
    print("!!! IMPORTANT: Ensure you replaced placeholder data directories with your actual labeled image datasets !!!")

    _plot_accuracy_histories(named_histories)
