In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import Dataset
from torchvision import datasets
import torchvision.transforms as transforms
from torchvision.utils import save_image,make_grid
import matplotlib.pyplot as plt
import os

# Hyperparameters
batch_size = 128   # images per DataLoader mini-batch
z_dim = 100        # length of the latent noise vector fed to the generator
image_size = 28    # Resize target; G and D are hard-wired for 28x28 (7x7 -> 14x14 -> 28x28)
channels = 1       # single-channel (grayscale) MNIST images
epochs = 50
lr = 0.0002        # Adam learning rate, shared by G and D
beta1 = 0.5        # Adam beta1, shared by G and D (beta2 is left at 0.999)
seed = 42          # seeds torch's RNGs so weight init, fixed_noise and shuffling repeat across runs
                   # (some GPU kernels may still be nondeterministic)
torch.manual_seed(seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs("generated_images", exist_ok=True)  # output directory for per-epoch sample grids


In [None]:
# Real-image preprocessing: ToTensor scales pixels to [0, 1], then Normalize with
# mean 0.5 / std 0.5 maps them to [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    # mean and std are both per-channel sequences (one entry for grayscale MNIST).
    transforms.Normalize((0.5,), (0.5,)),
])

# MNIST training split, downloaded into ./data on first run.
train_dataset = datasets.MNIST(root="./data",
                               train=True,
                               transform=transform,
                               download=True)

# Shuffled mini-batches; the final batch may be smaller than batch_size.
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

In [None]:
class Generator(nn.Module):
  """DCGAN-style generator that maps a latent vector to a 28x28 image.

  Args:
    z_dim: length of the latent noise vector. ``forward`` expects input shaped
      (batch, z_dim, 1, 1), as produced by ``torch.randn(n, z_dim, 1, 1)``.
    channels: number of output image channels. Defaults to 1 (the original
      single-channel architecture).
  """
  def __init__(self, z_dim, channels=1):
    super().__init__()
    self.net = nn.Sequential(
        # (z_dim, 1, 1) -> (256, 7, 7)
        nn.ConvTranspose2d(z_dim, 256, 7, 1, 0, bias=False),
        nn.BatchNorm2d(256),
        nn.ReLU(True),
        # (256, 7, 7) -> (128, 14, 14)
        nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
        nn.BatchNorm2d(128),
        nn.ReLU(True),
        # (128, 14, 14) -> (channels, 28, 28)
        nn.ConvTranspose2d(128, channels, 4, 2, 1, bias=False),
        # Tanh bounds pixels to [-1, 1], the range real images are normalized to.
        nn.Tanh(),
    )

  def forward(self, z):
    """Return generated images of shape (batch, channels, 28, 28) with values in [-1, 1]."""
    return self.net(z)

In [None]:
from torch.nn.modules.activation import Sigmoid
class Discriminator(nn.Module):
  """DCGAN-style discriminator that scores 28x28 images as real (-> 1) or fake (-> 0).

  Args:
    channels: number of input image channels. Defaults to 1 (the original
      single-channel architecture).
  """
  def __init__(self, channels=1):
    super().__init__()
    self.net = nn.Sequential(
        # (channels, 28, 28) -> (64, 14, 14); no BatchNorm on the input layer
        nn.Conv2d(channels, 64, 4, 2, 1, bias=False),
        nn.LeakyReLU(0.2, True),
        # (64, 14, 14) -> (128, 7, 7)
        nn.Conv2d(64, 128, 4, 2, 1, bias=False),
        nn.BatchNorm2d(128),
        nn.LeakyReLU(0.2, True),
        nn.Flatten(),
        nn.Linear(128 * 7 * 7, 1),
        # Outputs a probability, which is what nn.BCELoss in the training setup expects.
        nn.Sigmoid(),
    )

  def forward(self, img):
    """Return a (batch, 1) tensor of probabilities that each image is real."""
    return self.net(img)

In [None]:
# Build both networks on the training device (generator first, so weight-init
# random draws happen in the same order as before).
generator, discriminator = Generator(z_dim).to(device), Discriminator().to(device)

# Binary cross-entropy on the discriminator's sigmoid probabilities.
criterion = nn.BCELoss()

# One Adam optimizer per network, both with the same learning rate and betas.
optimizer_G, optimizer_D = [
    optim.Adam(params=net.parameters(), lr=lr, betas=(beta1, 0.999))
    for net in (generator, discriminator)
]

In [None]:
# 64 latent vectors reused every epoch so the saved sample grids are directly comparable.
fixed_noise = torch.randn(64,z_dim,1,1).to(device)

def generate_and_save_images(epoch, noise=None, out_dir="generated_images"):
  """Decode a batch of latent vectors with the generator and save the images as an 8-wide grid PNG.

  The generator is put in eval mode while sampling and is always switched back
  to train mode afterwards, even if saving fails.

  Args:
    epoch: only used to build the file name ``sample_epoch_{epoch}.png``.
    noise: latent batch to decode; defaults to the module-level ``fixed_noise``.
    out_dir: directory the PNG is written to; defaults to "generated_images".
  """
  if noise is None:
    noise = fixed_noise
  generator.eval()
  try:
    with torch.no_grad():
      # no_grad already skips graph construction, so no .detach() is needed.
      fake_images = generator(noise).cpu()
    # Undo the mean-0.5 / std-0.5 normalization: [-1, 1] -> [0, 1].
    fake_images = fake_images * 0.5 + 0.5
    save_image(fake_images, os.path.join(out_dir, f"sample_epoch_{epoch}.png"), nrow=8)
  finally:
    # Restore train mode so subsequent training steps use BatchNorm batch statistics.
    generator.train()

In [1]:
# Update schedule per real batch: p discriminator steps, then k generator steps.
# Note: unlike the k of Goodfellow et al. (2014), which counts discriminator
# steps, k here counts generator steps.
p, k = 1, 3

In [None]:
for epoch in range(1, epochs + 1):
  for i, (real_images, _) in enumerate(train_loader):
    batch_size_curr = real_images.shape[0]  # the last batch may be smaller than batch_size
    real_images = real_images.to(device)
    real_labels = torch.ones(batch_size_curr, 1, device=device)
    fake_labels = torch.zeros(batch_size_curr, 1, device=device)

    # Discriminator, p steps: minimize -[log D(x) + log(1 - D(G(z)))] via BCE.
    for _ in range(p):
      noise = torch.randn(batch_size_curr, z_dim, 1, 1, device=device)
      # G gets no gradient from this update, so don't build an autograd graph through it.
      with torch.no_grad():
        fake_images = generator(noise)
      fake_loss = criterion(discriminator(fake_images), fake_labels)
      real_loss = criterion(discriminator(real_images), real_labels)

      total_loss = real_loss + fake_loss
      optimizer_D.zero_grad()
      total_loss.backward()
      optimizer_D.step()

    # Generator, k steps: non-saturating loss -log D(G(z)) (fakes labelled as real).
    for _ in range(k):
      noise = torch.randn(batch_size_curr, z_dim, 1, 1, device=device)
      fake_images = generator(noise)
      generator_loss = criterion(discriminator(fake_images), real_labels)  # Fool D -> labels as Real
      optimizer_G.zero_grad()
      # This backward also fills D's .grad; optimizer_D.zero_grad() clears it
      # before the next discriminator update.
      generator_loss.backward()
      optimizer_G.step()

    if i % 200 == 0:
      print(f"Epoch [{epoch}/{epochs}] | Batch [{i}/{len(train_loader)}] | "
            f"D_Loss: {total_loss.item():.4f} | G_Loss :{generator_loss.item():.4f}" )
  # generate_and_save_images switches G to eval mode itself and back to train mode afterwards.
  generate_and_save_images(epoch)
