In [8]:
import tensorflow as tf
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.applications.efficientnet import preprocess_input
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras import layers
import os.path
from tensorflow import data as tf_data

In [None]:
IMAGE_SIZE = 224  # Input size for EfficientNetB0
MODEL_FILE = "model_eff.h5"
train = tf.keras.preprocessing.image_dataset_from_directory(
    "split_data/train",
    image_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=32,
    label_mode='int',
    shuffle=True
)
val = tf.keras.preprocessing.image_dataset_from_directory(
    "split_data/val",
    image_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=32,
    label_mode='int',
    shuffle=False
)

class_names = train.class_names # ['Pallas_cats', 'Persian_cats', 'Ragdolls', 'Singapura_cats', 'Sphynx_cats']
AUTOTUNE = tf.data.AUTOTUNE

augmentation = tf.keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.2),
    layers.RandomZoom(0.2),
    layers.RandomContrast(0.2)
])


def preprocess(x, y, train):
    x = tf.cast(x, tf.float32) 
    x = preprocess_input(x)  # EfficientNetB0 specific preprocessing
    if train:
        x = augmentation(x)  # Apply data augmentation
    return x, y

Found 3566 files belonging to 5 classes.
Found 894 files belonging to 5 classes.


In [10]:
train_generator = train.map(lambda x, y: preprocess(x, y, True)).prefetch(buffer_size=AUTOTUNE)
val_generator = val.map(lambda x, y: preprocess(x, y, False)).prefetch(buffer_size=AUTOTUNE)

In [11]:
def create_model(num_classes):
    base_model = EfficientNetB0(
        weights='imagenet',
        include_top=False,
        input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3)
    )
    base_model.trainable = False  # 初始冻结基础模型
    
    inputs = tf.keras.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 3))
    x = preprocess_input(inputs)  # 关键：使用专用预处理
    x = base_model(x, training=False)
    x = GlobalAveragePooling2D()(x)
    x = Dropout(0.2)(x)
    x = Dense(128, activation='relu')(x)
    x = Dropout(0.2)(x)
    outputs = Dense(num_classes, activation='softmax')(x)
    
    model = Model(inputs, outputs)
    return model

In [12]:
def load_existing(model_file):
    model = load_model(model_file)
    # 解冻最后4个块进行微调
    for layer in model.layers:
        if isinstance(layer, tf.keras.Model):  # 找到基础模型
            base_model = layer
            break
    
    if base_model:
        # 解冻最后4个块
        for layer in base_model.layers[-20:]:
            if not isinstance(layer, tf.keras.layers.BatchNormalization):
                layer.trainable = True
    return model


In [13]:
def train(model_file, train_path, validation_path, num_classes=5, steps=100, num_epochs=20):
    if os.path.exists(model_file):
        print("\n*** Loading existing model ***\n")
        model = load_existing(model_file)
        # 必须重新编译
        model.compile(
            optimizer=Adam(learning_rate=1e-4),
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy']
        )
    else:
        print("\n*** Creating new model ***\n")
        model = create_model(num_classes)
        model.compile(
            optimizer=Adam(learning_rate=1e-3),
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy']
        )

    checkpoint = ModelCheckpoint(
        model_file,
        save_best_only=True,
        monitor='val_accuracy',
        mode='max'
    )

    # 第一阶段：训练新添加的层
    print("=== Phase 1: Training Head ===")
    history = model.fit(
        train_generator,
        steps_per_epoch=steps,
        epochs=num_epochs,
        callbacks=[checkpoint],
        validation_data=val_generator,
        validation_steps=20
    )

    # 第二阶段：微调
    print("\n=== Phase 2: Fine-Tuning ===")
    # 找到基础模型并解冻部分层
    for layer in model.layers:
        if isinstance(layer, tf.keras.Model):
            base_model = layer
            break
    
    if base_model:
        base_model.trainable = True
        # 解冻最后4个块 (EfficientNetB0有7个块，解冻block5b到block7a)
        for layer in base_model.layers:
            layer.trainable = False  # 先冻结所有
            
        # 解冻最后部分层
        for layer in base_model.layers[-20:]:
            if not isinstance(layer, tf.keras.layers.BatchNormalization):
                layer.trainable = True

    # 使用更小的学习率
    model.compile(
        optimizer=Adam(learning_rate=1e-5),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

    model.fit(
        train_generator,
        steps_per_epoch=steps,
        epochs=num_epochs,
        callbacks=[checkpoint],
        validation_data=val_generator,
        validation_steps=20
    )


In [14]:
def main():
    train(
        MODEL_FILE,
        train_path="split_data/train",
        validation_path="split_data/val",
        steps=100,  # 根据数据集大小调整
        num_epochs=15
    )

if __name__ == '__main__':
    main()


*** Creating new model ***

=== Phase 1: Training Head ===
Epoch 1/15
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 207ms/step - accuracy: 0.7340 - loss: 0.7134



[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m29s[0m 249ms/step - accuracy: 0.7349 - loss: 0.7110 - val_accuracy: 0.9328 - val_loss: 0.2106
Epoch 2/15
[1m 12/100[0m [32m━━[0m[37m━━━━━━━━━━━━━━━━━━[0m [1m16s[0m 187ms/step - accuracy: 0.8855 - loss: 0.3330



[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 52ms/step - accuracy: 0.8829 - loss: 0.3426 - val_accuracy: 0.9328 - val_loss: 0.2113
Epoch 3/15
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m24s[0m 236ms/step - accuracy: 0.9021 - loss: 0.2570 - val_accuracy: 0.9297 - val_loss: 0.1981
Epoch 4/15
[1m 12/100[0m [32m━━[0m[37m━━━━━━━━━━━━━━━━━━[0m [1m16s[0m 182ms/step - accuracy: 0.8940 - loss: 0.2981



[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 52ms/step - accuracy: 0.9007 - loss: 0.2819 - val_accuracy: 0.9422 - val_loss: 0.1786
Epoch 5/15
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m24s[0m 238ms/step - accuracy: 0.9194 - loss: 0.2213 - val_accuracy: 0.9328 - val_loss: 0.1961
Epoch 6/15
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 50ms/step - accuracy: 0.9188 - loss: 0.2393 - val_accuracy: 0.9328 - val_loss: 0.2002
Epoch 7/15
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m24s[0m 236ms/step - accuracy: 0.9158 - loss: 0.2182 - val_accuracy: 0.9406 - val_loss: 0.1763
Epoch 8/15
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 53ms/step - accuracy: 0.8838 - loss: 0.2706 - val_accuracy: 0.9391 - val_loss: 0.1923
Epoch 9/15
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m25s[0m 245ms/step - accuracy: 0.9223 - loss: 0.2123 - val_accuracy: 0.9375 - val_loss: 0.1916
Epoch 10/15
[1m 12/100[0m



[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 53ms/step - accuracy: 0.8960 - loss: 0.2319 - val_accuracy: 0.9484 - val_loss: 0.1658
Epoch 11/15
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m24s[0m 243ms/step - accuracy: 0.9244 - loss: 0.1820 - val_accuracy: 0.9406 - val_loss: 0.1837
Epoch 12/15
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 52ms/step - accuracy: 0.9126 - loss: 0.2609 - val_accuracy: 0.9391 - val_loss: 0.2079
Epoch 13/15
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m24s[0m 242ms/step - accuracy: 0.9374 - loss: 0.1780 - val_accuracy: 0.9344 - val_loss: 0.1972
Epoch 14/15
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 52ms/step - accuracy: 0.9139 - loss: 0.2365 - val_accuracy: 0.9453 - val_loss: 0.1752
Epoch 15/15
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m24s[0m 242ms/step - accuracy: 0.9506 - loss: 0.1521 - val_accuracy: 0.9344 - val_loss: 0.2167

=== Phase 2: Fine-Tun



[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 55ms/step - accuracy: 0.9458 - loss: 0.1431 - val_accuracy: 0.9516 - val_loss: 0.1573
Epoch 3/15
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 220ms/step - accuracy: 0.9462 - loss: 0.1530



[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m25s[0m 253ms/step - accuracy: 0.9462 - loss: 0.1529 - val_accuracy: 0.9547 - val_loss: 0.1544
Epoch 4/15
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 53ms/step - accuracy: 0.9487 - loss: 0.1343 - val_accuracy: 0.9453 - val_loss: 0.1551
Epoch 5/15
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m25s[0m 251ms/step - accuracy: 0.9533 - loss: 0.1263 - val_accuracy: 0.9516 - val_loss: 0.1587
Epoch 6/15
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 53ms/step - accuracy: 0.9587 - loss: 0.1532 - val_accuracy: 0.9531 - val_loss: 0.1605
Epoch 7/15
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 255ms/step - accuracy: 0.9442 - loss: 0.1389 - val_accuracy: 0.9547 - val_loss: 0.1475
Epoch 8/15
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 53ms/step - accuracy: 0.9427 - loss: 0.1498 - val_accuracy: 0.9531 - val_loss: 0.1701
Epoch 9/15
[1m100/100[0m 