In [1]:
import os
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models, applications
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.preprocessing.image import load_img, img_to_array
import matplotlib.pyplot as plt
import time

# Enable on-demand GPU memory allocation on every visible GPU and report
# which kind of device will be used.
physical_devices = tf.config.list_physical_devices('GPU')
if not physical_devices:
    print("No GPU found, using CPU")
else:
    for device in physical_devices:
        tf.config.experimental.set_memory_growth(device, True)
    print(f"Found {len(physical_devices)} GPU(s)")

def find_dataset_path(base_dir="../data/fgvc-aircraft-2013b"):
    """Locate the directory that holds the dataset's ``variants.txt`` file.

    A few well-known layouts under ``base_dir`` are probed first; if none of
    them contains the marker file, the whole tree below ``base_dir`` is
    walked and the first directory containing it is used.

    Parameters:
    base_dir (str): Directory to start searching from

    Returns:
    str: Path of the directory containing ``variants.txt``

    Raises:
    FileNotFoundError: If no directory under ``base_dir`` contains the file
    """
    marker = "variants.txt"
    candidates = (
        base_dir,
        os.path.join(base_dir, "data"),
        os.path.join(base_dir, "fgvc-aircraft-2013b", "data"),
    )

    # Fast path: the usual extraction layouts.
    located = next(
        (c for c in candidates if os.path.exists(os.path.join(c, marker))),
        None,
    )

    # Slow path: recursive search of the whole tree.
    if located is None:
        located = next(
            (root for root, _dirs, files in os.walk(base_dir) if marker in files),
            None,
        )

    if located is None:
        raise FileNotFoundError(f"Could not find dataset in {base_dir}")

    print(f"Found dataset at: {located}")
    return located

def load_fgvc_aircraft(data_dir, img_height=224, img_width=224, batch_size=32, task_type='variant'):
    """
    Load the FGVC Aircraft dataset as batched tf.data pipelines.

    Parameters:
    data_dir (str): Path to the dataset directory
    img_height (int): Image height after resizing
    img_width (int): Image width after resizing
    batch_size (int): Batch size for training/evaluation
    task_type (str): Classification task - 'variant', 'family', or 'manufacturer'

    Returns:
    tuple: (datasets, classes). datasets maps 'train' / 'val' / 'test' to a
    batched tf.data.Dataset of (image, label) pairs; a split for which no
    usable image was found is omitted. classes is the list of class names,
    and every label is an index into that list.

    Raises:
    ValueError: If task_type is not one of the supported tasks
    """
    print(f"Loading dataset from: {data_dir}")

    # Select the appropriate class file based on task_type
    class_files = {
        'variant': "variants.txt",
        'family': "families.txt",
        'manufacturer': "manufacturers.txt",
    }
    if task_type not in class_files:
        raise ValueError(f"Unsupported task type: {task_type}. Choose from 'variant', 'family', or 'manufacturer'")
    class_file = os.path.join(data_dir, class_files[task_type])

    # Load class labels (blank lines must not become classes)
    with open(class_file, 'r') as f:
        classes = [line.strip() for line in f if line.strip()]

    print(f"Found {len(classes)} aircraft {task_type} classes")

    # Create class mapping
    class_to_idx = {cls: i for i, cls in enumerate(classes)}

    def load_image(img_path, label):
        """Read one JPEG, resize it and scale pixel values to [0, 1]."""
        img = tf.io.read_file(img_path)
        img = tf.image.decode_jpeg(img, channels=3)
        img = tf.image.resize(img, [img_height, img_width])
        img = tf.cast(img, tf.float32) / 255.0
        return img, label

    def augment(img, label):
        """Apply the random training-time augmentations to one image."""
        img = tf.image.random_flip_left_right(img)
        img = tf.image.random_brightness(img, 0.1)
        img = tf.image.random_contrast(img, 0.8, 1.2)
        img = tf.image.random_saturation(img, 0.8, 1.2)
        return img, label

    # Split name -> label used in the progress message
    split_titles = {'train': 'Train', 'val': 'Validation', 'test': 'Test'}
    datasets = {}

    for split, title in split_titles.items():
        # Get image filenames
        split_file = os.path.join(data_dir, f"images_{split}.txt")
        with open(split_file, 'r') as f:
            image_files = [line.strip() for line in f if line.strip()]

        # Get image labels. The annotation file has to belong to the same
        # task as the class list, otherwise the label lookup below fails for
        # every task except 'variant'.
        annotation_file = os.path.join(data_dir, f"images_{task_type}_{split}.txt")
        image_labels = {}
        with open(annotation_file, 'r') as f:
            for line in f:
                # "<image id> <class name>"; class names may contain spaces
                parts = line.strip().split(' ', 1)
                if len(parts) == 2:
                    image_labels[parts[0]] = parts[1]

        images = []
        labels = []

        for img_name in image_files:
            img_path = os.path.join(data_dir, 'images', f'{img_name}.jpg')
            if not os.path.exists(img_path):
                # Try alternative path structure
                img_path = os.path.join(data_dir, '..', 'images', f'{img_name}.jpg')
                if not os.path.exists(img_path):
                    print(f"Warning: Could not find image {img_name}.jpg, skipping")
                    continue

            class_name = image_labels.get(img_name)
            if class_name is None:
                print(f"Warning: No {task_type} label for image {img_name}, skipping")
                continue

            images.append(img_path)
            labels.append(class_to_idx[class_name])

        # Skip empty datasets
        if not images:
            print(f"Warning: No images found for {split} split")
            continue

        dataset = tf.data.Dataset.from_tensor_slices((images, labels))
        dataset = dataset.map(load_image, num_parallel_calls=tf.data.AUTOTUNE)

        if split == 'train':
            # Cache the decoded images, then reshuffle and re-augment them
            # on every epoch.
            dataset = dataset.cache().shuffle(buffer_size=len(images))
            dataset = dataset.map(augment, num_parallel_calls=tf.data.AUTOTUNE)

        datasets[split] = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
        print(f"{title} set: {len(images)} images")

    return datasets, classes

def create_model(num_classes, input_shape=(224, 224, 3), base_model_name='ResNet50'):
    """Build a transfer-learning classifier on a frozen ImageNet backbone.

    Parameters:
    num_classes (int): Number of units in the softmax output layer
    input_shape (tuple): Input shape handed to the backbone
    base_model_name (str): 'ResNet50', 'EfficientNetB0' or 'MobileNetV2'

    Returns:
    A Sequential model: backbone, global average pooling, batch
    normalization, Dense(512, relu), Dropout(0.5), softmax output.

    Raises:
    ValueError: If base_model_name is not one of the supported names
    """
    # NOTE(review): each Keras application documents its own expected input
    # preprocessing for the ImageNet weights - confirm that the pipeline
    # feeding this model matches the chosen backbone.
    supported_backbones = ('ResNet50', 'EfficientNetB0', 'MobileNetV2')
    if base_model_name not in supported_backbones:
        raise ValueError(f"Unsupported base model: {base_model_name}")

    # The supported names are exactly the constructor names in
    # keras.applications, so the constructor can be looked up directly.
    build_backbone = getattr(applications, base_model_name)
    base_model = build_backbone(weights='imagenet', include_top=False, input_shape=input_shape)

    # Only the new classification head is trained at first.
    base_model.trainable = False

    model = models.Sequential()
    model.add(base_model)
    model.add(layers.GlobalAveragePooling2D())
    model.add(layers.BatchNormalization())
    model.add(layers.Dense(512, activation='relu'))
    model.add(layers.Dropout(0.5))
    model.add(layers.Dense(num_classes, activation='softmax'))
    return model

def extract_features(model, dataset, layer_name=None):
    """Extract features from a specific layer of the model.

    Parameters:
    model: Built Keras model to read intermediate activations from
    dataset: Iterable yielding (images, labels) batches
    layer_name (str): Name of the layer whose output is returned. When
        omitted, the output of the second-to-last layer is used.

    Returns:
    tuple: (features, labels) as NumPy arrays concatenated over all batches

    Raises:
    ValueError: If the dataset yields no batches
    """
    if layer_name:
        feature_tensor = model.get_layer(layer_name).output
    else:
        # Use the model up to the layer before the final classification layer
        feature_tensor = model.layers[-2].output

    feature_extractor = models.Model(inputs=model.input, outputs=feature_tensor)

    features = []
    labels = []

    for images, batch_labels in dataset:
        # predict_on_batch runs a single forward pass; Model.predict is
        # meant for whole datasets and is slow and noisy when called once
        # per batch inside a Python loop.
        features.append(feature_extractor.predict_on_batch(images))
        labels.append(np.asarray(batch_labels))

    if not features:
        raise ValueError("Cannot extract features from an empty dataset")

    # Concatenate all features and labels
    return np.concatenate(features, axis=0), np.concatenate(labels, axis=0)

def train_model(model, datasets, epochs=20, learning_rate=0.001, output_dir='./model_output'):
    """Compile and fit the model, then save it together with its curves.

    Parameters:
    model: Keras model to train
    datasets (dict): Must provide the 'train' and 'val' datasets
    epochs (int): Maximum number of epochs
    learning_rate (float): Initial learning rate for Adam
    output_dir (str): Directory for checkpoints, the final model and the plot

    Returns:
    The History object returned by model.fit
    """
    os.makedirs(output_dir, exist_ok=True)

    optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
    model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])

    # Keep the best weights on disk as training progresses.
    checkpoint_cb = ModelCheckpoint(
        filepath=os.path.join(output_dir, 'best_model.h5'),
        monitor='val_accuracy',
        save_best_only=True,
        verbose=1,
    )
    # Stop once validation accuracy stalls and roll back to the best epoch.
    early_stop_cb = EarlyStopping(
        monitor='val_accuracy',
        patience=5,
        restore_best_weights=True,
        verbose=1,
    )
    # Lower the learning rate when validation loss plateaus.
    reduce_lr_cb = ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.2,
        patience=3,
        min_lr=1e-6,
        verbose=1,
    )

    started = time.time()
    history = model.fit(
        datasets['train'],
        validation_data=datasets['val'],
        epochs=epochs,
        callbacks=[checkpoint_cb, early_stop_cb, reduce_lr_cb],
        verbose=1,
    )
    elapsed = time.time() - started

    print(f"Training completed in {elapsed:.2f} seconds")

    model.save(os.path.join(output_dir, 'final_model.h5'))

    # One panel per metric: accuracy on the left, loss on the right.
    plt.figure(figsize=(12, 4))
    panels = (('accuracy', 'Accuracy'), ('loss', 'Loss'))
    for position, (metric, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(history.history[metric], label=f'Train {title}')
        plt.plot(history.history[f'val_{metric}'], label=f'Validation {title}')
        plt.xlabel('Epoch')
        plt.ylabel(title)
        plt.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'training_history.png'))

    return history

def evaluate_model(model, test_dataset, class_names, output_dir='./model_output'):
    """Evaluate the model on the test set and save the reports.

    Parameters:
    model: Compiled Keras model
    test_dataset: Iterable yielding (images, labels) batches
    class_names (list): Class names; labels are indices into this list
    output_dir (str): Directory for the confusion matrix image and the
        classification report (created if missing)

    Returns:
    The test accuracy reported by model.evaluate
    """
    from sklearn.metrics import confusion_matrix, classification_report

    # This function may be called without a prior training run, so make
    # sure the output directory exists before anything is written to it.
    os.makedirs(output_dir, exist_ok=True)

    # Evaluate the model
    test_loss, test_accuracy = model.evaluate(test_dataset)
    print(f"Test accuracy: {test_accuracy:.4f}")
    print(f"Test loss: {test_loss:.4f}")

    # Get predictions
    all_predictions = []
    all_labels = []

    for images, labels in test_dataset:
        # Single forward pass per batch; Model.predict is meant for whole
        # datasets and is slow when called repeatedly in a Python loop.
        probabilities = model.predict_on_batch(images)
        all_predictions.extend(np.argmax(probabilities, axis=1))
        all_labels.extend(labels.numpy())

    # Convert to numpy arrays
    all_predictions = np.array(all_predictions)
    all_labels = np.array(all_labels)

    # Pass the full label set explicitly: without it, a class that is absent
    # from the evaluated data shrinks the confusion matrix and makes
    # classification_report reject target_names.
    class_ids = np.arange(len(class_names))

    cm = confusion_matrix(all_labels, all_predictions, labels=class_ids)

    # Save confusion matrix as image (if not too large)
    if len(class_names) <= 100:  # Only plot if not too many classes
        plt.figure(figsize=(15, 15))
        plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
        plt.title('Confusion Matrix')
        plt.colorbar()
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, 'confusion_matrix.png'))

    # Generate and save classification report
    report = classification_report(
        all_labels,
        all_predictions,
        labels=class_ids,
        target_names=class_names,
    )
    with open(os.path.join(output_dir, 'classification_report.txt'), 'w') as f:
        f.write(report)

    print("Evaluation results saved to", output_dir)
    return test_accuracy

def fine_tune_model(model, datasets, epochs=10, learning_rate=0.0001, output_dir='./model_output'):
    """Fine-tune the model by unfreezing the top layers of the base model.

    Parameters:
    model: Model whose first layer is the pretrained backbone
    datasets (dict): Must provide the 'train' and 'val' datasets
    epochs (int): Maximum number of fine-tuning epochs
    learning_rate (float): Learning rate for Adam during fine-tuning
    output_dir (str): Directory for checkpoints, the saved model and the
        plot (created if missing)

    Returns:
    The History object returned by model.fit
    """
    os.makedirs(output_dir, exist_ok=True)

    # Get the base model
    base_model = model.layers[0]
    base_model.trainable = True

    # Number of trailing backbone layers left trainable, keyed by the prefix
    # of the backbone's Keras model name. The keras.applications entries are
    # factory functions rather than classes, so isinstance() cannot be used
    # to tell the backbones apart.
    # NOTE(review): the counts are carried over unchanged and are meant to
    # cover roughly the last block of each architecture - confirm.
    trainable_tail = {
        'resnet50': 33,
        'efficientnetb0': 20,
        'mobilenetv2': 23,
    }
    backbone_name = base_model.name.lower()
    tail = next(
        (count for prefix, count in trainable_tail.items() if backbone_name.startswith(prefix)),
        None,
    )

    # An unrecognised backbone is left fully trainable.
    if tail is not None:
        for layer in base_model.layers[:-tail]:
            layer.trainable = False

    # Keep BatchNormalization layers frozen so they stay in inference mode;
    # updating their statistics during fine-tuning can undo what the
    # pretrained backbone has learned.
    for layer in base_model.layers:
        if isinstance(layer, layers.BatchNormalization):
            layer.trainable = False

    # Recompile so the changed trainable flags and the lower learning rate
    # take effect.
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

    # Set up callbacks
    callbacks = [
        ModelCheckpoint(
            filepath=os.path.join(output_dir, 'best_finetuned_model.h5'),
            monitor='val_accuracy',
            save_best_only=True,
            verbose=1
        ),
        EarlyStopping(
            monitor='val_accuracy',
            patience=5,
            restore_best_weights=True,
            verbose=1
        ),
        ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.2,
            patience=3,
            min_lr=1e-7,
            verbose=1
        )
    ]

    # Fine-tune model
    print("Fine-tuning the model...")
    start_time = time.time()
    history = model.fit(
        datasets['train'],
        validation_data=datasets['val'],
        epochs=epochs,
        callbacks=callbacks,
        verbose=1
    )
    fine_tuning_time = time.time() - start_time

    print(f"Fine-tuning completed in {fine_tuning_time:.2f} seconds")

    # Save fine-tuned model
    model.save(os.path.join(output_dir, 'finetuned_model.h5'))

    # Plot fine-tuning history: accuracy on the left, loss on the right
    plt.figure(figsize=(12, 4))
    for position, (metric, title) in enumerate((('accuracy', 'Accuracy'), ('loss', 'Loss')), start=1):
        plt.subplot(1, 2, position)
        plt.plot(history.history[metric], label=f'Train {title}')
        plt.plot(history.history[f'val_{metric}'], label=f'Validation {title}')
        plt.xlabel('Epoch')
        plt.ylabel(title)
        plt.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'finetuning_history.png'))

    return history

def main():
    """Run the full pipeline: load data, train, evaluate, fine-tune, re-evaluate.

    Any failure is reported on stdout together with its traceback instead of
    being propagated, so the function always returns normally.
    """
    import traceback

    # Configuration
    img_height = 224
    img_width = 224
    batch_size = 32
    epochs = 20
    fine_tune_epochs = 10
    base_model_name = 'ResNet50'  # Options: 'ResNet50', 'EfficientNetB0', 'MobileNetV2'
    output_dir = './aircraft_model_output'
    task_type = 'variant'  # Options: 'variant', 'family', 'manufacturer'

    try:
        # Find and load dataset
        data_dir = find_dataset_path()
        print(f"Using dataset path: {data_dir}")

        datasets, classes = load_fgvc_aircraft(
            data_dir=data_dir,
            img_height=img_height,
            img_width=img_width,
            batch_size=batch_size,
            task_type=task_type
        )

        # The loader leaves out any split for which it found no images;
        # fail here with a clear message rather than with a bare KeyError
        # in the middle of the run.
        missing_splits = [name for name in ('train', 'val', 'test') if name not in datasets]
        if missing_splits:
            raise RuntimeError(f"Dataset splits missing or empty: {', '.join(missing_splits)}")

        num_classes = len(classes)
        print(f"Dataset loaded with {num_classes} classes")

        # Create model
        model = create_model(
            num_classes=num_classes,
            input_shape=(img_height, img_width, 3),
            base_model_name=base_model_name
        )

        model.summary()

        # Train model
        print("\n=== Training model ===")
        train_model(
            model=model,
            datasets=datasets,
            epochs=epochs,
            output_dir=output_dir
        )

        # Evaluate model
        print("\n=== Evaluating model ===")
        initial_accuracy = evaluate_model(
            model=model,
            test_dataset=datasets['test'],
            class_names=classes,
            output_dir=output_dir
        )

        # Fine-tune model (optional)
        print("\n=== Fine-tuning model ===")
        fine_tune_model(
            model=model,
            datasets=datasets,
            epochs=fine_tune_epochs,
            output_dir=output_dir
        )

        # Evaluate fine-tuned model
        print("\n=== Evaluating fine-tuned model ===")
        final_accuracy = evaluate_model(
            model=model,
            test_dataset=datasets['test'],
            class_names=classes,
            output_dir=output_dir
        )

        print(f"\nInitial test accuracy: {initial_accuracy:.4f}")
        print(f"Final test accuracy after fine-tuning: {final_accuracy:.4f}")

        # Extract features (example)
        print("\n=== Extracting features from the model ===")
        features, labels = extract_features(
            model=model,
            dataset=datasets['test'].take(10)  # Just take a few batches as an example
        )
        print(f"Extracted features shape: {features.shape}")

    except Exception as e:
        # Top-level boundary: keep the run from crashing, but show where the
        # failure happened - the message alone is rarely enough to debug it.
        print(f"Error: {e}")
        traceback.print_exc()

# Run the pipeline only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()


No GPU found, using CPU
Found dataset at: ../data/fgvc-aircraft-2013b\data
Using dataset path: ../data/fgvc-aircraft-2013b\data
Loading dataset from: ../data/fgvc-aircraft-2013b\data
Found 100 aircraft variant classes
Train set: 3334 images
Validation set: 3333 images
Test set: 3333 images
Dataset loaded with 100 classes


Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 resnet50 (Functional)       (None, 7, 7, 2048)        23587712  
                                                                 
 global_average_pooling2d (  (None, 2048)              0         
 GlobalAveragePooling2D)                                         
                                                                 
 batch_normalization (Batch  (None, 2048)              8192      
 Normalization)                                                  
                                                            

: 