In [2]:
import math
import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.distributions.multivariate_normal import MultivariateNormal

from tqdm import tqdm_notebook as tqdm
from collections import OrderedDict

In [3]:
# Read the four arrays this notebook works with from the .npz archive.
data_fname = 'data.npz'
load_data = np.load(data_fname)
true_pi, true_mu, samples, adv_sample = (
    load_data[key] for key in ('pi', 'mu', 'samples', 'adv_sample')
)

In [4]:
# Display the loaded arrays: true_pi and true_mu (presumably the ground-truth
# mixture weights and component means, going by the names), every 50th row of
# samples, and the adversarial point.
true_pi, true_mu, samples[::50], adv_sample

(array([0.23705428, 0.06574439, 0.39978303, 0.07123911, 0.22617919]),
 array([[2.02889543, 0.58387321],
        [1.48504397, 3.96130589],
        [1.26997612, 1.46258317],
        [3.4355634 , 1.58225174],
        [1.47788008, 3.48312234]]),
 array([[ 4.17993725,  1.24175583],
        [ 2.68098744,  0.32334823],
        [ 1.41416657,  0.3940763 ],
        [ 3.04086445,  4.25610681],
        [ 1.87104121,  4.03256269],
        [ 5.0214534 ,  1.57025624],
        [ 3.1324006 ,  1.80930315],
        [ 3.48725692,  2.19854219],
        [ 2.67629033, -0.01613154],
        [ 2.5919368 ,  2.00105177],
        [-0.10403355,  0.37719824],
        [ 3.65484623,  1.83349119],
        [ 0.27127926,  2.74596415],
        [ 1.16614525,  1.58661359],
        [ 1.47989679,  0.89424985],
        [ 0.9680544 ,  1.00286455],
        [ 4.67228469,  2.75282497],
        [ 0.3507703 ,  3.45548201],
        [ 2.11184957,  2.31217297],
        [ 2.10118491,  4.49285233]]),
 array([-1., -1.]))

In [5]:
def MoG_prob(x, pi, mu, cov):
    """Density of a Gaussian mixture evaluated at a single point.

    Args:
        x: tensor of shape (dim,), the point to evaluate.
        pi: tensor of shape (K,), unnormalised mixing logits; they are turned
            into mixing weights with a softmax.
        mu: tensor of shape (K, dim), one mean per component.
        cov: tensor of shape (K, dim, dim). Each ``cov[k]`` is used as a
            factor: the covariance of component k is ``cov[k].T @ cov[k]``.

    Returns:
        The mixture density at ``x`` as a scalar tensor (the plain float
        ``0.0`` if there are no components).

    Raises:
        AssertionError: if the argument shapes are inconsistent.
    """
    K, dim = mu.size()
    assert x.size() == (dim,)
    assert pi.size() == (K,)
    assert cov.size() == (K, dim, dim)

    weights = torch.softmax(pi, dim=0)
    # Normalisation constant shared by every component: -(dim / 2) * log(2 * pi).
    log_norm = -dim * 0.5 * math.log(2 * math.pi)

    density = 0.0
    for k in range(K):
        # Gram matrix of the factor, used as the component covariance.
        sigma = torch.matmul(cov[k].t(), cov[k])
        residual = x - mu[k]
        mahalanobis = sigma.inverse().matmul(residual).dot(residual)
        log_density_k = log_norm - 0.5 * sigma.logdet() - 0.5 * mahalanobis
        density += torch.exp(log_density_k) * weights[k]
    return density

def MoG_loss(X, z, pi, mu, cov, lam):
    """Penalised negative log-likelihood of a Gaussian mixture.

    The value is ``lam * MoG_prob(z, ...)`` minus the sum of
    ``log(MoG_prob(x, ...))`` over the rows ``x`` of ``X``.

    Args:
        X: tensor whose second dimension equals ``dim``; each row is one
            data point.
        z: the adversarial sample; its mixture density is the penalty term.
        pi: tensor of shape (K,), mixing logits (see ``MoG_prob``).
        mu: tensor of shape (K, dim), component means.
        cov: tensor of shape (K, dim, dim), covariance factors.
        lam: weight applied to the density at ``z``.

    Returns:
        The loss as a scalar tensor.

    Raises:
        AssertionError: if the argument shapes are inconsistent.
    """
    n_components, n_features = mu.size()
    assert X.size(1) == n_features
    assert pi.size() == (n_components,)
    assert cov.size() == (n_components, n_features, n_features)

    # Start from the penalty on the adversarial point, then subtract the
    # log-density of each data point in order.
    total = lam * MoG_prob(z, pi, mu, cov)
    for point in X:
        total = total - torch.log(MoG_prob(point, pi, mu, cov))
    return total

def GD_solver(samples, adv_sample, lam=2.0, max_step=10000, lr=0.001, K=5, dim=2,
              n_samples=10, log_every=1000, seed=None):
    """Fit a K-component Gaussian mixture by plain SGD on ``MoG_loss``.

    Args:
        samples: array-like of data points, sliceable along its first axis.
            Only the first ``n_samples`` rows are used for training.
        adv_sample: array-like adversarial point whose mixture density is
            penalised with weight ``lam``.
        lam: weight of the adversarial penalty passed to ``MoG_loss``.
        max_step: number of SGD steps.
        lr: SGD learning rate.
        K: number of mixture components.
        dim: dimensionality of each data point.
        n_samples: how many leading rows of ``samples`` to train on. The
            default of 10 reproduces the previously hard-coded subset; pass
            ``None`` to use every row.
        log_every: print parameters and loss every this many steps. The
            default of 1000 reproduces the previous behaviour; ``None`` or 0
            disables the periodic printout.
        seed: if given, seeds torch's global RNG before the random
            initialisation so a run can be reproduced. ``None`` (default)
            leaves the RNG state untouched.

    Returns:
        Tuple ``(pi, mu, cov)`` of the optimised leaf tensors (they still have
        ``requires_grad=True``). ``pi`` holds mixing logits and ``cov`` holds
        covariance factors, as interpreted by ``MoG_prob``.

    Raises:
        ValueError: from a gradient hook, as soon as any parameter receives a
            gradient containing NaN.
    """
    if seed is not None:
        torch.manual_seed(seed)

    train_rows = samples if n_samples is None else samples[:n_samples]
    X = torch.FloatTensor(train_rows)
    z = torch.FloatTensor(adv_sample)

    # Random initialisation; the draw order (rand for pi, then randn for mu)
    # is kept so that seeded runs stay comparable.
    pi = torch.rand(K, dtype=torch.float)
    pi /= pi.sum()
    pi.requires_grad_()
    mu = torch.randn(K, dim, dtype=torch.float)
    mu.requires_grad_()
    cov = torch.eye(dim, dtype=torch.float).repeat(K, 1, 1)
    cov.requires_grad_()

    params = [pi, mu, cov]
    named_params = OrderedDict([('pi', pi), ('mu', mu), ('cov', cov)])

    def _nan_guard(name, param):
        """Build a gradient hook that aborts as soon as a NaN gradient appears."""
        def _hook(grad):
            if torch.isnan(grad).any():
                print('Error in grad:', grad)
                print('Shape:', grad.size())
                print('Parameter name:', name)
                print('p.requires_grad:', param.requires_grad)
                raise ValueError(
                    'NaN gradient for parameter {!r}'.format(name))
        return _hook

    def _report(header, loss_value):
        """Print a header, every named parameter, and the given loss."""
        print(header)
        for name, param in named_params.items():
            print(name)
            print(param)
        print('loss:')
        print(loss_value)

    for name, param in named_params.items():
        param.register_hook(_nan_guard(name, param))
    _report('*** Init ***', MoG_loss(X, z, pi, mu, cov, lam=lam))

    optimizer = optim.SGD(params, lr=lr)

    for step in tqdm(range(max_step)):
        optimizer.zero_grad()
        loss = MoG_loss(X, z, pi, mu, cov, lam=lam)
        loss.backward()
        optimizer.step()

        # Note: the printed loss was computed before this step's update,
        # while the printed parameters are the post-update values.
        if log_every and (step + 1) % log_every == 0:
            _report('Step {}'.format(step + 1), loss)

    return pi, mu, cov

In [8]:
# Stand-alone training run on the full sample set.
# NOTE(review): this cell repeats the body of GD_solver almost line for line;
# the differences are that it trains on all of `samples` (GD_solver slices the
# first 10), uses max_step=100000, and prints every 10 steps instead of 1000.
# It also leaves pi, mu, cov, X, z, loss, ... as notebook-level globals.
lam=2.0
max_step=100000
lr=0.001
K=5
dim=2

X = torch.FloatTensor(samples)
z = torch.FloatTensor(adv_sample)

# Random initialisation. pi is normalised to sum to 1 here, but MoG_prob passes
# it through a softmax, so it is effectively treated as logits.
pi = torch.rand(K, dtype=torch.float)
pi /= pi.sum()
pi.requires_grad_()
mu = torch.randn(K, dim, dtype=torch.float)
mu.requires_grad_()
# K identity matrices; MoG_prob uses cov[k].T @ cov[k] as the covariance.
cov = torch.eye(dim, dtype=torch.float).repeat(K, 1, 1)
cov.requires_grad_()

params = [pi, mu, cov]
named_params = OrderedDict([('pi', pi), ('mu', mu), ('cov', cov)])

print('*** Init ***')
for n, p in named_params.items():
    print(n)
    print(p)
    # Gradient hook: dump diagnostics and abort if a NaN gradient reaches this
    # parameter. p and n are bound as default arguments so each hook keeps its
    # own parameter rather than the loop's last one.
    def _hook(grad, p=p, n=n):
        if torch.isnan(grad).sum() > 0:
            print('Error in grad:', grad)
            print('Shape:', grad.size())
            print('Parameter name:', n)
            print('p.requires_grad:', p.requires_grad)
            raise ValueError
    p.register_hook(_hook)
print('loss:')
print(MoG_loss(X, z, pi, mu, cov, lam=lam))

optimizer = optim.SGD(params, lr=lr)

for step in tqdm(range(max_step)):
    optimizer.zero_grad()
    loss = MoG_loss(X, z, pi, mu, cov, lam=lam)
    loss.backward()
    optimizer.step()

    # Prints every parameter tensor every 10 steps; over 100000 steps that is
    # 10000 printouts. The printed loss was computed before this step's update.
    if (step + 1) % 10 == 0:
        print('Step {}'.format(step + 1))
        for n, p in named_params.items():
            print(n)
            print(p)
        print('loss:')
        print(loss)

*** Init ***
pi
tensor([0.0526, 0.0831, 0.4028, 0.0535, 0.4080], requires_grad=True)
mu
tensor([[-0.0834,  1.5534],
        [-1.1490,  0.1361],
        [ 1.1723, -0.9099],
        [-2.2406, -0.0429],
        [-0.3928,  1.9011]], requires_grad=True)
cov
tensor([[[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]]], requires_grad=True)
loss:
tensor(6408.6011, grad_fn=<ThSubBackward>)


HBox(children=(IntProgress(value=0, max=100000), HTML(value='')))

Step 10
pi
tensor([ 0.6713, -0.7028,  0.8155, -0.8786,  1.0946], requires_grad=True)
mu
tensor([[ 0.9636,  2.0920],
        [-0.8570,  0.2619],
        [ 1.9601,  0.1706],
        [-2.1903, -0.0256],
        [ 0.8033,  2.6622]], requires_grad=True)
cov
tensor([[[1.7688, 0.4196],
         [0.3710, 1.5238]],

        [[1.2724, 0.1534],
         [0.1539, 0.9766]],

        [[1.7543, 0.6018],
         [0.6900, 1.3248]],

        [[1.0841, 0.0351],
         [0.0351, 1.0015]],

        [[1.6517, 0.6080],
         [0.4705, 1.7309]]], requires_grad=True)
loss:
tensor(3951.1877, grad_fn=<ThSubBackward>)
Step 20
pi
tensor([ 0.7717, -1.0206,  1.0013, -1.2517,  1.4992], requires_grad=True)
mu
tensor([[ 1.4133,  2.2651],
        [-0.7398,  0.3142],
        [ 2.4006,  0.5993],
        [-2.1628, -0.0161],
        [ 1.4071,  2.9561]], requires_grad=True)
cov
tensor([[[ 1.4023,  0.1861],
         [ 0.1094,  1.4448]],

        [[ 1.3342,  0.1946],
         [ 0.1962,  0.9496]],

        [[ 1.5258,  0.266

KeyboardInterrupt: 

In [6]:
# Fit with GD_solver's default hyper-parameters (it trains on the first 10
# samples only) and rebind the notebook-level pi, mu, cov to the result.
pi, mu, cov = GD_solver(samples, adv_sample)

*** Init ***
pi
tensor([0.1639, 0.1345, 0.0582, 0.4190, 0.2243], requires_grad=True)
mu
tensor([[-0.2219,  1.0180],
        [-0.5286,  2.2717],
        [-0.8652, -0.3107],
        [-1.0530, -0.1700],
        [-1.1053,  1.3982]], requires_grad=True)
cov
tensor([[[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]]], requires_grad=True)
loss:
tensor(69.5848, grad_fn=<ThSubBackward>)


HBox(children=(IntProgress(value=0, max=10000), HTML(value='')))

Step 1000
pi
tensor([ 1.3371,  0.6631, -0.2347, -0.1623, -0.6032], requires_grad=True)
mu
tensor([[ 1.9120,  1.3170],
        [ 1.1485,  2.5273],
        [-0.0710, -0.6226],
        [-0.2550, -0.4544],
        [-0.6992,  1.3855]], requires_grad=True)
cov
tensor([[[ 1.7059, -0.2857],
         [-0.0742,  0.3648]],

        [[ 0.9451,  0.9877],
         [ 0.5326,  0.8220]],

        [[ 1.4542,  0.0374],
         [-0.0307,  0.2930]],

        [[ 1.5630, -0.0896],
         [-0.1058,  0.4160]],

        [[ 1.4418, -0.0182],
         [-0.0026,  0.9869]]], requires_grad=True)
loss:
tensor(29.4981, grad_fn=<ThSubBackward>)
Step 2000
pi
tensor([ 1.4670,  0.8903,  0.1249, -0.4976, -0.9846], requires_grad=True)
mu
tensor([[ 2.5388,  1.1934],
        [ 1.5779,  3.0436],
        [ 0.6267, -0.9487],
        [ 0.0331, -0.7741],
        [-0.5802,  1.3420]], requires_grad=True)
cov
tensor([[[ 1.3027, -0.1987],
         [-0.0402,  0.3528]],

        [[ 0.6963,  0.6868],
         [ 0.3562,  0.6063]],

   

KeyboardInterrupt: 