# In [None]:
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Dropout

class PatchEmbedding(layers.Layer):
    """Split an image into non-overlapping patches and linearly embed each one.

    A ``Conv2D`` whose kernel size equals its stride (both ``patch_size``)
    projects every patch to ``embed_dim`` channels; the resulting grid is then
    flattened into a sequence of patch tokens.

    Args:
        patch_size: Passed to ``Conv2D`` as both ``kernel_size`` and
            ``strides`` (an int or a ``(height, width)`` pair).
        embed_dim: Number of output channels per patch token.
        **kwargs: Standard Keras layer arguments (``name``, ``dtype``,
            ``trainable``, ...). Accepting them lets ``from_config`` rebuild
            the layer from the dict returned by ``get_config``, which includes
            those base-layer keys.
    """

    def __init__(self, patch_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.projection = layers.Conv2D(embed_dim, kernel_size=patch_size, strides=patch_size)
        # -1 lets Reshape infer the number of patches from the conv output grid.
        self.flatten = layers.Reshape((-1, embed_dim))

    def call(self, x):
        """Return patch tokens shaped ``(batch, num_patches, embed_dim)``.

        Casts go to the layer's compute dtype (``float16`` under the
        ``mixed_float16`` policy) rather than a hard-coded ``float16`` so the
        layer also works under a plain ``float32`` policy.
        """
        x = tf.cast(x, self.compute_dtype)
        patches = self.projection(x)
        flattened = self.flatten(patches)
        return tf.cast(flattened, self.compute_dtype)

    def get_config(self):
        """Serialize constructor arguments on top of the base-layer config."""
        config = super().get_config()
        config.update({
            "patch_size": self.patch_size,
            "embed_dim": self.embed_dim
        })
        return config

class TransformerBlock(layers.Layer):
    """Post-norm Transformer encoder block: self-attention followed by an MLP.

    Each sub-layer is applied as ``LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        embed_dim: Width of the token embeddings (last dim of input/output).
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward network.
        dropout: Dropout rate applied to each sub-layer's output.
        **kwargs: Standard Keras layer arguments (``name``, ``dtype``, ...);
            accepted so ``from_config(get_config())`` round-trips.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.dropout_rate = dropout

        # NOTE(review): key_dim=embed_dim gives *every* head the full embedding
        # width; the conventional choice is embed_dim // num_heads. Left as-is
        # so previously trained weights still load — revisit if cost matters.
        self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = tf.keras.Sequential([
            layers.Dense(ff_dim, activation="gelu"),
            layers.Dense(embed_dim),
        ])
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(dropout)
        self.dropout2 = layers.Dropout(dropout)

    def call(self, inputs, training=None):
        """Apply attention and feed-forward sub-layers with residual + norm.

        Args:
            inputs: Token sequence; presumably ``(batch, seq, embed_dim)`` —
                the residual adds require the last dim to equal ``embed_dim``.
            training: Forwarded to dropout/attention/FFN sublayers; ``None``
                defers to the surrounding Keras call context.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        # Cast to the policy's compute dtype instead of a hard-coded float16.
        dtype = self.compute_dtype
        inputs = tf.cast(inputs, dtype)
        attn_output = self.att(inputs, inputs, training=training)
        attn_output = tf.cast(attn_output, dtype)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1, training=training)
        ffn_output = tf.cast(ffn_output, dtype)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)

    def get_config(self):
        """Serialize constructor arguments on top of the base-layer config."""
        config = super().get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "ff_dim": self.ff_dim,
            "dropout": self.dropout_rate
        })
        return config

class VisionTransformer(layers.Layer):
    """ViT-style encoder mapping a single-channel image to a same-size image.

    Pipeline: patch embedding -> prepend a learnable CLS token -> add learned
    position embeddings -> ``num_transformer_blocks`` Transformer blocks ->
    drop the CLS token -> per-patch MLP head predicting the
    ``patch_h * patch_w`` pixels of each patch -> reassemble the patches into
    a ``(batch, img_height, img_width, 1)`` tensor.

    Args:
        num_heads: Attention heads per Transformer block.
        embed_dim: Token embedding width.
        ff_dim: Hidden width of the Transformer FFNs and the MLP head.
        num_transformer_blocks: Number of stacked Transformer blocks.
        window_size: Stored and serialized, but not used by this layer.
        dropout: Dropout rate used in the blocks and the MLP head.
        img_height: Expected input image height (default 2664).
        img_width: Expected input image width (default 44).
        patch_size: ``(patch_h, patch_w)``; must evenly divide the image size
            (default ``(36, 4)``).
        **kwargs: Standard Keras layer arguments (``name``, ``dtype``, ...).

    Raises:
        ValueError: If the image size is not divisible by ``patch_size``.
    """

    def __init__(self, num_heads=8, embed_dim=256, ff_dim=2048,
                 num_transformer_blocks=6, window_size=512, dropout=0.1,
                 img_height=2664, img_width=44, patch_size=(36, 4), **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.ff_dim = ff_dim
        self.num_transformer_blocks = num_transformer_blocks
        self.window_size = window_size
        self.dropout_rate = dropout

        # Image and patch geometry (previously hard-coded; defaults unchanged).
        self.img_height = img_height
        self.img_width = img_width
        self.patch_size = tuple(patch_size)
        if img_height % self.patch_size[0] or img_width % self.patch_size[1]:
            raise ValueError(
                f"Image size ({img_height}, {img_width}) must be divisible by "
                f"patch_size {self.patch_size}."
            )
        self.grid_height = img_height // self.patch_size[0]
        self.grid_width = img_width // self.patch_size[1]
        self.num_patches = self.grid_height * self.grid_width

        self.patch_embed = PatchEmbedding(self.patch_size, embed_dim)
        # One extra position for the CLS token.
        self.position_embed = layers.Embedding(self.num_patches + 1, embed_dim)
        # Learnable CLS token created through add_weight so Keras tracks it as
        # a weight and stores it in the layer's variable dtype (float32 under
        # mixed_float16) instead of a raw float16 tf.Variable.
        self.cls_token = self.add_weight(
            name="cls_token",
            shape=(1, 1, embed_dim),
            initializer="zeros",
            trainable=True,
        )

        self.transformer_blocks = [
            TransformerBlock(embed_dim, num_heads, ff_dim, dropout)
            for _ in range(num_transformer_blocks)
        ]

        # Per-patch head: one output value per pixel of the patch.
        self.mlp_head = tf.keras.Sequential([
            layers.Dense(ff_dim, activation='gelu'),
            layers.Dropout(dropout),
            layers.Dense(self.patch_size[0] * self.patch_size[1])
        ])

    def call(self, inputs, training=None):
        """Encode ``inputs`` and return a ``(batch, img_height, img_width, 1)`` map.

        Args:
            inputs: Image batch; assumed ``(batch, img_height, img_width, C)``
                — TODO confirm C (the model in this file uses 1).
            training: Forwarded to sublayers; ``None`` defers to the Keras
                call context.
        """
        dtype = self.compute_dtype
        x = self.patch_embed(tf.cast(inputs, dtype))

        batch_size = tf.shape(x)[0]
        cls_tokens = tf.cast(tf.repeat(self.cls_token, batch_size, axis=0), dtype)
        x = tf.concat([cls_tokens, x], axis=1)

        positions = tf.range(start=0, limit=self.num_patches + 1)
        x = x + tf.cast(self.position_embed(positions), dtype)

        for transformer_block in self.transformer_blocks:
            x = transformer_block(x, training=training)

        # Drop the CLS token, leaving one vector per patch.
        x = x[:, 1:, :]
        x = self.mlp_head(x, training=training)

        return self._unpatchify(x)

    def _unpatchify(self, x):
        """Reassemble per-patch pixel vectors into a ``(batch, H, W, 1)`` image."""
        patch_h, patch_w = self.patch_size
        # Tokens follow the row-major patch grid (Conv2D output flattened by
        # Reshape); each token's vector is read as its patch's pixels, row-major.
        x = tf.reshape(x, [-1, self.grid_height, self.grid_width, patch_h, patch_w])
        # (B, gh, gw, ph, pw) -> (B, gh, ph, gw, pw): interleave pixel rows of
        # horizontally adjacent patches. Reshaping without this transpose
        # scatters each patch's pixels across whole image rows.
        x = tf.transpose(x, [0, 1, 3, 2, 4])
        return tf.reshape(x, [-1, self.img_height, self.img_width, 1])

    def get_config(self):
        """Serialize all constructor arguments on top of the base-layer config."""
        config = super().get_config()
        config.update({
            "num_heads": self.num_heads,
            "embed_dim": self.embed_dim,
            "ff_dim": self.ff_dim,
            "num_transformer_blocks": self.num_transformer_blocks,
            "window_size": self.window_size,
            "dropout": self.dropout_rate,
            "img_height": self.img_height,
            "img_width": self.img_width,
            "patch_size": self.patch_size,
        })
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the layer from ``get_config`` output.

        Base-layer keys (``name``, ``trainable``, ``dtype``) are passed
        through so the layer keeps its name. Keys missing from older configs
        fall back to the constructor defaults.
        """
        return cls(**config)

def create_vit_model(input_shape):
    """Build an uncompiled Keras ``Model`` wrapping ``VisionTransformer``.

    Side effect: sets the *global* Keras precision policy to
    ``mixed_float16``. This affects every layer created afterwards in the
    process.

    Args:
        input_shape: Per-sample input shape, e.g. ``(2664, 44, 1)``. The
            ``VisionTransformer`` is built without image-size arguments, so
            this must match the image size it expects (2664 x 44 in this file).

    Returns:
        A ``Model`` whose output tensor is float32.
    """
    tf.keras.mixed_precision.set_global_policy('mixed_float16')

    inputs = Input(shape=input_shape, dtype=tf.float16)
    vit = VisionTransformer(
        num_heads=8,
        embed_dim=256,
        ff_dim=512,
        num_transformer_blocks=6,
        window_size=256,
        dropout=0.1
    )
    x = vit(inputs)
    # Per the TF mixed-precision guide, keep the model output in float32 so
    # the loss is computed in float32 rather than float16 (avoids overflow).
    outputs = layers.Activation('linear', dtype='float32')(x)
    return Model(inputs=inputs, outputs=outputs)

# Example usage
if __name__ == "__main__":
    # Build the model once. summary() prints the table itself and returns
    # None, so it is not wrapped in print() (which would emit a stray "None").
    model = create_vit_model(input_shape=(2664, 44, 1))
    model.summary()