In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from glob import glob

In [2]:
class ImageDataset(Dataset):
    """Paired sketch/photo dataset built from two glob patterns.

    Both file lists are sorted independently and item ``i`` pairs the i-th
    sketch with the i-th photo, so correct pairing relies on the two
    directories sorting into the same order.

    Args:
        transform: Optional callable applied to the sketch and to the photo
            of every pair (separately, sketch first).
        sketches_glob: Glob pattern selecting the sketch files.
        photos_glob: Glob pattern selecting the photo files.
        image_size: ``(width, height)`` every image is resized to.
    """

    # NOTE(review): the default patterns read sketches from 'training_i' but
    # photos from 'training_ii' — confirm this matches the on-disk layout.
    def __init__(
        self,
        transform=None,
        sketches_glob='dataset/training_i/cropped_sketches/*.jpg',
        photos_glob='dataset/training_ii/cropped_photos/*.jpg',
        image_size=(256, 256),
    ):
        import warnings

        self.training_sketches = sorted(glob(sketches_glob))
        self.training_photos = sorted(glob(photos_glob))
        self.transform = transform
        self.image_size = image_size

        n_sketches = len(self.training_sketches)
        n_photos = len(self.training_photos)
        if n_sketches != n_photos:
            # Surface the mismatch early instead of failing with an
            # IndexError part-way through an epoch.
            warnings.warn(
                f"ImageDataset: found {n_sketches} sketches but {n_photos} photos; "
                f"only the first {min(n_sketches, n_photos)} pairs will be used.",
                stacklevel=2,
            )

    def __len__(self) -> int:
        """Return the number of complete (sketch, photo) pairs."""
        # Use the shorter list so every valid index has both a sketch and a photo.
        return min(len(self.training_sketches), len(self.training_photos))

    def _load(self, path):
        """Open ``path`` and return a resized copy, closing the file afterwards."""
        with Image.open(path) as image:
            # resize() returns a new image, so it stays valid after the file closes.
            return image.resize(self.image_size)

    def __getitem__(self, index: int):
        """Return the ``(sketch, photo)`` pair at ``index``.

        Both images are resized to ``image_size``; when a transform was given
        it is applied to each image before returning.
        """
        X = self._load(self.training_sketches[index])
        Y = self._load(self.training_photos[index])

        if self.transform:
            X = self.transform(X)
            Y = self.transform(Y)

        return (X, Y)

In [3]:
# Number of (sketch, photo) pairs per training batch.
BATCH_SIZE = 4

# ToTensor turns each PIL image into a float tensor (channels first).
transform = transforms.ToTensor()
training_dataset = ImageDataset(transform=transform)
# shuffle=True: the dataset lists its files in sorted order, so without
# shuffling every epoch would see the same batches in the same order.
training_dataloader = DataLoader(
    training_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)

In [4]:
def sample(kernel, in_channels, out_channels, dropout=False):
    """Build one resampling stage: conv -> batch norm -> [dropout] -> LeakyReLU.

    Args:
        kernel: Layer factory called as
            ``kernel(in_channels, out_channels, kernel_size=4, stride=2, padding=1)``;
            e.g. ``nn.Conv2d`` to downsample or ``nn.ConvTranspose2d`` to upsample.
        in_channels: Number of channels entering the stage.
        out_channels: Number of channels leaving the stage.
        dropout: When true, an ``nn.Dropout()`` layer (library defaults) is
            inserted between the batch norm and the activation.

    Returns:
        An ``nn.Sequential`` holding the layers in the order listed above.
    """
    resample = kernel(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
    normalise = nn.BatchNorm2d(num_features=out_channels)
    regularise = (nn.Dropout(),) if dropout else ()

    return nn.Sequential(resample, normalise, *regularise, nn.LeakyReLU())

In [5]:
class Generator(nn.Module):
    """Encoder-decoder generator made of ten ``sample`` stages.

    Five ``nn.Conv2d`` stages each halve the spatial size, then five
    ``nn.ConvTranspose2d`` stages each double it again, so the output has the
    same height and width as the input. There are no skip connections.

    Args:
        in_channels: Channels of the input image (default 1).
        out_channels: Channels of the generated image (default 3).
    """

    def __init__(self, in_channels: int = 1, out_channels: int = 3) -> None:
        super().__init__()
        # Shape comments below are per sample, for a 256x256 input and the
        # default channel counts.
        self.generator = nn.Sequential(

            # Downsample
            sample(nn.Conv2d, in_channels, 64),                    # [64, 128, 128]
            sample(nn.Conv2d, 64, 128),                            # [128, 64, 64]
            sample(nn.Conv2d, 128, 256),                           # [256, 32, 32]
            sample(nn.Conv2d, 256, 256, dropout=True),             # [256, 16, 16]
            sample(nn.Conv2d, 256, 256, dropout=True),             # [256, 8, 8]

            # Upsample
            sample(nn.ConvTranspose2d, 256, 256, dropout=True),    # [256, 16, 16]
            sample(nn.ConvTranspose2d, 256, 256),                  # [256, 32, 32]
            sample(nn.ConvTranspose2d, 256, 128),                  # [128, 64, 64]
            sample(nn.ConvTranspose2d, 128, 64),                   # [64, 128, 128]
            # NOTE(review): the output stage is also built by `sample`, so the
            # generated image passes through BatchNorm and LeakyReLU rather
            # than a bounded activation — confirm this is intended.
            sample(nn.ConvTranspose2d, 64, out_channels),          # [3, 256, 256]

        )

    def forward(self, input_image):
        """Map a batch of input images to generated images.

        Args:
            input_image: Tensor of shape ``[N, in_channels, H, W]``
                (``[N, 1, 256, 256]`` with the defaults).

        Returns:
            Tensor of shape ``[N, out_channels, H, W]``.
        """
        return self.generator(input_image)

In [6]:
class Discriminator(nn.Module):
    """Convolutional discriminator scoring an (input, target) image pair.

    The two images are concatenated along the channel axis and reduced to a
    single-channel map of raw (un-activated) scores.

    Args:
        in_channels: Total channels after concatenating the two images.
            The default of 6 corresponds to two 3-channel images.
    """

    # NOTE(review): `Generator` in this file takes a 1-channel input by default;
    # concatenating such an input with a 3-channel image gives 4 channels, not
    # 6 — pass `in_channels` accordingly or confirm the inputs are 3-channel.
    def __init__(self, in_channels: int = 6) -> None:
        super().__init__()
        # Shape comments below are per sample, for 256x256 inputs.
        self.discriminator = nn.Sequential(
            sample(nn.Conv2d, in_channels, 64),                    # [64, 128, 128]
            sample(nn.Conv2d, 64, 128),                            # [128, 64, 64]
            sample(nn.Conv2d, 128, 256),                           # [256, 32, 32]
            nn.Conv2d(256, 128, kernel_size=4, padding=1),         # [128, 31, 31]
            nn.LeakyReLU(),
            nn.Conv2d(128, 64, kernel_size=4, padding=1),          # [64, 30, 30]
            nn.LeakyReLU(),
            nn.Conv2d(64, 1, kernel_size=4, padding=1)             # [1, 29, 29]
        )

    def forward(self, input_image, target_image):
        """Score a batch of image pairs.

        Args:
            input_image: Tensor ``[N, C1, H, W]``.
            target_image: Tensor ``[N, C2, H, W]`` with ``C1 + C2 == in_channels``.

        Returns:
            Tensor ``[N, 1, h, w]`` of raw scores (``[N, 1, 29, 29]`` for
            256x256 inputs); no sigmoid is applied here.
        """
        # dim=1 is the channel axis of an [N, C, H, W] batch.
        concat = torch.cat([input_image, target_image], dim=1)
        return self.discriminator(concat)

In [8]:
# Networks and their optimizers (Adam with library-default hyperparameters).
# The generator is constructed before the discriminator.
generator = Generator()
optimizer_G = torch.optim.Adam(generator.parameters())

discriminator = Discriminator()
optimizer_D = torch.optim.Adam(discriminator.parameters())

# Loss functions: binary cross-entropy on raw logits, and mean squared error.
# NOTE(review): presumably the adversarial and reconstruction terms — their
# usage is not visible in this part of the notebook.
gan_loss_1 = nn.BCEWithLogitsLoss()
gan_loss_2 = nn.MSELoss()
