In [30]:
class Prior:
    """Base class for priors over a model's variational weight parameters.

    Subclasses override ``kl`` and ``init_params_for``. The base versions
    raise ``NotImplementedError`` rather than silently returning ``None``,
    so a missing override fails loudly instead of producing ``None`` deep
    inside a loss computation.
    """

    def __init__(self):
        pass

    def kl(self, q=None, p=None):
        """KL(q, p), where q is Gaussian and p is from the given prior family.

        ``q`` and ``p`` default to ``None`` only so that the historical
        zero-argument call signature is still accepted; subclasses take both.

        Args:
            q: parameters of the Gaussian approximate posterior.
            p: parameters of the prior distribution.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement kl(q, p)")

    def init_params_for(self, model, init_mu=0.0, init_scale=0.1, init_const=-3.0):
        """Initialise ``model``'s variational parameters for this prior family.

        Args:
            model: object holding the parameters to initialise.
            init_mu: location used when drawing initial means.
            init_scale: spread used when drawing initial means.
            init_const: constant used for the scale (rho) parameters.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement init_params_for(model, ...)")

In [31]:
class GaussianPrior(Prior):
    """Gaussian prior: closed-form KL between two diagonal Gaussians."""

    def __init__(self):
        super().__init__()

    def kl(self, q, p):
        """Elementwise KL(q || p) for Gaussians q = N(mu_q, s_q^2), p = N(mu_p, s_p^2).

        Args:
            q: mapping with ``'mu'`` and ``'sigma'`` of the approximate posterior.
            p: mapping with ``'mu'`` and ``'sigma'`` of the prior.

        Returns:
            ``0.5 * [log(s_p^2 / s_q^2) + s_q^2 / s_p^2 + ((mu_q - mu_p) / s_p)^2 - 1]``,
            broadcast over the inputs and not reduced.
        """
        variance_ratio = (q['sigma'] / p['sigma']) ** 2
        standardized_gap = (q['mu'] - p['mu']) / p['sigma']
        # log(s_p^2 / s_q^2), written as the log of the inverted variance ratio.
        log_term = torch.log(1. / variance_ratio)
        quadratic_term = variance_ratio + standardized_gap ** 2
        return 0.5 * (log_term + quadratic_term - 1)

    def init_params_for(self, model, init_mu=0.0, init_scale=0.1, init_const=-3.0):
        """Initialise ``model``'s variational parameters in place.

        Means (``W_mu``, ``b_mu``) are drawn from N(init_mu, init_scale^2).
        The rho parameters (``W_rho``, ``b_rho``) are filled with ``init_const``.
        """
        for mean_param in (model.W_mu, model.b_mu):
            nn.init.normal_(mean_param, mean=init_mu, std=init_scale)
        for rho_param in (model.W_rho, model.b_rho):
            rho_param.data.fill_(init_const)

In [32]:
class ExponentialPrior(Prior):
    """Prior built on an exponential distribution.

    The rate is encoded as ``p['mu'] = 1 / rate``; for an exponential
    distribution the mean and standard deviation both equal ``1 / rate``.
    """

    def __init__(self):
        super().__init__()

    def kl(self, q, p):
        """Return the divergence term between Gaussian ``q`` and exponential ``p``.

        Args:
            q: mapping with key ``'sigma'`` (Gaussian std). ``'mu'`` is not read.
            p: mapping with key ``'mu'`` holding ``1 / rate``. Tensors or
                plain Python numbers are accepted.

        Returns:
            ``E_q[log q(X)] - E_p[log p(X)]`` elementwise, i.e. ``H(p) - H(q)``.

        NOTE(review): this is not KL(q || p). The cross term substitutes
        ``E[X] = 1 / rate`` (the mean under p), not the mean under q, so the
        result does not depend on ``q['mu']``. A true KL from a Gaussian
        (support R) to an exponential (support [0, inf)) is infinite. The
        numerics are kept unchanged to preserve training behaviour; confirm intent.
        """
        # as_tensor lets plain floats through (torch.log rejects Python floats);
        # an existing tensor is returned unchanged, so autograd is unaffected.
        p_rate = 1. / torch.as_tensor(p['mu'])  # mu = sigma = 1/rate
        q_sigma = torch.as_tensor(q['sigma'])
        # E_p[log p(X)] = log(lambda) - lambda E_p[X] = log(lambda) - 1
        logp = torch.log(p_rate) - 1
        # E_q[log q(X)] = -0.5 * [(1/sigma^2)Var(X) + log(2pi sigma^2)] = -0.5 * [1 + log(2pi sigma^2)]
        logq = -0.5 * (1. + torch.log(2 * np.pi * (q_sigma ** 2)))
        return logq - logp

    def init_params_for(self, model, init_mu=0.0, init_scale=1., init_const=-3.0):
        """Initialise ``model``'s variational parameters in place.

        Each mean entry (``W_mu``, ``b_mu``) is an Exponential(rate = 1 / init_scale)
        magnitude times a random sign. The rho parameters (``W_rho``, ``b_rho``)
        are filled with ``init_const``.

        Args:
            model: object exposing ``W_mu``, ``b_mu``, ``W_rho`` and ``b_rho`` tensors.
            init_mu: unused; kept for interface compatibility with ``Prior``.
            init_scale: mean of the exponential magnitude distribution.
            init_const: constant written into ``W_rho`` and ``b_rho``.
        """
        exp_dist = Exponential(1. / init_scale)  # Exp(rate)
        with torch.no_grad():
            for param in (model.W_mu, model.b_mu):
                # The sample is a CPU tensor in the default dtype. Convert it to the
                # parameter's device/dtype and write in place: rebinding `.data`
                # failed on CUDA (CPU * CUDA product) and replaced the parameter object.
                magnitude = exp_dist.sample(param.shape)
                sign = torch.sign(torch.randn_like(param))
                param.copy_(magnitude.to(sign) * sign)
        model.W_rho.data.fill_(init_const)
        model.b_rho.data.fill_(init_const)