In [38]:
import os, sys
# Append the parent of the current working directory to sys.path (only once),
# presumably so the project-local `Sparse` package imported below resolves
# when the notebook is run from a sub-directory of the project.
project_dir = os.path.join(os.getcwd(),'..')
if project_dir not in sys.path:
    sys.path.append(project_dir)

from Sparse.modules.variational import LinearCD
import torch

In [137]:
from torch import nn
import torch.nn.functional as F
import numpy as np

class LinearCD(nn.Linear):
    r'''
        Linear layer with Concrete Dropout regularization on its input features.

        Code strongly inspired by:
            https://github.com/danielkelshaw/ConcreteDropout/blob/master/condrop/concrete_dropout.py

        NOTE: this notebook-local class shadows the `LinearCD` imported from
        `Sparse.modules.variational` at the top of the notebook.

        One drop logit (`logit_p`) is learned per input feature.

        * Training mode: every input element is multiplied by a relaxed
          (concrete) Bernoulli keep mask and divided by (1 - p). The per-sample
          sum of the keep mask over the feature axis, averaged over the
          remaining axes, is stored and exposed through `reg()` so it can be
          added to the loss as a sparsity penalty.
        * Eval mode: no noise is sampled. The weight columns of the features
          whose drop logit is not below logit(threshold) are zeroed; no
          1 / (1 - p) rescaling is applied on this path.

        Note the relationship between the weight regularizer (w_reg) and dropout
        regularization (drop_reg) in Concrete Dropout:

            w_reg/drop_reg = (l^2)/2

        with prior lengthscale l (number of in_features). The factor of two
        should be ignored for cross-entropy loss, and used only for the
        Euclidean loss. Neither w_reg nor drop_reg is a parameter of this
        class; the only penalty it provides is the one returned by `reg()`.

        Args:
            in_features: size of each input sample (last axis of the input).
            out_features: size of each output sample.
            bias: forwarded to nn.Linear.
            threshold: drop probability from which a feature is removed in
                eval mode.
            init_min, init_max: range of the initial drop probabilities.
    '''
    def __init__(self, in_features, out_features, bias=True, threshold=.95, init_min=0.05, init_max=0.1):
        super().__init__(in_features, out_features, bias)
        logit_init_min = np.log(init_min) - np.log(1. - init_min)
        logit_init_max = np.log(init_max) - np.log(1. - init_max)

        # Logit of the probability of deactivating (dropping) each input
        # feature, drawn uniformly in [logit(init_min), logit(init_max)).
        self.logit_p = nn.Parameter(torch.rand(in_features) * (logit_init_max - logit_init_min) + logit_init_min)
        self.logit_threshold = np.log(threshold) - np.log(1. - threshold)

        # Filled in by every training-mode forward pass; None until then so
        # that reg() can be called safely on a freshly built layer.
        self.activation_reg = None

    def forward(self, x):
        '''Apply the layer: noisy feature mask when training, hard mask otherwise.'''
        if self.training:
            return F.linear(self.concrete_bernoulli(x), self.weight, self.bias)

        # (in_features,) mask broadcast over the columns of the weight matrix.
        keep = (self.logit_p < self.logit_threshold).to(self.weight.dtype)
        return F.linear(x, self.weight * keep, self.bias)

    def concrete_bernoulli(self, x):
        '''
            Multiply `x` by a sampled relaxed-Bernoulli keep mask and rescale.

            The features are expected on the last axis of `x`. In training
            mode the regularizer returned by `reg()` is refreshed as a side
            effect.
        '''
        eps = 1e-8        # keeps the logarithms finite
        temperature = .1  # temperature of the concrete relaxation

        # rand_like follows the shape, dtype and device of x, replacing the
        # legacy torch.FloatTensor / torch.cuda.FloatTensor constructors.
        unif_noise = torch.rand_like(x)
        p = torch.sigmoid(self.logit_p)

        drop_logit = (torch.log(p + eps) - torch.log((1. - p) + eps)
                      + torch.log(unif_noise + eps) - torch.log((1. - unif_noise) + eps))
        drop_prob = torch.sigmoid(drop_logit / temperature)

        keep_mask = 1. - drop_prob
        retain_prob = 1. - p  # rescale factor typical for dropout

        if self.training:
            # Penalize the number of activated features. Summing over the last
            # axis (the feature axis logit_p broadcasts against) also covers
            # inputs that are not exactly 2-D.
            self.activation_reg = keep_mask.sum(dim=-1).mean()

        return torch.mul(x, keep_mask) / retain_prob

    def reg(self):
        '''
            Return the penalty stored by the last training-mode forward pass,
            or a zero scalar when no such pass has happened yet.
        '''
        if self.activation_reg is None:
            return self.logit_p.new_zeros(())
        return self.activation_reg

# Breast Cancer Wisconsin Dataset

In [152]:
from sklearn import datasets
from torch.utils.data import Dataset, DataLoader

class BreastCancer(Dataset):
    r'''
        Breast Cancer Wisconsin Dataset, loaded with
        `datasets.load_breast_cancer()` and held in memory as tensors.

        Attributes:
            data: float32 tensor with one row per sample.
            targets: int64 tensor of class indices, one per sample.

        Args:
            normalize: when True, every feature column is divided by its
                maximum value.
    '''
    def __init__(self, normalize=False):
        dataset = datasets.load_breast_cancer()
        self.data = torch.tensor(dataset.data).float()
        # nn.CrossEntropyLoss needs int64 class indices; the integer dtype of
        # the loaded targets is platform dependent, so cast explicitly.
        self.targets = torch.tensor(dataset.target).long()

        if normalize:
            col_max = torch.max(self.data, dim=0)[0]
            # A column whose maximum is 0 would otherwise turn into NaN/inf;
            # such a column is left unscaled.
            col_max = torch.where(col_max == 0, torch.ones_like(col_max), col_max)
            self.data = self.data / col_max

    def __getitem__(self, idx):
        '''Return the (features, target) pair at position `idx`.'''
        return self.data[idx], self.targets[idx]

    def __len__(self):
        '''Number of samples.'''
        return len(self.data)

In [157]:
import torch
from torch import nn

class Model(nn.Module):
    r'''
        Small classifier whose first layer is a LinearCD:

            LinearCD(in_features, nb_features, bias=False) -> ReLU
            -> Linear(nb_features, nb_features // 2) -> ReLU
            -> Linear(nb_features // 2, 2)

        Args:
            nb_features: width of the first hidden layer.
            in_features: number of input features; defaults to 30, the value
                that used to be hard-coded.
    '''
    def __init__(self, nb_features, in_features=30):
        super().__init__()
        hidden = nb_features // 2
        self.model = nn.Sequential(
            LinearCD(in_features, nb_features, bias=False),
            nn.ReLU(),
            nn.Linear(nb_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 2)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''Return the two output scores computed for `x`.'''
        return self.model(x)

In [154]:
# Dataset with every feature column divided by its maximum; passed to train() below.
dataset = BreastCancer(normalize=True)
# NOTE(review): train() builds its own DataLoader from `dataset`, so this
# `loader` is not used by the training call in this notebook.
loader = DataLoader(dataset, batch_size=64, shuffle=True)

In [209]:
from tqdm import tqdm

def train(model, dataset, batch_size = 64, n_epochs=10, lr=1e-3, reg_weight=1e-1):
    '''
        Train `model` on `dataset` with Adam and a cross-entropy loss plus the
        penalty exposed by the first layer of `model.model`.

        Args:
            model: module whose `model.model[0]` exposes a `reg()` method
                returning a scalar tensor (e.g. a LinearCD layer).
            dataset: map-style dataset yielding (inputs, targets) pairs.
            batch_size: mini-batch size; batches are reshuffled every epoch.
            n_epochs: number of passes over the dataset.
            lr: Adam learning rate.
            reg_weight: constant weight of the `reg()` penalty in the loss.

        Returns:
            The trained model, moved to CUDA when available (CPU otherwise)
            and left in training mode.
    '''
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    # The penalty is only refreshed by training-mode forward passes, so make
    # sure the model is not left in eval mode by the caller.
    model.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    epoch_iterator = tqdm(
            range(n_epochs),
            leave=True,
            unit="epoch",
            postfix={"tls": "%.4f" % 1},
        )

    for _ in epoch_iterator:
        for idx, (inputs, targets) in enumerate(loader):
            optimizer.zero_grad()

            inputs = inputs.to(device)
            targets = targets.to(device)
            pred = model(inputs)

            loss = criterion(pred, targets) + reg_weight * model.model[0].reg()
            loss.backward()
            optimizer.step()

            # Refresh the progress-bar loss every 10 batches.
            if idx % 10 == 0:
                epoch_iterator.set_postfix(tls="%.4f" % loss.item())

    return model

In [210]:
# Build the classifier with a 512-unit first layer and train it for 600 epochs.
model = Model(512)
model = train(model, dataset, n_epochs=600)

100%|██████████| 600/600 [00:18<00:00, 32.40epoch/s, tls=0.8433]


In [211]:
# True where the learned drop probability of an input feature is below 0.5.
torch.sigmoid(model.model[0].logit_p) < .5

tensor([False, False, False, False, False, False, False,  True, False,  True,
        False, False, False, False, False, False, False, False, False, False,
        False,  True, False, False, False, False, False,  True, False, False],
       device='cuda:0')

In [212]:
# Penalty stored by the first layer during its most recent training-mode forward pass.
model.model[0].reg()

tensor(7.7976, device='cuda:0', grad_fn=<MeanBackward0>)

In [214]:
# Learned drop probability of every input feature of the first layer.
torch.sigmoid(model.model[0].logit_p) 

tensor([0.8297, 0.7290, 0.8158, 0.8406, 0.7978, 0.8826, 0.7381, 0.1367, 0.7901,
        0.1992, 0.7448, 0.8966, 0.8553, 0.8460, 0.8912, 0.8829, 0.8925, 0.8951,
        0.8851, 0.8261, 0.8206, 0.4524, 0.8093, 0.7189, 0.7741, 0.8110, 0.7671,
        0.1566, 0.7976, 0.8744], device='cuda:0', grad_fn=<SigmoidBackward>)