In [66]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import numpy as np



In [67]:
# Define the generator network
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 784),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        return img.view(img.size(0), 1, 28, 28)




In [68]:
# Define the discriminator network
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity



In [69]:
# Initialize networks
generator = Generator()
discriminator = Discriminator()

# Loss function and optimizer
adversarial_loss = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Training parameters
num_epochs = 200
batch_size = 64
latent_dim = 100
sample_interval = 200



In [73]:
# Training loop
for epoch in range(num_epochs):
    for i in range(len(data_loader)):
        # Adversarial ground truths
        valid = torch.ones(batch_size, 1)
        fake = torch.zeros(batch_size, 1)

        # Generate a batch of images
        z = torch.randn(batch_size, latent_dim)
        gen_imgs = generator(z)

        # Train discriminator
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(gen_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward(retain_graph=True)  # Set retain_graph=True
        optimizer_D.step()

        # Train generator
        optimizer_G.zero_grad()
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward(retain_graph=True)  # Set retain_graph=True
        optimizer_G.step()

        # Print progress
        if i % 100 == 0:
            print(
                f"[Epoch {epoch}/{num_epochs}] [Batch {i}/{len(data_loader)}] [D loss: {d_loss.item():.6f}] [G loss: {g_loss.item():.6f}]"
            )




[Epoch 0/200] [Batch 0/16] [D loss: 0.693163] [G loss: 0.697557]
[Epoch 1/200] [Batch 0/16] [D loss: 0.693590] [G loss: 0.669914]
[Epoch 2/200] [Batch 0/16] [D loss: 0.692946] [G loss: 0.708356]
[Epoch 3/200] [Batch 0/16] [D loss: 0.693256] [G loss: 0.706473]
[Epoch 4/200] [Batch 0/16] [D loss: 0.695358] [G loss: 0.692029]
[Epoch 5/200] [Batch 0/16] [D loss: 0.692916] [G loss: 0.703328]
[Epoch 6/200] [Batch 0/16] [D loss: 0.695857] [G loss: 0.710978]
[Epoch 7/200] [Batch 0/16] [D loss: 0.689528] [G loss: 0.700532]
[Epoch 8/200] [Batch 0/16] [D loss: 0.690515] [G loss: 0.707524]
[Epoch 9/200] [Batch 0/16] [D loss: 0.695085] [G loss: 0.689372]
[Epoch 10/200] [Batch 0/16] [D loss: 0.689901] [G loss: 0.714130]
[Epoch 11/200] [Batch 0/16] [D loss: 0.693741] [G loss: 0.685341]
[Epoch 12/200] [Batch 0/16] [D loss: 0.696112] [G loss: 0.722313]
[Epoch 13/200] [Batch 0/16] [D loss: 0.695022] [G loss: 0.702495]
[Epoch 14/200] [Batch 0/16] [D loss: 0.695686] [G loss: 0.699445]
[Epoch 15/200] [Batc

In [75]:
 # Save generated images at sample interval
if epoch % sample_interval == 0 and i == 0:
            save_image(gen_imgs.data[:25], f"images/{epoch}.png", nrow=5, normalize=True)