<a href="https://www.kaggle.com/code/waylandwong/lung-cancer?scriptVersionId=140248912" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# 肺癌检测

In [None]:
pip install torch torchvision

In [None]:
pip install diskcache cassandra-driver

## 预先加载磁盘中的CT和标注数据

In [None]:
import copy
import csv
import functools
import glob
import os

from collections import namedtuple

import numpy as np

import torch
import torch.cuda
from torch.utils.data import Dataset


# Record for a single candidate location; the fields are filled in by
# getCandidateInfoList in exactly this order.
CandidateInfoTuple = namedtuple(
    'CandidateInfoTuple',
    ['isNodule_bool', 'diameter_mm', 'series_uid', 'center_xyz'],
)

# Load the candidate-location records (the result is cached after the first call).
@functools.lru_cache(1)
def getCandidateInfoList(requireOnDisk_bool=True):
    """Build the list of candidate records from the dataset's CSV files.

    ``annotations.csv`` supplies (center, diameter) pairs per series and
    ``candidates.csv`` supplies the candidates themselves. Each candidate
    takes the diameter of the first annotation of the same series whose
    center is within a quarter of that annotation's diameter on every
    axis, or ``0.0`` when no annotation matches.

    Because of ``lru_cache``, repeated calls with the same argument return
    the very same list object; callers should not mutate it in place.

    :param requireOnDisk_bool: when true, candidates whose series has no
        ``.mhd`` file in the dataset directory are skipped.
    :return: list of ``CandidateInfoTuple`` sorted in descending tuple
        order, i.e. nodules first, then by diameter, series_uid and center.
    """
    # We construct a set with all series_uids that are present on disk.
    # This will let us use the data, even if we haven't downloaded all of
    # the subsets yet.
    mhd_list = glob.glob('/kaggle/input/luna-lung-cancer-dataset/seg-lungs-LUNA16/seg-lungs-LUNA16/*.mhd')
    # Strip the directory and the 4-character ".mhd" suffix.
    presentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}

    diameter_dict = {}
    with open('/kaggle/input/luna-lung-cancer-dataset/annotations.csv', "r", newline='') as f:
        annotation_reader = csv.reader(f)
        next(annotation_reader, None)  # skip the header row
        for row in annotation_reader:
            series_uid = row[0]
            annotationCenter_xyz = tuple(float(x) for x in row[1:4])
            annotationDiameter_mm = float(row[4])

            diameter_dict.setdefault(series_uid, []).append(
                (annotationCenter_xyz, annotationDiameter_mm)
            )

    candidateInfo_list = []
    with open('/kaggle/input/luna-lung-cancer-dataset/candidates.csv', "r", newline='') as f:
        # Stream the rows instead of materialising the whole file in memory
        # just to drop the header.
        candidate_reader = csv.reader(f)
        next(candidate_reader, None)  # skip the header row
        for row in candidate_reader:
            series_uid = row[0]

            if requireOnDisk_bool and series_uid not in presentOnDisk_set:
                continue

            isNodule_bool = bool(int(row[4]))
            candidateCenter_xyz = tuple(float(x) for x in row[1:4])

            candidateDiameter_mm = 0.0
            for annotationCenter_xyz, annotationDiameter_mm in diameter_dict.get(series_uid, ()):
                tolerance_mm = annotationDiameter_mm / 4
                # An annotation matches only if no axis deviates by more
                # than the tolerance; the first match wins.
                too_far = any(
                    abs(candidate_val - annotation_val) > tolerance_mm
                    for candidate_val, annotation_val
                    in zip(candidateCenter_xyz, annotationCenter_xyz)
                )
                if not too_far:
                    candidateDiameter_mm = annotationDiameter_mm
                    break

            candidateInfo_list.append(CandidateInfoTuple(
                isNodule_bool,
                candidateDiameter_mm,
                series_uid,
                candidateCenter_xyz,
            ))

    # Descending sort over the whole tuple: the nodule flag is the primary
    # key, remaining fields break ties.
    candidateInfo_list.sort(reverse=True)
    return candidateInfo_list


## logconf.py

In [None]:
import logging
import logging.handlers

# Root logger: everything at INFO and above is let through.
root_logger = logging.getLogger()
root_logger.setLevel(logging.INFO)

# Some libraries install their own handlers on the root logger. Drop
# whatever is attached at this point so that only the handler added
# below remains.
for stale_handler in root_logger.handlers[:]:
    root_logger.removeHandler(stale_handler)

# Timestamp, level, pid, logger name, line number, function, message.
logfmt_str = (
    "%(asctime)s %(levelname)-8s pid:%(process)d "
    "%(name)s:%(lineno)03d:%(funcName)s %(message)s"
)
formatter = logging.Formatter(logfmt_str)

# The handler itself accepts DEBUG; the effective threshold is decided by
# the logger levels.
streamHandler = logging.StreamHandler()
streamHandler.setLevel(logging.DEBUG)
streamHandler.setFormatter(formatter)

root_logger.addHandler(streamHandler)


### util.py

In [None]:
import collections
import copy
import datetime
import gc
import time

# import torch
import numpy as np

import logging
# Logger for this module's helpers. The commented-out lines are alternative
# verbosity levels kept for quick switching; DEBUG is the active one.
log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)

# Voxel coordinate in array order (index, row, col) and physical coordinate
# (x, y, z). Note the two orderings are reversed relative to each other.
IrcTuple = collections.namedtuple('IrcTuple', ['index', 'row', 'col'])
XyzTuple = collections.namedtuple('XyzTuple', ['x', 'y', 'z'])

def irc2xyz(coord_irc, origin_xyz, vxSize_xyz, direction_a):
    """Convert a voxel (index, row, col) coordinate to physical (x, y, z).

    Computes ``direction_a @ (cri * vxSize) + origin`` where ``cri`` is
    ``coord_irc`` reversed into (col, row, index) order.

    :param coord_irc: voxel coordinate as (index, row, col).
    :param origin_xyz: physical position of voxel (0, 0, 0), as (x, y, z).
    :param vxSize_xyz: voxel size along x, y and z.
    :param direction_a: 3x3 direction matrix.
    :return: ``XyzTuple`` holding the physical coordinate.
    """
    # Reverse to (col, row, index) so the axes line up with (x, y, z).
    cri_a = np.array(coord_irc)[::-1]
    origin_a = np.array(origin_xyz)
    vxSize_a = np.array(vxSize_xyz)
    coords_xyz = (direction_a @ (cri_a * vxSize_a)) + origin_a
    return XyzTuple(*coords_xyz)

def xyz2irc(coord_xyz, origin_xyz, vxSize_xyz, direction_a):
    """Convert a physical (x, y, z) coordinate to the nearest voxel index.

    This is the inverse of ``irc2xyz`` followed by rounding to integers.

    :param coord_xyz: physical coordinate as (x, y, z).
    :param origin_xyz: physical position of voxel (0, 0, 0), as (x, y, z).
    :param vxSize_xyz: voxel size along x, y and z.
    :param direction_a: 3x3 direction matrix (must be invertible).
    :return: ``IrcTuple`` of Python ints in (index, row, col) order.
    """
    origin_a = np.array(origin_xyz)
    vxSize_a = np.array(vxSize_xyz)
    coord_a = np.array(coord_xyz)
    # Undo irc2xyz step by step: remove the offset, apply the inverse
    # direction matrix, then divide out the voxel size. The inverse must be
    # applied from the left; right-multiplying a 1-D vector would apply its
    # transpose, which only agrees for symmetric direction matrices.
    cri_a = (np.linalg.inv(direction_a) @ (coord_a - origin_a)) / vxSize_a
    cri_a = np.round(cri_a)
    # Back from (col, row, index) to (index, row, col).
    return IrcTuple(int(cri_a[2]), int(cri_a[1]), int(cri_a[0]))


def importstr(module_str, from_=None):
    """
    Import a module, or one attribute of a module, named by a string.

    >>> importstr('os')
    <module 'os' from '.../os.pyc'>
    >>> importstr('math', 'fabs')
    <built-in function fabs>

    :param module_str: dotted module path. When ``from_`` is not given it
        may carry the attribute name after a colon, e.g. ``'math:fabs'``.
    :param from_: optional name of the attribute to fetch from the module.
    :return: the module, or the requested attribute of it.
    :raises ImportError: if the module cannot be imported or does not
        have the requested attribute.
    """
    if from_ is None and ':' in module_str:
        # Split on the last colon only, so the unpacking always receives
        # exactly two parts.
        module_str, from_ = module_str.rsplit(':', 1)

    # __import__ returns the top-level package for dotted names, so walk
    # down the remaining components to reach the leaf module.
    module = __import__(module_str)
    for sub_str in module_str.split('.')[1:]:
        module = getattr(module, sub_str)

    if from_:
        try:
            return getattr(module, from_)
        except AttributeError as err:
            raise ImportError('{}.{}'.format(module_str, from_)) from err
    return module


# class dotdict(dict):
#     '''dict where key can be access as attribute d.key -> d[key]'''
#     @classmethod
#     def deep(cls, dic_obj):
#         '''Initialize from dict with deep conversion'''
#         return cls(dic_obj).deepConvert()
#
#     def __getattr__(self, attr):
#         if attr in self:
#             return self[attr]
#         log.error(sorted(self.keys()))
#         raise AttributeError(attr)
#         #return self.get(attr, None)
#     __setattr__= dict.__setitem__
#     __delattr__= dict.__delitem__
#
#
#     def __copy__(self):
#         return dotdict(self)
#
#     def __deepcopy__(self, memo):
#         new_dict = dotdict()
#         for k, v in self.items():
#             new_dict[k] = copy.deepcopy(v, memo)
#         return new_dict
#
#     # pylint: disable=multiple-statements
#     def __getstate__(self): return self.__dict__
#     def __setstate__(self, d): self.__dict__.update(d)
#
#     def deepConvert(self):
#         '''Convert all dicts at all tree levels into dotdict'''
#         for k, v in self.items():
#             if type(v) is dict: # pylint: disable=unidiomatic-typecheck
#                 self[k] = dotdict(v)
#                 self[k].deepConvert()
#             try: # try enumerable types
#                 for m, x in enumerate(v):
#                     if type(x) is dict: # pylint: disable=unidiomatic-typecheck
#                         x = dotdict(x)
#                         x.deepConvert()
#                         v[m] = x#

#             except TypeError:
#                 pass
#         return self
#
#     def copy(self):
#         # override dict.copy()
#         return dotdict(self)


def prhist(ary, prefix_str=None, **kwargs):
    """Print a plain-text histogram of ``ary`` to stdout.

    One line is printed per bin with its left edge and its count, followed
    by a final line carrying only the right-most edge.

    :param ary: data handed to ``np.histogram``.
    :param prefix_str: optional text put (plus one space) before each line.
    :param kwargs: forwarded unchanged to ``np.histogram``.
    """
    lead_str = '' if prefix_str is None else prefix_str + ' '

    count_ary, bins_ary = np.histogram(ary, **kwargs)
    # There is one more edge than there are counts; zip stops at the last
    # count and the remaining edge is printed on its own below.
    for left_edge, bin_count in zip(bins_ary, count_ary):
        edge_str = "{}{:-8.2f}".format(lead_str, left_edge)
        count_str = "{:-10}".format(bin_count)
        print(edge_str, count_str)
    print("{}{:-8.2f}".format(lead_str, bins_ary[-1]))

# def dumpCuda():
#     # small_count = 0
#     total_bytes = 0
#     size2count_dict = collections.defaultdict(int)
#     size2bytes_dict = {}
#     for obj in gc.get_objects():
#         if isinstance(obj, torch.cuda._CudaBase):
#             nbytes = 4
#             for n in obj.size():
#                 nbytes *= n
#
#             size2count_dict[tuple([obj.get_device()] + list(obj.size()))] += 1
#             size2bytes_dict[tuple([obj.get_device()] + list(obj.size()))] = nbytes
#
#             total_bytes += nbytes
#
#     # print(small_count, "tensors equal to or less than than 16 bytes")
#     for size, count in sorted(size2count_dict.items(), key=lambda sc: (size2bytes_dict[sc[0]] * sc[1], sc[1], sc[0])):
#         print('{:4}x'.format(count), '{:10,}'.format(size2bytes_dict[size]), size)
#     print('{:10,}'.format(total_bytes), "total bytes")


def enumerateWithEstimate(
        iter,
        desc_str,
        start_ndx=0,
        print_ndx=4,
        backoff=None,
        iter_len=None,
):
    """
    Generator that yields the same ``(index, item)`` pairs as the standard
    ``enumerate``, while logging progress and an estimated completion
    time as a side effect.

    A "starting" message is logged when iteration begins and a "done"
    message when the iterable is exhausted (both at WARNING level);
    estimates in between are logged at INFO level.

    :param iter: the iterable to enumerate. Required. It must support
        ``len()`` unless `iter_len` is given.

    :param desc_str: short human-readable description of what the loop
        is doing, used as the prefix of every log message, e.g.
        `"epoch 4 training"` or `"deleting temp files"`.

    :param start_ndx: number of leading iterations to leave out of the
        timing. Useful when the first iterations pay one-off startup
        costs (such as filling a cache) that would skew the average
        time per iteration.

        NOTE: the time spent in the skipped iterations is not included
        in the displayed duration.

        Defaults to `0`.

    :param print_ndx: index of the first iteration after which an
        estimate is logged. It is multiplied by `backoff` until it is
        at least `start_ndx * backoff`, so that no estimate is printed
        from the untimed early iterations.

        Defaults to `4`.

    :param backoff: factor by which the index of the next log message
        grows after each estimate, so messages get sparser over time.
        Must be at least 2.

        When not given it starts at `2` and is doubled for as long as
        `backoff ** 7` is smaller than `iter_len`.

    :param iter_len: number of items, needed to extrapolate the total
        duration. Defaults to `len(iter)`.

    :return: generator of `(index, item)` tuples.
    """
    if iter_len is None:
        iter_len = len(iter)

    if backoff is None:
        backoff = 2
        while backoff ** 7 < iter_len:
            backoff *= 2

    assert backoff >= 2
    while print_ndx < start_ndx * backoff:
        print_ndx *= backoff

    log.warning("{} ----/{}, starting".format(desc_str, iter_len))

    start_ts = time.time()
    for ndx, item in enumerate(iter):
        yield (ndx, item)

        if ndx == print_ndx:
            # Average time per timed iteration so far, extrapolated to all
            # iterations that take part in the timing.
            elapsed_sec = time.time() - start_ts
            timed_count = ndx - start_ndx + 1
            duration_sec = elapsed_sec / timed_count * (iter_len-start_ndx)

            done_dt = datetime.datetime.fromtimestamp(start_ts + duration_sec)
            done_td = datetime.timedelta(seconds=duration_sec)

            # The rsplit calls strip the fractional seconds.
            done_dt_str = str(done_dt).rsplit('.', 1)[0]
            done_td_str = str(done_td).rsplit('.', 1)[0]
            log.info("{} {:-4}/{}, done at {}, {}".format(
                desc_str, ndx, iter_len, done_dt_str, done_td_str,
            ))

            print_ndx *= backoff

        # The warm-up iterations are over: restart the clock.
        if ndx + 1 == start_ndx:
            start_ts = time.time()

    finished_str = str(datetime.datetime.now()).rsplit('.', 1)[0]
    log.warning("{} ----/{}, done at {}".format(desc_str, iter_len, finished_str))

#
# try:
#     import matplotlib
#     matplotlib.use('agg', warn=False)
#
#     import matplotlib.pyplot as plt
#     # matplotlib color maps
#     cdict = {'red':   ((0.0,  1.0, 1.0),
#                        # (0.5,  1.0, 1.0),
#                        (1.0,  1.0, 1.0)),
#
#              'green': ((0.0,  0.0, 0.0),
#                        (0.5,  0.0, 0.0),
#                        (1.0,  0.5, 0.5)),
#
#              'blue':  ((0.0,  0.0, 0.0),
#                        # (0.5,  0.5, 0.5),
#                        # (0.75, 0.0, 0.0),
#                        (1.0,  0.0, 0.0)),
#
#              'alpha':  ((0.0, 0.0, 0.0),
#                        (0.75, 0.5, 0.5),
#                        (1.0,  0.5, 0.5))}
#
#     plt.register_cmap(name='mask', data=cdict)
#
#     cdict = {'red':   ((0.0,  0.0, 0.0),
#                        (0.25,  1.0, 1.0),
#                        (1.0,  1.0, 1.0)),
#
#              'green': ((0.0,  1.0, 1.0),
#                        (0.25,  1.0, 1.0),
#                        (0.5, 0.0, 0.0),
#                        (1.0,  0.0, 0.0)),
#
#              'blue':  ((0.0,  0.0, 0.0),
#                        # (0.5,  0.5, 0.5),
#                        # (0.75, 0.0, 0.0),
#                        (1.0,  0.0, 0.0)),
#
#              'alpha':  ((0.0, 0.15, 0.15),
#                        (0.5,  0.3, 0.3),
#                        (0.8,  0.0, 0.0),
#                        (1.0,  0.0, 0.0))}
#
#     plt.register_cmap(name='maskinvert', data=cdict)
# except ImportError:
#     pass


### disk.py

In [None]:

import gzip

from cassandra.cqltypes import BytesType

# And the BytesIO line should be changed to the following:
from diskcache import FanoutCache, Disk,core
from diskcache.core import io
from io import BytesIO
from diskcache.core import MODE_BINARY

import logging
# Logger for the disk-cache helpers, set to INFO; the commented-out lines are
# alternative verbosity levels.
# NOTE(review): an earlier cell also binds `log = logging.getLogger(__name__)`;
# if both run in the same namespace this is presumably the same logger and
# this level replaces the earlier one - confirm that is intended.
log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
log.setLevel(logging.INFO)
# log.setLevel(logging.DEBUG)


class GzipDisk(Disk):
    """
    ``diskcache.Disk`` subclass that gzip-compresses values of type
    ``BytesType`` before storing them and gunzips values fetched with
    ``MODE_BINARY``. Everything else is delegated to the base class.
    """
    def store(self, value, read, key=None):
        """
        Override from base class diskcache.Disk.

        Chunking is due to needing to work on pythons < 2.7.13:
        - Issue #27130: In the "zlib" module, fix handling of large buffers
          (typically 2 or 4 GiB).  Previously, inputs were limited to 2 GiB, and
          compression and decompression operations did not properly handle results of
          2 or 4 GiB.

        :param value: value to convert
        :param bool read: True when value is file-like object
        :param key: accepted in the signature but not forwarded to the
            base class
        :return: (size, mode, filename, value) tuple for Cache table
        """
        # NOTE(review): BytesType is imported from cassandra.cqltypes in this
        # file, not the builtin `bytes` - confirm that this exact-type check
        # can actually match the payloads being cached.
        # pylint: disable=unidiomatic-typecheck
        if type(value) is BytesType:
            if read:
                # Pull the whole payload into memory so it can be compressed.
                value = value.read()
                read = False

            str_io = BytesIO()
            gz_file = gzip.GzipFile(mode='wb', compresslevel=1, fileobj=str_io)

            # Feed the compressor in 1 GiB slices (see the docstring).
            for offset in range(0, len(value), 2**30):
                gz_file.write(value[offset:offset+2**30])
            gz_file.close()

            value = str_io.getvalue()

        return super(GzipDisk, self).store(value, read)


    def fetch(self, mode, filename, value, read):
        """
        Override from base class diskcache.Disk.

        Chunking is due to needing to work on pythons < 2.7.13:
        - Issue #27130: In the "zlib" module, fix handling of large buffers
          (typically 2 or 4 GiB).  Previously, inputs were limited to 2 GiB, and
          compression and decompression operations did not properly handle results of
          2 or 4 GiB.

        :param int mode: value mode raw, binary, text, or pickle
        :param str filename: filename of corresponding value
        :param value: database value
        :param bool read: when True, return an open file handle
        :return: corresponding Python value
        """
        value = super(GzipDisk, self).fetch(mode, filename, value, read)

        # NOTE(review): assumes every MODE_BINARY value was gzip-compressed by
        # store() and that the base class returned bytes rather than an open
        # file handle (read=True) - confirm against the diskcache version used.
        if mode == MODE_BINARY:
            str_io = BytesIO(value)
            gz_file = gzip.GzipFile(mode='rb', fileobj=str_io)
            read_csio = BytesIO()

            # Decompress in chunks of at most 1 GiB until the stream is empty.
            while True:
                uncompressed_data = gz_file.read(2**30)
                if uncompressed_data:
                    read_csio.write(uncompressed_data)
                else:
                    break

            value = read_csio.getvalue()

        return value

def getCache(scope_str):
    """Create the ``FanoutCache`` for ``scope_str``.

    The cache lives in ``data-unversioned/cache/<scope_str>`` and uses
    ``GzipDisk`` as its disk implementation.

    :param scope_str: name appended to the cache base directory.
    :return: the ``FanoutCache`` instance.
    """
    cache_dir = 'data-unversioned/cache/' + scope_str
    cache_kwargs = dict(
        disk=GzipDisk,
        shards=64,
        timeout=1,
        size_limit=3e11,
        # disk_min_file_size=2**20,
    )
    return FanoutCache(cache_dir, **cache_kwargs)

# def disk_cache(base_path, memsize=2):
#     def disk_cache_decorator(f):
#         @functools.wraps(f)
#         def wrapper(*args, **kwargs):
#             args_str = repr(args) + repr(sorted(kwargs.items()))
#             file_str = hashlib.md5(args_str.encode('utf8')).hexdigest()
#
#             cache_path = os.path.join(base_path, f.__name__, file_str + '.pkl.gz')
#
#             if not os.path.exists(os.path.dirname(cache_path)):
#                 os.makedirs(os.path.dirname(cache_path), exist_ok=True)
#
#             if os.path.exists(cache_path):
#                 return pickle_loadgz(cache_path)
#             else:
#                 ret = f(*args, **kwargs)
#                 pickle_dumpgz(cache_path, ret)
#                 return ret
#
#         return wrapper
#
#     return disk_cache_decorator
#
#
# def pickle_dumpgz(file_path, obj):
#     log.debug("Writing {}".format(file_path))
#     with open(file_path, 'wb') as file_obj:
#         with gzip.GzipFile(mode='wb', compresslevel=1, fileobj=file_obj) as gz_file:
#             pickle.dump(obj, gz_file, pickle.HIGHEST_PROTOCOL)
#
#
# def pickle_loadgz(file_path):
#     log.debug("Reading {}".format(file_path))
#     with open(file_path, 'rb') as file_obj:
#         with gzip.GzipFile(mode='rb', fileobj=file_obj) as gz_file:
#             return pickle.load(gz_file)
#
#
# def dtpath(dt=None):
#     if dt is None:
#         dt = datetime.datetime.now()
#
#     return str(dt).rsplit('.', 1)[0].replace(' ', '--').replace(':', '.')
#
#
# def safepath(s):
#     s = s.replace(' ', '_')
#     return re.sub('[^A-Za-z0-9_.-]', '', s)


## CT file loader

In [None]:
import SimpleITK as sitk


class Ct:
    """A single CT series loaded from its ``.mhd`` file.

    Attributes set by ``__init__``:
      series_uid  -- the identifier the instance was created with
      hu_a        -- float32 voxel array, clipped to [-1000, 1000]
      origin_xyz  -- image origin as an ``XyzTuple``
      vxSize_xyz  -- voxel spacing as an ``XyzTuple``
      direction_a -- 3x3 direction matrix
    """

    def __init__(self, series_uid):
        """Read the series' ``.mhd`` image and keep its voxels and geometry.

        :param series_uid: identifier used to find ``<series_uid>.mhd``.
        :raises IndexError: if no matching file is found.
        """
        mhd_pattern = '/kaggle/input/luna-lung-cancer-dataset/seg-lungs-LUNA16/seg-lungs-LUNA16/{}.mhd'.format(series_uid)
        mhd_path = glob.glob(mhd_pattern)[0]

        ct_mhd = sitk.ReadImage(mhd_path)
        hu_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)

        # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
        # where air is about -1000 and water is 0. Clamping in place to
        # [-1000, 1000] removes the below-air values used to mark
        # out-of-field-of-view voxels and caps bone and other hotspots.
        np.clip(hu_a, -1000, 1000, out=hu_a)

        self.series_uid = series_uid
        self.hu_a = hu_a

        self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
        self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
        self.direction_a = np.array(ct_mhd.GetDirection()).reshape(3, 3)

    def getRawCandidate(self, center_xyz, width_irc):
        """Cut a box of voxels around a physical-space location.

        The box is centred on the voxel nearest to ``center_xyz``; when it
        would extend past either end of an axis it is shifted back inside
        the array instead of being truncated.

        :param center_xyz: centre of the crop as (x, y, z).
        :param width_irc: crop size per axis in (index, row, col) order.
        :return: ``(ct_chunk, center_irc)`` - the cropped view of ``hu_a``
            and the centre converted to voxel coordinates.
        """
        center_irc = xyz2irc(
            center_xyz,
            self.origin_xyz,
            self.vxSize_xyz,
            self.direction_a,
        )

        slice_list = []
        for axis, center_val in enumerate(center_irc):
            axis_width = width_irc[axis]
            start_ndx = int(round(center_val - axis_width/2))
            end_ndx = int(start_ndx + axis_width)

            axis_len = self.hu_a.shape[axis]
            assert 0 <= center_val < axis_len, repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])

            # Box starts before the array: slide it to begin at 0.
            if start_ndx < 0:
                start_ndx, end_ndx = 0, int(axis_width)

            # Box ends past the array: slide it to end at the last voxel.
            if end_ndx > axis_len:
                start_ndx, end_ndx = int(axis_len - axis_width), axis_len

            slice_list.append(slice(start_ndx, end_ndx))

        ct_chunk = self.hu_a[tuple(slice_list)]

        return ct_chunk, center_irc
