In [1]:
import torch
from torch import nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

In [2]:
# Show which PyTorch build this notebook is running against.
torch.__version__

'1.13.0+cu116'

# Discriminator Class

In [3]:
class Discrimmnator(nn.Module):
    """Fully connected discriminator that scores flattened images.

    The network is Linear(imgSize -> 128) -> LeakyReLU(0.2) ->
    Linear(128 -> 1) -> Sigmoid, so every input row is mapped to a single
    value in (0, 1).

    Args:
        imgSize: Number of features in each flattened input image.
    """

    def __init__(self, imgSize) -> None:
        super().__init__()
        hidden_width = 128
        # Layers are created in forward order so parameter initialisation
        # consumes the RNG in the same sequence as the layer stack.
        layers = [
            nn.Linear(imgSize, hidden_width),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_width, 1),
            nn.Sigmoid(),
        ]
        self.disc = nn.Sequential(*layers)

    def forward(self, x):
        """Return the discriminator output for a batch `x` of shape (N, imgSize)."""
        scores = self.disc(x)
        return scores

# Testing of Class
# MockDiscrimmnator = Discrimmnator(784)
# outputValueFromDisc = MockDiscrimmnator(torch.randn(1,784))
# outputValueFromDisc

# Generator Class

In [4]:
class Generator(nn.Module):
    """Fully connected generator that maps latent vectors to flattened images.

    The network is Linear(latentDim -> 256) -> LeakyReLU(0.1) ->
    Linear(256 -> imgSize) -> Tanh, so every output value lies in [-1, 1].

    Args:
        latentDim: Number of features in each input noise vector.
        imgSize: Number of features in each generated (flattened) image.
    """

    def __init__(self, latentDim, imgSize) -> None:
        super().__init__()
        hidden_width = 256
        # Layers are created in forward order so parameter initialisation
        # consumes the RNG in the same sequence as the layer stack.
        layers = [
            nn.Linear(latentDim, hidden_width),
            nn.LeakyReLU(0.1),
            nn.Linear(hidden_width, imgSize),
            nn.Tanh(),
        ]
        self.gen = nn.Sequential(*layers)

    def forward(self, x):
        """Return generated samples for a batch `x` of shape (N, latentDim)."""
        generated = self.gen(x)
        return generated

# Testing of Class
# MockGen = Generator(784, 784)
# outgen = MockGen(torch.randn(1,784)) # Due to Tanh, all values are in [-1, 1]
# outgen.shape

In [5]:
# Hyperparameters
device = "cuda" if torch.cuda.is_available() else "cpu"  # prefer the GPU when one is visible
lr = 3e-4  # learning rate passed to both Adam optimisers
latentDim = 64  # Can give 128, 256
imgDim = 28 * 28 * 1  # 784: flattened 28x28 single-channel image
batch_size = 32
num_epochs = 5

# Seed PyTorch's RNGs before any model or noise tensor is created, so weight
# initialisation, the fixed evaluation noise and data shuffling are repeatable
# after Restart & Run All.
seed = 42
torch.manual_seed(seed)

# Last expression of the cell: display the selected device.
device

'cuda'

In [6]:
# Instantiate both networks on the selected device.
disc = Discrimmnator(imgDim).to(device)
gen = Generator(latentDim, imgDim).to(device)

# Fixed noise: the same latent batch is reused at every logging step so the
# generated image grids are comparable across epochs.
fixedNoise = torch.randn((batch_size, latentDim)).to(device)

# Transforms for MNIST data: ToTensor scales pixels to [0, 1] and
# Normalize(mean=0.5, std=0.5) maps them to [-1, 1], the range of the
# generator's Tanh output.
# The pipeline is built from `torchvision.transforms` rather than the
# `transforms` alias, because this assignment rebinds `transforms`;
# going through the alias made the cell fail with AttributeError when re-run.
transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ]
)

In [7]:
# Dataset and DataLoader
# download=True fetches MNIST into dataset/ only when it is not already there,
# so the notebook also runs on a machine that has never downloaded the data.
dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)
# Shuffled mini-batches of `batch_size` (image, label) pairs.
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [8]:
# Loss: binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()

# One Adam optimiser per network, both using the shared learning rate `lr`.
optmDisc = optim.Adam(params=disc.parameters(), lr=lr)
optmGen = optim.Adam(params=gen.parameters(), lr=lr)

In [9]:
# Global TensorBoard step; incremented once per logged pair of image grids.
step = 0

# Separate runs for generated and real image grids so they appear side by side.
writer_fake = SummaryWriter(log_dir="runs/GAN_MNIST/fake")
writer_real = SummaryWriter(log_dir="runs/GAN_MNIST/real")

In [10]:
# Alternating GAN training: for every mini-batch, update the discriminator
# once, then the generator once. Image grids are logged to TensorBoard on the
# first batch of each epoch.
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        # Flatten each image to a vector of imgDim pixels.
        real = real.view(-1, imgDim).to(device)
        # Actual size of this batch, held in its own name so the global
        # `batch_size` hyperparameter is not overwritten.
        cur_batch_size = real.shape[0]

        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        noise = torch.randn(cur_batch_size, latentDim).to(device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        # detach() stops the discriminator loss from backpropagating into the
        # generator, so no retain_graph is needed and no generator gradients
        # are computed only to be thrown away.
        disc_fake = disc(fake.detach()).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward()
        optmDisc.step()

        ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z)))
        # where the second option of maximizing doesn't suffer from
        # saturating gradients
        # Fresh forward pass through the just-updated discriminator, this time
        # keeping the graph back to the generator.
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        optmGen.step()

        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} \
                      Loss D: {lossD:.4f}, loss G: {lossG:.4f}"
            )

            with torch.no_grad():
                # Separate names so the training tensors `fake` and `real`
                # are not rebound by the logging code.
                fake_imgs = gen(fixedNoise).reshape(-1, 1, 28, 28)
                real_imgs = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake_imgs, normalize=True)
                img_grid_real = torchvision.utils.make_grid(real_imgs, normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )
                step += 1

# Push buffered events to disk so TensorBoard shows the last logged grids
# even though the kernel (and the writers) stay alive.
writer_fake.flush()
writer_real.flush()

Epoch [0/5] Batch 0/1875                       Loss D: 0.7166, loss G: 0.6638
Epoch [1/5] Batch 0/1875                       Loss D: 0.4378, loss G: 1.2102
Epoch [2/5] Batch 0/1875                       Loss D: 0.7560, loss G: 0.8700
Epoch [3/5] Batch 0/1875                       Loss D: 0.8084, loss G: 0.8865
Epoch [4/5] Batch 0/1875                       Loss D: 0.4656, loss G: 1.4100


In [11]:
# Load the TensorBoard notebook extension, which provides the %tensorboard magic.
%load_ext tensorboard

In [12]:
# Open TensorBoard inline on runs/, the parent of the runs/GAN_MNIST/* log directories.
%tensorboard --logdir runs/