In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

In [2]:
class Discriminator(nn.Module):
    """Single-hidden-layer MLP discriminator for flattened images.

    Maps each input vector to a value in (0, 1) via a final Sigmoid, suitable
    as the prediction argument of ``nn.BCELoss``.

    Args:
        in_features: Length of each flattened input vector (e.g. 28 * 28 for MNIST).
        hidden_dim: Width of the hidden layer. Defaults to 128 (the original size).
        negative_slope: Negative slope of the LeakyReLU. Defaults to 0.01.
    """

    def __init__(self, in_features, hidden_dim=128, negative_slope=0.01):
        super().__init__()
        # Layer order (and therefore state_dict keys disc.0.* / disc.2.*) is
        # unchanged from the original definition, so old checkpoints still load.
        self.disc = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.LeakyReLU(negative_slope),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),  # squash the score into (0, 1) for BCELoss
        )

    def forward(self, x):
        """Return real-probabilities of shape (..., 1) for x of shape (..., in_features)."""
        return self.disc(x)

In [3]:
class Generator(nn.Module):
    """Single-hidden-layer MLP generator: latent noise -> flattened image.

    Args:
        z_dim: Length of each latent noise vector.
        img_dim: Length of each flattened output image (e.g. 28 * 28 for MNIST).
        hidden_dim: Width of the hidden layer. Defaults to 256 (the original size).
        negative_slope: Negative slope of the LeakyReLU. Defaults to 0.01.
    """

    def __init__(self, z_dim, img_dim, hidden_dim=256, negative_slope=0.01):
        super().__init__()
        # Layer order (and therefore state_dict keys gen.0.* / gen.2.*) is
        # unchanged from the original definition, so old checkpoints still load.
        self.gen = nn.Sequential(
            nn.Linear(z_dim, hidden_dim),
            nn.LeakyReLU(negative_slope),
            nn.Linear(hidden_dim, img_dim),
            # Tanh bounds generated pixels to [-1, 1], the same range the real
            # images occupy after Normalize((0.5,), (0.5,)), so the discriminator
            # cannot separate real from fake by value range alone.
            nn.Tanh(),
        )

    def forward(self, x):
        """Return images of shape (..., img_dim) with values in [-1, 1] for x of shape (..., z_dim)."""
        return self.gen(x)

In [4]:
# --- Configuration -----------------------------------------------------------
# Pick the best available accelerator instead of hard-coding "mps", so the
# notebook also runs on CUDA machines and plain CPUs.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
seed = 0
# Seeds weight init, noise sampling and DataLoader shuffle order
# (device kernels may still be nondeterministic).
torch.manual_seed(seed)
lr = 3e-4
z_dim = 64
image_dim = 28 * 28 * 1  # 784
batch_size = 128
num_epochs = 50

# --- Models, data, optimizers, logging ----------------------------------------
disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)
# One fixed latent batch, so the logged fake-image grids are comparable across epochs.
fixed_noise = torch.randn((batch_size, z_dim)).to(device)
# Named `transform` (not `transforms`) so the torchvision.transforms module is
# not shadowed; shadowing it made re-running this cell fail on `.Compose`.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),  # [0, 1] -> [-1, 1], matching the generator's Tanh
    ]
)
dataset = datasets.MNIST(root="dataset/", transform=transform, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
criterion = nn.BCELoss()  # expects probabilities, i.e. the Discriminator's Sigmoid output
writer_fake = SummaryWriter("logs/fake")
writer_real = SummaryWriter("logs/real")
step = 0  # TensorBoard global step; advanced once per logged epoch

# --- Training loop -------------------------------------------------------------
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.view(-1, image_dim).to(device)
        # Size of *this* batch (the last one may be smaller). Kept separate from
        # the configured `batch_size` so that setting is not overwritten.
        cur_batch_size = real.shape[0]

        ### Train Discriminator: max log(D(real)) + log(1 - D(G(z)))
        noise = torch.randn(cur_batch_size, z_dim).to(device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        loss_D_real = criterion(disc_real, torch.ones_like(disc_real))
        # detach(): this backward pass must not reach the generator, which also
        # leaves fake's generator graph intact for the generator step below.
        disc_fake = disc(fake.detach()).view(-1)
        loss_D_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (loss_D_real + loss_D_fake) / 2
        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        ### Train Generator: min log(1 - D(G(z))) <--> max log(D(G(z)))
        # Non-saturating form: label the fakes as real (ones) and minimise BCE.
        # This backward also writes grads into disc's parameters; they are
        # cleared by disc.zero_grad() before the next discriminator update.
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] "
                f"Loss D: {lossD.item():.4f}, loss G: {lossG.item():.4f}"
            )
            with torch.no_grad():
                fake_imgs = gen(fixed_noise).reshape(-1, 1, 28, 28)
                real_imgs = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake_imgs, normalize=True)
                img_grid_real = torchvision.utils.make_grid(real_imgs, normalize=True)

                writer_fake.add_image(
                    "MNIST Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "MNIST Real Images", img_grid_real, global_step=step
                )

                step += 1

# Push buffered TensorBoard events to disk. The writers stay usable afterwards.
writer_fake.flush()
writer_real.flush()


Epoch [0/60]Loss D:  0.5859, loss G: 0.7205
Epoch [1/60]Loss D:  0.2490, loss G: 1.6715
Epoch [2/60]Loss D:  0.7073, loss G: 0.9577
Epoch [3/60]Loss D:  0.2333, loss G: 2.2526
Epoch [4/60]Loss D:  0.3962, loss G: 1.3407
Epoch [5/60]Loss D:  0.7488, loss G: 0.7931
Epoch [6/60]Loss D:  0.7390, loss G: 0.9084
Epoch [7/60]Loss D:  0.8278, loss G: 0.8540
Epoch [8/60]Loss D:  0.3911, loss G: 1.3301
Epoch [9/60]Loss D:  0.3076, loss G: 1.7175
Epoch [10/60]Loss D:  0.4224, loss G: 1.6057
Epoch [11/60]Loss D:  0.7313, loss G: 0.9219
Epoch [12/60]Loss D:  0.3830, loss G: 1.7069
Epoch [13/60]Loss D:  0.6382, loss G: 1.1794
Epoch [14/60]Loss D:  0.7345, loss G: 0.9153
Epoch [15/60]Loss D:  0.4249, loss G: 1.3539
Epoch [16/60]Loss D:  0.6909, loss G: 1.2497
Epoch [17/60]Loss D:  0.9484, loss G: 0.9708
Epoch [18/60]Loss D:  0.6979, loss G: 1.1134
Epoch [19/60]Loss D:  0.6652, loss G: 1.1266
Epoch [20/60]Loss D:  0.6797, loss G: 1.0051
Epoch [21/60]Loss D:  0.5257, loss G: 1.1354
Epoch [22/60]Loss D: