In [4]:
pip install torchvision




In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torchvision

class Generator(nn.Module):
    def __init__(self, noise_dim, img_dim):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(noise_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, img_dim),
            nn.Tanh()  #range [-1, 1]
        )

    def forward(self, x):
        return self.net(x)
class Discriminator(nn.Module):
    def __init__(self, img_dim):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(img_dim, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()  #range [0, 1]
        )

    def forward(self, x):
        return self.net(x)

#hyperparameters
device = 'cuda' if torch.cuda.is_available() else 'cpu'
noise_dim = 100
img_dim = 28 * 28
batch_size = 128
epochs = 10
lr = 0.0002

#load dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
mnist = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True)

#initialize the models
generator = Generator(noise_dim, img_dim).to(device)
discriminator = Discriminator(img_dim).to(device)

optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

#loss function
criterion = nn.BCELoss()

#training
for epoch in range(epochs):
    for real_imgs, _ in dataloader:
        real_imgs = real_imgs.view(-1, img_dim).to(device)  #flatten images
        batch_size = real_imgs.size(0)

        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # ---------------------
        optimizer_D.zero_grad()
        #loss
        real_loss = criterion(discriminator(real_imgs), real_labels)
        #do fake images
        noise = torch.randn(batch_size, noise_dim).to(device)
        fake_imgs = generator(noise)
        #loss
        fake_loss = criterion(discriminator(fake_imgs.detach()), fake_labels)
        #total loss and backprop
        d_loss = real_loss + fake_loss
        d_loss.backward()
        optimizer_D.step()

        # -----------------
        optimizer_G.zero_grad()
        #do fake images and calculate loss
        fake_imgs = generator(noise)
        g_loss = criterion(discriminator(fake_imgs), real_labels)  #want the generator to fool the discriminator
        g_loss.backward()
        optimizer_G.step()
    print(f"Epoch [{epoch+1}/{epochs}] | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}")

    #do and save sample images
    if (epoch + 1) % 10 == 0:
        with torch.no_grad():
            sample_noise = torch.randn(16, noise_dim).to(device)
            generated_imgs = generator(sample_noise).view(-1, 1, 28, 28)
            save_path = f"./gan_sample_epoch_{epoch+1}.png"
            torchvision.utils.save_image(generated_imgs, save_path, nrow=4, normalize=True)


Epoch [1/10] | D Loss: 0.5513 | G Loss: 1.6891
Epoch [2/10] | D Loss: 0.8333 | G Loss: 5.7554
Epoch [3/10] | D Loss: 0.2645 | G Loss: 2.4879
Epoch [4/10] | D Loss: 0.6480 | G Loss: 1.0492
Epoch [5/10] | D Loss: 0.4894 | G Loss: 1.7415
Epoch [6/10] | D Loss: 0.5301 | G Loss: 1.9277
Epoch [7/10] | D Loss: 0.6832 | G Loss: 3.6620
Epoch [8/10] | D Loss: 0.8068 | G Loss: 1.4979
Epoch [9/10] | D Loss: 0.7813 | G Loss: 1.8743
Epoch [10/10] | D Loss: 0.8428 | G Loss: 1.4181
