In [37]:
!git clone https://github.com/adcollin/AMPLify-Feedback.git
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import layers
import numpy as np
import math

fatal: destination path 'AMPLify-Feedback' already exists and is not an empty directory.


In [38]:
generator = tf.keras.models.load_model('AMPLify-Feedback/model_weights/PeptideGenerator.keras')

In [39]:
generator.summary()

Model: "model_17"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 Input1 (InputLayer)         [(None, 326)]                0         []                            
                                                                                                  
 Input0 (InputLayer)         [(None, 10)]                 0         []                            
                                                                                                  
 Input1Transform (Dense)     (None, 10)                   3270      ['Input1[0][0]']              
                                                                                                  
 Concat (Concatenate)        (None, 20)                   0         ['Input0[0][0]',              
                                                                     'Input1Transform[0][0]

In [40]:
discriminator = tf.keras.models.load_model('AMPLify-Feedback/model_weights/PeptideDiscriminator.keras')



In [41]:
discriminator.summary()

Model: "discriminator"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 Flatten (Flatten)           (None, 8170)              0         
                                                                 
 Dense0 (Dense)              (None, 512)               4183552   
                                                                 
 Dropout (Dropout)           (None, 512)               0         
                                                                 
 Dense1 (Dense)              (None, 256)               131328    
                                                                 
 Output (Dense)              (None, 1)                 257       
                                                                 
Total params: 4315137 (16.46 MB)
Trainable params: 0 (0.00 Byte)
Non-trainable params: 4315137 (16.46 MB)
_________________________________________________________________


In [42]:
def create_oracle():
    inputs0 = tf.keras.layers.Input((190,43),name="SeqInput")
    inputs1 = tf.keras.layers.Input((326,),name="StateInput")
    x = tf.keras.layers.Conv1D(128, 5, activation='relu', name="Conv1D_0")(inputs0) # kernel_size=5 works well
    x = tf.keras.layers.Conv1D(128, 5, activation='relu', name="Conv1D_1")(x) # Just two layers work better
    x = tf.keras.layers.Flatten(name="Flatten_0")(x)
    x = tf.keras.layers.Dense(512, activation="relu", name="LearnSeqDense_0")(x)
    x = tf.keras.layers.Concatenate(axis=1, name="Concat")([x, inputs1])
    x = tf.keras.layers.Dense(1024, activation="relu", name="LearnConcatDense_0")(x)
    x = tf.keras.layers.LayerNormalization(name="LayerNorm_0")(x)
    x = tf.keras.layers.Dense(512, activation="relu", name="LearnConcatDense_1")(x)
    x = tf.keras.layers.LayerNormalization(name="LayerNorm_1")(x)
    x = tf.keras.layers.Dense(1, activation="linear", name="Output")(x)
    model = tf.keras.models.Model([inputs0, inputs1], x, name="MICPredictor")
    return model

In [43]:
oracle = create_oracle()
path = "AMPLify-Feedback/model_weights/MICPredictor"
for i, layer in enumerate(oracle.layers):
    weights = np.load(f"{path}/layer_{i}_weights.npy", allow_pickle=True)
    layer.set_weights(weights)

In [44]:
oracle.summary()

Model: "MICPredictor"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 SeqInput (InputLayer)       [(None, 190, 43)]            0         []                            
                                                                                                  
 Conv1D_0 (Conv1D)           (None, 186, 128)             27648     ['SeqInput[0][0]']            
                                                                                                  
 Conv1D_1 (Conv1D)           (None, 182, 128)             82048     ['Conv1D_0[0][0]']            
                                                                                                  
 Flatten_0 (Flatten)         (None, 23296)                0         ['Conv1D_1[0][0]']            
                                                                                       

In [45]:
# GAN
def compile_gan(generator, discriminator):
    discriminator.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    discriminator.trainable = False
    gan_input0 = layers.Input(shape=(latent_dim,))
    gan_input1 = layers.Input(shape=(326,))
    gan_output = discriminator(generator([gan_input0, gan_input1]))
    gan = tf.keras.Model([gan_input0, gan_input1], gan_output)
    gan.compile(loss='binary_crossentropy', optimizer='adam')
    return gan

In [46]:
def generate_sequences(generator, latent_dim, num_sequences):
    noise = np.random.normal(0, 1, (num_sequences, latent_dim))
    bacteria = []
    for i in range(num_sequences):
      bacterium = np.zeros(326)
      bacterium[np.random.randint(0, 326)]=1
      bacteria.append([bacterium])
    bacteria = np.concatenate(bacteria, axis=0)
    generated_sequences = generator.predict([noise, bacteria])
    return generated_sequences, bacteria

In [47]:
def fit_gan(generator, discriminator, gan, seq_train, state_train, labels, epochs, batch_size, latent_dim):
    for epoch in range(epochs):
        for i in range(0, seq_train.shape[0], batch_size):
            sequences = seq_train[i:i + batch_size]
            state_train_batch = state_train[i:i + batch_size]
            current_batch_size = sequences.shape[0]

            # Train discriminator
            discriminator.trainable = True
            d_loss = discriminator.train_on_batch(sequences, labels)
            discriminator.trainable = False

            # Train generator
            noise = np.random.normal(0, 1, (current_batch_size, latent_dim))
            g_loss = gan.train_on_batch([noise, state_train_batch], labels)

            # Print the progress
            print(f"Epoch {epoch+1}/{epochs}, Batch {i//batch_size+1}, Discriminator Loss: {d_loss}, Generator Loss: {g_loss}")

In [48]:
def feedback_loop_prediction(prediction, threshold):
  loop_prediction = np.zeros((prediction.shape[0], 1))
  for i in range (prediction.shape[0]):
    if prediction[i] < threshold:
      loop_prediction[i] = 1
  return loop_prediction

In [53]:
def RL_loop(generator, discriminator, gan, oracle, num_sequences):
  for i in range (n_iter_max) :
    sequences, bacteria = generate_sequences(generator, latent_dim, num_sequences)
    prediction = oracle.predict([sequences, bacteria])
    loop_prediction = feedback_loop_prediction(prediction, math.log(100,2))
    fit_gan(generator, discriminator, gan, sequences, bacteria, loop_prediction, epochs=5, batch_size=num_sequences, latent_dim=latent_dim)

In [54]:
latent_dim = 10
num_sequences = 100
n_iter_max = 10

gan = compile_gan(generator, discriminator)

RL_loop(generator, discriminator, gan, oracle, num_sequences)

Epoch 1/5, Batch 1, Discriminator Loss: [1.193461537361145, 0.36000001430511475], Generator Loss: 1.1725174188613892
Epoch 2/5, Batch 1, Discriminator Loss: [1.1837159395217896, 0.3400000035762787], Generator Loss: 1.0413657426834106
Epoch 3/5, Batch 1, Discriminator Loss: [1.2448168992996216, 0.3100000023841858], Generator Loss: 1.060624361038208
Epoch 4/5, Batch 1, Discriminator Loss: [1.126537799835205, 0.3799999952316284], Generator Loss: 0.919628918170929
Epoch 5/5, Batch 1, Discriminator Loss: [1.2678349018096924, 0.30000001192092896], Generator Loss: 0.577129602432251
Epoch 1/5, Batch 1, Discriminator Loss: [0.6505067348480225, 0.7200000286102295], Generator Loss: 0.7472743988037109
Epoch 2/5, Batch 1, Discriminator Loss: [0.5603952407836914, 0.7900000214576721], Generator Loss: 0.8301504254341125
Epoch 3/5, Batch 1, Discriminator Loss: [0.6779341697692871, 0.7300000190734863], Generator Loss: 0.6682521104812622
Epoch 4/5, Batch 1, Discriminator Loss: [0.6395621299743652, 0.7599