In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter



In [2]:
class Discriminator(nn.Module):

    def __init__(self, img_dim):

        super().__init__()
        self.fc = nn.Linear(img_dim, 128)
        self.leaky_relu = nn.LeakyReLU(0.01)
        self.fc2 = nn.Linear(128, 1)
        self.sigmoid = nn.Sigmoid()

    
    def forward(self, x):
        x = self.fc(x)
        x = self.leaky_relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)

        return x

In [3]:
class Generator(nn.Module):

    def __init__(self, z_dim, img_dim):

        super().__init__()
        self.fc = nn.Linear(z_dim, 256)
        self.leaky_relu = nn.LeakyReLU(0.01)
        self.fc2 = nn.Linear(256, img_dim)
        self.tanh = nn.Tanh()

    def forward(self, x):
        x = self.fc(x)
        x = self.leaky_relu(x)
        x = self.fc2(x)
        x = self.tanh(x)
        return x

In [4]:
# hyperparams
device = "mps" if torch.backends.mps.is_available() else "cpu"
lr = 1e-4
z_dim = 64
img_dim = 28 * 28
batch_size = 32
epochs = 200


In [5]:
# Init Discriminator and Generator
dis= Discriminator(img_dim).to(device)
gen = Generator(z_dim, img_dim).to(device)


In [6]:
# transformation
transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, ), (0.5, ))]
)

In [7]:
# using MNIST dataset

dataset = datasets.MNIST(root="datasets/", 
                         transform=transforms, 
                         download=False)

loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

opt_dis = optim.Adam(dis.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
criterion = nn.BCELoss()
writer_fake = SummaryWriter(f"runs/GAN_MNIST/fake")
writer_real = SummaryWriter(f"runs/GAN_MNIST/real")
step = 0


In [8]:
# Training loop
fixed_noise = torch.randn((batch_size, z_dim)).to(device)

for epoch in range(epochs):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.view(-1, 784).to(device)
        batch_size = real.shape[0]

        # Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        noise = torch.randn(batch_size, z_dim).to(device)
        fake = gen(noise)
        dis_real = dis(real).view(-1)
        dis_fake = dis(fake).view(-1)

        lossD = -torch.sum((torch.log(dis_real) + 
                            torch.log(1 - dis_fake)))

        dis.zero_grad()
        lossD.backward(retain_graph=True)
        opt_dis.step()

        # Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))
        output = dis(fake).view(-1)
        lossG = torch.sum(torch.log(1 - output))

        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{epochs}] Batch {batch_idx}/{len(loader)} \
                      Loss D: {lossD:.4f}, loss G: {lossG:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )
                step += 1



  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Epoch [0/200] Batch 0/1875                       Loss D: 43.1238, loss G: -23.0622
Epoch [1/200] Batch 0/1875                       Loss D: 13.8071, loss G: -7.0354
Epoch [2/200] Batch 0/1875                       Loss D: 40.0646, loss G: -18.2839
Epoch [3/200] Batch 0/1875                       Loss D: 25.9287, loss G: -12.4513
Epoch [4/200] Batch 0/1875                       Loss D: 21.2249, loss G: -9.6153
Epoch [5/200] Batch 0/1875                       Loss D: 35.9114, loss G: -16.7880
Epoch [6/200] Batch 0/1875                       Loss D: 48.4230, loss G: -25.9472
Epoch [7/200] Batch 0/1875                       Loss D: 52.0968, loss G: -23.8745
Epoch [8/200] Batch 0/1875                       Loss D: 12.1634, loss G: -5.3786
Epoch [9/200] Batch 0/1875                       Loss D: 11.8042, loss G: -6.3946
Epoch [10/200] Batch 0/1875                       Loss D: 21.8777, loss G: -10.1736
Epoch [11/200] Batch 0/1875                       Loss D: 12.2278, loss G: -5.4694
Epoch [

In [9]:
# Save the model after training
torch.save(gen.state_dict(), "gan_after_training.pt")