In [1]:
import os, sys
# Put the project root (one level above the notebook's working directory) and
# its modules/Sparse directory on the import path. Entries that are already
# present are skipped, so re-running this cell never duplicates them.
project_dir = os.path.join(os.getcwd(), '..')
sparse_dir = os.path.join(project_dir, 'modules/Sparse')
sys.path.extend(
    candidate for candidate in (project_dir, sparse_dir) if candidate not in sys.path
)

import numpy as np
import torch
from torch import nn

from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor

In [2]:
from Sparse.functional import sparse_sigmoid

In [11]:
# Seed torch's global RNG so the random activations below -- and every value
# displayed from them, including the KL estimate further down -- are
# reproduced exactly on Restart & Run All.
torch.manual_seed(0)

# Random activations in [0, 1) laid out like a (B, C, H, W) conv output:
# batch=1, channels=2, 4x4 spatial.
output = torch.rand((1, 2, 4, 4))
# NOTE(review): arguments are presumably (sparsity target, activations, eps);
# confirm against Sparse.functional.sparse_sigmoid.
sparse_sigmoid.apply(torch.tensor(0.05), output, 1e-6)

tensor(0.5125)

In [12]:
import torch
from torch import sigmoid

def kl_divergence(p: float, q: torch.Tensor, apply_sigmoid=True, eps=None) -> torch.Tensor:
    '''
        Kullback-Leibler (KL) divergence KL(p || rho_hat) between a Bernoulli random
        variable with mean p and Bernoulli random variables whose means rho_hat are
        the average activations of q.

        How rho_hat is averaged depends on the rank of q:
            - 4-D tensor (B, C, H, W), e.g. a convolutional output: rho_hat is the
              mean over the H*W positions of each channel of each sample, i.e. a
              (B, C) tensor. The KL terms are summed over the batch, giving one
              value per channel.
            - any other rank >= 2, shape (B, ...): rho_hat is the mean over all
              non-batch elements of each sample, i.e. a (B,) tensor. The KL terms
              are summed over the batch into a single value.
            - 0-D / 1-D: treated as a single sample (B = 1).

        The batch axis is summed, not averaged, so the magnitude grows with the
        batch size.

        Params
        ------
            p: float
                Sparsity parameter, typically a small value close to zero (e.g. 0.05).

            q: torch.Tensor
                The output of a layer.

            apply_sigmoid: bool
                Whether to apply the sigmoid function to q to map it into (0, 1)
                before averaging. Pass False when q already holds values in [0, 1].

            eps: float, optional
                rho_hat is clamped to [eps, 1 - eps] so neither logarithm becomes
                infinite when an average activation saturates at 0 or 1 (values
                outside [0, 1] are clamped as well). Defaults to the machine
                epsilon of rho_hat's dtype.

        Return
        ------
            kl divergence estimation: torch.Tensor
                A 0-dim tensor for non-4-D inputs; a tensor of shape [C] for a
                4-D (B, C, H, W) input.
    '''
    # A 0-D / 1-D tensor has no batch axis; view it as one sample so that
    # flatten(1) below is valid.
    if q.dim() < 2:
        q = q.reshape(1, -1)

    # 4-D conv output: average each channel over its spatial positions.
    # Otherwise: average every non-batch element of each sample.
    start_dim = 2 if q.dim() == 4 else 1

    q = sigmoid(q) if apply_sigmoid else q  # sigmoid because we need the probability distributions

    rho_hat = q.flatten(start_dim).mean(dim=start_dim)

    # Keep rho_hat strictly inside (0, 1): exactly 0 or 1 would make
    # log(rho / rho_hat) or log((1 - rho) / (1 - rho_hat)) inf / nan.
    if eps is None:
        eps = torch.finfo(rho_hat.dtype).eps
    rho_hat = rho_hat.clamp(min=eps, max=1 - eps)

    # Same dtype/device as rho_hat; no float32 allocation + device copy.
    rho = torch.ones_like(rho_hat) * p
    kl = rho * torch.log(rho / rho_hat) + (1 - rho) * torch.log((1 - rho) / (1 - rho_hat))
    return kl.sum(dim=0)

In [14]:
# Flatten everything after the batch axis, so the (1, 2, 4, 4) activations are
# scored as one (1, 32) sample rather than per channel. `output` comes from
# torch.rand and is already in [0, 1), so no sigmoid is applied.
kl_divergence(p=0.05, q=torch.flatten(output, start_dim=1), apply_sigmoid=False)

tensor(0.3692)