In [1]:
import torch
import torch.nn as nn
from torchvision.transforms import transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import torchvision
from datetime import datetime

now = datetime.now()

In [2]:
class Discriminator(nn.Module):
    def __init__(self,input_features):
        super().__init__()
        self.disc=nn.Sequential(
            nn.Linear(input_features,128),
            nn.LeakyReLU(0.1),
            nn.Linear(128,1),
            nn.Sigmoid()
        )
    
    def forward(self,x):
        return self.disc(x)

In [3]:
class Generator(nn.Module):
    def __init__(self,z_dim,input_features):
        super().__init__()
        self.gen=nn.Sequential(
            nn.Linear(z_dim,256),
            nn.LeakyReLU(0.1),
            nn.Linear(256,input_features),
            nn.Tanh() #pixel value between -1 and 1
        )
    def forward(self,x):
        return self.gen(x)

In [4]:
DEVICE="cuda" if torch.cuda.is_available() else "cpu"
lr=3e-5
z_dim=32
image_dim=28*28*1
batch_size=32
num_epochs=50

disc=Discriminator(image_dim).to(DEVICE)
gen=Generator(z_dim,image_dim).to(DEVICE)

fixed_noise=torch.randn((batch_size,z_dim)).to(DEVICE)
transforms= transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,),(0.3081,))]
)
dataset = datasets.MNIST(root="datasets/",transform=transforms,download=True)
loader=DataLoader(dataset, batch_size=batch_size,shuffle=True)
opt_disc=optim.Adam(disc.parameters(),lr=lr)
opt_gen=optim.Adam(gen.parameters(),lr=lr)
critereon=nn.BCELoss()
writer_fake=SummaryWriter(f"runs/GAN/fake/"+ now.strftime("%Y%m%d-%H%M%S") + "/")
writer_real=SummaryWriter(f"runs/GAN/real/"+ now.strftime("%Y%m%d-%H%M%S") + "/")

step=0
gen.train()
disc.train()
for epoch in range(num_epochs):
    for batch_index, (real,_) in enumerate(loader):
        real=real.view(-1,784).to(DEVICE)
        batch_size=real.shape[0]

        #Discriminator loss: max(log(D(real)) + log(1 - D(G(z))) )
        noise=torch.randn(batch_size,z_dim).to(DEVICE)
        fake_img=gen(noise)
        disc_real=disc(real).view(-1)
        lossD_real=critereon(disc_real,torch.ones_like(disc_real))
        disc_fake=disc(fake_img.detach()).view(-1)
        lossD_fake=critereon(disc_fake,torch.zeros_like(disc_fake))
        lossD = (lossD_fake+lossD_real)/2

        disc.zero_grad()
        lossD.backward(retain_graph=True)
        opt_disc.step()

        #Train Generator
        output=disc(fake_img).view(-1)
        lossG=critereon(output,torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if(batch_index==0):
            print(f'[{epoch}/{num_epochs}--Loss(D):{lossD:.4f}--Loss(G):{lossG:.4f}')

        #start training
            with torch.no_grad():
                fake = gen(fixed_noise).reshape(-1,1,28,28)
                data=real.reshape(-1,1,28,28)
                img_grid_fake=torchvision.utils.make_grid(fake,normalize=True)
                img_grid_real=torchvision.utils.make_grid(data,normalize=True)

                writer_fake.add_image(
                    "Fake img1",img_grid_fake,global_step=step
                )
                writer_real.add_image(
                    "Real img1",img_grid_real,global_step=step
                )

                step+=1



[0/50--Loss(D):0.6350--Loss(G):0.6812
[1/50--Loss(D):0.1324--Loss(G):2.3503
[2/50--Loss(D):0.0570--Loss(G):3.1967
[3/50--Loss(D):0.0498--Loss(G):3.8571
[4/50--Loss(D):0.0484--Loss(G):3.6545
[5/50--Loss(D):0.0756--Loss(G):4.0169
[6/50--Loss(D):0.0216--Loss(G):5.1679
[7/50--Loss(D):0.0189--Loss(G):4.8751
[8/50--Loss(D):0.0484--Loss(G):4.3366
[9/50--Loss(D):0.0130--Loss(G):4.7509
[10/50--Loss(D):0.0093--Loss(G):4.9355
[11/50--Loss(D):0.0058--Loss(G):5.5140
[12/50--Loss(D):0.0254--Loss(G):4.2647
[13/50--Loss(D):0.0338--Loss(G):4.8530
[14/50--Loss(D):0.0051--Loss(G):5.9261
[15/50--Loss(D):0.0252--Loss(G):4.8001
[16/50--Loss(D):0.0490--Loss(G):6.5986
[17/50--Loss(D):0.0130--Loss(G):5.5162
[18/50--Loss(D):0.0135--Loss(G):5.7511
[19/50--Loss(D):0.0217--Loss(G):6.0421
[20/50--Loss(D):0.0094--Loss(G):5.7583
[21/50--Loss(D):0.0228--Loss(G):5.4186
[22/50--Loss(D):0.0157--Loss(G):5.3369
[23/50--Loss(D):0.0054--Loss(G):6.5307
[24/50--Loss(D):0.0059--Loss(G):6.3735
[25/50--Loss(D):0.0083--Loss(G):6.0