## p3. 完成LunaDataset

In [2]:
import sys

# Extend the import search path so the `util.*` modules imported in the next
# cell can be resolved (presumably they live under ../../src/ — verify locally).
sys.path.append("../../src/")

In [18]:
import torch
from torch.utils.data import Dataset
from util.logconf import logging
from util.util import XyzTuple, xyz2irc
from collections import namedtuple, defaultdict
import csv
import glob
import numpy as np
import SimpleITK as sitk
import os
import pandas as pd
import copy

In [4]:
# Notebook-level logger; `logging` is the name imported from util.logconf above.
log = logging.getLogger(__name__)
# Verbosity switch: the commented lines are the quieter alternatives.
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)

In [25]:
# Record describing one nodule candidate. Field order is significant: tuples
# compare element by element, so sorting a list of these orders by is_nodule
# first, then diameter_mm, then series_uid, then center_xyz.
CandidateInfoTuple = namedtuple(
    "CandidateInfoTuple",
    ["is_nodule", "diameter_mm", "series_uid", "center_xyz"],
)

In [6]:
# Root directory of the LUNA16 data. The code below looks for subset*/*.mhd
# files as well as annotations.csv and candidates.csv directly under it.
# The value is machine-specific and must keep its trailing slash, because the
# file names are appended with plain string concatenation.
path = "D:/Code/data/luna16/"

In [33]:
def getCandidateInfoList(require_on_disk_bool=True):
    """Build the list of nodule candidates, ordered from most to least relevant.

    Reads ``annotations.csv`` and ``candidates.csv`` from the module-level
    ``path`` directory. Columns are accessed by position: column 0 is the
    series UID, columns 1-3 are the xyz centre, and column 4 is the diameter
    (annotations) or the nodule flag (candidates).

    A candidate takes the diameter of the first annotation of the same series
    whose centre lies within a quarter of that annotation's diameter on every
    axis; candidates without such a match get a diameter of 0.0.

    Args:
        require_on_disk_bool: when true, candidates whose series has no
            ``.mhd`` file under ``path + "subset*/"`` are skipped. This allows
            working with only some of the subsets extracted.

    Returns:
        A list of ``CandidateInfoTuple`` sorted in descending tuple order:
        nodules (``is_nodule`` true) come first, and within each group larger
        ``diameter_mm`` values come first.
    """
    # Series UIDs that actually have an .mhd file on disk (file name minus extension).
    mhd_list = glob.glob(path + "subset*/*.mhd")
    present_on_disk_set = {
        os.path.splitext(os.path.basename(mhd_path))[0] for mhd_path in mhd_list
    }

    # series_uid -> [(center_xyz, diameter_mm), ...]
    diameter_dict = defaultdict(list)
    annotations_df = pd.read_csv(path + "annotations.csv")
    # itertuples yields plain tuples and avoids building a Series per row,
    # which is what makes iterrows slow on large tables.
    for row in annotations_df.itertuples(index=False, name=None):
        series_uid = row[0]
        annotation_center_xyz = tuple(float(x) for x in row[1:4])
        annotation_diameter_mm = float(row[4])
        diameter_dict[series_uid].append(
            (annotation_center_xyz, annotation_diameter_mm)
        )

    candidate_info_list = []
    candidate_df = pd.read_csv(path + "candidates.csv")
    for row in candidate_df.itertuples(index=False, name=None):
        series_uid = row[0]
        if require_on_disk_bool and series_uid not in present_on_disk_set:
            continue

        is_nodule = bool(int(row[4]))
        candidate_center_xyz = tuple(float(x) for x in row[1:4])

        candidate_diameter_mm = 0.0
        # .get() instead of indexing: a lookup must not insert an empty list
        # into the defaultdict for every series that has no annotations.
        for annotation_center_xyz, annotation_diameter_mm in diameter_dict.get(
            series_uid, ()
        ):
            tolerance_mm = annotation_diameter_mm / 4
            too_far = any(
                abs(candidate_val - annotation_val) > tolerance_mm
                for candidate_val, annotation_val in zip(
                    candidate_center_xyz, annotation_center_xyz
                )
            )
            if not too_far:
                # Same nodule: adopt the annotated diameter and stop searching.
                candidate_diameter_mm = annotation_diameter_mm
                break

        candidate_info_list.append(
            CandidateInfoTuple(
                is_nodule,
                candidate_diameter_mm,
                series_uid,
                candidate_center_xyz,
            )
        )

    # Named tuples compare field by field and True > False, so a descending
    # sort puts nodules first and, within each group, larger diameters first.
    candidate_info_list.sort(reverse=True)
    return candidate_info_list

In [32]:
# Sanity check of the bool ordering the descending sort relies on:
# this prints False, i.e. True ranks above False.
print(True < False)

False


In [13]:
class Ct:
    """A single CT series loaded into memory, with a helper to crop around a point.

    Attributes:
        series_uid: the UID passed to the constructor.
        hu_arr: float32 voxel array, clamped to the range [-1000, 1000].
        origin_xyz: XyzTuple built from the image's GetOrigin().
        vxSize_xyz: XyzTuple built from the image's GetSpacing().
        direction_arr: 3x3 array reshaped from the image's GetDirection().
    """

    def __init__(self, series_uid):
        """Read the ``.mhd`` file of ``series_uid`` found under ``path + "subset*/"``.

        Raises:
            IndexError: if no ``.mhd`` file matches ``series_uid``.
        """
        # Locate the .mhd file for this series in whichever subset holds it.
        mhd_path = glob.glob(path + f"subset*/{series_uid}.mhd")[0]

        ct_mhd = sitk.ReadImage(mhd_path)
        ct_arr = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)

        # Clamp the HU values in place: air is around -1000 HU and bone around
        # +1000 HU, so values outside that range are not of interest.
        ct_arr.clip(-1000, 1000, out=ct_arr)

        self.series_uid = series_uid
        self.hu_arr = ct_arr

        self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
        self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
        self.direction_arr = np.array(ct_mhd.GetDirection()).reshape(3, 3)

    def getRawCandidate(self, center_xyz, width_irc):
        """Crop a block of ``hu_arr`` around ``center_xyz``.

        Args:
            center_xyz: centre of the crop; it is converted to array indices
                with ``xyz2irc`` using this CT's origin, voxel size and
                direction matrix.
            width_irc: crop size per array axis, in the same axis order as
                ``hu_arr``.

        Returns:
            ``(ct_chunk, center_irc)``: the cropped block and the centre as
            returned by ``xyz2irc``. A crop that would extend past an edge is
            shifted inwards so that it keeps the requested width; if the width
            exceeds the axis length, the whole axis is returned.

        Raises:
            AssertionError: if the converted centre lies outside ``hu_arr``.
        """
        center_irc = xyz2irc(
            center_xyz,
            self.origin_xyz,
            self.vxSize_xyz,
            self.direction_arr,
        )

        slice_list = []
        for axis, center_val in enumerate(center_irc):
            axis_len = self.hu_arr.shape[axis]
            start_idx = int(round(center_val - width_irc[axis] / 2))
            end_idx = int(start_idx + width_irc[axis])

            assert center_val >= 0 and center_val < axis_len, repr(
                [
                    self.series_uid,
                    center_xyz,
                    self.origin_xyz,
                    self.vxSize_xyz,
                    center_irc,
                    axis,
                ]
            )

            # The crop may stick out of the volume; shift it back inside.
            if start_idx < 0:
                start_idx = 0
                end_idx = int(width_irc[axis])

            if end_idx > axis_len:
                end_idx = axis_len
                # Clamp at 0: when the crop is wider than the axis, a negative
                # start would be read by NumPy as "count from the end" and
                # silently select the wrong region.
                start_idx = max(0, int(axis_len - width_irc[axis]))

            slice_list.append(slice(start_idx, end_idx))

        ct_chunk = self.hu_arr[tuple(slice_list)]

        return ct_chunk, center_irc

In [14]:
def getCt(series_uid):
    """Construct and return a new ``Ct`` for ``series_uid``; nothing is cached."""
    ct = Ct(series_uid)
    return ct

In [15]:
def getCtRawCandidate(series_uid, center_xyz, width_irc):
    """Return ``(ct_chunk, center_irc)`` for a crop of one series.

    Obtains the CT through ``getCt(series_uid)`` and forwards ``center_xyz``
    and ``width_irc`` to its ``getRawCandidate`` method.
    """
    return getCt(series_uid).getRawCandidate(center_xyz, width_irc)

In [28]:
class LunaDataset(Dataset):
    """Dataset of fixed-size CT crops centred on nodule candidates.

    Args:
        val_stride: step used to split off validation samples. With
            ``is_val`` true, every ``val_stride``-th candidate is kept; with
            ``is_val`` false and ``val_stride > 0``, those same candidates are
            removed. ``0`` disables the split.
        is_val: whether this instance is the validation part of the split.
        series_uid: when given, only candidates of that series are kept.

    Raises:
        AssertionError: if ``is_val`` is true while ``val_stride`` is not
            positive, or if the split leaves no samples.
    """

    def __init__(self, val_stride=0, is_val=None, series_uid=None):
        # Shallow copy so the slicing/deleting below works on this instance's own list.
        self.candidate_info_list = copy.copy(getCandidateInfoList())

        if series_uid:
            self.candidate_info_list = [
                x for x in self.candidate_info_list if x.series_uid == series_uid
            ]

        # Split into training and validation samples.
        if is_val:
            assert val_stride > 0, val_stride
            self.candidate_info_list = self.candidate_info_list[::val_stride]
            assert self.candidate_info_list
        elif val_stride > 0:
            del self.candidate_info_list[::val_stride]
            assert self.candidate_info_list

        # %-style arguments produce the same message as the former f-string,
        # which reused double quotes inside the f-string and therefore only
        # parsed on Python 3.12 or newer.
        log.info(
            "%r: %s %s samples",
            self,
            len(self.candidate_info_list),
            "validation" if is_val else "training",
        )

    def __len__(self):
        """Return the number of candidates in this dataset."""
        return len(self.candidate_info_list)

    def __getitem__(self, idx):
        """Return ``(candidate_t, label_t, series_uid, center_irc_t)`` for one candidate.

        ``candidate_t`` is the float32 crop with a size-1 dimension added in
        front. ``label_t`` is the int64 pair ``[not is_nodule, is_nodule]``.
        ``series_uid`` is the candidate's series UID, and ``center_irc_t`` is
        the crop centre returned by ``getCtRawCandidate``, as a tensor.
        """
        candidate_info_tuple = self.candidate_info_list[idx]
        # Fixed crop size, in the axis order expected by getCtRawCandidate.
        width_irc = (32, 48, 48)

        candidate_arr, center_irc = getCtRawCandidate(
            candidate_info_tuple.series_uid,
            candidate_info_tuple.center_xyz,
            width_irc,
        )

        candidate_t = torch.from_numpy(candidate_arr).to(torch.float32).unsqueeze(0)

        # One-hot class label (not network output): index 0 = not a nodule, index 1 = nodule.
        label_t = torch.tensor(
            [
                not candidate_info_tuple.is_nodule,
                candidate_info_tuple.is_nodule,
            ],
            dtype=torch.int64,
        )

        return (
            candidate_t,
            label_t,
            candidate_info_tuple.series_uid,
            torch.tensor(center_irc),
        )

In [23]:
# Alternative: restrict the dataset to the candidates of a single series.
# ds = LunaDataset(series_uid="1.3.6.1.4.1.14519.5.2.1.6279.6001.100684836163890911914061745866")
# Default arguments: no series filter and no validation split (val_stride=0).
ds = LunaDataset()

2024-06-28 11:15:11,068 INFO     pid:27028 __main__:019:__init__ <__main__.LunaDataset object at 0x0000017D769818E0>: 110143 training samples


In [27]:
# Index once and reuse the result: every ds[i] access goes through
# getCtRawCandidate, which loads the whole CT volume from disk again.
sample = ds[0]
sample[0].shape, sample[1], sample[2], sample[3]

(torch.Size([1, 32, 48, 48]),
 tensor([0, 1]),
 '1.3.6.1.4.1.14519.5.2.1.6279.6001.287966244644280690737019247886',
 tensor([ 91, 360, 341]))

**需要注意的是：由于上面 candidate_info_list 的排序方式，同一个 series_uid 的候选样本往往会连续出现。按固定步长 val_stride 划分时，同一个 CT 的样本可能同时落入训练集和验证集，因此需要考虑是否应当按 series_uid 来划分数据集，以防出现数据泄露的情况。**

In [34]:
# The series UID of each sample is already stored in the candidate list, so
# read it from there; going through ds[i] would load a full CT per element.
tuple(info.series_uid for info in ds.candidate_info_list[:6])

('1.3.6.1.4.1.14519.5.2.1.6279.6001.287966244644280690737019247886',
 '1.3.6.1.4.1.14519.5.2.1.6279.6001.511347030803753100045216493273',
 '1.3.6.1.4.1.14519.5.2.1.6279.6001.179049373636438705059720603192',
 '1.3.6.1.4.1.14519.5.2.1.6279.6001.179049373636438705059720603192',
 '1.3.6.1.4.1.14519.5.2.1.6279.6001.179049373636438705059720603192',
 '1.3.6.1.4.1.14519.5.2.1.6279.6001.179049373636438705059720603192')