## Import utilities

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import trange
import math
import numpy as np

In [2]:
import sys
sys.path.insert(0, ".")

## SVGD gradient with deep kernel

Define the function `svgd_gradient(particles, score, h=-1)`, which computes $\phi(x)$, the direction used to update the particles.

In [None]:
# defining deep kernel
'''
def rbf_kernel(x, y, sigma=1.0):
    K = torch.exp(-torch.cdist(x, y) ** 2 / (2 * sigma ** 2))
    return K
'''

class KernelLayer(nn.Module):
    """Layer that evaluates a kernel callable on its input paired with itself.

    Args:
        input_dim: stored as ``self.input_dim``; ``forward`` does not read it.
        kernel: callable invoked as ``kernel(x, x)``.
    """

    def __init__(self, input_dim, kernel):
        super().__init__()
        self.kernel = kernel
        self.input_dim = input_dim

    def forward(self, x):
        """Return ``self.kernel(x, x)``."""
        return self.kernel(x, x)

class KernelNN(nn.Module):
    """MLP feature map whose output is passed through a ``KernelLayer``.

    Args:
        kernel: callable ``kernel(a, b)`` handed to the final ``KernelLayer``.
        in_dim: size of the input features.
        out_dim: size of the features the kernel is evaluated on.
        hidden_dim: width of every hidden layer.
        depth: ``1`` builds one hidden layer; any other value builds two.
    """

    def __init__(self, kernel, in_dim=1, out_dim=1, hidden_dim=300, depth=1):
        super().__init__()
        extra_hidden = 0 if depth == 1 else 1

        modules = [nn.Linear(in_dim, hidden_dim), Swish(hidden_dim)]
        for _ in range(extra_hidden):
            modules += [nn.Linear(hidden_dim, hidden_dim), Swish(hidden_dim)]
        modules.append(nn.Linear(hidden_dim, out_dim))
        # KernelLayer takes (input_dim, kernel); the kernel must be forwarded,
        # otherwise constructing the layer raises TypeError.
        modules.append(KernelLayer(out_dim, kernel))

        self.layer = nn.Sequential(*modules)

    def forward(self, x):
        """Map ``x`` through the MLP and return the kernel of the mapped features."""
        return self.layer(x)


In [28]:
from scipy.spatial.distance import pdist, squareform

def svgd_kernel(x, h=-1):
    """Compute the RBF kernel matrix of a particle set and its repulsive gradient term.

    Args:
        x: particles, N x D (tensor or array-like accepted by ``torch.as_tensor``).
        h: kernel bandwidth; a negative value selects the median trick.

    Returns:
        ``(Kxy, dxkxy)``: the N x N kernel matrix and the N x D term
        ``sum_j Kxy[i, j] * (x_i - x_j) / h**2``, both as float64 tensors.
    """
    # Work on a detached double-precision copy so the result does not depend
    # on the input dtype and no autograd graph is built.
    theta = torch.as_tensor(x).detach().double()

    # Squared Euclidean distance between every pair of particles (N x N).
    diffs = theta.unsqueeze(1) - theta.unsqueeze(0)
    pairwise_dists = diffs.pow(2).sum(-1)

    if h < 0:
        # median trick
        h = torch.median(pairwise_dists)
        h = torch.sqrt(0.5 * h / math.log(theta.shape[0] + 1))

    # compute the rbf kernel
    Kxy = torch.exp(-pairwise_dists / h ** 2 / 2)

    # sum_j K_ij * (x_i - x_j) = x_i * sum_j K_ij - (K @ x)_i, for all columns at once
    sumkxy = Kxy.sum(dim=1, keepdim=True)
    dxkxy = (sumkxy * theta - torch.matmul(Kxy, theta)) / h ** 2

    return Kxy, dxkxy

In [28]:
def svgd_gradient(particles, score, h=-1):
    """Compute the SVGD update direction phi(*) for a set of particles.

    Args:
        particles: N x D tensor of particle positions.
        score: N x D tensor, gradient of log p(x) at each particle.
        h: RBF bandwidth; a negative value selects the median trick.

    Returns:
        N x D tensor: (kernel-weighted scores + kernel repulsion) / N.
    """
    n_particles = particles.shape[0]

    # Squared Euclidean distance between every pair of particles (N x N).
    diffs = particles[None] - particles[:, None]  # N x N x D
    sq_dists = diffs.pow(2.0).sum(-1)

    if h < 0:
        # Median trick: bandwidth derived from the median squared distance.
        median_sq = torch.median(sq_dists)
        h = torch.sqrt(0.5 * median_sq / math.log(n_particles + 1))

    bandwidth_sq = h ** 2
    gram = torch.exp(-sq_dists / bandwidth_sq / 2)  # N x N RBF kernel matrix

    # Repulsive term: sum_j K_ij * (x_i - x_j) / h^2
    row_mass = gram.sum(dim=1, keepdim=True)  # N x 1
    repulsion = (-torch.matmul(gram, particles) + row_mass * particles) / bandwidth_sq

    # Attractive term: kernel-smoothed score.
    attraction = torch.mm(gram, score)

    return (attraction + repulsion) / n_particles



## Swish activation & MLP model

In [3]:
# base model
class Swish(nn.Module):
    """Swish activation ``x * sigmoid(beta * x)``.

    Args:
        dim: when positive, ``beta`` is a learnable parameter with one entry
            per feature/channel; otherwise ``beta`` is a fixed scalar 1.
    """

    def __init__(self, dim=-1):
        super().__init__()
        if dim > 0:
            self.beta = nn.Parameter(torch.ones((dim,)))
        else:
            # Register the fixed beta as a buffer so that .to(device)/.cuda()
            # move it together with the module (a plain tensor attribute would
            # stay on the CPU). persistent=False keeps state_dict() unchanged.
            self.register_buffer("beta", torch.ones((1,)), persistent=False)

    def forward(self, x):
        """Apply the activation; ``beta`` is broadcast along dimension 1 of ``x``."""
        if len(x.size()) == 2:
            return x * torch.sigmoid(self.beta[None, :] * x)
        else:
            return x * torch.sigmoid(self.beta[None, :, None, None] * x)
        
class LargeFeatureExtractor(torch.nn.Sequential):
    """Feed-forward stack data_dim -> 1000 -> 500 -> 50 -> 2 with ReLU between layers.

    Sub-modules are registered as ``linear1``, ``relu1``, ..., ``linear4``;
    no activation follows the last linear layer.
    """

    def __init__(self, data_dim):
        super().__init__()
        widths = (data_dim, 1000, 500, 50, 2)
        n_linear = len(widths) - 1
        for idx, (fan_in, fan_out) in enumerate(zip(widths, widths[1:]), start=1):
            self.add_module(f'linear{idx}', torch.nn.Linear(fan_in, fan_out))
            if idx < n_linear:
                self.add_module(f'relu{idx}', torch.nn.ReLU())

class KernelLayer(nn.Module):
    """Wrap a kernel callable as a module that returns ``kernel(x, x)``.

    Args:
        input_dim: recorded on the instance; not used when computing the output.
        kernel: two-argument callable applied to the input paired with itself.
    """

    def __init__(self, input_dim, kernel):
        super().__init__()
        self.kernel = kernel
        self.input_dim = input_dim

    def forward(self, x):
        """Evaluate the stored kernel with ``x`` as both arguments."""
        return self.kernel(x, x)

class KernelNN(nn.Module):
    """MLP feature map whose output is passed through a ``KernelLayer``.

    Args:
        kernel: callable ``kernel(a, b)`` handed to the final ``KernelLayer``.
        in_dim: size of the input features.
        out_dim: size of the features the kernel is evaluated on.
        hidden_dim: width of every hidden layer.
        depth: ``1`` builds one hidden layer; any other value builds two.
    """

    def __init__(self, kernel, in_dim=1, out_dim=1, hidden_dim=300, depth=1):
        super().__init__()
        extra_hidden = 0 if depth == 1 else 1

        modules = [nn.Linear(in_dim, hidden_dim), Swish(hidden_dim)]
        for _ in range(extra_hidden):
            modules += [nn.Linear(hidden_dim, hidden_dim), Swish(hidden_dim)]
        modules.append(nn.Linear(hidden_dim, out_dim))
        # KernelLayer takes (input_dim, kernel); the kernel must be forwarded,
        # otherwise constructing the layer raises TypeError.
        modules.append(KernelLayer(out_dim, kernel))

        self.layer = nn.Sequential(*modules)

    def forward(self, x):
        """Map ``x`` through the MLP and return the kernel of the mapped features."""
        return self.layer(x)


In [10]:
# Toy input: 100 random samples with 10 features each; requires_grad_() lets
# gradients be taken with respect to the inputs themselves.
train_x = torch.randn(100,10).requires_grad_()
# Size of the last axis of train_x (the feature dimension).
data_dim = train_x.size(-1)

# Network that maps data_dim-dimensional inputs to 2 output features.
feature_extractor = LargeFeatureExtractor(data_dim)

In [11]:
# Run the inputs through the feature extractor and display the output shape.
x_projected=feature_extractor(train_x)
x_projected.shape

torch.Size([100, 2])

In [12]:
def rbf_kernel(x, y, sigma=1.0):
    """Return the RBF kernel matrix ``exp(-||x_i - y_j||^2 / (2 * sigma^2))``.

    Args:
        x: first batch of points, passed to ``torch.cdist``.
        y: second batch of points, passed to ``torch.cdist``.
        sigma: bandwidth; only ``sigma ** 2`` enters the computation.
    """
    sq_dists = torch.cdist(x, y) ** 2
    scale = 2 * sigma ** 2
    return torch.exp(-sq_dists / scale)

In [13]:
def keep_grad(output, input, grad_outputs=None):
    """Return d(output)/d(input) while keeping the graph alive.

    The graph is retained and the gradient itself is built with a graph, so the
    result can be differentiated again and ``output`` can be reused afterwards.
    """
    grads = torch.autograd.grad(
        output,
        input,
        grad_outputs=grad_outputs,
        retain_graph=True,
        create_graph=True,
    )
    return grads[0]
 

In [14]:
# Second name for the same projected features (an alias, not a copy).
x_projected_vice = x_projected
# The second argument is detached, so autograd only tracks the kernel's
# dependence on its first argument.
kernel_matrix = rbf_kernel(x_projected,x_projected_vice.detach())


In [15]:
# Display the shape of the kernel matrix.
kernel_matrix.shape

torch.Size([100, 100])

In [16]:
# Random weights with the same shape as the kernel matrix.
eps=torch.rand_like(kernel_matrix)
# Vector-Jacobian product: gradient of sum(eps * kernel_matrix) with respect
# to train_x, back-propagated through the feature extractor.
dkxy_dx = keep_grad(output=kernel_matrix,input=train_x,grad_outputs=eps)
dkxy_dx.shape

torch.Size([100, 10])

In [5]:
def f(x):
    """Return ``2 * x``."""
    return 2*x

## Base Neural SVGD

In [5]:
class BaseNeuralSVGD:
    """Base class for neural SVGD samplers.

    A network ``f_net`` is trained to maximise an L2-regularised estimate of the
    Stein discrepancy, and the particles are then moved along ``f_net(particles)``.

    Args:
        device: device the update network is placed on.
    """

    def __init__(self, device):
        self.device = device

    def keep_grad(self, output, input, grad_outputs=None):
        """Return d(output)/d(input), retaining the graph so it can be differentiated again."""
        return torch.autograd.grad(output, input, grad_outputs=grad_outputs, retain_graph=True, create_graph=True)[0]

    @staticmethod
    def rbf_kernel(x, y=None, sigma=-1.0):
        """Return the RBF kernel matrix ``exp(-||x_i - y_j||^2 / (2 * sigma^2))``.

        Args:
            x: first batch of points.
            y: second batch of points; defaults to ``x``.
            sigma: bandwidth. Only ``sigma ** 2`` is used, so its sign has no effect.
        """
        if y is None:
            y = x
        return torch.exp(-torch.cdist(x, y) ** 2 / (2 * sigma ** 2))

    def approx_jacobian_trace(self, fx, x):
        """Stochastic estimate of trace(d fx / d x) per sample: ``eps^T J eps`` with Gaussian ``eps``."""
        eps = torch.randn_like(fx)
        eps_dfdx = self.keep_grad(fx, x, grad_outputs=eps)
        tr_dfdx = (eps_dfdx * eps).sum(-1)
        return tr_dfdx

    def exact_jacobian_trace(self, fx, x):
        """Exact trace(d fx / d x) per sample, using one backward pass per input dimension."""
        diag = [self.keep_grad(fx[:, i].sum(), x)[:, i] for i in range(x.size(1))]
        return torch.stack(diag, dim=1).sum(dim=1)

    def _new_f_net(self, data_dim, hidden_dim):
        """Build a fresh update network mapping data_dim -> data_dim on ``self.device``."""
        # MLP is expected to be defined elsewhere in this file.
        # NOTE(review): the original re-initialisation also passed
        # ``activation=activation``, but no such name is defined in this method;
        # confirm MLP's default activation is the intended one.
        return MLP(in_dim=data_dim, out_dim=data_dim, hidden_dim=hidden_dim).to(self.device)

    def update_particles_base(self, log_p_target, data_dim, hidden_dim, net_lr, net_update_num, initial_particles,
                              n_iter, step_size, reg_coefficient, init_net_per_iter=1, jacobian_trace='approx'):
        """Transport particles towards the target by repeatedly fitting and applying ``f_net``.

        Args:
            log_p_target: callable returning the log density for a batch of particles.
            data_dim: particle dimensionality (input and output size of ``f_net``).
            hidden_dim: hidden width of ``f_net``.
            net_lr: learning rate for the ``f_net`` optimiser.
            net_update_num: optimiser steps on ``f_net`` per particle update.
            initial_particles: starting particles; not modified in place.
            n_iter: number of particle updates.
            step_size: step taken along ``f_net(particles)``.
            reg_coefficient: weight of the L2 penalty on ``f_net`` outputs.
            init_net_per_iter: ``f_net`` is re-created every this many iterations.
            jacobian_trace: ``'exact'`` for the exact trace, anything else for the estimate.

        Returns:
            ``(final_particles, stein_discrepancy)`` where ``stein_discrepancy`` holds
            the Stein loss recorded after the last network step of each iteration.
        """
        cur_particles = initial_particles
        stein_discrepancy = []
        f_net = self._new_f_net(data_dim, hidden_dim)
        optimizer = optim.Adam(f_net.parameters(), lr=net_lr)
        for i in trange(n_iter):
            if (i + 1) % init_net_per_iter == 0:
                f_net = self._new_f_net(data_dim, hidden_dim)
                # NOTE(review): the initial optimiser is Adam but re-initialised
                # networks use SGD; kept as written, confirm this is intended.
                optimizer = optim.SGD(f_net.parameters(), lr=net_lr)
            f_net.train()
            for j in range(net_update_num):
                optimizer.zero_grad()
                # Detach first so the caller's tensor is never switched to
                # requires_grad in place and each step starts from a fresh leaf.
                cur_particles = cur_particles.detach().requires_grad_()
                f_x = f_net(cur_particles)
                if jacobian_trace == 'exact':
                    tr_grad_f = self.exact_jacobian_trace(f_x, cur_particles)
                else:
                    tr_grad_f = self.approx_jacobian_trace(f_x, cur_particles)
                score_p = self.keep_grad(log_p_target(cur_particles).sum(), cur_particles)
                scorep_fx = (score_p * f_x).sum(-1)  # compute (dlogp(x)/dx)^T * f(x)
                stein_loss = (scorep_fx + tr_grad_f).mean()  # estimate of S(p, q)
                l2_penalty = (f_x * f_x).sum(1).mean() * reg_coefficient
                loss = -1.0 * stein_loss + l2_penalty
                loss.backward()
                optimizer.step()

                if j == net_update_num - 1:
                    stein_discrepancy.append(stein_loss.item())

            with torch.no_grad():
                f_net.eval()
                cur_particles = cur_particles + step_size * f_net(cur_particles)
        final_particles = cur_particles

        return final_particles, stein_discrepancy


## BayesianLR model with svgd

In [6]:
from tqdm.auto import tqdm

class BayesianLR(BaseNeuralSVGD):
    """Bayesian logistic regression whose posterior is sampled with neural SVGD.

    Args:
        X_train, y_train: training features and labels; the label broadcast in
            ``log_posterior`` assumes ``y_train`` is N x 1 (TODO confirm against the loader).
        X_test, y_test: held-out data used by ``evaluation``.
        batch_size: mini-batch size, capped at the number of training rows.
        alpha: stored inverted, i.e. ``self.alpha = 1.0 / alpha``.
        device: device for the update network; defaults to ``X_train.device``.
    """

    def __init__(self, X_train, y_train, X_test, y_test, batch_size=100, alpha=1.0, device=None):
        if device is None:
            # Do not rely on a module-level ``device`` that may not exist yet
            # when this class is instantiated; follow the training data instead.
            device = getattr(X_train, "device", torch.device("cpu"))
        super().__init__(device=device)
        self.X_train, self.y_train = X_train, y_train
        self.X_test, self.y_test = X_test, y_test
        self.batch_size = min(batch_size, self.X_train.shape[0])
        self.alpha = 1.0 / alpha
        self.N = X_train.shape[0]
        # Fixed random ordering of the training rows; mini-batches walk through it cyclically.
        self.permutation = torch.randperm(self.N)
        self.iter = 0

    def log_posterior(self, theta, iter):
        '''
        Mini-batch estimate of the unnormalised log posterior of each particle.

        Args:
            theta: particles (weight vectors), sample_num x D
            iter: iteration index, selects which mini-batch is used
        Returns:
            log density: 1-dim: sample_num
        '''
        W = theta  # sample_num x D
        D = theta.shape[1]
        # Consecutive window of the permutation, wrapping around at N.
        batch = torch.arange(iter * self.batch_size, (iter + 1) * self.batch_size) % self.N
        ridx = self.permutation[batch]
        Xs = self.X_train[ridx, :]  # batch x D
        ys = self.y_train[ridx]  # batch x 1
        z = torch.matmul(Xs, W.t())  # batch x sample_num
        coff = -z * ys  # batch x sample_num
        coff = torch.clamp(coff, min=-20, max=20)  # keeps exp() below from overflowing

        log_p_D_given_w = -torch.log(1 + torch.exp(coff)).sum(0)  # sample_num
        log_p_w_given_alpha = -0.5 * self.alpha * torch.sum(W * W, dim=1) + (D / 2) * math.log(self.alpha)   # sample_num
        # Rescale the mini-batch likelihood to the full data set size.
        log_posterior = log_p_D_given_w * self.N / Xs.shape[0] + log_p_w_given_alpha
        return log_posterior

    def evaluation(self, theta, X_test, y_test):
        """Return ``(accuracy, mean log-likelihood)`` of the particle ensemble on test data.

        The probability assigned to the given label is averaged over particles;
        an example counts as correct when that average exceeds 0.5.
        """
        W = theta.cpu().detach()  # N x D
        z = torch.matmul(X_test, W.t())  # batch x sample_num
        # sigmoid(z * y) equals 1 / (1 + exp(-z * y)) but does not overflow for large |z|.
        prob = torch.mean(torch.sigmoid(z * y_test), dim=1)
        acc = torch.mean((prob > .5).float())
        llh = torch.mean(torch.log(prob))
        return acc, llh

    def update_particles_and_eval_iters(self, target_dim, hidden_dim, net_lr, initial_particles,
                                        n_iter, step_size, reg_coefficient, jacobian_trace='approx',
                                        auto_corr=0.9, fudge_factor=1e-6, eval_every=100000):
        """Run neural SVGD on the posterior with an adaptive per-coordinate step.

        Each iteration takes one optimiser step on ``f_net`` (maximising the
        regularised Stein discrepancy on the current mini-batch) and then moves
        the particles along ``f_net(particles)`` divided by the root of a running
        average of its square.

        Args:
            target_dim: particle dimensionality.
            hidden_dim: hidden width of ``f_net``.
            net_lr: learning rate for ``f_net``.
            initial_particles: starting particles.
            n_iter: number of iterations.
            step_size: particle step size.
            reg_coefficient: weight of the L2 penalty on ``f_net`` outputs.
            jacobian_trace: ``'exact'`` for the exact trace, anything else for the estimate.
            auto_corr: decay of the running average of squared updates.
            fudge_factor: added to the denominator to avoid division by zero.
            eval_every: evaluate and print test metrics every this many
                iterations; a falsy value disables evaluation.

        Returns:
            The particles after ``n_iter`` updates.
        """
        cur_particles = initial_particles
        historical_grad = 0
        # MLP is expected to be defined elsewhere in this file.
        f_net = MLP(in_dim=target_dim, out_dim=target_dim, hidden_dim=hidden_dim).to(self.device)
        optimizer = optim.Adam(f_net.parameters(), lr=net_lr, betas=(0.1, 0.1), amsgrad=True)
        for i in tqdm(range(n_iter)): # progress bar
            f_net.train()
            optimizer.zero_grad()
            cur_particles = cur_particles.detach().requires_grad_()
            f_x = f_net(cur_particles)
            if jacobian_trace == 'exact':
                tr_grad_f = self.exact_jacobian_trace(f_x, cur_particles)
            else:
                tr_grad_f = self.approx_jacobian_trace(f_x, cur_particles)
            score_p = self.keep_grad(self.log_posterior(cur_particles, i).sum(), cur_particles) # B x D
            scorep_fx = (score_p * f_x).sum(-1) # compute (dlogp(x)/dx)^T * f(x)
            stein_loss = (scorep_fx + tr_grad_f).mean()  # estimate of S(p, q)
            l2_penalty = (f_x * f_x).sum(-1).mean() * reg_coefficient
            loss = -1.0 * stein_loss + l2_penalty
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                f_net.eval()
                phi = f_net(cur_particles)
                phi_sq = torch.multiply(phi, phi)
                if i == 0:
                    historical_grad = historical_grad + phi_sq
                else:
                    historical_grad = auto_corr * historical_grad + (1 - auto_corr) * phi_sq
                adj_grad = torch.divide(phi, fudge_factor + torch.sqrt(historical_grad))
                cur_particles = cur_particles + step_size * adj_grad
                if eval_every and (i + 1) % eval_every == 0:
                    acc, ll = self.evaluation(cur_particles.detach(), self.X_test, self.y_test)
                    print(f'Iter: {i + 1}, Acc:{acc:.4f}, LL:{ll:.4f}, phi:{f_x.mean():.4f}')
        final_particles = cur_particles

        return final_particles


## Load Dataset Utilities

In [15]:
from load_data import load_data_for_blr

## Train & Test

In [16]:
class Arguments:
    """Plain container for the experiment settings.

    Attributes set on the instance: ``num_particles``, ``batch_size``, ``num_trials``.
    """

    def __init__(self, num_particles, batch_size, num_trials):
        (self.num_particles,
         self.batch_size,
         self.num_trials) = num_particles, batch_size, num_trials

### preset

import argparse

# Command-line settings for the experiment.
parser = argparse.ArgumentParser()
parser.add_argument('--num_particles', type=int, default=200, help="particles number")
parser.add_argument('--batch_size', type=int, default=100)
parser.add_argument('--num_trials', type=int, default=20)
# parse_known_args() ignores argv entries this parser does not define (for
# example the ones a notebook kernel is launched with), where parse_args()
# would abort with SystemExit.
args, _unknown_args = parser.parse_known_args()


In [17]:
import random

# Seed Python's and NumPy's random generators.
seed=123
random.seed(seed)
np.random.seed(seed)
# NOTE(review): the torch generator is left unseeded (line below is disabled),
# so torch-side randomness is not reproducible between runs.
# torch.manual_seed(seed)

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


In [10]:
# prepare training
import time

# Dataset names; each table below is keyed by these names.
datasets = ['covertype', 'w8a', 'a9a', 'bioresponse']
# Per-dataset settings, presumably consumed by the training loop further down
# (not visible here; verify against the caller).
step_size = {'covertype': 0.008, 'w8a': 0.05, 'a9a': 0.03,  'bioresponse' : 0.003}
hidden_dim = {'covertype': 100, 'w8a': 1000, 'a9a': 500, 'bioresponse' : 5000}
depth = {'covertype': 2, 'w8a': 1, 'a9a': 1, 'bioresponse': 1}
max_iters = {'covertype': 10000, 'w8a': 10000, 'a9a': 10000,  'bioresponse' : 10000}
reg_coefficient = {'covertype': 5.0, 'w8a': 1.0, 'a9a': 1.0, 'bioresponse': 1.0}



In [11]:
# Set the run configuration directly: 10 particles, batch size 100, 2 trials.
args = Arguments(10,100,2)
print(f"num_particles = {args.num_particles}||batch_size = {args.batch_size}||num_trials = {args.num_trials}")

num_particles = 10||batch_size = 100||num_trials = 2


## Kernel as a MLP