In [None]:
import torch
from torch import nn, optim
from torch.autograd.variable import Variable
from torchvision import transforms, datasets
import matplotlib.pyplot as plt
import time

# Select the compute device once: CUDA when it is available, otherwise the CPU.
# Later cells read this module-level `device` when they create models and tensors.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)


In [None]:
# load MNIST
def mnist_data():
    """Return the MNIST training set with pixel values normalised by mean 0.5, std 0.5.

    Each image goes through ``ToTensor`` and is then normalised with
    ``(x - 0.5) / 0.5``. The data is downloaded into ``./dataset`` if it is not
    already there.

    Returns:
        torchvision.datasets.MNIST: the training split (``train=True``).
    """
    compose = transforms.Compose([
        transforms.ToTensor(),
        # Normalize expects one value per channel as a sequence; ``(.5)`` is
        # just the float 0.5, so use explicit one-element tuples (MNIST has a
        # single channel).
        transforms.Normalize((0.5,), (0.5,)),
    ])
    out_dir = './dataset'
    return datasets.MNIST(root=out_dir, train=True, transform=compose, download=True)
    
# Build the MNIST training dataset once (downloads on first use); a later cell
# wraps it in a DataLoader.
mnistdata = mnist_data()

# Generator
class GeneratorNet(torch.nn.Module):
    """
    Fully connected generator with three hidden layers.

    Maps a latent vector of ``n_features`` values to ``n_out`` values through
    Linear+LeakyReLU(0.2) layers of width 256, 512 and 1024, followed by a
    Linear+Tanh output layer, so every output lies in [-1, 1].

    Args:
        n_features: size of the latent input vector (default 100).
        n_out: size of the generated vector (default 784).

    All layers are moved to the module-level ``device`` on construction.
    """
    def __init__(self, n_features=100, n_out=784):
        super().__init__()
        self.n_features = n_features
        self.n_out = n_out

        # Attribute names and creation order are kept stable so that
        # state_dict keys and seeded initialisation do not change.
        self.hidden0 = nn.Sequential(
            nn.Linear(n_features, 256),
            nn.LeakyReLU(0.2)
        )
        self.hidden1 = nn.Sequential(
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2)
        )
        self.hidden2 = nn.Sequential(
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2)
        )
        self.out = nn.Sequential(
            nn.Linear(1024, n_out),
            nn.Tanh()
        )

        # One move for the whole network instead of one per sub-module.
        self.to(device)

    def forward(self, x):
        """Generate a ``(batch, n_out)`` tensor from latent input ``x`` of shape ``(batch, n_features)``."""
        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.out(x)
        return x

def images_to_vectors(images):
    """Flatten a batch of images to shape ``(batch, 784)`` and move it to ``device``.

    ``images`` must hold exactly 784 values per sample and be viewable without
    a copy (``Tensor.view`` is used).
    """
    batch = images.size(0)
    flat = images.view(batch, 784)
    return flat.to(device)

def vectors_to_images(vectors):
    """View a ``(batch, 784)`` tensor as single-channel images of shape ``(batch, 1, 28, 28)``."""
    n_samples = vectors.size(0)
    image_shape = (n_samples, 1, 28, 28)
    return vectors.view(image_shape)

def noise(size, n_features=100):
    """Sample a batch of latent vectors from a standard normal distribution.

    Args:
        size: number of vectors to draw (batch size).
        n_features: length of each latent vector (default 100).

    Returns:
        Tensor of shape ``(size, n_features)`` created directly on ``device``.
    """
    # torch.autograd.Variable has been a deprecated no-op wrapper since
    # PyTorch 0.4: it returns the tensor itself, so return the tensor directly.
    return torch.randn(size, n_features, device=device)

def batch_imshow(vector_batch, n_images=10):
    """Show the first images of a batch of flattened images side by side.

    Args:
        vector_batch: tensor of flattened images, reshaped through
            ``vectors_to_images`` before plotting.
        n_images: maximum number of images to show (default 10). Fewer are
            shown when the batch is smaller; nothing is drawn for an empty batch.
    """
    # Detach from autograd and move to host memory so matplotlib can draw the
    # data even when the batch lives on the GPU.
    imgs = vectors_to_images(vector_batch).detach().cpu().numpy()
    n_show = min(n_images, imgs.shape[0])
    if n_show <= 0:
        return

    # squeeze=False keeps `axs` two-dimensional even for a single image.
    fig, axs = plt.subplots(1, n_show, squeeze=False)
    for ax, img in zip(axs[0], imgs[:n_show, 0]):
        ax.imshow(img)
        ax.axis('off')
    plt.show()
    # Release the figure: this function is called repeatedly during training.
    plt.close(fig)
    return

In [None]:
# discriminator definition
class entropic_OT(nn.Module):
    """Entropic optimal-transport potential used as the "discriminator".

    The module stores a fixed set of target points ``y`` (one per row) and
    learns one scalar ``psi[j]`` per target. For each input row ``x`` the
    forward pass evaluates, over all targets ``j``, the cost

        c(x, y_j) = ||x - y_j||^2 / d        (d = number of features)

    and returns a (soft) minimum of ``c(x, y_j) - psi[j]``.

    Args:
        y: tensor of shape ``(n_targets, d)``; it is detached, so no gradient
            ever flows back into it.
        lambd: regularisation strength. ``lambd > 0`` uses the smooth minimum
            ``-lambd * logsumexp((psi - c) / lambd)``; any other value uses
            the hard minimum.

    Target tensors and ``psi`` are placed on the module-level ``device``.
    """
    def __init__(self, y, lambd):
        super().__init__()
        # detach() rather than requires_grad_(False): the transpose of a
        # tensor that requires grad is a non-leaf, and flipping the flag on a
        # non-leaf raises a RuntimeError.
        yt = y.detach().transpose(1, 0).to(device)
        # Non-persistent buffers follow the module through .to()/.cuda()
        # together with psi, but stay out of state_dict (as before).
        self.register_buffer("yt", yt, persistent=False)                  # (d, n_targets)
        self.register_buffer("sy2", torch.sum(yt ** 2, 0, keepdim=True),  # (1, n_targets)
                             persistent=False)
        self.psi = nn.Parameter(torch.zeros(y.size(0), device=device))
        self.lambd = lambd

    def forward(self, input):
        """Return one value per input row: the (soft) minimum over targets of ``c - psi``.

        Args:
            input: tensor of shape ``(batch, d)``.

        Returns:
            Tensor of shape ``(batch,)``.
        """
        # ||x||^2 + ||y||^2 - 2 x.y, i.e. pairwise squared distances, scaled
        # by the feature dimension.
        sq_dist = torch.sum(input ** 2, 1, keepdim=True) + self.sy2 - 2 * torch.matmul(input, self.yt)
        cxy = sq_dist / self.yt.size(0)
        if self.lambd > 0:
            output = -self.lambd * torch.logsumexp((self.psi.unsqueeze(0) - cxy) / self.lambd, 1)
        else:
            output = torch.min(cxy - self.psi.unsqueeze(0), 1)[0]
        return output

In [None]:
def train_discriminator(discriminator, optimizer, input_data):
    """Take one optimizer step that increases the discriminator's mean output.

    The loss is the negated mean of ``discriminator(input_data)``, so
    minimising it maximises the mean score.

    Returns:
        The loss tensor (negated mean score) computed before the step.
    """
    optimizer.zero_grad()

    scores = discriminator(input_data)
    loss = torch.neg(scores.mean())

    loss.backward()
    optimizer.step()
    return loss

def train_generator(discriminator, optimizer, input_data):
    """Take one optimizer step that decreases the discriminator's mean output.

    ``input_data`` is scored by ``discriminator``; the mean score is the loss
    that is back-propagated before ``optimizer.step()``.

    Returns:
        The loss tensor (mean score) computed before the step.
    """
    # Clear gradients left over from the previous step.
    optimizer.zero_grad()

    # Mean score is the quantity being minimised.
    scores = discriminator(input_data)
    loss = scores.mean()

    # Back-propagate, then apply the update.
    loss.backward()
    optimizer.step()
    return loss

In [None]:
# Build the discriminator from one shuffled batch of real images: that batch
# becomes the fixed set of target points stored inside entropic_OT, so
# batch_size below is the number of target points.
full_data = torch.utils.data.DataLoader(mnistdata, batch_size=5000, shuffle=True) ## set to 60000 to use the whole dataset (slower)

# Regularisation strength handed to entropic_OT.
lambd = 0.01
batch_data, _ = next(iter(full_data))
print(batch_data.shape)
discriminator = entropic_OT(images_to_vectors(batch_data), lambd)

# for n_batch, (real_batch,_) in enumerate(full_data):
#     print(images_to_vectors(real_batch).shape)
#     print(torch.min(images_to_vectors(real_batch)))
#     discriminator = entropic_OT(images_to_vectors(real_batch), lambd)

In [None]:
 generator = GeneratorNet()

d_optimizer = optim.Adam(discriminator.parameters(), lr=0.2)
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002)


In [None]:
# Training hyper-parameters
num_epochs = 1000   # number of generator updates
batch_size = 200    # generated samples per step
n_iter_psi = 100    # discriminator (psi) updates per generator update
T = time.time()

for epoch in range(num_epochs):

    # 1. Train Discriminator
    # A fresh ASGD optimizer per epoch restarts the iterate averaging, so the
    # average read below only covers this epoch's psi updates.
    d_optimizer = optim.ASGD(discriminator.parameters(), lr=0.8, alpha=0.5, t0=1)
    for _ in range(n_iter_psi):
        # The psi update needs no gradient through the generator, so skip
        # building its autograd graph instead of detaching afterwards.
        with torch.no_grad():
            fake_data = generator(noise(batch_size))
        d_error = train_discriminator(discriminator, d_optimizer, fake_data)

    with torch.no_grad():
        # Replace psi by ASGD's averaged iterate ('ax'). Copying, rather than
        # rebinding psi.data, keeps the parameter from aliasing optimizer state.
        discriminator.psi.copy_(d_optimizer.state[discriminator.psi]['ax'])
        # Centre psi so that its mean is zero.
        discriminator.psi.sub_(discriminator.psi.mean())

    # 2. Train Generator
    # Generate fake data, this time keeping the graph so gradients reach the generator.
    fake_data = generator(noise(batch_size))
    g_error = train_generator(discriminator, g_optimizer, fake_data)

    # Periodic report and sample preview.
    if epoch % 10 == 0:
        print("epoch {}:".format(epoch))
        print("Lambda = {}:".format(discriminator.lambd))
        print("Elapsed time {}:".format(time.time()-T))
        # '{:.4f}' (four decimals); the original '{:4f}' only set a minimum
        # field width of 4 and still printed six decimals.
        print('G error : {:.4f}'.format(g_error.item()))
        print('D error : {:.4f}'.format(d_error.item()))
        print(discriminator.psi)
        batch_imshow(fake_data.cpu())
