In [10]:
from core.gradest import infer, version
import torch 
from core.torchGradFlow import infer_cv, plot_norm_contour
version()

import numpy as np
from numpy import ones, zeros, eye 


Gradient Flow: Version 1.0
Song Liu et al., Variational Gradient Descent using Local Linear Models
arxiv: https://arxiv.org/abs/2305.15577

Copyright, Song Liu (song.liu@bristol.ac.uk)
Powered by Juzhen (https://github.com/anewgithubname/Juzhen)
    _           _                
   (_)_   _ ___| |__   ___ _ __  
   | | | | |_  / '_ \ / _ \ '_ \ 
   | | |_| |/ /| | | |  __/ | | |
  _/ |\__,_/___|_| |_|\___|_| |_|
 |__/                            
                                 


In [11]:
def gradcomp(mup, covarp, n=1001, seed=None, lim=2, grid_size=20):
    """Estimate the log-ratio gradient between a 2-D Gaussian p and q = N(0, I) on a grid.

    Draws ``n`` samples from p = N(mup, covarp) and ``n`` samples from
    q = N(0, I), then calls ``infer`` (from ``core.gradest``) with both sample
    sets and a regular ``grid_size`` x ``grid_size`` grid covering
    ``[-lim, lim]^2`` as the evaluation points.

    Parameters
    ----------
    mup : array_like, length 2
        Mean of p.
    covarp : array_like, 2 x 2
        Covariance of p.
    n : int, default 1001
        Number of samples drawn from each of p and q.
    seed : int or None, default None
        Seed for the sampler. ``None`` keeps the original behaviour of drawing
        from numpy's global random state.
    lim : float, default 2
        Half-width of the square evaluation grid.
    grid_size : int, default 20
        Number of grid points per axis (the defaults reproduce the original
        20 x 20 grid on [-2, 2]).

    Returns
    -------
    tuple
        ``(mup, covarp, muq, covarq, x0, grad)``, where ``muq``/``covarq`` are
        q's mean (zeros) and covariance (identity), ``x0`` is a float32 torch
        tensor of shape ``(grid_size**2, 2)`` and ``grad`` is whatever ``infer``
        returns for those points (presumably one 2-D gradient per row of
        ``x0`` -- confirm against ``core.gradest``).
    """
    from scipy.stats import multivariate_normal as MVN

    d = 2  # the evaluation grid (and the plotting code) is 2-D only

    muq = zeros(d)
    covarq = eye(d)

    # One shared generator for both draws: passing the same integer seed to
    # each rvs call would make Xp and Xq reuse identical standard-normal draws.
    # With seed=None, scipy falls back to numpy's global RandomState, as before.
    rng = None if seed is None else np.random.default_rng(seed)

    # atleast_2d keeps the (n, d) shape even for n == 1 (rvs(1) returns shape (d,)).
    Xp = np.atleast_2d(MVN(mup, covarp).rvs(n, random_state=rng)).astype(np.float32)
    Xq = np.atleast_2d(MVN(muq, covarq).rvs(n, random_state=rng)).astype(np.float32)

    # Regular grid on [-lim, lim]^2. indexing="ij" is torch's historical default;
    # stating it explicitly silences the meshgrid deprecation warning without
    # changing the point order.
    ticks = torch.linspace(-lim, lim, grid_size)
    x0 = torch.stack(torch.meshgrid(ticks, ticks, indexing="ij"), dim=-1).reshape(-1, 2)

    grad = infer(Xp, Xq, x0.numpy())
    return mup, covarp, muq, covarq, x0, grad


In [12]:
def _grad_log_ratio(mup, covarp, muq, covarq, x0):
    """Exact gradient of log N(x; mup, covarp) - log N(x; muq, covarq) at each row of x0 (via autograd)."""
    from torch.distributions.multivariate_normal import MultivariateNormal as torchMVN

    def as_tensor64(a):
        # Accepts numpy arrays, lists or tensors (torch.from_numpy only takes ndarrays).
        return torch.as_tensor(np.asarray(a), dtype=torch.float64)

    # Differentiate w.r.t. a detached copy so the caller's x0 is never mutated
    # (setting requires_grad on x0 itself would also fail for non-leaf tensors).
    x = x0.detach().clone().requires_grad_(True)
    log_p = torchMVN(as_tensor64(mup), as_tensor64(covarp)).log_prob(x)
    log_q = torchMVN(as_tensor64(muq), as_tensor64(covarq)).log_prob(x)
    return torch.autograd.grad((log_p - log_q).sum(), x)[0]


def plot_grad(mup, covarp, muq, covarq, x0, grad):
    """Overlay estimated and true log-ratio gradients on the contours of p and q.

    Parameters
    ----------
    mup, covarp : array_like
        Mean and covariance of p (contour drawn with ``plot_norm_contour``'s
        default colour).
    muq, covarq : array_like
        Mean and covariance of q (contour drawn in blue).
    x0 : torch.Tensor, shape (m, 2)
        Evaluation points; not modified.
    grad : array_like, shape (m, 2)
        Estimated gradient at each row of ``x0``, drawn as green arrows.

    The analytic gradient of log p(x) - log q(x) at ``x0`` is drawn as red
    arrows. The axes are fixed to [-2, 2] x [-2, 2].
    """
    import matplotlib.pyplot as plt

    plt.figure(figsize=(5, 5))
    plot_norm_contour(mup, covarp)
    plot_norm_contour(muq, covarq, 'b')
    plt.title("red: p, blue q, green: estimated grad, red: true gradient")

    true_grad = _grad_log_ratio(mup, covarp, muq, covarq, x0)

    # Arrow anchor positions, shared by both quiver plots.
    xs = x0[:, 0].detach().cpu()
    ys = x0[:, 1].detach().cpu()

    # Estimated gradient (green).
    plt.quiver(xs, ys, grad[:, 0], grad[:, 1], scale=40, color='g')
    # True gradient (red).
    plt.quiver(xs, ys, true_grad[:, 0].cpu(), true_grad[:, 1].cpu(), scale=40, color='r')

    plt.xlim(-2, 2)
    plt.ylim(-2, 2)

In [13]:
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

def f(mup1, mup2, varp, n):
    """Widget callback: compare estimated vs. true log-ratio gradients.

    Builds p = N([mup1, mup2], varp * I) and draws ``n`` samples from it and
    from q = N(0, I) via ``gradcomp``, then plots the result with ``plot_grad``.
    """
    mup, covarp, muq, covarq, x0, grad = gradcomp(np.array([mup1, mup2]), eye(2) * varp, n)
    plot_grad(mup, covarp, muq, covarq, x0, grad)


# interact_manual only re-runs f when its button is pressed, since every call
# re-samples and re-runs the estimator.
# varp uses step .05 so that 1.0 (p and q with equal variance) lies on the
# slider grid: .25 + 15 * .05 == 1.0. With step .1 the grid is .25, .35, ...,
# so 1.0 was only the initial value and presumably could not be re-selected.
# The trailing ';' suppresses the returned function's repr in the cell output.
interact_manual(
    f,
    mup1=widgets.FloatSlider(min=-2, max=2, step=.1, value=0, continuous_update=False),
    mup2=widgets.FloatSlider(min=-2, max=2, step=.1, value=0, continuous_update=False),
    varp=widgets.FloatSlider(min=.25, max=2, step=.05, value=1, continuous_update=False),
    n=widgets.IntSlider(min=500, max=10000, step=50, value=500, continuous_update=False),
);


interactive(children=(FloatSlider(value=0.0, continuous_update=False, description='mup1', max=2.0, min=-2.0), …

<function __main__.f(mup1, mup2, varp, n)>