Part 1: Environment setup

In [4]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pydicom
import seaborn as sns

from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score, f1_score, confusion_matrix


In [5]:
# Read the clinical tables of both datasets and preview the first rows of each.
clinical1 = pd.read_csv("../testdata/dataset1/clinical1.csv")
clinical2 = pd.read_csv("../testdata/dataset2/clinical2.csv")

display(clinical1.head(), clinical2.head())


Unnamed: 0,PatientID,age,clinical.T.Stage,Clinical.N.Stage,Clinical.M.Stage,Overall.Stage,Histology,gender,Survival.time,deadstatus.event
0,LUNG1-001,78.7515,2.0,3,0,IIIb,large cell,male,2165,1
1,LUNG1-002,83.8001,2.0,0,0,I,squamous cell carcinoma,male,155,1
2,LUNG1-003,68.1807,2.0,3,0,IIIb,large cell,male,256,1
3,LUNG1-004,70.8802,2.0,1,0,II,squamous cell carcinoma,male,141,1
4,LUNG1-005,80.4819,4.0,2,0,IIIb,squamous cell carcinoma,male,353,1


Unnamed: 0,Case ID,Patient affiliation,Age at Histological Diagnosis,Weight (lbs),Gender,Ethnicity,Smoking status,Pack Years,Quit Smoking Year,%GG,...,Recurrence,Recurrence Location,Date of Recurrence,Date of Last Known Alive,Survival Status,Date of Death,Time to Death (days),CT Date,Days between CT and surgery,PET Date
0,AMC-001,Stanford,34,Not Collected,Male,Not Recorded In Database,Nonsmoker,,,Not Assessed,...,yes,distant,10/7/1994,1/7/1997,Dead,1/7/1997,872.0,8/10/1994,9,Not Collected
1,AMC-002,Stanford,33,Not Collected,Female,Not Recorded In Database,Nonsmoker,,,Not Assessed,...,no,,,3/20/1992,Alive,,,2/19/1992,3,Not Collected
2,AMC-003,Stanford,69,Not Collected,Female,Not Recorded In Database,Nonsmoker,,,Not Assessed,...,no,,,6/19/1996,Alive,,,2/23/1995,28,Not Collected
3,AMC-004,Stanford,80,Not Collected,Female,Not Recorded In Database,Nonsmoker,,,Not Assessed,...,no,,,12/13/1996,Alive,,,12/26/1992,47,Not Collected
4,AMC-005,Stanford,76,Not Collected,Male,Not Recorded In Database,Former,30.0,1962.0,Not Assessed,...,yes,distant,1/4/1996,1/7/1997,Alive,,,7/21/1994,2,Not Collected


函数：返回3D volume(CT和segmentation)

In [31]:
import os
import pydicom
import numpy as np
import SimpleITK as sitk

def find_segmentation_file(folder):
    """Return the first ``.dcm`` file found beneath a "segmentation" directory.

    The tree rooted at ``folder`` is walked with ``os.walk``. A file qualifies
    when the path of the directory that contains it includes the substring
    "segmentation" (case-insensitive) and the file name ends with ``.dcm``
    (case-insensitive).

    Args:
        folder: Root directory to search.

    Returns:
        The joined path of the first qualifying file, or ``None`` when no file
        qualifies.
    """
    for current_dir, _subdirs, filenames in os.walk(folder):
        # The directory test does not depend on the file, so check it once.
        if "segmentation" not in current_dir.lower():
            continue
        dcm_names = (name for name in filenames if name.lower().endswith(".dcm"))
        first_match = next(dcm_names, None)
        if first_match is not None:
            return os.path.join(current_dir, first_match)
    return None

def get_series_uid_from_seg(seg_path):
    """Return the SeriesInstanceUID of the series referenced by a segmentation file.

    Only header elements are needed, so the file is read with
    ``stop_before_pixels=True`` to avoid loading the pixel data.

    Args:
        seg_path: Path to the segmentation DICOM file.

    Returns:
        ``SeriesInstanceUID`` of the first item in the file's
        ``ReferencedSeriesSequence``.

    Raises:
        Whatever ``pydicom.dcmread`` raises for an unreadable file; attribute or
        index errors propagate when the file has no referenced series.
    """
    ds = pydicom.dcmread(seg_path, stop_before_pixels=True)
    return ds.ReferencedSeriesSequence[0].SeriesInstanceUID

def find_ct_folder_by_uid(folder, target_uid):
    """Find the directory holding the DICOM series with the given SeriesInstanceUID.

    Every ``.dcm`` file (case-insensitive) under ``folder`` has its header read
    (pixel data is skipped) until one whose ``SeriesInstanceUID`` equals
    ``target_uid`` is found.

    Args:
        folder: Root directory to search.
        target_uid: SeriesInstanceUID to look for.

    Returns:
        The directory containing the first matching file, or ``None`` when no
        file matches.
    """
    for root, _dirs, files in os.walk(folder):
        for f in files:
            if not f.lower().endswith(".dcm"):
                continue
            path = os.path.join(root, f)
            try:
                ds = pydicom.dcmread(path, stop_before_pixels=True)
                if ds.SeriesInstanceUID == target_uid:
                    return root
            except Exception:
                # Best-effort scan: skip files that cannot be read or that lack
                # a SeriesInstanceUID. Catching Exception (not a bare except)
                # lets KeyboardInterrupt/SystemExit stop a long scan.
                continue
    return None

def read_image_volume(ct_folder):
    """Stack the DICOM slices in a folder into one array ordered by InstanceNumber.

    Only files directly inside ``ct_folder`` whose name ends with ``.dcm``
    (case-insensitive) are considered. Files that cannot be read, or that lack
    ``InstanceNumber`` or pixel data, are skipped.

    Args:
        ct_folder: Directory containing the slice files.

    Returns:
        ``np.ndarray`` built with ``np.stack`` from each slice's ``pixel_array``,
        sorted by ascending ``InstanceNumber``.

    Raises:
        ValueError: If no slice in the folder could be read.
    """
    # NOTE(review): ordering relies on InstanceNumber, which is not guaranteed
    # to follow spatial position — confirm against ImagePositionPatient if the
    # volume must line up with a mask.
    slices = []
    for f in os.listdir(ct_folder):
        if not f.lower().endswith(".dcm"):
            continue
        path = os.path.join(ct_folder, f)
        try:
            ds = pydicom.dcmread(path)
            slices.append((int(ds.InstanceNumber), ds.pixel_array))
        except Exception:
            # Best-effort: skip unreadable slices, but let KeyboardInterrupt
            # and SystemExit propagate (a bare except would swallow them).
            continue
    if not slices:
        # np.stack would also raise ValueError here, but with a message that
        # does not say which folder was empty.
        raise ValueError(f"No readable DICOM slices found in {ct_folder}")
    slices.sort(key=lambda item: item[0])
    return np.stack([pixels for _, pixels in slices])

def read_dicomseg_to_mask(seg_path):
    """Read a segmentation file with SimpleITK and return its voxel data as an array.

    Args:
        seg_path: Path to the segmentation file.

    Returns:
        The result of ``sitk.GetArrayFromImage`` on the image that was read.
    """
    # NOTE(review): the original author annotated the result as (Z, H, W) — confirm.
    return sitk.GetArrayFromImage(sitk.ReadImage(seg_path))

def load_dicomseg_and_ct_volume(patient_folder):
    """Load a patient's CT volume together with its segmentation mask.

    The segmentation file is located first; the series UID it references is
    then used to find the matching CT folder, and both volumes are read.

    Args:
        patient_folder: Root directory of one patient's data.

    Returns:
        A 4-tuple ``(image_3d, mask_3d, ct_folder, seg_path)``. When the
        segmentation file or the matching CT series cannot be found, a message
        is printed and ``(None, None, None, None)`` is returned, so callers can
        always unpack four values. A shape mismatch between image and mask
        only prints a warning; both arrays are still returned.
    """
    seg_path = find_segmentation_file(patient_folder)
    if seg_path is None:
        print(f"[跳过] 无 segmentation 文件：{patient_folder}")
        # Same arity as the success path; a 2-tuple here broke 4-way unpacking.
        return None, None, None, None

    target_uid = get_series_uid_from_seg(seg_path)
    ct_folder = find_ct_folder_by_uid(patient_folder, target_uid)
    if ct_folder is None:
        print(f"[跳过] 无匹配的 CT 图像序列：{patient_folder}")
        return None, None, None, None

    print(f"[✔] 加载中：{patient_folder}")
    image_3d = read_image_volume(ct_folder)
    mask_3d = read_dicomseg_to_mask(seg_path)

    # Warn (without failing) when the two volumes do not have the same shape.
    if image_3d.shape != mask_3d.shape:
        print(f"[⚠️] CT 与 mask 尺寸不符：{image_3d.shape} vs {mask_3d.shape}")

    return image_3d, mask_3d, ct_folder, seg_path


In [32]:
# Run the loader on one patient and report the shapes of the returned arrays.
image, mask, ct_folder, seg_path = load_dicomseg_and_ct_volume(
    "../testdata/dataset1/LUNG1-001"
)

print("CT shape:", None if image is None else image.shape)
print("Mask shape:", None if mask is None else mask.shape)


[✔] 加载中：../testdata/dataset1/LUNG1-001
[⚠️] CT 与 mask 尺寸不符：(134, 512, 512) vs (536, 512, 512)
CT shape: (134, 512, 512)
Mask shape: (536, 512, 512)


In [33]:
def get_z_coords(image_sitk):
    """Return the physical Z coordinate of every slice of a SimpleITK image.

    For slice index ``i`` the coordinate is computed as
    ``origin_z + i * spacing_z * direction[2, 2]``, where ``direction`` is the
    image direction reshaped to 3x3.

    Args:
        image_sitk: Object exposing ``GetSize``, ``GetSpacing``, ``GetOrigin``
            and ``GetDirection`` (nine direction values are expected).

    Returns:
        A list with one Z coordinate per slice, in slice-index order.
    """
    num_slices = image_sitk.GetSize()[2]        # size is (W, H, Z)
    slice_spacing = image_sitk.GetSpacing()[2]  # spacing is (sx, sy, sz)
    z_origin = image_sitk.GetOrigin()[2]
    direction_matrix = np.array(image_sitk.GetDirection()).reshape(3, 3)

    # Z component of the third direction column (the slice axis).
    z_axis_z_component = direction_matrix[2, 2]

    z_coords = []
    for slice_idx in range(num_slices):
        z_coords.append(z_origin + slice_idx * slice_spacing * z_axis_z_component)
    return z_coords

In [36]:
# Check whether the CT and segmentation slices follow the same Z ordering
def read_dicom_series(folder_path):
    """Read the DICOM files in a folder into a single 3D SimpleITK image.

    Args:
        folder_path: Directory containing the DICOM files of the series.

    Returns:
        The image produced by ``sitk.ImageSeriesReader.Execute``.
    """
    series_reader = sitk.ImageSeriesReader()
    series_reader.SetFileNames(series_reader.GetGDCMSeriesFileNames(folder_path))
    return series_reader.Execute()

# Load both volumes through SimpleITK, then derive the per-slice Z positions.
ct_sitk = read_dicom_series(ct_folder)
seg_sitk = sitk.ReadImage(seg_path)

ct_z, seg_z = get_z_coords(ct_sitk), get_z_coords(seg_sitk)

# Print the leading slices side by side (at most 10, bounded by the shorter list).
print("\nZ 轴坐标比较（前10层）:")
for i in range(min(len(ct_z), len(seg_z), 10)):
    print(f"Slice {i}: CT Z = {ct_z[i]:.2f}, SEG Z = {seg_z[i]:.2f}")


Z 轴坐标比较（前10层）:
Slice 0: CT Z = -681.50, SEG Z = -681.50
Slice 1: CT Z = -678.50, SEG Z = -680.50
Slice 2: CT Z = -675.50, SEG Z = -679.50
Slice 3: CT Z = -672.50, SEG Z = -678.50
Slice 4: CT Z = -669.50, SEG Z = -677.50
Slice 5: CT Z = -666.50, SEG Z = -676.50
Slice 6: CT Z = -663.50, SEG Z = -675.50
Slice 7: CT Z = -660.50, SEG Z = -674.50
Slice 8: CT Z = -657.50, SEG Z = -673.50
Slice 9: CT Z = -654.50, SEG Z = -672.50


In [37]:
def align_segmentation_by_z(seg_img, seg_z, ct_z, decimals=1):
    """Keep only the segmentation slices whose Z coordinate also occurs in the CT.

    Coordinates are compared after rounding to ``decimals`` decimal places so
    that small floating-point differences do not prevent a match.

    Args:
        seg_img: Segmentation image passed to ``sitk.GetArrayFromImage``.
        seg_z: Z coordinate of each segmentation slice, in slice order.
        ct_z: Z coordinates of the CT slices.
        decimals: Number of decimal places used for the comparison. The
            default of 1 reproduces the previously hard-coded precision.

    Returns:
        The mask array indexed along its first axis by the matching slice
        indices, in segmentation order. When nothing matches, the index list
        is empty.
    """
    # Round once into a set so each segmentation slice is an O(1) lookup.
    ct_z_keys = {round(z, decimals) for z in ct_z}
    matched_indices = [
        i for i, z in enumerate(seg_z) if round(z, decimals) in ct_z_keys
    ]
    mask_array = sitk.GetArrayFromImage(seg_img)
    return mask_array[matched_indices]


In [38]:
# Keep only the segmentation slices whose Z position also occurs in the CT volume.
aligned_mask = align_segmentation_by_z(seg_img=seg_sitk, seg_z=seg_z, ct_z=ct_z)
