In [13]:
import numpy as np
import os
import PIL.Image as Image

from tqdm import tqdm_notebook as tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from IPython import display
import matplotlib.pylab as plt
import ipywidgets

In [21]:
# Pick the compute device once, then report which one was chosen.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("The code will run on GPU." if device.type == 'cuda' else "The code will run on CPU.")

The code will run on GPU.


In [22]:
class Horses(torch.utils.data.Dataset):
    """Image dataset backed by the files in ``<data_path>/<train|test>/A``.

    Every directory entry is treated as an image file; entries are sorted by
    name so that a given index always refers to the same file.
    """

    def __init__(self, train, transform, data_path='/scratch/horse2zebra'):
        """Collect the image paths of one split.

        Args:
            train: truthy selects the ``train`` sub-folder, falsy selects ``test``.
            transform: callable applied to each opened PIL image in ``__getitem__``.
            data_path: dataset root containing the ``train`` and ``test`` folders.
        """
        self.transform = transform

        image_dir = os.path.join(data_path, 'train' if train else 'test', 'A')
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping reproducible across runs and filesystems.
        self.image_paths = [os.path.join(image_dir, file_name)
                            for file_name in sorted(os.listdir(image_dir))]

    def __len__(self):
        """Return the total number of samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Open the image at position ``idx`` and return it after ``transform``."""
        image = Image.open(self.image_paths[idx])
        return self.transform(image)

In [23]:
class Zebras(torch.utils.data.Dataset):
    """Image dataset backed by the files in ``<data_path>/<train|test>/B``.

    Every directory entry is treated as an image file; entries are sorted by
    name so that a given index always refers to the same file.
    """

    def __init__(self, train, transform, data_path='/scratch/horse2zebra'):
        """Collect the image paths of one split.

        Args:
            train: truthy selects the ``train`` sub-folder, falsy selects ``test``.
            transform: callable applied to each opened PIL image in ``__getitem__``.
            data_path: dataset root containing the ``train`` and ``test`` folders.
        """
        self.transform = transform

        image_dir = os.path.join(data_path, 'train' if train else 'test', 'B')
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping reproducible across runs and filesystems.
        self.image_paths = [os.path.join(image_dir, file_name)
                            for file_name in sorted(os.listdir(image_dir))]

    def __len__(self):
        """Return the total number of samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Open the image at position ``idx`` and return it after ``transform``."""
        image = Image.open(self.image_paths[idx])
        return self.transform(image)

In [25]:
batch_size = 64
size = 256  # NOTE(review): not used by any transform in this cell; no resizing is applied here


def make_transform():
    """Build the preprocessing pipeline shared by every split and domain.

    The image is converted to 3-channel grayscale, turned into a tensor and
    normalised per channel with mean 0.5 and std 0.5.
    """
    return transforms.Compose([
        transforms.Grayscale(3),
        transforms.ToTensor(),
        transforms.Normalize([.5,.5,.5], [.5,.5,.5])
    ])


# The four pipelines are identical today; separate names are kept so that each
# split/domain can be customised independently later.
transformer_train_horses = make_transform()
transformer_train_zebras = make_transform()
transformer_test_horses = make_transform()
transformer_test_zebras = make_transform()

train_set_horses    = Horses(True, transformer_train_horses)
train_loader_horses = DataLoader(train_set_horses, batch_size=batch_size, shuffle=True, num_workers=3)

test_set_horses     = Horses(False, transformer_test_horses)
test_loader_horses  = DataLoader(test_set_horses, batch_size=batch_size, shuffle=True, num_workers=3)

train_set_zebras    = Zebras(True, transformer_train_zebras)
train_loader_zebras = DataLoader(train_set_zebras, batch_size=batch_size, shuffle=True, num_workers=3)

# Fixed: the zebra test split was previously built from the Horses dataset,
# so the "zebra" test loader actually served horse images.
test_set_zebras     = Zebras(False, transformer_test_zebras)
test_loader_zebras  = DataLoader(test_set_zebras, batch_size=batch_size, shuffle=True, num_workers=3)

<hr>

<h2>Networks</h2>

In [None]:
class ResNetBlock(nn.Module):
    """Residual block that keeps the channel count and spatial size unchanged.

    The residual branch is two (3x3 conv with padding 1 -> ReLU) stages; its
    output is added to the input and the sum is passed through a final ReLU.
    """

    def __init__(self, n_features):
        """Create the residual branch for feature maps with ``n_features`` channels."""
        super().__init__()
        branch = []
        for _ in range(2):
            branch.append(nn.Conv2d(n_features, n_features, 3, padding=1))
            branch.append(nn.ReLU())
        self.resblock = nn.Sequential(*branch)

    def forward(self, x):
        """Return ``relu(resblock(x) + x)``."""
        residual = self.resblock(x)
        return F.relu(residual + x)

In [None]:
class ResNet(nn.Module):
    """Small residual classifier: conv stem, residual blocks, then an MLP head.

    The convolutional part preserves the spatial size, and the first linear
    layer is sized for ``32*32*n_features`` inputs, so the network expects
    32x32 images. The head ends in a softmax over 10 outputs.
    """

    def __init__(self, n_in, n_features, num_res_blocks=5):
        """Build the network.

        Args:
            n_in: number of channels of the input images.
            n_features: number of channels used throughout the residual blocks.
            num_res_blocks: how many ``ResNetBlock`` modules follow the stem.
        """
        super(ResNet, self).__init__()
        #First conv layers needs to output the desired number of features.
        conv_layers = [nn.Conv2d(n_in, n_features, kernel_size=3, stride=1, padding=1),
                       nn.ReLU()]
        for i in range(num_res_blocks):
            conv_layers.append(ResNetBlock(n_features))
        self.res_blocks = nn.Sequential(*conv_layers)
        self.fc = nn.Sequential(nn.Linear(32*32*n_features, 2048),
                                nn.ReLU(),
                                nn.Linear(2048, 512),
                                nn.ReLU(),
                                nn.Linear(512,10),
                                nn.Softmax(dim=1))

    def forward(self, x):
        """Return the ``(batch, 10)`` softmax output for a batch of images."""
        x = self.res_blocks(x)
        #reshape x so it becomes flat, except for the first dimension (which is the minibatch)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        # Fixed: this used to return the undefined name `out` (NameError).
        return x

In [27]:
class CGAN_G(nn.Module):
    """Image-to-image generator: conv encoder, residual blocks, transposed-conv decoder.

    The encoder halves the spatial size twice (3 -> 64 -> 128 -> 256 channels),
    the residual stack works at 256 channels, and the decoder doubles the
    spatial size twice back to a 3-channel image. For inputs whose height and
    width are divisible by 4 the output has exactly the input's shape, with
    values in [-1, 1].
    """

    def __init__(self, num_res_blocks=9):
        """Build the generator.

        Args:
            num_res_blocks: number of ``ResNetBlock`` modules in the middle
                of the network (default 9, the previously hard-coded value).
        """
        super(CGAN_G, self).__init__()

        # Channel width of the bottleneck. Fixed: `n_features` used to be
        # referenced without being defined in this scope (NameError).
        n_features = 256

        # encoder
        self.hidden0 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=1, padding=3),
            nn.ReLU()
        )
        self.hidden1 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU()
        )
        self.hidden2 = nn.Sequential(
            nn.Conv2d(128, n_features, kernel_size=3, stride=2, padding=1),
            nn.ReLU()
        )

        # transformer
        conv_layers = [nn.Conv2d(n_features, n_features, kernel_size=3, stride=1, padding=1),
                       nn.ReLU()]
        for _ in range(num_res_blocks):
            conv_layers.append(ResNetBlock(n_features))
        self.res_blocks = nn.Sequential(*conv_layers)

        # decoder
        # output_padding=1 makes each stride-2 transposed conv exactly double
        # the spatial size (2*H instead of 2*H - 1), so the decoder undoes the
        # two stride-2 encoder convolutions.
        self.hidden3 = nn.Sequential(
            nn.ConvTranspose2d(n_features, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU()
        )
        self.hidden4 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU()
        )
        # Tanh instead of ReLU: the training images are normalised to
        # [-1, 1], and a ReLU output can never produce negative pixel values.
        self.hidden5 = nn.Sequential(
            nn.ConvTranspose2d(64, 3, kernel_size=7, stride=1, padding=3),
            nn.Tanh()
        )

    def forward(self, x):
        """Translate a batch of 3-channel images; returns a tensor in [-1, 1]."""
        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.res_blocks(x)
        x = self.hidden3(x)
        x = self.hidden4(x)
        x = self.hidden5(x)
        return x

In [29]:
class CGAN_D(nn.Module):
    """Fully convolutional discriminator.

    Four (4x4 conv -> ReLU) stages widen 3 -> 64 -> 128 -> 256 -> 512 channels
    (the first three with stride 2, the fourth with stride 1), followed by a
    1-channel 4x4 conv with a sigmoid. The result is a spatial map of scores
    in [0, 1] rather than a single scalar per image.
    """

    @staticmethod
    def _conv_relu(in_channels, out_channels, stride):
        """Return one ``Conv2d(kernel 4, padding 1) -> ReLU`` stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=stride, padding=1),
            nn.ReLU(),
        )

    def __init__(self):
        """Create the five stages in order, from input to score map."""
        super().__init__()
        self.hidden0 = self._conv_relu(3, 64, stride=2)
        self.hidden1 = self._conv_relu(64, 128, stride=2)
        self.hidden2 = self._conv_relu(128, 256, stride=2)
        self.hidden3 = self._conv_relu(256, 512, stride=1)
        self.out = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Run ``x`` through every stage and return the sigmoid score map."""
        stages = (self.hidden0, self.hidden1, self.hidden2, self.hidden3, self.out)
        for stage in stages:
            x = stage(x)
        return x

In [30]:
def d_loss_lsgan(x,z):
    """Least-squares discriminator loss.

    ``x`` holds the discriminator outputs that should be pushed towards 1 and
    ``z`` the outputs that should be pushed towards 0. Returns
    ``0.5 * mean((x - 1)^2) + 0.5 * mean(z^2)`` as a scalar tensor.
    """
    real_term = torch.mean((x - 1) ** 2)
    fake_term = torch.mean(z ** 2)
    return .5 * real_term + .5 * fake_term

In [31]:
def g_loss_lsgan(z):
    """Least-squares generator loss.

    ``z`` holds the discriminator outputs that the generator wants pushed
    towards 1. Returns ``0.5 * mean((z - 1)^2)`` as a scalar tensor.
    """
    squared_error = (z - 1) ** 2
    return .5 * torch.mean(squared_error)

In [None]:
#Initialize networks
# Naming used in this cell (domain A = images from the 'A' folder / Horses,
# domain B = images from the 'B' folder / Zebras):
#   G_A translates A -> B,  G_B translates B -> A
#   D_A scores images of domain A,  D_B scores images of domain B
D_A = CGAN_D().to(device)
D_B = CGAN_D().to(device)
G_A = CGAN_G().to(device)
G_B = CGAN_G().to(device)

d_opt_a = torch.optim.Adam(D_A.parameters(), 0.0002, (0.5, 0.999))
g_opt_a = torch.optim.Adam(G_A.parameters(), 0.0001, (0.5, 0.999))
d_opt_b = torch.optim.Adam(D_B.parameters(), 0.0002, (0.5, 0.999))
g_opt_b = torch.optim.Adam(G_B.parameters(), 0.0001, (0.5, 0.999))

# Weight of the cycle-consistency (L1 reconstruction) term relative to the
# adversarial term in the generator loss.
lambda_cycle = 10.0

#Clean up all old variables on the GPU no longer in use
torch.cuda.empty_cache()

visualization = ipywidgets.Output()
display.display(visualization)

with visualization:
    plt.figure(figsize=(20,10))
subplots = [plt.subplot(2, 6, k+1) for k in range(12)]
num_epochs = 10
# zip() below stops at the shorter of the two loaders.
batches_per_epoch = min(len(train_loader_horses), len(train_loader_zebras))
for epoch in tqdm(range(num_epochs), unit='epoch'):
    paired_batches = zip(train_loader_horses, train_loader_zebras)
    for minibatch_no, (real_a, real_b) in tqdm(enumerate(paired_batches), total=batches_per_epoch):
        # The datasets yield image tensors only (no labels), already
        # normalised by their transforms, so no further rescaling is applied.
        real_a = real_a.to(device)
        real_b = real_b.to(device)

        #Update Discriminators
        with torch.no_grad(): #We don't need gradients for G when we update D
            fake_b = G_A(real_a)
            fake_a = G_B(real_b)

        D_A.zero_grad()
        d_loss_a = d_loss_lsgan(D_A(real_a), D_A(fake_a))
        d_loss_a.backward()
        d_opt_a.step()

        D_B.zero_grad()
        d_loss_b = d_loss_lsgan(D_B(real_b), D_B(fake_b))
        d_loss_b.backward()
        d_opt_b.step()

        d_loss_value = d_loss_a.item() + d_loss_b.item()

        #Update Generators
        G_A.zero_grad()
        G_B.zero_grad()
        fake_b = G_A(real_a)
        fake_a = G_B(real_b)
        adversarial_loss = g_loss_lsgan(D_B(fake_b)) + g_loss_lsgan(D_A(fake_a))
        # Cycle consistency: translating to the other domain and back should
        # reproduce the input. This requires the generators to return images
        # with the same spatial size as their input.
        cycle_loss = F.l1_loss(G_B(fake_b), real_a) + F.l1_loss(G_A(fake_a), real_b)
        g_loss = adversarial_loss + lambda_cycle * cycle_loss
        g_loss.backward()
        g_opt_a.step()
        g_opt_b.step()

        assert(not np.isnan(d_loss_value))
        #Plot every 100 minibatches
        if minibatch_no % 100 == 0:
            with torch.no_grad(), visualization:
                fake_b = fake_b.detach()
                P = D_B(fake_b)
                for k in range(min(11, fake_b.size(0))):
                    # (C, H, W) in [-1, 1] -> (H, W, C) in [0, 1] for imshow
                    fake_k = (fake_b[k].cpu().permute(1, 2, 0)/2+.5).clamp(0, 1)
                    subplots[k].imshow(fake_k.numpy())
                    # D returns a map of scores per image; show its mean
                    subplots[k].set_title('P=%.2f' % P[k].mean().item())
                    subplots[k].axis('off')

                H1 = P
                H2 = D_B(real_b)

                subplots[11].cla()
                subplots[11].hist(H1.cpu().flatten().numpy(), label='fake', range=(0, 1), alpha=0.5)
                subplots[11].hist(H2.cpu().flatten().numpy(), label='real', range=(0, 1), alpha=0.5)
                subplots[11].legend()
                subplots[11].set_title('Discriminator loss: %.2f' % d_loss_value)

                display.display(plt.gcf())
                display.clear_output(wait=True)
                