In [None]:
# SRSdenoiser: multi-branch convolutional neural network with a final residual step for baseline subtraction and denoising of stimulated Raman scattering (SRS) spectra
# Minimal code example for training and testing the SRSdenoiser NN architecture

### Load requirements

In [None]:
import numpy as np
import os
import matplotlib.pyplot as plt
import array
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from scipy.stats import norm
from scipy.stats import moment 
from tensorflow.keras.callbacks import History
import skimage
from skimage import metrics as sm

from scipy import signal
from scipy import io


from scipy.signal import spline_filter
from plotly.subplots import make_subplots
import plotly.graph_objects as go 

from scipy.signal import find_peaks 
import itertools


import wandb

from tensorflow.keras.models import load_model


import Models.customLossLib as cll
import Models.metrics_2 as mm
import Models.models as net



tole = 1.1  # peak-finder tolerance, meant for the PeakFinder defined in Models.metrics_2



### Define methods

In [None]:
### Load dataset 

def _normalize_split(X, Y, prep):
    '''
    Per-spectrum preprocessing of one split (each row of X is one spectrum).

        prep=0: X -> (X - min) / (max - min),   Y -> Y / (max - min)
        prep=1: X -> (X - mean) / std,          Y -> Y / std

    Statistics are computed row by row on X only; Y is scaled by the same per-row
    factor but not shifted. Any other value of prep returns X and Y unchanged.
    The inputs are never modified in place.
    '''
    if prep == 0:
        offset = np.min(X, axis=-1, keepdims=True)
        scale = np.max(X, axis=-1, keepdims=True) - offset
    elif prep == 1:
        offset = np.mean(X, axis=-1, keepdims=True)
        scale = np.std(X, axis=-1, keepdims=True)
    else:
        return X, Y
    return (X - offset) / scale, Y / scale


def data_load(pathToData,dataset, shuf=True, seed=1990, prep=1,ShowPlot=True):
    '''
    Load and preprocess the datasets.

    Reads '<pathToData>Snoise_<dataset>.txt' (raw spectra) and '<pathToData>Sclean_<dataset>.txt'
    (GT spectra), which store one spectrum per column, and optionally
    '<pathToData><dataset>_GTfrequencies.mat' (variable "freq") with the GT band positions.

    Inputs:
        pathToData: path prefix of the dataset folder; it is concatenated directly with the
                    file names, so it must end with a path separator
        dataset: name of the dataset, used to build the file names
        shuf: if True, shuffle the spectra before the 80/20 train/test split
        seed: seed of the random shuffling; use 1990 to reproduce the train/test split used in the paper
        prep: 0 or 1, selects the per-spectrum preprocessing routine (1 is the one used in the paper):
              0 -> X = (X - min) / (max - min);  1 -> X = (X - mean) / std.
              The GT is divided by the same per-spectrum scale (without the shift).
              Any other value skips the per-spectrum preprocessing.
        ShowPlot: if True, plot examples of raw (red) and GT (blue) spectra and print the output shapes
    Outputs:
        normX: global normalization factor, max |X_train| after preprocessing. It is used to scale
               every split and must be used to normalize test/inference data
        nbin: number of points in each spectrum (801 for the HN and LN datasets)
        X_test, X_train: raw test/train spectra, numpy arrays of shape (n, nbin, 1)
        Y_test, Y_train: GT test/train spectra, numpy arrays of shape (n, nbin, 1)
        GT_freq, GT_freq_train, GT_freq_test: spectral positions of the Raman bands in the GT
               (all/train/test); zero arrays with 2 columns if the .mat file cannot be loaded
    '''
    ### Load dataset: the files store one spectrum per column -> transpose to (n_spectra, nbin)
    Sn = np.loadtxt(pathToData+'Snoise_'+dataset+'.txt').transpose()
    Sc = np.loadtxt(pathToData+'Sclean_'+dataset+'.txt').transpose()
    nbin = Sc.shape[1]

    ### Shuffling: honours `shuf` and `seed`; the same permutation is reused for the GT frequencies
    shuffler = None
    if shuf:
        shuffler = np.random.RandomState(seed).permutation(len(Sn))
        Sn = Sn[shuffler]
        Sc = Sc[shuffler]

    if ShowPlot:
        plt.figure(figsize=(20,6))
        for idx in range(min(6, len(Sn))):
            plt.subplot(1,6,idx+1)
            plt.plot(Sn[idx], '-r')
            plt.plot(Sc[idx], '-b')
        plt.tight_layout()

    ### 80/20 train/test split followed by the per-spectrum preprocessing (prep=1 matches the paper)
    ev = int(Sc.shape[0]*0.8)
    X_train, Y_train = _normalize_split(Sn[:ev], Sc[:ev], prep)
    X_test, Y_test = _normalize_split(Sn[ev:], Sc[ev:], prep)

    ### Global scale factor taken from the training inputs only and applied to every split
    normX = np.max(np.abs(X_train))
    X_train, X_test = X_train / normX, X_test / normX
    Y_train, Y_test = Y_train / normX, Y_test / normX

    # Add a trailing channel axis: (n, nbin) -> (n, nbin, 1)
    X_train, X_test = X_train[..., np.newaxis], X_test[..., np.newaxis]
    Y_train, Y_test = Y_train[..., np.newaxis], Y_test[..., np.newaxis]

    if ShowPlot:
        print(X_train.shape)
        print(Y_train.shape)
        print(X_test.shape)
        print(Y_test.shape)

    print('Norm factor is '+str(normX)+ '\nUse this value to normalize input data during test and inference phase')

    ### Load GT frequencies (optional): fall back to zeros so that callers always receive arrays
    str_name = pathToData+dataset+'_GTfrequencies.mat'
    try:
        GT_freq = np.squeeze(io.loadmat(str_name)["freq"])
        if shuffler is not None:
            GT_freq = GT_freq[shuffler]
        GT_freq_train = GT_freq[:ev]
        GT_freq_test = GT_freq[ev:]
    except Exception as err:  # best effort: a missing/malformed file must not abort data loading
        print("GT_freq not loaded... ("+repr(err)+")")
        GT_freq = np.zeros((len(Sn), 2))
        GT_freq_train = np.zeros((ev, 2))
        GT_freq_test = np.zeros((len(Sn) - ev, 2))

    return normX, nbin, X_test,X_train,Y_test,Y_train,GT_freq, GT_freq_train, GT_freq_test

### Logger

In [None]:
### Toggle WandB logging: set to True to log the training runs to Weights & Biases
logger = False

if logger:
    # wandb is only needed when logging is enabled; install it first if missing, e.g.:
    #   %pip install wandb
    from wandb.keras import WandbCallback
    import wandb

    # Set up the logger
    wandb.login()

In [None]:
###Custom callback to log plots of test set and metrics at selected epochs during training

from plotly.subplots import make_subplots
import plotly.graph_objects as go
class CustomCallback(keras.callbacks.Callback):
    '''
    Keras callback that, every `log_every` epochs, logs to WandB:
      - a plotly figure ('Test set: gt,pred') comparing GT and predicted spectra of a few
        selected test samples (one subplot per sample);
      - mean and std of the 'ssim' and 'precision' metrics returned by mm.compute_metrics2
        on the whole test set.
    An active wandb run (wandb.init) is required while training.

    Inputs:
        model: the keras model being trained
        x_test, y_test: raw and GT spectra of the selected samples to plot
        x_test_all, y_test_all: raw and GT spectra of the whole test set, used for the metrics
        gt_freq_test: GT band positions of the whole test set, forwarded to mm.compute_metrics2
        log_every: logging period in epochs (epochs 0, log_every, 2*log_every, ... are logged)
    '''

    def __init__(self, model, x_test, y_test,x_test_all, y_test_all, gt_freq_test, log_every=5):
        super().__init__()
        # set_model is the Callback API for attaching a model (Keras calls it again in fit())
        self.set_model(model)
        self.x_test = x_test
        self.y_test = y_test
        self.x_test_all = x_test_all
        self.y_test_all = y_test_all
        self.gt_freq_test = gt_freq_test
        self.log_every = log_every

    def on_epoch_end(self, epoch, logs=None):
        if epoch % self.log_every != 0:
            return

        print("Logging metrics and plots of predicted test set...\n")
        nbin = self.y_test.shape[1]
        n_plot = self.y_test.shape[0]
        n_all = self.y_test_all.shape[0]
        xs = list(range(nbin))

        # One subplot per selected sample: GT (dotted blue) vs prediction (red).
        # Legend entries are shown once instead of once per subplot.
        y_pred = self.model.predict(self.x_test)
        fig = make_subplots(rows=1, cols=n_plot)
        for jj in range(n_plot):
            fig.add_trace(
                go.Scatter(x=xs, y=self.y_test[jj,:].reshape(nbin,), name='GT', showlegend=(jj == 0),
                           line=dict(color='rgb(22, 96, 167)', width=2, dash='dot')),
                row=1, col=jj+1
            )
            fig.add_trace(
                go.Scatter(x=xs, y=y_pred[jj,:].reshape(nbin,), name='Predicted', showlegend=(jj == 0),
                           line=dict(color='rgb(205, 12, 24)', width=1)),
                row=1, col=jj+1
            )
        fig.update_layout(height=600, width=800)
        # Log the finished figure once; commit=False keeps it in the same wandb step as the metrics below
        wandb.log({'Test set: gt,pred': fig}, commit=False)

        # Monitor selected metrics on the whole test set
        Y = self.model.predict(self.x_test_all).reshape(n_all, nbin)
        X = self.x_test_all.reshape(n_all, nbin)
        GT = self.y_test_all.reshape(n_all, nbin)

        metrics_NN = mm.compute_metrics2(Y, X, GT, self.gt_freq_test)
        wandb.log({'ssim_mean': np.mean(metrics_NN['ssim']), 'ssim_std': np.std(metrics_NN['ssim']),
                   'precision_mean': np.mean(metrics_NN['precision']), 'precision_std': np.std(metrics_NN['precision']),
                   }, commit=True)

           

## Training from scratch

In [None]:
# Load Dataset
### Select folder and filename of the dataset

dataset='HighNoise'
pathToData='Datasets/HighNoise/'

normFactor, nbin, X_test,X_train,Y_test,Y_train, GT_freq, GT_freq_train, GT_freq_test = data_load(pathToData,dataset, shuf=True,seed=1990, prep=1,ShowPlot=1)

# Define hyperparameters

ker=[21]        # number of kernel sizes, linearly spaced between 5 and max_ker[0]
max_ker=[88]    # largest kernel size
bn=False        # forwarded to the model builder as `bn`
nlayers=6
nDense=0
nConv=4

Nparam=10000    # target used to size the number of channels of each kernel size (see below)

# Training strategy: training is performed in two phases. First, only the reconstruction loss is
# optimized for epochs0 epochs; then the grad loss term is switched on and the full loss is optimized
# for epochs1 epochs. When the logger is on, these defaults are stored in (and read back from) wandb.config.
epochs0, batch_size0 = 25, 32
epochs1, batch_size1 = 60, 32

### Fix the seed during hyperparameter sweeps
fix_seed=False

# Instantiate keras model
keras.backend.clear_session()

if fix_seed:
    # Seed before the model is built so that the weight initialization is reproducible as well
    np.random.seed(1)
    tf.random.set_seed(2)
    # NOTE: PYTHONHASHSEED only affects Python processes started after this assignment
    os.environ['PYTHONHASHSEED'] = '0'

k_size=np.linspace(5,max_ker[0],ker[0]).astype(int)
# Channels per kernel size k: integer part of the positive root c of k*c**2 + c = Nparam
channels = ((-1+np.sqrt(1+4* k_size * Nparam)) / (2 *k_size)).astype(int)
channels=channels.tolist()
k_size=k_size.tolist()

model=net.Multi_parallelKernel_CNN_modified3((nbin,1),nlayers=nlayers,channels=channels,k_size=k_size,bn=bn,nDense=nDense,nConvFinal=nConv)
model_name = 'Multi_parallelKernel_CNN_modified3'

model.summary()

history = History()

###custom loss parameters
Grad_weight=.6 ###relative weight between the reconstruction and grad terms of the loss

### Global weight to avoid abrupt changes of the loss when switching on the grad term: to be adjusted depending on the dataset
Weight=8.5 #High noise with L2 on grad

#### For other datasets and choice of L2/L1, use:
### Weight=8.5 #High noise L2
### Weight=0.05 #High noise with L1 on grad

### Choose between L1 and L2 norm on the grad term
lnorm=2
if lnorm==1:
    custom_mse,grad_loss,mse_loss=cll.loss_def_norm1_modifiedL1(Grad_weight=Grad_weight,Weight=Weight)
else:
    custom_mse,grad_loss,mse_loss=cll.loss_def_norm1(Grad_weight=Grad_weight,Weight=Weight)

### Learning rate and loss of each phase
lr_initial=0.001    # phase 1
lr_refine=0.0005    # phase 2
loss0 = mse_loss    # phase 1: reconstruction loss only
loss1 = custom_mse  # phase 2: full loss

if logger:
    # Configure the logger run (here we use wandb).
    ### Change the wandb project, entity and all the parameters accordingly
    run = wandb.init(project='SRSdenoiser', entity='',group='', job_type='DatasetEval',config={
                            "learning_rate0": lr_initial,
                            "epochs0": epochs0,
                            "batch_size0": batch_size0,
                            "loss_function0": "mse",
                            "learning_rate1": lr_refine,
                            "epochs1": epochs1,
                            "batch_size1": batch_size1,
                            "loss_function1" : "custom_loss",
                            "Grad_weight" : Grad_weight,
                            "Weight" : Weight,
                            "architecture": model_name,
                            "number_of_layers":nlayers,
                            "N_denseLayers":nDense,
                            "N_finalConvs":nConv,
                            "Nparams":Nparam,
                            "channels":channels,
                            "kernel_sizes":k_size,
                            "batch_norm":bn,
                            "dataset": dataset,
                            "lnorm":lnorm
                        })
    wandb.run.name="Run test"

    ### Set up wandb log
    wandb_callback = WandbCallback(monitor="val_loss", verbose=0, mode="auto", save_weights_only=(False),
                    log_weights=(False), log_gradients=(False), save_model=(True),
                    training_data=(X_train, Y_train), validation_data=None, labels=[], data_type=None,
                    predictions=36, generator=None, input_type=None, output_type=None,
                    log_evaluation=(False), validation_steps=None, class_colors=None,
                    log_batch_frequency=None, log_best_prefix="best_", save_graph=(True),
                    validation_indexes=None, validation_row_processor=None,
                    prediction_row_processor=None, infer_missing_processors=(True),
                    log_evaluation_frequency=0)


def scheduler(epoch, lr):
    '''Keep `lr` for the first 15 epochs of a training phase, then decay it by 1% per epoch (used in both phases).'''
    if epoch < 15:
        print('----Using large lr----')
        return lr
    return lr * 0.99


iidx=[6,283,8,991]  #Selected test samples to be monitored during training (used only when logger is on)


def make_callbacks():
    '''Fresh per-phase callbacks: best-weights checkpoint, lr scheduler and, if logging, the test-set monitor.'''
    phase_callbacks = [
        keras.callbacks.ModelCheckpoint(
            filepath='best_weights',
            monitor="val_loss",
            save_weights_only=True,
            save_best_only=True,
            save_freq='epoch'),
        tf.keras.callbacks.LearningRateScheduler(scheduler),
    ]
    if logger:
        phase_callbacks.append(CustomCallback(model, X_test[iidx], Y_test[iidx], X_test, Y_test,GT_freq_test))
    return phase_callbacks


### Phase 1: fit the model using the reconstruction loss only
optimizer = keras.optimizers.Adam(learning_rate=lr_initial)
model.compile(loss=loss0, optimizer=optimizer, metrics=["mae",grad_loss,mse_loss])

if logger:
    epochs, batch_size = wandb.config.epochs0, wandb.config.batch_size0
else:
    epochs, batch_size = epochs0, batch_size0

callbacks = make_callbacks() + ([history, wandb_callback] if logger else [history])
history=model.fit(X_train, Y_train, validation_split=0.2, batch_size=batch_size, epochs=epochs, callbacks=callbacks)


### Phase 2: additional training with the full loss function (reconstruction + grad terms)
print('Introducing Grad Loss...')

optimizer = keras.optimizers.Adam(learning_rate=lr_refine)
model.compile(loss=loss1, optimizer=optimizer, metrics=["mae",grad_loss,mse_loss])

if logger:
    epochs, batch_size = wandb.config.epochs1, wandb.config.batch_size1
else:
    epochs, batch_size = epochs1, batch_size1

# New checkpoint callback: its best val_loss restarts here, so the phase-1 'best_weights' are overwritten
callbacks = make_callbacks() + ([history, wandb_callback] if logger else [history])
history=model.fit(X_train, Y_train, validation_split=0.2, batch_size=batch_size, epochs=epochs, callbacks=callbacks)

# NOTE: 'best_weights' is not reloaded here, so `model` keeps the weights of the last epoch

if logger:
    run.finish()


## Testing and inference

In [None]:
### Here we show how to use the trained networks for inference, on samples from the test set that were not seen during training

### Load raw data to be processed by the NN

dataset='HighNoise'
pathToData='Datasets/HighNoise/'


### Use seed=1990 to split the train and test datasets consistently with what was done in the paper
normFactor, nbin, X_test,X_train,Y_test,Y_train, GT_freq, GT_freq_train, GT_freq_test = data_load(pathToData,dataset, shuf=True,seed=1990, prep=1,ShowPlot=0)

### Examples from the simulated test set are already preprocessed by the data_load method
nSample=0 ### Choose a number between 0 and len(X_test)-1 (0-999 for the provided datasets)
Raw_data_test = X_test[nSample,:]
GT_data_test = Y_test[nSample,:]


### Alternatively, provide raw data formatted as an array of length nbin (801 for the provided datasets) and store it in Raw_data_test.
### In this case preprocessing is needed, as explained below (GT_data_test then no longer corresponds to the input).

### Preprocessing of raw data: the same preprocessing routine (prep=1) and normalization factor used during training must be used.
NormFactor = normFactor  # default: factor computed by data_load on the training split of `dataset`
### NormFactor for the pretrained networks (uncomment to override the default):
#NormFactor = 12.4987  #NormFactor for the HighNoise dataset
#NormFactor = 14.7908  #NormFactor for the LowNoise dataset

preprocess = False  #Set to True when providing external raw data

if preprocess:
    # Not in place: Raw_data_test may be a view of X_test, which in-place updates would corrupt
    # (and the normalization would then compound every time this cell is re-run).
    meanX=np.mean(Raw_data_test)
    stdX=np.std(Raw_data_test)
    Raw_data_test = (Raw_data_test - meanX) / stdX / NormFactor

In [None]:
### Inference phase: select a pretrained NN from the Weights folder and process the raw data with it

# The network takes a batch of spectra with a trailing channel axis: (1, nbin, 1)
Raw_data_test = Raw_data_test.reshape((1,nbin,1))

path_to_model = 'Weights/'

model_name = 'NN_HN'
# compile=False: no loss/optimizer is restored, which predict() does not need
model = tf.keras.models.load_model(path_to_model+model_name+'_model-best.h5',compile=False)
NN_output = model.predict(Raw_data_test)


### Plot the results
# rc_context applies the style to this figure only and restores the global rcParams afterwards,
# so re-running the cell gives the same figure and later figures are unaffected.
with plt.rc_context({'font.size': 26, 'axes.linewidth': 1}):
    fig, ax = plt.subplots(figsize=(10,5))
    ax.plot(Raw_data_test.reshape(nbin,),'-r',linewidth=1,label='Raw')
    ax.plot(NN_output.reshape(nbin,),'-g',linewidth=1.5,label='NN')
    ax.plot(GT_data_test.reshape(nbin,), '--k',linewidth=1.5,label='GT')
    ax.set_xlabel('Absolute Raman Shift (pixels)')
    ax.set_ylabel('Intensity (A.U.)')
    ax.set_title(model_name)
    ax.legend(loc='lower right')
    plt.show()