In [None]:
##peculiaridad de la arquitectura de segmentación
import tensorflow.keras.backend as K
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose, Activation, BatchNormalization, UpSampling2D, Add
from tensorflow.keras.optimizers import Adam


# Custom segmentation losses/metrics (alternatives to the usual MSE-style losses).
def dice_coef(y_true, y_pred, smooth=1.):
    """Dice coefficient: overlap between a ground-truth mask and a predicted mask.

    Computes ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)`` over
    the flattened tensors, i.e. over the whole batch at once rather than per
    image. For binary masks and predictions in [0, 1] the result lies in
    (0, 1]; values closer to 1 mean better overlap.

    Args:
        y_true: ground-truth mask tensor.
        y_pred: predicted mask tensor with the same number of elements.
        smooth: additive smoothing term that keeps the ratio defined when both
            masks are empty (default 1, the previously hard-coded value).
            Keras calls this with two arguments, so the default is used there.

    Returns:
        Scalar tensor with the Dice coefficient.
    """
    # Masks are frequently stored as integer/bool arrays; cast them to the
    # prediction dtype so the element-wise product does not fail on a
    # dtype mismatch.
    y_true_f = K.flatten(K.cast(y_true, y_pred.dtype))
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)


def dice_coef_loss(y_true, y_pred):
    """Training loss derived from the Dice coefficient.

    Returns the negated Dice score, so minimising this loss maximises the
    overlap between the masks; lower (more negative) values are better.
    """
    overlap = dice_coef(y_true, y_pred)
    return -overlap


def _conv_bn_relu(x, filters):
    """Apply two (3x3 Conv2D -> BatchNormalization -> ReLU) stages with `filters` channels."""
    for _ in range(2):
        x = Conv2D(filters, (3, 3), padding='same')(x)
        # Normalise the convolution output, then apply the activation as a
        # separate layer so it runs after normalisation.
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
    return x


def _up_concat(x, skip, filters):
    """Upsample `x` 2x with a transposed conv and concatenate the encoder skip tensor on the channel axis."""
    up = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
    return concatenate([up, skip], axis=3)


def get_unet2D(img_rows, img_cols, img_ch, n_classes=1):
    """Build a 2D U-Net segmentation model (not compiled).

    Encoder: four stages of two conv-BN-ReLU layers (32, 64, 128, 256
    filters), each followed by 2x2 max-pooling. Bottleneck: two conv-BN-ReLU
    layers with 512 filters. Decoder: four 2x transposed-conv upsamplings,
    each concatenated with the matching encoder output (skip connection) and
    followed by two conv-BN-ReLU layers. There are no Dense layers.

    Args:
        img_rows: input height; must be divisible by 16 (or None).
        img_cols: input width; must be divisible by 16 (or None).
        img_ch: number of input channels.
        n_classes: number of output maps. With 1 (default) the head uses a
            sigmoid (binary segmentation); with more it uses a softmax over
            the channel axis.

    Returns:
        An uncompiled ``tensorflow.keras`` ``Model`` whose output has shape
        (batch, img_rows, img_cols, n_classes).

    Raises:
        ValueError: if img_rows or img_cols is not divisible by 16.
    """
    # Four 2x2 poolings followed by four 2x upsamplings: the decoder outputs
    # only line up with the skip connections when each size is divisible by 16.
    for name, size in (('img_rows', img_rows), ('img_cols', img_cols)):
        if size is not None and size % 16 != 0:
            raise ValueError(f'{name} must be divisible by 16 (got {size})')

    inputs = Input((img_rows, img_cols, img_ch))

    # Encoding phase
    conv1 = _conv_bn_relu(inputs, 32)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = _conv_bn_relu(pool1, 64)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = _conv_bn_relu(pool2, 128)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    conv4 = _conv_bn_relu(pool3, 256)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    # Bottleneck
    conv5 = _conv_bn_relu(pool4, 512)

    # Decoding phase (upsample + skip connection + two convs)
    conv6 = _conv_bn_relu(_up_concat(conv5, conv4, 256), 256)
    conv7 = _conv_bn_relu(_up_concat(conv6, conv3, 128), 128)
    conv8 = _conv_bn_relu(_up_concat(conv7, conv2, 64), 64)
    conv9 = _conv_bn_relu(_up_concat(conv8, conv1, 32), 32)

    # Output head: a 1x1 convolution projects the 32 feature maps down to one
    # map per class. Without it the model would emit 32 channels, which no
    # longer matches a single-channel mask in the Dice loss.
    activation = 'sigmoid' if n_classes == 1 else 'softmax'
    output = Conv2D(n_classes, (1, 1), activation=activation)(conv9)

    # Build (but do not compile) the model from its input and output tensors.
    model = Model(inputs=[inputs], outputs=[output])
    return model

In [None]:
from keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau

def train():
    """Build, compile and fit the U-Net with Dice loss.

    Relies on notebook-level globals defined outside this function:
    ``data_generator_train``, ``X_train``, ``X_val``, ``y_val`` and
    ``batch_size``. NOTE(review): presumably ``data_generator_train`` yields
    (image, mask) batches indefinitely; confirm against the cell that builds it.

    The model with the lowest validation loss is written to
    ``./models/unet.h5``.

    Returns:
        The Keras ``History`` object returned by ``model.fit``.
    """
    import os

    # Import callbacks from tf.keras, the same package the model is built
    # with; mixing the standalone `keras` package with `tensorflow.keras`
    # objects can fail.
    from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau

    model = get_unet2D(32, 32, 3)

    # Loss and metric are the custom Dice functions defined above.
    model.compile(optimizer="adam", loss=dice_coef_loss, metrics=[dice_coef])

    # Make sure the checkpoint directory exists before the first save.
    os.makedirs('./models', exist_ok=True)
    model_checkpoint = ModelCheckpoint(filepath='./models/unet.h5',
                                       monitor='val_loss',   # select on validation loss
                                       save_best_only=True)  # keep only the best model

    # Divide the learning rate by 10 after 3 epochs without val_loss
    # improvement, never going below 1e-7 (min_lr is the floor; min_delta
    # would only be the improvement threshold). The patience is shorter than
    # EarlyStopping's so the reduced rate gets some epochs to help before
    # training stops; with equal patience both would fire on the same epoch.
    rlrp = ReduceLROnPlateau(monitor='val_loss',
                             factor=0.1,
                             patience=3,
                             min_lr=1e-7,
                             verbose=1)

    # Stop after 5 epochs without improvement and roll back to the best weights.
    earlystopping = EarlyStopping(monitor="val_loss",
                                  mode="min", patience=5,
                                  restore_best_weights=True,
                                  verbose=1)

    callbacks = [model_checkpoint, rlrp, earlystopping]

    # At least one step per epoch, even if X_train holds fewer samples than a batch.
    steps_per_epoch = max(1, X_train.shape[0] // batch_size)

    H = model.fit(data_generator_train,
                  steps_per_epoch=steps_per_epoch,
                  epochs=20,
                  verbose=1,
                  callbacks=callbacks,
                  validation_data=(X_val, y_val))
    return H
