## Import utilities

In [211]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import trange
from tqdm.auto import tqdm
import math
import numpy as np

In [212]:
import sys
sys.path.insert(0, ".")

## SVGD gradient with deep kernel

Define function `svgd_gradient(particals, score, h=-1)` for $\phi(x)$ to update particals

In [213]:
# defining deep kernel
'''
def rbf_kernel(x, y, sigma=1.0):
    K = torch.exp(-torch.cdist(x, y) ** 2 / (2 * sigma ** 2))
    return K
'''

class KernelNN(nn.Module):
    def __init__(self, in_dim=1, out_dim=1, hidden_dim=300, depth=1):
        super().__init__()
        if depth == 1:
            self.layer = nn.Sequential(
                nn.Linear(in_dim, hidden_dim),
                nn.SiLU(hidden_dim),
                nn.Linear(hidden_dim, out_dim),
            )
        else:
            self.layer = nn.Sequential(
                nn.Linear(in_dim, hidden_dim),
                nn.SiLU(hidden_dim),
                nn.Linear(hidden_dim, hidden_dim),
                nn.SiLU(hidden_dim),
                nn.Linear(hidden_dim, out_dim),
            )

    def forward(self, x):
        out = self.layer(x)
        return out


In [214]:

def svgd_grad_with_deepkernel(theta, score, model, h=-1):
    particles = model(theta)
    delta_x = particles.unsqueeze(0) - particles.unsqueeze(1)  # N x N x D
    pairwise_dists = delta_x.pow(2.0).sum(-1)  # N x N
    if h < 0: 
        # if h < 0, using median trick
        h = torch.median(pairwise_dists)
        h = torch.sqrt(0.5 * h / torch.log(torch.tensor(particles.shape[0]+1)))

    # compute the rbf kernel
    Kxy = torch.exp( - pairwise_dists / h**2 / 2)
    dxkxy = torch.autograd.grad(Kxy, theta, grad_outputs=torch.rand_like(Kxy), retain_graph=True, create_graph=True)[0]

    repulsive_grad = dxkxy
    attractive_grad = torch.mm(Kxy, score)
    return (attractive_grad + repulsive_grad) / particles.shape[0]


In [215]:
def log_posterior(theta, X_train, y_train):
    '''
    Returns:
        log density: 1-dim: sample_num
    '''
    W = theta  # sample_num x D
    D = theta.shape[1]
    Xs = X_train
    ys = y_train  # batch x 1
    z = torch.matmul(Xs, W.t())  # batch x sample_num
    coff = -z * ys  # batch x sample_num
    coff = torch.clamp(coff, min=-20, max=20)
    log_p_D_given_w = -torch.log(1 + torch.exp(coff)).sum(0)  # sample_num
    log_p_w_given_alpha = -0.5 * torch.sum(W * W, dim=1) + (D / 2) * math.log(1.)   # sample_num
    log_posterior = log_p_D_given_w  + log_p_w_given_alpha
    return log_posterior


In [216]:
X_train = torch.randn(100,10)
y_train = torch.randint(1,(100,)) 
y_train[y_train == 0] = -1
particles = torch.randn(100,10).requires_grad_()

In [217]:
log_posterior = log_posterior(particles, X_train= X_train, y_train= y_train)
score = torch.autograd.grad(log_posterior.sum(), particles, grad_outputs=None, retain_graph=True, create_graph=True)[0]

In [218]:
target_dim = X_train.shape[1]
kernel_model = KernelNN(in_dim=target_dim, out_dim=target_dim)
phi = svgd_grad_with_deepkernel(theta=particles, score=score, model=kernel_model)

In [219]:
phi.sum(-1).mean()

tensor(0.1123, grad_fn=<MeanBackward0>)

In [220]:
phi.mean()

tensor(0.0112, grad_fn=<MeanBackward0>)

## SVGD with deep kernel

In [221]:
class DKSVGD:
    def __init__(self, X_train, y_train, X_test, y_test, batch_size=100, alpha=1.0):
        self.device = device
        self.X_train, self.y_train = X_train, y_train
        self.X_test, self.y_test = X_test, y_test
        self.batch_size = min(batch_size, self.X_train.shape[0])
        self.alpha = 1.0 / alpha
        self.N = X_train.shape[0]
        self.permutation = torch.randperm(self.N)
        self.iter = 0

    def keep_grad(self, output, input, grad_outputs=None):
        return torch.autograd.grad(output, input, grad_outputs=grad_outputs, retain_graph=True, create_graph=True)[0]
    
    def svgd_grad_with_deepkernel(self, theta, score, model, h=-1):
        particles = model(theta)
        delta_x = particles.unsqueeze(0) - particles.unsqueeze(1)  # N x N x D
        pairwise_dists = delta_x.pow(2.0).sum(-1)  # N x N
        if h < 0: 
            # if h < 0, using median trick
            h = torch.median(pairwise_dists)
            h = torch.sqrt(0.5 * h / torch.log(torch.tensor(particles.shape[0]+1)))
        
        # compute the rbf kernel
        Kxy = torch.exp( - pairwise_dists / h**2 / 2)
        dxkxy = torch.autograd.grad(Kxy, theta, grad_outputs=torch.ones_like(Kxy), retain_graph=True, create_graph=True)[0]
        
        repulsive_grad = dxkxy
        attractive_grad = torch.mm(Kxy, score)
        return (attractive_grad + repulsive_grad) / particles.shape[0]
    
    def log_posterior_f(self, theta, iter):
        '''
        Returns:
            log density: 1-dim: sample_num
        '''
        W = theta  # sample_num x D
        D = theta.shape[1]
        batch = [i % self.N for i in range(iter * self.batch_size, (iter + 1) * self.batch_size)]
        ridx = self.permutation[batch]
        Xs = self.X_train[ridx, :]  # batch x D
        ys = self.y_train[ridx]  # batch x 1
        z = torch.matmul(Xs, W.t())  # batch x sample_num
        coff = -z * ys  # batch x sample_num
        coff = torch.clamp(coff, min=-20, max=20)
        log_p_D_given_w = -torch.log(1 + torch.exp(coff)).sum(0)  # sample_num
        log_p_w_given_alpha = -0.5 * self.alpha * torch.sum(W * W, dim=1) + (D / 2) * math.log(1.)   # sample_num
        log_posterior = log_p_D_given_w * self.N / Xs.shape[0] + log_p_w_given_alpha
        return log_posterior
    
    def evaluation(self, theta, X_test, y_test):
        W = theta.cpu().detach() # N x D

        #print('BLR weight: ', torch.mean(W, dim=0))
        # print('BLR weight std: ', torch.std(W, dim=0))
        z = torch.matmul(X_test, W.t())  # batch x smale_num
        coff = -z * y_test  # batch x sample_num
        prob = torch.mean(1. / (1 + torch.exp(coff)), dim=1)
        acc = torch.mean((prob > .5).float())
        llh = torch.mean(torch.log(prob))
        return acc, llh

    def update_particles_and_eval_iters(self, target_dim, hidden_dim, net_lr, initial_particles,
                                        n_iter, step_size, reg_coefficient,auto_corr=0.9, 
                                        fudge_factor=1e-6):
        cur_particles = initial_particles
        historical_grad = 0
        kernel_model = KernelNN(in_dim=target_dim, out_dim=target_dim, hidden_dim=hidden_dim)
        optimizer = optim.Adam(kernel_model.parameters(), lr=net_lr, betas=(0.1, 0.1), amsgrad=True)
        for i in tqdm(range(n_iter)): # progress bar
            kernel_model.train()
            optimizer.zero_grad()
            cur_particles = cur_particles.requires_grad_()
            log_p_target = self.log_posterior_f(cur_particles, i)
            score_p = self.keep_grad(log_p_target.sum(), cur_particles)

            f_x = svgd_grad_with_deepkernel(theta=particles, score=score_p, model=kernel_model)
            stein_loss = f_x.sum(-1).mean()  # estimate of S(p, q)
            loss = -1.0 * stein_loss
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                kernel_model.eval()
                phi = f_x
                if i == 0:
                    historical_grad = historical_grad + torch.multiply(phi, phi)
                else:
                    historical_grad = auto_corr * historical_grad + (1 - auto_corr) * torch.multiply(phi, phi)
                adj_grad = torch.divide(phi, fudge_factor + torch.sqrt(historical_grad))
                # adj_grad = phi
                cur_particles = cur_particles + step_size * adj_grad
                if (i + 1) % 100000 == 0:
                    acc, ll = self.evaluation(cur_particles.detach(), self.X_test, self.y_test)
                    print(f'Iter: {i + 1}, Acc:{acc:.4f}, LL:{ll:.4f}, phi:{f_x.mean():.4f}')
        final_particles = cur_particles

        return final_particles

## Load DateSet Utilities

In [222]:
from load_data import load_data_for_blr

## Train & Test

In [223]:
class Arguments():
    def __init__(self, num_particles, batch_size, num_trials):
        self.num_particles = num_particles
        self.batch_size = batch_size
        self.num_trials = num_trials

### preset

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--num_particles', type=int, default=200, help="particles number")
parser.add_argument('--batch_size', type=int, default=100)
parser.add_argument('--num_trials', type=int, default=20)
args = parser.parse_args()


In [224]:
import random

seed=123
random.seed(seed)
np.random.seed(seed)
# torch.manual_seed(seed)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


In [225]:
# prepare training
import time

datasets = ['covertype', 'w8a', 'a9a', 'bioresponse']
step_size = {'covertype': 0.008, 'w8a': 0.05, 'a9a': 0.03,  'bioresponse' : 0.003}
hidden_dim = {'covertype': 100, 'w8a': 1000, 'a9a': 500, 'bioresponse' : 5000}
depth = {'covertype': 2, 'w8a': 1, 'a9a': 1, 'bioresponse': 1}
max_iters = {'covertype': 10000, 'w8a': 10000, 'a9a': 10000,  'bioresponse' : 10000}
reg_coefficient = {'covertype': 5.0, 'w8a': 1.0, 'a9a': 1.0, 'bioresponse': 1.0}



In [226]:
args = Arguments(10,100,2)
print(f"num_particles = {args.num_particles}||batch_size = {args.batch_size}||num_trials = {args.num_trials}")

num_particles = 10||batch_size = 100||num_trials = 2


In [227]:
# train
start = time.time()

for dataset in datasets[2:3]:
    acc = torch.zeros(args.num_trials)
    ll = torch.zeros(args.num_trials)
    X_train, y_train, X_test, y_test, X_val, y_val = load_data_for_blr(dataset)

    X_train, y_train = X_train.to(device), y_train.to(device)

    D = X_train.shape[1]
    for trial in range(args.num_trials):
        cur = torch.rand(args.num_particles, D).to(device).requires_grad_()
        model = DKSVGD(X_train, y_train, X_test, y_test, batch_size=args.batch_size)
        final_particles = model.update_particles_and_eval_iters(
            target_dim=D, hidden_dim=hidden_dim[dataset], net_lr=1e-4,
            initial_particles=cur, n_iter=max_iters[dataset],
            step_size=step_size[dataset], reg_coefficient=reg_coefficient[dataset])
        acc[trial], ll[trial] = model.evaluation(final_particles, X_test, y_test)

    print(f'Dataset: {dataset}, '
            f'Acc mean: {acc.mean().item():.6f}, Acc std: {acc.std().item():.6f}, '
            f'Likelihood mean: {ll.mean().item():.6f}, Likelihood std: {ll.std().item():.6f}')
end = time.time()
print(f'Run time: {end - start:.4f}s')

  0%|          | 0/10000 [00:00<?, ?it/s]

RuntimeError: mat1 and mat2 shapes cannot be multiplied (100x10 and 124x500)