In [13]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

In [2]:
class Discriminator(nn.Module):
    def  __init__(self, img_dim) -> None:
        super().__init__()
        self.disc = nn.Sequential(
            nn.Linear(img_dim, 128),
            nn.LeakyReLU(0.1),
            nn.Linear(128,1),
            nn.Sigmoid(),
        )
    def forward(self,x):
        return self.disc(x)

In [4]:
class Generator(nn.Module):
    def  __init__(self, z_dim, img_dim):
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.1),
            nn.Linear(256, img_dim), # MNIMST 28*28*1
            nn.Tanh(),
        )
    def forward(self,x):
        return self.gen(x)

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4
z_dim = 64
image_dim = 28*28*1
batch_size = 32
num_epochs = 50

In [6]:
discriminator = Discriminator(image_dim).to(device)
generator = Generator(z_dim, image_dim).to(device)
fixed_noise = torch.randn((batch_size, z_dim)).to(device)
transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] # means and std of MNIST 
)
dataset = datasets.MNIST(root = "dataset/", transform=transforms, download = True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
opt_disc = optim.Adam(discriminator.parameters(), lr=lr)
opt_gen = optim.Adam(generator.parameters(), lr=lr)
criterion = nn.BCELoss()
writer_fake = SummaryWriter(f"runs/GAN_MNIST/fake")
writer_real = SummaryWriter(f"runs/GAN_MNIST/real")
step = 0

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to dataset/MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [01:19<00:00, 123966.08it/s]


Extracting dataset/MNIST\raw\train-images-idx3-ubyte.gz to dataset/MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to dataset/MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 54092.20it/s]


Extracting dataset/MNIST\raw\train-labels-idx1-ubyte.gz to dataset/MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to dataset/MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:16<00:00, 99005.19it/s] 


Extracting dataset/MNIST\raw\t10k-images-idx3-ubyte.gz to dataset/MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to dataset/MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<?, ?it/s]

Extracting dataset/MNIST\raw\t10k-labels-idx1-ubyte.gz to dataset/MNIST\raw






In [14]:
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.view(-1,784).to(device)
        batch_size = real.shape[0]
        ### Train Discriminator: max (log(D(real)) + log(1 - D(fake)))
        noise = torch.randn(batch_size, z_dim).to(device)  # random noise
        fake = generator(noise)                          # generate a fake image
        disc_real = discriminator(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake = discriminator(fake).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2 
        discriminator.zero_grad()
        lossD.backward(retain_graph=True)
        opt_disc.step()

        ### Train Discriminator: min (log(1 - D(G(z)) --> max log(D(G(z))))
        output = discriminator(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        generator.zero_grad()
        lossG.backward()
        opt_gen.step()
        
        ## tensorboard
        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} \
                      Loss D: {lossD:.4f}, loss G: {lossG:.4f}"
            )

            with torch.no_grad():
                fake = generator(fixed_noise).reshape(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )
                step += 1

Epoch [0/50] Batch 0/1875                       Loss D: 0.5309, loss G: 0.6450
Epoch [1/50] Batch 0/1875                       Loss D: 0.0744, loss G: 2.6749
Epoch [2/50] Batch 0/1875                       Loss D: 0.0327, loss G: 3.8955
Epoch [3/50] Batch 0/1875                       Loss D: 0.0397, loss G: 3.7457
Epoch [4/50] Batch 0/1875                       Loss D: 0.0356, loss G: 4.4808
Epoch [5/50] Batch 0/1875                       Loss D: 0.0569, loss G: 5.5115
Epoch [6/50] Batch 0/1875                       Loss D: 0.0138, loss G: 6.3366
Epoch [7/50] Batch 0/1875                       Loss D: 0.0083, loss G: 5.2438
Epoch [8/50] Batch 0/1875                       Loss D: 0.0293, loss G: 5.7426
Epoch [9/50] Batch 0/1875                       Loss D: 0.0415, loss G: 5.4944
Epoch [10/50] Batch 0/1875                       Loss D: 0.0074, loss G: 5.3813
Epoch [11/50] Batch 0/1875                       Loss D: 0.0103, loss G: 5.4017
Epoch [12/50] Batch 0/1875                       L