In [28]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os


In [29]:
IMG_SIZE = 128
BATCH_SIZE = 4
EPOCHS = 200
LAMBDA_L1 = 100
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


In [30]:
class FaceDataset(Dataset):
    def __init__(self, root):
        self.real_dir = os.path.join(root, "real")
        self.comic_dir = os.path.join(root, "comic")
        self.images = os.listdir(self.real_dir)

        self.transform = transforms.Compose([
            transforms.Resize((IMG_SIZE, IMG_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize([0.5]*3, [0.5]*3)
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        name = self.images[idx]
        real = Image.open(os.path.join(self.real_dir, name)).convert("RGB")
        comic = Image.open(os.path.join(self.comic_dir, name)).convert("RGB")

        return self.transform(real), self.transform(comic)


In [31]:
class UNetGenerator(nn.Module):
    def __init__(self):
        super().__init__()

        def down(in_c, out_c):
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, 4, 2, 1),
                nn.BatchNorm2d(out_c),
                nn.LeakyReLU(0.2)
            )

        def up(in_c, out_c):
            return nn.Sequential(
                nn.ConvTranspose2d(in_c, out_c, 4, 2, 1),
                nn.BatchNorm2d(out_c),
                nn.ReLU()
            )

        self.d1 = down(3, 64)
        self.d2 = down(64, 128)
        self.d3 = down(128, 256)
        self.d4 = down(256, 512)

        self.u1 = up(512, 256)
        self.u2 = up(512, 128)
        self.u3 = up(256, 64)

        self.final = nn.Sequential(
            nn.ConvTranspose2d(128, 3, 4, 2, 1),
            nn.Tanh()
        )

    def forward(self, x):
        d1 = self.d1(x)
        d2 = self.d2(d1)
        d3 = self.d3(d2)
        d4 = self.d4(d3)

        u1 = self.u1(d4)
        u2 = self.u2(torch.cat([u1, d3], 1))
        u3 = self.u3(torch.cat([u2, d2], 1))

        return self.final(torch.cat([u3, d1], 1))


In [32]:
class PatchDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(6, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 1, 4, 1, 1)
        )

    def forward(self, x, y):
        return self.model(torch.cat([x, y], 1))


In [33]:
dataset = FaceDataset("train")
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

G = UNetGenerator().to(DEVICE)
G.eval()
D = PatchDiscriminator().to(DEVICE)

opt_G = optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_D = optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))

criterion_GAN = nn.BCEWithLogitsLoss()
criterion_L1 = nn.L1Loss()

for epoch in range(EPOCHS):
    for real, comic in loader:
        real, comic = real.to(DEVICE), comic.to(DEVICE)

        # ---------------------
        # Train Discriminator
        # ---------------------
        fake = G(real)

        D_real = D(real, comic)
        D_fake = D(real, fake.detach())

        loss_D = (criterion_GAN(D_real, torch.ones_like(D_real)) +
                  criterion_GAN(D_fake, torch.zeros_like(D_fake))) * 0.5

        opt_D.zero_grad()
        loss_D.backward()
        opt_D.step()

        # -----------------
        # Train Generator
        # -----------------
        D_fake = D(real, fake)
        loss_G = criterion_GAN(D_fake, torch.ones_like(D_fake)) + \
                 LAMBDA_L1 * criterion_L1(fake, comic)

        opt_G.zero_grad()
        loss_G.backward()
        opt_G.step()

    print(f"Epoch {epoch} | D: {loss_D.item():.4f} | G: {loss_G.item():.4f}")


Epoch 0 | D: 0.6675 | G: 39.2742
Epoch 1 | D: 0.6253 | G: 37.8299
Epoch 2 | D: 0.6088 | G: 30.1862
Epoch 3 | D: 0.5976 | G: 25.2234
Epoch 4 | D: 0.6354 | G: 19.6381
Epoch 5 | D: 0.6090 | G: 16.9838
Epoch 6 | D: 0.6063 | G: 15.8521
Epoch 7 | D: 0.5731 | G: 15.4483
Epoch 8 | D: 0.5565 | G: 14.3382
Epoch 9 | D: 0.5042 | G: 12.5873
Epoch 10 | D: 0.4874 | G: 14.2626
Epoch 11 | D: 0.4376 | G: 13.2247
Epoch 12 | D: 0.5476 | G: 12.8965
Epoch 13 | D: 0.4647 | G: 12.6731
Epoch 14 | D: 0.4236 | G: 13.3178
Epoch 15 | D: 0.3965 | G: 12.3911
Epoch 16 | D: 0.3614 | G: 11.6619
Epoch 17 | D: 0.3814 | G: 11.5810
Epoch 18 | D: 0.4841 | G: 12.2905
Epoch 19 | D: 0.4464 | G: 12.0989
Epoch 20 | D: 0.3683 | G: 11.1056
Epoch 21 | D: 0.3593 | G: 10.6318
Epoch 22 | D: 0.4167 | G: 11.5636
Epoch 23 | D: 0.4351 | G: 10.1235
Epoch 24 | D: 0.5154 | G: 10.5241
Epoch 25 | D: 0.4060 | G: 10.2519
Epoch 26 | D: 0.3110 | G: 11.2847
Epoch 27 | D: 0.3072 | G: 9.7368
Epoch 28 | D: 0.2730 | G: 9.9237
Epoch 29 | D: 0.2588 | G: 

In [34]:
# After training finishes
torch.save(G.state_dict(), "pix2pix_generator_128.pth")
print("Generator saved!")


Generator saved!


In [35]:
class UNetGenerator(nn.Module):
    def __init__(self):
        super().__init__()

        def down(in_c, out_c):
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, 4, 2, 1),
                nn.BatchNorm2d(out_c),
                nn.LeakyReLU(0.2)
            )

        def up(in_c, out_c):
            return nn.Sequential(
                nn.ConvTranspose2d(in_c, out_c, 4, 2, 1),
                nn.BatchNorm2d(out_c),
                nn.ReLU()
            )

        self.d1 = down(3, 64)
        self.d2 = down(64, 128)
        self.d3 = down(128, 256)
        self.d4 = down(256, 512)

        self.u1 = up(512, 256)
        self.u2 = up(512, 128)
        self.u3 = up(256, 64)

        self.final = nn.Sequential(
            nn.ConvTranspose2d(128, 3, 4, 2, 1),
            nn.Tanh()
        )

    def forward(self, x):
        d1 = self.d1(x)
        d2 = self.d2(d1)
        d3 = self.d3(d2)
        d4 = self.d4(d3)

        u1 = self.u1(d4)
        u2 = self.u2(torch.cat([u1, d3], 1))
        u3 = self.u3(torch.cat([u2, d2], 1))

        return self.final(torch.cat([u3, d1], 1))


In [36]:
G = UNetGenerator().to(DEVICE)
G.load_state_dict(torch.load("pix2pix_generator_128.pth", map_location=DEVICE))
G.eval()


UNetGenerator(
  (d1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2)
  )
  (d2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2)
  )
  (d3): Sequential(
    (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2)
  )
  (d4): Sequential(
    (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2)
  )
  (u1): Sequential(
    (0): ConvTranspose2d(512, 256, kernel_size=(4,

In [37]:
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])


In [38]:
@torch.no_grad()
def generate_comic(input_path, output_path):
    img = Image.open(input_path).convert("RGB")
    x = transform(img).unsqueeze(0).to(DEVICE)

    fake = G(x)

    # De-normalize
    fake = (fake.squeeze(0) + 1) / 2
    fake = fake.clamp(0, 1)

    out = transforms.ToPILImage()(fake.cpu())
    out.save(output_path)

    print("Saved:", output_path)


In [39]:
generate_comic(
    input_path="face.6.jpg",
    output_path="face.png"
)


Saved: face.png
