In [None]:
# Mount Google Drive: every CSV, NIfTI input and output folder used below lives under /content/drive/MyDrive.
from google.colab import drive

drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install pydicom
!pip install opencv-python
!pip install pillow # optional
!pip install pandas
!pip3 install numpy
!pip3 install dicom2nifti
!pip3 install nibabel
!pip3 install pydicom
!pip3 install tqdm
!pip3 install nilearn
!pip install --quiet torchio==0.18.90

Collecting pydicom
  Downloading pydicom-2.4.3-py3-none-any.whl (1.8 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.8 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.2/1.8 MB[0m [31m6.5 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.7/1.8 MB[0m [31m10.8 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━[0m [32m1.4/1.8 MB[0m [31m13.0 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m13.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pydicom
Successfully installed pydicom-2.4.3
Collecting dicom2nifti
  Downloading dicom2nifti-2.4.9-py3-none-any.whl (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.7/43.7 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
Col

In [None]:
import pathlib as plb
import tempfile
import os
import dicom2nifti
import nibabel as nib
import numpy as np
import pydicom
import sys
import shutil
import nilearn.image
from tqdm import tqdm

import enum
import time
import random
import multiprocessing
from pathlib import Path

import torch
import torchvision
import torchio as tio
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision.utils import save_image

import numpy as np
# from unet import UNet
import pandas as pd
from scipy import stats
import matplotlib.pyplot as plt

from IPython import display
from tqdm.auto import tqdm
from pathlib import Path

from PIL import Image

from skimage.measure import regionprops_table

In [None]:
debug = False

In [None]:
# Processor Class
class TwoDimensionSlicesProcessing():
  """Cut resampled CT volumes into 2D JPEG slices plus a bounding-box label CSV.

  For every (CTres.nii.gz, SEG.nii.gz) pair found under ``nifti_folder``, the
  CT and the segmentation are resampled with ``tio.transforms.Resample(voxel_size)``.
  Each resampled volume is cut into 2D slices along ``plane``, and each kept slice
  is written to ``output_folder`` as ``{patient_id}_{plane}_{index:03}.jpg``.
  For slices that contain segmented voxels, one row per connected lesion
  component is appended to ``label_file`` with the columns::

      img_filename, x_min, y_min, x_max, y_max, cancer_type, img_width, img_height

  Box coordinates are pixel indices in the saved image (x = column, y = row).
  They follow skimage's half-open convention: [x_min, x_max) and [y_min, y_max).

  Args:
    nifti_folder: folder whose direct sub-entries are searched recursively
      for CTres.nii.gz / SEG.nii.gz pairs.
    output_folder: folder the JPEG slices are written to (created if missing).
    label: value written to the ``cancer_type`` column.
    label_file: CSV the rows are appended to (created with a header if missing).
    channels: 1 saves the single slice; 3 stacks the slice with its two
      neighbours along the slicing axis (first and last slices are skipped).
    save_negatives: if False, slices without any segmented voxel are skipped.
    voxel_size: target spacing passed to ``tio.transforms.Resample``.
    plane: 'axial', 'coronal' or 'sagittal'.
    verbose: print the max CT / SEG value of each loaded volume. This used to
      be unconditional and is now opt-in.
  """

  # Column order of the label CSV (header written by __init__ and by run()).
  _CSV_COLUMNS = ('img_filename', 'x_min', 'y_min', 'x_max', 'y_max',
                  'cancer_type', 'img_width', 'img_height')

  def __init__(self,
               nifti_folder: str,
               output_folder: str,
               label: str,
               label_file: str,
               channels: int,
               save_negatives: bool,
               voxel_size: int = 2,
               plane: str = 'axial',
               verbose: bool = False
  ):

    self.nifti_folder = plb.Path(nifti_folder)
    self.output_folder = plb.Path(output_folder)

    # make folder (and parents: a bare os.mkdir fails when the parent is missing)
    self.output_folder.mkdir(parents=True, exist_ok=True)

    self.delim = ','

    self.label = label
    self.label_file = plb.Path(label_file)

    # create label file with headers
    if not self.label_file.is_file():
      self.label_file.parent.mkdir(parents=True, exist_ok=True)
      with open(self.label_file, 'w') as op:
        op.write(self._csv_header())

    self.voxel_size = voxel_size
    self.transform = tio.transforms.Resample(self.voxel_size)

    self.channels = channels
    assert self.channels in [1, 3], f'channels must be 1 or 3, got {channels!r}'

    self.save_negatives = save_negatives
    self.verbose = verbose

    self.plane = plane
    # Index of the voxel axis the slices are stacked along. This assumes a
    # RAS-like voxel order (0: left-right, 1: anterior-posterior,
    # 2: inferior-superior) -- TODO confirm the orientation of these NIfTI files.
    self.plane_to_axis = {
        'axial': 2,
        'coronal': 1,
        'sagittal': 0
    }
    self.axis = self.plane_to_axis.get(self.plane, None)

    assert self.axis is not None, (
        f'unknown plane {plane!r}; expected one of {list(self.plane_to_axis)}')


  def _csv_header(self):
    """Return the label-CSV header line (newline-terminated)."""
    return self.delim.join(self._CSV_COLUMNS) + '\n'


  def find_ct_files(self):
    """Return ``[ct_path, seg_path]`` pairs, one per direct sub-entry of nifti_folder.

    Each sub-entry is searched recursively. It contributes a pair only when it
    contains exactly one CTres.nii.gz AND exactly one SEG.nii.gz. Entries with
    none, with duplicates, or with several nested studies are skipped.
    Sub-entries are visited in sorted order so the output is deterministic.
    """
    ct_name = 'CTres.nii.gz'
    seg_name = 'SEG.nii.gz'

    pairs = []
    for sub_dir in sorted(self.nifti_folder.glob('*')):
      cts = sorted(sub_dir.rglob(ct_name))
      segs = sorted(sub_dir.rglob(seg_name))
      # Require one of each: a bare "two matches" test would also accept two
      # CT files and no segmentation.
      if len(cts) == 1 and len(segs) == 1:
        pairs.append([cts[0], segs[0]])

    return pairs


  def load_and_standardize_spacing(self, ct_file, seg_file):
    """Load a CT / SEG pair and resample both with ``self.transform``.

    Returns:
      (ct_data, seg_data): the torchio ``.data`` tensors of the resampled
      images (channel-first, C x X x Y x Z).
    """
    ct_img = tio.ScalarImage(ct_file)
    seg_label = tio.LabelMap(seg_file)
    if self.verbose:
      print(torch.max(ct_img.data), torch.max(seg_label.data))

    # standardize spacing
    return self.transform(ct_img).data, self.transform(seg_label).data


  def get_2d_bb_indices(self, seg_label):
    """Return one bounding box per connected foreground component of a 2D mask.

    Args:
      seg_label: 2D tensor; non-zero pixels are foreground.

    Returns:
      None if the slice has no foreground. Otherwise, a list of
      ``[col_min, row_min, col_max, row_max]`` (max values exclusive), i.e.
      ``[x_min, y_min, x_max, y_max]`` in image coordinates.
    """
    # check if no positive labels in slice
    if not torch.any(seg_label):
      return None

    # regionprops treats each distinct label VALUE as one region. With a
    # binary SEG mask, every lesion in the slice would merge into a single box,
    # so label 8-connected components first to get one box per lesion.
    from skimage.measure import label as label_connected_components

    components = label_connected_components(seg_label.numpy() != 0)
    props = regionprops_table(components, properties=('bbox',))

    # bbox-0..bbox-3 = (min_row, min_col, max_row, max_col); pixels lie in the
    # half-open intervals [min_row, max_row) and [min_col, max_col).
    return [
        [col_min, row_min, col_max, row_max]
        for row_min, col_min, row_max, col_max in zip(
            props['bbox-0'], props['bbox-1'], props['bbox-2'], props['bbox-3'])
    ]


  def process_three_channels(self, ct_img_permuted, a_i):
    """Return slices a_i-1, a_i, a_i+1 stacked as a 3 x H x W tensor.

    Returns None for the first and last slice, which lack a neighbour.
    """
    axis_max = ct_img_permuted.shape[0]

    # skip if lower/upper slice out of bound
    if a_i == 0 or a_i == axis_max - 1:
      return None

    return ct_img_permuted[a_i-1:a_i+2, :, :]


  def _label_rows(self, img_file_name, slice_bb_indices, img_shape):
    """Format one CSV line per bounding box of a saved slice.

    ``img_shape`` is the shape of the saved tensor, either (H, W) or (C, H, W).
    Width is the last dim and height the one before it, matching the pixel grid
    written by ``save_image`` and the column/row order of the boxes.
    """
    img_height, img_width = img_shape[-2], img_shape[-1]
    d = self.delim
    return [
        (f'{img_file_name}{d}'
         f'{x_min}{d}{y_min}{d}'
         f'{x_max}{d}{y_max}{d}'
         f'{self.label}{d}{img_width}{d}{img_height}\n')
        for x_min, y_min, x_max, y_max in slice_bb_indices
    ]


  def run(self):
    """Slice every CT / SEG pair from ``find_ct_files`` and write images + labels."""
    # Move the slicing axis to the front, keeping the other two in order:
    # axis 0 -> (0, 1, 2), axis 1 -> (1, 0, 2), axis 2 -> (2, 0, 1).
    permute_axis = (self.axis,) + tuple(d for d in range(3) if d != self.axis)

    for ct_file, seg_file in tqdm(self.find_ct_files()):
      # Assumes the layout <patient>/<study>/SEG.nii.gz, so parts[-3] is the
      # patient folder.
      # NOTE(review): several studies of the same patient yield identical image
      # names and overwrite each other's JPEGs -- TODO disambiguate.
      patient_id = seg_file.parts[-3]

      # load nii files
      ct_img, seg_label = self.load_and_standardize_spacing(ct_file, seg_file)

      # Drop only the leading channel dim (1 x X x Y x Z -> X x Y x Z); a bare
      # squeeze() would also drop a spatial dim of size 1.
      ct_img = ct_img.squeeze(0)
      seg_label = seg_label.squeeze(0)

      ct_img_permuted = torch.permute(ct_img, permute_axis)
      seg_label_permuted = torch.permute(seg_label, permute_axis)

      for a_i in range(ct_img_permuted.shape[0]):
        slice_bb_indices = self.get_2d_bb_indices(seg_label_permuted[a_i])

        # skip slices without a lesion unless negatives are requested
        if not self.save_negatives and not slice_bb_indices:
          continue

        if self.channels == 3:
          ct_tensor_slice = self.process_three_channels(ct_img_permuted, a_i)
          if ct_tensor_slice is None:
            continue
        else:
          ct_tensor_slice = ct_img_permuted[a_i]

        # save image
        img_file_name = f'{patient_id}_{self.plane}_{a_i:0>3}.jpg'
        save_image(ct_tensor_slice, self.output_folder / img_file_name, normalize=True)

        if slice_bb_indices:
          # create csv label
          with open(self.label_file, 'a') as op:
            if os.stat(self.label_file).st_size == 0:
              op.write(self._csv_header())
            op.writelines(
                self._label_rows(img_file_name, slice_bb_indices, ct_tensor_slice.shape))


In [None]:
import pandas as pd
import matplotlib.patches as patches

In [None]:
df = pd.read_csv('/content/drive/MyDrive/Capstone_GE_DSI_CV_Project/Shared_csv/all_patients.csv')

def get_label(id, labels_df=None):
  """Return the lower-cased 'diagnosis' of the subject whose 'Subject ID' equals ``id``.

  Args:
    id: value matched against the 'Subject ID' column (the loop below passes a
      patient folder name).
    labels_df: table to search; defaults to the module-level ``df`` read from
      all_patients.csv.

  Raises:
    KeyError: if no row has this Subject ID. Previously this surfaced as an
      opaque IndexError from ``.values[0]``.
  """
  table = df if labels_df is None else labels_df
  diagnoses = table.loc[table['Subject ID'] == id, 'diagnosis']
  if diagnoses.empty:
    raise KeyError(f'Subject ID {id!r} not found in patient table')
  # If a subject appears more than once, the first row wins (as before).
  return diagnoses.iloc[0].lower()

In [None]:
# Split the patient folders into chunks of CHUNK_SIZE so they can be processed one chunk at a time.
NIFTI_ROOT = '/content/drive/MyDrive/Capstone_GE_DSI_CV_Project/data/nifti'
CHUNK_SIZE = 90
# os.listdir returns entries in arbitrary order. Sort so every session builds identical
# chunks; otherwise a patient could be processed twice or skipped across sessions.
all_patient_folders = sorted(os.listdir(NIFTI_ROOT))
nifti_folders = [all_patient_folders[i:i + CHUNK_SIZE]
                 for i in range(0, len(all_patient_folders), CHUNK_SIZE)]
print(len(nifti_folders), len(nifti_folders[0]))

10 90


In [None]:
# Process one chunk of patient folders (pick it with CHUNK_INDEX). Each patient gets its own
# labels.csv under CT_<diagnosis>/<patient>/; the JPEG slices go to the shared CT_<diagnosis>/ folder.
CHUNK_INDEX = 0
NIFTI_ROOT = '/content/drive/MyDrive/Capstone_GE_DSI_CV_Project/data/nifti'
OUTPUT_ROOT = '/content/drive/MyDrive/Capstone_GE_DSI_CV_Project/preprocessed_data/test_ys'

for count, f in enumerate(nifti_folders[CHUNK_INDEX]):
  if (count + 1) % 10 == 0:
    print(count + 1)
  label = get_label(f)
  label_dir = os.path.join(OUTPUT_ROOT, f'CT_{label}')
  patient_dir = os.path.join(label_dir, f)
  # Start from an empty patient folder: the processor appends to labels.csv, so rows
  # from a previous run would otherwise accumulate.
  if os.path.isdir(patient_dir):
    shutil.rmtree(patient_dir)
  os.makedirs(patient_dir)
  DataProcessor = TwoDimensionSlicesProcessing(nifti_folder=os.path.join(NIFTI_ROOT, f),
                                               output_folder=label_dir,
                                               label=label,
                                               label_file=os.path.join(patient_dir, 'labels.csv'),
                                               channels=3,
                                               save_negatives=False)
  DataProcessor.run()

  0%|          | 0/1 [00:00<?, ?it/s]

tensor(3497.8071) tensor(1, dtype=torch.uint8)


  0%|          | 0/1 [00:00<?, ?it/s]

KeyboardInterrupt: ignored