In [None]:
import torch
from torch import nn

Reference implementation of Sparse Variational Dropout: https://github.com/bayesgroup/variational-dropout-sparsifies-dnn

In [None]:
def compute_log_alpha(log_sigma, theta):
  r'''
      Return the clipped, element-wise \log(\alpha) of a set of weights.

      Following https://arxiv.org/abs/1701.05369, the noise variance is tied to
      the weight mean through \sigma^2 = \alpha * \theta^2, hence
        \log(\alpha) = 2*\log(\sigma) - 2*\log(|\theta|)

      Parameters
      ----------
          log_sigma: torch.Tensor,
              \log(\sigma) of every weight (it is doubled here).

          theta: torch.Tensor,
              Weight means; only their absolute value is used.

      Returns
      -------
          torch.Tensor holding \log(\alpha), limited to the range [-10, 10].
  '''
  # The 1e-16 offset keeps the logarithm finite for weights that are exactly zero.
  log_abs_theta = torch.log(torch.abs(theta) + 1e-16)
  unclipped = 2.0 * log_sigma - 2.0 * log_abs_theta
  # Clipping for numerical stability.
  return torch.clamp(unclipped, min=-10, max=10)



In [None]:
import numpy as np
from torch.nn.parameter import Parameter
import torch.nn.functional as F
# Linear Sparse Variational Dropout
# See https://arxiv.org/pdf/1701.05369.pdf for details
class LinearSVD(nn.Linear):
    r'''
        Fully-connected layer trained with Sparse Variational Dropout.

        Besides the usual ``weight`` and ``bias`` the layer learns one
        \log(\sigma) per weight (``log_sigma``). While training, the output is
        sampled with the local reparametrization trick; in evaluation mode the
        weights whose \log(\alpha) reaches ``log_alpha_threshold`` are zeroed.
    '''
    def __init__(self, in_features, out_features, p_threshold = 0.952572, bias=True) -> None:
        r'''
            Parameters
            ----------
                in_features: int,
                    Number of input features.

                out_features: int,
                    Number of output features.

                p_threshold: float,
                    Binary dropout rate \rho above which a weight is discarded. It is
                    converted to the Gaussian dropout scale through
                    \alpha = \rho/(1-\rho) and stored as \log(\alpha). The default
                    0.952572 corresponds to \log(\alpha) ~ 3.

                bias: bool,
                    If True, adds a bias term to the output.
        '''
        super(LinearSVD, self).__init__(in_features, out_features, bias)

        self.log_alpha_threshold = np.log(p_threshold / (1-p_threshold))
        # Initialization based on the paper, Figure 1.
        self.log_sigma = Parameter(torch.full((out_features, in_features), -5.0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''Sample the output (training) or apply the pruned weights (evaluation).'''
        # Kept as an attribute so callers can inspect the sparsity after a forward pass.
        self.log_alpha = compute_log_alpha(self.log_sigma, self.weight)

        if self.training:
            # LRT = local reparametrization trick (For details, see https://arxiv.org/pdf/1506.02557.pdf)
            lrt_mean = F.linear(x, self.weight, self.bias)
            # The small constant keeps the square root (and its gradient) finite.
            lrt_std = torch.sqrt(F.linear(x * x, torch.exp(self.log_sigma * 2.0)) + 1e-8)
            return lrt_mean + lrt_std * torch.randn_like(lrt_std)

        keep_mask = (self.log_alpha < self.log_alpha_threshold).to(self.weight.dtype)
        return F.linear(x, self.weight * keep_mask, self.bias)

    def kl_reg(self):
        r'''
            Approximate KL divergence of the layer, summed over all weights.

            Uses the approximation of https://arxiv.org/pdf/1701.05369.pdf with its
            additive constant left out. \log(\alpha) is recomputed from the current
            parameters, so the method works before any forward pass and on any device.
        '''
        # Plain floats broadcast on whatever device the parameters live on.
        k1, k2, k3 = 0.63576, 1.8732, 1.48695
        log_alpha = compute_log_alpha(self.log_sigma, self.weight)
        neg_kl = k1 * torch.sigmoid(k2 + k3 * log_alpha) - 0.5 * torch.log1p(torch.exp(-log_alpha))
        return -torch.sum(neg_kl)

In [None]:
# Loss function 
from torch.nn.functional import cross_entropy
class SGVBL(nn.Module):
    '''
        Stochastic Gradient Variational Bayes (SGVB) loss function.
        More details in https://arxiv.org/pdf/1506.02557.pdf and https://arxiv.org/pdf/1312.6114.pdf

        Parameters
        ----------
            model: nn.Module,
                Network whose sub-modules are scanned for a ``kl_reg()`` method.

            train_size: int,
                Number of training samples; the KL term is divided by it.

            loss: callable,
                Data term, called as ``loss(input, target)``.
    '''

    def __init__(self, model, train_size, loss=cross_entropy):
        super(SGVBL, self).__init__()
        self.train_size = train_size
        self.net = model
        self.loss = loss

        # Every sub-module exposing `kl_reg()` contributes to the regularizer, so the
        # same loss works for any variational layer type.
        self.variational_layers = [
            module for module in model.modules()
            if callable(getattr(module, 'kl_reg', None))
        ]

    def forward(self, input, target, kl_weight=1.0):
        '''Return ``loss(input, target) + (kl_weight / train_size) * sum of kl_reg()``.'''
        assert not target.requires_grad
        kl = 0.0
        for layer in self.variational_layers:
            kl = kl + layer.kl_reg()

        # Scaling taken from Concrete Dropout, where the KL weight is 1/train_size.
        return self.loss(input, target) + (kl_weight/self.train_size) * kl

In [None]:
# A small network of three LinearSVD layers: 784 -> 300 -> 100 -> 10
class Net(nn.Module):
    def __init__(self, threshold):
        '''Build the three layers, all sharing the same pruning `threshold`.'''
        super().__init__()
        self.fc1 = LinearSVD(28*28, 300, threshold)
        self.fc2 = LinearSVD(300,  100, threshold)
        self.fc3 = LinearSVD(100,  10, threshold)
        self.threshold = threshold

    def forward(self, x):
        '''Return the raw outputs of the last layer (no softmax is applied).'''
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [None]:
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader

# Convert each image to a tensor, then normalize it with mean 0.1307 and std 0.3081.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])
# Train and test splits, stored under ./data (downloaded when missing).
train_dataset = MNIST('./data', train=True, transform=transform, download=True)
test_dataset = MNIST('./data', train=False, transform=transform, download=True)

# Mini-batches of 64 samples; only the training split is shuffled.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=64)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=64)

In [None]:
from torch.optim import Adam

# Prefer the GPU but fall back to the CPU so the cell also runs without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = Net(threshold=.95).to(device)
optimizer = Adam(model.parameters(), lr=1e-3)
# The learning rate is multiplied by 0.2 at epochs 50, 60, 70 and 80.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50,60,70,80], gamma=0.2)

# The loss wraps the model and uses the training-set size to scale the KL term.
sgvlb = SGVBL(model, len(train_loader.dataset)).to(device)

In [None]:
from torch.utils.tensorboard import SummaryWriter
import time
kl_weight = 0.02
epochs = 100

logger = SummaryWriter('log/sparse_vd')

# Run every batch on whatever device the model already lives on.
device = next(model.parameters()).device

for epoch in range(1, epochs + 1):
    time_start = time.perf_counter()
    model.train()
    train_loss, train_acc = 0, 0
    # KL annealing: the weight grows by 0.02 per epoch and is capped at 1.
    kl_weight = min(kl_weight+0.02, 1)
    logger.add_scalar('kl', kl_weight, epoch)
    # get_last_lr() is the supported way to read the current learning rate.
    logger.add_scalar('lr', scheduler.get_last_lr()[0], epoch)
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.to(device).view(-1, 28*28)
        target = target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = sgvlb(output, target, kl_weight)
        loss.backward()
        optimizer.step()

        train_loss += float(loss)
        train_acc += (output.argmax(dim=1) == target).sum().item()

    scheduler.step()

    # NOTE(review): one loss value is accumulated per batch but the total is divided by
    # the number of samples; confirm that this scale is the intended one.
    logger.add_scalar('tr_loss', train_loss / len(train_loader.dataset), epoch)
    logger.add_scalar('tr_acc', train_acc / len(train_loader.dataset) * 100, epoch)

    model.eval()
    test_loss, test_acc = 0, 0
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            data = data.to(device).view(-1, 28*28)
            target = target.to(device)
            output = model(data)
            test_loss += float(sgvlb(output, target, kl_weight))
            test_acc += (output.argmax(dim=1) == target).sum().item()

    logger.add_scalar('te_loss', test_loss / len(test_loader.dataset), epoch)
    logger.add_scalar('te_acc', test_acc / len(test_loader.dataset) * 100, epoch)

    # Sparsity per layer: fraction of weights the evaluation-mode forward pass zeroes,
    # i.e. those whose log-alpha reaches the layer's own log-alpha threshold.
    for i, c in enumerate(model.children()):
        if hasattr(c, 'kl_reg'):
            pruned = c.log_alpha >= c.log_alpha_threshold
            logger.add_scalar('sp_%s' % i, pruned.float().mean().item(), epoch)

    logger.add_scalar('time', time.perf_counter() - time_start, epoch)

# Flush pending events to disk.
logger.close()

In [None]:
from torch.nn import CrossEntropyLoss
# Scratch check: cross-entropy of one batch multiplied by the training-set size.
# NOTE(review): `data` and `target` are not defined in this cell; presumably they are
# the variables left over from the last batch of the loop above -- confirm.
ce_loss = CrossEntropyLoss()
output = model(data)

ce_loss(output, target)*len(train_dataset)

In [None]:
# Total KL regularizer: add up kl_reg() of every sub-module that provides one.
kl = sum(
    (module.kl_reg() for module in model.modules() if hasattr(module, 'kl_reg')),
    0.0,
)

print(kl)

In [None]:
# Scratch: natural logarithm of 1e-6.
np.log(1e-6)

In [None]:
# Two random 3x2 matrices (`a` is drawn first).
a = torch.rand(3, 2)
b = torch.rand(3, 2)

# Row-wise dot product: multiply element-wise, then add up each row.
torch.sum(a * b, dim=1)

# Convolutional Variational Dropout Layer

In [None]:
import torch
from torch import nn
import numpy as np
from torch.nn.parameter import Parameter

class Conv2dSVD(nn.Conv2d):
    r'''
        2D convolution trained with Sparse Variational Dropout
        (see https://arxiv.org/pdf/1701.05369.pdf).

        Besides the usual ``weight`` and ``bias`` the layer learns one
        \log(\sigma) per kernel weight (``log_sigma``). While training, the
        output is sampled with the local reparametrization trick; in evaluation
        mode the weights whose \log(\alpha) reaches ``log_alpha_threshold`` are
        zeroed.

        Parameters
        ----------
            in_channels, out_channels, kernel_size:
                Forwarded to ``nn.Conv2d``.

            p_threshold: float,
                Binary dropout rate \rho above which a weight is discarded; stored
                as \log(\rho/(1-\rho)).

            **kargs:
                Any other ``nn.Conv2d`` keyword argument (stride, padding, bias, ...).
    '''
    def __init__(self, in_channels, out_channels, kernel_size, p_threshold = 0.952572, **kargs):
        super(Conv2dSVD, self).__init__(in_channels, out_channels, kernel_size, **kargs)

        self.log_alpha_threshold = np.log(p_threshold / (1-p_threshold))
        # Initialization based on the paper, Figure 1.
        self.log_sigma = Parameter(torch.full(self.weight.shape, -5.0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''Sample the output (training) or convolve with the pruned kernel (evaluation).'''
        self.log_alpha = compute_log_alpha(self.log_sigma, self.weight)

        if self.training:
            # LRT = local reparametrization trick (For details, see https://arxiv.org/pdf/1506.02557.pdf)
            # Mean and variance use the layer's own convolution settings.
            lrt_mean = self._conv_forward(x, self.weight, self.bias)
            lrt_var = self._conv_forward(x * x, torch.exp(self.log_sigma * 2.0), None)
            lrt_std = torch.sqrt(lrt_var + 1e-8)
            return lrt_mean + lrt_std * torch.randn_like(lrt_std)

        keep_mask = (self.log_alpha < self.log_alpha_threshold).to(self.weight.dtype)
        return self._conv_forward(x, self.weight * keep_mask, self.bias)

    def kl_reg(self):
        r'''
            Approximate KL divergence of the layer, summed over all kernel weights
            (same approximation as in https://arxiv.org/pdf/1701.05369.pdf, without
            its additive constant).
        '''
        k1, k2, k3 = 0.63576, 1.8732, 1.48695
        log_alpha = compute_log_alpha(self.log_sigma, self.weight)
        neg_kl = k1 * torch.sigmoid(k2 + k3 * log_alpha) - 0.5 * torch.log1p(torch.exp(-log_alpha))
        return -torch.sum(neg_kl)

In [None]:
# Single 3x3 convolution (1 -> 3 channels). Note that this rebinds the name `model`.
model = Conv2dSVD(in_channels=1, out_channels=3, kernel_size=3, stride=1, padding=1, bias=True)

In [None]:
# Scratch: the two convolutions of the local reparametrization trick, one row per sample.
x = torch.rand(2, 1, 6, 6)
mean = torch.flatten(model._conv_forward(x, model.weight, model.bias), start_dim=1)
# No square root is taken here, so `std` holds the convolved variance term.
std = torch.flatten(model._conv_forward(x * x, (model.log_sigma * 2.0).exp(), None), start_dim=1)

In [None]:
# Display the tensor named `std` (assigned in the previous cell).
std

In [None]:
# 1x1x3x3 tensor filled with 0.5, except for a 1 in its first entry.
a = torch.full((1, 1, 3, 3), 0.5)
a[0, 0, 0, 0] = 1
# Element-wise square, shown as the cell output.
torch.mul(a, a)

In [None]:
# Scratch: log-alpha of the convolution kernel for an all-zero log-sigma.
# `a` is a plain tensor after the previous cell and has no `.weight`, so the layer
# `model` is used; zeros replace the uninitialized memory of torch.Tensor(shape).
sigma = torch.zeros(model.weight.shape)
compute_log_alpha(sigma, model.weight)

In [None]:
import torch
# Scratch: a tensor built from a weight shape has that same shape.
# `a` is a plain tensor without `.weight`, so the layer `model` is used instead.
torch.Tensor(model.weight.shape).shape

In [None]:
import torch
# Two random matrices, (35, 4) and (35, 8), squashed by a sigmoid; `y` is drawn first.
y = torch.rand(35, 4).sigmoid()
p = torch.rand(35, 8).sigmoid()

In [None]:
# Gram matrices: inner products between the rows of y, then between the rows of p.
print(torch.matmul(y, y.T))

print(torch.matmul(p, p.T))

In [None]:
# 4x4 identity scaled by 1e-6, created with the dtype and device of y.
eps = 1e-6 * torch.eye(4, dtype=y.dtype, device=y.device)
print(eps)
# Prepend two singleton axes: (4, 4) -> (1, 1, 4, 4).
eps = eps[None, None]
eps

In [None]:
# Scratch: Cholesky factor of (y y^T - y p^T).
# NOTE(review): if y is (35, 4) and p is (35, 8) as created two cells above, `y@p.T`
# multiplies a (35, 4) by an (8, 35) matrix and the shapes do not match -- confirm intent.
x = (y@y.T) - y@p.T
x = torch.linalg.cholesky(x)
# diag = torch.diagonal(x, dim1=-2, dim2=-1)
# 2 * torch.sum(torch.log(diag + 1e-8), dim=-1)

In [2]:
import torch
from torch import sigmoid
from torch import nn
import numpy as np

class LinearSCD(nn.Linear):
    r'''
        Linear layer with Sparse Concrete Dropout regularization.

        Code strongly inspired by:
            https://github.com/danielkelshaw/ConcreteDropout/blob/master/condrop/concrete_dropout.py

        The layer owns a single learnable dropout logit (``logit_p``).
        ``concrete_bernoulli`` draws a relaxed, differentiable dropout mask from it
        and ``kl_reg`` returns the matching regularization term.

        Note the relationship between the weight regularizer (w_reg) and dropout
        regularization (drop_reg):

            w_reg/drop_reg = (l^2)/2

        with prior lengthscale l. Note also that the factor of two should be ignored
        for cross-entropy loss, and used only for the Euclidean loss.

        NOTE(review): ``forward`` is not overridden, so the mask is only applied when
        the caller invokes ``concrete_bernoulli`` explicitly, and ``logit_threshold``
        is stored but never read inside this class.

        Parameters
        ----------
            in_features, out_features, bias:
                Forwarded to ``nn.Linear``.

            p_threshold: float,
                Dropout probability threshold, stored as a logit in ``logit_threshold``.

            w_reg: float,
                Coefficient of the weight regularization term.

            drop_reg: float,
                Coefficient of the dropout (entropy) regularization term.

            init_min, init_max: float,
                Range from which the initial dropout probability is drawn.
    '''
    def __init__(self, in_features, out_features, bias=True, p_threshold = 0.5, w_reg=1e-6, drop_reg=1e-3, init_min=0.2, init_max=0.5):
        super(LinearSCD, self).__init__(in_features, out_features, bias)
        self.logit_threshold = np.log(p_threshold) - np.log((1-p_threshold))

        logit_init_min = float(np.log(init_min) - np.log(1. - init_min))
        logit_init_max = float(np.log(init_max) - np.log(1. - init_max))

        # Logit of the probability of deactivating an input, drawn uniformly between
        # the logits of init_min and init_max.
        self.logit_p = nn.Parameter(torch.rand(1) * (logit_init_max - logit_init_min) + logit_init_min)

        # The weight and Dropout regularization term.
        self.w_reg = w_reg
        self.drop_reg = drop_reg

    def concrete_bernoulli(self, x):
        '''Apply a relaxed (Concrete) dropout mask to `x` and rescale by 1/(1-p).'''
        # Reparametrization trick
        eps = 1e-8
        temperature = .1
        # The noise is created like `x` so that it shares its device and dtype.
        unif_noise = torch.empty_like(x).uniform_()
        p = torch.sigmoid(self.logit_p)

        drop_logit = (torch.log(p + eps) - torch.log((1 - p) + eps)
                      + torch.log(unif_noise + eps) - torch.log((1. - unif_noise) + eps))
        drop_prob = torch.sigmoid(drop_logit / temperature)

        random_tensor = 1 - drop_prob
        retain_prob = 1 - p # rescale factor typical for dropout
        return torch.mul(x, random_tensor) / retain_prob

    def kl_reg(self):
        '''
            KL regularization term (a scalar tensor).
            For more details, see https://arxiv.org/pdf/1705.07832.pdf
        '''
        p = torch.sigmoid(self.logit_p)

        # Squared weights, added up per input feature.
        # TODO: decide whether the squared bias belongs here as well.
        square_param = torch.sum(torch.pow(self.weight, 2), dim=0)

        # Weights regularization divided by (1-p) because of the rescaling
        # factor in the dropout distribution.
        weights_reg = self.w_reg * square_param / (1.0 - p)

        # Dropout regularization term (negative Bernoulli entropy), scaled by the
        # number of input features.
        in_features = self.weight.size(1)
        neg_entropy = p * torch.log(p) + (1.0 - p) * torch.log(1.0 - p)
        dropout_reg = (self.drop_reg * in_features) * neg_entropy

        # The two terms are reduced separately: adding them first would broadcast the
        # entropy term over every input feature and count it in_features times again.
        return torch.sum(weights_reg) + torch.sum(dropout_reg)


In [8]:
from torch.nn import functional as F
class Net(nn.Module):
    '''Three LinearSCD layers (100 -> 48 -> 24 -> 10) with ReLU in between.'''
    def __init__(self, threshold):
        super().__init__()
        lengthscale = 1e-4
        weight_reg = lengthscale**2.
        # Original note: "2 for euclidean loss and 1 for cross-entropy loss".
        dropout_reg = 1e-6
        # Settings shared by the three layers; only the first one has no bias.
        shared = dict(p_threshold=threshold, w_reg=weight_reg, drop_reg=dropout_reg)
        self.fc1 = LinearSCD(10*10, 48, bias=False, **shared)
        self.fc2 = LinearSCD(48,  24, bias=True, **shared)
        self.fc3 = LinearSCD(24,  10, bias=True, **shared)
        self.threshold = threshold

    def forward(self, x):
        '''Return the raw outputs of the last layer.'''
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [9]:
# Loss function 
from torch.nn.functional import cross_entropy
class SGVBL(nn.Module):
    '''
        Stochastic Gradient Variational Bayes (SGVB) loss function.
        More details in https://arxiv.org/pdf/1506.02557.pdf and https://arxiv.org/pdf/1312.6114.pdf

        Parameters
        ----------
            model: nn.Module,
                Network whose sub-modules are scanned for a ``kl_reg()`` method.

            train_size: int,
                Number of training samples; the data term is multiplied by it.

            loss: callable,
                Data term, called as ``loss(input, target)``.
    '''

    def __init__(self, model, train_size, loss=cross_entropy):
        super(SGVBL, self).__init__()
        self.train_size = train_size
        self.net = model
        self.loss = loss

        # Every sub-module exposing `kl_reg()` contributes to the regularizer, so the
        # same loss works for any variational layer type.
        self.variational_layers = [
            module for module in model.modules()
            if callable(getattr(module, 'kl_reg', None))
        ]

    def forward(self, input, target, kl_weight=1.0):
        '''Return ``loss(input, target) * train_size + kl_weight * sum of kl_reg()``.'''
        assert not target.requires_grad
        kl = 0.0
        for layer in self.variational_layers:
            kl = kl + layer.kl_reg()

        # Alternative scaling seen in "concrete dropout" (KL weight of 1/train_size):
        #   self.loss(input, target) + (kl_weight/self.train_size) * kl
        return self.loss(input, target) * self.train_size + kl_weight * kl

In [10]:
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader

# Shrink each image to 10x10, convert it to a tensor, then normalize it with
# mean 0.1307 and std 0.3081.
transform = transforms.Compose([
    transforms.Resize((10, 10)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])
# Train and test splits, stored under ./data (downloaded when missing).
train_dataset = MNIST('./data', train=True, transform=transform, download=True)
test_dataset = MNIST('./data', train=False, transform=transform, download=True)

# Mini-batches of 64 samples; only the training split is shuffled.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=64)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=64)

In [11]:
from torch.optim import Adam

# Prefer the GPU but fall back to the CPU so the cell also runs without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = Net(threshold=.95).to(device)
optimizer = Adam(model.parameters(), lr=1e-3)
# The learning rate is multiplied by 0.2 at epochs 50, 60, 70 and 80.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50,60,70,80], gamma=0.2)

# The loss wraps the model and uses the training-set size to scale the data term.
sgvlb = SGVBL(model, len(train_loader.dataset)).to(device)

In [12]:
from torch.utils.tensorboard import SummaryWriter
import time
kl_weight = 0.02
epochs = 100

logger = SummaryWriter('log/sparse_scd')

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)

for epoch in range(1, epochs + 1):
    time_start = time.perf_counter()
    model.train()
    train_loss, train_acc = 0, 0
    # KL annealing: the weight grows by 0.02 per epoch and is capped at 1.
    kl_weight = min(kl_weight+0.02, 1)
    logger.add_scalar('kl', kl_weight, epoch)
    # get_last_lr() is the supported way to read the current learning rate.
    logger.add_scalar('lr', scheduler.get_last_lr()[0], epoch)
    for batch_idx, (data, target) in enumerate(train_loader):
        # Batches follow the model to `device` instead of assuming CUDA.
        data = data.to(device).view(-1, 10*10)
        target = target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = sgvlb(output, target, kl_weight)
        loss.backward()
        optimizer.step()

        train_loss += float(loss)
        train_acc += (output.argmax(dim=1) == target).sum().item()

    scheduler.step()

    # NOTE(review): one loss value is accumulated per batch but the total is divided by
    # the number of samples; confirm that this scale is the intended one.
    logger.add_scalar('tr_loss', train_loss / len(train_loader.dataset), epoch)
    logger.add_scalar('tr_acc', train_acc / len(train_loader.dataset) * 100, epoch)

    model.eval()
    test_loss, test_acc = 0, 0
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            data = data.to(device).view(-1, 10*10)
            target = target.to(device)
            output = model(data)
            test_loss += float(sgvlb(output, target, kl_weight))
            test_acc += (output.argmax(dim=1) == target).sum().item()

    logger.add_scalar('te_loss', test_loss / len(test_loader.dataset), epoch)
    logger.add_scalar('te_acc', test_acc / len(test_loader.dataset) * 100, epoch)

    logger.add_scalar('time', time.perf_counter() - time_start, epoch)

# Flush pending events to disk.
logger.close()



KeyboardInterrupt: 

In [15]:
# Current dropout probability of the second layer (sigmoid of its logit), averaged.
model.fc2.logit_p.sigmoid().mean()

tensor(0.4997, device='cuda:0', grad_fn=<MeanBackward0>)

In [None]:
from matplotlib import pyplot as plt

# Histogram (50 bins) of every first-layer weight, copied to the CPU as a NumPy array.
test = model.fc1.weight.detach().cpu().numpy()
plt.hist(test.ravel(), bins=50)
plt.show()

In [None]:
# NOTE(review): `a` ends up as a shape (from `.shape`) while `b` is a summed tensor
# (from `.sum()`), so the two are not the same kind of quantity. Both comparisons are
# signed: `< 0.1` keeps every negative weight and `> -0.1` every positive one.
a = model.fc1.weight[model.fc1.weight < 0.1].shape
b = model.fc1.weight[model.fc1.weight > -0.1].sum()

In [None]:
# NOTE(review): given how the previous cell assigns them, this subtracts a shape from a
# tensor; verify what was meant to be compared.
b - a

In [None]:
# Shape of the selection of first-layer weights below 0.01.
# NOTE(review): the comparison is signed, so every negative weight is selected as well.
model.fc1.weight[model.fc1.weight < 0.01].shape

In [None]:
# Shape of the selection of first-layer weights above -0.01.
# NOTE(review): the comparison is signed, so every positive weight is selected as well.
model.fc1.weight[model.fc1.weight > -0.01].shape

In [None]:
# Scratch: 28*28, the input size used by the first network in this notebook.
28*28

In [None]:
# Inspect the first-layer weights; the layer is an attribute of `model`, not a
# standalone variable.
model.fc1.weight