In [38]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter 
from PIL import Image
import numpy as np
import torchvision.utils as vutils
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML


class Discriminator(nn.Module):
    """Single-hidden-layer MLP discriminator for flattened images.

    Args:
        in_features: length of each flattened input vector.
        hidden_dim: width of the hidden layer (default 128, the original size).
        negative_slope: LeakyReLU slope for negative inputs (default 0.01).

    ``forward`` maps a ``(N, in_features)`` tensor to ``(N, 1)`` scores passed
    through a final Sigmoid, i.e. probabilities suitable for ``nn.BCELoss``.
    """

    def __init__(self, in_features, hidden_dim=128, negative_slope=0.01):
        super().__init__()
        # Keep the attribute name `disc` so state_dict keys remain "disc.*".
        self.disc = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.LeakyReLU(negative_slope),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return D(x): per-sample probability that ``x`` is a real image."""
        return self.disc(x)


class Generator(nn.Module):
    """Single-hidden-layer MLP generator mapping latent noise to flat images.

    Args:
        z_dim: length of each latent noise vector.
        img_dim: length of each generated (flattened) image.
        hidden_dim: width of the hidden layer (default 256, the original size).
        negative_slope: LeakyReLU slope for negative inputs (default 0.01).

    ``forward`` maps a ``(N, z_dim)`` tensor to ``(N, img_dim)``; the final Tanh
    bounds outputs to [-1, 1].
    """

    def __init__(self, z_dim, img_dim, hidden_dim=256, negative_slope=0.01):
        super().__init__()
        # Keep the attribute name `gen` so state_dict keys remain "gen.*".
        self.gen = nn.Sequential(
            nn.Linear(z_dim, hidden_dim),
            nn.LeakyReLU(negative_slope),
            nn.Linear(hidden_dim, img_dim),
            nn.Tanh(),
        )

    def forward(self, x):
        """Return G(x): flat generated images for latent batch ``x``."""
        return self.gen(x)



# --- Configuration -----------------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4
z_dim = 64
image_dim = 28 * 28 * 1  # 784: flattened single-channel 28x28 MNIST image
batch_size = 32
num_epochs = 50
seed = 42

# Seed before building the models so weight init, fixed_noise and the
# DataLoader shuffle order are reproducible across Restart & Run All.
torch.manual_seed(seed)

disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)
# Fixed latent batch reused during training to visualise the generator's progress.
fixed_noise = torch.randn((batch_size, z_dim)).to(device)

# Named `transform` (not `transforms`) so the imported torchvision.transforms
# module is not shadowed; shadowing it made this cell fail when re-run.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        # Map [0, 1] pixels to [-1, 1], the same range as the generator's Tanh.
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

dataset = datasets.MNIST(root="dataset/", transform=transform, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
criterion = nn.BCELoss()
writer_fake = SummaryWriter("logs/fake")
writer_real = SummaryWriter("logs/real")
step = 0  # global iteration counter, advanced by the training loop


In [None]:
# Loss snapshots (taken on the first batch of every epoch) and generator
# samples on `fixed_noise` (every 500 steps and at the very end) for later plots.
G_losses = []
img_list = []
D_losses = []
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        # Flatten each image to a length-`image_dim` vector for the MLPs.
        real = real.view(-1, image_dim).to(device)
        # Local name: do not overwrite the global `batch_size` config.
        cur_batch_size = real.shape[0]

        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        noise = torch.randn(cur_batch_size, z_dim, device=device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))  # -log(D(x))
        # detach(): the D update must not backprop into G. This also removes the
        # need for retain_graph=True, because G's graph is still intact for the
        # generator step below.
        disc_fake = disc(fake.detach()).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))  # -log(1 - D(G(z)))
        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z)))
        # Non-saturating loss: label the fakes as real. D is re-run (with its
        # freshly updated weights) so gradients flow through `fake` into G.
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if batch_idx == 0:
            d_loss, g_loss = lossD.item(), lossG.item()
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} "
                f"Loss D: {d_loss:.4f}, loss G: {g_loss:.4f}"
            )
            G_losses.append(g_loss)
            D_losses.append(d_loss)

        is_last_batch = epoch == num_epochs - 1 and batch_idx == len(loader) - 1
        if step % 500 == 0 or is_last_batch:
            with torch.no_grad():
                # Reshape flat outputs back to (N, 1, 28, 28). make_grid treats a
                # 2-D tensor as ONE image, which rendered a single 32x784 strip
                # instead of a grid of digits.
                fake_fixed = gen(fixed_noise).view(-1, 1, 28, 28).cpu()
            img_list.append(vutils.make_grid(fake_fixed, padding=2, normalize=True))
        step += 1

Epoch [0/50] Batch 0/1875                       Loss D: 0.6759, loss G: 0.7918
Epoch [1/50] Batch 0/1875                       Loss D: 0.3874, loss G: 1.3801
Epoch [2/50] Batch 0/1875                       Loss D: 0.5009, loss G: 1.0968
Epoch [3/50] Batch 0/1875                       Loss D: 0.8164, loss G: 0.7524
Epoch [4/50] Batch 0/1875                       Loss D: 0.8706, loss G: 0.9132


In [None]:
# Loss curves. The training loop records one (G, D) loss pair per epoch (on the
# first batch), so the x-axis is epochs, not iterations.
fig, ax = plt.subplots(figsize=(10, 5))
ax.set_title("Generator and Discriminator Loss During Training")
ax.plot(G_losses, label="G")
ax.plot(D_losses, label="D")
ax.set_xlabel("epoch")
ax.set_ylabel("Loss")
ax.legend()
plt.show()

In [None]:
# Animate the fixed-noise sample grids collected during training.
fig = plt.figure(figsize=(8, 8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(grid, (1, 2, 0)), animated=True)] for grid in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

# Render to HTML first, then close the figure so the inline backend does not
# also display a static copy of the last frame under the animation.
ani_html = ani.to_jshtml()
plt.close(fig)
HTML(ani_html)

In [None]:
# Side-by-side comparison: one batch of real MNIST digits vs. the latest
# generator samples on `fixed_noise`.
real_batch = next(iter(loader))

fig, (ax_real, ax_fake) = plt.subplots(1, 2, figsize=(15, 15))

# Real images. The grid is built directly on CPU; moving the batch to `device`
# only to copy it back for plotting is wasted work.
real_grid = vutils.make_grid(real_batch[0][:64], padding=5, normalize=True)
ax_real.axis("off")
ax_real.set_title("Real Images")
ax_real.imshow(real_grid.permute(1, 2, 0))

# Fake images from the last snapshot taken during training.
ax_fake.axis("off")
ax_fake.set_title("Fake Images")
ax_fake.imshow(img_list[-1].permute(1, 2, 0))
plt.show()