In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import grad
import torch.nn.functional as F
from torch.utils import *
from torch.utils.data import DataLoader
from torch.nn.functional import gumbel_softmax
import matplotlib.pyplot as plt
import os
import time
from plot_metrics import plot_metrics

In [2]:
from total_amp_encoder import AMPSequenceDataset

if __name__ == "__main__":
    file_path = "/files/private/notebooks/GANs/WUGAN/all_amp.fasta"
    dataset = AMPSequenceDataset(file_path)
    
    # Create dataloader for training
    dataloader = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)
    
    # Quick test
    for batch in dataloader:
        print("Batch shape:", batch.shape)
        print("Number of batches:", (len(dataloader)))
        break

Batch shape: torch.Size([64, 156, 5])
Number of batches: 75


In [3]:
from models import Generator_lang, UNetDiscriminator

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Define results directory with absolute path
results_dir = "/files/private/notebooks/GANs/WUGAN"

# Create model directories
model_save_dir = os.path.join(results_dir, 'saved_models')
jsd_models_dir = os.path.join(model_save_dir, 'best_jsd')
orf_models_dir = os.path.join(model_save_dir, 'best_orf')
amp_models_dir = os.path.join(model_save_dir, 'best_amp')

# Create directories if they don't exist
os.makedirs(jsd_models_dir, exist_ok=True)
os.makedirs(orf_models_dir, exist_ok=True)
os.makedirs(amp_models_dir, exist_ok=True)

############################## DEFINE #######################################################
# Parameters                                                                                
n_chars = 5                                                                                 
seq_len = 156
batch_size = 64
hidden_g = 192
hidden_d = 128
num_epochs = 250
lambda_gp = 10  # Gradient penalty coefficient
d_step = 5
g_step = 2
lambda_dec = 0.4
lambda_mix = 1
scale = 1

# Initialize models
generator = Generator_lang(n_chars, seq_len, batch_size, hidden_g).to(device)
discriminator = UNetDiscriminator(n_chars, seq_len, hidden_d).to(device)

# Optimizers
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.00005, betas=(0.9, 0.999), weight_decay=1e-5)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0001, betas=(0.9, 0.999), weight_decay=1e-5)

d_scheduler = torch.optim.lr_scheduler.ExponentialLR(d_optimizer, gamma=0.99)
g_scheduler = torch.optim.lr_scheduler.ExponentialLR(g_optimizer, gamma=0.99)
#############################################################################################

cuda


In [None]:
from wugan_train import train

# --- Run Training ---
iteration_losses, total_iterations, jsd_history, amp_history = train(
    generator=generator,
    discriminator=discriminator,
    dataloader=dataloader,
    num_epochs=num_epochs,
    n_chars=n_chars,
    device=device,
    results_dir=results_dir,
    d_optimizer=d_optimizer,
    g_optimizer=g_optimizer,
    d_scheduler=d_scheduler,
    g_scheduler=g_scheduler,
    lambda_mix=lambda_mix,
    lambda_dec=lambda_dec,
    scale=scale
)

# Plot JSD and AMP
plot_metrics(iteration_losses, total_iterations, num_epochs, len(dataloader),
             jsd_history=jsd_history, amp_history=amp_history,
             plot_jsd=True, plot_fred=False, plot_amp=True)

  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Batch [1/75]
D_total_loss: 10.0182
Wasserstein Loss: -0.0146
Gradient Penalty: 9.7486
Decoder Loss: 0.7106
G_total_loss: 0.3246
G_loss: -0.0260
Pixel Loss: 0.8767

Batch [42/75]
D_total_loss: -9.5557
Wasserstein Loss: -11.6272
Gradient Penalty: 1.7893
Decoder Loss: 0.7054
G_total_loss: 1.0449
G_loss: 0.6950
Pixel Loss: 0.8749

Epoch [1/250] - Epoch Time: 70.59s - Total Time: 00:01:10
D_total_loss: -7.6749
Wasserstein Loss: -10.1830
Gradient Penalty: 2.2253
Decoder Loss: 0.7072
G_total_loss: 0.6831
G_loss: 0.3331
Pixel Loss: 0.8749

Latest JSD Score: 0.1673
AMP Score: 31.56%
--------------------------------------------------
Batch [1/75]
D_total_loss: -9.7905
Wasserstein Loss: -11.9419
Gradient Penalty: 1.8697
Decoder Loss: 0.7041
G_total_loss: 0.0770
G_loss: -0.2721
Pixel Loss: 0.8728

Batch [42/75]
D_total_loss: -9.9860
Wasserstein Loss: -12.2327
Gradient Penalty: 1.9635
Decoder Loss: 0.7081
G_total_loss: 0.1342
G_loss: -0.2136
Pixel Loss: 0.8694

Epoch [2/250] - Epoch Time: 69.98s - 

In [5]:
!pip install biopython

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com


In [6]:
from sequence_generator import generate_and_filter_sequences, convert_dna_fasta_to_protein

saved_model_path = r"GANs/WUGAN/saved_models2/best_amp/generator_amp_2_epoch_80_score_48.44.pt"
generator.eval()

sequences, analysis, dna_file = generate_and_filter_sequences(
    generator, 
    batch_size=64,
    num_samples=11270-8*640,
    count_atg=False,
    add_atg=False
)

num_proteins, protein_file = convert_dna_fasta_to_protein(input_fasta=dna_file, add_atg=False)


Generation complete!
Found 5334 valid sequences out of 6150 generated
Sequences saved in: /files/private/notebooks/GANs/WUGAN/campr4/valid_sequences.fasta
Analysis saved in: /files/private/notebooks/GANs/WUGAN/campr4/sequence_analysis.txt
Converted 5099 DNA sequences to proteins
Protein sequences saved in: /files/private/notebooks/GANs/WUGAN/campr4/valid_sequences_proteins.fasta
