In [None]:
import os
import time
import numpy as np
import json
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
from PIL import Image
import tensorflow as tf
from tensorflow.keras.applications import VGG19
from tensorflow.keras.layers import Dense, Flatten, Dropout, GlobalAveragePooling2D
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint, Callback

# Custom callback that records per-epoch metrics, saves the model and plots progress
class EpochMetricsCallback(Callback):
    """Record metrics after every epoch, persist them to JSON, save the model and plot.

    The metric history lives in ``<working_path>/epoch_metrics.json``. It is
    loaded when the callback is constructed, so the history (and the
    ``global_epoch`` counter derived from its length) continues across runs.

    Args:
        working_path: Directory that receives ``epoch_metrics.json`` and the
            per-epoch ``model_epoch_<n>.h5`` files.
    """

    # Keys copied from the Keras ``logs`` dict into each epoch record, in the
    # order they are written to the JSON file.
    METRIC_KEYS = (
        'accuracy', 'val_accuracy',
        'loss', 'val_loss',
        'precision', 'val_precision',
        'recall', 'val_recall',
    )

    def __init__(self, working_path):
        super().__init__()
        self.working_path = working_path
        self.epoch_metrics_file = os.path.join(working_path, 'epoch_metrics.json')
        self.epoch_metrics = self.load_existing_metrics()
        # Set by on_epoch_begin; stays None until the first epoch starts.
        self.epoch_start_time = None

    def load_existing_metrics(self):
        """Return the metric history stored on disk, or ``[]`` if no file exists."""
        try:
            with open(self.epoch_metrics_file, 'r') as f:
                return json.load(f)
        except FileNotFoundError:
            return []

    def save_epoch_metrics(self):
        """Write the full metric history to ``epoch_metrics.json``.

        The data is written to a temporary sibling file first and then moved
        over the real file with ``os.replace``, so an interrupted run cannot
        leave a half-written JSON file that a later run would fail to parse.
        """
        tmp_file = self.epoch_metrics_file + '.tmp'
        with open(tmp_file, 'w') as f:
            json.dump(self.epoch_metrics, f, indent=4)
        os.replace(tmp_file, self.epoch_metrics_file)

    def on_epoch_begin(self, epoch, logs=None):
        """Remember when the epoch started so its duration can be reported."""
        self.epoch_start_time = time.time()

    def on_epoch_end(self, epoch, logs=None):
        """Append this epoch's metrics, persist them, save the model and re-plot.

        Args:
            epoch: 0-based epoch index supplied by Keras for the current
                ``fit()`` call.
            logs: Metric dict supplied by Keras; missing keys are recorded as 0.
        """
        # Keras declares ``logs`` as optional; treat a missing dict as empty.
        logs = logs or {}

        if self.epoch_start_time is None:
            epoch_duration = 0.0
        else:
            epoch_duration = time.time() - self.epoch_start_time

        # The global counter continues from whatever history was loaded from disk.
        global_epoch_num = len(self.epoch_metrics) + 1

        record = {
            'global_epoch': global_epoch_num,
            # 1-based epoch number as reported by Keras for this fit() call.
            'local_epoch': epoch + 1,
            'duration_minutes': epoch_duration / 60,
        }
        # float() keeps the record JSON-serialisable even if a metric arrives
        # as a NumPy scalar rather than a built-in float.
        for key in self.METRIC_KEYS:
            record[key] = float(logs.get(key, 0))

        self.epoch_metrics.append(record)
        self.save_epoch_metrics()

        print(f"\nGlobal Epoch {global_epoch_num} Metrics:")
        print(f"Time taken: {record['duration_minutes']:.2f} minutes")
        print(f"Training Accuracy: {record['accuracy']:.4f}")
        print(f"Validation Accuracy: {record['val_accuracy']:.4f}")
        print(f"Training Loss: {record['loss']:.4f}")
        print(f"Validation Loss: {record['val_loss']:.4f}")

        model_path = os.path.join(self.working_path, f'model_epoch_{global_epoch_num}.h5')
        self.model.save(model_path)
        print(f"Model saved to: {model_path}")

        plot_metrics(self.epoch_metrics)

# Dataset and output locations
BASE_PATH = '/Users/sridhariyer/Sridhar/PHD FINAL PROJECT/Untitled Folder/400BirdSpecies'
WORKING_PATH = '/Users/sridhariyer/Sridhar/PHD FINAL PROJECT/Untitled Folder/VGG19/'
TRAIN_PATH, VALID_PATH, TEST_PATH = (
    os.path.join(BASE_PATH, split) for split in ('train', 'valid', 'test')
)

# Every generator feeds images through the VGG19 preprocessing function.
vgg19_preprocess = tf.keras.applications.vgg19.preprocess_input

# Random augmentation is applied to the training images only.
augmentation_options = {
    'rotation_range': 30,
    'width_shift_range': 0.2,
    'height_shift_range': 0.2,
    'horizontal_flip': True,
    'vertical_flip': False,
    'zoom_range': 0.2,
    'shear_range': 0.2,
    'brightness_range': [0.8, 1.2],
    'fill_mode': 'nearest',
}

train_datagen = ImageDataGenerator(
    preprocessing_function=vgg19_preprocess,
    **augmentation_options
)
valid_datagen = ImageDataGenerator(preprocessing_function=vgg19_preprocess)
test_datagen = ImageDataGenerator(preprocessing_function=vgg19_preprocess)

# Directory-iterator settings shared by the three splits.
flow_options = {
    'target_size': (224, 224),
    'batch_size': 32,
    'class_mode': 'categorical',
}

train_generator = train_datagen.flow_from_directory(TRAIN_PATH, **flow_options)
validation_generator = valid_datagen.flow_from_directory(VALID_PATH, **flow_options)
# The test split is read without shuffling.
test_generator = test_datagen.flow_from_directory(TEST_PATH, shuffle=False, **flow_options)

# Class names and class count are taken from the training directory listing.
class_labels = list(train_generator.class_indices)
n_classes = len(class_labels)

# Restore the resume information written by a previous run, if any
def load_previous_state(working_path):
    """Read ``training_state.json`` from *working_path*.

    Args:
        working_path: Directory expected to contain ``training_state.json``.

    Returns:
        A ``(global_epoch, best_val_accuracy, model_path)`` tuple. Keys absent
        from the file fall back to ``0``, ``0`` and ``None`` respectively;
        ``(0, 0, None)`` is returned when the file does not exist.
    """
    state_file = os.path.join(working_path, 'training_state.json')
    try:
        with open(state_file, 'r') as handle:
            saved = json.load(handle)
    except FileNotFoundError:
        return 0, 0, None

    fallbacks = (
        ('global_epoch', 0),
        ('best_val_accuracy', 0),
        ('model_path', None),
    )
    return tuple(saved.get(key, default) for key, default in fallbacks)

def save_training_state(working_path, global_epoch, best_val_accuracy, model_path):
    """Persist the resume information to ``<working_path>/training_state.json``.

    The JSON is written to a temporary sibling file and then moved into place
    with ``os.replace``. Overwriting the file directly would leave truncated
    JSON behind if the process were interrupted mid-write, and the next run
    could not parse it.

    Args:
        working_path: Directory that holds ``training_state.json``.
        global_epoch: Number of epochs completed so far.
        best_val_accuracy: Best validation accuracy seen so far.
        model_path: Path of the best saved model, or ``None``.
    """
    state = {
        'global_epoch': global_epoch,
        'best_val_accuracy': best_val_accuracy,
        'model_path': model_path
    }
    state_file = os.path.join(working_path, 'training_state.json')
    tmp_file = state_file + '.tmp'
    with open(tmp_file, 'w') as f:
        json.dump(state, f, indent=4)
    # Swap the complete file into place in a single step.
    os.replace(tmp_file, state_file)

# Build a VGG19-based classifier for transfer learning
def create_vgg19_model(input_shape=(224, 224, 3), num_classes=n_classes):
    """Return a VGG19 network with a new classification head.

    The convolutional base is loaded with ImageNet weights and every one of
    its layers is frozen. The head is global average pooling followed by two
    L2-regularised ReLU dense layers (1024 and 512 units, with dropout 0.5 and
    0.4) and a softmax output layer.

    Args:
        input_shape: Shape of the input images passed to VGG19.
        num_classes: Number of units in the softmax output layer.

    Returns:
        An uncompiled Keras ``Model``.
    """
    backbone = VGG19(
        weights='imagenet',
        include_top=False,
        input_shape=input_shape
    )

    # Keep the pre-trained convolutional weights fixed.
    for backbone_layer in backbone.layers:
        backbone_layer.trainable = False

    # Classification head: (dense units, dropout rate) applied in order.
    head_spec = ((1024, 0.5), (512, 0.4))
    features = GlobalAveragePooling2D()(backbone.output)
    for units, drop_rate in head_spec:
        features = Dense(
            units,
            activation='relu',
            kernel_regularizer=tf.keras.regularizers.l2(0.0001)
        )(features)
        features = Dropout(drop_rate)(features)

    class_probabilities = Dense(num_classes, activation='softmax')(features)
    return Model(inputs=backbone.input, outputs=class_probabilities)

def save_best_model(model, val_accuracy, best_val_accuracy, working_path):
    """Save *model* as ``best_model.h5`` when it beats the best accuracy so far.

    Args:
        model: Object exposing ``save(path)``.
        val_accuracy: Validation accuracy of the epoch just finished.
        best_val_accuracy: Best validation accuracy seen before this epoch.
        working_path: Directory that receives ``best_model.h5``.

    Returns:
        ``(val_accuracy, saved_path)`` when a new best model was written,
        otherwise ``(best_val_accuracy, None)``.
    """
    improved = val_accuracy > best_val_accuracy
    if not improved:
        return best_val_accuracy, None

    target = os.path.join(working_path, 'best_model.h5')
    model.save(target)
    print(f"New best model saved with validation accuracy: {val_accuracy}")
    return val_accuracy, target

# Plot the accuracy / loss / precision / recall history
def plot_metrics(epoch_metrics):
    """Draw a 2x2 grid of training-vs-validation curves over the global epochs.

    Args:
        epoch_metrics: List of per-epoch dicts. Each dict must contain
            ``global_epoch`` plus ``accuracy``, ``loss``, ``precision`` and
            ``recall`` and their ``val_``-prefixed counterparts. An empty list
            is accepted and plots nothing.
    """
    if not epoch_metrics:
        # Nothing recorded yet, so there is nothing to draw.
        return

    epochs = [m['global_epoch'] for m in epoch_metrics]

    # (metric key, display name) for each subplot, in grid order.
    panels = (
        ('accuracy', 'Accuracy'),
        ('loss', 'Loss'),
        ('precision', 'Precision'),
        ('recall', 'Recall'),
    )

    fig = plt.figure(figsize=(15, 10))

    for position, (key, label) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        plt.plot(epochs, [m[key] for m in epoch_metrics], label=f'Training {label}')
        plt.plot(epochs, [m[f'val_{key}'] for m in epoch_metrics], label=f'Validation {label}')
        plt.title(f'Model {label}')
        plt.xlabel('Epoch')
        plt.ylabel(label)
        plt.legend()

    plt.tight_layout()
    plt.show()
    # This runs once per epoch; close the figure so open figures do not pile up.
    plt.close(fig)

# Initialize the per-epoch metrics callback
epoch_metrics_callback = EpochMetricsCallback(WORKING_PATH)

# Load previous state with global epoch
global_epoch, best_val_accuracy, previous_model_path = load_previous_state(WORKING_PATH)

# Create or load the model
best_model_file = os.path.join(WORKING_PATH, 'best_model.h5')
if os.path.exists(best_model_file):
    print("Loading previous model...")
    model = load_model(best_model_file)
else:
    model = create_vgg19_model()

# Compile the model. The metrics are named explicitly: auto-generated names
# gain a numeric suffix ("precision_1", ...) when the cell is re-run in the
# same session, and EpochMetricsCallback looks the values up under the plain
# 'precision' / 'recall' keys.
model.compile(
    optimizer=Adam(learning_rate=0.0001, clipnorm=1.0),
    loss='categorical_crossentropy',
    metrics=['accuracy',
             tf.keras.metrics.Precision(name='precision'),
             tf.keras.metrics.Recall(name='recall')]
)

# Callbacks
checkpoint = ModelCheckpoint(
    filepath=os.path.join(WORKING_PATH, 'checkpoint_epoch-{epoch:02d}_val_loss-{val_loss:.2f}.weights.h5'),
    monitor='val_accuracy',
    save_best_only=True,
    save_weights_only=True,
    verbose=1
)

early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=7,
    restore_best_weights=True,
    min_delta=0.001
)

reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.2,
    patience=4,
    min_lr=1e-7
)


class TrainingStateCallback(Callback):
    """Track the best validation accuracy and persist resume state every epoch.

    Args:
        working_path: Directory for ``best_model.h5`` and ``training_state.json``.
        global_epoch: Number of epochs already completed by earlier runs.
        best_val_accuracy: Best validation accuracy reached by earlier runs.
        model_path: Path of the best model saved by earlier runs, or ``None``.
    """

    def __init__(self, working_path, global_epoch, best_val_accuracy, model_path):
        super().__init__()
        self.working_path = working_path
        self.global_epoch = global_epoch
        self.best_val_accuracy = best_val_accuracy
        self.model_path = model_path

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # fit() below is started with initial_epoch=<resumed epoch>, so the
        # 0-based ``epoch`` index already counts epochs from earlier runs.
        self.global_epoch = epoch + 1

        val_accuracy = logs.get('val_accuracy')
        if val_accuracy is not None:
            self.best_val_accuracy, new_model_path = save_best_model(
                self.model, float(val_accuracy), self.best_val_accuracy, self.working_path
            )
            # Keep the most recent best-model path instead of falling back to
            # the one loaded at start-up on epochs that set no new best.
            if new_model_path:
                self.model_path = new_model_path

        save_training_state(
            self.working_path, self.global_epoch, self.best_val_accuracy, self.model_path
        )


training_state_callback = TrainingStateCallback(
    WORKING_PATH, global_epoch, best_val_accuracy, previous_model_path
)

# Training configuration
total_epochs = 50
histories = []
total_start_time = time.time()

# Continue training from the last global epoch.
# A single fit() call is used on purpose: EarlyStopping and ReduceLROnPlateau
# reset their patience counters at the start of every fit(), so calling
# fit(epochs=1) in a Python loop would never let either of them trigger.
if global_epoch < total_epochs:
    print(f'\nTraining global epochs {global_epoch + 1}-{total_epochs}')
    history = model.fit(
        train_generator,
        validation_data=validation_generator,
        initial_epoch=global_epoch,
        epochs=total_epochs,
        callbacks=[checkpoint, early_stopping, reduce_lr,
                   epoch_metrics_callback, training_state_callback],
        verbose=1
    )
    histories.append(history)

    global_epoch = training_state_callback.global_epoch
    best_val_accuracy = training_state_callback.best_val_accuracy

    # stopped_epoch stays 0 unless EarlyStopping actually halted training.
    if early_stopping.stopped_epoch > 0:
        print("Early stopping triggered. Training complete.")

# Total training time
total_training_time = (time.time() - total_start_time) / 60
print(f"\nTotal Training Time: {total_training_time:.2f} minutes")

# Print final training summary
print("\nTraining Summary:")
print(f"Total Global Epochs: {global_epoch}")
print(f"Best Validation Accuracy: {best_val_accuracy:.4f}")
print(f"Total Training Time: {total_training_time:.2f} minutes")

Found 58389 images belonging to 400 classes.
Found 2000 images belonging to 400 classes.
Found 2005 images belonging to 400 classes.


2025-01-17 10:03:28.710782: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M3 Pro
2025-01-17 10:03:28.710808: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 18.00 GB
2025-01-17 10:03:28.710810: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 6.00 GB
2025-01-17 10:03:28.710840: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2025-01-17 10:03:28.710850: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)



Global Epoch 1/50


  self._warn_if_super_not_called()
2025-01-17 10:03:29.721668: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.


[1m 386/1825[0m [32m━━━━[0m[37m━━━━━━━━━━━━━━━━[0m [1m5:51[0m 245ms/step - accuracy: 0.0040 - loss: 28.5526 - precision: 0.0042 - recall: 0.0036