In [1]:
#https://www.analyticsvidhya.com/blog/2021/06/a-detailed-explanation-of-gan-with-implementation-using-tensorflow-and-keras/
from numpy import zeros, ones, expand_dims, asarray
from numpy.random import randn, randint
from keras.models import Sequential, load_model,Model
from keras.layers import Input, Dense, Reshape, Flatten
from keras.layers import LeakyReLU, Dropout
from keras.layers import BatchNormalization, Activation
from keras import initializers
from keras.initializers import RandomNormal
from matplotlib import pyplot
import numpy as np
from math import sqrt
import pickle

In [2]:
# Load the dataset from the pickled file 'liste_prot' and convert it to a NumPy array.
# NOTE(review): pickle.load can execute arbitrary code; only open files from a trusted source.
with open('liste_prot', 'rb') as f:
    X_train = np.array(pickle.load(f))

In [3]:
# 3500 proteins, each 939 amino acids long, over 21 distinct amino-acid classes
X_train.shape

(3500, 939, 21)

In [4]:
# Draw random noise that the generator will try to transform into samples.
def generate_latent_points(latent_dim, n_samples):
    """Sample a batch of latent vectors from a standard normal distribution.

    Args:
        latent_dim: length of a single latent vector.
        n_samples: number of latent vectors to draw.

    Returns:
        Array of shape (n_samples, latent_dim).
    """
    # Drawing directly in the target shape yields the same values, in the same
    # order, as a flat draw of latent_dim * n_samples numbers followed by a reshape.
    return randn(n_samples, latent_dim)

In [5]:
# Real samples are labelled with class 1 so the discriminator learns that they are genuine.
def generate_real_samples(X_train, n_samples):
    """Pick a random batch of real samples and pair it with "real" labels.

    Args:
        X_train: array of training samples, indexed along its first axis.
        n_samples: number of samples to draw (with replacement).

    Returns:
        Tuple (X, y): X holds the selected rows of X_train and y is an
        (n_samples, 1) array of ones.
    """
    # Random row indices in [0, len(X_train)); duplicates are possible.
    picked = randint(0, len(X_train), n_samples)
    labels = ones((n_samples, 1))
    return X_train[picked], labels

In [6]:
# Generated samples are labelled with class 0 so the discriminator learns that they are fake.
def generate_fake_samples(generator, latent_dim, n_samples):
    """Produce a batch of generator outputs paired with "fake" labels.

    Args:
        generator: model exposing ``predict``; it is fed latent vectors.
        latent_dim: length of each latent vector given to the generator.
        n_samples: number of samples to generate.

    Returns:
        Tuple (images, y): images is the generator's prediction for the
        sampled latent points and y is an (n_samples, 1) array of zeros.
    """
    z_input = generate_latent_points(latent_dim, n_samples)
    # verbose=0: this function is called once per training step, so a
    # per-call progress bar would flood the notebook output.
    images = generator.predict(z_input, verbose=0)
    y = zeros((n_samples, 1))
    return images, y

In [7]:
# Define the layers of the discriminator.
def define_discriminator(in_shape=(939, 21, 1)):
    """Build and compile the discriminator network.

    The input is flattened and passed through three fully connected stages
    (1024, 512 and 256 units), each followed by LeakyReLU(0.2) and
    Dropout(0.3), ending in a single sigmoid unit.

    Args:
        in_shape: shape of one input sample (without the batch axis).

    Returns:
        A Sequential model compiled with binary cross-entropy, the 'Adam'
        optimizer and an accuracy metric.
    """
    model = Sequential()
    model.add(Input(shape=in_shape))
    model.add(Flatten())
    # Three identical Dense -> LeakyReLU -> Dropout stages of decreasing width.
    for width in (1024, 512, 256):
        model.add(Dense(width))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.3))
    # Single sigmoid output: the score for the "real" class (label 1).
    model.add(Dense(1, activation='sigmoid'))
    model.compile(loss='binary_crossentropy', optimizer='Adam', metrics=['accuracy'])
    return model

In [8]:
# Instantiate the discriminator with its default input shape.
discriminator = define_discriminator()

In [9]:
# Define the layers of the generator.
def define_generator(latent_dim, out_shape=(939, 21, 1)):
    """Build and compile the generator network.

    A latent vector goes through three fully connected stages (256, 512 and
    1024 units, each followed by LeakyReLU(0.2)), then a tanh-activated
    Dense layer whose output is reshaped to ``out_shape``.

    Args:
        latent_dim: length of the latent input vector.
        out_shape: shape of one generated sample (without the batch axis).
            Defaults to (939, 21, 1).

    Returns:
        A compiled Sequential model mapping (latent_dim,) to out_shape.
    """
    # Number of units needed so the last Dense output can be reshaped to out_shape.
    n_outputs = 1
    for dim in out_shape:
        n_outputs *= dim

    model_g = Sequential()
    model_g.add(Input(shape=(latent_dim,)))
    for width in (256, 512, 1024):
        model_g.add(Dense(width))
        model_g.add(LeakyReLU(alpha=0.2))
    # Pass the activation by name via the keyword argument rather than handing
    # an Activation layer instance to Dense's positional `activation` slot.
    # NOTE(review): tanh emits values in [-1, 1]; if the training data is a
    # 0/1 encoding (TODO confirm), a sigmoid/softmax output would match it better.
    model_g.add(Dense(n_outputs, activation='tanh'))
    model_g.add(Reshape(tuple(out_shape)))
    # Kept from the original. Within this file the generator's weights are only
    # updated through the combined model built by define_gan.
    model_g.compile(loss='categorical_crossentropy',optimizer = 'Adam')
    return model_g

In [10]:
# Build the generator with a 100-dimensional latent input.
generator = define_generator(100)

In [11]:
# Create the GAN that chains the generator and the discriminator.
def define_gan(g_model, d_model):
    """Stack the generator and the discriminator into one trainable model.

    Args:
        g_model: generator model; its input becomes the GAN input.
        d_model: discriminator model; applied to the generator output.

    Returns:
        A Model compiled with binary cross-entropy, the 'Adam' optimizer
        and an accuracy metric.
    """
    # When the GAN is trained, only the generator should learn, so the
    # discriminator is marked non-trainable before the combined model is built.
    d_model.trainable = False
    stacked = Model(inputs=g_model.input, outputs=d_model(g_model.output))
    stacked.compile(loss='binary_crossentropy', optimizer='Adam', metrics=['accuracy'])
    return stacked

In [12]:
# Combine the generator and the discriminator into the GAN model.
gan_model = define_gan(generator, discriminator)


In [13]:
# Training function.
def train(g_model, d_model, gan_model, X_train, latent_dim, n_epochs=100, n_batch=64,
          save_path='./exemple_GAN_model/model_GAN'):
    """Alternately train the discriminator and the generator, then save the generator.

    Each step trains the discriminator on one real batch and one generated
    batch, then trains the generator through ``gan_model`` using "real"
    labels for generated samples.

    Args:
        g_model: generator model.
        d_model: discriminator model (trained directly on real/fake batches).
        gan_model: combined model used to update the generator.
        X_train: array of real training samples.
        latent_dim: length of the latent vectors fed to the generator.
        n_epochs: number of passes over the data.
        n_batch: batch size.
        save_path: where the generator is saved once training ends.
            Defaults to the location used previously.
    """
    # Whole batches per epoch; a trailing partial batch is dropped.
    bat_per_epo = len(X_train) // n_batch
    n_steps = bat_per_epo * n_epochs
    # Targets for the generator update: generated samples are labelled as
    # real (1). They do not change between steps, so build them once.
    y_gan = ones((n_batch, 1))
    for i in range(n_steps):
        # Train the discriminator on a real batch, then on a generated batch.
        X_real, y_real = generate_real_samples(X_train, n_batch)
        d_loss_r, d_acc_r = d_model.train_on_batch(X_real, y_real)
        X_fake, y_fake = generate_fake_samples(g_model, latent_dim, n_batch)
        d_loss_f, d_acc_f = d_model.train_on_batch(X_fake, y_fake)
        # Train the generator through the combined model.
        z_input = generate_latent_points(latent_dim, n_batch)
        g_loss, g_acc = gan_model.train_on_batch(z_input, y_gan)
        # Track losses and accuracies to check that training progresses.
        print('>%d, dr[%.3f,%.3f], df[%.3f,%.3f], g[%.3f,%.3f]' % (i+1, d_loss_r,d_acc_r, d_loss_f,d_acc_f, g_loss,g_acc))
    # Save the generator for later use.
    g_model.save(save_path)

In [14]:
# Train the GAN: 100-dimensional latent space, 20 epochs, batches of 64.
latent_dim = 100
train(generator, discriminator, gan_model, X_train, latent_dim, n_epochs=20, n_batch=64)

>1, dr[0.648,0.641], df[0.818,0.000], g[3.410,0.000]
>2, dr[0.000,1.000], df[6.676,0.000], g[0.704,0.453]
>3, dr[0.000,1.000], df[0.624,0.688], g[5.960,0.000]
>4, dr[2.648,0.203], df[0.149,1.000], g[3.276,0.000]
>5, dr[0.000,1.000], df[9.741,0.000], g[0.005,1.000]
>6, dr[0.000,1.000], df[7.434,0.000], g[7.522,0.000]
>7, dr[0.000,1.000], df[0.000,1.000], g[22.721,0.000]
>8, dr[0.000,1.000], df[0.002,1.000], g[17.176,0.000]
>9, dr[0.003,1.000], df[1.982,0.422], g[2.847,0.281]
>10, dr[0.988,0.750], df[5.194,0.000], g[3.465,0.062]
>11, dr[0.201,0.875], df[0.006,1.000], g[17.005,0.000]
>12, dr[0.073,0.953], df[0.000,1.000], g[26.279,0.000]
>13, dr[0.009,1.000], df[0.000,1.000], g[21.936,0.000]
>14, dr[0.005,1.000], df[1.737,0.594], g[13.077,0.000]
>15, dr[0.003,1.000], df[0.007,1.000], g[23.715,0.000]
>16, dr[0.019,0.984], df[0.000,1.000], g[29.569,0.000]
>17, dr[0.034,0.984], df[0.084,0.953], g[14.262,0.031]
>18, dr[0.003,1.000], df[6.401,0.016], g[1.669,0.484]
>19, dr[0.242,0.953], df[1.3

>150, dr[0.001,1.000], df[0.106,0.953], g[19.531,0.000]
>151, dr[0.001,1.000], df[0.167,0.938], g[22.434,0.000]
>152, dr[0.005,1.000], df[0.000,1.000], g[34.121,0.000]
>153, dr[0.016,1.000], df[0.003,1.000], g[30.852,0.000]
>154, dr[0.045,1.000], df[0.120,0.969], g[16.210,0.000]
>155, dr[0.111,0.953], df[0.406,0.812], g[17.518,0.000]
>156, dr[0.203,0.906], df[0.090,0.969], g[21.716,0.000]
>157, dr[0.113,0.984], df[0.072,0.969], g[22.881,0.000]
>158, dr[0.124,0.953], df[0.012,1.000], g[18.093,0.000]
>159, dr[0.031,1.000], df[0.034,0.984], g[20.523,0.000]
>160, dr[0.013,1.000], df[0.245,0.906], g[39.866,0.000]
>161, dr[0.018,1.000], df[0.000,1.000], g[65.099,0.000]
>162, dr[0.031,1.000], df[0.000,1.000], g[86.440,0.000]
>163, dr[0.105,0.984], df[0.000,1.000], g[76.341,0.000]
>164, dr[0.016,1.000], df[0.000,1.000], g[44.731,0.000]
>165, dr[0.004,1.000], df[0.174,0.938], g[40.217,0.000]
>166, dr[0.005,1.000], df[0.054,0.969], g[49.300,0.000]
>167, dr[0.016,1.000], df[0.000,1.000], g[55.594

>296, dr[0.088,0.984], df[0.000,1.000], g[474.561,0.000]
>297, dr[0.008,1.000], df[0.000,1.000], g[389.947,0.000]
>298, dr[0.009,1.000], df[0.000,1.000], g[232.001,0.000]
>299, dr[0.008,1.000], df[0.127,0.922], g[70.301,0.016]
>300, dr[0.031,0.984], df[0.102,0.969], g[65.096,0.031]
>301, dr[0.005,1.000], df[0.035,0.984], g[89.393,0.000]
>302, dr[0.030,1.000], df[0.028,0.984], g[75.507,0.000]
>303, dr[0.015,1.000], df[0.000,1.000], g[69.756,0.000]
>304, dr[0.022,1.000], df[0.009,1.000], g[93.066,0.000]
>305, dr[0.020,1.000], df[0.033,0.984], g[80.847,0.000]
>306, dr[0.008,1.000], df[0.003,1.000], g[66.042,0.000]
>307, dr[0.004,1.000], df[0.000,1.000], g[69.870,0.000]
>308, dr[0.006,1.000], df[0.036,0.984], g[56.463,0.016]
>309, dr[0.003,1.000], df[0.033,0.984], g[45.653,0.000]
>310, dr[0.003,1.000], df[0.016,1.000], g[36.213,0.047]
>311, dr[0.004,1.000], df[0.091,0.938], g[41.000,0.016]
>312, dr[0.006,1.000], df[0.001,1.000], g[45.690,0.000]
>313, dr[0.013,1.000], df[0.000,1.000], g[54.

>443, dr[0.052,0.984], df[0.045,0.984], g[22.497,0.016]
>444, dr[0.009,1.000], df[0.091,0.969], g[30.858,0.000]
>445, dr[0.013,1.000], df[0.016,0.984], g[31.518,0.000]
>446, dr[0.006,1.000], df[0.025,0.984], g[40.232,0.000]
>447, dr[0.013,1.000], df[0.000,1.000], g[34.813,0.000]
>448, dr[0.021,1.000], df[0.000,1.000], g[42.270,0.000]
>449, dr[0.008,1.000], df[0.002,1.000], g[28.081,0.016]
>450, dr[0.007,1.000], df[0.051,0.984], g[24.686,0.000]
>451, dr[0.007,1.000], df[0.015,0.984], g[26.079,0.016]
>452, dr[0.006,1.000], df[0.001,1.000], g[24.631,0.016]
>453, dr[0.004,1.000], df[0.010,1.000], g[27.762,0.016]
>454, dr[0.010,1.000], df[0.070,0.953], g[19.344,0.000]
>455, dr[0.003,1.000], df[0.046,0.969], g[20.341,0.016]
>456, dr[0.013,1.000], df[0.019,0.984], g[24.641,0.000]
>457, dr[0.037,0.984], df[0.048,0.969], g[25.016,0.000]
>458, dr[0.290,0.906], df[0.052,0.969], g[17.062,0.031]
>459, dr[0.003,1.000], df[0.120,0.938], g[22.398,0.031]
>460, dr[0.001,1.000], df[0.306,0.891], g[34.456

>590, dr[0.017,1.000], df[0.066,0.953], g[17.018,0.000]
>591, dr[0.070,0.984], df[0.034,0.984], g[16.626,0.000]
>592, dr[0.011,1.000], df[0.013,1.000], g[18.174,0.000]
>593, dr[0.123,0.969], df[0.133,0.938], g[20.566,0.000]
>594, dr[0.004,1.000], df[0.006,1.000], g[25.180,0.000]
>595, dr[0.007,1.000], df[0.001,1.000], g[24.838,0.000]
>596, dr[0.007,1.000], df[0.007,1.000], g[17.579,0.000]
>597, dr[0.047,0.984], df[0.273,0.875], g[22.058,0.000]
>598, dr[0.292,0.891], df[0.054,0.969], g[18.850,0.000]
>599, dr[0.003,1.000], df[0.024,1.000], g[23.273,0.000]
>600, dr[0.002,1.000], df[0.000,1.000], g[28.036,0.000]
>601, dr[0.001,1.000], df[0.000,1.000], g[29.038,0.000]
>602, dr[0.002,1.000], df[0.000,1.000], g[24.496,0.000]
>603, dr[0.004,1.000], df[0.042,0.984], g[17.303,0.000]
>604, dr[0.006,1.000], df[0.137,0.906], g[17.674,0.000]
>605, dr[0.031,1.000], df[0.078,0.984], g[21.502,0.000]
>606, dr[0.226,0.906], df[0.080,0.984], g[15.224,0.000]
>607, dr[0.009,1.000], df[0.415,0.859], g[22.643

>737, dr[0.085,0.984], df[0.016,1.000], g[24.425,0.000]
>738, dr[0.250,0.938], df[0.021,1.000], g[16.053,0.047]
>739, dr[0.066,1.000], df[0.174,0.906], g[14.636,0.016]
>740, dr[0.034,1.000], df[0.105,0.938], g[20.649,0.000]
>741, dr[0.026,1.000], df[0.062,0.969], g[24.542,0.016]
>742, dr[0.052,1.000], df[0.007,1.000], g[23.584,0.000]
>743, dr[0.061,1.000], df[0.024,1.000], g[22.063,0.000]
>744, dr[0.073,0.984], df[0.155,0.938], g[18.688,0.031]
>745, dr[0.053,1.000], df[0.216,0.906], g[21.946,0.000]
>746, dr[0.084,1.000], df[0.095,0.938], g[25.606,0.000]
>747, dr[0.202,0.953], df[0.127,0.953], g[25.858,0.000]
>748, dr[0.145,0.938], df[0.260,0.875], g[20.481,0.016]
>749, dr[0.099,0.969], df[0.222,0.922], g[27.841,0.000]
>750, dr[0.105,0.969], df[0.022,1.000], g[29.447,0.000]
>751, dr[0.157,0.969], df[0.058,0.984], g[20.155,0.000]
>752, dr[0.027,1.000], df[0.159,0.969], g[18.931,0.000]
>753, dr[0.068,0.984], df[0.096,0.953], g[21.370,0.000]
>754, dr[0.028,1.000], df[0.027,1.000], g[23.138

>884, dr[0.005,1.000], df[0.000,1.000], g[36.241,0.000]
>885, dr[0.004,1.000], df[0.005,1.000], g[28.086,0.000]
>886, dr[0.002,1.000], df[0.033,0.984], g[18.725,0.000]
>887, dr[0.002,1.000], df[0.018,0.984], g[24.248,0.000]
>888, dr[0.004,1.000], df[0.064,0.969], g[31.160,0.000]
>889, dr[0.004,1.000], df[0.001,1.000], g[31.968,0.000]
>890, dr[0.015,1.000], df[0.002,1.000], g[32.326,0.000]
>891, dr[0.007,1.000], df[0.016,1.000], g[23.210,0.000]
>892, dr[0.048,0.969], df[0.033,0.984], g[19.295,0.016]
>893, dr[0.009,1.000], df[0.085,0.969], g[20.775,0.000]
>894, dr[0.014,1.000], df[0.014,1.000], g[23.486,0.000]
>895, dr[0.035,1.000], df[0.016,0.984], g[25.098,0.000]
>896, dr[0.019,1.000], df[0.013,1.000], g[20.672,0.000]
>897, dr[0.033,1.000], df[0.054,0.984], g[18.286,0.016]
>898, dr[0.024,1.000], df[0.242,0.906], g[22.481,0.016]
>899, dr[0.205,0.891], df[0.214,0.922], g[18.899,0.000]
>900, dr[0.164,0.938], df[0.617,0.797], g[23.413,0.000]
>901, dr[0.233,0.922], df[0.440,0.844], g[18.574

>1030, dr[0.028,1.000], df[0.000,1.000], g[38.423,0.000]
>1031, dr[0.042,0.984], df[0.000,1.000], g[37.755,0.000]
>1032, dr[0.080,0.969], df[0.001,1.000], g[40.534,0.000]
>1033, dr[0.069,1.000], df[0.000,1.000], g[28.830,0.000]
>1034, dr[0.024,1.000], df[0.001,1.000], g[20.244,0.016]
>1035, dr[0.014,1.000], df[0.175,0.922], g[16.651,0.062]
>1036, dr[0.010,1.000], df[0.191,0.906], g[21.454,0.047]
>1037, dr[0.016,1.000], df[0.041,0.984], g[24.349,0.000]
>1038, dr[0.034,1.000], df[0.020,1.000], g[20.830,0.000]
>1039, dr[0.200,0.953], df[0.192,0.922], g[13.736,0.000]
>1040, dr[0.071,1.000], df[0.308,0.859], g[15.769,0.016]
>1041, dr[0.083,0.984], df[0.021,1.000], g[19.040,0.000]
>1042, dr[0.122,0.984], df[0.016,1.000], g[18.713,0.000]
>1043, dr[0.088,1.000], df[0.080,0.938], g[16.311,0.000]
>1044, dr[0.081,1.000], df[0.489,0.750], g[17.011,0.016]
>1045, dr[0.117,0.984], df[0.034,1.000], g[18.165,0.000]
>1046, dr[0.621,0.750], df[0.590,0.719], g[14.292,0.047]
>1047, dr[0.046,0.984], df[0.31



INFO:tensorflow:Assets written to: ./exemple_GAN_model/model_GAN\assets


INFO:tensorflow:Assets written to: ./exemple_GAN_model/model_GAN\assets


In [15]:
# Test: generate proteins with the saved generator.
model = load_model('./exemple_GAN_model/model_GAN')
latent_dim = 100
n_examples = 100
latent_points = generate_latent_points(latent_dim, n_examples)
# Explicit check instead of a bare `.shape` expression, whose value is never
# displayed when it is not the last line of a cell.
assert latent_points.shape == (n_examples, latent_dim)
X = model.predict(latent_points)

In [16]:
# 100 proteins were generated, each with 939 amino-acid positions over 21 amino-acid classes
X.shape

(100, 939, 21, 1)

In [17]:
# TODO: write a function that converts a 939x21 array into an amino-acid
# sequence string such as "HWY..."