In [1]:
import os
import argparse
import tensorflow as tf
import numpy as np
import lib

2024-10-01 16:03:28.184357: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-10-01 16:03:28.195813: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-10-01 16:03:28.206711: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-10-01 16:03:28.209866: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-10-01 16:03:28.219133: I tensorflow/core/platform/cpu_feature_guar

In [None]:
# Report the installed TensorFlow version so the run can be reproduced.
print(tf.__version__)

In [2]:
# Run-specific locations.
# checkpoint: only referenced by the (currently commented-out) checkpoint-restore cell;
# None means no checkpoint is restored.
checkpoint = None
# data_loc: location of the sequence dataset handed to lib.dna.load. Set the
# EXPRESSIONGAN_DATA_LOC environment variable to point at a local copy; when it is
# unset the original cluster path is used, so existing runs behave exactly as before.
data_loc = os.path.abspath(
    os.environ.get("EXPRESSIONGAN_DATA_LOC", "/home1/smaruj/ExpressionGAN/scripts/data")
)
# log_dir: log directory, resolved against the current working directory.
log_dir = os.path.abspath("./logs")

In [3]:
# ---------------------------------------------------------------------------
# Run configuration. Each comment sits directly above the setting it describes.
# (data_loc, log_dir and checkpoint are set in the previous cell.)
# ---------------------------------------------------------------------------

# --- Data source ---
# When True, train on randomly generated sequences instead of the dataset at data_loc.
# Default False: the real dataset is used.
generic = False
# Passed to lib.dna.load as data_start_line.
data_start = 0
# Defines the maximum length of DNA sequences in the dataset.
max_seq_len = 50
# Vocabulary name handed to lib.dna.get_vocab / lib.dna.load.
vocab = "dna"
# Optionally sets a specific order for the one-hot encoding of vocabulary characters.
vocab_order = None
# When True, the encoding gets one extra annotation channel (see data_enc_dim).
annotate = False
# Determines whether a validation set will be used during training.
validate = False
balanced_bins = False

# --- Logging ---
log_name = "gan_unbalanced"

# --- Model architecture ---
# Either "resnet" or "mlp" (multi-layer perceptron).
# NOT CHECKED FOR RESNET SO FAR!
model_type = "mlp"
# Size of the latent space (random noise vector) that the generator takes as input.
latent_dim = 100
gen_dim, disc_dim = 100, 100
gen_layers, disc_layers = 5, 5

# --- Optimisation ---
# Short smoke-test run; the value noted alongside the original setting was 500000.
train_iters = 10
# Discriminator updates per generator update.
disc_iters = 5
checkpoint_iters = 100
batch_size = 64
learning_rate = 1e-5
# Weight of the gradient-penalty term in the discriminator loss.
lmbda = 10.0

# --- Reproducibility ---
# Random seed for reproducibility of results.
seed = 42

In [4]:
#%% set RNG
# Seed NumPy's and TensorFlow's global generators from the configured `seed`
# so that repeated runs draw the same random numbers.
np.random.seed(seed)
tf.random.set_seed(seed)

In [5]:
#%% fix vocabulary of model
# Build the vocabulary lookups for the configured sequence type.
#   vocab:       name of the vocabulary to use ("dna" in the configuration cell).
#   vocab_order: optional custom order for the one-hot encoding of characters.
# NOTE(review): presumably `charmap` maps characters (like 'A', 'T', 'G', 'C' for DNA) to
# their one-hot encoding indices and `rev_charmap` is the reverse mapping from indices
# back to characters -- confirm against lib.dna.get_vocab.
charmap, rev_charmap = lib.dna.get_vocab(vocab, vocab_order)
# The vocabulary size is the number of entries in charmap.
vocab_size = len(charmap)

In [6]:
# Identity matrix of size vocab_size: row i is the one-hot vector for vocabulary index i.
# The random-data feed() indexes into it to one-hot encode the sequences it samples.
I = np.eye(vocab_size)

In [None]:
#%% organize model logs/checkpoints
#logdir, checkpoint_baseline = lib.log(args, samples_dir=True)
# This line calls a function lib.log() to set up directories for saving logs and model checkpoints during training.

In [7]:
#%% build GAN
latent_vars = tf.Variable(tf.random.normal(shape=[batch_size, latent_dim], seed=seed), name='latent_vars')
# This line initializes the latent space variables (latent_vars) that the generator will take as input. 
# These variables are sampled from a normal distribution and are of shape [args.batch_size, args.latent_dim].
data_enc_dim = vocab_size + 1 if annotate else vocab_size
data_size = max_seq_len * data_enc_dim
# The data encoding dimension data_enc_dim is adjusted based on whether annotations are included or not.
# If args.annotate is True, the encoding will have an additional annotation channel (hence vocab_size + 1).
# Otherwise, the encoding will only include the vocabulary size (vocab_size).
# data_size: Total size of the encoded sequence data, calculated as the maximum sequence length (args.max_seq_len) 
# multiplied by the encoding dimension.

2024-10-01 16:03:40.810027: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 43626 MB memory:  -> device: 0, name: NVIDIA L40S, pci bus id: 0000:42:00.0, compute capability: 8.9


In [9]:
# Generator network: takes the latent variables and produces generated sequences.
# The architecture ("mlp" or "resnet") is selected by model_type.
#   MLP generator:    built once as `generator_model`, then called on latent_vars.
#   ResNet generator: called directly on latent_vars (NOT CHECKED SO FAR).
# NOTE(review): tf.compat.v1.variable_scope is carried over from the TF1 version of this
# script; it probably has no effect on the model object built here -- confirm before removing.
with tf.compat.v1.variable_scope("Generator", reuse=None) as scope:
  if model_type=="mlp":
    # The output width follows max_seq_len (the configured sequence length) instead of a
    # hard-coded 50, so it stays correct if max_seq_len is changed.
    # NOTE(review): the data feeds flatten one-hot encoded batches to
    # seq_len * vocab_size values per sequence (cf. data_size), not max_seq_len --
    # confirm which width the MLP models should use once real data is fed in.
    generator_model = lib.models.mlp_generator((latent_dim,), dim=gen_dim, output_size=max_seq_len, num_layers=gen_layers)
    gen_data = generator_model(latent_vars)
  elif model_type=="resnet":
    gen_data = lib.models.resnet_generator(latent_vars, gen_dim, max_seq_len, data_enc_dim, annotate)
#  gen_vars = lib.get_vars(scope)
# gen_vars (disabled above) would hold the trainable variables of the generator.

In [10]:
# Stand-in "real" batch plus the interpolated batch used by the gradient penalty
# (WGAN-GP style). real_data is random normal noise for now; its shape depends on
# whether the MLP or the ResNet model is used.
if model_type == "mlp":
    real_data = tf.random.normal([batch_size, max_seq_len])  # Will be replaced by the real data source
    eps = tf.random.uniform([batch_size, 1])  # one random mixing weight per sample (TF2 name of tf.random_uniform)
elif model_type == "resnet":
    real_data = tf.random.normal([batch_size, max_seq_len, data_enc_dim])  # Will be replaced by the real data source
    eps = tf.random.uniform([batch_size, 1, 1])  # one random mixing weight per sample, broadcast over both trailing axes

# Linear interpolation between real_data and gen_data, weighted by eps.
# The discriminator's gradient at these points feeds the gradient penalty below.
interp = eps * real_data + (1 - eps) * gen_data

In [11]:
# Display the interpolated batch to check its shape and values.
interp

<tf.Tensor: shape=(64, 50), dtype=float32, numpy=
array([[ 0.22526273, -0.5796349 ,  0.28014633, ...,  0.57496464,
         0.4842909 , -0.22990184],
       [ 0.10483949, -0.47229636, -0.0832711 , ...,  0.1461472 ,
        -0.75750464, -0.75956464],
       [-0.11446839, -1.3069288 , -0.7515979 , ...,  0.9511246 ,
        -0.06501768,  0.29098797],
       ...,
       [-0.08224764,  0.13212356, -0.6718516 , ...,  0.80016863,
         0.21342859, -0.91508466],
       [ 0.2805145 , -0.05317048,  0.00689527, ..., -1.0051547 ,
        -0.7431211 , -0.38497156],
       [ 0.04917089, -0.06297167,  0.84044915, ...,  0.14386684,
         0.0486706 ,  0.3304389 ]], dtype=float32)>

In [12]:
# Discriminator network: scores a batch of sequences. As with the generator, the
# architecture is chosen by model_type ("mlp" or "resnet").
#   gen_score: the discriminator's score for the generated data.
# NOTE(review): tf.compat.v1.variable_scope is carried over from the TF1 version of this
# script; it probably has no effect on the model object built here -- confirm before removing.
with tf.compat.v1.variable_scope("Discriminator", reuse=None) as scope:
  if model_type=="mlp":
    # The input width follows max_seq_len instead of a hard-coded 50 (it was changed from
    # data_size to 50 earlier), so it stays correct if max_seq_len is changed.
    # NOTE(review): the data feeds flatten one-hot encoded batches to
    # seq_len * vocab_size values per sequence (cf. data_size) -- confirm which width
    # the MLP models should use once real data is fed in.
    discriminator_model = lib.models.mlp_discriminator((max_seq_len,), dim=disc_dim, num_layers=disc_layers)
    gen_score = discriminator_model(gen_data)
  elif model_type=="resnet":
    gen_score = lib.models.resnet_discriminator(gen_data, disc_dim, max_seq_len, data_enc_dim, res_layers=disc_layers)
#  disc_vars = lib.get_vars(scope)
# disc_vars (disabled above) would hold the trainable variables of the discriminator.

In [13]:
# Score the real and the interpolated batches with the discriminator.
#   real_score:   the discriminator's score for real_data.
#   interp_score: the discriminator's score for the interpolated data (recomputed under a
#                 GradientTape in the gradient-penalty cell).
# For the MLP the same `discriminator_model` object is called again.
# NOTE(review): presumably that means the same weights score generated, real and
# interpolated data; for the ResNet path weight sharing relies on reuse=True -- confirm.
with tf.compat.v1.variable_scope("Discriminator", reuse=True) as scope:
  if model_type=="mlp":
    real_score = discriminator_model(real_data)
    interp_score = discriminator_model(interp)
  elif model_type=="resnet":
    real_score = lib.models.resnet_discriminator(real_data, disc_dim, max_seq_len, data_enc_dim, res_layers=disc_layers)
    interp_score = lib.models.resnet_discriminator(interp, disc_dim, max_seq_len, data_enc_dim, res_layers=disc_layers)

In [14]:
#%% cost function
mean_gen_score = tf.reduce_mean(gen_score)
mean_real_score = tf.reduce_mean(real_score)
# gen_score and real_score: These are the outputs of the discriminator when given either generated data (gen_score) or 
# real data (real_score). These scores represent how likely the discriminator thinks the data is "real" (with higher values 
# indicating more "real").

gen_cost = - mean_gen_score #tf.reduce_mean(gen_score)
# The goal of the Generator is to maximize the discriminator's score for the generated data, which means it wants the 
# generated data to appear as real as possible. By negating mean_gen_score, the generator will be trained to increase this 
# value (since optimizers minimize the loss).
disc_diff = mean_gen_score - mean_real_score 
# The Discriminator wants to maximize the difference between the scores it gives to real and fake data, i.e., 
# it wants the real data to have high scores and the generated data to have low scores.
# disc_diff measures the gap between the mean score assigned to fake data (mean_gen_score) and real data (mean_real_score).

In [15]:
#%% gradient penalty
# TF2 replacement for the TF1 call `tf.gradients(interp_score, interp)[0]`.
#
# The gradient penalty needs the gradient of the discriminator's score with respect to
# the interpolated samples, i.e. how much the score changes when the interpolated data
# changes. The penalty is evaluated on points interpolated between real and generated
# data, as in WGAN-GP.
with tf.GradientTape() as tape:
    # `interp` is a plain tensor, not a tf.Variable, so it has to be watched explicitly.
    tape.watch(interp)

    # The score must be recomputed inside the tape: a score computed in an earlier cell
    # was not recorded by this tape and would yield a gradient of None.
    if model_type == "mlp":
        interp_score = discriminator_model(interp)
    else:
        # Previously a non-MLP model silently reused the stale interp_score, which made
        # `grads` None. The ResNet call is kept here for reference until it is ported:
        #     interp_score = lib.models.resnet_discriminator(interp, disc_dim, max_seq_len, data_enc_dim, res_layers=disc_layers)
        raise NotImplementedError(
            "Gradient penalty is only implemented for model_type='mlp' (got {!r}).".format(model_type)
        )

# Gradient of the interpolated score with respect to the interpolated samples.
grads = tape.gradient(interp_score, interp)

In [17]:
# Per-sample L2 norm (magnitude) of the gradient, taken over all non-batch dimensions.
# Flattening to [batch, -1] first gives exactly the same values as
# `tf.norm(grads, axis=1)` for the rank-2 MLP gradients, and also yields one norm per
# sample for higher-rank inputs such as [batch, seq_len, channels].
# NOTE: might need an extra (epsilon) term for numerical stability of SGD.
grad_norms = tf.norm(tf.reshape(grads, [tf.shape(grads)[0], -1]), axis=1)

In [18]:
# Gradient penalty:
#   grad_norms - 1.          how far each sample's gradient norm is from 1 (the target norm
#                            used by WGAN-GP to encourage a Lipschitz discriminator).
#   (grad_norms - 1.) ** 2   squares the difference, so large deviations are penalised more.
#   tf.reduce_mean(...)      averages over the interpolated samples to get a single term.
#   lmbda                    hyperparameter that controls the strength of the penalty.
grad_penalty = lmbda * tf.reduce_mean((grad_norms - 1.) ** 2)
# Total discriminator loss, made of two components:
#   disc_diff:    mean score of generated data minus mean score of real data.
#   grad_penalty: the gradient penalty term computed above.
disc_cost = disc_diff + grad_penalty

In [19]:
from tensorflow.keras.optimizers import Adam

# One Adam optimizer per network (generator first, then discriminator). Both share the
# same learning rate and beta settings; only the optimizer name differs.
gen_optimizer, disc_optimizer = (
    Adam(learning_rate=learning_rate, beta_1=0.5, beta_2=0.9, name=optimizer_name)
    for optimizer_name in ('gen_optimizer', 'disc_optimizer')
)

In [20]:
# Metric histories for the training loop. Every name is bound to its own,
# independent empty list.
(train_cost, gen_costs, gen_scores, real_scores,
 gen_counts, train_counts, valid_cost, valid_counts) = ([] for _ in range(8))

In [21]:
# Load dataset
def feed(batch_size=batch_size, seq_len=max_seq_len, data_len=None):
    """Yield an endless stream of random one-hot encoded sequence batches.

    Each batch draws `batch_size * seq_len` vocabulary indices with
    `np.random.choice(vocab_size, ...)` and one-hot encodes them through the
    identity matrix `I`.

    Args:
        batch_size: number of sequences per batch.
        seq_len: number of positions per sequence.
        data_len: accepted for call compatibility; not used.

    Yields:
        For model_type "mlp": an array of shape [batch_size, seq_len * vocab_size].
        For model_type "resnet": an array of shape [batch_size, seq_len, vocab_size].

    Raises:
        ValueError: if model_type is neither "mlp" nor "resnet".
    """
    while True:
        samples = np.random.choice(vocab_size, [batch_size, seq_len])
        # Indexing the identity matrix with the whole index array one-hot encodes every
        # position at once; same result as stacking I[row] for each row of `samples`.
        data = I[samples]
        if model_type == "mlp":
            # The MLP consumes flat vectors.
            yield np.reshape(data, [batch_size, -1])
        elif model_type == "resnet":
            yield data
        else:
            raise ValueError("Unknown model_type: {!r}".format(model_type))

In [22]:
# Prepare the batch iterators: random data when `generic` is set, otherwise the
# dataset loaded from data_loc.
if generic:
    print("\n Inside if and preparing random data!")
    if annotate:
        raise Exception("args `annotate` and `generic` are incompatible.")

    train_seqs = feed()
    if validate:
        valid_seqs = feed(data_len=100)
else:
    print("\n Loading seqs data!")
    data = lib.dna.load(data_loc, vocab_order=vocab_order, max_seq_len=max_seq_len,
                         data_start_line=data_start, vocab=vocab, valid=validate,
                         annotate=annotate)
    if validate:
        # The first half of what load() returned is training data, the second half
        # validation data; a single-element half is unwrapped.
        split = len(data) // 2
        train_data = data[:split]
        valid_data = data[split:]
        if len(train_data) == 1:
            train_data = train_data[0]
        if len(valid_data) == 1:
            valid_data = valid_data[0]
    else:
        train_data = data
    if annotate:
        # Join the parts along axis 2.
        # NOTE(review): presumably sequence encoding + annotation channel -- confirm
        # against lib.dna.load.
        if validate:
            valid_data = np.concatenate(valid_data, 2)
        train_data = np.concatenate(train_data, 2)

    # NOTE: when this branch runs, this definition replaces the random-data feed()
    # defined in the previous cell.
    def feed(data, batch_size=batch_size):
        """Yield consecutive full batches of `data`, cycling forever.

        Args:
            data: array of sequences; the first axis indexes sequences.
            batch_size: number of sequences per batch. Trailing sequences that do
                not fill a whole batch are never yielded.

        Yields:
            For model_type "mlp": slices of `data` flattened to [batch_size, -1].
            For model_type "resnet": slices of `data` in their original shape.

        Raises:
            ValueError: if `data` holds fewer than `batch_size` sequences, or if
                model_type is neither "mlp" nor "resnet".
        """
        num_batches = len(data) // batch_size
        if num_batches == 0:
            # Without this guard the loop below would spin forever without yielding.
            raise ValueError(
                "Need at least batch_size={} sequences, got {}.".format(batch_size, len(data))
            )
        if model_type == "mlp":
            reshaped_data = np.reshape(data, [data.shape[0], -1])
        elif model_type == "resnet":
            reshaped_data = data
        else:
            raise ValueError("Unknown model_type: {!r}".format(model_type))
        while True:
            for ctr in range(num_batches):
                yield reshaped_data[ctr * batch_size: (ctr + 1) * batch_size]

    train_seqs = feed(train_data)
    if validate:
        # Mirror the random-data branch, which also provides a validation iterator.
        valid_seqs = feed(valid_data)


 Loading seqs data!


In [None]:
# Load checkpoint (if any)
# if args.checkpoint:
#     checkpoint = tf.train.Checkpoint(optimizer=gen_optimizer, generator=gen_data, discriminator=real_data)
#     checkpoint.restore(args.checkpoint).expect_partial()

In [23]:
# Train GAN
print("Training GAN")
print("================================================")

# A fixed set of latent batches, each a [batch_size, latent_dim] draw from a normal
# distribution, created once before training starts.
nSampleBatches = 10
fixed_latents = [
    np.random.normal(size=[batch_size, latent_dim])
    for nBatches in range(nSampleBatches)
]

Training GAN


In [None]:
# WGAN-GP training loop: `disc_iters` discriminator updates per generator update.
#
# The loop drives the model objects built earlier (`generator_model`,
# `discriminator_model`); those only exist for the MLP path.
# NOTE(review): assumes both objects are Keras-style models exposing
# `.trainable_variables`, and that real batches have the same width as the
# discriminator input -- confirm against lib.models and the data feed.
if model_type != "mlp":
    raise NotImplementedError(
        "The training loop is only implemented for model_type='mlp' (got {!r}).".format(model_type)
    )

for idx in range(train_iters):
    true_count = idx + 1

    # ---- Train discriminator ----
    for _ in range(disc_iters):
        # Feeds yield NumPy arrays (float64 for the one-hot data); the models work in float32.
        real_batch = tf.cast(next(train_seqs), tf.float32)
        latent_vars = tf.random.normal(shape=[batch_size, latent_dim])
        # Fresh mixing weights for every step, one per sample.
        eps = tf.random.uniform([batch_size, 1])

        with tf.GradientTape() as tape:
            # Every score is computed inside the tape so the loss is connected to the
            # discriminator weights.
            gen_data = generator_model(latent_vars)
            gen_score = discriminator_model(gen_data)
            real_score = discriminator_model(real_batch)

            # Compute costs
            mean_gen_score = tf.reduce_mean(gen_score)
            mean_real_score = tf.reduce_mean(real_score)
            disc_diff = mean_gen_score - mean_real_score

            # Gradient penalty on points interpolated between this step's real and
            # generated batches. The inner tape's gradient is taken while the outer tape
            # is still recording, so the penalty can be differentiated w.r.t. the weights.
            interp = eps * real_batch + (1 - eps) * gen_data
            with tf.GradientTape() as penalty_tape:
                penalty_tape.watch(interp)
                interp_score = discriminator_model(interp)
            grads = penalty_tape.gradient(interp_score, interp)
            grad_norms = tf.norm(grads, axis=1)
            grad_penalty = lmbda * tf.reduce_mean((grad_norms - 1.) ** 2)

            disc_cost = disc_diff + grad_penalty

        disc_vars = discriminator_model.trainable_variables
        gradients = tape.gradient(disc_cost, disc_vars)
        disc_optimizer.apply_gradients(zip(gradients, disc_vars))

    # ---- Train generator ----
    latent_vars = tf.random.normal(shape=[batch_size, latent_dim])
    with tf.GradientTape() as tape:
        gen_data = generator_model(latent_vars)
        gen_score = discriminator_model(gen_data)
        gen_cost = -tf.reduce_mean(gen_score)

    gen_vars = generator_model.trainable_variables
    gradients = tape.gradient(gen_cost, gen_vars)
    gen_optimizer.apply_gradients(zip(gradients, gen_vars))

    # ---- Log results ----
    if idx % 10 == 0:
        # NOTE(review): train_cost and gen_costs both record the generator cost --
        # confirm whether train_cost was meant to hold the discriminator cost.
        train_cost.append(gen_cost.numpy())
        gen_costs.append(gen_cost.numpy())
        # The score means come from the last discriminator step of this iteration.
        gen_scores.append(mean_gen_score.numpy())
        real_scores.append(mean_real_score.numpy())
        train_counts.append(true_count)
        print("Iteration: {}. Generator Cost: {:.4f}, Discriminator Cost: {:.4f}".format(idx, gen_cost.numpy(), disc_cost.numpy()))

    # Save checkpoints
    # if idx % args.checkpoint_iters == 0:
    #     checkpoint.save(file_prefix=checkpoint_baseline)

In [None]:
# List the trainable variables of both networks. They are read from the model objects:
# `gen_data` is an output tensor and has no `trainable_variables`.
# NOTE(review): assumes both objects are Keras-style models -- confirm against lib.models.
print("Discriminator Trainable Variables:", discriminator_model.trainable_variables)
print("Generator Trainable Variables:", generator_model.trainable_variables)

In [None]:
# After training, save model
# The generator model object is saved: `gen_data` is only one of its output tensors.
# `logdir` was produced by the disabled lib.log(...) call, so the target directory is
# built from the configured log_dir and log_name instead.
# NOTE(review): confirm tf.saved_model.save accepts the object returned by
# lib.models.mlp_generator (some Keras versions prefer model.export / model.save).
final_model_dir = os.path.join(log_dir, log_name, "final_model")
tf.saved_model.save(generator_model, final_model_dir)