In [22]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transform
from torch.utils.tensorboard import SummaryWriter

In [23]:
class Discriminator(nn.Module):
    """Fully connected discriminator that scores flattened images.

    The network is Linear -> LeakyReLU -> Linear -> Sigmoid and emits one
    value in [0, 1] per input row.

    Args:
        img_dim: Number of features in each flattened input image.
        hidden_dim: Width of the single hidden layer. Defaults to 128.
        negative_slope: Slope LeakyReLU applies to negative inputs.
            Defaults to 0.1.
    """

    def __init__(self, img_dim, hidden_dim=128, negative_slope=0.1):
        super().__init__()
        # The attribute name `disc` and the layer order are kept as they were,
        # so state_dict keys (disc.0.*, disc.2.*) stay the same.
        self.disc = nn.Sequential(
            nn.Linear(img_dim, hidden_dim),
            nn.LeakyReLU(negative_slope),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),  # squash the raw score into [0, 1]
        )

    def forward(self, x):
        """Score `x` of shape (..., img_dim); returns a tensor of shape (..., 1)."""
        return self.disc(x)

In [24]:
class Generator(nn.Module):
    """Fully connected generator mapping latent vectors to flattened images.

    The network is Linear -> LeakyReLU -> Linear -> Tanh, so every output
    feature lies in [-1, 1].

    Args:
        z_dim: Size of the latent (noise) vector fed to the network.
        img_dim: Number of features in each generated flattened image.
        hidden_dim: Width of the single hidden layer. Defaults to 256.
        negative_slope: Slope LeakyReLU applies to negative inputs.
            Defaults to 0.1.
    """

    def __init__(self, z_dim, img_dim, hidden_dim=256, negative_slope=0.1):
        super().__init__()
        # The attribute name `gen` and the layer order are kept as they were,
        # so state_dict keys (gen.0.*, gen.2.*) stay the same.
        self.gen = nn.Sequential(
            nn.Linear(z_dim, hidden_dim),
            nn.LeakyReLU(negative_slope),
            nn.Linear(hidden_dim, img_dim),
            nn.Tanh(),  # bound outputs to [-1, 1]
        )

    def forward(self, x):
        """Map noise `x` of shape (..., z_dim) to a tensor of shape (..., img_dim)."""
        return self.gen(x)

In [25]:
# Preprocessing applied to every MNIST sample: convert the image to a float
# tensor, then normalize with mean 0.5 / std 0.5 (i.e. (x - 0.5) / 0.5).
#
# This assignment rebinds `transform`, which the import cell bound to the
# `torchvision.transforms` module. The pipeline is therefore built through the
# full `torchvision.transforms` path, so the cell can be re-run without
# raising AttributeError on the already-rebound name.
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    # Normalize documents mean/std as per-channel sequences; MNIST has one channel.
    torchvision.transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

In [26]:
# MNIST training split, stored under data/ (downloaded when requested via
# download=True) with the preprocessing pipeline applied to each image.
dataset = datasets.MNIST(
    root='data/',
    train=True,  # the library default, spelled out
    transform=transform,
    download=True,
)

In [27]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [28]:
# Learning rate handed to both Adam optimizers (generator and discriminator).
lr = 3e-4
lr

0.0003

In [29]:
# z_dim: length of the noise vector fed to the generator.
# img_dim: number of features in one flattened 28x28 image.
z_dim, img_dim = 100, 28 * 28

In [30]:
# Build the generator, then move its parameters to the selected device.
gen = Generator(z_dim=z_dim, img_dim=img_dim)
gen = gen.to(device)

In [31]:
# Build the discriminator, then move its parameters to the selected device.
disc = Discriminator(img_dim=img_dim)
disc = disc.to(device)

In [32]:
# Mini-batch size for the DataLoader; also the number of fixed noise vectors
# used for the generated-image previews.
batch_size = 32

In [33]:
# Noise batch drawn once and reused for every preview, so that logged
# generator samples come from the same latent inputs each time.
fixed_noise = torch.randn((batch_size, z_dim))
fixed_noise = fixed_noise.to(device)

In [34]:
# Iterate over the dataset in mini-batches; shuffle=True asks the loader to
# reshuffle the samples at every epoch.
loader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [35]:
# One Adam optimizer per network, both using the same learning rate.
opt_disc = optim.Adam(params=disc.parameters(), lr=lr)
opt_gen = optim.Adam(params=gen.parameters(), lr=lr)

In [36]:
# Binary cross-entropy on probabilities (the discriminator ends in a Sigmoid),
# averaged over the batch; "mean" is the library default, spelled out.
criterion = nn.BCELoss(reduction="mean")

In [37]:
# Separate TensorBoard writers (and log directories) for generated and real
# image grids.
writer_fake = SummaryWriter(log_dir="runs/RGAN_MNIST/fake")
writer_real = SummaryWriter(log_dir="runs/RGAN_MNIST/real")
# Global step for the image logs; the training loop increments it each time
# a pair of grids is written.
step = 0

In [38]:
# Alternating GAN training. For every mini-batch the discriminator is updated
# first (real images -> label 1, generated images -> label 0), then the
# generator is updated so that the discriminator labels its samples as 1.
num_epochs = 200
for epochs in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        # Flatten every image to a vector of img_dim features.
        real = real.view(-1, img_dim).to(device)
        # Kept in a local name so the notebook-level `batch_size` setting is
        # not overwritten by the loop.
        cur_batch_size = real.shape[0]

        # ---- Discriminator step ----
        noise = torch.randn(cur_batch_size, z_dim).to(device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        # detach() stops this backward pass at `fake`: the generator's graph is
        # left intact for its own update below, so retain_graph is not needed
        # and no gradients are computed for the generator here.
        disc_fake = disc(fake.detach()).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        # Average of the two halves of the discriminator objective.
        loss_D = (lossD_fake + lossD_real) / 2
        disc.zero_grad()
        loss_D.backward()
        opt_disc.step()

        # ---- Generator step ----
        # Re-score the same fakes with the freshly updated discriminator; the
        # generator is rewarded when they are classified as real (label 1).
        output = disc(fake).view(-1)
        loss_G = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        loss_G.backward()
        opt_gen.step()

        # Once per epoch (on its first batch): print losses and log image grids.
        if batch_idx == 0:
            print(f"Epoch:{epochs}")
            print(f"Loss Generator:{loss_G}\nLoss Discriminator:{loss_D}")

            with torch.no_grad():
                # Restore the (N, 1, 28, 28) image layout before building grids.
                fake_imgs = gen(fixed_noise).reshape(-1, 1, 28, 28)
                real_imgs = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake_imgs, normalize=True)
                # The reshaped batch must be used here: passing the flattened
                # (N, 784) tensor makes make_grid treat it as one wide image.
                img_grid_real = torchvision.utils.make_grid(real_imgs, normalize=True)
                writer_fake.add_image(
                    "MNIST_Generated_Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "MNIST_Real_Images", img_grid_real, global_step=step
                )
                step += 1


Epoch:0
Loss Generator:0.7285459041595459
Loss Discriminator:0.6496785879135132
Epoch:1
Loss Generator:1.3258187770843506
Loss Discriminator:0.42927712202072144
Epoch:2
Loss Generator:0.8693891763687134
Loss Discriminator:0.6125895380973816
Epoch:3
Loss Generator:0.7162713408470154
Loss Discriminator:0.831730306148529
Epoch:4
Loss Generator:1.039095163345337
Loss Discriminator:0.4831978678703308
Epoch:5
Loss Generator:1.163738489151001
Loss Discriminator:0.4721294641494751
Epoch:6
Loss Generator:0.5646107792854309
Loss Discriminator:0.9028291702270508
Epoch:7
Loss Generator:1.0207855701446533
Loss Discriminator:0.6893266439437866
Epoch:8
Loss Generator:1.0265462398529053
Loss Discriminator:0.630061149597168
Epoch:9
Loss Generator:0.9728230834007263
Loss Discriminator:0.5132157206535339
Epoch:10
Loss Generator:1.0659806728363037
Loss Discriminator:0.6028090715408325
Epoch:11
Loss Generator:1.051514983177185
Loss Discriminator:0.6069328188896179
Epoch:12
Loss Generator:1.200023889541626
