In [1]:
import collections
from collections import namedtuple  
import glob
import SimpleITK as sitk
from torch.utils.data import Dataset
import os
import csv
import numpy as np
from util import XyzTuple, xyz2irc, logging, getCache
import functools
import random
import torch
import functools

In [2]:
# Record describing one candidate location inside a CT series.
CandidateInfoTuple = namedtuple(
    'CandidateInfoTuple',
    [
        'isNodule_bool',       # True for annotated nodules, False for non-nodule candidates
        'hasAnnotation_bool',  # True when the entry comes from the annotation file
        'isMal_bool',          # malignancy flag read from the annotation file
        'diameter_mm',         # annotated diameter; 0.0 for non-nodule candidates
        'series_uid',          # identifier of the CT series the candidate belongs to
        'center_xyz',          # (x, y, z) centre as a tuple of floats
        'classes',             # 0 = malignant nodule, 1 = other nodule, 2 = non-nodule
    ],
)

In [3]:
def getCandidateInfoDict(requireOnDisk_bool=True):
    """Group the candidate list by CT series.

    Args:
        requireOnDisk_bool: forwarded to getCandidateInfoList. When True,
            candidates whose series has no .mhd file on disk are left out.

    Returns:
        dict mapping series_uid to the list of CandidateInfoTuple entries of
        that series, in the order produced by getCandidateInfoList.
    """
    candidateInfo_dict = {}

    for candidateInfo_tup in getCandidateInfoList(requireOnDisk_bool):
        candidateInfo_dict.setdefault(candidateInfo_tup.series_uid,
                                      []).append(candidateInfo_tup)

    return candidateInfo_dict


def getCandidateInfoList(requireOnDisk_bool=True):
    """Read the annotation and candidate CSV files into one sorted list.

    Annotated nodules come from annotations_with_malignancy.csv (series uid,
    x, y, z, diameter, 'True'/'False' malignancy flag). Non-nodule candidates
    come from candidates_V2.csv (series uid, x, y, z, integer nodule flag);
    rows flagged as nodules in that file are not added.
    NOTE(review): presumably those rows are skipped because the annotation
    file already describes the nodules -- confirm.

    Args:
        requireOnDisk_bool: when True (default), rows whose series_uid has no
            matching .mhd file under the dataset directory are skipped.

    Returns:
        list of CandidateInfoTuple sorted in descending tuple order, which puts
        nodules before non-nodule candidates.

    Raises:
        KeyError: if the malignancy column holds anything but 'True'/'False'.
        ValueError: if a coordinate, diameter or nodule flag is not numeric.
    """
    mhd_list = glob.glob('E:/LUNA/Luna16_AugData/subset*/*.mhd')
    # Series uid = file name without the 4-character '.mhd' suffix.
    presentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}

    candidateInfo_list = []

    with open('E:/LUNA/annotation/annotations_with_malignancy.csv', "r", newline='') as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row
        for row in reader:
            if not row:
                continue  # tolerate blank lines

            series_uid = row[0]
            annotationCenter_xyz = tuple([float(x) for x in row[1:4]])
            annotationDiameter_mm = float(row[4])
            isMal_bool = {'False': False, 'True': True}[row[5]]
            classes = 0 if isMal_bool else 1

            if requireOnDisk_bool and series_uid not in presentOnDisk_set:
                continue

            candidateInfo_list.append(
                CandidateInfoTuple(
                    True,
                    True,
                    isMal_bool,
                    annotationDiameter_mm,
                    series_uid,
                    annotationCenter_xyz,
                    classes
                )
            )

    with open('E:/LUNA/annotation/candidates_V2.csv', "r", newline='') as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row
        for row in reader:
            if not row:
                continue  # tolerate blank lines

            series_uid = row[0]

            if requireOnDisk_bool and series_uid not in presentOnDisk_set:
                continue

            isNodule_bool = bool(int(row[4]))
            candidateCenter_xyz = tuple([float(x) for x in row[1:4]])

            # Only the non-nodule candidates are taken from this file.
            if not isNodule_bool:
                candidateInfo_list.append(
                    CandidateInfoTuple(
                        False,
                        False,
                        False,
                        0.0,
                        series_uid,
                        candidateCenter_xyz,
                        2
                    )
                )

    candidateInfo_list.sort(reverse=True)
    return candidateInfo_list

In [5]:
# Module-level lookup from series_uid to that series' candidates; Ct.__init__
# reads it to find the nodules of the series it loads.
candidateInfoDict = getCandidateInfoDict()

In [6]:
class Ct:
    """A single CT series loaded from disk together with its nodule mask.

    Attributes set by __init__:
        series_uid: identifier the volume was loaded for.
        hu_a: float32 array of voxel values, addressed as [index, row, col].
        origin_xyz, vxSize_xyz, direction_a: origin, spacing and 3x3 direction
            matrix taken from the image metadata; passed to xyz2irc.
        positiveInfo_list: candidates of this series with isNodule_bool set.
        positive_mask: boolean array shaped like hu_a marking nodule voxels.
        positive_indexes: slice indices that contain at least one mask voxel.
    """

    def __init__(self, series_uid):
        """Load the .mhd volume of `series_uid` and build its nodule mask.

        Raises:
            FileNotFoundError: if no .mhd file matches the series uid.
            KeyError: if the series has no entry in candidateInfoDict.
        """
        mhd_paths = glob.glob('E:/LUNA/Luna16_AugData/subset*/{}.mhd'.format(series_uid))
        if not mhd_paths:
            raise FileNotFoundError(
                'No .mhd file found for series_uid {}'.format(series_uid))

        # Read exactly one file name rather than the list returned by glob.
        # NOTE(review): handing SimpleITK a list presumably makes it read an
        # image series, which is likely why the 4-D guard below exists; the
        # guard is kept so such images are still squeezed to 3-D.
        ct_mhd = sitk.ReadImage(mhd_paths[0])
        if ct_mhd.GetDimension() == 4 and ct_mhd.GetSize()[3] == 1:
            ct_mhd = ct_mhd[..., 0]
        self.hu_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)

        # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
        # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.

        self.series_uid = series_uid
        self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
        self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
        self.direction_a = np.array(ct_mhd.GetDirection()).reshape(3, 3)

        candidateInfo_list = candidateInfoDict[self.series_uid]

        # Only the nodules contribute to the mask.
        self.positiveInfo_list = [
            candidate_tup
            for candidate_tup in candidateInfo_list
            if candidate_tup.isNodule_bool
        ]
        # Same shape as hu_a.
        self.positive_mask = self.buildAnnotationMask(self.positiveInfo_list)
        # Summing over the row and column axes counts the mask voxels of each
        # slice; keep the indices of the slices where that count is non-zero.
        self.positive_indexes = (self.positive_mask.sum(axis=(1, 2))
                                 .nonzero()[0].tolist())

    def getFullMask(self):
        """Return the boolean nodule mask of the whole volume."""
        return self.positive_mask

    def _searchRadius(self, center_irc, axis, threshold_hu):
        """Grow a radius along `axis` until a voxel at either end is not above
        `threshold_hu`.

        Starts at radius 2. When the search runs past either end of the volume
        the radius is stepped back by one, so it stays inside the array.
        """
        upper_ndx = list(center_irc)
        lower_ndx = list(center_irc)
        radius = 2
        try:
            while True:
                upper_ndx[axis] = center_irc[axis] + radius
                lower_ndx[axis] = center_irc[axis] - radius
                if not self.hu_a[tuple(upper_ndx)] > threshold_hu:
                    break
                # A negative index would silently wrap to the far side of the
                # volume, so treat it like running off the upper end.
                if lower_ndx[axis] < 0:
                    raise IndexError('search left the volume on axis {}'.format(axis))
                if not self.hu_a[tuple(lower_ndx)] > threshold_hu:
                    break
                radius += 1
        except IndexError:
            radius -= 1
        return radius

    def buildAnnotationMask(self, positiveInfo_list, threshold_hu = -700):
        """Build a per-voxel nodule mask by thresholding inside bounding boxes.

        For every nodule a box is grown around its centre voxel along the
        index, row and column axes (see _searchRadius). The mask is the set of
        voxels that lie inside some box and whose value is above
        `threshold_hu`.

        Args:
            positiveInfo_list: entries exposing a `center_xyz` attribute.
            threshold_hu: voxels at or below this value are never in the mask.

        Returns:
            Boolean array with the same shape as hu_a.
        """
        boundingBox_a = np.zeros_like(self.hu_a, dtype=bool)  # starts all False

        for candidateInfo_tup in positiveInfo_list:
            center_irc = xyz2irc(
                candidateInfo_tup.center_xyz,
                self.origin_xyz,
                self.vxSize_xyz,
                self.direction_a,
            )
            # Centre voxel as plain ints, ordered like the axes of hu_a.
            center_tup = (
                int(center_irc.index),
                int(center_irc.row),
                int(center_irc.col),
            )

            radius_tup = tuple(
                self._searchRadius(center_tup, axis, threshold_hu)
                for axis in range(3)
            )

            # Clamp both bounds at 0: a negative bound would be interpreted
            # relative to the end of the axis and select the wrong region.
            box_slices = tuple(
                slice(max(center - radius, 0), max(center + radius + 1, 0))
                for center, radius in zip(center_tup, radius_tup)
            )
            boundingBox_a[box_slices] = True

        # Inside the boxes keep only the voxels above the threshold.
        mask_a = boundingBox_a & (self.hu_a > threshold_hu)

        return mask_a

    def getRawFullCT(self, center_xyz, contextSlices_count):
        """Return the stack of full slices around `center_xyz` plus its mask.

        The stack holds `2 * contextSlices_count + 1` slices centred on the
        slice containing `center_xyz`. Slice indices falling outside the volume
        are clamped to the first/last slice, so edge slices are repeated; the
        CT stack and the mask stack use the same indices.

        Args:
            center_xyz: candidate centre, converted with xyz2irc.
            contextSlices_count: number of slices taken on each side.

        Returns:
            (ct_t, pos_t, index, row, col): float32 tensor of voxel values
            clamped to [-1000, 1000], the matching boolean mask tensor, and
            the centre's index/row/col as returned by xyz2irc.
        """
        center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz,
                             self.direction_a)
        slice_ndx = center_irc.index

        context_ndx_a = np.clip(
            np.arange(slice_ndx - contextSlices_count,
                      slice_ndx + contextSlices_count + 1),
            0,
            self.hu_a.shape[0] - 1,
        )

        # Integer-array indexing copies, so the in-place clamp below cannot
        # modify hu_a.
        ct_t = torch.from_numpy(
            self.hu_a[context_ndx_a].astype(np.float32, copy=False))

        # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
        # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
        # The lower bound gets rid of negative density stuff used to indicate out-of-FOV
        # The upper bound nukes any weird hotspots and clamps bone down
        ct_t.clamp_(-1000, 1000)

        pos_t = torch.from_numpy(self.positive_mask[context_ndx_a])

        return ct_t, pos_t, center_irc.index, center_irc.row, center_irc.col

In [7]:
# Build the full candidate list (sorted in descending tuple order by
# getCandidateInfoList) and show its first entry as a sanity check.
candi_list = getCandidateInfoList()
print(candi_list[0])

CandidateInfoTuple(isNodule_bool=True, hasAnnotation_bool=True, isMal_bool=True, diameter_mm=32.27003025, series_uid='1.3.6.1.4.1.14519.5.2.1.6279.6001.287966244644280690737019247886', center_xyz=(67.82725575, 85.37992457, -109.74672379999998), classes=0)


In [8]:
# Total number of candidates (nodules plus non-nodule candidates) to export.
print(len(candi_list))

754600


In [9]:
# Export, for every candidate, the stack of CT slices around it and the
# matching mask slices as two .pt files.
#
# Constructing a Ct reads the whole volume from disk and rebuilds its mask, so
# the most recently loaded one is kept and reused while consecutive candidates
# belong to the same series. Non-nodule candidates share their first four
# tuple fields, so the sorted list orders them by series_uid and they arrive
# grouped by series. Only one Ct is held at a time to bound memory use.
current_ct = None
for i, candi in enumerate(candi_list):
    if current_ct is None or current_ct.series_uid != candi.series_uid:
        current_ct = Ct(candi.series_uid)

    # 5 context slices on each side of the candidate's slice.
    origin, mask, index, row, col = current_ct.getRawFullCT(candi.center_xyz, 5)
    torch.save(origin, 'E:/LUNA/Luna_Classification_Data/origin/{}_({},{},{})_{}.pt'.format(candi.classes, index, row, col, candi.series_uid))
    torch.save(mask, 'E:/LUNA/Luna_Classification_Data/mask/{}_({},{},{})_{}.pt'.format(candi.classes, index, row, col, candi.series_uid))
    if i % 100 == 0:
        print("finised image: ", i)
    # if (i < 3):
    #     fullMask = Ct(candi.series_uid).getFullMask()
    #     np.save('E:/LUNA/Luna_Classification_Data/grouping/{}_({},{},{})_{}.npy'.format(candi.classes, index, row, col, candi.series_uid), fullMask)

finised image:  0
finised image:  100
finised image:  200
finised image:  300
finised image:  400
finised image:  500
finised image:  600
finised image:  700
finised image:  800
finised image:  900
finised image:  1000
finised image:  1100
finised image:  1200
finised image:  1300
finised image:  1400
finised image:  1500
finised image:  1600
finised image:  1700
finised image:  1800
finised image:  1900
finised image:  2000
finised image:  2100
finised image:  2200
finised image:  2300
finised image:  2400
finised image:  2500
finised image:  2600
finised image:  2700
finised image:  2800
finised image:  2900
finised image:  3000
finised image:  3100
finised image:  3200
finised image:  3300
finised image:  3400
finised image:  3500
finised image:  3600
finised image:  3700
finised image:  3800
finised image:  3900
finised image:  4000
finised image:  4100
finised image:  4200
finised image:  4300
finised image:  4400
finised image:  4500
finised image:  4600
finised image:  4700
fini

KeyboardInterrupt: 

In [None]:
# Shapes of the last CT/mask pair produced by the export loop; relies on the
# `origin` and `mask` variables that loop leaves behind.
print(origin.size())
print(mask.size())