In [115]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

In [116]:
class Discriminator(nn.Module):
    """Fully connected binary classifier over flattened images.

    The input of size ``img_dim`` goes through a 128-unit hidden layer with
    LeakyReLU(0.1) and then a single-unit sigmoid output, so every sample
    yields one value in (0, 1).

    Args:
        img_dim: number of input features per sample.
    """

    def __init__(self, img_dim):
        super().__init__()
        # Layer order (and the attribute name `disc`) fix the state_dict keys.
        layers = [
            nn.Linear(in_features=img_dim, out_features=128),
            nn.LeakyReLU(negative_slope=0.1),
            nn.Linear(in_features=128, out_features=1),
            nn.Sigmoid(),
        ]
        self.disc = nn.Sequential(*layers)

    def forward(self, x):
        """Return the sigmoid score for each row of ``x``."""
        scores = self.disc(x)
        return scores

In [117]:
class Generator(nn.Module):
    """Fully connected network that maps latent vectors to flattened images.

    A ``z_dim``-sized input goes through a 256-unit hidden layer with
    LeakyReLU(0.1) and is projected to ``img_dim`` outputs squashed by Tanh,
    so every output value lies in (-1, 1).

    Args:
        z_dim: number of input (latent) features per sample.
        img_dim: number of output features per sample.
    """

    def __init__(self, z_dim, img_dim):
        super().__init__()
        # Layer order (and the attribute name `gen`) fix the state_dict keys.
        layers = [
            nn.Linear(in_features=z_dim, out_features=256),
            nn.LeakyReLU(negative_slope=0.1),
            nn.Linear(in_features=256, out_features=img_dim),
            nn.Tanh(),
        ]
        self.gen = nn.Sequential(*layers)

    def forward(self, x):
        """Return the generated output for each row of ``x``."""
        generated = self.gen(x)
        return generated

In [118]:
# Use the GPU only when CUDA is actually usable, otherwise fall back to the CPU.
# is_available must be *called*: the bare function object is always truthy, which
# would select "cuda" even on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [119]:
# Learning rate passed to both Adam optimizers.
lr = 0.0003
lr

0.0003

In [120]:
z_dim = 64  # length of each latent noise vector fed to the generator

My understanding of the latent space: the generator is fed a random noise vector with a certain number of dimensions (here 64). During training the
network updates its own weights so that it can produce images from points in this latent space. Once training has gone well, you can move around the latent space to produce different images.

In [121]:
img_dim = 28 * 28  # a 28x28 image flattened to 784 values

In [122]:
# Mini-batch size and number of passes over the data loader.
batch_size, epochs = 32, 50

In [123]:
# Build both networks, then move their parameters to the selected device
# (Module.to returns the module itself, so rebinding keeps the same object).
disc = Discriminator(img_dim=img_dim)
disc = disc.to(device)
gen = Generator(z_dim=z_dim, img_dim=img_dim)
gen = gen.to(device)

In [124]:
# Noise batch sampled once and reused for every logged image grid, so the
# generator's progress is visualised on the same latent points each time.
fixed_noise = torch.randn(batch_size, z_dim).to(device)

In [125]:
# Preprocessing: convert each image to a tensor, then apply Normalize with
# mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=0.5, std=0.5)]
)

In [126]:
# MNIST stored under data/ (downloaded if missing), with the preprocessing
# transform applied to every image.
dataset = datasets.MNIST(
    root="data/",
    transform=transform,
    download=True,
)

In [127]:
# Shuffled mini-batch iterator over the dataset. The batch size comes from the
# shared batch_size setting rather than a second hard-coded 32, so the loader
# stays in step with the rest of the notebook when that setting is changed.
loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

In [128]:
# One Adam optimizer per network, both using the shared learning rate.
opt_disc = optim.Adam(params=disc.parameters(), lr=lr)
opt_gen = optim.Adam(params=gen.parameters(), lr=lr)

In [129]:
# Binary cross-entropy on probabilities, averaged over the batch (the default).
criterion = nn.BCELoss(reduction="mean")

In [130]:
# Separate TensorBoard writers for generated and real image grids.
writer_fake = SummaryWriter("runs/GAN_MNIST/fake")
writer_real = SummaryWriter("runs/GAN_MNIST/real")
step = 0  # global_step counter for the logged image grids

D(real): the discriminator's prediction on a real image; it wants this to be 1.
G(z): the image generated from noise z by the generator.
D(G(z)): the discriminator's prediction on a fake image; it wants this to be 0.


BCE loss: $L = -\left[\, y \cdot \log(p) + (1 - y) \cdot \log(1 - p) \,\right]$

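With `torch.ones_like` as the target (y = 1), the (1 − y) term becomes zero and you are left with −log(p).
Here y is the target tensor built with `ones_like` or `zeros_like`.
With `torch.zeros_like` as the target (y = 0), the y term becomes zero and you get −log(1 − p).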

`disc_real` is D(real).
`disc_fake` is D(G(z)).

In [132]:
# Adversarial training: for every mini-batch do one discriminator update and
# one generator update; on the first batch of each epoch print the losses and
# log image grids to TensorBoard.
for epoch in range(epochs):  # own loop variable, so the `epochs` setting is not overwritten
    for batch_idx, (real, _) in enumerate(loader):
        # Flatten every image to a vector of img_dim values.
        real = real.view(-1, img_dim).to(device)
        cur_batch_size = real.shape[0]  # local name: leaves the global batch_size untouched

        ## Train discriminator: max log(D(real)) + log(1 - D(G(z)))
        noise = torch.randn(cur_batch_size, z_dim).to(device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)  # flatten the (N, 1) scores to (N,)
        # Target of ones leaves only the -log(D(real)) term of the BCE loss.
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        # detach() cuts the graph at `fake`: the discriminator step needs no
        # gradients for the generator, so backward() stops at the discriminator
        # and retain_graph=True is unnecessary. The generator's graph behind
        # `fake` stays intact for the generator step below.
        disc_fake = disc(fake.detach()).view(-1)
        # Target of zeros leaves only the -log(1 - D(G(z))) term.
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_fake + lossD_real) / 2  # average of the two halves
        disc.zero_grad()
        lossD.backward()
        opt_disc.step()  # update the discriminator's weights

        ## Train generator: min log(1 - D(G(z))) suffers from vanishing
        ## gradients, so maximise log(D(G(z))) instead (target = ones).
        output = disc(fake).view(-1)  # re-score with the freshly updated discriminator
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()  # gradients flow through the discriminator into the generator
        opt_gen.step()  # update the generator's weights

        if batch_idx == 0:
            print(f"{epoch+1}")
            print(f"Loss Generator:{lossG}\n Loss Discriminator:{lossD}")

            with torch.no_grad():
                # Images generated from the fixed noise, plus the current real
                # batch, reshaped back to (N, 1, 28, 28) for make_grid.
                fake_images = gen(fixed_noise).reshape(-1, 1, 28, 28)
                real_images = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake_images, normalize=True)
                img_grid_real = torchvision.utils.make_grid(real_images, normalize=True)
                writer_fake.add_image(
                    "MNIST Fake image", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "MNIST Real image", img_grid_real, global_step=step
                )
                writer_real.flush()
                writer_fake.flush()
                step += 1

1
Loss Generator:1.181633710861206
 Loss Discriminator:0.5205905437469482
2
Loss Generator:0.9782984256744385
 Loss Discriminator:0.5327741503715515
3
Loss Generator:1.2214889526367188
 Loss Discriminator:0.5596961379051208
4
Loss Generator:1.2432044744491577
 Loss Discriminator:0.5784693360328674
5
Loss Generator:1.1513179540634155
 Loss Discriminator:0.5754204392433167
6
Loss Generator:1.1876444816589355
 Loss Discriminator:0.6705034971237183
7
Loss Generator:1.2364444732666016
 Loss Discriminator:0.7147685289382935
8
Loss Generator:0.9844697117805481
 Loss Discriminator:0.5824205875396729
9
Loss Generator:1.0112885236740112
 Loss Discriminator:0.7141392827033997
10
Loss Generator:1.198325514793396
 Loss Discriminator:0.5414620637893677
11
Loss Generator:0.8715474009513855
 Loss Discriminator:0.6740407943725586
12
Loss Generator:0.9233049750328064
 Loss Discriminator:0.6600501537322998
13
Loss Generator:1.068987488746643
 Loss Discriminator:0.581080436706543
14
Loss Generator:0.93708