# Sinkhorn Algorithm and Sinkhorn Divergence

In [None]:
import torch

def lse(V):
    """Return the row-wise log-sum-exp of a 2-D tensor.

    Computes ``log(sum_j exp(V[i, j]))`` for every row ``i`` and keeps the
    reduced dimension, so an ``(n, m)`` input yields an ``(n, 1)`` output.

    ``torch.logsumexp`` is used instead of a manual max-shift so that rows
    whose maximum is infinite are handled: a row made only of ``-inf``
    (e.g. from the log of zero weights) yields ``-inf`` rather than the NaN
    produced by evaluating ``-inf - (-inf)``.
    """
    return torch.logsumexp(V, dim=1, keepdim=True)

def sink_maps(x, y, eps, p):
    """Build the two log-domain Sinkhorn update maps for point clouds x and y.

    The cost matrix is the pairwise ``p``-norm of ``x[i] - y[j]`` divided by
    ``eps``. Two closures over that matrix are returned:

    * ``S_f(g)``: flattens ``g`` to a row, subtracts the cost matrix and
      returns minus the row-wise log-sum-exp (one value per row of ``x``).
    * ``S_g(f)``: the same operation against the transposed cost matrix
      (one value per row of ``y``).
    """
    pairwise_diff = x[:, None] - y[None]
    C = torch.norm(pairwise_diff, p=p, dim=2) / eps

    def S_f(g):
        return -lse(g.view(1, -1) - C)

    def S_g(f):
        return -lse(f.view(1, -1) - C.T)

    return S_f, S_g

def sink(a, x, b, y, p=2, eps=1, iter=100, tol=1e-3, converge=True):
    """Run log-domain Sinkhorn iterations and return the two potentials.

    Args:
        a, b: weights attached to the rows of ``x`` and ``y``; their logs are
            added to the potentials before each update.
        x, y: point clouds handed to ``sink_maps`` to build the cost.
        p: norm order forwarded to ``sink_maps``.
        eps: regularisation strength; the cost is divided by it and the
            returned potentials are multiplied by it.
        iter: maximum number of alternating updates.
        tol: the loop stops once ``eps * mean(|g - g_previous|)`` drops
            below this value.
        converge: if True, the loop runs with autograd disabled and only one
            final update (on detached inputs) is differentiable; if False,
            autograd is enabled for the whole loop.

    Returns:
        A pair ``(f, g)`` of potentials scaled by ``eps``.
    """
    log_a = a.log()
    log_b = b.log()
    f = torch.zeros_like(a)
    g = torch.zeros_like(b)
    S_f, S_g = sink_maps(x, y, eps, p)

    with torch.set_grad_enabled(not converge):
        for _ in range(iter):
            g_prev = g
            f = S_f(g + log_b)
            g = S_g(f + log_a)
            if eps * (g - g_prev).abs().mean() < tol:
                break

    if converge:
        # Rebuild each map with one cloud detached and feed it detached
        # potentials, so gradients only flow through this last step.
        S_f, _ = sink_maps(x.detach(), y, eps, p)
        _, S_g = sink_maps(x, y.detach(), eps, p)
        f_input = (g + log_b).detach()
        g_input = (f + log_a).detach()
        return eps * S_f(f_input), eps * S_g(g_input)

    return eps * S_f(g + log_b), eps * S_g(f + log_a)

def entropic_ot(a, x, b, y, **kwargs):
    """Return ``f.T @ a + g.T @ b`` for the potentials produced by ``sink``.

    All keyword arguments are forwarded unchanged to ``sink``.
    """
    potential_f, potential_g = sink(a, x, b, y, **kwargs)
    source_term = potential_f.T.matmul(a)
    target_term = potential_g.T.matmul(b)
    return source_term + target_term

def sinkhorn_divergence(a, x, b, y, **kwargs):
    """Return the cross ``entropic_ot`` term minus half of each self term.

    Computes ``OT(a, x, b, y) - 0.5 * OT(a, x, a, x) - 0.5 * OT(b, y, b, y)``
    where ``OT`` is ``entropic_ot``; keyword arguments are forwarded to every
    call.
    """
    cross_term = entropic_ot(a, x, b, y, **kwargs)
    self_term_x = entropic_ot(a, x, a, x, **kwargs)
    self_term_y = entropic_ot(b, y, b, y, **kwargs)
    return cross_term - 0.5 * self_term_x - 0.5 * self_term_y

# GAN Toy Problem
## Gaussian Mixture Dataset

In [None]:
from math import pi

# Mixture of 8 unit-variance Gaussians whose centres sit on a circle of
# radius 15, at angles 0, pi/4, ..., 7*pi/4.
size = 100  # number of samples per mixture component
angles = torch.linspace(0, 1.75 * pi, steps=8).reshape(8, 1)
unit_circle = torch.cat((angles.cos(), angles.sin()), dim=1)
# Tile the 8 centres `size` times: row i belongs to component i % 8.
shift = (15 * unit_circle).repeat(size, 1)
dataset = shift + torch.randn(8 * size, 2)

In [None]:
import matplotlib.pyplot as plt
# Scatter the samples and force equal axis scaling so distances look the
# same along both axes.
plt.scatter(x=dataset[:, 0], y=dataset[:, 1], s=3)
plt.gca().set_aspect(aspect='equal', adjustable='box')

## Vanilla GAN

In [None]:
import torch.nn as nn

class Generator(nn.Module):
    """Fully connected generator: repeated Linear -> BatchNorm1d -> ReLU stages
    followed by a final Linear projection.

    Args:
        in_features: size of each input vector.
        width: number of units in every hidden stage.
        out_features: size of each generated sample (default 2, the value
            that used to be hard-coded).
        hidden_layers: number of Linear/BatchNorm1d/ReLU stages (default 5,
            the depth that used to be written out by hand).

    Raises:
        ValueError: if ``hidden_layers`` is negative.
    """

    def __init__(self, in_features=2, width=100, out_features=2, hidden_layers=5):
        super().__init__()
        if hidden_layers < 0:
            raise ValueError('hidden_layers must be non-negative')

        layers = []
        fan_in = in_features
        # Layers are created in the same order as the original hand-written
        # stack, so module indices and state_dict keys are unchanged for the
        # default arguments.
        for _ in range(hidden_layers):
            layers.append(nn.Linear(fan_in, width))
            layers.append(nn.BatchNorm1d(width))
            layers.append(nn.ReLU())
            fan_in = width
        layers.append(nn.Linear(fan_in, out_features))
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Map a batch of input vectors to a batch of generated samples."""
        return self.model(input)

class Discriminator(nn.Module):
    """Three-layer MLP producing one unnormalised score (logit) per sample.

    Args:
        in_features: size of each input vector.
        width: number of units in the two hidden layers.
    """

    def __init__(self, in_features=2, width=100):
        super().__init__()
        # Batch normalisation is intentionally left out of this stack (it was
        # disabled after each hidden Linear layer).
        stack = [
            nn.Linear(in_features, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, 1),
        ]
        self.model = nn.Sequential(*stack)

    def forward(self, input):
        """Return a ``(batch, 1)`` tensor of scores for ``input``."""
        scores = self.model(input)
        return scores

In [None]:
from torch.utils.data import DataLoader
from torch.optim import Adam
import torch.nn.functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters for the vanilla GAN run.
epochs = 5000
batch_size = 400
latent_size = 2
width = 1024

# `dataset` is already a tensor: copy it with clone().detach() instead of
# torch.tensor(), which emits a copy-construct UserWarning on tensor input.
dataloader = DataLoader(
    dataset.clone().detach().to(device),
    batch_size, shuffle=True
    )

discriminator = Discriminator(2, width).to(device)
generator = Generator(latent_size, width).to(device)

# One Adam optimiser per network, both with learning rate 3e-4.
optimizer_D = Adam(discriminator.parameters(), lr=3e-4)
optimizer_G = Adam(generator.parameters(), lr=3e-4)

def js_loss_D(real, fake):
    """Discriminator loss: BCE-with-logits of real samples against all-ones
    targets plus BCE-with-logits of fake samples against all-zeros targets.

    Uses the module-level ``discriminator`` to score both batches.
    """
    real_scores = discriminator(real)
    real_targets = torch.ones(len(real), 1, device=real.device)
    real_term = F.binary_cross_entropy_with_logits(real_scores, real_targets)

    fake_scores = discriminator(fake)
    fake_targets = torch.zeros(len(fake), 1, device=fake.device)
    fake_term = F.binary_cross_entropy_with_logits(fake_scores, fake_targets)

    return real_term + fake_term

def js_loss_G(fake):
    """Generator loss: the negated BCE-with-logits of the fake samples against
    all-zeros targets, scored by the module-level ``discriminator``.

    The alternative formulation (plain BCE against all-ones targets) is not
    used here.
    """
    fake_scores = discriminator(fake)
    zero_targets = torch.zeros(len(fake), 1, device=fake.device)
    return -F.binary_cross_entropy_with_logits(fake_scores, zero_targets)

# record_epochs = [0, 10, 20, 30, 40, 50, 100, 200, 300, 400, 500, 600, 700, 800, 900, epochs-1]
records = []

for epoch in range(epochs):
    for batch in dataloader:
        # Five discriminator updates for every generator update.
        for _ in range(5):
            optimizer_D.zero_grad()
            # Detach the fakes: the discriminator step never uses generator
            # gradients, so there is no need to back-propagate through it.
            fake = generator(torch.rand(batch_size, latent_size, device=device)).detach()
            d_loss = js_loss_D(batch, fake)
            d_loss.backward()
            optimizer_D.step()
        for _ in range(1):
            optimizer_G.zero_grad()
            g_loss = js_loss_G(generator(torch.rand(batch_size, latent_size, device=device)))
            g_loss.backward()
            optimizer_G.step()

    # Report and plot every 500 epochs and at the very last epoch.
    if epoch % 500 == 0 or epoch == epochs-1:
        # '{:.4f}' prints four decimals ('{:4f}' only set a minimum width of 4).
        print('Epoch: {} \t JS Loss: {:.4f}'.format(epoch, d_loss.item()))

        # Sampling for the plot needs no autograd graph. The generator is left
        # in its current (training) mode, as before.
        with torch.no_grad():
            record = generator(torch.rand(800, latent_size, device=device)).cpu().numpy()
        plt.scatter(dataset[:, 0], dataset[:, 1], s=3)
        plt.scatter(record[:, 0], record[:, 1], s=3)
        plt.xlim(-20, 20); plt.ylim(-20, 20)
        plt.title('GAN')
        plt.xlabel('x1')
        plt.ylabel('x2')
        plt.gca().set_aspect('equal', adjustable='box')
        plt.show()

## EGAN

In [None]:
from torch.utils.data import DataLoader
from torch.optim import Adam

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters for the Sinkhorn run (trailing comments keep the
# alternative values that were tried).
epochs = 5000 #10000
batch_size = 400 #800
latent_size = 2
width = 1024

# `dataset` is already a tensor: copy it with clone().detach() instead of
# torch.tensor(), which emits a copy-construct UserWarning on tensor input.
dataloader = DataLoader(
    dataset.clone().detach().to(device),
    batch_size, shuffle=True
    )
generator = Generator(latent_size, width).to(device)
optimizer = Adam(generator.parameters(), lr=3e-4)

# NOTE(review): neither name is read by the code visible in this cell, and
# the 5000 / 9000 entries are never reached with epochs = 5000 -- confirm
# whether they are still needed.
record_epochs = [0, 10, 20, 30, 40, 50, 100, 300, 500, 1000, 5000, 9000, epochs-1]
records = []

for epoch in range(epochs):
    for batch in dataloader:
        optimizer.zero_grad()
        # Uniform weights on the real batch and on the generated batch.
        # eps is 50 for the first 500 epochs and 5e-4 afterwards.
        loss = sinkhorn_divergence(
            torch.ones(len(batch), 1, device=device) / len(batch),
            batch,
            torch.ones(batch_size, 1, device=device) / batch_size,
            generator(torch.rand(batch_size, latent_size, device=device)),
            eps=50 if epoch < 500 else 5e-4, iter=1000
            )
        loss.backward()
        optimizer.step()

    # Report and plot every 500 epochs and at the very last epoch.
    if epoch % 500 == 0 or epoch == epochs-1:
        # '{:.4f}' prints four decimals ('{:4f}' only set a minimum width of 4).
        print('Epoch: {} \t Sinkhorn Divergence: {:.4f}'.format(epoch, loss.item()))

        # Sampling for the plot needs no autograd graph. The generator is left
        # in its current (training) mode, as before.
        with torch.no_grad():
            record = generator(torch.rand(800, latent_size, device=device)).cpu().numpy()
        plt.scatter(dataset[:, 0], dataset[:, 1], s=3)
        plt.scatter(record[:, 0], record[:, 1], s=3)
        plt.xlim(-20, 20); plt.ylim(-20, 20)
        plt.title('Sinkhorn EGAN')
        plt.xlabel('x1')
        plt.ylabel('x2')
        plt.gca().set_aspect('equal', adjustable='box')
        plt.show()
