In [13]:
import numpy as np
import os
import PIL.Image as Image

from tqdm import tqdm_notebook as tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from IPython import display
import matplotlib.pylab as plt
import ipywidgets

In [21]:
# Select the compute device once; models and tensors below are moved onto it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("The code will run on GPU." if device.type == 'cuda' else "The code will run on CPU.")

The code will run on GPU.


In [22]:
class Horses(torch.utils.data.Dataset):
    """Dataset over the image files stored in the ``A`` folder of a split.

    Files are read from ``<data_path>/train/A`` or ``<data_path>/test/A``.
    """

    def __init__(self, train, transform, data_path='/scratch/horse2zebra'):
        """Collect the paths of every file in the selected split.

        Args:
            train: truthy selects the ``train`` sub-folder, falsy selects ``test``.
            transform: callable applied to each opened PIL image in ``__getitem__``.
            data_path: root folder that contains the ``train``/``test`` sub-folders.
        """
        self.transform = transform

        split_dir = os.path.join(data_path, 'train' if train else 'test', 'A')
        # os.listdir returns entries in arbitrary order; sorting makes a given
        # index refer to the same file on every run and on every machine.
        self.image_paths = [os.path.join(split_dir, file_name)
                            for file_name in sorted(os.listdir(split_dir))]

    def __len__(self):
        """Return the total number of samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Open the image at position ``idx`` and return ``transform(image)``."""
        image = Image.open(self.image_paths[idx])
        return self.transform(image)

In [23]:
class Zebras(torch.utils.data.Dataset):
    """Dataset over the image files stored in the ``B`` folder of a split.

    Files are read from ``<data_path>/train/B`` or ``<data_path>/test/B``.
    """

    def __init__(self, train, transform, data_path='/scratch/horse2zebra'):
        """Collect the paths of every file in the selected split.

        Args:
            train: truthy selects the ``train`` sub-folder, falsy selects ``test``.
            transform: callable applied to each opened PIL image in ``__getitem__``.
            data_path: root folder that contains the ``train``/``test`` sub-folders.
        """
        self.transform = transform

        split_dir = os.path.join(data_path, 'train' if train else 'test', 'B')
        # os.listdir returns entries in arbitrary order; sorting makes a given
        # index refer to the same file on every run and on every machine.
        self.image_paths = [os.path.join(split_dir, file_name)
                            for file_name in sorted(os.listdir(split_dir))]

    def __len__(self):
        """Return the total number of samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Open the image at position ``idx`` and return ``transform(image)``."""
        image = Image.open(self.image_paths[idx])
        return self.transform(image)

In [33]:
batch_size = 1
size = 256  # NOTE(review): not used by any transform in this cell; confirm whether a Resize was intended


def _make_transform():
    """Build the preprocessing pipeline shared by every split and domain."""
    return transforms.Compose([
        transforms.Grayscale(3),                       # grayscale, replicated to 3 output channels
        transforms.ToTensor(),
        transforms.Normalize([.5,.5,.5], [.5,.5,.5])   # per channel: (x - 0.5) / 0.5
    ])


# The four pipelines are currently identical; separate names are kept so that
# each split/domain can be customised independently later on.
transformer_train_horses = _make_transform()
transformer_train_zebras = _make_transform()
transformer_test_horses  = _make_transform()
transformer_test_zebras  = _make_transform()

train_set_horses    = Horses(True, transformer_train_horses)
train_loader_horses = DataLoader(train_set_horses, batch_size=batch_size, shuffle=True, num_workers=3)

test_set_horses     = Horses(False, transformer_test_horses)
test_loader_horses  = DataLoader(test_set_horses, batch_size=batch_size, shuffle=True, num_workers=3)

train_set_zebras    = Zebras(True, transformer_train_zebras)
train_loader_zebras = DataLoader(train_set_zebras, batch_size=batch_size, shuffle=True, num_workers=3)

# Fixed: this used to instantiate Horses, so the "zebra" test loader served horse images.
test_set_zebras     = Zebras(False, transformer_test_zebras)
test_loader_zebras  = DataLoader(test_set_zebras, batch_size=batch_size, shuffle=True, num_workers=3)

<hr>

<h2>Networks</h2>

In [None]:
class ResNetBlock(nn.Module):
    """Residual block: two (3x3 conv, ReLU) stages plus a skip connection."""

    def __init__(self, n_features):
        """Create the block; both convolutions map ``n_features`` to ``n_features`` channels."""
        super().__init__()
        stages = []
        for _ in range(2):
            # padding=1 with a 3x3 kernel keeps the spatial size, so the skip can be added
            stages += [nn.Conv2d(n_features, n_features, kernel_size=3, padding=1), nn.ReLU()]
        self.resblock = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``relu(resblock(x) + x)``."""
        residual = self.resblock(x)
        return F.relu(residual + x)

In [None]:
class ResNet(nn.Module):
    """Convolutional residual stack followed by a fully connected head.

    The head ends in a 10-way ``Softmax``, so ``forward`` returns one row of
    probabilities per sample.
    """

    def __init__(self, n_in, n_features, num_res_blocks=5, image_size=32):
        """Build the network.

        Args:
            n_in: number of channels of the input images.
            n_features: number of channels used throughout the residual stack.
            num_res_blocks: how many ``ResNetBlock`` modules to stack.
            image_size: height/width of the (square) input images. The conv
                stack keeps the spatial size, so this fixes the number of
                inputs of the first linear layer. Defaults to the previously
                hard-coded value of 32.
        """
        super(ResNet, self).__init__()
        # The first conv layer needs to output the desired number of features.
        conv_layers = [nn.Conv2d(n_in, n_features, kernel_size=3, stride=1, padding=1),
                       nn.ReLU()]
        for _ in range(num_res_blocks):
            conv_layers.append(ResNetBlock(n_features))
        self.res_blocks = nn.Sequential(*conv_layers)
        self.fc = nn.Sequential(nn.Linear(image_size*image_size*n_features, 2048),
                                nn.ReLU(),
                                nn.Linear(2048, 512),
                                nn.ReLU(),
                                nn.Linear(512,10),
                                nn.Softmax(dim=1))

    def forward(self, x):
        """Run the conv stack, flatten per sample, and apply the head."""
        x = self.res_blocks(x)
        # Flatten everything except the first (mini-batch) dimension.
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        # Fixed: this used to return the undefined name `out` (NameError).
        return x

In [32]:
class CGAN_G(nn.Module):
    """Encoder / residual-transformer / decoder generator for 3-channel images.

    The encoder halves the spatial size twice, and the decoder doubles it
    twice. The output therefore has the same shape as the input whenever the
    input height and width are divisible by 4.
    """

    def __init__(self, num_res_blocks=9):
        """Build the generator.

        Args:
            num_res_blocks: number of ``ResNetBlock`` modules in the
                transformer stage (default 9, the previously hard-coded value).
        """
        super(CGAN_G, self).__init__()

        # encoder: 3 -> 64 -> 128 -> 256 channels
        # BatchNorm2d requires the channel count of the tensor it normalises.
        self.hidden0 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=1, padding=3),
            nn.LeakyReLU(.2),
            nn.BatchNorm2d(64)
        )
        self.hidden1 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(.2),
            nn.BatchNorm2d(128)
        )
        self.hidden2 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(.2),
            nn.BatchNorm2d(256)
        )

        # transformer: operates at a constant 256 channels
        n_features = 256
        conv_layers = [nn.Conv2d(n_features, n_features, kernel_size=3, stride=1, padding=1),
                       nn.ReLU()]
        for _ in range(num_res_blocks):
            conv_layers.append(ResNetBlock(n_features))
        self.res_blocks = nn.Sequential(*conv_layers)

        # decoder: 256 -> 128 -> 64 -> 3 channels
        # output_padding=1 makes each stride-2 transposed conv exactly double
        # the spatial size (2H instead of 2H-1), undoing the encoder's halving.
        self.hidden3 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.LeakyReLU(.2),
            nn.BatchNorm2d(128)
        )
        self.hidden4 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.LeakyReLU(.2),
            nn.BatchNorm2d(64)
        )
        # kernel 7 / stride 1 / padding 3 keeps the spatial size unchanged.
        # NOTE(review): the output stage is LeakyReLU + BatchNorm as in the
        # original; a Tanh output is the more common choice for images scaled
        # to [-1, 1]. Confirm which is intended.
        self.hidden5 = nn.Sequential(
            nn.ConvTranspose2d(64, 3, kernel_size=7, stride=1, padding=3),
            nn.LeakyReLU(.2),
            nn.BatchNorm2d(3)
        )

    def forward(self, x):
        """Map a batch of shape (N, 3, H, W) through encoder, transformer, and decoder."""
        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.res_blocks(x)
        x = self.hidden3(x)
        x = self.hidden4(x)
        x = self.hidden5(x)
        return x

In [29]:
class CGAN_D(nn.Module):
    """Fully convolutional discriminator for 3-channel images.

    Four (4x4 conv, ReLU) stages widen the channels 3 -> 64 -> 128 -> 256 -> 512,
    then a final 4x4 conv + Sigmoid produces a single-channel map of scores in (0, 1).
    """

    def __init__(self):
        """Create the five stages ``hidden0`` .. ``hidden3`` and ``out``."""
        super().__init__()

        def conv_relu(c_in, c_out, stride):
            # One feature stage: 4x4 convolution with padding 1, followed by ReLU.
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=stride, padding=1),
                nn.ReLU(),
            )

        # The first three stages use stride 2; the last two use stride 1.
        self.hidden0 = conv_relu(3, 64, 2)
        self.hidden1 = conv_relu(64, 128, 2)
        self.hidden2 = conv_relu(128, 256, 2)
        self.hidden3 = conv_relu(256, 512, 1)
        self.out = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Apply the stages in order and return the score map."""
        for stage in (self.hidden0, self.hidden1, self.hidden2, self.hidden3, self.out):
            x = stage(x)
        return x

In [31]:
def g_loss_lsgan(z):
    """Least-squares loss against a target of 1: ``0.5 * mean((z - 1)^2)``.

    Args:
        z: tensor of scores.

    Returns:
        Scalar tensor holding the loss.
    """
    squared_error = torch.square(z - 1)
    return torch.mean(squared_error) / 2

In [34]:
def cycl_loss(ga2b, gb2a, a, b):
    """Cycle-consistency loss: L1 distance between inputs and their round trips.

    Computes ``mean(|gb2a(ga2b(a)) - a|) + mean(|ga2b(gb2a(b)) - b|)``.

    Args:
        ga2b: callable mapping a batch from domain A to domain B.
        gb2a: callable mapping a batch from domain B to domain A.
        a: batch from domain A.
        b: batch from domain B.

    Returns:
        Scalar tensor holding the summed loss of both cycles.
    """
    cycled_a = gb2a(ga2b(a))
    cycled_b = ga2b(gb2a(b))
    # The absolute value is essential: a signed mean lets positive and negative
    # errors cancel and is unbounded below, so it cannot serve as a loss.
    return torch.mean(torch.abs(cycled_a - a)) + torch.mean(torch.abs(cycled_b - b))

In [None]:
# Networks: two CGAN_D and two CGAN_G instances, all moved onto `device`.
# The creation order is kept as-is because it determines the random initialisation.
D_A   = CGAN_D().to(device)
D_B   = CGAN_D().to(device)
G_A2B = CGAN_G().to(device)
G_B2A = CGAN_G().to(device)

# Optimizers: Adam everywhere with shared betas; the G networks use half the
# learning rate of the D networks.
adam_betas = (0.5, 0.999)
lr_d, lr_g = 0.0002, 0.0001
d_opt_A   = torch.optim.Adam(D_A.parameters(), lr=lr_d, betas=adam_betas)
g_opt_A2B = torch.optim.Adam(G_A2B.parameters(), lr=lr_g, betas=adam_betas)
d_opt_B   = torch.optim.Adam(D_B.parameters(), lr=lr_d, betas=adam_betas)
g_opt_B2A = torch.optim.Adam(G_B2A.parameters(), lr=lr_g, betas=adam_betas)

# Loss functions
criterion_gan   = g_loss_lsgan
criterion_cycle = cycl_loss
criterion_id    = torch.nn.L1Loss()

# Loss weights: lambda_identity scales the identity loss in the training loop.
# lambda_cycle presumably scales the cycle loss (its use is not visible here).
lambda_identity = .5
lambda_cycle    = 10

# Release cached GPU memory that is no longer in use
torch.cuda.empty_cache()

num_epochs = 10

for epoch in tqdm(range(num_epochs), unit='epoch'):
    for i, (real_A, real_B) in tqdm(enumerate(zip(train_loader_horses, train_loader_zebras)), total=len(train_loader_horses)):
        g_opt_A2B.zero_grad()
        g_opt_B2A.zero_grad()
        
        # identity loss
        replicate_B     = G_A2B(real_B)
        loss_identity_B = criterion_id(replicate_B, real_B) * lambda_identity
        
        replicate_A     = G_B2A(real_A)
        loss_identity_A = criterion_id(replicate_A, real_A) * lambda_identity
        
        