In [None]:
import sys
from typing import Optional, Sequence, Union

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

# Add the path to the custom library to the system path
sys.path.append('..')

# Import the module from the custom library
from src import utils
from src import Model
from src import layers
from src import callbacks
from src import optimizers
from src import activations
from src import loss_functions

### Constants and hyperparameters

In [None]:
# --- Data ---
n_samples = 1_000            # Samples to generate (NOTE(review): appears unused; MNIST fixes the dataset size)
train_test_split_pct = 0.2   # Fraction held out for testing (NOTE(review): appears unused; MNIST ships its own test split)
train_valid_split = 0.2      # Fraction of the training samples held out for validation
data_noise = 0.15            # Noise level to add to the data (NOTE(review): appears unused)

# --- Optimization ---
learning_rate = 2e-4         # Step size passed to the optimizer
batch_size = 512             # Samples per mini-batch
epochs = 300                 # Upper bound on training epochs (early stopping may end training sooner)

# --- Reproducibility ---
seed = 1234                  # Seed shared by the RNG and the train/validation split

In [None]:
# Fix NumPy's global random state so stochastic steps that draw from it
# (e.g. `np.random.normal` inside `sample_latent`) are reproducible.
# NOTE(review): presumably the custom `src` library also draws from this global
# state for weight initialization — confirm.
np.random.seed(seed=seed)

### Data loading

In [None]:
# Fetch MNIST through Keras: a training split and a test split of images and labels
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data() # type: ignore

# Give every image an explicit trailing channel axis of size 1
X_train, X_test = (np.expand_dims(images, axis=-1) for images in (X_train, X_test)) # type: ignore

# Count the distinct labels present in the training targets
num_classes = np.unique(y_train).size # type: ignore

# Carve a validation set out of the training data (fraction `train_valid_split`, fixed `seed`)
X_train, X_valid, y_train, y_valid = train_test_split( # type: ignore
    X_train, y_train, test_size=train_valid_split, random_state=seed
)

# Report the number of classes and the shape of every split
print("Number of classes:", num_classes)
for split_name, (split_images, split_targets) in (
    ('Training set:', (X_train, y_train)),
    ('Validation set:', (X_valid, y_valid)),
    ('Testing set:', (X_test, y_test)),
):
    print(split_name, split_images.shape, split_targets.shape)

### Data processing

In [None]:
# Normalization
def normalize(X: np.ndarray, max_value: float = 255.0) -> np.ndarray:
    """
    Rescale the input so that values in [0, max_value] map onto [0, 1].

    The divisor is a fixed constant (255 by default, the largest 8-bit pixel value),
    not a statistic of the training set. The identical transform is therefore applied
    to the training, validation and test splits.

    Parameters:
    - X (np.ndarray): The input data to normalize
    - max_value (float): The value that maps to 1.0 (default: 255.0)

    Returns:
    - np.ndarray: The normalized input data, as a floating-point array

    Raises:
    - ValueError: If max_value is not strictly positive
    """

    # Guard against a silent division by zero / sign flip
    if max_value <= 0:
        raise ValueError(f"max_value must be positive, got {max_value}")

    # True division promotes integer (e.g. uint8) input to float
    return X / max_value

# Normalize the input data.
# Only integer-typed arrays (the raw MNIST images are uint8) are rescaled. Re-running
# this cell therefore leaves already-normalized float arrays untouched instead of
# dividing them by 255 a second time.
X_train, X_valid, X_test = (
    normalize(split) if np.issubdtype(split.dtype, np.integer) else split
    for split in (X_train, X_valid, X_test)
)

### Data visualization

In [None]:
def plot_samples(samples: list[np.ndarray], labels: Optional[Union[Sequence, np.ndarray]] = None) -> None:
    """
    Plot the samples side by side in a single row.

    Parameters:
    - samples (list[np.ndarray]): The images to plot. A trailing channel axis of
      size 1 (e.g. shape (H, W, 1)) is dropped so the image is drawn in grayscale.
    - labels (Sequence | np.ndarray, optional): One label per sample, shown as the
      subplot title. Python sequences (e.g. lists) and NumPy arrays are both
      accepted. When omitted, no titles are drawn.

    Raises:
    - ValueError: If `samples` is empty
    """

    num_samples = len(samples)
    if num_samples == 0:
        raise ValueError("plot_samples() requires at least one sample")

    # squeeze=False always returns a 2-D array of Axes (here shape (1, num_samples)),
    # so the single-sample case needs no special handling
    _, axes = plt.subplots(1, num_samples, figsize=(20, 5), squeeze=False)

    for i, (ax, sample) in enumerate(zip(axes[0], samples)):
        image = np.asarray(sample)
        # Drop a singleton channel axis: (H, W, 1) -> (H, W)
        if image.ndim == 3 and image.shape[-1] == 1:
            image = image[..., 0]
        ax.imshow(image, cmap='gray')
        ax.axis('off')
        # Test for None rather than for np.ndarray, so plain lists of labels are shown too
        if labels is not None:
            ax.set_title(f'Label: {labels[i]}')
        
# Preview the first few training images together with their labels
n_preview = 10
plot_samples(list(X_train[:n_preview]), y_train[:n_preview])

In [None]:
# Size of the bottleneck vector produced by the encoder
latent_dim = 32

# Encoder: two (conv -> 2x2 max-pool) stages, then a dense projection onto the latent space
encoder_modules = [
    layers.Input(shape=X_train.shape[1:]),
    layers.Conv2D(num_filters=32, kernel_size=(3, 3), activation=activations.ReLU(), padding='same'),
    layers.MaxPool2D(size=(2, 2), padding='same'),
    layers.Conv2D(num_filters=64, kernel_size=(3, 3), activation=activations.ReLU(), padding='same'),
    layers.MaxPool2D(size=(2, 2), padding='same'),
    layers.Flatten(),
    # NOTE(review): a ReLU here presumably makes every latent coordinate >= 0, while
    # `sample_latent` draws from a Gaussian that can go negative. Consider a linear
    # activation if generated samples look poor.
    layers.Dense(num_units=latent_dim, activation=activations.ReLU(), name='Latent'),
]
encoder = Model(name='Encoder', modules=encoder_modules)

# Decoder: dense layer reshaped to a 7x7x64 map, then two (conv -> 2x upsampling) stages
# and a final single-channel sigmoid conv, so outputs lie in [0, 1] like the normalized inputs
decoder_modules = [
    layers.Input(shape=(latent_dim,), name='Latent'),
    layers.Dense(num_units=7 * 7 * 64, activation=activations.ReLU()),
    layers.Reshape(shape=(7, 7, 64)),
    layers.Conv2D(num_filters=64, kernel_size=(3, 3), activation=activations.ReLU(), padding='same'),
    layers.UpSampling2D(size=(2, 2)),
    layers.Conv2D(num_filters=32, kernel_size=(3, 3), activation=activations.ReLU(), padding='same'),
    layers.UpSampling2D(size=(2, 2)),
    layers.Conv2D(num_filters=1, kernel_size=(3, 3), activation=activations.Sigmoid(), padding='same'),
]
decoder = Model(name='Decoder', modules=decoder_modules)

# The auto-encoder chains the encoder and the decoder
autoencoder = Model(name='Autoencoder', modules=[encoder, decoder])

# Adam optimizer with binary cross-entropy, which pairs with the sigmoid output layer
optimizer = optimizers.Adam(learning_rate)
loss_fn = loss_functions.BinaryCrossEntropy()

In [None]:
# Run one forward pass on a single batch so the layers create their weights.
# Not required before training, but it lets the summaries below report input sizes.
warmup_batch = X_train[:batch_size]
autoencoder(warmup_batch);

In [None]:
# Print the architecture of each sub-model first ...
for sub_model in (encoder, decoder):
    sub_model.summary()

# ... then of the full auto-encoder (left as the cell's last expression)
autoencoder.summary()

### Training the model

In [None]:
# Early stopping on the validation loss with a patience of 5 epochs
# (presumably: stop after 5 epochs without improvement — confirm against src.callbacks)
early_stopping = callbacks.EarlyStopping(monitor='val_loss', patience=5)

# Train the auto-encoder to reconstruct its input: the targets are the inputs themselves
history = autoencoder.fit(
    X_train=X_train, y_train=X_train,
    optimizer=optimizer, loss_fn=loss_fn,
    X_valid=X_valid, y_valid=X_valid,
    batch_size=batch_size, epochs=epochs,
    callbacks=[early_stopping],
)

In [None]:
# Plot the per-epoch training and validation loss recorded on the model during fit
train_loss_curve = autoencoder.history["loss"]
valid_loss_curve = autoencoder.history["val_loss"]
utils.plot_history(train_loss_curve, valid_loss_curve, "Training and Validation Loss", "Epoch", "Loss")

### Inference

In [None]:
# Compute the statistics of the latent space.
# Encode the training set in mini-batches of `batch_size`. A single forward pass over
# the whole of X_train would hold every intermediate conv activation for all of it
# at once, which can exhaust memory. The encoder contains no batch-coupled layers,
# so the concatenated codes match a one-shot pass.
latent_train = np.concatenate(
    [encoder(X_train[start:start + batch_size]) for start in range(0, len(X_train), batch_size)],
    axis=0,
)

# Per-dimension mean and standard deviation of the latent codes
latent_mean = np.mean(latent_train, axis=0)
latent_std = np.std(latent_train, axis=0)

In [None]:
def sample_latent(num_samples: int = 1) -> np.ndarray:
    """
    Draw random latent vectors from an axis-aligned Gaussian fitted to the training codes.

    Each of the `latent_dim` coordinates is sampled independently from a normal
    distribution. Its mean is `latent_mean` and its standard deviation is
    `latent_std`, both read from module-level variables. NumPy's global random
    state is used, so results follow `np.random.seed`.

    NOTE(review): if the encoder's latent activation is a ReLU, real codes are
    non-negative, but these Gaussian draws can be negative.

    Parameters:
    - num_samples (int): How many latent vectors to draw (default: 1)

    Returns:
    - np.ndarray: Array of shape (num_samples, latent_dim) holding the sampled vectors
    """

    output_shape = (num_samples, latent_dim)
    return np.random.normal(loc=latent_mean, scale=latent_std, size=output_shape)

In [None]:
# Number of new images to generate
num_generated = 10

# Draw random vectors from the latent space
samples = sample_latent(num_samples=num_generated)

# Decode the latent vectors to generate new images
decoded_images = decoder(samples)

# Reshape to the dataset's per-image shape instead of hard-coding 28x28x1
decoded_images = decoded_images.reshape(-1, *X_train.shape[1:])

# Plot the generated images
plot_samples(list(decoded_images))