# GAN-ZT 10-fold cross validation
### This notebook contains the code needed to run a random 10-fold cross-validation of a conditional generative adversarial network using [0,1]-encoded individual toxicity matrices and chemical structural data.
#### See http://biorxiv.org/lookup/doi/10.1101/2020.10.02.322917 for details

By Adrian J Green, PhD

#### Import Tensorflow and manage GPUs

In [1]:
import os

# tensorflow
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.utils import multi_gpu_model

# Start from a clean Keras state so re-running the notebook does not stack graphs.
tf.keras.backend.clear_session()

# Minimize GPU usage by letting TensorFlow grow its allocation on demand
# instead of reserving all device memory up front.
gpus = tf.config.experimental.list_physical_devices('GPU')
if len(gpus) > 0:
    try:
        # Memory growth has to be configured identically on every visible GPU.
        for device in gpus:
            tf.config.experimental.set_memory_growth(device, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as err:
        # Raised when the GPUs were already initialized before this cell ran.
        print(err)

1 Physical GPUs, 1 Logical GPUs


#### Import Python data handling and visualization modules, and local routines

In [2]:
from __future__ import absolute_import, division, print_function, unicode_literals

# standard python
import numpy as np
# from sklearn.metrics import mean_squared_log_error, mean_squared_error
# from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import ShuffleSplit
import pathlib
import os.path
import warnings
import timeit

# plotting, especially for jupyter notebooks
# import matplotlib
#matplotlib.rcParams['text.usetex'] = True # breaks for some endpoint labels
from matplotlib import pyplot as plt
import matplotlib.gridspec as gridspec
from IPython.display import Image

# pandas
import pandas as pd

# local routines
from chemdataprep import load_PDBs
from toxmathandler_AG import load_tmats, load_indv_tmats

# NN build routines
from NNbuild_train_vis import init_NN_v2

# NN train routines
from NNbuild_train_vis import discriminator_loss,generator_loss, get_train_function, write_training_file

# Performance evaluation routines
from gen_AggE import calc_AggE_indv, display_conf_matrix

# Report the runtime environment so results can be tied to a TF build / device count.
eager_mode = tf.executing_eagerly()
gpu_count = len(tf.config.experimental.list_physical_devices('GPU'))
print("tensorflow version", tf.__version__, ". Executing eagerly?", eager_mode)
print("Number of GPUs: ", gpu_count)

tensorflow version 2.1.0 . Executing eagerly? True
Number of GPUs:  1


## Global Options and Variables

#### Chemical structure and toxicity import and model output variables

In [3]:
### PDB options

# cGAN & views parameters, stored positionally as
# [Gfeatures, Gbaselayers, Glayers, Dfeatures, Dbaselayers, Dlayers, carbonbased, setNatoms, views, ClassLabels]
parameters = [279, 3, 11, 50, 0, 3, False, 82, 126, 1003] # GAN-ZT_v7

# Entries 6-9 of `parameters` are also needed as stand-alone options:
#   carbonbased    - base views on carbon or not (safe even if some chemicals have no carbon);
#                    True makes the data smaller in memory and everything run faster.
#   setNatoms      - truncate the length of views to this many atoms (None = use max number in data).
#                    Truncating makes the data and NN smaller; it makes sense if looking at all
#                    neighborhoods of some size gives sufficient understanding of the chemical.
#   views          - number of views per chemical.
#   useClassLabels - allow the cGAN to use class labels in training, None or int.
carbonbased, setNatoms, views, useClassLabels = parameters[6:10]

dataType = '(0,1)_18x1'

# Which of the available concentrations to use: all six for the 18x6 encoding,
# otherwise only index 5.
concentrations = [0,1,2,3,4,5] if '(0,1)_18x6' in dataType else [5]

# Which of the available endpoints to use (all of them).
endpoints = list(range(4, 22))

# File names for the saved generator / discriminator weights.
genpath = 'AG-model-GT-'+dataType+'.h5'
discpath = 'AG-model-DT-'+dataType+'.h5'

# Training data: individual toxicity matrices and PDB structures
trpath = '/home2/ajgreen4/Read-Across_w_GAN/DataFiles/(0,1)_encoding_indv/Tox21_training_compounds/'
valpath = '/home2/ajgreen4/Read-Across_w_GAN/DataFiles/(0,1)_encoding_indv/Tox21_validation_compounds/'
allpath = '/home2/ajgreen4/Read-Across_w_GAN/DataFiles/(0,1)_encoding_indv/Tox21_all_train_compounds/'

# Output locations
modelpath = '/home2/ajgreen4/Read-Across_w_GAN/Models/'
imageOut = '/home2/ajgreen4/Read-Across_w_GAN/imageOut/'

#### cGAN variables

In [4]:
# Losses: both networks are scored with mean squared error.
mse_loss = tf.keras.losses.MeanSquaredError
Gloss_function = mse_loss()
Dloss_function = mse_loss()

# Optimizers: SGD with momentum for the generator, Adam for the discriminator.
generator_optimizer = tf.keras.optimizers.SGD(momentum=0.9, learning_rate=0.01)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

# Data Preparation

### Load chemical data and vectorize into weights and views

In [5]:
%%time
# all training files
[ws, vs, Natoms, Nviews, chemnames, Vshape] = load_PDBs(allpath,setNatoms=setNatoms,setNviews=views,carbonbased=carbonbased)

1003 pdb files found at /home2/ajgreen4/Read-Across_w_GAN/DataFiles/(0,1)_encoding_indv/Tox21_all_train_compounds/PDBs/
Species occurring = {'P', 'CL', 'B', 'N', 'O', 'F', 'BR', 'S', 'H', 'C', 'I', 'AS', 'SI'}
Setting all views to Natoms= 82
126 views needed, but setting to 126
Maximum views used = 126
Data tensor (w,v) shapes= (1003, 126) (1003, 126, 410)
CPU times: user 1min 46s, sys: 245 ms, total: 1min 46s
Wall time: 1min 46s


#### Encode chemical labels

In [6]:
# One integer class label per chemical, stored as a column vector of shape (n_chemicals, 1).
chem_labels = np.arange(len(chemnames))[:, np.newaxis]

#### Load individual toxicity data

In [7]:
%%time
# individual toxicity
print("Loading individual toxicity matrices (~15 minutes)")
### Toxicity matrix options
[toxicity,rows, cols, fish] = load_indv_tmats(allpath,chemnames,endpoint_indexes=endpoints,
                                              SET=2, verbose=1)

# legend labels for plotting
print("Using", len(concentrations), "concentrations")
print("Using", len(endpoints), "endpoints")
endpoints = [i for i in range(len(rows))]
concentrations = [i for i in range(len(cols))]
legend = [rows,cols,endpoints,concentrations]
    
if 0:
    print('CASRN: ',chemnames[0])
    print('Chem label: ',fish[0])
    print('First tox matrix:', toxicity[0])

Loading individual toxicity matrices (~15 minutes)
Number of chemicals= 1003
Using concentrations ['64 uM']
Using endpoints: ['MORT', 'YSE_', 'AXIS', 'EYE_', 'SNOU', 'JAW_', 'OTIC', 'PE__', 'BRAI', 'SOMI', 'PFIN', 'CFIN', 'PIG_', 'CIRC', 'TRUN', 'SWIM', 'NC__', 'TR__']
Toxicity vector length Ntoxicity= 18
Using 1 concentrations
Using 18 endpoints
CPU times: user 14min 47s, sys: 6.02 s, total: 14min 53s
Wall time: 15min 2s


## Neural Network training code

In [8]:
# Create wrapper function to allow the model to be discarded and re-initialized between cross-validation folds
# NOTE(review): this definition shadows the `get_train_function` imported from
# NNbuild_train_vis earlier in the notebook; from this cell on, the local version is used.
def get_train_function():
    """Build and return a fresh ``tf.function``-compiled training step.

    The returned ``train_step`` does not receive the networks or optimizers as
    arguments: it reads ``generator``, ``discriminator``, ``generator_optimizer``,
    ``discriminator_optimizer`` and the per-chemical repeat counts ``y`` from the
    enclosing (notebook-global) scope when it runs.

    :returns: The compiled ``train_step`` function.
    :rtype: callable
    """
    # Compile training function
    @tf.function
    def train_step(G_data,real_data,chemClass,toxClass,doG=True,doD=True):
        """Train Conditional Generative Adversarial Network.

        Every data argument is a list, because the lists are concatenated with
        ``+`` to assemble the network inputs.

        :parameter G_data: List containing a np.array vector with weights and 
                           a np.array matrix with vectorized views.
                           (see chemdataprep.load_pdb())
        :type G_data: list
        :parameter real_data: One-element list holding the master toxicity data matrix. 
                              Rows correspond to chemicals and columns to toxicity measurements.
                              (see toxmathandler.load_tmats())
        :type real_data: list
        :parameter chemClass: One-element list holding the chemical class labels
                              fed to the generator and, for generated data, to the discriminator.
        :type chemClass: list
        :parameter toxClass: One-element list holding the toxicity class labels
                             fed to the discriminator together with the real data.
        :type toxClass: list
        :parameter doG: If True train generator
        :type doG: boolean
        :parameter doD: If True train discriminator
        :type doD: boolean

        :returns: Generator loss and discriminator loss, in that order
        :rtype: tuple

        """
        
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            generated_matrix, Gpw_features = generator(G_data+chemClass, training=True)
            # Repeat each row of generator features according to the global `y`
            # so it lines up with the rows of the real toxicity data.
            # presumably y[i] is the number of real toxicity rows for chemical i — verify against caller
            expanded_Dpw_model = tf.repeat(Gpw_features, repeats=y, axis=0)

            real_output = discriminator([expanded_Dpw_model]+real_data+toxClass, training=True)
            fake_output = discriminator([Gpw_features]+[generated_matrix]+chemClass, training=True)

            gen_loss = generator_loss(fake_output)
            disc_loss = discriminator_loss(real_output, fake_output)

            # NOTE(review): both update branches below are still inside the outer
            # `with` block, so the gradient/apply calls run while the outer tapes are active.
            if doD:
                # update discriminator
                gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
                discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

                # Additional training
                # Two extra discriminator updates per step; the generator is run
                # with training=False for these passes.
                for i in range(2):
                    # NOTE(review): this `with` rebinds gen_tape, disc_tape, gen_loss and
                    # disc_loss. When doD is truthy, the generator update further down
                    # therefore uses the tape and loss from the last extra pass rather
                    # than from the first forward pass — confirm this is intended.
                    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
                        generated_matrix, Gpw_features = generator(G_data+chemClass, training=False)
                        expanded_Dpw_model = tf.repeat(Gpw_features, repeats=y, axis=0)

                        real_output = discriminator([expanded_Dpw_model]+real_data+toxClass, training=True)
                        fake_output = discriminator([Gpw_features]+[generated_matrix]+chemClass, 
                                                    training=True)


                        gen_loss = generator_loss(fake_output)
                        disc_loss = discriminator_loss(real_output, fake_output)

                    # update discriminator
                    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
                    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

            if doG:
                # update generator
                gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
                generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
        # The returned losses are whichever values were computed last (see the rebinding note above).
        return (gen_loss,disc_loss)
    return train_step

In [9]:
# Initialize the cGAN: both networks start out trainable.
doG, doD = 1, 1
if useClassLabels:
    # For easy reset of notebook state.
    tf.keras.backend.clear_session()

    # Initialize the G & D networks and show their layouts
    generator, discriminator = init_NN_v2([ws, vs], toxicity, parameters)
    for network in (generator, discriminator):
        network.summary()

# Shuffle split chemicals: 10 random splits, each holding out 20% of the chemicals,
# with a fixed seed so the splits are reproducible.
X = chem_labels
ss = ShuffleSplit(test_size=0.2, n_splits=10, random_state=0)

Model: "generator"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
gen_class_label (InputLayer)    [(None, 1)]          0                                            
__________________________________________________________________________________________________
parallelwrapper_input0 (InputLa [(None, 126, 410)]   0                                            
__________________________________________________________________________________________________
gen_class_embedding (Embedding) (None, 1, 50)        50150       gen_class_label[0][0]            
__________________________________________________________________________________________________
gen_chem_feature_base (Model)   (None, 126, 279)     1428021     parallelwrapper_input0[0][0]     
__________________________________________________________________________________________

#### Training Notes:
- Generator wants gen_ability to be 1
- Discriminator wants disc_ability to be 1 and gen_ability to be 0

In [10]:
%%time
verbose = 0
# cGAN training loop with 10-fold cross-validation
k = 1
# Set which of G or D or both trains.
for train_index, test_index in ss.split(X): 
    start_time = timeit.default_timer()
    print("Running fold : ", k, " of 10")
    fish_index = []
    for chemical in train_index:
        results = np.where(fish == chemical)
        fish_index.extend(results[0].tolist())
    fish_index = np.array(fish_index)
    fish_index = fish_index.flatten()
    chem_index = train_index.flatten()

    ws_foldT = ws[chem_index]
    vs_foldT = vs[chem_index]
    chem_labels_foldT = chem_labels[:len(chem_index)] # rename chemical labels 0 - 801

    toxicity_foldT = toxicity[fish_index]
    fish_foldT = fish[fish_index]

    # raname toxicity labels 0 - 801
    chem_index_sorted = np.sort(chem_index)
    l = 0
    for label in chem_index_sorted:
        fish_foldT[fish_foldT == label] = l
        l += 1

    y = np.bincount(fish_foldT[:,0])
    y = y[y != 0]

    fish_index = []
    for chemical in test_index:
        results = np.where(fish == chemical)
        fish_index.extend(results[0].tolist())
    fish_index = np.array(fish_index)
    fish_index = fish_index.flatten()
    chem_index = test_index.flatten()

    ws_V = ws[chem_index]
    vs_V = vs[chem_index]

    # Create validation chemical list
    chemnames_array = np.array(chemnames)
    chemnames_V = chemnames_array[chem_index]
    chemnames_V = chemnames_V.tolist()

    chem_labels_V = chem_labels[:len(chem_index)]

    toxicity_V = toxicity[fish_index]
    fish_V = fish[fish_index]

    # raname toxicity labels 0 - 801
    chem_index_sorted = np.sort(chem_index)
    l = 0
    for label in chem_index_sorted:
        fish_V[fish_V == label] = l
        l += 1

    toxicity_V = toxicity[fish_index]
    fish_V = fish[:len(fish_index)]
    z = np.bincount(fish_V[:,0])
    z = z[z != 0]        

    epochs = 800
    best_kappa = 0
    training_loss = np.zeros((epochs,2))
    val_loss = np.zeros((epochs,2))
    train_step = get_train_function()
    j = 0
    while j < epochs:
        info = train_step([ws_foldT,vs_foldT],[toxicity_foldT],[chem_labels_foldT],[fish_foldT],doG,doD)

        # find out how well the two NNs are doing on training set
        gen_lab, Gchem_features = generator.predict([ws_foldT,vs_foldT,chem_labels_foldT])
        expanded_Dpw_model = tf.repeat(Gchem_features, repeats=y, axis=0)

        fake_output = discriminator.predict([Gchem_features,gen_lab,chem_labels_foldT])
        gen_ability_T = Dloss_function(tf.zeros_like(fake_output), fake_output).numpy()

        real_output = discriminator.predict([expanded_Dpw_model,toxicity_foldT,fish_foldT])
        disc_ability_T = Dloss_function(tf.ones_like(real_output), real_output).numpy()

        doG = 1
        doD = 1

        # if G is winning stop training G
        if (disc_ability_T-gen_ability_T) < 0 or gen_ability_T > 0.95:
            doG = 0
        # if D is winning stop training D
        if disc_ability_T > 0.95:
            doD = 0
        if gen_ability_T > 0.90 and disc_ability_T > 0.90:
            doG = 1
            doD = 1            

        # find out how well the two NNs are doing on validation set
        gen_lab_V, Gchem_features_V = generator.predict([ws_V,vs_V,chem_labels_V])
        expanded_Dpw_model_V = tf.repeat(Gchem_features_V, repeats=z, axis=0)

        fake_output_V = discriminator.predict([Gchem_features_V,gen_lab_V,chem_labels_V])
        gen_ability_V = Dloss_function(tf.zeros_like(fake_output_V), fake_output_V).numpy()

        real_output_V = discriminator.predict([expanded_Dpw_model_V,toxicity_V,fish_V])
        disc_ability_V = Dloss_function(tf.ones_like(real_output_V), real_output_V).numpy()

        # Calculate chemical activity - ignoring warning due to potential division by zero
        with warnings.catch_warnings():
            warnings.simplefilter(action='ignore', category=Warning)
            [gen_activity_table, tox_activity_table, gen_AggE, tox_AggE] = calc_AggE_indv(toxicity_V, chem_labels_V, 
                                                                                          chemnames_V, gen_lab_V, 
                                                                                          fish_V, endpoints, z)

            metrics = display_conf_matrix(gen_activity_table, tox_activity_table)
        model_kappa = metrics[0]

        if verbose:
            print("\nTraining Dataset")
            print('    Kappa: ', metrics[0], '  AUROC: ', metrics[1], '    SE', metrics[2])
            print("gen_ability:", gen_ability_T, " disc_ability:", disc_ability_T)
            print("Training State: doG =", doG, ", doD =", doD, " j = ", j, "\n")

        if model_kappa > best_kappa:
            best_kappa = model_kappa
            best_metrics = metrics

            # find out how well the two NNs are doing
            print("\n    Validation Dataset")
            print('    Kappa: ', metrics[0], '  AUROC: ', metrics[1], '    SE', metrics[2])
            print("    gen_ability:", gen_ability_V, " disc_ability:", disc_ability_V)
            print("    Training State: doG =", doG, ", doD =", doD, " j = ", j, "\n")

        if j % 25 == 0:
            print(j, end=" ")
        training_loss[j] = [gen_ability_T, disc_ability_T]
        val_loss[j] = [gen_ability_V, disc_ability_V]
        j += 1
    k += 1

    # Save training results
    best_kappa_f = round(best_kappa,5)

    model_ID = "AG-model-GT-"+dataType+"-Kappa-"+str(best_kappa_f)+"-"+str(k)+"-fold.h5"

    summary_file_df = write_training_file(parameters,[model_ID, concentrations, ws.shape[1], gen_ability_V, disc_ability_V],
                                  best_metrics, '/home2/ajgreen4/Read-Across_w_GAN/output/10-fold-crossval-4-14-21.xlsx')

    # Re-initialize the cGAN
    doG = 1
    doD = 1
    # Initialize the G & D netowrks  
    [generator, discriminator] = init_NN_v2([ws, vs],toxicity,parameters)

    # Determine time taken to run fold
    elapsed = timeit.default_timer() - start_time
    print(elapsed)

Running fold :  1  of 10

    Validation Dataset
    Kappa:  0.134   AUROC:  0.589     SE 61.2
    gen_ability: 0.037265867  disc_ability: 1.430598
    Training State: doG = 1 , doD = 0  j =  0 

0 
    Validation Dataset
    Kappa:  0.163   AUROC:  0.6055     SE 61.2
    gen_ability: 0.045626413  disc_ability: 1.4746188
    Training State: doG = 1 , doD = 0  j =  1 


    Validation Dataset
    Kappa:  0.167   AUROC:  0.5972     SE 51.0
    gen_ability: 0.2560663  disc_ability: 0.48462793
    Training State: doG = 1 , doD = 1  j =  18 

25 
    Validation Dataset
    Kappa:  0.194   AUROC:  0.6219     SE 61.2
    gen_ability: 0.4037956  disc_ability: 0.37176904
    Training State: doG = 0 , doD = 1  j =  40 

50 75 100 125 150 175 200 225 250 275 300 325 350 375 400 425 450 475 500 525 550 575 600 625 650 675 700 725 750 775 2216.4027229570784
Running fold :  2  of 10
0 
    Validation Dataset
    Kappa:  0.004   AUROC:  0.5061     SE 100.0
    gen_ability: 0.0015555418  disc_ability:

In [None]:
# Terminate the kernel process immediately with exit status 0.
os._exit(0)

# 