In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.autograd as autograd
import torch.optim as optim
from tqdm import trange
from scipy.spatial.distance import pdist, squareform

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
class Small_F_net(nn.Module):
    """Tanh MLP mapping R^z_dim to R^z_dim through two hidden layers.

    Architecture: Linear(z_dim, latent_dim) -> Tanh -> Linear(latent_dim,
    latent_dim) -> Tanh -> Linear(latent_dim, z_dim), stored in ``self.dnn``.

    Parameters:
        z_dim: input and output dimension.
        latent_dim: width of both hidden layers.
    """

    def __init__(self, z_dim, latent_dim):
        super().__init__()
        self.z_dim = z_dim
        self.latent_dim = latent_dim
        widths = (z_dim, latent_dim, latent_dim, z_dim)
        modules = []
        # Layers are created in forward order so parameter initialisation
        # consumes the RNG in the same sequence as a literal Sequential(...).
        for depth, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            if depth > 0:
                modules.append(nn.Tanh())
            modules.append(nn.Linear(fan_in, fan_out))
        self.dnn = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the MLP to ``x`` (last dimension must be ``z_dim``)."""
        return self.dnn(x)


class Conditioned_Diffusion:
    """Conditioned-diffusion inverse problem over the Brownian increments xi.

    The latent path is the Euler-Maruyama recursion
        u_0 = 0,  u_{k+1} = u_k + dt * drift(u_k) + sqrt(dt) * xi_k,
    with dt = T / num_interval and k = 0, ..., num_interval - 1.
    Observations are y_j = u_{(j + 1) * m} + sigma * eps_j for
    j = 0, ..., num_obs - 1, where m = num_interval // num_obs, so the last
    observation sits at u_T.

    Parameters:
        num_interval: number of Euler steps (length of xi).
        num_obs: number of observations; must be a positive divisor of
            num_interval so every observation falls on a grid point.
        beta: drift strength.
        T: time horizon.
    Raises:
        ValueError: if num_obs is not a positive divisor of num_interval.
    """

    def __init__(self, num_interval, num_obs, beta=10.0, T=1.0):
        if num_obs <= 0 or num_interval % num_obs != 0:
            raise ValueError(
                "num_obs must be a positive divisor of num_interval; got "
                f"num_interval={num_interval}, num_obs={num_obs}")
        self.num_interval = num_interval
        self.num_obs = num_obs
        self.T = T
        self.stepsize = T / num_interval
        self.beta = beta
        # Integer grid spacing between consecutive observations, usable
        # directly for indexing and for the modulo test in loglikelihood.
        self.obs_interval = num_interval // num_obs

    def drift(self, u):
        """Elementwise drift beta * u * (1 - u^2) / (1 + u^2)."""
        return self.beta * u * (1 - u**2) / (1 + u**2)

    def _integrate(self, xi):
        """Run the Euler recursions for a batch of increments.

        Parameters:
            xi: increments, shape (batch, num_interval).
        Returns:
            (x, u), each of shape (batch, num_interval + 1) on xi's device and
            both starting at 0: x is the random walk of sqrt(dt) * xi, u the
            diffusion path. Draws no random numbers.
        """
        batch_size = xi.shape[0]
        x = torch.zeros(batch_size, self.num_interval + 1, device=xi.device)
        u = torch.zeros(batch_size, self.num_interval + 1, device=xi.device)
        sqrt_dt = np.sqrt(self.stepsize)
        for k in range(self.num_interval):
            temp = sqrt_dt * xi[:, k]
            x[:, k + 1] = x[:, k] + temp
            u[:, k + 1] = u[:, k] + self.stepsize * self.drift(u[:, k]) + temp
        return x, u

    def generate_path(self, sigma=0.1, batch_size=1, xi=None):
        """Simulate paths and their noisy observations.

        Parameters:
            sigma: standard deviation of observation noise.
            batch_size: number of paths; ignored when xi is a tensor.
            xi: optional increments (batch, num_interval); drawn from N(0, 1)
                and moved to ``device`` when not a tensor.
        Returns:
            (xi, x, u, y): x and u have shape (batch, num_interval + 1);
            y = u[:, [m, 2m, ..., num_interval]] + sigma * noise has shape
            (batch, num_obs).
        """
        if isinstance(xi, torch.Tensor):
            batch_size = xi.shape[0]
        else:
            xi = torch.randn(batch_size, self.num_interval).to(device)
        x, u = self._integrate(xi)

        # Drawn from the default (CPU) generator and then moved, so a fixed
        # torch seed reproduces the same observations on any device.
        noise = torch.randn(batch_size, self.num_obs).to(u.device) * sigma
        obs_idx = torch.arange(self.obs_interval, self.num_interval + 1,
                               self.obs_interval, device=u.device)
        y = u[:, obs_idx] + noise

        return xi, x, u, y

    def loglikelihood(self, xi, y, sigma=0.1):
        """Gaussian log-likelihood of y given xi, up to an additive constant.

        Parameters:
            xi: inputs, with shape (batch, num_interval)
            y: observations, with shape (batch, num_obs) (a leading dimension
                of 1 broadcasts over the batch)
            sigma: standard deviation of observation noises
        Returns:
            Tensor of shape (batch,) equal to
            -0.5 * sum_j (y_j - u_{(j + 1) m})^2 / sigma^2.
        """
        batch = xi.shape[0]
        m = self.obs_interval
        sqrt_dt = np.sqrt(self.stepsize)
        uk = torch.zeros(batch, device=xi.device)
        logll = torch.zeros(batch, device=xi.device)

        for k in range(self.num_interval):
            vk = uk + self.stepsize * self.drift(uk) + sqrt_dt * xi[:, k]
            # uk is u_k here; observation j = k // m - 1 is taken at u_k.
            if k > 0 and k % m == 0:
                logll = logll - 0.5 * (y[:, k // m - 1] - uk)**2 / sigma**2
            uk = vk
        # The final observation is at u_T = u_{num_interval}.
        logll = logll - 0.5 * (y[:, -1] - uk)**2 / sigma**2
        return logll

    def posterior_score(self, xi, y, sigma=0.1):
        """Gradient w.r.t. xi of the log posterior under a N(0, I) prior.

        Parameters:
            xi: inputs, with shape (batch, num_interval)
            y: observations, with shape (batch, num_obs)
            sigma: standard deviation of observation noises
        Returns:
            -xi + d loglikelihood / d xi, same shape as xi. The prior term
            keeps xi's autograd history; the likelihood term is a constant
            (no graph) tensor.
        """
        prior_score = -xi
        # Differentiate on a detached copy so the caller's tensor (and its
        # requires_grad flag) is left untouched.
        xi_leaf = xi.detach().reshape(xi.shape[0], -1).requires_grad_(True)
        with torch.enable_grad():
            logll = self.loglikelihood(xi_leaf, y, sigma)
            likelihood_score = autograd.grad(logll.sum(), xi_leaf)[0]

        return prior_score + likelihood_score

    def evaluation(self, xi, ref_path="cd_sgld_u_12345.pt", alphas=(0.1, 1)):
        """Log MMD between the paths generated by xi and reference samples.

        Parameters:
            xi: particle increments, shape (n, num_interval).
            ref_path: file of reference samples, compared row-wise against
                u[:, 1:]; loaded once per path and cached on the instance.
            alphas: kernel scales forwarded to MMDStatistic.
        Returns:
            np.log of the unbiased MMD estimate. That estimate can be
            negative, in which case the result is nan.
        """
        # Only the deterministic path map is needed. Going through
        # generate_path would also draw (and discard) observation noise,
        # advancing the global torch RNG of the sampler that calls this.
        _, u = self._integrate(xi.detach())
        u = u.detach().cpu()[:, 1:]
        u_ref = self._load_reference(ref_path)
        mmd = MMDStatistic(u.shape[0], u_ref.shape[0])
        return np.log(mmd(u, u_ref, alphas))

    def _load_reference(self, path):
        """Load (and memoise per path) a reference tensor onto the CPU."""
        cache = self.__dict__.setdefault('_ref_cache', {})
        if path not in cache:
            # torch.load unpickles the file: only point this at trusted files.
            cache[path] = torch.load(path, map_location='cpu').detach().cpu()
        return cache[path]


class MMDStatistic:
    """Unbiased squared-MMD estimate between two fixed-size samples.

    The kernel is sum_a exp(-a * d**4) over the scales in ``alphas``, where
    d is the Euclidean distance between rows of the pooled samples.
    NOTE(review): the usual Gaussian kernel uses d**2; exp(-a * d**4) is not
    positive definite, so please confirm the exponent 4 is intended. The
    unbiased estimate may be negative in any case.

    Parameters:
        n_1: number of rows in the first sample.
        n_2: number of rows in the second sample.
    """

    def __init__(self, n_1, n_2):
        self.n_1 = n_1
        self.n_2 = n_2

        # Weights for the within-sample (diagonal excluded) and cross terms.
        self.a00 = 1. / (n_1 * (n_1 - 1))
        self.a11 = 1. / (n_2 * (n_2 - 1))
        self.a01 = - 1. / (n_1 * n_2)

    def __call__(self, sample_1, sample_2, alphas, ret_matrix=False):
        """Return the MMD estimate (and the pooled kernel matrix if asked).

        Parameters:
            sample_1: (n_1, d) CPU tensor.
            sample_2: (n_2, d) CPU tensor.
            alphas: iterable of kernel scales whose kernels are summed.
            ret_matrix: also return the (n_1 + n_2)^2 kernel matrix.
        """
        pooled = torch.cat((sample_1, sample_2), 0)
        dist_pow4 = squareform(pdist(pooled)) ** 4
        kernels = sum(np.exp(-scale * dist_pow4) for scale in alphas)

        n = self.n_1
        within_1 = kernels[:n, :n]
        within_2 = kernels[n:, n:]
        cross = kernels[:n, n:]

        cross_term = 2 * self.a01 * cross.sum()
        term_1 = self.a00 * (within_1.sum() - np.trace(within_1))
        term_2 = self.a11 * (within_2.sum() - np.trace(within_2))
        mmd = cross_term + term_1 + term_2
        return (mmd, kernels) if ret_matrix else mmd


# Particle VI class

In [None]:
class ParVI:
    """Particle-based variational inference sampler.

    The particle update rule is selected by ``g_fn``:
      * 'sgld'         -- Langevin step with ``target_score`` plus Gaussian
                          noise;
      * 'svgd'         -- Stein variational gradient descent, RBF kernel;
      * 'p_norm'       -- move along the learned velocity field ``f_net``,
                          fitted by maximising ``gssm`` with a p-norm
                          regulariser (p optionally adapted);
      * 'precondition' -- as 'p_norm' but regularised by ``precondition_g``.

    Parameters:
        target_score: callable mapping particles (n, dim) to the gradient of
            the log target density at those particles, same shape.
        dim: particle dimension.
        latent_dim: hidden width of ``f_net``.
    """

    def __init__(self, target_score, dim, latent_dim):
        self.target_score = target_score
        self.dim = dim
        self.f_net = Small_F_net(dim, latent_dim).to(device)

    def precondition_g(self, x, alpha=1, beta=1.0):
        """Preconditioned quadratic regulariser 0.5 * mean_i sum_d H_d f_d(x_i)^2.

        H_d = 1 / Var(x[:, d])**alpha is computed from the current particles
        (detached, so no gradient flows through it) and blended into
        ``self.H`` as an exponential moving average with weight ``beta`` on
        the new value. beta=1.0 (the default) uses only the current batch.

        Parameters:
            x: particles, reshaped to (n, dim).
            alpha: exponent applied to the per-coordinate variance.
            beta: EMA weight on the freshly computed H.
        Returns:
            Scalar tensor, differentiable w.r.t. the ``f_net`` parameters.
        """
        dup_x = x.reshape(-1, self.dim)
        dup_x.requires_grad_(True)
        n = dup_x.shape[0]
        fx = self.f_net(dup_x)
        H = (1 / torch.var(dup_x, dim=0)**alpha).repeat(n, 1).detach()
        prev_H = getattr(self, 'H', None)
        if beta == 1.0 or prev_H is None or prev_H.shape != H.shape:
            # Nothing to blend with (or pure current-batch mode).
            self.H = H
        else:
            self.H = beta * H + (1 - beta) * prev_H
        return 0.5 * torch.sum(self.H * fx * fx) / n

    def svgd_kernel(self, x, h=-1):
        """RBF kernel matrix and its repulsive-gradient term for SVGD.

        Parameters:
            x: particles, tensor of shape (n, dim); copied to CPU numpy.
            h: bandwidth; when negative, chosen by the median heuristic
               h = sqrt(0.5 * median(||x_i - x_j||^2) / log(n + 1)).
        Returns:
            (Kxy, dxkxy) as float32 CPU tensors with
            Kxy[i, j] = exp(-||x_i - x_j||^2 / (2 h^2)) and
            dxkxy[i] = sum_j Kxy[i, j] (x_i - x_j) / h^2.
        """
        pts = x.detach().cpu().numpy()
        pairwise_dists = squareform(pdist(pts))**2
        if h < 0:  # if h < 0, using median trick
            h = np.median(pairwise_dists)
            h = np.sqrt(0.5 * h / np.log(pts.shape[0] + 1))

        Kxy = np.exp(-pairwise_dists / h**2 / 2)

        # sum_j K_ij (x_i - x_j) = x_i * sum_j K_ij - (K @ x)_i
        sumkxy = np.sum(Kxy, axis=1)
        dxkxy = -np.matmul(Kxy, pts) + pts * sumkxy[:, None]
        dxkxy = dxkxy / (h**2)
        return torch.from_numpy(Kxy).float(), torch.from_numpy(dxkxy).float()

    def gssm(self,
             samples,
             n_particles=1,
             g_fn='p_norm',
             p=2,
             precond_alpha=1.0):
        """Stein-type objective that is maximised to fit ``f_net``.

        Returns the scalar
            mean_i f(x_i) . s(x_i) + mean_i v_i^T J_f(x_i) v_i - reg(f),
        where s = ``target_score``, v_i ~ N(0, I) averaged over
        ``n_particles`` draws (a Hutchinson estimate of div f), and reg is
            'p_norm':       mean_i ||f(x_i)||_p^p / p
            'precondition': precondition_g(x, precond_alpha).
        create_graph=True keeps the divergence term differentiable so the
        result can be back-propagated into the ``f_net`` parameters.

        Raises:
            ValueError: if g_fn is neither 'p_norm' nor 'precondition'.
        """
        if g_fn not in ('p_norm', 'precondition'):
            raise ValueError(
                f"gssm supports g_fn='p_norm' or 'precondition', got {g_fn!r}")

        dup_samples = samples.view(-1, self.dim)
        dup_samples.requires_grad_(True)

        score = self.target_score(dup_samples)
        f = self.f_net(dup_samples)

        loss1 = torch.sum(f * score, dim=-1).mean()
        loss2 = torch.zeros(dup_samples.shape[0], device=dup_samples.device)
        for _ in range(n_particles):
            vectors = torch.randn_like(dup_samples)
            gradv = torch.sum(f * vectors)
            grad2 = autograd.grad(gradv,
                                  dup_samples,
                                  create_graph=True,
                                  retain_graph=True)[0]
            loss2 = loss2 + torch.sum(vectors * grad2, dim=-1) / n_particles
        loss2 = loss2.mean()

        if g_fn == 'p_norm':
            loss3 = (torch.norm(f, p=p, dim=-1)**p / p).mean()
        else:
            loss3 = self.precondition_g(dup_samples, precond_alpha)

        return loss1 + loss2 - loss3

    def _fit_f_net(self, x, f_opt, g_fn, p, precond_alpha):
        """Take one optimiser step on ``f_net`` that maximises gssm at x."""
        particles = x.detach().requires_grad_(True)
        f_loss = -self.gssm(particles, g_fn=g_fn, p=p,
                            precond_alpha=precond_alpha)
        f_opt.zero_grad()
        f_loss.backward()
        f_opt.step()

    def sample(
        self,
        f_opt,
        g_fn,
        sample_size,
        p=2,
        precond_alpha=1,
        step_size=1e-3,
        adagrad=False,
        adaptive=False,
        p_step=1e-4,
        lb=1.1,
        ub=6.0,
        alpha=0.9,
        n_epoch=2000,
        f_iter=1,
        pre_train_epoch=100,
        check_frq=200,
        evaluation=lambda x: 0,
    ):
        """Evolve ``sample_size`` particles, initialised from N(0, I).

        Parameters:
            f_opt: optimiser over ``self.f_net`` parameters (used only by
                'p_norm' / 'precondition').
            g_fn: 'sgld', 'svgd', 'p_norm' or 'precondition'.
            sample_size: number of particles.
            p: norm exponent for 'p_norm' (initial value when ``adaptive``).
            precond_alpha: exponent passed to ``precondition_g``.
            step_size: particle step size.
            adagrad: divide velocities by fudge + sqrt of an exponential
                moving average (decay ``alpha``) of their squares.
            adaptive: for 'p_norm', take a clipped gradient step of size
                ``p_step`` on p and clamp it to [lb, ub] after every update.
            n_epoch: number of particle updates.
            f_iter: ``f_net`` optimiser steps before each particle update.
            pre_train_epoch: ``f_net`` optimiser steps before sampling
                starts (skipped for 'sgld' / 'svgd', which never use f_net).
            check_frq: snapshot and evaluate every ``check_frq`` epochs and
                at the last epoch.
            evaluation: callable applied to each snapshot; its result is
                printed and collected.
        Returns:
            (xs, info): detached particle snapshots of shape
            (sample_size, dim) and the matching evaluation results.
        """
        x = torch.randn(sample_size, self.dim).to(device)
        xs = []
        info = []
        historical_grad = 0
        fudge_factor = 1e-6
        self.H = torch.zeros(sample_size, self.dim,
                             requires_grad=False).to(device)
        # 'sgld' and 'svgd' never query f_net, and gssm has no regulariser
        # for them, so fitting the network would only fail.
        learns_velocity = g_fn not in ('sgld', 'svgd')

        if learns_velocity:
            for _ in range(pre_train_epoch):
                self._fit_f_net(x, f_opt, g_fn, p, precond_alpha)

        for ep in trange(n_epoch):
            if g_fn == 'sgld':
                noise = torch.randn_like(x).to(device)
                score = self.target_score(x.detach().requires_grad_(True))
                x = x + step_size * score + np.sqrt(2 * step_size) * noise
            else:
                if g_fn == 'svgd':
                    s = self.target_score(x.detach().requires_grad_(True))
                    kxy, dxkxy = self.svgd_kernel(x.cpu(), h=-1)
                    kxy = kxy.to(device)
                    dxkxy = dxkxy.to(device)
                    v = (torch.matmul(kxy, s) + dxkxy) / sample_size
                else:
                    for _ in range(f_iter):
                        self._fit_f_net(x, f_opt, g_fn, p, precond_alpha)
                    with torch.no_grad():
                        v = self.f_net(x)
                # Velocities are plain values from here on.
                v = v.detach()

                # adagrad
                if adagrad:
                    if ep == 0:
                        historical_grad = historical_grad + v**2
                    else:
                        historical_grad = alpha * historical_grad + (
                            1 - alpha) * v**2
                    v = torch.divide(
                        v, fudge_factor + torch.sqrt(historical_grad))

                # update particles
                x = x + step_size * v

                # adaptive p
                if g_fn == 'p_norm' and adaptive:
                    # compute gradient of p
                    v_pow = v.abs().cpu().numpy()**p
                    grad_p = np.sum(v_pow * (np.log(v_pow + 1e-7) - 1),
                                    axis=1).mean() / p**2
                    grad_p = np.clip(grad_p, -0.1, 0.1)

                    p += p_step * grad_p
                    p = np.clip(p, lb, ub)

            # Cut autograd history every epoch; otherwise each update chains
            # onto the previous one and memory grows with n_epoch.
            x = x.detach()

            if (ep % check_frq == 0) or (ep == n_epoch - 1):
                xs.append(x)
                info.append(evaluation(x))
                print(info[-1])

        return xs, info


In [None]:
# adaptive
# Experiment: ParVI with the 'p_norm' regulariser, adapting p during sampling
# (adaptive=True). The problem setup and seed match the 'precondition' cell.

# Seed before any random draw: generate_path samples xi and observation noise
# from the torch RNG and ParVI's network initialisation also consumes it, so
# the statement order in this cell determines the results.
torch.manual_seed(12345)
np.random.seed(12345)

num_interval = 100  # Euler steps on [0, T]; also the length of xi
num_obs = 20  # noisy observations of the path (must divide num_interval)
beta = 10  # drift strength
T = 1.0  # time horizon
sigma = 0.1  # observation noise standard deviation
CD = Conditioned_Diffusion(num_interval, num_obs, beta, T)
# A single synthetic ground-truth path (batch_size=1) and its observations y.
xi_true, x_true, u_true, y = CD.generate_path(sigma, batch_size=1)

# Particles live in the space of the increments xi.
dim = num_interval
latent_dim = 200  # hidden width of the velocity network f_net
sample_size = 1000  # number of particles


def cd_score(x):
    """Posterior score of increments x given the observations y.

    CD, y and sigma are looked up as module globals at call time, so a later
    cell that rebinds them changes what this function computes.
    """
    return CD.posterior_score(x, y, sigma)


parvi = ParVI(cd_score, dim, latent_dim)
f_opt = optim.Adam(parvi.f_net.parameters(), lr=1e-3)
xi_ada, info_ada = parvi.sample(f_opt,
                                'p_norm',
                                sample_size,
                                p=2.0,  # initial p, adapted within [lb, ub]
                                adaptive=True,
                                p_step=3e-3,  # step size of the p update
                                step_size=3e-3,  # particle step size
                                n_epoch=500,
                                f_iter=15,  # f_net steps per particle update
                                pre_train_epoch=100,
                                check_frq=50,  # log-MMD check every 50 epochs
                                evaluation=CD.evaluation)


In [None]:
# precondition
# Experiment: ParVI regularised by the variance-based preconditioner
# ('precondition'). The problem setup and seed match the 'adaptive' cell.

# Re-seed so this run starts from the same RNG state; the statement order
# below (path generation, then network initialisation) fixes the results.
torch.manual_seed(12345)
np.random.seed(12345)

num_interval = 100  # Euler steps on [0, T]; also the length of xi
num_obs = 20  # noisy observations of the path (must divide num_interval)
beta = 10  # drift strength
T = 1.0  # time horizon
sigma = 0.1  # observation noise standard deviation
CD = Conditioned_Diffusion(num_interval, num_obs, beta, T)
# A single synthetic ground-truth path (batch_size=1) and its observations y.
xi_true, x_true, u_true, y = CD.generate_path(sigma, batch_size=1)

# Particles live in the space of the increments xi.
dim = num_interval
latent_dim = 200  # hidden width of the velocity network f_net
sample_size = 1000  # number of particles


def cd_score(x):
    """Posterior score of increments x given the observations y.

    Rebinds the cd_score of the earlier cell; CD, y and sigma are looked up
    as module globals at call time.
    """
    return CD.posterior_score(x, y, sigma)


parvi = ParVI(cd_score, dim, latent_dim)
f_opt = optim.Adam(parvi.f_net.parameters(), lr=1e-3)
xi_H2, info_H2 = parvi.sample(f_opt,
                              'precondition',
                              sample_size,
                              p=2,  # unused by the 'precondition' regulariser
                              precond_alpha=1,  # H = 1 / Var(x) per coordinate
                              step_size=3e-3,  # particle step size
                              n_epoch=500,
                              f_iter=15,  # f_net steps per particle update
                              pre_train_epoch=100,
                              check_frq=50,  # log-MMD check every 50 epochs
                              evaluation=CD.evaluation)
