In [2]:
import copy
import csv
import functools
import glob
import os
import collections

from collections import namedtuple
from itertools import chain

import SimpleITK as sitk
import numpy as np

import torch
import torch.cuda
from torch.utils.data import Dataset

## `Util`

---

### **# 1** `from util.disk import getCache`

`getCache()` 函数的源码如下:

```python
def getCache(scope_str):
    return FanoutCache('data-unversioned/cache/' + scope_str,
                       disk=GzipDisk,
                       shards=64,
                       timeout=1,
                       size_limit=3e11,
                       # disk_min_file_size=2**20,
                       )
```

其依赖关系包括: `FanoutCache` 和 `GzipDisk`, 其中, `FanoutCache` 可以从 `diskcache` 中导入, 而 `GzipDisk` 需要我们重新定义.

---

#### **# 1.1** `FanoutCache`

In [3]:
from diskcache import FanoutCache, Disk

👆

就本笔记而言, `getCache()` 函数似乎并不是必需的, 所以这里暂且略过.

### **# 2** `from util.util import XyzTuple, xyz2irc`

#### **# 2.1** `XyzTuple`

In [4]:
# Named triple with fields x, y, z (in that order).
XyzTuple = collections.namedtuple(
    typename='XyzTuple',
    field_names=('x', 'y', 'z'),
)

#### **# 2.2** `xyz2irc`

`xyz2irc` 依赖 `IrcTuple`, 所以先在同一个单元格中定义它.

In [5]:
# Named triple with fields index, row, col (in that order).
IrcTuple = collections.namedtuple(
    typename='IrcTuple',
    field_names=('index', 'row', 'col'),
)


def xyz2irc(coord_xyz, origin_xyz, vxSize_xyz, direction_a):
    """Convert an (x, y, z) coordinate into an integer (index, row, col) triple.

    The origin is subtracted from the coordinate, the offset is multiplied by
    the inverse of ``direction_a``, divided element-wise by the voxel size and
    rounded with ``np.round``. The result is in (col, row, index) order and is
    reversed before being returned.

    Args:
        coord_xyz: Sequence of three numbers (x, y, z) to convert.
        origin_xyz: Sequence of three numbers (x, y, z) giving the origin.
        vxSize_xyz: Sequence of three numbers (x, y, z) giving the voxel size.
        direction_a: 3x3 array; it is inverted, so it must be non-singular.

    Returns:
        An ``IrcTuple`` of Python ints.
    """
    offset_a = np.array(coord_xyz) - np.array(origin_xyz)
    scaled_a = (offset_a @ np.linalg.inv(direction_a)) / np.array(vxSize_xyz)
    voxel_cri = np.round(scaled_a)

    # voxel_cri is ordered (col, row, index); IrcTuple wants the reverse.
    index, row, col = (int(voxel_cri[axis]) for axis in (2, 1, 0))
    return IrcTuple(index, row, col)

### **# 3** `from util.logconf import logging`

据 Copilot 最初的说法, `logging` 是被直接导入的，而不是 `logconf` 模块的一部分。所以，如果你想在其他地方使用 `logging`，你应该直接导入它，就像这样：`import logging`。(这一说法并不完整: `logconf` 还修改了 `logging` 的配置, 详见下方的“更新”。)

---

#### 更新

对话如下:

**You:** 有没有可能是这段代码修改了部分关于logging的相关参数?

**Copilot:** 

> 是的，你完全正确。这段代码实际上是在配置Python的`logging`模块。
> 
> 首先，它获取了根日志记录器（`root_logger`），并设置了其日志级别为`INFO`。这意味着只有级别为`INFO`或更高级别的日志消息才会被处理。
> 
> 然后，它移除了所有已经添加到根日志记录器的处理器。这是因为有些库可能会尝试添加自己的处理器，这可能会干扰我们的日志记录。
> 
> 接下来，它定义了一个日志格式字符串（`logfmt_str`），并创建了一个格式器（`formatter`）。这个格式器会按照这个格式字符串来格式化日志消息。
> 
> 然后，它创建了一个流处理器（`streamHandler`），并设置了其格式器和日志级别。这个流处理器会将日志消息输出到标准输出流（通常是控制台）。
> 
> 最后，它将这个流处理器添加到根日志记录器。这样，当我们使用`logging`模块来记录日志时，日志消息就会被这个流处理器处理，并按照我们设置的格式输出到标准输出流。
> 
> 所以，当你在其他地方使用`from util.logconf import logging`时，你实际上是在导入这个已经被配置过的`logging`模块。

由此可见, 以下日志配置代码是必要的:

In [6]:
import logging
import logging.handlers

# Configure the root logger so every logger in the notebook shares one format.
root_logger = logging.getLogger()
root_logger.setLevel(logging.INFO)

# Other libraries sometimes install their own handlers on the root logger.
# Remove them all; iterate over a snapshot because removeHandler mutates the
# underlying list.
for handler in tuple(root_logger.handlers):
    root_logger.removeHandler(handler)

# Timestamp, level, process id, logger name, line number, function, message.
logfmt_str = "%(asctime)s %(levelname)-8s pid:%(process)d %(name)s:%(lineno)03d:%(funcName)s %(message)s"
formatter = logging.Formatter(fmt=logfmt_str)

# StreamHandler() is created without a stream argument. The handler itself
# accepts DEBUG and above, but the root logger level (INFO) filters first.
streamHandler = logging.StreamHandler()
streamHandler.setLevel(logging.DEBUG)
streamHandler.setFormatter(formatter)

root_logger.addHandler(streamHandler)

## `Dataset`

In [7]:
# Root directory of the LUNA16 data (contains annotations.csv, candidates.csv
# and the subsetN/ folders). Set the LUNA16_DIR environment variable to point
# somewhere else; without it the original location is used.
# The rest of the notebook builds file names with `path + '...'`, so
# os.path.join(..., '') guarantees a trailing path separator.
path = os.path.join(
    os.environ.get('LUNA16_DIR', '/home/ben/Datasets/luna16/'),
    '',
)
# Number of subset folders to scan: subset0 .. subset{sublist - 1}.
sublist = 2

In [8]:
# One candidate record. The field order matters: lists of these tuples are
# sorted directly, so isNodule_bool and then diameter_mm drive the ordering.
CandidateInfoTuple = namedtuple(
    typename='CandidateInfoTuple',
    field_names='isNodule_bool, diameter_mm, series_uid, center_xyz',
)

In [9]:
def getCandidateInfoList(requireOnDisk_bool=True):
    """Build the list of candidates described by the LUNA16 CSV files.

    Reads ``annotations.csv`` and ``candidates.csv`` from the module-level
    ``path`` directory and looks for ``.mhd`` files in ``subset0`` ..
    ``subset{sublist - 1}`` beneath it.

    A candidate receives the diameter of the first annotation of the same
    series whose centre lies within a quarter of that annotation's diameter
    along every axis; otherwise its diameter is 0.0.

    Args:
        requireOnDisk_bool: When true (the default), candidates whose series
            has no ``.mhd`` file in the scanned subsets are skipped.

    Returns:
        A list of ``CandidateInfoTuple`` sorted in descending order, so
        nodules come first and, among equal flags, larger diameters first.
    """
    # A series UID is the .mhd file name without its extension.
    presentOnDisk_set = {
        os.path.splitext(os.path.basename(mhd_path))[0]
        for i in range(sublist)
        for mhd_path in glob.glob(path + f'subset{i}/*.mhd')
    }

    # series_uid -> [(annotation centre xyz, annotation diameter), ...]
    diameter_dict = {}
    # newline='' is how the csv module asks files to be opened.
    with open(path + 'annotations.csv', "r", newline='') as f:
        reader = csv.reader(f)
        next(reader, None)  # Skip the header row; stream the rest.
        for row in reader:
            series_uid = row[0]
            annotationCenter_xyz = tuple(float(x) for x in row[1:4])
            annotationDiameter_mm = float(row[4])

            diameter_dict.setdefault(series_uid, []).append(
                (annotationCenter_xyz, annotationDiameter_mm)
            )

    candidateInfo_list = []
    with open(path + 'candidates.csv', "r", newline='') as f:
        reader = csv.reader(f)
        next(reader, None)  # Skip the header row; stream the rest.
        for row in reader:
            series_uid = row[0]

            if requireOnDisk_bool and series_uid not in presentOnDisk_set:
                continue

            isNodule_bool = bool(int(row[4]))
            candidateCenter_xyz = tuple(float(x) for x in row[1:4])

            candidateDiameter_mm = 0.0
            for annotationCenter_xyz, annotationDiameter_mm in diameter_dict.get(series_uid, ()):
                tolerance_mm = annotationDiameter_mm / 4
                # Written as "no axis is too far" (rather than "all axes are
                # close") so the comparison is exactly `delta > tolerance`.
                tooFar_bool = any(
                    abs(candidate_val - annotation_val) > tolerance_mm
                    for candidate_val, annotation_val
                    in zip(candidateCenter_xyz, annotationCenter_xyz)
                )
                if not tooFar_bool:
                    candidateDiameter_mm = annotationDiameter_mm
                    break

            candidateInfo_list.append(CandidateInfoTuple(
                isNodule_bool,
                candidateDiameter_mm,
                series_uid,
                candidateCenter_xyz,
            ))

    candidateInfo_list.sort(reverse=True)
    return candidateInfo_list

In [10]:
class Ct:
    """A single CT series loaded from its ``.mhd`` file.

    Attributes:
        series_uid: The identifier passed to the constructor.
        hu_a: float32 array of voxel values, clipped in place to [-1000, 1000].
        origin_xyz: ``XyzTuple`` built from the image's ``GetOrigin()``.
        vxSize_xyz: ``XyzTuple`` built from the image's ``GetSpacing()``.
        direction_a: 3x3 array built from the image's ``GetDirection()``.
    """

    def __init__(self, series_uid):
        """Find ``{series_uid}.mhd`` in the configured subsets and load it.

        Args:
            series_uid: File name (without extension) of the series to load.

        Raises:
            IndexError: If no matching ``.mhd`` file exists in any of the
                scanned ``subset{i}`` folders.
        """
        mhd_path = list(
            chain.from_iterable(
                glob.glob(
                    path + f'subset{i}/{series_uid}.mhd'
                ) for i in range(sublist)
            )
        )[0]
        ct_mhd = sitk.ReadImage(mhd_path)
        ct_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)

        # Clip in place (third argument is the output array).
        ct_a.clip(-1000, 1000, ct_a)

        self.series_uid = series_uid
        self.hu_a = ct_a

        self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
        self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
        self.direction_a = np.array(ct_mhd.GetDirection()).reshape(3, 3)

    def getRawCandidate(self, center_xyz, width_irc):
        """Cut a box of voxels around a candidate centre.

        The box is shifted as needed so that it stays inside the volume. If
        the requested width is larger than the volume along some axis, the
        chunk covers that whole axis and is therefore smaller than requested.

        Args:
            center_xyz: (x, y, z) centre; converted with ``xyz2irc`` using
                this CT's origin, voxel size and direction.
            width_irc: Requested box size along (index, row, col).

        Returns:
            ``(ct_chunk, center_irc)`` where ``ct_chunk`` is a slice of
            ``hu_a`` (not a copy) and ``center_irc`` is the converted centre.
        """
        center_irc = xyz2irc(
            center_xyz,
            self.origin_xyz,
            self.vxSize_xyz,
            self.direction_a,
        )

        slice_list = []
        for axis, center_val in enumerate(center_irc):
            axis_len = self.hu_a.shape[axis]
            start_ndx = int(round(center_val - width_irc[axis]/2))
            end_ndx = int(start_ndx + width_irc[axis])

            if start_ndx < 0:
                start_ndx = 0
                end_ndx = int(width_irc[axis])

            if end_ndx > axis_len:
                end_ndx = axis_len
                # Clamp at 0: a negative start would be read as "count from
                # the end" and silently truncate the chunk when the requested
                # width exceeds the axis length.
                start_ndx = max(0, int(axis_len - width_irc[axis]))

            slice_list.append(slice(start_ndx, end_ndx))

        ct_chunk = self.hu_a[tuple(slice_list)]

        return ct_chunk, center_irc

In [11]:
@functools.lru_cache(1, typed=True)
def getCt(series_uid):
    """Return the ``Ct`` for ``series_uid``, reusing the last one loaded.

    Building a ``Ct`` reads the whole volume from disk. A single-entry cache
    means consecutive requests for the same series share one load, while at
    most one volume is kept alive by the cache.

    Because the same ``Ct`` object is handed out again on a cache hit, callers
    should treat it (and array slices taken from it) as read-only.

    Args:
        series_uid: Hashable identifier of the series to load.

    Returns:
        The ``Ct`` instance for that series.
    """
    return Ct(series_uid)

In [12]:
def getCtRawCandidate(series_uid, center_xyz, width_irc):
    """Fetch the CT for ``series_uid`` and cut a candidate box out of it.

    Args:
        series_uid: Identifier passed to ``getCt``.
        center_xyz: Candidate centre, forwarded to ``Ct.getRawCandidate``.
        width_irc: Requested box size, forwarded to ``Ct.getRawCandidate``.

    Returns:
        The ``(ct_chunk, center_irc)`` pair produced by ``getRawCandidate``.
    """
    raw_chunk, voxel_center = getCt(series_uid).getRawCandidate(center_xyz, width_irc)
    return (raw_chunk, voxel_center)

In [13]:
class LunaDataset(Dataset):
    """Dataset of candidate crops with one-hot nodule labels.

    Each item is a 4-tuple:
      * the candidate crop as a float32 tensor with a leading dimension of 1,
      * a long tensor ``[not isNodule, isNodule]``,
      * the series UID,
      * the candidate centre in (index, row, col) as a tensor.
    """

    def __init__(self,
                 val_stride=0,
                 isValSet_bool=None,
                 series_uid=None,
            ):
        """Select the candidates this dataset serves.

        Args:
            val_stride: When positive, every ``val_stride``-th candidate
                belongs to the validation split.
            isValSet_bool: When truthy, keep only the validation split
                (requires ``val_stride > 0``); otherwise drop it when
                ``val_stride`` is positive.
            series_uid: When truthy, keep only candidates of this series.
        """
        # Shallow copy so the filtering below never touches the source list.
        candidates = copy.copy(getCandidateInfoList())

        if series_uid:
            candidates = [
                info for info in candidates if info.series_uid == series_uid
            ]

        if isValSet_bool:
            assert val_stride > 0, val_stride
            candidates = candidates[::val_stride]
            assert candidates
        elif val_stride > 0:
            del candidates[::val_stride]
            assert candidates

        self.candidateInfo_list = candidates

    def __len__(self):
        """Number of candidates kept by the constructor."""
        return len(self.candidateInfo_list)

    def __getitem__(self, ndx):
        """Load the crop and label for candidate ``ndx``."""
        info = self.candidateInfo_list[ndx]

        chunk_a, center_irc = getCtRawCandidate(
            info.series_uid,
            info.center_xyz,
            (32, 48, 48),  # crop size along (index, row, col)
        )

        # Add a leading dimension of size 1 in front of the crop.
        candidate_t = torch.from_numpy(chunk_a).to(torch.float32).unsqueeze(0)

        is_nodule = info.isNodule_bool
        pos_t = torch.tensor([not is_nodule, is_nodule], dtype=torch.long)

        return (
            candidate_t,
            pos_t,
            info.series_uid,
            torch.tensor(center_irc),
        )

## `RUN dataset()`

In [14]:
# Build the dataset with default arguments: no series filter and no
# validation stride, so every candidate from getCandidateInfoList() is kept.
ds = LunaDataset()

In [20]:
# Sanity check: shape of the first sample's candidate tensor
# (a leading dimension of size 1 followed by the 32 x 48 x 48 crop).
ds[0][0].shape

torch.Size([1, 32, 48, 48])