### Description

This implements a CVAE based on https://github.com/unnir/cVAE adapted for promoter sequences and expressions. This does include CNN auxiliary loss, "re-masking" the generated sequence before giving it to the CNN.

In [1]:
from CVAE_unnir_1_1 import *

In [2]:
# Defining hyperparameters
batch_size = 64
epochs = 100
early_stopping_patience = 10
early_stopping_min_delta = 0.01
latent_size = 20

In [3]:
# Set seed for reproducibility
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)

# Paths to Data and Pre-trained CNN
path_to_train = '../Data/Augmented/augmented_train_data_6_1.csv'
path_to_test = '../Data/Augmented/augmented_test_data_6_1.csv'
path_to_cvae = '../Models/CVAE_6_1.pt'
path_to_cnn = '../Models/CNN_6_1.keras'
path_to_summary = '../Testing CVAE/runs/CNN_6_1_summary'

# Set up device
device = get_device()

# Initialize model, optimizer
cnn = KerasModelWrapper(path_to_cnn)
model = CVAE(150, latent_size, 1).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Load data and one-hot encode sequences
onehot_masked_train, mask_lengths_train, mask_starts_train, expressions_train = load_data(path_to_train)
onehot_masked_test, mask_lengths_test, mask_starts_test, expressions_test = load_data(path_to_test)

# Preprocess sequences and expressions into tensors
masked_tensor_train = torch.tensor(np.stack(onehot_masked_train), dtype=torch.float32)
expressions_tensor_train = torch.tensor(expressions_train.values, dtype=torch.float32)
masked_tensor_test = torch.tensor(np.stack(onehot_masked_test), dtype=torch.float32)
expressions_tensor_test = torch.tensor(expressions_test.values, dtype=torch.float32)

# Create DataLoader
train_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(masked_tensor_train, expressions_tensor_train),
    batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True
)
test_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(masked_tensor_test, expressions_tensor_test),
    batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True
)


  saveable.load_own_variables(weights_store.get(inner_path))


In [4]:
# # Train and test the model
# train_losses, test_losses = fit_model(epochs,
#                                         model,
#                                         cnn,
#                                         path_to_summary,
#                                         train_loader,
#                                         test_loader,
#                                         optimizer,
#                                         device,
#                                         early_stopping_patience,
#                                         early_stopping_min_delta
# )

# # Plot the training and testing losses
# plot_losses(train_losses, test_losses)

# # Save the model
# save_model(model, path_to_cvae)

In [5]:
# Load the model
load_model(path_to_cvae)

RecursiveScriptModule(
  original_name=CVAE
  (fc1): RecursiveScriptModule(original_name=Linear)
  (fc21): RecursiveScriptModule(original_name=Linear)
  (fc22): RecursiveScriptModule(original_name=Linear)
  (fc3): RecursiveScriptModule(original_name=Linear)
  (fc4): RecursiveScriptModule(original_name=Linear)
  (elu): RecursiveScriptModule(original_name=ELU)
  (sigmoid): RecursiveScriptModule(original_name=Sigmoid)
)

In [9]:
# Test Example
"TTTTCTATCTACGTAC	TTGACA	CTATTTCCTATTTCTCT	TATAAT	CCCCGCGG	CTCTACCTTAGTTTGTACGTT"

masked_sequences = ['TTTTCTATCTACGTACNNNNNNCTATTTCCTATTTCTCTTATAATCCCCGCGGCTCTACCTTAGTTTGTACGTT']
expressions = [0.2]

# Generate infills
infilled_sequences, predicted_expressions = generate_infills(model, cnn, masked_sequences, expressions)
for masked, infilled, expressions in zip(masked_sequences, infilled_sequences, predicted_expressions):
    print("Masked:  ", masked)
    print("Infilled:", infilled)
    print("Predicted Expression:", expressions)
    print()

Masked:   TTTTCTATCTACGTACNNNNNNCTATTTCCTATTTCTCTTATAATCCCCGCGGCTCTACCTTAGTTTGTACGTT
Infilled: TTTTCTATCTACGTACAGCGGACTATTTCCTATTTCTCTTATAATCCCCGCGGCTCTACCTTAGTTTGTACGTT
Predicted Expression: 0.21180134

