In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from collections import namedtuple
import csv
import functools
import glob
import os
import sys

# Let later cells import project-local modules (`util`, `disk`), which are
# presumably located in the notebook's parent directory.
sys.path.append(os.path.abspath(os.pardir))

CandidateInfoTuple = namedtuple('CandidateInfoTuple', 'isNodule_bool, diameter_mm, series_uid, center_xyz',)

@functools.lru_cache(1)  # cache the result in memory to avoid hitting the disk on every call
def getCandidateInfoList(requreOnDisk_bool = True):
    """Build the list of LUNA16 nodule candidates, nodules and larger diameters first.

    Reads ``../data/luna16/candidates.csv`` (centre + nodule flag) and
    ``../data/luna16/annotations.csv`` (centre + diameter). Each candidate gets
    the diameter of the first annotation in the same series whose centre lies
    within a quarter of that annotation's diameter on every axis, or 0.0 if
    no annotation matches.

    Args:
        requreOnDisk_bool: when true, skip candidates whose series has no
            ``../data/luna16/subset*/<series_uid>.mhd`` file. (The misspelled
            name is kept for backward compatibility with callers.)

    Returns:
        list of ``CandidateInfoTuple`` sorted in descending tuple order, i.e.
        nodules first, then by decreasing diameter.

    The returned list is cached by ``lru_cache`` and shared between callers;
    copy it before mutating.
    """
    mhd_list = glob.glob('../data/luna16/subset*/*.mhd')
    # A scan's file name minus its '.mhd' suffix is its series UID.
    presentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}

    # series_uid -> [(annotationCenter_xyz, annotationDiameter_mm), ...] in file order.
    # Files passed to csv.reader are opened with newline='' as the csv docs require,
    # and the reader is streamed rather than materialised with list(...)[1:].
    diameter_dict = {}
    with open('../data/luna16/annotations.csv', 'r', newline='') as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row
        for row in reader:
            series_uid = row[0]
            annotationCenter_xyz = tuple(float(x) for x in row[1:4])
            annotationDiameter_mm = float(row[4])
            diameter_dict.setdefault(series_uid, []).append((annotationCenter_xyz, annotationDiameter_mm))

    candidateInfo_list = []
    with open('../data/luna16/candidates.csv', 'r', newline='') as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row
        for row in reader:
            series_uid = row[0]

            if requreOnDisk_bool and series_uid not in presentOnDisk_set:
                continue

            isNodule_bool = bool(int(row[4]))
            candidateCenter_xyz = tuple(float(x) for x in row[1:4])

            # Candidates and annotations approximate the nodule centre differently,
            # so look for the first annotation in this series that is close enough
            # on all three axes. The for/else runs its `else` only when the inner
            # loop finished without breaking, i.e. every axis was within range.
            candidateDiameter_mm = 0.0
            for annotationCenter_xyz, annotationDiameter_mm in diameter_dict.get(series_uid, []):
                for i in range(3):
                    delta_mm = abs(candidateCenter_xyz[i] - annotationCenter_xyz[i])
                    if delta_mm > annotationDiameter_mm / 4:
                        break
                else:
                    candidateDiameter_mm = annotationDiameter_mm
                    break

            candidateInfo_list.append(CandidateInfoTuple(isNodule_bool, candidateDiameter_mm, series_uid, candidateCenter_xyz))

    candidateInfo_list.sort(reverse=True)
    return candidateInfo_list

In [4]:
# Preview only the first few candidates (largest nodules first); the full list
# is far too long to render usefully.
getCandidateInfoList()[:5]

[CandidateInfoTuple(isNodule_bool=True, diameter_mm=32.27003025, series_uid='1.3.6.1.4.1.14519.5.2.1.6279.6001.287966244644280690737019247886', center_xyz=(67.61451718, 85.02525992, -109.8084416)),
 CandidateInfoTuple(isNodule_bool=True, diameter_mm=25.23320204, series_uid='1.3.6.1.4.1.14519.5.2.1.6279.6001.511347030803753100045216493273', center_xyz=(63.4740118048, 73.9174523314, -213.736128767)),
 CandidateInfoTuple(isNodule_bool=True, diameter_mm=23.35064438, series_uid='1.3.6.1.4.1.14519.5.2.1.6279.6001.179049373636438705059720603192', center_xyz=(57.42, 81.14, -118.09)),
 CandidateInfoTuple(isNodule_bool=True, diameter_mm=23.35064438, series_uid='1.3.6.1.4.1.14519.5.2.1.6279.6001.179049373636438705059720603192', center_xyz=(56.4889724157, 85.9418105037, -115.9731945)),
 CandidateInfoTuple(isNodule_bool=True, diameter_mm=23.35064438, series_uid='1.3.6.1.4.1.14519.5.2.1.6279.6001.179049373636438705059720603192', center_xyz=(50.4361084414, 90.0424445754, -113.908439345)),
 CandidateI

In [5]:
import numpy as np
import SimpleITK as sitk

from util import XyzTuple, xyz2irc

class Ct:
    """One LUNA16 CT scan loaded from ``../data/luna16/subset*/<series_uid>.mhd``.

    Attributes:
        series_uid: series identifier passed to the constructor.
        hu_a: float32 voxel array from ``sitk.GetArrayFromImage``, clipped in
            place to [-1000, 1000]; its axes are indexed (I, R, C).
        origin_xyz: image origin from ``GetOrigin()``, as an ``XyzTuple``.
        vxSize_xyz: voxel spacing from ``GetSpacing()``, as an ``XyzTuple``.
        direction_a: 3x3 direction matrix from ``GetDirection()``.
    """

    def __init__(self, series_uid):
        mhd_path = glob.glob(f'../data/luna16/subset*/{series_uid}.mhd')[0]
        ct_mhd = sitk.ReadImage(mhd_path)
        ct_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)

        # Voxels are expressed in Hounsfield units (HU):
        #   -1000 => lowest density => air
        #       0 => water
        #   +1000 => bone
        # Values below -1000 usually mean "outside the scanner's range".
        # Clip in place (out=ct_a) to [-1000, 1000].
        ct_a.clip(-1000, 1000, ct_a)

        self.series_uid = series_uid
        self.hu_a = ct_a

        # The image array uses voxel coordinates I, R, C (index, row, column).
        # Annotations use the patient coordinate system XYZ in mm
        # (X+ => patient left, Y+ => patient behind, Z+ => patient head).
        self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
        self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
        self.direction_a = np.array(ct_mhd.GetDirection()).reshape(3, 3)

    @staticmethod
    def _chunkSlice(center_val, width, axis_len):
        """Return a ``slice`` of ``width`` voxels centred on ``center_val`` along
        an axis of length ``axis_len``.

        The window is shifted (not shrunk) to stay inside ``[0, axis_len)``. It is
        shorter than ``width`` only when the axis itself is shorter than ``width``.
        """
        start_ndx = int(round(center_val - width / 2))
        end_ndx = int(start_ndx + width)

        if start_ndx < 0:
            start_ndx = 0
            end_ndx = int(width)

        if end_ndx > axis_len:
            end_ndx = axis_len
            # Clamp at 0: if width > axis_len, a negative start would make NumPy
            # count from the end of the axis and silently drop voxels.
            start_ndx = max(0, int(axis_len - width))

        return slice(start_ndx, end_ndx)

    def getRawCandidate(self, center_xyz, width_irc):
        """Crop a box of ``width_irc`` voxels around ``center_xyz`` from ``hu_a``.

        Args:
            center_xyz: candidate centre in patient XYZ coordinates.
            width_irc: requested box size along the (I, R, C) axes, in voxels.

        Returns:
            ``(ct_chunk, center_irc)``: ``ct_chunk`` is a basic-slicing view of
            ``hu_a``, and ``center_irc`` is ``center_xyz`` converted by ``xyz2irc``.

        Raises:
            AssertionError: if the converted centre lies outside the volume.
        """
        center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_a)

        slice_list = []
        for axis, center_val in enumerate(center_irc):
            assert center_val >= 0 and center_val < self.hu_a.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])
            slice_list.append(self._chunkSlice(center_val, width_irc[axis], self.hu_a.shape[axis]))

        ct_chunk = self.hu_a[tuple(slice_list)]
        return ct_chunk, center_irc

In [6]:
from disk import getCache
# Cache object whose .memoize decorator wraps getCtRawCandidate below.
# NOTE(review): getCache comes from the project `disk` module; presumably the
# name selects a persistent on-disk cache location — confirm there.
raw_cache = getCache('part2ch10_raw')

@functools.lru_cache(maxsize=1, typed=True)
def getCt(series_uid):
    """Return the ``Ct`` for ``series_uid``, keeping only the most recent one in memory.

    Repeated calls with the same ``series_uid`` reuse the already-loaded object;
    a call with a different ``series_uid`` evicts it (maxsize=1).
    """
    loaded_ct = Ct(series_uid)
    return loaded_ct

@raw_cache.memoize(typed=True)
def getCtRawCandidate(series_uid, center_xyz, width_irc):
    """Crop the ``width_irc`` box around ``center_xyz`` from the CT of ``series_uid``.

    Returns ``(chunk, center_irc)`` exactly as ``Ct.getRawCandidate`` produces them;
    results are memoized through ``raw_cache`` keyed on all three arguments.
    """
    chunk_a, centerIrc_tup = getCt(series_uid).getRawCandidate(center_xyz, width_irc)
    return chunk_a, centerIrc_tup

In [7]:
import copy
from torch.utils.data import Dataset

class LunaDataset(Dataset):
    """Torch ``Dataset`` yielding one raw CT crop per LUNA16 candidate.

    Args:
        val_stride: every ``val_stride``-th candidate belongs to the validation
            split; 0 disables splitting.
        isValSet_bool: if truthy, keep only the validation candidates (requires
            ``val_stride > 0``); otherwise, when ``val_stride > 0``, drop them.
        series_uid: if given, restrict to candidates from this series.
    """

    def __init__(self, val_stride=0, isValSet_bool=None, series_uid=None) -> None:
        super().__init__()
        # Shallow copy: the list from getCandidateInfoList is cached and shared,
        # and the `del` below mutates this one.
        candidates = copy.copy(getCandidateInfoList())

        if series_uid:
            candidates = [cand for cand in candidates if cand.series_uid == series_uid]

        if isValSet_bool:
            assert val_stride > 0, val_stride
            candidates = candidates[::val_stride]
            assert candidates
        elif val_stride > 0:
            del candidates[::val_stride]
            assert candidates

        self.candidateInfo_list = candidates

    def __len__(self):
        return len(self.candidateInfo_list)

    def __getitem__(self, ndx):
        info = self.candidateInfo_list[ndx]

        candidate_a, center_irc = getCtRawCandidate(
            info.series_uid,
            info.center_xyz,
            (32, 48, 48),  # crop size along (I, R, C)
        )

        # float32 tensor with a leading channel dimension added by unsqueeze(0)
        candidate_t = torch.from_numpy(candidate_a).to(torch.float32).unsqueeze(0)

        # two-class one-hot target: [not nodule, nodule]
        pos_t = torch.tensor(
            [not info.isNodule_bool, info.isNodule_bool],
            dtype=torch.long,
        )

        return candidate_t, pos_t, info.series_uid, torch.tensor(center_irc)

In [25]:
# Training split: every 10th candidate is held out for validation.
dataset = LunaDataset(val_stride=10)

# Fetch exactly the first 10,000 samples (or all of them, if fewer). Each fetch
# goes through getCtRawCandidate, which is memoized by raw_cache, so this pass
# presumably fills that cache for the next cell.
count = 0
for ndx in range(min(10_000, len(dataset))):
    dataset[ndx]
    count += 1

In [26]:
# Second pass over the same first 10,000 samples (or all of them, if fewer);
# with raw_cache filled by the previous cell these fetches presumably hit the cache.
count = 0
for ndx in range(min(10_000, len(dataset))):
    dataset[ndx]
    count += 1