In [1]:
""" 
Simple GAN using fully connected layers
for generating digit '4'

"""

import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.data import DataLoader

In [2]:
class Discriminator(nn.Module):
    """Fully connected discriminator scoring flattened images as real vs. fake.

    Args:
        in_features: length of each flattened input vector
            (``image_dim`` = 28 * 28 * 1 = 784 in this notebook).
    """

    def __init__(self, in_features):
        super().__init__()
        hidden = 128
        layers = [
            nn.Linear(in_features, hidden),
            nn.LeakyReLU(0.01),
            nn.Linear(hidden, 1),
            # Squash the single logit to a probability for nn.BCELoss.
            nn.Sigmoid(),
        ]
        # The attribute name `sequential` appears in state_dict keys
        # ("sequential.0.weight", ...) and hence in saved checkpoints; keep it.
        self.sequential = nn.Sequential(*layers)

    def forward(self, x):
        """Return one score in [0, 1] per input row (shape (N, 1) for a 2-D ``x``)."""
        scores = self.sequential(x)
        return scores


class Generator(nn.Module):
    """Fully connected generator mapping latent noise to flattened images.

    Args:
        z_dim: length of each input noise vector.
        img_dim: length of each flattened output image
            (``image_dim`` = 784 in this notebook).
    """

    def __init__(self, z_dim, img_dim):
        super().__init__()
        hidden = 256
        layers = [
            nn.Linear(z_dim, hidden),
            nn.LeakyReLU(0.01),
            nn.Linear(hidden, img_dim),
            # Real images are normalized to [-1, 1] (Normalize((0.5,), (0.5,))),
            # so generated pixels are squashed to the same range.
            nn.Tanh(),
        ]
        # `sequential` is part of the state_dict keys stored in checkpoints; keep it.
        self.sequential = nn.Sequential(*layers)

    def forward(self, x):
        """Map noise rows of length ``z_dim`` to image rows of length ``img_dim`` in [-1, 1]."""
        images = self.sequential(x)
        return images

In [3]:
# Hyperparameters etc.
seed = 42
# Seed torch's RNGs (CPU and CUDA) before anything random happens: weight init,
# fixed_noise, DataLoader shuffling and the per-batch training noise all draw from them.
torch.manual_seed(seed)

device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 0.0001
z_dim = 64
image_dim = 28 * 28 * 1  # 784
batch_size = 64
num_epochs = 100
target_digit = 4  # the only MNIST class the GAN is trained on

discriminator = Discriminator(image_dim).to(device)
generator = Generator(z_dim, image_dim).to(device)
# Latent vectors reused at every sampling step so generated grids are comparable across epochs.
fixed_noise = torch.randn(batch_size, z_dim, device=device)

_transforms = transforms.Compose(
    [
        transforms.ToTensor(),                 # PIL image -> float tensor in [0, 1]
        transforms.Normalize((0.5,), (0.5,)),  # [0, 1] -> [-1, 1], matching the generator's Tanh range
    ]
)

dataset = datasets.MNIST(root="data/", train=True, transform=_transforms, download=True)

# Keep only the training images labelled `target_digit`.
labels = dataset.targets.numpy()
indices = np.flatnonzero(labels == target_digit)
dataset_4 = torch.utils.data.Subset(dataset, indices)
assert len(dataset_4) > 0, f"no MNIST images with label {target_digit}"

In [4]:
# Binary cross-entropy on the discriminator's sigmoid output is used for both players.
criterion = nn.BCELoss()

# Separate Adam optimizers so each network only updates its own parameters.
opt_discriminator = optim.Adam(params=discriminator.parameters(), lr=lr)
opt_generator = optim.Adam(params=generator.parameters(), lr=lr)

# Shuffled mini-batches drawn from the digit-4 subset.
loader = DataLoader(dataset=dataset_4, batch_size=batch_size, shuffle=True)

In [5]:
# Train both networks with the non-saturating GAN objective.
# Sample grids are written every `sample_every` epochs and on the final epoch.
# (`epoch` stays defined after the loop; the checkpoint cell reads it.)
sample_every = 10
fake_dir = "samples/Simple_GAN_FC/fake"
real_dir = "samples/Simple_GAN_FC/real"
os.makedirs(fake_dir, exist_ok=True)
os.makedirs(real_dir, exist_ok=True)

for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.view(-1, image_dim).to(device)
        # Local name on purpose: the last batch can be smaller than `batch_size`,
        # and overwriting the global hyperparameter would leak into later cells
        # (e.g. re-creating the DataLoader).
        cur_batch_size = real.shape[0]

        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        noise = torch.randn(cur_batch_size, z_dim, device=device)
        fake = generator(noise)
        disc_real = discriminator(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        # detach(): D's loss must not backprop into G. This also removes the need
        # for retain_graph, because G's update below does its own forward pass through D.
        disc_fake = discriminator(fake.detach()).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        opt_discriminator.zero_grad()
        lossD.backward()
        opt_discriminator.step()

        ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z)))
        # where the second option of maximizing doesn't suffer from
        # saturating gradients
        output = discriminator(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        # Also clears G-side state before backward; D's grads from this backward
        # are cleared by opt_discriminator.zero_grad() on the next iteration.
        opt_generator.zero_grad()
        lossG.backward()
        opt_generator.step()

        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} "
                f"Loss D: {lossD.item():.4f}, loss G: {lossG.item():.4f}"
            )
            if epoch % sample_every == 0 or epoch == num_epochs - 1:
                with torch.no_grad():
                    fake_samples = generator(fixed_noise).reshape(-1, 1, 28, 28)
                real_samples = real.reshape(-1, 1, 28, 28)
                # save_image tiles the batch into a grid itself and, with
                # normalize=True, rescales the [-1, 1] pixels to [0, 1].
                vutils.save_image(fake_samples, f"{fake_dir}/{epoch}.png", normalize=True)
                vutils.save_image(real_samples, f"{real_dir}/{epoch}.png", normalize=True)

Epoch [0/100] Batch 0/92                       Loss D: 0.6861, loss G: 0.6670
Epoch [1/100] Batch 0/92                       Loss D: 0.4625, loss G: 0.6171
Epoch [2/100] Batch 0/92                       Loss D: 0.5991, loss G: 0.4932
Epoch [3/100] Batch 0/92                       Loss D: 0.4579, loss G: 0.7643
Epoch [4/100] Batch 0/92                       Loss D: 0.3861, loss G: 0.8797
Epoch [5/100] Batch 0/92                       Loss D: 0.4599, loss G: 0.7778
Epoch [6/100] Batch 0/92                       Loss D: 0.5188, loss G: 0.7155
Epoch [7/100] Batch 0/92                       Loss D: 0.5704, loss G: 0.6544
Epoch [8/100] Batch 0/92                       Loss D: 0.4817, loss G: 0.8202
Epoch [9/100] Batch 0/92                       Loss D: 0.4946, loss G: 0.8374
Epoch [10/100] Batch 0/92                       Loss D: 0.5184, loss G: 0.7857
Epoch [11/100] Batch 0/92                       Loss D: 0.5742, loss G: 0.7267
Epoch [12/100] Batch 0/92                       Loss D: 0.5264

In [6]:
def save_checkpoint(state, file_name='checkpoint.pth.tar', directory="./models/Simple_GAN_FC/"):
    """Serialize ``state`` with ``torch.save`` into ``directory/file_name``.

    Args:
        state: object to save; in this notebook a dict with 'epoch',
            'state_dict' and 'optimizer' entries.
        file_name: name of the file written inside ``directory``.
        directory: target folder; created (including parents) if missing.
    """
    os.makedirs(directory, exist_ok=True)
    torch.save(state, os.path.join(directory, file_name))

In [7]:
# Saving params: each network is stored with its optimizer state and the
# number of completed epochs, so training can be resumed later.
for net, net_opt, ckpt_name in (
    (discriminator, opt_discriminator, 'D_c.pth.tar'),
    (generator, opt_generator, 'G_c.pth.tar'),
):
    save_checkpoint(
        {
            'epoch': epoch + 1,
            'state_dict': net.state_dict(),
            'optimizer': net_opt.state_dict(),
        },
        ckpt_name,
    )