In [0]:
from time import time
from multiprocessing import cpu_count

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils import data
from keras.datasets import mnist

Using TensorFlow backend.


In [0]:
# GPU configuration: prefer the first CUDA device, otherwise run on the CPU.
use_gpu = torch.cuda.is_available()
device = torch.device("cuda:0") if use_gpu else torch.device("cpu")

# Load MNIST data into a DataLoader

In [0]:
class Dataset(data.Dataset):
    """Map-style dataset yielding aligned pairs of samples from two arrays.

    Item ``i`` is ``(mnist1[i], mnist2[i])``, so each DataLoader batch is a
    pair of equally sized minibatches.

    Args:
        mnist1: array indexed along axis 0.
        mnist2: array with the same length along axis 0 as ``mnist1``.

    Raises:
        ValueError: if the two arrays differ in length along axis 0. Without
            this check, a shorter ``mnist2`` fails only mid-epoch and a longer
            one is silently truncated.
    """

    def __init__(self, mnist1, mnist2):
        if mnist1.shape[0] != mnist2.shape[0]:
            raise ValueError(
                "mnist1 and mnist2 must have the same length, got %d and %d"
                % (mnist1.shape[0], mnist2.shape[0]))
        self.mnist1 = mnist1
        self.mnist2 = mnist2

    def __len__(self):
        return self.mnist1.shape[0]

    def __getitem__(self, index):
        return self.mnist1[index], self.mnist2[index]

In [0]:
# Load the MNIST training images and build two independently shuffled copies.
# The copies serve as the two "real" minibatch streams. Labels and the test
# split are not used.
SEED = 0
rng = np.random.default_rng(SEED)  # seeded so both shuffles are reproducible

(mnist_data, _), (_, _) = mnist.load_data()  # We only care about images
# float32 halves memory vs. float64 and matches the FloatTensor used in training.
mnist_data = mnist_data.astype(np.float32) / 255.
mnist_data = np.expand_dims(mnist_data, axis=3)  # Add a trailing channel dimension
mnist_data1 = mnist_data[rng.permutation(mnist_data.shape[0])]
mnist_data2 = mnist_data[rng.permutation(mnist_data.shape[0])]

In [0]:
# Each item is an (image from copy 1, image from copy 2) pair, batched 128 at a time.
mnist_dataset = Dataset(mnist_data1, mnist_data2)
mnist_dataloader = data.DataLoader(
    mnist_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=cpu_count(),
)

# OT-GAN

In [0]:
class Generator(nn.Module):
    """Three-layer MLP mapping latent vectors to images with values in (0, 1).

    Args:
        input_size: latent dimension.
        hidden_size_1: width of the first hidden layer.
        hidden_size_2: width of the second hidden layer.
        output_size: number of output units; should equal the product of
            ``output_shape``.
        output_shape: tuple the flat output is reshaped to, with a batch
            dimension prepended.
    """

    def __init__(self, input_size, hidden_size_1, hidden_size_2,
                 output_size, output_shape):
        super().__init__()
        # Layer creation order is kept so parameter init / state_dict order match.
        self.map1 = nn.Linear(input_size, hidden_size_1)
        self.map2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.map3 = nn.Linear(hidden_size_2, output_size)
        self.f = nn.LeakyReLU(negative_slope=0.2)
        self.output_shape = output_shape

    def forward(self, x):
        hidden = self.f(self.map2(self.f(self.map1(x))))
        pixels = torch.sigmoid(self.map3(hidden))
        target_shape = (-1,) + self.output_shape
        return pixels.reshape(target_shape)

In [0]:
class Critic(nn.Module):
    """Three-layer MLP embedding images into a feature space for the OT cost.

    Inputs of shape ``(batch, ...)`` are flattened to ``(batch, -1)`` before
    the first linear layer, so ``input_size`` must equal the number of
    elements per sample.

    Args:
        input_size: flattened size of one input sample.
        hidden_size_1: width of the first hidden layer.
        hidden_size_2: width of the second hidden layer.
        output_size: dimension of the output feature vector.
    """

    def __init__(self, input_size, hidden_size_1, hidden_size_2, output_size):
        super(Critic, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size_1)
        self.map2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.map3 = nn.Linear(hidden_size_2, output_size)
        self.f = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, x):
        # Flatten all non-batch dims (same as nn.Flatten()) without building a
        # new module object on every call.
        x = torch.flatten(x, start_dim=1)
        x = self.f(self.map1(x))
        x = self.f(self.map2(x))
        return self.map3(x)

In [0]:
class GAN():
    """OT-GAN trainer: an MLP generator and an MLP critic trained adversarially
    on the minibatch energy distance built from entropic OT distances.

    Args:
        dataloader: DataLoader whose batches are ``(real_1, real_2)`` pairs of
            real image minibatches.
        gen_params: positional arguments forwarded to ``Generator``.
        critic_params: positional arguments forwarded to ``Critic``.
        z_dim: dimension of the Gaussian latent vector fed to the generator.
    """

    def __init__(self, dataloader, gen_params, critic_params, z_dim):

        self.dataloader = dataloader

        # Image geometry is read from the first real sample; assumes samples
        # are stored as (rows, cols, channels).
        first_img = dataloader.dataset[0][0]
        self.img_rows = first_img.shape[0]
        self.img_cols = first_img.shape[1]
        self.img_channels = first_img.shape[2]
        self.img_shape = (self.img_rows, self.img_cols, self.img_channels)
        self.z_dim = z_dim

        self.generator = Generator(*gen_params).to(device)
        self.critic = Critic(*critic_params).to(device)

    def _sample_z(self, n_sample):
        """Draw ``n_sample`` standard-normal latent vectors as a float tensor on ``device``."""
        z = np.random.randn(n_sample, self.z_dim)
        return torch.FloatTensor(z).to(device)

    def sample_data(self, n_sample=100):
        """Generate ``n_sample`` images and return them as a NumPy array.

        Runs under ``torch.no_grad()`` so no autograd graph is built.
        """
        with torch.no_grad():
            samples = self.generator(self._sample_z(n_sample))
        return samples.cpu().numpy()

    def sinkhorn(self, a, b, C, reg=0.001, max_iters=100):
        """Entropy-regularised optimal transport plan between weights a and b.

        Runs the Sinkhorn fixed-point updates ``u = a / (K v)``,
        ``v = b / (K^T u)`` with ``K = exp(-C / reg)``, but in the log domain.
        The direct form underflows ``K`` to 0 in float32 once ``C / reg``
        exceeds roughly 100, which produces inf/NaN plans. The result is
        differentiable w.r.t. ``C``.

        Args:
            a: 1-D weights of length n (rows of ``C``), assumed non-negative.
            b: 1-D weights of length m (columns of ``C``), assumed non-negative.
            C: (n, m) cost matrix.
            reg: entropic regularisation strength.
            max_iters: number of Sinkhorn iterations.

        Returns:
            The (n, m) plan ``diag(u) K diag(v)``.
        """
        log_K = -C / reg
        log_a = torch.log(a)
        log_b = torch.log(b)
        log_u = torch.zeros_like(a)
        log_v = torch.zeros_like(b)
        for _ in range(max_iters):
            log_u = log_a - torch.logsumexp(log_K + log_v.unsqueeze(0), dim=1)
            log_v = log_b - torch.logsumexp(log_K + log_u.unsqueeze(1), dim=0)
        # Broadcasting instead of diag matmuls: O(n*m) rather than O(n^3).
        return torch.exp(log_u.unsqueeze(1) + log_K + log_v.unsqueeze(0))

    def cost(self, batch_1, batch_2):
        """Pairwise cosine distance ``1 - cos(x_i, y_j)`` between rows.

        Args:
            batch_1: (n, d) tensor.
            batch_2: (m, d) tensor.

        Returns:
            (n, m) cost matrix with entries in [0, 2]. ``F.normalize`` clamps
            tiny norms, so an all-zero row yields cost 1 instead of NaN.
        """
        unit_1 = F.normalize(batch_1, p=2, dim=1)
        unit_2 = F.normalize(batch_2, p=2, dim=1)
        return 1 - torch.matmul(unit_1, unit_2.transpose(0, 1))

    def _ot_distance(self, x, y, reg):
        """Entropic OT distance <plan, cost> between two uniformly weighted batches."""
        a = torch.ones(x.shape[0], device=x.device) / x.shape[0]
        b = torch.ones(y.shape[0], device=y.device) / y.shape[0]
        costs = self.cost(x, y)
        plan = self.sinkhorn(a, b, costs, reg=reg)
        return torch.sum(plan * costs)

    def _minibatch_energy_distance(self, real_1, real_2, fake_1, fake_2, reg):
        """Minibatch energy distance between real and generated feature batches.

        The four real-vs-fake terms are added. The real-vs-real and
        fake-vs-fake terms are each subtracted twice.
        """
        return (self._ot_distance(real_1, fake_1, reg)
                + self._ot_distance(real_1, fake_2, reg)
                + self._ot_distance(real_2, fake_1, reg)
                + self._ot_distance(real_2, fake_2, reg)
                - 2 * self._ot_distance(real_1, real_2, reg)
                - 2 * self._ot_distance(fake_1, fake_2, reg))

    def _show_samples(self, n_sample):
        """Plot ``n_sample`` generated images, one figure each.

        Assumes the generator outputs lie in [0, 1] (sigmoid output layer).
        """
        for img in self.sample_data(n_sample):
            plt.figure()
            plt.imshow(img[:, :, 0], cmap='gray', vmin=0, vmax=1)
            plt.show()

    def train(self, epochs=1000, print_interval=1, reg=0.01, lr=0.0001):
        """Train generator and critic adversarially on the minibatch energy distance.

        The generator minimises the distance and the critic maximises it.
        Both networks are updated from one backward pass: the critic's
        gradients are negated before its optimiser step (gradient ascent).

        Args:
            epochs: number of passes over ``self.dataloader``.
            print_interval: print the mean batch loss every ``print_interval``
                epochs; show 3 samples every ``5 * print_interval`` epochs.
            reg: entropic regularisation passed to ``sinkhorn``.
            lr: Adam learning rate for both networks.
        """
        c_optimizer = optim.Adam(self.critic.parameters(), lr=lr)
        g_optimizer = optim.Adam(self.generator.parameters(), lr=lr)

        for epoch in range(epochs):

            t = time()
            epoch_losses = []

            for real_1, real_2 in self.dataloader:

                batch_size = real_1.shape[0]

                c_optimizer.zero_grad()
                g_optimizer.zero_grad()

                real_1 = real_1.to(device, dtype=torch.float32)
                real_2 = real_2.to(device, dtype=torch.float32)
                fake_1 = self.generator(self._sample_z(batch_size))
                fake_2 = self.generator(self._sample_z(batch_size))

                loss = self._minibatch_energy_distance(
                    self.critic(real_1), self.critic(real_2),
                    self.critic(fake_1), self.critic(fake_2), reg)

                loss.backward()
                # The critic must *maximise* the distance. If it also minimised
                # it, it could map every input to the same feature vector,
                # making all cosine costs (and the loss) zero.
                for param in self.critic.parameters():
                    if param.grad is not None:
                        param.grad.neg_()
                c_optimizer.step()
                g_optimizer.step()

                epoch_losses.append(loss.item())

            if epoch % print_interval == 0:
                mean_loss = np.mean(epoch_losses) if epoch_losses else float('nan')
                print("Epoch %s: Loss %s;  time (%s)" %
                      (epoch, mean_loss, time() - t))

            if epoch % (print_interval * 5) == 0:
                self._show_samples(3)


In [0]:
# Network sizes derived from the data: images are flattened to img_size values.
img_shape = mnist_data[0].shape
img_size = int(np.prod(img_shape))

z_dim = 32       # latent (noise) dimension fed to the generator
critic_dim = 32  # dimension of the critic's output feature vector

# Generator: (input_size, hidden_size_1, hidden_size_2, output_size, output_shape)
gen_params = (z_dim, 256, 512, img_size, img_shape)
# Critic: (input_size, hidden_size_1, hidden_size_2, output_size)
critic_params = (img_size, 512, 256, critic_dim)

In [0]:
gan = GAN(dataloader=mnist_dataloader, gen_params=gen_params,
          critic_params=critic_params, z_dim=z_dim)

In [0]:
gan.train(epochs=200)

torch.Size([4, 4, 128, 128])
torch.Size([128, 128])
torch.Size([4, 4, 128, 128])
torch.Size([128, 128])
torch.Size([4, 4, 128, 128])
torch.Size([128, 128])
torch.Size([4, 4, 128, 128])
torch.Size([128, 128])
torch.Size([4, 4, 128, 128])
torch.Size([128, 128])
torch.Size([4, 4, 128, 128])
torch.Size([128, 128])
torch.Size([4, 4, 128, 128])
torch.Size([128, 128])
torch.Size([4, 4, 128, 128])
torch.Size([128, 128])
torch.Size([4, 4, 128, 128])
torch.Size([128, 128])
torch.Size([4, 4, 128, 128])
torch.Size([128, 128])
torch.Size([4, 4, 128, 128])
torch.Size([128, 128])
torch.Size([4, 4, 128, 128])
torch.Size([128, 128])
torch.Size([4, 4, 128, 128])
torch.Size([128, 128])
torch.Size([4, 4, 128, 128])
torch.Size([128, 128])


KeyboardInterrupt: ignored

In [0]:
# The generator ends in a sigmoid, so samples already lie in [0, 1];
# a [-1, 1] -> [0, 1] rescale (for tanh outputs) would squash them into [0.5, 1].
samples = gan.sample_data(10)

In [0]:
# Show all generated samples side by side in one figure (explicit Axes API).
fig, axes = plt.subplots(1, len(samples), figsize=(2 * len(samples), 2), squeeze=False)
for ax, img in zip(axes[0], samples):
    ax.imshow(img[:, :, 0], cmap='gray')
    ax.axis('off')
plt.show()