In [1]:
import torch
from torch import nn, optim
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

In [2]:
class Critic(nn.Module):
  def __init__(self, img_channels, features_d, img_size, num_classes):
    super(Critic, self).__init__()
    self.img_size = img_size

    self.disc = nn.Sequential(
        nn.Conv2d(in_channels=img_channels + 1, out_channels=features_d, kernel_size=4, stride=1, padding=1),
        nn.LeakyReLU(0.2),
        self._block(features_d, features_d * 2, 4, 2, 1),
        self._block(features_d * 2, features_d * 4, 4, 2, 1),
        self._block(features_d * 4, features_d * 8, 4, 2, 1),
        nn.Conv2d(features_d * 8, 1, 4, 2, 0)
    )

    self.embedding = nn.Embedding(num_classes, img_size * img_size)

  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
        nn.InstanceNorm2d(num_features=out_channels, affine=True),
        nn.LeakyReLU(0.2)
    )
  def forward(self, x, labels):
    embeds = self.embedding(labels).view(labels.shape[0], 1, self.img_size, self.img_size)
    x = torch.cat([x, embeds], dim=1)
    return self.disc(x)

In [3]:
class Generator(nn.Module):
  def __init__(self, z_dim, img_channels, features_g, num_classes, img_size, embed_size):
    super(Generator, self).__init__()
    self.img_size = img_size
    self.gen = nn.Sequential(
        self._block(in_channels=z_dim +embed_size,
                    out_channels=features_g * 16,
                    kernel_size=4,
                    stride=1,
                    padding=0),
        self._block(features_g * 16, features_g * 8, 4, 2, 1),
        self._block(features_g * 8, features_g * 4, 4, 2, 1),
        self._block(features_g * 4, features_g * 2, 4, 2, 1),
        nn.ConvTranspose2d(features_g * 2, img_channels, 4, 2, 1),
        nn.Tanh()
    )
    self.embedding = nn.Embedding(num_classes, embed_size)

  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
    return nn.Sequential(
        nn.ConvTranspose2d(in_channels=in_channels,
                           out_channels=out_channels,
                           stride=stride,
                           kernel_size=kernel_size,
                           padding=padding,
                           bias=False),
        nn.InstanceNorm2d(out_channels, affine=True),
        nn.ReLU()
    )

  def forward(self, x, labels):
    embed = self.embedding(labels).unsqueeze(2).unsqueeze(3)
    x = torch.cat([x, embed], dim=1)
    return self.gen(x)

In [4]:
writer_real = SummaryWriter(log_dir='logs/Cond_WGANs/Real')
writer_fake = SummaryWriter(log_dir='logs/Cond_WGANs/Fake')
step = 0

In [5]:
EPOCHS = 5
CRITIC_ITERATIONS = 5
EMBED_SIZE = 100
Z_DIM = 100
LEARNING_RATE = 1e-4
IMG_SIZE = 64
IMG_CHANNELS = 1
NUM_CLASSES = 10
BATCH_SIZE = 64
FEATURES_G = 64
FEATURES_D = 64
LAMBDA_GP = 10
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [6]:
def initialize_weights(model):
  for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.InstanceNorm2d)):
      nn.init.normal_(m.weight.data, 0.0, 0.2)

In [7]:
def gradient_penalty(critic, real, fake, labels, device='cpu'):
  BATCH_SIZE, C, H, W = real.shape
  epsilon = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
  interpolated = real * epsilon + fake * (1 - epsilon)

  mixed_scores = critic(interpolated, labels)

  gradient = torch.autograd.grad(
      inputs=interpolated,
      outputs=mixed_scores,
      grad_outputs=torch.ones_like(mixed_scores),
      create_graph=True,
      retain_graph=True
  )[0]

  gradient = gradient.view(gradient.shape[0], -1)
  gradient_norm = gradient.norm(2, dim=1)
  gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
  return gradient_penalty

In [8]:
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.5 for _ in range(IMG_CHANNELS)], [0.5 for _ in range(IMG_CHANNELS)])
])

In [9]:
dataset = datasets.MNIST(root='data/', download=True, transform=transform)

loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [10]:
gen = Generator(Z_DIM, IMG_CHANNELS, FEATURES_G, NUM_CLASSES, IMG_SIZE, EMBED_SIZE).to(device)
initialize_weights(gen)

critic = Critic(IMG_CHANNELS, FEATURES_D, IMG_SIZE, NUM_CLASSES).to(device)
initialize_weights(critic)

In [11]:
gen_opt = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
critic_opt = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))

In [12]:
from tqdm import tqdm

critic.train()
gen.train()

for epoch in range(EPOCHS):
  for batch_idx, (real, labels) in enumerate(tqdm(loader)):
    real = real.to(device)
    BATCH_SIZE = real.shape[0]
    labels = labels.to(device)

    for _ in range(CRITIC_ITERATIONS):
      noise = torch.rand((BATCH_SIZE, Z_DIM, 1, 1)).to(device)
      fake = gen(noise, labels)

      critic_fake = critic(fake, labels).reshape(-1)
      critic_real = critic(real, labels).reshape(-1)

      gp = gradient_penalty(critic, real, fake, labels, device=device)

      critic_loss = -(torch.mean(critic_real) + torch.mean(critic_fake)) + (gp * LAMBDA_GP)

      critic.zero_grad()
      critic_loss.backward(retain_graph=True)
      critic_opt.step()

    outputs = critic(fake, labels)
    gen_loss = -torch.mean(outputs)
    gen.zero_grad()
    gen_loss.backward()
    gen_opt.step()

    if batch_idx % 100 == 0 and batch_idx > 0:
      gen.eval()
      critic.eval()
      print(f"\nEpoch [{epoch}/{EPOCHS}], Loss D: {critic_loss:.4f}, loss G: {gen_loss:.4f}")

      with torch.no_grad():
        fake = gen(noise, labels)
        img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
        img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)

        writer_real.add_image("Real", img_grid_real, global_step=step)
        writer_fake.add_image("Fake", img_grid_fake, global_step=step)

        step += 1
        gen.train()
        critic.train()

 11%|██▏                 | 101/938 [01:32<13:22,  1.04it/s]


Epoch [0/5], Loss D: -282.9848, loss G: -148.2388


 21%|████▎               | 201/938 [03:03<12:06,  1.02it/s]


Epoch [0/5], Loss D: -739.3724, loss G: -372.5141


 32%|██████▍             | 301/938 [04:34<10:28,  1.01it/s]


Epoch [0/5], Loss D: -1212.5627, loss G: -607.4992


 43%|████████▌           | 401/938 [06:05<08:49,  1.02it/s]


Epoch [0/5], Loss D: -1789.4749, loss G: -895.8740


 53%|██████████▋         | 501/938 [07:37<07:10,  1.01it/s]


Epoch [0/5], Loss D: -2490.3088, loss G: -1246.2604


 64%|████████████▊       | 601/938 [09:08<05:31,  1.02it/s]


Epoch [0/5], Loss D: -3318.4053, loss G: -1660.3459


 75%|██████████████▉     | 701/938 [10:40<03:53,  1.02it/s]


Epoch [0/5], Loss D: -4276.5571, loss G: -2139.6248


 85%|█████████████████   | 801/938 [12:11<02:14,  1.02it/s]


Epoch [0/5], Loss D: -5364.8540, loss G: -2683.9658


 96%|███████████████████▏| 901/938 [13:43<00:36,  1.02it/s]


Epoch [0/5], Loss D: -6584.9507, loss G: -3294.0862


100%|████████████████████| 938/938 [14:16<00:00,  1.10it/s]
 11%|██▏                 | 101/938 [01:32<13:43,  1.02it/s]


Epoch [1/5], Loss D: -8484.1484, loss G: -4243.8042


 21%|████▎               | 201/938 [03:03<12:04,  1.02it/s]


Epoch [1/5], Loss D: -10016.8574, loss G: -5010.3696


 32%|██████▍             | 301/938 [04:35<10:26,  1.02it/s]


Epoch [1/5], Loss D: -11681.7842, loss G: -5843.2148


 43%|████████▌           | 401/938 [06:06<08:48,  1.02it/s]


Epoch [1/5], Loss D: -13479.6641, loss G: -6742.0942


 53%|██████████▋         | 501/938 [07:38<07:10,  1.02it/s]


Epoch [1/5], Loss D: -15410.9248, loss G: -7707.8809


 64%|████████████▊       | 601/938 [09:09<05:31,  1.02it/s]


Epoch [1/5], Loss D: -17473.8984, loss G: -8739.4590


 75%|██████████████▉     | 701/938 [10:39<03:51,  1.02it/s]


Epoch [1/5], Loss D: -19670.4355, loss G: -9837.8936


 85%|█████████████████   | 801/938 [12:09<02:12,  1.03it/s]


Epoch [1/5], Loss D: -21997.5645, loss G: -11002.6855


 96%|███████████████████▏| 901/938 [13:40<00:36,  1.02it/s]


Epoch [1/5], Loss D: -24461.4688, loss G: -12234.3555


100%|████████████████████| 938/938 [14:12<00:00,  1.10it/s]
 11%|██▏                 | 101/938 [01:30<13:45,  1.01it/s]


Epoch [2/5], Loss D: -28085.6855, loss G: -14045.9268


 21%|████▎               | 201/938 [03:01<11:48,  1.04it/s]


Epoch [2/5], Loss D: -30868.9219, loss G: -15437.6836


 32%|██████▍             | 301/938 [04:31<10:15,  1.03it/s]


Epoch [2/5], Loss D: -33776.3242, loss G: -16896.6328


 43%|████████▌           | 401/938 [06:01<08:48,  1.02it/s]


Epoch [2/5], Loss D: -36839.9375, loss G: -18423.8477


 53%|██████████▋         | 501/938 [07:32<07:10,  1.02it/s]


Epoch [2/5], Loss D: -40023.8086, loss G: -20017.4062


 64%|████████████▊       | 601/938 [09:04<05:33,  1.01it/s]


Epoch [2/5], Loss D: -43347.9766, loss G: -21678.2246


 75%|██████████████▉     | 701/938 [10:36<03:54,  1.01it/s]


Epoch [2/5], Loss D: -46796.2695, loss G: -23407.3359


 85%|█████████████████   | 801/938 [12:08<02:15,  1.01it/s]


Epoch [2/5], Loss D: -50398.0742, loss G: -25203.3711


 96%|███████████████████▏| 901/938 [13:40<00:36,  1.01it/s]


Epoch [2/5], Loss D: -54102.6523, loss G: -27066.1582


100%|████████████████████| 938/938 [14:13<00:00,  1.10it/s]
 11%|██▏                 | 101/938 [01:32<13:47,  1.01it/s]


Epoch [3/5], Loss D: -59490.9062, loss G: -29749.9961


 21%|████▎               | 201/938 [03:04<12:03,  1.02it/s]


Epoch [3/5], Loss D: -63536.1758, loss G: -31774.4707


 32%|██████▍             | 301/938 [04:35<10:25,  1.02it/s]


Epoch [3/5], Loss D: -67723.5938, loss G: -33866.2188


 43%|████████▌           | 401/938 [06:06<08:46,  1.02it/s]


Epoch [3/5], Loss D: -72041.6016, loss G: -36025.4375


 53%|██████████▋         | 501/938 [07:37<07:09,  1.02it/s]


Epoch [3/5], Loss D: -76495.8672, loss G: -38252.8047


 64%|████████████▊       | 601/938 [09:08<05:30,  1.02it/s]


Epoch [3/5], Loss D: -81089.4062, loss G: -40549.7578


 75%|██████████████▉     | 701/938 [10:40<03:53,  1.02it/s]


Epoch [3/5], Loss D: -85818.4531, loss G: -42914.4922


 85%|█████████████████   | 801/938 [12:11<02:14,  1.02it/s]


Epoch [3/5], Loss D: -90686.0234, loss G: -45348.3594


 96%|███████████████████▏| 901/938 [13:42<00:36,  1.02it/s]


Epoch [3/5], Loss D: -95692.8906, loss G: -47851.7852


100%|████████████████████| 938/938 [14:15<00:00,  1.10it/s]
 11%|██▏                 | 101/938 [01:32<13:49,  1.01it/s]


Epoch [4/5], Loss D: -102833.3047, loss G: -51422.4141


 21%|████▎               | 201/938 [03:04<12:08,  1.01it/s]


Epoch [4/5], Loss D: -108164.4297, loss G: -54087.7891


 32%|██████▍             | 301/938 [04:36<10:30,  1.01it/s]


Epoch [4/5], Loss D: -113632.9219, loss G: -56822.1797


 43%|████████▌           | 401/938 [06:08<08:51,  1.01it/s]


Epoch [4/5], Loss D: -119240.6953, loss G: -59626.4180


 53%|██████████▋         | 501/938 [07:38<07:12,  1.01it/s]


Epoch [4/5], Loss D: -124931.6562, loss G: -62499.5547


 64%|████████████▊       | 601/938 [09:08<05:33,  1.01it/s]


Epoch [4/5], Loss D: -130856.3125, loss G: -65442.2656


 75%|██████████████▉     | 701/938 [10:40<03:55,  1.01it/s]


Epoch [4/5], Loss D: -136825.7812, loss G: -68451.2500


 85%|█████████████████   | 801/938 [12:12<02:15,  1.01it/s]


Epoch [4/5], Loss D: -143050.4219, loss G: -71531.6406


 96%|███████████████████▏| 901/938 [13:44<00:36,  1.01it/s]


Epoch [4/5], Loss D: -149345.1562, loss G: -74679.5000


100%|████████████████████| 938/938 [14:18<00:00,  1.09it/s]
