In [1]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader

import torchvision
from torchvision import datasets, transforms
from torchvision.utils import save_image

from torch.utils.tensorboard import SummaryWriter

root_path = '/mnt/c/Users/121js/OneDrive/Desktop/TorchImages/'

# Make Critic and Generator classes

In [2]:
class Critic(nn.Module):
    def __init__(self, num_classes, img_size, img_channels=3, features=64):
        super().__init__()
        self.img_size = img_size
        self.crit = nn.Sequential(
            self._make_conv2d_block(img_channels+1, features, use_bn=False),

            self._make_conv2d_block(features, features*2),
            self._make_conv2d_block(features*2, features*4),
            self._make_conv2d_block(features*4, features*8),

            self._make_conv2d_block(features*8, 1, 4, 2, 0, use_act=False, use_bn=False),
        )
        self.embed = nn.Embedding(num_classes, img_size*img_size)

    def _make_conv2d_block(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, use_bn=True, use_act=True, leak=0.2):
        layers = [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
        ]
        if use_bn:
            layers.append(nn.InstanceNorm2d(out_channels))
        if use_act:
            layers.append(nn.LeakyReLU(leak))
        return nn.Sequential(*layers)
    
    def forward(self, x, labels):
        embedding = self.embed(labels).view(labels.shape[0], 1, self.img_size, self.img_size)
        x = torch.cat([x, embedding], axis=1)
        return self.crit(x)
    

class Generator(nn.Module):
    def __init__(self, num_classes, embed_size, z_dim=100, features=64, img_channels=3) -> None:
        super().__init__()
        self.gen = nn.Sequential(
            self._make_convT2d_block(z_dim+embed_size, features*8, 4, 1, 0),

            self._make_convT2d_block(features*8, features*4),
            self._make_convT2d_block(features*4, features*2),
            self._make_convT2d_block(features*2, features),

            self._make_convT2d_block(features, img_channels, use_bn=False, use_act=False),
            nn.Tanh(),
        )
        self.embed = nn.Embedding(num_classes, embed_size)

    def _make_convT2d_block(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, use_bn=True, use_act=True):
        layers = [
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            )
        ]
        if use_bn:
            layers.append(nn.BatchNorm2d(out_channels))
        if use_act:
            layers.append(nn.ReLU())
        return nn.Sequential(*layers)        
    
    def forward(self, x, labels):
        embedding = self.embed(labels).unsqueeze(2).unsqueeze(3)
        x = torch.cat([x, embedding], dim=1)
        return self.gen(x)
    
def initialize_weights(model):
    for m in model.modules():
        if isinstance(m , (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            nn.init.normal_(
                tensor=m.weight.data,
                mean=0.0,
                std=0.02,
            )

def find_gradient_penalty(critic, labels, real, fake, device='gpu'):
    b, c, h, w = real.shape
    alpha = torch.rand((b, 1, 1, 1)).repeat(1, c, h, w).to(device)
    interpolated_batch = real*alpha + fake*(1-alpha)
    critic_scores = critic(interpolated_batch, labels)

    gradient = torch.autograd.grad(
        outputs=critic_scores,
        inputs=interpolated_batch,
        grad_outputs=torch.ones_like(critic_scores),
        retain_graph=True,
        create_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm-1)**2)
    return gradient_penalty

def save_checkpoint(state, filename='mnist_wgangp.pth.tar'):
    print("=> Saving Checkpoint <=")
    torch.save(state, filename)

def load_checkpoint(checkpoint, gen, crit):
    print("=> Loading Checkpoint <=")
    gen.load_state_dict(checkpoint['gen'])
    crit.load_state_dict(checkpoint['crit'])

# Testing

In [3]:
# BATCH_SIZE, NUM_CHANNELS, H, W = 8, 3, 64, 64
# Z_DIM = 100
# FEATURES = 64

# img_test = torch.randn((BATCH_SIZE, NUM_CHANNELS, H, W))
# crit = Critic(10, 64, NUM_CHANNELS, FEATURES)
# assert crit(img_test).shape == (BATCH_SIZE, 1, 1, 1), 'Crit Test Failed'

# noise_test = torch.randn((BATCH_SIZE, Z_DIM, 1, 1))
# gen = Generator(10, 100, Z_DIM, FEATURES, NUM_CHANNELS)
# assert gen(noise_test).shape == (BATCH_SIZE, NUM_CHANNELS, H, W), 'Gen Test Failed'

# print('Success!')

# Hyperparameters

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 128
NUM_CHANNELS = 1
IMG_SIZE = 64
NUM_CLASSES = 10
GEN_EMBEDDING = 100
Z_DIM = 100
FEATURES_CRIT = 64
FEATURES_GEN = 128
LR = 1e-4
NUM_EPOCHS = 50
CRIT_ITER = 5
GRID_SHOW = 10
LAMBDA_GP = 10

# SetUp

In [5]:
crit = Critic(NUM_CLASSES, IMG_SIZE, NUM_CHANNELS, FEATURES_CRIT).to(device)
gen = Generator(NUM_CLASSES, GEN_EMBEDDING, Z_DIM, FEATURES_GEN, NUM_CHANNELS).to(device)
initialize_weights(crit), initialize_weights(gen)
fixed_labels = torch.arange(10).to(device)

transformations = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5 for _ in range(NUM_CHANNELS)], [0.5 for _ in range(NUM_CHANNELS)]
        ),
    ]
)
# images = datasets.ImageFolder(root=root_path + 'celebs', transform=transformations)
images = datasets.MNIST(root='mnist', train=True, transform=transformations, download=True)
images_loader = DataLoader(images, batch_size=BATCH_SIZE, shuffle=True)

optim_crit = optim.Adam(crit.parameters(), lr=LR, betas=(0.5, 0.9))
optim_gen = optim.Adam(gen.parameters(), lr=LR, betas=(0.5, 0.9))

writer_fake = SummaryWriter('logs/fake')
writer_real = SummaryWriter('logs/real')
steps = 1

In [6]:
def model_size(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Generator Size: {model_size(gen):.3e}, Discriminator Size: {model_size(crit):.3e}')

Generator Size: 1.429e+07, Discriminator Size: 2.804e+06


# Training

In [6]:
crit.train(), gen.train()

for epoch in range(NUM_EPOCHS):
    for batch_idx, (real, labels) in enumerate(images_loader):
        real = real.to(device)
        labels = labels.to(device)

        for _ in range(CRIT_ITER):
            noise = torch.randn((real.shape[0], Z_DIM, 1, 1)).to(device)
            fake = gen(noise, labels)

            crit_real = crit(real, labels).view(-1)
            crit_fake = crit(fake, labels).view(-1)
            gp = find_gradient_penalty(crit, labels, real, fake, device)
            loss_crit = (
                -(torch.mean(crit_real) - torch.mean(crit_fake)) + LAMBDA_GP*gp
                )
            crit.zero_grad()
            loss_crit.backward(retain_graph=True)
            optim_crit.step()

        crit_fake = crit(fake, labels).view(-1)
        loss_gen = -torch.mean(crit_fake)

        gen.zero_grad()
        loss_gen.backward()
        optim_gen.step()

        if batch_idx % 40 == 0:
            print(
                f'Epoch: {epoch+1}/{NUM_EPOCHS} -- Step: {steps} -- Batch: {batch_idx+1}/{len(images_loader)} -- Crit Loss: {loss_crit:.4f} -- Gen Loss: {loss_gen:.4f}'
            )
            with torch.no_grad():
                fixed_noise = torch.randn((GRID_SHOW, Z_DIM, 1, 1)).to(device)
                fake = gen(fixed_noise, fixed_labels)

                fake_images = torchvision.utils.make_grid(fake, nrow=5, normalize=True)
                save_image(fake_images, f'wgangp_results/{steps}.png')
                real_images = torchvision.utils.make_grid(real[:GRID_SHOW], nrow=5, normalize=True)
                writer_fake.add_scalar('Gen Loss', loss_gen, global_step=steps)
                writer_fake.add_image(
                    'Fake', fake_images, global_step=steps
                )
                writer_real.add_image(
                    'Real', real_images, global_step=steps
                )
                steps += 1

Epoch: 1/50 -- Step: 1 -- Batch: 1/469 -- Crit Loss: -10.7217 -- Gen Loss: 9.2009
Epoch: 1/50 -- Step: 2 -- Batch: 41/469 -- Crit Loss: -131.0640 -- Gen Loss: 110.4134
Epoch: 1/50 -- Step: 3 -- Batch: 81/469 -- Crit Loss: -128.1404 -- Gen Loss: 120.5416
Epoch: 1/50 -- Step: 4 -- Batch: 121/469 -- Crit Loss: -115.0725 -- Gen Loss: 109.5433
Epoch: 1/50 -- Step: 5 -- Batch: 161/469 -- Crit Loss: -101.7197 -- Gen Loss: 99.7237
Epoch: 1/50 -- Step: 6 -- Batch: 201/469 -- Crit Loss: -87.0332 -- Gen Loss: 95.6893
Epoch: 1/50 -- Step: 7 -- Batch: 241/469 -- Crit Loss: -71.2586 -- Gen Loss: 90.8242
Epoch: 1/50 -- Step: 8 -- Batch: 281/469 -- Crit Loss: -58.5242 -- Gen Loss: 92.8212
Epoch: 1/50 -- Step: 9 -- Batch: 321/469 -- Crit Loss: -46.1799 -- Gen Loss: 90.5199
Epoch: 1/50 -- Step: 10 -- Batch: 361/469 -- Crit Loss: -38.3794 -- Gen Loss: 94.3378
Epoch: 1/50 -- Step: 11 -- Batch: 401/469 -- Crit Loss: -31.8554 -- Gen Loss: 97.4765
Epoch: 1/50 -- Step: 12 -- Batch: 441/469 -- Crit Loss: -26.9