In [3]:
import tensorflow as tf
from tensorflow.keras import layers, datasets, models
import numpy as np

# Cargar el conjunto de datos CIFAR-10
(x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()

# Normalizar los datos de entrada
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# Convertir etiquetas a one-hot encoding
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)

# Parámetros del transformer
PATCH_SIZE = 4  # Dividimos la imagen en patches de 4x4
NUM_PATCHES = (32 // PATCH_SIZE) ** 2  # CIFAR10 tiene imágenes de 32x32
D_MODEL = 64  # Tamaño de la representación por patch
NUM_HEADS = 8  # Número de cabezas para la atención múltiple
D_FF = 128  # Tamaño del feedforward en el bloque transformer
NUM_BLOCKS = 4  # Número de bloques transformer

# Función para dividir las imágenes en patches
def extract_patches(images, patch_size):
    batch_size = tf.shape(images)[0]
    patches = tf.image.extract_patches(
        images=images,
        sizes=[1, patch_size, patch_size, 1],
        strides=[1, patch_size, patch_size, 1],
        rates=[1, 1, 1, 1],
        padding='VALID'
    )
    patch_dims = patches.shape[-1]
    patches = tf.reshape(patches, [batch_size, -1, patch_dims])
    return patches

# Implementar Self-Attention
class MultiHeadSelfAttention(layers.Layer):
    def __init__(self, d_model, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model

        assert d_model % self.num_heads == 0

        self.depth = d_model // self.num_heads
        self.wq = layers.Dense(d_model)
        self.wk = layers.Dense(d_model)
        self.wv = layers.Dense(d_model)
        self.dense = layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, depth)."""
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, v, k, q):
        batch_size = tf.shape(q)[0]

        q = self.wq(q)  # (batch_size, seq_len, d_model)
        k = self.wk(k)  # (batch_size, seq_len, d_model)
        v = self.wv(v)  # (batch_size, seq_len, d_model)

        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
        v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)

        # Escalado de los valores
        matmul_qk = tf.matmul(q, k, transpose_b=True)
        dk = tf.cast(tf.shape(k)[-1], tf.float32)
        scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

        # Softmax para obtener las atenciones
        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)

        # Multiplicar las atenciones por los valores
        output = tf.matmul(attention_weights, v)  # (batch_size, num_heads, seq_len_q, depth_v)

        output = tf.transpose(output, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)
        concat_attention = tf.reshape(output, (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)

        output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)
        return output

# Bloque transformer básico
class TransformerBlock(layers.Layer):
    def __init__(self, d_model, num_heads, dff):
        super(TransformerBlock, self).__init__()
        self.att = MultiHeadSelfAttention(d_model, num_heads)
        self.ffn = models.Sequential([
            layers.Dense(dff, activation='relu'),  # Primera capa feedforward
            layers.Dense(d_model)  # Segunda capa feedforward
        ])
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(0.1)
        self.dropout2 = layers.Dropout(0.1)

    def call(self, x, training):
        attn_output = self.att(x, x, x)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(x + attn_output)

        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)

# Modelo completo con patches y bloques transformer
class TransformerClassifier(models.Model):
    def __init__(self, num_patches, d_model, num_heads, dff, num_blocks, num_classes):
        super(TransformerClassifier, self).__init__()
        self.patch_proj = layers.Dense(d_model)
        self.transformer_blocks = [TransformerBlock(d_model, num_heads, dff) for _ in range(num_blocks)]
        self.pool = layers.GlobalAveragePooling1D()  # Agregamos un GlobalAveragePooling
        self.fc = layers.Dense(num_classes, activation='softmax')

    def call(self, x, training):
        # Dividir en patches
        patches = extract_patches(x, PATCH_SIZE)
        x = self.patch_proj(patches)

        # Aplicar bloques Transformer
        for block in self.transformer_blocks:
            x = block(x, training=training)

        # Global average pooling para aplanar las dimensiones
        x = self.pool(x)

        # Clasificación final
        return self.fc(x)

# Parámetros del modelo
num_classes = 10
model = TransformerClassifier(NUM_PATCHES, D_MODEL, NUM_HEADS, D_FF, NUM_BLOCKS, num_classes)

# Compilar el modelo
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Entrenamiento
model.fit(x_train, y_train, epochs=10, batch_size=64, validation_data=(x_test, y_test))

# Evaluar el modelo en el conjunto de test
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f'Precisión en test: {test_acc}')


Epoch 1/10
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m316s[0m 381ms/step - accuracy: 0.2338 - loss: 2.0214 - val_accuracy: 0.3529 - val_loss: 1.7505
Epoch 2/10
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m329s[0m 390ms/step - accuracy: 0.3973 - loss: 1.6204 - val_accuracy: 0.4607 - val_loss: 1.4781
Epoch 3/10
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m298s[0m 381ms/step - accuracy: 0.4670 - loss: 1.4599 - val_accuracy: 0.4891 - val_loss: 1.3974
Epoch 4/10
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m320s[0m 378ms/step - accuracy: 0.5027 - loss: 1.3649 - val_accuracy: 0.4966 - val_loss: 1.3851
Epoch 5/10
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m322s[0m 379ms/step - accuracy: 0.5306 - loss: 1.2939 - val_accuracy: 0.5152 - val_loss: 1.3495
Epoch 6/10
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m321s[0m 378ms/step - accuracy: 0.5464 - loss: 1.2577 - val_accuracy: 0.5475 - val_loss: 1.2708
Epoc