In [None]:
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter  # to print to tensorboard
import albumentations as A

In [None]:
class Generator(nn.Module):
  """DCGAN-style generator built from transposed convolutions.

  A latent tensor of shape (N, latent_dim, 1, 1) is upsampled through four
  conv-transpose/batch-norm/ReLU stages and one final transposed convolution,
  so the spatial size goes 1 -> 4 -> 8 -> 16 -> 32 -> 64. The output has
  ``img_channels`` channels and is squashed to [-1, 1] by ``Tanh``.

  Args:
    latent_dim: number of channels of the latent input.
    img_channels: number of channels of the generated image.
    feature_g: base width; the stages use 16x, 8x, 4x and 2x this value.
  """

  def __init__(self, latent_dim, img_channels, feature_g):
    super().__init__()
    # Channel width entering/leaving each upsampling stage.
    widths = [latent_dim, feature_g * 16, feature_g * 8, feature_g * 4, feature_g * 2]
    # (stride, padding) per stage: the first expands 1x1 to 4x4, the rest double the size.
    geometry = [(1, 0), (2, 1), (2, 1), (2, 1)]

    stages = [
        self.block(c_in, c_out, 4, stride, padding)
        for c_in, c_out, (stride, padding) in zip(widths[:-1], widths[1:], geometry)
    ]
    # Output head: last doubling (32 -> 64) straight to image channels, no batch norm.
    stages.append(
        nn.ConvTranspose2d(widths[-1], img_channels, kernel_size=4, stride=2, padding=1)
    )
    stages.append(nn.Tanh())
    self.gen = nn.Sequential(*stages)

  def block(self, in_channels, out_channels, kernel_size, stride, padding):
    """Return one upsampling stage: ConvTranspose2d (no bias) -> BatchNorm2d -> ReLU."""
    upsample = nn.ConvTranspose2d(
        in_channels, out_channels, kernel_size, stride, padding, bias=False
    )
    normalise = nn.BatchNorm2d(out_channels)
    return nn.Sequential(upsample, normalise, nn.ReLU())

  def forward(self, x):
    """Map a latent batch ``x`` to a batch of generated images."""
    return self.gen(x)

class Discriminator(nn.Module):
  """DCGAN-style convolutional discriminator / critic for 64x64 images.

  Four stride-2 convolutions halve the spatial size (64 -> 32 -> 16 -> 8 -> 4)
  and a final 4x4 convolution reduces it to a single score per image, so the
  output has shape (N, 1, 1, 1) for an (N, img_channels, 64, 64) input.

  Args:
    img_channels: number of channels of the input images.
    feature_d: base width; the layers use 2x, 4x, 8x and 16x this value.
    use_sigmoid: when True (default, the original behaviour) the score is
      passed through ``Sigmoid`` and lies in [0, 1]. Pass False to get the raw
      unbounded score, e.g. when the network is used as a Wasserstein critic.
  """

  def __init__(self, img_channels, feature_d, use_sigmoid=True):
    super().__init__()
    layers = [
        nn.Conv2d(img_channels, feature_d*2, 4, 2, 1),  # 64 -> 32, no batch norm on the input layer
        nn.LeakyReLU(0.2),
        self.block(feature_d*2, feature_d*4, 4, 2, 1),   # 32 -> 16
        self.block(feature_d*4, feature_d*8, 4, 2, 1),   # 16 -> 8
        self.block(feature_d*8, feature_d*16, 4, 2, 1),  # 8 -> 4
        nn.Conv2d(feature_d*16, 1, 4, 1, 0),             # 4 -> 1
    ]
    # Sigmoid has no parameters and is last, so toggling it leaves state_dict keys unchanged.
    if use_sigmoid:
      layers.append(nn.Sigmoid())
    self.disc = nn.Sequential(*layers)

  def block(self, in_channels, out_channels, kernel_size, stride, padding):
    """Return one downsampling stage: Conv2d (no bias) -> BatchNorm2d -> LeakyReLU(0.2).

    LeakyReLU matches the activation used after the first convolution; a plain
    ReLU here would zero the gradient for every negative pre-activation.
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.LeakyReLU(0.2))

  def forward(self, x):
    """Score a batch of images ``x``; returns a tensor of shape (N, 1, 1, 1) for 64x64 inputs."""
    return self.disc(x)

In [None]:
# Training Loop
# Wasserstein-style GAN training on MNIST resized to 64x64: the critic is
# updated `critic_loop` times per batch with weight clipping, then the
# generator is updated once.
import time

N, img_channels, H, W = 8, 1, 64, 64
latent_dim = 100
features_g = 64
features_d = 64
lr_gen = 5e-5
lr_disc = 5e-5
epochs = 5
batch_size = 32
weight_clip = 0.01  # critic parameters are clamped to [-weight_clip, weight_clip]

device = "cuda" if torch.cuda.is_available() else "cpu"

gen = Generator(latent_dim, img_channels, features_g).to(device)
disc = Discriminator(img_channels, features_d).to(device)
# NOTE(review): the Wasserstein losses below are usually paired with a critic
# that outputs an unbounded score; if `Discriminator` ends in a Sigmoid,
# consider removing it for this training scheme.

# Latent vectors are drawn from a standard normal (zero-centred), not uniform [0, 1).
fixed_noise = torch.randn((batch_size, latent_dim, 1, 1), device=device)
transform = transforms.Compose(
    [
        transforms.Resize((H, W)),
        transforms.ToTensor(),
        # Maps pixel values from [0, 1] to [-1, 1].
        transforms.Normalize([0.5 for _ in range(img_channels)], [0.5 for _ in range(img_channels)]),
    ]
)

dataset = datasets.MNIST(root="dataset/", transform=transform, download=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

gen_opt = torch.optim.RMSprop(gen.parameters(), lr=lr_gen)
disc_opt = torch.optim.RMSprop(disc.parameters(), lr=lr_disc)

# Not used by the losses below; kept in case another cell refers to it.
criterion = nn.BCELoss()

step = 0
critic_loop = 10
writer_fake = SummaryWriter("/content/logs/fake")
writer_real = SummaryWriter("/content/logs/real")

for epoch in range(epochs):

  for batch_idx, (real_img, _) in enumerate(dataloader):
    # Batch start time; (time.time() - s_time) / 60 gives minutes spent on this batch.
    s_time = time.time()
    N = real_img.shape[0]  # the last batch may be smaller than batch_size
    real_img = real_img.to(device)

    ##### Train Discriminator (critic) ####
    for _ in range(critic_loop):
      lat_vec = torch.randn((N, latent_dim, 1, 1), device=device)
      fake_img = gen(lat_vec)  # (N, img_channels, H, W)
      real_pred = disc(real_img).view(-1)  # Flatten it from (N,1,1,1) to (N)
      # detach(): the critic update must not back-propagate into the generator.
      fake_pred = disc(fake_img.detach()).view(-1)
      # Maximise mean(real) - mean(fake) by minimising its negation.
      lossD = -(torch.mean(real_pred) - torch.mean(fake_pred))
      disc_opt.zero_grad()
      lossD.backward()
      disc_opt.step()

      # Weight clipping keeps the critic's parameters in a bounded box after every update.
      with torch.no_grad():
        for param in disc.parameters():
          param.clamp_(-weight_clip, weight_clip)

    ##### Train Generator ####
    # Reuses the fakes from the last critic iteration, scored by the updated critic.
    pred = disc(fake_img).view(-1)
    lossG = -torch.mean(pred)

    gen_opt.zero_grad()
    lossG.backward()
    gen_opt.step()

    if batch_idx % 500 == 0:
      # Adjacent f-strings avoid embedding source indentation in the log line.
      print(
          f"Epoch [{epoch}/{epochs}] Batch {batch_idx}/{len(dataloader)} "
          f"Loss D: {lossD.item():.4f}, loss G: {lossG.item():.4f}"
      )

    # Log one grid of generated and real images per epoch.
    if batch_idx == 0:
      with torch.no_grad():
        fake = gen(fixed_noise)
        data = real_img
        img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
        img_grid_real = torchvision.utils.make_grid(data, normalize=True)

        writer_fake.add_image(
            "Mnist Fake Images", img_grid_fake, global_step=step
        )
        writer_real.add_image(
            "Mnist Real Images", img_grid_real, global_step=step
        )
        step += 1

# Make sure buffered TensorBoard events reach disk once training finishes.
writer_fake.flush()
writer_real.flush()