In [35]:
import math
import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.distributions.multivariate_normal import MultivariateNormal

from tqdm import tqdm_notebook as tqdm
from collections import OrderedDict

In [4]:
# Read the four arrays stored in the .npz archive.
# np.load on an .npz returns a lazy NpzFile that keeps the underlying file
# open; the `with` block closes it once every array has been materialised.
data_fname = 'data.npz'
with np.load(data_fname) as load_data:
    true_pi = load_data['pi']
    true_mu = load_data['mu']
    samples = load_data['samples']
    adv_sample = load_data['adv_sample']

In [6]:
# Display the loaded arrays: `true_pi`, `true_mu`, every 50th row of
# `samples`, and `adv_sample`.
true_pi, true_mu, samples[::50], adv_sample

(array([0.23705428, 0.06574439, 0.39978303, 0.07123911, 0.22617919]),
 array([[2.02889543, 0.58387321],
        [1.48504397, 3.96130589],
        [1.26997612, 1.46258317],
        [3.4355634 , 1.58225174],
        [1.47788008, 3.48312234]]),
 array([[ 4.17993725,  1.24175583],
        [ 2.68098744,  0.32334823],
        [ 1.41416657,  0.3940763 ],
        [ 3.04086445,  4.25610681],
        [ 1.87104121,  4.03256269],
        [ 5.0214534 ,  1.57025624],
        [ 3.1324006 ,  1.80930315],
        [ 3.48725692,  2.19854219],
        [ 2.67629033, -0.01613154],
        [ 2.5919368 ,  2.00105177],
        [-0.10403355,  0.37719824],
        [ 3.65484623,  1.83349119],
        [ 0.27127926,  2.74596415],
        [ 1.16614525,  1.58661359],
        [ 1.47989679,  0.89424985],
        [ 0.9680544 ,  1.00286455],
        [ 4.67228469,  2.75282497],
        [ 0.3507703 ,  3.45548201],
        [ 2.11184957,  2.31217297],
        [ 2.10118491,  4.49285233]]),
 array([-1., -1.]))

In [68]:
def MoG_prob(x, pi, mu, cov):
    """Density of a mixture of Gaussians evaluated at a single point.

    Args:
        x: tensor of shape (dim,), the point at which to evaluate the density.
        pi: tensor of shape (K,), unnormalised mixture logits; a softmax
            turns them into mixing weights.
        mu: tensor of shape (K, dim), component means.
        cov: tensor of shape (K, dim, dim), per-component factors ``A_k``;
            the covariance actually used is ``A_k^T A_k``.

    Returns:
        Scalar tensor ``sum_k softmax(pi)_k * N(x; mu_k, A_k^T A_k)``.
    """
    K, dim = mu.size()
    assert x.size() == (dim,)
    assert pi.size() == (K,)
    assert cov.size() == (K, dim, dim)

    weights = torch.softmax(pi, dim=0)
    # Constant part of the Gaussian log-density, shared by all components.
    log_norm = -dim * 0.5 * math.log(2 * math.pi)

    density = 0.0
    for weight, mean, factor in zip(weights, mu, cov):
        sigma = torch.matmul(factor.t(), factor)
        diff = x - mean
        # Quadratic form (x - mu)^T Sigma^{-1} (x - mu).
        quad = sigma.inverse().matmul(diff).dot(diff)
        log_density = log_norm - 0.5 * sigma.logdet() - 0.5 * quad
        density += torch.exp(log_density) * weight
    return density

def MoG_loss(X, z, pi, mu, cov, lam):
    """Objective: ``lam * p(z) - sum_{x in X} log p(x)`` under the mixture.

    Args:
        X: tensor whose rows are data points (second dimension equals dim).
        z: the adversarial sample, passed to ``MoG_prob`` as a single point.
        pi: tensor of shape (K,), mixture logits.
        mu: tensor of shape (K, dim), component means.
        cov: tensor of shape (K, dim, dim), per-component factors.
        lam: weight of the density penalty on ``z``.

    Returns:
        Scalar tensor with the loss value.
    """
    n_components, n_dims = mu.size()
    assert X.size(1) == n_dims
    assert pi.size() == (n_components,)
    assert cov.size() == (n_components, n_dims, n_dims)

    # Penalty: weighted mixture density at the adversarial point.
    total = lam * MoG_prob(z, pi, mu, cov)
    # Subtract the log-likelihood of every data point, one at a time.
    for sample in X:
        total = total - torch.log(MoG_prob(sample, pi, mu, cov))
    return total

def GD_solver(samples, adv_sample, lam=2.0, max_step=10000, lr=0.01, K=5, dim=2):
    """Fit a K-component mixture of Gaussians by plain SGD on ``MoG_loss``.

    Args:
        samples: array-like of shape (N, dim); data points whose negative
            log-likelihood is minimised.
        adv_sample: array-like of shape (dim,); point whose mixture density
            is penalised with weight ``lam``.
        lam: weight of the density penalty on ``adv_sample``.
        max_step: number of SGD steps.
        lr: SGD learning rate.
        K: number of mixture components.
        dim: dimensionality of the data.

    Returns:
        Tuple ``(pi, mu, cov)`` of the optimised tensors: ``pi`` of shape
        (K,) holds mixture logits (``MoG_prob`` applies a softmax), ``mu``
        of shape (K, dim) holds the means, and ``cov`` of shape
        (K, dim, dim) holds factors ``A_k`` whose product ``A_k^T A_k`` is
        the covariance used by ``MoG_prob``.
    """
    # NumPy arrays are float64 by default; cast to float32 so the data
    # matches the dtype of the parameters created below.
    X = torch.tensor(samples, dtype=torch.float)
    z = torch.tensor(adv_sample, dtype=torch.float)

    # Normalise *before* enabling autograd: an in-place op on a leaf tensor
    # that already requires grad raises a RuntimeError.
    pi = torch.rand(K, dtype=torch.float)
    pi /= pi.sum()
    pi.requires_grad_()
    mu = torch.randn(K, dim, dtype=torch.float).requires_grad_()
    # repeat(K, 1, 1) stacks K identity matrices into shape (K, dim, dim);
    # repeat(K, 1) would yield (K * dim, dim) and fail MoG_loss's shape check.
    cov = torch.eye(dim, dtype=torch.float).repeat(K, 1, 1).requires_grad_()

    named_params = (('pi', pi), ('mu', mu), ('cov', cov))

    print('*** Init ***')
    for name, value in named_params:
        print('{}:'.format(name))
        print(value)
    print('loss:')
    print(MoG_loss(X, z, pi, mu, cov, lam=lam))

    optimizer = optim.SGD([pi, mu, cov], lr=lr)

    for step in tqdm(range(max_step)):
        optimizer.zero_grad()
        loss = MoG_loss(X, z, pi, mu, cov, lam=lam)
        loss.backward()
        optimizer.step()

        if (step + 1) % 100 == 0:
            # Report the 1-based step count so it matches the condition above.
            print('Step {}'.format(step + 1))
            for name, value in named_params:
                print('{}:'.format(name))
                print(value)
            print('loss:')
            print(loss)

    return pi, mu, cov

In [73]:
# Hyper-parameters for this run.
lam = 2.0
max_step = 100000
lr = 0.001
K = 5
dim = 2

# Only the first ten samples are used; both inputs are cast to float32.
X = torch.FloatTensor(samples[:10])
z = torch.FloatTensor(adv_sample)

# Initial parameters. Random draws happen in the order pi, then mu.
# pi is normalised before autograd is switched on.
pi = torch.rand(K, dtype=torch.float)
pi = (pi / pi.sum()).requires_grad_()
mu = torch.randn(K, dim, dtype=torch.float).requires_grad_()
cov = torch.eye(dim, dtype=torch.float).repeat(K, 1, 1).requires_grad_()

named_params = OrderedDict(zip(('pi', 'mu', 'cov'), (pi, mu, cov)))
params = list(named_params.values())

print('*** Init ***')
for n, p in named_params.items():
    print(n)
    print(p)

    # Gradient hook: abort with diagnostics as soon as a NaN gradient shows up.
    # `p` and `n` are bound as defaults so each hook keeps its own parameter.
    def _hook(grad, p=p, n=n):
        nan_count = torch.isnan(grad).sum()
        if not nan_count > 0:
            return
        print('Error in grad:', grad)
        print('Shape:', grad.size())
        print('Parameter name:', n)
        print('p.requires_grad:', p.requires_grad)
        raise ValueError

    p.register_hook(_hook)
print('loss:')
print(MoG_loss(X, z, pi, mu, cov, lam=lam))

optimizer = optim.SGD(params, lr=lr)

for step in tqdm(range(max_step)):
    optimizer.zero_grad()
    loss = MoG_loss(X, z, pi, mu, cov, lam=lam)
    loss.backward()
    optimizer.step()

    # Log the parameters and the loss every 1000 steps only.
    if (step + 1) % 1000 != 0:
        continue
    print('Step {}'.format(step + 1))
    for n, p in named_params.items():
        print(n)
        print(p)
    print('loss:')
    print(loss)

*** Init ***
pi
tensor([0.1777, 0.1601, 0.1980, 0.2296, 0.2346], requires_grad=True)
mu
tensor([[-0.6630, -0.7774],
        [ 1.0629, -0.0470],
        [-0.3393,  0.1417],
        [-1.1408,  0.4482],
        [ 2.0556, -2.8561]], requires_grad=True)
cov
tensor([[[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]]], requires_grad=True)
loss:
tensor(64.0133, grad_fn=<ThSubBackward>)


HBox(children=(IntProgress(value=0, max=100000), HTML(value='')))

Step 1000
pi
tensor([-0.3527,  2.4673, -0.0929, -0.4501, -0.5716], requires_grad=True)
mu
tensor([[-0.2620, -0.7911],
        [ 2.2233,  1.5458],
        [ 0.2996,  0.3389],
        [-0.8062,  0.5309],
        [ 2.0056, -2.7227]], requires_grad=True)
cov
tensor([[[ 1.0928,  0.0576],
         [ 0.0689,  0.7590]],

        [[ 1.1582, -0.0057],
         [-0.0785,  1.3659]],

        [[ 0.9570,  0.2624],
         [ 0.2126,  1.2744]],

        [[ 1.3312,  0.1374],
         [ 0.1421,  1.1000]],

        [[ 1.0128, -0.0693],
         [-0.0744,  1.2059]]], requires_grad=True)
loss:
tensor(34.4722, grad_fn=<ThSubBackward>)
Step 2000
pi
tensor([-0.1258,  2.7062, -0.0869, -0.6749, -0.8186], requires_grad=True)
mu
tensor([[ 0.2894, -1.0852],
        [ 2.2727,  1.6652],
        [ 0.6177,  0.3353],
        [-0.7001,  0.5295],
        [ 1.9796, -2.6588]], requires_grad=True)
cov
tensor([[[ 0.6297,  0.6022],
         [ 0.1295,  0.1141]],

        [[ 1.1307, -0.0545],
         [-0.1335,  1.3169]],

   

KeyboardInterrupt: 