In [None]:
import tensorflow as tf
import pydicom
import cv2
from sklearn import preprocessing
import math
import numpy as np
import re
import pandas as pd
from glob import glob
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt

In [None]:
def extract_number_from_path(path):
    """Return the integer that directly precedes a trailing ``.dcm`` in ``path``.

    Used as a sort key so that slice files are ordered numerically rather
    than lexicographically. Paths that do not end in ``<digits>.dcm`` map to 0.
    """
    found = re.search(r'(\d+)\.dcm$', path)
    return int(found.group(1)) if found else 0

def get_data_for_3d_volumes(data,train_data_cat, path, number_idx):
    """Build a table with one row per sampled series and its ordered slice paths.

    Parameters
    ----------
    data : pd.DataFrame
        Must contain the columns ``patient_id`` and ``series_id``.
    train_data_cat : pd.DataFrame
        Must contain the columns ``patient_id`` and ``any_injury``.
    path : str
        Root directory holding ``<patient_id>/<series_id>/`` sub-directories.
        A trailing ``/`` is optional.
    number_idx : int
        Number of rows to keep after shuffling.

    Returns
    -------
    pd.DataFrame
        Columns ``Patient_id``, ``Series_id``, ``Patient_paths`` (list of file
        paths ordered by the number in front of ``.dcm``) and
        ``Patient_category`` (the ``any_injury`` value of the patient).
    """
    # Attach the patient-level label to every series of that patient.
    series_labels = data[["patient_id", "series_id"]].merge(
        train_data_cat[["patient_id", "any_injury"]], on='patient_id', how='left'
    )

    # Fixed-seed shuffle so the same subset is drawn on every run, then keep
    # the first ``number_idx`` rows.
    selected = (
        series_labels.sample(frac=1, random_state=42)
        .iloc[:number_idx]
        .reset_index(drop=True)
    )

    # Accept the root with or without its trailing separator; plain string
    # concatenation would otherwise glue the patient id onto the folder name.
    root = path if (not path or path.endswith("/")) else path + "/"

    patient_ids = selected["patient_id"].tolist()
    series_ids = selected["series_id"].tolist()
    category = selected["any_injury"].tolist()

    # One list of slice files per series, sorted by slice number.
    total_paths = [
        sorted(glob(f"{root}{p_id}/{s_id}/*"), key=extract_number_from_path)
        for p_id, s_id in zip(patient_ids, series_ids)
    ]

    final_data = pd.DataFrame(list(zip(patient_ids, series_ids, total_paths, category)),
               columns =["Patient_id","Series_id", "Patient_paths", "Patient_category"])

    return final_data

In [None]:
# Root of the competition dataset; every input read in this cell lives under it.
DATA_DIR = "/kaggle/input/rsna-2023-abdominal-trauma-detection/"

# Series-level table: supplies the patient_id / series_id pairs.
train_data = pd.read_csv(DATA_DIR + "train_series_meta.csv")
# Patient-level table: supplies the any_injury label.
cat_data = pd.read_csv(DATA_DIR + "train.csv")
# The trailing "/" is kept because "<patient_id>/<series_id>" is appended to it.
path = DATA_DIR + "train_images/"
# Keep 100 shuffled series together with their ordered slice paths.
cleaned_df = get_data_for_3d_volumes(train_data, cat_data, path=path, number_idx=100)

In [None]:
# Select the row at position 96 of cleaned_df as the worked example; the
# double brackets keep the result a one-row DataFrame instead of a Series.
patient_exp = cleaned_df.iloc[[96]]
# Bare expression so the notebook renders the selected row.
patient_exp

In [None]:
# One label per slice: repeat the series-level category of row 96 once for
# every slice path of that series.
cat_p1 = [cleaned_df["Patient_category"][96]] * len(cleaned_df["Patient_paths"][96])

In [None]:
# Sanity check: the number of slice paths and the number of labels should match.
n_slices = len(cleaned_df["Patient_paths"][96])
n_labels = len(cat_p1)
print(n_slices, n_labels)

In [None]:
class SuperDataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` serving batches of preprocessed DICOM slices.

    Each slice is read from disk, corrected for its pixel representation,
    rescaled with the file's slope/intercept, clipped to an intensity window,
    resized to ``target_size`` and min-max scaled to ``[0, 1]``.
    """

    def __init__(self, x_set: list, y_set: list, batch_size: int, target_size: tuple, window_width: int, window_level: int, **kwargs) -> None:
        """Store the data and preprocessing settings.

        Parameters
        ----------
        x_set : list
            DICOM file paths, one per slice.
        y_set : list
            Labels, one per entry of ``x_set``.
        batch_size : int
            Number of slices per batch; must be positive.
        target_size : tuple
            Output size as ``(rows, columns)``.
        window_width : int
            Width of the intensity window.
        window_level : int
            Centre of the intensity window.
        **kwargs
            Forwarded to the Keras base class constructor.

        Raises
        ------
        ValueError
            If ``batch_size`` is not positive or the two sets differ in length.
        """
        # Newer Keras versions warn when the base constructor is skipped.
        super().__init__(**kwargs)
        if batch_size <= 0:
            raise ValueError("batch_size must be a positive integer.")
        if len(x_set) != len(y_set):
            raise ValueError("x_set and y_set must have the same length.")
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size
        self.target_size = target_size
        self.window_width = window_width
        self.window_level = window_level

    def __len__(self):
        """Return the number of batches; the last one may be shorter."""
        return math.ceil(len(self.x) / self.batch_size)

    def window_converter(self, image):
        """Clip ``image`` to the configured intensity window.

        The half width is computed with integer division, so an odd
        ``window_width`` yields a window one unit narrower than requested.

        Returns
        -------
        np.ndarray
            A copy of ``image`` with values limited to
            ``[level - width // 2, level + width // 2]``.
        """
        half_width = self.window_width // 2
        img_min = self.window_level - half_width
        img_max = self.window_level + half_width
        # Work on a copy so the caller's array is left untouched.
        window_image = image.copy()
        window_image[window_image < img_min] = img_min
        window_image[window_image > img_max] = img_max
        return window_image

    def transform_to_hu(self, medical_image, image):
        """Apply the DICOM rescale slope and intercept to ``image``.

        Parameters
        ----------
        medical_image : str or pydicom.dataset.Dataset
            Path of the DICOM file, or an already loaded dataset. Passing the
            dataset avoids reading the file a second time.
        image : np.ndarray
            Pixel data to rescale.

        Returns
        -------
        np.ndarray
            ``image * RescaleSlope + RescaleIntercept`` (Hounsfield units for
            CT data -- assumes the rescale tags map to HU).
        """
        if isinstance(medical_image, pydicom.dataset.Dataset):
            meta_image = medical_image
        else:
            meta_image = pydicom.dcmread(medical_image)
        intercept = float(meta_image.RescaleIntercept)
        slope = float(meta_image.RescaleSlope)
        return image * slope + intercept

    def standardize_pixel_array(self, dcm: pydicom.dataset.FileDataset) -> np.ndarray:
        """Return the pixel array, corrected when ``PixelRepresentation == 1``.

        For such files the values are shifted left by
        ``BitsAllocated - BitsStored``, cast back to the array's dtype and
        shifted right by the same amount (presumably to sign-extend values
        stored in fewer bits than allocated).

        Returns
        -------
        np.ndarray
            The pixel array of ``dcm``, corrected if required.
        """
        pixel_array = dcm.pixel_array
        if dcm.PixelRepresentation == 1:
            bit_shift = dcm.BitsAllocated - dcm.BitsStored
            dtype = pixel_array.dtype
            pixel_array = (pixel_array << bit_shift).astype(dtype) >> bit_shift
        return pixel_array

    def resize_img(self, image_path: str) -> np.ndarray:
        """Load one DICOM slice and return it rescaled, windowed and resized.

        Returns
        -------
        np.ndarray
            Array of shape ``(target_size[0], target_size[1])``; values are
            windowed but not yet normalised.
        """
        # Read the file once and reuse the dataset for pixels and rescale tags.
        dcm = pydicom.dcmread(image_path)
        pixels = self.standardize_pixel_array(dcm)
        hu_image = self.transform_to_hu(dcm, pixels)
        window_image = self.window_converter(hu_image)
        # cv2.resize expects (width, height) while target_size is (rows, cols).
        return cv2.resize(window_image, (self.target_size[1], self.target_size[0]))

    def normalize_image(self, image):
        """Min-max scale a 2D image to ``[0, 1]`` using its own extrema.

        Raises
        ------
        ValueError
            If ``image`` is not two-dimensional.
        """
        if image.ndim != 2:
            raise ValueError("Input must be a 2D image.")
        # MinMaxScaler scales column by column, so all pixels are presented as
        # one column to obtain a single global minimum and maximum.
        scaled_column = MinMaxScaler().fit_transform(image.reshape(-1, 1))
        return scaled_column.reshape(image.shape)

    def __getitem__(self, index):
        """Return batch ``index`` as ``(images, labels)``.

        Returns
        -------
        tuple of np.ndarray
            ``images`` has shape ``(n, target_size[0], target_size[1], 1)`` and
            ``labels`` has shape ``(n,)``; both are ``float64``.
        """
        start = index * self.batch_size
        stop = start + self.batch_size
        batch_x = self.x[start:stop]
        batch_y = self.y[start:stop]

        images = np.zeros(
            (len(batch_x), self.target_size[0], self.target_size[1]), dtype=np.float64
        )
        for i, file_name in enumerate(batch_x):
            images[i] = self.normalize_image(self.resize_img(file_name))
        # Append a trailing channel axis for the single grey channel.
        return np.expand_dims(images, -1), np.array(batch_y, dtype=np.float64)

In [None]:
# Generator over the slices of example row 96: batches of 32 images resized
# to 512 x 512, clipped to an intensity window of width 557 at level 107.
data_gen = SuperDataGenerator(
    x_set=cleaned_df["Patient_paths"][96],
    y_set=cat_p1,
    batch_size=32,
    target_size=(512, 512),
    window_width=557,
    window_level=107,
)

In [None]:
# Fetch the batch at index 1 (the second batch): x holds the preprocessed
# images, y the matching labels.
x, y = data_gen[1]

In [None]:
# Preview up to eight images of the batch on a 2 x 4 grid.
fig, axes = plt.subplots(2, 4, figsize=(12, 6))

for i, ax in enumerate(axes.flatten()):
    # A short batch (e.g. the last one of a series) can hold fewer than eight
    # images; leave the spare panels blank instead of raising IndexError.
    if i < len(x):
        ax.imshow(x[i], cmap='gray')
    ax.axis('off')

plt.tight_layout()
plt.show()