In [1]:
import pandas as pd
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import pydicom as dcm
import os
import nibabel as nib
from tqdm import tqdm
from tensorflow import keras
from keras import layers, models
import tensorflow_io as tfio
import cv2
import warnings
warnings.filterwarnings('ignore')
from tqdm import tqdm
import tensorflow_addons as tfa
import glob
import scipy

In [68]:
def window_image(img, minn, maxx, intercept, slope, rescale=True):
    """Apply ``img * slope + intercept``, clip to ``[minn, maxx]``, optionally map to [0, 1].

    The linear transform produces a new array, so a NumPy ``img`` passed in is
    not modified. Values below ``minn`` are set to ``minn`` and values above
    ``maxx`` are set to ``maxx``. With ``rescale=True`` the clipped values are
    returned as ``(v - minn) / (maxx - minn)``; otherwise the clipped values
    themselves are returned.
    """
    windowed = img * slope + intercept
    below = windowed < minn
    windowed[below] = minn
    above = windowed > maxx
    windowed[above] = maxx
    if not rescale:
        return windowed
    return (windowed - minn) / (maxx - minn)
    
def get_first_of_dicom_field_as_int(x):
    """Convert a DICOM element value to ``int``.

    If ``x`` is exactly a pydicom ``MultiValue`` (subclasses are not matched),
    only its first entry is used; any other value goes to ``int()`` as-is, so a
    float value is truncated toward zero.
    """
    first = x[0] if type(x) is dcm.multival.MultiValue else x
    return int(first)
    
def get_windowing(data):
    """Return ``[window_center, window_width, intercept, slope]`` as ints.

    Reads DICOM tags (0028,1050) Window Center, (0028,1051) Window Width,
    (0028,1052) Rescale Intercept and (0028,1053) Rescale Slope via
    ``data[(group, element)].value`` and converts each with
    ``get_first_of_dicom_field_as_int``. All four elements are fetched before
    any conversion happens.
    """
    tags = [('0028', '1050'), ('0028', '1051'), ('0028', '1052'), ('0028', '1053')]
    raw_values = [data[tag].value for tag in tags]
    return list(map(get_first_of_dicom_field_as_int, raw_values))
def channeling(img1, img2, img3):
    """Stack three equally-shaped arrays along a new trailing channel axis.

    Same result as ``np.stack([img1, img2, img3], axis=-1)``: shape
    ``img1.shape + (3,)``, with ``img1`` in channel 0, ``img2`` in channel 1
    and ``img3`` in channel 2.
    """
    planes = [np.asanyarray(img)[..., np.newaxis] for img in (img1, img2, img3)]
    return np.concatenate(planes, axis=-1)
def load_data(UID, data_path='../../Downloads/rsna-2022-cervical-spine-fracture-detection/train_images/'):
    """Return the CT rescale ``(intercept, slope)`` for one study.

    The values come from a single DICOM file in ``data_path/UID/``. It is the
    first non-hidden file name in sorted order, so the choice is the same on
    every machine (``os.listdir`` order is arbitrary).
    NOTE(review): assumes every file of the study shares the same rescale
    parameters — confirm.

    Parameters
    ----------
    UID : str
        Study directory name under ``data_path``.
    data_path : str
        Root folder of the per-study DICOM directories. Defaults to the path
        that was previously hard-coded.

    Returns
    -------
    tuple of float
        ``(intercept, slope)`` from DICOM tags (0028,1052) and (0028,1053).
        They are kept as floats; an ``int()`` conversion would truncate a
        fractional slope such as 0.5 to 0 and zero out the whole volume.

    Raises
    ------
    IndexError
        If the study directory contains no non-hidden files.
    """
    series_dir = os.path.join(data_path, UID)
    # Skip hidden entries (e.g. OS metadata files), which would otherwise sort first.
    candidates = sorted(f for f in os.listdir(series_dir) if not f.startswith('.'))
    data = dcm.dcmread(os.path.join(series_dir, candidates[0]))

    def _first_as_float(tag):
        # pydicom may return a multi-valued element; use its first entry.
        value = data[tag].value
        if isinstance(value, dcm.multival.MultiValue):
            value = value[0]
        return float(value)

    # Read only the two rescale tags: the window center/width tags are not
    # needed here, so their absence should not make this function fail.
    return _first_as_float(('0028', '1052')), _first_as_float(('0028', '1053'))

In [None]:
# Square input side length and number of output classes for DeeplabV3Plus.
# NUM_CLASSES = 8 covers label ids 0..7, matching the `> 7 -> 7` clamp in the generators.
IMAGE_SIZE, NUM_CLASSES = 128, 8
# NOTE(review): the values below are not referenced by any visible cell; the
# tf.data pipeline batches by a literal 15 rather than BATCH_SIZE.
BATCH_SIZE = 4
NUM_TRAIN_IMAGES, NUM_VAL_IMAGES = 1000, 50

In [98]:
def convolution_block(
    block_input,
    num_filters=256,
    kernel_size=3,
    dilation_rate=1,
    padding="same",
    use_bias=False,
):
    """Conv2D (He-normal init) -> BatchNormalization -> ReLU.

    Parameters
    ----------
    block_input : Keras tensor
        Input feature map.
    num_filters : int
        Number of convolution filters.
    kernel_size : int or tuple
        Convolution kernel size.
    dilation_rate : int or tuple
        Atrous (dilation) rate of the convolution.
    padding : str
        Passed to ``Conv2D`` (``"same"`` by default, as before).
    use_bias : bool
        Whether the convolution has a bias term.

    Returns
    -------
    Keras tensor
        The activated output.
    """
    x = layers.Conv2D(
        num_filters,
        kernel_size=kernel_size,
        dilation_rate=dilation_rate,
        # Previously hard-coded to "same", which silently ignored this argument.
        padding=padding,
        use_bias=use_bias,
        kernel_initializer=keras.initializers.HeNormal(),
    )(block_input)
    x = layers.BatchNormalization()(x)
    # A Keras layer instead of tf.nn.relu: raw TF ops cannot consume Keras 3
    # KerasTensors in the functional API, while layers.ReLU works in both.
    return layers.ReLU()(x)

def DilatedSpatialPyramidPooling(dspp_input):
    """Spatial pyramid pooling head built from ``convolution_block``.

    Concatenates, along the channel axis:
      * an image-level branch (average pool over the full spatial extent,
        1x1 conv with bias, bilinear upsampling back to the input size),
      * a 1x1 conv branch,
      * 3x3 conv branches with dilation rates 6, 12 and 18,
    and fuses them with a final 1x1 ``convolution_block``.
    """
    height, width = dspp_input.shape[-3], dspp_input.shape[-2]

    pooled = layers.AveragePooling2D(pool_size=(height, width))(dspp_input)
    pooled = convolution_block(pooled, kernel_size=1, use_bias=True)
    pool_branch = layers.UpSampling2D(
        size=(height // pooled.shape[1], width // pooled.shape[2]),
        interpolation="bilinear",
    )(pooled)

    # Branch creation order is kept fixed so auto-generated layer names do not change.
    branches = [pool_branch, convolution_block(dspp_input, kernel_size=1, dilation_rate=1)]
    for rate in (6, 12, 18):
        branches.append(convolution_block(dspp_input, kernel_size=3, dilation_rate=rate))

    merged = layers.Concatenate(axis=-1)(branches)
    return convolution_block(merged, kernel_size=1)
def DeeplabV3Plus(image_size, num_classes):
    """Build a DeepLabV3+-style segmentation model on a ResNet50 backbone.

    The encoder feature ``conv4_block6_2_relu`` goes through
    ``DilatedSpatialPyramidPooling`` and is upsampled to ``image_size // 4``.
    It is then concatenated with a 48-filter 1x1 projection of
    ``conv2_block3_2_relu``, refined by two ``convolution_block`` calls,
    upsampled toward ``image_size`` and projected to ``num_classes`` channels
    by a 1x1 ``Conv2D`` with no activation (raw logits).

    NOTE(review): the backbone loads ImageNet weights; those usually expect
    ``keras.applications.resnet50.preprocess_input``-style inputs. Confirm that
    feeding the windowed [0, 1] slices directly is intended.
    """
    inputs = keras.Input(shape=(image_size, image_size, 3))
    backbone = keras.applications.ResNet50(
        weights="imagenet", include_top=False, input_tensor=inputs
    )

    high = backbone.get_layer("conv4_block6_2_relu").output
    high = DilatedSpatialPyramidPooling(high)

    quarter = image_size // 4
    high_up = layers.UpSampling2D(
        size=(quarter // high.shape[1], quarter // high.shape[2]),
        interpolation="bilinear",
    )(high)
    low = backbone.get_layer("conv2_block3_2_relu").output
    low = convolution_block(low, num_filters=48, kernel_size=1)

    decoded = layers.Concatenate(axis=-1)([high_up, low])
    decoded = convolution_block(decoded)
    decoded = convolution_block(decoded)
    decoded = layers.UpSampling2D(
        size=(image_size // decoded.shape[1], image_size // decoded.shape[2]),
        interpolation="bilinear",
    )(decoded)
    logits = layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same")(decoded)
    return keras.Model(inputs=inputs, outputs=logits)


model = DeeplabV3Plus(image_size=IMAGE_SIZE, num_classes=NUM_CLASSES)
model.summary()
# The head ends in a Conv2D with no activation, consistent with from_logits=True.
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(
    loss=loss,
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    metrics=["accuracy"],
)


IndentationError: expected an indented block (2715640895.py, line 38)

In [30]:
# Folder with one '<UID>.nii' segmentation per study. The trailing slash is
# required because file paths are built by plain string concatenation.
path_seg = ('../../Downloads/rsna-2022-cervical-spine-fracture-detection/'
            'segmentations/')

In [64]:
# Sorted so the patients[:75] / patients[75:] train/test split used by the
# generators is reproducible (os.listdir order is arbitrary and OS-dependent).
# Only '.nii' files are kept, because the generators strip a 4-character suffix
# and seg_extractor re-appends '.nii'.
patients = sorted(f for f in os.listdir(path_seg) if f.endswith('.nii'))
patients[0]

'1.2.826.0.1.3680043.10633.nii'

In [69]:
def seg_extractor(UID, size=128):
    """Load one study's segmentation and resample it to a ``size``-cube.

    Parameters
    ----------
    UID : str
        Study UID; the file read is ``path_seg + UID + '.nii'``.
    size : int
        Output length along every axis (default 128, as before).

    Returns
    -------
    np.ndarray
        Float array of shape ``(size, size, size)``: the NIfTI data transposed
        with axes ``(2, 1, 0)``, resampled, then flipped along its first two axes.
    """
    seg = nib.load(path_seg + UID + '.nii').get_fdata()
    seg = np.transpose(seg, axes=[2, 1, 0])
    # order=0 (nearest neighbour): the volume holds integer class labels.
    # Linear interpolation (the previous order=1) creates in-between values at
    # label borders, e.g. a 0|7 border yields 1..6, which become wrong classes
    # once cast to an integer label dtype.
    seg = scipy.ndimage.zoom(seg, [size / n for n in seg.shape], order=0)
    return seg[::-1, ::-1, :]
def ct_extractor(UID, size=128, img_dir=None):
    """Load one study's CT volume and resample it to a ``size``-cube.

    Parameters
    ----------
    UID : str
        Study UID; the file read is ``img_dir + UID + '.nii.gz'``.
    size : int
        Output length along every axis (default 128).
    img_dir : str, optional
        Folder holding the CT NIfTI files, including a trailing slash.
        Defaults to the module-level ``path_img``.

    Returns
    -------
    np.ndarray
        Float array of shape ``(size, size, size)`` (linear interpolation).
    """
    if img_dir is None:
        # NOTE(review): `path_img` is not defined in any visible cell (only
        # path_seg is), so a fresh kernel raises NameError here unless it is
        # defined elsewhere or img_dir is passed explicitly.
        img_dir = path_img
    img = nib.load(img_dir + UID + '.nii.gz').get_fdata()
    # Resize every axis. Previously only axis 0 was resized, so the slices
    # matched the generators' (128, 128, 3) signature only when the file was
    # already 128 x 128 along axes 1 and 2. For such files the factor is 1 on
    # those axes and they are left unchanged.
    img = scipy.ndimage.zoom(img, [size / n for n in img.shape], order=1)
    return img


In [77]:
def data_gen(patient_files=None, slice_range=range(28, 100)):
    """Yield 2-D ``(image, mask)`` training pairs, one per selected slice.

    Parameters
    ----------
    patient_files : sequence of str, optional
        Segmentation file names ending in '.nii'. Defaults to ``patients[:75]``
        (looked up at call time), the training split.
    slice_range : iterable of int
        Indices along the third axis of each resampled volume to emit.

    Yields
    ------
    (np.ndarray, np.ndarray)
        ``ct[:, :, i]``, the three windowed copies of the CT slice stacked
        channels-last, and ``seg[:, :, i]`` with labels above 7 clamped to 7.
    """
    if patient_files is None:
        patient_files = patients[:75]
    for patient_file in patient_files:
        uid = patient_file[:-4]  # drop the '.nii' extension
        intercept, slope = load_data(uid)
        ct = ct_extractor(uid)
        # Three intensity windows of the rescaled volume, stacked as channels.
        ct = channeling(
            window_image(ct, 150, 300, intercept, slope),
            window_image(ct, 300, 450, intercept, slope),
            window_image(ct, 450, 600, intercept, slope),
        )
        # NOTE(review): every label above 7 is merged into class 7 (NUM_CLASSES
        # = 8 gives ids 0..7). If labels 8+ are different structures, mapping
        # them to 0 may be what was intended — confirm.
        seg = np.minimum(seg_extractor(uid), 7)
        for i in slice_range:
            yield ct[:, :, i], seg[:, :, i]

In [91]:
def test_gen(patient_files=None, slice_range=range(28, 100)):
    """Yield 2-D ``(image, mask)`` evaluation pairs, one per selected slice.

    Same processing as ``data_gen``, but over the held-out patients.

    Parameters
    ----------
    patient_files : sequence of str, optional
        Segmentation file names ending in '.nii'. Defaults to ``patients[75:]``
        (looked up at call time), the evaluation split.
    slice_range : iterable of int
        Indices along the third axis of each resampled volume to emit.

    Yields
    ------
    (np.ndarray, np.ndarray)
        ``ct[:, :, i]``, the three windowed copies of the CT slice stacked
        channels-last, and ``seg[:, :, i]`` with labels above 7 clamped to 7.
    """
    if patient_files is None:
        patient_files = patients[75:]
    for patient_file in patient_files:
        uid = patient_file[:-4]  # drop the '.nii' extension
        intercept, slope = load_data(uid)
        ct = ct_extractor(uid)
        # Three intensity windows of the rescaled volume, stacked as channels.
        ct = channeling(
            window_image(ct, 150, 300, intercept, slope),
            window_image(ct, 300, 450, intercept, slope),
            window_image(ct, 450, 600, intercept, slope),
        )
        # NOTE(review): labels above 7 are merged into class 7 — confirm intent.
        seg = np.minimum(seg_extractor(uid), 7)
        for i in slice_range:
            yield ct[:, :, i], seg[:, :, i]

In [96]:
# Both generators yield (windowed CT slice, label slice) pairs of the same
# shape, so they share one element spec. `output_signature` replaces the
# deprecated positional output_types / output_shapes arguments.
slice_signature = (
    tf.TensorSpec(shape=(128, 128, 3), dtype=tf.float32),
    tf.TensorSpec(shape=(128, 128), dtype=tf.int8),
)

dataset = tf.data.Dataset.from_generator(data_gen, output_signature=slice_signature)
testset = tf.data.Dataset.from_generator(test_gen, output_signature=slice_signature)

In [97]:
# NOTE(review): this cell reassigns its own inputs, so running it twice batches
# the already-batched datasets again; re-run the from_generator cell first.
# NOTE(review): the config cell defines BATCH_SIZE = 4, but a literal 15 is
# used here — confirm which is intended.
# AUTOTUNE lets tf.data size the prefetch buffer instead of a fixed 1.
dataset = dataset.batch(15).prefetch(tf.data.AUTOTUNE)
testset = testset.batch(15).prefetch(tf.data.AUTOTUNE)