In [3]:
# !conda install pytorch torchvision torchaudio pytorch-cuda=11.6 -c pytorch -c nvidia

In [19]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

In [29]:
class Discriminator(nn.Module):
    
    def __init__(self, img_dim):
        super().__init__()
        self.disc = nn.Sequential(
            nn.Linear(img_dim, 128),
            nn.LeakyReLU(0.1),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        return self.disc(x)

In [30]:
class Generator(nn.Module):
    
    def __init__(self, z_dim, img_dim):
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.1),
            nn.Linear(256, img_dim),
            nn.Tanh()
        )
    
    def forward(self, x):
        return self.gen(x)

In [31]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [32]:
lr = 3e-4

In [39]:
z_dim = 64 # 128, 512

In [40]:
image_dim = 28 * 28 * 1

In [41]:
batch_size = 32

In [59]:
num_epochs = 10

In [60]:
disc = Discriminator(image_dim).to(device)

In [61]:
gen = Generator(z_dim, image_dim).to(device)

In [62]:
fixed_noise = torch.randn((batch_size, z_dim)).to(device)

In [63]:
# Mean and SD is 0.1307, 0.3081
import torchvision.transforms as transforms
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

In [64]:
dataset = datasets.MNIST(root='dataset/', transform=transform, download=True)

In [65]:
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [66]:
opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)

In [67]:
criterion = nn.BCELoss()

In [68]:
writer_fake = SummaryWriter(f"runs/GAN_MNIST/fake")
writer_real = SummaryWriter(f"runs/GAN_MNIST/real")

In [69]:
step = 0

In [70]:
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.view(-1, 784).to(device)
        batch_size = real.shape[0]
        
        # Train Discriminator max log(D(real)) + log(1 - D(g(z)))
        noise = torch.randn((batch_size, z_dim)).to(device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward(retain_graph=True)
        opt_disc.step()
        
        # Train Generator min log(1 - D(G(z))) <--> max log(D(G(z)))
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()
        
        
        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} \
                      Loss D: {lossD:.4f}, loss G: {lossG:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )
                step += 1

Epoch [0/10] Batch 0/1875                       Loss D: 0.6499, loss G: 0.7425
Epoch [1/10] Batch 0/1875                       Loss D: 0.2653, loss G: 1.7771
Epoch [2/10] Batch 0/1875                       Loss D: 0.1128, loss G: 2.6949
Epoch [3/10] Batch 0/1875                       Loss D: 0.0893, loss G: 3.7012
Epoch [4/10] Batch 0/1875                       Loss D: 0.0764, loss G: 4.3006
Epoch [5/10] Batch 0/1875                       Loss D: 0.0414, loss G: 4.3429
Epoch [6/10] Batch 0/1875                       Loss D: 0.0770, loss G: 4.5209
Epoch [7/10] Batch 0/1875                       Loss D: 0.0288, loss G: 4.3206
Epoch [8/10] Batch 0/1875                       Loss D: 0.0157, loss G: 5.1158
Epoch [9/10] Batch 0/1875                       Loss D: 0.0141, loss G: 5.0246
