In [None]:
# Jupyter Notebook - 代码

# 导入必要的库
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf
import seaborn as sns
from tqdm import tqdm
from sklearn.metrics import confusion_matrix
import tensorflow_model_optimization as tfmot  # TF Model Optimization Toolkit; used below for quantization-aware training (QAT)
quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer  # marks an individual layer for QAT
quantize_apply = tfmot.quantization.keras.quantize_apply  # converts an annotated model into the QAT model that is trained below
import datetime
# Silence TensorFlow's Python-side log messages below ERROR level
tf.get_logger().setLevel('ERROR')

# 🔹 Hyperparameters / global configuration
IMG_SIZE = (96, 96)          # (height, width) every image is resized to
AUTOTUNE = tf.data.AUTOTUNE  # let tf.data pick parallelism / prefetch depth

# Seed Python, NumPy and TensorFlow RNGs so weight initialisation, shuffling and
# random augmentation are repeatable across "Restart & Run All".
SEED = 12
tf.keras.utils.set_random_seed(SEED)

In [None]:
# 🔹 Data augmentation: random geometric + photometric jitter, applied to
# un-normalised pixels (see preprocess_image_Au).
data_augmentation = tf.keras.Sequential()
data_augmentation.add(tf.keras.layers.RandomFlip(mode="horizontal"))
# factor is a fraction of a full turn: ±0.125 * 360° = ±45°
data_augmentation.add(
    tf.keras.layers.RandomRotation(factor=(-0.125, 0.125), fill_mode="nearest"))
data_augmentation.add(
    tf.keras.layers.RandomZoom(height_factor=0.25, fill_mode="nearest"))
data_augmentation.add(
    tf.keras.layers.RandomTranslation(height_factor=0.25, width_factor=0.25))
data_augmentation.add(tf.keras.layers.RandomBrightness(factor=0.25))
data_augmentation.add(tf.keras.layers.RandomContrast(factor=0.3))

def preprocess_image(image, label):
    """Convert an image (batch) to float32 and rescale it by 1/255.

    Assumes pixel values in [0, 255] (what image_dataset_from_directory yields),
    so the result lies in [0, 1]. The label is passed through untouched.
    """
    scaled = tf.cast(image, tf.float32) / 255.0
    return scaled, label

def preprocess_image_Au(image, label):
    """Randomly augment an image (batch), then rescale it to [0, 1] float32.

    Augmentation runs on the un-normalised pixel values because
    RandomBrightness uses value_range=(0, 255) by default.

    Args:
        image: image tensor, presumably with pixel values in [0, 255].
        label: label tensor, returned unchanged.

    Returns:
        Tuple ``(augmented_scaled_image, label)``.
    """
    image = tf.cast(image, tf.float32)
    # Keras random preprocessing layers are no-ops in inference mode. Outside of
    # fit() (here: inside tf.data.map) the mode would otherwise be inferred from
    # framework defaults, so request training behaviour explicitly.
    image = data_augmentation(image, training=True)
    image = image / 255.0
    return image, label

# -------------------------------- 第一阶段训练 ---------------------------------

In [None]:
# 🔹 Stage 1 dataset paths (train and validation come from the same folder)
base_dir = './dataset'
train_dir = os.path.join(base_dir, 'DATA_1')
valid_dir = os.path.join(base_dir, 'DATA_1')

BATCH_SIZE = 128
# Fraction of the images held out for validation. Both subsets MUST use the same
# split and seed: "training" takes the first (1 - split) of the shuffled files and
# "validation" the last `split`. A mismatched pair (0.8 / 0.2) would train on only
# 20% of the data and silently ignore 60% of it.
VALIDATION_SPLIT = 0.2

# 🔹 Load the images (labels inferred from sub-directory names)
train_dataset_raw = tf.keras.preprocessing.image_dataset_from_directory(
    train_dir,
    validation_split=VALIDATION_SPLIT,
    subset="training",
    seed=12,
    batch_size=BATCH_SIZE,
    image_size=IMG_SIZE)

validation_dataset_raw = tf.keras.preprocessing.image_dataset_from_directory(
    valid_dir,
    validation_split=VALIDATION_SPLIT,
    subset="validation",
    seed=12,
    batch_size=BATCH_SIZE,
    image_size=IMG_SIZE)

class_names = train_dataset_raw.class_names
print("Class Names:", class_names)

# A tf.data file cache is read back whenever its files already exist, even if the
# source data changed. Give this stage its own cache prefix and delete leftovers
# (including lockfiles from interrupted runs) so stage 1 really reads DATA_1.
CACHE_DIR = './cache'
os.makedirs(CACHE_DIR, exist_ok=True)
train_cache = os.path.join(CACHE_DIR, 'stage1_train')
val_cache = os.path.join(CACHE_DIR, 'stage1_val')
stale_cache_files = tf.io.gfile.glob(train_cache + '*') + tf.io.gfile.glob(val_cache + '*')
for stale_cache_file in stale_cache_files:
    tf.io.gfile.remove(stale_cache_file)

# Normalise -> cache to file -> shuffle (whole batches) -> prefetch
train_dataset = (train_dataset_raw
                 .map(preprocess_image, num_parallel_calls=AUTOTUNE)
                 .cache(train_cache)
                 .shuffle(1000)
                 .prefetch(AUTOTUNE))

validation_dataset = (validation_dataset_raw
                      .map(preprocess_image, num_parallel_calls=AUTOTUNE)
                      .cache(val_cache)
                      .prefetch(AUTOTUNE))


In [None]:
IMG_SHAPE = IMG_SIZE + (3,)  # (height, width, 3 colour channels)

# ImageNet-pretrained MobileNetV2 backbone (width multiplier 0.35) without its
# classifier; pooling='avg' makes it output one feature vector per image.
# NOTE(review): Keras' MobileNetV2 weights expect inputs scaled to [-1, 1]
# (mobilenet_v2.preprocess_input), while this notebook feeds [0, 1].
# Fine-tuning can adapt to this, but confirm it is intentional.
base_model = tf.keras.applications.MobileNetV2(
    input_shape=IMG_SHAPE,
    include_top=False,
    pooling='avg',
    alpha=0.35,
    weights='imagenet')

# Only the softmax classification head is annotated for quantization-aware
# training. The backbone stays float here and is quantized later by the TFLite
# converter using the representative dataset.
annotated_model = tf.keras.Sequential([
    base_model,
    quantize_annotate_layer(tf.keras.layers.Dense(len(class_names), activation='softmax'))
])

# Insert the fake-quantization ops for the annotated layers
qat_model = quantize_apply(annotated_model)
# Derive the build shape from IMG_SHAPE instead of hard-coding 96x96x3, so
# changing IMG_SIZE in the config cell cannot desynchronise it.
qat_model.build((None,) + IMG_SHAPE)
qat_model.summary()

In [None]:
# Compile with Adam and an exponentially decaying (staircase) learning rate:
# the rate is multiplied by 0.9 every 1000 optimizer steps.
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-5,
    decay_steps=1000,
    decay_rate=0.9,
    staircase=True,
)
qat_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=lr_schedule),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),  # head already applies softmax
    metrics=['accuracy'],
)

# Stage 1 training: stop after 2 epochs without improvement and roll back to the best weights
early_stopping = tf.keras.callbacks.EarlyStopping(
    patience=2,
    restore_best_weights=True,
)
history_stage1 = qat_model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=10,
    callbacks=[early_stopping],
)

# Persist the stage-1 model
qat_model.save('./model/stage1_model.h5')
     

# -------------------------------- 第二阶段训练 ---------------------------------

In [None]:
# 🔹 Stage 2 dataset paths (train and validation come from the same folder)
base_dir = './dataset'
train_dir = os.path.join(base_dir, 'DATA_2')
valid_dir = os.path.join(base_dir, 'DATA_2')

BATCH_SIZE = 128
# Fraction of the images held out for validation. Both subsets MUST use the same
# split and seed: "training" takes the first (1 - split) of the shuffled files and
# "validation" the last `split`. A mismatched pair (0.8 / 0.2) would train on only
# 20% of the data and silently ignore 60% of it.
VALIDATION_SPLIT = 0.2

# 🔹 Load the images (labels inferred from sub-directory names)
train_dataset_raw = tf.keras.preprocessing.image_dataset_from_directory(
    train_dir,
    validation_split=VALIDATION_SPLIT,
    subset="training",
    seed=12,
    batch_size=BATCH_SIZE,
    image_size=IMG_SIZE)

validation_dataset_raw = tf.keras.preprocessing.image_dataset_from_directory(
    valid_dir,
    validation_split=VALIDATION_SPLIT,
    subset="validation",
    seed=12,
    batch_size=BATCH_SIZE,
    image_size=IMG_SIZE)

class_names = train_dataset_raw.class_names
print("Class Names:", class_names)

# A tf.data file cache is read back whenever its files already exist, even if the
# source data changed. Sharing one cache name across stages would make this stage
# train on stage-1 data. Use a stage-specific prefix and clear leftovers.
CACHE_DIR = './cache'
os.makedirs(CACHE_DIR, exist_ok=True)
train_cache = os.path.join(CACHE_DIR, 'stage2_train')
val_cache = os.path.join(CACHE_DIR, 'stage2_val')
stale_cache_files = tf.io.gfile.glob(train_cache + '*') + tf.io.gfile.glob(val_cache + '*')
for stale_cache_file in stale_cache_files:
    tf.io.gfile.remove(stale_cache_file)

# Normalise -> cache to file -> shuffle (whole batches) -> prefetch
train_dataset = (train_dataset_raw
                 .map(preprocess_image, num_parallel_calls=AUTOTUNE)
                 .cache(train_cache)
                 .shuffle(1000)
                 .prefetch(AUTOTUNE))

validation_dataset = (validation_dataset_raw
                      .map(preprocess_image, num_parallel_calls=AUTOTUNE)
                      .cache(val_cache)
                      .prefetch(AUTOTUNE))


In [None]:
# Stage 2: continue fine-tuning the stage-1 weights on DATA_2.
# Guard: DATA_2 must have as many classes as the classification head predicts.
num_model_classes = qat_model.output.shape[-1]
assert len(class_names) == num_model_classes, (
    f"DATA_2 has {len(class_names)} classes but the model predicts {num_model_classes}")

# Recompile with a fresh Adam optimizer and an exponentially decaying learning rate
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.00001, decay_steps=1000, decay_rate=0.90, staircase=True)

qat_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr_schedule),
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                  metrics=['accuracy'])

# Train stage 2; stop after 2 epochs without improvement and keep the best weights
early_stopping = tf.keras.callbacks.EarlyStopping(patience=2, restore_best_weights=True)

history_stage2 = qat_model.fit(train_dataset,
                               validation_data=validation_dataset,
                               epochs=10,
                               callbacks=[early_stopping])

# Save the stage-2 model
qat_model.save('./model/stage2_model.h5')

# -------------------------------- 第三阶段训练 ---------------------------------

In [None]:
# 🔹 Stage 3 dataset paths: all data, with augmentation for training
base_dir = './dataset'
train_dir = os.path.join(base_dir, 'DATA_ALL')
valid_dir = os.path.join(base_dir, 'DATA_ALL')

BATCH_SIZE = 128
# Fraction of the images held out for validation. Both subsets MUST use the same
# split and seed: "training" takes the first (1 - split) of the shuffled files and
# "validation" the last `split`. A mismatched pair (0.8 / 0.2) would train on only
# 20% of the data and silently ignore 60% of it.
VALIDATION_SPLIT = 0.2

# 🔹 Load the images (labels inferred from sub-directory names)
train_dataset_raw = tf.keras.preprocessing.image_dataset_from_directory(
    train_dir,
    validation_split=VALIDATION_SPLIT,
    subset="training",
    seed=12,
    batch_size=BATCH_SIZE,
    image_size=IMG_SIZE)

validation_dataset_raw = tf.keras.preprocessing.image_dataset_from_directory(
    valid_dir,
    validation_split=VALIDATION_SPLIT,
    subset="validation",
    seed=12,
    batch_size=BATCH_SIZE,
    image_size=IMG_SIZE)

class_names = train_dataset_raw.class_names
print("Class Names:", class_names)

# A tf.data file cache is read back whenever its files already exist, even if the
# source data changed. Sharing one cache name across stages would make this stage
# train on stage-1 data. Use a stage-specific prefix and clear leftovers.
CACHE_DIR = './cache'
os.makedirs(CACHE_DIR, exist_ok=True)
train_cache = os.path.join(CACHE_DIR, 'stage3_train')
val_cache = os.path.join(CACHE_DIR, 'stage3_val')
stale_cache_files = tf.io.gfile.glob(train_cache + '*') + tf.io.gfile.glob(val_cache + '*')
for stale_cache_file in stale_cache_files:
    tf.io.gfile.remove(stale_cache_file)

# Training: cache the *un-augmented* images, then augment after the cache so that
# every epoch sees fresh random transforms. Augmenting before .cache() would
# freeze the first epoch's augmentations for the rest of training.
train_dataset = (train_dataset_raw
                 .cache(train_cache)
                 .shuffle(1000)
                 .map(preprocess_image_Au, num_parallel_calls=AUTOTUNE)
                 .prefetch(AUTOTUNE))

# Validation must stay un-augmented so metrics (and the confusion matrix) reflect real inputs
validation_dataset = (validation_dataset_raw
                      .map(preprocess_image, num_parallel_calls=AUTOTUNE)
                      .cache(val_cache)
                      .prefetch(AUTOTUNE))


In [None]:
# Stage 3: final fine-tuning on DATA_ALL with data augmentation.
# Guard: DATA_ALL must have as many classes as the classification head predicts.
num_model_classes = qat_model.output.shape[-1]
assert len(class_names) == num_model_classes, (
    f"DATA_ALL has {len(class_names)} classes but the model predicts {num_model_classes}")

# Recompile with a fresh Adam optimizer and an exponentially decaying learning rate
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.00001, decay_steps=1000, decay_rate=0.90, staircase=True)

qat_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr_schedule),
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                  metrics=['accuracy'])

# Train stage 3; stop after 2 epochs without improvement and keep the best weights
early_stopping = tf.keras.callbacks.EarlyStopping(patience=2, restore_best_weights=True)

history_stage3 = qat_model.fit(train_dataset,
                               validation_data=validation_dataset,
                               epochs=10,
                               callbacks=[early_stopping])

# Save the stage-3 (final) model
qat_model.save('./model/stage3_model.h5')

In [None]:
# Reload the final stage-3 quantization-aware-training model (no pruning is
# applied in this notebook). Models built with quantize_apply contain tfmot
# wrapper layers that Keras can only deserialize inside quantize_scope().
with tfmot.quantization.keras.quantize_scope():
    model = tf.keras.models.load_model('./model/stage3_model.h5')

def representative_dataset():
    """Yield calibration inputs for full-integer TFLite quantization.

    Calibration samples must be preprocessed exactly like the training data.
    The model was trained on images rescaled to [0, 1] by ``preprocess_image``,
    but ``validation_dataset_raw`` still holds raw [0, 255] pixels. Feeding those
    raw pixels would calibrate the input quantization range for [0, 255]
    instead of [0, 1].

    Yields:
        A one-element list holding a float32 image batch.
    """
    for image_batch, label_batch in tqdm(validation_dataset_raw.take(500), desc="Processing"):
        scaled_batch, _ = preprocess_image(image_batch, label_batch)
        yield [scaled_batch]

converter = tf.lite.TFLiteConverter.from_keras_model(model)
# Enable default optimizations; activation ranges are calibrated on real samples
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
# Only int8 builtin kernels are allowed: conversion errors out instead of falling back to float ops
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# The converted model takes and returns uint8 tensors
converter.inference_input_type = converter.inference_output_type = tf.uint8

tflite_model_quant = converter.convert()

# ---- 4. Write the TFLite model to a file whose name carries the current time ----
timestamp = format(datetime.datetime.now(), "%Y%m%d_%H%M")
output_tflite_path = './model/model_{}.tflite'.format(timestamp)  # e.g. model_20240101_1230.tflite

with open(output_tflite_path, mode='wb') as tflite_file:
    tflite_file.write(tflite_model_quant)

# Remove the intermediate *.h5 stage checkpoints now that the TFLite model is written.
# The stage-3 model is already loaded in `model`, so later cells still work.
target_dir = "model"
for file in os.listdir(target_dir):
    file_path = os.path.join(target_dir, file)
    # Only regular .h5 files; skip anything else (e.g. a directory named *.h5)
    if file.endswith(".h5") and os.path.isfile(file_path):
        try:
            os.remove(file_path)
            print(f"已删除: {file_path}")
        except OSError as e:
            # Best-effort cleanup: report the failure and continue with the remaining files
            print(f"删除失败 [{file_path}]: {e}")

In [None]:
# Confusion matrix of the reloaded Keras model on the validation set.
# NOTE(review): this evaluates the float Keras QAT model, not the int8 TFLite file.
# Predictions and labels are collected in the SAME pass over the dataset, so they
# are guaranteed to line up regardless of any shuffling in the input pipeline.
true_label_batches, pred_label_batches = [], []
for images, labels in validation_dataset:
    probabilities = model.predict_on_batch(images)
    pred_label_batches.append(np.argmax(probabilities, axis=1))
    true_label_batches.append(labels.numpy())
y_true = np.concatenate(true_label_batches)
y_pred = np.concatenate(pred_label_batches)

# Fix the label set so the matrix is always len(class_names) x len(class_names),
# even when some class never occurs in y_true / y_pred.
cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(class_names)))

fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(cm, annot=True, cmap="Blues", fmt="d",
            xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set(xlabel="Predicted Label", ylabel="True Label", title="Confusion Matrix")
plt.show()

In [None]:
from librariy import polt_improved

# Labels for the three training stages. Stage 3 is full-data training with
# augmentation (no pruning is performed in this notebook).
# NOTE(review): stage_names is not passed to the plotting helper; check whether
# plot_combined_curves_improved accepts stage labels.
stage_names = ["Stage 1", "Stage 2", "Stage 3 (Augmented)"]
history_list = [history_stage1, history_stage2, history_stage3]
polt_improved.plot_combined_curves_improved(history_list)