# In [56]:
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np

class PatchEmbedding(layers.Layer):
    """Split an image into non-overlapping patches and embed each patch.

    A ``Conv2D`` whose kernel size and stride both equal ``patch_size``
    projects every patch to an ``embed_dim``-sized vector; the resulting
    grid is then flattened into a sequence of patch tokens.

    Args:
        patch_size: Passed as both ``kernel_size`` and ``strides`` of the
            projection convolution (an int or a pair of ints).
        embed_dim: Number of filters of the projection, i.e. the size of
            each patch embedding.
        **kwargs: Standard ``Layer`` keyword arguments (``name``,
            ``dtype``, ...), forwarded to the base class.
    """

    def __init__(self, patch_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.projection = layers.Conv2D(
            embed_dim, kernel_size=patch_size, strides=patch_size
        )
        # Collapse the patch grid into one token axis: (batch, -1, embed_dim).
        self.flatten = layers.Reshape((-1, embed_dim))

    def call(self, x):
        """Return the patch embeddings of ``x`` as ``(batch, -1, embed_dim)``."""
        patches = self.projection(x)
        return self.flatten(patches)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update(
            {
                "patch_size": self.patch_size,
                "embed_dim": self.embed_dim,
            }
        )
        return config

class TransformerBlock(layers.Layer):
    """Transformer encoder block: self-attention + feed-forward, post-norm.

    Each sub-block applies dropout to its output, adds the residual, and
    then layer-normalizes the sum.

    Args:
        embed_dim: Width of the token embeddings; also the output width of
            the feed-forward network.
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward network.
        dropout: Dropout rate applied after attention and after the
            feed-forward network.
        **kwargs: Standard ``Layer`` keyword arguments (``name``,
            ``dtype``, ...), forwarded to the base class.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1, **kwargs):
        super().__init__(**kwargs)
        # Keep the hyperparameters so get_config() can report them.
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.dropout_rate = dropout

        # NOTE(review): key_dim is the per-head query/key size, so every head
        # is embed_dim wide here (not embed_dim // num_heads). Kept as-is to
        # leave the model unchanged.
        self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = tf.keras.Sequential([
            layers.Dense(ff_dim, activation="gelu"),
            layers.Dense(embed_dim),
        ])
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(dropout)
        self.dropout2 = layers.Dropout(dropout)

    def call(self, inputs, training=None):
        """Apply the block to ``inputs``.

        Args:
            inputs: Token sequence used as both query and value of the
                self-attention.
            training: Optional flag forwarded to the attention and dropout
                layers; ``None`` lets Keras decide from the call context.

        Returns:
            The layer-normalized output of the feed-forward sub-block.
        """
        attn_output = self.att(inputs, inputs, training=training)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update(
            {
                "embed_dim": self.embed_dim,
                "num_heads": self.num_heads,
                "ff_dim": self.ff_dim,
                "dropout": self.dropout_rate,
            }
        )
        return config
    
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, Multiply
from tensorflow.keras import backend as K
from tensorflow.keras.layers import LeakyReLU

def attention_block(x, g, inter_channel):
    """Gate a skip connection with an additive attention map.

    Both inputs are projected to ``inter_channel`` channels with 1x1
    convolutions, summed, passed through LeakyReLU, and reduced to a
    single-channel sigmoid map that multiplies ``x``.

    Only Keras layers are used (no raw ``tf.*`` ops), so the block can be
    applied to symbolic Keras tensors while building a functional model.

    Args:
        x: Skip connection input; indexed as a 4-D tensor with the spatial
            size on axes 1 and 2 and channels on axis 3.
        g: Gate signal (from previous layer). It is resized to the spatial
            size of ``x`` when the two differ.
        inter_channel: Number of intermediate channels; converted with
            ``int`` and clamped to at least 1.

    Returns:
        ``x`` multiplied by the attention map (same shape as ``x``).

    Raises:
        ValueError: If ``g`` must be resized but the height or width of
            ``x`` is not statically known.
    """
    g_shape = K.int_shape(g)
    x_shape = K.int_shape(x)
    target_h, target_w = x_shape[1], x_shape[2]

    # Ensure inter_channel is a valid (positive, integer) filter count.
    inter_channel = max(1, int(inter_channel))

    # Bring the gate signal to the spatial size of the skip connection.
    if g_shape[1] != target_h or g_shape[2] != target_w:
        if target_h is None or target_w is None:
            raise ValueError(
                "attention_block needs a static height and width on `x` "
                "to resize the gate signal; got shape %s" % (x_shape,)
            )
        # Resizing defaults to bilinear interpolation, like tf.image.resize.
        g = layers.Resizing(target_h, target_w)(g)
        # NOTE(review): this projection only runs when a resize happened
        # (as in the original); phi_g below re-projects g anyway.
        if g_shape[3] != x_shape[3]:
            g = Conv2D(x_shape[3], (1, 1), padding='same')(g)

    # Project both signals to the same intermediate channel count.
    theta_x = Conv2D(inter_channel, (1, 1), padding='same')(x)
    phi_g = Conv2D(inter_channel, (1, 1), padding='same')(g)

    # Additive compatibility score, reduced to one channel.
    f = LeakyReLU(alpha=0.3)(layers.Add()([theta_x, phi_g]))
    psi_f = Conv2D(1, (1, 1), padding='same')(f)

    # Attention weights in (0, 1).
    rate = layers.Activation('sigmoid')(psi_f)

    # Apply the single-channel map to x.
    att_x = Multiply()([x, rate])

    return att_x

def unet_transformer_inpainting(input_shape=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)):
    """Build the Transformer + attention-gated convolutional inpainting model.

    The input is embedded as (72, 4) patches, passed through four
    ``TransformerBlock`` layers and reshaped back to a spatial grid. That
    grid is then used as the gate signal of ``attention_block`` against the
    encoder feature maps ``conv4``, ``conv3``, ``conv2`` and ``conv1`` in
    turn; each gated map is concatenated with its skip connection. A final
    1x1 convolution with ReLU yields a single-channel output.

    NOTE(review): the bottleneck and decoder tensors (``conv41`` through
    ``conv11``) are built but are not on the path from ``inputs`` to
    ``outputs``, so they are not part of the returned model. They are kept
    as written; confirm whether the decoder was meant to feed the output.

    Args:
        input_shape: ``(height, width, channels)`` of the model input. The
            patch grid is derived from this value.

    Returns:
        A Keras ``Model`` mapping ``inputs`` to the single-channel output.
    """
    # Function-scope imports: only the names actually used below.
    from tensorflow.keras.initializers import glorot_normal
    from tensorflow.keras.layers import LeakyReLU

    inputs = Input(input_shape)

    # ---- Transformer branch -------------------------------------------
    # The patch size equals the cumulative pooling factor of the five
    # encoder stages below: height 2*2*2*3*3 = 72, width 2*1*1*2*1 = 4.
    patch_size = (72, 4)
    embed_dim = 256
    x = PatchEmbedding(patch_size, embed_dim)(inputs)

    for _ in range(4):
        x = TransformerBlock(embed_dim, num_heads=8, ff_dim=4 * embed_dim)(x)

    # Reshape the token sequence back to its patch grid. Derived from
    # input_shape (not module globals) so non-default shapes work.
    grid_h = input_shape[0] // patch_size[0]
    grid_w = input_shape[1] // patch_size[1]
    x = layers.Reshape((grid_h, grid_w, embed_dim))(x)

    # ---- Shared convolution settings ----------------------------------
    # Seeded initializer and one shared activation for every convolution.
    initializer = glorot_normal(seed=0)
    activation = LeakyReLU(alpha=0.3)
    kernel_size = (3, 3)

    def conv(filters, size=kernel_size):
        """Conv2D with the shared activation, padding and initializer."""
        return Conv2D(filters, size, activation=activation, padding='same',
                      kernel_initializer=initializer)

    def upconv(filters, strides):
        """Conv2DTranspose with the shared kernel size and initializer."""
        return Conv2DTranspose(filters, kernel_size, strides=strides,
                               padding='same', kernel_initializer=initializer)

    # ---- Encoder ------------------------------------------------------
    # Block 1
    # NOTE(review): (1500, 6) is an unusually large kernel — presumably
    # tied to the input height; verify it is intended.
    conv1 = conv(32, (1500, 6))(inputs)
    conv1 = conv(32)(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    # Block 2
    conv2 = conv(64)(pool1)
    conv2 = conv(64)(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 1))(conv2)

    # Block 3
    conv3 = conv(128)(pool2)
    conv3 = conv(128)(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 1))(conv3)

    # Block 4
    conv4 = conv(256)(pool3)
    conv4 = conv(256)(conv4)
    drop4 = Dropout(0)(conv4)
    pool4 = MaxPooling2D(pool_size=(3, 2))(drop4)

    # Block 5
    conv41 = conv(512)(pool4)
    conv41 = conv(512)(conv41)
    drop41 = Dropout(0.01)(conv41)
    pool41 = MaxPooling2D(pool_size=(3, 1))(drop41)

    # ---- Bottleneck ---------------------------------------------------
    conv5 = conv(1024)(pool41)
    conv5 = conv(1024)(conv5)

    # ---- Decoder (not connected to `outputs`, see docstring) ----------
    # Block 61
    up61 = upconv(512, (3, 1))(conv5)
    up61 = concatenate([up61, drop41], axis=3)
    conv61 = conv(512)(up61)
    conv61 = conv(512)(conv61)

    # Block 6
    up6 = upconv(256, (3, 2))(conv61)
    up6 = concatenate([up6, drop4], axis=3)
    conv6 = conv(256)(up6)
    conv6 = conv(256)(conv6)

    # Block 7
    up7 = upconv(128, (2, 1))(conv6)
    up7 = concatenate([up7, conv3], axis=3)
    conv7 = conv(128)(up7)
    conv7 = conv(128)(conv7)

    # Block 8
    up8 = upconv(64, (2, 1))(conv7)
    up8 = concatenate([up8, conv2], axis=3)
    conv8 = conv(64)(up8)
    conv8 = conv(64)(conv8)

    # Block 9
    up9 = upconv(32, (2, 2))(conv8)
    up9 = concatenate([up9, conv1], axis=3)
    conv9 = conv(16, (1, 1))(up9)
    conv10 = conv(16, (1, 1))(conv9)
    conv11 = conv(1, (1, 1))(conv10)

    # ---- Attention-gated path (this is what produces the output) ------
    # Gate each skip connection with the running tensor `x`, halving the
    # intermediate channel count at every level.
    for i, skip in enumerate([conv4, conv3, conv2, conv1]):
        x = attention_block(skip, x, embed_dim // (2 ** i))
        x = concatenate([x, skip], axis=3)

    outputs = Conv2D(1, (1, 1), activation='relu')(x)

    return Model(inputs, outputs)

# Build the model.
model = unet_transformer_inpainting(
    input_shape=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS),
)


# Record the current IPython execution count (index of this cell).
cell_index = get_ipython().execution_count