# 生成叶片前景区域（MOCO之后操作）

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import os

def kd_fill_holes(image_path, kernel_size=5):
    """Binarize a grayscale mask image and close small gaps in its foreground.

    The image is read as 8-bit grayscale and binarized with Otsu's method
    (foreground -> 255, background -> 0). It is then smoothed with a
    morphological closing (dilation followed by erosion) using a square
    ``kernel_size`` x ``kernel_size`` kernel. Closing only fills dark holes and
    gaps roughly smaller than the kernel; large enclosed holes are left as they are.

    Args:
        image_path: Path of the mask image to read.
        kernel_size: Side length in pixels of the square closing kernel.
            Defaults to 5.

    Returns:
        uint8 array with values 0/255 and the same height/width as the input.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path``.
    """
    image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread reports failure by returning None rather than raising.
    if image is None:
        raise FileNotFoundError(f"Could not read mask image: {image_path}")

    # With THRESH_OTSU the passed threshold (0) is ignored and computed from the histogram.
    _, binary_image = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)

    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    return cv2.morphologyEx(binary_image, cv2.MORPH_CLOSE, kernel)

def segment_foreground(original_image, mask_image):
    """Keep the pixels of ``original_image`` where ``mask_image`` is above 127.

    All other pixels are set to 0. Per OpenCV's ``mask`` requirement,
    ``mask_image`` should be a single-channel 8-bit array with the same
    height and width as ``original_image``.
    """
    keep_mask = cv2.threshold(mask_image, 127, 255, cv2.THRESH_BINARY)[1]
    # AND-ing the image with itself leaves it unchanged, so only the mask decides what survives.
    return cv2.bitwise_and(original_image, original_image, mask=keep_mask)

# Input photos (.jpg), predicted leaf masks (.png with the same file stem), and
# the folder that receives the transparent foreground PNGs.
original_images_folder = "./experiments/data/Publish_Dataset/Pixel-level_annotation/Image/"
mask_images_folder = "./experiments/predictions/CCAM_Apple_MOCO@scale=0.5,1.0,1.5,2.0@t=0.5@ccam_inference_crf=10/"
output_folder = "./experiments/predictions/Leaf_Foreground/"

# cv2.imwrite does not create missing directories; it just returns False.
os.makedirs(output_folder, exist_ok=True)

for filename in sorted(os.listdir(original_images_folder)):
    if not filename.endswith(".jpg"):
        continue

    # splitext keeps inner dots in the stem ("a.b.jpg" -> "a.b").
    stem = os.path.splitext(filename)[0]
    mask_image_path = os.path.join(mask_images_folder, stem + '.png')
    # Only photos with a predicted mask are processed; check before decoding the photo.
    if not os.path.exists(mask_image_path):
        continue

    original_image_path = os.path.join(original_images_folder, filename)
    original_image = cv2.imread(original_image_path)
    if original_image is None:  # cv2.imread returns None for unreadable files
        print(f"Skipping unreadable image: {original_image_path}")
        continue

    # Otsu-binarized, morphologically closed mask (uint8, values 0/255).
    mask_image = kd_fill_holes(mask_image_path)

    # Match the photo to the mask's resolution so cv2.bitwise_and gets equal sizes.
    mask_height, mask_width = mask_image.shape[:2]
    original_image = cv2.resize(original_image, (mask_width, mask_height))

    foreground = segment_foreground(original_image, mask_image)

    # 4-channel output: colour channels from the masked photo (OpenCV order, BGR),
    # alpha = 255 where the mask is set, so background pixels become transparent.
    transparent_png = np.zeros((foreground.shape[0], foreground.shape[1], 4), dtype=np.uint8)
    transparent_png[:, :, :3] = foreground
    transparent_png[:, :, 3] = np.where(mask_image != 0, 255, 0)

    output_path = os.path.join(output_folder, stem + '.png')
    if not cv2.imwrite(output_path, transparent_png):
        print(f"Failed to write {output_path}")

# 显著性目标检测评估指标

## 1 IoU、Dice、准确率、精确率、召回率、F1 分数、MAE

In [None]:
import torch
from torchvision.transforms import ToTensor
from PIL import Image
import os
import cv2

# 定义转换函数，将图像加载并转换为 PyTorch 张量
def load_and_transform_image(image_path):
    """Load an image file as a single-channel float tensor.

    The image is converted to 8-bit grayscale (PIL mode ``'L'``) and passed
    through torchvision's ``ToTensor``. The result is a ``(1, H, W)`` float
    tensor scaled to [0, 1].
    """
    # Image.open is lazy and may keep the file open (e.g. multi-frame GIFs);
    # the context manager releases the handle once the grayscale copy exists.
    with Image.open(image_path) as image:
        gray = image.convert('L')
    return ToTensor()(gray)

def is_image_file(filename):
    """Return True if ``filename`` ends with a common image extension.

    The check is case-insensitive and accepts .png, .jpg, .jpeg, .bmp and .gif.
    """
    # Optional, currently disabled dataset filter: return False for names starting with "Strawberry".
    lowered = filename.lower()
    return any(lowered.endswith(ext) for ext in ('.png', '.jpg', '.jpeg', '.bmp', '.gif'))

def read_file_names_from_txt(txt_file):
    """Return the set of non-blank, whitespace-stripped lines in ``txt_file``.

    The file is decoded as UTF-8, and a leading byte-order mark (written by
    some Windows editors) is ignored. Non-ASCII file names therefore read the
    same on every platform, whatever the locale's default encoding. Blank
    lines are skipped, so the set never contains ``""``.
    """
    with open(txt_file, 'r', encoding='utf-8-sig') as f:
        return {name for name in (line.strip() for line in f) if name}
    
def _segmentation_metrics(gt_image, pred_image, threshold=0.5, eps=1e-8):
    """Compute per-image metrics for one ground-truth / prediction pair.

    Args:
        gt_image: Ground-truth tensor with the same shape as ``pred_image``.
            Values are expected in [0, 1], as returned by ``load_and_transform_image``.
        pred_image: Predicted saliency tensor.
        threshold: Values strictly greater than this count as foreground.
        eps: Small constant added to the IoU, Dice, precision, recall and F1
            denominators. A zero denominator (e.g. both masks empty) then
            yields 0 instead of NaN.

    Returns:
        dict of Python floats with keys "iou", "dice", "accuracy", "precision",
        "recall" and "f1", all computed on the binarized masks, and "mae",
        computed on the raw values.
    """
    gt_binary = (gt_image > threshold).float()
    pred_binary = (pred_image > threshold).float()

    true_positive = torch.sum(gt_binary * pred_binary)
    false_positive = torch.sum((1 - gt_binary) * pred_binary)
    false_negative = torch.sum(gt_binary * (1 - pred_binary))

    # IoU and Dice use the binarized GT, consistent with precision/recall:
    # |GT ∪ Pred| = TP + FP + FN and |GT| + |Pred| = 2TP + FP + FN.
    iou = true_positive / (true_positive + false_positive + false_negative + eps)
    dice = 2.0 * true_positive / (2.0 * true_positive + false_positive + false_negative + eps)
    accuracy = torch.sum((gt_binary == pred_binary).float()) / gt_binary.numel()
    precision = true_positive / (true_positive + false_positive + eps)
    recall = true_positive / (true_positive + false_negative + eps)
    f1 = 2 * (precision * recall) / (precision + recall + eps)
    mae = torch.mean(torch.abs(pred_image - gt_image))

    return {
        "iou": iou.item(),
        "dice": dice.item(),
        "accuracy": accuracy.item(),
        "precision": precision.item(),
        "recall": recall.item(),
        "f1": f1.item(),
        "mae": mae.item(),
    }


def evaluating(pred_folder, gt_folder, file_path, threshold=0.5):
    """Evaluate predicted saliency maps against ground-truth masks and report the averages.

    Each image file directly inside ``pred_folder`` is paired with the file of
    the same name in ``gt_folder``. Pairs whose GT file is missing, or whose
    sizes differ, are skipped with a message. Per-image metrics (see
    ``_segmentation_metrics``) are averaged over the evaluated pairs, printed,
    and appended to ``file_path`` as UTF-8 text.

    Args:
        pred_folder: Folder holding the prediction images. A missing folder is
            treated as one with no samples.
        gt_folder: Folder holding the ground-truth images, with matching names.
        file_path: Text file the averaged results are appended to.
        threshold: Binarization threshold for both GT and prediction. Defaults to 0.5.

    Returns:
        dict mapping each metric name to its average, or None when no pair was evaluated.
    """
    # GT files are looked up by bare file name, so only the top level of
    # pred_folder is scanned. sorted() makes the order deterministic.
    names = sorted(os.listdir(pred_folder)) if os.path.isdir(pred_folder) else []

    totals = {}
    total_samples = 0
    for name in names:
        pred_path = os.path.join(pred_folder, name)
        if not is_image_file(name) or not os.path.isfile(pred_path):
            continue
        gt_path = os.path.join(gt_folder, name)
        if not os.path.exists(gt_path):
            print(f"Missing GT for {name}: {gt_path}")
            continue

        gt_image = load_and_transform_image(gt_path)
        pred_image = load_and_transform_image(pred_path)
        if gt_image.size() != pred_image.size():
            print(f"Size mismatch for {name}: GT {gt_image.size()} vs Pred {pred_image.size()}")
            continue

        for key, value in _segmentation_metrics(gt_image, pred_image, threshold).items():
            totals[key] = totals.get(key, 0.0) + value
        total_samples += 1

    if total_samples == 0:
        print("No valid samples found.")
        return None

    averages = {key: value / total_samples for key, value in totals.items()}

    # (label, value) rows shared by the console output and the results file.
    report = [
        ("平均 IoU:", averages["iou"]),
        ("平均 Dice:", averages["dice"]),
        ("平均准确率:", averages["accuracy"]),
        ("平均精确率:", averages["precision"]),
        ("平均召回率:", averages["recall"]),
        ("平均 F1 分数:", averages["f1"]),
        ("平均 MAE:", averages["mae"]),
    ]
    for label, value in report:
        print(label, value)
    print()

    # Explicit UTF-8 so the Chinese labels can be written under any locale.
    with open(file_path, 'a', encoding='utf-8') as f:
        f.writelines(f"{label} {value}\n" for label, value in report)
        f.write("\n")
    return averages


# Checkpoint identifiers; each one selects ./experiments/predictions_<id>/.
predictions_nums = [0, 1, 2, 3, 4, 9, "best"]

# Experiments to evaluate. Runs used in this notebook: "CCAM_Apple_MOCO",
# "CCAM_Apple_DETCO", "CCAM_Apple_Plant". List several to evaluate them in
# one pass; each is paired with every entry of predictions_nums.
detco_nums = ["CCAM_Apple_Plant"]

# Ground-truth folder for each supported experiment-name prefix (checked in order).
gt_folders_by_prefix = {
    "CCAM_Apple": './experiments/data/万张图/Original_Resized_picture_gt',
    "CCAM_Extra": './experiments/data/万张图/Extra_Original_Resized_Leaf/Original_picture_gt',
}

for detco_num in detco_nums:
    gt_folder = next(
        (folder for prefix, folder in gt_folders_by_prefix.items() if detco_num.startswith(prefix)),
        None,
    )
    if gt_folder is None:
        # Fail loudly instead of reusing a stale / undefined gt_folder.
        raise ValueError(f"No GT folder configured for experiment {detco_num!r}")

    for predictions_num in predictions_nums:
        pred_folder = f'./experiments/predictions_{predictions_num}/{detco_num}@scale=2,3@t=0.5@ccam_inference_crf=10'
        # The last two path components identify the run: "predictions_<id>/<experiment>@...".
        tag = "/".join(pred_folder.split('/')[-2:])
        print("tag:", tag)
        result_dir = f"./Eval_SOD/{tag}"
        os.makedirs(result_dir, exist_ok=True)
        file_path = result_dir + "/" + "result.txt"
        evaluating(pred_folder, gt_folder, file_path)
