<a href="https://colab.research.google.com/github/SarveshD7/WGAN-Pytorch/blob/main/WGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ***Pros of WGAN***
Better training stability.

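The loss is meaningful, unlike in standard GANs: the critic's loss tracks an estimate of the Wasserstein distance between real and generated data (up to a constant factor), so it can be monitored and used as a termination criterion.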

Prevents mode collapse (explained below):
# Mode collapse
Mode collapse is when the model outputs images of only one specific class (or a small set of them).
Usually you want your GAN to produce a wide variety of outputs. You want, for example, a different face for every random input to your face generator.

However, if a generator produces an especially plausible output, the generator may learn to produce only that output. In fact, the generator is always trying to find the one output that seems most plausible to the discriminator.

If the generator starts producing the same output (or a small set of outputs) over and over again, the discriminator's best strategy is to learn to always reject that output. But if the next generation of discriminator gets stuck in a local minimum and doesn't find the best strategy, then it's too easy for the next generator iteration to find the most plausible output for the current discriminator.

Each iteration of generator over-optimizes for a particular discriminator, and the discriminator never manages to learn its way out of the trap. As a result the generators rotate through a small set of output types. This form of GAN failure is called mode collapse.

# Wasserstein loss:
The Wasserstein loss alleviates mode collapse by letting you train the discriminator to optimality without worrying about vanishing gradients. If the discriminator doesn't get stuck in local minima, it learns to reject the outputs that the generator stabilizes on. So the generator has to try something new.

# ***Cons of WGAN***
Slower to train (the critic is updated several times per generator update).

---



In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


In [7]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")

# --- Optimisation -----------------------------------------------------------
LEARNING_RATE = 5e-5      # step size used by both RMSprop optimisers
BATCH_SIZE = 64
NUM_EPOCHS = 5
CRITIC_ITERATIONS = 5     # critic updates performed per generator update
WEIGHT_CLIP = 0.01        # bound used when clipping the critic's weights

# --- Data / architecture ----------------------------------------------------
IMAGE_SIZE = 64           # images are resized to this before training
CHANNELS_IMG = 1          # grayscale input
Z_DIM = 100               # length of the generator's latent noise vector
FEATURES_critic = 64      # base channel width of the critic
FEATURES_GEN = 64         # base channel width of the generator

In [22]:
class Discriminator(nn.Module):
  """WGAN critic: maps an image batch to one raw score per image.

  The final convolution has no activation after it, so the scores are
  unbounded rather than probabilities.

  Args:
    channels_img: number of channels in the input images.
    features_d: base channel width; successive layers use 1x, 2x, 4x and 8x
      this many channels.
  """

  def __init__(self, channels_img, features_d):
    super().__init__()
    widths = [features_d * mult for mult in (1, 2, 4, 8)]
    # Spatial sizes for an N x channels_img x 64 x 64 input:
    # 64 -> 32 (stem) -> 16 -> 8 -> 4 (blocks) -> 1 (head).
    stem = [
        nn.Conv2d(channels_img, widths[0], kernel_size=4, stride=2, padding=1),
        nn.LeakyReLU(0.2),
    ]
    blocks = [
        self._block(c_in, c_out, 4, 2, 1)
        for c_in, c_out in zip(widths, widths[1:])
    ]
    head = nn.Conv2d(widths[-1], 1, kernel_size=4, stride=2, padding=0)
    self.critic = nn.Sequential(*stem, *blocks, head)

  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
    """Return a bias-free strided Conv2d followed by LeakyReLU(0.2)."""
    conv = nn.Conv2d(
        in_channels, out_channels, kernel_size, stride, padding, bias=False
    )
    return nn.Sequential(conv, nn.LeakyReLU(0.2))

  def forward(self, x):
    """Score each image in ``x``; output is N x 1 x 1 x 1 for 64x64 input."""
    scores = self.critic(x)
    return scores


In [4]:
class Generator(nn.Module):
  """DCGAN-style generator: latent noise N x z_dim x 1 x 1 -> images N x channels_img x 64 x 64.

  The output passes through Tanh, so pixel values lie in [-1, 1].

  Args:
    z_dim: number of channels of the latent noise input.
    channels_img: number of channels of the generated images.
    features_g: base channel width; the stages use 16x, 8x, 4x and 2x this
      many channels.
  """

  def __init__(self, z_dim, channels_img, features_g):
    super().__init__()
    # (in, out, kernel, stride, padding) per upsampling stage.
    # Spatial size: 1 -> 4 -> 8 -> 16 -> 32.
    stages = [
        (z_dim, features_g * 16, 4, 1, 0),
        (features_g * 16, features_g * 8, 4, 2, 1),
        (features_g * 8, features_g * 4, 4, 2, 1),
        (features_g * 4, features_g * 2, 4, 2, 1),
    ]
    upsample = [self._block(*cfg) for cfg in stages]
    # Final 32 -> 64 upsampling straight to image channels.
    to_image = nn.ConvTranspose2d(
        features_g * 2, channels_img, kernel_size=4, stride=2, padding=1
    )
    self.gen = nn.Sequential(*upsample, to_image, nn.Tanh())

  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
    """Return a bias-free ConvTranspose2d followed by ReLU."""
    deconv = nn.ConvTranspose2d(
        in_channels, out_channels, kernel_size, stride, padding, bias=False
    )
    return nn.Sequential(deconv, nn.ReLU())

  def forward(self, x):
    """Map latent noise ``x`` to a batch of images in [-1, 1]."""
    images = self.gen(x)
    return images


In [5]:
def initialize_weights(model):
  """Initialise conv and batch-norm weights of ``model`` in place (DCGAN style).

  * Conv2d / ConvTranspose2d weights are drawn from N(0, 0.02).
  * BatchNorm2d scale (gamma) is drawn from N(1, 0.02) and its shift (beta)
    set to 0. Drawing gamma around 0 would multiply every normalised
    activation by ~0 and effectively switch the layer off.
  * Conv biases are left at PyTorch's defaults.

  Args:
    model: any nn.Module; every sub-module from ``model.modules()`` is visited.
  """
  for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
      nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
      # affine=False BatchNorm has no learnable weight/bias to initialise.
      nn.init.normal_(m.weight, 1.0, 0.02)
      nn.init.constant_(m.bias, 0.0)

In [8]:
# Preprocessing pipeline: resize to IMAGE_SIZE, convert to a float tensor in
# [0, 1], then normalise each channel with mean 0.5 / std 0.5 so pixel values
# land in [-1, 1], the same range as the generator's Tanh output.
# NOTE(review): this rebinds the name `transforms`, shadowing the
# `torchvision.transforms` module imported earlier in the notebook.
transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize(IMAGE_SIZE),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.5] * CHANNELS_IMG, [0.5] * CHANNELS_IMG),
])

In [13]:
# MNIST training split, downloaded to /content/dataset/ if not already present
# and preprocessed with the `transforms` pipeline.
dataset = datasets.MNIST(
    root="/content/dataset/",
    train=True,
    transform=transforms,
    download=True,
)
# Reshuffled every epoch. drop_last is left at its default, so the final batch
# of an epoch may contain fewer than BATCH_SIZE images.
loader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE)

In [19]:
# Build both networks on the selected device, then apply the custom weight
# initialisation (generator first, then critic).
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
critic = Discriminator(CHANNELS_IMG, FEATURES_critic).to(device)
for net in (gen, critic):
  initialize_weights(net)

In [20]:
# One RMSprop optimiser per network, both with the same learning rate.
# No loss criterion object is needed: the WGAN losses are computed directly
# from mean critic scores in the training loop.
opt_gen, opt_critic = (
    optim.RMSprop(net.parameters(), lr=LEARNING_RATE) for net in (gen, critic)
)


In [21]:
# Put both networks into training mode before the training loop
# (this matters for layers such as BatchNorm and Dropout).
gen.train(mode=True)
critic.train(mode=True)

Discriminator(
  (disc): Sequential(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): LeakyReLU(negative_slope=0.2)
    )
    (4): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): LeakyReLU(negative_slope=0.2)
    )
    (5): Conv2d(512, 1, kernel_size=(4, 4), stride=(2, 2))
    (6): Sigmoid()
  )
)

In [21]:
# Training loop (WGAN with weight clipping).
# For every real batch the critic is updated CRITIC_ITERATIONS times, then the
# generator once. The critic minimises -(E[critic(real)] - E[critic(fake)]);
# the generator minimises -E[critic(fake)]. After every critic step all critic
# weights are clamped to [-WEIGHT_CLIP, WEIGHT_CLIP].
for epoch in range(NUM_EPOCHS):
  print(f"Started epoch-> {epoch}")
  for real, _ in loader:
    real = real.to(device)
    # The last batch of an epoch can be smaller than BATCH_SIZE; match it so
    # real and fake batches always have the same size.
    cur_batch_size = real.shape[0]

    # ---- Critic: CRITIC_ITERATIONS updates per generator update ----
    for _ in range(CRITIC_ITERATIONS):
      noise = torch.randn(cur_batch_size, Z_DIM, 1, 1, device=device)
      fake = gen(noise)
      critic_real = critic(real).reshape(-1)
      # detach(): the critic update must not backpropagate into the generator.
      # This also keeps the generator's graph for `fake` intact for the
      # generator step below, so retain_graph is not needed.
      critic_fake = critic(fake.detach()).reshape(-1)
      loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake))
      critic.zero_grad()
      loss_critic.backward()
      opt_critic.step()

      # Weight clipping. The lower bound must be -WEIGHT_CLIP: clamping to
      # (WEIGHT_CLIP, WEIGHT_CLIP) would overwrite every weight with the
      # constant WEIGHT_CLIP.
      with torch.no_grad():
        for p in critic.parameters():
          p.clamp_(-WEIGHT_CLIP, WEIGHT_CLIP)

    # ---- Generator: raise the critic's score on the last fake batch ----
    output = critic(fake).reshape(-1)
    loss_gen = -torch.mean(output)
    gen.zero_grad()
    loss_gen.backward()
    opt_gen.step()
  print(f"Completed epoch-> {epoch}")
  # -loss_critic is the critic's current Wasserstein estimate; watching it is
  # the point of the "meaningful loss" property of WGAN.
  print(f"  loss_critic={loss_critic.item():.4f}  loss_gen={loss_gen.item():.4f}")
