In [1]:
import os
import multiprocessing
import shutil
import subprocess

import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import pydicom
import numpy as np
import nibabel as nib
from tqdm import tqdm
from joblib import Parallel, delayed

# Short aliases for the os / os.path helpers used throughout the notebook.
l, rme, mk = os.listdir, os.rename, os.makedirs
j, e, basename = os.path.join, os.path.exists, os.path.basename

### Rename dcm files per scan

In [2]:
def reindex(pth):
    """Rename the ``.dcm`` files in *pth* to consecutive indices ``0.dcm, 1.dcm, ...``.

    Files are ordered by the integer value of their stem, so ``2.dcm`` comes
    before ``10.dcm``. Zero-padded stems such as ``007.dcm`` are handled
    because the real filename is kept next to its numeric key. Entries that do
    not end in ``.dcm`` are left untouched.

    Args:
        pth: folder containing the DICOM slices.

    Raises:
        ValueError: if the part of a ``.dcm`` name before the first dot is not
            an integer.
    """
    # Keep the original name alongside the key: rebuilding it as
    # ``str(int(stem)) + ".dcm"`` would miss zero-padded files.
    numbered = sorted(
        (int(name.split('.')[0]), name)
        for name in os.listdir(pth)
        if name.endswith(".dcm")
    )
    for new_index, (_, old_name) in enumerate(numbered):
        new_name = f"{new_index}.dcm"
        if new_name == old_name:
            continue  # already in place; no rename needed
        os.rename(os.path.join(pth, old_name), os.path.join(pth, new_name))

### Helper functions

In [3]:
def calc_black_pct(img):
    """Percentage (0-100) of pixels that are exactly zero.

    Intended for 2-D arrays: the denominator is ``shape[0] * shape[1]``.
    """
    rows, cols = img.shape[0], img.shape[1]
    zero_count = (img == 0).sum()
    return zero_count / (rows * cols) * 100

In [4]:
def process_mask(pth, slice_step=5, pixel_step=2, threshold=0.5):
    """Load a NIfTI mask, reorient it, binarise it and subsample it.

    Args:
        pth: path to a mask file readable by ``nibabel.load``.
        slice_step: keep every ``slice_step``-th entry along axis 0 of the
            reoriented volume (default 5, the previously hard-coded value).
        pixel_step: keep every ``pixel_step``-th entry along axes 1 and 2
            (default 2, the previously hard-coded value).
        threshold: voxels ``>= threshold`` become 1, all others 0.

    Returns:
        ``np.uint8`` array containing only 0 and 1.
    """
    m = nib.load(pth).get_fdata()

    # Reorientation: the order of these four steps matters.
    # NOTE(review): presumably this aligns the NIfTI axis order with the DICOM
    # slice order used later in the pipeline — confirm before changing it.
    m = np.transpose(m, [1, 0, 2])
    m = np.rot90(m, 1, (1, 2))
    m = m[::-1, :, :]
    m = np.transpose(m, [1, 0, 2])

    # Binarise.
    m = m.astype(np.float32)
    m[m < threshold] = 0
    m[m >= threshold] = 1
    m = m.astype(np.uint8)

    # NOTE(review): slice_step=5 appears to mirror initPreprocess's TICK=5,
    # which only lines up when the slice thickness is 1 — TODO confirm.
    return m[::slice_step, ::pixel_step, ::pixel_step]

In [5]:
def preprocess_dcm2img(dicom_image):
    """Convert a DICOM dataset into a windowed, min-max normalised 8-bit image.

    Steps:
    1. Sign-extend signed pixel data.
    2. Apply the modality rescale (slope/intercept), exactly once.
    3. Clip to the dataset's display window.
    4. Stretch to 0-255.
    5. For MONOCHROME1 data, invert.

    Args:
        dicom_image: a pydicom Dataset, or any object exposing
            ``pixel_array``, ``PixelRepresentation``, ``WindowCenter`` and
            ``WindowWidth``. ``RescaleSlope``, ``RescaleIntercept`` and
            ``PhotometricInterpretation`` are optional; they default to
            1, 0 and MONOCHROME2.

    Returns:
        ``np.uint8`` array with the shape of ``pixel_array``.
    """
    pixel_array = dicom_image.pixel_array
    modality_lut_applied = False

    if dicom_image.PixelRepresentation == 1:
        # Signed data: sign-extend values stored in fewer bits than allocated.
        bit_shift = dicom_image.BitsAllocated - dicom_image.BitsStored
        dtype = pixel_array.dtype
        new_array = (pixel_array << bit_shift).astype(dtype) >> bit_shift
        # apply_modality_lut already applies RescaleSlope/RescaleIntercept (or a
        # Modality LUT), so the manual rescale below must not run again.
        pixel_array = pydicom.pixel_data_handlers.util.apply_modality_lut(new_array, dicom_image)
        modality_lut_applied = True

    # Work in float so rescaling and clipping cannot wrap unsigned dtypes.
    pixel_array = np.asarray(pixel_array, dtype=np.float64)

    # transform to hounsfield units (unless the modality LUT already did)
    if not modality_lut_applied:
        slope = float(getattr(dicom_image, "RescaleSlope", 1))
        intercept = float(getattr(dicom_image, "RescaleIntercept", 0))
        pixel_array = pixel_array * slope + intercept

    # windowing; WindowCenter/WindowWidth may be multi-valued -> use the first
    window_center = int(np.ravel(dicom_image.WindowCenter)[0])
    window_width = int(np.ravel(dicom_image.WindowWidth)[0])
    img_min = window_center - window_width // 2
    img_max = window_center + window_width // 2
    pixel_array = np.clip(pixel_array, img_min, img_max)

    # min-max normalisation; guard the constant-image case (division by zero)
    lo, hi = pixel_array.min(), pixel_array.max()
    if hi > lo:
        pixel_array = (pixel_array - lo) / (hi - lo)
    else:
        pixel_array = np.zeros_like(pixel_array)

    # MONOCHROME1: low values are displayed bright. Invert on the normalised
    # [0, 1] image; ``1 - x`` on raw unsigned pixels would wrap around.
    if getattr(dicom_image, "PhotometricInterpretation", "MONOCHROME2") == "MONOCHROME1":
        pixel_array = 1.0 - pixel_array

    return (pixel_array * 255).astype(np.uint8)

In [6]:
def preprocess_single_dcm(dcm_path, jpgScan_pth, SIZE):
    """Read one DICOM slice, preprocess it and save it as ``<stem>.jpeg``.

    Args:
        dcm_path: path to the ``.dcm`` file.
        jpgScan_pth: output folder for the JPEG.
        SIZE: the image is resized to ``SIZE x SIZE`` pixels.

    Returns:
        Path of the written JPEG.

    Raises:
        OSError: if OpenCV fails to write the file. ``cv2.imwrite`` only
            signals failure through its return value.
    """
    dcm = pydicom.dcmread(dcm_path)
    img = preprocess_dcm2img(dcm)
    img = cv2.resize(img, (SIZE, SIZE))

    # convert dicom to jpeg: keep the slice's file stem, swap the extension
    stem = os.path.splitext(os.path.basename(dcm_path))[0]
    out_pth = os.path.join(jpgScan_pth, stem + ".jpeg")

    if not cv2.imwrite(out_pth, img):
        raise OSError(f"failed to write {out_pth}")
    return out_pth

### Initial preprocessing

In [7]:

# Helpers needed (defined above):
# 1- preprocess_dcm2img(dicom_image)
# 2- preprocess_single_dcm(dcm_path, jpgScan_pth, SIZE)

def initPreprocess(dicom_scanFolder_pth, jpgScan_pth, strt, end, TICK=5, SIZE=256):
    """Convert the DICOM slices inside the ROI to JPEGs, keeping one slice every ~TICK.

    The ``.dcm`` files are ordered by their integer filename and subsampled with
    ``step = round(TICK / SliceThickness)``; the first slice's thickness is
    used. Positions ``strt..end`` (inclusive) of that subsampled list are
    converted in parallel by ``preprocess_single_dcm``.

    Args:
        dicom_scanFolder_pth: folder of ``<int>.dcm`` files.
        jpgScan_pth: output folder for the JPEGs.
        strt, end: inclusive range of positions in the subsampled slice list.
        TICK: target spacing between kept slices, in the unit of SliceThickness.
        SIZE: output image side length in pixels.

    Raises:
        FileNotFoundError: if the folder contains no ``.dcm`` file.
    """
    print("Initial preprocessing based ROI...")

    # Only .dcm files, ordered numerically (keeping their real names).
    dcms = [
        name for _, name in sorted(
            (int(os.path.splitext(name)[0]), name)
            for name in os.listdir(dicom_scanFolder_pth)
            if name.endswith(".dcm")
        )
    ]
    if not dcms:
        raise FileNotFoundError(f"no .dcm files in {dicom_scanFolder_pth}")

    # get tick: how many slices to skip so kept slices are ~TICK apart
    curr_tick = float(pydicom.dcmread(os.path.join(dicom_scanFolder_pth, dcms[0])).SliceThickness)
    step = max(1, round(TICK / curr_tick))  # thick slices would otherwise give step 0

    dcm_paths = [
        os.path.join(dicom_scanFolder_pth, dcms[i])
        for idx, i in enumerate(range(0, len(dcms), step))
        if strt <= idx <= end
    ]

    # One parallel pass over the selected slices. Previously this ran inside the
    # loop and re-converted every accumulated slice on each iteration.
    n_jobs = max(1, multiprocessing.cpu_count() // 2)  # joblib expects an int >= 1
    Parallel(n_jobs=n_jobs)(
        delayed(preprocess_single_dcm)(path, jpgScan_pth, SIZE) for path in dcm_paths
    )

    print("DONE")

### Convert DICOM (.dcm) to NIfTI (.nii)

In [8]:
def Dcm2Nii(dcm_pth, nii_pth):
    """Convert the DICOM series in *dcm_pth* to NIfTI using the ``dcm2niix`` CLI.

    Args:
        dcm_pth: folder containing the DICOM slices.
        nii_pth: output folder for the NIfTI file(s).

    Raises:
        FileNotFoundError: if the ``dcm2niix`` executable cannot be found.
        subprocess.CalledProcessError: if ``dcm2niix`` exits with a non-zero status.
    """
    print("converting...")
    # Argument list instead of a shell string: paths containing spaces or shell
    # metacharacters are passed through verbatim.
    subprocess.run(["dcm2niix", "-o", nii_pth, dcm_pth], check=True)
    print("DONE")

### Generate Masks

In [9]:
def genMasks(nii_pth, ms_pth):
    """Run TotalSegmentator (fast mode, liver + urinary_bladder only) on a NIfTI scan.

    The input is the first ``.nii`` / ``.nii.gz`` file in *nii_pth*, in
    sorted order. Masks are written to *ms_pth*.

    Raises:
        FileNotFoundError: if *nii_pth* holds no NIfTI file, or the
            ``TotalSegmentator`` executable cannot be found.
        subprocess.CalledProcessError: if TotalSegmentator exits with an error.
    """
    print("generating masks...")
    nii_files = sorted(f for f in os.listdir(nii_pth) if f.endswith((".nii", ".nii.gz")))
    if not nii_files:
        raise FileNotFoundError(f"no NIfTI file found in {nii_pth}")
    input_pth = os.path.join(nii_pth, nii_files[0])
    # Argument list, not a shell string: safe for paths with spaces.
    subprocess.run(
        ["TotalSegmentator", "-i", input_pth, "-o", ms_pth,
         "--fast", "--roi_subset", "liver", "urinary_bladder"],
        check=True,
    )
    print("Done")

### Get ROI (region of interest)

In [10]:
def getROI(mfolder_pth, thresh):
    """For each mask file, return the first slice whose black-pixel percentage is below *thresh*.

    Masks are visited in sorted filename order so the tuple order is
    deterministic; ``os.listdir`` order is arbitrary. Indices refer to the
    slice axis of the volumes returned by ``process_mask``.
    NOTE(review): the caller unpacks the result as ``(strt, end)``, which
    presumably relies on ``liver*`` sorting before ``urinary_bladder*``.

    Args:
        mfolder_pth: folder containing the mask files.
        thresh: black-pixel percentage (0-100) a slice must be below.

    Returns:
        Tuple with one slice index per mask file.

    Raises:
        ValueError: if a mask has no qualifying slice. Such a mask used to be
            skipped silently, which shifted the remaining indices in the tuple.
    """
    print("Extract ROI...")
    ROI = []
    for m_name in sorted(os.listdir(mfolder_pth)):
        m = process_mask(os.path.join(mfolder_pth, m_name))
        first = next(
            (idx for idx in range(m.shape[0]) if calc_black_pct(m[idx, :, :]) < thresh),
            None,
        )
        if first is None:
            raise ValueError(f"no slice of {m_name} has less than {thresh}% black pixels")
        ROI.append(first)
    print("DONE")
    return tuple(ROI)

### Model input

In [11]:
def select_elements_with_spacing(input_list, divsion):
    """Pick ``divsion`` roughly evenly spaced elements from *input_list*.

    Elements are taken at indices ``0, s, 2s, ..., (divsion-2)*s``, with
    ``s = max(1, len(input_list) // divsion)``, followed by the last element.
    When the list is shorter than ``divsion``, indices are clamped to the last
    element, so the result always has exactly ``divsion`` items.

    Args:
        input_list: non-empty indexable sequence.
        divsion: number of elements to return (>= 1).

    Returns:
        List of ``divsion`` elements; the last one is always ``input_list[-1]``.

    Raises:
        ValueError: if *input_list* is empty or ``divsion < 1``.
    """
    if divsion < 1:
        raise ValueError("divsion must be >= 1")
    if len(input_list) == 0:
        raise ValueError("input_list must not be empty")

    last = len(input_list) - 1
    spacing = max(1, len(input_list) // divsion)

    # The min() only takes effect when len(input_list) < divsion; otherwise
    # every spaced index is already below `last`.
    selected_indices = [min(spacing * i, last) for i in range(divsion - 1)]
    selected_indices.append(last)

    return [input_list[index] for index in selected_indices]

In [12]:
def preprocess_jpeg(jpg_scanFolder_pth, jpeg_path):
    """Load ``<jpg_scanFolder_pth>/<jpeg_path>`` as a grayscale float image in [0, 1].

    Raises:
        FileNotFoundError: if OpenCV cannot read the file. ``cv2.imread``
            returns ``None`` instead of raising.
    """
    full_path = os.path.join(jpg_scanFolder_pth, jpeg_path)
    img = cv2.imread(full_path)
    if img is None:
        raise FileNotFoundError(f"could not read image: {full_path}")
    greyscale = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) / 255
    return greyscale

In [13]:
def prepareModelInput(jpg_scanFolder_pth, num_frames=10):
    """Build the model input tensor from the ``<int>.jpeg`` slices in a folder.

    Files are ordered numerically. ``num_frames`` of them are chosen with
    ``select_elements_with_spacing``, loaded with ``preprocess_jpeg`` and
    stacked.

    Returns:
        ``torch.float32`` tensor of shape ``(1, F, H, W)``, where F is the
        number of selected frames and H, W are the JPEG dimensions.
    """
    print("input preparing...")

    # Only .jpeg files: stray entries (e.g. Windows' Thumbs.db) would break int().
    frame_names = [
        name for _, name in sorted(
            (int(os.path.splitext(name)[0]), name)
            for name in os.listdir(jpg_scanFolder_pth)
            if name.endswith(".jpeg")
        )
    ]
    frame_names = select_elements_with_spacing(frame_names, num_frames)

    images = np.stack([preprocess_jpeg(jpg_scanFolder_pth, f) for f in frame_names])
    images = np.expand_dims(images, axis=0)  # add the batch axis
    image = torch.tensor(images, dtype=torch.float)

    print("DONE")

    return image

### Model Architecture

In [14]:
# Model architecture
class CNNModel(nn.Module):
    """EfficientNet-B0 backbone with one linear head per injury target.

    A 3x3 convolution first maps the ``in_frames`` stacked slices to the 3
    channels the backbone consumes.

    Args:
        weights: torchvision weights spec for ``efficientnet_b0``. The default
            loads the ImageNet weights. Pass ``None`` when a trained
            ``state_dict`` is loaded right afterwards (e.g. for inference);
            this avoids fetching weights that are immediately overwritten.
        in_frames: number of input slices/channels; must match the checkpoint
            (default 10).
    """

    def __init__(self, weights='IMAGENET1K_V1', in_frames=10):
        super().__init__()

        self.input = nn.Conv2d(in_frames, 3, kernel_size=3)

        model = models.efficientnet_b0(weights=weights)

        self.features = model.features
        self.avgpool = model.avgpool

        # heads; 1280 must equal the flattened size of features -> avgpool
        self.bowel = nn.Linear(1280, 1)          # binary (1 / 0)
        self.extravasation = nn.Linear(1280, 1)  # binary (1 / 0)
        self.kidney = nn.Linear(1280, 3)         # 3 classes
        self.liver = nn.Linear(1280, 3)          # 3 classes
        self.spleen = nn.Linear(1280, 3)         # 3 classes

    def forward(self, x):
        """Return raw logits ``(bowel, extravasation, kidney, liver, spleen)``.

        ``x`` is expected as ``(batch, in_frames, H, W)``.
        """
        # extract features
        x = self.input(x)
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)

        # output logits
        bowel = self.bowel(x)
        extravasation = self.extravasation(x)
        kidney = self.kidney(x)
        liver = self.liver(x)
        spleen = self.spleen(x)

        return bowel, extravasation, kidney, liver, spleen

### Run Model

In [15]:
def runModel(model_path, modelInput):
    """Load the trained CNNModel checkpoint and run a single forward pass.

    Args:
        model_path: path to a ``state_dict`` saved with ``torch.save``.
        modelInput: input tensor, as produced by ``prepareModelInput``.

    Returns:
        Raw logits ``(bowel, extravasation, kidney, liver, spleen)``.
    """
    print("initial model...")
    model = CNNModel()
    # map_location="cpu": the input tensor is created on the CPU, and a
    # checkpoint saved from a GPU would otherwise fail on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    print('predicting...')
    with torch.no_grad():
        bowel, extravasation, kidney, liver, spleen = model(modelInput)
    print('DONE')
    return bowel, extravasation, kidney, liver, spleen
    

### Get Results

In [16]:
def getResults(b, e, k, l, s, thresh=0.5):
    """Turn the model's raw outputs into human-readable findings.

    Args:
        b, e: bowel / extravasation logit (a single value, as tensor, array or
            float). Each is passed through a sigmoid and compared against
            *thresh*.
        k, l, s: kidney / liver / spleen scores for the classes
            (healthy, low injury, high injury); the arg-max class is reported.
        thresh: probability threshold for the two binary findings.

    Returns:
        Tuple of five strings ``(bowel, extravasation, kidney, liver, spleen)``.
    """
    def to_numpy(x):
        # Tensors -> detached CPU numpy; anything else through np.asarray.
        if isinstance(x, torch.Tensor):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    def probability(logit):
        # CNNModel outputs logits; map to a probability before using `thresh`.
        value = float(to_numpy(logit).reshape(-1)[0])
        return 1.0 / (1.0 + np.exp(-value))

    def labels(name):
        return {0: f"Healthy {name}", 1: f"Low Injury {name}", 2: f"High Injury {name}"}

    br = "Healthy bowel" if probability(b) < thresh else "injured bowel"
    er = "NO extravasation detected" if probability(e) < thresh else "extravasation DETECTED"

    kr = labels("Kidneys")[int(np.argmax(to_numpy(k)))]
    lr = labels("Liver")[int(np.argmax(to_numpy(l)))]
    sr = labels("Spleen")[int(np.argmax(to_numpy(s)))]

    return br, er, kr, lr, sr


<h3>DEBUG</h3> Empty the constant working folders (DON'T USE)

In [17]:
import os
import shutil

def reset_constants(root='C:\\App'):
    """Delete and recreate the pipeline's working folders under *root*.

    ``niiScan``, ``masks`` and ``jpgScan`` are removed if present and then
    recreated empty, so each run starts from empty folders.

    Args:
        root: base folder; defaults to ``C:\\App``, the location hard-coded
            in the pipeline configuration.
    """
    for folder in ('niiScan', 'masks', 'jpgScan'):
        path = os.path.join(root, folder)
        # os.path.exists rather than the `e` alias: a later cell may rebind
        # `e` to a model output, which would break this function on re-runs.
        if os.path.exists(path):
            shutil.rmtree(path)
        os.makedirs(path, exist_ok=True)

# Wipe and recreate the niiScan / masks / jpgScan working folders under C:\App.
reset_constants()


### Start Pipeline

In [18]:
# Re-bind the os / os.path shorthands (same as the first cell) so they are
# valid again even if a previous run of a later cell reused any of these names.
l, rme, mk = os.listdir, os.rename, os.makedirs
j, e, basename = os.path.join, os.path.exists, os.path.basename

In [19]:

root = r'C:\App'
dicom_scanFolder_pth = root + r'\007'  # scan folder, chosen by browsing
nifti_scanFolder_pth = root + r'\niiScan'
masks_scanFolder_pth = root + r'\masks'
jpg_scanFolder_pth = root + r'\jpgScan'
model_pth = root + r'\efficientnet_b0_1.658.pth'

In [20]:
#pipeline
reindex(dicom_scanFolder_pth)
Dcm2Nii(dicom_scanFolder_pth, nifti_scanFolder_pth)
genMasks(nifti_scanFolder_pth, masks_scanFolder_pth) 
strt, end = getROI(masks_scanFolder_pth, thresh = 96.5)
initPreprocess(dicom_scanFolder_pth, jpg_scanFolder_pth, strt, end) #(53.8s)
modelInput = prepareModelInput(jpg_scanFolder_pth, num_frames=10)
b, e, k, l, s = runModel(model_pth, modelInput)
br, er, kr, lr, sr = getResults(b, e, k, l, s, thresh=0.5)
print(f"bowel result >>> {br}\nextravasation result >>> {er}\nKidney result >>> {kr}\nLiver result >>> {lr}\nSpleen result >>> {sr}")


converting...
DONE
generating masks...
Done
Extract ROI...
DONE
Initial preprocessing based ROI...
DONE
input preparing...
DONE
predicting...
DONE
bowel result >>> Healthy bowel
extravasation result >>> NO extravasation detected
Kidney result >>> Healthy Kidneys
Liver result >>> Healthy Liver
Spleen result >>> Healthy Spleen
