In [9]:
import os
import glob
import numpy as np
import cv2
import matplotlib.pyplot as plt
import scipy.ndimage
from skimage import exposure, filters, segmentation, morphology, color, util, measure
from tqdm import tqdm

# ================= 配置区域 =================
# 请确认你的数据路径是否正确
DATA_DIR = "/home/chenx/code/medical_project/data/LIDC-IDRI-slices"
IMAGE_SIZE = (128, 128) 
SAVE_RESULT_DIR = "./output_top10_results" # 结果保存路径
# ===========================================

class LIDC_FinalProcessor:
    def __init__(self, data_dir):
        self.data_dir = data_dir
        if not os.path.exists(SAVE_RESULT_DIR):
            os.makedirs(SAVE_RESULT_DIR)

    def load_data_and_gt(self):
        """
        加载数据并根据 '3/4 共识' 逻辑生成 Ground Truth
        """
        patient_dirs = glob.glob(os.path.join(self.data_dir, '**', 'images'), recursive=True)
        valid_samples = []
        
        print(f"扫描数据集中: {self.data_dir} ...")
        
        # 扫描前 300 个文件夹 (如果机器性能好，可以去掉 [:300] 跑全量)
        # 建议至少跑 200-300 个样本以确保能筛选出高质量的 Top 10
        for img_dir in tqdm(patient_dirs[:300]): 
            base_dir = os.path.dirname(img_dir)
            image_files = glob.glob(os.path.join(img_dir, "*.png")) + glob.glob(os.path.join(img_dir, "*.jpg"))
            
            for img_path in image_files:
                filename = os.path.basename(img_path)
                
                # 读取图像
                img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
                if img is None: continue
                img = cv2.resize(img, IMAGE_SIZE)

                # 读取 Masks
                masks = []
                for i in range(4):
                    mask_path = os.path.join(base_dir, f"mask-{i}", filename)
                    if os.path.exists(mask_path):
                        m = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
                        m = cv2.resize(m, IMAGE_SIZE)
                        masks.append(m)
                    else:
                        masks.append(np.zeros(IMAGE_SIZE, dtype=np.uint8))

                # === 共识逻辑 ===
                mask_sums = [np.sum(m) for m in masks]
                # 统计有多少个医生认为有结节 (阈值>10排除噪点)
                votes = sum(1 for s in mask_sums if s > 10)

                # 判定: >2 人 (即至少3人) 同意才有病
                if votes > 2:
                    # 取面积最大的 mask 作为 GT
                    gt_mask = masks[np.argmax(mask_sums)]
                    gt_mask = (gt_mask > 127).astype(np.uint8)
                    
                    valid_samples.append({
                        'image': img,
                        'mask': gt_mask,
                        'id': f"{os.path.basename(base_dir)}_{filename}"
                    })
        
        print(f"筛选出 {len(valid_samples)} 个正样本用于测试。")
        return valid_samples

    def process_pipeline(self, image_uint8):
        """
        改进版流水线：
        1. 预处理：CLAHE (对比度增强) + MedianBlur (去噪保边)
        2. ROI提取：生成肺部掩码 (Lung Mask)，排除肋骨和背景
        3. 分割：ROI内的 Otsu 阈值分割
        4. 特征筛选：剔除血管 (长条状) 和 噪点 (过小)
        """
        
        # === 1. 图像复原与增强 ===
        # 使用 CLAHE 替代全局均衡化，避免背景噪点爆炸
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
        enhanced = clahe.apply(image_uint8)
        
        # 使用中值滤波替代维纳滤波，更好地保留边缘
        denoised = cv2.medianBlur(enhanced, 3)
        
        # 转为 float 用于可能的后续计算，但在 mask 生成中我们主要用 uint8
        img_float = util.img_as_float(denoised)

        # === 2. 肺实质提取 (Lung Mask Generation) ===
        # 目标：生成一个只包含左右肺的 mask
        
        # Otsu 阈值区分身体和背景
        _, binary = cv2.threshold(denoised, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        
        # 清除边界 (去除贴在图像边缘的非肺组织)
        clear_border = segmentation.clear_border(binary)
        
        # 标记连通域
        labels = measure.label(clear_border)
        props = measure.regionprops(labels)
        
        lung_mask = np.zeros_like(binary, dtype=np.uint8)
        
        # 按面积排序区域
        sorted_regions = sorted(props, key=lambda x: x.area, reverse=True)
        
        # 取最大的前2个区域作为肺 (左右肺)
        for region in sorted_regions[:2]:
            # 阈值设为 200，防止漏掉肺尖/肺底较小的切片
            if region.area > 200: 
                r_mask = (labels == region.label)
                # 填充肺内部空洞
                r_mask = scipy.ndimage.binary_fill_holes(r_mask)
                lung_mask[r_mask] = 1
        
        # === [关键修复] 安全检查 ===
        # 如果 Mask 为空 (没找到肺)，直接返回，避免后续 Otsu 报错
        if np.sum(lung_mask) == 0:
            return {
                'denoised': denoised,
                'enhanced': enhanced,
                'lung_mask': lung_mask,
                'prediction': np.zeros_like(image_uint8)
            }

        # === 3. 结节候选分割 (Candidate Segmentation) ===
        
        # 只在肺掩码范围内进行处理
        masked_img = denoised.copy()
        masked_img[lung_mask == 0] = 0
        
        # 计算肺内部的 Otsu 阈值
        try:
            # 只取 lung_mask > 0 的像素来计算阈值
            thresh_val = filters.threshold_otsu(masked_img[lung_mask > 0])
        except Exception:
            # 万一计算失败，回退到固定阈值
            thresh_val = 0.5 * 255
            
        # 二值化提取高亮区域
        candidates = (masked_img > thresh_val).astype(np.uint8)

        # === 4. 形态学筛选 (剔除血管) ===
        cand_labels = measure.label(candidates)
        cand_props = measure.regionprops(cand_labels)
        
        final_mask = np.zeros_like(candidates)
        
        for prop in cand_props:
            # --- 筛选逻辑 ---
            
            # 1. 面积筛选
            if prop.area < 10 or prop.area > 600:
                continue
                
            # 2. 形状筛选 (区分血管 vs 结节)
            # Eccentricity (离心率): 0是圆，1是直线。血管通常 > 0.85
            # Solidity (实心度): 结节通常比较实心 (>0.8)
            
            is_circle = prop.eccentricity < 0.85 
            is_solid = prop.solidity > 0.75 # 稍微放宽一点
            
            if is_circle and is_solid:
                final_mask[cand_labels == prop.label] = 1

        return {
            'denoised': denoised,  
            'enhanced': masked_img, # 展示 Mask 后的效果
            'lung_mask': lung_mask, # 调试用：看肺有没有提取对
            'prediction': final_mask
        }

    def calculate_dice(self, y_true, y_pred):
        """计算 Dice 系数"""
        intersection = np.sum(y_true * y_pred)
        return (2. * intersection) / (np.sum(y_true) + np.sum(y_pred) + 1e-6)

    def visualize_best(self, rank, original, results, gt_mask, dice_score):
        """保存结果图"""
        fig, axes = plt.subplots(1, 5, figsize=(20, 4))
        ax = axes.ravel()

        # 1. 原图
        ax[0].imshow(original, cmap='gray')
        ax[0].set_title(f'Rank {rank}: Original')

        # 2. 增强后 (Masked)
        ax[1].imshow(results['enhanced'], cmap='gray')
        ax[1].set_title('CLAHE + Lung Mask')

        # 3. 肺部掩码 (检查 ROI 提取是否正确)
        ax[2].imshow(results['lung_mask'], cmap='gray')
        ax[2].set_title('Generated Lung Mask')

        # 4. 预测结果 (叠加红色)
        overlay = color.label2rgb(results['prediction'], image=results['enhanced'], 
                                  bg_label=0, colors=['red'], alpha=0.3)
        ax[3].imshow(overlay)
        ax[3].set_title(f'Prediction (Dice: {dice_score:.3f})')

        # 5. Ground Truth (叠加绿色)
        gt_overlay = color.label2rgb(gt_mask, image=original, 
                                     bg_label=0, colors=['green'], alpha=0.3)
        ax[4].imshow(gt_overlay)
        ax[4].set_title('Ground Truth')

        for a in ax: a.axis('off')

        save_name = os.path.join(SAVE_RESULT_DIR, f"rank_{rank}_dice_{dice_score:.3f}.png")
        plt.tight_layout()
        plt.savefig(save_name)
        plt.close(fig) # 关闭以释放内存

    def run(self):
        # 1. 加载数据
        samples = self.load_data_and_gt()
        if not samples:
            print("未找到有效数据，请检查路径。")
            return

        all_results = []
        
        print("开始处理流程...")
        # 2. 遍历处理
        for sample in tqdm(samples):
            img = sample['image']
            gt = sample['mask']
            
            # 运行管线
            res = self.process_pipeline(img)
            
            # 计算 Dice
            dice = self.calculate_dice(gt, res['prediction'])
            
            # 保存结果到内存列表 (为了后续排序)
            all_results.append({
                'dice': dice,
                'image': img,
                'gt': gt,
                'results': res
            })

        # 3. 排序并输出 Top 10
        print("\n正在排序并生成 Top 10 结果...")
        
        # 按 Dice 降序排列
        all_results.sort(key=lambda x: x['dice'], reverse=True)
        
        # 取前 10 个
        top_10 = all_results[:10]

        for rank, item in enumerate(top_10, 1):
            self.visualize_best(
                rank=rank,
                original=item['image'],
                results=item['results'],
                gt_mask=item['gt'],
                dice_score=item['dice']
            )
            print(f"Rank {rank} Saved: Dice = {item['dice']:.4f}")

        # 计算整体平均分
        avg_dice = np.mean([x['dice'] for x in all_results])
        print(f"\n=== 统计信息 ===")
        print(f"总处理样本数: {len(all_results)}")
        print(f"平均 Dice Score: {avg_dice:.4f}")
        print(f"最高 Dice Score: {all_results[0]['dice']:.4f}")
        print(f"结果已保存至: {SAVE_RESULT_DIR}")

if __name__ == "__main__":
    processor = LIDC_FinalProcessor(DATA_DIR)
    processor.run()

扫描数据集中: /home/chenx/code/medical_project/data/LIDC-IDRI-slices ...


100%|██████████| 300/300 [00:01<00:00, 266.24it/s]


筛选出 764 个正样本用于测试。
开始处理流程...


100%|██████████| 764/764 [00:01<00:00, 504.17it/s]



正在排序并生成 Top 10 结果...
Rank 1 Saved: Dice = 0.9619
Rank 2 Saved: Dice = 0.9027
Rank 3 Saved: Dice = 0.8897
Rank 4 Saved: Dice = 0.8734
Rank 5 Saved: Dice = 0.8679
Rank 6 Saved: Dice = 0.8661
Rank 7 Saved: Dice = 0.8619
Rank 8 Saved: Dice = 0.8583
Rank 9 Saved: Dice = 0.8482
Rank 10 Saved: Dice = 0.8314

=== 统计信息 ===
总处理样本数: 764
平均 Dice Score: 0.0523
最高 Dice Score: 0.9619
结果已保存至: ./output_top10_results
