In [1]:
from torch.utils.data import Dataset, DataLoader
import torchvision
import numpy as np
from torchvision import transforms
from torch import nn
import torch
from torch import optim
import tensorboard
import torchvision
from torch.utils.tensorboard import SummaryWriter
import torchvision.datasets as datasets

torch.cuda.empty_cache()

In [2]:
celeb_data_path = "./celeb_dataset/"
bitmoji_path = "./Bitmoji-Faces/"

img_size = 64
no_channels = 3
batch_size = 64

transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((img_size, img_size)),
        transforms.Normalize(
            [0.5 for _ in range(no_channels)], [0.5 for _ in range(no_channels)]
        ),
    ]
)

celeb_dataset = datasets.ImageFolder(root="./celeb_dataset/", transform=transform)
bitmoji_dataset = datasets.ImageFolder(root=bitmoji_path, transform=transform)

celeb_dataloader = DataLoader(celeb_dataset, batch_size, shuffle=True)
bitmoji_dataloader = DataLoader(bitmoji_dataset, batch_size, shuffle=True)

In [3]:
class Discriminator(nn.Module):
    def __init__(self, channels_img, features_d):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            # input: N x channels_img x 64 x 64
            nn.Conv2d(
                channels_img, features_d, kernel_size=4, stride=2, padding=1
            ),
            nn.LeakyReLU(0.2),
            # _block(in_channels, out_channels, kernel_size, stride, padding)
            self._block(features_d, features_d * 2, 4, 2, 1),
            self._block(features_d * 2, features_d * 4, 4, 2, 1),
            self._block(features_d * 4, features_d * 8, 4, 2, 1),
            # After all _block img output is 4x4 (Conv2d below makes into 1x1)
            nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=2, padding=0),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
            nn.InstanceNorm2d(out_channels, affine=True),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.disc(x)


class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
            if down
            else nn.ConvTranspose2d(in_channels, out_channels, **kwargs),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True) if use_act else nn.Identity()
        )

    def forward(self, x):
        return self.conv(x)

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            ConvBlock(channels, channels, kernel_size=3, padding=1),
            ConvBlock(channels, channels, use_act=False, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return x + self.block(x)

class Generator(nn.Module):
    def __init__(self, img_channels, num_features = 64, num_residuals=9):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(img_channels, num_features, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
            nn.InstanceNorm2d(num_features),
            nn.ReLU(inplace=True),
        )
        self.down_blocks = nn.ModuleList(
            [
                ConvBlock(num_features, num_features*2, kernel_size=3, stride=2, padding=1),
                ConvBlock(num_features*2, num_features*4, kernel_size=3, stride=2, padding=1),
            ]
        )
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(num_features*4) for _ in range(num_residuals)]
        )
        self.up_blocks = nn.ModuleList(
            [
                ConvBlock(num_features*4, num_features*2, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
                ConvBlock(num_features*2, num_features*1, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
            ]
        )

        self.last = nn.Conv2d(num_features*1, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")

    def forward(self, x):
        x = self.initial(x)
        for layer in self.down_blocks:
            x = layer(x)
        x = self.res_blocks(x)
        for layer in self.up_blocks:
            x = layer(x)
        return torch.tanh(self.last(x))


def initialize_weights(model):
    # Initializes weights according to the DCGAN paper
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        if isinstance(m,nn.BatchNorm2d):
            nn.init.normal_(m.weight.data, 1.0, 0.02)

In [4]:
def gradient_penalty(critic, real, fake, device="cpu"):
    BATCH_SIZE, C, H, W = real.shape
    alpha = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated_images = real * alpha + fake * (1 - alpha)

    # Calculate critic scores
    mixed_scores = critic(interpolated_images)

    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [6]:
import random

LEARNING_RATE = 1e-5
EPOCHS = 5
GRADIENT_PENALITY_LAMBDA = 10
LAMBDA_CYCLE = 10
LAMBDA_IDENTITY = 1

gen_x = Generator(3).to(device)
gen_y = Generator(3).to(device)
disc_x = Discriminator(3, 64).to(device)
disc_y = Discriminator(3, 64).to(device)

writer_1 = SummaryWriter(f"logs/real")
writer_2 = SummaryWriter(f"logs/fake")

opt_gen = optim.Adam(
    list(gen_x.parameters()) + list(gen_y.parameters()),
    LEARNING_RATE,
    (0.5, 0.999)
)

opt_disc_x = optim.Adam(disc_x.parameters(), LEARNING_RATE, (0.5, 0.999))
opt_disc_y = optim.Adam(disc_y.parameters(), LEARNING_RATE, (0.5, 0.999))

l1 = nn.L1Loss()
mse = nn.MSELoss()

# for tensorboard plotting
writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")
step = 0

gen_x.train()
gen_y.train()
disc_x.train()
disc_y.train()

num_iters = 100000
step = 0

for epoch in range(EPOCHS):
    for iter in range(num_iters):
        rand_x = random.randint(0, num_iters)
        rand_y = random.randint(0, num_iters)

        real_x = celeb_dataset[rand_x]
        real_y = bitmoji_dataset[rand_y]
        real_x = real_x.to(device)
        real_y = real_y.to(device)

        # Train discriminator x
        fake_x = gen_x(real_x)
        d_x_real = disc_x(real_x).reshape(-1)
        d_x_fake = disc_x(fake_x.detach()).reshape(-1)
        gp = gradient_penalty(disc_x, real_x, fake_x, device)
        loss_disc_x = -(torch.mean(d_x_real) - torch.mean(d_x_fake)) + GRADIENT_PENALITY_LAMBDA * gp
        opt_disc_x.zero_grad()
        loss_disc_x.backward(retain_graph=True)
        opt_disc_x.step()

        #Train discriminator y
        fake_y = gen_x(real_y)
        d_y_real = disc_y(real_y).reshape(-1)
        d_y_fake = disc_y(fake_y.detach()).reshape(-1)
        gp = gradient_penalty(disc_y, real_y, fake_y, device)
        loss_disc_y = -(torch.mean(d_y_real) - torch.mean(d_y_fake)) + GRADIENT_PENALITY_LAMBDA * gp
        opt_disc_y.zero_grad()
        loss_disc_y.backward(retain_graph=True)
        opt_disc_y.step()

        # Adveserial Loss
        gen_x_fake = disc_x(fake_x).reshape(-1)
        loss_gen_x = -torch.mean(gen_x_fake)
        gen_Y_fake = disc_y(fake_y).reshape(-1)
        loss_gen_y = -torch.mean(gen_Y_fake)

        # Cycle Loss
        cycle_x = gen_x(fake_y)
        cycle_y = gen_y(fake_x)
        cycle_x_loss = l1(real_x, cycle_x)
        cycle_y_loss = l1(real_y, cycle_y)

        # Identity Loss
        id_x = gen_x(real_x)
        id_y = gen_y(real_y)
        id_x_loss = l1(real_x, id_x)
        id_y_loss = l1(real_y, id_y)

        net_gen_loss = (
            loss_gen_x +
            loss_gen_y +
            (cycle_x_loss * LAMBDA_CYCLE) +
            (cycle_y_loss * LAMBDA_CYCLE) +
            (id_x_loss * LAMBDA_IDENTITY) +
            (id_y_loss * LAMBDA_IDENTITY)
        )

        opt_gen.zero_grad()
        net_gen_loss.backward()
        opt_gen.step()

        if iter % 100 == 0:
            with torch.no_grad():
                index = random.randint(0, num_iters)

                x = celeb_dataset[index]
                y = bitmoji_dataset[index]

                fake_x = gen_x(y)
                fake_y = gen_y(x)
                cycle_x = gen_x(fake_y)
                cycle_y = gen_y(fake_x)

                img_grid_1 = torchvision.utils.make_grid([x, fake_y, cycle_x], nrow=3, normalize=True)
                img_grid_2 = torchvision.utils.make_grid([y, fake_x, cycle_y], nrow=3, normalize=True)

                writer_1.add_image("Domain 1", img_grid_1, step)
                writer_2.add_image("Domain 2", img_grid_2, step)

                step += 1

AttributeError: 'tuple' object has no attribute 'to'

In [7]:
gen_x

Generator(
  (initial): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), padding_mode=reflect)
    (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (2): ReLU(inplace=True)
  )
  (down_blocks): ModuleList(
    (0): ConvBlock(
      (conv): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), padding_mode=reflect)
        (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (2): ReLU(inplace=True)
      )
    )
    (1): ConvBlock(
      (conv): Sequential(
        (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), padding_mode=reflect)
        (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (2): ReLU(inplace=True)
      )
    )
  )
  (res_blocks): Sequential(
    (0): ResidualBlock(
      (block): Sequential(
        (0): ConvBlock(
          (conv): S

In [None]:
gen_y

In [None]:
disc_x

In [None]:
disc_y