In [None]:
# Mount Google Drive inside the Colab runtime at /content/drive.
# NOTE(review): in this notebook the only reference to Drive is the
# commented-out weight-saving cell at the bottom — confirm the mount is still needed.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!git clone https://github.com/AvonYangXX1/DreamWalker.git
from DreamWalker.commandline_scripts.utils import ModelLoader, Preprocessing
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import layers
import numpy as np
import math
import matplotlib.pyplot as plt
import statistics
import pandas as pd
import random

In [None]:
# Instantiate the project's ModelLoader and use it to create the oracle model.
# `loader` is also read as a notebook-level global by Dreaming.__init__,
# which calls loader.create_generator(...), so this cell must run before
# any Dreaming instance is constructed.
loader = ModelLoader()
oracle = loader.create_oracle()

In [None]:
# Table shipped with the cloned DreamWalker repository; the columns used
# below are "Target" (species name) and "Marker".
data = pd.read_csv("DreamWalker/data/processed_data/AMP/PeptideMarkerMIC.csv")

# Species of interest.
targets = [
    "Escherichia_coli",
    "Pseudomonas_aeruginosa",
    "Staphylococcus_aureus",
    "Bacillus_subtilis",
    "Salmonella_enterica",
]

# One row per requested species (its first occurrence in the CSV), indexed by
# species name with a single "Marker" column. Rows keep the CSV's order of
# first appearance, which is not necessarily the order of `targets`.
SpeciesToMarker = (
    data.loc[data["Target"].isin(targets)]
    .drop_duplicates(subset="Target")
    .loc[:, ["Target", "Marker"]]
    .set_index("Target")
)

In [None]:
# NOTE(review): the argument 5 is presumably the k-mer length — confirm
# against Preprocessing's definition in the DreamWalker repository.
pp = Preprocessing(5)

# SpeciesToMarker has a single "Marker" column, so each row handed to
# CountKmers is a one-element list holding that species' marker. The
# per-species results are joined along axis 0, one entry per species.
SpeciesToKmers = np.concatenate(
    [pp.CountKmers(marker_row) for marker_row in SpeciesToMarker.values.tolist()],
    axis=0,
)

In [None]:
# Vocabulary array stored alongside the model weights in the cloned repo.
aa_vocal = np.load("DreamWalker/model_weights/PepTV_vocal.npy")

# Inverted lookup: maps integer token ids back to vocabulary strings.
# The first vocabulary entry is dropped and the out-of-vocabulary token is
# rendered as the empty string.
# NOTE(review): presumably aa_vocal[0] is the OOV/padding entry — confirm.
pep_decoder = tf.keras.layers.StringLookup(
    invert=True,
    oov_token='',
    vocabulary=aa_vocal[1:],
)

In [None]:
class Dreaming():
    """Trains the DreamWalker generator against an oracle (exploring the Oracle's dream).

    Every training step feeds one "positive" marker per sample to the
    generator, scores the generated one-hot output with the oracle against
    that positive marker and against each "negative" marker, and minimises
    ``mic_pos - alpha * mic_neg`` with respect to the generator's trainable
    variables only; the oracle's weights are never updated by this class.

    Attributes:
        alpha: weight of the summed negative-marker score in the loss.
        generator: model returned by ``create_generator(decoder_choice)``.
        oracle: model called as ``oracle([generated_onehot, markers])``.
        pep_decoder: callable that maps integer token ids to characters.
        learning_rate: learning rate handed to the Adam optimizer.
        optimizer: Adam instance, created on the first training run.
        loss: loss value (Python float) of the most recent training step,
            ``None`` before any training.
        generated_onehot, gradients: retained for backward compatibility and
            left as ``None``; tensors assigned from inside a ``tf.function``
            are graph tensors that cannot be used once the call returns.
    """

    def __init__(self, oracle, pep_decoder, alpha=0.25, decoder_choice="decoder",
                 learning_rate=1e-3, model_loader=None):
        """Build the generator and store the training configuration.

        Args:
            oracle: scoring model, called as ``oracle([onehot, markers])``.
            pep_decoder: callable turning token ids into characters.
            alpha: weight applied to the summed negative-marker score.
            decoder_choice: forwarded to ``create_generator``.
            learning_rate: learning rate for the Adam optimizer.
            model_loader: object exposing ``create_generator``; when omitted
                the notebook-level ``loader`` global is used, as before.
        """
        self.alpha = alpha
        self.generated_onehot = None
        self.loss = None
        self.gradients = None
        self.optimizer = None
        if model_loader is None:
            # Backward-compatible fallback to the global created in an earlier cell.
            model_loader = loader
        self.generator = model_loader.create_generator(decoder_choice)
        self.oracle = oracle
        self.pep_decoder = pep_decoder
        self.learning_rate = learning_rate

    @tf.function
    def compute_gradient(self, marker_batch_pos, marker_batch_neg, batch_size):
        """Run one optimisation step on the generator.

        Args:
            marker_batch_pos: tensor with one positive marker per sample.
            marker_batch_neg: tensor whose axis 1 enumerates the negative
                markers of each sample; its size along axis 1 must be static.
            batch_size: unused; kept so existing callers keep working.

        Returns:
            ``(mic_pos, mic_neg)``: the batch-mean oracle output for the
            positive markers, and the sum over negatives of the batch-mean
            oracle output for each negative marker.

        Raises:
            ValueError: if the number of negative markers is not statically known.
        """
        n_negatives = marker_batch_neg.shape[1]
        if n_negatives is None:
            raise ValueError("marker_batch_neg must have a static size along axis 1")

        # Trainable variables are watched by the tape automatically, so the
        # input tensors do not need an explicit watch().
        with tf.GradientTape() as gen_tape:
            generated_onehot = self.generator(marker_batch_pos, training=True)
            mic_pos = tf.reduce_mean(self.oracle([generated_onehot, marker_batch_pos]))
            mic_neg = tf.zeros_like(mic_pos)
            # Plain Python range over a static size: unrolled at trace time.
            for neg_idx in range(n_negatives):
                mic_neg += tf.reduce_mean(
                    self.oracle([generated_onehot, marker_batch_neg[:, neg_idx]])
                )
            loss = mic_pos - self.alpha * mic_neg

        # Keep intermediates local: storing graph tensors on ``self`` from
        # inside a tf.function leaks tensors that are unusable afterwards.
        generator_vars = self.generator.trainable_variables
        gradients = gen_tape.gradient(loss, generator_vars)
        self.optimizer.apply_gradients(zip(gradients, generator_vars))
        return mic_pos, mic_neg

    @staticmethod
    def _build_marker_batch(species_kmers, first_sample, batch_size):
        """Stack ``batch_size`` per-sample shufflings of the rows of ``species_kmers``.

        Sample ``s`` uses the row permutation produced by ``random.Random(s)``,
        which matches ``random.seed(s); random.shuffle(...)`` without touching
        the global random state.
        """
        n_species = species_kmers.shape[0]
        batch = []
        for sample_idx in range(first_sample, first_sample + batch_size):
            order = list(range(n_species))
            random.Random(sample_idx).shuffle(order)
            batch.append(species_kmers[order])
        return np.stack(batch, axis=0)

    def ExploreHallucination(self, SpeciesToKmers, n_iter_max, batch_size):
        """Train the generator for ``n_iter_max`` steps and log the oracle scores.

        For every sample the rows of ``SpeciesToKmers`` are shuffled
        deterministically (seeded by the global sample index); the first row
        becomes the positive marker and the remaining rows the negatives.
        Batch ``b`` covers samples ``b * batch_size`` to
        ``(b + 1) * batch_size - 1``, so batches never overlap.

        Args:
            SpeciesToKmers: array-like with one row per species (at least two).
            n_iter_max: number of training steps.
            batch_size: number of samples per step.

        Returns:
            ``(mic_pos_log, mic_neg_log)``: numpy arrays with one entry per step.

        Raises:
            ValueError: on fewer than two species, a negative ``n_iter_max``
                or a non-positive ``batch_size``.
        """
        species_kmers = np.asarray(SpeciesToKmers)
        if species_kmers.ndim < 2 or species_kmers.shape[0] < 2:
            raise ValueError("SpeciesToKmers must hold at least two species (rows)")
        if n_iter_max < 0 or batch_size < 1:
            raise ValueError("n_iter_max must be >= 0 and batch_size must be >= 1")

        # Create the optimizer once: the traced compute_gradient graph is
        # expected to keep using the optimizer it saw when first traced, so
        # replacing it on every call would leave the new instance unused.
        if self.optimizer is None:
            self.optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate)

        mic_pos_log, mic_neg_log = [], []
        for step in range(n_iter_max):
            # Build only the current batch instead of materialising every
            # sample of every step up front.
            markers = self._build_marker_batch(species_kmers, step * batch_size, batch_size)
            marker_batch_pos = tf.convert_to_tensor(markers[:, 0], dtype=np.float64)
            marker_batch_neg = tf.convert_to_tensor(markers[:, 1:], dtype=np.float64)

            mic_pos, mic_neg = self.compute_gradient(marker_batch_pos, marker_batch_neg, batch_size)
            step_pos, step_neg = mic_pos.numpy(), mic_neg.numpy()
            # Recomputed eagerly so the last loss is a usable Python value.
            self.loss = float(step_pos - self.alpha * step_neg)
            mic_pos_log.append(step_pos)
            mic_neg_log.append(step_neg)

        return np.array(mic_pos_log), np.array(mic_neg_log)

    def onehot2seq(self, onehot):
        """Decode one-hot (or per-position score) arrays into strings.

        Takes the argmax over axis 2, maps the resulting ids through
        ``pep_decoder`` and joins the characters of every sequence.

        Returns:
            A list with one decoded string per entry along axis 0.
        """
        token_ids = tf.math.argmax(onehot, axis=2)
        chars_array = self.pep_decoder(token_ids).numpy().astype('str')
        return ["".join(chars) for chars in chars_array]

In [None]:
# Training configuration: number of optimisation steps and samples per step.
n_iter_max, batch_size = 2000, 128

# Build the trainer around the oracle with the "GAN" generator variant.
dreaming = Dreaming(
    oracle,
    pep_decoder,
    alpha=0.125,
    decoder_choice="GAN",
    learning_rate=1e-4,
)

# Per-step logs of the oracle output for positive and negative markers.
mic_pos_log, mic_neg_log = dreaming.ExploreHallucination(
    SpeciesToKmers,
    n_iter_max,
    batch_size,
)

In [None]:
# Generate one output per species marker with the trained generator.
onehot = dreaming.generator.predict(SpeciesToKmers, verbose=0)

# Oracle output for each generated sequence paired with its own marker.
predicted_mic = oracle.predict([onehot, SpeciesToKmers], verbose=0)
print(predicted_mic)

# Decode the generator output into strings; shown as the cell result.
decoded_sequences = dreaming.onehot2seq(onehot)
decoded_sequences

In [None]:
# Curves to draw. mic_neg_log is a sum over the negative markers (four when
# five species are used), so scaling by 0.25 turns it into their mean.
parameters = {
    "MIC_POS": mic_pos_log,
    "MEAN_MIC_NEG": 0.25 * mic_neg_log,
}

for label, series in parameters.items():
    plt.plot(range(len(series)), series, linestyle='-', label=label)

plt.title("Performances with the Hallucination Exploration")
plt.xlabel("Batch index")
plt.ylabel("MIC")
plt.legend(loc="upper right")
plt.savefig("MIC.png")

In [None]:
# path = "drive/MyDrive/DreamWalker/model_weights/DreamWalkerWeights"
# for i, layer in enumerate(dreaming.generator.layers):
#     param = layer.get_weights()
#     if len(param) == 0:
#         continue
#     weights = param[0]
#     biases = param[1]
#     np.savez_compressed(f'{path}/layer_{i}_weights', weights=weights)
#     np.savez_compressed(f'{path}/layer_{i}_biases', biases=biases)