In [1]:
import argparse
import glob
import hashlib
import math
import os
import sys

import numpy as np
import scipy.ndimage.measurements as measure
import scipy.ndimage.morphology as morph

import torch
import torch.nn as nn
import torch.optim

from torch.utils.data import DataLoader

from util import enumerateWithEstimate

# from .dsets import LunaDataset, Luna2dSegmentationDataset, getCt, getCandidateInfoList, CandidateInfoTuple
from dsets_ch13 import (
    Luna2dSegmentationDataset,
    getCt,
    getCandidateInfoList,
    getCandidateInfoDict,
    CandidateInfoTuple,
)
from dsets_ch14 import LunaDataset
from model_ch13 import UNetWrapper
from model_ch14 import LunaModel

from logconf import logging
from util import xyz2irc, irc2xyz
import scipy.ndimage.measurements as measurements
import scipy.ndimage.morphology as morphology

log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)
logging.getLogger("dsets_ch13").setLevel(logging.WARNING)
logging.getLogger("dsets_ch14").setLevel(logging.WARNING)

In [None]:
segmentation_path = "/home/lim/Desktop/other/ML/pytorch_dlwpt-code-master/data/part2/models/seg_2020-01-26_19.45.12_w4d3c1-bal_1_nodupe-label_pos-d1_fn8-adam.best.state"
log.debug(segmentation_path)
classification_path = "/home/lim/Desktop/other/ML/pytorch_dlwpt-code-master/data/part2/models/cls_2020-02-06_14.16.55_final-nodule-nonnodule.best.state"
log.debug(classification_path)
malignancy_path = None

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [9]:
def make_circle_conv(radius):
    diameter = 1 + radius * 2

    a = torch.linspace(-1, 1, steps=diameter) ** 2
    b = (a[None] + a[:, None]) ** 0.5

    circle_weights = (b <= 1.0).to(torch.float32)

    conv = nn.Conv3d(
        1,
        1,
        kernel_size=(1, diameter, diameter),
        padding=(0, radius, radius),
        bias=False,
    )
    conv.weight.data.fill_(1)
    conv.weight.data *= circle_weights / circle_weights.sum()

    return conv


conv_list = nn.ModuleList(
    [make_circle_conv(radius).to(device) for radius in range(1, 8)]
)


def erode(input_mask, radius, threshold=1):
    conv = conv_list[radius - 1]
    input_float = input_mask.to(torch.float32)
    result = conv(input_float)

    # log.debug(['erode in ', radius, threshold, input_float.min().item(), input_float.mean().item(), input_float.max().item()])
    # log.debug(['erode out', radius, threshold, result.min().item(), result.mean().item(), result.max().item()])

    return result >= threshold

In [10]:
def initModels():
    with open(segmentation_path, "rb") as f:
        log.debug(segmentation_path)
        log.debug(hashlib.sha1(f.read()).hexdigest())

    seg_dict = torch.load(segmentation_path)

    seg_model = UNetWrapper(
        in_channels=7,
        n_classes=1,
        depth=3,
        wf=4,
        padding=True,
        batch_norm=True,
        up_mode="upconv",
    )
    seg_model.load_state_dict(seg_dict["model_state"])
    seg_model.eval()

    with open(classification_path, "rb") as f:
        log.debug(classification_path)
        log.debug(hashlib.sha1(f.read()).hexdigest())

    cls_dict = torch.load(classification_path)

    cls_model = LunaModel()
    # cls_model = AlternateLunaModel()
    cls_model.load_state_dict(cls_dict["model_state"])
    cls_model.eval()

    if torch.cuda.device_count() > 1:
        seg_model = nn.DataParallel(seg_model)
        cls_model = nn.DataParallel(cls_model)

        seg_model = seg_model.to(device)
        cls_model = cls_model.to(device)

    conv_list = nn.ModuleList(
        [make_circle_conv(radius).to(device) for radius in range(1, 8)]
    )

    return seg_model, cls_model

In [None]:
seg_model, cls_model = initModels()

In [None]:
seg_model

In [3]:
def clusterSegmentationOutput(series_uid, ct, clean_g):
    clean_a = clean_g.cpu().numpy()
    candidateLabel_a, candidate_count = measure.label(clean_a)
    centerIrc_list = measure.center_of_mass(
        ct.hu_a.clip(-1000, 1000) + 1001,
        labels=candidateLabel_a,
        index=list(range(1, candidate_count + 1)),
    )

    candidateInfo_list = []
    for i, center_irc in enumerate(centerIrc_list):
        assert np.isfinite(center_irc).all(), repr(
            [
                series_uid,
                i,
                candidate_count,
                (ct.hu_a[candidateLabel_a == i + 1]).sum(),
                center_irc,
            ]
        )
        center_xyz = irc2xyz(
            center_irc,
            ct.origin_xyz,
            ct.vxSize_xyz,
            ct.direction_a,
        )
        diameter_mm = 0.0
        # pixel_count = (candidateLabel_a == i+1).sum()
        # area_mm2 = pixel_count * ct.vxSize_xyz[0] * ct.vxSize_xyz[1]
        # diameter_mm = 2 * (area_mm2 / math.pi) ** 0.5

        candidateInfo_tup = CandidateInfoTuple(
            None, None, None, diameter_mm, series_uid, center_xyz
        )
        candidateInfo_list.append(candidateInfo_tup)

    return candidateInfo_list, centerIrc_list, candidateLabel_a

========================================================================================================


In [1]:
import argparse
import glob
import hashlib
import math
import os
import sys

import numpy as np
import scipy.ndimage.measurements as measure
import scipy.ndimage.morphology as morph

import torch
import torch.nn as nn
import torch.optim

from torch.utils.data import DataLoader

from util import enumerateWithEstimate

# from .dsets import LunaDataset, Luna2dSegmentationDataset, getCt, getCandidateInfoList, CandidateInfoTuple
from dsets_ch13 import (
    Luna2dSegmentationDataset,
    getCt,
    getCandidateInfoList,
    getCandidateInfoDict,
    CandidateInfoTuple,
)
from dsets_ch14 import LunaDataset
from model_ch13 import UNetWrapper
from model_ch14 import LunaModel

from logconf import logging
from util import xyz2irc, irc2xyz
import scipy.ndimage.measurements as measurements
import scipy.ndimage.morphology as morphology

log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)
logging.getLogger("dsets_ch13").setLevel(logging.WARNING)
logging.getLogger("dsets_ch14").setLevel(logging.WARNING)

In [2]:
segmentation_path = "/home/lim/Desktop/other/ML/pytorch_dlwpt-code-master/data/part2/models/seg_2020-01-26_19.45.12_w4d3c1-bal_1_nodupe-label_pos-d1_fn8-adam.best.state"
log.debug(segmentation_path)
classification_path = "/home/lim/Desktop/other/ML/pytorch_dlwpt-code-master/data/part2/models/cls_2020-02-06_14.16.55_final-nodule-nonnodule.best.state"
log.debug(classification_path)
malignancy_path = None

2024-11-03 18:41:26,812 DEBUG    pid:30412 __main__:002:<module> /home/lim/Desktop/other/ML/pytorch_dlwpt-code-master/data/part2/models/seg_2020-01-26_19.45.12_w4d3c1-bal_1_nodupe-label_pos-d1_fn8-adam.best.state
2024-11-03 18:41:26,813 DEBUG    pid:30412 __main__:004:<module> /home/lim/Desktop/other/ML/pytorch_dlwpt-code-master/data/part2/models/cls_2020-02-06_14.16.55_final-nodule-nonnodule.best.state


In [3]:
batch_size = 4
num_workers = 0
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = "cuda"
else:
    device = "cpu"

In [10]:
import model_ch14


def initModels():
    log.debug(segmentation_path)
    seg_dict = torch.load(segmentation_path)

    seg_model = UNetWrapper(
        in_channels=7,
        n_classes=1,
        depth=3,
        wf=4,
        padding=True,
        batch_norm=True,
        up_mode="upconv",
    )

    seg_model.load_state_dict(seg_dict["model_state"])
    seg_model.eval()

    log.debug(classification_path)
    cls_dict = torch.load(classification_path)

    cls_model = "LunaModel"
    model_cls = getattr(model_ch14, cls_model)
    cls_model = model_cls()
    cls_model.load_state_dict(cls_dict["model_state"])
    cls_model.eval()

    if use_cuda:
        if torch.cuda.device_count() > 1:
            seg_model = nn.DataParallel(seg_model)
            cls_model = nn.DataParallel(cls_model)

        seg_model.to(device)
        cls_model.to(device)

    if malignancy_path:
        model_cls = getattr(model_ch14, malignancy_model)
        malignancy_model = model_cls()
        malignancy_dict = torch.load(malignancy_path)
        malignancy_model.load_state_dict(malignancy_dict["model_state"])
        malignancy_model.eval()
        if use_cuda:
            malignancy_model.to(device)
    else:
        malignancy_model = None
    return seg_model, cls_model, malignancy_model


seg_model, cls_model, malignancy_model = initModels()

2024-11-03 18:43:50,128 DEBUG    pid:30412 __main__:005:initModels /home/lim/Desktop/other/ML/pytorch_dlwpt-code-master/data/part2/models/seg_2020-01-26_19.45.12_w4d3c1-bal_1_nodupe-label_pos-d1_fn8-adam.best.state
  seg_dict = torch.load(segmentation_path)
2024-11-03 18:43:50,146 DEBUG    pid:30412 __main__:021:initModels /home/lim/Desktop/other/ML/pytorch_dlwpt-code-master/data/part2/models/cls_2020-02-06_14.16.55_final-nodule-nonnodule.best.state
  cls_dict = torch.load(classification_path)


In [5]:
def initSegmentationDl(series_uid):
    seg_ds = Luna2dSegmentationDataset(
        contextSlices_count=3,
        series_uid=series_uid,
        fullCt_bool=True,
    )
    seg_dl = DataLoader(
        seg_ds,
        batch_size=batch_size * (torch.cuda.device_count() if use_cuda else 1),
        num_workers=num_workers,
        pin_memory=use_cuda,
    )

    return seg_dl


def segmentCt(ct, series_uid):
    with torch.no_grad():
        output_a = np.zeros_like(ct.hu_a, dtype=np.float32)
        seg_dl = initSegmentationDl(series_uid)  #  <3>
        for input_t, _, _, slice_ndx_list in seg_dl:

            input_g = input_t.to(device)
            # Move model to the correct device
            # seg_model = seg_model.to(device)
            prediction_g = seg_model(input_g)

            for i, slice_ndx in enumerate(slice_ndx_list):
                output_a[slice_ndx] = prediction_g[i].cpu().numpy()

        mask_a = output_a > 0.5
        mask_a = morphology.binary_erosion(mask_a, iterations=1)

    return mask_a

In [6]:
def groupSegmentationOutput(series_uid, ct, clean_a):
    candidateLabel_a, candidate_count = measurements.label(clean_a)
    centerIrc_list = measurements.center_of_mass(
        ct.hu_a.clip(-1000, 1000) + 1001,
        labels=candidateLabel_a,
        index=np.arange(1, candidate_count + 1),
    )

    candidateInfo_list = []
    for i, center_irc in enumerate(centerIrc_list):
        center_xyz = irc2xyz(
            center_irc,
            ct.origin_xyz,
            ct.vxSize_xyz,
            ct.direction_a,
        )
        assert np.all(np.isfinite(center_irc)), repr(
            ["irc", center_irc, i, candidate_count]
        )
        assert np.all(np.isfinite(center_xyz)), repr(["xyz", center_xyz])
        candidateInfo_tup = CandidateInfoTuple(
            False, False, False, 0.0, series_uid, center_xyz
        )
        candidateInfo_list.append(candidateInfo_tup)

    return candidateInfo_list

In [11]:
def initClassificationDl(candidateInfo_list):
    cls_ds = LunaDataset(
        sortby_str="series_uid",
        candidateInfo_list=candidateInfo_list,
    )
    cls_dl = DataLoader(
        cls_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=use_cuda,
    )

    return cls_dl


def classifyCandidates(ct, candidateInfo_list):
    cls_dl = initClassificationDl(candidateInfo_list)
    classifications_list = []
    for batch_ndx, batch_tup in enumerate(cls_dl):
        input_t, _, _, series_list, center_list = batch_tup

        input_g = input_t.to(device)
        with torch.no_grad():
            _, probability_nodule_g = cls_model(input_g)
            if malignancy_model is not None:
                _, probability_mal_g = malignancy_model(input_g)
            else:
                probability_mal_g = torch.zeros_like(probability_nodule_g)

        zip_iter = zip(
            center_list,
            probability_nodule_g[:, 1].tolist(),
            probability_mal_g[:, 1].tolist(),
        )
        for center_irc, prob_nodule, prob_mal in zip_iter:
            center_xyz = irc2xyz(
                center_irc,
                direction_a=ct.direction_a,
                origin_xyz=ct.origin_xyz,
                vxSize_xyz=ct.vxSize_xyz,
            )
            cls_tup = (prob_nodule, prob_mal, center_xyz, center_irc)
            classifications_list.append(cls_tup)
    return classifications_list

In [12]:
import numpy as np

# # Set the device
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# # Move model to the correct device
# seg_model = seg_model.to(device)

# # Move input tensor to the same device
# input_g = input_t.to(device)

# # Now, make the prediction
# prediction_g = seg_model(input_g)

np.bool = np.bool_
series_uid = "1.3.6.1.4.1.14519.5.2.1.6279.6001.979083010707182900091062408058"
ct = getCt(series_uid)
mask_a = segmentCt(ct, series_uid)

candidateInfo_list = groupSegmentationOutput(series_uid, ct, mask_a)
classifications_list = classifyCandidates(ct, candidateInfo_list)
# # )

  mask_a = morphology.binary_erosion(mask_a, iterations=1)
  candidateLabel_a, candidate_count = measurements.label(clean_a)
  centerIrc_list = measurements.center_of_mass(


In [13]:
len(candidateInfo_list), len(classifications_list)

(994, 994)

In [14]:
classifications_list

[(2.223405999757233e-06,
  0.0,
  XyzTuple(x=-157.05469751358032, y=-20.084469691711433, z=26.759995000000004),
  tensor([132, 186,  21])),
 (5.249447667665663e-07,
  0.0,
  XyzTuple(x=-157.05469751358032, y=-22.7407177843628, z=14.759995000000004),
  tensor([126, 182,  21])),
 (1.7852922837846563e-06,
  0.0,
  XyzTuple(x=-157.05469751358032, y=-21.412593738037117, z=24.759995000000004),
  tensor([131, 184,  21])),
 (0.0005576977855525911,
  0.0,
  XyzTuple(x=-148.42189121246338, y=57.61078701834106, z=-177.240005),
  tensor([ 30, 303,  34])),
 (0.023690884932875633,
  0.0,
  XyzTuple(x=-141.11720895767212, y=60.267035110992424, z=-179.240005),
  tensor([ 29, 307,  45])),
 (0.22037701308727264,
  0.0,
  XyzTuple(x=-137.7968988418579, y=40.34517441610717, z=-175.240005),
  tensor([ 31, 277,  50])),
 (0.2829154431819916,
  0.0,
  XyzTuple(x=-137.7968988418579, y=41.673298462432854, z=-175.240005),
  tensor([ 31, 279,  50])),
 (0.12494125962257385,
  0.0,
  XyzTuple(x=-136.46877479553223,