In [1]:
import torch
from torch import nn
from torch.distributions.normal import Normal
from torch.distributions.kl import kl_divergence

ModuleNotFoundError: No module named 'torch'

In [None]:
def sample_new_weight(module, input):
    m = module

    if ('bias' in module._parameters.keys()) and (module.bias is not None):
        has_bias = True
    else:
        has_bias = False

    # Weight
    # std = module.weight_logvar.mul(1/2).exp()
    # eps = torch.randn_like(std)
    # setattr(module, "weight", module.weight_mean + eps*std)
    setattr(module, "weight", Normal(m.weight_mean, m.weight_logvar.mul(1/2).exp()).rsample())

    # Bias
    # std = torch.exp(0.5 * module.bias_logvar)
    # eps = torch.radn_like(std)
    # setattr(module, "bias", module.bias_mean + eps*std)
    if has_bias:
        setattr(module, "bias", Normal(m.bias_mean, m.bias_logvar.mul(1/2).exp()).rsample())
    else:
        setattr(module, "bias", None)

    return None


def make_variational_linear(module, name=None):
    m, has_bias = module, module.bias is not None
    setattr(module, "name", name)

    del m._parameters['weight']

    if has_bias:
        del m._parameters['bias']

    weight_mean_param   = nn.Parameter(torch.Tensor(m.out_features, m.in_features))
    weight_logvar_param = nn.Parameter(torch.Tensor(m.out_features, m.in_features))

    m.register_parameter('weight_mean', weight_mean_param)
    m.register_parameter('weight_logvar', weight_logvar_param)

    variance = 2 / (m.out_features + m.in_features)
    nn.init.normal_(weight_mean_param, 0.0, std = variance**(1/2))
    nn.init.constant_(weight_logvar_param, -5) # WAS -10

    if has_bias:
        bias_mean_param    = nn.Parameter(torch.Tensor(m.out_features))
        bias_logvar_param  = nn.Parameter(torch.Tensor(m.out_features))

        m.register_parameter('bias_mean', bias_mean_param)
        m.register_parameter('bias_logvar', bias_logvar_param)

        nn.init.constant_(bias_mean_param, 0.1)
        nn.init.constant_(bias_logvar_param, -5) # WAS -10


    m.register_forward_pre_hook(sample_new_weight)
    sample_new_weight(m, None)

    return m


In [None]:
variational_layers = []

fc3 = nn.Linear(2, 32)
fc3 = make_variational_linear(fc3)
variational_layers.append(fc3)

fc3h = nn.Linear(32, 64)
fc3h = make_variational_linear(fc3h)
variational_layers.append(fc3h)

In [None]:
type(variational_layers[0])

torch.nn.modules.linear.Linear

In [None]:
variational_layers

[Linear(in_features=2, out_features=32, bias=False),
 Linear(in_features=32, out_features=64, bias=False)]

In [None]:
variational_layers[0]._parameters.keys()

odict_keys(['weight_mean', 'weight_logvar', 'bias_mean', 'bias_logvar'])

In [None]:
KLP = 0
for l in variational_layers:
    bias_mu, bias_std = l.bias_mean, l.bias_logvar.mul(1/2).exp()
    q_bias = Normal(bias_mu, bias_std)
    p_bias = Normal(torch.zeros_like(bias_mu), torch.ones_like(bias_std))
    KLP += kl_divergence(q_bias, p_bias).sum()

In [None]:
KLP

tensor(192.8034, grad_fn=<AddBackward0>)

<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=35838f82-2ce6-4453-9bd2-2d87a43af151' target="_blank">
 </img>
Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>