# In [None]:
import torch
from tqdm import tqdm

def train_gan(model_G, model_D, dataloader, config):
    """
    Generic training loop for GAN models.

    Each batch performs one discriminator update followed by one generator
    update, both with Adam and binary cross-entropy loss.

    Args:
        model_G: Generator module. Called with a noise tensor of shape
            ``(batch, *noise_shape)``; by default ``noise_shape`` is
            ``(latent_dim, 1, 1)`` (DCGAN-style conv generator).
        model_D: Discriminator module. Its output is fed to ``BCELoss``, so it
            must already be a probability in [0, 1] (e.g. end in a sigmoid).
            Any output shape is accepted; target labels are built to match it.
        dataloader: Iterable yielding ``(images, labels)`` pairs; the labels
            are ignored. ``images`` must be a tensor with batch on dim 0.
        config: Dict with required keys ``'latent_dim'``, ``'num_epochs'``,
            ``'learning_rate'`` and ``'betas'``. Optional keys:
            ``'noise_shape'`` (per-sample noise shape, e.g. ``(latent_dim,)``
            for an MLP generator) and ``'device'`` (overrides the automatic
            CUDA/CPU choice).

    Returns:
        dict with lists ``'d_loss'`` and ``'g_loss'`` holding the mean
        discriminator / generator loss over the batches of each epoch
        (``nan`` for an epoch that saw no batches).
    """
    # Load hyperparameters
    latent_dim = config['latent_dim']
    num_epochs = config['num_epochs']
    learning_rate = config['learning_rate']
    betas = config['betas']
    # Optional overrides; the defaults reproduce the original behavior.
    noise_shape = tuple(config.get('noise_shape', (latent_dim, 1, 1)))
    if config.get('device') is not None:
        device = torch.device(config['device'])
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Move models to device (nn.Module.to is in-place and returns the module).
    model_G = model_G.to(device)
    model_D = model_D.to(device)
    # Make sure dropout / batch-norm layers behave as in training, even if the
    # caller handed over models left in eval mode.
    model_G.train()
    model_D.train()

    # Define optimizers
    optimizer_G = torch.optim.Adam(model_G.parameters(), lr=learning_rate, betas=betas)
    optimizer_D = torch.optim.Adam(model_D.parameters(), lr=learning_rate, betas=betas)

    # BCELoss expects the discriminator to output probabilities, not logits.
    criterion = torch.nn.BCELoss()

    history = {'d_loss': [], 'g_loss': []}

    # Training loop
    for epoch in range(num_epochs):
        # Accumulate on-device so the loop does not force a host sync
        # (.item()) on every batch.
        d_loss_sum = torch.zeros((), device=device)
        g_loss_sum = torch.zeros((), device=device)
        num_batches = 0

        for real_images, _ in tqdm(dataloader):
            real_images = real_images.to(device)
            batch_size = real_images.size(0)

            # ---- Train Discriminator: minimise -log D(x) - log(1 - D(G(z))) ----
            optimizer_D.zero_grad()

            # Real images loss. *_like builds targets with D's exact output
            # shape/dtype/device, so D may return (B,), (B, 1), (B, 1, 1, 1), ...
            real_preds = model_D(real_images)
            real_loss = criterion(real_preds, torch.ones_like(real_preds))

            # Fake images loss; detach() keeps the D update from reaching G.
            noise = torch.randn(batch_size, *noise_shape, device=device)
            fake_images = model_G(noise)
            fake_preds = model_D(fake_images.detach())
            fake_loss = criterion(fake_preds, torch.zeros_like(fake_preds))

            d_loss = real_loss + fake_loss
            d_loss.backward()
            optimizer_D.step()

            # ---- Train Generator: minimise -log D(G(z)) (non-saturating) ----
            optimizer_G.zero_grad()

            # Re-run the just-updated D on the non-detached fakes so gradients
            # flow back into G.
            gen_preds = model_D(fake_images)
            g_loss = criterion(gen_preds, torch.ones_like(gen_preds))
            # This backward also writes into D's .grad; that is harmless
            # because optimizer_D.zero_grad() clears it before the next D step.
            g_loss.backward()
            optimizer_G.step()

            d_loss_sum += d_loss.detach()
            g_loss_sum += g_loss.detach()
            num_batches += 1

        if num_batches:
            mean_d = (d_loss_sum / num_batches).item()
            mean_g = (g_loss_sum / num_batches).item()
        else:
            # Empty (or already exhausted one-shot) dataloader: nothing to report.
            mean_d = mean_g = float('nan')
        history['d_loss'].append(mean_d)
        history['g_loss'].append(mean_g)

        print(f"Epoch [{epoch+1}/{num_epochs}], D Loss: {mean_d:.4f}, G Loss: {mean_g:.4f}")

    return history
