
Commit

Updated README and code.
clouizos committed Nov 2, 2017
1 parent bf42062 commit 7bb140f
Showing 5 changed files with 125 additions and 92 deletions.
128 changes: 71 additions & 57 deletions BayesianLayer.py → BayesianLayers.py
@@ -5,7 +5,7 @@
Variational Dropout version of linear and convolutional layers
Karen Ullrich, Oct 2017
Karen Ullrich, Christos Louizos, Oct 2017
"""

import math
@@ -19,7 +19,7 @@
from torch.nn.modules import utils


def reparameterise(mu, logvar, cuda=True, sampling=True):
def reparametrize(mu, logvar, cuda=False, sampling=True):
if sampling:
std = logvar.mul(0.5).exp_()
if cuda:
@@ -36,23 +36,23 @@ def reparameterise(mu, logvar, cuda=True, sampling=True):
# LINEAR LAYER
# -------------------------------------------------------

class LinearGroupVD(Module):
"""Linear Group Variational Dropout Layer
class LinearGroupNJ(Module):
"""Fully Connected Group Normal-Jeffrey's layer (aka Group Variational Dropout).
References:
[1] Kingma, Diederik P., Tim Salimans, and Max Welling. "Variational dropout and the local reparameterization trick." NIPS (2015).
[2] Molchanov, Dmitry, Arsenii Ashukha, and Dmitry Vetrov. "Variational Dropout Sparsifies Deep Neural Networks." ICML (2017).
[3] Louizos, Christos, Karen Ullrich, and Max Welling. "Bayesian Compression for Deep Learning." NIPS (2017).
"""

def __init__(self, in_features, out_features, cuda=True, init_weight=None, init_bias=None, clip_var=None):
def __init__(self, in_features, out_features, cuda=False, init_weight=None, init_bias=None, clip_var=None):

super(LinearGroupVD, self).__init__()
super(LinearGroupNJ, self).__init__()
self.cuda = cuda
self.in_features = in_features
self.out_features = out_features
self.clip_var = clip_var
self.deterministic = False # flag is used for compressed inference
self.deterministic = False # flag is used for compressed inference
# trainable params according to Eq.(6)
# dropout params
self.z_mu = Parameter(torch.Tensor(in_features))
@@ -95,14 +95,17 @@ def reset_parameters(self, init_weight, init_bias):
self.weight_logvar.data.normal_(-9, 1e-2)
self.bias_logvar.data.normal_(-9, 1e-2)

def clip_variances(self):
if self.clip_var:
self.weight_logvar.data.clamp_(max=math.log(self.clip_var))
self.bias_logvar.data.clamp_(max=math.log(self.clip_var))

def get_log_dropout_rates(self):
log_alpha = self.z_logvar - torch.log(self.z_mu.pow(2) + self.epsilon)
return log_alpha

def compute_posterior_params(self):
weight_var, z_var = self.weight_logvar.exp(), self.z_logvar.exp()
if self.clip_var:
weight_var = torch.clamp(weight_var, 0., self.clip_var)
self.post_weight_var = self.z_mu.pow(2) * weight_var + z_var * self.weight_mu.pow(2) + z_var * weight_var
self.post_weight_mu = self.weight_mu * self.z_mu
return self.post_weight_mu, self.post_weight_var
@@ -115,20 +118,16 @@ def forward(self, x):
batch_size = x.size()[0]
# compute z
# note that we reparametrise according to [2] Eq. (11) (not [1])
z = reparameterise(self.z_mu.repeat(batch_size, 1), self.z_logvar.repeat(batch_size, 1), sampling=self.training,
cuda=self.cuda)
z = reparametrize(self.z_mu.repeat(batch_size, 1), self.z_logvar.repeat(batch_size, 1), sampling=self.training,
cuda=self.cuda)

# apply local reparametrisation trick see [1] Eq. (6)
# to the parametrisation given in [3] Eq. (6)
xz = x * z
mu_activations = F.linear(xz, self.weight_mu, self.bias_mu)
var_activations = F.linear(xz.pow(2), self.weight_logvar.exp(), self.bias_logvar.exp())

var_weights = self.weight_logvar.exp()
if self.clip_var:
var_weights = torch.clamp(var_weights, 0., self.clip_var)
var_activations = F.linear(xz.pow(2), var_weights, self.bias_logvar.exp())

return reparameterise(mu_activations, var_activations.log(), sampling=self.training, cuda=self.cuda)
return reparametrize(mu_activations, var_activations.log(), sampling=self.training, cuda=self.cuda)

def kl_divergence(self):
# KL(q(z)||p(z))
@@ -158,10 +157,17 @@ def __repr__(self):
# CONVOLUTIONAL LAYER
# -------------------------------------------------------

class _ConvNdGroupVD(Module):
def __init__(self, in_channels, out_channels, kernel_size, stride,
padding, dilation, transposed, output_padding, groups, bias, init_weight, init_bias, cuda=True):
super(_ConvNdGroupVD, self).__init__()
class _ConvNdGroupNJ(Module):
"""Convolutional Group Normal-Jeffrey's layers (aka Group Variational Dropout).
References:
[1] Kingma, Diederik P., Tim Salimans, and Max Welling. "Variational dropout and the local reparameterization trick." NIPS (2015).
[2] Molchanov, Dmitry, Arsenii Ashukha, and Dmitry Vetrov. "Variational Dropout Sparsifies Deep Neural Networks." ICML (2017).
[3] Louizos, Christos, Karen Ullrich, and Max Welling. "Bayesian Compression for Deep Learning." NIPS (2017).
"""
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding,
groups, bias, init_weight, init_bias, cuda=False, clip_var=None):
super(_ConvNdGroupNJ, self).__init__()
if in_channels % groups != 0:
raise ValueError('in_channels must be divisible by groups')
if out_channels % groups != 0:
@@ -177,6 +183,8 @@ def __init__(self, in_channels, out_channels, kernel_size, stride,
self.groups = groups

self.cuda = cuda
self.clip_var = clip_var
self.deterministic = False # flag is used for compressed inference

if transposed:
self.weight_mu = Parameter(torch.Tensor(
@@ -192,8 +200,8 @@ def __init__(self, in_channels, out_channels, kernel_size, stride,
self.bias_mu = Parameter(torch.Tensor(out_channels))
self.bias_logvar = Parameter(torch.Tensor(out_channels))

self.z_mu = Parameter(torch.Tensor(self.out_channels, 1, 1))
self.z_logvar = Parameter(torch.Tensor(self.out_channels, 1, 1))
self.z_mu = Parameter(torch.Tensor(self.out_channels))
self.z_logvar = Parameter(torch.Tensor(self.out_channels))

self.reset_parameters(init_weight, init_bias)

@@ -229,6 +237,11 @@ def reset_parameters(self, init_weight, init_bias):
self.weight_logvar.data.normal_(-9, 1e-2)
self.bias_logvar.data.normal_(-9, 1e-2)

def clip_variances(self):
if self.clip_var:
self.weight_logvar.data.clamp_(max=math.log(self.clip_var))
self.bias_logvar.data.clamp_(max=math.log(self.clip_var))

def get_log_dropout_rates(self):
log_alpha = self.z_logvar - torch.log(self.z_mu.pow(2) + self.epsilon)
return log_alpha
@@ -274,20 +287,20 @@ def __repr__(self):
return s.format(name=self.__class__.__name__, **self.__dict__)


class Conv1dGroupVD(_ConvNdGroupVD):
class Conv1dGroupNJ(_ConvNdGroupNJ):
r"""
"""

def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True, cuda=True, init_weight=None, init_bias=None):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,
cuda=False, init_weight=None, init_bias=None, clip_var=None):
kernel_size = utils._single(kernel_size)
stride = utils._single(stride)
padding = utils._single(padding)
dilation = utils._single(dilation)

super(Conv1dGroupVD, self).__init__(
super(Conv1dGroupNJ, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
False, utils._pair(0), groups, bias, init_weight, init_bias, cuda)
False, utils._pair(0), groups, bias, init_weight, init_bias, cuda, clip_var)

def forward(self, x):
if self.deterministic:
@@ -299,37 +312,37 @@ def forward(self, x):
mu_activations = F.conv1d(x, self.weight_mu, self.bias_mu, self.stride,
self.padding, self.dilation, self.groups)

var_weights = self.weight_logvar.exp()
var_activations = F.conv1d(x.pow(2), var_weights, self.bias_logvar.exp(), self.stride,
var_activations = F.conv1d(x.pow(2), self.weight_logvar.exp(), self.bias_logvar.exp(), self.stride,
self.padding, self.dilation, self.groups)
# compute z
# note that we reparametrise according to [2] Eq. (11) (not [1])
z = reparameterise(self.z_mu.repeat(batch_size, 1, 1, 1), self.z_logvar.repeat(batch_size, 1, 1, 1),
sampling=self.training,
cuda=self.cuda)
return reparameterise(mu_activations * z, (var_activations * z.pow(2)).log(), sampling=self.training,
cuda=self.cuda)
z = reparametrize(self.z_mu.repeat(batch_size, 1, 1), self.z_logvar.repeat(batch_size, 1, 1),
sampling=self.training, cuda=self.cuda)
z = z[:, :, None]

return reparametrize(mu_activations * z, (var_activations * z.pow(2)).log(), sampling=self.training,
cuda=self.cuda)

def __repr__(self):
return self.__class__.__name__ + ' (' \
+ str(self.in_channels) + ' -> ' \
+ str(self.out_channels) + ')'


class Conv2dGroupVD(_ConvNdGroupVD):
class Conv2dGroupNJ(_ConvNdGroupNJ):
r"""
"""

def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True, cuda=True, init_weight=None, init_bias=None):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,
cuda=False, init_weight=None, init_bias=None, clip_var=None):
kernel_size = utils._pair(kernel_size)
stride = utils._pair(stride)
padding = utils._pair(padding)
dilation = utils._pair(dilation)

super(Conv2dGroupVD, self).__init__(
super(Conv2dGroupNJ, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
False, utils._pair(0), groups, bias, init_weight, init_bias, cuda)
False, utils._pair(0), groups, bias, init_weight, init_bias, cuda, clip_var)

def forward(self, x):
if self.deterministic:
@@ -341,37 +354,37 @@ def forward(self, x):
mu_activations = F.conv2d(x, self.weight_mu, self.bias_mu, self.stride,
self.padding, self.dilation, self.groups)

var_weights = self.weight_logvar.exp()
var_activations = F.conv2d(x.pow(2), var_weights, self.bias_logvar.exp(), self.stride,
var_activations = F.conv2d(x.pow(2), self.weight_logvar.exp(), self.bias_logvar.exp(), self.stride,
self.padding, self.dilation, self.groups)
# compute z
# note that we reparametrise according to [2] Eq. (11) (not [1])
z = reparameterise(self.z_mu.repeat(batch_size, 1, 1, 1), self.z_logvar.repeat(batch_size, 1, 1, 1),
sampling=self.training,
cuda=self.cuda)
return reparameterise(mu_activations * z, (var_activations * z.pow(2)).log(), sampling=self.training,
cuda=self.cuda)
z = reparametrize(self.z_mu.repeat(batch_size, 1), self.z_logvar.repeat(batch_size, 1),
sampling=self.training, cuda=self.cuda)
z = z[:, :, None, None]

return reparametrize(mu_activations * z, (var_activations * z.pow(2)).log(), sampling=self.training,
cuda=self.cuda)

def __repr__(self):
return self.__class__.__name__ + ' (' \
+ str(self.in_channels) + ' -> ' \
+ str(self.out_channels) + ')'


class Conv3dGroupVD(_ConvNdGroupVD):
class Conv3dGroupNJ(_ConvNdGroupNJ):
r"""
"""

def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True, cuda=True, init_weight=None, init_bias=None):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,
cuda=False, init_weight=None, init_bias=None, clip_var=None):
kernel_size = utils._triple(kernel_size)
stride = utils._triple(stride)
padding = utils._triple(padding)
dilation = utils._triple(dilation)

super(Conv3dGroupVD, self).__init__(
super(Conv3dGroupNJ, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
False, utils._pair(0), groups, bias, init_weight, init_bias, cuda)
False, utils._pair(0), groups, bias, init_weight, init_bias, cuda, clip_var)

def forward(self, x):
if self.deterministic:
@@ -388,11 +401,12 @@ def forward(self, x):
self.padding, self.dilation, self.groups)
# compute z
# note that we reparametrise according to [2] Eq. (11) (not [1])
z = reparameterise(self.z_mu.repeat(batch_size, 1, 1, 1), self.z_logvar.repeat(batch_size, 1, 1, 1),
sampling=self.training,
cuda=self.cuda)
return reparameterise(mu_activations * z, (var_activations * z.pow(2)).log(), sampling=self.training,
cuda=self.cuda)
z = reparametrize(self.z_mu.repeat(batch_size, 1, 1, 1, 1), self.z_logvar.repeat(batch_size, 1, 1, 1, 1),
sampling=self.training, cuda=self.cuda)
z = z[:, :, None, None, None]

return reparametrize(mu_activations * z, (var_activations * z.pow(2)).log(), sampling=self.training,
cuda=self.cuda)

def __repr__(self):
return self.__class__.__name__ + ' (' \
33 changes: 18 additions & 15 deletions README.md
@@ -1,17 +1,18 @@
# Code release for "Bayesian Compression for Deep Learning"


In "Bayesian Compression for Deep Learning" we take an information theoretic take on the compression of neural networks. We explicitly revisit the connection between the minimum description length principle and variational inference.
In "Bayesian Compression for Deep Learning" we adopt a Bayesian view for the compression of neural networks.
By revisiting the connection between the minimum description length principle and variational inference we are
able to achieve up to 700x compression and up to 50x speed-up (CPU to sparse GPU) for neural networks.

Compression of neural networks with up to 700x and speed up up to 50x (CPU to sparse GPU).

We achive these results by learning additive noise to the weights. We visualize the learning process in the following figures examplarily for a dense network with 300 and 100 connections. Whitness represents redundacy, red and blue positive and negative weights respectively.
We visualize the learning process in the following figures for a dense network with 300 and 100 connections.
White represents redundancy, whereas red and blue represent positive and negative weights, respectively.

|First layer weights |Second Layer weights|
| :------ |:------: |
|![alt text](./figures/weight0_e.gif "First layer weights")|![alt text](./figures/weight1_e.gif "Second Layer weights")|

For dense networks it is also simple to reconstruct input feature impartance. We show this for a mask and 5 randomly chosen digits.
For dense networks it is also simple to reconstruct input feature importance. We show this for a mask and 5 randomly chosen digits.
![alt text](./figures/pixel.gif "Pixel importance")


@@ -30,21 +31,21 @@ For dense networks it is also simple to reconstruct input feature impartance. We
| |BC-GHS | 9.0 | 18* | 59|

## Usage
We provide an implementation in pyTorch for linear and convolutional layers for the group normal-Jeffreys prior (aka Group Variational Dropout) via
We provide a PyTorch implementation of fully connected and convolutional layers with the group normal-Jeffreys prior (aka Group Variational Dropout) via:
```python
import BayesianLayer
import BayesianLayers
```
Layers can be inclued equivalently to their frequentist counter parts.
The layers can then be straightforwardly included as follows:
```python
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
# activation
self.relu = nn.ReLU()
# layers
self.fc1 = BayesianLayer.LinearGroupVD(28 * 28, 300, clip_var=0.04)
self.fc2 = BayesianLayer.LinearGroupVD(300, 100)
self.fc3 = BayesianLayer.LinearGroupVD(100, 10)
self.fc1 = BayesianLayers.LinearGroupNJ(28 * 28, 300, clip_var=0.04)
self.fc2 = BayesianLayers.LinearGroupNJ(300, 100)
self.fc3 = BayesianLayers.LinearGroupNJ(100, 10)
# layers including kl_divergence
self.kl_list = [self.fc1, self.fc2, self.fc3]

@@ -60,7 +61,8 @@ Layers can be inclued equivalently to their frequentist counter parts.
KLD += layer.kl_divergence()
return KLD
```
The only addional effort is to define a KL-divergence. Which is of need for the optimisation of the variational lower bound
The only additional effort is to include the KL-divergence in the objective.
This is necessary if we want to optimize the variational lower bound that leads to sparse solutions:
```python
N = 60000.
discrimination_loss = nn.functional.cross_entropy
@@ -70,15 +72,16 @@ def objective(output, target, kl_divergence):
return discrimination_error + kl_divergence / N
```
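
For orientation, a minimal training step built around this objective might look like the sketch below. The `model`, `optimizer`, and mini-batch variables, as well as calling `clip_variances()` after every update, are assumptions based on the layer API above and are not part of this commit.
```python
import torch.optim as optim

# assumed: model = Net() as defined above; data, target form a MNIST mini-batch
optimizer = optim.Adam(model.parameters())

model.train()
optimizer.zero_grad()
output = model(data)
loss = objective(output, target, model.kl_divergence())
loss.backward()
optimizer.step()
# if a layer was built with clip_var, keep its weight variances bounded
for layer in model.kl_list:
    layer.clip_variances()
```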
## Run an example
We provide a simple example, the LeNet-300-100 trained with normal-Jeffreys prior.
We provide a simple example, the LeNet-300-100 trained with the group normal-Jeffreys prior:
```sh
python example.py
```
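
The layers in this commit also expose a `deterministic` flag and a `compute_posterior_params()` method for compressed inference. A rough sketch of how they might be switched on after training (the exact post-training workflow is an assumption, not shown here):
```python
# assumed: model is a trained Net instance
model.eval()
for layer in model.kl_list:
    # cache the posterior mean/variance of the effective weights
    # (post_weight_mu, post_weight_var)
    layer.compute_posterior_params()
    # route forward() through the deterministic (compressed) path
    layer.deterministic = True
```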

## Retraining a regular neural network
Often times we wish to compress an already existing network. To retrain a pretrained network just inialize the weights when creating an equivalent Bayesian network
Instead of training a network from scratch, we often need to compress an already existing network.
In this case we can simply initialize the weights with those of the pretrained network:
```python
BayesianLayer.LinearGroupVD(28*28, 300,init_weight=pretrained_weight, init_bias=pretrained_bias)
BayesianLayers.LinearGroupNJ(28*28, 300, init_weight=pretrained_weight, init_bias=pretrained_bias)
```
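As a concrete, assumed example, the initial values could be taken from a matching `torch.nn.Linear` layer of a pretrained network (`pretrained_net` and the exact format expected by `init_weight`/`init_bias` are assumptions; NumPy arrays should work as well):
```python
# assumed: pretrained_net.fc1 is an nn.Linear(28 * 28, 300)
fc1 = BayesianLayers.LinearGroupNJ(
    28 * 28, 300,
    init_weight=pretrained_net.fc1.weight.data,
    init_bias=pretrained_net.fc1.bias.data)
```
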
## *Reference*
The paper "Bayesian Compression for Deep Learning" has been accepted to NIPS 2017. Please cite us:
