In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import grad
import torch.nn.functional as F
from torch.utils import *
from torch.utils.data import DataLoader
from torch.nn.functional import gumbel_softmax
import matplotlib.pyplot as plt
import os
import time
from plot_metrics import plot_metrics

In [None]:
from total_amp_encoder import AMPSequenceDataset

# Path to the AMP FASTA file. Set the AMP_FASTA_PATH environment variable to
# point elsewhere instead of editing this machine-specific default.
file_path = os.environ.get("AMP_FASTA_PATH", r"C:\Users\kotsgeo\Documents\GANs\all_amp.fasta")
dataset = AMPSequenceDataset(file_path)

# Training DataLoader. drop_last=True discards a trailing partial batch so every
# batch has exactly `loader_batch_size` samples.
# NOTE(review): presumably this must equal `batch_size` in the model-setup cell,
# which Generator_lang is constructed with — keep the two values in sync.
loader_batch_size = 64
dataloader = DataLoader(dataset, batch_size=loader_batch_size, shuffle=True, drop_last=True)

# With drop_last=True a dataset smaller than one batch yields zero batches; fail
# loudly here instead of letting the training loop silently do nothing.
if len(dataloader) == 0:
    raise ValueError(
        f"Dataset at {file_path!r} has {len(dataset)} samples; "
        f"need at least one full batch of {loader_batch_size}."
    )

# Quick sanity check: inspect a single batch, then report the batch count once.
batch = next(iter(dataloader))
print("Batch shape:", batch.shape)
print("Number of batches:", len(dataloader))

In [None]:
from models import Generator_lang, UNetDiscriminator

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Root directory handed to train() as `results_dir`; checkpoint folders live under it.
results_dir = r"C:\Users\kotsgeo\Documents\GANs\Old"

# One checkpoint folder per selection criterion (presumably best-JSD / best-ORF /
# best-AMP checkpoints written during training — confirm against wugan_train).
model_save_dir = os.path.join(results_dir, 'saved_models')
jsd_models_dir = os.path.join(model_save_dir, 'best_jsd')
orf_models_dir = os.path.join(model_save_dir, 'best_orf')
amp_models_dir = os.path.join(model_save_dir, 'best_amp')

# makedirs also creates the intermediate saved_models/ directory.
for checkpoint_dir in (jsd_models_dir, orf_models_dir, amp_models_dir):
    os.makedirs(checkpoint_dir, exist_ok=True)

############################## DEFINE #######################################################
# Parameters
seed = 42           # RNG seed so model initialisation and later shuffling are reproducible
n_chars = 5         # passed to both models; presumably the one-hot alphabet size — confirm
seq_len = 156       # passed to both models
batch_size = 64     # NOTE(review): presumably must equal the DataLoader batch size (64) — confirm
batch_eval = 256    # must be a multiple of batch_size (validated below)
hidden_g = 192      # width argument for Generator_lang
hidden_d = 128      # width argument for UNetDiscriminator
num_epochs = 250
lambda_gp = 10      # Gradient penalty coefficient
d_step = 5          # presumably discriminator updates per iteration
g_step = 2          # presumably generator updates per iteration
lambda_dec = 0.1
lambda_pixel = 0.5
lambda_mix = 1
scale = 1
# NOTE(review): lambda_gp, d_step and g_step are not forwarded to train() in the
# training cell; confirm whether wugan_train uses its own values for them.

if batch_eval % batch_size != 0:
    raise ValueError(
        f"batch_eval ({batch_eval}) must be a multiple of batch_size ({batch_size})"
    )

# Seed before constructing the models so weight initialisation is repeatable.
torch.manual_seed(seed)

# Initialize models
generator = Generator_lang(n_chars, seq_len, batch_size, hidden_g).to(device)
discriminator = UNetDiscriminator(n_chars, seq_len, hidden_d).to(device)

# Optimizers (the discriminator uses half the generator's learning rate)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.00005, betas=(0.9, 0.999), weight_decay=1e-5)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0001, betas=(0.9, 0.999), weight_decay=1e-5)

# ExponentialLR multiplies the learning rate by gamma on every scheduler.step();
# how often step() is called is decided inside train().
d_scheduler = torch.optim.lr_scheduler.ExponentialLR(d_optimizer, gamma=0.99)
g_scheduler = torch.optim.lr_scheduler.ExponentialLR(g_optimizer, gamma=0.99)
#############################################################################################

In [None]:
from wugan_train import train

# --- Run Training ---
# Every argument forwarded to train(), gathered in one mapping so the exact
# set of hyperparameters being passed is visible at a glance.
train_kwargs = dict(
    generator=generator,
    discriminator=discriminator,
    dataloader=dataloader,
    batch_eval=batch_eval,
    num_epochs=num_epochs,
    n_chars=n_chars,
    device=device,
    results_dir=results_dir,
    d_optimizer=d_optimizer,
    g_optimizer=g_optimizer,
    d_scheduler=d_scheduler,
    g_scheduler=g_scheduler,
    lambda_mix=lambda_mix,
    lambda_dec=lambda_dec,
    lambda_pixel=lambda_pixel,
    scale=scale,
)
training_results = train(**train_kwargs)
iteration_losses, total_iterations, jsd_history, amp_history = training_results

# Plot JSD and AMP curves; the FRED panel is switched off.
batches_per_epoch = len(dataloader)
plot_metrics(
    iteration_losses,
    total_iterations,
    num_epochs,
    batches_per_epoch,
    jsd_history=jsd_history,
    amp_history=amp_history,
    plot_jsd=True,
    plot_fred=False,
    plot_amp=True,
)