# Train Signal Classification and Reconstruction Network (SCRNet)
Uses the TensorFlow machine learning library to build and train the Signal Classification and Reconstruction Network (SCRNet), which classifies and reconstructs compressed Extended-MNIST <span style="color:blue">***(EMNIST)*** letter</span> images with imposed Poisson noise.

* First, run **create_comp_noisy_emnist_letters_training_data.ipynb** to create the compressed/noisy training data.

**Author:** Fabian Santiago  
**Update:** September 18, 2024

Jupyter Notebook Version: 6.5.4  
Python Version: 3.11.5  
TensorFlow Version: 2.16.2

***[Download EMNIST](https://pypi.org/project/emnist/)***


## Import Modules and Libraries

In [None]:
# Import necessary modules and libraries 
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, MaxPooling2D, Conv2D, Reshape, Flatten, Add
from tensorflow.keras.models import Model
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.optimizers import Adam
from emnist import extract_training_samples
from emnist import extract_test_samples
import numpy as np
import h5py

## Set Values: Seed, Compression, and Output Dimension

In [None]:
# Seed TensorFlow's global random number generator (which Keras also draws
# from) to make weight initialization and training runs more reproducible
tf.random.set_seed(seed=101)

# Side length of the output images; the original EMNIST images are 28 x 28.
# Do not change: the clean targets below are flattened to out_dim**2 values.
out_dim = 28

In [None]:
# Create a dictionary of training sets keyed by input compression size.
# A dict comprehension gives every key its own placeholder list;
# dict.fromkeys(keys, []) would make all keys share one list object.
in_signals_by_cmp = {f'{dim}x{dim}': [] for dim in (4, 7, 14, 28)}

## Helper Function Definitions

In [None]:
# SCRNet architecture builder function
def build_SCRNet(in_dim = 7, out_dim = 28, enc_dim = 256, learning_rate = 0.001,
                 num_classes = 26, cls_loss = 'mse'):
    """Build and compile the Signal Classification and Reconstruction Network (SCRNet).

    The model maps a flattened ``in_dim x in_dim`` compressed/noisy signal to two
    outputs:

    * ``'reconimg'``   -- a flattened ``out_dim x out_dim`` reconstruction
      (sigmoid output, values in (0, 1)), and
    * ``'digit-pred'`` -- a softmax distribution over ``num_classes`` classes.

    Parameters
    ----------
    in_dim : int
        Side length of the square input signal; the input layer has ``in_dim**2`` units.
    out_dim : int
        Side length of the square output image. It must be large enough to pass
        through the classifier's two conv/max-pool stages (28 is used in this notebook).
    enc_dim : int
        Width of the hidden Dense layers in the reconstruction encoder/decoder.
    learning_rate : float
        Learning rate passed to the Adam optimizer.
    num_classes : int
        Number of classes predicted by the ``'digit-pred'`` head
        (default 26, the EMNIST letter classes).
    cls_loss : str or Keras loss
        Loss for the ``'digit-pred'`` head. Defaults to ``'mse'`` to preserve the
        original training setup; ``'categorical_crossentropy'`` is the usual
        choice for a softmax output.

    Returns
    -------
    tf.keras.Model
        Compiled model with outputs ``[reconimg, digit-pred]``.
    """

    # Input layer: flattened compressed/noisy signal
    inputs = Input(shape=(in_dim**2,))

    # Lift the input to the output signal size (flattened out_dim x out_dim)
    Dec1 = Dense(out_dim**2, activation='sigmoid')(inputs)

    ##################
    # Classification #
    ##################
    # View the lifted signal as a single-channel image for the CNN classifier
    reshpimg = Reshape((out_dim, out_dim, 1))(Dec1)
    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(reshpimg)
    maxp1 = MaxPooling2D((2, 2))(conv1)
    conv2 = Conv2D(64, (3, 3), activation='relu')(maxp1)
    maxp2 = MaxPooling2D((2, 2))(conv2)
    flatt = Flatten()(maxp2)
    hid1 = Dense(128, activation='relu')(flatt)
    # The layer name 'digit-pred' is the key used by the loss/metric dicts below
    dpred = Dense(num_classes, activation='softmax', name='digit-pred')(hid1)

    ##########################
    # Reconstruction Network #
    ##########################
    # Encoder: compress the lifted signal back down to in_dim**2 values
    Enc2_hid = Dense(enc_dim, activation='sigmoid')(Dec1)
    Enc2 = Dense(in_dim**2, activation='sigmoid')(Enc2_hid)

    # Project the class probabilities to Enc2's size so they can be added to the
    # encoding, injecting the classifier's prediction into the reconstruction path
    dpred_transformed = Dense(in_dim**2, activation='sigmoid')(dpred)
    addenc = Add()([Enc2, dpred_transformed])

    # Decoder: expand the combined encoding to the flattened output image
    Dec2_hid = Dense(enc_dim, activation='sigmoid')(addenc)
    reconimg = Dense(out_dim**2, activation='sigmoid', name='reconimg')(Dec2_hid)

    # Define the model
    model = Model(inputs=inputs, outputs=[reconimg, dpred])

    # Optimizer (uses the Adam class imported at the top of the notebook)
    tf_opt = Adam(learning_rate=learning_rate)

    # Compile the model; both losses are weighted equally
    model.compile(optimizer=tf_opt,
        loss={
            'reconimg': 'mse',        # MSE for reconstruction
            'digit-pred': cls_loss    # Character recognition loss (default: MSE)
        },
        loss_weights={
            'reconimg': 1.00,    # Weight for the reconstruction loss
            'digit-pred': 1.00,  # Weight for character recognition loss
        },
        metrics={'reconimg': 'mse', 'digit-pred': 'categorical_crossentropy'})
    return model

# Define function for fitting model using multiple stages with different epochs and batch sizes
def fit_model(in_model, epochs_list, batch_sizes, input_data, output_data, validation_split=0.2):
    """Train ``in_model`` in consecutive stages, one per (epochs, batch size) pair.

    Stage ``i`` calls ``in_model.fit`` with ``epochs_list[i]`` epochs and
    ``batch_sizes[i]`` as the batch size. Training continues from the weights
    left by the previous stage.

    Parameters
    ----------
    in_model : object with a Keras-style ``fit`` method
        Model to train; it is trained in place and also returned.
    epochs_list : sequence of int
        Number of epochs for each stage.
    batch_sizes : sequence of int
        Batch size for each stage; must have the same length as ``epochs_list``.
    input_data, output_data :
        Passed unchanged to ``in_model.fit`` as ``x`` and ``y``.
    validation_split : float
        Fraction of the data held out for validation in each stage (default 0.2).

    Returns
    -------
    tuple
        ``(in_model, fit_hist_all)``. ``fit_hist_all`` maps each history key to
        the concatenation of its per-epoch values across all stages. It is
        ``{}`` if no stages are given.

    Raises
    ------
    ValueError
        If ``epochs_list`` and ``batch_sizes`` differ in length.
    """

    if len(epochs_list) != len(batch_sizes):
        raise ValueError(f"Input lists must be of equal size. epochs_list has length {len(epochs_list)} but batch_sizes has length {len(batch_sizes)}.")

    # Start with an empty history so an empty stage list returns {} instead of
    # hitting an unbound local variable
    fit_hist_all = {}
    for idx, (epoch, batch_size) in enumerate(zip(epochs_list, batch_sizes)):
        print(f"Training Stage #{idx+1} with {epoch} epochs and a batch size of {batch_size}.\n")

        fit_hist = in_model.fit(input_data, output_data, epochs=epoch, batch_size=batch_size, validation_split=validation_split)
        # Append this stage's per-epoch values. Accumulate into fresh lists so the
        # History object returned by fit() is never mutated or aliased.
        for key, values in fit_hist.history.items():
            fit_hist_all.setdefault(key, []).extend(values)
    return in_model, fit_hist_all

## Load Training Data
Load the EMNIST letters data and the compressed/noisy training data. The compressed/noisy data must already have been created by **create_comp_noisy_emnist_letters_training_data.ipynb**.

In [None]:
# Directory holding the compressed/noisy .h5 files produced by
# create_comp_noisy_emnist_letters_training_data.ipynb
directory = 'training_data'

# Load EMNIST letters. The clean images are the reconstruction targets and the
# labels are the classification targets for the 4x4, 7x7, 14x14, and 28x28 input architectures.
clean_train, train_labels = extract_training_samples('letters')
clean_test, test_labels   = extract_test_samples('letters')


for cmp_dim in [4, 7, 14, 28]:
    # Path of the compressed/noisy data file for this input size
    dat_path = f'{directory}/emnist_{cmp_dim}x{cmp_dim}_train.h5'

    # Load the compressed/noisy training/test signals. The open file handle uses
    # its own name so it does not shadow the path string.
    with h5py.File(dat_path, 'r') as h5f:
        noisy_train = h5f['noisy_train'][:]
        noisy_test  = h5f['noisy_test'][:]

    # Dictionary of compressed/noisy signals for this input size
    # ('test' is stored here but not used for training in this notebook)
    in_signals = {'train': noisy_train, 'test': noisy_test}

    # Add to dictionary
    in_signals_by_cmp[f'{cmp_dim}x{cmp_dim}'] = in_signals

# Prepare the clean EMNIST images as reconstruction targets: divide by 255
# (assumes 8-bit pixel values) and flatten each image to out_dim**2 values.
# A single vectorized reshape replaces a per-image Python loop.
clean_train  = (clean_train / 255).reshape(-1, out_dim**2)
clean_test   = (clean_test / 255).reshape(-1, out_dim**2)

# Shift labels down by one; EMNIST letter labels appear to be 1-based (1-26),
# and to_categorical below expects classes 0..25
train_labels = train_labels - 1
test_labels  = test_labels - 1
train_labels = to_categorical(train_labels, num_classes=26)
test_labels  = to_categorical(test_labels, num_classes=26)

## Compile & Train Models: 4x4, 7x7, 14x14, and 28x28
### Train 4x4 Compressed and Noisy Input Model

In [None]:
# Build and compile an SCRNet for flattened 4x4 (16-value) inputs,
# reconstructing 28x28 images
scrnet4x4 = build_SCRNet(
    in_dim=4,
    out_dim=28,
)

In [None]:
# Train the 4x4 model in two stages: 20 epochs at batch size 100, then
# 30 epochs at batch size 150. The combined history is discarded.
scrnet4x4, _ = fit_model(
    in_model=scrnet4x4,
    epochs_list=[20, 30],
    batch_sizes=[100, 150],
    input_data=in_signals_by_cmp['4x4']['train'],
    output_data=[clean_train, train_labels],
)

### Train 7x7 Compressed and Noisy Input Model

In [None]:
# Build and compile an SCRNet for flattened 7x7 (49-value) inputs,
# reconstructing 28x28 images
scrnet7x7 = build_SCRNet(
    in_dim=7,
    out_dim=28,
)

In [None]:
# Train the 7x7 model in two stages: 20 epochs at batch size 100, then
# 30 epochs at batch size 150. The combined history is discarded.
scrnet7x7, _ = fit_model(
    in_model=scrnet7x7,
    epochs_list=[20, 30],
    batch_sizes=[100, 150],
    input_data=in_signals_by_cmp['7x7']['train'],
    output_data=[clean_train, train_labels],
)

### Train 14x14 Compressed and Noisy Input Model

In [None]:
# Build and compile an SCRNet for flattened 14x14 (196-value) inputs,
# reconstructing 28x28 images
scrnet14x14 = build_SCRNet(
    in_dim=14,
    out_dim=28,
)

In [None]:
# Train the 14x14 model in two stages: 20 epochs at batch size 100, then
# 30 epochs at batch size 150. The combined history is discarded.
scrnet14x14, _ = fit_model(
    in_model=scrnet14x14,
    epochs_list=[20, 30],
    batch_sizes=[100, 150],
    input_data=in_signals_by_cmp['14x14']['train'],
    output_data=[clean_train, train_labels],
)

### Train 28x28 Noisy Input Model

In [None]:
# Build and compile an SCRNet for flattened, uncompressed 28x28 (784-value)
# noisy inputs, reconstructing 28x28 images
scrnet28x28 = build_SCRNet(
    in_dim=28,
    out_dim=28,
)

In [None]:
# Train the 28x28 model in two stages: 20 epochs at batch size 100, then
# 30 epochs at batch size 150. The combined history is discarded.
scrnet28x28, _ = fit_model(
    in_model=scrnet28x28,
    epochs_list=[20, 30],
    batch_sizes=[100, 150],
    input_data=in_signals_by_cmp['28x28']['train'],
    output_data=[clean_train, train_labels],
)

### Save Trained Models

In [None]:
import os

# Make sure the output directory exists; saving into a missing directory
# may fail on a fresh checkout
os.makedirs('trained_models', exist_ok=True)

# Save the trained models in the native Keras (.keras) format
scrnet4x4.save('trained_models/scrnet4x4.keras')
scrnet7x7.save('trained_models/scrnet7x7.keras')
scrnet14x14.save('trained_models/scrnet14x14.keras')
scrnet28x28.save('trained_models/scrnet28x28.keras')