In [12]:
import os

import pandas as pd
import seaborn as sn

import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.nn import functional as F

# Directory MNIST is downloaded to / read from; overridable through the
# PATH_DATASETS environment variable, defaulting to the current directory.
PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
# Use a larger batch when CUDA is available, a smaller one otherwise.
BATCH_SIZE = 256 if torch.cuda.is_available() else 64

import sys
sys.path.append("../../src")
from concretedropout.pytorch import ConcreteDropout, ConcreteDropout2D, get_weight_regularizer, get_dropout_regularizer


Define the model and utility functions:

In [13]:
class MNISTModel(LightningModule):
    """Small convolutional MNIST classifier with Concrete Dropout on both conv layers.

    The module is self-contained: besides the network it also owns the data
    pipeline (download, train/val/test split and dataloaders), so it can be
    passed directly to ``Trainer.fit`` / ``Trainer.test``.

    Args:
        weight_regularizer: Weight-regularisation coefficient forwarded to every
            ``ConcreteDropout2D`` layer (see ``get_weight_regularizer``).
        dropout_regularizer: Dropout-regularisation coefficient forwarded to
            every ``ConcreteDropout2D`` layer (see ``get_dropout_regularizer``).
    """

    def __init__(self, weight_regularizer, dropout_regularizer):
        super().__init__()

        self.weight_regularizer = weight_regularizer
        self.dropout_regularizer = dropout_regularizer

        self.data_dir = PATH_DATASETS

        # Convert PIL images to tensors and standardise them.
        # (0.1307, 0.3081) are the commonly used MNIST mean / std values.
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

        # Block 1: 1 -> 32 channels. The conv layer is not called directly in
        # forward(); it is handed to the Concrete Dropout wrapper instead.
        self.conv1 = nn.Conv2d(1, 32, 3, padding="same")
        self.cd1 = ConcreteDropout2D(weight_regularizer=self.weight_regularizer, dropout_regularizer=self.dropout_regularizer)
        self.maxpool1 = nn.MaxPool2d(2)
        self.activation1 = nn.ReLU()

        # Block 2: 32 -> 64 channels, same structure as block 1.
        self.conv2 = nn.Conv2d(32, 64, 3, padding="same")
        self.cd2 = ConcreteDropout2D(weight_regularizer=self.weight_regularizer, dropout_regularizer=self.dropout_regularizer)
        self.maxpool2 = nn.MaxPool2d(2)
        self.activation2 = nn.ReLU()

        self.flatten = nn.Flatten()

        # Two 2x2 max-pools shrink the 28x28 MNIST image to 7x7 with 64
        # channels ("same" padding keeps the conv outputs at input size).
        self.l1 = torch.nn.Linear(7 * 7*64, 10)

        # NOTE(review): recent torchmetrics releases require an explicit task,
        # e.g. Accuracy(task="multiclass", num_classes=10) - confirm against the
        # installed version before upgrading.
        self.val_accuracy = Accuracy()
        self.test_accuracy = Accuracy()

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images.

        Each Concrete Dropout wrapper receives the input together with the
        layer it wraps and is responsible for invoking that layer.
        """
        x = self.cd1(x, self.conv1)
        x = self.activation1(x)
        x = self.maxpool1(x)

        x = self.cd2(x, self.conv2)
        x = self.activation2(x)
        x = self.maxpool2(x)

        x = self.flatten(x)

        x = self.l1(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_nb):
        """Compute the training loss: NLL plus the Concrete Dropout regularisation.

        The ``regularization`` attribute of every ``ConcreteDropout`` module is
        read after the forward pass and summed into the loss.
        """
        x, y = batch
        logits = self(x)

        # Allocate the accumulator on the same device as the model output so
        # the sum also works when training on GPU/other accelerators
        # (a plain torch.zeros(1) always lives on the CPU).
        reg = torch.zeros(1, device=logits.device)
        for module in self.modules():
            if isinstance(module, ConcreteDropout):
                # Out-of-place add keeps the autograd graph straightforward.
                reg = reg + module.regularization

        loss = F.nll_loss(logits, y) + reg
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the NLL loss and running accuracy on a validation batch."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.val_accuracy.update(preds, y)

        # Values passed to self.log are forwarded to the configured logger
        # and, with prog_bar=True, shown in the progress bar.
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.val_accuracy, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Log the NLL loss and running accuracy on a test batch."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.test_accuracy.update(preds, y)

        # Values passed to self.log are forwarded to the configured logger
        # and, with prog_bar=True, shown in the progress bar.
        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", self.test_accuracy, prog_bar=True)

    def configure_optimizers(self):
        """Optimise all parameters with Adam at a learning rate of 0.02."""
        return torch.optim.Adam(self.parameters(), lr=0.02)

    def prepare_data(self):
        """Download the MNIST train and test sets into ``self.data_dir``."""
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Create the datasets needed for the given stage (all when ``stage`` is None)."""
        # Train/val datasets: split the 60000 training images into 55000/5000.
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Test dataset.
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Return the training dataloader."""
        # Reshuffle every epoch so the batches are not identical across epochs.
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

    def val_dataloader(self):
        """Return the validation dataloader (fixed order)."""
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE, num_workers=4)

    def test_dataloader(self):
        """Return the test dataloader (fixed order)."""
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE, num_workers=4)

In [14]:
# Regulariser coefficients for the Concrete Dropout layers.
# 55000 matches the size of the training split created in MNISTModel.setup
# (presumably the dataset size N the regularisers are scaled by - confirm
# against the concretedropout helpers).
wr = get_weight_regularizer(55000, l=1e-2, tau=0.1) # tau and l may need tuning, depending on the dataset
dr = get_dropout_regularizer(55000, tau=0.1, cross_entropy_loss=True)

# Both conv layers of the model share the same pair of coefficients.
mnist_model = MNISTModel(wr, dr)

In [15]:
# Training runs on the CPU even when a GPU is present (accelerator is pinned
# to "cpu"); the progress bar is refreshed every 20 batches.
trainer = Trainer(
    accelerator="cpu",
    devices=1 if torch.cuda.is_available() else None,  # limiting for iPython runs
    max_epochs=3,
    callbacks=[TQDMProgressBar(refresh_rate=20)],
)

GPU available: True, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(


In [16]:
# Train the model; the dataloaders come from the LightningModule itself.
trainer.fit(mnist_model)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw



Missing logger folder: /mnt/d/Aure/Documenti/Github/ConcreteDropout-TF2/examples/PyTorch/lightning_logs

   | Name          | Type              | Params
-----------------------------------------------------
0  | conv1         | Conv2d            | 320   
1  | cd1           | ConcreteDropout2D | 1     
2  | maxpool1      | MaxPool2d         | 0     
3  | activation1   | ReLU              | 0     
4  | conv2         | Conv2d            | 18.5 K
5  | cd2           | ConcreteDropout2D | 1     
6  | maxpool2      | MaxPool2d         | 0     
7  | activation2   | ReLU              | 0     
8  | flatten       | Flatten           | 0     
9  | l1            | Linear            | 31.4 K
10 | val_accuracy  | Accuracy          | 0     
11 | test_accuracy | Accuracy          | 0     
-----------------------------------------------------
50.2 K    Trainable params
0         Non-trainable params
50.2 K    Total params
0.201     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [17]:
# Evaluate on the test set using the best checkpoint saved during fit.
trainer.test(ckpt_path='best')

Restoring states from the checkpoint path at /mnt/d/Aure/Documenti/Github/ConcreteDropout-TF2/examples/PyTorch/lightning_logs/version_0/checkpoints/epoch=2-step=645.ckpt
Loaded model weights from checkpoint at /mnt/d/Aure/Documenti/Github/ConcreteDropout-TF2/examples/PyTorch/lightning_logs/version_0/checkpoints/epoch=2-step=645.ckpt


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.9722999930381775
        test_loss           0.0905378982424736
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.0905378982424736, 'test_acc': 0.9722999930381775}]

Learned dropout probabilities of the Concrete Dropout layers:

In [18]:
# Print the learned dropout probability `p` of every Concrete Dropout layer.
for module in mnist_model.modules():
    if isinstance(module, ConcreteDropout):
        # .cpu() is a no-op on CPU tensors but is required before .numpy()
        # when the model lives on a GPU or another accelerator.
        print(module.p.detach().cpu().numpy())

[0.00322718]
[0.00941805]
