In [48]:
import torch

def dkl(mu_1, sigma_1, mu_2, sigma_2):
    sigma_1 = torch.pow(sigma_1, 2)
    sigma_2 = torch.pow(sigma_2, 2)
    term_1 = torch.pow(mu_2 - mu_1, 2) / sigma_2 
    term_2 = sigma_1 / sigma_2 
    term_3 = torch.log(term_2)
    print(term_1)
    print(term_2)
    print(term_3)
    out = (.5 * (term_1 + term_2 - term_3 - 1)).sum()
    out = torch.nan_to_num(out)
    print(out)

In [49]:
mu_1 = torch.tensor([0])
mu_2 = torch.tensor([10])
sigma_1 = torch.tensor([2])
sigma_2 = torch.tensor([30])

dkl(mu_1, sigma_1, mu_2, sigma_2)

tensor([0.1111])
tensor([0.0044])
tensor([-5.4161])
tensor(2.2658)


In [50]:
mu_1 = torch.tensor([1])
mu_2 = torch.tensor([11])
sigma_1 = torch.tensor([2])
sigma_2 = torch.tensor([30])

dkl(mu_1, sigma_1, mu_2, sigma_2)

tensor([0.1111])
tensor([0.0044])
tensor([-5.4161])
tensor(2.2658)


In [51]:
mu_1 = torch.tensor([0])
mu_2 = torch.tensor([1])
sigma_1 = torch.tensor([5])
sigma_2 = torch.tensor([5])

dkl(mu_1, sigma_1, mu_2, sigma_2)

tensor([0.0400])
tensor([1.])
tensor([0.])
tensor(0.0200)


In [52]:
mu_1 = torch.tensor([0, 1, 0])
mu_2 = torch.tensor([10, 11, 1])
sigma_1 = torch.tensor([2, 2, 5])
sigma_2 = torch.tensor([30, 30, 5])

dkl(mu_1, sigma_1, mu_2, sigma_2)

tensor([0.1111, 0.1111, 0.0400])
tensor([0.0044, 0.0044, 1.0000])
tensor([-5.4161, -5.4161,  0.0000])
tensor(4.5517)


In [53]:
import argparse

from utils import default_args

def get_title(arg_dict):
    parser = argparse.ArgumentParser()
    for arg in vars(default_args):
        if(arg in arg_dict.keys()): parser.add_argument('--{}'.format(arg), default = arg_dict[arg])
        else:                       parser.add_argument('--{}'.format(arg), default = getattr(default_args, arg))
    args, _ = parser.parse_known_args()
    title = ""
    first = True
    for arg in vars(args):
        if(getattr(args, arg) != getattr(default_args, arg)):
            if(not first): title += "_"
            title += "{}_{}".format(arg, getattr(args, arg)) ; first = False
    if(len(title) == 0): title = "default"
    print(arg_dict, title)
    return(args, title)

with_bias  = default_args
no_bias, _ = get_title({"bias" : False})

{'bias': False} default


In [54]:
from torch import nn
from torchinfo import summary as torch_summary
from blitz.modules import BayesianLinear

from utils import device, default_args, init_weights, weights

class Bayes_Forward(nn.Module):
    
    def __init__(self, args):
        super(Bayes_Forward, self).__init__()
        
        self.pos_out = nn.Sequential(
            BayesianLinear(8, args.hidden, bias = args.bias),
            BayesianLinear(args.hidden, 6, bias = args.bias))
        
        self.pos_out.apply(init_weights)
        self.to(device)
        
    def forward(self, pos, action):
        x = torch.cat([pos, action], -1)
        x = self.pos_out(x).to("cpu")
        return(x) 
    
bayes_forward = Bayes_Forward(with_bias)
weights_mu, weights_sigma, bias_mu, bias_sigma = weights(bayes_forward)
print(weights_mu.shape, weights_sigma.shape, bias_mu.shape, bias_sigma.shape)

bayes_forward = Bayes_Forward(no_bias)
weights_mu, weights_sigma, bias_mu, bias_sigma = weights(bayes_forward)
print(weights_mu.shape, weights_sigma.shape, bias_mu.shape, bias_sigma.shape)

(default_args.hidden * 8 + default_args.hidden * 6, default_args.hidden + 6)



bayes_forward = Bayes_Forward(default_args)

print("\n\n")
print(bayes_forward)
print()
print(torch_summary(bayes_forward, ((1,6), (1,2))))

AttributeError: 'Namespace' object has no attribute 'bias'

In [55]:
class DKL_Guesser(nn.Module):
    
    def __init__(self, args):
        super(DKL_Guesser, self).__init__()
        
        self.error_in = nn.Linear(1, args.dkl_hidden)
        self.w_mu     = nn.Linear(args.hidden * 8 + args.hidden * 6, args.dkl_hidden)
        self.w_sigma  = nn.Linear(args.hidden * 8 + args.hidden * 6, args.dkl_hidden)
        self.b_mu     = nn.Linear(args.hidden + 6, args.dkl_hidden)
        self.b_sigma  = nn.Linear(args.hidden + 6, args.dkl_hidden)
        self.DKL_out  = nn.Linear(args.dkl_hidden * 5, 1)
        
    def forward(self, errors, weights_mu, weights_sigma, bias_mu, bias_sigma):
        errors_shape = errors.shape
        errors  = self.error_in(errors) 
        w_mu    = self.w_mu(weights_mu)
        w_sigma = self.w_sigma(weights_sigma)
        b_mu    = self.b_mu(bias_mu)
        b_sigma = self.b_sigma(bias_sigma) 
        print(errors.shape, w_mu.shape, w_sigma.shape, b_mu.shape, b_sigma.shape)
        w_mu    = torch.tile(w_mu, (1, errors.shape[1], errors.shape[2], 1))
        w_sigma = torch.tile(w_sigma, (1, errors.shape[1], errors.shape[2], 1))
        b_mu    = torch.tile(b_mu, (1, errors.shape[1], errors.shape[2], 1))
        b_sigma = torch.tile(b_sigma, (1, errors.shape[1], errors.shape[2], 1))
        print(errors.shape, w_mu.shape, w_sigma.shape, b_mu.shape, b_sigma.shape)
        x = torch.cat([errors, w_mu, w_sigma, b_mu, b_sigma], -1)
        x = self.DKL_out(x)
        return(x)
    
errors        = torch.ones((2, 8, 10, 1))
weights_mu    = torch.ones((2, 1, 1, 448))
weights_sigma = torch.ones((2, 1, 1, 448))
bias_mu       = torch.ones((2, 1, 1, 38))
bias_sigma    = torch.ones((2, 1, 1, 38))

dkl_guesser = DKL_Guesser(default_args)

print("\n\n")
print(dkl_guesser)
print()
print(torch_summary(dkl_guesser, (errors.shape, weights_mu.shape, weights_sigma.shape, bias_mu.shape, bias_sigma.shape)))

AttributeError: 'Namespace' object has no attribute 'dkl_hidden'