In [43]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
from  torchinfo import summary
import os
import torchvision


In [44]:
device= "cuda" if torch.cuda.is_available() else "cpu"
lr=3e-4
z_dim= 64
img_dim = 28*28*1
batch_size = 32
num_epochs = 20

In [45]:
ROOT_DIR = Path.cwd()
IMAGE_DIR=ROOT_DIR.parent.resolve().parent.resolve() / "data"
# ROOT_DIR, IMAGE_DIR, Path.cwd()

(WindowsPath('g:/Code/deep-learning-libray/pytorch/GANs/02_simple_gans'),
 WindowsPath('G:/Code/deep-learning-libray/pytorch/data'),
 WindowsPath('g:/Code/deep-learning-libray/pytorch/GANs/02_simple_gans'))

In [46]:
class Discriminator(nn.Module):
    def __init__(self, img_dim):
        super().__init__()
        self.disc= nn.Sequential(
            nn.Linear(img_dim, 128),
            nn.LeakyReLU(.1),
            nn.Linear(128,1),
            nn.Tanh(),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        return self.disc(x)

In [47]:
class Generator(nn.Module):
    def __init__(self, z_dim, img_dim):
        super().__init__()
        self.disc= nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(.1),
            nn.Linear(256,img_dim),
            nn.Tanh(),
        )
    
    def forward(self, x):
        return self.disc(x)

In [48]:
disc = Discriminator(img_dim).to(device)
gen = Generator(z_dim, img_dim).to(device)
fixed_noise = torch.randn((batch_size, z_dim)).to(device)
transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

In [49]:
dataset = datasets.MNIST(root=IMAGE_DIR, transform=transforms, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [50]:
opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
criterion = nn.BCELoss()
writer_fake = SummaryWriter(f"{ROOT_DIR}/logs/fake")
writer_real = SummaryWriter(f"{ROOT_DIR}/logs/real")
step = 0
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.view(-1, 784).to(device)
        batch_size = real.shape[0]

        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        noise = torch.randn(batch_size, z_dim).to(device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward(retain_graph=True)
        opt_disc.step()

        ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))
        # where the second option of maximizing doesn't suffer from
        # saturating gradients
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} \
                      Loss D: {lossD:.4f}, loss G: {lossG:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )
                step += 1

Epoch [0/20] Batch 0/1875                       Loss D: 0.7793, loss G: 0.6805
Epoch [1/20] Batch 0/1875                       Loss D: 0.6231, loss G: 0.7990
Epoch [2/20] Batch 0/1875                       Loss D: 0.5400, loss G: 0.9060
Epoch [3/20] Batch 0/1875                       Loss D: 0.8133, loss G: 0.3133
Epoch [4/20] Batch 0/1875                       Loss D: 0.8133, loss G: 0.3133
Epoch [5/20] Batch 0/1875                       Loss D: 0.4111, loss G: 1.1707
Epoch [6/20] Batch 0/1875                       Loss D: 0.8139, loss G: 0.3133
Epoch [7/20] Batch 0/1875                       Loss D: 0.8133, loss G: 0.3133
Epoch [8/20] Batch 0/1875                       Loss D: 0.8132, loss G: 0.3133
Epoch [9/20] Batch 0/1875                       Loss D: 0.6138, loss G: 0.8440
Epoch [10/20] Batch 0/1875                       Loss D: 0.8133, loss G: 0.3133
Epoch [11/20] Batch 0/1875                       Loss D: 0.8133, loss G: 0.3133
Epoch [12/20] Batch 0/1875                       L