In [None]:
#############################################################################################
## trains three hypertuned deep learning neural network architectures on training data sets 
## of different sizes and sampled in multiple different ways. 
## Determines generalization performance of the resulting neural networks on test data.
#############################################################################################

import numpy as np
import pandas as pd
from pathlib import Path


import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import regularizers
import keras_tuner as kt
from keras import backend


#to get reproducible results 
# Three independent random number generators are seeded below: Python's
# built-in `random`, numpy's legacy global generator, and TensorFlow's.
import random
# NOTE(review): `seed` is imported here but not referenced in the visible
# cells; the numpy generator is seeded through `np.random.seed` instead.
from numpy.random import seed
random.seed(637281) #python core random number generator
np.random.seed(123784) #numpy global random number generator
tf.random.set_seed(243924) #tensorflow global random number generator

from scipy import stats
import sys

import os #for remove command

#for time stamping output file
from datetime import datetime
# day-month-year pattern passed to strftime when the output file name is built
dateFORMAT = '%d-%m-%Y'
# NOTE(review): pandas is already imported at the top of this cell; this
# second import is redundant but harmless.
import pandas as pd

# project-local helper module providing the data sampling routine and the
# model constructors used in the main routine
import deep_funcs_pub as aw
#in case aw has changed
# reload re-executes the module so that edits to deep_funcs_pub are picked up
# without restarting the kernel
import importlib
importlib.reload(aw)


In [None]:
#########################################################
### extract fitness data
#########################################################

# Tab-separated input table. Its first four columns are read by position
# below as: nucleotide sequence, amino acid sequence, fitness, and the
# fitness error estimate (presumably a standard error -- confirm against
# the data file).
fitdatfile="fitness_data_science_papkou2023.tsv"
pathstr=""
filepath = pathstr + fitdatfile
infile = Path(filepath)
df =  pd.read_csv(infile, sep='\t')

# one list of column values per table row, keyed by the row index
fitdatall=df.T.to_dict('list')

# dictionaries keyed by nucleotide sequence
fit={}
sefit={}
aaseq={}
for row in fitdatall.values():
    ntseq=row[0]
    aaseq[ntseq]=row[1]
    fit[ntseq]=row[2]
    sefit[ntseq]=row[3]

#shift fitness by a constant to avoid divergence of mape near zero and for
#consistency with the data transformation used during hypertuning
fitshift = 2
# the shift is applied through `fitshift` (not a repeated literal) so that
# the reported value and the value actually added can never disagree
fit = {s: f + fitshift for s, f in fit.items()}
print("shifting fitness values by ", fitshift, 
      "min/max after shift", np.min(list(fit.values())), np.max(list(fit.values())) )

#now write data for genotypes with (shifted) fitness at or above the threshold
#below into new dicts; these correspond to the viable genotypes
hilothresh=1.5
fit_hi={s: f for s, f in fit.items() if f>=hilothresh}
# the companion dicts contain exactly the sequences that passed the filter,
# in the same order
sefit_hi={s: sefit[s] for s in fit_hi}
aaseq_hi={s: aaseq[s] for s in fit_hi}
print("number of data points after filtering for fitness above ", hilothresh, 
      ":", len(list(fit_hi.values())))
 

In [None]:
#######################################################################
## main routine: loop over NN architectures, sampling modes and
## training/validation set sizes
#######################################################################


#defines fold-cross validation used during training
crossfold=4  
print("performing ", crossfold, "-fold cross validation ")

#it is useful to have the currently best network stored away
#this file is deleted before each new replicate training
tmp_best_model_file = "sampling_tmp_best.keras"


def _new_callbacks():
    """Return a fresh [EarlyStopping, ModelCheckpoint] callback list.

    Training stops if the validation loss does not improve for `patience`
    epochs, and the model with the best validation loss seen by this
    checkpoint object is written to `tmp_best_model_file`. A new list has to
    be created after the checkpoint file was deleted, otherwise the file is
    apparently not recreated.
    """
    return [keras.callbacks.EarlyStopping(monitor="val_loss", 
                                          patience=5),
            keras.callbacks.ModelCheckpoint(filepath=tmp_best_model_file, 
                                            monitor="val_loss", 
                                            save_best_only=True)]


def _fold_average(fold_histories, n_epochs):
    """Average a per-epoch metric over all folds for the first n_epochs epochs."""
    return [np.mean([hist[epoch] for hist in fold_histories]) 
            for epoch in range(n_epochs)]


def _mean_and_serr(values):
    """Return [mean, standard error] of values, or [nan, nan] if values is empty.

    The standard error is np.std(values)/sqrt(n), with numpy's default
    ddof=0 as in the original analysis.
    """
    n_values = len(values)
    if n_values == 0:
        return [np.nan, np.nan]
    return [np.mean(values), np.std(values)/np.sqrt(n_values)]


callbacks_list = _new_callbacks()

max_epochs = 100
batch_size = 128


trva_size_arr=range(200, 8001, 200)


#the different sampling modes to be used
sampling_mode_arr=["random", "unique_aas", "two_syn_aas", "max_codon_usage", 
                   "maxdiv_nts", "maxdiv_aas_georgiev", "NNK", "NNS", "NNG", 
                   "NNT", "NDT", "Tang"]


#an array to hold different NN architectures to try
architectures=["dense_stack", "RNN_stack", "transf"]

#for each supported architecture: 
#(name of the model constructor in aw, 
# suffix of the data set key holding the network input, 
# whether loading the saved model needs the custom PositionalEmbedding layer)
arch_specs = {"dense_stack": ("dense_stack_v1", "ntseq_1hot", False),
              "RNN_stack": ("RNN_stack_w_pos_embed_v1", "codseq_int", True),
              "transf": ("transf_cod_w_pos_embed_v1", "codseq_int", True)}

#number of replicate trainings to be performed for each nn
n_replicates=3
print("performing ", n_replicates, " replicates")  

#a data frame to which the output will be written
datoutarr=[]
datoutdf = pd.DataFrame(datoutarr, columns=['architecture', 'sampling_mode', 'size', 'actual_repl', 
                                            'mean_act_n_sample', 'serr_act_n_sample',
                                            'mean_frac_sample','serr_frac_sample',
                                            'mean_minave_va_loss', 'serr_minave_va_loss', 
                                            'mean_te_loss', 'serr_te_loss',
                                            'mean_te_mae', 'serr_te_mae', 
                                            'mean_te_mape', 'serr_te_mape', 
                                            'mean_te_spe', 'serr_te_spe',
                                            'mean_te_pea',  'serr_te_pea'])


#loop over nn architectures  
for archit in architectures:
    #fail early and explicitly instead of printing a message and running 
    #into an undefined variable later
    if archit not in arch_specs:
        raise ValueError("error_aw: invalid model architecture: " + str(archit))
    builder_name, input_key, needs_pos_embed = arch_specs[archit]

    #loop over sampling strategies
    for s_mode in sampling_mode_arr:

        #loop over sample sizes
        for size in trva_size_arr:
                        
            #to free memory from previous sessions
            tf.keras.backend.clear_session()
            try:
                del model
            except NameError:
                pass

            
            #define the target size for the training|validation (tr/va) data set
            #as a fraction of the filtered data set
            f_tr_va=size/len(fit_hi)
              
            #arrays that will hold for each replicate 
            trva_size=[] #actual number of data points in the tr/va set 
            trva_frac=[] #target tr/va set size as fraction of whole data set
            minave_va_loss=[] #min of the average validation loss across all k-fold cross validation 
            te_loss=[] # test loss
            te_mae=[] #test mae
            te_mape=[] #test mape

            te_spe=[] #spearman correlation between predicted and observed fitness
            te_pea=[] #pearson correlation between predicted and observed fitness

            
            #loop over replicate trainings
            for repl in range(n_replicates):
                #console output for monitoring
                print("\n\nMODEL:", archit, " sampling_mode: ", s_mode, " size", size, 
                      " replicate", repl, "\n\n")

                #to record training losses, one history per fold
                tr_loss_hist = []
                tr_mae_hist = []
                tr_mape_hist = []
                va_loss_hist = []
                va_mae_hist = []
                va_mape_hist = []
                       

                # notice that each replicate uses its own data sample
                # all supported models use data from the same function
                [tr, va, te]=aw.dhfr_sample_data_kfold_int_codon_onehot(s_mode, crossfold, aaseq_hi, fit_hi, sefit_hi, 
                                                             f_tr_va, f_te=0.5, flattenflag=1) 

                #delete the checkpoint model from the last training session
                #this is so awkward because apptly keras has no command to do that
                if os.path.exists(tmp_best_model_file):
                    os.remove(tmp_best_model_file)
                
                #fresh callbacks are needed after the file was deleted
                #the same checkpoint object is then shared by all folds of this replicate
                callbacks_list=_new_callbacks()

                    
                #epoch at which early stopping was triggered for each fold 
                #(keras leaves this at 0 if training ran for all max_epochs)
                stopped_epoch=[]
                #number of epochs actually trained for each fold
                epochs_run=[]
                for i in range(crossfold):
                    
                    model = getattr(aw, builder_name)()

                    tr_x=tr["tr_" + input_key][i]
                    tr_f=tr["tr_fit"][i]
                    va_x=va["va_" + input_key][i]
                    va_f=va["va_fit"][i]
                    
                    history = model.fit(tr_x, tr_f,
                                callbacks=callbacks_list,
                                validation_data=(va_x, va_f),
                                epochs=max_epochs, batch_size=batch_size, verbose=0) 
                    
                    print("epoch at which stopping occurred", callbacks_list[0].stopped_epoch)
                    stopped_epoch.append(callbacks_list[0].stopped_epoch)
                    epochs_run.append(len(history.history["loss"]))

                    tr_loss_hist.append(history.history["loss"])
                    va_loss_hist.append(history.history["val_loss"])
                    tr_mae_hist.append(history.history["mae"])
                    va_mae_hist.append(history.history["val_mae"])
                    tr_mape_hist.append(history.history["mape"])
                    va_mape_hist.append(history.history["val_mape"])



                #number of epochs for which every fold has a recorded history
                #this is taken from the histories and not from stopped_epoch, because 
                #stopped_epoch is 0 for a fold that trained for all max_epochs without 
                #triggering early stopping, and is a zero-based index otherwise
                num_epochs = int(np.min(epochs_run))
                
                #if a fold recorded no epoch at all there is nothing to average,
                #print a warning and throw out the entire replicate
                if num_epochs==0:
                    print("\nwarning_aw: at least one fold cross-training failed, skipping entire replicate")
                    continue
                
                
                #average each metric over all k folds for each epoch (up to num_epochs) 
                ave_tr_loss_hist = _fold_average(tr_loss_hist, num_epochs)
                ave_tr_mape_hist = _fold_average(tr_mape_hist, num_epochs)
                ave_tr_mae_hist = _fold_average(tr_mae_hist, num_epochs)
                ave_va_loss_hist = _fold_average(va_loss_hist, num_epochs)
                ave_va_mae_hist = _fold_average(va_mae_hist, num_epochs)
                ave_va_mape_hist = _fold_average(va_mape_hist, num_epochs)

            
                #load the checkpointed best model and evaluate it on the test set
                if needs_pos_embed:
                    custom_objects = {"PositionalEmbedding": aw.PositionalEmbedding}
                else:
                    custom_objects = None
                best_model=keras.models.load_model(tmp_best_model_file, 
                                                    custom_objects=custom_objects)
                #define test sets
                te_x=te["te_" + input_key]
                te_f=te["te_fit"] 
                te_sefit=te["te_sefit"] 

                predict_fit=best_model.predict(te_x)
                eval_results = best_model.evaluate(te_x, te_f, 
                                                    return_dict=True,
                                                    verbose =2) #2: single line output, 0: silent
               
                    
                #now fill the result arrays for this replicate
                
                #compute the actual size of the combined training and validation set
                trva_size.append(len(tr["tr_fit"][0])+len(va["va_fit"][0]))
                trva_frac.append(f_tr_va)
                #this computes the minimum over the average loss (over all k folds) for all epochs, which
                #is not necessarily the minimum at the last epochs, but probably close.
                minave_va_loss.append(np.min(ave_va_loss_hist))

                te_loss.append(eval_results["loss"])
                te_mae.append(eval_results["mae"])
                te_mape.append(eval_results["mape"])

                #some versions of the spearmanr and pearsonr routines require flat arrays as inputs
                predict_flat=predict_fit.flatten()
                speres=stats.spearmanr(te_f, predict_flat)
                te_spe.append(speres.correlation)

                peares=stats.pearsonr(te_f, predict_flat)
                te_pea.append(peares.statistic)

                for metric in eval_results.keys():
                    print(metric, " on test set", eval_results[metric])
                print("fitness act. vs. pred. for test set", speres, "n=", len(te_f))      
                print("fitness act. vs. pred. for test set", peares, "n=", len(te_f))      

            #end of loop over replicates, now calculate statistics over the replicates
            
            #number of replicates that were not skipped
            actual_repl=len(trva_size)
            
            #one (mean, standard error) pair per metric, in the column order of datoutdf
            #if all replicates were skipped every pair is (nan, nan), so the row 
            #always has as many entries as datoutdf has columns
            datrow=[archit, s_mode, size, actual_repl]
            for replicate_values in [trva_size, trva_frac, 
                                     minave_va_loss, #note that this is an average of averages
                                     te_loss, te_mae, te_mape, te_spe, te_pea]:
                datrow.extend(_mean_and_serr(replicate_values))
            print(datrow)
            datoutarr.append(datrow)
            #appending the row to the data file
            datoutdf.loc[len(datoutdf)] = datrow

        #end of loop over all sizes, write copy of the file
        datoutfile="sampling_"+ \
                     datetime.now().strftime(dateFORMAT) + ".txt"
        datoutdf.to_csv(datoutfile, sep='\t', index=False) 


