<a href="https://colab.research.google.com/github/Masciel-Sevilla/CalculadoraVLSM_MascielSevilla/blob/master/Untitled21.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import regularizers # Import for L2 regularization
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import cv2
import os
from glob import glob

# --- FORCE ALL COMPUTATIONS TO FLOAT32 ---
tf.keras.mixed_precision.set_global_policy('float32')

# -------------------- CONFIGURACIÓN --------------------
IMG_HEIGHT = 128
IMG_WIDTH = 128
NUM_CLASSES = 6
BATCH_SIZE = 16
EPOCHS = 50 # Increased epochs, EarlyStopping will manage
LEARNING_RATE = 1e-3
DATASET_ROOT_PATH = "Balanced" # Renamed to avoid confusion with internal "Balanced" folder

# -------------------- CBAM MODULE --------------------
class CBAMModule(layers.Layer):
    def __init__(self, channels, reduction_ratio=16, **kwargs):
        super(CBAMModule, self).__init__(**kwargs)
        self.channels = channels
        self.reduction_ratio = reduction_ratio
        self.channel_avg_pool = layers.GlobalAveragePooling2D()
        self.channel_max_pool = layers.GlobalMaxPooling2D()
        self.channel_fc1 = layers.Dense(channels // reduction_ratio, activation='relu')
        self.channel_fc2 = layers.Dense(channels, activation='sigmoid')
        self.spatial_conv = layers.Conv2D(1, 7, padding='same', activation='sigmoid')

    def call(self, inputs):
        avg_pool = self.channel_avg_pool(inputs)
        max_pool = self.channel_max_pool(inputs)
        avg_out = self.channel_fc2(self.channel_fc1(avg_pool))
        max_out = self.channel_fc2(self.channel_fc1(max_pool)) # Corrected: should be max_pool
        channel_attention = layers.Add()([avg_out, max_out])
        channel_attention = layers.Reshape((1, 1, self.channels))(channel_attention)
        channel_refined = layers.Multiply()([inputs, channel_attention])

        avg_spatial = tf.reduce_mean(channel_refined, axis=-1, keepdims=True)
        max_spatial = tf.reduce_max(channel_refined, axis=-1, keepdims=True)
        spatial_concat = layers.Concatenate(axis=-1)([avg_spatial, max_spatial])
        spatial_attention = self.spatial_conv(spatial_concat)
        spatial_refined = layers.Multiply()([channel_refined, spatial_attention])
        return spatial_refined

# -------------------- ASPP MODULE --------------------
class ASPPModule(layers.Layer):
    def __init__(self, filters=256, **kwargs):
        super(ASPPModule, self).__init__(**kwargs)
        self.conv_1x1 = layers.Conv2D(filters, 1, padding='same', kernel_regularizer=regularizers.l2(1e-4)) # Added L2
        self.bn_1x1 = layers.BatchNormalization()
        self.conv_3x3_2 = layers.Conv2D(filters, 3, padding='same', dilation_rate=2, kernel_regularizer=regularizers.l2(1e-4)) # Added L2
        self.bn_3x3_2 = layers.BatchNormalization()
        self.conv_3x3_4 = layers.Conv2D(filters, 3, padding='same', dilation_rate=4, kernel_regularizer=regularizers.l2(1e-4)) # Added L2
        self.bn_3x3_4 = layers.BatchNormalization()
        self.conv_3x3_8 = layers.Conv2D(filters, 3, padding='same', dilation_rate=8, kernel_regularizer=regularizers.l2(1e-4)) # Added L2
        self.bn_3x3_8 = layers.BatchNormalization()
        self.global_avg_pool = layers.GlobalAveragePooling2D(keepdims=True)
        self.conv_1x1_gap = layers.Conv2D(filters, 1, padding='same', kernel_regularizer=regularizers.l2(1e-4)) # Added L2
        self.bn_gap = layers.BatchNormalization()
        self.conv_final = layers.Conv2D(filters, 1, padding='same', kernel_regularizer=regularizers.l2(1e-4)) # Added L2
        self.bn_final = layers.BatchNormalization()
        self.dropout = layers.Dropout(0.4) # Keep dropout

    def call(self, inputs, training=None):
        input_shape = tf.shape(inputs)
        conv_1x1 = tf.nn.relu(self.bn_1x1(self.conv_1x1(inputs), training=training))
        conv_3x3_2 = tf.nn.relu(self.bn_3x3_2(self.conv_3x3_2(inputs), training=training))
        conv_3x3_4 = tf.nn.relu(self.bn_3x3_4(self.conv_3x3_4(inputs), training=training))
        conv_3x3_8 = tf.nn.relu(self.bn_3x3_8(self.conv_3x3_8(inputs), training=training))

        gap = self.global_avg_pool(inputs)
        # tf.image.resize works fine with KerasTensors usually
        gap = tf.image.resize(gap, [input_shape[1], input_shape[2]])

        concat = layers.Concatenate()([conv_1x1, conv_3x3_2, conv_3x3_4, conv_3x3_8, gap])
        output = tf.nn.relu(self.bn_final(self.conv_final(concat), training=training))
        return self.dropout(output, training=training)

# -------------------- CONVNEXT BACKBONE --------------------
def create_convnext_backbone(input_shape):
    # Consider ConvNeXtTiny or ConvNeXtSmall for lighter model if needed
    # base_model = tf.keras.applications.ConvNeXtTiny( # Lighter option
    #     input_shape=input_shape, weights='imagenet', include_top=False
    # )
    base_model = tf.keras.applications.ConvNeXtBase(
        input_shape=input_shape, weights='imagenet', include_top=False
    )
    layer_names = [
        'convnext_base_stage_0_block_0_layer_scale',
        'convnext_base_stage_1_block_0_layer_scale',
        'convnext_base_stage_2_block_0_layer_scale',
        'convnext_base_stage_3_block_2_layer_scale'
    ]
    outputs = [base_model.get_layer(name).output for name in layer_names]

    # FIX: Wrap tf.cast in a Lambda layer for KerasTensor compatibility
    casted_outputs = [layers.Lambda(lambda x: tf.cast(x, tf.keras.backend.floatx()))(o) for o in outputs]

    return tf.keras.Model(inputs=base_model.input, outputs=casted_outputs)

# -------------------- MODELO PRINCIPAL --------------------
def create_model(input_shape=(IMG_HEIGHT, IMG_WIDTH, 3), num_classes=NUM_CLASSES):
    inputs = layers.Input(shape=input_shape, dtype=tf.keras.backend.floatx())

    # Create the backbone model
    backbone = create_convnext_backbone(input_shape)

    # Apply backbone to inputs
    features = backbone(inputs)

    early, mid, deep = features[0], features[2], features[-1]

    aspp = ASPPModule(256)(deep)
    cbam = CBAMModule(256)(aspp)
    decoder = layers.UpSampling2D((4, 4))(cbam)

    early = layers.Conv2D(64, 1, activation='relu', kernel_regularizer=regularizers.l2(1e-4))(early) # Added L2
    early = layers.BatchNormalization()(early)
    early = layers.UpSampling2D((2, 2))(early)

    mid = layers.Conv2D(96, 1, activation='relu', kernel_regularizer=regularizers.l2(1e-4))(mid) # Added L2
    mid = layers.BatchNormalization()(mid)
    mid = layers.UpSampling2D((2, 2))(mid)

    x = layers.Concatenate()([decoder, mid])
    x = layers.Conv2D(256, 3, padding='same', activation='relu', kernel_regularizer=regularizers.l2(1e-4))(x) # Added L2
    x = layers.BatchNormalization()(x)
    x = layers.Conv2D(256, 3, padding='same', activation='relu', kernel_regularizer=regularizers.l2(1e-4))(x) # Added L2
    x = layers.BatchNormalization()(x)
    x = layers.UpSampling2D((4, 4))(x)

    x = layers.Concatenate()([x, early])
    x = layers.Conv2D(128, 3, padding='same', activation='relu', kernel_regularizer=regularizers.l2(1e-4))(x) # Added L2
    x = layers.BatchNormalization()(x)
    x = layers.Conv2D(128, 3, padding='same', activation='relu', kernel_regularizer=regularizers.l2(1e-4))(x) # Added L2
    x = layers.BatchNormalization()(x)
    x = layers.UpSampling2D((2, 2))(x)

    outputs = layers.Conv2D(num_classes, 1, activation='softmax', dtype=tf.keras.backend.floatx())(x)
    return keras.Model(inputs, outputs)

# -------------------- MÉTRICAS PERSONALIZADAS --------------------
def dice_coefficient(y_true, y_pred, smooth=1e-6):
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    # Ensure y_pred is not exactly 0 or 1 to avoid NaNs with log for focal loss
    y_pred = tf.clip_by_value(y_pred, tf.keras.backend.epsilon(), 1.0 - tf.keras.backend.epsilon())

    intersection = tf.reduce_sum(y_true * y_pred, axis=[1, 2, 3])
    union = tf.reduce_sum(y_true + y_pred, axis=[1, 2, 3])
    dice = (2. * intersection + smooth) / (union + smooth)
    return tf.reduce_mean(dice)

def iou_metric(y_true, y_pred, smooth=1e-6):
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    y_pred = tf.clip_by_value(y_pred, tf.keras.backend.epsilon(), 1.0 - tf.keras.backend.epsilon())

    intersection = tf.reduce_sum(y_true * y_pred, axis=[1, 2, 3])
    union = tf.reduce_sum(y_true + y_pred, axis=[1, 2, 3]) - intersection
    return tf.reduce_mean((intersection + smooth) / (union + smooth))

class DiceLoss(tf.keras.losses.Loss):
    def __init__(self, name="dice_loss", smooth=1e-6):
        super().__init__(name=name)
        self.smooth = smooth

    def call(self, y_true, y_pred):
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.cast(y_pred, tf.float32)
        y_pred = tf.clip_by_value(y_pred, tf.keras.backend.epsilon(), 1.0 - tf.keras.backend.epsilon())

        y_true_f = tf.reshape(y_true, [-1, NUM_CLASSES])
        y_pred_f = tf.reshape(y_pred, [-1, NUM_CLASSES])

        intersection = tf.reduce_sum(y_true_f * y_pred_f, axis=0)
        union = tf.reduce_sum(y_true_f + y_pred_f, axis=0)

        dice = (2. * intersection + self.smooth) / (union + self.smooth)
        return 1.0 - tf.reduce_mean(dice)

# New: Optional Focal Loss for severe class imbalance
class FocalLoss(tf.keras.losses.Loss):
    def __init__(self, gamma=2.0, alpha=0.25, name="focal_loss"):
        super().__init__(name=name)
        self.gamma = gamma
        self.alpha = alpha

    def call(self, y_true, y_pred):
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.cast(y_pred, tf.float32)

        epsilon = tf.keras.backend.epsilon()
        y_pred = tf.clip_by_value(y_pred, epsilon, 1. - epsilon)

        # Calculate Cross Entropy
        cross_entropy = -y_true * tf.math.log(y_pred)

        # Calculate Focal Loss
        pt = y_pred * y_true + (1 - y_pred) * (1 - y_true) # p_t for true class
        loss = self.alpha * tf.pow(1. - pt, self.gamma) * cross_entropy
        return tf.reduce_mean(tf.reduce_sum(loss, axis=-1)) # Sum over classes, then mean over batch and spatial dims


# -------------------- PREPARAR DATOS --------------------
def load_data_and_create_datasets():
    image_paths = []
    mask_paths = []

    # Corrected DATASET_PATH usage
    for subset in ['train', 'test', 'val']:
        image_paths.extend(sorted(glob(f"{DATASET_ROOT_PATH}/Balanced/{subset}/images/*")))
        mask_paths.extend(sorted(glob(f"{DATASET_ROOT_PATH}/Balanced/{subset}/masks/*")))

    if not image_paths or not mask_paths:
        raise ValueError(f"No images or masks found. Check the dataset path: {DATASET_ROOT_PATH}. "
                         f"Found images: {len(image_paths)}, masks: {len(mask_paths)}")

    images = [cv2.resize(cv2.imread(p), (IMG_WIDTH, IMG_HEIGHT)).astype(np.float32) for p in image_paths]
    masks = [cv2.resize(cv2.imread(p, cv2.IMREAD_GRAYSCALE), (IMG_WIDTH, IMG_HEIGHT), interpolation=cv2.INTER_NEAREST).astype(np.int32) for p in mask_paths]

    images = np.array(images) / 255.0
    masks = np.array(masks)
    masks_one_hot = keras.utils.to_categorical(masks, num_classes=NUM_CLASSES)

    # Calculate class weights for imbalance handling (optional)
    flat_masks_int = np.argmax(masks_one_hot, axis=-1).flatten()
    class_counts = np.bincount(flat_masks_int, minlength=NUM_CLASSES)
    # Avoid division by zero for classes that might not exist in training data
    class_counts = np.where(class_counts == 0, 1, class_counts)
    total_samples = len(flat_masks_int)
    class_weights = total_samples / (NUM_CLASSES * class_counts)
    class_weights_dict = {i: weight for i, weight in enumerate(class_weights)}
    print(f"Calculated Class Weights: {class_weights_dict}")


    X_train_raw, X_val_raw, y_train_raw, y_val_raw = train_test_split(images, masks_one_hot, test_size=0.2, random_state=42)

    def augment_fn(image, mask):
        # Random Horizontal Flip
        if tf.random.uniform(()) > 0.5:
            image = tf.image.flip_left_right(image)
            mask = tf.image.flip_left_right(mask)

        # Random Vertical Flip
        if tf.random.uniform(()) > 0.5:
            image = tf.image.flip_up_down(image)
            mask = tf.image.flip_up_down(mask)

        # Random Rotations (multiples of 90 degrees for simpler implementation with masks)
        k = tf.random.uniform(shape=[], minval=0, maxval=4, dtype=tf.int32)
        image = tf.image.rot90(image, k)
        mask = tf.image.rot90(mask, k)

        # Random brightness (only for image)
        image = tf.image.random_brightness(image, max_delta=0.1)
        # Random contrast (only for image)
        image = tf.image.random_contrast(image, lower=0.8, upper=1.2)

        return image, mask

    def preprocess_fn(image, mask):
        # Ensure correct data types for TensorFlow operations
        image = tf.image.convert_image_dtype(image, tf.float32)
        mask = tf.image.convert_image_dtype(mask, tf.float32)
        return image, mask

    train_dataset = tf.data.Dataset.from_tensor_slices((X_train_raw, y_train_raw))
    train_dataset = train_dataset.shuffle(buffer_size=len(X_train_raw)) \
                                 .map(augment_fn, num_parallel_calls=tf.data.AUTOTUNE) \
                                 .map(preprocess_fn, num_parallel_calls=tf.data.AUTOTUNE) \
                                 .batch(BATCH_SIZE) \
                                 .prefetch(tf.data.AUTOTUNE)

    val_dataset = tf.data.Dataset.from_tensor_slices((X_val_raw, y_val_raw))
    val_dataset = val_dataset.map(preprocess_fn, num_parallel_calls=tf.data.AUTOTUNE) \
                             .batch(BATCH_SIZE) \
                             .prefetch(tf.data.AUTOTUNE)

    return train_dataset, val_dataset, X_val_raw, y_val_raw, class_weights_dict


train_ds, val_ds, X_val_raw, y_val_raw, class_weights = load_data_and_create_datasets()

# -------------------- COMPILAR Y ENTRENAR --------------------
model = create_model()

# Print model summary to check parameters and size
print("\nModel Summary:")
model.summary()

# Use custom DiceLoss or FocalLoss
custom_loss = DiceLoss(name="dice_loss") # Or FocalLoss(gamma=2.0, alpha=0.25)
# If using FocalLoss, you might not need class_weight in model.fit, or combine them carefully.

model.compile(optimizer=keras.optimizers.Adam(LEARNING_RATE),
              loss=custom_loss,
              metrics=[iou_metric, dice_coefficient, 'accuracy'])

# Callbacks for better training
callbacks = [
    tf.keras.callbacks.ModelCheckpoint("best_model.keras", save_best_only=True, monitor='val_loss', mode='min', verbose=1),
    tf.keras.callbacks.EarlyStopping(patience=15, restore_best_weights=True, monitor='val_loss', mode='min', verbose=1),
    tf.keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=7, min_lr=1e-6, verbose=1, monitor='val_loss', mode='min')
]

history = model.fit(train_ds,
                    validation_data=val_ds,
                    epochs=EPOCHS,
                    callbacks=callbacks,
                    class_weight=class_weights) # Apply class weights

# -------------------- GUARDAR MODELO (Recommended .keras format) --------------------
model.save("modelo_mejorado_segmentacion.keras") # Recommended format

# -------------------- MOSTRAR MÉTRICAS --------------------
print("\nEvaluación final (usando el mejor modelo):")
# Load the best model to evaluate
best_model = keras.models.load_model("best_model.keras",
                                     custom_objects={
                                         'DiceLoss': DiceLoss,
                                         'iou_metric': iou_metric,
                                         'dice_coefficient': dice_coefficient,
                                         'FocalLoss': FocalLoss # Add if you plan to use it
                                     })

results = best_model.evaluate(val_ds, verbose=2) # Evaluate on the dataset directly
print(f"Pérdida (Dice Loss): {results[0]:.4f}")
print(f"IoU: {results[1]:.4f}")
print(f"DICE: {results[2]:.4f}")
print(f"Precisión: {results[3]:.4f}")


# -------------------- VISUALIZACIÓN DE PREDICCIONES --------------------
def display_predictions(model_to_use, raw_images, raw_masks, num_samples=3):
    # Select a few random samples from the raw validation data
    indices = np.random.choice(len(raw_images), num_samples, replace=False)

    for idx in indices:
        image = raw_images[idx]
        true_mask_one_hot = raw_masks[idx]

        # Preprocess the single image for prediction (matching dataset preprocessing)
        input_image = tf.image.convert_image_dtype(image, tf.float32)
        input_image = tf.expand_dims(input_image, 0) # Add batch dimension

        pred_mask_probs = model_to_use.predict(input_image)[0] # Get rid of batch dim

        # Convert one-hot to single channel for visualization
        true_mask_single_channel = np.argmax(true_mask_one_hot, axis=-1)
        pred_mask_single_channel = np.argmax(pred_mask_probs, axis=-1)

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        axes[0].imshow(image) # Already 0-1 scaled, no need for .numpy() for raw image
        axes[0].set_title('Original Image')
        axes[0].axis('off')

        axes[1].imshow(true_mask_single_channel, cmap='viridis', vmin=0, vmax=NUM_CLASSES-1) # Set vmin/vmax for consistent coloring
        axes[1].set_title('Ground Truth Mask')
        axes[1].axis('off')

        axes[2].imshow(pred_mask_single_channel, cmap='viridis', vmin=0, vmax=NUM_CLASSES-1)
        axes[2].set_title('Predicted Mask')
        axes[2].axis('off')

        plt.tight_layout()
        plt.show()

print("\nVisualizing Predictions from Best Model:")
display_predictions(best_model, X_val_raw, y_val_raw, num_samples=5)

Calculated Class Weights: {0: np.float64(0.20199047836666972), 1: np.float64(5.8397371792391155), 2: np.float64(11.995724180345873), 3: np.float64(6.9607691999558154), 4: np.float64(26.536845417692465), 5: np.float64(1.6304642973413686)}





Model Summary:


Epoch 1/50
[1m105/105[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 597ms/step - accuracy: 0.6000 - dice_coefficient: 0.5865 - iou_metric: 0.4323 - loss: 1.0278
Epoch 1: val_loss improved from inf to 0.92625, saving model to best_model.keras
[1m105/105[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m221s[0m 955ms/step - accuracy: 0.6006 - dice_coefficient: 0.5872 - iou_metric: 0.4329 - loss: 1.0272 - val_accuracy: 0.8067 - val_dice_coefficient: 0.8036 - val_iou_metric: 0.7122 - val_loss: 0.9263 - learning_rate: 0.0010
Epoch 2/50
[1m105/105[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 598ms/step - accuracy: 0.6880 - dice_coefficient: 0.6838 - iou_metric: 0.5294 - loss: 0.8855
Epoch 2: val_loss improved from 0.92625 to 0.85304, saving model to best_model.keras
[1m105/105[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m80s[0m 761ms/step - accuracy: 0.6880 - dice_coefficient: 0.6839 - iou_metric: 0.5294 - loss: 0.8852 - val_accuracy: 0.7476 - val_dice_coefficient: 0.74

ValueError: The `{arg_name}` of this `Lambda` layer is a Python lambda. Deserializing it is unsafe. If you trust the source of the config artifact, you can override this error by passing `safe_mode=False` to `from_config()`, or calling `keras.config.enable_unsafe_deserialization().