In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
import time

In [2]:
class Generator(nn.Module):
    """Fully connected generator that maps latent vectors to 1x28x28 images.

    The network is a four-layer MLP (latent_dim -> 256 -> 512 -> 1024 -> 784)
    with ReLU activations and a final Tanh, so every output pixel lies in
    [-1, 1].

    Args:
        latent_dim: Width of the input noise vector. Defaults to 100, which
            is the size the original hard-coded network expected.
    """

    def __init__(self, latent_dim: int = 100):
        super().__init__()
        # Kept as a plain attribute so callers can size their noise tensors
        # without reaching into the first layer.
        self.latent_dim = latent_dim
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 784),
            nn.Tanh(),
        )

    def forward(self, x):
        """Generate images from a batch of latent vectors.

        Args:
            x: Tensor of shape (batch, latent_dim).

        Returns:
            Tensor of shape (batch, 1, 28, 28) with values in [-1, 1].
        """
        output = self.model(x)
        # 784 flat activations per sample -> one-channel 28x28 image.
        output = output.view(x.size(0), 1, 28, 28)
        return output

In [3]:
class Discriminator(nn.Module):
    """Fully connected discriminator scoring 28x28 images as real or fake.

    The network is a four-layer MLP (784 -> 1024 -> 512 -> 256 -> 1) with
    ReLU activations, dropout (p=0.3) after each hidden layer, and a final
    Sigmoid, so the output is a value in [0, 1] per sample.
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: Tensor holding 784 values per sample, e.g. (batch, 1, 28, 28).

        Returns:
            Tensor of shape (batch, 1) with values in [0, 1].
        """
        # reshape (rather than view) also accepts non-contiguous inputs such
        # as transposed or sliced batches; it still raises if the element
        # count per sample is not 784.
        x = x.reshape(x.size(0), 784)
        output = self.model(x)
        return output

In [None]:
def _show_sample_grid(images, rows: int = 4, cols: int = 4):
    """Plot up to ``rows * cols`` single-channel 28x28 images in a grid."""
    fig = plt.figure(figsize=(cols, rows))
    for i, image in enumerate(images[: rows * cols]):
        ax = fig.add_subplot(rows, cols, i + 1)
        ax.imshow(image.reshape(28, 28), cmap="gray_r")
        ax.axis("off")
    plt.show()
    # One figure is created per epoch; close it so figures do not pile up
    # on backends that keep them alive after show().
    plt.close(fig)


def train_gan(
    batch_size: int = 32,
    num_epochs: int = 100,
    device: str = "cuda:0" if torch.cuda.is_available() else "cpu",
):
    """Train the MLP GAN on MNIST and preview generated digits every epoch.

    MNIST is downloaded into the current directory if missing and pixel
    values are normalized to [-1, 1] to match the generator's Tanh output.
    After each epoch the cell output is cleared, the losses of the last
    mini-batch are printed, and a 4x4 grid of generated digits is shown.

    Args:
        batch_size: Number of real images per mini-batch.
        num_epochs: Number of full passes over the training set.
        device: Torch device string. The default is evaluated once, when the
            function is defined (CUDA if available, otherwise CPU).

    Returns:
        None. Progress is reported through notebook output only.
    """
    latent_dim = 100  # must match the input width of Generator's first layer
    num_preview = 16  # images shown in the 4x4 preview grid

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    train_set = torchvision.datasets.MNIST(
        root=".",
        train=True,
        download=True,
        transform=transform
    )

    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True
    )

    discriminator = Discriminator().to(device)
    generator = Generator().to(device)

    loss_function = nn.BCELoss()
    optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=0.0001)
    optimizer_g = torch.optim.Adam(generator.parameters(), lr=0.0001)

    # Fixed latent vectors for the preview: always num_preview images,
    # independent of the size of the last training batch, and the same
    # digits can be compared from one epoch to the next.
    preview_latents = torch.randn((num_preview, latent_dim), device=device)

    for epoch in range(num_epochs):
        for real_samples, _ in train_loader:
            current_batch_size = real_samples.size(0)
            real_samples = real_samples.to(device)

            real_labels = torch.ones((current_batch_size, 1), device=device)
            fake_labels = torch.zeros((current_batch_size, 1), device=device)

            # --- Discriminator step ---
            # The fakes are produced without autograd tracking so this update
            # does not backpropagate through (or build a graph for) the
            # generator.
            latent_vectors = torch.randn(
                (current_batch_size, latent_dim), device=device
            )
            with torch.no_grad():
                fake_samples = generator(latent_vectors)

            all_samples = torch.cat((real_samples, fake_samples))
            all_labels = torch.cat((real_labels, fake_labels))

            optimizer_d.zero_grad()
            output = discriminator(all_samples)
            loss_d = loss_function(output, all_labels)
            loss_d.backward()
            optimizer_d.step()

            # --- Generator step ---
            # Fresh noise; the generator is rewarded when the discriminator
            # labels its output as real.
            latent_vectors = torch.randn(
                (current_batch_size, latent_dim), device=device
            )
            optimizer_g.zero_grad()
            generated = generator(latent_vectors)
            output = discriminator(generated)
            loss_g = loss_function(output, real_labels)
            loss_g.backward()
            optimizer_g.step()

        clear_output(wait=True)
        # Losses shown are those of the last mini-batch of the epoch.
        print(
            f"Epoch {epoch} | Loss D: {loss_d.item():.4f} | "
            f"Loss G: {loss_g.item():.4f}"
        )

        with torch.no_grad():
            preview = generator(preview_latents).cpu()
        _show_sample_grid(preview)

# Run the full training: 100 epochs over MNIST with mini-batches of 32.
# The device argument is left at its default (CUDA if available, else CPU).
train_gan(batch_size=32, num_epochs=100)

Epoch 0 | Loss D: 0.0485 | Loss G: 4.0814
