In [1]:
import os, sys
# Make the project root (parent of the notebook's working directory) importable.
# abspath normalizes away the '..' so the membership check below really
# prevents duplicate sys.path entries for the same directory.
project_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_dir not in sys.path:
    sys.path.append(project_dir)

from Sparse.modules.variational import LinearCD
import torch

In [70]:
from torch import nn
import torch.nn.functional as F
import numpy as np

class LinearCD(nn.Linear):
    r'''
        Linear layer with Concrete Dropout regularization applied to its input features.

        Code strongly inspired by:
            https://github.com/danielkelshaw/ConcreteDropout/blob/master/condrop/concrete_dropout.py

        Training mode: every input feature is multiplied by a relaxed (Concrete)
        Bernoulli keep-value whose drop probability is ``sigmoid(logit_p)``.
        Eval mode: the relaxation is replaced by a hard mask on the weight columns.
        Input feature ``j`` is kept iff its drop probability is below ``threshold``.

        Background note from the reference implementation (this class only
        implements the dropout-side penalty, see ``reg``): the weight regularizer
        (w_reg) and dropout regularization (drop_reg) are related by

            w_reg/drop_reg = (l^2)/2

        with prior lengthscale l. The factor of two should be ignored for
        cross-entropy loss and used only for the Euclidean loss.

        Args:
            in_features, out_features, bias: as for ``nn.Linear``.
            threshold: drop-probability cut-off for the eval-mode hard mask.
            init_min, init_max: range of the initial drop probabilities.
                They are sampled uniformly between the two values in logit space.
            temperature: temperature of the Concrete relaxation. It is used both
                when sampling the mask and in ``reg``.
    '''
    def __init__(self, in_features, out_features, bias=True, threshold=.1, init_min=.5, init_max=.5,
                 temperature=.1):
        super().__init__(in_features, out_features, bias)
        logit_init_min = np.log(init_min) - np.log(1. - init_min)
        logit_init_max = np.log(init_max) - np.log(1. - init_max)

        # Logit of the probability of dropping (deactivating) each input feature.
        self.logit_p = nn.Parameter(torch.rand(in_features) * (logit_init_max - logit_init_min) + logit_init_min)
        self.logit_threshold = np.log(threshold) - np.log(1. - threshold)
        self.temperature = temperature

    def forward(self, x):
        if self.training:
            return F.linear(self.concrete_bernoulli(x), self.weight, self.bias)

        # Hard mask over weight columns (one per input feature). Keep feature j
        # iff logit_p[j] < logit_threshold, i.e. its drop probability < threshold.
        keep = (self.logit_p < self.logit_threshold).to(self.weight.dtype)
        return F.linear(x, self.weight * keep, self.bias)

    def concrete_bernoulli(self, x):
        '''Multiply ``x`` by a freshly sampled relaxed keep-mask (independent of train/eval mode).'''
        eps = 1e-8
        # Uniform [0, 1) noise matching x's shape, dtype and device. This replaces
        # the legacy torch.FloatTensor / torch.cuda.FloatTensor constructors.
        unif_noise = torch.rand_like(x)

        p = torch.sigmoid(self.logit_p)
        drop_prob = (torch.log(p + eps) - torch.log((1 - p) + eps)
                     + torch.log(unif_noise + eps) - torch.log((1. - unif_noise) + eps))
        drop_prob = torch.sigmoid(drop_prob / self.temperature)

        # NOTE: unlike standard dropout, kept inputs are NOT rescaled by 1 / (1 - p).
        # The eval-mode hard mask applies no rescaling either.
        random_tensor = 1 - drop_prob
        return torch.mul(x, random_tensor)

    def reg(self):
        '''Sum over features of the relaxed "keep" value 1 - sigmoid(logit(p) / temperature).

        The value is close to 1 for features with a small drop probability and
        close to 0 for a large one. The sum therefore roughly counts the kept features.
        '''
        eps = 1e-6
        p = torch.sigmoid(self.logit_p)
        logit_drop = torch.log(p + eps) - torch.log((1 - p) + eps)
        keep = 1 - torch.sigmoid(logit_drop / self.temperature)
        return torch.sum(keep)

# Breast Cancer Wisconsin Dataset

In [71]:
from sklearn import datasets
from torch.utils.data import Dataset, DataLoader

class BreastCancer(Dataset):
    r'''
        Breast Cancer Wisconsin Dataset

        Wraps scikit-learn's ``load_breast_cancer`` as a torch ``Dataset``.
        Each item is a ``(features, target)`` pair, where ``features`` is a
        float32 row of the data matrix.

        Args:
            normalize: if True, divide every feature column by its maximum
                over the whole dataset.
    '''
    def __init__(self, normalize=False):
        raw = datasets.load_breast_cancer()
        features = torch.tensor(raw.data).float()
        if normalize:
            column_max, _ = features.max(dim=0)
            features = features / column_max
        self.data = features
        self.targets = torch.tensor(raw.target)

    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.targets[idx]
        return sample, label

    def __len__(self):
        return len(self.data)

In [72]:
import torch
from torch import nn

class Model(nn.Module):
    r'''
        MLP classifier whose first layer is a ``LinearCD`` feature-selection layer.

        Args:
            nb_features: width of the first hidden layer. The second hidden
                layer has ``nb_features // 2`` units.
            threshold: drop-probability cut-off passed to ``LinearCD``
                (used by its eval-mode hard mask).
            in_features: number of input features (default 30, the width of
                the Breast Cancer Wisconsin data used in this notebook).
            n_classes: number of output logits.
    '''
    def __init__(self, nb_features, threshold = .75, in_features=30, n_classes=2):
        super().__init__()
        self.model = nn.Sequential(
            LinearCD(in_features, nb_features, bias=False, threshold=threshold),
            nn.ReLU(),
            nn.Linear(nb_features, nb_features//2),
            nn.ReLU(),
            nn.Linear(nb_features//2, n_classes)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''Return the unnormalized class logits for a batch ``x``.'''
        return self.model(x)

In [73]:
SPLIT_SEED = 0  # fixes the train/eval split so re-runs are reproducible

# NOTE: normalization divides by the per-feature max over the FULL dataset
# (computed before the split), so eval-set statistics leak into the scaling.
dataset = BreastCancer(normalize=True)

eval_len = len(dataset) // 5 # 20% of the dataset
train_set, eval_set = torch.utils.data.random_split(
    dataset,
    [len(dataset) - eval_len, eval_len],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Evaluation loader: order is irrelevant for evaluation, so keep it deterministic.
loader = DataLoader(eval_set, batch_size=128, shuffle=False)

In [74]:
from tqdm import tqdm

def train(model, dataset, batch_size = 128, n_epochs=10, lr=1e-2,
          reg_init=1e-6, reg_step=.5e-5, reg_max=1e-3):
    r'''
        Train ``model`` with Adam on cross-entropy plus a warmed-up LinearCD penalty.

        The penalty is the sum of ``reg()`` over every ``LinearCD`` module in
        ``model``. Its weight starts at ``reg_init``. At the start of each
        epoch it is increased by ``reg_step``, capped at ``reg_max``, so the
        first epoch already uses ``reg_init + reg_step``.

        Args:
            model: the network to train. It is moved to CUDA when available, otherwise to CPU.
            dataset: a map-style dataset of ``(inputs, target)`` pairs.
            batch_size: mini-batch size.
            n_epochs: number of passes over ``dataset``.
            lr: Adam learning rate.
            reg_init, reg_step, reg_max: penalty-weight warm-up schedule.

        Returns:
            The trained model (on the selected device).
    '''
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    # Make sure LinearCD layers sample concrete masks. In eval mode they use a
    # hard mask that gives logit_p no gradient through the data loss.
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    cd_modules = [module for module in model.modules() if isinstance(module, LinearCD)]

    reg_weight = reg_init
    epoch_iterator = tqdm(
            range(n_epochs),
            leave=True,
            unit="epoch",
            postfix={"tls": "%.4f" % 1},
        )

    for _ in epoch_iterator:
        reg_weight = min(reg_weight + reg_step, reg_max)
        for idx, (inputs, targets) in enumerate(loader):
            optimizer.zero_grad()

            inputs = inputs.to(device)
            targets = targets.to(device)
            pred = model(inputs)

            reg_value = sum(module.reg() for module in cd_modules)
            loss = criterion(pred, targets) + reg_weight * reg_value
            loss.backward()
            optimizer.step()

            if idx % 10 == 0:
                epoch_iterator.set_postfix(tls="%.4f" % loss.item())

    # Report the final penalty weight reached by the warm-up schedule.
    print(reg_weight)
    return model

In [96]:
# Seed torch's global RNGs so weight/logit_p init, shuffling and mask noise are
# reproducible (CUDA kernels may still introduce small nondeterminism).
torch.manual_seed(0)
model = Model(512, threshold=.1)
model = train(model, train_set, n_epochs=500)

100%|██████████| 500/500 [00:09<00:00, 50.27epoch/s, tls=0.1075]

0.001





In [97]:
# Per-feature drop probability learned by the LinearCD input layer.
model.model[0].logit_p.sigmoid()

tensor([0.7144, 0.5733, 0.7956, 0.1486, 0.7911, 0.6613, 0.1491, 0.0516, 0.8569,
        0.6134, 0.1162, 0.6057, 0.1518, 0.0117, 0.6140, 0.1717, 0.6347, 0.5819,
        0.6631, 0.0370, 0.4127, 0.0559, 0.1536, 0.0157, 0.6988, 0.2164, 0.0713,
        0.0347, 0.6146, 0.8080], device='cuda:0', grad_fn=<SigmoidBackward>)

In [98]:
# Features whose drop probability is below 0.1 (the threshold the model was
# built with); True = kept by the eval-mode hard mask. `.detach()` replaces
# the discouraged `.data` accessor.
torch.sigmoid(model.model[0].logit_p).detach().cpu().numpy() < .1

array([False, False, False, False, False, False, False,  True, False,
       False, False, False, False,  True, False, False, False, False,
       False,  True, False,  True, False,  True, False, False,  True,
        True, False, False])

In [99]:
# Random probe batch on the model's device (works on CPU-only machines too).
x = torch.rand(10, 30, device=model.model[0].logit_p.device)
model.model[0].concrete_bernoulli(x)[0]

tensor([4.3115e-04, 9.5253e-08, 0.0000e+00, 9.3902e-01, 7.7253e-04, 0.0000e+00,
        7.3897e-01, 6.7087e-01, 0.0000e+00, 0.0000e+00, 2.6603e-01, 3.2602e-01,
        4.3707e-01, 9.6831e-01, 3.3489e-01, 5.3411e-01, 3.0525e-02, 3.0429e-01,
        3.2627e-01, 6.2648e-01, 4.1079e-01, 6.3978e-03, 2.3534e-01, 4.4083e-01,
        9.6102e-02, 1.7012e-01, 5.9307e-01, 2.9681e-01, 1.4406e-03, 0.0000e+00],
       device='cuda:0', grad_fn=<SelectBackward>)

In [100]:
# NOTE: `concrete_bernoulli` draws fresh noise on every call and does not look
# at train/eval mode, so this comparison is stochastic. Features that come out
# (numerically) unchanged are those whose relaxed keep-value is ~1.
# isclose avoids relying on exact float equality.
model.eval()
torch.isclose(x[0], model.model[0].concrete_bernoulli(x)[0])

tensor([False, False, False,  True, False, False, False, False, False, False,
        False, False, False,  True, False, False,  True, False, False, False,
        False,  True, False,  True, False,  True,  True,  True, False, False],
       device='cuda:0')

In [101]:
# Rank input features by ascending drop probability (most important first).
features_score, index = torch.sigmoid(model.model[0].logit_p).detach().sort()

# `feature_names` is already in column order; no need to build a DataFrame.
features_names = datasets.load_breast_cancer().feature_names[index.cpu().numpy()]

print('Features:{}'.format(features_names))
print('Features Score:{}'.format(features_score))

Features:['area error' 'worst area' 'worst concave points'
 'fractal dimension error' 'mean concave points' 'worst texture'
 'worst concavity' 'radius error' 'mean area' 'mean concavity'
 'perimeter error' 'worst perimeter' 'compactness error'
 'worst compactness' 'worst radius' 'mean texture' 'concave points error'
 'texture error' 'mean fractal dimension' 'smoothness error'
 'worst symmetry' 'concavity error' 'mean compactness' 'symmetry error'
 'worst smoothness' 'mean radius' 'mean smoothness' 'mean perimeter'
 'worst fractal dimension' 'mean symmetry']
Features Score:tensor([0.0117, 0.0157, 0.0347, 0.0370, 0.0516, 0.0559, 0.0713, 0.1162, 0.1486,
        0.1491, 0.1518, 0.1536, 0.1717, 0.2164, 0.4127, 0.5733, 0.5819, 0.6057,
        0.6134, 0.6140, 0.6146, 0.6347, 0.6613, 0.6631, 0.6988, 0.7144, 0.7911,
        0.7956, 0.8080, 0.8569], device='cuda:0', grad_fn=<SortBackward0>)


In [104]:
# Gather the WHOLE eval set (not just the first batch) into single tensors.
x, y = (torch.cat(parts) for parts in zip(*loader))

In [108]:
threshold = .1
# Store the logit as a plain float, the same way LinearCD.__init__ does.
model.model[0].logit_threshold = float(np.log(threshold) - np.log(1. - threshold))
model.eval()
device = model.model[0].logit_p.device
with torch.no_grad():
    # argmax over the logits equals argmax over their softmax.
    preds = model(x.to(device)).argmax(dim=1)
preds

tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1,
        1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0,
        1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0,
        1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1,
        1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1], device='cuda:0')

In [109]:
# Ground-truth labels for the samples in `x`, for comparison with the predictions.
y

tensor([0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1,
        1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0,
        1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0,
        1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1,
        1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1])

In [110]:
# Accuracy of the eval-mode (hard-masked) model on `x`, `y`; device-agnostic.
device = model.model[0].logit_p.device
with torch.no_grad():
    accuracy = (model(x.to(device)).argmax(dim=1) == y.to(device)).float().mean()
accuracy

tensor(0.9823, device='cuda:0')