In [None]:
!pip install imageio torch matplotlib

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import glob
import random
import os
import math
import time
from tqdm import tqdm
import imageio
import itertools
import datetime
import sys

In [None]:
def initialize_weights_normal(module):
    class_name = module.__class__.__name__
    if class_name.find("Conv") != -1:
        torch.nn.init.normal_(module.weight.data, 0.0, 0.02)
    elif class_name.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(module.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(module.bias.data, 0.0)

In [None]:
class LearningRateScheduler:
    def __init__(self, total_epochs, offset_epoch, decay_start_epoch):
        assert (total_epochs - decay_start_epoch) > 0, "Decay must start before the training session ends!"
        self.total_epochs = total_epochs
        self.offset_epoch = offset_epoch
        self.decay_start_epoch = decay_start_epoch
    def step(self, current_epoch):
        return 1.0 - max(0, current_epoch + self.offset_epoch - self.decay_start_epoch) / (self.total_epochs - self.decay_start_epoch)

In [None]:
class ResidualLayer(nn.Module):
    def __init__(self, num_features):
        super(ResidualLayer, self).__init__()

        convolutional_block = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(num_features, num_features, 3),
            nn.InstanceNorm2d(num_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(num_features, num_features, 3),
            nn.InstanceNorm2d(num_features),
        ]
        self.convolutional_block = nn.Sequential(*convolutional_block)
    def forward(self, input_tensor):
        return input_tensor + self.convolutional_block(input_tensor)

In [None]:
class FeatureEncoder(nn.Module):
    def __init__(self, input_channels=3, initial_dim=64, num_downsamples=2, shared_layer=None):
        super(FeatureEncoder, self).__init__()
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_channels, initial_dim, 7),
            nn.InstanceNorm2d(initial_dim),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        current_dim = initial_dim
        for _ in range(num_downsamples):
            layers += [
                nn.Conv2d(current_dim, current_dim * 2, 4, stride=2, padding=1),
                nn.InstanceNorm2d(current_dim * 2),
                nn.ReLU(inplace=True),
            ]
            current_dim *= 2
        for _ in range(3):
            layers += [ResidualLayer(current_dim)]
        self.model_layers = nn.Sequential(*layers)
        self.shared_layer = shared_layer
    def reparameterize(self, mean_tensor):
        TensorType = torch.cuda.FloatTensor if mean_tensor.is_cuda else torch.FloatTensor
        sampled_tensor = Variable(TensorType(np.random.normal(0, 1, mean_tensor.shape)))
        return sampled_tensor + mean_tensor
    def forward(self, input_data):
        encoded_features = self.model_layers(input_data)
        mean_tensor = self.shared_layer(encoded_features)
        latent_vector = self.reparameterize(mean_tensor)
        return mean_tensor, latent_vector

In [None]:
class FeatureGenerator(nn.Module):
    def __init__(self, output_channels=3, base_dim=64, num_upsamples=2, shared_layer=None):
        super(FeatureGenerator, self).__init__()
        self.shared_layer = shared_layer
        layers = []
        current_dim = base_dim * (2 ** num_upsamples)
        for _ in range(3):
            layers += [ResidualLayer(current_dim)]
        for _ in range(num_upsamples):
            layers += [
                nn.ConvTranspose2d(current_dim, current_dim // 2, 4, stride=2, padding=1),
                nn.InstanceNorm2d(current_dim // 2),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            current_dim //= 2
        layers += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(current_dim, output_channels, 7),
            nn.Tanh()
        ]
        self.model_layers = nn.Sequential(*layers)
    def forward(self, input_tensor):
        transformed_tensor = self.shared_layer(input_tensor)
        generated_output = self.model_layers(transformed_tensor)
        return generated_output

In [None]:
class PatchDiscriminator(nn.Module):
    def __init__(self, input_shape):
        super(PatchDiscriminator, self).__init__()
        channels, image_height, image_width = input_shape
        self.output_shape = (1, image_height // 2 ** 4, image_width // 2 ** 4)
        def create_discriminator_block(input_filters, output_filters, apply_normalization=True):
            block_layers = [nn.Conv2d(input_filters, output_filters, 4, stride=2, padding=1)]
            if apply_normalization:
                block_layers.append(nn.InstanceNorm2d(output_filters))
            block_layers.append(nn.LeakyReLU(0.2, inplace=True))
            return block_layers
        self.model_layers = nn.Sequential(
            *create_discriminator_block(channels, 64, apply_normalization=False),
            *create_discriminator_block(64, 128),
            *create_discriminator_block(128, 256),
            *create_discriminator_block(256, 512),
            nn.Conv2d(512, 1, 3, padding=1)
        )
    def forward(self, input_image):
        return self.model_layers(input_image)

In [None]:
class PairedImageDataset(Dataset):
    def __init__(self, dataset_root, transforms_=None, allow_unaligned=True, dataset_mode="train"):
        self.transform = transforms_  # Use the passed Compose object
        self.allow_unaligned = allow_unaligned
        self.monet_images = sorted(glob.glob(os.path.join(dataset_root, "monet_jpg") + "/*.*"))
        self.photo_images = sorted(glob.glob(os.path.join(dataset_root, "photo_jpg") + "/*.*"))
    def __getitem__(self, index):
        monet_image = self.transform(Image.open(self.monet_images[index % len(self.monet_images)]))
        if self.allow_unaligned:
            photo_image = self.transform(
                Image.open(self.photo_images[random.randint(0, len(self.photo_images) - 1)])
            )
        else:
            photo_image = self.transform(Image.open(self.photo_images[index % len(self.photo_images)]))
        return {"monet_image": monet_image, "photo_image": photo_image}
    def __len__(self):
        return min(len(self.monet_images), len(self.photo_images))

In [None]:
is_cuda_available = torch.cuda.is_available()
use_cuda = True if is_cuda_available else False

In [None]:
output_directory = "generated_images"
batch_size = 1
learning_rate = 0.005
beta1 = 0.5
beta2 = 0.999
learning_rate_decay_epoch = 2000
num_workers = 1
image_height = 256
image_width = 256
num_channels = 3
sample_save_interval = 2000
checkpoint_save_interval = 1000

In [None]:
learning_rates = [0.0002]
num_downsampling_layers = 1

In [None]:
def denormalize_image(tensor_image, mean_value=0.5, std_value=0.5):
    if torch.is_tensor(tensor_image):
        tensor_image = tensor_image.detach().numpy()
    restored_image = tensor_image * std_value + mean_value
    restored_image = restored_image * 255
    return np.uint8(restored_image)

In [None]:
def save_sample_images(batch_number):
    images = next(iter(val_dataloader))
    monet_images = Variable(images["monet_image"].type(tensor_type))
    photo_images = Variable(images["photo_image"].type(tensor_type))
    _, monet_latent_vector = encoder_1(monet_images)
    _, photo_latent_vector = encoder_2(photo_images)
    fake_monet_images = generator_1(photo_latent_vector)
    fake_photo_images = generator_2(monet_latent_vector)
    monet_grid = denormalize_image(
        make_grid(photo_images.cpu(), nrow=4).permute(1, 2, 0).numpy()
    )
    fake_monet_grid = denormalize_image(
        make_grid(fake_monet_images.cpu(), nrow=4).permute(1, 2, 0).numpy()
    )
    fig, (axis_monet, axis_fake_monet) = plt.subplots(2, 1, sharex=True, sharey=True, figsize=(30, 20))
    axis_monet.imshow(monet_grid)
    axis_monet.axis("off")
    axis_monet.set_title("Photo Images (X)")
    axis_fake_monet.imshow(fake_monet_grid)
    axis_fake_monet.axis("off")
    axis_fake_monet.set_title("Generated Monet-like Images (Fake Y)")
    plt.show()

In [None]:
def compute_kl_divergence(mean_tensor):
    squared_mean = torch.pow(mean_tensor, 2)
    kl_loss = torch.mean(squared_mean)
    return kl_loss

In [None]:
base_dim = 64
losses_over_epochs = []
n_epochs = 15
dataset_name = 'gan-getting-started'
beta1 = 0.5
beta2 = 0.999
epoch = 0
batches_completed = 0
for learning_rate in learning_rates:
    gan_loss_criterion = torch.nn.MSELoss()
    pixel_loss_criterion = torch.nn.L1Loss()
    input_image_shape = (3, 256, 256)
    shared_embedding_dim = base_dim * (2 ** num_downsampling_layers)
    shared_encoder_block = ResidualLayer(num_features=shared_embedding_dim)
    encoder_1 = FeatureEncoder(initial_dim=base_dim, num_downsamples=num_downsampling_layers, shared_layer=shared_encoder_block)
    encoder_2 = FeatureEncoder(initial_dim=base_dim, num_downsamples=num_downsampling_layers, shared_layer=shared_encoder_block)
    shared_generator_block = ResidualLayer(num_features=shared_embedding_dim)
    generator_1 = FeatureGenerator(base_dim=base_dim, num_upsamples=num_downsampling_layers, shared_layer=shared_generator_block)
    generator_2 = FeatureGenerator(base_dim=base_dim, num_upsamples=num_downsampling_layers, shared_layer=shared_generator_block)
    discriminator_1 = PatchDiscriminator(input_image_shape)
    discriminator_2 = PatchDiscriminator(input_image_shape)
    if use_cuda:
        encoder_1.cuda()
        encoder_2.cuda()
        generator_1.cuda()
        generator_2.cuda()
        discriminator_1.cuda()
        discriminator_2.cuda()
        gan_loss_criterion.cuda()
        pixel_loss_criterion.cuda()
    for model in [encoder_1, encoder_2, generator_1, generator_2, discriminator_1, discriminator_2]:
        model.apply(initialize_weights_normal)
    loss_weights = {
        "gan": 10,
        "kl_encoded": 0.1,
        "id_pixel": 100,
        "kl_translated": 0.1,
        "cycle_pixel": 100,
    }
    optimizer_generator = torch.optim.Adam(
        itertools.chain(encoder_1.parameters(), encoder_2.parameters(),
                        generator_1.parameters(), generator_2.parameters()),
        lr=learning_rate,
        betas=(beta1, beta2),
    )
    optimizer_discriminator_1 = torch.optim.Adam(discriminator_1.parameters(), lr=learning_rate, betas=(beta1, beta2))
    optimizer_discriminator_2 = torch.optim.Adam(discriminator_2.parameters(), lr=learning_rate, betas=(beta1, beta2))
    lr_schedulers = {
        "generator": torch.optim.lr_scheduler.LambdaLR(
            optimizer_generator, lr_lambda=LearningRateScheduler(n_epochs, 0, 0).step
        ),
        "discriminator_1": torch.optim.lr_scheduler.LambdaLR(
            optimizer_discriminator_1, lr_lambda=LearningRateScheduler(n_epochs, 0, 0).step
        ),
        "discriminator_2": torch.optim.lr_scheduler.LambdaLR(
            optimizer_discriminator_2, lr_lambda=LearningRateScheduler(n_epochs, 0, 0).step
        ),
    }
    tensor_type = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
    image_transforms = transforms.Compose([
        transforms.Resize(int(256 * 1.12), Image.BICUBIC),
        transforms.RandomCrop((256, 256)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    epoch_losses = []
    start_time = time.time()
    for epoch in range(n_epochs):
        torch.manual_seed(epoch)
        epoch_loss_generator = 0
        train_loader = DataLoader(
        PairedImageDataset(dataset_name, transforms_=image_transforms, allow_unaligned=True),
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        )
        val_dataloader = DataLoader(
        PairedImageDataset(dataset_name, transforms_=image_transforms, allow_unaligned=True, dataset_mode="test"),
        batch_size=5,
        shuffle=False,
        num_workers=1,
        )
        for i, batch in enumerate(train_loader):
            real_monet_images = Variable(batch["monet_image"].type(tensor_type))
            real_photo_images = Variable(batch["photo_image"].type(tensor_type))
            valid_labels = Variable(tensor_type(np.ones((real_monet_images.size(0), *discriminator_1.output_shape))),
                                    requires_grad=False)
            fake_labels = Variable(tensor_type(np.zeros((real_monet_images.size(0), *discriminator_1.output_shape))),
                                   requires_grad=False)
            optimizer_generator.zero_grad()
            monet_mean, monet_latent = encoder_1(real_monet_images)
            photo_mean, photo_latent = encoder_2(real_photo_images)
            reconstructed_monet = generator_1(monet_latent)
            reconstructed_photo = generator_2(photo_latent)
            fake_monet = generator_1(photo_latent)
            fake_photo = generator_2(monet_latent)
            monet_cycle_mean, monet_cycle_latent = encoder_1(fake_monet)
            photo_cycle_mean, photo_cycle_latent = encoder_2(fake_photo)
            cycle_monet = generator_1(photo_cycle_latent)
            cycle_photo = generator_2(monet_cycle_latent)
            loss_gan_1 = loss_weights["gan"] * gan_loss_criterion(discriminator_1(fake_monet), valid_labels)
            loss_gan_2 = loss_weights["gan"] * gan_loss_criterion(discriminator_2(fake_photo), valid_labels)
            loss_kl_monet = loss_weights["kl_encoded"] * compute_kl_divergence(monet_mean)
            loss_kl_photo = loss_weights["kl_encoded"] * compute_kl_divergence(photo_mean)
            loss_pixel_monet = loss_weights["id_pixel"] * pixel_loss_criterion(reconstructed_monet, real_monet_images)
            loss_pixel_photo = loss_weights["id_pixel"] * pixel_loss_criterion(reconstructed_photo, real_photo_images)
            loss_cycle_monet = loss_weights["cycle_pixel"] * pixel_loss_criterion(cycle_monet, real_monet_images)
            loss_cycle_photo = loss_weights["cycle_pixel"] * pixel_loss_criterion(cycle_photo, real_photo_images)

            total_loss_generator = (
                loss_gan_1 + loss_gan_2 +
                loss_kl_monet + loss_kl_photo +
                loss_pixel_monet + loss_pixel_photo +
                loss_cycle_monet + loss_cycle_photo
            )

            total_loss_generator.backward()
            optimizer_generator.step()
            epoch_loss_generator += total_loss_generator.item() / len(train_loader)
            optimizer_discriminator_1.zero_grad()
            loss_discriminator_1 = gan_loss_criterion(discriminator_1(real_monet_images), valid_labels) + \
                                   gan_loss_criterion(discriminator_1(fake_monet.detach()), fake_labels)
            loss_discriminator_1.backward()
            optimizer_discriminator_1.step()
            optimizer_discriminator_2.zero_grad()
            loss_discriminator_2 = gan_loss_criterion(discriminator_2(real_photo_images), valid_labels) + \
                                   gan_loss_criterion(discriminator_2(fake_photo.detach()), fake_labels)
            loss_discriminator_2.backward()
            optimizer_discriminator_2.step()
            batches_completed = epoch * len(train_loader) + i
            estimated_time_left = datetime.timedelta(
                seconds=(n_epochs * len(train_loader) - batches_completed) * (time.time() - start_time)
            )
            sys.stdout.write(
                f"\r[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(train_loader)}] "
                f"[D Loss: {(loss_discriminator_1 + loss_discriminator_2).item()}] "
                f"[G Loss: {total_loss_generator.item()}] ETA: {estimated_time_left}"
            )
        if batches_completed % sample_save_interval == 0:
            save_sample_images(batches_completed)
        epoch_losses.append(epoch_loss_generator)
        for scheduler in lr_schedulers.values():
            scheduler.step()
    losses_over_epochs.append(epoch_losses)

[Epoch 14/15] [Batch 299/300] [D Loss: 0.28558385372161865] [G Loss: 59.693138122558594] ETA: 0:14:39.882703405832311

In [None]:
import os
class ImageDataset(Dataset):
    def __init__(self, img_path, img_size=256, normalize=True):
        self.img_path = img_path
        if normalize:
            self.transform = transforms.Compose([
                transforms.Resize(img_size),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5], std=[0.5])
            ])
        else:
            self.transform = transforms.Compose([
                transforms.Resize(img_size),
                transforms.ToTensor()
            ])
        self.img_idx = {number_: img_ for number_, img_ in enumerate(os.listdir(self.img_path))}
    def __len__(self):
        return len(self.img_idx)
    def __getitem__(self, idx):
        img_path = os.path.join(self.img_path, self.img_idx[idx])
        img = Image.open(img_path)
        img = self.transform(img)
        return img
encoder_1.eval()
generator_1.eval()
path_photo = "gan-getting-started/photo_jpg"
dataset_photo = ImageDataset(path_photo, img_size=256, normalize=True)
submit_dataloader = DataLoader(dataset_photo, batch_size=1, shuffle=False)
output_dir = "generatedimages"
os.makedirs(output_dir, exist_ok=True)
mean_ = 0.5
std_ = 0.5
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
encoder_1.to(device)
generator_1.to(device)
dataiter = iter(submit_dataloader)
for image_idx in tqdm(range(len(submit_dataloader))):
    fixed_X = next(dataiter)
    _, encod_fake = encoder_1(fixed_X.to(device))
    fake_Y = generator_1(encod_fake)
    fake_Y = fake_Y.detach().cpu().numpy()
    fake_Y = denormalize_image(fake_Y, mean_value=mean_, std_value=std_)
    fake_Y = fake_Y[0].transpose(1, 2, 0)
    fake_Y = np.uint8(fake_Y)
    fake_Y = Image.fromarray(fake_Y)
    fake_Y.save(os.path.join(output_dir, f"{image_idx}.jpg"))
encoder_1.train()
generator_1.train()

100%|██████████| 7038/7038 [06:17<00:00, 18.66it/s]


FeatureGenerator(
  (shared_layer): ResidualLayer(
    (convolutional_block): Sequential(
      (0): ReflectionPad2d((1, 1, 1, 1))
      (1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
      (2): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (3): ReLU(inplace=True)
      (4): ReflectionPad2d((1, 1, 1, 1))
      (5): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
      (6): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    )
  )
  (model_layers): Sequential(
    (0): ResidualLayer(
      (convolutional_block): Sequential(
        (0): ReflectionPad2d((1, 1, 1, 1))
        (1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (2): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (3): ReLU(inplace=True)
        (4): ReflectionPad2d((1, 1, 1, 1))
        (5): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (6): InstanceNorm2

In [None]:
print(losses_over_epochs)

[[132.05707829793292, 118.11951235453297, 108.16086034138999, 106.5316303126017, 99.82629423777253, 97.97450312296546, 95.07357529958087, 89.69938290913899, 89.15468419392904, 86.69816014607746, 83.75062998453782, 82.78465667724605, 79.68022748311368, 78.10049790700279, 76.4397444534302]]


In [None]:
!pip install pytorch_fid

In [None]:
from pytorch_fid.fid_score import calculate_fid_given_paths
real_images_path = 'gan-getting-started/photo_jpg'
generated_images_path = 'generatedimages'
fid_score = calculate_fid_given_paths(
    paths=[real_images_path, generated_images_path],
    batch_size=50,
    device='cuda',
    dims=2048
)
print(f"FID Score: {fid_score}")

Downloading: "https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth" to /root/.cache/torch/hub/checkpoints/pt_inception-2015-12-05-6726825d.pth
100%|██████████| 91.2M/91.2M [00:01<00:00, 60.8MB/s]
100%|██████████| 141/141 [00:49<00:00,  2.83it/s]
100%|██████████| 141/141 [00:33<00:00,  4.18it/s]


FID Score: 22.316519949675524
