In [None]:
import tensorflow as tf
from tensorflow.keras import layers, losses, Model
from tensorflow.keras.datasets import mnist
import numpy as np

# Load MNIST, scale pixels to [0, 1] and corrupt every image with additive
# Gaussian noise of standard deviation `scale`.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
scale = 0.4
noise_seed = 0  # fixed seed so a fresh kernel reproduces the same noisy images
rng = np.random.default_rng(noise_seed)

# The noise is drawn directly as float32: a float64 draw (the NumPy default)
# would silently upcast the whole image array to float64.
# The noisy values are not clipped, so pixels may fall outside [0, 1].
x_train = x_train.astype('float32') / 255. + np.float32(scale) * rng.standard_normal(size=x_train.shape, dtype=np.float32)
x_test = x_test.astype('float32') / 255. + np.float32(scale) * rng.standard_normal(size=x_test.shape, dtype=np.float32)

# Hold out the samples from index 50000 onward as the validation set.
x_val = x_train[50000:]
y_val = y_train[50000:]
x_train = x_train[:50000]
y_train = y_train[:50000]

# Append a trailing channel axis to every image array.
x_train = x_train[..., tf.newaxis]
x_val = x_val[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]

print(x_train.shape,x_val.shape,x_test.shape,y_train.shape,y_val.shape,y_test.shape)

In [None]:
# Show the training labels as this cell's output.
y_train

In [None]:
import matplotlib.pyplot as plt

# Plot original images (top row) against reconstructed images (bottom row).
def plot_mnist_autoencoder(x, xpred, cmap='gray', vmin=0, vmax=1, figsize=(8, 1)):
    """Show each image of `x` above the image at the same index in `xpred`.

    Args:
        x: batch of original images; `x.shape[0]` is the number of columns.
        xpred: batch of reconstructed images, indexed like `x`.
        cmap, vmin, vmax: forwarded unchanged to every `imshow` call.
        figsize: figure size passed to `plt.subplots`.

    Returns:
        None. The figure is displayed with `plt.show()`.
    """
    num_images = x.shape[0]
    # squeeze=False keeps `ax` two-dimensional even when there is a single
    # image; otherwise the `ax[row, col]` indexing below fails.
    fig, ax = plt.subplots(2, num_images, figsize=figsize, squeeze=False)
    for col in range(num_images):
        for row, batch in enumerate((x, xpred)):
            ax[row, col].imshow(batch[col], cmap=cmap, vmin=vmin, vmax=vmax)
            ax[row, col].set_xticks([])
            ax[row, col].set_yticks([])
    plt.show()
    return

def plot_mnist_autoencoder2(x, xpred, y, y_pred, cmap='gray', vmin=0, vmax=1, figsize=(15, 5)):
    """Show originals and reconstructions in two rows, each image captioned with a label.

    Args:
        x: batch of original images; `x.shape[0]` is the number of columns.
        xpred: batch of reconstructed images, indexed like `x`.
        y: true labels; only `y[i]` for `i < x.shape[0]` is read.
        y_pred: predicted labels, indexed like `y`.
        cmap, vmin, vmax: forwarded unchanged to every `imshow` call.
        figsize: figure size passed to `plt.subplots`.

    Returns:
        None. The figure is displayed with `plt.show()`.
    """
    num_images = x.shape[0]

    # squeeze=False keeps `ax` two-dimensional even when there is a single
    # image; otherwise the `ax[row, col]` indexing below fails.
    fig, ax = plt.subplots(2, num_images, figsize=figsize, squeeze=False)

    # (image batch, label sequence, caption prefix) for the top and bottom row.
    rows = ((x, y, "y"), (xpred, y_pred, "y_pred"))
    for col in range(num_images):
        for row, (batch, labels, prefix) in enumerate(rows):
            axis = ax[row, col]
            axis.imshow(batch[col], cmap=cmap, vmin=vmin, vmax=vmax)
            axis.set_xticks([])
            axis.set_yticks([])
            # The caption is an x-axis label moved to the top edge, so it
            # appears above the image.
            axis.set_xlabel(f"{prefix}: {labels[col]}", fontsize=10)
            axis.xaxis.set_label_position('top')

    plt.tight_layout()
    plt.show()
    return
# Smoke-test the plotting helper: the first 15 noisy training images fill both
# rows and the true labels fill both caption rows. The label arrays are passed
# whole; the helper only reads one label per plotted image.
plot_mnist_autoencoder2(
    x=x_train[:15],
    xpred=x_train[:15],
    y=y_train,
    y_pred=y_train,
)

from matplotlib.offsetbox import OffsetImage, AnnotationBbox

# Draw every image as a thumbnail at its 2-D latent coordinates.
def plot_mnist_2d(Z,y,images,img_w=28,img_h=28,zoom=0.5,cmap='jet'):
    """Place each image at the position given by the first two columns of `Z`.

    Args:
        Z: array indexed as `Z[i, 0]`, `Z[i, 1]`; one row per image.
        y: unused; kept so existing calls keep working.
        images: sequence whose items support `.reshape`; item `i` is drawn at row `i` of `Z`.
        img_w: image width in pixels (number of columns).
        img_h: image height in pixels (number of rows).
        zoom: scale factor passed to `OffsetImage`.
        cmap: colormap passed to `OffsetImage`.
    """
    fig, ax = plt.subplots(figsize=(5,5))
    ax.axis('off')
    for i in range(Z.shape[0]):
        # An image array is (rows, columns) = (height, width); reshaping to
        # (width, height) would scramble non-square images.
        image = images[i].reshape((img_h, img_w))
        im = OffsetImage(image, zoom=zoom, cmap=cmap)
        ab = AnnotationBbox(im, (Z[i,0], Z[i,1]), xycoords='data', frameon=False)
        ax.add_artist(ab)
        # Annotation boxes do not extend the data limits on their own, so each
        # anchor point is registered explicitly.
        ax.update_datalim([(Z[i,0], Z[i,1])])
    # Rescaling once after all points are registered is enough.
    ax.autoscale()
    plt.show()

In [None]:
from tensorflow.keras.constraints import Constraint
class OrthogonalConstraint(Constraint):
    """Weight constraint that replaces a kernel by its orthogonal polar factor.

    With the thin SVD `w = U diag(s) V^T`, the constraint returns `U V^T`,
    which has the same shape as `w`. For a kernel with at least as many rows
    as columns this matrix has orthonormal columns and is the closest such
    matrix to `w` in Frobenius norm. Returning `U` alone would additionally
    rotate the columns by `V`, re-mixing the latent coordinates after every
    optimizer step.
    """

    def __call__(self, w):
        # tf.linalg.svd returns (singular values, U, V); the singular values
        # are discarded because the projection sets them all to one.
        _, u, v = tf.linalg.svd(w, full_matrices=False)
        return tf.matmul(u, v, transpose_b=True)
class DenseTransposeLayer(layers.Layer):
    """Tied-weight encoder/decoder pair sharing a single kernel `w`.

    The layer computes `encoded = activation(inputs @ w)` and then
    `decoded = encoded @ w^T`. No bias terms are used.

    Args:
        units: size of the encoded representation (number of columns of `w`).
        factor_o: factor of the orthogonality regularizer attached to `w`.
        activation: activation applied to the encoded representation, in any
            form accepted by `tf.keras.activations.get`.
        **kwargs: forwarded to `layers.Layer`.
    """

    def __init__(self, units, factor_o=0.1,activation=None, **kwargs):
        super(DenseTransposeLayer, self).__init__(**kwargs)
        self.units = units
        self.factor_o = factor_o
        self.activation = tf.keras.activations.get(activation)

    def build(self, input_shape):
        """Create the shared kernel.

        Args:
            input_shape: either the flattened input size as a plain int, or a
                shape whose last entry is the input size.

        Raises:
            ValueError: if a shape is given and its last dimension is unknown.
        """
        if isinstance(input_shape, int):
            input_dim = input_shape
        else:
            last_dim = input_shape[-1]
            if last_dim is None:
                raise ValueError(
                    "DenseTransposeLayer needs a defined last input dimension, "
                    f"got input_shape={input_shape}."
                )
            input_dim = int(last_dim)

        self.w = self.add_weight(
            shape=(input_dim, self.units),
            initializer="random_normal",
            trainable=True,
            # mode="columns" targets the `units` column vectors of `w`. The
            # default "rows" mode would penalise the `input_dim` rows, which
            # cannot be mutually orthogonal when input_dim > units.
            regularizer=tf.keras.regularizers.OrthogonalRegularizer(
                factor=self.factor_o, mode="columns"
            ),
            constraint=OrthogonalConstraint(),
        )

        super(DenseTransposeLayer, self).build(input_shape)

    def call(self, inputs):
        """Encode and decode `inputs` with the shared kernel.

        Returns:
            A `(decoded, encoded)` tuple, reconstruction first.
        """
        encoded = tf.matmul(inputs, self.w)
        if self.activation is not None:
            encoded = self.activation(encoded)
        # Decoding reuses the transposed encoder kernel.
        decoded = tf.matmul(encoded, tf.transpose(self.w))
        return decoded, encoded

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super(DenseTransposeLayer, self).get_config()
        config.update({
            "units": self.units,
            "factor_o": self.factor_o,
            "activation": tf.keras.activations.serialize(self.activation),
        })
        return config

In [None]:
class PCAutoencoder(Model):
    """Tied-weight linear autoencoder with a classifier on the latent code.

    The input is flattened and passed through a `DenseTransposeLayer`, which
    yields both a reconstruction and the encoded representation. The encoded
    representation is fed to a small dense classifier.

    Args:
        encoding_dim: size of the encoded representation.
        num_classes: number of units of the softmax output layer.
        factor_o: orthogonality regularization factor passed to the
            `DenseTransposeLayer`.
    """

    def __init__(self, encoding_dim, num_classes, factor_o=0.1):
        super(PCAutoencoder, self).__init__()
        self.encoding_dim = encoding_dim
        self.num_classes = num_classes
        self.factor_o = factor_o
        # Encoder/decoder: flatten, then the shared-kernel projection.
        self.encoder_input_layer = layers.Flatten()
        self.encoder_decoder_transpose = DenseTransposeLayer(self.encoding_dim, factor_o=self.factor_o, activation='linear')

        # Classifier head applied to the encoded representation.
        self.classifier_layers = [
            layers.Dense(128, activation='relu'),
            layers.Dropout(0.5),
            layers.Dense(256, activation='relu'),
            layers.Dense(num_classes, activation='softmax')
        ]

        # Created in build(), once the input shape is known.
        self.decoder_output_layer = None

    def build(self, input_shape):
        """Create the shared kernel and the output reshape for `input_shape`.

        Args:
            input_shape: full input shape including the batch axis, e.g.
                `(None, height, width, channels)`; all non-batch entries must
                be known.
        """
        # Flatten collapses every non-batch axis, so the kernel's input size is
        # the product of all of them, channels included.
        flat_dim = 1
        for dim in input_shape[1:]:
            flat_dim *= int(dim)
        self.encoder_decoder_transpose.build(flat_dim)
        self.decoder_output_layer = layers.Reshape(input_shape[1:])
        super().build(input_shape)

    def call(self, inputs):
        """Return `(decoded, classification)` for a batch of inputs.

        `decoded` has the per-sample shape given to `build`; `classification`
        is the output of the final softmax layer.
        """
        flat = self.encoder_input_layer(inputs)
        reconstruction, latent = self.encoder_decoder_transpose(flat)
        # Run the encoded representation through the classifier head.
        classification = latent
        for layer in self.classifier_layers:
            classification = layer(classification)
        decoded = self.decoder_output_layer(reconstruction)
        return decoded, classification


In [None]:
# Instantiate and build the autoencoder
encoding_dim = 64
input_shape = (None, 28, 28, 1)
factor_o = 0.1
num_classes = 10
pcautoencoder = PCAutoencoder(encoding_dim, num_classes, factor_o=factor_o)

pcautoencoder.build(input_shape)
# Compile with one loss and one metric per output. The model returns a
# (decoded, classification) tuple, so the entries are given positionally in
# that order; dict keys would have to match output names, which a subclassed
# model does not define as 'decoded' / 'classification'.
pcautoencoder.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=['mse', 'sparse_categorical_crossentropy'],
    metrics=[['mse'], ['accuracy']],
)
pcautoencoder.summary()

In [None]:
# Reset Keras' global state, then create the optimizer and the two loss
# functions used by the custom training loop.
tf.keras.backend.clear_session()
optimizer = tf.keras.optimizers.Adam()
loss_object_reconstruction = tf.keras.losses.MeanSquaredError()
loss_object_classification = tf.keras.losses.SparseCategoricalCrossentropy()

# Running means of the classification-loss term (training / validation).
train_classification_loss, test_classification_loss = (
    tf.keras.metrics.Mean(name=metric_name)
    for metric_name in ('train_classification_loss', 'val_classification_loss')
)
# Running means of the reconstruction-loss term (training / validation).
train_loss, test_loss = (
    tf.keras.metrics.Mean(name=metric_name)
    for metric_name in ('train_loss', 'val_loss')
)
@tf.function
def train_step(images, labels):
    """Run one optimization step on a batch and update the training metrics.

    The objective is `a * reconstruction_loss + b * classification_loss` with
    `a = 0.9` and `b = 1 - a`. The reconstruction target is `images` itself.
    Each metric accumulates its own weighted term.

    NOTE(review): `pcautoencoder.losses` (e.g. weight regularizers) is not
    added to the objective here; confirm whether that is intended.
    """
    a = 0.9
    b = 1-a
    with tf.GradientTape() as tape:
        reconstructed, predictions = pcautoencoder(images, training=True)
        weighted_reconstruction = a*loss_object_reconstruction(images, reconstructed)
        weighted_classification = b*loss_object_classification(labels, predictions)
        total_loss = weighted_reconstruction + weighted_classification
    variables = pcautoencoder.trainable_variables
    gradients = tape.gradient(total_loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    train_loss(weighted_reconstruction)
    train_classification_loss(weighted_classification)

@tf.function
def test_step(images, labels):
    """Evaluate one batch without updating weights and record both weighted loss terms.

    Uses the weights `a = 0.9` and `b = 1 - a` for the reconstruction and
    classification terms respectively.
    """
    a = 0.9
    b = 1-a
    reconstructed, predictions = pcautoencoder(images, training=False)
    weighted_reconstruction = a*loss_object_reconstruction(images, reconstructed)
    weighted_classification = b*loss_object_classification(labels, predictions)
    test_loss(weighted_reconstruction)
    test_classification_loss(weighted_classification)

epochs = 20
batch_size = 64
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(buffer_size=1024).batch(batch_size)
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val)).shuffle(buffer_size=1024).batch(batch_size)

for epoch in range(epochs):
    # Reset the metrics at the start of the epoch so each report covers one
    # epoch only. `reset_state()` replaces the deprecated `reset_states()`.
    for metric in (train_loss, test_loss, train_classification_loss, test_classification_loss):
        metric.reset_state()

    for images, labels in train_dataset:
        train_step(images, labels)

    for val_images, val_labels in val_dataset:
        test_step(val_images, val_labels)

    print(f'Epoch {epoch + 1}, '
          f'Loss: {train_loss.result()}, '
          f'Classification Loss: {train_classification_loss.result()}, '
          f'Test Loss: {test_loss.result()}, '
          f'Test Classification Loss: {test_classification_loss.result()}')

    if (epoch+1) % 5 == 0:
        # Every 5 epochs, visualise the last validation batch seen above:
        # reconstructions plus the arg-max of the classifier output.
        val_reconstructed, val_label_re = pcautoencoder(val_images, training=False)
        print(val_reconstructed.shape)
        plot_mnist_autoencoder2(val_images, val_reconstructed,val_labels,np.argmax(val_label_re.numpy(),axis=1))

In [None]:
# Compute the inner products among the learned basis vectors (columns of the
# shared kernel): W^T W. The kernel is looked up by attribute name rather than
# by position in `pcautoencoder.layers`, and fetched only once.
basis = pcautoencoder.encoder_decoder_transpose.get_weights()[0]
o_ = tf.linalg.matmul(basis, basis, transpose_a=True)
plt.pcolormesh(o_.numpy())
plt.colorbar()
# Flip the y axis so row 0 is at the top; the limit follows the matrix size.
plt.ylim([o_.shape[0], 0])
plt.title('Inner products of the basis vectors')
plt.xlabel('basis vector')
plt.ylabel('basis vector')
plt.show()

In [None]:
# Heat map of the shared encoder/decoder kernel (rows: flattened input
# positions, columns: latent dimensions). The kernel is looked up by attribute
# name rather than by position in `pcautoencoder.layers`.
plt.pcolormesh(pcautoencoder.encoder_decoder_transpose.get_weights()[0])
plt.colorbar()
plt.title('Shared kernel weights')
plt.xlabel('latent dimension')
plt.ylabel('flattened input index')
plt.show()