In [None]:
'''
Uses the TensorFlow machine learning library to train the Poisson 
autoencoder inverting network (PAIN) to reconstruct compressed 
MNIST images with Poisson noise. Cells that save the training loss
and the trained model are included but commented out.

Authors: Fabian Santiago
Last Update: August 18, 2024
'''
from __future__ import division, print_function, absolute_import
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os
from skimage.util.shape import view_as_windows

In [None]:
# Set the seeds for reproducibility:
# - TensorFlow/Keras: weight initialisation and training randomness
# - NumPy's global RNG: drives np.random.poisson, which generates the
#   Poisson noise added to the compressed MNIST images
tf.random.set_seed(101)
np.random.seed(101)

In [None]:
# Define sliding window compression
def median_downsampling(in_image, cmp_dim):
    """Compress a 2-D image by taking the median of non-overlapping square windows.

    The window side is ``w = in_image.shape[0] // cmp_dim`` (28 // cmp_dim for
    MNIST). The result has shape ``(H // w, W // w)``. This equals
    ``(cmp_dim, cmp_dim)`` only when cmp_dim divides the image size; otherwise
    trailing rows/columns that do not fill a whole window are dropped.

    Args:
        in_image: 2-D array (height x width); a square image is assumed.
        cmp_dim: target side length of the compressed image.

    Returns:
        The median-pooled image (float). If the window would be 1 pixel or
        smaller, no compression is possible and ``in_image`` is returned
        unchanged.

    Raises:
        ValueError: if cmp_dim is smaller than 1.
    """
    if cmp_dim < 1:
        raise ValueError(f"cmp_dim must be a positive integer, got {cmp_dim}")

    # Derive the window from the actual image size instead of assuming 28 x 28
    window_size = in_image.shape[0] // cmp_dim
    if window_size <= 1:
        # Nothing to compress: keep the image as is
        return in_image

    # Non-overlapping windows, shape (H // w, W // w, w, w)
    windows = view_as_windows(in_image, (window_size, window_size), step=window_size)
    # Median over the pixels of each window
    return np.median(windows, axis=(2, 3))

# Define compression of entries in an array
def down_sample_list(in_array, cmp_dim):
    """Median-downsample every image in ``in_array`` to ``cmp_dim x cmp_dim``.

    Args:
        in_array: sized, indexable collection of 2-D images (e.g. an
            (N, H, W) array).
        cmp_dim: side length of each compressed image.

    Returns:
        Array of shape (N, cmp_dim, cmp_dim) holding the compressed images.
        np.empty defaults to float64, so values are stored as floats.
    """
    n_images = len(in_array)
    compressed = np.empty((n_images, cmp_dim, cmp_dim))

    for k in range(n_images):
        # Raises if a compressed image does not fit a cmp_dim x cmp_dim slot
        compressed[k] = median_downsampling(in_array[k], cmp_dim)

    return compressed

In [None]:
# Load the MNIST Dataset
########################
# Dimension of the compressed/noisy images (width=height)  
# cmp_dim = 4:(for 4x4), 7:(7x7), 14:(14x14), or 28:(28x28)
cmp_dim = 7

# Dimension of output, original are 28 x 28
out_dim  = 28 

# Load MNIST: 28 x 28 grayscale digits with pixel values 0-255
(clean_train, _), (clean_test, _) = tf.keras.datasets.mnist.load_data()

# The network's output layer has out_dim**2 units, so the targets must match
assert clean_train.shape[1:] == (out_dim, out_dim), \
    f"out_dim={out_dim} does not match the MNIST image size {clean_train.shape[1:]}"

# Step 1: Compress images using the median over non-overlapping windows
cmp_train = down_sample_list(clean_train, cmp_dim)
cmp_test  = down_sample_list(clean_test, cmp_dim)

# Step 2: Add Poisson noise (each compressed pixel value is the Poisson rate)
# and clip back to the 8-bit range [0, 255]
noisy_train_ = np.random.poisson(lam=cmp_train)
noisy_test_  = np.random.poisson(lam=cmp_test)
noisy_train  = np.clip(noisy_train_, 0, 255)
noisy_test   = np.clip(noisy_test_, 0, 255)

# Step 3: Scale inputs to [0, 1] and flatten each image to a cmp_dim**2 vector
noisy_train = (noisy_train / 255).reshape(len(noisy_train), cmp_dim**2)
noisy_test  = (noisy_test / 255).reshape(len(noisy_test), cmp_dim**2)

# Step 4: Scale expected outputs to [0, 1] and flatten to out_dim**2 vectors
clean_train = (clean_train / 255).reshape(len(clean_train), out_dim**2)
clean_test  = (clean_test / 255).reshape(len(clean_test), out_dim**2)

In [None]:
# noisy_train[0].shape

In [None]:
# Define Poisson autoencoder inverting network (PAIN)
####################################################

# Define model architecture
def build_PAIN(in_dim, out_dim, enc_dim = 256):
    """Build the (uncompiled) PAIN fully-connected network.

    Every Dense layer uses a sigmoid activation; layers are stacked as:
        input            (in_dim**2)
        -> decoder 1:      Dense(out_dim**2)
        -> encoder hidden: Dense(enc_dim)
        -> encoder:        Dense(in_dim**2)
        -> decoder hidden: Dense(enc_dim)
        -> output:         Dense(out_dim**2)

    Args:
        in_dim: side length of the square compressed input; the model expects
            it flattened to in_dim**2 values.
        out_dim: side length of the square reconstructed image.
        enc_dim: width of the two hidden layers.

    Returns:
        A tf.keras.Model mapping in_dim**2 inputs to out_dim**2 outputs.
    """
    # Width of each successive Dense layer, in stacking order
    layer_units = (
        out_dim**2,  # first decoder
        enc_dim,     # encoder hidden layer
        in_dim**2,   # encoder (back to the input size)
        enc_dim,     # second decoder hidden layer
        out_dim**2,  # second decoder / output
    )

    inputs = tf.keras.Input(shape=(in_dim**2,))
    hidden = inputs
    for units in layer_units:
        hidden = tf.keras.layers.Dense(units, activation='sigmoid')(hidden)

    return tf.keras.Model(inputs=inputs, outputs=hidden)

In [None]:
# Build the PAIN network for the chosen compressed (input) and output sizes
PAIN = build_PAIN(in_dim=cmp_dim, out_dim=out_dim)

# RMSProp optimizer.
# NOTE(review): 0.05 is far above the Keras RMSprop default learning rate
# (0.001) -- confirm this value is intentional.
RMSp = tf.keras.optimizers.RMSprop(learning_rate=0.05)

# Train against a mean-squared-error reconstruction loss
PAIN.compile(loss='mean_squared_error', optimizer=RMSp)

In [None]:
# Fit PAIN to map noisy compressed inputs to the clean images.
# The returned History object records per-epoch 'loss' and, because
# validation data is supplied, 'val_loss'.
fit_history = PAIN.fit(
    x=noisy_train,
    y=clean_train,
    batch_size=250,
    epochs=250,
    validation_data=(noisy_test, clean_test),
)

In [None]:
# # Save training and validation loss
# ###################################

# # Model directory 
# directory = "TrainLoss"

# # If directory does not exist, create it
# if not os.path.exists(directory):
#     os.makedirs(directory)
    
# # Get training loss values
# loss_values = fit_history.history['loss']

# # Get validation loss values
# val_loss_values = fit_history.history.get('val_loss')

# import pandas as pd  # pandas is not imported in the top cell; needed for pd.DataFrame
# df = pd.DataFrame({
#     'epoch': range(1, len(loss_values) + 1),
#     'loss': loss_values,
#     'val_loss': val_loss_values
# })
# df.to_csv(f'TrainLoss/PAIN{cmp_dim}x{cmp_dim}_loss.csv', index=False)

In [None]:
# # Save Trained Model
# ####################

# # Model directory 
# directory = "TrainedModels"

# # If the directory does not exist, create it
# if not os.path.exists(directory):
#     os.makedirs(directory)

# # Save trained model
# PAIN.save(f'{directory}/PAIN{cmp_dim}x{cmp_dim}.keras')

In [None]:
# Run the trained PAIN on every noisy, compressed training input; each row
# of the result is one flattened out_dim x out_dim reconstruction
pred_train_out = PAIN.predict(x=noisy_train)

In [None]:
# Print examples: training set
################################
# Grid of n_examples consecutive training samples starting at sft_idx:
#   row 0 = noisy compressed input, row 1 = PAIN reconstruction,
#   row 2 = original MNIST image.

# Number of examples shown (grid columns)
n_examples = 4

# Shift window through training dataset 
sft_idx = 0

# One entry per figure row: (images, side length, y-label, label colour or
# None for the default colour)
plot_rows = [
    (noisy_train, cmp_dim, f'{cmp_dim} x {cmp_dim}\n(Input)\nCompressed & Noisy', 'blue'),
    (pred_train_out, out_dim, f'{out_dim} x {out_dim}\n(Output)\nPAIN Reconstruction', 'blue'),
    (clean_train, out_dim, f'{out_dim} x {out_dim}\n(Original)\nMNIST', None),
]

# Create a 3 by n_examples subplot grid (squeeze=False keeps axes 2-D)
fig, axes = plt.subplots(3, n_examples, figsize=(10, 8), squeeze=False)

# Add title
fig.suptitle('Application of PAIN Architecture\n(MNIST Training Set)',fontsize=20,fontweight='bold', fontfamily='serif')

for row, (images, side, ylabel, color) in enumerate(plot_rows):
    for col in range(n_examples):
        ax = axes[row, col]
        # Inputs and targets are scaled to [0, 1] and the output layer is a
        # sigmoid, so a fixed vmin/vmax puts every panel on the same grey
        # scale instead of auto-stretching each image's contrast separately
        ax.imshow(images[sft_idx + col].reshape(side, side), cmap='gray', vmin=0, vmax=1)
        ax.set_xticks([]) # Remove xticklabels
        ax.set_yticks([]) # Remove yticklabels

    # Set the row's y-label on the first column only
    label_font = {'fontsize': 12, 'fontfamily': 'serif'}
    if color is not None:
        label_font['color'] = color
    axes[row, 0].set_ylabel(ylabel, fontdict=label_font)

# Arrow under each input pointing down to its reconstruction
for ax in axes[0]:
    ax.set_xlabel('⇩',fontdict={'fontsize': 25, 'fontweight': 'bold', 'fontfamily': 'serif', 'color':'blue'})

# Adjust layout to decrease padding between subplots
plt.subplots_adjust(wspace=0.1, hspace=0.25)

# Display the figure
plt.show()

In [None]:
# Run the trained PAIN on every noisy, compressed validation input; each row
# of the result is one flattened out_dim x out_dim reconstruction
pred_test_out = PAIN.predict(x=noisy_test)

In [None]:
# Print examples: validation set
################################
# Grid of n_examples consecutive validation samples starting at sft_idx:
#   row 0 = noisy compressed input, row 1 = PAIN reconstruction,
#   row 2 = original MNIST image.

# Number of examples shown (grid columns)
n_examples = 4

# Shift window through validation dataset 
sft_idx = 0

# One entry per figure row: (images, side length, y-label, label colour or
# None for the default colour)
plot_rows = [
    (noisy_test, cmp_dim, f'{cmp_dim} x {cmp_dim}\n(Input)\nCompressed & Noisy', 'blue'),
    (pred_test_out, out_dim, f'{out_dim} x {out_dim}\n(Output)\nPAIN Reconstruction', 'blue'),
    (clean_test, out_dim, f'{out_dim} x {out_dim}\n(Original)\nMNIST', None),
]

# Create a 3 by n_examples subplot grid (squeeze=False keeps axes 2-D)
fig, axes = plt.subplots(3, n_examples, figsize=(10, 8), squeeze=False)

# Add title
fig.suptitle('Application of PAIN Architecture\n(MNIST Validation Set)',fontsize=20,fontweight='bold', fontfamily='serif')

for row, (images, side, ylabel, color) in enumerate(plot_rows):
    for col in range(n_examples):
        ax = axes[row, col]
        # Inputs and targets are scaled to [0, 1] and the output layer is a
        # sigmoid, so a fixed vmin/vmax puts every panel on the same grey
        # scale instead of auto-stretching each image's contrast separately
        ax.imshow(images[sft_idx + col].reshape(side, side), cmap='gray', vmin=0, vmax=1)
        ax.set_xticks([]) # Remove xticklabels
        ax.set_yticks([]) # Remove yticklabels

    # Set the row's y-label on the first column only
    label_font = {'fontsize': 12, 'fontfamily': 'serif'}
    if color is not None:
        label_font['color'] = color
    axes[row, 0].set_ylabel(ylabel, fontdict=label_font)

# Arrow under each input pointing down to its reconstruction
for ax in axes[0]:
    ax.set_xlabel('⇩',fontdict={'fontsize': 25, 'fontweight': 'bold', 'fontfamily': 'serif', 'color':'blue'})

# Adjust layout to decrease padding between subplots
plt.subplots_adjust(wspace=0.1, hspace=0.25)

# Display the figure
plt.show()