In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import grad
import torch.nn.functional as F
from torch.utils import *
from torch.utils.data import DataLoader
from torch.nn.functional import gumbel_softmax
import matplotlib.pyplot as plt
import os
import time
from plot_metrics import plot_jsd, plot_jsd_fred, plot_metrics

ModuleNotFoundError: No module named 'torch'

In [None]:
from sequence_encoder import SequenceDataset

# Define the file path
file_path = r"C:\Users\kotsgeo\Documents\GANs\Old\AMPdata.txt"

# Create an instance of SequenceDataset
dataset = SequenceDataset(file_path)

# Print the number of sequences
num_sequences = len(dataset)
print("Number of sequences:", num_sequences)

# Create a DataLoader
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)

# Use the dataloader in your training loop or other processes
for batch in dataloader:
    print("Batch shape:", batch.shape)
    print("Number of batches:", (len(dataloader)))
    break

Number of sequences: 2600
Batch shape: torch.Size([64, 156, 5])
Number of batches: 40


In [4]:
from seq_analysis import sample_and_analyze, save_analysis, analyze_sequences
from JSD import jsd
from CutMix import create_cutmix_mask
from models import Generator_lang, UNetDiscriminator
from amp_evaluator import evaluate_amp_batch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Define results directory with absolute path
results_dir = r"C:\Users\kotsgeo\Documents\GANs\Old"

# Create model directories
model_save_dir = os.path.join(results_dir, 'saved_models2')
jsd_models_dir = os.path.join(model_save_dir, 'best_jsd')
orf_models_dir = os.path.join(model_save_dir, 'best_orf')
amp_models_dir = os.path.join(model_save_dir, 'best_amp')

# Create directories if they don't exist
os.makedirs(jsd_models_dir, exist_ok=True)
os.makedirs(orf_models_dir, exist_ok=True)
os.makedirs(amp_models_dir, exist_ok=True)

############################## DEFINE #######################################################
# Parameters                                                                                
n_chars = 5                                                                                 
seq_len = 156
batch_size = 64
hidden_g = 128
hidden_d = 128
num_epochs = 120
lambda_gp = 10  # Gradient penalty coefficient

# Initialize models
generator = Generator_lang(n_chars, seq_len, batch_size, hidden_g).to(device)
discriminator = UNetDiscriminator(n_chars, seq_len, hidden_d).to(device)

# Optimizers
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.00005, betas=(0.9, 0.999), weight_decay=1e-5)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.9, 0.999), weight_decay=1e-5)

d_scheduler = torch.optim.lr_scheduler.ExponentialLR(d_optimizer, gamma=0.99)
g_scheduler = torch.optim.lr_scheduler.ExponentialLR(g_optimizer, gamma=0.99)
#############################################################################################

def format_time(seconds):
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    seconds = int(seconds % 60)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d}"

def calc_gradient_penalty(discriminator, real_data, fake_data, device, lambda_gp=10):
    batch_size = real_data.size(0)
    
    # Only interpolate between real and fake (remove mixed data interpolation)
    alpha = torch.rand(batch_size, 1, 1, device=device)
    alpha = alpha.expand_as(real_data)
    
    interpolates = alpha * real_data + (1 - alpha) * fake_data
    interpolates.requires_grad_(True)
    
    disc_interpolates, _ = discriminator(interpolates)
    
    gradients = torch.autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )[0]
    
    # Flatten gradients
    gradients = gradients.view(batch_size, -1)
    gradient_penalty = lambda_gp * ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    
    return gradient_penalty

def discriminator_train(discriminator, real_sequences, fake_sequences, mixed_sequences, mask, optimizer, 
                        device, lambda_mix=1, lambda_dec=1, lambda_consistency=1, scale=0.1):
    # Ensure correct shape
    if real_sequences.shape[1] != n_chars:
        real_sequences = real_sequences.transpose(1, 2)
    if fake_sequences.shape[1] != n_chars:
        fake_sequences = fake_sequences.transpose(1, 2)
    if mixed_sequences.shape[1] != n_chars:
        mixed_sequences = mixed_sequences.transpose(1, 2)

    # Get predictions for real and fake data
    real_global, real_pixel = discriminator(real_sequences)
    fake_global, fake_pixel = discriminator(fake_sequences)
    mixed_global, mixed_pixel = discriminator(mixed_sequences)

    wasserstein_loss = (
        -torch.mean(real_global) +      # Maximize real scores
        torch.mean(fake_global)         # Minimize fake scores  
        # lambda_mix * torch.mean(mixed_global)  # Mixed: partially real
    )

    # # Consistency loss for mixed samples
    # consistency_loss = torch.mean((mixed_global - 
    #                   (lambda_mix * real_global + (1-lambda_mix) * fake_global))**2)
    
    # # Calculate Wasserstein loss
    # wasserstein_loss = real_loss + fake_loss + mixed_loss 

    criterion = torch.nn.BCEWithLogitsLoss()

    mask = mask.squeeze(-1).unsqueeze(1)  # [batch_size, 1, seq_len]    

    # # Calculate unified decoder loss
    # dec_loss = -torch.mean(
    #     torch.log(real_pixel + 1e-8) + # Real positions should be 1
    #     torch.log(1 - fake_pixel + 1e-8) + # Fake positions should be 0
    #     mask * torch.log(mixed_pixel + 1e-8) + # Mixed sequences should be treated as real where mask is 1
    #     (1 - mask) * torch.log(1 - mixed_pixel + 1e-8) # Mixed sequences should be treated as fake where mask is 0
    # ) / 3.0

    # Ensure all tensors have matching dimensions
    real_pixel = real_pixel.view(real_pixel.size(0), 1, -1)[:, :, :5]
    fake_pixel = fake_pixel.view(fake_pixel.size(0), 1, -1)[:, :, :5]
    mixed_pixel = mixed_pixel.view(mixed_pixel.size(0), 1, -1)[:, :, :5]

    dec_loss = (
        criterion(real_pixel, torch.ones_like(real_pixel)) +    # Real → 1
        criterion(fake_pixel, torch.zeros_like(fake_pixel)) +   # Fake → 0
        criterion(mixed_pixel, mask.float())                     # Mixed → mask
    ) / 3.0

    # Calculate gradient penalty including mixed sequences
    gradient_penalty = calc_gradient_penalty(discriminator, real_sequences, fake_sequences, device)

    # Scale the losses
    wasserstein_loss = wasserstein_loss * scale  # Scale Wasserstein loss
    gradient_penalty = gradient_penalty * scale   # Scale gradient penalty

    # Normalize decoder loss and combine losses
    normalized_dec_loss = 1 + torch.tanh(dec_loss)  # Normalize decoder loss to be in range [0, 2]
    total_d_loss = wasserstein_loss + gradient_penalty * (lambda_dec * dec_loss) #+ lambda_consistency * consistency_loss # Combine losses
    
    # Update weights
    optimizer.zero_grad()
    total_d_loss.backward()

    # Gradient clipping
    torch.nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)

    optimizer.step()

    return {
        'total_d_loss': total_d_loss.item(),
        'wasserstein_loss': wasserstein_loss.item(),
        'gradient_penalty': gradient_penalty.item(),
        'dec_loss': dec_loss.item(), 
        'normalized_dec_loss': normalized_dec_loss.item()
    }

def generator_train(generator, discriminator, batch_size, optimizer, device, lambda_pixel=1, scale=1):
    # Generate fake sequences
    noise = torch.randn(batch_size, 128, device=device)
    fake_sequences = generator(noise) 
    fake_sequences = fake_sequences.transpose(1, 2) # Shape: [batch_size, n_chars, seq_len]

    # Get discriminator predictions
    fake_global, fake_pixel = discriminator(fake_sequences)
    
    pixel_loss = -torch.mean(torch.log(fake_pixel + 1e-8))

    normalized_dec_loss = 1 + torch.tanh(fake_pixel.mean())  # Normalize decoder loss to be in range [0, 2]
    # Generator loss
    g_loss = -torch.mean(fake_global) * scale * (lambda_pixel * normalized_dec_loss)

    # Update weights
    optimizer.zero_grad()
    g_loss.backward()

    # Gradient clipping
    torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)

    optimizer.step()

    return {
        'g_loss': g_loss.item(),
        'pixel_loss': pixel_loss.item()
    }

def train(generator, discriminator, dataloader, num_epochs, d_step=5, g_step=1, device=device, 
          lambda_mix=1, lambda_dec=1, lambda_pixel=1, lambda_consistency=1, scale=0.1):
    # Initialize scores
    jsd_history = []
    amp_history = [] 

    # Save best models
    best_jsd_models = [] 
    best_orf_models = [] 
    best_amp_models = []

    # Initialize FReD score tracking
    # fred_history = []

    iteration_losses = {
        'total_d_loss': [],
        'wasserstein_loss': [],
        'gradient_penalty': [],
        'total_g_loss': [],
        'dec_loss': [],
        'normalized_dec_loss': [],
        'pixel_loss': []
    }

    def update_best_models(score, epoch, model, best_list, maximize=False, max_keep=5):
        """Helper function to update best models list"""
        model_state = model.state_dict()
        if len(best_list) < max_keep:
            best_list.append((score, epoch, model_state))
            best_list.sort(reverse=maximize)
        else:
            if (maximize and score > best_list[-1][0]) or (not maximize and score < best_list[-1][0]):
                best_list[-1] = (score, epoch, model_state)
                best_list.sort(reverse=maximize)
        return best_list

    total_iterations = 0
    start_time = time.time()

    for epoch in range(num_epochs):
        epoch_start_time = time.time()
        running_losses = {
        'total_d_loss': 0,
        'wasserstein_loss': 0,
        'gradient_penalty': 0,
        'total_g_loss': 0,
        'dec_loss': 0,
        'normalized_dec_loss': 0,
        'pixel_loss': 0
    }
        num_batches = len(dataloader)

        current_lambda_mix = min(lambda_mix, epoch/10)

        for batch_idx, real_sequences in enumerate(dataloader):
            total_iterations += 1
            
            # Ensure real_sequences has shape [batch_size, n_chars, seq_len]
            if real_sequences.shape[1] != n_chars:
                real_sequences = real_sequences.transpose(1, 2)
            real_sequences = real_sequences.to(device)
            batch_size = real_sequences.size(0)

            # Generate fake sequences
            noise = torch.randn(batch_size, 128, device=device)
            fake_sequences = generator(noise).detach()
            if fake_sequences.shape[1] != n_chars:
                fake_sequences = fake_sequences.transpose(1, 2)

            # Create CutMix mask and mixed sequences
            mask = create_cutmix_mask(real_sequences, lam=0.8)
            mixed_sequences = mask * real_sequences + (1 - mask) * fake_sequences

            if mixed_sequences.shape[1] != n_chars:
                mixed_sequences = mixed_sequences.transpose(1, 2)

            # Train discriminator
            d_losses_sum = {'total_d_loss': 0, 'wasserstein_loss': 0, 'gradient_penalty': 0, 'dec_loss': 0, 'normalized_dec_loss': 0}
            for _ in range(d_step):
                d_losses = discriminator_train(discriminator, real_sequences, fake_sequences, mixed_sequences, mask, 
                                               d_optimizer, device, current_lambda_mix, lambda_dec, lambda_consistency, scale)
                for key in d_losses:
                    d_losses_sum[key] += d_losses[key]

            d_losses_avg = {k: v / d_step for k, v in d_losses_sum.items()}

            # Train generator
            g_losses_sum = {'g_loss': 0, 'pixel_loss': 0}
            for _ in range(g_step):
                g_losses = generator_train(generator, discriminator, batch_size, g_optimizer, device, lambda_pixel, scale)
                for key in g_losses:
                    g_losses_sum[key] += g_losses[key]

            g_losses_avg = {k: v / g_step for k, v in g_losses_sum.items()}

            # Update losses
            iteration_losses['total_d_loss'].append(d_losses_avg['total_d_loss'])
            iteration_losses['wasserstein_loss'].append(d_losses_avg['wasserstein_loss'])
            iteration_losses['gradient_penalty'].append(d_losses_avg['gradient_penalty'])
            iteration_losses['dec_loss'].append(d_losses_avg['dec_loss'])
            iteration_losses['normalized_dec_loss'].append(d_losses_avg['normalized_dec_loss'])
            iteration_losses['total_g_loss'].append(g_losses_avg['g_loss'])
            iteration_losses['pixel_loss'].append(g_losses_avg['pixel_loss'])

            running_losses['total_d_loss'] += d_losses_avg['total_d_loss']
            running_losses['wasserstein_loss'] += d_losses_avg['wasserstein_loss']
            running_losses['gradient_penalty'] += d_losses_avg['gradient_penalty']
            running_losses['dec_loss'] += d_losses_avg.get('dec_loss', 0)
            running_losses['normalized_dec_loss'] += d_losses_avg.get('normalized_dec_loss', 0)
            running_losses['total_g_loss'] += g_losses_avg['g_loss']
            running_losses['pixel_loss'] += g_losses_avg['pixel_loss']

            if batch_idx % 41 == 0:
                print(f'Batch [{batch_idx+1}/{num_batches}]')
                print(f'D_total_loss: {d_losses_avg["total_d_loss"]:.4f}')
                print(f'Wasserstein Loss: {d_losses_avg["wasserstein_loss"]:.4f}')
                print(f'Gradient Penalty: {d_losses_avg["gradient_penalty"]:.4f}')
                print(f'Decoder Loss: {d_losses_avg.get("dec_loss", 0):.4f}')
                print(f'Normalized Decoder Loss: {d_losses_avg.get("normalized_dec_loss", 0):.4f}')
                print(f'G_total_loss: {g_losses_avg["g_loss"]:.4f}')
                print(f'Pixel Loss: {g_losses_avg["pixel_loss"]:.4f}\n')


        # Calculate time for this epoch
        epoch_time = time.time() - epoch_start_time
        total_time = time.time() - start_time

        # Calculate epoch averages
        avg_losses = {k: v / num_batches for k, v in running_losses.items()}

        # # Calculate metrics and save best models every epoch
        # current_jsd = jsd(generator, dataloader, num_batches=5)
        # jsd_history.append(current_jsd)

        # # Get generated sequences and analyze them
        # generated_seqs = sample_and_analyze(generator, epoch=epoch, device=device)
        # save_analysis(generated_seqs, epoch, results_dir='Results2')

        # # Analyze sequences to get ORF count
        # seq_properties = analyze_sequences(generated_seqs)
        # orf_count = seq_properties['valid_orfs']

        # ===== EVALUATION SECTION =====
        # Generate sequences once for all evaluations
        noise = torch.randn(320, 128, device=device)
        with torch.no_grad():
            generated_sequences = generator(noise)

        # Get a batch of real sequences for JSD comparison
        real_batch = next(iter(dataloader)).to(device)
        if real_batch.size(0) < 320:
            # If batch size doesn't match, get multiple batches
            real_sequences = []
            for batch in dataloader:
                real_sequences.append(batch)
                if sum(b.size(0) for b in real_sequences) >= 320:
                    break
            real_batch = torch.cat(real_sequences, dim=0)[:320].to(device)

        # Convert to DNA strings using existing function
        generated_seqs = sample_and_analyze(pre_generated=generated_sequences, epoch=epoch, device=device)
        save_analysis(generated_seqs, epoch, results_dir='Results2')

        # Calculate JSD
        current_jsd = jsd(real_batch, generated_sequences)
        jsd_history.append(current_jsd)

        # Analyze sequences to get ORF count
        seq_properties = analyze_sequences(generated_seqs)
        orf_count = seq_properties['valid_orfs']

        # Convert to DNA strings for AMP evaluation (removing padding)
        generated_dna_seqs = [seq.replace('P', '') for seq in generated_seqs]

        # Evaluate AMP properties
        amp_score, amp_details = evaluate_amp_batch(generated_dna_seqs, return_details=True)
        amp_history.append(amp_score)

        # Update all best models
        best_jsd_models = update_best_models(current_jsd, epoch, generator, best_jsd_models, maximize=False)
        best_orf_models = update_best_models(orf_count, epoch, generator, best_orf_models, maximize=True)
        best_amp_models = update_best_models(amp_details["perfect_amp_percentage"], epoch, generator, best_amp_models, maximize=True) # change amp_score with 
                                                                                                                                      # amp_details["perfect_amp_percentage"]

        # Print epoch averages
        print(f'Epoch [{epoch+1}/{num_epochs}] - Epoch Time: {epoch_time:.2f}s - Total Time: {format_time(total_time)}')
        print(f'D_total_loss: {avg_losses["total_d_loss"]:.4f}')
        print(f'Wasserstein Loss: {avg_losses["wasserstein_loss"]:.4f}')
        print(f'Gradient Penalty: {avg_losses["gradient_penalty"]:.4f}')
        print(f'Decoder Loss: {avg_losses["dec_loss"]:.4f}')
        print(f'Normalized Decoder Loss: {avg_losses["normalized_dec_loss"]:.4f}')
        print(f'G_total_loss: {avg_losses["total_g_loss"]:.4f}')
        print(f'Pixel Loss: {avg_losses["pixel_loss"]:.4f}\n')
        print(f'Latest JSD Score: {current_jsd:.4f}')
        print(f'AMP Score: {amp_score:.2f}% (Perfect AMPs: {amp_details["perfect_amp_percentage"]:.1f}%)')
        print(50 * "-")

        # Step the schedulers
        d_scheduler.step()
        g_scheduler.step()

    # Save best models at the end of training
    for i, (jsd_score, epoch, model_state) in enumerate(best_jsd_models):
        save_path = os.path.join(jsd_models_dir, f'generator_jsd_{i+1}_epoch_{epoch}_score_{jsd_score:.4f}.pt')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model_state,
            'jsd_score': jsd_score
        }, save_path)

    for i, (orf_count, epoch, model_state) in enumerate(best_orf_models):
        save_path = os.path.join(orf_models_dir, f'generator_orf_{i+1}_epoch_{epoch}_count_{orf_count}.pt')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model_state,
            'orf_count': orf_count
        }, save_path)

    # Save best AMP models
    for i, (amp_score, epoch, model_state) in enumerate(best_amp_models):
        save_path = os.path.join(amp_models_dir, f'generator_amp_{i+1}_epoch_{epoch}_score_{amp_score:.2f}.pt')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model_state,
            'amp_score': amp_score
        }, save_path)

    return iteration_losses, total_iterations, jsd_history, amp_history

ModuleNotFoundError: No module named 'torch'

In [5]:
iteration_losses, total_iterations, jsd_history, amp_history = train(generator, discriminator, dataloader, num_epochs, d_step=3, 
                                                        g_step=1, device=device, lambda_dec=1, lambda_pixel=1, scale=1)

# Plot JSD and AMP
plot_metrics(iteration_losses, total_iterations, num_epochs, len(dataloader),
             jsd_history=jsd_history, amp_history=amp_history,
             plot_jsd=True, plot_fred=False, plot_amp=True)

NameError: name 'train' is not defined

In [10]:
from sequence_generator import generate_and_filter_sequences, convert_dna_fasta_to_protein

saved_model_path = r"C:\Users\kotsgeo\Documents\GANs\Old\saved_models2\best_amp\generator_amp_2_epoch_66_score_50.31.pt"  
generator.eval()

# Generate sequences and convert to proteins in one go
sequences, analysis, dna_file = generate_and_filter_sequences(generator, batch_size=64, num_samples=11270-2*640, count_atg=False, add_atg=False)
num_proteins, protein_file = convert_dna_fasta_to_protein(input_fasta=dna_file, add_atg=False)


Generation complete!
Found 5893 valid sequences out of 9990 generated
Sequences saved in: C:\Users\kotsgeo\Documents\GANs\Old\campr4\valid_sequences.fasta
Analysis saved in: C:\Users\kotsgeo\Documents\GANs\Old\campr4\sequence_analysis.txt
Converted 5564 DNA sequences to proteins
Protein sequences saved in: C:\Users\kotsgeo\Documents\GANs\Old\campr4\valid_sequences_proteins.fasta


In [7]:
from analyze_amp_probabilities import calculate_averages

# Actual usage example
files = [
    r"C:\Users\kotsgeo\Documents\GANs\Old\campr4\CAMPdownload_rf.txt",
    r"C:\Users\kotsgeo\Documents\GANs\Old\campr4\CAMPdownload_svm.txt",
    r"C:\Users\kotsgeo\Documents\GANs\Old\campr4\CAMPdownload_ann.txt"
]

model_types = [
    "Random Forest",
    "SVM",
    "ANN"
]

calculate_averages(files, model_types)

Percentage above 0.5 (Random Forest): 49.70%
Percentage above 0.8 (Random Forest): 8.22%
Percentage above 0.5 (SVM): 37.71%
Percentage above 0.8 (SVM): 14.70%
Percentage above 0.5 (ANN): 34.47%
Percentage above 0.8 (ANN): 15.67%

Averages across all models:
Average percentage above 0.5: 40.63%
Average percentage above 0.8: 12.86%
