In [None]:
import torch
from torch import nn

In [None]:
def compute_log_alpha(log_sigma, theta, min_log_alpha=-10.0, max_log_alpha=10.0, eps=1e-16):
    r'''
    Compute log(\alpha) from the weights \theta and log(\sigma).

    In Sparse Variational Dropout (https://arxiv.org/abs/1701.05369) the
    per-weight variance is parameterized as \sigma^2 = \alpha * \theta^2, hence

        \log(\alpha) = 2 * \log(\sigma) - 2 * \log(|\theta|)

    Parameters
    ----------
    log_sigma : torch.Tensor
        Logarithm of the standard deviation \sigma.
    theta : torch.Tensor
        Weight tensor (broadcastable against ``log_sigma``). Its sign is ignored.
    min_log_alpha, max_log_alpha : float
        Range the result is clipped to, for numerical stability.
        Defaults (-10, 10) match the original implementation.
    eps : float
        Added to |theta| before taking the log so that zero weights give a
        large finite value instead of inf.

    Returns
    -------
    torch.Tensor
        Element-wise log(alpha), clipped to [min_log_alpha, max_log_alpha].
    '''
    log_alpha = 2.0 * log_sigma - 2.0 * torch.log(torch.abs(theta) + eps)
    return torch.clamp(log_alpha, min_log_alpha, max_log_alpha)



In [None]:
import numpy as np
from torch.nn.parameter import Parameter
import torch.nn.functional as F
# Linear Sparse Variational Dropout
# See https://arxiv.org/pdf/1701.05369.pdf for details
class LinearSVD(nn.Linear):
    r'''
    Fully-connected layer trained with Sparse Variational Dropout.

    See https://arxiv.org/abs/1701.05369. Every weight w_ij has a learnable
    ``log_sigma``. Its dropout level is expressed as
    log(alpha) = 2*log(sigma) - 2*log|w| (computed by ``compute_log_alpha``).

    In training mode the pre-activations are sampled with the local
    reparameterization trick. In eval mode, weights whose log(alpha) is not
    below ``log_alpha_threshold`` are zeroed out.
    '''

    def __init__(self, in_features, out_features, p_threshold = 0.952572, bias=True) -> None:
        r'''
        Parameters
        ----------
        in_features: int
            Number of input features.
        out_features: int
            Number of output features.
        p_threshold: float
            Threshold on the equivalent binary dropout rate \rho, in the open
            interval (0, 1). Gaussian dropout with noise variance \alpha
            corresponds to a dropout rate via \alpha = \rho / (1 - \rho). At eval
            time a weight is kept only if
            log(\alpha) < log(p_threshold / (1 - p_threshold)).
            The default 0.952572 gives a log(\alpha) threshold of about 3.
        bias: bool
            If True, adds a bias term to the output.

        Raises
        ------
        ValueError
            If ``p_threshold`` is not strictly between 0 and 1 (the logit would
            be infinite or undefined).
        '''
        if not 0.0 < p_threshold < 1.0:
            raise ValueError(f'p_threshold must be in (0, 1), got {p_threshold}')
        super().__init__(in_features, out_features, bias)

        # Logit of the dropout-rate threshold, i.e. the threshold on log(alpha).
        self.log_alpha_threshold = np.log(p_threshold / (1 - p_threshold))
        # Initialization based on the paper, Figure 1.
        self.log_sigma = Parameter(torch.full((out_features, in_features), -5.0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Cached on the module: kl_reg() and the training loop read it after forward().
        self.log_alpha = compute_log_alpha(self.log_sigma, self.weight)

        if self.training:
            # Local reparameterization trick (https://arxiv.org/pdf/1506.02557.pdf):
            # sample pre-activations from N(x W^T + b, (x*x) (sigma^2)^T) instead of weights.
            lrt_mean = F.linear(x, self.weight, self.bias)
            lrt_var = F.linear(x * x, torch.exp(2.0 * self.log_sigma))
            lrt_std = torch.sqrt(lrt_var + 1e-8)  # +1e-8 avoids an infinite sqrt gradient at 0
            return lrt_mean + lrt_std * torch.randn_like(lrt_std)

        # Eval: keep only weights whose log(alpha) is below the threshold.
        keep_mask = (self.log_alpha < self.log_alpha_threshold).to(self.weight.dtype)
        return F.linear(x, self.weight * keep_mask, self.bias)

    def kl_reg(self):
        r'''
        Approximate KL divergence between the posterior and the log-uniform prior.

        Uses the paper's approximation, summed over all weights:
            -KL ~ k1 * sigmoid(k2 + k3 * log_alpha) - 0.5 * log(1 + 1/alpha) + const
        with k1 = 0.63576, k2 = 1.8732, k3 = 1.48695. The additive constant is
        omitted because it does not affect gradients, so the returned value can
        be negative.

        Reads ``self.log_alpha`` cached by the most recent ``forward`` call, so
        it must be called after a forward pass.

        Returns
        -------
        torch.Tensor
            0-dim tensor on the same device and dtype as ``self.log_alpha``.
        '''
        # Plain Python floats broadcast on any device (no hard-coded .cuda()).
        k1, k2, k3 = 0.63576, 1.8732, 1.48695
        neg_kl = k1 * torch.sigmoid(k2 + k3 * self.log_alpha) - 0.5 * torch.log1p(torch.exp(-self.log_alpha))
        return -torch.sum(neg_kl)

In [72]:
# Loss function 
from torch.nn.functional import cross_entropy
class SGVLB(nn.Module):
    '''
    Stochastic Gradient Variational Lower Bound (SGVLB), negated so it can be minimized.

        loss = train_size * mean_cross_entropy(input, target) + kl_weight * KL

    KL is the sum of ``kl_reg()`` over the Bayesian layers of ``net``. Scaling
    the mini-batch mean cross-entropy by ``train_size`` turns it into an
    estimate of the full-dataset data term.
    More details in https://arxiv.org/pdf/1506.02557.pdf and https://arxiv.org/pdf/1312.6114.pdf

    Parameters
    ----------
    net : nn.Module
        Model whose submodules expose ``kl_reg()``. These layers may be nested
        inside containers such as ``nn.Sequential``.
    train_size : int
        Number of training examples.
    '''

    def __init__(self, net, train_size):
        super().__init__()
        self.train_size = train_size
        self.net = net

    @staticmethod
    def _collect_kl(module):
        # Sum kl_reg() over the subtree below `module`. A submodule that defines
        # kl_reg is counted once and not descended into, so nothing is double counted.
        total = 0.0
        for child in module.children():
            if hasattr(child, 'kl_reg'):
                total = total + child.kl_reg()
            else:
                total = total + SGVLB._collect_kl(child)
        return total

    def forward(self, input, target, kl_weight=1.0):
        '''
        Parameters
        ----------
        input : torch.Tensor
            Raw logits; ``cross_entropy`` applies log-softmax itself.
        target : torch.Tensor
            Targets for ``cross_entropy`` (class labels in this notebook). Must not require grad.
        kl_weight : float
            Multiplier on the KL term (used for KL warm-up).

        Returns
        -------
        torch.Tensor
            Scalar loss.
        '''
        assert not target.requires_grad
        kl = self._collect_kl(self.net)
        return cross_entropy(input, target) * self.train_size + kl_weight * kl

In [73]:
# Define a simple 3-layer network (two hidden layers of 300 and 100 units)
class Net(nn.Module):
    '''
    Three-layer MLP built from LinearSVD layers (784 -> 300 -> 100 -> 10 by default, for MNIST).

    Parameters
    ----------
    threshold : float
        Passed to every LinearSVD layer as ``p_threshold``. It is a dropout
        *rate* in (0, 1), not a log(alpha) value.
    in_features : int
        Size of the flattened input (28*28 for MNIST).
    hidden1, hidden2 : int
        Widths of the two hidden layers.
    num_classes : int
        Number of output logits.
    '''

    def __init__(self, threshold, in_features=28*28, hidden1=300, hidden2=100, num_classes=10):
        super().__init__()
        self.fc1 = LinearSVD(in_features, hidden1, threshold)
        self.fc2 = LinearSVD(hidden1, hidden2, threshold)
        self.fc3 = LinearSVD(hidden2, num_classes, threshold)
        self.threshold = threshold

    def forward(self, x):
        # Accept image-shaped batches such as (B, 1, 28, 28); 1-D and 2-D input is left untouched.
        if x.dim() > 2:
            x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # Raw logits: the SGVLB loss uses cross_entropy, which applies log_softmax itself.
        return self.fc3(x)

In [74]:
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader

# ToTensor, then standardize with the commonly used MNIST mean / std.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# Both splits share root, download flag and transform; only `train` differs.
train_dataset, test_dataset = (
    MNIST(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Shuffle only the training split; evaluation order stays fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

In [75]:
from torch.optim import Adam
# Seed torch's RNGs so weight init and data shuffling are reproducible.
SEED = 0
torch.manual_seed(SEED)

# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# `threshold` is the dropout *rate* passed to each LinearSVD as p_threshold.
model = Net(threshold=.95).to(device)
optimizer = Adam(model.parameters(), lr=1e-3)
# Multiply the LR by gamma=0.2 once scheduler.step() has been called 50, 60, 70 and 80 times.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 60, 70, 80], gamma=0.2)

# The loss scales the mean cross-entropy by the training-set size.
sgvlb = SGVLB(model, len(train_loader.dataset)).to(device)

In [76]:
from torch.utils.tensorboard import SummaryWriter
import time
# KL warm-up: the KL weight grows by 0.02 per epoch until it reaches 1.
kl_weight = 0.02
epochs = 100

logger = SummaryWriter('log/sparse_vd')
# Move batches to whichever device the model's parameters live on.
device = next(model.parameters()).device

try:
    for epoch in range(1, epochs + 1):
        time_start = time.perf_counter()

        # ---- training ----
        model.train()
        train_loss, train_acc = 0.0, 0
        kl_weight = min(kl_weight + 0.02, 1)
        logger.add_scalar('kl', kl_weight, epoch)
        # Learning rate used by this epoch's optimizer steps.
        logger.add_scalar('lr', scheduler.get_last_lr()[0], epoch)
        for data, target in train_loader:
            data = data.to(device).view(-1, 28*28)
            target = target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = sgvlb(output, target, kl_weight)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_acc += (output.argmax(dim=1) == target).sum().item()
        # Since PyTorch 1.1 the scheduler must step after the optimizer;
        # stepping first skips the initial learning rate of the schedule.
        scheduler.step()

        # Each sgvlb value is a whole-batch loss (already scaled by the training-set
        # size), so average over batches rather than dividing by the sample count.
        logger.add_scalar('tr_loss', train_loss / len(train_loader), epoch)
        logger.add_scalar('tr_acc', train_acc / len(train_loader.dataset) * 100, epoch)

        # ---- evaluation: eval-mode forward uses pruned weights; no autograd graph needed ----
        model.eval()
        test_loss, test_acc = 0.0, 0
        with torch.no_grad():
            for data, target in test_loader:
                data = data.to(device).view(-1, 28*28)
                target = target.to(device)
                output = model(data)
                test_loss += sgvlb(output, target, kl_weight).item()
                test_acc += (output.argmax(dim=1) == target).sum().item()

        logger.add_scalar('te_loss', test_loss / len(test_loader), epoch)
        logger.add_scalar('te_acc', test_acc / len(test_loader.dataset) * 100, epoch)

        # Sparsity per LinearSVD layer: the fraction of weights the eval-mode forward
        # zeroes out, i.e. log_alpha >= that layer's own log-alpha threshold.
        for i, c in enumerate(model.children()):
            if hasattr(c, 'kl_reg'):
                pruned = (c.log_alpha >= c.log_alpha_threshold).float().mean().item()
                logger.add_scalar('sp_%s' % i, pruned, epoch)

        logger.add_scalar('time', time.perf_counter() - time_start, epoch)
finally:
    # Flush pending events even if training is interrupted (e.g. KeyboardInterrupt).
    logger.close()



KeyboardInterrupt: 

In [56]:
from torch.nn import CrossEntropyLoss
# Data term of the SGVLB for one batch: mean cross-entropy scaled to the
# training-set size. Relies on `data` / `target` left over from the last loop iteration.
ce_loss = CrossEntropyLoss()
output = model(data)
len(train_dataset) * ce_loss(output, target)

tensor(8007.0654, device='cuda:0', grad_fn=<MulBackward0>)

In [55]:
# Total KL regularizer across every module of the model that defines kl_reg().
kl = sum(
    (module.kl_reg() for module in model.modules() if hasattr(module, 'kl_reg')),
    0.0,
)
print(kl)

tensor(-115500.7109, device='cuda:0', grad_fn=<AddBackward0>)


In [59]:
np.log(0.000001)  # natural log of one millionth

-13.815510557964274

In [63]:
# Row-wise dot product of two random (3, 2) tensors.
a, b = torch.rand(3, 2), torch.rand(3, 2)
torch.mul(a, b).sum(-1)

tensor([0.5528, 0.6258, 0.3259])

In [66]:
a.contiguous(memory_format=torch.contiguous_format)  # row-major copy, or `a` itself if already contiguous

tensor([[0.6956, 0.4385],
        [0.8895, 0.0064],
        [0.4863, 0.1158]])

In [71]:
# Select channel 0 and re-insert a singleton channel axis.
a = torch.rand(3, 2, 3, 3)
a[:, 0].unsqueeze(1).shape

torch.Size([3, 1, 3, 3])

In [78]:
# Sanity check: raising e to log(10) should recover 10 up to floating-point rounding.
a = np.log(10)

print(np.e ** a)

10.000000000000002


In [102]:
np.sum(model.fc1.log_alpha.detach().cpu().numpy() < -2)  # fc1 weights with log(alpha) < -2, from the last forward

452

In [99]:
# Logit of a dropout rate p, i.e. the matching log(alpha) = log(p / (1 - p)).
p = 0.1
np.log(p / (1 - p))

-2.197224577336219