# Stein Variational Gradient Descent (SVGD)

### References
- [Qiang Liu, Dilin Wang, "Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm", NIPS, 2016.](https://arxiv.org/abs/1608.04471)

In [None]:
import time
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import ArtistAnimation
# named hex colors used by the plots below
color = dict(
    orange='#FF4B00',
    blue='#005AFF',
    green='#03AF7A',
    purple='#990099',
)

## Prepare the Ground Truth Probability Density Function

In [None]:
# Mixture Gaussian Distribution
class MixtureGaussianDistribution:
    """Mixture of K multivariate Gaussian components in D dimensions.

    The density is p(x) = sum_k w_k * N(x | mu_k, Sigma_k). ``prob`` returns p(x)
    and ``dlnprob`` returns the gradient of ln p(x) with respect to x.
    """

    def __init__(
        self,
        dimension: int, # the dimension of the data : writing it as D below
        num_components: int, # how many gaussian components in the mixture? : writing it as K below
        weight_array: np.ndarray, # the weights of the gaussian components : size is (K,)
        mu_array: np.ndarray, # the means of the gaussian components : size is (K, D)
        cov_array: np.ndarray, # the covariance matrices of the gaussian components : size is (K, D, D)
    ) -> None:
        """Validate and store the mixture parameters.

        Raises:
            ValueError: if an array has the wrong shape or the weights do not sum to 1.
        """
        self.pdf_name = "Mixture Gaussian Distribution"
        self.dimension = dimension # D
        self.num_components = num_components # K

        # check if the size of weight_array is (K,)
        if not(weight_array.shape == (self.num_components,)):
            raise ValueError("The size of the weight_array must be (num_components,)")
        # check if the size of mu_array is (K, D)
        if not(mu_array.shape == (self.num_components, self.dimension)):
            raise ValueError("The size of the mu_array must be (num_components, dimension)")
        # check if the size of cov_array is (K, D, D)
        if not(cov_array.shape == (self.num_components, self.dimension, self.dimension)):
            raise ValueError("The size of the cov_array must be (num_components, dimension, dimension)")
        # check if the sum of the weights is 1
        if not(np.abs(np.sum(weight_array) - 1) < 1e-10):
            raise ValueError("The sum of the weights must be 1")

        # store the parameters
        self.weight_array = weight_array
        self.mu_array = mu_array
        self.cov_array = cov_array

        # precision matrices Sigma_k^{-1}, inverted once for all K components (batched inverse)
        self.precision_array = np.linalg.inv(np.asarray(self.cov_array, dtype=float))

        # ln[(2*pi)^D * det(Sigma_k)] for each component, so dlnprob can work in log space
        _, logdet_array = np.linalg.slogdet(np.asarray(self.cov_array, dtype=float))
        self._log_norm_array = self.dimension * np.log(2.0 * np.pi) + logdet_array

    def gaussian(self, x: np.ndarray, mu: np.ndarray, sigma: np.ndarray) -> float:
        """Return the density N(x | mu, sigma) of a single Gaussian with covariance sigma."""
        return 1 / np.sqrt((2 * np.pi) ** self.dimension * np.linalg.det(sigma)) * np.exp(-0.5 * (x - mu).T @ np.linalg.inv(sigma) @ (x - mu))

    def mix_gaussian(self, x: np.ndarray) -> float:
        """Return the mixture density p(x) = sum_k w_k N(x | mu_k, Sigma_k)."""
        prob = 0
        for k in range(self.num_components):
            prob += self.weight_array[k] * self.gaussian(x, self.mu_array[k], self.cov_array[k])
        return prob

    def dln_mix_gaussian(self, x: np.ndarray) -> np.ndarray:
        """Return grad_x ln p(x) as an array of shape (D,).

        Uses grad ln p(x) = sum_k r_k(x) * (-Sigma_k^{-1} (x - mu_k)), where the
        responsibilities r_k(x) = w_k N_k(x) / p(x) are computed in log space. This
        avoids the 0/0 = NaN of the direct ratio when every N_k(x) underflows,
        e.g. for points far from all component means.
        """
        x = np.asarray(x, dtype=float)
        diff = x - self.mu_array  # (K, D): x - mu_k for every component
        prec_diff = np.einsum('kij,kj->ki', self.precision_array, diff)  # (K, D): Sigma_k^{-1} (x - mu_k)
        mahalanobis = np.einsum('ki,ki->k', diff, prec_diff)  # (K,)
        with np.errstate(divide='ignore'):  # a zero weight gives log(0) = -inf, which is harmless here
            log_weights = np.log(self.weight_array)
        log_terms = log_weights - 0.5 * (self._log_norm_array + mahalanobis)  # ln(w_k N_k(x))
        # shift by the max before exponentiating (log-sum-exp) so the largest term is exp(0) = 1
        resp = np.exp(log_terms - np.max(log_terms))
        resp /= np.sum(resp)
        return resp @ (-prec_diff)

    def prob(self, x: np.ndarray) -> float:
        """Return p(x)."""
        return self.mix_gaussian(x)

    def dlnprob(self, x: np.ndarray) -> np.ndarray:
        """Return nabla ln p(x)."""
        return self.dln_mix_gaussian(x)

    def plot_2d(self,
                idx_select=(0, 1), # axis labels: the axes are labelled x_{idx_select[0]} and x_{idx_select[1]}
                x_min=-10, x_max=10, y_min=-10, y_max=10, res_x=100, res_y=100, # axis range and resolution
                levels=10, dpi=100, save_path=None # other settings
        ):
        """Draw a contour plot of the pdf on a res_x-by-res_y grid.

        The pdf is evaluated at the 2-vectors (x, y), so this expects ``dimension == 2``;
        ``idx_select`` only sets the axis labels. If ``save_path`` is given, the figure
        is also saved there.
        """
        fig, ax = plt.subplots(dpi=dpi)
        x = np.linspace(x_min, x_max, res_x)
        y = np.linspace(y_min, y_max, res_y)
        X, Y = np.meshgrid(x, y)
        # meshgrid returns arrays of shape (res_y, res_x); size Z to match so non-square grids work
        Z = np.zeros(X.shape)
        for i, j in np.ndindex(X.shape):
            Z[i, j] = self.prob(np.array([X[i, j], Y[i, j]]))
        ax.contour(X, Y, Z, levels=levels, cmap='viridis')
        ax.set_aspect('equal')
        ax.set_title(f"{self.pdf_name}", fontsize=16)
        ax.set_xlabel(f"x_{idx_select[0]}", fontsize=14)
        ax.set_ylabel(f"x_{idx_select[1]}", fontsize=14)
        # save figure if necessary
        if save_path is not None:
            plt.savefig(save_path)
        plt.show()

# run test: construct a 2-D mixture of three Gaussians and show its contour plot
gmm_test = MixtureGaussianDistribution(
    2,  # dimension
    3,  # num_components
    np.array([0.3, 0.3, 0.4]),  # weight_array
    np.array([[-4, 3], [4, -3], [5, 5]]),  # mu_array
    np.array([[[4, 2], [2, 5]], [[6, 0], [0, 4]], [[3.0, 0.0], [0.0, 3.0]]]),  # cov_array
)
gmm_test.plot_2d()

## SVGD Algorithm

In [None]:
# SVGD class
class SVGD:
    """Stein Variational Gradient Descent (Liu & Wang, 2016).

    Every iteration moves all particles simultaneously:
        x_i <- x_i + gd_step_size * phi(x_i),
        phi(x) = (1/n) * sum_j [ k(x_j, x) * grad ln p(x_j) + grad_{x_j} k(x_j, x) ],
    with the RBF kernel k(a, b) = exp(-||a - b||^2 / h), h = sqh**2.
    """

    def __init__(
        self,
        dim_particle: int, # the dimension of a particle
        num_particles: int, # the number of particles
        gd_step_size: float, # the step size of the gradient descent
        dlnprob: callable, # the gradient of the log probability
        particles_init: np.ndarray, # the initial particles: size is (num_particles, dim_particle)
    ):
        # load parameters
        self.dim_particle = dim_particle
        self.num_particles = num_particles
        self.gd_step_size = gd_step_size
        self.dlnprob = dlnprob

        # initialization
        self.reset(particles_init=particles_init)

    def reset(
        self,
        particles_init: np.ndarray, # the initial particles: size is (num_particles, dim_particle)
    ):
        """Load a new set of particles and reset the iteration counter.

        Raises:
            ValueError: if particles_init is not (num_particles, dim_particle).
        """
        if not(particles_init.shape == (self.num_particles, self.dim_particle)):
            raise ValueError("The size of particles_init must be (num_particles, dim_particle)")

        # keep a float copy so updates never modify the caller's array in place
        self.particles = np.array(particles_init, dtype=float)

        # reset the iteration count
        self._iter_count = 0

    def kernel(
        self,
        x_a: np.ndarray,
        x_b: np.ndarray,
        sqh = 0.5 # kernel width: h = sqh**2 (the median trick mentioned in the SVGD paper is not implemented)
    ):
        """RBF kernel exp(-||x_a - x_b||^2 / sqh**2).

        Inputs broadcast against each other; the norm is taken over the last axis,
        so two 1-D points give a scalar and (n, 1, D) vs (1, n, D) give an (n, n) matrix.
        """
        # TODO: implement median trick of sqh (when sqh < 0)
        sq_dist = np.sum((np.asarray(x_a) - np.asarray(x_b)) ** 2, axis=-1)
        return np.exp(-sq_dist / sqh**2)

    def d_kernel(
        self,
        x_a: np.ndarray,
        x_b: np.ndarray,
        sqh = 0.5 # kernel width: h = sqh**2 (the median trick mentioned in the SVGD paper is not implemented)
    ):
        """Gradient of ``kernel(x_a, x_b, sqh)`` with respect to x_a.

        Equals -2 (x_a - x_b) / sqh**2 * k(x_a, x_b); broadcasts like ``kernel`` and
        keeps the trailing coordinate axis.
        """
        # TODO: implement median trick of sqh (when sqh < 0)
        diff = np.asarray(x_a) - np.asarray(x_b)
        k = self.kernel(x_a, x_b, sqh)
        return (-2.0 / sqh**2) * diff * np.expand_dims(k, -1)

    def update(
        self,
        n_iter: int, # the number of iterations
    ):
        """Run n_iter SVGD steps.

        Returns:
            (a copy of the particles, elapsed wall time in ms, total iteration count).
        """
        # start the timer
        time_ms_calc_start = time.perf_counter() * 1000.0

        for _ in range(n_iter):
            self._iter_count += 1 # increment the total iteration count
            x = self.particles  # snapshot of the current particles, (n, D)
            # score of every particle, evaluated once per iteration (not once per pair)
            grad_lnp = np.array([self.dlnprob(x_j) for x_j in x], dtype=float)  # (n, D)
            # pairwise terms indexed [j, i]: k(x_j, x_i) and grad_{x_j} k(x_j, x_i)
            x_j = x[:, np.newaxis, :]
            x_i = x[np.newaxis, :, :]
            kxx = self.kernel(x_j, x_i)     # (n, n)
            dkxx = self.d_kernel(x_j, x_i)  # (n, n, D)
            # phi[i] = (1/n) * sum_j (kxx[j, i] * grad_lnp[j] + dkxx[j, i]); 1/n scales BOTH terms
            phi = (kxx.T @ grad_lnp + dkxx.sum(axis=0)) / self.num_particles
            # simultaneous update from the same snapshot, scaled by the step size
            self.particles = x + self.gd_step_size * phi

        # stop the timer
        time_ms_calc_end = time.perf_counter() * 1000.0

        # return the updated particles and the calculation time [ms]
        return self.particles.copy(), (time_ms_calc_end - time_ms_calc_start), self._iter_count

## Visualizer

In [None]:
# Visualizer class
class Visualizer:
    """Draw a 2-D pdf as contours (optionally with particles) and record animation frames."""

    def __init__(
        self,
        pdf: callable,
        view_x_min: float, view_x_max: float,
        view_y_min: float, view_y_max: float,
        dpi: int = 100, res_x: int = 100, res_y: int = 100, figsize: tuple[int, int] = (7, 7),
        make_animation: bool = False,
    ) -> None:
        # load parameters
        self.pdf = pdf
        self.view_x_min, self.view_x_max, self.res_x = view_x_min, view_x_max, res_x
        self.view_y_min, self.view_y_max, self.res_y = view_y_min, view_y_max, res_y
        self.dpi = dpi
        self.figsize = figsize

        # evaluate the pdf on the grid once; every figure and frame reuses it
        self.__init_grid()

        # initialize figure
        self.__init_fig()

        # the frame list always exists, so add_animation_frame also works when make_animation is False
        self.make_animation = make_animation
        self.__init_animation()

    def __init_grid(self):
        # prepare grid_x, grid_y and the pdf values on the mesh
        self.grid_x = np.linspace(self.view_x_min, self.view_x_max, self.res_x)
        self.grid_y = np.linspace(self.view_y_min, self.view_y_max, self.res_y)
        self.mesh_x, self.mesh_y = np.meshgrid(self.grid_x, self.grid_y)
        # meshgrid returns arrays of shape (res_y, res_x); size mesh_z to match so non-square grids work
        self.mesh_z = np.zeros(self.mesh_x.shape)
        for i, j in np.ndindex(self.mesh_x.shape):
            self.mesh_z[i, j] = self.pdf(np.array([self.mesh_x[i, j], self.mesh_y[i, j]]))

    def __init_fig(self):
        # make a fresh figure with fixed view limits and no ticks
        self.fig, self.ax = plt.subplots(1, 1, dpi=self.dpi, figsize=self.figsize)
        self.ax.set_xlim(self.view_x_min, self.view_x_max)
        self.ax.set_ylim(self.view_y_min, self.view_y_max)
        self.ax.tick_params(labelbottom=False, labelleft=False, labelright=False, labeltop=False)
        self.ax.tick_params(bottom=False, left=False, right=False, top=False)
        self.ax.set_aspect('equal')

    def plot_2d(
        self,
        init_fig: bool = True,
        iter_count: int = None,
        particles: np.ndarray = None, # set np.ndarray if scatter plot is necessary. size: (?, 2)
    ) -> None:
        """Show the pdf contours, optionally titled with iter_count and overlaid with particles."""
        if init_fig:
            self.__init_fig()

        self.ax.contour(self.mesh_x, self.mesh_y, self.mesh_z, cmap='viridis')

        if iter_count is not None:
            self.ax.set_title(f"iteration: {iter_count}")

        if particles is not None:
            self.ax.scatter(particles[:, 0], particles[:, 1], color=color["blue"])

        plt.show()

    def __init_animation(
        self
    ) -> None:
        # clear animation frames
        self.anim_frames = []

    def add_animation_frame(
        self,
        particles: np.ndarray,
        iter_count: int = None,
        init_anim = False
    ) -> None:
        """Append one frame (contours, particles, title, optional iteration text) to anim_frames."""
        if init_anim:
            self.__init_animation()

        # add contents to the latest frame
        frame = [self.ax.contour(self.mesh_x, self.mesh_y, self.mesh_z, cmap='viridis')]
        frame += [self.ax.scatter(particles[:, 0], particles[:, 1], color=color["blue"])]

        # add title text
        info_text = "Bayesian Inference with SVGD"
        frame += [self.ax.text(0.5, 1.07, info_text, ha='center', transform=self.ax.transAxes, fontsize=14, fontfamily='monospace')]

        # add text info if necessary
        if iter_count is not None:
            info_text = f"iteration: {iter_count: 5d}"
            frame += [self.ax.text(0.5, 1.02, info_text, ha='center', transform=self.ax.transAxes, fontsize=14, fontfamily='monospace')]

        self.anim_frames.append(frame)

    def save_animation(self, filename, interval=500, movie_writer="ffmpeg") -> None:
        """Save the recorded frames to filename; interval is in ms (ffmpeg writer by default)."""
        ani = ArtistAnimation(self.fig, self.anim_frames, interval=interval, blit=True)
        ani.save(filename, writer=movie_writer)

## Run Bayesian Inference with SVGD

In [None]:
# prepare ground truth probability density function (pdf) and its logarithmic gradient
ground_truth_pdf = MixtureGaussianDistribution(
    dimension=2,
    num_components=3,
    weight_array=np.array([0.3, 0.3, 0.4]),
    mu_array=np.array([[-4, 3], [4, -3], [5, 5]]),
    cov_array=np.array([[[4, 2], [2, 5]], [[6, 0], [0, 4]], [[3.0, 0.0], [0.0, 3.0]]])
)

# set parameters and initialize svgd
NUM_PARTICLES = 50
NUM_FRAMES = 50       # number of animation frames recorded after the initial one
ITERS_PER_FRAME = 10  # svgd iterations between two consecutive frames
RANDOM_SEED = 0       # fixed seed so the run (and the saved movie) is reproducible
rng = np.random.default_rng(RANDOM_SEED)
inferrer = SVGD(
    dim_particle=2,
    num_particles=NUM_PARTICLES,
    gd_step_size=5.0,
    dlnprob=ground_truth_pdf.dlnprob,
    particles_init=rng.standard_normal((NUM_PARTICLES, 2))
)

# initialize visualizer
vis = Visualizer(
    pdf = ground_truth_pdf.prob,
    view_x_min = -10.0,
    view_x_max = +10.0,
    view_y_min = -10.0,
    view_y_max = +10.0,
    make_animation=True,
)

# record the initial particle configuration (iteration 0) as the first frame
vis.add_animation_frame(inferrer.particles.copy(), iter_count=0)

# run bayesian inference to approximate the ground truth pdf using svgd
for _ in range(NUM_FRAMES):
    # update svgd particles
    particles_updated, calc_time, current_iter_count = inferrer.update(n_iter=ITERS_PER_FRAME)
    print(f"iteration count: {current_iter_count:04d}")

    # [FOR DEBUG] plot figure
    # vis.plot_2d(init_fig=True, iter_count=current_iter_count, particles=particles_updated)

    # add animation frame
    vis.add_animation_frame(particles_updated, iter_count=current_iter_count)

# save animation
vis.save_animation(interval=100, filename="svgd.mp4")