In [None]:
import tensorflow as tf
from tensorflow.keras.applications import EfficientNetB0, ResNet50, MobileNetV2
from tensorflow.keras.layers import Dense, Dropout, GlobalAveragePooling2D, BatchNormalization
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint


def create_transfer_learning_model(base_model_name='EfficientNetB0',
                                   input_shape=(256, 256, 3),
                                   num_classes=9,
                                   freeze_base=True):
    """
    Build an ImageNet-pretrained backbone topped with a dense classification head.

    Args:
        base_model_name: One of 'EfficientNetB0', 'ResNet50' or 'MobileNetV2'.
        input_shape: Forwarded to the backbone constructor as ``input_shape``.
        num_classes: Number of units in the final softmax layer.
        freeze_base: When True the backbone's ``trainable`` flag is set to
            False; when False it is set to True.

    Returns:
        A ``(model, base_model)`` tuple: the assembled Sequential model and
        the backbone instance it wraps.

    Raises:
        ValueError: If ``base_model_name`` is not one of the supported names.
    """
    supported_names = ('EfficientNetB0', 'ResNet50', 'MobileNetV2')
    if base_model_name not in supported_names:
        raise ValueError("Choose from 'EfficientNetB0', 'ResNet50', or 'MobileNetV2'")

    # The lookup table is built only after validation, so an unsupported name
    # is rejected before any backbone class is touched.
    backbone_classes = {
        'EfficientNetB0': EfficientNetB0,
        'ResNet50': ResNet50,
        'MobileNetV2': MobileNetV2,
    }
    base_model = backbone_classes[base_model_name](
        weights='imagenet',
        include_top=False,
        input_shape=input_shape,
    )

    # A frozen base means its weights are excluded from training.
    base_model.trainable = not freeze_base
    state_label = 'Frozen' if freeze_base else 'Unfrozen'
    print(f"{state_label} base model ({base_model_name}) with {len(base_model.layers)} layers")

    # Classification head: pooled backbone features -> two regularised dense
    # stages (512 then 256 units) -> softmax over num_classes.
    head_layers = [
        GlobalAveragePooling2D(),
        Dense(512, activation='relu'),
        BatchNormalization(),
        Dropout(0.5),
        Dense(256, activation='relu'),
        BatchNormalization(),
        Dropout(0.3),
        Dense(num_classes, activation='softmax'),
    ]
    model = Sequential([base_model, *head_layers])

    return model, base_model


def compile_transfer_model(model, learning_rate=0.001):
    """
    Compile ``model`` for multi-class training and return it.

    Args:
        model: Object exposing a Keras-style ``compile`` method.
        learning_rate: Learning rate handed to the Adam optimizer.

    Returns:
        The same ``model`` object that was passed in, after ``compile`` ran.
    """
    optimizer = Adam(learning_rate=learning_rate)
    model.compile(optimizer=optimizer,
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    return model


def create_callbacks(model_name='best_transfer_model.keras'):
    """
    Assemble the list of callbacks used during training.

    Args:
        model_name: File path handed to ModelCheckpoint as its first argument.

    Returns:
        A list ``[EarlyStopping, ReduceLROnPlateau, ModelCheckpoint]``, in
        that order.
    """
    # Stop on stalled val_loss (patience 7) and restore the best weights.
    early_stopping = EarlyStopping(monitor='val_loss',
                                   patience=7,
                                   restore_best_weights=True,
                                   verbose=1)

    # Reduce the learning rate by a factor of 0.2 on a val_loss plateau
    # (patience 5), with a floor of 1e-7.
    lr_on_plateau = ReduceLROnPlateau(monitor='val_loss',
                                      factor=0.2,
                                      patience=5,
                                      min_lr=1e-7,
                                      verbose=1)

    # Save to model_name, keeping only the best model as judged by
    # val_accuracy (note: a different metric from the two callbacks above).
    checkpoint = ModelCheckpoint(model_name,
                                 monitor='val_accuracy',
                                 save_best_only=True,
                                 verbose=1)

    return [early_stopping, lr_on_plateau, checkpoint]


# Build the classifier on an EfficientNetB0 base with the base frozen
# (freeze_base=True), then compile it.
# train_generator and val_generator are not defined in this section; they must
# already exist when this code runs.
# NOTE(review): input_shape is hard-coded here; presumably it matches the image
# size the generators yield -- confirm against where they are created.
model, base_model = create_transfer_learning_model(
    base_model_name='EfficientNetB0',
    input_shape=(256, 256, 3),
    num_classes=train_generator.num_classes,   # size the softmax layer from the generator's class count
    freeze_base=True
)

# NOTE(review): compile_transfer_model uses 'categorical_crossentropy';
# presumably the generators emit one-hot labels -- confirm.
model = compile_transfer_model(model, learning_rate=0.001)

# Print the layer-by-layer summary of the assembled model
print("\nModel Architecture:")
model.summary()

# Build the callback list; the checkpoint is written to 'best_transfer_model.keras'
callbacks = create_callbacks('best_transfer_model.keras')

# Train for up to 20 epochs, validating on val_generator. The callbacks include
# EarlyStopping, so training may end before epoch 20. The return value of
# fit() is kept in `history`.
history = model.fit(
    train_generator,
    epochs=20,
    validation_data=val_generator,
    callbacks=callbacks,
    verbose=1
)
