In [None]:
import wandb
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Input, Flatten, LSTM, Dense, Reshape, Dropout
from tensorflow.keras.models import Model

from types import SimpleNamespace
import keras
import numpy as np
import pandas as pd
import pydicom
import cv2
from scipy.ndimage import zoom
from sklearn import preprocessing
from glob import glob
import re
import sqlite3

import matplotlib.pyplot as plt
import os 
import nibabel as nib
from glob import glob
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from keras.utils import to_categorical
from keras import backend as K
# Global TensorFlow switch: execute tf.function-wrapped code eagerly instead of as graphs.
# NOTE(review): TF documents this as a debugging aid that slows execution — consider
# disabling it for full training runs.
tf.config.run_functions_eagerly(True)

In [None]:
# Authenticate with Weights & Biases before wandb.init() is called further down.
# anonymous="allow" presumably permits an anonymous session when no API key is
# configured — confirm against the installed wandb version.
wandb.login(anonymous="allow")

****DATA PREPROCESSING****

In [None]:
def window_converter(image, window_width=400, window_level=50):
    """Clip an image to an intensity window.

    The window spans ``window_level +/- window_width // 2``. Values below the
    window become its lower bound and values above become its upper bound.
    The input array is left untouched; a new array is returned.

    Args:
        image: numeric array, e.g. the output of ``transform_to_hu``.
        window_width: width of the intensity window.
        window_level: centre of the intensity window.

    Returns:
        A clipped copy of ``image``.
    """
    half_width = window_width // 2
    lower_bound = window_level - half_width
    upper_bound = window_level + half_width
    # np.clip allocates a new array, so the caller's image is never modified.
    return np.clip(image, lower_bound, upper_bound)

def transform_to_hu(medical_image, image):
    """Apply the DICOM modality rescale (slope/intercept) to a pixel array.

    Args:
        medical_image: path (or file-like object) of the DICOM file whose header
            holds ``RescaleSlope``/``RescaleIntercept``. An already-parsed
            dataset (any object exposing those attributes) is also accepted;
            passing it avoids reading the file a second time.
        image: pixel array belonging to that file.

    Returns:
        ``image * RescaleSlope + RescaleIntercept``. For CT data this is
        presumably Hounsfield units. Missing tags fall back to slope 1 and
        intercept 0, i.e. the identity rescale.
    """
    if isinstance(medical_image, (str, bytes, os.PathLike)) or hasattr(medical_image, "read"):
        # Only header tags are needed; skip decoding the (large) pixel data element.
        meta_image = pydicom.dcmread(medical_image, stop_before_pixels=True)
    else:
        meta_image = medical_image
    slope = getattr(meta_image, "RescaleSlope", 1)
    intercept = getattr(meta_image, "RescaleIntercept", 0)
    return image * slope + intercept

def standardize_pixel_array(dcm: pydicom.dataset.FileDataset) -> np.ndarray:
    """Return ``dcm.pixel_array``, sign-extended when the pixel data is signed.

    When ``PixelRepresentation == 1`` (signed), only ``BitsStored`` of the
    ``BitsAllocated`` bits hold data. Shifting left by the unused bit count and
    then right again (an arithmetic shift on signed dtypes) propagates the sign
    bit. Otherwise the array is returned as-is.
    """
    pixels = dcm.pixel_array
    if dcm.PixelRepresentation != 1:
        return pixels
    unused_bits = dcm.BitsAllocated - dcm.BitsStored
    original_dtype = pixels.dtype
    shifted_up = (pixels << unused_bits).astype(original_dtype)
    return shifted_up >> unused_bits
    
def resize_img(img_paths, target_size=(128, 128)):
    """Load DICOM slices and stack them into a windowed, resized volume.

    Each file goes through four steps before resizing:
    - read from disk;
    - sign-corrected with ``standardize_pixel_array``;
    - rescaled with ``transform_to_hu``;
    - clipped with ``window_converter`` (default width 400, level 50).

    Args:
        img_paths: ordered sequence of DICOM file paths. Slice ``i`` of the
            output comes from ``img_paths[i]``.
        target_size: ``(height, width)`` of every output slice.

    Returns:
        float64 array of shape ``(height, width, len(img_paths))``.
    """
    height, width = target_size[0], target_size[1]
    volume = np.zeros((height, width, len(img_paths)), dtype=np.float64)
    for i, image_path in enumerate(img_paths):
        # dcmread replaces read_file, which was deprecated and removed in pydicom 3.0.
        dataset = pydicom.dcmread(image_path)
        pixels = standardize_pixel_array(dataset)
        hu_image = transform_to_hu(image_path, pixels)
        window_image = window_converter(hu_image)
        # OpenCV's dsize is (width, height), the reverse of NumPy's (rows, cols) order.
        volume[:, :, i] = cv2.resize(window_image, (width, height))
    return volume
    
def normalize_volume(resized_volume):
    """Min-max scale a whole volume with one global minimum and maximum.

    Every voxel is treated as a sample of a single feature, so the scaling is
    shared across all slices rather than computed per slice. The result has
    the same shape as ``resized_volume``.
    """
    # MinMaxScaler scales column-wise; one (n_voxels, 1) column yields a single global range.
    voxel_column = resized_volume.reshape(-1, 1)
    scaled_column = preprocessing.MinMaxScaler().fit_transform(voxel_column)
    return scaled_column.reshape(resized_volume.shape)

def generate_patient_processed_data(list_img_paths, list_labels, target_size=(128,128)):
    """Build the normalized volume for one series plus one label per slice.

    Args:
        list_img_paths: ordered DICOM slice paths of the series.
        list_labels: label of the whole series. Despite the name this is a
            single value; it is repeated once per slice.
        target_size: ``(height, width)`` forwarded to ``resize_img``.

    Returns:
        ``(volume, labels)``:
        - ``volume`` is the ``resize_img`` output min-max scaled by
          ``normalize_volume``, with shape ``(height, width, len(list_img_paths))``.
        - ``labels`` is a list holding ``list_labels`` repeated
          ``len(list_img_paths)`` times.
    """
    height, width = target_size[0], target_size[1]
    depth = len(list_img_paths)

    # Report the expected shape without allocating a throwaway zero volume.
    print("Initializing data preprocessing with the following dimensions-> Volumes:{}".format((height, width, depth)))

    volume_array = normalize_volume(resize_img(list_img_paths, target_size=target_size))
    labels = [list_labels] * depth

    return volume_array, labels

In [None]:
_SLICE_NUMBER_RE = re.compile(r'(\d+)\.dcm$')


def extract_number_from_path(path):
    """Return the integer slice number from a path ending in ``<digits>.dcm``.

    Paths that do not match map to 0, so they sort before numbered slices.
    """
    match = _SLICE_NUMBER_RE.search(path)
    return int(match.group(1)) if match else 0


def get_data_for_3d_volumes(data, train_data_cat, path, number_idx):
    """Pick a reproducible random subset of series and collect their slice paths.

    Args:
        data: frame with ``patient_id`` and ``series_id`` columns, one row per series.
        train_data_cat: frame with ``patient_id`` and ``any_injury`` columns,
            left-joined onto ``data`` by ``patient_id``.
        path: image root laid out as ``<path>/<patient_id>/<series_id>/<file>``.
            A trailing separator is optional.
        number_idx: number of series to keep after shuffling with a fixed seed (42).

    Returns:
        DataFrame with these columns:
        - ``Patient_id`` and ``Series_id``;
        - ``Patient_paths``: list of the series' file paths, sorted by slice number;
        - ``Patient_category``: the ``any_injury`` value, NaN when the patient
          is missing from ``train_data_cat``.
    """
    merged = data[["patient_id", "series_id"]].merge(
        train_data_cat[["patient_id", "any_injury"]], on="patient_id", how="left"
    )
    # Shuffle with a fixed seed, then keep the first `number_idx` series.
    selected = merged.sample(frac=1, random_state=42).iloc[:number_idx]

    records = []
    for patient_id, series_id, any_injury in selected.itertuples(index=False, name=None):
        # os.path.join works whether or not `path` ends with a separator.
        pattern = os.path.join(path, str(patient_id), str(series_id), "*")
        # Pre-sort by name so files sharing a slice number (e.g. all non-.dcm
        # files, which map to 0) come out in a deterministic order.
        slice_paths = sorted(sorted(glob(pattern)), key=extract_number_from_path)
        records.append((patient_id, series_id, slice_paths, any_injury))

    return pd.DataFrame(
        records,
        columns=["Patient_id", "Series_id", "Patient_paths", "Patient_category"],
    )

In [None]:
# NOTE(review): this cell repeats the definitions from the previous cell and, being
# run later, is the copy that takes effect; keep only one of the two cells.
_SLICE_NUMBER_RE = re.compile(r'(\d+)\.dcm$')


def extract_number_from_path(path):
    """Return the integer slice number from a path ending in ``<digits>.dcm``.

    Paths that do not match map to 0, so they sort before numbered slices.
    """
    match = _SLICE_NUMBER_RE.search(path)
    return int(match.group(1)) if match else 0


def get_data_for_3d_volumes(data, train_data_cat, path, number_idx):
    """Pick a reproducible random subset of series and collect their slice paths.

    Args:
        data: frame with ``patient_id`` and ``series_id`` columns, one row per series.
        train_data_cat: frame with ``patient_id`` and ``any_injury`` columns,
            left-joined onto ``data`` by ``patient_id``.
        path: image root laid out as ``<path>/<patient_id>/<series_id>/<file>``.
            A trailing separator is optional.
        number_idx: number of series to keep after shuffling with a fixed seed (42).

    Returns:
        DataFrame with these columns:
        - ``Patient_id`` and ``Series_id``;
        - ``Patient_paths``: list of the series' file paths, sorted by slice number;
        - ``Patient_category``: the ``any_injury`` value, NaN when the patient
          is missing from ``train_data_cat``.
    """
    merged = data[["patient_id", "series_id"]].merge(
        train_data_cat[["patient_id", "any_injury"]], on="patient_id", how="left"
    )
    # Shuffle with a fixed seed, then keep the first `number_idx` series.
    selected = merged.sample(frac=1, random_state=42).iloc[:number_idx]

    records = []
    for patient_id, series_id, any_injury in selected.itertuples(index=False, name=None):
        # os.path.join works whether or not `path` ends with a separator.
        pattern = os.path.join(path, str(patient_id), str(series_id), "*")
        # Pre-sort by name so files sharing a slice number (e.g. all non-.dcm
        # files, which map to 0) come out in a deterministic order.
        slice_paths = sorted(sorted(glob(pattern)), key=extract_number_from_path)
        records.append((patient_id, series_id, slice_paths, any_injury))

    return pd.DataFrame(
        records,
        columns=["Patient_id", "Series_id", "Patient_paths", "Patient_category"],
    )

In [None]:
import numpy as np

# NOTE(review): `buffer_size` and `buffer` are not referenced anywhere else in this
# notebook. They were scaffolding for a buffered train_on_batch experiment whose
# commented-out code has been removed; drop this cell if the idea is abandoned.
buffer_size = 10
buffer = []


In [None]:
# Competition inputs: series metadata, per-patient labels and the DICOM image tree.
DATA_ROOT = "/kaggle/input/rsna-2023-abdominal-trauma-detection"
train_data = pd.read_csv(DATA_ROOT + "/train_series_meta.csv")
cat_data = pd.read_csv(DATA_ROOT + "/train.csv")
path = DATA_ROOT + "/train_images/"

# Sample 200 series (the shuffle inside get_data_for_3d_volumes uses a fixed seed).
cleaned_df = get_data_for_3d_volumes(train_data, cat_data, path=path, number_idx=200)
print("Data extraction terminated...")

In [None]:
# Inspect the sampled series: one row per series with its sorted slice paths and label.
cleaned_df

In [None]:

# Split the sampled series by label: Patient_category (any_injury) == 1 vs == 0.
df_injury = cleaned_df[cleaned_df["Patient_category"].eq(1)]
df_healthy = cleaned_df[cleaned_df["Patient_category"].eq(0)]
print(df_injury.count())
print(df_healthy.count())

# Keep at most the first 20 series of each class to form a balanced subset.
df_injury, df_healthy = df_injury.head(20), df_healthy.head(20)

cleaned_df = pd.concat([df_injury, df_healthy])

In [None]:
# Inspect the class-balanced subset (injured rows first, then healthy ones).
cleaned_df

In [None]:
# Renumber rows 0..n-1 (discarding the old labels) so label lookups such as
# cleaned_df[col][i] follow row order.
cleaned_df = cleaned_df.reset_index(drop=True)
cleaned_df

In [None]:
# Build one normalized volume per selected series, plus one label per slice.
volume_dcm = []
volume_labels = []

# Iterate over whatever rows cleaned_df holds instead of a hard-coded 40, which raised
# KeyError whenever a class had fewer than 20 series.
for series_paths, series_label in zip(cleaned_df["Patient_paths"], cleaned_df["Patient_category"]):
    volume_img, slice_labels = generate_patient_processed_data(series_paths, series_label)
    volume_dcm.append(volume_img)
    volume_labels.append(slice_labels)

In [None]:
# Stack every series along the slice (third) axis; labels are flattened in the same order.
volume_of_imgs = np.dstack(volume_dcm)
volume_of_labels = np.hstack(volume_labels)
volume_of_imgs.shape, volume_of_labels.shape

In [None]:
# Put slices first for Keras, (H, W, N) -> (N, H, W), then add a trailing channel axis -> (N, H, W, 1).
transposed_volume_dcm = np.moveaxis(volume_of_imgs, 2, 0)[..., np.newaxis]

In [None]:
from sklearn.model_selection import GroupShuffleSplit

# Hold out whole patients rather than random slices. Neighbouring slices of one CT
# series are near-duplicates, so a slice-level split leaks training images into the
# validation set and inflates validation accuracy.
# volume_dcm[k] was built from row k of cleaned_df, so each slice inherits that row's
# Patient_id (np.repeat raises if the two ever get out of sync).
slice_groups = np.repeat(cleaned_df["Patient_id"].to_numpy(), [vol.shape[2] for vol in volume_dcm])
# test_size is now a fraction of patients, not slices. With few patients the held-out
# set may contain a single class — check the label balance of test_labels.
splitter = GroupShuffleSplit(n_splits=1, test_size=0.10, random_state=0)
train_idx, test_idx = next(splitter.split(transposed_volume_dcm, volume_of_labels, groups=slice_groups))
train_images, test_images = transposed_volume_dcm[train_idx], transposed_volume_dcm[test_idx]
train_labels, test_labels = volume_of_labels[train_idx], volume_of_labels[test_idx]

****Build Model Architecture****

In [None]:
def conv_block(input, num_filters):
    """Apply two stacked Conv2D(3x3, 'same') -> BatchNormalization -> ReLU units.

    Args:
        input: Keras tensor to transform.
        num_filters: number of filters for both convolutions.

    Returns:
        The output tensor; spatial size is unchanged because of 'same' padding.
    """
    x = input
    for _ in range(2):
        x = Conv2D(num_filters, 3, padding="same")(x)
        x = BatchNormalization()(x)
        x = Activation("relu")(x)
    return x

def encoder_block(input, num_filters):
    """One contracting step: a conv_block followed by 2x2 max pooling.

    Returns:
        ``(features, pooled)``: the conv_block output (the U-Net skip tensor)
        and its max-pooled version.
    """
    features = conv_block(input, num_filters)
    pooled = MaxPool2D(pool_size=(2, 2))(features)
    return features, pooled

def dense_block(input_shape, num_classes, inputs=None):
    """Build a dense classification head.

    Layer stack: Flatten -> Dense(512, relu) -> Dropout(0.4) -> Dense(num_classes, softmax).

    Args:
        input_shape: shape of a fresh ``Input`` created when ``inputs`` is not given.
        num_classes: number of softmax outputs.
        inputs: optional existing Keras tensor to attach the head to. When
            omitted, a new ``Input`` is created as before. That ``Input`` is not
            returned, so callers who need to build a ``Model`` from the head
            should pass their own tensor here.

    Returns:
        The softmax output tensor.
    """
    if inputs is None:
        inputs = Input(shape=input_shape)

    x = Flatten()(inputs)
    x = Dense(units=512, activation='relu')(x)
    x = Dropout(0.4)(x)
    return Dense(units=num_classes, activation='softmax')(x)

def build_unet_encoder_model_lstm(input_shape):
    """Build the contracting (encoder) half of a U-Net as a Keras Model.

    The model has:
    - four encoder_block stages with 64, 128, 256 and 512 filters, each halving
      the spatial size via 2x2 max pooling;
    - a 1024-filter conv_block bottleneck, flattened into one feature vector
      per image.

    Args:
        input_shape: shape of one input image, e.g. (128, 128, 1).

    Returns:
        Model mapping images to flat bottleneck feature vectors.
    """
    units = 128
    inputs = Input(input_shape)

    # Encoder. The skip tensors a full U-Net decoder would consume are not needed here.
    # `units // 2` keeps the filter count an int (`units / 2` yields the float 64.0).
    _, p1 = encoder_block(inputs, units // 2)
    _, p2 = encoder_block(p1, units)
    _, p3 = encoder_block(p2, units * 2)
    _, p4 = encoder_block(p3, units * 4)
    bottleneck = conv_block(p4, units * 8)
    encoder_output = Flatten()(bottleneck)

    return Model(inputs, encoder_output)


def unet_encoder_lstm_model(input_shape, lstm_units):
    """Stack an LSTM and a sigmoid classifier on top of the U-Net encoder.

    Args:
        input_shape: shape of one input image, forwarded to the encoder builder.
        lstm_units: number of LSTM units.

    Returns:
        Keras Model from the encoder's input to a single sigmoid probability.
    """
    encoder = build_unet_encoder_model_lstm(input_shape)
    features = encoder.output

    # The encoder ends in Flatten, so features are (batch, F). Reshape them into a
    # sequence of F-dim steps (length inferred, i.e. 1) for the LSTM.
    sequence = Reshape((-1, features.shape[1]))(features)
    summary = LSTM(lstm_units)(sequence)

    # Dense layer for classification: one sigmoid unit (binary).
    probability = Dense(1, activation='sigmoid')(summary)

    return Model(inputs=encoder.input, outputs=probability)

****RUN OF THE MODEL****

In [None]:
if __name__ == "__main__":

    input_shape = (128, 128, 1)

    # Hyperparameters, tracked by W&B via the run config.
    config = SimpleNamespace(
        lstm_units=100,
        L_R=1e-5,
        LOSS="binary_crossentropy",
        METRICS="accuracy",
        EPOCHS=200,
        BATCH_SIZE=32,
    )

    # Start the W&B run.
    wandb.init(project="Unet-LSTM-2D-Model", config=config)

    OPTIMIZER = tf.keras.optimizers.SGD(learning_rate=config.L_R)

    # Model training.
    model = unet_encoder_lstm_model(input_shape, config.lstm_units)
    # compile() documents `metrics` as a list; Keras 3 raises on a bare string.
    model.compile(optimizer=OPTIMIZER, loss=config.LOSS, metrics=[config.METRICS])
    model.summary()

    history = model.fit(
        train_images,
        train_labels,
        epochs=config.EPOCHS,
        batch_size=config.BATCH_SIZE,
        validation_data=(test_images, test_labels),
        shuffle=True,
    )

In [None]:
# Log the per-epoch training and validation curves to W&B, one wandb.log call per epoch.
# (The previous version logged model.metrics — Metric objects — and model.loss — the
# loss name — rather than the recorded values.)
metric_keys = ["accuracy", "loss", "val_accuracy", "val_loss"]
for epoch in range(len(history.history["loss"])):
    wandb.log({"epoch": epoch, **{key: history.history[key][epoch] for key in metric_keys}})

fig, ax = plt.subplots()
ax.plot(history.history["accuracy"], label="Accuracy")
ax.plot(history.history["val_accuracy"], label="Validation accuracy")
# Full [0, 1] range: a fixed lower bound of 0.75 hid the early-epoch accuracy.
ax.set(xlabel="Epoch", ylabel="Accuracy", title="Training vs. validation accuracy", ylim=(0, 1))
ax.legend(loc="lower right")
plt.show()

# Finish the run.
wandb.finish()