In [None]:
# 二值分割
import os
import nibabel as nib
import numpy as np
from PIL import Image
from tqdm import tqdm
from sklearn.model_selection import train_test_split

def nii_to_png_split(img_dir, label_dir, output_base_dir, test_ratio=0.1):
    """Slice paired NIfTI volumes into PNG images and binary PNG masks.

    Files ending in ``.nii`` / ``.nii.gz`` are listed from ``img_dir`` and
    ``label_dir``, sorted, and paired by position.  The pairs are split at
    file level into a train and a test subset (fixed ``random_state=42``).
    Every slice along the third axis is written to
    ``<output_base_dir>/<subset>/images`` and
    ``<output_base_dir>/<subset>/masks``.

    Args:
        img_dir: Directory containing the image volumes.
        label_dir: Directory containing the label volumes.
        output_base_dir: Root directory for the exported PNG files.
        test_ratio: Forwarded to ``train_test_split`` as ``test_size``.

    Raises:
        AssertionError: If the number of image and label files differs, or
            an image volume and its label volume have different shapes.
    """
    os.makedirs(output_base_dir, exist_ok=True)

    # Image and label files must correspond one-to-one after sorting
    nifti_suffixes = ('.nii', '.nii.gz')
    img_files = sorted(name for name in os.listdir(img_dir) if name.endswith(nifti_suffixes))
    label_files = sorted(name for name in os.listdir(label_dir) if name.endswith(nifti_suffixes))
    assert len(img_files) == len(label_files), "图像和标签数量不一致"

    # File-level train / test split
    train_imgs, test_imgs, train_labels, test_labels = train_test_split(
        img_files, label_files, test_size=test_ratio, random_state=42
    )

    def export_subset(volume_names, label_names, subset_name):
        # Write every slice of every volume pair of one subset as PNG files.
        subset_root = os.path.join(output_base_dir, subset_name)
        image_out = os.path.join(subset_root, "images")
        mask_out = os.path.join(subset_root, "masks")
        for folder in (image_out, mask_out):
            os.makedirs(folder, exist_ok=True)

        written = 0  # running slice counter shared by all volumes of the subset
        progress = tqdm(
            zip(volume_names, label_names),
            total=len(volume_names),
            desc=f"Processing {subset_name}",
        )
        for volume_name, label_name in progress:
            volume = nib.load(os.path.join(img_dir, volume_name)).get_fdata()
            labels = nib.load(os.path.join(label_dir, label_name)).get_fdata()
            assert volume.shape == labels.shape, f"维度不一致: {volume_name}"

            for z in range(volume.shape[2]):
                plane = volume[:, :, z]
                label_plane = labels[:, :, z]

                # Per-slice min-max scaling to 0..255; the epsilon avoids a
                # division by zero on constant slices
                lowest = np.min(plane)
                span = np.ptp(plane) + 1e-8
                plane_u8 = ((plane - lowest) / span * 255).astype(np.uint8)

                # Any positive label becomes foreground (255)
                mask_u8 = np.where(label_plane > 0, 255, 0).astype(np.uint8)

                image_png = Image.fromarray(plane_u8).convert('L')
                mask_png = Image.fromarray(mask_u8).convert('L')

                image_png.save(os.path.join(image_out, f"{subset_name}_img_{written:05d}.png"))
                mask_png.save(os.path.join(mask_out, f"{subset_name}_mask_{written:05d}.png"))

                written += 1

    # Export the training subset first, then the test subset
    export_subset(train_imgs, train_labels, "train")
    export_subset(test_imgs, test_labels, "test")

if __name__ == "__main__":
    # Source directories holding the image and label volumes
    img_dir = "/home/hxy/Documents/RawData/Training/img"
    label_dir = "/home/hxy/Documents/RawData/Training/label"

    # Destination root for the exported binary-mask dataset
    output_base_dir = "./processed_dataset"

    nii_to_png_split(
        img_dir=img_dir,
        label_dir=label_dir,
        output_base_dir=output_base_dir,
        test_ratio=0.1,
    )


In [None]:
# 多标签分割
import os
import nibabel as nib
import numpy as np
from PIL import Image
from tqdm import tqdm
from sklearn.model_selection import train_test_split

def nii_to_png_split(img_dir, label_dir, output_base_dir, test_ratio=0.1):
    """Slice paired NIfTI volumes into PNG images and multi-class PNG masks.

    Files ending in ``.nii`` / ``.nii.gz`` are listed from ``img_dir`` and
    ``label_dir``, sorted, and paired by position.  The pairs are split at
    file level into a train and a test subset (fixed ``random_state=42``).
    Every slice along the third axis is written to
    ``<output_base_dir>/<subset>/images`` (min-max scaled grayscale) and
    ``<output_base_dir>/<subset>/masks`` (pixel value = class id).

    Args:
        img_dir: Directory containing the image volumes.
        label_dir: Directory containing the label volumes.
        output_base_dir: Root directory for the exported PNG files.
        test_ratio: Forwarded to ``train_test_split`` as ``test_size``.

    Raises:
        AssertionError: If the number of image and label files differs, or
            an image volume and its label volume have different shapes.
        ValueError: If a label volume holds values that do not fit into
            0..255 after rounding; an 8-bit mask cannot represent them.
    """
    os.makedirs(output_base_dir, exist_ok=True)

    nifti_suffixes = ('.nii', '.nii.gz')
    img_files = sorted(f for f in os.listdir(img_dir) if f.endswith(nifti_suffixes))
    label_files = sorted(f for f in os.listdir(label_dir) if f.endswith(nifti_suffixes))

    assert len(img_files) == len(label_files), "图像和标签数量不一致"

    # File-level split keeps all slices of one volume in the same subset
    train_imgs, test_imgs, train_labels, test_labels = train_test_split(
        img_files, label_files, test_size=test_ratio, random_state=42
    )

    def process_subset(img_list, label_list, subset_name):
        # Write every slice of every volume pair of one subset as PNG files.
        img_output_dir = os.path.join(output_base_dir, subset_name, "images")
        mask_output_dir = os.path.join(output_base_dir, subset_name, "masks")
        os.makedirs(img_output_dir, exist_ok=True)
        os.makedirs(mask_output_dir, exist_ok=True)

        idx = 0  # running slice counter shared by all volumes of the subset
        for img_file, label_file in tqdm(zip(img_list, label_list), total=len(img_list), desc=f"Processing {subset_name}"):
            img_path = os.path.join(img_dir, img_file)
            label_path = os.path.join(label_dir, label_file)

            img_nii = nib.load(img_path).get_fdata()
            label_nii = nib.load(label_path).get_fdata()
            assert img_nii.shape == label_nii.shape, f"维度不一致: {img_file}"

            # get_fdata() yields floating-point values.  Round to the nearest
            # class id before the uint8 cast so that e.g. 2.9999 is stored as
            # class 3 instead of being truncated to 2.
            label_ids = np.rint(label_nii)
            # Written as a negated "in range" test so NaN also fails it
            if label_ids.size and not (label_ids.min() >= 0 and label_ids.max() <= 255):
                raise ValueError(
                    f"Label values outside 0..255 cannot be stored in an 8-bit mask: {label_file}"
                )

            for i in range(img_nii.shape[2]):
                img_slice = img_nii[:, :, i]

                # Per-slice min-max scaling to 0..255; the epsilon avoids a
                # division by zero on constant slices
                img_norm = ((img_slice - np.min(img_slice)) / (np.ptp(img_slice) + 1e-8) * 255).astype(np.uint8)

                # Keep the original class ids (0, 1, 2, 3, ...) as pixel values
                label_class = label_ids[:, :, i].astype(np.uint8)

                img_pil = Image.fromarray(img_norm).convert('L')  # grayscale image
                # A 2-D uint8 array is read as a single-channel 'L' image
                label_pil = Image.fromarray(label_class)

                img_pil.save(os.path.join(img_output_dir, f"{subset_name}_img_{idx:05d}.png"))
                label_pil.save(os.path.join(mask_output_dir, f"{subset_name}_mask_{idx:05d}.png"))

                idx += 1

    process_subset(train_imgs, train_labels, "train")
    process_subset(test_imgs, test_labels, "test")


if __name__ == "__main__":
    # Source directories holding the image and label volumes
    img_dir = "/home/hxy/Documents/RawData/Training/img"
    label_dir = "/home/hxy/Documents/RawData/Training/label"
    # Destination root for the exported multi-class dataset
    output_base_dir = "./processed_dataset_labels"
    nii_to_png_split(
        img_dir=img_dir,
        label_dir=label_dir,
        output_base_dir=output_base_dir,
        test_ratio=0.1,
    )


In [None]:
import os
from PIL import Image
import numpy as np
from tqdm import tqdm

def inspect_dataset(image_dir, mask_dir, max_samples=200):
    """Print size, pixel range and label ids for image/mask PNG pairs.

    The ``.png`` files of both directories are sorted and paired by position.
    At most ``max_samples`` pairs are reported, and never more than the
    shorter of the two listings, so a missing mask cannot raise
    ``IndexError``.

    Args:
        image_dir: Directory containing the image PNG files.
        mask_dir: Directory containing the mask PNG files.
        max_samples: Upper bound on the number of pairs to report.
    """
    image_files = sorted(f for f in os.listdir(image_dir) if f.endswith('.png'))
    mask_files = sorted(f for f in os.listdir(mask_dir) if f.endswith('.png'))

    print(f"共找到 {len(image_files)} 张图像，{len(mask_files)} 个标签")
    print("=" * 60)

    # Bound by both listings: the counts printed above may differ
    sample_count = min(len(image_files), len(mask_files), max_samples)
    for i in range(sample_count):
        img_path = os.path.join(image_dir, image_files[i])
        mask_path = os.path.join(mask_dir, mask_files[i])

        # Read the image; the context manager releases the file handle
        with Image.open(img_path) as img:
            img_size = img.size
            img_np = np.array(img)

        # Read the mask
        with Image.open(mask_path) as mask:
            mask_size = mask.size
            mask_np = np.array(mask)

        # PIL reports (width, height); reverse it to print (H, W)
        print(f"样本 {i+1}:")
        print(f"  图像: {image_files[i]}")
        print(f"    - 尺寸: {img_size[::-1]} (H, W)")
        print(f"    - 像素范围: {img_np.min()} ~ {img_np.max()}")

        print(f"  标签: {mask_files[i]}")
        print(f"    - 尺寸: {mask_size[::-1]} (H, W)")
        print(f"    - 类别 ID: {np.unique(mask_np)}")
        print("-" * 60)

if __name__ == "__main__":
    # Inspect the training split of the multi-class dataset
    image_dir = "./processed_dataset_labels/train/images"
    mask_dir = "./processed_dataset_labels/train/masks"
    inspect_dataset(image_dir=image_dir, mask_dir=mask_dir, max_samples=200)


In [None]:
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

def visualize(image_path, mask_path, pred_path=None, selected_labels=None, save_path=None):
    """Plot an image next to its label mask and, optionally, a prediction.

    Args:
        image_path: Path of the image; it is converted to grayscale.
        mask_path: Path of the ground-truth mask (pixel value = label id).
        pred_path: Optional path of a predicted mask, shown as a third panel.
        selected_labels: Optional collection of label ids to keep; every
            other pixel of the mask (and prediction) is set to 0.
        save_path: If given, the figure is written there and closed;
            otherwise it is shown with ``plt.show()``.
    """
    with Image.open(image_path) as image:
        image_np = np.array(image.convert('L'))  # grayscale
    with Image.open(mask_path) as mask:
        mask_np = np.array(mask)

    pred_np = None
    if pred_path is not None:
        with Image.open(pred_path) as pred:
            pred_np = np.array(pred)

    # Zero out every label that was not requested
    if selected_labels is not None:
        mask_filtered = np.isin(mask_np, selected_labels)
        mask_np = mask_np * mask_filtered.astype(np.uint8)
        if pred_np is not None:
            pred_filtered = np.isin(pred_np, selected_labels)
            pred_np = pred_np * pred_filtered.astype(np.uint8)

    # Both label panels share one colour scale so equal ids get equal colours.
    # The scale covers the prediction too; otherwise an all-zero ground truth
    # would collapse the range and render every predicted label identically.
    color_max = int(np.max(mask_np))
    if pred_np is not None:
        color_max = max(color_max, int(np.max(pred_np)))
    color_max = max(color_max, 1)  # avoid a degenerate vmin == vmax range

    num_cols = 3 if pred_np is not None else 2
    fig = plt.figure(figsize=(5 * num_cols, 5))

    plt.subplot(1, num_cols, 1)
    plt.imshow(image_np, cmap='gray')
    plt.title('Original Image')
    plt.axis('off')

    plt.subplot(1, num_cols, 2)
    plt.imshow(mask_np, cmap='nipy_spectral', vmin=0, vmax=color_max)
    plt.title(f'Ground Truth (Labels: {np.unique(mask_np)})')
    plt.axis('off')

    if pred_np is not None:
        plt.subplot(1, num_cols, 3)
        plt.imshow(pred_np, cmap='nipy_spectral', vmin=0, vmax=color_max)
        plt.title(f'Prediction (Labels: {np.unique(pred_np)})')
        plt.axis('off')

    plt.tight_layout()
    if save_path:
        # dirname is '' for a bare file name, and os.makedirs('') raises
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        plt.savefig(save_path)
        # Saved figures are never shown, so release them explicitly;
        # otherwise they accumulate when this runs in a loop
        plt.close(fig)
    else:
        plt.show()


def visualize_selected(image_dir, mask_dir, pred_dir=None, indices=(0, 1), selected_labels=None):
    """Visualize the image/mask(/prediction) triples at the given indices.

    The ``.png`` files of each directory are sorted and paired by position.
    An index that is not valid for every listing in use is reported and
    skipped instead of raising ``IndexError``.

    Args:
        image_dir: Directory containing the image PNG files.
        mask_dir: Directory containing the mask PNG files.
        pred_dir: Optional directory containing predicted mask PNG files.
        indices: Positions (in sorted order) of the samples to show.
        selected_labels: Forwarded to ``visualize`` to filter label ids.
    """
    def list_pngs(folder):
        # Sorted names of the PNG files directly inside ``folder``.
        return sorted(f for f in os.listdir(folder) if f.endswith('.png'))

    image_files = list_pngs(image_dir)
    mask_files = list_pngs(mask_dir)
    pred_files = list_pngs(pred_dir) if pred_dir else None

    # An index is usable only if every listing involved contains it
    available = min(len(image_files), len(mask_files))
    if pred_files is not None:
        available = min(available, len(pred_files))

    for idx in indices:
        if not -available <= idx < available:
            print(f"❌ 索引 {idx} 超出范围")
            continue

        image_path = os.path.join(image_dir, image_files[idx])
        mask_path = os.path.join(mask_dir, mask_files[idx])
        pred_path = os.path.join(pred_dir, pred_files[idx]) if pred_files is not None else None

        print(f"🔍 Viewing index {idx} : {image_files[idx]}, labels: {selected_labels}")
        visualize(image_path, mask_path, pred_path, selected_labels)


if __name__ == "__main__":
    image_dir = "./processed_dataset_labels/test/images"
    mask_dir = "./processed_dataset_labels/test/masks"
    pred_dir = "./outputs_medsam_200_labels-lr001"  # directory with the predicted masks

    indices_to_view = list(range(150, 200))  # sample indices to display
    labels_to_highlight = [0, 6, 7, 8, 9]  # label ids to keep in the plots

    visualize_selected(image_dir, mask_dir, pred_dir, indices_to_view, labels_to_highlight)


In [None]:
import os
from PIL import Image
import matplotlib.pyplot as plt
from glob import glob
import numpy as np

def load_image(path):
    """Open the image at ``path`` and return it converted to grayscale ('L')."""
    source = Image.open(path)
    grayscale = source.convert('L')
    return grayscale

def compare_images(pred_path, orig_path, gt_path, save_path=None, show=True):
    """Place original image, ground truth and prediction side by side.

    All three files are loaded as grayscale through ``load_image``.  Ground
    truth and prediction are brought to the size of the original image when
    they differ from it.

    Args:
        pred_path: Path of the predicted mask.
        orig_path: Path of the original image.
        gt_path: Path of the ground-truth mask.
        save_path: If given, the merged strip (original | ground truth |
            prediction) is saved there.
        show: If true, the three images are also displayed with matplotlib.
    """
    pred = load_image(pred_path)
    orig = load_image(orig_path)
    gt = load_image(gt_path)

    # Masks carry label values, so resize with nearest-neighbour sampling:
    # an interpolating filter would invent intermediate values along the
    # region borders.  Skip the resize entirely when sizes already match.
    if gt.size != orig.size:
        gt = gt.resize(orig.size, Image.NEAREST)
    if pred.size != orig.size:
        pred = pred.resize(orig.size, Image.NEAREST)

    # Horizontal strip: original | ground truth | prediction
    merged = Image.new('L', (orig.width * 3, orig.height))
    merged.paste(orig, (0, 0))
    merged.paste(gt, (orig.width, 0))
    merged.paste(pred, (orig.width * 2, 0))

    if save_path:
        # dirname is '' for a bare file name, and os.makedirs('') raises
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        merged.save(save_path)

    if show:
        panels = (
            ('Original Image', orig),
            ('Ground Truth', gt),
            ('Prediction', pred),
        )
        plt.figure(figsize=(12, 4))
        for column, (title, panel) in enumerate(panels, start=1):
            plt.subplot(1, 3, column)
            plt.imshow(panel, cmap='gray')
            plt.title(title)
            plt.axis('off')
        plt.tight_layout()
        plt.show()

def batch_compare(pred_dir, orig_dir, gt_dir, save_dir=None):
    """Run ``compare_images`` for every PNG prediction in ``pred_dir``.

    The part of a prediction's file name after its last underscore (for
    example ``00020.png``) selects ``test_img_<suffix>`` in ``orig_dir`` and
    ``test_mask_<suffix>`` in ``gt_dir``.  Predictions whose image or mask is
    missing are reported and skipped.

    Args:
        pred_dir: Directory containing the predicted PNG masks.
        orig_dir: Directory containing the original test images.
        gt_dir: Directory containing the ground-truth test masks.
        save_dir: Optional directory for the merged comparison images.
    """
    for prediction in sorted(glob(os.path.join(pred_dir, "*.png"))):
        file_name = os.path.basename(prediction)
        suffix = file_name.rsplit('_', 1)[-1]  # e.g. 00020.png
        original = os.path.join(orig_dir, f"test_img_{suffix}")
        truth = os.path.join(gt_dir, f"test_mask_{suffix}")

        if not os.path.exists(original) or not os.path.exists(truth):
            print(f"[跳过] 缺失原图或标签: {file_name}")
            continue

        target = None
        if save_dir:
            target = os.path.join(save_dir, file_name)
        compare_images(prediction, original, truth, target)

if __name__ == "__main__":
    pred_dir = "outputs"
    orig_dir = "processed_dataset/test/images"
    gt_dir = "processed_dataset/test/masks"
    # Leave as None when the comparison images should not be saved
    save_dir = None

    batch_compare(
        pred_dir=pred_dir,
        orig_dir=orig_dir,
        gt_dir=gt_dir,
        save_dir=save_dir,
    )


In [None]:
import numpy as np
from PIL import Image
import os

# Collect every pixel value that occurs in the predicted PNG masks
pred_dir = "./outputs_medsam_20_labels"
all_classes = []

for name in os.listdir(pred_dir):
    if not name.endswith(".png"):
        continue
    mask_path = os.path.join(pred_dir, name)
    mask = np.array(Image.open(mask_path))
    all_classes += np.unique(mask).tolist()

unique_classes = sorted(set(all_classes))
print("预测中出现的类别:", unique_classes)
