In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import grad
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import numpy as np

In [2]:
device = 'mps'

In [2]:
import tensorflow as tf
from tensorflow.keras import layers, Model
import numpy as np

# Define the residual block used in the generator and discriminator
class ResidualBlock(layers.Layer):
    def __init__(self, units, dropout_rate=0.1):
        super(ResidualBlock, self).__init__()
        self.units = units
        self.dropout_rate = dropout_rate
        self.dense1 = layers.Dense(units)
        self.swish1 = layers.Activation(tf.nn.swish)
        self.dropout1 = layers.Dropout(dropout_rate)
        self.dense2 = layers.Dense(units)
        self.swish2 = layers.Activation(tf.nn.swish)
        self.dropout2 = layers.Dropout(dropout_rate)
    
    def call(self, inputs):
        x = self.dense1(inputs)
        x = self.swish1(x)
        x = self.dropout1(x)
        x = self.dense2(x)
        x = self.swish2(x)
        x = self.dropout2(x)
        return inputs + x

# Generator model
class Generator(Model):
    def __init__(self):
        super(Generator, self).__init__()
        self.dense1 = layers.Dense(256)
        self.res_block1 = ResidualBlock(256)
        self.res_block2 = ResidualBlock(512)
        self.res_block3 = ResidualBlock(1024)
        self.dense2 = layers.Dense(21, activation='sigmoid')  # Output layer for sequence of length 21
    
    def call(self, z, labels):
        x = tf.concat([z, labels], axis=-1)
        x = self.dense1(x)
        x = self.res_block1(x)
        x = self.res_block2(x)
        x = self.res_block3(x)
        x = self.dense2(x)
        return x

# Discriminator model
class Discriminator(Model):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.dense1 = layers.Dense(1024)
        self.res_block1 = ResidualBlock(1024)
        self.res_block2 = ResidualBlock(512)
        self.res_block3 = ResidualBlock(256)
        self.dense2 = layers.Dense(1)  # Output layer without activation
    
    def call(self, x, labels):
        x = tf.concat([x, labels], axis=-1)
        x = self.dense1(x)
        x = self.res_block1(x)
        x = self.res_block2(x)
        x = self.res_block3(x)
        x = self.dense2(x)
        return x

# WGAN-GP loss functions
def generator_loss(fake_output):
    return -tf.reduce_mean(fake_output)

def discriminator_loss(real_output, fake_output):
    return tf.reduce_mean(fake_output) - tf.reduce_mean(real_output)

def gradient_penalty(discriminator, real_data, fake_data, labels):
    alpha = tf.random.uniform([real_data.shape[0], 1], 0.0, 1.0)
    interpolated = alpha * real_data + (1 - alpha) * fake_data
    with tf.GradientTape() as tape:
        tape.watch(interpolated)
        interpolated_output = discriminator(interpolated, labels)
    grads = tape.gradient(interpolated_output, [interpolated])[0]
    norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1]))
    gp = tf.reduce_mean((norm - 1.0) ** 2)
    return gp

# Training step function
@tf.function
def train_step(real_data, labels, generator, discriminator, generator_optimizer, discriminator_optimizer, batch_size, gp_weight=10):
    z_dim = 128  # Latent vector dimension
    
    # Sample random noise and generate fake data
    z = tf.random.normal([batch_size, z_dim])
    fake_data = generator(z, labels)

    # Discriminator loss
    with tf.GradientTape() as tape:
        real_output = discriminator(real_data, labels)
        fake_output = discriminator(fake_data, labels)
        d_loss = discriminator_loss(real_output, fake_output)
        gp = gradient_penalty(discriminator, real_data, fake_data, labels)
        total_d_loss = d_loss + gp_weight * gp

    d_gradients = tape.gradient(total_d_loss, discriminator.trainable_variables)
    discriminator_optimizer.apply_gradients(zip(d_gradients, discriminator.trainable_variables))

    # Generator loss
    with tf.GradientTape() as tape:
        fake_data = generator(z, labels)
        fake_output = discriminator(fake_data, labels)
        g_loss = generator_loss(fake_output)

    g_gradients = tape.gradient(g_loss, generator.trainable_variables)
    generator_optimizer.apply_gradients(zip(g_gradients, generator.trainable_variables))
    
    return total_d_loss, g_loss

# Training function
def train(dataset, labels, generator, discriminator, generator_optimizer, discriminator_optimizer, epochs, batch_size):
    for epoch in range(epochs):
        for real_data, labels_batch in dataset.batch(batch_size):
            d_loss, g_loss = train_step(real_data, labels_batch, generator, discriminator, generator_optimizer, discriminator_optimizer, batch_size)
        print(f"Epoch {epoch+1}, Discriminator Loss: {d_loss}, Generator Loss: {g_loss}")

# Preparing the dataset
def preprocess_sequence(sequence):
    # Convert RNA sequences to numerical representation
    mapping = {'A': 0, 'C': 1, 'G': 2, 'U': 3}
    return [mapping[char] for char in sequence]

# Load and preprocess the data
data = [
    ("CAUGGAGAGAUGUUCUUUACU", 0), ("CAUAUCAACUUUUAUUCUCUC", 0), 
    ("AUUAUGAAACUGUUGUGGUGU", 0), ("CAAGUCGGCUUUGCUAUAAAC", 1),
    ("CCCUGGGCAGUAUAGAGACGU", 1), ("UUGCCUUCUUUUAAGAGAUGG", 1), 
    ("AAAGCUAGGUUCCAACCUGAA", 0), ("UAGCCGGGCAUGGUGGCACAC", 0),
    ("AGGUUUUAGUUUUUGCUUUAU", 1), ("CAGCCUGGGCUAACCAGCAUG", 0),
    ("CUUCGAGGCUUUUCCCCACUG", 0), ("GGUUUGGACAUUGAAAUGGCU", 1)
]

sequences, labels = zip(*data)
sequences = np.array([preprocess_sequence(seq) for seq in sequences])
labels = np.array(labels)

# One-hot encode labels
labels = tf.keras.utils.to_categorical(labels, num_classes=2)

# Create dataset
dataset = tf.data.Dataset.from_tensor_slices((sequences, labels))

# Initialize models and optimizers
generator = Generator()
discriminator = Discriminator()
generator_optimizer = tf.keras.optimizers.Adam(0.0001, beta_1=0.5, beta_2=0.9)
discriminator_optimizer = tf.keras.optimizers.Adam(0.0001, beta_1=0.5, beta_2=0.9)

# Train the models
train(dataset, labels, generator, discriminator, generator_optimizer, discriminator_optimizer, epochs=10000, batch_size=32)


ValueError: in user code:

    File "/var/folders/pr/4495xxw90_g8dgzr4yr3lsyh0000gn/T/ipykernel_6966/1499253138.py", line 90, in train_step  *
        fake_data = generator(z, labels)
    File "/Users/arish/Workspace/experiments/rna_modification/.venv/lib/python3.12/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler  **
        raise e.with_traceback(filtered_tb) from None
    File "/var/folders/pr/4495xxw90_g8dgzr4yr3lsyh0000gn/T/ipykernel_6966/1499253138.py", line 38, in call
        x = tf.concat([z, labels], axis=-1)

    ValueError: Exception encountered when calling Generator.call().
    
    [1mDimension 0 in both shapes must be equal, but are 32 and 12. Shapes are [32] and [12]. for '{{node generator_1_1/concat}} = ConcatV2[N=2, T=DT_FLOAT, Tidx=DT_INT32](random_normal, Cast, generator_1_1/concat/axis)' with input shapes: [32,128], [12,2], [] and with computed input tensors: input[2] = <-1>.[0m
    
    Arguments received by Generator.call():
      • z=tf.Tensor(shape=(32, 128), dtype=float32)
      • labels=tf.Tensor(shape=(12, 2), dtype=float32)
