In [1]:
!git clone https://github.com/adcollin/AMPLify-Feedback.git
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
import numpy as np
import math

Cloning into 'AMPLify-Feedback'...
remote: Enumerating objects: 368, done.[K
remote: Counting objects: 100% (17/17), done.[K
remote: Compressing objects: 100% (12/12), done.[K
remote: Total 368 (delta 5), reused 14 (delta 5), pack-reused 351[K
Receiving objects: 100% (368/368), 99.94 MiB | 7.58 MiB/s, done.
Resolving deltas: 100% (172/172), done.
Updating files: 100% (72/72), done.


In [2]:
# Load the saved peptide generator from the repository cloned in the first cell.
generator = tf.keras.models.load_model('AMPLify-Feedback/model_weights/PeptideGenerator.keras')

In [3]:
# Print the generator's layer table (inputs, output shapes, parameter counts).
generator.summary()

Model: "model_17"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 Input1 (InputLayer)         [(None, 326)]                0         []                            
                                                                                                  
 Input0 (InputLayer)         [(None, 10)]                 0         []                            
                                                                                                  
 Input1Transform (Dense)     (None, 10)                   3270      ['Input1[0][0]']              
                                                                                                  
 Concat (Concatenate)        (None, 20)                   0         ['Input0[0][0]',              
                                                                     'Input1Transform[0][0]

In [4]:
# Load the saved discriminator from the cloned repository.
# NOTE(review): the recorded output of this cell is an OSError, so in that run
# `discriminator` was never defined and every later cell using it would fail --
# confirm that this file actually exists at this path in the repository.
discriminator = tf.keras.models.load_model('AMPLify-Feedback/model_weights/discriminator.keras')

OSError: ignored

In [None]:
# Print the discriminator's layer table; requires `discriminator` to have been loaded successfully.
discriminator.summary()

In [None]:
def create_oracle():
    """Build the untrained two-input Keras model named "MICPredictor".

    Inputs:
        "SeqInput"   -- tensor of shape (190, 43) per sample.
        "StateInput" -- tensor of shape (326,) per sample.

    The sequence input goes through two Conv1D layers (128 filters, kernel
    size 5, ReLU), is flattened and projected by Dense(512, ReLU). The result
    is concatenated with the state input along axis 1 and passed through
    Dense(1024, ReLU) -> LayerNormalization -> Dense(512, ReLU) ->
    LayerNormalization -> Dense(1, linear).

    Layer names and creation order are deliberately fixed: elsewhere in this
    notebook the weights are restored by iterating over `model.layers` by
    index, so reordering layers here would mis-assign the saved weights.

    Returns:
        A `tf.keras.Model` taking `[SeqInput, StateInput]` and producing one
        linear output value per sample. No weights are loaded here.
    """
    keras_layers = tf.keras.layers

    seq_input = keras_layers.Input((190, 43), name="SeqInput")
    state_input = keras_layers.Input((326,), name="StateInput")

    # Sequence branch (author's notes: kernel_size=5 works well; just two conv layers work better).
    conv_0 = keras_layers.Conv1D(128, 5, activation='relu', name="Conv1D_0")(seq_input)
    conv_1 = keras_layers.Conv1D(128, 5, activation='relu', name="Conv1D_1")(conv_0)
    flat_seq = keras_layers.Flatten(name="Flatten_0")(conv_1)
    seq_features = keras_layers.Dense(512, activation="relu", name="LearnSeqDense_0")(flat_seq)

    # Merge the learned sequence features with the raw state vector.
    merged = keras_layers.Concatenate(axis=1, name="Concat")([seq_features, state_input])

    # Joint head: two Dense + LayerNormalization stages, then a scalar linear output.
    hidden_0 = keras_layers.Dense(1024, activation="relu", name="LearnConcatDense_0")(merged)
    normed_0 = keras_layers.LayerNormalization(name="LayerNorm_0")(hidden_0)
    hidden_1 = keras_layers.Dense(512, activation="relu", name="LearnConcatDense_1")(normed_0)
    normed_1 = keras_layers.LayerNormalization(name="LayerNorm_1")(hidden_1)
    output = keras_layers.Dense(1, activation="linear", name="Output")(normed_1)

    return tf.keras.models.Model([seq_input, state_input], output, name="MICPredictor")

In [None]:
# Rebuild the oracle architecture, then restore its weights one layer at a time.
oracle = create_oracle()
path = "AMPLify-Feedback/model_weights/MICPredictor"
# Files are looked up by the layer's position in `oracle.layers`, so the layer
# order produced by create_oracle() must match the order used when saving.
# Note that `i`, `layer` and `weights` stay defined at module level after the loop.
for i, layer in enumerate(oracle.layers):
    # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary Python
    # objects from the file; only use it on weight files from a trusted source.
    weights = np.load(f"{path}/layer_{i}_weights.npy", allow_pickle=True)
    layer.set_weights(weights)

In [None]:
# Print the oracle's layer table to check the rebuilt architecture.
oracle.summary()

In [None]:
# GAN
def compile_gan(generator, discriminator):
    """Compile the discriminator and build the stacked generator -> discriminator model.

    The discriminator is compiled first (binary cross-entropy, Adam, accuracy
    metric). Its `trainable` flag is then set to False before the stacked
    model is compiled, with the intent that training the stacked model only
    updates the generator.

    Args:
        generator: Keras model called as `generator([latent, state])`, where
            `latent` has shape (latent_dim,) and `state` has shape (326,).
        discriminator: Keras model applied to the generator's output.

    Returns:
        The compiled stacked `tf.keras.Model` taking `[latent, state]` and
        returning the discriminator's output.

    Note:
        The latent input width is read from the module-level `latent_dim`,
        which must be defined before this function is called.
    """
    discriminator.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    discriminator.trainable = False
    # Use the fully qualified tf.keras.layers: the notebook never imports a
    # bare `layers` name, so `layers.Input` raised NameError.
    gan_input0 = tf.keras.layers.Input(shape=(latent_dim,))
    gan_input1 = tf.keras.layers.Input(shape=(326,))
    gan_output = discriminator(generator([gan_input0, gan_input1]))
    gan = tf.keras.Model([gan_input0, gan_input1], gan_output)
    gan.compile(loss='binary_crossentropy', optimizer='adam')
    return gan

In [None]:
def generate_sequences(generator, latent_dim, num_sequences, num_states=326):
    """Sample random generator inputs and return the generator's predictions.

    Args:
        generator: object exposing `predict([noise, bacteria])`.
        latent_dim: width of the standard-normal noise vector per sample.
        num_sequences: number of samples to draw; 0 is allowed and yields
            empty (0, latent_dim) / (0, num_states) inputs.
        num_states: length of the one-hot state vector (defaults to 326, the
            value previously hard-coded here).

    Returns:
        Tuple `(generated_sequences, bacteria)` where `generated_sequences` is
        whatever `generator.predict` returns and `bacteria` is a float array
        of shape (num_sequences, num_states) with exactly one 1 per row at a
        uniformly random position.
    """
    noise = np.random.normal(0, 1, (num_sequences, latent_dim))
    # Build all one-hot rows at once; unlike concatenating a list of rows this
    # also works when num_sequences == 0.
    bacteria = np.zeros((num_sequences, num_states))
    hot_columns = np.random.randint(0, num_states, size=num_sequences)
    bacteria[np.arange(num_sequences), hot_columns] = 1
    generated_sequences = generator.predict([noise, bacteria])
    return generated_sequences, bacteria

In [None]:
################### if gradient approach using hallucination #####################
def update_generator(generator, MIC):
  """Attempt one Adam step on the generator's variables using `MIC` as the objective.

  Args:
    generator: model exposing `trainable_variables`.
    MIC: value converted with `tf.constant` and differentiated with respect
      to the generator's trainable variables.

  Behavior:
    Prints each trainable variable's name with the shape of its gradient
    (or `None` when there is no gradient), then applies the available
    gradients with a freshly created Adam optimizer (learning rate 0.001).
    If no variable received a gradient, nothing is applied.

  NOTE(review): `y` is a constant built from `MIC` and no generator variable
  is used inside the tape, so every gradient is expected to be None and this
  function currently performs no update. A working version needs the
  generator/oracle forward pass executed inside the GradientTape context.
  """
  learning_rate = 0.001
  trainable_vars = generator.trainable_variables
  with tf.GradientTape() as tape:
    y = tf.constant(MIC)
  gradients = tape.gradient(y, trainable_vars)
  ####### check gradient
  # Iterate over the generator's own variables (the previous code read an
  # unrelated module-level `layer`) and tolerate missing gradients.
  for var, g in zip(trainable_vars, gradients):
    grad_shape = None if g is None else g.shape
    print(f'{var.name}, shape: {grad_shape}')
  #########
  gradients_and_vars = [(g, var) for g, var in zip(gradients, trainable_vars) if g is not None]
  if not gradients_and_vars:
    # apply_gradients rejects an empty/all-None gradient list; nothing to update.
    return
  # A new optimizer is created per call, so Adam's moment estimates do not
  # carry over between calls.
  optimizer = tf.optimizers.Adam(learning_rate)
  optimizer.apply_gradients(gradients_and_vars)

In [None]:
def fit_gan(generator, discriminator, gan, seq_train, state_train, labels, epochs, batch_size, latent_dim):
    """Alternate discriminator and generator updates over mini-batches.

    For every epoch the training data is walked in steps of `batch_size`
    (the last batch may be smaller). Per batch:
      1. the discriminator is trained on the sequence batch against the
         matching slice of `labels`;
      2. the stacked `gan` model is trained on fresh standard-normal noise
         plus the state batch against the same label slice.
    The two losses are printed after each batch.

    Args:
        generator: not used directly here; kept for interface compatibility.
        discriminator: model exposing `train_on_batch(x, y)` and a
            `trainable` attribute.
        gan: stacked model exposing `train_on_batch([noise, state], y)`.
        seq_train: sequence data, indexable along axis 0 and exposing `.shape`.
        state_train: state vectors aligned row-by-row with `seq_train`.
        labels: targets aligned row-by-row with `seq_train`.
        epochs: number of passes over the data.
        batch_size: number of rows per batch.
        latent_dim: width of the noise vector fed to `gan`.

    Returns:
        None.
    """
    for epoch in range(epochs):
        for i in range(0, seq_train.shape[0], batch_size):
            sequences = seq_train[i:i + batch_size]
            state_train_batch = state_train[i:i + batch_size]
            # Slice the labels with the same window as the inputs; passing the
            # full label array made x and y sizes disagree on every batch
            # unless the whole data set fit in a single batch.
            labels_batch = labels[i:i + batch_size]
            current_batch_size = sequences.shape[0]

            # Train discriminator
            discriminator.trainable = True
            d_loss = discriminator.train_on_batch(sequences, labels_batch)
            discriminator.trainable = False

            # Train generator
            noise = np.random.normal(0, 1, (current_batch_size, latent_dim))
            g_loss = gan.train_on_batch([noise, state_train_batch], labels_batch)

            # Print the progress
            print(f"Epoch {epoch+1}/{epochs}, Batch {i//batch_size+1}, Discriminator Loss: {d_loss}, Generator Loss: {g_loss}")

In [None]:
def feedback_loop_prediction(prediction, threshold):
  """Turn per-row predictions into a 0/1 column by thresholding.

  Args:
    prediction: array whose first axis indexes samples; each `prediction[i]`
      must be usable in a boolean comparison against `threshold`.
    threshold: rows strictly below this value are flagged.

  Returns:
    Float array of shape (n, 1) holding 1 where `prediction[i] < threshold`
    and 0 elsewhere, with n = `prediction.shape[0]`.
  """
  n_rows = prediction.shape[0]
  flags = [1.0 if prediction[row] < threshold else 0.0 for row in range(n_rows)]
  return np.array(flags, dtype=float).reshape(n_rows, 1)

In [None]:
def RL_loop(generator, oracle):
  """Run the generate -> score -> retrain feedback loop.

  Each of the `n_iter_max` iterations samples `num_sequences` candidates from
  the generator, scores them with the oracle, flags the rows whose score is
  below log2(100), and then calls `fit_gan` for 5 epochs with batch size 32.

  Args:
    generator: model passed to `generate_sequences` and `fit_gan`.
    oracle: model exposing `predict([sequences, bacteria])`.

  Relies on module-level names: `n_iter_max`, `latent_dim`, `num_sequences`,
  `discriminator` and `gan`.

  NOTE(review): `tf.one_hot` takes integer indices, while the flags returned
  by `feedback_loop_prediction` are floating point -- confirm whether a cast
  is needed. The dtype/shape of the generator output is not visible here, so
  the same question applies to the sequences. Also confirm that a depth-43
  one-hot label is what the discriminator expects.
  """
  score_threshold = math.log(100, 2)
  for _ in range(n_iter_max):
    candidates, states = generate_sequences(generator, latent_dim, num_sequences)
    scores = oracle.predict([candidates, states])
    flags = feedback_loop_prediction(scores, score_threshold)
    seq_output = tf.one_hot(candidates.squeeze(), depth=43)
    seq_label = tf.one_hot(flags.squeeze(), depth=43)
    fit_gan(
        generator,
        discriminator,
        gan,
        seq_output,
        states,
        seq_label,
        epochs=5,
        batch_size=32,
        latent_dim=latent_dim,
    )



    ################ EVALUATION METRIC
    #generator.compile(optimizer=Adam(3e-5))
    #generator.fit([seq_output, bacteria], seq_val)
    #generator_update = update_generator(generator, prediction)
    #generator.train([seq_train, bacteria], seq_train)

  #return prediction

In [None]:
# Run configuration. These are module-level names that other cells read as
# globals: compile_gan reads `latent_dim`; RL_loop reads all three.
latent_dim = 10       # width of the generator's noise input
num_sequences = 100   # candidates sampled per feedback iteration
n_iter_max = 10       # number of feedback iterations

# Compile the discriminator and build the stacked generator -> discriminator model.
gan = compile_gan(generator, discriminator)

# Start the feedback loop; it also uses the module-level `discriminator` and `gan`.
RL_loop(generator, oracle)

In [None]:
# Debugging aid: turn off TensorFlow's traceback filtering so full stack traces are shown.
# NOTE(review): this is the last cell, so it only affects cells executed after it;
# move it near the top if full tracebacks are wanted for the cells above.
tf.debugging.disable_traceback_filtering()