In [1]:
# %%script /Library/Frameworks/Python.framework/Versions/3.9/bin/python3

import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
import torchvision.datasets as datasets
from torch.utils.data import DataLoader

def visualize_images(batch, filepath=None, save_image=False, single=False):
    grid = torchvision.utils.make_grid(batch, nrow=8)
    if single:
        grid = torchvision.utils.make_grid(batch, nrow=1)
    if save_image:
        torchvision.utils.save_image(grid, filepath)
    else:
        plt.figure(figsize=(100,100))
        plt.imshow(grid.permute(2, 2, 0))
        plt.show()

## WGAN-GP (MNIST)

In [2]:
class Discriminator(nn.Module):
  def __init__(self, channels_img, features_d):
    super(Discriminator, self).__init__()
    self.disc = nn.Sequential(
        #Input: N x channels_img x 32 x 32
        nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1), #16x16
        nn.LeakyReLU(0.2),
        self._block(features_d, features_d*2, 4, 2, 1), #8x8
        self._block(features_d*2, features_d*4, 4, 2, 1), #8x8
        self._block(features_d*4, features_d*8, 4, 2, 1),  #4x4
        nn.Conv2d(features_d*8, 1, kernel_size=4, stride=2, padding=0), # 1x1 representing probability
    )

  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
        nn.InstanceNorm2d(out_channels, affine=True),
        nn.LeakyReLU(0.2)
    )

  def forward(self, x):
    return self.disc(x)

In [3]:
class Generator(nn.Module):
  def __init__(self, z_dim, channels_img, features_g):
    super(Generator, self).__init__()
    self.gen = nn.Sequential(
        #Input: N x z_dim x 1 x 1
        self._block(z_dim, features_g*16, 4, 1, 0), # img 4 x 4
        self._block(features_g*16, features_g*8, 4, 2, 1), # img 8x8
        self._block(features_g*8, features_g*4, 4, 2, 1), # img 16x16
        self._block(features_g*4, features_g*2, 4, 2, 1 ), #img 32x32
        nn.ConvTranspose2d(features_g*2, channels_img, kernel_size=4, stride=2, padding=1), #N x channels_img x 64 x 64
        nn.Tanh() #model aligns with normalzed images between -1 and 1
    )
  
  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
    return nn.Sequential(
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU()
    )
  
  def forward(self, x):
    return self.gen(x)

In [4]:
def initialize_weights(model):
  for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
      nn.init.normal(m.weight.data, 0.0, 0.02)

In [5]:
#Hyperparameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lr = 1e-4 
Z_DIM = 100
IMG_SIZE = 64
CHANNELS_IMG = 1
batch_size = 64
num_epochs = 10
features_critic = 64
features_gen = 64
critic_iterations = 5
lambda_gp = 10

#Initializing the models
# gen = Generator(Z_DIM, CHANNELS_IMG, features_gen).to(device)
# critic = Discriminator(CHANNELS_IMG, features_critic).to(device)
# initialize_weights(gen)
# initialize_weights(critic)

#Normalizing the data
transforms = transforms.Compose(
    [
     transforms.Resize(IMG_SIZE),
     transforms.ToTensor(),
     transforms.Normalize([0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)])
     ]
)

In [6]:
def gradient_penalty(critic, real, fake, device="cpu"):
  batch_size, c, H, W = real.shape
  epsilon = torch.rand((batch_size, c, 1, 1)).repeat(1, c, H, W).to(device)
  interpolated_images = (real * epsilon) + fake * (1 - epsilon)

  mixed_scores = critic(interpolated_images)
  gradient = torch.autograd.grad(
      inputs = interpolated_images,
      outputs = mixed_scores,
      grad_outputs = torch.ones_like(mixed_scores),
      create_graph = True,
      retain_graph = True
  )[0]

  gradient = gradient.view(gradient.shape[0], -1)
  gradient_norm = gradient.norm(2, dim=1)
  gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
  return gradient_penalty

In [7]:
GEN_FILE = "../input/output-model/gen_2.pth"
# torch.save(gen.state_dict(), GEN_FILE)

loaded_gen = Generator(Z_DIM, CHANNELS_IMG, features_gen).to(device)
loaded_gen.load_state_dict(torch.load(GEN_FILE))
loaded_gen.train()

#---

CRITIC_FILE = "../input/output-model/critic_2.pth"
# torch.save(critic.state_dict(), CRITIC_FILE)

loaded_critic = Discriminator(CHANNELS_IMG, features_critic).to(device)
loaded_critic.load_state_dict(torch.load(CRITIC_FILE))
loaded_critic.train()

Discriminator(
  (disc): Sequential(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (4): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (5): Conv2d(512, 1, kernel_size=(4, 4), stride=(2, 2))
  )
)

In [8]:
dataset = datasets.MNIST("dataset/", transform=transforms, download=True)
#test_dataset = torch.utils.data.Subset(dataset, range(10000))
loader = DataLoader(dataset, batch_size = batch_size, shuffle=True)

opt_critic = optim.Adam(loaded_critic.parameters(), lr=lr, betas=(0.0, 0.9))
opt_gen = optim.Adam(loaded_gen.parameters(), lr=lr, betas=(0.0, 0.9))

# gen.train()
# critic.train()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to dataset/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/train-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to dataset/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/train-labels-idx1-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw



  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [9]:
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.to(device)
        new_batch_size = real.shape[0]

        # Training critic
        for _ in range(critic_iterations):
          noise = torch.randn(new_batch_size, Z_DIM, 1, 1).to(device)
          fake = loaded_gen(noise)
          critic_real = loaded_critic(real).reshape(-1)
          critic_fake = loaded_critic(fake).reshape(-1)

          gp = gradient_penalty(loaded_critic, real, fake, device=device)
          loss_critic = (-(torch.mean(critic_real) - torch.mean(critic_fake)) + lambda_gp * gp)

          loaded_critic.zero_grad()
          loss_critic.backward(retain_graph=True)
          opt_critic.step()


        # Training Generator min -D(G(z))
        output = loaded_critic(fake).reshape(-1)
        loss_gen = -(torch.mean(output))
        loaded_gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()


        if batch_idx % 100 == 0:
            print(f"Epoch[{epoch+3}/{num_epochs}]  Batch[{batch_idx}/937] \n"F"Loss D: {loss_critic:.4f}, Loss G: {loss_gen:.4f}")

    save_noise = torch.randn(batch_size, Z_DIM, 1, 1).to(device)
    sample_8x8 = loaded_gen(save_noise).reshape(-1,1,IMG_SIZE,IMG_SIZE)
    filepath = (f"./{epoch+20}.png")
    visualize_images(sample_8x8*0.5+0.5, filepath=filepath, save_image=True)

Epoch[3/10]  Batch[0/937] 
Loss D: -4.7875, Loss G: -41.9317
Epoch[3/10]  Batch[100/937] 
Loss D: -6.4032, Loss G: -42.9390
Epoch[3/10]  Batch[200/937] 
Loss D: -5.5045, Loss G: -45.2581
Epoch[3/10]  Batch[300/937] 
Loss D: -8.0953, Loss G: -48.9978
Epoch[3/10]  Batch[400/937] 
Loss D: -5.1560, Loss G: -46.3866
Epoch[3/10]  Batch[500/937] 
Loss D: -5.1991, Loss G: -43.8552
Epoch[3/10]  Batch[600/937] 
Loss D: -4.4756, Loss G: -52.7531
Epoch[3/10]  Batch[700/937] 
Loss D: -3.4474, Loss G: -54.8127
Epoch[3/10]  Batch[800/937] 
Loss D: -4.8000, Loss G: -46.4909
Epoch[3/10]  Batch[900/937] 
Loss D: -3.5890, Loss G: -48.4664
Epoch[4/10]  Batch[0/937] 
Loss D: -3.7596, Loss G: -46.8140
Epoch[4/10]  Batch[100/937] 
Loss D: -3.7190, Loss G: -56.4249
Epoch[4/10]  Batch[200/937] 
Loss D: -5.2032, Loss G: -53.6386
Epoch[4/10]  Batch[300/937] 
Loss D: -5.4736, Loss G: -56.1157
Epoch[4/10]  Batch[400/937] 
Loss D: -5.3431, Loss G: -52.0358
Epoch[4/10]  Batch[500/937] 
Loss D: -5.4670, Loss G: -52.0

In [10]:
torch.save(loaded_gen.state_dict(), "./gen_3.pth")
torch.save(loaded_critic.state_dict(), "./critic_3.pth" )