In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import os

In [3]:
os.chdir('drive/MyDrive/Studies/ArtGan')

In [None]:
# !unzip dataset_artgan_realism.zip

In [5]:
import torch.nn as nn
import torch.nn.functional as F
import torch

In [6]:
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        if hasattr(m, "bias") and m.bias is not None:
            torch.nn.init.constant_(m.bias.data, 0.0)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

In [7]:
class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()

        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
        )

    def forward(self, x):
        return x + self.block(x)

In [8]:
class GeneratorResNet(nn.Module):
    def __init__(self, input_shape, num_residual_blocks):
        super(GeneratorResNet, self).__init__()

        channels = input_shape[0]

        # Initial convolution block
        out_features = 64
        model = [
            nn.ReflectionPad2d(channels),
            nn.Conv2d(channels, out_features, 7),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features

        # Downsampling
        for _ in range(2):
            out_features *= 2
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Residual blocks
        for _ in range(num_residual_blocks):
            model += [ResidualBlock(out_features)]

        # Upsampling
        for _ in range(2):
            out_features //= 2
            model += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Output layer
        model += [nn.ReflectionPad2d(channels), nn.Conv2d(out_features, channels, 7), nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

In [9]:
class Discriminator(nn.Module):
    def __init__(self, input_shape):
        super(Discriminator, self).__init__()

        channels, height, width = input_shape

        # Calculate output shape of image discriminator
        self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)

        def discriminator_block(in_filters, out_filters, normalize=True):
            """Returns downsampling layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *discriminator_block(channels, 64, normalize=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, 4, padding=1)
        )

    def forward(self, img):
        return self.model(img)

In [10]:
class ReplayBuffer:
    def __init__(self, max_size=50):
        assert max_size > 0, "Empty buffer or trying to create a black hole. Be careful."
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        to_return = []
        for element in data.data:
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                self.data.append(element)
                to_return.append(element)
            else:
                if random.uniform(0, 1) > 0.5:
                    i = random.randint(0, self.max_size - 1)
                    to_return.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    to_return.append(element)
        return Variable(torch.cat(to_return))

In [11]:
class LambdaLR:
    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert (n_epochs - decay_start_epoch) > 0, "Decay must start before the training session ends!"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch) / (self.n_epochs - self.decay_start_epoch)

In [12]:
import glob
import random
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

In [13]:
def to_rgb(image):
    rgb_image = Image.new("RGB", image.size)
    rgb_image.paste(image)
    return rgb_image

In [14]:
class ImageDataset(Dataset):
    def __init__(self, root, transforms_=None, unaligned=False, mode="train"):
        self.transform = transforms.Compose(transforms_)
        self.unaligned = unaligned

        self.files_A = sorted(glob.glob(os.path.join(root, "%s_A" % mode) + "/*.*"))
        self.files_B = sorted(glob.glob(os.path.join(root, "%s_B" % mode) + "/*.*"))

    def __getitem__(self, index):
        image_A = Image.open(self.files_A[index % len(self.files_A)])
        image_B = Image.open(self.files_B[index % len(self.files_B)])

        # Convert grayscale images to rgb
        if image_A.mode != "RGB":
            image_A = to_rgb(image_A)
        if image_B.mode != "RGB":
            image_B = to_rgb(image_B)

        item_A = self.transform(image_A)
        item_B = self.transform(image_B)
        return {"A": item_A, "B": item_B}

    def __len__(self):
        return max(len(self.files_A), len(self.files_B))

In [15]:
import argparse
import os
import numpy as np
import math
import itertools
import datetime
import time

import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

In [None]:
os.makedirs("images/realism_art", exist_ok=True)
os.makedirs("saved_models/realism_art", exist_ok=True)

In [16]:
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

In [17]:
cuda = torch.cuda.is_available()

input_shape = (3, 256, 256)
G_AB = GeneratorResNet(input_shape, 9)
G_BA = GeneratorResNet(input_shape, 9)
D_A = Discriminator(input_shape)
D_B = Discriminator(input_shape)

In [None]:
if cuda:
    G_AB = G_AB.cuda()
    G_BA = G_BA.cuda()
    D_A = D_A.cuda()
    D_B = D_B.cuda()
    criterion_GAN.cuda()
    criterion_cycle.cuda()
    criterion_identity.cuda()

In [None]:
start_epoch = 80
n_epochs = 200

In [None]:
if start_epoch != 0 and not cuda:
    # Load pretrained models
    G_AB.load_state_dict(torch.load("saved_models/%s/G_AB_%d.pth" % ('realism_art', start_epoch), map_location=torch.device('cpu')))
    G_BA.load_state_dict(torch.load("saved_models/%s/G_BA_%d.pth" % ('realism_art', start_epoch), map_location=torch.device('cpu')))
    D_A.load_state_dict(torch.load("saved_models/%s/D_A_%d.pth" % ('realism_art', start_epoch), map_location=torch.device('cpu')))
    D_B.load_state_dict(torch.load("saved_models/%s/D_B_%d.pth" % ('realism_art', start_epoch), map_location=torch.device('cpu')))
elif start_epoch != 0 and cuda:
    # Load pretrained models
    G_AB.load_state_dict(torch.load("saved_models/%s/G_AB_%d.pth" % ('realism_art', start_epoch)))
    G_BA.load_state_dict(torch.load("saved_models/%s/G_BA_%d.pth" % ('realism_art', start_epoch)))
    D_A.load_state_dict(torch.load("saved_models/%s/D_A_%d.pth" % ('realism_art', start_epoch)))
    D_B.load_state_dict(torch.load("saved_models/%s/D_B_%d.pth" % ('realism_art', start_epoch)))
else:
    # Initialize weights
    G_AB.apply(weights_init_normal)
    G_BA.apply(weights_init_normal)
    D_A.apply(weights_init_normal)
    D_B.apply(weights_init_normal)

In [None]:
optimizer_G = torch.optim.Adam(
    itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=0.0002, betas=(0.5, 0.999)
)
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Learning rate update schedulers
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
    optimizer_G, lr_lambda=LambdaLR(n_epochs, start_epoch, 100).step
)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_A, lr_lambda=LambdaLR(n_epochs, start_epoch, 100).step
)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_B, lr_lambda=LambdaLR(n_epochs, start_epoch, 100).step
)

In [32]:
Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

# Buffers of previously generated samples
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

# Image transformations
transforms_ = [
    transforms.Resize(int(256 * 1.12), Image.BICUBIC),
    transforms.RandomCrop((256, 256)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

  "Argument interpolation should be of type InterpolationMode instead of int. "


In [None]:
# './ArtGAN_project'
# Training data loader
# Test data loader
val_dataloader = DataLoader(
    ImageDataset('', transforms_=transforms_, unaligned=True, mode="test"),
    batch_size=5,
    shuffle=True,
    num_workers=1,
)
dataloader = DataLoader(
    ImageDataset('', transforms_=transforms_, unaligned=True),
    batch_size=1,
    shuffle=True,
    num_workers=8,
)

  cpuset_checked))


In [None]:
def sample_images(batches_done):
    """Saves a generated sample from the test set"""
    imgs = next(iter(val_dataloader))
    G_AB.eval()
    G_BA.eval()
    real_A = Variable(imgs["A"].type(Tensor))
    fake_B = G_AB(real_A)
    real_B = Variable(imgs["B"].type(Tensor))
    fake_A = G_BA(real_B)
    # Arange images along x-axis
    real_A = make_grid(real_A, nrow=5, normalize=True)
    real_B = make_grid(real_B, nrow=5, normalize=True)
    fake_A = make_grid(fake_A, nrow=5, normalize=True)
    fake_B = make_grid(fake_B, nrow=5, normalize=True)
    # Arange images along y-axis
    image_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)
    save_image(image_grid, "images/%s/%s.png" % ('realism_art', batches_done), normalize=False)

In [None]:
import sys

In [None]:
prev_time = time.time()
for epoch in range(start_epoch, n_epochs):
    for i, batch in enumerate(dataloader):

        # Set model input
        real_A = Variable(batch["A"].type(Tensor))
        real_B = Variable(batch["B"].type(Tensor))

        # Adversarial ground truths
        valid = Variable(Tensor(np.ones((real_A.size(0), *D_A.output_shape))), requires_grad=False)
        fake = Variable(Tensor(np.zeros((real_A.size(0), *D_A.output_shape))), requires_grad=False)

        # ------------------
        #  Train Generators
        # ------------------

        G_AB.train()
        G_BA.train()

        optimizer_G.zero_grad()

        # Identity loss
        loss_id_A = criterion_identity(G_BA(real_A), real_A)
        loss_id_B = criterion_identity(G_AB(real_B), real_B)

        loss_identity = (loss_id_A + loss_id_B) / 2

        # GAN loss
        fake_B = G_AB(real_A)
        loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
        fake_A = G_BA(real_B)
        loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)

        loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

        # Cycle loss
        recov_A = G_BA(fake_B)
        loss_cycle_A = criterion_cycle(recov_A, real_A)
        recov_B = G_AB(fake_A)
        loss_cycle_B = criterion_cycle(recov_B, real_B)

        loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

        # Total loss
        loss_G = loss_GAN + 10.0 * loss_cycle + 5.0 * loss_identity

        loss_G.backward()
        optimizer_G.step()

        # -----------------------
        #  Train Discriminator A
        # -----------------------

        optimizer_D_A.zero_grad()

        # Real loss
        loss_real = criterion_GAN(D_A(real_A), valid)
        # Fake loss (on batch of previously generated samples)
        fake_A_ = fake_A_buffer.push_and_pop(fake_A)
        loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
        # Total loss
        loss_D_A = (loss_real + loss_fake) / 2

        loss_D_A.backward()
        optimizer_D_A.step()

        # -----------------------
        #  Train Discriminator B
        # -----------------------

        optimizer_D_B.zero_grad()

        # Real loss
        loss_real = criterion_GAN(D_B(real_B), valid)
        # Fake loss (on batch of previously generated samples)
        fake_B_ = fake_B_buffer.push_and_pop(fake_B)
        loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
        # Total loss
        loss_D_B = (loss_real + loss_fake) / 2

        loss_D_B.backward()
        optimizer_D_B.step()

        loss_D = (loss_D_A + loss_D_B) / 2

        # --------------
        #  Log Progress
        # --------------

        # Determine approximate time left
        batches_done = epoch * len(dataloader) + i
        batches_left = 200 * len(dataloader) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()

        # Print log
        sys.stdout.write(
            "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s"
            % (
                epoch,
                200,
                i,
                len(dataloader),
                loss_D.item(),
                loss_G.item(),
                loss_GAN.item(),
                loss_cycle.item(),
                loss_identity.item(),
                time_left,
            )
        )

        # If at sample interval save image
        if batches_done % 1000 == 0:
            sample_images(batches_done)

    # Update learning rates
    lr_scheduler_G.step()
    lr_scheduler_D_A.step()
    lr_scheduler_D_B.step()

    if epoch % 5 == 0:
        # Save model checkpoints
        torch.save(G_AB.state_dict(), "saved_models/%s/G_AB_%d.pth" % ('realism_art', epoch))
        torch.save(G_BA.state_dict(), "saved_models/%s/G_BA_%d.pth" % ('realism_art', epoch))
        torch.save(D_A.state_dict(), "saved_models/%s/D_A_%d.pth" % ('realism_art', epoch))
        torch.save(D_B.state_dict(), "saved_models/%s/D_B_%d.pth" % ('realism_art', epoch))

In [None]:
# Initially wanted to train for 200 epochs, but only got 140 epochs.
# Each epochs took approximately 2 hours 40 mins to train.
# Testing an evaluating.

In [43]:
if cuda:
    G_AB = G_AB.cuda()
    G_BA = G_BA.cuda()
    D_A = D_A.cuda()
    D_B = D_B.cuda()
    criterion_GAN.cuda()
    criterion_cycle.cuda()
    criterion_identity.cuda()

# Load pretrained models
# 140 epochs done
epoch = 140
G_AB.load_state_dict(torch.load("saved_models/%s/G_AB_%d.pth" % ('realism_art', epoch)))
G_BA.load_state_dict(torch.load("saved_models/%s/G_BA_%d.pth" % ('realism_art', epoch)))
D_A.load_state_dict(torch.load("saved_models/%s/D_A_%d.pth" % ('realism_art', epoch)))
D_B.load_state_dict(torch.load("saved_models/%s/D_B_%d.pth" % ('realism_art', epoch)))

<All keys matched successfully>

In [46]:
def get_parameter_count(model):
 return (sum(p.numel() for p in model.parameters()), 
         sum(p.numel() for p in model.parameters() if p.requires_grad))

print("G_AB paramters:", get_parameter_count(G_AB))
print("G_BA paramters:", get_parameter_count(G_BA))
print("D_A paramters:", get_parameter_count(D_A))
print("D_B paramters:", get_parameter_count(D_B))

G_AB paramters: (11378179, 11378179)
G_BA paramters: (11378179, 11378179)
D_A paramters: (2764737, 2764737)
D_B paramters: (2764737, 2764737)


In [33]:
# Test dataloader
test_datasetloader = DataLoader(
    ImageDataset('', transforms_=transforms_, unaligned=True, mode="check"),
    batch_size=1,
    shuffle=True,
    num_workers=1,
)

In [None]:
os.makedirs("images/test_real_A", exist_ok=True)
os.makedirs("images/test_real_B", exist_ok=True)
os.makedirs("images/test_fake_A", exist_ok=True)
os.makedirs("images/test_fake_B", exist_ok=True)

In [None]:
def sample_images(count):
    """Saves a generated sample from the test set"""
    test_real_A = "test_real_A"
    test_real_B = "test_real_B"
    test_fake_A = "test_fake_A"
    test_fake_B = "test_fake_B"
    imgs = next(iter(test_datasetloader))
    G_AB.eval()
    G_BA.eval()
    real_A = Variable(imgs["A"].type(Tensor))
    fake_B = G_AB(real_A)
    real_B = Variable(imgs["B"].type(Tensor))
    fake_A = G_BA(real_B)
    # Arange images along y-axis
    save_image(real_A, "images/%s/%s.png" % (test_real_A, count), normalize=True)
    save_image(real_B, "images/%s/%s.png" % (test_real_B, count), normalize=True)
    save_image(fake_A, "images/%s/%s.png" % (test_fake_A, count), normalize=True)
    save_image(fake_B, "images/%s/%s.png" % (test_fake_B, count), normalize=True)

In [34]:
 imgs = next(iter(test_datasetloader))

In [35]:
real_A = Variable(imgs["A"].type(Tensor))

In [38]:
real_A.shape

torch.Size([1, 3, 256, 256])

In [None]:
total_img = min(len(os.listdir('check_A')), len(os.listdir('check_B')))

In [None]:
for c in range(total_img):
  sample_images(c)

In [20]:
import glob
import numpy as np
import scipy.misc
from scipy.spatial.distance import minkowski
from scipy.stats import ks_2samp
import time, imageio, os
import torch
from tqdm import tqdm
from PIL import Image

In [21]:
def get_image_vector(filename):
    image = Image.open(filename)
    image = image.resize((256, 256))
    image.save('tmp_file.'+filename.split('.')[-1])
    image.close()
    im = imageio.imread('tmp_file.'+filename.split('.')[-1], pilmode='RGB')

    return np.float32(np.ndarray.flatten(im)) / 255

In [22]:
def gpu_LS(real, gen):
    # to torch tensors
    t_gen = torch.from_numpy(gen)
    t_real = torch.from_numpy(real)

    dist_real = torch.cdist(t_real, t_real)  # ICD 1
    dist_real = torch.flatten(torch.tril(dist_real, diagonal=-1))  # remove repeats
    dist_real = dist_real[dist_real.nonzero()].flatten()  # remove distance=0 for distances btw same data points

    dist_gen = torch.cdist(t_gen, t_gen)  # ICD 2
    dist_gen = torch.flatten(torch.tril(dist_gen, diagonal=-1))  # remove repeats
    dist_gen = dist_gen[dist_gen.nonzero()].flatten()  # remove distance=0 for distances btw same data points

    distbtw = torch.cdist(t_gen, t_real)  # BCD
    distbtw = torch.flatten(distbtw)

    D_Sep_1, _ = ks_2samp(dist_real, distbtw)
    D_Sep_2, _ = ks_2samp(dist_gen, distbtw)

    return 1 - np.max([D_Sep_1, D_Sep_2])  # LS=1-DSI

In [52]:
import random
# Base-line between Real-life and Real-art - network only generates 2 types of images.
filenames_1 = glob.glob(os.path.join('images/', 'test_real_A/*'))
gen = np.array([get_image_vector(filename) for filename in filenames_1])

filenames_2 = glob.glob(os.path.join('images/', 'test_real_B/*'))
filenames_2 = random.choices(filenames_2, k=2)
real = np.array([get_image_vector(filename) for filename in filenames_2])

print('real #:   ' + str(len(real)))
print('gen #:   ' + str(len(gen)))

print('\n', 'LS= ', gpu_LS(real, gen))

real #:   2
gen #:   630

 LS=  0.26190476190476186


In [53]:
import random
# Base-line between Real-life and Real-art - network only generates 2 types of images.
filenames_1 = glob.glob(os.path.join('images/', 'test_real_B/*'))
gen = np.array([get_image_vector(filename) for filename in filenames_1])

filenames_2 = glob.glob(os.path.join('images/', 'test_real_A/*'))
filenames_2 = random.choices(filenames_2, k=2)
real = np.array([get_image_vector(filename) for filename in filenames_2])

print('real #:   ' + str(len(real)))
print('gen #:   ' + str(len(gen)))

print('\n', 'LS= ', gpu_LS(real, gen))

real #:   2
gen #:   630

 LS=  0.29126984126984123


In [24]:
# Real-life to generated Art
filenames_1 = glob.glob(os.path.join('images/', 'test_fake_A/*'))
gen = np.array([get_image_vector(filename) for filename in filenames_1])

filenames_2 = glob.glob(os.path.join('images/', 'test_real_B/*'))
real = np.array([get_image_vector(filename) for filename in filenames_2])

print('real #:   ' + str(len(real)))
print('gen #:   ' + str(len(gen)))

print('\n', 'LS= ', gpu_LS(real, gen))

real #:   630
gen #:   630

 LS=  0.9467593764232419


In [25]:
# Art to generated real-life
filenames_1 = glob.glob(os.path.join('images/', 'test_real_A/*'))
gen = np.array([get_image_vector(filename) for filename in filenames_1])

filenames_2 = glob.glob(os.path.join('images/', 'test_fake_B/*'))
real = np.array([get_image_vector(filename) for filename in filenames_2])

print('real #:   ' + str(len(real)))
print('gen #:   ' + str(len(gen)))

print('\n', 'LS= ', gpu_LS(real, gen))

real #:   630
gen #:   630

 LS=  0.9454908329698246
