<a href="https://colab.research.google.com/github/TAKE-JP-17/Pytorch/blob/main/water_detection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
from tensorflow.keras import backend as K

# precision
def precision_m(y_true, y_pred):
    """Return precision computed over the whole batch.

    Both the product ``y_true * y_pred`` and ``y_pred`` are clipped to
    [0, 1] and rounded before being summed, so the counts come from
    hard 0/1 decisions. The Keras epsilon is added to the denominator
    to avoid dividing by zero when nothing is predicted positive.
    """
    kb = tf.keras.backend
    hits = kb.sum(kb.round(kb.clip(y_true * y_pred, 0, 1)))
    flagged = kb.sum(kb.round(kb.clip(y_pred, 0, 1)))
    return hits / (flagged + kb.epsilon())

# recall
def recall_m(y_true, y_pred):
    """Return recall computed over the whole batch.

    True positives are the rounded, clipped product ``y_true * y_pred``;
    the denominator is the rounded, clipped sum of ``y_true`` plus the
    Keras epsilon (guards against a batch with no positive labels).
    """
    kb = tf.keras.backend
    hits = kb.sum(kb.round(kb.clip(y_true * y_pred, 0, 1)))
    actual = kb.sum(kb.round(kb.clip(y_true, 0, 1)))
    return hits / (actual + kb.epsilon())

# f1 score
def f1_m(y_true, y_pred):
    """Return the F1 score: harmonic mean of ``precision_m`` and ``recall_m``.

    The Keras epsilon in the denominator keeps the result finite when
    both precision and recall are zero.
    """
    p = precision_m(y_true, y_pred)
    r = recall_m(y_true, y_pred)
    denominator = p + r + tf.keras.backend.epsilon()
    return 2 * ((p * r) / denominator)

def dsc(y_true, y_pred):
    """Return the smoothed Dice similarity coefficient.

    Both tensors are flattened and the coefficient is computed as
    ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)`` with
    ``smooth = 1``, so two all-zero inputs score 1 rather than 0/0.
    """
    kb = tf.keras.backend
    smooth = 1.
    truth = kb.flatten(y_true)
    guess = kb.flatten(y_pred)
    overlap = kb.sum(truth * guess)
    numerator = 2. * overlap + smooth
    denominator = kb.sum(truth) + kb.sum(guess) + smooth
    return numerator / denominator

def dice_loss(y_true, y_pred):
    """Return the Dice loss, i.e. one minus the Dice coefficient ``dsc``."""
    return 1 - dsc(y_true, y_pred)

In [None]:
from keras import backend as K
from keras.models import *
from keras.layers import *
from keras.optimizers import *
from keras.losses import *
from tensorflow.keras.layers import UpSampling2D, multiply

# Attention gating block.
#   x           -- feature map from the encoder (skip) path; the spatially larger input.
#   g           -- gating feature map from the decoder path; the spatially smaller input.
#   inter_shape -- channel count of the intermediate projections of x and g.
#   activation  -- activation applied to the summed projections (callers in this file pass 'relu').
#   name        -- suffix appended to the named layers so each block's layers stay distinct.
def AttnGatingBlock(x, g, inter_shape, activation, name):
    ''' Weight the skip features x by attention coefficients derived from g.

    Steps: project x with a strided 2x2 conv and g with a 1x1 conv to
    inter_shape channels, upsample the projected g to the projected x's
    spatial size, add them, apply `activation`, reduce to one channel with a
    1x1 conv, squash with a sigmoid, upsample that map back to x's spatial
    size, multiply it with x, then apply a 1x1 conv and batch normalization.

    Spatial sizes are read from the static tensor shapes, so they must be
    known (not None) when the block is built.
    '''

    skip_shape = x.shape
    gate_shape = g.shape

    # Project x: the stride-2 conv downsamples it and sets the channel count to inter_shape.
    theta = Conv2D(inter_shape, (2, 2), strides=(2, 2), padding='same', name='xl'+name)(x)
    theta_shape = theta.shape

    # Project g to inter_shape channels, then use a transposed conv whose stride
    # is the integer size ratio so it lands on theta's spatial size.
    phi = Conv2D(inter_shape, (1, 1), padding='same')(g)
    gate_strides = (theta_shape[1] // gate_shape[1], theta_shape[2] // gate_shape[2])
    gate_up = Conv2DTranspose(inter_shape, (3, 3), strides=gate_strides, padding='same', name='g_up'+name)(phi)

    # Element-wise sum of the two projections (an add, not a concatenation),
    # followed by the requested activation.
    summed = add([gate_up, theta])
    activated = Activation(activation)(summed)

    # Collapse to a single channel and squash to (0, 1): the attention coefficients.
    psi = Conv2D(1, (1, 1), padding='same', name='psi'+name)(activated)
    coeffs = Activation('sigmoid')(psi)
    coeffs_shape = coeffs.shape

    # Bring the coefficient map back to x's spatial size.
    up_factor = (skip_shape[1] // coeffs_shape[1], skip_shape[2] // coeffs_shape[2])
    coeffs_up = UpSampling2D(size=up_factor)(coeffs)

    # The one-channel map is not tiled to x's channel count here; the multiply
    # below presumably relies on broadcasting over the channel axis -- TODO confirm.
    gated = multiply([coeffs_up, x], name='q_attn' + name)

    # 1x1 conv producing as many channels as x, then batch normalization.
    projected = Conv2D(skip_shape[3], (1, 1), padding='same',name='q_attn_conv'+name)(gated)
    return BatchNormalization(name='q_attn_bn'+name)(projected)

# UnetConv2D: the basic U-Net convolution block -- two 3x3 convolutions, each
# followed by optional batch normalization and an activation.
def UnetConv2D(input, outdim, is_batchnorm, activation, name):
    """Apply two (3x3 conv -> optional BN -> activation) stages to ``input``.

    Args:
        input: Input tensor.
        outdim: Number of filters for both convolutions.
        is_batchnorm: When true, add a BatchNormalization layer after each conv.
        activation: Activation applied after each stage. Previously this
            argument was ignored and 'relu' was always used; every caller in
            this file passes 'relu', so their behavior is unchanged.
        name: Prefix for the layer names (``<name>_1``, ``<name>_1_bn``,
            ``<name>_1_act``, and the same with ``_2``).

    Returns:
        The output tensor of the second stage.

    Convolution kernels are initialized with the module-level ``kinit``.
    """
    x = input
    for stage in ('1', '2'):
        x = Conv2D(outdim, (3, 3), strides=(1, 1), kernel_initializer=kinit,
                   padding="same", name=f'{name}_{stage}')(x)
        if is_batchnorm:
            x = BatchNormalization(name=f'{name}_{stage}_bn')(x)
        x = Activation(activation, name=f'{name}_{stage}_act')(x)
    return x

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation

def UnetGatingSignal(input, is_batchnorm, name):
    ''' Build the gating signal: 1x1 convolution, optional batch norm, ReLU.

    The convolution keeps the channel count of `input` (read from axis 3 of
    its static shape) and uses the module-level `kinit` initializer. Layers
    are named `<name>_conv`, `<name>_bn` and `<name>_act`.
    '''
    channels = input.shape[3]
    gate = Conv2D(channels, (1, 1), strides=(1, 1), padding="same",
                  kernel_initializer=kinit, name=name + '_conv')(input)
    if is_batchnorm:
        gate = BatchNormalization(name=name + '_bn')(gate)
    return Activation('relu', name=name + '_act')(gate)

K.set_image_data_format('channels_last')  # use channels-last dimension ordering for the code in this file
kinit = 'glorot_normal' # kernel initializer passed as kernel_initializer by the layer builders in this file

def attn_unet(lr, loss_func=None, pretrained_weights=None,input_size = (256,256,3)):
    """Build and compile a two-level attention U-Net with a 1-channel sigmoid output.

    Args:
        lr: Learning rate for the Adam optimizer.
        loss_func: Loss passed straight to ``model.compile``.
        pretrained_weights: Optional path; when truthy, weights are loaded
            from it after compilation.
        input_size: Shape of one input sample.

    Returns:
        The compiled model, tracking accuracy, f1_m, precision_m, recall_m and dsc.
    """
    net_in = Input(shape=input_size)

    # Contracting path: two conv blocks, each followed by 2x2 max pooling.
    enc1 = UnetConv2D(net_in, 32, is_batchnorm=True, activation='relu', name='conv1')
    down1 = MaxPooling2D(pool_size=(2, 2))(enc1)

    enc2 = UnetConv2D(down1, 64, is_batchnorm=True, activation='relu', name='conv2')
    # The dropout layer is named 'drop_conv3' although it follows the 'conv2' block.
    enc2 = Dropout(0.1,name='drop_conv3')(enc2)
    down2 = MaxPooling2D(pool_size=(2, 2))(enc2)

    bottleneck = UnetConv2D(down2,128,is_batchnorm=True, activation='relu', name='center')

    # Expansion path: gate the enc2 skip connection with a signal from the bottleneck.
    gate = UnetGatingSignal(bottleneck, is_batchnorm=True, name='g1')
    gated_skip = AttnGatingBlock(enc2, gate, 128, activation='relu', name='_1')
    deconv1 = Conv2DTranspose(64, (3,3), strides=(2,2), padding='same', activation='relu',
                              kernel_initializer=kinit, name='convt1')(bottleneck)
    merged1 = concatenate([deconv1, gated_skip], name='up1')

    # The second upsampling stage concatenates the ungated enc1 features.
    deconv2 = Conv2DTranspose(64, (3,3), strides=(2,2), padding='same', activation='relu',
                              kernel_initializer=kinit, name='convt2')(merged1)
    merged2 = concatenate([deconv2, enc1], name='up2')
    mask = Conv2D(1, (1, 1), activation='sigmoid',  kernel_initializer=kinit, name='final')(merged2)

    model = Model(net_in, mask)

    tracked = ['accuracy', f1_m, precision_m, recall_m, dsc]
    model.compile(optimizer=Adam(learning_rate=lr), loss=loss_func, metrics=tracked)

    if pretrained_weights:
        model.load_weights(pretrained_weights)

    return model

# Build the network with Adam at lr=0.001 and binary cross-entropy loss, then print its layer summary.
model = attn_unet(0.001, loss_func='binary_crossentropy', input_size=(256,256,3))
model.summary()

In [None]:
import os
import shutil
import cv2
import numpy as np
import rasterio
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ModelCheckpoint
import matplotlib.pyplot as plt
import random

# Dataset directory paths
original_train_dir = 'C:/Users/User/gary/wd_data_t2/train'
augmented_train_dir = 'C:/Users/User/gary/wd_data_t2/new_train'
val_dir = 'C:/Users/User/gary/wd_data_t2/val'

# Create the Image/Label sub-folders for the new training and validation sets
for _base_dir in (augmented_train_dir, val_dir):
    for _sub_dir in ('Image', 'Label'):
        os.makedirs(os.path.join(_base_dir, _sub_dir), exist_ok=True)

# **Step 1: split the dataset**
def split_dataset(train_dir, val_dir, val_ratio=0.1, out_train_dir=None):
    """
    Copy matched image/label pairs from ``train_dir`` into a validation set
    and a training set.

    Images are read from ``train_dir/Image`` and labels from
    ``train_dir/Label``. A pair is matched when the file names are equal once
    the ``image_tile_`` / ``label_tile_`` text is removed. Files without a
    partner are ignored. Source files are copied, never moved.

    Args:
        train_dir: Source directory containing ``Image`` and ``Label``.
        val_dir: Destination for the validation pairs.
        val_ratio: Fraction of pairs sent to validation (rounded down).
        out_train_dir: Destination for the remaining pairs. Defaults to the
            module-level ``augmented_train_dir`` so existing callers keep
            their behavior.

    Raises:
        AssertionError: If no image/label pair can be matched.
    """
    if out_train_dir is None:
        out_train_dir = augmented_train_dir

    image_dir = os.path.join(train_dir, 'Image')
    label_dir = os.path.join(train_dir, 'Label')

    # Index both folders by the part of the name they share.
    image_dict = {img.replace('image_tile_', ''): img for img in sorted(os.listdir(image_dir))}
    label_dict = {lbl.replace('label_tile_', ''): lbl for lbl in sorted(os.listdir(label_dir))}

    # Sort the keys: set iteration order varies between processes, which would
    # make the split irreproducible even with a fixed random seed.
    common_keys = sorted(image_dict.keys() & label_dict.keys())
    if not common_keys:
        # Raised explicitly (rather than via `assert`) so the check survives `python -O`.
        raise AssertionError("無法找到影像與標籤的匹配項！")

    paired_files = [(image_dict[key], label_dict[key]) for key in common_keys]

    # Pick the validation indices; a set gives O(1) membership tests below.
    val_size = int(len(paired_files) * val_ratio)
    val_indices = set(random.sample(range(len(paired_files)), val_size))

    # Make sure every destination folder exists before copying.
    for dest_root in (val_dir, out_train_dir):
        for sub_dir in ('Image', 'Label'):
            os.makedirs(os.path.join(dest_root, sub_dir), exist_ok=True)

    for idx, (image_file, label_file) in enumerate(paired_files):
        dest_root = val_dir if idx in val_indices else out_train_dir
        shutil.copy(os.path.join(image_dir, image_file),
                    os.path.join(dest_root, 'Image', image_file))
        shutil.copy(os.path.join(label_dir, label_file),
                    os.path.join(dest_root, 'Label', label_file))

    print(f"分割完成：訓練集 {len(paired_files) - val_size} 個，驗證集 {val_size} 個。")



# **Step 2: data augmentation**
def _warp_pair(image, label, matrix, size):
    """Apply one affine matrix to image (bilinear) and label (nearest), reflecting at the borders."""
    image = cv2.warpAffine(image, matrix, size, flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT)
    label = cv2.warpAffine(label.squeeze(), matrix, size, flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_REFLECT)
    return image, label[..., None]  # restore the single channel axis


def augment_image_and_label(image, label, augmentations):
    """Apply the same random geometric transforms to an image and its label.

    Args:
        image: Array whose shape unpacks as (h, w, channels).
        label: Label array for the same tile; it is squeezed before each
            OpenCV call and returned with a trailing single-channel axis.
        augmentations: Dict of settings. The keys read here are
            'rotation_range' (degrees), 'width_shift_range' and
            'height_shift_range' (fractions of w / h), 'zoom_range',
            'horizontal_flip' and 'vertical_flip'. Other keys are ignored.

    Returns:
        (image, label) after augmentation. Labels always use nearest-neighbour
        interpolation so class values are never blended.
    """
    h, w, _ = image.shape

    # Random rotation about the image centre
    if 'rotation_range' in augmentations:
        angle = np.random.uniform(-augmentations['rotation_range'], augmentations['rotation_range'])
        M = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
        image, label = _warp_pair(image, label, M, (w, h))

    # Random translation
    if 'width_shift_range' in augmentations or 'height_shift_range' in augmentations:
        max_dx = int(augmentations.get('width_shift_range', 0) * w)
        max_dy = int(augmentations.get('height_shift_range', 0) * h)
        dx = np.random.randint(-max_dx, max_dx + 1)
        dy = np.random.randint(-max_dy, max_dy + 1)
        M = np.float32([[1, 0, dx], [0, 1, dy]])
        image, label = _warp_pair(image, label, M, (w, h))

    # Random zoom
    if 'zoom_range' in augmentations:
        zoom = np.random.uniform(1 - augmentations['zoom_range'], 1 + augmentations['zoom_range'])
        new_w, new_h = int(w * zoom), int(h * zoom)
        image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
        label = cv2.resize(label.squeeze(), (new_w, new_h), interpolation=cv2.INTER_NEAREST)[..., None]

        if zoom > 1:  # zoomed in: centre-crop back to (h, w)
            start_x = (new_w - w) // 2
            start_y = (new_h - h) // 2
            image = image[start_y:start_y + h, start_x:start_x + w]
            label = label[start_y:start_y + h, start_x:start_x + w]
        else:  # zoomed out: pad back to (h, w)
            # Put the odd leftover pixel on the right/bottom. Padding the same
            # floor-divided amount on both sides would leave the result one
            # pixel short whenever the size difference is odd.
            pad_left = (w - new_w) // 2
            pad_right = (w - new_w) - pad_left
            pad_top = (h - new_h) // 2
            pad_bottom = (h - new_h) - pad_top
            image = cv2.copyMakeBorder(image, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_REFLECT)
            label = cv2.copyMakeBorder(label, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_REFLECT)

    # Random horizontal flip
    if augmentations.get('horizontal_flip', False) and np.random.rand() > 0.5:
        image = cv2.flip(image, 1)
        label = cv2.flip(label, 1)

    # Random vertical flip
    if augmentations.get('vertical_flip', False) and np.random.rand() > 0.5:
        image = cv2.flip(image, 0)
        label = cv2.flip(label, 0)

    # Normalize the label back to a single trailing channel, whatever the
    # OpenCV calls above did to that axis.
    label = label.squeeze()[..., None]

    return image, label


def augment_data(data_dir, augmentations, target_size=None, augment_count=5):
    """
    Write augmented copies of every image/label pair found in ``data_dir``.

    Images are read from ``data_dir/Image`` and labels from ``data_dir/Label``.
    The two sorted file lists are paired by position. For each pair,
    ``augment_count`` augmented versions are written next to the originals as
    ``<name>_aug_<i>.tif``.

    Files whose name already contains ``_aug_`` are skipped as inputs.
    Without this, running the function a second time would augment the
    previous run's outputs and compound the dataset. This assumes original
    tiles never contain ``_aug_`` in their names -- TODO confirm.

    Args:
        data_dir: Directory containing ``Image`` and ``Label`` sub-folders.
        augmentations: Settings dict forwarded to ``augment_image_and_label``.
        target_size: Optional size passed to ``cv2.resize`` before augmenting.
        augment_count: Number of augmented copies per pair.
    """
    image_dir = os.path.join(data_dir, 'Image')
    label_dir = os.path.join(data_dir, 'Label')
    image_files = sorted(f for f in os.listdir(image_dir) if '_aug_' not in f)
    label_files = sorted(f for f in os.listdir(label_dir) if '_aug_' not in f)

    for image_file, label_file in zip(image_files, label_files):
        with rasterio.open(os.path.join(image_dir, image_file)) as src:
            img = src.read()
            img = np.transpose(img, (1, 2, 0))  # bands-first -> bands-last

        with rasterio.open(os.path.join(label_dir, label_file)) as src:
            lbl = src.read(1)[..., None]  # first band only, kept as a single channel

        if target_size:
            img = cv2.resize(img, target_size, interpolation=cv2.INTER_LINEAR)
            lbl = cv2.resize(lbl.squeeze(), target_size, interpolation=cv2.INTER_NEAREST)[..., None]

        for i in range(augment_count):
            aug_img, aug_lbl = augment_image_and_label(img, lbl, augmentations)

            # Save the augmented image and label
            output_image_path = os.path.join(image_dir, f"{image_file.replace('.tif', '')}_aug_{i}.tif")
            output_label_path = os.path.join(label_dir, f"{label_file.replace('.tif', '')}_aug_{i}.tif")

            with rasterio.open(output_image_path, 'w', driver='GTiff',
                               height=aug_img.shape[0], width=aug_img.shape[1],
                               count=aug_img.shape[2], dtype=aug_img.dtype) as dst:
                dst.write(np.transpose(aug_img, (2, 0, 1)))  # back to bands-first

            with rasterio.open(output_label_path, 'w', driver='GTiff',
                               height=aug_lbl.shape[0], width=aug_lbl.shape[1],
                               count=1,  # always a single band
                               dtype=aug_lbl.dtype) as dst:
                dst.write(aug_lbl.squeeze(), 1)

    print("資料增強完成！")


# Split the dataset, then augment the training split
split_dataset(original_train_dir, val_dir)
# NOTE: augment_image_and_label, as defined in this file, reads only
# rotation_range, width_shift_range, height_shift_range, zoom_range,
# horizontal_flip and vertical_flip. 'shear_range' and 'fill_mode' are not
# read by it.
augmentations = {
    'rotation_range': 30,
    'width_shift_range': 0.2,
    'height_shift_range': 0.2,
    'shear_range': 0.2,
    'zoom_range': 0.2,
    'horizontal_flip': True,
    'vertical_flip': True,
    'fill_mode': 'reflect'
}
augment_data(augmented_train_dir, augmentations, augment_count=5)

# **Step 3: define the data generator**
def data_generator_from_dir(data_dir, batch_size, target_size=(256, 256)):
    """Yield ``(images, labels)`` float32 batches from ``data_dir`` forever.

    Images come from ``data_dir/Image`` and labels from ``data_dir/Label``.
    The two sorted file lists are paired by position. Each image is min-max
    scaled to [0, 1] (a constant image becomes all zeros) and resized
    bilinearly. Each label is read from the first band, resized with
    nearest-neighbour interpolation, given a trailing channel axis and divided
    by 255 (presumably masks are stored as 0/255 -- TODO confirm).

    Args:
        data_dir: Directory containing ``Image`` and ``Label`` sub-folders.
        batch_size: Maximum number of samples per batch; the last batch of a
            pass may be smaller.
        target_size: Size passed to ``cv2.resize`` as ``dsize``, i.e.
            (width, height). Defaults to the previously hard-coded (256, 256).

    Raises:
        ValueError: On first iteration, if either sub-folder is empty. An
            empty image folder would otherwise make the endless loop spin
            without ever yielding.
    """
    image_dir = os.path.join(data_dir, 'Image')
    label_dir = os.path.join(data_dir, 'Label')

    image_files = sorted(os.listdir(image_dir))
    label_files = sorted(os.listdir(label_dir))

    if not image_files or not label_files:
        raise ValueError(f"No image/label files found under {data_dir!r}")

    while True:
        for start in range(0, len(image_files), batch_size):
            stop = start + batch_size
            images = []
            labels = []

            for img_file, label_file in zip(image_files[start:stop], label_files[start:stop]):
                with rasterio.open(os.path.join(image_dir, img_file)) as src:
                    img = np.transpose(src.read(), (1, 2, 0))  # bands-first -> bands-last

                # Per-image min-max normalization, guarding against a flat image.
                img_min, img_max = img.min(), img.max()
                if img_max > img_min:
                    img = (img - img_min) / (img_max - img_min)
                else:
                    img = np.zeros_like(img)
                img = cv2.resize(img, target_size, interpolation=cv2.INTER_LINEAR)

                with rasterio.open(os.path.join(label_dir, label_file)) as src:
                    label = src.read(1)  # first band only

                label = cv2.resize(label, target_size, interpolation=cv2.INTER_NEAREST)
                label = label[..., None] / 255  # keep a single channel axis

                images.append(img.astype(np.float32))
                labels.append(label.astype(np.float32))

            yield (np.array(images, dtype=np.float32),
                   np.array(labels, dtype=np.float32))




# Training and validation data generators
batch_size = 4
train_generator = data_generator_from_dir(augmented_train_dir, batch_size=batch_size)
val_generator = data_generator_from_dir(val_dir, batch_size=batch_size)
# Pull one batch to check the array shapes. This advances train_generator by
# one batch before training starts.
sample_batch = next(train_generator)
print(f"Sample batch images shape: {sample_batch[0].shape}")
print(f"Sample batch labels shape: {sample_batch[1].shape}")


# Steps per epoch use floor division, so a trailing partial batch is not counted.
steps_per_epoch = len(os.listdir(os.path.join(augmented_train_dir, 'Image'))) // batch_size
validation_steps = len(os.listdir(os.path.join(val_dir, 'Image'))) // batch_size

# Checkpoint callback: keep only the model with the lowest validation loss
checkpoint = ModelCheckpoint('C:/Users/User/gary/space/01_model/wd_t2_e50_att_unet_model_checkpoint.keras', save_best_only=True, monitor='val_loss', mode='min')


# Train the model
history = model.fit(
    train_generator,
    steps_per_epoch=steps_per_epoch,
    validation_data=val_generator,
    validation_steps=validation_steps,
    epochs=50,
    verbose=1,
    callbacks=[checkpoint]
)

# **Step 4: plot the training curves**
plt.figure(figsize=(10, 6))
plt.plot(history.history['loss'], 'bo-', label='Training Loss')
plt.plot(history.history['val_loss'], 'ro-', label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
import rasterio
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, cohen_kappa_score

import warnings
# Silence rasterio's NotGeoreferencedWarning for the raster files opened in this cell.
warnings.filterwarnings('ignore', category=rasterio.errors.NotGeoreferencedWarning)

# Batch-loading data generator
def data_generator_from_dir(data_dir, batch_size, target_size=(256, 256)):
    """Yield ``(images, labels)`` float32 batches from ``data_dir`` forever.

    Images come from ``data_dir/Image`` and labels from ``data_dir/Label``.
    The two sorted file lists are paired by position. Each image is min-max
    scaled to [0, 1] (a constant image becomes all zeros) and resized
    bilinearly. Each label is read from the first band, resized with
    nearest-neighbour interpolation, given a trailing channel axis and divided
    by 255 (presumably masks are stored as 0/255 -- TODO confirm).

    Args:
        data_dir: Directory containing ``Image`` and ``Label`` sub-folders.
        batch_size: Maximum number of samples per batch; the last batch of a
            pass may be smaller.
        target_size: Size passed to ``cv2.resize`` as ``dsize``, i.e.
            (width, height). Defaults to the previously hard-coded (256, 256).

    Raises:
        ValueError: On first iteration, if either sub-folder is empty. An
            empty image folder would otherwise make the endless loop spin
            without ever yielding.
    """
    # This cell's import list does not include cv2, so import it here. That
    # keeps the cell runnable in a fresh kernel.
    import cv2

    image_dir = os.path.join(data_dir, 'Image')
    label_dir = os.path.join(data_dir, 'Label')

    image_files = sorted(os.listdir(image_dir))
    label_files = sorted(os.listdir(label_dir))

    if not image_files or not label_files:
        raise ValueError(f"No image/label files found under {data_dir!r}")

    while True:
        for start in range(0, len(image_files), batch_size):
            stop = start + batch_size
            images = []
            labels = []

            for img_file, label_file in zip(image_files[start:stop], label_files[start:stop]):
                with rasterio.open(os.path.join(image_dir, img_file)) as src:
                    img = np.transpose(src.read(), (1, 2, 0))  # bands-first -> bands-last

                # Per-image min-max normalization, guarding against a flat image.
                img_min, img_max = img.min(), img.max()
                if img_max > img_min:
                    img = (img - img_min) / (img_max - img_min)
                else:
                    img = np.zeros_like(img)
                img = cv2.resize(img, target_size, interpolation=cv2.INTER_LINEAR)

                with rasterio.open(os.path.join(label_dir, label_file)) as src:
                    label = src.read(1)  # first band only

                label = cv2.resize(label, target_size, interpolation=cv2.INTER_NEAREST)
                label = label[..., None] / 255  # keep a single channel axis

                images.append(img.astype(np.float32))
                labels.append(label.astype(np.float32))

            yield (np.array(images, dtype=np.float32),
                   np.array(labels, dtype=np.float32))

# Training, validation and test data generators
batch_size = 8
train_data_dir = 'C:/Users/User/gary/wd_data_t2/new_train'
val_data_dir = 'C:/Users/User/gary/wd_data_t2/val'
test_data_dir = 'C:/Users/User/gary/wd_data_t2/test'

train_generator = data_generator_from_dir(train_data_dir, batch_size=batch_size)
val_generator = data_generator_from_dir(val_data_dir, batch_size=batch_size)
test_generator = data_generator_from_dir(test_data_dir, batch_size=batch_size)

# Save a comparison figure for a single sample
def save_sample(X, label, pred_label, index, folder_path):
    """Save a three-panel figure (input, ground truth, prediction) with metrics.

    Only the first sample of each argument is used: ``X[0]`` is shown as an
    image, and channel 0 of ``label[0]`` and ``pred_label[0]`` are shown in
    grayscale. Pixel-wise accuracy, precision, recall and F1 are written into
    the figure title. When both masks sum to zero, all four metrics are
    reported as 1.0. The figure is written to
    ``folder_path/sample_<index>.png`` and ``folder_path`` is created if missing.
    """
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)

    fig, axes = plt.subplots(1, 3, figsize=(20, 10))

    truth = label[0, :, :, 0]
    guess = pred_label[0, :, :, 0]

    # (picture, extra imshow options, panel title) for each of the three axes
    panels = (
        (X[0], {}, 'RGB Image'),
        (truth, {'cmap': 'gray'}, 'Ground Truth Label'),
        (guess, {'cmap': 'gray'}, 'Predicted Label (Thresholded)'),
    )
    for ax, (picture, options, title) in zip(axes, panels):
        ax.imshow(picture, **options)
        ax.set_title(title)
        ax.axis('off')

    # Pixel-wise metrics on the flattened masks
    truth_flat = truth.flatten()
    guess_flat = guess.flatten()

    both_empty = np.sum(truth_flat) == 0 and np.sum(guess_flat) == 0
    if both_empty:
        # No positive pixel anywhere: treat the empty prediction as a perfect match.
        accuracy = precision = recall = f1 = 1.0
    else:
        accuracy = accuracy_score(truth_flat, guess_flat)
        precision = precision_score(truth_flat, guess_flat, zero_division=1)
        recall = recall_score(truth_flat, guess_flat, zero_division=1)
        f1 = f1_score(truth_flat, guess_flat, zero_division=1)

    # Show the metrics as the figure title
    summary = (
        f'Accuracy: {accuracy:.4f}\n'
        f'Precision: {precision:.4f}\n'
        f'Recall: {recall:.4f}\n'
        f'F1 Score: {f1:.4f}\n'
    )
    fig.suptitle(summary, fontsize=16)

    # Write the figure to disk, then close it to release memory
    plt.savefig(os.path.join(folder_path, f'sample_{index}.png'))
    plt.close()
# Predict batch by batch and save one figure per sample
def save_results_in_batches(generator, dataset_name, folder_path, total_images, batch_size=8, threshold=0.5):
    """Run the module-level ``model`` over ``generator`` and save a figure per sample.

    Args:
        generator: Iterable yielding ``(X_batch, label_batch)`` pairs.
        dataset_name: Name used in the progress message.
        folder_path: Directory passed to ``save_sample`` for the output figures.
        total_images: Number of samples to save before stopping. This is what
            ends the loop when the generator never finishes.
        batch_size: Unused here; kept so existing calls remain valid.
        threshold: Predictions ``>= threshold`` become 1, the rest 0.

    The stop condition is checked before the first batch and right after each
    batch is saved. Previously it was only checked inside the per-sample loop,
    so a further batch was loaded and predicted before the function returned.
    """
    print(f'\nSaving results for {dataset_name}...')

    # Running count of saved samples; also used as the output file index
    index = 0

    if index >= total_images:
        print(f"Reached total of {total_images} images, stopping...")
        return

    for X_batch, label_batch in generator:
        # Model output, binarized at the threshold
        pred_label_batch = model.predict(X_batch)
        pred_label_batch = (pred_label_batch >= threshold).astype(np.uint8)

        # Save only as many samples from this batch as are still needed
        remaining = total_images - index
        for j in range(min(len(X_batch), remaining)):
            save_sample(X_batch[j:j+1], label_batch[j:j+1],
                        pred_label_batch[j:j+1],
                        index=index, folder_path=folder_path)
            index += 1

        if index >= total_images:
            print(f"Reached total of {total_images} images, stopping...")
            return


# Load the trained model (compile=False: the saved training configuration is not restored)
model = load_model('C:/Users/User/gary/space/01_model/wd_t2_e50_att_unet_model_checkpoint.keras', compile=False)

# Step counts derived from the number of files and the batch size.
# NOTE(review): neither value is used in the visible remainder of this cell.
steps_per_epoch = len(os.listdir(os.path.join(train_data_dir, 'Image'))) // batch_size
validation_steps = len(os.listdir(os.path.join(val_data_dir, 'Image'))) // batch_size

# Total number of training images
total_train_images = len(os.listdir(os.path.join(train_data_dir, 'Image')))

# Save the training-set results
save_results_in_batches(train_generator, 'Training Set', 'C:/Users/User/gary/space/model/01result/02_test2_e50/02_t2_e50_training_results', total_images=total_train_images, batch_size=batch_size)

# Total number of validation images (disabled)
#total_val_images = len(os.listdir(os.path.join(val_data_dir, 'Image')))

# Save the validation-set results (disabled)
#save_results_in_batches(val_generator, 'Validation Set', 'C:/Users/User/gary/space/model/01result/02_test2_e50/02_t2_e50_validation_results', total_images=total_val_images, batch_size=batch_size)

# Total number of test images (disabled)
#total_test_images = len(os.listdir(os.path.join(test_data_dir, 'Image')))

# Save the test-set results (disabled)
#save_results_in_batches(test_generator, 'Test Set', 'C:/Users/User/gary/space/model/01result/02_test2_e50/02_t2_e50_test_results', total_images=total_test_images, batch_size=batch_size)
