In [22]:
import os
import random
import cv2
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import tensorflow.keras.backend as K
from tensorflow.keras import Model, Sequential
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.layers import Input, Reshape, Dense, Dropout, \
    Activation, LeakyReLU, Conv2D, Conv2DTranspose, Embedding, \
    Concatenate, multiply, Flatten, BatchNormalization
from tensorflow.keras.initializers import glorot_normal
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import load_model

In [23]:
# Filesystem locations: GAN checkpoints, the saved encoder/decoder models,
# and the directory that receives the generated sequence files.
bagan_gp_weight_path = "./model/bagan_gp_step_weights/"
encoder_model_path = "./model/seq_encoder_model/seq_encoder_model.h5"
decoder_model_path = "./model/seq_encoder_model/seq_decoder_model.h5"
target_seq_path = "./result/seq_result/"

# Nucleotide <-> scalar code tables. The inverse table is derived from the
# forward one so the two can never drift apart.
seq_dict = {'T': 0.25, 'C': 0.5, 'A': 0.75, 'G': 1}
inverse_seq_dict = {code: base for base, code in seq_dict.items()}

# Load the saved sequence encoder and decoder models from disk.
encoder_model = load_model(encoder_model_path)
decoder_model = load_model(decoder_model_path)



In [24]:
# construct BAGAN_GP model
class BAGAN_GP(Model):
    """Keras model that trains a label-conditioned generator against a critic
    using a gradient penalty.

    For every batch, ``train_step`` performs ``trainRatio`` discriminator
    updates followed by a single generator update. Random labels are sampled
    using the module-level ``n_classes``.
    """

    def __init__(
        self,
        discriminator,
        generator,
        latent_dim,
        gp_weight=10.0,
        trainRatio=3,
    ):
        """Store the sub-models and training hyper-parameters.

        Args:
            discriminator: model called as ``discriminator([images, labels])``.
            generator: model called as ``generator([latent, labels])``.
            latent_dim: size of the latent vector fed to the generator.
            gp_weight: multiplier applied to the gradient penalty term.
            trainRatio: number of discriminator updates per generator update.
        """
        super(BAGAN_GP, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.train_ratio = trainRatio
        self.gp_weight = gp_weight

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        """Attach the optimizers and loss callables used by ``train_step``.

        Args:
            d_optimizer: optimizer applied to the discriminator variables.
            g_optimizer: optimizer applied to the generator variables.
            d_loss_fn: called with keyword arguments ``real_logits``,
                ``fake_logits`` and ``wrong_label_logits``.
            g_loss_fn: called with the discriminator logits of generated images.
        """
        super(BAGAN_GP, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn

    def gradient_penalty(self, batch_size, real_images, fake_images, labels):
        """Calculate the gradient penalty on images interpolated between real and fake.

        Args:
            batch_size: number of samples in the batch.
            real_images: batch of real images.
            fake_images: batch of generated images, same shape as ``real_images``.
            labels: labels passed to the discriminator together with the
                interpolated images.

        Returns:
            Scalar tensor: mean over the batch of ``(||grad|| - 1) ** 2``.
        """
        # The interpolation coefficient must lie in [0, 1] so that every
        # sample sits on the segment between a real and a fake image; a
        # normally distributed coefficient would also extrapolate beyond it.
        alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
        diff = fake_images - real_images
        interpolated = real_images + alpha * diff

        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            # 1. Get the discriminator output for the interpolated images.
            pred = self.discriminator([interpolated, labels], training=True)

        # 2. Calculate the gradients w.r.t. the interpolated images.
        grads = gp_tape.gradient(pred, [interpolated])[0]
        # 3. Calculate the per-sample norm of the gradients.
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
        gp = tf.reduce_mean((norm - 1.0) ** 2)
        return gp

    def train_step(self, data):
        """Run ``train_ratio`` discriminator updates and one generator update.

        Args:
            data: tuple whose first two items are the real images and their labels.

        Returns:
            Dict with the last discriminator loss (``"d_loss"``) and the
            generator loss (``"g_loss"``).

        Raises:
            ValueError: if ``data`` is not a tuple.
        """
        if not isinstance(data, tuple):
            # Previously this fell through and failed later on an unbound name.
            raise ValueError(
                "BAGAN_GP.train_step expects `data` to be an (images, labels) tuple"
            )
        real_images = data[0]
        labels = data[1]

        # Get the batch size
        batch_size = tf.shape(real_images)[0]

        ########################### Train the Discriminator ###########################
        # For each batch, perform several discriminator updates
        for _ in range(self.train_ratio):
            # Get the latent vector
            random_latent_vectors = tf.random.normal(
                shape=(batch_size, self.latent_dim)
            )
            # Random labels in [0, n_classes). NOTE(review): these are sampled
            # as floats while the sub-models declare int32 label inputs;
            # presumably they are cast on input - confirm.
            fake_labels = tf.random.uniform((batch_size,), 0, n_classes)
            wrong_labels = tf.random.uniform((batch_size,), 0, n_classes)
            with tf.GradientTape() as tape:
                # Generate fake images from the latent vector
                fake_images = self.generator([random_latent_vectors, fake_labels], training=True)
                # Get the logits for the fake images
                fake_logits = self.discriminator([fake_images, fake_labels], training=True)
                # Get the logits for real images
                real_logits = self.discriminator([real_images, labels], training=True)
                # Get the logits for real images paired with random labels
                wrong_label_logits = self.discriminator([real_images, wrong_labels], training=True)

                # Calculate discriminator loss using fake and real logits
                d_cost = self.d_loss_fn(real_logits=real_logits, fake_logits=fake_logits,
                                        wrong_label_logits=wrong_label_logits
                                        )

                # The penalty is computed inside the tape so that it
                # contributes to the discriminator gradients.
                gp = self.gradient_penalty(batch_size, real_images, fake_images, labels)
                # Add the gradient penalty to the original discriminator loss
                d_loss = d_cost + gp * self.gp_weight

            # Get the gradients w.r.t the discriminator loss
            d_gradient = tape.gradient(d_loss, self.discriminator.trainable_variables)
            # Update the weights of the discriminator using the discriminator optimizer
            self.d_optimizer.apply_gradients(
                zip(d_gradient, self.discriminator.trainable_variables)
            )

        ########################### Train the Generator ###########################
        # Get the latent vector
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        fake_labels = tf.random.uniform((batch_size,), 0, n_classes)
        with tf.GradientTape() as tape:
            # Generate fake images using the generator
            generated_images = self.generator([random_latent_vectors, fake_labels], training=True)
            # Get the discriminator logits for fake images
            gen_img_logits = self.discriminator([generated_images, fake_labels], training=True)
            # Calculate the generator loss
            g_loss = self.g_loss_fn(gen_img_logits)

        # Get the gradients w.r.t the generator loss
        gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
        # Update the weights of the generator using the generator optimizer
        self.g_optimizer.apply_gradients(
            zip(gen_gradient, self.generator.trainable_variables)
        )
        return {"d_loss": d_loss, "g_loss": g_loss}

# Build Discriminator without inheriting the pre-trained Encoder
def discriminator_cwgan():
    """Build a label-conditioned discriminator from scratch.

    Reads the module-level ``img_size`` and ``n_classes``.

    Returns:
        A Keras ``Model`` taking ``[image, label]`` and producing one
        un-activated score per sample.
    """
    kernel_init = RandomNormal(stddev=0.02)

    image_in = Input(img_size)
    label_in = Input((1,), dtype='int32')

    # Image branch: four stride-2 4x4 convolutions, each followed by LeakyReLU(0.2).
    features = image_in
    for n_filters in (64, 128, 128, 256):
        features = Conv2D(n_filters, (4, 4), strides=(2, 2), padding='same',
                          kernel_initializer=kernel_init)(features)
        features = LeakyReLU(0.2)(features)
    features = Flatten()(features)

    # Label branch: 512-d embedding projected to 4 * 4 * 256 units so it can
    # be multiplied element-wise with the flattened image features.
    label_features = Flatten()(Embedding(n_classes, 512)(label_in))
    label_features = Dense(4 * 4 * 256)(label_features)
    label_features = LeakyReLU(0.2)(label_features)

    merged = multiply([features, label_features])
    merged = Dense(512)(merged)
    score = Dense(1)(merged)

    return Model(inputs=[image_in, label_in], outputs=score)

# Build discriminator with pre-trained Encoder
def build_discriminator(encoder):
    """Build a label-conditioned discriminator on top of an existing encoder.

    The encoder is reused up to the output of its third-from-last layer as the
    image feature extractor. Reads the module-level ``img_size`` and
    ``n_classes``.

    Args:
        encoder: Keras model exposing ``input`` and ``layers``.

    Returns:
        A Keras ``Model`` taking ``[image, label]`` and producing one
        un-activated score per sample.
    """
    label_in = Input((1,), dtype='int32')
    image_in = Input(img_size)

    feature_extractor = Model(inputs=encoder.input, outputs=encoder.layers[-3].output)
    image_features = feature_extractor(image_in)

    # Label branch: 512-d embedding -> Dense(4 * 4 * 256) -> LeakyReLU(0.2).
    label_features = Flatten()(Embedding(n_classes, 512)(label_in))
    label_features = Dense(4 * 4 * 256)(label_features)
    label_features = LeakyReLU(0.2)(label_features)

    # Combine the two branches by element-wise multiplication.
    merged = multiply([image_features, label_features])
    merged = Dense(512)(merged)
    score = Dense(1)(merged)

    return Model(inputs=[image_in, label_in], outputs=score)
    
def generator_label(embedding, decoder):
    """Chain the label-embedding model and the decoder into one generator.

    The embedding model is left trainable so that it is trained along with
    the GAN. Reads the module-level ``latent_dim``.

    Args:
        embedding: model called as ``embedding([latent, label])``.
        decoder: model called on the output of ``embedding``.

    Returns:
        A Keras ``Model`` taking ``[latent, label]`` and returning the
        decoder output.
    """
    label_in = Input((1,), dtype='int32')
    latent_in = Input((latent_dim,))

    conditioned_latent = embedding([latent_in, label_in])
    image_out = decoder(conditioned_latent)

    return Model([latent_in, label_in], image_out)

def embedding_labeled_latent():
    """Build the model that conditions a noise vector on a class label.

    The label is embedded into a ``latent_dim``-sized vector and multiplied
    element-wise with the noise. Reads the module-level ``latent_dim`` and
    ``n_classes``.

    Returns:
        A Keras ``Model`` taking ``[noise, label]`` and returning the
        label-modulated noise.
    """
    # Earlier variants with a custom initializer and extra Dense/LeakyReLU
    # projections on either branch are intentionally not used here.
    label_in = Input((1,), dtype='int32')
    noise_in = Input((latent_dim,))

    label_vector = Flatten()(Embedding(n_classes, latent_dim)(label_in))
    conditioned_noise = multiply([noise_in, label_vector])

    return Model([noise_in, label_in], conditioned_noise)

def decoder():
    """Build the transposed-convolution decoder that maps a latent vector to an image.

    Reads the module-level ``latent_dim`` and ``channel``.

    Returns:
        A Keras ``Model`` from a ``latent_dim`` vector to a
        64 x 64 x ``channel`` tanh-activated image.
    """
    kernel_init = RandomNormal(stddev=0.02)

    latent_in = Input((latent_dim,))

    # Project the latent vector and fold it into a 4 x 4 x 256 feature map.
    x = Dense(4*4*256)(latent_in)
    x = LeakyReLU(alpha=0.2)(x)
    x = Reshape((4, 4, 256))(x)

    # Each stage doubles the spatial size: 8 x 8 x 128 -> 16 x 16 x 128 -> 32 x 32 x 64.
    for n_filters in (128, 128, 64):
        x = Conv2DTranspose(n_filters, (4, 4), strides=(2, 2), padding='same',
                            kernel_initializer=kernel_init)(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(0.2)(x)

    # Final upsampling to 64 x 64 with `channel` output channels.
    image_out = Conv2DTranspose(channel, (4, 4), strides=(2, 2), padding='same',
                                activation='tanh', kernel_initializer=kernel_init)(x)

    return Model(inputs=latent_in, outputs=image_out)

# Hyper-parameters; the builder functions and BAGAN_GP.train_step read these
# as module-level globals.
latent_dim = 128
channel = 1
n_classes = 5
img_size = (64, 64, channel)
trainRatio = 5

# Generator parts: decoder plus the label-embedding model.
de = decoder()
em = embedding_labeled_latent()

# Discriminator built from scratch (not initialised from the encoder);
# generator assembled from the embedding and the decoder.
d_model = discriminator_cwgan()
g_model = generator_label(em, de)

bagan_gp = BAGAN_GP(d_model, g_model, latent_dim, trainRatio=trainRatio)

In [25]:
# Create the sequence API
def min_distance(x):
    """Return the nucleotide letter whose code is closest to ``x``.

    Candidates come from the module-level ``inverse_seq_dict``. The search is
    seeded with 'T' (code 0.25) and a candidate only replaces the current
    best when it is strictly closer, so ties keep the earlier choice.

    Args:
        x: numeric value to decode.

    Returns:
        The letter of the nearest code.
    """
    best_base, best_gap = 'T', abs(x - 0.25)
    for code, base in inverse_seq_dict.items():
        gap = abs(x - code)
        if gap < best_gap:
            best_base, best_gap = base, gap
    return best_base

def value_to_seq(value):
    """Decode rows of numeric codes into nucleotide strings.

    Args:
        value: iterable of rows, each row an iterable of numeric codes.

    Returns:
        List with one string per row; every code is mapped to a letter
        with ``min_distance``.
    """
    return ["".join(min_distance(code) for code in row) for row in value]

def use_bagan_create(num,target):
    """Generate ``num`` sequences of class ``target`` with the loaded GAN.

    Latent vectors are drawn from a standard normal distribution, passed
    through ``bagan_gp.generator`` together with a constant label vector, and
    the result is fed to ``decoder_model`` and converted to strings.

    Args:
        num: number of sequences to generate (int).
        target: class of the sequences (int, 0-4).

    Returns:
        List of generated sequence strings.
    """
    noise = np.random.normal(size=(num, latent_dim))
    class_labels = target * np.ones(num)
    generated_imgs = bagan_gp.generator.predict([noise, class_labels])
    return value_to_seq(decoder_model.predict(generated_imgs))

In [27]:
#create seq
# For each of the 20 saved checkpoints, generate 1024 sequences per class and
# write them one per line to pre_seq_step{step}_class{class}.txt (1-based).
for i in range(20):
    weight_prefix = bagan_gp_weight_path + "bagan_gp_weight_%d" %(i)
    # expect_partial() silences warnings about checkpoint values that are not used here.
    bagan_gp.load_weights(weight_prefix).expect_partial()
    for j in range(5):
        pre_seq = use_bagan_create(1024, j)
        out_path = target_seq_path + "pre_seq_step{i}_class{j}.txt".format(i=i+1,j=j+1)
        with open(out_path, 'w') as file:
            file.writelines(sub_seq + '\n' for sub_seq in pre_seq)

In [29]:
def count_seq_same(step_num,class_num):
    """Count distinct generated sequences and how many of them occur in the training data.

    Reads ``pre_seq_step{step_num}_class{class_num}.txt`` under the
    module-level ``target_seq_path`` and ``./data/raw_data/{class_num}.txt``.
    Both files are decoded as GBK and every line is stripped before comparison.

    Args:
        step_num: checkpoint step number used in the generated file name.
        class_num: class number used in both file names.

    Returns:
        Tuple ``(no_repeat_num, count_same)``:
        ``no_repeat_num`` is the number of distinct generated sequences;
        ``count_same`` is how many of those also appear in the training file.
    """
    generated_path = target_seq_path + 'pre_seq_step{step_num}_class{class_num}.txt'.format(
        step_num=step_num, class_num=class_num)
    with open(generated_path, encoding='gbk') as f:
        # dict.fromkeys removes duplicates in linear time and keeps first-seen order.
        no_repeat = list(dict.fromkeys(line.strip() for line in f))

    with open('./data/raw_data/{class_num}.txt'.format(class_num=class_num), encoding='gbk') as f:
        # A set gives constant-time membership checks instead of a nested scan.
        training_seqs = {line.strip() for line in f}

    no_repeat_num = len(no_repeat)
    count_same = sum(1 for seq in no_repeat if seq in training_seqs)
    return (no_repeat_num, count_same)
# One row per checkpoint step (1-20) and one column per class (1-5).
all_step_no_repeat = np.empty((20, 5))
all_step_count_same = np.empty((20, 5))
for i, j in np.ndindex(20, 5):
    # count_seq_same takes 1-based step and class numbers.
    all_step_no_repeat[i, j], all_step_count_same[i, j] = count_seq_same(i + 1, j + 1)

In [30]:
# Rows are checkpoint steps 1-20, columns are classes 1-5; each entry is the
# number of distinct sequences in the generated file for that step and class.
print(all_step_no_repeat)

[[982. 822. 895. 833. 852.]
 [962. 812. 858. 820. 809.]
 [983. 811. 884. 819. 825.]
 [977. 797. 856. 833. 790.]
 [975. 788. 888. 807. 804.]
 [977. 794. 867. 820. 809.]
 [958. 810. 857. 803. 826.]
 [965. 794. 878. 827. 816.]
 [983. 791. 892. 796. 788.]
 [975. 806. 872. 814. 821.]
 [979. 822. 865. 826. 802.]
 [964. 793. 898. 808. 836.]
 [961. 793. 889. 808. 830.]
 [973. 818. 894. 800. 790.]
 [970. 822. 877. 811. 823.]
 [978. 807. 878. 816. 809.]
 [975. 794. 896. 807. 823.]
 [969. 781. 903. 827. 821.]
 [969. 783. 894. 793. 809.]
 [964. 789. 889. 813. 784.]]


In [31]:
# Rows are checkpoint steps 1-20, columns are classes 1-5; each entry is the
# number of distinct generated sequences that also appear in the class's
# ./data/raw_data/ file.
print(all_step_count_same)

[[237. 205. 124. 120. 137.]
 [264. 242. 147. 131. 177.]
 [274. 255. 143. 111. 181.]
 [299. 258. 147. 134. 186.]
 [306. 265. 157. 142. 174.]
 [299. 267. 158. 127. 172.]
 [313. 280. 140. 130. 187.]
 [312. 262. 160. 141. 195.]
 [299. 270. 152. 132. 173.]
 [318. 259. 144. 132. 186.]
 [306. 260. 136. 134. 174.]
 [292. 248. 142. 123. 176.]
 [310. 237. 152. 127. 177.]
 [286. 250. 141. 131. 172.]
 [301. 251. 142. 115. 180.]
 [294. 230. 157. 133. 172.]
 [300. 255. 145. 134. 168.]
 [292. 229. 154. 135. 175.]
 [300. 236. 154. 128. 176.]
 [292. 245. 146. 117. 169.]]


In [32]:
# Distinct generated sequences per step (row) and class (column) that were
# not found in the class's ./data/raw_data/ file.
print(all_step_no_repeat-all_step_count_same)

[[745. 617. 771. 713. 715.]
 [698. 570. 711. 689. 632.]
 [709. 556. 741. 708. 644.]
 [678. 539. 709. 699. 604.]
 [669. 523. 731. 665. 630.]
 [678. 527. 709. 693. 637.]
 [645. 530. 717. 673. 639.]
 [653. 532. 718. 686. 621.]
 [684. 521. 740. 664. 615.]
 [657. 547. 728. 682. 635.]
 [673. 562. 729. 692. 628.]
 [672. 545. 756. 685. 660.]
 [651. 556. 737. 681. 653.]
 [687. 568. 753. 669. 618.]
 [669. 571. 735. 696. 643.]
 [684. 577. 721. 683. 637.]
 [675. 539. 751. 673. 655.]
 [677. 552. 749. 692. 646.]
 [669. 547. 740. 665. 633.]
 [672. 544. 743. 696. 615.]]
