## Imports

In [5]:
# Imports
import os
import h5py
import numpy as np
import random
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras import layers, models, datasets, backend
from tensorflow.keras.layers import Conv2D, Reshape
from tensorflow.keras.callbacks import EarlyStopping

from PIL import Image
from sklearn.model_selection import train_test_split

## Carga de datos

In [6]:
# Cargar el directorio con los archivos .h5
directory = "./BraTS2020_training_data/content/data"

In [7]:
# Cargar archivos a array
h5_files = [f for f in os.listdir(directory) if f.endswith('.h5')]
print(f"Found {len(h5_files)} .h5 files")

Found 57195 .h5 files


In [8]:
# Abrir los primeros archivos .h5 de la lista para inspeccionarlos
if h5_files:
    selected_file = random.choice(h5_files)
    file_path = os.path.join(directory, selected_file)
    with h5py.File(file_path, 'r') as file:
        print("\nKeys for each file:", list(file.keys()))
        for key in file.keys():
            print(f"\nData type of {key}:", type(file[key][()]))
            print(f"Shape of {key}:", file[key].shape)
            print(f"Array dtype: {file[key].dtype}")
            print(f"Array max val: {np.max(file[key])}")
            print(f"Array min val: {np.min(file[key])}")
            print("*"*10)
            print(f"Mean: {np.mean(file[key])}")
            print(f"Standard deviation: {np.std(file[key])}")
            print("*"*10)
            
            # Verificar valores nulos
            if np.isnan(file[key]).any():
                print("Hay valores NaN en los datos.")
            else:
                print("No se encontraron valores NaN.")
else:
    print("No .h5 files found.")


Keys for each file: ['image', 'mask']

Data type of image: <class 'numpy.ndarray'>
Shape of image: (240, 240, 4)
Array dtype: float64
Array max val: 7.844274508692563
Array min val: -0.2880464900765529
**********
Mean: -4.884981308350689e-17
Standard deviation: 1.0
**********
No se encontraron valores NaN.

Data type of mask: <class 'numpy.ndarray'>
Shape of mask: (240, 240, 3)
Array dtype: uint8
Array max val: 0
Array min val: 0
**********
Mean: 0.0
Standard deviation: 0.0
**********
No se encontraron valores NaN.


## Dataset

### Visualizar

In [9]:
def visualize_image_and_masks(image, mask):
    """
    Args:
        image: Tensor o array de la imagen (H, W, C).
        mask: Tensor o array de la máscara (H, W) o (H, W, C).
    """
    plt.figure(figsize=(12, 6))
    
    # Mostrar imagen
    plt.subplot(1, 2, 1)
    plt.imshow(image)
    plt.title("Imagen")
    plt.axis("off")
    
    # Mostrar máscara (asumiendo que está codificada en un solo canal)
    if len(mask.shape) == 3 and mask.shape[-1] > 1:
        mask_display = tf.argmax(mask, axis=-1)  # Reducir la máscara a un solo canal si está en one-hot encoding
    else:
        mask_display = mask
    
    plt.subplot(1, 2, 2)
    plt.imshow(mask_display, cmap="jet")
    plt.title("Máscara")
    plt.axis("off")
    
    plt.show()

### Crear Dataset

In [10]:
# Cargar archivos
h5_files = [os.path.join(directory, f) for f in os.listdir(directory) if f.endswith('.h5')]
np.random.seed(42)
np.random.shuffle(h5_files)

# Dividir dataset en entrenaminto y validación (80:10)
split_idx = int(0.8 * len(h5_files))
train_files = h5_files[:split_idx]
val_files = h5_files[split_idx:]

In [11]:
# Preprocesamiento
def preprocess(h5_file):
    with h5py.File(h5_file.numpy().decode('utf-8'), 'r') as file:
        image = file['image'][()]
        mask = file['mask'][()]
        
        # Reescalar la imagen: (H, W, C) -> (C, H, W)
        image = image.transpose((2, 0, 1))
        mask = mask.transpose((2, 0, 1))

        # Ajustar los valores de los píxeles en la imagen para que estén entre 0 y 1
        for i in range(image.shape[0]):
            min_val = np.min(image[i])
            image[i] = image[i] - min_val
            max_val = np.max(image[i]) + 1e-4
            image[i] = image[i] / max_val

        # Reescalar la imagen: (C, H, W) -> (H, W, C) para ser ompatibles con TensorShape
        image = image.transpose((1, 2, 0))
        mask = mask.transpose((1, 2, 0))

        # Normalizar entre 0 y 1
        image = tf.cast(image, tf.float32) / 255.0
        mask = tf.cast(mask, tf.float32) / 255.0

    return image, mask

In [12]:
# Establecer las formas después de tf.py_function
def preprocess_with_shape(h5_file):
    image, mask = tf.py_function(preprocess, [h5_file], [tf.float32, tf.float32])
    image.set_shape((240, 240, 4))  # Forma de la imagen: (C, H, W)
    mask.set_shape((240, 240, 3))   # Forma de la máscara: (C, H, W)
    return image, mask

### Entrenamiento

In [13]:
# Crear datasets para entrenamiento y validación
train_dataset = tf.data.Dataset.from_tensor_slices(train_files)
train_dataset = train_dataset.map(preprocess_with_shape, num_parallel_calls=tf.data.AUTOTUNE)
train_dataset = train_dataset.batch(16).prefetch(tf.data.AUTOTUNE)

val_dataset = tf.data.Dataset.from_tensor_slices(val_files)
val_dataset = val_dataset.map(preprocess_with_shape, num_parallel_calls=tf.data.AUTOTUNE)
val_dataset = val_dataset.batch(16).prefetch(tf.data.AUTOTUNE)

In [14]:
# Ver ejemplos del dataset
for image, mask in train_dataset.take(1):
    print("Forma de la imagen de entrenamiento:", image.shape)
    print("Forma de la máscara de entrenamiento:", mask.shape)

Forma de la imagen de entrenamiento: (16, 240, 240, 4)
Forma de la máscara de entrenamiento: (16, 240, 240, 3)


# Implementación del Modelo

### Bloque SE-ResNet

In [15]:
# Bloque SE-ResNet
def se_resnet_block(input_tensor, n_filters, kernel_size=3, stride=1, reduction_ratio=16):
    # Bloque Residual
    # Primera convolución
    x = layers.Conv2D(n_filters, kernel_size, strides=stride, padding='same')(input_tensor)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    # Segunda convolución
    x = layers.Conv2D(n_filters, kernel_size, strides=stride, padding='same')(x)
    x = layers.BatchNormalization()(x)

    # Bloque SE
    # Squeeze - Global Average Pooling
    se = layers.GlobalAveragePooling2D()(x)
    se = layers.Dense(n_filters // reduction_ratio, activation='relu')(se) # Bottleneck
    # se = layers.Dropout(0.2)(se)
    se = layers.Dense(n_filters, activation='sigmoid')(se) # Excitación
    se = layers.Reshape((1, 1 , n_filters))(se) # Ajuste de dimensiones
    x = layers.Multiply()([x, se]) # Recalibración

    # Shortcut connection - Identity
    if input_tensor.shape[-1] != n_filters:
        shortcut = layers.Conv2D(n_filters, (1, 1), padding='same')(input_tensor)
    else:
        shortcut = input_tensor

    x = layers.Add()([x, shortcut])  # Suma residual
    x = layers.ReLU()(x)

    return x

### Attention Gate

In [16]:
# Attention Gate
def attention_gate(dec, enc, n_filters):
    # Reducir dimensión de Encoder
    enc1 = layers.Conv2D(n_filters, (1, 1), padding='same')(enc)
    enc1 = layers.BatchNormalization()(enc1)

    # Reducir dimensión de Gatting Signal
    gatting_sig = layers.Conv2D(n_filters, (1, 1), padding='same')(dec)
    gatting_sig = layers.BatchNormalization()(gatting_sig)

    # Combinar Gatting Signal y Skip Connection
    combined = layers.Add()([enc1, gatting_sig])
    combined = layers.ReLU()(combined)

    # Mapa de Atención
    attention_map = layers.Conv2D(1, (1, 1), activation='sigmoid',padding='same')(combined)

    # Aplicar mapa de Atención
    output = layers.Multiply()([enc, attention_map])

    return output

### Encoder

In [17]:
# Encoder
def encoder_block(input_tensor, n_filters, apply_pooling=True):
    x = se_resnet_block(input_tensor, n_filters)
    p = layers.MaxPooling2D((2, 2))(x)

    return x, p # x se utiliza en el skip connection; p pasa al siguiente bloque
    
# Downsampling
def build_encoder(input_tensor, filters_list):
    skips = []
    
    # Primera Convolución
    x = layers.Conv2D(filters=32, kernel_size=3, strides=1, padding='same')(input_tensor)
    x = layers.ReLU()(x)
    x = layers.MaxPooling2D((2, 2))(x)
    
    for i, n_filters in enumerate(filters_list):
        # Verifica si es el último bloque
        apply_pooling = i < len(filters_list) - 1  # True excepto en el último bloque
        x, p = encoder_block(x, n_filters, apply_pooling=apply_pooling)
        skips.append(x)  # Guardar skip connection
        
        # Avanzar al siguiente bloque solo si hay pooling
        if apply_pooling:
            x = p

    return x, skips

### Decoder

In [18]:
# Decoder
def decoder_block(input_tensor, skip_tensor, n_filters):
    # Gatting Signal
    gated_skip = attention_gate(input_tensor, skip_tensor, n_filters)
    gated_skip = layers.Conv2DTranspose(n_filters, (2, 2), strides=(2, 2), padding='same')(gated_skip)

    x = layers.Conv2DTranspose(n_filters, (2, 2), strides=(2, 2), padding='same')(input_tensor) # Upsampling
    x = layers.Concatenate()([x, gated_skip])  # Combina con skip connection
    x = se_resnet_block(x, n_filters)

    return x

# Upsampling
def build_decoder(encoder_output, skips, filters_list):
    x = encoder_output
    for filters, skip_tensor in zip(filters_list, reversed(skips)):        
        x = decoder_block(x, skip_tensor, filters)

    x = layers.Conv2D(3, (1, 1), strides=1, padding='same', activation='softmax')(x)
    
    return x

## Creación de modelo

In [19]:
# SE-ResANet
def se_resanet(input_shape, encoder_filters, decoder_filters):
    input_tensor = layers.Input(shape=input_shape)
    # input_tensor = layers.InputLayer(input_shape=input_shape)

    # Encoder
    encoder_output, skips = build_encoder(input_tensor, encoder_filters)

    # Decoder
    decoder_output = build_decoder(encoder_output, skips, decoder_filters)

    # Output
    # outputs = layers.Conv2D(filters=4, kernel_size=1, strides=1, padding='same', activation='softmax')(decoder_output)
    outputs = layers.Conv2D(3, (1, 1), activation='softmax')(decoder_output)

    # Modelo
    return models.Model(inputs=input_tensor, outputs=outputs)

In [20]:
input_shape = (240, 240, 4)  # Dimensiones de las imágenes BraTS2020

encoder_filters = [64, 128, 256, 512] # Filtros de cada nivel del Encoder
decoder_filters = [512, 256, 128, 64] # Filtros de cada nivel del Decoder

model = se_resanet(input_shape, encoder_filters, decoder_filters)

# Resumen
model.summary()

In [21]:
# Dice Coefficient
def dice_coefficient(y_true, y_pred, smooth=1e-6):
    y_true_f = backend.flatten(y_true)
    y_pred_f = backend.flatten(y_pred)
    intersection = backend.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (backend.sum(y_true_f) + backend.sum(y_pred_f) + smooth)

# Recall
def recall(y_true, y_pred):
    true_positives = backend.sum(backend.round(backend.clip(y_true * y_pred, 0, 1)))
    possible_positives = backend.sum(backend.round(backend.clip(y_true, 0, 1)))
    return true_positives / (possible_positives + backend.epsilon())

# Ejecución

In [22]:
# Compilar
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),  # Optimizar con Adam
    loss='categorical_crossentropy',
    metrics=['accuracy', dice_coefficient, recall]  # Métrica de precisión
)

## Entrenamiento

In [23]:
class CustomMetricsCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # Imprimir las métricas con 4 decimales
        metrics = {key: f"{value:.4f}" for key, value in logs.items()}
        print(f"Epoch {epoch + 1}: {metrics}")
    
    def on_train_end(self, logs=None):
        # Mostrar las métricas finales al terminar el entrenamiento
        print("\nEntrenamiento terminado. Métricas finales:")
        for key, value in self.model.history.history.items():
            total = sum(value) / len(value)
            print(f"{key}: {total:.4f}")

In [None]:
early_stopping = EarlyStopping(monitor='val_loss', patience=5)

model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=50,
    callbacks=[early_stopping, CustomMetricsCallback()]
)

Epoch 1/50
[1m 173/2860[0m [32m━[0m[37m━━━━━━━━━━━━━━━━━━━[0m [1m17:36:07[0m 24s/step - accuracy: 0.0123 - dice_coefficient: 3.3457e-05 - loss: 4.8743e-05 - recall: 0.0000e+00

## Evaluación del modelo

In [14]:
# Evaluar el modelo
val_loss, val_accuracy, val_dice, val_recall = model.evaluate(val_dataset)

print(f"Validation Loss: {val_loss}")
print(f"Validation Accuracy: {val_accuracy}")
print(f"Validation Dice Coefficient: {val_dice}")
print(f"Validation Recall: {val_recall}")

ValueError: as_list() is not defined on an unknown TensorShape.