In [None]:
import torch                                            # Main deep learning library.
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms            # Handles datasets and transfomations.
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter       # Track training loss

2025-03-11 15:25:19.215112: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-03-11 15:25:19.232877: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-03-11 15:25:19.237925: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-11 15:25:19.251331: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [None]:
# Define Variational Autoencoder (VAE) Model 

### See also the Conditional VAE (CVAE), which introduces an additional conditioning variable, typically class labels, allowing for controlled generation.

'''
This class defines the VAE model. It consists of an encoder and a decoder. 
The encoder compresses the input image into a latent space, and the decoder reconstructs 
the image from the latent space.
'''

class VAE(nn.Module):
    """Convolutional variational autoencoder for 3-channel 256x256 images.

    The encoder applies four stride-2 convolutions (each halving height and
    width while increasing the number of feature maps) followed by two linear
    heads that predict the mean and log-variance of the latent distribution.
    The decoder expands a latent vector with a linear layer and upsamples it
    with four stride-2 transposed convolutions; the final ``Sigmoid`` keeps
    pixel values in [0, 1].

    The linear layers are sized for a 256x16x16 feature map, which is what
    the encoder produces for a 256x256 input.

    Args:
        latent_dim: Size of the latent representation.
    """

    def __init__(self, latent_dim=128):
        super().__init__()
        self.latent_dim = latent_dim
        flat_features = 256 * 16 * 16

        # Encoder: feature extraction, 3 -> 32 -> 64 -> 128 -> 256 channels.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
        )
        # Latent heads: mean and log-variance of the latent distribution.
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

        # Decoder: latent vector -> 256x16x16 feature map -> image.
        self.fc_dec = nn.Linear(latent_dim, flat_features)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        )

    def reparameterize(self, mu, logvar):
        """Sample a latent vector with the reparameterization trick.

        The noise is drawn independently of ``mu`` and ``logvar`` so that
        gradients can flow through the sampling step.

        Args:
            mu: Mean of the latent distribution.
            logvar: Log-variance of the latent distribution.

        Returns:
            ``mu + eps * std`` with ``eps`` drawn from a standard normal.
        """
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Args:
            x: Image batch laid out as (batch, 3, height, width).

        Returns:
            Tuple ``(reconstruction, mu, logvar)``.
        """
        batch_size = x.size(0)
        features = self.encoder(x).view(batch_size, -1)
        mu = self.fc_mu(features)
        logvar = self.fc_logvar(features)
        z = self.reparameterize(mu, logvar)
        feature_map = self.fc_dec(z).view(batch_size, 256, 16, 16)
        return self.decoder(feature_map), mu, logvar

In [None]:
# Training Function
'''
This function trains the VAE. It initializes the model, optimizer, and TensorBoard writer. 
It then iterates over the dataset for a specified number of epochs, computes the loss, performs 
backpropagation, and updates the model parameters. The loss is logged to TensorBoard, 
and the trained model is saved.
'''

def train_vae(epochs=10, batch_size=32, latent_dim=128):
    """Train the VAE, log the per-epoch loss to TensorBoard and save the weights.

    ``get_dataloader`` and ``vae_loss(x_recon, images, mu, logvar)`` are not
    defined in this part of the notebook; they are expected to be provided
    by another cell.

    Args:
        epochs: Number of passes over the dataset.
        batch_size: Batch size forwarded to ``get_dataloader``.
        latent_dim: Size of the VAE latent representation.

    Returns:
        The trained ``VAE`` instance, on the device used for training.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")   # Use the GPU if available.
    dataloader = get_dataloader(batch_size)                                 # Load the dataset.
    vae = VAE(latent_dim).to(device)                                        # Build the model on the target device.
    optimizer = optim.Adam(vae.parameters(), lr=1e-4)                       # Adam optimizer.
    writer = SummaryWriter("runs/vae_experiment")                           # TensorBoard logger.

    try:
        for epoch in range(epochs):
            total_loss = 0
            for images, _ in dataloader:
                images = images.to(device)
                optimizer.zero_grad()
                x_recon, mu, logvar = vae(images)
                loss = vae_loss(x_recon, images, mu, logvar)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            # NOTE(review): dividing by the number of samples gives a per-sample
            # average only if vae_loss returns a batch-summed loss; confirm.
            avg_loss = total_loss / len(dataloader.dataset)
            writer.add_scalar('Loss/train', avg_loss, epoch)
            print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")
    finally:
        # Always close the writer so the event file is flushed even when
        # training is interrupted or raises.
        writer.close()

    torch.save(vae.state_dict(), "vae_256x256.pth")
    return vae

In [None]:
# Generate and Visualize Samples

'''
This function generates and visualizes images from the trained VAE. 
It samples random latent vectors, decodes them into images, and 
displays the images using Matplotlib.
'''

def generate_images(vae, num_images=5, latent_dim=128):
    """Decode random latent vectors and display the resulting images.

    The model is moved to the GPU when one is available and switched to
    evaluation mode as a side effect.

    Args:
        vae: Trained VAE exposing ``fc_dec`` and ``decoder``.
        num_images: Number of images to generate and display.
        latent_dim: Size of the sampled latent vectors; must match the input
            size of ``vae.fc_dec``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vae.to(device)
    vae.eval()
    with torch.no_grad():
        z = torch.randn(num_images, latent_dim).to(device)                              # Sample random latent vectors.
        generated_images = vae.decoder(vae.fc_dec(z).view(num_images, 256, 16, 16))     # Decode them into images.

    # (N, C, H, W) -> (N, H, W, C), the channel-last layout imshow expects.
    generated_images = generated_images.cpu().numpy().transpose(0, 2, 3, 1)
    # squeeze=False keeps `axes` two-dimensional, so the loop below also works
    # for num_images == 1 (where subplots would otherwise return a bare Axes).
    fig, axes = plt.subplots(1, num_images, figsize=(15, 5), squeeze=False)
    for image, ax in zip(generated_images, axes[0]):
        ax.imshow(image)
        ax.axis("off")
    plt.show()

# Training Example
# vae = train_vae(epochs=10)
# generate_images(vae)


In [None]:
### VAE jax/flax/neural network implementation

import jax                                  # JAX and JAX NumPy: used for GPU/TPU-accelerated numerical computations.
import jax.numpy as jnp
# NOTE: this rebinds `nn`, which the first cell bound to `torch.nn`; PyTorch
# code that looks up `nn` after this cell has run will get Flax instead.
import flax.linen as nn                     # Provides neural network layers (nn.Module)
from flax import nnx                        # TODO: find out what the nnx module is for.
import optax                                # Implements optimization algorithms (Adam optimizer)
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.data import Dataset
from jax.scipy.stats import norm
# NOTE: do not `import jax.devices` -- `jax.devices` is a function, not a
# submodule, so that import raises ModuleNotFoundError. Call `jax.devices()`
# through the plain `import jax` above instead.
import glob
import os

# Ensure JAX uses GPU/TPU
jax.config.update("jax_platform_name", "gpu")           # Use "tpu" for TPU support

### Define Encoder
class Encoder(nn.Module):
    """Convolutional encoder producing the latent mean and log-variance.

    Two stride-2 convolutions (32, then 64 features) extract image features
    while reducing the spatial dimensions. The flattened features pass through
    a 128-unit dense layer and then two dense heads of size ``latent_dim``:
    ``mean`` (centre of the latent distribution) and ``logvar`` (its spread).
    """
    latent_dim: int

    @nn.compact
    def __call__(self, x):
        # Convolutional feature extractor.
        for features in (32, 64):
            x = nn.relu(nn.Conv(features, (4, 4), strides=(2, 2), padding='SAME')(x))

        # Flatten everything except the batch axis, then project to 128 units.
        flat = x.reshape((x.shape[0], -1))
        hidden = nn.relu(nn.Dense(128)(flat))

        mean = nn.Dense(self.latent_dim)(hidden)
        logvar = nn.Dense(self.latent_dim)(hidden)
        return mean, logvar

### Reparameterization trick
def reparameterize(rng, mean, logvar):
    """Draw a latent sample ``mean + eps * std`` with ``eps ~ N(0, 1)``.

    The random noise is drawn separately from ``mean`` and ``logvar``, which
    lets gradients pass through the sampling step.

    Args:
        rng: JAX PRNG key used to draw the noise.
        mean: Mean of the latent distribution.
        logvar: Log-variance of the latent distribution.

    Returns:
        The sampled latent array, with the broadcast shape of the inputs.
    """
    sigma = jnp.exp(0.5 * logvar)
    noise = jax.random.normal(rng, sigma.shape)
    return mean + noise * sigma

### Define Decoder
class Decoder(nn.Module):
    """Decode a latent vector ``z`` into an image with values in [0, 1].

    Two dense layers expand ``z`` into a 32x32 feature map with 64 channels.
    Three stride-2 transposed convolutions with SAME padding then double the
    resolution each time (32 -> 64 -> 128 -> 256), and a sigmoid bounds the
    pixel values.
    """

    @nn.compact
    def __call__(self, z):
        x = nn.Dense(128)(z)
        x = nn.relu(x)
        # The seed feature map must be 32x32 so that three 2x upsampling steps
        # end at 256x256, the image size train_vae initialises the model with.
        # A 64x64 seed would end at 512x512, which cannot be subtracted from
        # the input batch in the reconstruction loss.
        x = nn.Dense(32 * 32 * 64)(x)
        x = nn.relu(x)
        x = x.reshape((-1, 32, 32, 64))
        x = nn.ConvTranspose(64, (4, 4), strides=(2, 2), padding='SAME')(x)
        x = nn.relu(x)
        x = nn.ConvTranspose(32, (4, 4), strides=(2, 2), padding='SAME')(x)
        x = nn.relu(x)
        x = nn.ConvTranspose(3, (4, 4), strides=(2, 2), padding='SAME')(x)
        return nn.sigmoid(x)  # Constrain output pixels to [0, 1]
    
### Define VAE
class VAE(nn.Module):
    """Variational autoencoder combining ``Encoder`` and ``Decoder``.

    Calling the module encodes the input images, samples a latent vector
    with ``reparameterize`` and decodes it back into an image.
    """
    latent_dim: int

    def setup(self):
        self.encoder = Encoder(latent_dim=self.latent_dim)
        self.decoder = Decoder()

    def __call__(self, x, rng):
        """Return ``(reconstruction, mean, logvar)`` for the image batch ``x``.

        Args:
            x: Batch of input images.
            rng: JAX PRNG key used for the latent sampling step.
        """
        mu, log_var = self.encoder(x)
        latent = reparameterize(rng, mu, log_var)
        return self.decoder(latent), mu, log_var
    
# Loss Function
def vae_loss(model, params, batch, rng):
    """Compute the VAE training loss for one batch.

    The loss is the sum of two terms:
      * reconstruction loss: mean squared error between ``batch`` and the
        model's reconstruction;
      * KL divergence: encourages the latent distribution to follow a
        standard normal, averaged over all latent entries.

    Args:
        model: Object exposing ``apply(params, batch, rng)`` that returns
            ``(reconstruction, mean, logvar)``.
        params: Model variables passed to ``model.apply``.
        batch: Input image batch.
        rng: PRNG key forwarded to ``model.apply``.

    Returns:
        Scalar loss ``reconstruction_loss + kl_loss``.
    """
    reconstruction, mu, log_var = model.apply(params, batch, rng)

    residual = batch - reconstruction
    recon_term = jnp.mean(residual ** 2)

    kl_integrand = 1 + log_var - mu**2 - jnp.exp(log_var)
    kl_term = -0.5 * jnp.mean(kl_integrand)

    return recon_term + kl_term

# Optimizer: Adam with a learning rate of 0.001.
optimizer = optax.adam(learning_rate=1e-3)

# Training Step
@jax.jit                                                # jax.jit compiles the step for faster execution.
def train_step(state, batch, rng):
    """Run one optimisation step.

    Args:
        state: Train state exposing ``params``, ``apply_fn`` and
            ``apply_gradients`` (e.g. ``flax.training.train_state.TrainState``).
        batch: Input image batch.
        rng: PRNG key for the latent sampling inside the model.

    Returns:
        Tuple ``(new_state, loss)``.
    """
    from types import SimpleNamespace

    # vae_loss has the signature (model, params, batch, rng) and only uses
    # `model.apply`. Only the train state is available here, so expose its
    # `apply_fn` under that attribute name.
    model = SimpleNamespace(apply=state.apply_fn)

    def loss_fn(params):
        # Close over everything except `params`, so that the gradient is
        # taken with respect to the parameters only.
        return vae_loss(model, params, batch, rng)

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss

# Load and Preprocess Personal Images (NumPy Format)
def load_personal_dataset(image_folder, batch_size=32):
    """Load every ``*.npy`` image in ``image_folder`` into a batched dataset.

    Args:
        image_folder: Directory containing one ``.npy`` array per image.
        batch_size: Number of images per batch.

    Returns:
        A batched, prefetching ``tf.data.Dataset`` of float32 images whose
        values have been divided by 255.

    Raises:
        FileNotFoundError: If ``image_folder`` contains no ``.npy`` files.
    """
    # `glob` is imported as a module, so the function is `glob.glob`.
    # Sorting makes the sample order reproducible across runs and platforms.
    image_paths = sorted(glob.glob(os.path.join(image_folder, "*.npy")))
    if not image_paths:
        raise FileNotFoundError(f"No .npy images found in {image_folder!r}")

    def process_image(image_path):
        image = np.load(image_path)
        # Scales to [0, 1] assuming 8-bit pixel values (0-255) -- TODO confirm.
        return image.astype(np.float32) / 255.0

    # np.stack requires all images to share one shape and fails loudly otherwise.
    images = np.stack([process_image(path) for path in image_paths])
    ds = Dataset.from_tensor_slices(images)
    ds = ds.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)   # Let TensorFlow tune the prefetch buffer.
    return ds

# Training Loop (Optimized for GPU/TPU)
def train_vae(model, dataset, epochs=10):
    """Initialise the model parameters and train them with ``train_step``.

    Args:
        model: Flax module whose ``__call__`` takes ``(x, rng)``.
        dataset: Iterable of image batches (e.g. the ``tf.data.Dataset``
            returned by ``load_personal_dataset``).
        epochs: Number of passes over ``dataset``.

    Returns:
        The final train state (parameters and optimizer state).
    """
    from flax.training import train_state   # Not imported at file level.

    rng = jax.random.PRNGKey(0)
    rng, init_rng, sample_rng = jax.random.split(rng, 3)
    # model.init forwards its positional arguments to __call__(x, rng), so the
    # sampling key has to be supplied together with the dummy image batch.
    params = model.init(init_rng, jnp.ones([1, 256, 256, 3]), sample_rng)
    state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)

    for epoch in range(epochs):
        loss = None                         # Stays None if the dataset yields no batches.
        for batch in dataset:
            rng, sub_rng = jax.random.split(rng)
            # Convert to NumPy first: a tf.data.Dataset yields tf.Tensor
            # objects, which the jit-compiled train_step does not accept.
            state, loss = train_step(state, np.asarray(batch), sub_rng)
        print(f'Epoch {epoch}, Loss: {loss}')
    return state

# Visualizing Reconstructions
def visualize_reconstructions(model, state, dataset):
    """Plot the first batch of ``dataset`` next to its reconstructions.

    Original images are shown in the top row and the corresponding
    reconstructions in the bottom row.

    Args:
        model: Flax module whose ``apply(params, batch, rng)`` returns
            ``(reconstruction, mean, logvar)``.
        state: Train state holding the trained ``params``.
        dataset: Dataset exposing ``take(n)`` (e.g. a ``tf.data.Dataset``).
    """
    rng = jax.random.PRNGKey(1)
    for batch in dataset.take(1):
        batch = np.array(batch)
        recon_x, _, _ = model.apply(state.params, batch, rng)
        recon_x = np.asarray(recon_x)
        # squeeze=False keeps `axes` two-dimensional, so axes[row, col]
        # indexing also works when the batch holds a single image.
        fig, axes = plt.subplots(2, len(batch), figsize=(15, 5), squeeze=False)
        for col, (original, reconstruction) in enumerate(zip(batch, recon_x)):
            axes[0, col].imshow(original)
            axes[0, col].axis('off')
            axes[1, col].imshow(reconstruction)
            axes[1, col].axis('off')
        plt.show()

# Example Usage
# Runs at cell execution time: builds the model, loads the images, trains and plots.
vae = VAE(latent_dim=128)
dataset = load_personal_dataset("path/to/your/numpy/images")   # Placeholder path: point this at a folder of .npy images.
trained_vae = train_vae(vae, dataset)                          # NOTE: train_vae returns the train state, not a model object.
visualize_reconstructions(vae, trained_vae, dataset)


#### TODO: hyperparameter tuning is still missing.


In [None]:
class VAE(nn.Module):
    """Single-module Flax VAE for 256x256x3 images (channels-last layout).

    The encoder applies four stride-2 convolutions with SAME padding
    (256 -> 128 -> 64 -> 32 -> 16 pixels per side). The decoder mirrors it
    with four stride-2 transposed convolutions starting from a 16x16x256
    feature map, and a sigmoid bounds the output pixels to [0, 1].
    """
    latent_dim: int

    def setup(self):
        self.encoder = nn.Sequential([
            nn.Conv(32, (4, 4), strides=(2, 2), padding='SAME'), nn.relu,
            nn.Conv(64, (4, 4), strides=(2, 2), padding='SAME'), nn.relu,
            nn.Conv(128, (4, 4), strides=(2, 2), padding='SAME'), nn.relu,
            nn.Conv(256, (4, 4), strides=(2, 2), padding='SAME'), nn.relu,
        ])
        self.fc_mu = nn.Dense(self.latent_dim)
        self.fc_logvar = nn.Dense(self.latent_dim)
        self.fc_dec = nn.Dense(16 * 16 * 256)
        # Strides are given as tuples, matching the Encoder/Decoder modules
        # earlier in this file.
        self.decoder = nn.Sequential([
            nn.ConvTranspose(128, (4, 4), strides=(2, 2), padding='SAME'), nn.relu,
            nn.ConvTranspose(64, (4, 4), strides=(2, 2), padding='SAME'), nn.relu,
            nn.ConvTranspose(32, (4, 4), strides=(2, 2), padding='SAME'), nn.relu,
            nn.ConvTranspose(3, (4, 4), strides=(2, 2), padding='SAME'), nn.sigmoid,
        ])

    def __call__(self, x, rng=None):
        """Return ``(reconstruction, mu, logvar)`` for the image batch ``x``.

        Args:
            x: Image batch laid out as (batch, height, width, channels).
            rng: Optional PRNG key for the latent noise. When omitted, a key
                is drawn from the ``"reparam"`` RNG stream if the caller
                passed one (``apply(..., rngs={"reparam": key})``). Otherwise
                a fixed key is used, which gives the same noise on every call.
        """
        x = self.encoder(x)
        x = x.reshape((x.shape[0], -1))
        mu, logvar = self.fc_mu(x), self.fc_logvar(x)
        std = jnp.exp(0.5 * logvar)
        if rng is None:
            if self.has_rng("reparam"):
                rng = self.make_rng("reparam")
            else:
                # Backward-compatible fallback for callers that supply no key.
                rng = jax.random.PRNGKey(0)
        eps = jax.random.normal(rng, std.shape)
        z = mu + eps * std
        # Flax convolutions are channels-last, so the feature map must be
        # (N, 16, 16, 256) rather than the PyTorch-style (N, 256, 16, 16).
        x = self.fc_dec(z).reshape((-1, 16, 16, 256))
        return self.decoder(x), mu, logvar



### TensorFlow (Keras) implementation

import tensorflow as tf
from tensorflow.keras import layers

class VAE(tf.keras.Model):
    """Keras VAE for 256x256x3 images.

    Assumes Keras' default ``channels_last`` image data format. The encoder
    applies four stride-2 convolutions and the decoder four stride-2
    transposed convolutions; with ``padding='same'`` each layer exactly
    halves or doubles the spatial size (256 <-> 128 <-> 64 <-> 32 <-> 16).

    Args:
        latent_dim: Size of the latent representation.
    """

    def __init__(self, latent_dim):
        super().__init__()
        # padding='same' is required here: with the default 'valid' padding
        # the encoder would produce 14x14 maps and the decoder a 286x286
        # output instead of mirroring the 256x256 input.
        self.encoder = tf.keras.Sequential([
            layers.Conv2D(32, (4, 4), strides=2, padding='same', activation='relu'),
            layers.Conv2D(64, (4, 4), strides=2, padding='same', activation='relu'),
            layers.Conv2D(128, (4, 4), strides=2, padding='same', activation='relu'),
            layers.Conv2D(256, (4, 4), strides=2, padding='same', activation='relu'),
            layers.Flatten()
        ])
        self.fc_mu = layers.Dense(latent_dim)
        self.fc_logvar = layers.Dense(latent_dim)
        self.fc_dec = layers.Dense(16 * 16 * 256)
        self.decoder = tf.keras.Sequential([
            layers.Conv2DTranspose(128, (4, 4), strides=2, padding='same', activation='relu'),
            layers.Conv2DTranspose(64, (4, 4), strides=2, padding='same', activation='relu'),
            layers.Conv2DTranspose(32, (4, 4), strides=2, padding='same', activation='relu'),
            layers.Conv2DTranspose(3, (4, 4), strides=2, padding='same', activation='sigmoid'),
        ])

    def call(self, x):
        """Return ``(reconstruction, mu, logvar)`` for the image batch ``x``."""
        x = self.encoder(x)
        mu, logvar = self.fc_mu(x), self.fc_logvar(x)
        std = tf.exp(0.5 * logvar)
        # tf.shape gives the runtime shape, so sampling also works when the
        # static batch dimension is unknown (None) during graph tracing.
        eps = tf.random.normal(tf.shape(std), dtype=std.dtype)
        z = mu + eps * std
        # Channels-last feature map: (N, 16, 16, 256), not (N, 256, 16, 16).
        x = tf.reshape(self.fc_dec(z), (-1, 16, 16, 256))
        return self.decoder(x), mu, logvar