<a href="https://colab.research.google.com/github/andriygav/BayesianDistilation/blob/master/code/main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [45]:
import torch

import numpy as np

from tqdm.notebook import tqdm

In [3]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

# Модель

In [368]:
class BayesianNetworkFNN(torch.nn.Module):
    r'''
    Fully connected network with a diagonal Gaussian posterior and prior
    over its flattened weight vector.

    The parameters of ``self.network`` act as the posterior mean. The
    tensors ``log_sigma_posterior``, ``log_sigma_prior`` and ``mu_prior``
    hold one entry per network weight. ``forward`` evaluates the network
    at the posterior mean; no weight sampling is performed.
    '''
    @property
    def device(self):
        r'''
        Device of the first parameter registered on the module.
        '''
        return next(self.parameters()).device
    
    def __init__(self, 
                 input_dim=10, 
                 output_dim=1, 
                 layers=None,
                 mu_prior=None,
                 sigma_prior=None):
        r'''
        Initialise Neural Network for Bayesian inference

        :param input_dim: size of feature space
        :type input_dim: int
        :param output_dim: size of target space
        :type output_dim: int
        :param layers: sizes of each hidden layer for fully connected 
            neural network; ``None`` (default) means no hidden layers
        :type layers: list[int] or tuple[int] or None
        :param mu_prior: parameter of prior for :math:`\mathsf{E}w`; 
            when given, the network weights are initialised with a copy 
            of it, otherwise the prior mean is a copy of the initial weights
        :type mu_prior: array
        :param sigma_prior: parameter of prior for :math:`\mathsf{V}w`; 
            its logarithm is stored in ``log_sigma_prior``
        :type sigma_prior: array
        '''
        super(BayesianNetworkFNN, self).__init__()

        # ``None`` replaces the former mutable default; any sequence of
        # sizes (list or tuple) is accepted.
        hidden = [] if layers is None else list(layers)
        sizes = [input_dim] + hidden + [output_dim]

        self.network = torch.nn.Sequential()
        last = len(sizes) - 1
        for i in range(1, len(sizes)):
            self.network.add_module('layer{}'.format(i), 
                                    torch.nn.Linear(sizes[i-1], 
                                                    sizes[i]))
            # No activation after the output layer.
            if i != last:
                self.network.add_module('relu{}'.format(i), 
                                        torch.nn.LeakyReLU())
        weight = self.get_weight(self.network)
        
        # Registration order (posterior log-sigma, prior log-sigma, prior
        # mean) is kept so parameters()/state_dict ordering is unchanged.
        self.log_sigma_posterior = torch.nn.Parameter(
            torch.zeros_like(weight.detach()).float(), requires_grad=True)

        if sigma_prior is None:
            log_sigma_prior = torch.zeros_like(weight.detach()).float()
        else:
            log_sigma_prior = torch.log(self._as_float_tensor(sigma_prior))
        self.log_sigma_prior = torch.nn.Parameter(
            log_sigma_prior, requires_grad=True)

        if mu_prior is None:
            # Independent copy of the freshly initialised weights.
            self.mu_prior = torch.nn.Parameter(
                weight.detach().clone(), requires_grad=True)
        else:
            self.mu_prior = torch.nn.Parameter(
                self._as_float_tensor(mu_prior), requires_grad=True)
            
            self.set_weight(self.network, weight, self.mu_prior.data)

        assert len(self.log_sigma_prior) == len(weight), \
          'length of sigma_prior and weight must bu equels but {} != {}'.format(
              len(self.log_sigma_prior), len(weight)
          )
        assert len(self.mu_prior) == len(weight), \
          'length of mu_prior and weight must bu equels but {} != {}'.format(
              len(self.mu_prior), len(weight)
          )

    @staticmethod
    def _as_float_tensor(values):
        r'''
        Convert an array-like or tensor into an independent float tensor.

        The result never shares memory with ``values``.
        '''
        return torch.as_tensor(values).detach().clone().float()

    def posterior_parameters(self):
        r'''
        Yield the posterior parameters: ``log_sigma_posterior`` followed 
        by every parameter of the network.
        '''
        yield self.log_sigma_posterior
        for param in self.network.parameters():
            yield param

    def prior_parameters(self):
        r'''
        Yield the prior parameters: ``log_sigma_prior`` and ``mu_prior``.
        '''
        yield self.log_sigma_prior
        yield self.mu_prior

    @staticmethod  
    def set_weight(seq, weight, new_weight):
        r'''
        Overwrite the parameters of ``seq`` with values from a flat vector.

        :param seq: module whose parameters are replaced, 
            in ``seq.parameters()`` order
        :param weight: current flat weight vector, 
            used only for the length check
        :type weight: Tensor
        :param new_weight: flat vector with the new values
        :type new_weight: Tensor
        '''
        assert len(weight) == len(new_weight)

        offset = 0
        for param in seq.parameters():
            size = param.numel()
            chunk = new_weight[offset:offset+size]
            # Clone so the parameter owns its storage. Assigning a view
            # would alias ``new_weight`` and in-place optimiser steps on
            # the network would silently change it as well.
            param.data = chunk.detach().clone().view_as(param)
            
            offset += size

    @property
    def weight(self):
        r'''
        Flat vector of all network parameters (the posterior mean).
        '''
        return self.get_weight(self.network)

    @staticmethod
    def get_weight(seq, requires_grad=True):
        r'''
        Concatenate all parameters of ``seq`` into one flat vector.

        :param seq: module whose parameters are flattened
        :param requires_grad: when ``False`` the vector is built under 
            ``torch.no_grad()`` and is detached from the parameters
        :type requires_grad: bool
        :return: 1-D tensor in ``seq.parameters()`` order
        :rtype: Tensor
        '''
        def flatten():
            return torch.cat([param.reshape(-1) for param in seq.parameters()])

        if requires_grad:
            return flatten()
        with torch.no_grad():
            return flatten()

    def forward(self, x_batch):
        r'''
        Model inference for one batch

        :param x_batch: 
            input tensor of shape `batch_size` :math:`\times` `input_dim`
        :type x_batch: Tensor
        :return: 
            output tensor of shape `batch_size` :math:`\times` `output_dim`
        :rtype: Tensor
        '''

        return self.network(x_batch)

    def _D_KL_loss(self):
        r'''
        The method return KL divergence of the posterior from the prior. 
        The distributions are assumed to be normal with diagonal covariance.

        The formula uses ``exp(log_sigma_*)`` directly as the variance 
        terms, i.e. both ``log_sigma`` tensors act as log-variances.

        :return: scalar tensor with the divergence value
        :rtype: Tensor
        '''

        mu_posterior = self.weight

        # Trace term: sum of posterior variance over prior variance.
        D_KL_1 = 0.5*torch.exp(self.log_sigma_posterior \
                               -self.log_sigma_prior).sum()
        # Squared distance between the means, scaled by the prior variance.
        D_KL_2 = (0.5*torch.exp(-self.log_sigma_prior) \
                     *(self.mu_prior-mu_posterior)**2).sum()
        # Log-determinant ratio of the two diagonal covariances.
        D_KL_3 = 0.5*(self.log_sigma_prior.sum() \
                      - self.log_sigma_posterior.sum())
        # Constant term: minus half the number of weights.
        D_KL_4 = -1*0.5*len(mu_posterior)

        return D_KL_1 + D_KL_2 + D_KL_3 + D_KL_4

    def loss(self, likelihood, alpha=1.):
        r'''
        Compute completed loss for bayesian model.

        :param likelihood: data likelihood
        :type likelihood: Tensor
        :param alpha: multiplier applied to the KL term
        :type alpha: float
        :return: ``alpha`` times the KL term minus ``likelihood``
        :rtype: Tensor
        '''
        return alpha*self._D_KL_loss() - likelihood

# Данные

In [369]:
# Fixed seed so the synthetic regression problem is reproducible.
np.random.seed(0)

# l samples with n features each, and the weights of a random linear model.
l, n = 1024, 10
X = np.random.standard_normal((l, n))
w = np.random.standard_normal(n)

# Targets: linear response plus standard normal noise scaled by 0.1.
y = X @ w + 0.1 * np.random.standard_normal(l)

In [370]:
# First 900 rows are used for training, the remaining rows for testing.
X_train_tr, X_test_tr = (
    torch.tensor(part).float() for part in (X[:900], X[900:])
)

# Targets become float column vectors of shape (rows, 1).
y_train_tr, y_test_tr = (
    torch.tensor(part).view(-1, 1).float() for part in (y[:900], y[900:])
)


In [371]:
# Pair features with targets for the train and test splits.
train_dataset, test_dataset = (
    torch.utils.data.TensorDataset(features, targets)
    for features, targets in ((X_train_tr, y_train_tr),
                              (X_test_tr, y_test_tr))
)

In [372]:
def test(model, dataset, batch_size = 128):
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, pin_memory=True)

    real = []
    pred = []
    for x_batch, y_batch in dataloader:
        x_batch = x_batch.to(model.device)
        y_batch = y_batch.to(model.device)
        output = model(x_batch)

        real += y_batch.view(-1).detach().cpu().numpy().tolist()
        pred += output.view(-1).detach().cpu().numpy().tolist()

    return ((np.array(real) - np.array(pred))**2).mean()

In [382]:
# Build the model with default sizes and move it to the selected device;
# a single assignment keeps the cell from echoing the module.
model = BayesianNetworkFNN().to(device)

In [383]:
# Base learning rate; the prior parameters use a rate 100 times smaller.
lr = 0.001
loss_function = torch.nn.MSELoss()
optimiser_posterior = torch.optim.Adam(
    model.posterior_parameters(),
    lr=lr,
)
optimiser_prior = torch.optim.Adam(
    model.prior_parameters(),
    lr=0.01*lr,
)

In [384]:
# Mean squared error of the model on the training dataset.
test(model, train_dataset)

17.00112491541768

In [390]:
# Shuffled mini-batches of 8 samples; host memory is pinned only when the
# model lives on a CUDA device, where it speeds up the copies.
dataloader = torch.utils.data.DataLoader(train_dataset, 
                                         batch_size=8, 
                                         pin_memory=model.device.type == 'cuda', 
                                         shuffle=True)

iterator = tqdm(range(1000))
for epoch in iterator:
    for x_batch, y_batch in dataloader:
        x_batch = x_batch.to(model.device)
        y_batch = y_batch.to(model.device)

        # Alternate the updates: first the posterior parameters, then the
        # prior parameters. Each step uses a fresh forward pass, so the
        # prior step sees the already-updated posterior.
        for optimiser in (optimiser_posterior, optimiser_prior):
            optimiser.zero_grad()
            output = model(x_batch)

            # NOTE(review): MSELoss averages over elements by default, so
            # the extra division by the batch size scales the term down
            # once more. Kept as is; confirm the intended scaling.
            likelihood = -1*loss_function(output, y_batch)/len(x_batch)
            loss = model.loss(likelihood)

            loss.backward()
            optimiser.step()

    # Show the test-set error in the progress bar after every epoch.
    iterator.set_postfix({'test': test(model, test_dataset)})

HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




In [391]:
# Mean squared error on the train and test datasets, shown as a pair.
test(model, train_dataset), test(model, test_dataset)

(0.6924054174011397, 0.6591168450820817)

In [392]:
# Current value of the KL term between the model's posterior and prior.
model._D_KL_loss()

tensor(0.0807, device='cuda:0', grad_fn=<AddBackward0>)

In [393]:
# Likelihood tensor left over from the last executed training step.
likelihood

tensor(-0.0162, device='cuda:0', grad_fn=<DivBackward0>)

In [395]:
# Inspect the model's prior log-sigma parameter.
model.log_sigma_prior

Parameter containing:
tensor([0.8753, 1.3614, 0.5844, 0.5627, 0.8744, 1.3111, 0.5383, 1.1906, 1.3039,
        1.2961, 0.5340], device='cuda:0', requires_grad=True)