In [None]:
import torch.nn as nn
import torch
from tqdm import tqdm_notebook as tqdm
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torchvision.datasets
import torchvision.transforms as transforms
import torchvision

In [None]:
from DCGAN import *

In [None]:
# MNIST images stored under dataset/ (download=True), each converted to a tensor.
dataset = torchvision.datasets.MNIST(
    root='dataset/',
    download=True,
    transform=transforms.ToTensor(),
)
# Shuffled mini-batches of 128 samples.
loader = DataLoader(dataset, batch_size=128, shuffle=True)

In [None]:

# Hyperparameters shared by the training cells below.
epochs = 100             # number of passes over the data loader
z_dim = 64               # length of the noise vector fed to the generator
# Use the GPU when one is available; otherwise fall back to the CPU so the
# later `.to(device)` calls do not fail on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 0.0001              # learning rate passed to both Adam optimizers
image_dim = 28 * 28 * 1  # flattened image size (height * width * channels)
batch_size = 128         # NOTE: the DataLoader cell uses its own literal 128



In [None]:
# Build both networks and move them to the configured device.
disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)

# One Adam optimizer per network, both using the learning rate `lr`.
opt_disc = torch.optim.Adam(disc.parameters(), lr=lr)
opt_gen = torch.optim.Adam(gen.parameters(), lr=lr)

# Separate TensorBoard writers for generated and real image grids.
writer_fake = SummaryWriter("fake")
writer_real = SummaryWriter("real")


# Training with binary cross-entropy (BCE) loss

In [None]:
# Binary cross-entropy on raw scores (the sigmoid is applied inside the loss).
# NOTE(review): assumes Discriminator returns unnormalized logits; confirm in DCGAN.
criterion = nn.BCEWithLogitsLoss()

In [None]:
# Adversarial training with the BCE criterion: one discriminator update and
# one generator update per mini-batch, logging image grids once per epoch.
for epoch in range(epochs):
    for batch_id, (real, _) in enumerate(loader):
        # Flatten every image to a vector of length image_dim.
        real = real.view(-1, image_dim).to(device)
        # Size the noise from the actual batch so the fake batch matches the
        # real one even when the final batch of an epoch is smaller.
        cur_batch_size = real.size(0)
        noise = torch.randn(cur_batch_size, z_dim, device=device)
        fake = gen(noise)

        ### Discriminator step: real samples -> 1, generated samples -> 0 ###
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        # detach() keeps this step from backpropagating into the generator,
        # so the generator graph stays intact without retain_graph=True.
        disc_fake = disc(fake.detach()).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        ### Generator step: make the updated discriminator output 1 on fakes ###
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        # Log one grid of fake and real images at the first batch of each epoch.
        if batch_id == 0:
            with torch.no_grad():
                fake = gen(noise).reshape(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)
                img_grid_real = torchvision.utils.make_grid(data[:32], normalize=True)
                writer_fake.add_image("MNIST Fake Images", img_grid_fake, global_step=epoch)
                writer_real.add_image("MNIST Real Images", img_grid_real, global_step=epoch)
                print(f"Epoch [{epoch}/{epochs}] Batch {batch_id}/{len(loader)} Loss D: {lossD:.4f}, loss G: {lossG:.4f}")


        

# Training with Wasserstein loss and gradient penalty (WGAN-GP)


In [None]:
def get_gradient(crit, real, fake, epsilon):
    """Return the gradient of the critic's scores w.r.t. interpolated samples.

    Parameters:
        crit: callable that maps a batch of samples to a tensor of scores.
        real: batch of real samples.
        fake: batch of generated samples, broadcast-compatible with `real`.
        epsilon: per-sample mixing weights, broadcast-compatible with `real`;
            `real` is weighted by `epsilon` and `fake` by `1 - epsilon`.

    Returns:
        A tensor with the same shape as the interpolated batch holding
        d(scores)/d(interpolated samples). It is built with
        create_graph=True so that a loss computed from it can itself be
        backpropagated.
    """
    mixed_images = real * epsilon + fake * (1 - epsilon)
    # If none of the inputs tracks gradients, the mix is a plain leaf tensor;
    # mark it so autograd can differentiate the scores with respect to it.
    if not mixed_images.requires_grad:
        mixed_images.requires_grad_(True)

    # Score the mix with the critic that was passed in, not a global model.
    mixed_scores = crit(mixed_images)

    gradient = torch.autograd.grad(
        inputs=mixed_images,
        outputs=mixed_scores,
        # A ones tensor sums the per-sample scores into a single backward pass.
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    return gradient

In [None]:
def gradient_penalty(gradient):
    """Penalize per-sample gradient norms that differ from 1.

    Each sample's gradient is flattened to a single row, its L2 norm is
    taken, and the mean of (norm - 1)^2 over the batch is returned.
    """
    per_sample = gradient.view(len(gradient), -1)
    norms = per_sample.norm(2, dim=1)
    return ((norms - 1) ** 2).mean()

In [None]:
def get_gen_loss(crit_fake_pred):
    """Generator loss: the negated mean of the critic's scores on fake samples."""
    mean_fake_score = crit_fake_pred.mean()
    return -1 * mean_fake_score

In [None]:
def get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda):
    """Critic loss: mean fake score minus mean real score plus c_lambda * gp."""
    score_gap = crit_fake_pred.mean() - crit_real_pred.mean()
    weighted_penalty = c_lambda * gp
    return score_gap + weighted_penalty

In [None]:
# WGAN-GP training: `crit_repeats` critic updates, then one generator update,
# for every mini-batch.
cur_step = 0
c_lambda = 10        # weight of the gradient penalty in the critic loss
crit_repeats = 5     # critic updates per generator update
generator_losses = []
critic_losses = []
for epoch in range(epochs):
    # Dataloader returns the batches
    for real, _ in tqdm(loader):
        cur_batch_size = len(real)
        # Flatten the images the same way the BCE training loop does, so they
        # match the generator's output when real and fake samples are mixed.
        # NOTE(review): assumes `disc` consumes flat vectors, as in the BCE loop.
        real = real.view(cur_batch_size, -1).to(device)

        mean_iteration_critic_loss = 0
        for _ in range(crit_repeats):
            ### Update critic ###
            opt_disc.zero_grad()
            fake_noise = torch.randn(cur_batch_size, z_dim, device=device)
            # Detach once: the critic update must not reach the generator.
            fake = gen(fake_noise).detach()
            crit_fake_pred = disc(fake)
            crit_real_pred = disc(real)

            # One mixing weight per sample, broadcast across the flattened
            # pixels. requires_grad=True makes the mixed batch part of the
            # autograd graph that get_gradient differentiates through.
            epsilon = torch.rand(cur_batch_size, 1, device=device, requires_grad=True)
            gradient = get_gradient(disc, real, fake, epsilon)
            gp = gradient_penalty(gradient)
            crit_loss = get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda)

            # Keep track of the average critic loss in this batch
            mean_iteration_critic_loss += crit_loss.item() / crit_repeats
            # The graph is rebuilt on every repeat, so it need not be retained.
            crit_loss.backward()
            # Update optimizer
            opt_disc.step()
        critic_losses += [mean_iteration_critic_loss]

        ### Update generator ###
        opt_gen.zero_grad()
        # Draw from the same normal distribution used for the critic's fakes
        # (the earlier uniform torch.rand did not match it).
        fake_noise_2 = torch.randn(cur_batch_size, z_dim, device=device)
        fake_2 = gen(fake_noise_2)
        crit_fake_pred = disc(fake_2)

        gen_loss = get_gen_loss(crit_fake_pred)
        gen_loss.backward()

        # Update the weights
        opt_gen.step()

        # Keep track of the average generator loss
        generator_losses += [gen_loss.item()]