# In [None]:
import tensorflow as tf
import os
import numpy as np
from tensorflow.keras import layers, models, mixed_precision
from PIL import Image
import glob
import matplotlib.pyplot as plt
import time # 用于计时

# --- Initial setup and global parameters ---
print("--- 初始化 TensorFlow 和环境配置 ---")
print("TensorFlow 版本:", tf.__version__)

# 1. Configure GPUs
print("\n--- 配置 GPU ---")
physical_devices = tf.config.list_physical_devices('GPU')
if physical_devices:
    try:
        # Grow GPU memory on demand instead of reserving all VRAM up front.
        # TensorFlow requires the memory-growth setting to be the same on every
        # visible GPU, so apply it to all devices rather than only the first one.
        for gpu_device in physical_devices:
            tf.config.experimental.set_memory_growth(gpu_device, True)
        gpu_names = ", ".join(device.name for device in physical_devices)
        print(f"GPU 可用，内存增长已启用: {gpu_names}")
    except RuntimeError as e:
        # NOTE(review): TF docs say this is raised when GPUs were already initialized.
        print(f"无法设置内存增长: {e}")
else:
    print("未检测到 GPU，将使用 CPU (速度会非常慢)")
print("--- GPU 配置结束 ---")

# 2. Configure mixed precision
print("\n--- 配置混合精度 ---")
policy_name = 'mixed_float16'  # intended for NVIDIA GPUs with compute capability >= 7.0
gpu_supported_fp16 = False
if physical_devices:
    # Probe the first GPU; any failure while probing means "assume no FP16".
    try:
        details = tf.config.experimental.get_device_details(physical_devices[0])
        cc = details.get('compute_capability')
        gpu_supported_fp16 = bool(cc) and cc[0] >= 7
        if not gpu_supported_fp16:
            print(f"GPU Compute Capability ({cc}) 可能不支持 FP16 或效率不高 (需要 >= 7.0)")
        else:
            print(f"检测到支持 FP16 的 GPU (Compute Capability {cc[0]}.{cc[1]})")
    except Exception as e:
        print(f"检查 GPU 能力时出错: {e}，假定不支持 FP16")

if not gpu_supported_fp16:
    # No FP16-capable GPU: pin float32 explicitly so `policy` is always bound.
    print("未检测到支持 FP16 的 GPU 或不满足条件。使用默认 float32 精度。")
    policy = mixed_precision.Policy('float32')
    mixed_precision.set_global_policy(policy)
    print('计算精度 (Compute dtype): float32')
    print('变量精度 (Variable dtype): float32')
else:
    try:
        print(f"尝试启用混合精度策略 '{policy_name}'")
        policy = mixed_precision.Policy(policy_name)
        mixed_precision.set_global_policy(policy)
        print('策略设置成功!')
        print(f'计算精度 (Compute dtype): {policy.compute_dtype}')
        print(f'变量精度 (Variable dtype): {policy.variable_dtype}')
    except Exception as e:
        # Fall back to full precision if the mixed policy cannot be installed.
        print(f"设置混合精度策略 '{policy_name}' 时出错: {e}。将使用 float32。")
        policy = mixed_precision.Policy('float32')
        mixed_precision.set_global_policy(policy)
        print("已回退到 float32 策略。")
print("--- 混合精度配置结束 ---")

# 3. Dataset locations -- adjust base_path to your machine.
base_path = r"D:\Class\CV\Task2\CDSet\dataset_YOLO_format_3434"
# Images and (pre-processed) masks are each split into train/ and test/ sub-folders.
train_img_path, test_img_path = (
    os.path.join(base_path, "images", split) for split in ("train", "test")
)
train_mask_path, test_mask_path = (
    os.path.join(base_path, "masks_crosswalk", split) for split in ("train", "test")
)

# 4. Image and training hyper-parameters
IMG_HEIGHT, IMG_WIDTH = 352, 640  # reduced model input resolution (height, width)
NUM_CLASSES = 2   # background (0) + crosswalk (1)
BATCH_SIZE = 8    # decrease if the GPU runs out of memory
EPOCHS = 5        # number of training epochs

print("\n--- 全局参数设置 ---")
print(f"图像分辨率: {IMG_HEIGHT}x{IMG_WIDTH}")
print(f"批量大小 (Batch Size): {BATCH_SIZE}")
print(f"目标类别数 (含背景): {NUM_CLASSES}")
print(f"训练轮数 (Epochs): {EPOCHS}")
print(f"训练图像路径: {train_img_path}")
print(f"训练掩码路径: {train_mask_path}")
print(f"测试图像路径: {test_img_path}")
print(f"测试掩码路径: {test_mask_path}")
print("--- 参数设置结束 ---")

# --- 辅助函数定义 ---

# 5. 数据生成器 (加载和基础预处理)
def create_dataset(img_dir, mask_dir, batch_size, shuffle=True, is_training=True):
    """Build a batched ``tf.data`` pipeline of (image, mask) pairs.

    Every ``*.jpg`` / ``*.png`` in ``img_dir`` is paired with ``<mask_dir>/<stem>.png``;
    images without a mask are skipped. File lists are sorted so the pairing (and the
    order of an unshuffled dataset) is deterministic across runs and file systems.

    Args:
        img_dir: directory containing the input images.
        mask_dir: directory containing single-channel PNG label masks.
        batch_size: number of samples per batch.
        shuffle: shuffle the file pairs (buffer = whole dataset) before loading.
        is_training: only selects the phase label used in log messages.

    Returns:
        A ``tf.data.Dataset`` of ``(img, mask)`` batches: ``img`` is float32 in [0, 1]
        with shape (B, IMG_HEIGHT, IMG_WIDTH, 3); ``mask`` is uint8 clipped to
        [0, NUM_CLASSES - 1] with shape (B, IMG_HEIGHT, IMG_WIDTH, 1).

    Raises:
        ValueError: if ``img_dir`` holds no images, or none of them has a mask.
    """
    phase = "训练" if is_training else "测试"
    print(f"\n--- 开始准备 {phase} 数据集 ---")
    print(f"正在从 {img_dir} 和 {mask_dir} 匹配图像和掩码...")
    img_files, valid_img_files, mask_files = _pair_images_with_masks(img_dir, mask_dir)
    print(f"找到 {len(img_files)} 张图像，{len(valid_img_files)} 个有效图像-掩码对")

    # Distinguish "wrong image path" from "masks not generated" for a useful error.
    if not img_files:
        raise ValueError(f"在 {img_dir} 中没有找到图像文件 (*.jpg / *.png)。请检查路径！")
    if not valid_img_files:
        raise ValueError(f"在 {mask_dir} 中没有找到有效的掩码文件。请确保已运行预处理脚本！")

    # Load, resize and normalize one pair (no one-hot encoding here).
    def load_and_preprocess(img_path, mask_path):
        img = tf.io.read_file(img_path)
        img = tf.image.decode_image(img, channels=3, expand_animations=False)
        img = tf.image.resize(img, [IMG_HEIGHT, IMG_WIDTH])
        img = tf.cast(img, tf.float32) / 255.0
        img = tf.clip_by_value(img, 0.0, 1.0)
        img = tf.ensure_shape(img, [IMG_HEIGHT, IMG_WIDTH, 3])

        mask = tf.io.read_file(mask_path)
        mask = tf.image.decode_png(mask, channels=1)
        # Nearest-neighbour keeps mask values as discrete class ids.
        mask = tf.image.resize(mask, [IMG_HEIGHT, IMG_WIDTH], method='nearest')
        mask = tf.cast(mask, tf.uint8)
        mask = tf.clip_by_value(mask, 0, NUM_CLASSES - 1)
        mask = tf.ensure_shape(mask, [IMG_HEIGHT, IMG_WIDTH, 1])
        return img, mask

    dataset = tf.data.Dataset.from_tensor_slices((valid_img_files, mask_files))
    if shuffle:
        # Shuffle file paths (cheap) before the expensive decode step.
        dataset = dataset.shuffle(buffer_size=len(valid_img_files))

    dataset = dataset.map(load_and_preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)  # overlap input loading with compute

    print(f"--- {phase} 数据集准备完成 ---")
    return dataset


def _pair_images_with_masks(img_dir, mask_dir):
    """Return ``(img_files, valid_img_files, mask_files)`` for ``create_dataset``.

    ``img_files`` is every sorted ``*.jpg`` / ``*.png`` in ``img_dir``;
    ``valid_img_files`` / ``mask_files`` are the aligned subset that has a
    ``<stem>.png`` mask in ``mask_dir``.
    """
    img_files = sorted(
        glob.glob(os.path.join(img_dir, "*.jpg")) + glob.glob(os.path.join(img_dir, "*.png"))
    )
    valid_img_files, mask_files = [], []
    for img_f in img_files:
        base_name = os.path.splitext(os.path.basename(img_f))[0]
        mask_f = os.path.join(mask_dir, base_name + ".png")
        if os.path.exists(mask_f):
            valid_img_files.append(img_f)
            mask_files.append(mask_f)
    return img_files, valid_img_files, mask_files

# 6. 数据可视化函数 (可视化原始单通道掩码)
def visualize_dataset(dataset, num_samples=3):
    """Plot up to ``num_samples`` (image, raw mask) pairs from the first batch.

    The mask is shown as its single-channel label map (background 0, crosswalk 1).
    Any exception is printed rather than raised, so a plotting problem never aborts
    the caller.
    """
    print("\n--- 可视化部分数据样本 (图像和对应原始掩码) ---")
    try:
        # take(1) yields at most one batch, so this loop body runs at most once.
        for images, masks in dataset.take(1):
            sample_count = min(num_samples, images.shape[0])
            for idx in range(sample_count):
                plt.figure(figsize=(12, 4))

                plt.subplot(1, 2, 1)
                plt.title("Image")
                plt.imshow(images[idx])
                plt.axis('off')

                plt.subplot(1, 2, 2)
                plt.title("Mask (Crosswalk=1, Background=0)")
                label_map = tf.squeeze(masks[idx])  # drop the trailing channel axis
                plt.imshow(label_map, cmap='gray', vmin=0, vmax=NUM_CLASSES - 1)
                plt.colorbar()
                plt.axis('off')
                plt.show()
    except Exception as e:
        print(f"可视化数据时出错: {e}")
        print("请确保数据集已正确加载。")
    print("--- 可视化结束 ---")


# In [None]:
# 7. One-Hot 编码预处理函数 (用于训练)
@tf.function
def preprocess_data_for_training(img, mask):
    """Turn an integer label mask into a float32 one-hot training target.

    The image is returned unchanged. The mask's trailing size-1 channel axis is
    removed and replaced by a ``NUM_CLASSES``-wide one-hot axis, so this works
    for single samples (H, W, 1) as well as batches (B, H, W, 1).
    """
    labels = tf.squeeze(tf.cast(mask, tf.int32), axis=-1)
    one_hot_target = tf.one_hot(labels, depth=NUM_CLASSES, dtype=tf.float32)
    return img, one_hot_target

# 8. U-Net 模型定义
def unet_model(input_shape=(IMG_HEIGHT, IMG_WIDTH, 3), num_classes=NUM_CLASSES):
    """Build a 4-level U-Net that outputs per-pixel class probabilities.

    Args:
        input_shape: (H, W, C) of the input. H and W must be divisible by 16
            (four 2x2 poolings are undone by four stride-2 transposed convs, and
            the skip concatenations need matching spatial sizes).
        num_classes: number of output channels / classes.

    Returns:
        A ``tf.keras.Model`` named ``"unet_crosswalk"`` whose output has shape
        (batch, H, W, num_classes), softmax-normalized over the last axis, float32.
    """
    inputs = tf.keras.Input(shape=input_shape)

    # Encoder: keep each stage's output for the matching decoder skip connection.
    skips = []
    x = inputs
    for filters in (32, 64, 128, 256):
        x = _conv_block(x, filters)
        skips.append(x)
        x = layers.MaxPooling2D((2, 2))(x)

    # Bottleneck
    x = _conv_block(x, 512)

    # Decoder: upsample, concatenate the skip, refine.
    for filters, skip in zip((256, 128, 64, 32), reversed(skips)):
        x = layers.Conv2DTranspose(filters, 2, strides=(2, 2), padding='same')(x)
        x = layers.concatenate([x, skip])
        x = _conv_block(x, filters)

    # Per-pixel logits, then softmax computed in float32. Under the mixed_float16
    # policy a float16 softmax is numerically fragile; the Keras mixed-precision
    # guide recommends running the final softmax itself in float32 (not just
    # casting its float16 result afterwards).
    logits = layers.Conv2D(num_classes, 1)(x)
    outputs = layers.Activation('softmax', dtype='float32')(logits)

    model = models.Model(inputs, outputs, name="unet_crosswalk")
    return model


def _conv_block(x, filters):
    """Apply two 3x3 ReLU convolutions (same padding, He-normal init) to ``x``."""
    for _ in range(2):
        x = layers.Conv2D(filters, 3, activation='relu', padding='same',
                          kernel_initializer='he_normal')(x)
    return x


# In [None]:
# 9. 自定义回调函数 (用于监控训练)
class CustomCallback(tf.keras.callbacks.Callback):
    """Print each epoch's wall-clock duration and its main train/validation metrics."""

    # Log keys shown on the training and validation lines, in display order.
    _TRAIN_KEYS = ('loss', 'accuracy', 'mean_io_u')
    _VAL_KEYS = ('val_loss', 'val_accuracy', 'val_mean_io_u')

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_start_time = time.time()

    def on_epoch_end(self, epoch, logs=None):
        epoch_duration = time.time() - self.epoch_start_time
        print(f"\nEpoch {epoch + 1}/{self.params['epochs']} 完成，耗时: {epoch_duration:.2f} 秒")
        if logs:  # logs may be None or empty
            print(f"  训练集 - {self._format_metrics(logs, self._TRAIN_KEYS)}")
            print(f"  验证集 - {self._format_metrics(logs, self._VAL_KEYS)}")
        else:
            print("  未能获取此 Epoch 的训练指标。")

    @staticmethod
    def _format_metrics(logs, keys):
        """Render ``key: value`` pairs; a missing key shows ``N/A``.

        Formatting ``None`` with ``:.4f`` raises TypeError and would abort ``fit``
        (e.g. when training without validation data), so guard each value.
        """
        parts = []
        for key in keys:
            value = logs.get(key)
            parts.append(f"{key}: {value:.4f}" if value is not None else f"{key}: N/A")
        return ", ".join(parts)

# 10. 训练主函数
def train_model():
    """Run the training pipeline: data, one-hot targets, model, compile, fit, save.

    Returns:
        ``(history, model)``: the Keras ``History`` returned by ``fit`` and the
        trained model.
    """
    # Base datasets yield raw (B, H, W, 1) label masks (no one-hot yet).
    train_dataset_base = create_dataset(train_img_path, train_mask_path, BATCH_SIZE, shuffle=True, is_training=True)
    test_dataset_base = create_dataset(test_img_path, test_mask_path, BATCH_SIZE, shuffle=False, is_training=False)

    visualize_dataset(train_dataset_base, num_samples=2)

    print("\n--- 对数据集应用 One-Hot 编码 ---")
    # The base datasets are already batched, so this map runs per batch; prefetch
    # again so the extra map step overlaps with training.
    train_dataset = (train_dataset_base
                     .map(preprocess_data_for_training, num_parallel_calls=tf.data.AUTOTUNE)
                     .prefetch(tf.data.AUTOTUNE))
    test_dataset = (test_dataset_base
                    .map(preprocess_data_for_training, num_parallel_calls=tf.data.AUTOTUNE)
                    .prefetch(tf.data.AUTOTUNE))
    print("--- One-Hot 编码应用完成 ---")

    # Debug: inspect one batch after one-hot encoding.
    print("\n--- 检查训练数据批次形状 ---")
    for img_batch, mask_batch in train_dataset.take(1):
        print("应用 preprocess_data_for_training 后:")
        print(f"图像批次形状: {img_batch.shape}")  # expected (B, H, W, 3)
        print(f"掩码批次形状: {mask_batch.shape}")  # expected (B, H, W, NUM_CLASSES)
        print(f"掩码数据类型: {mask_batch.dtype}")  # expected float32
        sample_pixel = mask_batch[0, IMG_HEIGHT // 2, IMG_WIDTH // 2, :]
        print(f"掩码 One-Hot 检查 (样本0中心像素): {sample_pixel.numpy()}")
        print(f"掩码 One-Hot 检查 (中心像素类别和): {tf.reduce_sum(sample_pixel).numpy()}")  # expected 1.0
    print("--- 形状检查结束 ---")

    print("\n--- 构建和编译模型 ---")
    model = unet_model(input_shape=(IMG_HEIGHT, IMG_WIDTH, 3), num_classes=NUM_CLASSES)
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
    model.compile(optimizer=optimizer,
                  loss='categorical_crossentropy',  # targets are one-hot
                  # The IoU metric must understand one-hot targets / probability
                  # outputs; it keeps the 'mean_io_u' name used by callbacks/plots.
                  metrics=['accuracy', _build_mean_iou_metric(NUM_CLASSES)])
    model.summary()
    print("--- 模型构建和编译完成 ---")

    print(f"\n--- 开始训练，共 {EPOCHS} 个 Epoch ---")
    start_time = time.time()
    history = model.fit(
        train_dataset,
        epochs=EPOCHS,
        validation_data=test_dataset,
        callbacks=[CustomCallback()]
    )
    total_training_time = time.time() - start_time
    print(f"--- 训练完成，总耗时: {total_training_time:.2f} 秒 ---")

    model_save_path = "unet_crosswalk_model_final_refactored.h5"
    print(f"\n--- 保存模型到 {model_save_path} ---")
    try:
        model.save(model_save_path)
        print("--- 模型保存完成 ---")
    except Exception as e:
        print(f"模型保存失败: {e}")

    return history, model


def _build_mean_iou_metric(num_classes=NUM_CLASSES):
    """Return a MeanIoU metric named 'mean_io_u' for one-hot targets and softmax outputs.

    Plain ``tf.keras.metrics.MeanIoU`` expects integer class labels; given one-hot
    targets and probabilities it flattens them and treats every entry as a label,
    which makes the score meaningless. Prefer the built-in ``OneHotMeanIoU`` when
    this TensorFlow version has it, otherwise fall back to an argmax wrapper.
    """
    one_hot_cls = getattr(tf.keras.metrics, 'OneHotMeanIoU', None)
    if one_hot_cls is not None:
        return one_hot_cls(num_classes=num_classes, name='mean_io_u')
    return _ArgmaxMeanIoU(num_classes=num_classes, name='mean_io_u')


class _ArgmaxMeanIoU(tf.keras.metrics.MeanIoU):
    """MeanIoU that reduces one-hot targets and class scores to labels via argmax."""

    def update_state(self, y_true, y_pred, sample_weight=None):
        return super().update_state(
            tf.argmax(y_true, axis=-1), tf.argmax(y_pred, axis=-1), sample_weight)

# 11. 绘制训练曲线函数
def plot_history(history):
    """Plot accuracy, loss and (if logged) MeanIoU curves from a Keras ``History``.

    The accuracy and loss panels are always drawn (a missing series becomes an
    empty line); the IoU panel is only drawn when both train and validation IoU
    were logged, otherwise the available metric keys are printed.
    """
    print("\n--- 绘制训练和验证曲线 ---")
    curves = history.history
    plt.figure(figsize=(18, 5))

    # (subplot index, train key, val key, train label, val label, title, y label)
    fixed_panels = (
        (1, 'accuracy', 'val_accuracy', 'Train Accuracy', 'Val Accuracy', 'Model Accuracy', 'Accuracy'),
        (2, 'loss', 'val_loss', 'Train Loss', 'Val Loss', 'Model Loss', 'Loss'),
    )
    for position, train_key, val_key, train_label, val_label, title, y_label in fixed_panels:
        plt.subplot(1, 3, position)
        plt.plot(curves.get(train_key, []), label=train_label, marker='o')
        plt.plot(curves.get(val_key, []), label=val_label, marker='x')
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()
        plt.grid(True)

    iou_metric_name = 'mean_io_u'
    val_iou_metric_name = 'val_mean_io_u'
    if iou_metric_name not in curves or val_iou_metric_name not in curves:
        print(f"警告：找不到 MeanIoU 指标 ('{iou_metric_name}' 或 '{val_iou_metric_name}')")
        print("可用指标:", list(curves.keys()))
    else:
        plt.subplot(1, 3, 3)
        plt.plot(curves[iou_metric_name], label='Train MeanIoU', marker='o')
        plt.plot(curves[val_iou_metric_name], label='Val MeanIoU', marker='x')
        plt.title('Model Mean IoU')
        plt.xlabel('Epoch')
        plt.ylabel('Mean IoU')
        plt.legend()
        plt.grid(True)

    plt.tight_layout()
    plt.suptitle('U-Net Training Metrics', fontsize=16, y=1.02)
    plt.show()
    print("--- 曲线绘制完成 ---")

# 12. 测试和可视化预测结果函数
def test_and_visualize(model, num_samples=1):
    """Predict on the test split and show input / ground truth / prediction side by side.

    Args:
        model: trained Keras model; ``argmax`` over its last output axis gives the
            per-pixel class.
        num_samples: number of test images to display; ``<= 0`` displays nothing
            and skips loading data and running the model.
    """
    print("\n--- 开始在测试集上进行预测和可视化 ---")
    if num_samples <= 0:
        print("--- 预测可视化完成 ---")
        return

    # Unshuffled test batches with raw (B, H, W, 1) label masks (no one-hot), so the
    # ground truth can be displayed directly. Iterating stops once enough samples
    # are shown, so only the needed batches are loaded and predicted.
    test_dataset_for_vis = create_dataset(test_img_path, test_mask_path, batch_size=BATCH_SIZE, shuffle=False, is_training=False)

    samples_shown = 0
    try:
        for img_batch, mask_batch_true_label in test_dataset_for_vis:
            predictions_probs = model.predict(img_batch, verbose=0)  # (B, H, W, NUM_CLASSES)
            predictions_labels = tf.argmax(predictions_probs, axis=-1)  # (B, H, W)

            remaining = num_samples - samples_shown
            for i in range(min(img_batch.shape[0], remaining)):
                plt.figure(figsize=(18, 6))

                plt.subplot(1, 3, 1)
                plt.title("Input Image")
                plt.imshow(img_batch[i])
                plt.axis('off')

                plt.subplot(1, 3, 2)
                plt.title("Ground Truth (Crosswalk=1)")
                plt.imshow(tf.squeeze(mask_batch_true_label[i]), cmap='gray', vmin=0, vmax=1)
                plt.axis('off')

                plt.subplot(1, 3, 3)
                plt.title("Prediction (Crosswalk=1)")
                plt.imshow(predictions_labels[i], cmap='gray', vmin=0, vmax=1)
                plt.axis('off')

                plt.tight_layout()
                plt.show()
                samples_shown += 1

            if samples_shown >= num_samples:
                break
    except Exception as e:
        print(f"测试和可视化过程中出错: {e}")
    print("--- 预测可视化完成 ---")


# In [None]:
# --- 主执行流程 ---
if __name__ == "__main__":  # also true when this cell runs inside a notebook kernel
    history, trained_model = train_model()

    # Training curves
    if not history:
        print("未能获取训练历史，无法绘制曲线。")
    else:
        plot_history(history)

    # Qualitative check on a few test images
    if not trained_model:
        print("模型训练失败或未返回，无法进行测试可视化。")
    else:
        test_and_visualize(trained_model, num_samples=3)

    print("\n--- 脚本执行完毕 ---")

# 应用到自己的图像（预处理时会自动缩放为 352x640 以适配模型，无需手动调整尺寸）

# In [None]:
import tensorflow as tf
from tensorflow.keras import models, mixed_precision
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os
import glob
import time


# Model input geometry -- must match the values the model was trained with.
IMG_HEIGHT, IMG_WIDTH = 352, 640
NUM_CLASSES = 2
# NOTE(review): the training cell saves "unet_crosswalk_model_final_refactored.h5";
# point MODEL_PATH at whichever file you actually trained.
MODEL_PATH = "unet_crosswalk_model_final.h5"

INPUT_IMAGE_DIR = r"D:\Class\CV\Task2\CampusZebra_new"    # images to segment
OUTPUT_MASK_DIR = r"D:\Class\CV\Task2\CampusZebra_result"  # where predicted masks go
NUM_SAMPLES_TO_SHOW = 2  # how many (image, mask) pairs to plot

def preprocess_single_image_for_prediction(image_path, target_height=IMG_HEIGHT, target_width=IMG_WIDTH):
    """Load one image file and turn it into a model-ready batch of size 1.

    Args:
        image_path: path to an image readable by PIL.
        target_height, target_width: size the image is resized to.

    Returns:
        float32 array of shape (1, target_height, target_width, 3) with values in
        [0, 1], or ``None`` if the file could not be processed (the error is printed).
    """
    try:
        # Use a context manager so the file handle is released deterministically;
        # convert() returns an independent in-memory copy.
        with Image.open(image_path) as src:
            img = src.convert('RGB')
        # NOTE(review): training resized with tf.image.resize (default bilinear);
        # LANCZOS here can give slightly different pixels for non-352x640 inputs.
        img_resized = img.resize((target_width, target_height), Image.Resampling.LANCZOS)
        img_normalized = np.asarray(img_resized, dtype=np.float32) / 255.0
        img_normalized = np.clip(img_normalized, 0.0, 1.0)
        return np.expand_dims(img_normalized, axis=0)
    except Exception as e:
        # Best-effort: a bad file is reported and skipped by the caller.
        print(f"Error processing {image_path}: {e}")
        return None

def predict_and_save_masks(model, input_dir, output_dir):
    """Segment every ``*.jpg`` / ``*.png`` in ``input_dir`` and save the masks.

    Each mask is written to ``<output_dir>/<stem>_pred_mask.png`` as an 8-bit
    grayscale image in which the predicted class id is multiplied by 255
    (assumes a binary model: background 0 -> 0, crosswalk 1 -> 255).

    Args:
        model: Keras model producing (1, H, W, NUM_CLASSES) class scores.
        input_dir: directory with the images to segment.
        output_dir: destination directory; created if missing.
    """
    print(f"\n--- 开始批量预测 (可视化前) ---")
    os.makedirs(output_dir, exist_ok=True)

    # Sorted for a reproducible processing order.
    image_files = sorted(glob.glob(os.path.join(input_dir, "*.jpg")) + glob.glob(os.path.join(input_dir, "*.png")))
    if not image_files:
        print(f"错误：在 {input_dir} 未找到图像。")
        return

    print(f"找到 {len(image_files)} 张图像进行预测...")
    for img_path in image_files:
        img_batch = preprocess_single_image_for_prediction(img_path)
        if img_batch is None:
            continue

        predictions_probs = model.predict(img_batch, verbose=0)
        predicted_mask_labels = np.argmax(predictions_probs[0], axis=-1)  # (H, W) class ids
        output_mask_array = (predicted_mask_labels * 255).astype(np.uint8)
        # A 2-D uint8 array maps to mode 'L' automatically (the mode= argument is deprecated).
        output_mask_image = Image.fromarray(output_mask_array)

        base_name = os.path.splitext(os.path.basename(img_path))[0]
        output_filename = base_name + "_pred_mask.png"
        output_path = os.path.join(output_dir, output_filename)
        try:
            output_mask_image.save(output_path)
        except Exception as e:
            print(f"保存掩码 {output_filename} 出错: {e}")
    print(f"--- 批量预测完成，掩码保存在: {output_dir} ---")


# ---可视化函数 ---
def visualize_predictions(input_dir, predicted_mask_dir, num_samples=NUM_SAMPLES_TO_SHOW):
    """Show original images next to their saved ``<stem>_pred_mask.png`` masks.

    Args:
        input_dir: directory with the original ``*.jpg`` / ``*.png`` images.
        predicted_mask_dir: directory with the predicted mask files.
        num_samples: maximum number of pairs to display (images without a mask
            are skipped and do not count).
    """
    print(f"\n--- 开始可视化预测结果 ---")
    print(f"原始图像目录: {input_dir}")
    print(f"预测掩码目录: {predicted_mask_dir}")

    if not os.path.isdir(input_dir):
        print(f"错误：找不到输入图像目录 {input_dir}")
        return
    if not os.path.isdir(predicted_mask_dir):
        print(f"错误：找不到预测掩码目录 {predicted_mask_dir}。请先运行预测。")
        return

    # Sorted so the same samples are shown on every run.
    image_files = sorted(glob.glob(os.path.join(input_dir, "*.jpg")) + glob.glob(os.path.join(input_dir, "*.png")))
    if not image_files:
        print(f"错误：在 {input_dir} 中未找到用于可视化的图像。")
        return

    print(f"将尝试可视化最多 {num_samples} 个样本...")
    samples_shown = 0
    for img_path in image_files:
        if samples_shown >= num_samples:
            break

        base_name = os.path.splitext(os.path.basename(img_path))[0]
        predicted_mask_filename = base_name + "_pred_mask.png"
        predicted_mask_path = os.path.join(predicted_mask_dir, predicted_mask_filename)

        if not os.path.exists(predicted_mask_path):
            print(f"警告：找不到图像 '{os.path.basename(img_path)}' 对应的预测掩码 '{predicted_mask_filename}'，跳过。")
            continue

        try:
            # Context managers release the file handles; convert() yields in-memory copies.
            with Image.open(img_path) as src:
                original_image = src.convert('RGB')
            with Image.open(predicted_mask_path) as src:
                predicted_mask_image = src.convert('L')

            plt.figure(figsize=(12, 6))

            plt.subplot(1, 2, 1)
            plt.title(f"Original Image: {os.path.basename(img_path)}")
            plt.imshow(original_image)
            plt.axis('off')

            plt.subplot(1, 2, 2)
            plt.title(f"Predicted Mask: {predicted_mask_filename}")
            plt.imshow(predicted_mask_image, cmap='gray', vmin=0, vmax=255)
            plt.axis('off')

            plt.tight_layout()
            plt.show()
            samples_shown += 1

        except Exception as e:
            print(f"处理或显示图像 {os.path.basename(img_path)} 时出错: {e}")

    if samples_shown == 0:
        print("未能成功可视化任何样本。请检查路径和文件是否存在。")
    print(f"--- 可视化完成，共显示 {samples_shown} 个样本 ---")


# --- 主执行部分 ---
if __name__ == "__main__":
    # Guard against running with placeholder paths.
    if "path/to/your" in INPUT_IMAGE_DIR or "path/to/save" in OUTPUT_MASK_DIR:
        print("错误：请先在代码中修改 INPUT_IMAGE_DIR 和 OUTPUT_MASK_DIR 为你的实际路径！")
    else:
        # Set to True to (re)generate the predicted mask files before visualizing.
        run_prediction = False

        if run_prediction:
            # Load the model only when it is actually needed for prediction.
            loaded_model = None
            if not os.path.exists(MODEL_PATH):
                print(f"错误：找不到模型文件 {MODEL_PATH}。无法执行预测。")
            else:
                try:
                    loaded_model = models.load_model(MODEL_PATH, compile=False)
                    print("模型加载成功！")
                except Exception as e:
                    print(f"加载模型时出错: {e}")

            if loaded_model is not None:
                predict_and_save_masks(loaded_model, INPUT_IMAGE_DIR, OUTPUT_MASK_DIR)
            else:
                print("模型未加载，无法执行预测。")
        else:
            print("跳过预测步骤，直接进行可视化（假设掩码已存在）。")

        # Visualize whatever masks exist, freshly predicted or not.
        visualize_predictions(INPUT_IMAGE_DIR, OUTPUT_MASK_DIR, num_samples=NUM_SAMPLES_TO_SHOW)

    print("\n--- 脚本执行完毕 ---")