In [73]:
from os import listdir
import time
import datetime
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as VF
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision.utils import save_image

import matplotlib.pyplot as plt

In [None]:
def split_paired_image(img, split_at=256):
    """Split one image into two tensors along its second axis.

    The image is converted to a float tensor once and sliced twice, instead
    of being converted separately for each half.

    Args:
        img: Anything ``np.array`` can convert. Assumed to be a PIL image laid
            out as (height, width, channels) -- TODO confirm against the data.
        split_at: Index on the second axis where the second part begins.

    Returns:
        Tuple ``(first, second)`` of float tensors: indices ``[:split_at]``
        and ``[split_at:]`` of the second axis.
    """
    pixels = torch.tensor(np.array(img), dtype=torch.float)
    return pixels[:, :split_at, :], pixels[:, split_at:, :]


# NOTE(review): this transform is only displayed; it is neither assigned nor
# added to a Compose pipeline in this cell.
transforms.Lambda(split_paired_image)

In [46]:
# Preprocessing shared by both splits: convert to a tensor, then normalise each
# of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

# Training and validation images are read from sibling folders with the same
# preprocessing.
dataset = datasets.ImageFolder(root='train_images', transform=transform)
val_dataset = datasets.ImageFolder(root='val_images', transform=transform)

In [47]:
# Training loader: one image per step, reshuffled every epoch so the networks
# do not see the samples in the same fixed order each time.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)
# Validation loader: a single batch holding the whole validation set, so that
# later cells can index arbitrary samples out of it.
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=len(val_dataset))

In [31]:
# Grab one batch for interactive inspection. Unlike a `for ...: break` loop,
# this raises StopIteration on an empty loader instead of silently leaving
# `batch` unbound (or holding a stale value from an earlier run).
batch = next(iter(dataloader))

In [17]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [38]:
def weights_init_normal(m):
    """Initialise one module in place; intended for ``net.apply(...)``.

    Modules are matched by class name:

    * names containing ``"Conv"``: weight drawn from N(0.0, 0.02);
    * names containing ``"BatchNorm2d"``: weight drawn from N(1.0, 0.02) and
      bias set to 0.

    Every other module is left untouched. Parameters that are absent or
    ``None`` (for example ``BatchNorm2d(affine=False)``) are skipped instead
    of raising ``AttributeError``.

    Args:
        m: The module to initialise.
    """
    classname = m.__class__.__name__
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)

    if "Conv" in classname:
        if weight is not None:
            torch.nn.init.normal_(weight, 0.0, 0.02)
    elif "BatchNorm2d" in classname:
        if weight is not None:
            torch.nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            torch.nn.init.constant_(bias, 0.0)


##############################
#           U-NET
##############################


class UNetDown(nn.Module):
    """Encoder stage of the U-Net.

    Layers, in order: a bias-free 4x4 convolution with stride 2 and padding 1,
    optional instance normalisation, LeakyReLU(0.2), optional dropout.

    Args:
        in_size: Number of input channels.
        out_size: Number of output channels.
        normalize: When truthy, add ``InstanceNorm2d`` after the convolution.
        dropout: Dropout probability; a falsy value (the default 0.0) adds no
            dropout layer.
    """

    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super().__init__()
        stages = [nn.Conv2d(in_size, out_size, kernel_size=4, stride=2, padding=1, bias=False)]
        stages += [nn.InstanceNorm2d(out_size)] if normalize else []
        stages += [nn.LeakyReLU(0.2)]
        stages += [nn.Dropout(dropout)] if dropout else []
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the stage and return the result."""
        return self.model(x)


class UNetUp(nn.Module):
    """Decoder stage of the U-Net with a skip connection.

    Layers, in order: a bias-free 4x4 transposed convolution with stride 2 and
    padding 1, instance normalisation, in-place ReLU, optional dropout. The
    result is then concatenated with the skip tensor along the channel axis.

    Args:
        in_size: Number of input channels.
        out_size: Number of channels produced before concatenation.
        dropout: Dropout probability; a falsy value (the default 0.0) adds no
            dropout layer.
    """

    def __init__(self, in_size, out_size, dropout=0.0):
        super().__init__()
        stages = [
            nn.ConvTranspose2d(in_size, out_size, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(out_size),
            nn.ReLU(inplace=True),
        ]
        stages += [nn.Dropout(dropout)] if dropout else []
        self.model = nn.Sequential(*stages)

    def forward(self, x, skip_input):
        """Upsample ``x`` and append ``skip_input`` after it on dim 1."""
        upsampled = self.model(x)
        return torch.cat((upsampled, skip_input), 1)


class GeneratorUNet(nn.Module):
    """U-Net generator: eight ``UNetDown`` stages, seven ``UNetUp`` stages
    with skip connections, and a final upsample + convolution + Tanh head.

    Sub-modules are registered as ``down1``..``down8``, ``up1``..``up7`` and
    ``final``.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the generated image.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()

        # Encoder stages as (input channels, output channels, extra options).
        encoder_spec = [
            (in_channels, 64, dict(normalize=False)),
            (64, 128, {}),
            (128, 256, {}),
            (256, 512, dict(dropout=0.5)),
            (512, 512, dict(dropout=0.5)),
            (512, 512, dict(dropout=0.5)),
            (512, 512, dict(dropout=0.5)),
            (512, 512, dict(normalize=False, dropout=0.5)),
        ]
        for idx, (c_in, c_out, extra) in enumerate(encoder_spec, start=1):
            setattr(self, "down%d" % idx, UNetDown(c_in, c_out, **extra))

        # Decoder stages as (input channels, output channels, dropout). From
        # the second stage on, the input width counts the previous stage's
        # output plus the skip tensor that UNetUp concatenates onto it.
        decoder_spec = [
            (512, 512, 0.5),
            (1024, 512, 0.5),
            (1024, 512, 0.5),
            (1024, 512, 0.5),
            (1024, 256, 0.0),
            (512, 128, 0.0),
            (256, 64, 0.0),
        ]
        for idx, (c_in, c_out, drop) in enumerate(decoder_spec, start=1):
            setattr(self, "up%d" % idx, UNetUp(c_in, c_out, dropout=drop))

        # Output head: takes the 128 channels left by up7 (64 + 64 from down1).
        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(128, out_channels, 4, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode ``x``, decode with skip connections, and apply the head."""
        # Encoder pass, remembering every stage output for the skips.
        encoder_outputs = []
        hidden = x
        for idx in range(1, 9):
            hidden = getattr(self, "down%d" % idx)(hidden)
            encoder_outputs.append(hidden)

        # The deepest output seeds the decoder; remaining outputs are consumed
        # deepest-first, so up1 pairs with down7, ..., up7 with down1.
        hidden = encoder_outputs.pop()
        for idx in range(1, 8):
            hidden = getattr(self, "up%d" % idx)(hidden, encoder_outputs.pop())

        return self.final(hidden)


##############################
#        Discriminator
##############################


class Discriminator(nn.Module):
    """Conditional patch discriminator.

    Two images are concatenated on the channel axis and passed through four
    stride-2 convolution blocks (64, 128, 256, 512 filters; instance
    normalisation on all but the first), then a zero pad and a final bias-free
    convolution down to one output channel.

    Args:
        in_channels: Channels of each of the two input images.
    """

    def __init__(self, in_channels=3):
        super().__init__()

        # Channel widths at the boundaries of the four downsampling blocks.
        widths = [in_channels * 2, 64, 128, 256, 512]

        layers = []
        for position, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Conv2d(c_in, c_out, 4, stride=2, padding=1))
            # The first block is the only one without normalisation.
            if position > 0:
                layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        layers.append(nn.ZeroPad2d((1, 0, 1, 0)))
        layers.append(nn.Conv2d(512, 1, 4, padding=1, bias=False))

        self.model = nn.Sequential(*layers)

    def forward(self, img_A, img_B):
        """Score the pair ``(img_A, img_B)``; returns a one-channel map."""
        # Concatenate image and condition image by channels to produce input
        stacked = torch.cat((img_A, img_B), 1)
        return self.model(stacked)

In [45]:
# Losses: mean-squared error for the adversarial term, L1 for the pixel term.
criterion_GAN = nn.MSELoss()
criterion_pixelwise = nn.L1Loss()

# Per-sample shape of the discriminator output (PatchGAN). For 256x256 inputs
# the four stride-2 convolutions give 16x16, and the final pad + convolution
# keeps that size.
patch = (1, 16, 16)

# Build both networks directly on the target device, then initialise them.
generator = GeneratorUNet().to(device)
discriminator = Discriminator().to(device)

generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# One Adam optimiser per network, with identical hyper-parameters.
adam_settings = dict(lr=2e-4, betas=(0.5, 0.999))
optimizer_G = torch.optim.Adam(generator.parameters(), **adam_settings)
optimizer_D = torch.optim.Adam(discriminator.parameters(), **adam_settings)

def set_requires_grad(net, req_grad):
    """Set ``requires_grad`` to ``req_grad`` on every parameter of ``net``.

    Args:
        net: Module whose parameters are updated in place.
        req_grad: ``True`` to enable gradient tracking, ``False`` to disable.
    """
    for weight in net.parameters():
        weight.requires_grad_(req_grad)

def sample_images(batches_done):
    """Saves a generated sample from the validation set.

    Takes the first batch of ``val_dataloader``, splits each image at column
    256 of the last axis into the generator input (first part) and the target
    (second part), runs the generator, and writes input, output and target
    stacked along dim -2 to ``images/<batches_done>.png``.

    Args:
        batches_done: Value used as the output file name.
    """
    import os  # local import: the file only imports `listdir` from os

    # The loader yields [images, labels]; only the images are needed.
    imgs = next(iter(val_dataloader))
    images = imgs[0].to(device)
    real_A = images[:, :, :, :256]
    real_B = images[:, :, :, 256:]

    # Inference only, so skip building the autograd graph.
    with torch.no_grad():
        fake_B = generator(real_A)

    img_sample = torch.cat((real_A, fake_B, real_B), -2)
    os.makedirs("images", exist_ok=True)
    save_image(img_sample, "images/%s.png" % batches_done, nrow=5, normalize=True)


# ----------
#  Training
# ----------

prev_time = time.time()

n_epochs = 100
# Weight of the L1 pixel term relative to the adversarial term in loss_G.
lambda_pixel = 100
# Column of the last axis at which each loaded image is split into A and B.
half_width = 256

for epoch in range(n_epochs):
    for i, batch in enumerate(dataloader):

        # Model inputs. The loader yields [images, labels]; the images are
        # moved to the networks' device before being split into the two parts.
        images = batch[0].to(device)
        real_A = images[:, :, :, :half_width]
        real_B = images[:, :, :, half_width:]

        # Adversarial ground truths, created directly on the same device.
        valid = torch.ones((real_A.size(0), *patch), device=device)
        fake = torch.zeros((real_A.size(0), *patch), device=device)

        # ------------------
        #  Train Generator
        # ------------------

        # Freeze the discriminator for the whole generator step so that
        # loss_G.backward() does not compute gradients for its parameters.
        # The finally clause re-enables them even if the cell is interrupted.
        set_requires_grad(discriminator, False)
        try:
            optimizer_G.zero_grad()

            # GAN loss
            fake_B = generator(real_A)
            pred_fake = discriminator(fake_B, real_A)
            loss_GAN = criterion_GAN(pred_fake, valid)
            # Pixel-wise loss
            loss_pixel = criterion_pixelwise(fake_B, real_B)

            # Total loss
            loss_G = loss_GAN + lambda_pixel * loss_pixel

            loss_G.backward()
            optimizer_G.step()
        finally:
            set_requires_grad(discriminator, True)

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Real loss
        pred_real = discriminator(real_B, real_A)
        loss_real = criterion_GAN(pred_real, valid)

        # Fake loss (detached so no gradient reaches the generator)
        pred_fake = discriminator(fake_B.detach(), real_A)
        loss_fake = criterion_GAN(pred_fake, fake)

        # Total loss
        loss_D = 0.5 * (loss_real + loss_fake)

        loss_D.backward()
        optimizer_D.step()

        # --------------
        #  Log Progress
        # --------------

        # Determine approximate time left from the duration of this batch
        batches_done = epoch * len(dataloader) + i
        batches_left = n_epochs * len(dataloader) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()

        # Print log
        print(
            "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, pixel: %f, adv: %f] ETA: %s"
            % (
                epoch,
                n_epochs,
                i,
                len(dataloader),
                loss_D.item(),
                loss_G.item(),
                loss_pixel.item(),
                loss_GAN.item(),
                time_left,
            )
        )

        # If at sample interval save image
        # NOTE(review): disabled because `opt` is not defined in this notebook.
        #if batches_done % opt.sample_interval == 0:
        #    sample_images(batches_done)

[Epoch 0/100] [Batch 0/400] [D loss: 1.535743] [G loss: 74.877800, pixel: 0.731446, adv: 1.733238] ETA: 21:56:04.290619
[Epoch 0/100] [Batch 1/400] [D loss: 2.128437] [G loss: 71.994637, pixel: 0.698849, adv: 2.109776] ETA: 23:29:59.201419
[Epoch 0/100] [Batch 2/400] [D loss: 2.246799] [G loss: 67.437668, pixel: 0.649170, adv: 2.520687] ETA: 1 day, 1:34:13.721898
[Epoch 0/100] [Batch 3/400] [D loss: 1.631126] [G loss: 56.127113, pixel: 0.547190, adv: 1.408149] ETA: 20:00:17.255695
[Epoch 0/100] [Batch 4/400] [D loss: 1.727078] [G loss: 53.590683, pixel: 0.518232, adv: 1.767525] ETA: 19:44:43.818031
[Epoch 0/100] [Batch 5/400] [D loss: 3.136618] [G loss: 57.604401, pixel: 0.526777, adv: 4.926655] ETA: 19:03:57.964592
[Epoch 0/100] [Batch 6/400] [D loss: 2.328010] [G loss: 46.968796, pixel: 0.445398, adv: 2.428969] ETA: 18:32:50.283095
[Epoch 0/100] [Batch 7/400] [D loss: 1.509299] [G loss: 47.290497, pixel: 0.458047, adv: 1.485766] ETA: 18:34:45.904536
[Epoch 0/100] [Batch 8/400] [D los

[Epoch 0/100] [Batch 68/400] [D loss: 0.262877] [G loss: 44.289925, pixel: 0.433861, adv: 0.903860] ETA: 19:57:33.520674
[Epoch 0/100] [Batch 69/400] [D loss: 0.253105] [G loss: 48.183781, pixel: 0.471489, adv: 1.034861] ETA: 19:32:59.608001
[Epoch 0/100] [Batch 70/400] [D loss: 0.269272] [G loss: 45.029507, pixel: 0.438439, adv: 1.185568] ETA: 20:50:26.801934
[Epoch 0/100] [Batch 71/400] [D loss: 0.274529] [G loss: 39.402306, pixel: 0.385060, adv: 0.896282] ETA: 20:48:29.438095


KeyboardInterrupt: 

In [None]:
# The validation loader yields [images, labels]; the images are at index 0.
imgs = next(iter(val_dataloader))
# Three random sample indices (drawn with replacement).
ix = np.random.randint(0, len(val_dataset), 3).tolist()
# Split the selected images at column 256 of the last axis into generator
# input and target, on the same device as the generator.
real_A = imgs[0][ix, :, :, :256].to(device)
real_B = imgs[0][ix, :, :, 256:].to(device)
# Inference only, so skip building the autograd graph.
with torch.no_grad():
    fake_B = generator(real_A)
# Stack input, output and target along dim -2 and write them as one grid.
img_sample = torch.cat((real_A, fake_B, real_B), -2)
save_image(img_sample, "images/%s.png" % batches_done, nrow=5, normalize=True)

In [64]:
# One batch from the validation loader (it is configured with a batch size
# equal to the whole validation set), then three randomly drawn indices.
imgs = next(iter(val_dataloader))
ix = np.random.randint(low=0, high=len(val_dataset), size=3).tolist()
# imgs[0] is the image tensor; split the chosen images at column 256 of the
# last axis into the two parts.
real_A = imgs[0][ix, ..., :256]
real_B = imgs[0][ix, ..., 256:]

In [66]:
# Inference only: move the input to the generator's device and skip building
# the autograd graph.
with torch.no_grad():
    fake_B = generator(real_A.to(device))

In [71]:
# Write the generated batch as a single image grid under images/.
sample_name = 'facades_' + str(1)
save_image(fake_B, "images/%s.png" % sample_name, nrow=5, normalize=True)

In [77]:
# fake_B is a batch of images (N, C, H, W). Reshaping it to (C, H, N * W)
# would reinterpret the memory and scramble the pixels, so the N images are
# laid out left to right by concatenating them along the width axis instead.
tiled = torch.cat(fake_B.detach().cpu().unbind(0), dim=-1)
# The generator ends in Tanh, so values lie in [-1, 1]; map them to [0, 1]
# before conversion so negative values are not corrupted by the byte cast.
img = VF.to_pil_image((tiled * 0.5 + 0.5).clamp(0, 1))

In [None]:
# Display the PIL image inline by handing its pixel array to matplotlib.
pixels = np.asarray(img)
plt.imshow(pixels)