## Watsite On-the-fly



This notebook contains the code for the grid-to-grid mapping model described in our publication. 
Training data files are not included because they are produced with the commercial program FLAP. You can purchase FLAP to generate your own data, or use any other input type with similar properties. 

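This model works with 48x48x48 grids. Input arrays must be stored in HDF5 (h5) format with the shape (48,48,48,number of channels). The model outputs one channel per water-occupancy threshold, i.e. (48,48,48,6) with the six default thresholds. The prediction helpers sum these channels into a single (48,48,48) grid before saving.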


In [21]:
import numpy as np
import os,glob,time
from keras.optimizers import Adam,SGD
from keras.layers import Activation, Input, Dropout, merge, Concatenate, multiply,concatenate,add
from keras.layers.convolutional import Conv3D, UpSampling3D, Deconv3D,Conv3DTranspose
from keras.layers.normalization import BatchNormalization
from keras.layers.advanced_activations import LeakyReLU, ELU
from keras.models import Model
from keras.layers import Flatten, Dense, Reshape, Lambda,GaussianNoise,MaxPooling3D
from keras.utils import generic_utils as keras_generic_utils
import keras.backend as K
import tensorflow as tf
import keras
from sklearn.model_selection import train_test_split

from keras import callbacks
import matplotlib.pyplot



import h5py
import random


# Seed NumPy's global RNG so shuffling and channel permutations are repeatable.
np.random.seed(42) # for consistent results

# Enumerate GPUs by PCI bus id and expose only GPU 0 to TensorFlow.
os.environ.update({
    'CUDA_DEVICE_ORDER': 'PCI_BUS_ID',
    'CUDA_VISIBLE_DEVICES': '0',
})

print('Keras version is:', keras.__version__)
print('Tensorflow version is:', tf.__version__)

Keras version is: 2.2.4
Tensorflow version is: 1.8.0


### Models: Inception model and baseline U-Net

In [None]:
def inception_block(inputs, depth, batch_mode=0, splitted=False, activation='relu'):
    """Build a 3D Inception-style block whose output has ``depth`` channels.

    Four parallel branches are concatenated on the last (channel) axis:
      * 1x1x1 conv                                  -> depth/4 channels
      * 1x1x1 conv -> act -> 3x3x3 conv             -> depth/2 channels
      * 1x1x1 conv -> act -> 5x5x5 conv             -> depth/8 channels
      * 3x3x3 max-pool (stride 1) -> 1x1x1 conv     -> depth/8 channels
    The concatenation is followed by BatchNormalization and the activation.
    All convolutions and the pooling use stride 1 with 'same' padding, so the
    spatial shape of ``inputs`` is preserved.

    Args:
        inputs: 5D Keras tensor, channels last.
        depth: number of output channels; must be a multiple of 16.
        batch_mode: unused; kept for backward compatibility.
        splitted: if True, replace the 3x3x3 / 5x5x5 convolutions with a
            factorized (1,k,1) -> BN -> act -> (k,1,1) pair (experimental;
            only two of the three spatial axes are factorized).
        activation: 'relu' (implemented as LeakyReLU(0.0)) or 'elu'.

    Returns:
        Keras tensor with the spatial shape of ``inputs`` and ``depth`` channels.

    Raises:
        ValueError: if ``activation`` is not 'relu' or 'elu'.
    """
    assert depth % 16 == 0, 'depth must be a multiple of 16'
    if activation not in ('relu', 'elu'):
        raise ValueError("activation must be 'relu' or 'elu', got {!r}".format(activation))

    def actv():
        # Each call must create a NEW layer instance (layers are not reusable here).
        return LeakyReLU(0.0) if activation == 'relu' else ELU(1.0)

    # Branch 1: 1x1x1 projection.
    c1_1 = Conv3D(depth // 4, 1, kernel_initializer='he_normal', padding='same')(inputs)

    # Branch 2: 1x1x1 reduction -> 3-wide convolution.
    c2_1 = Conv3D(depth // 8 * 3, 1, kernel_initializer='he_normal', padding='same')(inputs)
    c2_1 = actv()(c2_1)
    if splitted:
        print('WARNING: Splitted not supported.')
        c2_2 = Conv3D(depth // 2, (1, 3, 1), kernel_initializer='he_normal', padding='same')(c2_1)
        c2_2 = BatchNormalization(axis=-1)(c2_2)
        c2_2 = actv()(c2_2)
        c2_3 = Conv3D(depth // 2, (3, 1, 1), kernel_initializer='he_normal', padding='same')(c2_2)
    else:
        c2_3 = Conv3D(depth // 2, 3, kernel_initializer='he_normal', padding='same')(c2_1)

    # Branch 3: 1x1x1 reduction -> 5-wide convolution.
    # NOTE: no BatchNormalization follows the reduction conv (same as branch 2).
    c3_1 = Conv3D(depth // 16, 1, kernel_initializer='he_normal', padding='same')(inputs)
    c3_1 = actv()(c3_1)
    if splitted:
        # Kernel sizes must be tuples: a bare third positional argument to
        # Conv3D is `strides`, which would downsample this branch and break
        # the concatenation below.
        c3_2 = Conv3D(depth // 8, (1, 5, 1), kernel_initializer='he_normal', padding='same')(c3_1)
        c3_2 = BatchNormalization(axis=-1)(c3_2)
        c3_2 = actv()(c3_2)
        c3_3 = Conv3D(depth // 8, (5, 1, 1), kernel_initializer='he_normal', padding='same')(c3_2)
    else:
        c3_3 = Conv3D(depth // 8, 5, kernel_initializer='he_normal', padding='same')(c3_1)

    # Branch 4: stride-1 max-pool -> 1x1x1 projection.
    p4_1 = MaxPooling3D(pool_size=3, strides=1, padding='same')(inputs)
    c4_2 = Conv3D(depth // 8, 1, kernel_initializer='he_normal', padding='same')(p4_1)

    # depth/4 + depth/2 + depth/8 + depth/8 == depth channels.
    res = concatenate([c1_1, c2_3, c3_3, c4_2], axis=-1)
    res = BatchNormalization(axis=-1)(res)
    res = actv()(res)
    return res


def _shortcut(_input, residual):
    """Add ``residual`` to ``_input``, projecting ``_input`` when shapes differ.

    If ``residual`` is spatially smaller than ``_input`` (integer stride ratio)
    or has a different channel count, ``_input`` is passed through a strided
    1x1x1 Conv3D with as many filters as ``residual`` has channels; otherwise
    the identity is used. Both tensors are 5D, channels last.

    Returns:
        Keras tensor ``shortcut + residual``.
    """
    in_shape = _input._keras_shape
    res_shape = residual._keras_shape
    stride_width = in_shape[1] // res_shape[1]
    stride_height = in_shape[2] // res_shape[2]
    stride_depth = in_shape[3] // res_shape[3]
    equal_channels = res_shape[4] == in_shape[4]

    shortcut = _input
    # 1 X 1 conv if shape is different. Else identity.
    if stride_width > 1 or stride_height > 1 or stride_depth > 1 or not equal_channels:
        # Filter count is the channel axis (index 4); index 0 is the batch size (None).
        shortcut = Conv3D(res_shape[4], 1,
                          strides=(stride_width, stride_height, stride_depth),
                          kernel_initializer="he_normal", padding="valid")(_input)

    return add([shortcut, residual])


def rblock(inputs, num, depth, scale=0.1):
    """Scaled residual block: ELU(shortcut(inputs) + scale * BN(Conv3D(inputs))).

    ``num`` is the Conv3D kernel size and ``depth`` its number of filters;
    ``_shortcut`` projects ``inputs`` if its shape differs from the branch.
    """
    branch = Conv3D(depth, num, padding='same')(inputs)
    branch = BatchNormalization(axis=-1)(branch)
    scaled_branch = Lambda(lambda t: t * scale)(branch)
    summed = _shortcut(inputs, scaled_branch)
    return ELU()(summed)

def NConvolution3D(nb_filter, dim, padding='same', strides=1):
    """Return a layer factory applying Conv3D -> BatchNormalization -> ELU.

    The returned callable takes a Keras tensor and applies
    ``Conv3D(nb_filter, dim, strides=strides, padding=padding)``, channels-last
    batch normalization and an ELU activation, returning the result.
    """
    def apply(tensor):
        out = Conv3D(nb_filter, dim, strides=strides, padding=padding)(tensor)
        out = BatchNormalization(axis=-1)(out)
        out = ELU()(out)
        return out

    return apply

def inception(input_img_dim, num_output_channels):
    """Build and compile the Inception U-Net.

    Encoder: five ``inception_block`` stages (32, 64, 128, 256, 512 channels);
    the first four are each followed by a stride-2 ``NConvolution3D``
    downsampling and Dropout(0.5). Decoder: four stages that upsample by 2,
    concatenate with the matching encoder output refined by a 1x1x1
    ``rblock``, and apply an ``inception_block`` plus Dropout(0.5). A 1x1x1
    sigmoid Conv3D produces the output.

    Because there are four stride-2 downsamplings matched by four x2
    upsamplings, spatial sizes should be divisible by 16 (48 is) for the skip
    connections to line up.

    Args:
        input_img_dim: input shape without the batch axis, channels last,
            e.g. (48, 48, 48, 12).
        num_output_channels: number of sigmoid output channels.

    Returns:
        A Keras ``Model`` compiled with Adam and ``gen_dice_loss``.
    """
    optimizer = Adam(lr=0.045, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
    splitted = False
    act = 'elu'

    inputs = Input(shape=input_img_dim, name='main_input')

    # Encoder: inception block -> stride-2 conv downsampling -> dropout.
    conv1 = inception_block(inputs, 32, batch_mode=2, splitted=splitted, activation=act)
    pool1 = NConvolution3D(32, 3, padding='same', strides=2)(conv1)
    pool1 = Dropout(0.5)(pool1)

    conv2 = inception_block(pool1, 64, batch_mode=2, splitted=splitted, activation=act)
    pool2 = NConvolution3D(64, 3, padding='same', strides=2)(conv2)
    pool2 = Dropout(0.5)(pool2)

    conv3 = inception_block(pool2, 128, batch_mode=2, splitted=splitted, activation=act)
    pool3 = NConvolution3D(128, 3, padding='same', strides=2)(conv3)
    pool3 = Dropout(0.5)(pool3)

    conv4 = inception_block(pool3, 256, batch_mode=2, splitted=splitted, activation=act)
    pool4 = NConvolution3D(256, 3, padding='same', strides=2)(conv4)
    pool4 = Dropout(0.5)(pool4)

    # Bottleneck.
    conv5 = inception_block(pool4, 512, batch_mode=2, splitted=splitted, activation=act)
    conv5 = Dropout(0.5)(conv5)

    # Decoder: upsample x2 and concatenate with the rblock-refined encoder output.
    # (Layer creation order is kept as-is so previously saved weights still load.)
    after_conv4 = rblock(conv4, 1, 256)
    up6 = concatenate([UpSampling3D(size=2)(conv5), after_conv4], axis=-1)
    conv6 = inception_block(up6, 256, batch_mode=2, splitted=splitted, activation=act)
    conv6 = Dropout(0.5)(conv6)

    after_conv3 = rblock(conv3, 1, 128)
    up7 = concatenate([UpSampling3D(size=2)(conv6), after_conv3], axis=-1)
    conv7 = inception_block(up7, 128, batch_mode=2, splitted=splitted, activation=act)
    conv7 = Dropout(0.5)(conv7)

    after_conv2 = rblock(conv2, 1, 64)
    up8 = concatenate([UpSampling3D(size=2)(conv7), after_conv2], axis=-1)
    conv8 = inception_block(up8, 64, batch_mode=2, splitted=splitted, activation=act)
    conv8 = Dropout(0.5)(conv8)

    after_conv1 = rblock(conv1, 1, 32)
    up9 = concatenate([UpSampling3D(size=2)(conv8), after_conv1], axis=-1)
    conv9 = inception_block(up9, 32, batch_mode=2, splitted=splitted, activation=act)
    conv9 = Dropout(0.5)(conv9)

    conv10 = Conv3D(num_output_channels, 1, kernel_initializer='he_normal',
                    activation='sigmoid', name='main_output')(conv9)

    model = Model(inputs=inputs, outputs=[conv10])
    model.compile(optimizer=optimizer, loss=gen_dice_loss, metrics=['acc', precision, recall])

    return model


def baselineUnet(input_img_dim, num_output_channels):
    """Build the (uncompiled) baseline 3D U-Net generator.

    Encoder: six Conv3D(kernel 2, stride 2) -> BatchNorm [-> Dropout(0.5)] ->
    LeakyReLU(0.2) stages with 32, 64, 128, 256, 512, 512 filters (no dropout
    in the first stage).
    Decoder: UpSampling3D -> Conv3D(kernel 2) -> BatchNorm -> Dropout(0.5) ->
    concatenate with an encoder output -> ReLU, then a final x2 upsampling
    and a 3x3x3 sigmoid Conv3D.

    For a 48^3 input the spatial sizes are
        48 -> 24 (en_1) -> 12 (en_2) -> 6 (en_3) -> 3 (en_4) -> 2 (en_5) -> 1 (en_6)
    and the decoder goes
        1 --x3--> 3 (+en_4) -> 6 (+en_3) -> 12 (+en_2) -> 24 (+en_1) -> 48.
    en_5 has no skip connection; the x3 upsampling of en_6 is what makes the
    first decoder stage line up with en_4. This layout was sized for 48^3
    inputs; other sizes may produce mismatched skip shapes.

    Args:
        input_img_dim: input shape without the batch axis, channels last.
        num_output_channels: number of sigmoid output channels.

    Returns:
        Uncompiled Keras ``Model`` named 'unet_generator'.
    """
    stride = 2

    # batch norm merge axis (channels last)
    bn_axis = -1

    input_layer = Input(shape=input_img_dim, name="unet_input")

    # 1 encoder C32: 48 -> 24
    en_1 = Conv3D(32, 2, padding='same', strides=stride)(input_layer)
    en_1 = BatchNormalization(name='gen_en_bn_1', axis=bn_axis)(en_1)
    en_1 = LeakyReLU(alpha=0.2)(en_1)

    # 2 encoder C64: 24 -> 12
    en_2 = Conv3D(64, 2, padding='same', strides=stride)(en_1)
    en_2 = BatchNormalization(name='gen_en_bn_2', axis=bn_axis)(en_2)
    en_2 = Dropout(0.5)(en_2)
    en_2 = LeakyReLU(alpha=0.2)(en_2)

    # 3 encoder C128: 12 -> 6
    en_3 = Conv3D(128, 2, padding='same', strides=stride)(en_2)
    en_3 = BatchNormalization(name='gen_en_bn_3', axis=bn_axis)(en_3)
    en_3 = Dropout(0.5)(en_3)
    en_3 = LeakyReLU(alpha=0.2)(en_3)

    # 4 encoder C256: 6 -> 3
    en_4 = Conv3D(256, 2, padding='same', strides=stride)(en_3)
    en_4 = BatchNormalization(name='gen_en_bn_4', axis=bn_axis)(en_4)
    en_4 = Dropout(0.5)(en_4)
    en_4 = LeakyReLU(alpha=0.2)(en_4)

    # 5 encoder C512: 3 -> 2 ('same' padding rounds up)
    en_5 = Conv3D(512, 2, padding='same', strides=stride)(en_4)
    en_5 = BatchNormalization(name='gen_en_bn_5', axis=bn_axis)(en_5)
    en_5 = Dropout(0.5)(en_5)
    en_5 = LeakyReLU(alpha=0.2)(en_5)

    # 6 encoder C512: 2 -> 1
    en_6 = Conv3D(512, 2, padding='same', strides=stride)(en_5)
    en_6 = BatchNormalization(name='gen_en_bn_6', axis=bn_axis)(en_6)
    en_6 = Dropout(0.5)(en_6)
    en_6 = LeakyReLU(alpha=0.2)(en_6)

    # Decoder 4: upsample en_6 x3 (1 -> 3) and join with en_4.
    de_4 = UpSampling3D(size=3)(en_6)
    de_4 = Conv3D(512, 2, padding='same')(de_4)
    de_4 = BatchNormalization(name='gen_de_bn_4', axis=bn_axis)(de_4)
    de_4 = Dropout(0.5)(de_4)
    de_4 = concatenate([de_4, en_4], axis=-1)
    de_4 = Activation('relu')(de_4)

    # Decoder 5: 3 -> 6, join with en_3.
    de_5 = UpSampling3D(size=(2, 2, 2))(de_4)
    de_5 = Conv3D(256, 2, padding='same')(de_5)
    de_5 = BatchNormalization(name='gen_de_bn_5', axis=bn_axis)(de_5)
    de_5 = Dropout(0.5)(de_5)
    de_5 = concatenate([de_5, en_3], axis=-1)
    de_5 = Activation('relu')(de_5)

    # Decoder 6: 6 -> 12, join with en_2.
    de_6 = UpSampling3D(size=(2, 2, 2))(de_5)
    de_6 = Conv3D(128, 2, padding='same')(de_6)
    de_6 = BatchNormalization(name='gen_de_bn_6', axis=bn_axis)(de_6)
    de_6 = Dropout(0.5)(de_6)
    de_6 = concatenate([de_6, en_2], axis=-1)
    de_6 = Activation('relu')(de_6)

    # Decoder 7: 12 -> 24, join with en_1.
    de_7 = UpSampling3D(size=(2, 2, 2))(de_6)
    de_7 = Conv3D(64, 2, padding='same')(de_7)
    de_7 = BatchNormalization(name='gen_de_bn_7', axis=bn_axis)(de_7)
    de_7 = Dropout(0.5)(de_7)
    de_7 = concatenate([de_7, en_1], axis=-1)
    de_7 = Activation('relu')(de_7)

    # Output: 24 -> 48, one sigmoid channel per requested output.
    de_8 = UpSampling3D(size=(2, 2, 2))(de_7)
    de_8 = Conv3D(num_output_channels, 3, padding='same')(de_8)
    de_8 = Activation('sigmoid')(de_8)

    unet_generator = Model(inputs=[input_layer], outputs=[de_8], name='unet_generator')
    return unet_generator

### Preprocessing and data generators

In [None]:
def normalize(x):
    """Clip atom-grid values to [-20, 20] and divide by 10, mapping them into [-2, 2]."""
    lower, upper = -20.0, 20.0
    clipped = np.clip(x, lower, upper)
    return clipped / 10
    

def normalize_wat(X, thresh):
    """Set every element of ``X`` strictly greater than ``thresh`` to 1.0, in place.

    Elements less than or equal to ``thresh`` are left unchanged. ``X`` is
    modified and also returned, so pass a copy if the original is still needed.
    """
    above = X > thresh
    X[above] = 1.
    return X
  
# Water-occupancy thresholds: one target channel is built per value (values
# above the threshold are set to 1 by normalize_wat).
thresh = [0.0, 0.02, 0.03, 0.045, 0.06, 0.07]

def load_grids(folders, batch_size=10, printname=False, shuffle=True, permutate_axis=None):
    """Endlessly yield ``(x, y)`` training batches read from HDF5 grid files.

    Each file must contain datasets 'x' (input atom grids, channels last) and
    'y' (water grids; only channel 0 is used). For every file, ``y[..., 0]``
    is expanded into one target channel per entry of the module-level
    ``thresh`` via ``normalize_wat``. Inputs get NaNs replaced by 0 and are
    passed through ``normalize``. Targets get NaNs replaced by 0.

    Args:
        folders: sequence of HDF5 file paths; shuffled IN PLACE before every
            pass when ``shuffle`` is True.
        batch_size: maximum number of files per batch. The final batch of a
            pass may be smaller.
        printname: if True, append each path to 'dataset_list.txt'.
        shuffle: reshuffle ``folders`` before every pass.
        permutate_axis: optional index of an input channel whose values are
            randomly permuted, for permutation feature importance (e.g. eli5).

    Yields:
        ``(x_batch, y_batch)`` numpy arrays. Files that fail to open are
        reported and skipped. A batch in which no file could be opened is
        not yielded.

    Raises:
        TypeError: if ``permutate_axis`` is not an integer.
        IndexError: if ``permutate_axis`` is not a valid channel index.
        ValueError: if ``folders`` is empty (the generator would never yield).
        RuntimeError: if a whole pass yields no batch (no file could be opened).
    """
    if permutate_axis is not None and not isinstance(permutate_axis, (int, np.integer)):
        raise TypeError('permutate_axis must be an integer, got {!r}'.format(permutate_axis))
    num_images = len(folders)
    if num_images == 0:
        raise ValueError('load_grids: no input files given')

    # iterate forever because Keras' fit_generator expects an endless generator
    while True:
        if shuffle:
            np.random.shuffle(folders)
        yielded_any = False
        for i in range(0, num_images, batch_size):
            # The last batch of a pass may be shorter than batch_size.
            i_end = min(i + batch_size, num_images)
            x_batch_atomgrids = []
            y_batch_watergrids = []
            for j in range(i, i_end):
                if printname:
                    with open('dataset_list.txt', 'a') as fil:
                        print('Folder: {0}'.format(folders[j]), file=fil)
                try:
                    h5 = h5py.File(folders[j], 'r')
                except Exception as e:
                    # Best effort: report unreadable files and skip them.
                    print(e)
                    continue
                with h5:
                    x_grids = h5['x'][:]
                    y_grid = h5['y'][:, :, :, 0]

                if permutate_axis is not None:
                    if permutate_axis >= x_grids.shape[-1]:
                        raise IndexError('permutate_axis {0} out of range for {1} channels'.format(
                            permutate_axis, x_grids.shape[-1]))
                    # np.random.permutation on a multi-dimensional array
                    # shuffles along its first axis only.
                    x_grids[..., permutate_axis] = np.random.permutation(x_grids[..., permutate_axis])
                x_batch_atomgrids.append(x_grids)

                # One target channel per threshold; copy y_grid each time
                # because normalize_wat modifies its argument in place.
                y_batch_watergrids.append(
                    np.stack([normalize_wat(np.array(y_grid), t) for t in thresh], axis=-1))

            if not x_batch_atomgrids:
                continue

            x_batch_atomgrids = normalize(np.nan_to_num(np.array(x_batch_atomgrids)))
            y_batch_watergrids = np.nan_to_num(np.array(y_batch_watergrids))
            yielded_any = True
            yield x_batch_atomgrids, y_batch_watergrids

        if not yielded_any:
            raise RuntimeError('load_grids: none of the input files could be opened')
            

            
def process_batch(j, printname=False):
    """Read one HDF5 grid file and return its raw ``(x, y)`` arrays.

    Args:
        j: path of the HDF5 file.
        printname: if True, print the path before reading it.

    Returns:
        ``(x_batch, y_batch)``, where ``x_batch`` is the full 'x' dataset and
        ``y_batch`` is channel 0 of 'y' with a trailing singleton axis, or
        ``None`` if the file cannot be opened.
    """
    if printname:
        print('Folder: {0}'.format(j))
    try:
        f = h5py.File(j, 'r')
    except Exception as e:
        # Best effort, matching load_grids: report and skip unreadable files.
        print(e)
        return None

    # Close the file even if a dataset is missing or a read fails.
    with f:
        x_batch = f['x'][:]
        y_batch = np.expand_dims(f['y'][:, :, :, 0], axis=-1)
    return x_batch, y_batch
    

#### Losses, metrics, and other helper functions

In [None]:

    
    
def cropped_dice(y_true, y_pred, edge=20, grid_size=48):
    """Metric restricted to the central ``edge``^3 voxel cube of the grid.

    NOTE: despite its name this returns ``recall`` on the cropped region, not
    Dice. The name is kept because Keras reports the metric under it. A Dice
    variant would call ``dice_coef`` on the same slices.

    Args:
        y_true, y_pred: 5D tensors (batch, x, y, z, channels) with cubic
            spatial size ``grid_size``.
        edge: side length, in voxels, of the central cube.
        grid_size: spatial side length of the full grid (48 in this notebook).
    """
    # The original comment ("dice < 5. A.") suggests the cube spans about
    # +/-5 Angstrom around the centre -- confirm against the grid spacing.
    start = grid_size // 2 - edge // 2
    end = start + edge

    return recall(y_true[:, start:end, start:end, start:end, :],
                  y_pred[:, start:end, start:end, start:end, :])



def dice_coefprint(y_true, y_pred, smooth=K.epsilon()):
    """Debug variant of ``dice_coef`` that prints its value when evaluated.

    Computes the same soft Dice (2*sum|A*B| + smooth) / (sum A^2 + sum B^2 + smooth)
    along the last axis, and wraps it in ``tf.Print`` so the value is logged.
    ref: https://arxiv.org/pdf/1606.04797v1.pdf
    """
    value = dice_coef(y_true, y_pred, smooth)
    return tf.Print(value, [value])


def dice_coef_thresh(y_true, y_pred):
    """Soft Dice over all elements, with +1 smoothing in numerator and denominator.

    NOTE: despite the name, no threshold is applied to ``y_pred``. A
    thresholded variant would use ``K.cast(K.greater(y_pred_f, t), 'float32')``.
    """
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    denominator = K.sum(truth) + K.sum(pred) + 1.0
    return (2.0 * overlap + 1.0) / denominator

def dice_coef(y_true, y_pred, smooth=K.epsilon()):
    """Soft (V-Net style) Dice coefficient reduced along the last axis only.

    Dice = (2*sum|A*B| + smooth) / (sum(A^2) + sum(B^2) + smooth), where the
    sums run over axis -1. For 5D (batch, x, y, z, channel) inputs this
    yields one value per voxel.
    ref: https://arxiv.org/pdf/1606.04797v1.pdf
    """
    numerator = 2. * K.sum(K.abs(y_true * y_pred), axis=-1) + smooth
    denominator = K.sum(K.square(y_true), -1) + K.sum(K.square(y_pred), -1) + smooth
    return numerator / denominator

def dice_coef_loss(y_true, y_pred):
    """Loss form of ``dice_coef``: 1 minus the soft Dice coefficient."""
    dice = dice_coef(y_true, y_pred)
    return 1 - dice

def gen_dice_loss(y_true, y_pred):
    '''
    Soft Dice loss pooled over the whole batch (all voxels and all channels).

    Modelled on the generalised Dice loss (https://arxiv.org/pdf/1707.03237),
    but with the weight w = 1 / sum(r)^2 taken from a single sum over every
    channel, w is one scalar that cancels between numerator and denominator:

        1 - 2*w*sum(r*p) / (w*(sum(r) + sum(p)))  ==  1 - 2*sum(r*p) / (sum(r) + sum(p))

    The right-hand side is computed directly. No cross-entropy term is
    included. The epsilon smoothing keeps the loss finite (0) when both
    y_true and y_pred are all zeros, where the weighted form gives 0/0 = NaN.
    '''
    smooth = K.epsilon()
    intersection = K.sum(y_true * y_pred)
    total = K.sum(y_true) + K.sum(y_pred)
    return 1. - (2. * intersection + smooth) / (total + smooth)




def precision(y_true, y_pred):
    """Batch-wise precision: true positives / (predicted positives + epsilon).

    Products and predictions are clipped to [0, 1] and rounded to the nearest
    integer before counting, so the metric is computed over the whole batch
    rather than averaged per sample.
    """
    hits = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    called = K.sum(K.round(K.clip(y_pred, 0, 1)))
    return hits / (called + K.epsilon())
def recall(y_true, y_pred):
    """Batch-wise recall: true positives / (actual positives + epsilon).

    Products and ground-truth values are clipped to [0, 1] and rounded to the
    nearest integer before counting, so the metric is computed over the whole
    batch rather than averaged per sample.
    """
    hits = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    actual = K.sum(K.round(K.clip(y_true, 0, 1)))
    return hits / (actual + K.epsilon())






def log_gen(generator_model, validation_atomgrid, wat_ground_truth, epoch, scale=35):
    """Save scaled predictions and ground truth for one epoch as .npy files.

    Writes ``e{epoch}_predicted.npy`` (model prediction * scale) and
    ``e{epoch}_groundtruth.npy`` (ground truth * scale) to the working
    directory. ``wat_ground_truth`` is NOT modified, so the function can be
    called every epoch with the same validation array.

    Args:
        generator_model: object with a Keras-style ``predict`` method.
        validation_atomgrid: model input batch.
        wat_ground_truth: array of target grids.
        epoch: value inserted into the output file names.
        scale: multiplier applied to both arrays. The default 35 presumably
            undoes a target scaling -- confirm against the data pipeline.
    """
    watgrid = generator_model.predict(validation_atomgrid) * scale
    # Scale a copy: an in-place `*=` would mutate the caller's array and
    # compound the factor on every call.
    ground_truth = np.asarray(wat_ground_truth) * scale

    np.save('e{0}_predicted.npy'.format(epoch), watgrid)
    np.save('e{0}_groundtruth.npy'.format(epoch), ground_truth)

def feed_to_generator(gen_model, X, scaling_fac=1.0):
    """Run ``gen_model.predict(X)`` and return the prediction times ``scaling_fac``."""
    prediction = gen_model.predict(X)
    return prediction * scaling_fac
    
    
def water_predictor(generator_model, atomgrid, wat, prefix='', grid_shape=(48, 48, 48)):
    """Predict water grids for one sample and save prediction and ground truth.

    The model's output channels (one per occupancy threshold in this
    notebook) are summed, as are the channels of ``wat``. Both results are
    reshaped to ``grid_shape``, so ``atomgrid`` and ``wat`` must hold a single
    sample. Writes ``<prefix>watsitemapper_predicted.npy`` and
    ``<prefix>watsitemapper_groundtruth.npy``.

    Args:
        generator_model: model used for prediction (via ``feed_to_generator``).
        atomgrid: input batch of size 1.
        wat: ground-truth grids with channels last.
        prefix: string prepended to the output file names.
        grid_shape: spatial shape of the saved grids.
    """
    # Use the model argument rather than any notebook-global model.
    pr_wat = feed_to_generator(generator_model, atomgrid)
    print(pr_wat.shape)
    pr_wat = np.reshape(np.sum(pr_wat, axis=-1), grid_shape)
    wat = np.reshape(np.sum(wat, axis=-1), grid_shape)

    np.save(prefix + 'watsitemapper_predicted.npy', pr_wat)
    np.save(prefix + 'watsitemapper_groundtruth.npy', wat)
    
def water_predictor_num(generator_model, atomgrid, wat, i, prefix='', grid_shape=(48, 48, 48)):
    """Like ``water_predictor`` but suffixes the output files with ``i``.

    Sums the channels of the prediction and of ``wat``, reshapes both to
    ``grid_shape`` (so a single sample is expected) and writes
    ``<prefix>watsitemapper_predicted_<i>.npy`` and
    ``<prefix>watsitemapper_groundtruth_<i>.npy``.

    Args:
        generator_model: model used for prediction (via ``feed_to_generator``).
        atomgrid: input batch of size 1.
        wat: ground-truth grids with channels last.
        i: identifier formatted into the file names (e.g. an index or PDB id).
        prefix: string prepended to the output file names.
        grid_shape: spatial shape of the saved grids.
    """
    # Use the model argument rather than any notebook-global model.
    pr_wat = feed_to_generator(generator_model, atomgrid)
    pr_wat = np.reshape(np.sum(pr_wat, axis=-1), grid_shape)
    wat = np.reshape(np.sum(wat, axis=-1), grid_shape)

    np.save(prefix + 'watsitemapper_predicted_{0}.npy'.format(i), pr_wat)
    np.save(prefix + 'watsitemapper_groundtruth_{0}.npy'.format(i), wat)
    


In [None]:
def tf_precision(y_true, y_pred):
    """Streaming precision metric backed by ``tf.metrics.precision``.

    Initializes the TF local variables in the current Keras session and
    returns the metric's update op (second element of the returned pair).
    """
    _, update_op = tf.metrics.precision(y_true, y_pred)
    K.get_session().run(tf.local_variables_initializer())
    return update_op
def tf_recall(y_true, y_pred):
    """Streaming recall metric backed by ``tf.metrics.recall``.

    Initializes the TF local variables in the current Keras session and
    returns the metric's update op (second element of the returned pair).
    """
    _, update_op = tf.metrics.recall(y_true, y_pred)
    K.get_session().run(tf.local_variables_initializer())
    return update_op

### Model compile and run

Two pre-trained models are provided: the Inception U-Net and the baseline U-Net. To use either one, call the corresponding model-creator function (`inception` or `baselineUnet`) and load its matching weights file.

In [None]:
#probes_idx = [22, 53, 30, 47, 66, 24, 4, 58, 33, 61, 19, 32]



#save_weight = callbacks.ModelCheckpoint('weights/weights.inception.mapper_watsite_multithresh_koesset.hdf5', monitor='val_loss', verbose=0, save_best_only=True, period=1)
#save_weight = callbacks.ModelCheckpoint('weights/weights.unetbaseline.mapper_watsite_multithresh_koesset.hdf5', monitor='val_loss', verbose=0, save_best_only=True, period=1)

# Cubic grids, 48 voxels per side.
im_width = im_height = im_depth = 48
input_channels = 12  # feature channels per voxel in the input grids
output_channels = len(thresh)  # one output channel per water-occupancy threshold
# Shapes exclude the batch axis; channels go last.
input_img_dim = (im_width, im_height, im_depth) + (input_channels,)
output_img_dim = (im_width, im_height, im_depth) + (output_channels,)

# Build ONE architecture and load its matching weights file below.
#generator_nn = inception(input_img_dim, output_channels) #inception unet
generator_nn = baselineUnet(input_img_dim, output_channels) #baseline unet

opt = Adam(lr=1E-3)

generator_nn.load_weights('weights/weights.unetbaseline.mapper_watsite_multithresh_koesset.hdf5')

generator_nn.compile(loss=gen_dice_loss, optimizer=opt,
                     metrics=['mae', 'acc', precision, recall, dice_coef, cropped_dice])
batch_size = 16

#data_path = '..' # change this
#datafiles = glob.glob(data_path+'/*/newwatsite_grids.h5')

def get_list(f, root=None):
    """Read PDB ids (one per line) from file ``f`` and map each to its grid path.

    Each id becomes ``root + '/' + id + '/newwatsite_grids.h5'``. Surrounding
    whitespace is stripped and blank lines are skipped.

    Args:
        f: path of a text file listing one PDB id per line.
        root: data directory. Defaults to the module-level ``data_path``,
            looked up at call time, so it must be defined before calling.

    Returns:
        List of grid-file paths, in file order.
    """
    if root is None:
        root = data_path
    with open(f, 'r') as fin:
        pdbs = [line.strip() for line in fin]
    return [root + '/' + pdb + '/newwatsite_grids.h5' for pdb in pdbs if pdb]

#trainlist = get_list('../train_test_sets/all_0.5_2__reducedtrain0.list')
#testlist = get_list('../train_test_sets/all_0.5_2__reducedtest0.list')

#datafiles = check_converged(datafiles)
#trainfiles, valfiles = train_test_split(datafiles,test_size=.1)
#trainfiles = np.array(datafiles[100:])
#valfiles = np.array(datafiles[0:100])
#trainfiles = np.array(list(set(datafiles).intersection(set(trainlist))))
#valfiles = np.array(list(set(datafiles).intersection(set(testlist))))

#np.random.shuffle(trainfiles)
#np.random.shuffle(valfiles)

#trainfiles = trainfiles[0:7000]

#print('trains:',len(trainfiles))
#print('tests:',len(valfiles))
#nb_epoch = 1001
#n_images_per_epoch = len(trainfiles)
#tng_gen = load_grids(trainfiles,batch_size=batch_size)
#val_gen = load_grids(valfiles,batch_size=batch_size)



#generator_nn.fit_generator(tng_gen, epochs=nb_epoch,steps_per_epoch=len(trainfiles)/batch_size, verbose=2, validation_data=val_gen,validation_steps=len(valfiles)/batch_size, shuffle=True,callbacks=[save_weight]) 
#out  = generator_nn.evaluate_generator(val_gen,steps=len(valfiles)/batch_size,verbose=1) # to only evaluate, use this line
#print(out) # get the loss and metrics





### Running a prediction task

The input is an HDF5 file with two keys: 'x', the FLAP atom grids, and 'y', the water-occupancy grids. The snippet below takes the grid file for PDB entry 1adl, generates the ground-truth and predicted water grids, and saves both as .npy arrays. These arrays can later be converted to DX files for visualization with Python tools or for other downstream tasks (such as input for Gnina).

In [None]:
## Use this cell to predict grids and save them as npy files


data_path = '.'
pdbs = ['1adl']

for pdb in pdbs:
    grid_path = data_path + '/' + pdb + '/newwatsite_grids.h5'
    # Read both datasets; the file is closed as soon as the block exits.
    with h5py.File(grid_path, 'r') as grids:
        X = np.array([grids['x'][:]])           # add a batch axis of size 1
        y = np.array([grids['y'][:, :, :, 0]])  # channel 0 of the water grid

    X = normalize(np.nan_to_num(X))
    y = np.nan_to_num(y)

    # One target channel per occupancy threshold, stacked on the last axis.
    # np.array(y) copies y because normalize_wat modifies its argument in place.
    all_y = [normalize_wat(np.array(y), t) for t in thresh]
    y = np.stack(all_y, axis=-1)

    # Suffix the output files with the PDB id, e.g. watsitemapper_predicted_1adl.npy.
    water_predictor_num(generator_nn, X, y, pdb, prefix='')

    

