<a href="https://colab.research.google.com/github/ZichuLiu/NER-pytorch/blob/master/Multi_Gaussian_Exp.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [9]:
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.functional as functional
from torch.distributions.normal import Normal
from torch.autograd import grad, Variable, backward
from torch.utils.data import DataLoader, RandomSampler

from torchvision import transforms
from torchvision.datasets import mnist

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

import scipy as sp
import scipy.stats

import seaborn as sns
import pandas as pd

from time import time
import os
import socket
from datetime import datetime


Define Discriminator and Generator



In [15]:
class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers=6):
        super(Discriminator, self).__init__()
        layers = [nn.Linear(input_size, hidden_size)]
        for i in range(num_layers - 2):
            layers.extend([nn.Tanh(),
                           nn.Linear(hidden_size, hidden_size)])
        layers.append(nn.Linear(hidden_size, output_size).cuda())
        self.net = nn.Sequential(*layers)
        self.output = torch.nn.Sigmoid()

    def forward(self, x):
        return self.output(self.net(x))

    def get_penalty(self, x_true, x_gen):
        x_true = x_true.view_as(x_gen).cuda()
        alpha = torch.rand((len(x_true),) + (1,) * (x_true.dim() - 1))
        if x_true.is_cuda:
            alpha = alpha.cuda(x_true.get_device())
        x_penalty = Variable(alpha * x_true + (1 - alpha) * x_gen, requires_grad=True).cuda()
        p_penalty = self.forward(x_penalty)
        gradients = grad(p_penalty, x_penalty, grad_outputs=torch.ones_like(p_penalty).cuda(
            x_true.get_device()) if x_true.is_cuda else torch.ones_like(p_penalty), create_graph=True,
                         retain_graph=True, only_inputs=True)[0]
        penalty = ((gradients.view(len(x_true), -1).norm(2, 1) - 1) ** 2).mean()

        return penalty

class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers=6):
        super(Generator, self).__init__()
        layers = [nn.Linear(input_size, hidden_size)]
        for i in range(num_layers-2):
            layers.extend([nn.Tanh(),
                          nn.Linear(hidden_size, hidden_size)])
        layers.append(nn.Linear(hidden_size, output_size).cuda())
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

Define Sinkhorn distance and related utility functions

In [11]:
def sinkhorn_normalized(x, y, epsilon, n, niter,p=1):
    Cxy = cost_matrix(x, y, p)
    Wxy, pi = sinkhorn_loss(x, y, epsilon, n, niter, C_Matrix=Cxy)
    Cxx = cost_matrix(x, x, p)
    Wxx, pi_x = sinkhorn_loss(x, x, epsilon, n, niter, C_Matrix=Cxx)
    Cyy = cost_matrix(y, y, p)
    Wyy, pi_y = sinkhorn_loss(y, y, epsilon, n, niter, C_Matrix=Cyy)
    return 2 * Wxy - Wxx - Wyy, (pi, pi_x, pi_y)


def mixed_sinkhorn_normalized(x, y, fx, fy, epsilon, n, niter, p=1, nfac=1.0):
    if len(x.shape) > 2:
        x = x.view(x.shape[0], -1).cuda()
        y = y.view(y.shape[0], -1).cuda()
    Cxy = cost_matrix(x, y, p) + nfac * RKHS_Norm(fx, fy)
    # Cxy = cost_matrix(x, y, p) + cost_matrix(fx, fy, p)
    Wxy, pi = sinkhorn_loss(x, y, epsilon, n, niter, C_Matrix=Cxy)
    Cxx = cost_matrix(x, x, p) + nfac * RKHS_Norm(fx,fx)
    # Cxx = cost_matrix(x, x, p) + cost_matrix(fx, fx, p)
    Wxx, pi_x = sinkhorn_loss(x, x, epsilon, n, niter, C_Matrix=Cxx)
    Cyy = cost_matrix(y, y, p) + nfac * RKHS_Norm(fy,fy)
    # Cyy = cost_matrix(y, y, p) + cost_matrix(fy, fy, p)
    Wyy, pi_y = sinkhorn_loss(y, y, epsilon, n, niter, C_Matrix=Cyy)
    # Wxx = 0
    # Wyy = 0
    return 2 * Wxy - Wxx - Wyy, pi


def sinkhorn_loss(x, y, epsilon, n, niter, C_Matrix):
    """
    Given two emprical measures with n points each with locations x and y
    outputs an approximation of the OT cost with regularization parameter epsilon
    niter is the max. number of steps in sinkhorn loop
    """

    # The Sinkhorn algorithm takes as input three variables :

    C = C_Matrix  # Wasserstein cost function
    # both marginals are fixed with equal weights
    mu = Variable(1. / n * torch.cuda.FloatTensor(n).fill_(1), requires_grad=False)
    nu = Variable(1. / n * torch.cuda.FloatTensor(n).fill_(1), requires_grad=False)

    # Parameters of the Sinkhorn algorithm.
    rho = 1  # (.5) **2          # unbalanced transport
    tau = -.8  # nesterov-like acceleration
    lam = rho / (rho + epsilon)  # Update exponent
    thresh = 10 ** (-1)  # stopping criterion

    # Elementary operations .....................................................................
    def ave(u, u1):
        "Barycenter subroutine, used by kinetic acceleration through extrapolation."
        return tau * u + (1 - tau) * u1

    def M(u, v):
        "Modified cost for logarithmic updates"
        "$M_{ij} = (-c_{ij} + u_i + v_j) / \epsilon$"
        return (-C + u.unsqueeze(1) + v.unsqueeze(0)) / epsilon

    def lse(A):
        "log-sum-exp"
        return torch.logsumexp(A, dim=1, keepdim=True)

    # Actual Sinkhorn loop ......................................................................
    u, v, err = 0. * mu, 0. * nu, 0.
    actual_nits = 0  # to check if algorithm terminates because of threshold or max iterations reached

    for i in range(niter):
        u1 = u  # useful to check the update
        u = epsilon * (torch.log(mu) - lse(M(u, v)).squeeze()) + u
        v = epsilon * (torch.log(nu) - lse(M(u, v).t()).squeeze()) + v
        # accelerated unbalanced iterations
        # u = ave( u, lam * ( epsilon * ( torch.log(mu) - lse(M(u,v)).squeeze()   ) + u ) )
        # v = ave( v, lam * ( epsilon * ( torch.log(nu) - lse(M(u,v).t()).squeeze() ) + v ) )
        err = (u - u1).abs().sum()

        actual_nits += 1
        if (err.data.tolist() < thresh):
            break
    U, V = u, v
    pi = torch.exp(M(U, V))  # Transport plan pi = diag(a)*K*diag(b)
    cost = torch.sum(pi * C)  # Sinkhorn cost

    return cost, pi


def cost_matrix(x, y, p=1):
    "Returns the matrix of $|x_i-y_j|^p$."
    x_col = x.unsqueeze(1).cuda()
    y_lin = y.unsqueeze(0).cuda()
    c = torch.sum((torch.abs(x_col - y_lin)).cuda() ** p, 2)
    return c


def RBF_Kernel(fx, fy, gamma):
    "Returns the matrix of $exp(-gamma * |x_i-y_j|^2)$."
    x_col = fx.unsqueeze(1).cuda()
    y_lin = fy.unsqueeze(0).cuda()
    c = torch.norm(torch.abs(x_col - y_lin), p=1, dim=2)
    # c = torch.sum((torch.abs(x_col - y_lin)) ** p, 2)
    RBF_K = torch.exp(-gamma * c)
    return RBF_K


def RKHS_Norm(x, y, gamma=0.5):
    Kxy = RBF_Kernel(x, y, gamma)
    return 1 + 1 - 2 * Kxy

Define real sample generators and noise sampler

In [12]:
def x_real_builder(batch_size):
    sigma = .01
    skel = np.array([
        [2.0, 2.0],
        [2.0, 1.0],
        [2.0, 0.0],
        [2.0, -1.0],
        [2.0, -2.0],
        [1.0, 2.0],
        [1.0, 1.0],
        [1.0, 0.0],
        [1.0, -1.0],
        [1.0, -2.0],
        [0.0, 2.0],
        [0.0, 1.0],
        [0.0, 0.0],
        [0.0, -1.0],
        [0.0, -2.0],
        [-1.0, 2.0],
        [-1.0, 1.0],
        [-1.0, 0.0],
        [-1.0, -1.0],
        [-1.0, -2.0],
        [-2.0, 2.0],
        [-2.0, 1.0],
        [-2.0, 0.0],
        [-2.0, -1.0],
        [-2.0, -2.0],
    ])
    temp = np.tile(skel, (batch_size // 25 + 1, 1))
    mus = temp[0:batch_size, :]
    m = Normal(torch.FloatTensor([.0]), torch.FloatTensor([sigma]))
    samples = m.sample((batch_size, 2))
    samples = samples.view((batch_size, 2))
    return samples.new(mus) + samples  # * .2


def get_noise_sampler():
    return lambda m, n: torch.randn(m, n).requires_grad_().cuda()

Configurations

In [13]:
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--exp_name", type=str, default='Sinkhorn_GAN')
parser.add_argument("--optim", type=str, default="ema", help="optimization algorithm to use")
parser.add_argument("--alpha", type=float, default=0.2)
parser.add_argument("--g_lr", type=float, default=1e-3)
parser.add_argument("--d_lr", type=float, default=1e-4)
parser.add_argument("--weight_decay", type=float, default=2e-5)
parser.add_argument("--grad_penalty", type=float, default=1.0)
parser.add_argument("--mix_metric_flag", type=int, default=1)
parser.add_argument("--nonlinear_OT_flag", type=int, default=0)
parser.add_argument("--nonlinear_fac", type=float, default=1.0)

parser.add_argument("--num_epochs", type=int, default=2000000)
parser.add_argument("--num_workers", type=int, default=4)
parser.add_argument("--batch_size", type=int, default=500)
parser.add_argument("--print_interval", type=int, default=1000)
parser.add_argument("--vis_interval", type=int, default=1000)

parser.add_argument("--latent_dim", type=int, default=64)
parser.add_argument("--hidden_size", type=int, default=384)
parser.add_argument("--num_layers", type=int, default=6)
parser.add_argument("--seed", type=int, default=0)

args = parser.parse_args(args=[])
print("see all args:", args)

see all args: Namespace(alpha=0.2, batch_size=500, d_lr=0.0001, exp_name='Sinkhorn_GAN', g_lr=0.001, grad_penalty=1.0, hidden_size=384, latent_dim=64, mix_metric_flag=1, nonlinear_OT_flag=0, nonlinear_fac=1.0, num_epochs=2000000, num_layers=6, num_workers=4, optim='ema', print_interval=1000, seed=0, vis_interval=1000, weight_decay=2e-05)


In [3]:
args.exp_name

'Sinkhorn_GAN'

main training loop

In [16]:
expname = args.exp_name
current_time = datetime.now().strftime('%Y-%m-%d_%H')
log_dir = os.path.join('runs', expname + "_" + current_time + "_" + socket.gethostname(),
                           'mix' + str(args.mix_metric_flag) + 'fac' + str(args.nonlinear_fac),
                           'nonlinear' + str(args.nonlinear_OT_flag))
os.makedirs(log_dir, exist_ok=True)
with open(os.path.join(log_dir, "args.txt"), "w") as fp:
    for arg in vars(args):
        fp.write("%s:%s \n" % (arg, str(getattr(args, arg))))

np.random.seed(args.seed)
torch.random.manual_seed(args.seed)

epsilon = .01

g_input_size = args.latent_dim

d_minibatch_size = args.batch_size

num_epochs = args.num_epochs

d_learning_rate = args.d_lr
g_learning_rate = args.g_lr

noise_data = get_noise_sampler()

g_hidden_size = args.hidden_size
g_output_size = 2

d_input_size = 2  
d_hidden_size = args.hidden_size
d_output_size = 32

G = Generator(input_size=g_input_size, hidden_size=g_hidden_size, output_size=g_output_size,
              num_layers=args.num_layers).cuda()
D = Discriminator(input_size=d_input_size, hidden_size=d_hidden_size, output_size=d_output_size,
                  num_layers=args.num_layers).cuda()
g_optimizer = torch.optim.Adam(G.parameters(), lr=args.g_lr, weight_decay=args.weight_decay)
d_optimizer = torch.optim.Adam(D.parameters(), lr=args.d_lr, weight_decay=args.weight_decay)

z_test = torch.rand((d_minibatch_size, g_input_size)).cuda()
z_test = Variable((z_test - 0.5) * 2)

for epoch in range(num_epochs):
    z = torch.rand((d_minibatch_size, g_input_size)).cuda()
    z = Variable((z - 0.5) * 2)
    images = x_real_builder(d_minibatch_size).float().cuda()
    generated_imgs = G.forward(z)

    if args.mix_metric_flag:
        nfac = args.nonlinear_fac
        D_fake = D.forward(generated_imgs)
        D_real = D.forward(images)
        D_fake = D_fake.view(D_fake.shape[0], -1)
        D_real = D_real.view(D_real.shape[0], -1)

        if epoch % 5 == 0:
            for param in D.parameters():
                param.requires_grad = True
            for param in G.parameters():
                param.requires_grad = False

            loss, _ = mixed_sinkhorn_normalized(generated_imgs, images, D_fake, D_real, epsilon, d_minibatch_size,
                                                500, nfac=nfac)
            d_optimizer.zero_grad()
            if args.grad_penalty:
                grad_penalty = D.get_penalty(images, generated_imgs)
                D_loss = -loss + args.grad_penalty * grad_penalty
            else:
                D_loss = -loss
            D_loss.backward()
            d_optimizer.step()

        else:
            for param in D.parameters():
                param.requires_grad = False
            for param in G.parameters():
                param.requires_grad = True
            g_loss, _ = mixed_sinkhorn_normalized(generated_imgs, images, D_fake, D_real, epsilon, d_minibatch_size,
                                                  500, nfac=nfac)
            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()
            print(g_loss.data.tolist())

    elif args.nonlinear_OT_flag:

        D_fake = D.forward(generated_imgs)
        D_real = D.forward(images)
        D_fake = D_fake.view(D_fake.shape[0], -1)
        D_real = D_real.view(D_real.shape[0], -1)

        if epoch % 5 == 0:
            for param in D.parameters():
                param.requires_grad = True
            for param in G.parameters():
                param.requires_grad = False

            loss, _ = sinkhorn_normalized(D_fake, D_real, epsilon, d_minibatch_size, 500)
            d_optimizer.zero_grad()
            if args.grad_penalty:
                grad_penalty = D.get_penalty(images, generated_imgs)
                D_loss = -loss + args.grad_penalty * grad_penalty
            else:
                D_loss = -loss
            D_loss.backward()
            d_optimizer.step()

        else:
            for param in D.parameters():
                param.requires_grad = False
            for param in G.parameters():
                param.requires_grad = True
            g_loss, _ = sinkhorn_normalized(D_fake, D_real, epsilon, d_minibatch_size, 500)
            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()
            print(g_loss.data.tolist())
    else:
        g_loss, _ = sinkhorn_normalized(generated_imgs, images, epsilon, d_minibatch_size, 500)
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()
        print(g_loss.data.tolist())

    if epoch % 5 == 0:
        images = x_real_builder(d_minibatch_size)
        fake_data = G.forward(z_test)
        fake_data = [item.data.tolist() for item in fake_data]
        real_data = [item.data.tolist() for item in images]
        X = [-2, -1, 0., 1, 2]
        Y = [-2, -1, 0., 1, 2]
        fig, axes = plt.subplots(1, 1)
        for x in X:
            for y in Y:
                axes.plot(x, y, 'go')
        for item in fake_data:
            axes.plot(item[0], item[1], 'b.')
        for item in real_data:
            axes.plot(item[0], item[1], 'r.')

        axes.grid()
        fig.savefig(os.path.join(log_dir, "gauss_iter_%i.jpg" % epoch))

        plt.close()

5.248564720153809
5.249998092651367
4.9268574714660645
4.337074279785156
3.2011466026306152
1.8678901195526123
1.9524468183517456
2.383619546890259
1.8689770698547363
1.702895998954773
1.6313105821609497
1.6742336750030518
1.7272125482559204
1.9923789501190186
2.014218807220459


KeyboardInterrupt: ignored