In [1]:
"""GAN's Generator-> create fake image
   GAN's Discriminator->  predict fake or real image                    """
#purpose maximize discrimator minimize generator



import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter



In [2]:
class Discriminator(nn.Module):
    def __init__(self,img_dim):
        super().__init__()
        self.disc=nn.Sequential(
            nn.Linear(img_dim,128), #img_dim->28x28 mnist images
            nn.LeakyReLU(0.2), #risk of dead neurans
            nn.Linear(128,1),
            nn.Sigmoid() #[0,1]
        )

    def forward(self, x):
        return self.disc(x)

class Generator(nn.Module):
    def __init__(self,z_dim,img_dim): #z_dim->noise dimension
        super().__init__()
        self.gen=nn.Sequential(
            nn.Linear(z_dim,256),#Convert noise vector to 256 neurons
            nn.LeakyReLU(0.2),
            nn.Linear(256,img_dim),#translates to an image that is 784 pixels by 256 inches long.
            nn.Tanh(),#[-1,1] normalize image
         )

    def forward(self, x):
        return self.gen(x)


#Hyperparams
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
learning_rate=3e-4
z_dim=64
image_dim=784#28*28*1
batch_size=32
num_epochs=35


#normalize MNIST dset mean=0 std=1
disc=Discriminator(image_dim).to(device)
gen=Generator(z_dim,image_dim).to(device)
fixed_noise=torch.randn(batch_size,z_dim).to(device)
transforms=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,))])#

#dataLoader->batch
dataset=datasets.MNIST(root='./data_mnist',train=True,transform=transforms,download=True)
loader=DataLoader(dataset,batch_size=batch_size,shuffle=True)
opt_disc=optim.Adam(disc.parameters(),lr=learning_rate)
opt_gen=optim.Adam(gen.parameters(),lr=learning_rate)
criterion=nn.BCELoss()
writer_fake=SummaryWriter(f"runs/GAN_MNIST/fake")
writer_real=SummaryWriter(f"runs/GAN_MNIST/real")
step=0

for epoch in range(num_epochs):
    for batch_idx,(real,_) in enumerate(loader):
        real=real.view(-1,784).to(device) #The image is made into a 28x28 → 784 flat vector.
        batch_size=real.shape[0]

        """Train Discriminator: max log(D(real))+log1-D(G(z))"""
        noise=torch.randn(batch_size,z_dim).to(device)
        fake=gen(noise)#Generate fake image from random noise.

        #Label real images as 1 and fake images as 0.
        disc_real=disc(real).view(-1)
        lossD_real=criterion(disc_real,torch.ones_like(disc_real))

        disc_fake=disc(fake).view(-1)
        lossD_fake=criterion(disc_fake,torch.zeros_like(disc_fake))

        #avg the two losses
        lossD=(lossD_real+lossD_fake)/2
        disc.zero_grad()
        #backprop
        lossD.backward(retain_graph=True)
        opt_disc.step()


        """Train Generator: min log(1-D(G(z)))->max log(D(G(z)))"""
        #The Generator wants to fool the Discriminator → The fake images should look like “1” (real).
        output=disc(fake).view(-1)
        lossG=criterion(output,torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} \
                      Loss D: {lossD:.4f}, loss G: {lossG:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )
                step += 1


100.0%
100.0%
100.0%
100.0%


Epoch [0/35] Batch 0/1875                       Loss D: 0.6689, loss G: 0.6960
Epoch [1/35] Batch 0/1875                       Loss D: 0.2933, loss G: 1.8737
Epoch [2/35] Batch 0/1875                       Loss D: 0.1392, loss G: 2.6075
Epoch [3/35] Batch 0/1875                       Loss D: 0.0743, loss G: 3.0826
Epoch [4/35] Batch 0/1875                       Loss D: 0.1858, loss G: 3.9209
Epoch [5/35] Batch 0/1875                       Loss D: 0.0401, loss G: 3.8737
Epoch [6/35] Batch 0/1875                       Loss D: 0.1053, loss G: 4.2930
Epoch [7/35] Batch 0/1875                       Loss D: 0.0323, loss G: 4.6096
Epoch [8/35] Batch 0/1875                       Loss D: 0.1195, loss G: 4.8633
Epoch [9/35] Batch 0/1875                       Loss D: 0.0612, loss G: 4.6216
Epoch [10/35] Batch 0/1875                       Loss D: 0.0838, loss G: 5.1876
Epoch [11/35] Batch 0/1875                       Loss D: 0.0376, loss G: 4.5221
Epoch [12/35] Batch 0/1875                       L

In [4]:
"""%load_ext tensorboard
%tensorboard --logdir runs"""

'%load_ext tensorboard\n%tensorboard --logdir runs'