# In [None]:
"""
fingerprint_vgg16
=====================
An example end-to-end script for training a VGG16-based CNN
to classify damaged fingerprint images.

Folder structure assumption:
  data/
    train/
      class1/
        img_1.jpg
        img_2.jpg
        ...
      class2/
        img_1.jpg
        img_2.jpg
        ...
      ...
    val/
      class1/
        ...
      class2/
        ...
    test/
      class1/
        ...
      class2/
        ...
"""

import os

import tensorflow as tf
from tensorflow.keras.applications import VGG16
from tensorflow.keras.applications.vgg16 import preprocess_input
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.layers import Dense, Flatten, GlobalAveragePooling2D, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator

###########################
# 1. USER-DEFINED PARAMS  #
###########################

# Dataset locations; per the layout in the module docstring, each directory
# holds one sub-folder per class.
TRAIN_DIR, VAL_DIR, TEST_DIR = "data/train", "data/val", "data/test"

# Where the best checkpoint is written, and the log directory name.
CHECKPOINT_PATH = "checkpoints/vgg16_fingerprint_best.h5"
LOG_DIR = "logs"

# Input geometry and batching. Images are resized to a square 224x224.
IMG_HEIGHT = IMG_WIDTH = 224
# NOTE(review): the Keras VGG16 docs state that `input_shape` must have
# exactly 3 channels, so setting this to 1 for grayscale data also requires
# changing how the backbone is built -- confirm before changing it.
CHANNELS = 3
BATCH_SIZE = 16

# Optimisation settings: epoch budget, Adam learning rate, head dropout.
EPOCHS = 20
LR = 1e-4
DROPOUT_RATE = 0.3

###########################
# 2. DATA AUGMENTATION    #
###########################

# Image generators for the three splits.
#
# The backbone is loaded with ImageNet weights, so images must go through the
# preprocessing those weights were trained with. Per the Keras docs,
# `vgg16.preprocess_input` converts RGB to BGR and zero-centres each channel
# against the ImageNet means, without scaling. Rescaling to [0, 1] instead
# would feed the pretrained (and partly frozen) convolutions inputs in a
# range they never saw. `rescale` is therefore intentionally NOT set: it would
# be applied on top of the preprocessing function.
#
# For damaged fingerprints, you may want robust augmentations. Only the
# training generator augments; validation and test images are only
# preprocessed so that metrics are measured on unmodified data.
train_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    rotation_range=10,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=False,   # Usually fingerprints aren't flipped horizontally, but you can experiment
    vertical_flip=False,     # Same note as above
    # You could add custom "damage simulation" here if desired
)

val_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)

# If you have separate test data:
test_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)

###########################
# 3. DATA LOADERS         #
###########################

# All three splits are read with the same target size, batch size, one-hot
# ('categorical') labels and colour mode; only the source directory, the
# generator and the shuffling differ.
_color_mode = 'rgb' if CHANNELS == 3 else 'grayscale'
_shared_flow_args = dict(
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    class_mode='categorical',  # or 'sparse' if you prefer
    color_mode=_color_mode,
)

# Training batches are shuffled.
train_generator = train_datagen.flow_from_directory(
    directory=TRAIN_DIR, shuffle=True, **_shared_flow_args
)

# Validation and test batches are not shuffled.
val_generator = val_datagen.flow_from_directory(
    directory=VAL_DIR, shuffle=False, **_shared_flow_args
)
test_generator = test_datagen.flow_from_directory(
    directory=TEST_DIR, shuffle=False, **_shared_flow_args
)

# The size of the output layer is taken from the training split.
num_classes = train_generator.num_classes
print(f"Detected {num_classes} classes in the training set.")

###################################
# 4. BUILD THE VGG16-BASED MODEL  #
###################################

# Convolutional backbone: VGG16 with ImageNet weights and without its
# original classifier layers (include_top=False).
backbone_input_shape = (IMG_HEIGHT, IMG_WIDTH, CHANNELS)
base_model = VGG16(
    include_top=False,
    weights='imagenet',
    input_shape=backbone_input_shape,
)

# Freeze the first ten backbone layers so they act as a fixed feature
# extractor; the layers after them are left untouched. The cut-off index is
# something to experiment with.
for frozen_layer in base_model.layers[:10]:
    frozen_layer.trainable = False

# Classification head. The backbone's feature maps are reduced with global
# average pooling (Flatten() would be the alternative), then passed through a
# 256-unit ReLU layer with dropout. More dense layers can be added if needed.
pooled = GlobalAveragePooling2D()(base_model.output)
hidden = Dense(256, activation='relu')(pooled)
hidden = Dropout(DROPOUT_RATE)(hidden)

# One softmax unit per class detected in the training set.
predictions = Dense(num_classes, activation='softmax')(hidden)

# Assemble backbone + head into a single model and print its layer summary.
model = Model(inputs=base_model.input, outputs=predictions)
model.summary()

###################################
# 5. COMPILE THE MODEL            #
###################################

# Training configuration: Adam at the configured learning rate, categorical
# cross-entropy (the generators yield 'categorical' labels), accuracy metric.
adam_optimizer = Adam(learning_rate=LR)
model.compile(
    loss='categorical_crossentropy',
    optimizer=adam_optimizer,
    metrics=['accuracy'],
)

###################################
# 6. TRAINING CALLBACKS           #
###################################

# Stop training after 5 epochs without improvement in validation loss and
# restore the weights of the best epoch.
early_stop = EarlyStopping(monitor='val_loss', restore_best_weights=True, patience=5)

# Write the model to CHECKPOINT_PATH whenever validation accuracy improves.
# Note that this callback tracks 'val_accuracy' while early stopping tracks
# 'val_loss', so the checkpointed epoch and the restored epoch can differ.
checkpoint = ModelCheckpoint(
    CHECKPOINT_PATH,
    monitor='val_accuracy',
    verbose=1,
    save_best_only=True,
)

callbacks_list = [early_stop, checkpoint]

###################################
# 7. FIT / TRAIN THE MODEL        #
###################################
# Number of batches that make up one full pass over each split.
#
# len() of a Keras directory iterator is its batch count including the final,
# possibly smaller, batch. Floor division by BATCH_SIZE would drop that last
# batch: the validation generator is unshuffled, so the same trailing images
# would never be evaluated, and the step count would be 0 for any split with
# fewer than BATCH_SIZE images.
steps_per_epoch = len(train_generator)
validation_steps = len(val_generator)

# Train for up to EPOCHS epochs, validating after each one; the callbacks
# handle early stopping and best-model checkpointing. The returned History
# object holds the per-epoch metrics.
history = model.fit(
    train_generator,
    steps_per_epoch=steps_per_epoch,
    epochs=EPOCHS,
    validation_data=val_generator,
    validation_steps=validation_steps,
    callbacks=callbacks_list
)

###################################
# 8. EVALUATE ON THE TEST SET     #
###################################
# Run the trained model over the test generator and report the two values
# it returns (loss, then the accuracy metric set at compile time).
print("\nEvaluating on the test set...")
test_metrics = model.evaluate(test_generator, verbose=1)
test_loss, test_acc = test_metrics
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.4f}")

###################################
# 9. SAVE THE TRAINED MODEL       #
###################################
# Persist the model in its final state. This is separate from the
# best-epoch checkpoint written during training, so it is saved even if it is
# not the best model.
final_model_path = "vgg16_fingerprint_final.h5"
model.save(filepath=final_model_path)
print(f"Model saved to {final_model_path}")

###################################
# 10. OPTIONAL: PLOT RESULTS      #
###################################
import matplotlib.pyplot as plt


def _plot_curve_pair(history, train_key, val_key, train_label, val_label, title, ylabel):
    """Show one 8x4 figure with a training curve and its validation curve.

    Args:
        history: Object whose ``history`` attribute maps metric names to
            per-epoch value sequences.
        train_key: Key of the training series in ``history.history``.
        val_key: Key of the validation series in ``history.history``.
        train_label: Legend label for the training series.
        val_label: Legend label for the validation series.
        title: Figure title.
        ylabel: Label for the y axis (the x axis is always 'Epoch').
    """
    plt.figure(figsize=(8, 4))
    plt.plot(history.history[train_key], label=train_label)
    plt.plot(history.history[val_key], label=val_label)
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.legend()
    plt.show()


def plot_history(history):
    """Plot training vs. validation accuracy, then training vs. validation loss.

    Args:
        history: Object returned by ``model.fit``; its ``history`` dict must
            contain 'accuracy', 'val_accuracy', 'loss' and 'val_loss'.
    """
    # Accuracy figure first, loss figure second; each is shown on its own.
    _plot_curve_pair(
        history, 'accuracy', 'val_accuracy',
        'Train Accuracy', 'Val Accuracy', 'Model Accuracy', 'Accuracy',
    )
    _plot_curve_pair(
        history, 'loss', 'val_loss',
        'Train Loss', 'Val Loss', 'Model Loss', 'Loss',
    )


plot_history(history)
