In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

In [2]:
class Discriminator(nn.Module):
  """Fully connected critic that scores a flattened image as real (near 1) or fake (near 0).

  Args:
    img_dim: number of features in one flattened input image.
  """

  def __init__(self, img_dim):
    super().__init__()
    layers = [
        nn.Linear(img_dim, 128),  # image features -> 128 hidden units
        nn.LeakyReLU(0.1),
        nn.Linear(128, 1),        # single score per sample
        nn.Sigmoid(),             # probability in (0, 1), the input form nn.BCELoss takes
    ]
    self.disc = nn.Sequential(*layers)

  def forward(self, x):
    """Return D(x); for x of shape (batch, img_dim) the result has shape (batch, 1)."""
    probs = self.disc(x)
    return probs

class Generator(nn.Module):
  """Fully connected generator mapping a latent noise vector to a flattened fake image.

  Args:
    z_dim: length of the input noise vector.
    img_dim: number of features in one flattened output image
      (784 for the 1x28x28 MNIST images used in this notebook).
  """

  def __init__(self, z_dim, img_dim):
    super().__init__()
    hidden = nn.Linear(z_dim, 256)
    activation = nn.LeakyReLU(0.1)
    to_image = nn.Linear(256, img_dim)
    # Tanh bounds outputs to [-1, 1], the same range the real images are normalized to.
    squash = nn.Tanh()
    self.gen = nn.Sequential(hidden, activation, to_image, squash)

  def forward(self, x):
    """Return G(x); for x of shape (batch, z_dim) the result has shape (batch, img_dim)."""
    images = self.gen(x)
    return images

In [3]:
# Hyperparameters, models, data pipeline, optimizers and TensorBoard writers for the MNIST GAN.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed before building models so weight init, noise and shuffling are reproducible.
torch.manual_seed(0)

lr = 3e-4
z_dim = 64               # length of the generator's latent noise vector
image_dim = 28 * 28 * 1  # 1 channel x 28 x 28 = 784 features per flattened image
batch_size = 32
num_epochs = 50

disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)

# One fixed latent batch reused at every logging step, so generated samples are comparable over time.
fixed_noise = torch.randn(batch_size, z_dim, device=device)

# Named `transform`, not `transforms`: rebinding `transforms` shadowed the torchvision
# module and made this cell raise AttributeError on `transforms.Compose` when re-run.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),  # (x - 0.5) / 0.5: [0, 1] -> [-1, 1], matching the generator's Tanh
])

dataset = datasets.MNIST(
    root="dataset/",
    transform=transform,
    download=True,
)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
criterion = nn.BCELoss()  # the discriminator already ends in Sigmoid, so BCE on probabilities

writer_fake = SummaryWriter("runs/GAN_MNIST/fake")
writer_real = SummaryWriter("runs/GAN_MNIST/real")

step = 0  # TensorBoard global_step counter used when logging image grids


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to dataset/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 16.4MB/s]


Extracting dataset/MNIST/raw/train-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to dataset/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 485kB/s]


Extracting dataset/MNIST/raw/train-labels-idx1-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 4.50MB/s]


Extracting dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 8.40MB/s]

Extracting dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw






In [4]:
# Alternate one discriminator step and one generator step per batch; log once per epoch.
for epoch in range(num_epochs):
  for batch_idx, (real, _) in enumerate(loader):  # labels are not used by an unconditional GAN
    # Flatten (batch, 1, 28, 28) images into (batch, image_dim) vectors for the MLPs.
    real = real.view(-1, image_dim).to(device)
    # Local name so the configured `batch_size` (also used for fixed_noise) is not overwritten.
    cur_batch_size = real.shape[0]

    ## Train discriminator: max log(D(real)) + log(1 - D(G(z)))
    noise = torch.randn(cur_batch_size, z_dim, device=device)
    fake = gen(noise)
    disc_real = disc(real).view(-1)
    lossD_real = criterion(disc_real, torch.ones_like(disc_real))
    # detach(): the D update must not backprop into G. This also keeps G's graph
    # intact for the generator step below, so retain_graph=True is no longer needed.
    disc_fake = disc(fake.detach()).view(-1)
    lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))

    lossD = (lossD_real + lossD_fake) / 2
    disc.zero_grad()
    lossD.backward()
    opt_disc.step()

    ## Train generator: min log(1 - D(G(z))) <-> max log(D(G(z))) (non-saturating form)
    output = disc(fake).view(-1)  # fresh pass through the just-updated discriminator
    lossG = criterion(output, torch.ones_like(output))
    gen.zero_grad()
    lossG.backward()
    opt_gen.step()

    if batch_idx == 0:
      print(
          f"Epoch [{epoch}/{num_epochs}] Loss D: {lossD.item():.4f}, loss G: {lossG.item():.4f}"
      )

      with torch.no_grad():
        fake_samples = gen(fixed_noise).reshape(-1, 1, 28, 28)
        data = real.reshape(-1, 1, 28, 28)
        img_grid_fake = torchvision.utils.make_grid(fake_samples, normalize=True)
        img_grid_real = torchvision.utils.make_grid(data, normalize=True)

        writer_fake.add_image(
            "Mnist fake images", img_grid_fake, global_step=step
        )
        writer_real.add_image(
            "Mnist real images", img_grid_real, global_step=step
        )
      # Make the images visible to a running TensorBoard without waiting for the periodic flush.
      writer_fake.flush()
      writer_real.flush()
      # Advance the step; without this every epoch logged at step 0 and overwrote the previous grid.
      step += 1


Epoch [0/50] Loss D: 0.6006, loss G: 0.6608
Epoch [1/50] Loss D: 0.8731, loss G: 0.6363
Epoch [2/50] Loss D: 0.7240, loss G: 0.7823
Epoch [3/50] Loss D: 0.9306, loss G: 0.7126
Epoch [4/50] Loss D: 0.6595, loss G: 0.9400
Epoch [5/50] Loss D: 0.7465, loss G: 0.8753
Epoch [6/50] Loss D: 0.8142, loss G: 0.9120
Epoch [7/50] Loss D: 0.5021, loss G: 1.1905
Epoch [8/50] Loss D: 0.4907, loss G: 1.3925
Epoch [9/50] Loss D: 0.7406, loss G: 0.5375
Epoch [10/50] Loss D: 0.5480, loss G: 1.1605
Epoch [11/50] Loss D: 0.3682, loss G: 1.7964
Epoch [12/50] Loss D: 0.6903, loss G: 0.8703
Epoch [13/50] Loss D: 0.7642, loss G: 0.9952
Epoch [14/50] Loss D: 0.6353, loss G: 1.1871
Epoch [15/50] Loss D: 0.4674, loss G: 1.2122
Epoch [16/50] Loss D: 0.6952, loss G: 1.2240
Epoch [17/50] Loss D: 0.5857, loss G: 1.2320
Epoch [18/50] Loss D: 0.4437, loss G: 1.3437
Epoch [19/50] Loss D: 0.5335, loss G: 1.6024
Epoch [20/50] Loss D: 0.5383, loss G: 1.3981
Epoch [21/50] Loss D: 0.7233, loss G: 1.0100
Epoch [22/50] Loss D

In [5]:
!tensorboard lostdir runs

2025-03-14 06:36:04.928342: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1741934164.949939    4700 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1741934164.956117    4700 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
usage: tensorboard [-h] [--helpfull] {serve} ...
tensorboard: error: argument {serve}: invalid choice: 'lostdir' (choose from 'serve')
