# VAE-GAN Training Notebook (Standalone)

This notebook trains a Generative Adversarial Network (GAN) that uses a pre-trained VAE's decoder as its generator. It is self-contained: every model is defined in the notebook and the VAE is trained here before the GAN, so it can run on platforms like Google Colab without any pre-existing model files.

### Steps:
1. **Setup**: Imports the required libraries (uncomment the `pip install` line if any are missing) and defines all required classes and configuration parameters.
2. **Initialization**: Sets up Weights & Biases, the dataset, the VAE (trained here or loaded from a checkpoint), the models (Generator and Discriminator), optimizers, and loss function.
3. **Training**: Runs the main training loop.
4. **Logging & Saving**: Logs results to Weights & Biases and saves model checkpoints.

## 1. Setup and Dependencies

In [1]:
# Install necessary libraries
# !pip install torch torchvision tqdm wandb numpy matplotlib imageio

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
from tqdm.notebook import tqdm
import wandb
import os
import numpy as np
import matplotlib.pyplot as plt
import imageio

print("Libraries imported successfully.")

Libraries imported successfully.


### All-in-One Code Block: Config, Models, and Utilities

In [None]:
# --- Configuration Class ---
class Config:
    """Namespace of hyper-parameters and paths shared by every cell.

    Values are read straight off the class (e.g. ``Config.latent_dim``); the
    cells shown here never instantiate it. The initialization cell forwards
    every non-dunder attribute to ``wandb.init(config=...)``.
    """

    # Runtime
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    seed = 42

    # Data and VAE architecture
    dataset_name = "CIFAR-10"
    image_size = 32
    in_channels = 3
    out_channels = 3
    encoder_channels = [32, 64, 128]
    decoder_channels = [128, 64, 32]
    kernel_size = 4
    latent_dim = 128
    hidden_dim = 512  # not referenced by the cells shown here

    # VAE optimisation
    batch_size = 128
    learning_rate = 1e-3
    num_epochs = 50

    # GAN
    d_features = 64
    gan_lr = 2e-4
    gan_epochs = 50

    # Logging / checkpoint cadence (in epochs unless noted)
    checkpoint_interval = 10
    sample_interval = 5
    log_interval = 100  # not referenced by the cells shown here
    save_reconstruction_interval = 2  # not referenced by the cells shown here

    # Filesystem paths
    data_path = "./data"
    reconstruction_save_path = "./saves/reconstructions"
    model_save_path = "./saves/checkpoints"
    dataset_path = "./data"  # duplicates data_path; the dataset cell reads data_path
    sample_dir = "./saves/gan_samples"

    # Weights & Biases
    wandb_project = "VAE-GAN-CIFAR-10-Colab"
    wandb_run_name = "gan_run"

# --- VAE Model ---
class Encoder(nn.Module):
    """Convolutional encoder mapping an image batch to Gaussian posterior parameters.

    Three stride-2 convolutions (kernel ``Config.kernel_size``, padding 1), each
    followed by BatchNorm and ReLU, feed two linear heads that read the
    flattened feature map and emit ``mu`` and ``logvar``.
    """

    def __init__(self):
        super().__init__()
        widths = [Config.in_channels, *Config.encoder_channels[:3]]
        n_stages = len(widths) - 1
        stages = []
        for idx in range(n_stages):
            # Only the last conv keeps its bias. This mirrors the original layer
            # layout so existing checkpoints load with identical state_dict keys.
            stages.extend([
                nn.Conv2d(
                    in_channels=widths[idx],
                    out_channels=widths[idx + 1],
                    kernel_size=Config.kernel_size,
                    stride=2,
                    padding=1,
                    bias=(idx == n_stages - 1),
                ),
                nn.BatchNorm2d(widths[idx + 1]),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*stages)
        self.flatten = nn.Flatten()
        # The heads hard-code a 4x4 feature map; this assumes 32x32 inputs and
        # kernel_size 4 (three stride-2 convs halve 32 -> 16 -> 8 -> 4).
        feat_dim = Config.encoder_channels[2] * 4 * 4
        self.mu = nn.Linear(in_features=feat_dim, out_features=Config.latent_dim)
        self.logvar = nn.Linear(in_features=feat_dim, out_features=Config.latent_dim)

    def forward(self, x):
        """Return ``(mu, logvar)``, each with ``Config.latent_dim`` features."""
        features = self.flatten(self.conv(x))
        return self.mu(features), self.logvar(features)

class Decoder(nn.Module):
    """Transposed-convolution decoder mapping latent codes to images in [-1, 1].

    A linear layer expands ``z`` into a ``Config.decoder_channels[0] x 4 x 4``
    feature map. Three stride-2 transposed convolutions then upsample it, and a
    final Tanh bounds pixel values to [-1, 1].
    """

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(in_features=Config.latent_dim, out_features=Config.decoder_channels[0] * 4 * 4)
        widths = [*Config.decoder_channels[:3], Config.out_channels]
        last = len(widths) - 2
        blocks = []
        for idx, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            is_output = idx == last
            # Hidden layers: bias-free ConvT + BatchNorm + ReLU.
            # Output layer: ConvT with bias + Tanh.
            blocks.append(
                nn.ConvTranspose2d(
                    in_channels=c_in,
                    out_channels=c_out,
                    kernel_size=Config.kernel_size,
                    stride=2,
                    padding=1,
                    bias=is_output,
                )
            )
            if is_output:
                blocks.append(nn.Tanh())
            else:
                blocks.extend([nn.BatchNorm2d(c_out), nn.ReLU(inplace=True)])
        self.deconv = nn.Sequential(*blocks)

    def forward(self, z):
        """Decode latent codes (last dimension ``Config.latent_dim``) into images."""
        grid = self.fc(z).view(-1, Config.decoder_channels[0], 4, 4)
        return self.deconv(grid)

class VAE(nn.Module):
    """Variational autoencoder pairing an ``Encoder`` with a ``Decoder``."""

    def __init__(self):
        super().__init__()
        # Registration order (encoder, then decoder) fixes both the state_dict
        # key order and the order in which weights are initialised.
        self.encoder = Encoder()
        self.decoder = Decoder()

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + sigma * eps`` with ``eps ~ N(0, I)`` and ``sigma = exp(logvar / 2)``.

        Drawing the noise separately keeps ``z`` differentiable w.r.t. ``mu`` and ``logvar``.
        """
        sigma = logvar.mul(0.5).exp()
        noise = torch.randn_like(sigma)
        return mu + sigma * noise

    def forward(self, x):
        """Encode, sample a latent code, and decode; return ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encoder(x)
        reconstruction = self.decoder(self.reparameterize(mu, logvar))
        return reconstruction, mu, logvar

# --- Discriminator Model ---
class Discriminator(nn.Module):
    """DCGAN-style convolutional discriminator emitting raw (pre-sigmoid) logits.

    Three stride-2 conv stages widen ``base_channels`` -> x2 -> x4. Each stage
    uses LeakyReLU(0.2), and all but the first also use BatchNorm. A final 4x4
    unpadded convolution then maps the features to a single channel. No sigmoid
    is applied; the notebook pairs it with ``nn.BCEWithLogitsLoss``.
    """

    def __init__(self, in_channels=Config.in_channels, base_channels=Config.d_features):
        super().__init__()
        widths = (in_channels, base_channels, base_channels * 2, base_channels * 4)
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=4, stride=2, padding=1, bias=False))
            if stage:  # the first stage has no BatchNorm
                layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers.append(nn.Conv2d(in_channels=widths[-1], out_channels=1, kernel_size=4, stride=1, bias=False))
        self.discriminator = nn.Sequential(*layers)

    def forward(self, x):
        """Return a flat logit tensor; for 32x32 inputs that is one logit per image."""
        return self.discriminator(x).view(-1)

# --- Utility Functions ---
def create_interpolation_gif(decoder, latent_dim, steps, gif_save_path):
    """Write a GIF that walks linearly between two random latent codes.

    Args:
        decoder: Module mapping ``(n, latent_dim)`` latent codes to images laid
            out as ``(n, C, H, W)`` with values in [-1, 1] (as the notebook's
            Tanh-terminated ``Decoder`` produces).
        latent_dim: Size of each latent vector.
        steps: Number of frames, including both endpoints.
        gif_save_path: Destination path of the GIF.

    The decoder runs in eval mode and is then restored to the train/eval mode
    it had before the call, even if rendering or writing fails. The endpoints
    are drawn on CPU from the global torch RNG.
    """
    was_training = decoder.training
    decoder.eval()
    try:
        with torch.no_grad():
            device = next(decoder.parameters()).device
            z_start = torch.randn(1, latent_dim).to(device)
            z_end = torch.randn(1, latent_dim).to(device)
            # (steps, 1) mixing weights broadcast against the (1, latent_dim) endpoints.
            ratios = torch.linspace(0, 1, steps).unsqueeze(1).to(device)
            z_path = z_start * (1 - ratios) + z_end * ratios
            imgs = decoder(z_path).cpu()
            # Map the Tanh range [-1, 1] to [0, 1]. The clamp guards the uint8
            # cast against any value that falls outside that range.
            imgs = ((imgs + 1) / 2).clamp(0, 1)
            # (steps, C, H, W) -> (steps, H, W, C) uint8 frames for imageio.
            frames = (imgs.permute(0, 2, 3, 1).numpy() * 255).astype(np.uint8)
        imageio.mimsave(gif_save_path, list(frames), fps=5)
    finally:
        # Restore the caller's mode instead of unconditionally forcing train mode.
        decoder.train(was_training)


print("All required classes and functions are defined.")

All required classes and functions are defined.


## 2. Initialization

In [11]:
# Seed torch's RNGs so weight initialisation and noise sampling are repeatable.
torch.manual_seed(Config.seed)

# --- Initialize WandB ---
# You will be prompted to login to wandb
wandb.login()
# Every non-dunder Config attribute becomes part of the run's config.
run_config = {key: value for key, value in vars(Config).items() if not key.startswith("__")}
wandb.init(
    project=Config.wandb_project,
    name=f"{Config.wandb_run_name}_2",
    config=run_config,
)

# --- Create necessary directories ---
for directory in (Config.sample_dir, Config.model_save_path):
    os.makedirs(directory, exist_ok=True)

print("WandB initialized and directories created.")



0,1
Discriminator_Loss,▁
Epoch,▁
Generator_Loss,▁

0,1
Discriminator_Loss,0.21534
Epoch,1.0
Generator_Loss,3.29011


WandB initialized and directories created.


In [4]:
# --- Load The Dataset ---
# ToTensor yields pixels in [0, 1]; Normalize with mean = std = 0.5 on every
# channel maps them to [-1, 1], matching the decoder's Tanh output range.
transform = transforms.Compose([
    transforms.Resize((Config.image_size, Config.image_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,) * Config.in_channels, (0.5,) * Config.in_channels),
])
# download=True lets the notebook run on a fresh runtime (e.g. Colab). If the
# files already exist and pass torchvision's integrity check, nothing is re-downloaded.
train_dataset = datasets.CIFAR10(root=Config.data_path, train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=Config.batch_size, shuffle=True, num_workers=2)

print(f"Dataset loaded. Found {len(train_dataset)} images.")

Dataset loaded. Found 50000 images.


### A Note on the Pre-trained VAE Model
This notebook needs a pre-trained VAE whose decoder serves as the generator. A standalone notebook cannot assume a checkpoint file already exists, so the next cell trains a VAE from scratch and saves it as `vae_epoch_50.pt` in `Config.model_save_path` (`./saves/checkpoints`). If you already have a pre-trained `vae_epoch_50.pt`, upload it to your Colab environment, skip the training cell, and point `vae_model_path` in the loading cell at your file.

### Step 2a: Train the VAE (or load if you have it)

In [None]:
# --- VAE Training Cell ---
print("Starting VAE training...")
vae = VAE().to(Config.device)
vae.train()
optimizer = optim.Adam(vae.parameters(), lr=Config.learning_rate)
# Summed (not averaged) squared error, so it is on the same per-batch scale as
# the summed KL term below.
loss_fn = nn.MSELoss(reduction="sum")

for epoch in range(Config.num_epochs):
    loop = tqdm(train_loader, total=len(train_loader), desc=f"VAE Epoch [{epoch+1}/{Config.num_epochs}]", leave=False)
    for images, _ in loop:
        images = images.to(Config.device)
        recon_images, mu, logvar = vae(images)

        recon_loss = loss_fn(recon_images, images)
        # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over batch and latent dims.
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        loss = recon_loss + kld

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loop.set_postfix(loss=loss.item())

# Save the trained VAE model in the {'model_state_dict': ...} layout that the GAN
# initialization cell reads. The filename stays fixed because that cell loads
# this exact name.
os.makedirs(Config.model_save_path, exist_ok=True)
vae_model_path = os.path.join(Config.model_save_path, "vae_epoch_50.pt")
torch.save({"epoch": Config.num_epochs, "model_state_dict": vae.state_dict()}, vae_model_path)
print(f"VAE training complete. Model saved to {vae_model_path}")

In [6]:
# Notebook echo: display the directory where the VAE and GAN checkpoints are written/read.
Config.model_save_path

'./saves/checkpoints'

In [9]:
# --- Initialize GAN Models, Optimizers, and Loss Function ---

# Load the trained VAE so its decoder can serve as the generator.
vae_model_path = os.path.join(Config.model_save_path, "vae_epoch_50.pt")
# NOTE(review): torch.load unpickles arbitrary objects unless weights_only=True
# is passed; only load checkpoint files you trust.
checkpoint = torch.load(vae_model_path, map_location=Config.device)
# Accept both a wrapped checkpoint ({'model_state_dict': ...}) and a bare
# state_dict as written by `torch.save(vae.state_dict(), ...)`.
if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
    vae_state_dict = checkpoint["model_state_dict"]
else:
    vae_state_dict = checkpoint
vae = VAE().to(Config.device)
vae.load_state_dict(vae_state_dict)
print(f"Pre-trained VAE model loaded successfully from {vae_model_path}")

decoder = vae.decoder  # This is our Generator (already on Config.device via `vae`)

# Initialize the Discriminator
discriminator = Discriminator().to(Config.device)

# Optimizers & Loss Function (beta1 = 0.5 for both GAN optimizers).
gen_optimizer = optim.Adam(decoder.parameters(), lr=Config.gan_lr, betas=(0.5, 0.999))
disc_optimizer = optim.Adam(discriminator.parameters(), lr=Config.gan_lr, betas=(0.5, 0.999))
# The discriminator outputs raw logits, so the sigmoid lives inside the loss.
criterion = nn.BCEWithLogitsLoss()

print("GAN components initialized.")

  checkpoint = torch.load(vae_model_path, map_location=Config.device)


Pre-trained VAE model loaded successfully from ./saves/checkpoints\vae_epoch_50.pt
GAN components initialized.


## 3. Training the GAN

In [12]:
decoder.train()
discriminator.train()

print("Starting GAN training...")

steps_per_epoch = len(train_loader)

# Alternating GAN updates: one discriminator step, then one generator step per batch.
for epoch in range(Config.gan_epochs):
    loop = tqdm(enumerate(train_loader), total=steps_per_epoch, desc=f"GAN Epoch [{epoch + 1}/{Config.gan_epochs}]", leave=False)
    epoch_g_losses, epoch_d_losses = [], []

    for i, (real_imgs, _) in loop:

        real_imgs = real_imgs.to(Config.device)
        batch_size = real_imgs.size(0)

        # Targets match the discriminator's flat (batch,) logit vector.
        real_labels = torch.ones(batch_size, device=Config.device)
        fake_labels = torch.zeros(batch_size, device=Config.device)

        # --- Train Discriminator ---
        # Fakes are detached so this step does not backpropagate into the generator.
        z = torch.randn(batch_size, Config.latent_dim, device=Config.device)
        fake_imgs_detached = decoder(z).detach()

        real_preds = discriminator(real_imgs)
        fake_preds = discriminator(fake_imgs_detached)

        d_loss_real = criterion(real_preds, real_labels)
        d_loss_fake = criterion(fake_preds, fake_labels)
        d_loss = (d_loss_real + d_loss_fake) / 2

        disc_optimizer.zero_grad()
        d_loss.backward()
        disc_optimizer.step()

        # --- Train Generator ---
        # Non-saturating objective: push the discriminator to label fresh fakes as real.
        z = torch.randn(batch_size, Config.latent_dim, device=Config.device)
        fake_imgs = decoder(z)
        preds = discriminator(fake_imgs)

        g_loss = criterion(preds, real_labels)

        gen_optimizer.zero_grad()
        g_loss.backward()
        gen_optimizer.step()

        # Read each scalar once: every .item() forces a host-device sync.
        d_val, g_val = d_loss.item(), g_loss.item()
        loop.set_postfix(D_Loss=d_val, G_Loss=g_val)
        epoch_d_losses.append(d_val)
        epoch_g_losses.append(g_val)

        # Log loss per step (batch)
        global_step = epoch * steps_per_epoch + i
        wandb.log({
            'Step_Discriminator_Loss': d_val,
            'Step_Generator_Loss': g_val,
            'Step': global_step
        }, step=global_step)

    # --- End of Epoch Logging & Saving ---
    avg_d_loss = np.mean(epoch_d_losses)
    avg_g_loss = np.mean(epoch_g_losses)

    log_dict = {
        'Epoch': epoch + 1,
        'Discriminator_Loss': avg_d_loss,
        'Generator_Loss': avg_g_loss,
    }

    if (epoch + 1) % Config.sample_interval == 0:
        # Grids show the epoch's last batch. Detach and move to CPU so the
        # logged tensors carry no autograd history.
        fake_grid = make_grid(fake_imgs[:16].detach().cpu(), nrow=4, normalize=True)
        real_grid = make_grid(real_imgs[:16].cpu(), nrow=4, normalize=True)
        gif_save_path = os.path.join(Config.sample_dir, f"interpolation_epoch_{epoch + 1}.gif")
        create_interpolation_gif(decoder, Config.latent_dim, steps=10, gif_save_path=gif_save_path)

        log_dict.update({
            "Sample_Image_Fake": wandb.Image(fake_grid, caption=f"Fake Samples at Epoch {epoch + 1}"),
            "Sample_Image_Real": wandb.Image(real_grid, caption=f"Real Samples at Epoch {epoch + 1}")
        })
        if os.path.exists(gif_save_path):
            log_dict["Latent_Interpolation"] = wandb.Video(gif_save_path, format="gif")

    # NOTE(review): no explicit `step` here, so wandb uses its internal step
    # counter (presumably the last batch's step) — confirm in the run history.
    wandb.log(log_dict)
    print(f"Epoch [{epoch+1}/{Config.gan_epochs}] Completed | Avg D_Loss: {avg_d_loss:.4f} | Avg G_Loss: {avg_g_loss:.4f}")

    if (epoch + 1) % Config.checkpoint_interval == 0:
        checkpoint_path = os.path.join(Config.model_save_path, f"gan_checkpoint_epoch_{epoch+1}.pt")
        checkpoint = {
            "epoch": epoch + 1,
            "decoder_state_dict": decoder.state_dict(),
            "discriminator_state_dict": discriminator.state_dict(),
            "gen_optimizer_state_dict": gen_optimizer.state_dict(),
            "disc_optimizer_state_dict": disc_optimizer.state_dict(),
        }
        torch.save(checkpoint, checkpoint_path)
        print(f"Checkpoint saved to {checkpoint_path}")

print("--- Training Finished ---")


Starting GAN training...


GAN Epoch [1/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [1/50] Completed | Avg D_Loss: 0.1606 | Avg G_Loss: 3.4944


GAN Epoch [2/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [2/50] Completed | Avg D_Loss: 0.1662 | Avg G_Loss: 3.1941


GAN Epoch [3/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [3/50] Completed | Avg D_Loss: 0.1596 | Avg G_Loss: 3.1417


GAN Epoch [4/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [4/50] Completed | Avg D_Loss: 0.1589 | Avg G_Loss: 3.0231


GAN Epoch [5/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [5/50] Completed | Avg D_Loss: 0.1743 | Avg G_Loss: 2.9335


GAN Epoch [6/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [6/50] Completed | Avg D_Loss: 0.1807 | Avg G_Loss: 2.9742


GAN Epoch [7/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [7/50] Completed | Avg D_Loss: 0.1739 | Avg G_Loss: 3.1046


GAN Epoch [8/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [8/50] Completed | Avg D_Loss: 0.1515 | Avg G_Loss: 3.1530


GAN Epoch [9/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [9/50] Completed | Avg D_Loss: 0.1636 | Avg G_Loss: 3.1696


GAN Epoch [10/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [10/50] Completed | Avg D_Loss: 0.1474 | Avg G_Loss: 3.2291
Checkpoint saved to ./saves/checkpoints\gan_checkpoint_epoch_10.pt


GAN Epoch [11/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [11/50] Completed | Avg D_Loss: 0.1633 | Avg G_Loss: 3.2630


GAN Epoch [12/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [12/50] Completed | Avg D_Loss: 0.1189 | Avg G_Loss: 3.3371


GAN Epoch [13/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [13/50] Completed | Avg D_Loss: 0.1952 | Avg G_Loss: 3.2487


GAN Epoch [14/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [14/50] Completed | Avg D_Loss: 0.1515 | Avg G_Loss: 3.2716


GAN Epoch [15/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [15/50] Completed | Avg D_Loss: 0.1263 | Avg G_Loss: 3.3857


GAN Epoch [16/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [16/50] Completed | Avg D_Loss: 0.1514 | Avg G_Loss: 3.4394


GAN Epoch [17/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [17/50] Completed | Avg D_Loss: 0.1408 | Avg G_Loss: 3.4097


GAN Epoch [18/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [18/50] Completed | Avg D_Loss: 0.1359 | Avg G_Loss: 3.5170


GAN Epoch [19/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [19/50] Completed | Avg D_Loss: 0.1003 | Avg G_Loss: 3.6285


GAN Epoch [20/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [20/50] Completed | Avg D_Loss: 0.1494 | Avg G_Loss: 3.6396
Checkpoint saved to ./saves/checkpoints\gan_checkpoint_epoch_20.pt


GAN Epoch [21/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [21/50] Completed | Avg D_Loss: 0.0510 | Avg G_Loss: 3.8050


GAN Epoch [22/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [22/50] Completed | Avg D_Loss: 0.1349 | Avg G_Loss: 3.7006


GAN Epoch [23/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [23/50] Completed | Avg D_Loss: 0.1278 | Avg G_Loss: 3.7332


GAN Epoch [24/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [24/50] Completed | Avg D_Loss: 0.0876 | Avg G_Loss: 3.9249


GAN Epoch [25/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [25/50] Completed | Avg D_Loss: 0.1449 | Avg G_Loss: 3.8110


GAN Epoch [26/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [26/50] Completed | Avg D_Loss: 0.0982 | Avg G_Loss: 3.8978


GAN Epoch [27/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [27/50] Completed | Avg D_Loss: 0.1324 | Avg G_Loss: 3.7806


GAN Epoch [28/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [28/50] Completed | Avg D_Loss: 0.1501 | Avg G_Loss: 3.8620


GAN Epoch [29/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [29/50] Completed | Avg D_Loss: 0.0450 | Avg G_Loss: 4.1399


GAN Epoch [30/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [30/50] Completed | Avg D_Loss: 0.1074 | Avg G_Loss: 4.0649
Checkpoint saved to ./saves/checkpoints\gan_checkpoint_epoch_30.pt


GAN Epoch [31/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [31/50] Completed | Avg D_Loss: 0.1128 | Avg G_Loss: 4.0354


GAN Epoch [32/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [32/50] Completed | Avg D_Loss: 0.0902 | Avg G_Loss: 4.2210


GAN Epoch [33/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [33/50] Completed | Avg D_Loss: 0.1542 | Avg G_Loss: 3.9426


GAN Epoch [34/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [34/50] Completed | Avg D_Loss: 0.0386 | Avg G_Loss: 4.2717


GAN Epoch [35/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [35/50] Completed | Avg D_Loss: 0.1544 | Avg G_Loss: 3.9673


GAN Epoch [36/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [36/50] Completed | Avg D_Loss: 0.0331 | Avg G_Loss: 4.3971


GAN Epoch [37/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [37/50] Completed | Avg D_Loss: 0.1946 | Avg G_Loss: 3.7709


GAN Epoch [38/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [38/50] Completed | Avg D_Loss: 0.0332 | Avg G_Loss: 4.3347


GAN Epoch [39/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [39/50] Completed | Avg D_Loss: 0.1738 | Avg G_Loss: 4.0346


GAN Epoch [40/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [40/50] Completed | Avg D_Loss: 0.1184 | Avg G_Loss: 4.0731
Checkpoint saved to ./saves/checkpoints\gan_checkpoint_epoch_40.pt


GAN Epoch [41/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [41/50] Completed | Avg D_Loss: 0.0510 | Avg G_Loss: 4.1837


GAN Epoch [42/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [42/50] Completed | Avg D_Loss: 0.1529 | Avg G_Loss: 3.9959


GAN Epoch [43/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [43/50] Completed | Avg D_Loss: 0.1424 | Avg G_Loss: 3.9723


GAN Epoch [44/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [44/50] Completed | Avg D_Loss: 0.0326 | Avg G_Loss: 4.3615


GAN Epoch [45/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [45/50] Completed | Avg D_Loss: 0.1022 | Avg G_Loss: 4.3502


GAN Epoch [46/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [46/50] Completed | Avg D_Loss: 0.0846 | Avg G_Loss: 4.3919


GAN Epoch [47/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [47/50] Completed | Avg D_Loss: 0.1385 | Avg G_Loss: 3.9699


GAN Epoch [48/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [48/50] Completed | Avg D_Loss: 0.0329 | Avg G_Loss: 4.4216


GAN Epoch [49/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [49/50] Completed | Avg D_Loss: 0.0241 | Avg G_Loss: 4.8433


GAN Epoch [50/50]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [50/50] Completed | Avg D_Loss: 0.2238 | Avg G_Loss: 3.8056
Checkpoint saved to ./saves/checkpoints\gan_checkpoint_epoch_50.pt
--- Training Finished ---


## 4. Finalize

In [None]:
# End the WandB run
# Ends the run opened by wandb.init(...) in the initialization cell.
wandb.finish()
print("WandB run finished.")