In [None]:
import torch
import torch.nn as nn

## Model

In [None]:
class Generator(nn.Module):
    """DCGAN-style generator mapping a latent tensor to an RGB image.

    Five transposed-convolution stages upsample a ``(N, 128, 1, 1)`` latent
    to ``(N, 3, 64, 64)``: 1 -> 4 -> 8 -> 16 -> 32 -> 64 pixels per side,
    while channels shrink 512 -> 256 -> 128 -> 64 -> 3. The final ``Tanh``
    bounds pixel values to [-1, 1].
    """

    @staticmethod
    def _up_block(in_ch, out_ch, stride, padding):
        """ConvTranspose2d(kernel 4) -> BatchNorm2d -> ReLU as one nn.Sequential."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=stride, padding=padding),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        )

    def __init__(self):
        super().__init__()
        # 1x1 latent -> 4x4 feature map (stride 1, no padding).
        self.deconv1 = self._up_block(128, 512, stride=1, padding=0)
        # Each following stage doubles the spatial size and halves the channels.
        self.deconv2 = self._up_block(512, 256, stride=2, padding=1)
        self.deconv3 = self._up_block(256, 128, stride=2, padding=1)
        self.deconv4 = self._up_block(128, 64, stride=2, padding=1)
        # Output stage: no normalization; Tanh squashes to [-1, 1].
        self.deconv5 = nn.Sequential(
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        """Decode latents ``z`` into images by running the five stages in order."""
        stages = (self.deconv1, self.deconv2, self.deconv3, self.deconv4, self.deconv5)
        h = z
        for stage in stages:
            h = stage(h)
        return h

In [None]:
z = torch.randn(64, 128, 1, 1)
generator = Generator()
out = generator(z)
# Smoke test: a (N, 128, 1, 1) latent should decode to (N, 3, 64, 64) images.
assert out.shape == (64, 3, 64, 64), out.shape
out.shape

In [None]:
class Discriminator(nn.Module):
    """WGAN-GP critic: maps an RGB image to an unbounded real-valued score.

    Four stride-2 convolutions downsample 64 -> 32 -> 16 -> 8 -> 4 pixels per
    side, and a final 4x4 convolution without padding reduces the map to 1x1.
    A ``(N, 3, 64, 64)`` batch therefore yields scores of shape
    ``(N, 1, 1, 1)``. There is no sigmoid: the WGAN losses consume raw scores.

    Normalization is per-sample (``GroupNorm`` with a single group, i.e. layer
    normalization over C x H x W) rather than ``BatchNorm2d``. The gradient
    penalty constrains the gradient of each sample's score with respect to
    that sample. With batch norm, each score also depends on the other images
    in the batch, which invalidates the penalty. Gulrajani et al. (2017)
    recommend layer normalization in the critic for this reason.

    NOTE: state_dicts saved from the earlier BatchNorm version of this class
    (running_mean / running_var buffers) will not load into this one.
    """

    def __init__(self):
        super().__init__()

        # First block is left unnormalized.
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2))

        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.GroupNorm(1, 128),  # per-sample layer norm (see class docstring)
            nn.LeakyReLU(0.2))

        self.conv3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.GroupNorm(1, 256),
            nn.LeakyReLU(0.2))

        self.conv4 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.GroupNorm(1, 512),
            nn.LeakyReLU(0.2))

        # 4x4 -> 1x1 single-channel score.
        self.conv5 = nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0)

    def forward(self, img):
        """Return the critic score for each image in ``img``."""
        out = self.conv1(img)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        logit = self.conv5(out)
        return logit

In [None]:
discriminator = Discriminator()
logit = discriminator(out)
# Smoke test: one unbounded score per generated image, shaped (N, 1, 1, 1).
assert logit.shape == (out.shape[0], 1, 1, 1), logit.shape
logit.shape

## Loss

In [None]:
def d_loss_fn(r_logit, f_logit):
    """WGAN critic loss: mean(D(fake)) - mean(D(real)).

    Minimizing it pushes the critic to score real images higher than fakes.

    Args:
        r_logit: critic scores for real images.
        f_logit: critic scores for generated images.
    """
    fake_score = f_logit.mean()
    real_score = r_logit.mean()
    return fake_score - real_score

def g_loss_fn(f_logit):
    """WGAN generator loss: -mean(D(fake)); minimizing raises the fakes' scores."""
    return f_logit.mean().neg()

## Data

In [None]:
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CelebA

In [None]:
# CelebA preprocessing: random mirror, square center crop, downscale to
# 64x64, then map ToTensor's [0, 1] pixel range to [-1, 1] (the generator's
# Tanh output range).
# NOTE(review): the same pipeline is reused for the validation split, so
# validation images are randomly flipped as well.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.CenterCrop(size=148),
    transforms.Resize(size=64),
    transforms.ToTensor(),
    transforms.Normalize(std=[0.5, 0.5, 0.5], mean=[0.5, 0.5, 0.5]),
])

In [None]:
# Directory that must already contain the CelebA files (download=False).
DATA_ROOT = '../../data'

train_dataset = CelebA(root=DATA_ROOT, split='train', transform=transform, download=False)
# CelebA ships a dedicated 'valid' split; use it for validation so the
# 'test' split stays untouched for final evaluation.
val_dataset = CelebA(root=DATA_ROOT, split='valid', transform=transform, download=False)

In [None]:
TRAIN_BATCH_SIZE = 32
VAL_BATCH_SIZE = 144  # NOTE(review): 144 = 12 x 12, presumably for an image grid — confirm

# drop_last keeps every batch the same size.
train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, batch_size=TRAIN_BATCH_SIZE)
val_loader = DataLoader(val_dataset, shuffle=True, drop_last=True, batch_size=VAL_BATCH_SIZE)

In [None]:
# Peek at one training batch: shapes of its first two entries (images, targets).
batch = next(iter(train_loader))
tuple(batch[i].shape for i in range(2))

## Training

In [None]:
import torch.optim as optim

In [None]:
# Adam with beta1 = 0, beta2 = 0.9 as in the WGAN-GP paper (Gulrajani et al.,
# 2017, Algorithm 1). PyTorch's default betas (0.9, 0.999) add momentum that
# can destabilize critic training. The learning rate is kept at 2e-4.
ADAM_BETAS = (0.0, 0.9)
G_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=ADAM_BETAS)
D_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=ADAM_BETAS)

### Training discriminator

In [None]:
# A fresh iterator over the shuffled loader; keep only the image tensor.
real_batch = next(iter(train_loader))
x_real = real_batch[0]
x_real.shape

In [None]:
# Training discriminator
# Draw one latent per real image instead of a hard-coded 32. The gradient
# penalty's interpolation (`sample`) needs real and fake batches of equal
# size. 128 must match the Generator's input channels.
z = torch.randn(x_real.size(0), 128, 1, 1, device=x_real.device)
z.shape

In [None]:
# The critic step must not backpropagate into the generator. Generating
# under no_grad avoids recording the generator's autograd graph at all,
# instead of building it and then discarding it with .detach().
with torch.no_grad():
    x_fake = generator(z)
x_fake.shape

In [None]:
# Critic scores for the real batch.
x_real_d_logit = discriminator(img=x_real)
x_real_d_logit.size()

In [None]:
# Critic scores for the generated batch.
x_fake_d_logit = discriminator(img=x_fake)
x_fake_d_logit.size()

In [None]:
# WGAN critic loss on this batch (the gradient penalty is not included here).
d_loss = d_loss_fn(r_logit=x_real_d_logit, f_logit=x_fake_d_logit)
d_loss

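## Gradient Penalty

WGAN-GP regularizes the critic by pushing the L2 norm of its gradient towards 1. The gradient is evaluated at random interpolates $\hat{x} = \alpha\,x_{\text{fake}} + (1-\alpha)\,x_{\text{real}}$, with one $\alpha \sim U[0, 1)$ per sample:
$$\mathrm{GP} = \mathbb{E}_{\hat{x}}\big[(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2\big]$$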

In [None]:
(x_real.size(), x_fake.size())

In [None]:
# Return random interpolations between real and fake inputs.
def sample(real, fake):
    """Return per-sample random interpolates between ``real`` and ``fake``.

    For each sample i a single ``alpha_i ~ U[0, 1)`` is drawn, and the result
    is ``alpha_i * fake[i] + (1 - alpha_i) * real[i]``. These are the points
    at which the gradient penalty is evaluated.

    Args:
        real: batch of real inputs, batch dimension first (any rank >= 1).
        fake: batch of generated inputs with exactly the same shape.

    Returns:
        Tensor with the same shape as ``real``.

    Raises:
        ValueError: if ``real`` and ``fake`` differ in shape. Without this
            check they could broadcast silently into a wrong result.
    """
    if real.shape != fake.shape:
        raise ValueError(
            f"real and fake must have the same shape, got {tuple(real.shape)} "
            f"and {tuple(fake.shape)}")
    # One alpha per sample, broadcast over every non-batch dim. The previous
    # hard-coded [N, 1, 1, 1] only broadcast correctly for 4-D inputs.
    alpha_shape = (real.shape[0],) + (1,) * (real.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real.device)
    return alpha * fake + (1 - alpha) * real

In [None]:
# Random interpolates between the real and (detached) fake batches.
x_sample = sample(real=x_real, fake=x_fake)
x_sample.size()

In [None]:
# The full batch tensor is far too large to read. Show its value range
# instead; up to rounding, it lies within the [-1, 1] range of both inputs.
x_sample.min().item(), x_sample.max().item()

In [None]:
import functools  # no longer used in this cell

# Former TODO: "why is functools.partial needed? Without it, grad errors."
# It is not needed: functools.partial(discriminator) with no bound arguments
# behaves exactly like `discriminator`. The autograd.grad error comes from
# operation order. x_sample must require grad *before* the forward pass;
# otherwise `pred` is not connected to x_sample in the autograd graph. The
# error was most likely seen after out-of-order cell execution.
x_sample.requires_grad_(True)
pred = discriminator(x_sample)
pred.shape

In [None]:
x_sample.requires_grad_(True)
# Gradient of sum(pred) w.r.t. x_sample. This only works if `pred` was
# computed after x_sample started requiring grad. create_graph=True keeps the
# graph so a penalty built from this gradient can itself be backpropagated.
grad, = torch.autograd.grad(
    outputs=pred,
    inputs=x_sample,
    grad_outputs=torch.ones_like(pred),
    create_graph=True,
)
grad.shape

In [None]:
# Per-sample L2 norm of the gradient, taken over all non-batch dims.
grad_norm = grad.flatten(start_dim=1).norm(p=2, dim=1)
grad_norm.size()

In [None]:
# Two-sided penalty: mean squared deviation of the gradient norm from 1.
gradient_penalty = (grad_norm - 1).pow(2).mean()
gradient_penalty

## Training Generator

In [None]:
# Fresh latents for the generator step, one per real image, so the batch
# size follows train_loader instead of a hard-coded 32. No detach here: the
# generator loss has to backpropagate into the generator.
z = torch.randn(x_real.size(0), 128, 1, 1, device=x_real.device)
x_fake = generator(z)
x_fake.shape

In [None]:
# Critic scores for the fakes; gradients flow back through x_fake to the generator.
x_fake_d_logit = discriminator(img=x_fake)
x_fake_d_logit.size()

In [None]:
# WGAN generator loss on this batch.
g_loss = g_loss_fn(f_logit=x_fake_d_logit)
g_loss