In [None]:
# Experiment name; the inference cell loads the checkpoint f"{_exp_name}_best.ckpt".
_exp_name: str = "CTSlice"

In [None]:
# Import necessary packages.
import numpy as np
import pandas as pd
import torch
import os
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
# "ConcatDataset" and "Subset" are possibly useful when doing semi-supervised learning.
from torch.utils.data import ConcatDataset, DataLoader, Subset, Dataset
from torchvision.datasets import DatasetFolder, VisionDataset
# This is for the progress bar.
from tqdm.auto import tqdm
import random
from random import shuffle

# cache
from functools import lru_cache

In [None]:
# Seed every RNG this notebook imports (python `random`, numpy, torch) so runs
# are reproducible.
myseed = 6666  # set a random seed for reproducibility
# Force deterministic cuDNN kernels and keep the cuDNN auto-tuner OFF:
# benchmark=True lets cuDNN pick the fastest algorithm per input shape, which
# may differ between runs and break reproducibility.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
random.seed(myseed)
np.random.seed(myseed)
torch.manual_seed(myseed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(myseed)

In [None]:
class DADataset(Dataset):
    """Adjacent-slice pairs taken from a single CT volume.

    Item ``idx`` is ``(pid, idx, im, label, num_slices)``:
      * ``pid``   - file name of the volume without its extension,
      * ``im``    - float tensor stacking slice ``idx`` and slice ``idx + 1``
                    along a new leading channel axis (the last slice has no
                    successor and is stacked with itself),
      * ``label`` - slice ``idx`` as a float tensor (identical to ``im[0]``),
      * ``num_slices`` - number of slices along axis 0 of the volume.

    NOTE(review): ``sitk`` is not imported in any visible cell of this
    notebook; ``import SimpleITK as sitk`` is needed on a fresh kernel.
    """

    def __init__(self, path, files=None):
        """Read the volume at ``path`` and clip its values to [-1000, 1000].

        Args:
            path: image file path passed to ``sitk.ReadImage`` (the notebook
                passes ``.mhd`` files).
            files: unused; kept only so existing call sites keep working.
        """
        super().__init__()
        self.path = path

        # Read the image file and convert it to a numpy array; this class
        # treats axis 0 as the slice axis.
        data = sitk.ReadImage(self.path)
        ct_scan = sitk.GetArrayFromImage(data)
        # Clamp intensities in place (presumably Hounsfield units - confirm).
        ct_scan[ct_scan < -1000] = -1000
        ct_scan[ct_scan > 1000] = 1000

        self.num_slices = ct_scan.shape[0]
        self.files = ct_scan

    def __len__(self):
        """One item per slice along axis 0."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(pid, slice_num, im, label, num_slices)`` for slice ``idx``."""
        # Strip directories on both separator styles so the id is the bare file
        # name on Windows ("a\\b.mhd") and POSIX ("a/b.mhd") paths alike.
        file_name = os.fspath(self.path).replace("\\", "/").rsplit("/", 1)[-1]
        pid = os.path.splitext(file_name)[0]
        # The last slice has no successor, so it is paired with itself.
        next_idx = min(idx + 1, len(self.files) - 1)
        im = torch.from_numpy(np.stack((self.files[idx], self.files[next_idx]), axis=0)).float()
        label = torch.from_numpy(self.files[idx]).float()
        return pid, idx, im, label, self.num_slices

In [None]:
class UNet(nn.Module):
    """Three-level U-Net: a 2-channel slice pair in, a 1-channel slice out.

    Spatial size is preserved; height and width must be divisible by 8
    (three 2x poolings followed by three 2x upsamplings).
    The attribute names below form the ``state_dict`` key layout of saved
    checkpoints, so they must not be renamed.
    """

    def __init__(self):
        super(UNet, self).__init__()
        # Encoder: 2 -> 64 -> 128 -> 256 -> 512 channels.
        self.down_conv1 = self.double_conv(2, 64)
        self.down_conv2 = self.double_conv(64, 128)
        self.down_conv3 = self.double_conv(128, 256)
        self.down_conv4 = self.double_conv(256, 512)
        # Decoder: each stage fuses upsampled features with the matching skip.
        self.up_conv1 = self.double_conv(512 + 256, 256)
        self.up_conv2 = self.double_conv(256 + 128, 128)
        self.up_conv3 = self.double_conv(128 + 64, 64)
        # Final 1x1 projection to a single output channel.
        self.up_conv4 = nn.Conv2d(64, 1, kernel_size=1)
        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def double_conv(self, in_channels, out_channels):
        """(3x3 conv with padding 1 -> ReLU) applied twice."""
        layers = []
        for conv_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(conv_in, out_channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
            ]
        return nn.Sequential(*layers)

    def forward(self, x):
        # Encoder; the pre-pool map of every level is kept for the skip links.
        skip_full = self.down_conv1(x)
        skip_half = self.down_conv2(self.maxpool(skip_full))
        skip_quarter = self.down_conv3(self.maxpool(skip_half))
        features = self.down_conv4(self.maxpool(skip_quarter))

        # Decoder: upsample, concatenate [upsampled, skip] on channels, fuse.
        stages = (
            (self.up_conv1, skip_quarter),
            (self.up_conv2, skip_half),
            (self.up_conv3, skip_full),
        )
        for fuse, skip in stages:
            features = fuse(torch.cat([self.upsample(features), skip], dim=1))

        return self.up_conv4(features)

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Freshly initialised UNet moved onto `device`.
# NOTE(review): no visible cell uses `model` (inference uses `model_best`);
# it only occupies memory - confirm before removing.
model = UNet()
model = model.to(device)

# Items (slice pairs) per DataLoader batch.
batch_size = 10

In [None]:
import matplotlib.pyplot as plt
from numpy import save

In [None]:
# Inference: for every CT volume, predict one synthetic slice between each pair
# of adjacent slices and write a volume of 2n-1 slices with half the z spacing.
# NOTE(review): `sitk` is not imported in any visible cell; this cell needs
# `import SimpleITK as sitk` on a fresh kernel.
model_best = UNet().to(device)
# map_location lets a checkpoint saved from a GPU load on a CPU-only machine.
model_best.load_state_dict(torch.load(f"{_exp_name}_best.ckpt", map_location=device))
model_best.eval()

read_root = "../Luna16_data/Luna16_img"
# Raw string: "\L" is not a valid escape sequence (SyntaxWarning on Python 3.12+).
write_root = r"F:\Luna16_new2_data/Luna16_img"
# How many sub-directories of `read_root` to process (in sorted order).
# 1 processes only the first one; set to None to process all of them.
max_subsets = 1

# Creates missing parents and tolerates directories that already exist.
os.makedirs(write_root, exist_ok=True)


def predict_volume(net, dataset, batch_size, device, desc=None):
    """Predict the in-between slice for every adjacent slice pair of ``dataset``.

    Args:
        net: model mapping a (B, 2, H, W) pair batch to a (B, 1, H, W) batch.
        dataset: dataset yielding ``(pid, slice_num, im, label, num_slices)``
            as ``DADataset`` does.
        batch_size: items per forward pass.
        device: device ``net`` lives on.
        desc: optional progress-bar label.

    Returns:
        ``(orig, pred)`` numpy arrays. ``orig`` has shape (n, H, W) and holds
        the input slices in order; ``pred`` has shape (n - 1, H, W) and holds
        the prediction for each pair (i, i + 1), i = 0 .. n - 2.
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                        num_workers=0, pin_memory=(device.type == "cuda"))
    orig_batches, pred_batches = [], []
    with torch.no_grad():
        for _, slice_num, data, _, num_slices in tqdm(loader, desc=desc):
            # Channel 0 of each input pair is slice `slice_num` itself.
            orig_batches.append(data[:, 0])
            # Only pairs with a real successor slice are predicted. Masking per
            # item (rather than skipping every batch that contains the last
            # slice) keeps all n-1 predictions for any batch_size.
            keep = (slice_num + 1) < num_slices
            if keep.any():
                pred = net(data[keep].to(device))
                # Move to CPU immediately so GPU memory does not grow with n.
                pred_batches.append(pred[:, 0].cpu())
    orig = torch.cat(orig_batches).numpy()
    if pred_batches:
        pred = torch.cat(pred_batches).numpy()
    else:  # single-slice volume: nothing to interpolate
        pred = np.empty((0,) + orig.shape[1:], dtype=orig.dtype)
    return orig, pred


def interleave_slices(orig, pred):
    """Return a float64 volume with ``orig`` at even and ``pred`` at odd slice indices.

    Raises:
        ValueError: if ``pred`` does not hold exactly one slice fewer than ``orig``.
    """
    if pred.shape[0] != orig.shape[0] - 1:
        raise ValueError(
            f"expected {orig.shape[0] - 1} predicted slices, got {pred.shape[0]}")
    # Every row is assigned below, so np.empty never leaks uninitialised data.
    # NOTE(review): float64 output files are large; consider a narrower dtype.
    final = np.empty((orig.shape[0] + pred.shape[0],) + orig.shape[1:], dtype=np.float64)
    final[::2] = orig
    final[1::2] = pred
    return final


def write_like(volume, reference_path, out_path):
    """Write ``volume`` to ``out_path`` with the geometry of ``reference_path``.

    Origin and direction are copied from the reference image; the z spacing is
    halved because one slice was inserted between every original pair.
    Returns the written SimpleITK image.
    """
    reference = sitk.ReadImage(reference_path)
    spacing = reference.GetSpacing()
    image = sitk.GetImageFromArray(volume)
    image.SetOrigin(reference.GetOrigin())
    image.SetSpacing((spacing[0], spacing[1], spacing[2] / 2))
    image.SetDirection(reference.GetDirection())
    sitk.WriteImage(image, out_path)
    return image


sub_roots = sorted(
    name for name in os.listdir(read_root)
    if os.path.isdir(os.path.join(read_root, name))
)
if max_subsets is not None:
    sub_roots = sub_roots[:max_subsets]

for sub_root in sub_roots:
    root = os.path.join(read_root, sub_root)
    out_dir = os.path.join(write_root, sub_root)
    for path in sorted(os.listdir(root)):
        # Only *.mhd image headers are opened as volumes.
        if not path.lower().endswith(".mhd"):
            continue
        imagePath = os.path.join(root, path)
        test_set = DADataset(imagePath)
        orig_np, pred_np = predict_volume(model_best, test_set, batch_size, device, desc=path)
        final = interleave_slices(orig_np, pred_np)
        os.makedirs(out_dir, exist_ok=True)
        sitk_image = write_like(final, imagePath, os.path.join(out_dir, path))

  0%|          | 0/13 [00:00<?, ?it/s]

last 120


  0%|          | 0/12 [00:00<?, ?it/s]

last 118


  0%|          | 0/17 [00:00<?, ?it/s]

last 160


  0%|          | 0/54 [00:00<?, ?it/s]

last 537


  0%|          | 0/13 [00:00<?, ?it/s]

last 123


  0%|          | 0/20 [00:00<?, ?it/s]

last 194


  0%|          | 0/14 [00:00<?, ?it/s]

last 132


  0%|          | 0/68 [00:00<?, ?it/s]

last 671


  0%|          | 0/14 [00:00<?, ?it/s]

last 132


  0%|          | 0/49 [00:00<?, ?it/s]

last 482


  0%|          | 0/28 [00:00<?, ?it/s]

last 273


  0%|          | 0/20 [00:00<?, ?it/s]

last 196


  0%|          | 0/25 [00:00<?, ?it/s]

last 245


  0%|          | 0/30 [00:00<?, ?it/s]

last 294


  0%|          | 0/59 [00:00<?, ?it/s]

last 587


  0%|          | 0/14 [00:00<?, ?it/s]

last 139


  0%|          | 0/28 [00:00<?, ?it/s]

last 279


  0%|          | 0/16 [00:00<?, ?it/s]

last 156


  0%|          | 0/73 [00:00<?, ?it/s]

last 729


  0%|          | 0/28 [00:00<?, ?it/s]

last 274


  0%|          | 0/13 [00:00<?, ?it/s]

last 124


  0%|          | 0/15 [00:00<?, ?it/s]

last 146


  0%|          | 0/25 [00:00<?, ?it/s]

last 244


  0%|          | 0/20 [00:00<?, ?it/s]

last 194


  0%|          | 0/13 [00:00<?, ?it/s]

last 128


  0%|          | 0/14 [00:00<?, ?it/s]

last 132


  0%|          | 0/20 [00:00<?, ?it/s]

last 193


  0%|          | 0/28 [00:00<?, ?it/s]

last 275


  0%|          | 0/25 [00:00<?, ?it/s]

last 249


  0%|          | 0/28 [00:00<?, ?it/s]

last 279


  0%|          | 0/15 [00:00<?, ?it/s]

last 140


  0%|          | 0/21 [00:00<?, ?it/s]

last 208


  0%|          | 0/26 [00:00<?, ?it/s]

last 254


  0%|          | 0/19 [00:00<?, ?it/s]

last 182


  0%|          | 0/33 [00:00<?, ?it/s]

last 324


  0%|          | 0/49 [00:00<?, ?it/s]

last 480


  0%|          | 0/23 [00:00<?, ?it/s]

last 220


  0%|          | 0/30 [00:00<?, ?it/s]

last 296


  0%|          | 0/13 [00:00<?, ?it/s]

last 122


  0%|          | 0/28 [00:00<?, ?it/s]

last 279


  0%|          | 0/18 [00:00<?, ?it/s]

last 176


  0%|          | 0/12 [00:00<?, ?it/s]

last 116


  0%|          | 0/14 [00:00<?, ?it/s]

last 138


  0%|          | 0/48 [00:00<?, ?it/s]

last 470


  0%|          | 0/48 [00:00<?, ?it/s]

last 476


  0%|          | 0/25 [00:00<?, ?it/s]

last 249


  0%|          | 0/24 [00:00<?, ?it/s]

last 231


  0%|          | 0/25 [00:00<?, ?it/s]

last 243


  0%|          | 0/37 [00:00<?, ?it/s]

last 368


  0%|          | 0/30 [00:00<?, ?it/s]

last 299


  0%|          | 0/29 [00:00<?, ?it/s]

last 289


  0%|          | 0/13 [00:00<?, ?it/s]

last 126


  0%|          | 0/13 [00:00<?, ?it/s]

last 126


  0%|          | 0/16 [00:00<?, ?it/s]

last 150


  0%|          | 0/13 [00:00<?, ?it/s]

last 122


  0%|          | 0/13 [00:00<?, ?it/s]

last 128


  0%|          | 0/47 [00:00<?, ?it/s]

last 464


  0%|          | 0/48 [00:00<?, ?it/s]

last 473


  0%|          | 0/16 [00:00<?, ?it/s]

last 156


  0%|          | 0/27 [00:00<?, ?it/s]

last 266


  0%|          | 0/21 [00:00<?, ?it/s]

last 203


  0%|          | 0/14 [00:00<?, ?it/s]

last 138


  0%|          | 0/28 [00:00<?, ?it/s]

last 279


  0%|          | 0/11 [00:00<?, ?it/s]

last 108


  0%|          | 0/33 [00:00<?, ?it/s]

last 324


  0%|          | 0/18 [00:00<?, ?it/s]

last 175


  0%|          | 0/28 [00:00<?, ?it/s]

last 279


  0%|          | 0/74 [00:00<?, ?it/s]

last 732


  0%|          | 0/13 [00:00<?, ?it/s]

last 128


  0%|          | 0/15 [00:00<?, ?it/s]

last 146


  0%|          | 0/52 [00:00<?, ?it/s]

last 515


  0%|          | 0/55 [00:00<?, ?it/s]

last 544


  0%|          | 0/23 [00:00<?, ?it/s]

last 228


  0%|          | 0/27 [00:00<?, ?it/s]

last 264


  0%|          | 0/14 [00:00<?, ?it/s]

last 132


  0%|          | 0/14 [00:00<?, ?it/s]

last 132


  0%|          | 0/49 [00:00<?, ?it/s]

last 482


  0%|          | 0/12 [00:00<?, ?it/s]

last 119


  0%|          | 0/43 [00:00<?, ?it/s]

last 428


  0%|          | 0/40 [00:00<?, ?it/s]

last 391


  0%|          | 0/13 [00:00<?, ?it/s]

last 126


  0%|          | 0/15 [00:00<?, ?it/s]

last 140


  0%|          | 0/41 [00:00<?, ?it/s]

last 403


  0%|          | 0/14 [00:00<?, ?it/s]

last 132


  0%|          | 0/25 [00:00<?, ?it/s]

last 245


  0%|          | 0/20 [00:00<?, ?it/s]

last 199


  0%|          | 0/16 [00:00<?, ?it/s]

last 151


  0%|          | 0/13 [00:00<?, ?it/s]

last 123


  0%|          | 0/14 [00:00<?, ?it/s]

last 139
