<a href="https://colab.research.google.com/github/andrkech/GENERATIVE-METHODS-IN-GENOMICS/blob/main/DNA_seq_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Libraries

In [None]:
!pip install pytorch-lightning

In [None]:
import os

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

## Dataset

In [None]:
# Define dataset for DNA sequences.

class DNADataset(Dataset):
    """Map-style dataset that serves DNA sequences as numeric tensors.

    String sequences are encoded on access into a 1-D ``float32`` tensor
    holding one integer code per nucleotide (see ``NUCLEOTIDE_CODES``), so
    the default ``DataLoader`` collation can stack equal-length sequences
    into a single ``(batch, sequence_length)`` tensor. Items that are not
    strings (e.g. already-encoded tensors) are returned unchanged.

    Args:
        dna_sequences: Indexable, sized collection of sequences.
    """

    # Integer code assigned to each nucleotide letter (matched case-insensitively).
    NUCLEOTIDE_CODES = {"A": 0, "C": 1, "G": 2, "T": 3}

    def __init__(self, dna_sequences):
        self.dna_sequences = dna_sequences

    def __len__(self):
        """Return the number of stored sequences."""
        return len(self.dna_sequences)

    def __getitem__(self, idx):
        """Return the sequence at ``idx``, encoding it if it is a string.

        Raises:
            ValueError: If a string sequence contains a character that is
                not one of A, C, G, T.
        """
        dna_sequence = self.dna_sequences[idx]
        if not isinstance(dna_sequence, str):
            # Already numeric (or otherwise caller-defined): pass through.
            return dna_sequence
        return self._encode(dna_sequence)

    @classmethod
    def _encode(cls, dna_sequence):
        """Convert a nucleotide string into a 1-D float32 tensor of codes."""
        try:
            codes = [cls.NUCLEOTIDE_CODES[base] for base in dna_sequence.upper()]
        except KeyError as err:
            raise ValueError(
                f"Unsupported nucleotide {err.args[0]!r} in sequence {dna_sequence!r}"
            ) from err
        return torch.tensor(codes, dtype=torch.float32)

## Generator

In [None]:
class Generator(nn.Module):
    """Fully connected generator mapping a latent vector to a sequence.

    Four linear layers: ``latent_dim -> 128 -> 256 -> 4 * sequence_length
    -> sequence_length``, with ReLU after the first three. The last layer
    has no activation, so outputs are unbounded real values.

    Args:
        latent_dim: Size of the last dimension of the input noise.
        sequence_length: Size of the last dimension of the output.
    """

    def __init__(self, latent_dim, sequence_length):
        super().__init__()
        self.latent_dim = latent_dim
        self.sequence_length = sequence_length

        expanded_dim = sequence_length * 4
        self.fc1 = nn.Linear(latent_dim, 128)
        self.fc2 = nn.Linear(128, 256)
        self.fc3 = nn.Linear(256, expanded_dim)
        self.fc4 = nn.Linear(expanded_dim, sequence_length)

    def forward(self, z):
        """Return the generated output for latent input ``z``."""
        hidden = z
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.relu(layer(hidden))
        # Final projection is left linear (no activation).
        return self.fc4(hidden)

## Discriminator

In [None]:
class Discriminator(nn.Module):
    """Fully connected binary classifier over sequences.

    Three linear layers: ``sequence_length -> 256 -> 128 -> 1``, with ReLU
    after the first two and a sigmoid on the output, so each input row is
    scored with a single value in [0, 1].

    Args:
        sequence_length: Size of the last dimension of the input.
    """

    def __init__(self, sequence_length):
        super().__init__()
        self.sequence_length = sequence_length
        self.fc1 = nn.Linear(sequence_length, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 1)

    def forward(self, x):
        """Return the sigmoid score for each row of ``x``."""
        features = x
        for layer in (self.fc1, self.fc2):
            features = F.relu(layer(features))
        logit = self.fc3(features)
        return torch.sigmoid(logit)

## GAN

In [None]:
class GAN(pl.LightningModule):
    """Lightning module that trains a Generator against a Discriminator.

    The module uses manual optimization: every ``training_step`` performs
    one generator update followed by one discriminator update, each with
    its own Adam optimizer.

    Args:
        latent_dim: Size of the noise vector fed to the generator.
        sequence_length: Number of values in each real/generated sequence.
        lr: Learning rate shared by both optimizers.
    """

    def __init__(self, latent_dim=100, sequence_length=50, lr=0.0002):
        super().__init__()
        self.save_hyperparameters()
        self.generator = Generator(latent_dim=latent_dim, sequence_length=sequence_length)
        self.discriminator = Discriminator(sequence_length=sequence_length)
        # Fixed noise for inspecting generator output. This is a plain
        # attribute (not a buffer), so it is not moved with the module.
        self.validation_z = torch.randn(6, latent_dim)

        # Two optimizers are stepped by hand inside training_step.
        self.automatic_optimization = False

    def forward(self, z):
        """Generate sequences from latent noise ``z``."""
        return self.generator(z)

    def training_step(self, batch, batch_idx):
        """Run one generator update and one discriminator update.

        With manual optimization Lightning calls this hook once per batch
        and does not pass an ``optimizer_idx``; both optimizers are
        therefore driven explicitly here.

        Args:
            batch: Tensor of real sequences; the first dimension is the
                batch size.
            batch_idx: Index of the batch within the epoch (unused).
        """
        real_dna = batch
        batch_size = real_dna.shape[0]
        opt_g, opt_d = self.optimizers()

        # type_as puts noise and labels on the batch's device and dtype.
        z = torch.randn(batch_size, self.hparams.latent_dim).type_as(real_dna)
        real_labels = torch.ones(batch_size, 1).type_as(real_dna)
        fake_labels = torch.zeros(batch_size, 1).type_as(real_dna)

        # --- Generator update: make the discriminator label fakes as real.
        fake_dna = self(z)
        g_loss = F.binary_cross_entropy(self.discriminator(fake_dna), real_labels)
        opt_g.zero_grad()
        self.manual_backward(g_loss)
        opt_g.step()

        # --- Discriminator update: separate real batch from detached fakes.
        real_pred = self.discriminator(real_dna)
        fake_pred = self.discriminator(fake_dna.detach())
        d_real_loss = F.binary_cross_entropy(real_pred, real_labels)
        d_fake_loss = F.binary_cross_entropy(fake_pred, fake_labels)
        d_loss = d_real_loss + d_fake_loss
        # zero_grad also discards the discriminator gradients that the
        # generator's backward pass accumulated above.
        opt_d.zero_grad()
        self.manual_backward(d_loss)
        opt_d.step()

        self.log_dict({'g_loss': g_loss, 'd_loss': d_loss}, prog_bar=True)

    def configure_optimizers(self):
        """Return the generator and discriminator Adam optimizers, in that order."""
        lr = self.hparams.lr
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr)
        return [opt_g, opt_d], []

## Synthetic DNA sequences

In [None]:
# Build a toy corpus: each sequence is a string of letters drawn
# independently from A/T/C/G with numpy's global random generator.
num_sequences = 1000
sequence_length = 50
dna_sequences = []
for _ in range(num_sequences):
    bases = np.random.choice(["A", "T", "C", "G"], size=sequence_length)
    dna_sequences.append("".join(bases))

# Wrap the corpus in a DNADataset and serve it in shuffled mini-batches.
batch_size = 64
dna_dataset = DNADataset(dna_sequences)
data_loader = DataLoader(dna_dataset, shuffle=True, batch_size=batch_size)

## Model Initialization

In [None]:
# Build the GAN with a 20-dimensional latent space; sequence_length is the
# value defined in the synthetic-data cell so model and data agree.
latent_dim = 20
model = GAN(sequence_length=sequence_length, latent_dim=latent_dim)

# Trainer capped at 20 epochs; "auto" lets Lightning pick the accelerator.
trainer = pl.Trainer(accelerator="auto", max_epochs=20)

## GAN Training

In [None]:
# Train the GAN: run the trainer's fit loop on `model`, drawing batches
# from `data_loader`.
trainer.fit(model, data_loader)

# Generate new DNA sequences using the trained GAN.
def generate_sequences(num_sequences):
    """Sample ``num_sequences`` outputs from the module-level ``model``.

    The noise is sized from the model's own ``latent_dim`` hyperparameter
    and created on the model's device, so the call cannot mismatch the
    model in shape or device. Switches ``model`` to eval mode and leaves it
    there.

    Args:
        num_sequences: Number of latent vectors to draw.

    Returns:
        The model output for the sampled noise, one row per latent vector.
    """
    model.eval()
    with torch.no_grad():
        z = torch.randn(
            num_sequences,
            model.hparams.latent_dim,
            device=model.device,
        )
        generated_dna = model(z)
    return generated_dna

In [None]:
# Draw ten samples from the model via generate_sequences.
num_generated_sequences = 10
generated_sequences = generate_sequences(num_generated_sequences)

# Print each returned item with a 1-based label. Items are printed as
# returned by the model; no decoding to nucleotide letters happens here.
for i, seq in enumerate(generated_sequences):
    print(f"Generated Sequence {i+1}: {seq}")