In [1]:
import torch
from torch import nn, optim
import torchvision
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

In [None]:
class Discriminator(nn.Module):
    """Fully connected network that maps a flattened sample to one sigmoid score.

    Architecture: ``Linear -> LeakyReLU(0.1) -> Linear -> Sigmoid``.

    Args:
        input_dim: Number of features in each flattened input row.
        dense_units: Width of the single hidden layer.
    """

    def __init__(self, input_dim, dense_units):
        super().__init__()
        # Layers are created in the same order as they are applied.
        hidden = nn.Linear(input_dim, dense_units)
        activation = nn.LeakyReLU(negative_slope=0.1)
        head = nn.Linear(dense_units, 1)
        squash = nn.Sigmoid()
        # Stored under ``disc`` so parameter names stay ``disc.0.*`` / ``disc.2.*``.
        self.disc = nn.Sequential(hidden, activation, head, squash)

    def forward(self, x):
        """Return a ``(rows, 1)`` tensor of sigmoid outputs for ``x``."""
        scores = self.disc(x)
        return scores

In [None]:
class Generator(nn.Module):
    """Fully connected network that maps latent vectors to flattened images.

    Architecture: ``Linear -> LeakyReLU(0.1) -> Linear -> Tanh``; the final
    Tanh bounds every output value to [-1, 1].

    Args:
        latent_dim: Number of features in each latent input row.
        img_dim: Number of values in each generated (flattened) sample.
        dense_units: Width of the single hidden layer.
    """

    def __init__(self, latent_dim, img_dim, dense_units):
        # nn.Module must be initialised before any submodule is assigned,
        # otherwise setting ``self.gen`` raises AttributeError.
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(latent_dim, dense_units),
            # Negative slope 0.1, matching the discriminator. The previous
            # ``LeakyReLU(0, 1)`` meant slope 0 with inplace=1 (a plain ReLU).
            nn.LeakyReLU(0.1),
            nn.Linear(dense_units, img_dim),
            nn.Tanh(),
        )

    def forward(self, x):
        """Return a ``(rows, img_dim)`` tensor generated from latent rows ``x``."""
        return self.gen(x)

In [None]:
# Run configuration and hyperparameters read by the cells below.
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # prefer the GPU when present
lr = 1e-4
latent_dim = 8 * 8  # length of each latent noise vector
# Size of one flattened 28x28 single-channel image. This must be a plain int:
# a trailing comma here would turn it into the tuple (784,), which is not a
# valid layer size or view dimension.
image_dim = 28 * 28 * 1
dense_units = 256  # hidden-layer width for both networks
batch_size = 32
epochs = 50

In [None]:
# Instantiate both networks on the selected device.
disc = Discriminator(image_dim, dense_units).to(device)
gen = Generator(latent_dim, image_dim, dense_units).to(device)

# Latent batch sampled once, so the same vectors can be reused whenever
# generator samples need to be comparable over time.
fixed_noise = torch.randn((batch_size, latent_dim)).to(device)

# Bound to its own name so the imported ``transforms`` module is not
# overwritten; shadowing it makes this cell fail when it is run a second time.
# ToTensor scales pixels to [0, 1]; Normalize(0.5, 0.5) then maps them to [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

# MNIST is downloaded into ``dataset/`` on first use.
dataset = datasets.MNIST(root='dataset/', transform=transform, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# One Adam optimizer per network, sharing the same learning rate.
disc_optimizer = optim.Adam(disc.parameters(), lr=lr)
gen_optimizer = optim.Adam(gen.parameters(), lr=lr)

# Binary cross-entropy on the discriminator's sigmoid outputs.
criterion = nn.BCELoss()

# Separate TensorBoard writers for generated and real image grids.
writer_fake = SummaryWriter('runs/GAN/fake')
writer_real = SummaryWriter('runs/GAN/real')
step = 0  # global_step counter for the TensorBoard image logs

In [None]:
# Adversarial training: each mini-batch performs one discriminator update
# followed by one generator update.
for epoch in range(epochs):
    for batch_idx, (x, _) in enumerate(loader):
        # Flatten every image to a single row; the labels are not used.
        x = x.to(device)
        x = x.view(x.shape[0], -1)
        # Size the fake batch from the real one so a shorter final batch
        # still pairs equal numbers of real and generated samples.
        cur_batch_size = x.shape[0]

        # Disc training: max log(D(real)) + log(1 - D(G(z)))
        # BCELoss = -w_n [y_n log(x_n) + (1-y_n) log(1 - x_n)]
        noise = torch.randn(cur_batch_size, latent_dim, device=device)
        fake = gen(noise)
        disc_real = disc(x).view(-1)

        # Targets of 1 keep only the y_n log(x_n) term.
        loss_real = criterion(disc_real, torch.ones_like(disc_real))

        # detach() cuts the graph at ``fake`` so this loss never reaches the
        # generator's parameters.
        disc_fake = disc(fake.detach()).view(-1)
        # Targets of 0 keep only the (1 - y_n) log(1 - x_n) term.
        loss_fake = criterion(disc_fake, torch.zeros_like(disc_fake))

        lossD = (loss_fake + loss_real) / 2
        disc.zero_grad()
        # retain_graph is not needed: because of detach() above, this backward
        # pass only frees the discriminator's graph. The generator graph that
        # produced ``fake`` stays intact for the generator step below.
        lossD.backward()
        disc_optimizer.step()

        # Train Generator: min log(1-D(G(z))) <-> max log(D(G(z)))
        # A fresh forward pass through the just-updated discriminator.
        output_disc = disc(fake).view(-1)
        loss_gen = criterion(output_disc, torch.ones_like(output_disc))
        gen.zero_grad()
        loss_gen.backward()
        gen_optimizer.step()

        # Log once per epoch, on its first batch.
        if batch_idx == 0:
            print(f' epoch {epoch}/{epochs} gen_loss: {loss_gen} disc_loss: {lossD}')

            with torch.no_grad():
                # Render the fixed latent batch so the logged samples are
                # comparable from one epoch to the next.
                fake_images = gen(fixed_noise).reshape(-1, 1, 28, 28)
                real_images = x.reshape(-1, 1, 28, 28)
                fake_grid = torchvision.utils.make_grid(fake_images, normalize=True)
                real_grid = torchvision.utils.make_grid(real_images, normalize=True)

            writer_fake.add_image('fake_', fake_grid, global_step=step)
            writer_real.add_image('real_', real_grid, global_step=step)

            step += 1