### Description

This implements a CVAE-GAN that infills a masked promoter sequence based on a combined loss function, including:
1. Reconstruction Loss (CVAE): Ensures that the infilled sequence closely resembles the original.
2. Adversarial Loss (GAN): Encourages the generation of realistic sequences.
3. Auxiliary Loss (MSE): Ensures that the generated sequence produces the correct expression when evaluated by the CNN.

In [1]:
import CTGAN_1_5 as parent

In [2]:
version = '1_5'

In [3]:
# Hyperparameters
batch_size = 64
epochs = 5
learning_rate = 0.0002
adversarial_lambda = 1
cnn_lambda = 10
path_to_cnn = '../Models/CNN_5_0.keras'
path_to_data = '../Data/combined/LaFleur_supp.csv'

# Load Data and Prepare Dataloaders
df = parent.load_data(path_to_data)
train_loader, test_loader = parent.prepare_dataloader(df, batch_size)
 
# Initialize Models
generator = parent.Generator()
discriminator = parent.Discriminator()
cnn = parent.KerasModelWrapper(path_to_cnn)

In [4]:
# Train Models with Training DataLoader
parent.train_ctgan(generator, discriminator, train_loader, cnn, epochs, learning_rate, adversarial_lambda, cnn_lambda)

# Save the trained models
parent.save_model(generator, f'../Models/generator_{version}.pth')
parent.save_model(discriminator, f'../Models/generator_{version}.pth')

Epoch [1/5]  Loss aD: 0.7234, Loss cD: 0.0372, Loss G: 1.0952
Epoch [2/5]  Loss aD: 0.5904, Loss cD: 0.0215, Loss G: 0.8051
Epoch [3/5]  Loss aD: 0.6089, Loss cD: 0.0184, Loss G: 0.7927
Epoch [4/5]  Loss aD: 0.7077, Loss cD: 0.0198, Loss G: 0.9057
Epoch [5/5]  Loss aD: 0.7998, Loss cD: 0.0233, Loss G: 1.0328
Model saved to ../Models/generator_1_5.pth
Model saved to ../Models/generator_1_5.pth


In [6]:
# # Load the models
# parent.load_model(generator, f'../Models/generator_{version}.pth')
# parent.load_model(discriminator, f'../Models/generator_{version}.pth')

# Evaluate the generator on the test set
parent.evaluate_generator(generator, cnn, test_loader)

Average MSE on Test Set: 0.0188


Average MSE on Test Set: 0.0188

In [7]:
# Test Example
sequences = ['TTTTCTATCTACGTACTTGACACTATTTCNNNNNNNNNNATTACCTTAGTTTGTACGTT']
expressions = [0.5]

# Generate infills
infilled = parent.generate_infills(generator, sequences, expressions)
for original, infilled in zip(sequences, infilled):
    print("Original Sequences:", original)
    print("Infilled Sequences:", infilled)

Original Sequences: TTTTCTATCTACGTACTTGACACTATTTCNNNNNNNNNNATTACCTTAGTTTGTACGTT
Infilled Sequences: TTTTCTATCTACGTACTTGACACTATTTCACTGGTTTAAATTACCTTAGTTTGTACGTT


In [115]:
import torch

def generate_infills(generator, sequences, expressions, mask_size=10):
    infilled_sequences = []
    predicted_exprs = []
    
    for sequence, expr in zip(sequences, expressions):
        start = sequence.find('N' * mask_size)
        if start == -1:
            raise ValueError("No masked region ('N') found in the sequence.")
        
        # Convert the masked sequence to a tensor
        sequence_tensor = parent.one_hot_encode_sequence(sequence).unsqueeze(0)
        expr_tensor = torch.tensor([expr], dtype=torch.float32).view(1, 1)

        # Generate infill using the generator
        generated_segment = generator(sequence_tensor, expr_tensor)
        predicted_infill = parent.decode_one_hot_sequence(generated_segment.argmax(dim=2).squeeze().numpy())

        # Predict the expression of the generated segment using the CNN
        cnn_input = parent.preprocess_cnn_input(sequence_tensor, generated_segment, 0)
        predicted_expr = cnn(cnn_input).item()
        
        # Reconstruct the full sequence
        infilled_sequence = (
            sequence[:start] + predicted_infill + sequence[start + mask_size:]
        )
        infilled_sequences.append(infilled_sequence)
        predicted_exprs.append(predicted_expr)
    
    return infilled_sequences, predicted_exprs

# Test Example
sequences = ['TTTTCTATCTACGTACTTGACACTATTTCCTATTTCNNNNNNNNNNATATTACTCTACCTTAGTTTGTACGTT']
real_exprs = [0.5]

# Generate infills
infilled, predicted_exprs = generate_infills(generator, sequences, real_exprs)
for original, infilled, predicted_expr, real_expr in zip(sequences, infilled, predicted_exprs, real_exprs):
    print(f"Original Sequences: {original}, Expression: {real_expr}")
    print(f"Infilled Sequences: {infilled}, Expression: {predicted_expr}")

Original Sequences: TTTTCTATCTACGTACTTGACACTATTTCCTATTTCNNNNNNNNNNATATTACTCTACCTTAGTTTGTACGTT, Expression: 0.5
Infilled Sequences: TTTTCTATCTACGTACTTGACACTATTTCCTATTTCACTCGTTTAAATATTACTCTACCTTAGTTTGTACGTT, Expression: 0.6876955628395081
