In [1]:
# %%script /Library/Frameworks/Python.framework/Versions/3.9/bin/python3

import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
from torchvision import datasets
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, ConcatDataset

def visualize_images(batch, filepath=None, save_image=False, single=False):
    """Tile a batch of images into one grid and either save it or show it.

    Args:
        batch: image tensor accepted by ``torchvision.utils.make_grid``
            (e.g. shape (B, C, H, W)). ``make_grid`` is called without
            ``normalize=True``, so values are used as-is; map them to [0, 1]
            before calling.
        filepath: destination image path; required when ``save_image`` is True.
        save_image: if True, write the grid to ``filepath`` instead of
            displaying it.
        single: lay the images out one per row instead of 8 per row.

    Raises:
        ValueError: if ``save_image`` is True but ``filepath`` is None.
    """
    grid = torchvision.utils.make_grid(batch, nrow=1 if single else 8)

    if save_image:
        if filepath is None:
            raise ValueError("filepath is required when save_image=True")
        torchvision.utils.save_image(grid, filepath)
        return

    # matplotlib is only needed for on-screen display, so import it lazily.
    import matplotlib.pyplot as plt

    plt.figure(figsize=(100, 100))
    # make_grid returns (C, H, W); imshow expects (H, W, C) on the CPU and
    # without autograd history (e.g. when batch comes straight from a generator).
    plt.imshow(grid.permute(1, 2, 0).detach().cpu())
    plt.show()

## Initializing Models

In [2]:
class Discriminator(nn.Module):
  """Convolutional critic used for WGAN-GP training.

  Every stride-2 / kernel-4 / padding-1 convolution halves the spatial size.
  For a 64x64 input (the size used in this notebook) the resolution goes
  64 -> 32 -> 16 -> 8 -> 4, and the final kernel-4 / padding-0 convolution
  reduces 4x4 to 1x1. The result is one score per image, shape (N, 1, 1, 1).
  No sigmoid is applied: the output is an unbounded critic score, not a
  probability.

  Args:
    channels_img: number of channels of the input images.
    features_d: base channel width; hidden layers use 1x, 2x, 4x and 8x of it.
  """

  def __init__(self, channels_img, features_d):
    super().__init__()
    width = features_d
    self.disc = nn.Sequential(
        nn.Conv2d(channels_img, width, kernel_size=4, stride=2, padding=1),  # 64 -> 32
        nn.LeakyReLU(0.2),
        self._block(width, 2 * width, 4, 2, 1),      # 32 -> 16
        self._block(2 * width, 4 * width, 4, 2, 1),  # 16 -> 8
        self._block(4 * width, 8 * width, 4, 2, 1),  # 8 -> 4
        nn.Conv2d(8 * width, 1, kernel_size=4, stride=2, padding=0),  # 4 -> 1, raw score
    )

  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
    """Build one downsampling stage: bias-free Conv2d -> affine InstanceNorm2d -> LeakyReLU(0.2)."""
    # The conv has no bias; the affine InstanceNorm that follows supplies a per-channel shift.
    conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
    norm = nn.InstanceNorm2d(out_channels, affine=True)
    return nn.Sequential(conv, norm, nn.LeakyReLU(0.2))

  def forward(self, x):
    """Return the critic score tensor for a batch of images ``x``."""
    scores = self.disc(x)
    return scores

In [3]:
class Generator(nn.Module):
  """DCGAN-style generator mapping latent noise to images.

  Input noise has shape (N, z_dim, 1, 1). The first transposed conv
  (kernel 4, stride 1, padding 0) expands 1x1 to 4x4. Each following
  kernel-4 / stride-2 / padding-1 transposed conv doubles the size:
  4 -> 8 -> 16 -> 32 -> 64. The output has shape (N, channels_img, 64, 64).
  The final Tanh bounds pixels to [-1, 1], matching images normalised with
  mean 0.5 / std 0.5.

  Args:
    z_dim: length of the latent noise vector.
    channels_img: number of output image channels.
    features_g: base channel width; hidden layers use 16x, 8x, 4x and 2x of it.
  """

  def __init__(self, z_dim, channels_img, features_g):
    super().__init__()
    width = features_g
    self.gen = nn.Sequential(
        self._block(z_dim, 16 * width, 4, 1, 0),      # 1 -> 4
        self._block(16 * width, 8 * width, 4, 2, 1),  # 4 -> 8
        self._block(8 * width, 4 * width, 4, 2, 1),   # 8 -> 16
        self._block(4 * width, 2 * width, 4, 2, 1),   # 16 -> 32
        nn.ConvTranspose2d(2 * width, channels_img, kernel_size=4, stride=2, padding=1),  # 32 -> 64
        nn.Tanh(),
    )

  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
    """Build one upsampling stage: bias-free ConvTranspose2d -> BatchNorm2d -> ReLU."""
    upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
    norm = nn.BatchNorm2d(out_channels)
    return nn.Sequential(upconv, norm, nn.ReLU())

  def forward(self, x):
    """Generate a batch of images from latent noise ``x``."""
    images = self.gen(x)
    return images

In [4]:
def gradient_penalty(critic, real, fake, device="cpu"):
  """Compute the WGAN-GP gradient penalty.

  Draws one mixing coefficient epsilon ~ U[0, 1) per sample and evaluates the
  critic at ``real * eps + fake * (1 - eps)``. The gradient of the scores
  w.r.t. those points is flattened per sample, and the squared deviation of
  its L2 norm from 1 is averaged over the batch.

  Args:
    critic: callable mapping a batch shaped like ``real`` to scores whose
      first dimension is the batch.
    real: batch of real samples, shape (N, ...). Any number of trailing
      dimensions is accepted.
    fake: batch of generated samples, same shape as ``real``.
    device: device the mixing coefficients are moved to; must match the
      device of ``real`` / ``fake``.

  Returns:
    0-dim tensor: mean over the batch of (||grad||_2 - 1) ** 2. The gradient
    is taken with ``create_graph=True``, so the penalty can itself be
    backpropagated.
  """
  batch_size = real.shape[0]
  # One coefficient per sample, broadcast over all remaining dimensions.
  epsilon = torch.rand((batch_size,) + (1,) * (real.dim() - 1)).to(device)
  interpolated_images = real * epsilon + fake * (1 - epsilon)
  if not interpolated_images.requires_grad:
    # Neither input carries autograd history (e.g. ``fake`` was detached), so
    # the interpolation is a plain leaf tensor. Mark it so that the critic's
    # gradient w.r.t. it can be computed.
    interpolated_images.requires_grad_(True)

  mixed_scores = critic(interpolated_images)
  gradient = torch.autograd.grad(
      inputs=interpolated_images,
      outputs=mixed_scores,
      grad_outputs=torch.ones_like(mixed_scores),
      create_graph=True,
      retain_graph=True,
  )[0]

  # reshape (not view): the returned gradient is not guaranteed to be contiguous.
  gradient = gradient.reshape(batch_size, -1)
  gradient_norm = gradient.norm(2, dim=1)
  penalty = torch.mean((gradient_norm - 1) ** 2)
  return penalty

In [5]:
def initialize_weights(model, mean=0.0, std=0.02):
  """Redraw the weights of conv and batch-norm layers from a normal distribution, in place.

  Every ``nn.Conv2d``, ``nn.ConvTranspose2d`` and ``nn.BatchNorm2d`` found by
  ``model.modules()`` gets its ``weight`` redrawn from N(mean, std ** 2).
  Biases and all other module types (e.g. ``nn.InstanceNorm2d``) keep their
  existing values.

  Args:
    model: the ``nn.Module`` to initialise (modified in place).
    mean: mean of the normal distribution (default 0.0, as originally used).
    std: standard deviation of the normal distribution (default 0.02).
  """
  for module in model.modules():
    if not isinstance(module, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
      continue
    if module.weight is None:
      # BatchNorm2d(affine=False) has no learnable weight to initialise.
      continue
    # NOTE(review): DCGAN-style setups commonly centre BatchNorm weights at 1.0.
    # With mean=0.0 the BN outputs start close to zero. Confirm this is intended.
    nn.init.normal_(module.weight, mean, std)

## Training

In [6]:
# Hyperparameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 0
# Makes weight init, noise and shuffling repeatable across runs
# (GPU kernels may still introduce some nondeterminism).
torch.manual_seed(SEED)

lr = 1e-4                 # Adam learning rate for both networks
Z_DIM = 100               # length of the generator's latent noise vector
batch_size = 64
features_critic = 64      # base channel width of the critic
features_gen = 64         # base channel width of the generator
critic_iterations = 5     # critic updates per generator update
lambda_gp = 10            # weight of the gradient-penalty term in the critic loss
SHVN_IMG_SIZE = 64        # images are resized to 64x64 (SVHN is stored as 32x32)
SHVN_CHANNELS_IMG = 3
shvn_num_epochs = 20

# Initializing the models
shvn_gen = Generator(Z_DIM, SHVN_CHANNELS_IMG, features_gen).to(device)
initialize_weights(shvn_gen)
shvn_gen.train()

shvn_critic = Discriminator(SHVN_CHANNELS_IMG, features_critic).to(device)
initialize_weights(shvn_critic)
shvn_critic.train()

  after removing the cwd from sys.path.


Discriminator(
  (disc): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (4): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (5): Conv2d(512, 1, kernel_size=(4, 4), stride=(2, 2))
  )
)

In [7]:
# Normalizing the data: resize every image to SHVN_IMG_SIZE x SHVN_IMG_SIZE and
# map pixels from [0, 1] to [-1, 1] via (x - 0.5) / 0.5. This matches the
# generator's Tanh output range.
# Named img_transform so it does not shadow the torchvision `transforms` module.
# Rebinding `transforms` made this cell fail with AttributeError when re-run.
img_transform = transforms.Compose(
    [
        # An explicit (h, w) size guarantees square outputs for every source,
        # so SVHN and the generated images can be batched together.
        transforms.Resize((SHVN_IMG_SIZE, SHVN_IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * SHVN_CHANNELS_IMG, [0.5] * SHVN_CHANNELS_IMG),
    ]
)

SVHN_SUBSET_SIZE = 30000  # only the first 30k SVHN training images are used
shvn_pre_dataset = datasets.SVHN(root='data/', split='train', download=True, transform=img_transform)
shvn_dataset = torch.utils.data.Subset(shvn_pre_dataset, range(SVHN_SUBSET_SIZE))

generated_mnist = ImageFolder(root="../input/wgan-generated-mnist/", transform=img_transform)
dataset = ConcatDataset([shvn_dataset, generated_mnist])
loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

# betas=(0.0, 0.9) as in the WGAN-GP setup; same optimiser settings for both networks.
opt_shvn_critic = optim.Adam(shvn_critic.parameters(), lr=lr, betas=(0.0, 0.9))
opt_shvn_gen = optim.Adam(shvn_gen.parameters(), lr=lr, betas=(0.0, 0.9))

Downloading http://ufldl.stanford.edu/housenumbers/train_32x32.mat to data/train_32x32.mat


  0%|          | 0/182040794 [00:00<?, ?it/s]

In [None]:
# Each loader step runs `critic_iterations` critic updates on the same real
# batch, then one generator update.
for epoch in range(shvn_num_epochs):
  for shvn_batch_idx, (shvn_real, _) in enumerate(loader):
    shvn_real = shvn_real.to(device)
    # The last batch of an epoch may be smaller than batch_size.
    shvn_new_batch_size = shvn_real.shape[0]

    # Training critic: minimise -(E[D(real)] - E[D(fake)]) + lambda_gp * GP
    for _ in range(critic_iterations):
      noise = torch.randn(shvn_new_batch_size, Z_DIM, 1, 1).to(device)
      shvn_fake = shvn_gen(noise)
      critic_real = shvn_critic(shvn_real).reshape(-1)
      critic_fake = shvn_critic(shvn_fake).reshape(-1)

      gp = gradient_penalty(shvn_critic, shvn_real, shvn_fake, device=device)
      loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake)) + lambda_gp * gp

      shvn_critic.zero_grad()
      # retain_graph keeps the generator part of shvn_fake's graph alive,
      # because the generator step below backpropagates through the last
      # shvn_fake again.
      loss_critic.backward(retain_graph=True)
      opt_shvn_critic.step()

    # Training Generator: min -D(G(z)), reusing the last fake batch
    output = shvn_critic(shvn_fake).reshape(-1)
    loss_gen = -torch.mean(output)
    # Also discards generator grads accumulated by the critic backward passes.
    shvn_gen.zero_grad()
    loss_gen.backward()
    opt_shvn_gen.step()

    if shvn_batch_idx % 300 == 0:
      print(f"Epoch[{epoch+1}/{shvn_num_epochs}] \n" f"Loss D: {loss_critic.item():.4f}, Loss G: {loss_gen.item():.4f}")

  # Save a grid of generated samples after every epoch (no graph needed).
  with torch.no_grad():
    save_noise = torch.randn(batch_size, Z_DIM, 1, 1).to(device)
    # Keep the colour channels together. Reshaping to (-1, 1, H, W) would split
    # each RGB sample into three separate grayscale tiles.
    sample_8x8 = shvn_gen(save_noise).reshape(-1, SHVN_CHANNELS_IMG, SHVN_IMG_SIZE, SHVN_IMG_SIZE)
  # NOTE(review): the +21 offset presumably continues file numbering from an
  # earlier 20-epoch run. Confirm before starting from scratch.
  filepath = f"./{epoch+21}.png"
  # Map the Tanh output from [-1, 1] back to [0, 1] before saving.
  visualize_images(sample_8x8 * 0.5 + 0.5, filepath=filepath, save_image=True)

## Saving the Model

In [9]:
# Persist only the learned parameters (state_dicts) of both networks.
for net, path in ((shvn_gen, "./shvn_gen.pth"), (shvn_critic, "./shvn_critic.pth")):
  torch.save(net.state_dict(), path)