In [None]:
import os.path

import tensorflow as tf
from tensorflow import data as tf_data
from tensorflow.keras import layers
from tensorflow.keras.applications.efficientnet import EfficientNetB0, preprocess_input
from tensorflow.keras.applications.efficientnet_v2 import EfficientNetV2B3
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [None]:
# --- Configuration ------------------------------------------------------------
IMAGE_SIZE = 224  # Input size for EfficientNetB0 (images are resized to IMAGE_SIZE x IMAGE_SIZE)
MODEL_FILE = "model_eff.h5"  # path handed to train() by main()

# --- Datasets -----------------------------------------------------------------
# Both splits are loaded with identical settings (batches of 32, integer
# labels); only the training split is shuffled.
# NOTE: the name `train` is rebound to a function by the later `def train(...)`
# cell, so this cell (and the one building the generators) must run first.
train, val = (
    tf.keras.preprocessing.image_dataset_from_directory(
        split_dir,
        image_size=(IMAGE_SIZE, IMAGE_SIZE),
        batch_size=32,
        label_mode='int',
        shuffle=shuffle_split,
    )
    for split_dir, shuffle_split in (("split_data/train", True), ("split_data/val", False))
)

# Class names as inferred from the training directory, e.g.
# ['Pallas_cats', 'Persian_cats', 'Ragdolls', 'Singapura_cats', 'Sphynx_cats']
class_names = train.class_names
AUTOTUNE = tf.data.AUTOTUNE

# --- Augmentation -------------------------------------------------------------
# Random flip / rotation / zoom / contrast, applied in this order.
augmentation = tf.keras.Sequential()
augmentation.add(layers.RandomFlip("horizontal"))
augmentation.add(layers.RandomRotation(0.2))
augmentation.add(layers.RandomZoom(0.2))
augmentation.add(layers.RandomContrast(0.2))


def preprocess(x, y, train):
    """Prepare one ``(images, labels)`` element of the dataset for the model.

    Args:
        x: Image tensor produced by the dataset pipeline.
        y: Labels; returned unchanged.
        train: Truthy to additionally run the random ``augmentation`` pipeline
            (used for the training split only).

    Returns:
        Tuple ``(x, y)`` where ``x`` has been cast to float32, passed through
        ``preprocess_input`` and, when ``train`` is truthy, augmented.
    """
    x = tf.cast(x, tf.float32)
    x = preprocess_input(x)  # EfficientNet-specific preprocessing
    if train:
        # This call runs inside Dataset.map, i.e. outside Model.fit, so Keras
        # has no training context to infer from. Without an explicit
        # training=True the random layers may run in inference mode and
        # return the images unchanged.
        x = augmentation(x, training=True)
    return x, y

In [None]:
# Build the input pipelines consumed by train(). Preprocessing (and, for the
# training split, augmentation) is mapped in parallel so it does not become a
# sequential bottleneck; element order is preserved. Prefetching overlaps
# input preparation with model execution.
train_generator = (
    train
    .map(lambda x, y: preprocess(x, y, True), num_parallel_calls=AUTOTUNE)
    .prefetch(buffer_size=AUTOTUNE)
)
val_generator = (
    val
    .map(lambda x, y: preprocess(x, y, False), num_parallel_calls=AUTOTUNE)
    .prefetch(buffer_size=AUTOTUNE)
)

In [None]:
def create_model(num_classes, dropout_rate=0.2, hidden_units=128):
    """Build an EfficientNetB0-based classifier with a frozen backbone.

    The ImageNet-pretrained backbone (without its classification top) is
    followed by global average pooling, dropout, a ReLU dense layer, dropout
    again and a softmax output layer.

    Args:
        num_classes: Number of units in the softmax output layer.
        dropout_rate: Rate used by both Dropout layers.
        hidden_units: Width of the intermediate Dense layer.

    Returns:
        An uncompiled ``Model`` taking ``(IMAGE_SIZE, IMAGE_SIZE, 3)`` inputs.
    """
    base_model = EfficientNetB0(
        weights='imagenet',
        include_top=False,
        input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3)
    )
    base_model.trainable = False  # start with the backbone frozen

    inputs = tf.keras.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 3))
    # Use the backbone's dedicated preprocessing.
    # NOTE(review): the tf.data pipeline also calls preprocess_input; Keras
    # documents it as a pass-through for EfficientNet, so this should be
    # harmless - confirm for the installed version.
    x = preprocess_input(inputs)
    # training=False keeps the backbone in inference mode inside this model.
    x = base_model(x, training=False)
    x = GlobalAveragePooling2D()(x)
    x = Dropout(dropout_rate)(x)
    x = Dense(hidden_units, activation='relu')(x)
    x = Dropout(dropout_rate)(x)
    outputs = Dense(num_classes, activation='softmax')(x)

    model = Model(inputs, outputs)
    return model

In [None]:
def load_existing(model_file):
    """Load a saved model and make the top of its backbone trainable.

    The backbone is taken to be the first top-level layer that is itself a
    Keras ``Model``. Only its last 20 layers, excluding BatchNormalization
    layers, are left trainable. If no nested model is found, the loaded model
    is returned as-is.

    Args:
        model_file: Path passed to ``load_model``.

    Returns:
        The loaded model. The caller must compile it again for the changed
        ``trainable`` flags to take effect.
    """
    model = load_model(model_file)

    # Default to None so a model without a nested backbone does not raise
    # UnboundLocalError below.
    base_model = next(
        (layer for layer in model.layers if isinstance(layer, tf.keras.Model)),
        None,
    )
    if base_model is None:
        return model

    # A container with trainable=False reports no trainable weights whatever
    # its sub-layers say, so enable the container first, then freeze every
    # layer and re-enable only the tail.
    base_model.trainable = True
    for layer in base_model.layers:
        layer.trainable = False
    for layer in base_model.layers[-20:]:
        if not isinstance(layer, tf.keras.layers.BatchNormalization):
            layer.trainable = True
    return model


In [None]:
def _find_base_model(model):
    """Return the first top-level layer of ``model`` that is itself a Keras Model, else None."""
    return next(
        (layer for layer in model.layers if isinstance(layer, tf.keras.Model)),
        None,
    )


def _unfreeze_top_layers(base_model, num_layers=20):
    """Leave only the last ``num_layers`` non-BatchNormalization layers of ``base_model`` trainable."""
    base_model.trainable = True
    for layer in base_model.layers:
        layer.trainable = False  # freeze everything first
    for layer in base_model.layers[-num_layers:]:
        if not isinstance(layer, tf.keras.layers.BatchNormalization):
            layer.trainable = True


def _compile_classifier(model, learning_rate):
    """Compile ``model`` with Adam, sparse categorical cross-entropy and accuracy."""
    model.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )


def train(model_file, train_path, validation_path, num_classes=5, steps=100, num_epochs=20,
          validation_steps=None):
    """Train (or resume) the classifier in two phases and checkpoint the best model.

    Phase 1 fits the model as created or loaded; phase 2 unfreezes the last
    20 non-BatchNormalization backbone layers and fits again with a smaller
    learning rate. Data comes from the module-level ``train_generator`` and
    ``val_generator``.

    Args:
        model_file: Checkpoint path. If it exists the model is loaded from it
            via ``load_existing``; otherwise a new model is created. The best
            model by ``val_accuracy`` is saved back to this path.
        train_path: Not read by this function; kept for interface
            compatibility.
        validation_path: Not read by this function; kept for interface
            compatibility.
        num_classes: Output classes for a newly created model.
        steps: ``steps_per_epoch`` for both phases. ``None`` means one full
            pass over the training data per epoch.
        num_epochs: Epochs per phase.
        validation_steps: Validation batches per evaluation. The default
            ``None`` evaluates the whole validation set. The validation
            dataset is built unshuffled, so a fixed prefix of batches would
            only score the first files in directory order.
    """
    if os.path.exists(model_file):
        print("\n*** Loading existing model ***\n")
        model = load_existing(model_file)
        # The model must be compiled again after its trainable flags changed.
        _compile_classifier(model, 1e-4)
    else:
        print("\n*** Creating new model ***\n")
        model = create_model(num_classes)
        _compile_classifier(model, 1e-3)

    checkpoint = ModelCheckpoint(
        model_file,
        save_best_only=True,
        monitor='val_accuracy',
        mode='max'
    )

    # With an explicit steps_per_epoch Keras keeps drawing from one iterator
    # across epochs and stops training early when a finite dataset runs out.
    # Repeating the dataset guarantees `steps` batches are always available.
    train_data = train_generator.repeat() if steps is not None else train_generator
    fit_kwargs = dict(
        steps_per_epoch=steps,
        epochs=num_epochs,
        callbacks=[checkpoint],
        validation_data=val_generator,
        validation_steps=validation_steps,
    )

    # Phase 1: train the model in its initial freeze state.
    print("=== Phase 1: Training Head ===")
    model.fit(train_data, **fit_kwargs)

    # Phase 2: fine-tune the top of the backbone.
    print("\n=== Phase 2: Fine-Tuning ===")
    base_model = _find_base_model(model)
    if base_model is not None:
        _unfreeze_top_layers(base_model)

    # Use a smaller learning rate for fine-tuning.
    _compile_classifier(model, 1e-5)
    model.fit(train_data, **fit_kwargs)


In [None]:
def main():
    """Run training on the split_data directories using MODEL_FILE as the checkpoint."""
    run_options = dict(
        train_path="split_data/train",
        validation_path="split_data/val",
        steps=100,  # adjust to the size of the dataset
        num_epochs=15,
    )
    train(MODEL_FILE, **run_options)


# Start training when this file is executed directly.
if __name__ == '__main__':
    main()