In [1]:

import torchvision
import torch
import numpy as np

  return torch._C._cuda_getDeviceCount() > 0


In [2]:
train_set = torchvision.datasets.MNIST(
   root="./data",
   download=True,
   train=True,
   transform=torchvision.transforms.ToTensor()
)
test_set = torchvision.datasets.MNIST(
   root="./data",
   download=True,
   train=False,
   transform=torchvision.transforms.ToTensor()
)

amount = 500
batch_size = 100
epochs = 1
workers = 8

shuffled_ind = np.random.permutation(np.arange(len(train_set)))

calibration_set = torch.utils.data.Subset(train_set, shuffled_ind[:amount])
train_set = torch.utils.data.Subset(train_set, shuffled_ind[amount:])

# create Dataloaders
train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, batch_size=batch_size, num_workers=8)
calibration_loader = torch.utils.data.DataLoader(calibration_set, shuffle=True, batch_size=batch_size, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, batch_size=batch_size, num_workers=8)

In [3]:
#Bad Net
import pytorch_lightning as pl
from torch import nn
from torch.nn import functional
import torch

class BadNet(pl.LightningModule):

    def __init__(self):
        super(BadNet, self).__init__()        
        self.conv1 = torch.nn.Sequential(         
            nn.Conv2d(
                in_channels=1,              
                out_channels=16,            
                kernel_size=5,              
                stride=1,                   
            ),                              
            nn.ReLU(),                      
            nn.AvgPool2d(kernel_size=(2, 2), stride=2),    
        )
        self.conv2 = nn.Sequential(         
            nn.Conv2d(
                in_channels=16, 
                out_channels=32, 
                kernel_size=5, 
                stride=1),     
            nn.ReLU(),                      
            nn.AvgPool2d(kernel_size=(2, 2), stride=2),                
        )
        self.f1 = nn.Sequential(
            nn.Linear(in_features=(32 * 4 * 4), out_features=512),
            nn.ReLU(),
        )
        # fully connected layer, output 10 classes
        self.out = nn.Sequential(
            nn.Linear(in_features=512, out_features=10),
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)        
        # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
        x = x.view(x.size(0), -1)       
        x = self.f1(x)
        output = self.out(x)
        return output

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        x, y = batch
        out = self.forward(x)

        loss = nn.functional.cross_entropy(out, y)
        self.log("train_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        # this is the test loop
        x, y = batch
        out = self.forward(x)
        test_loss = nn.functional.cross_entropy(out, y)
        self.log("test_loss", test_loss)

    def backward(self, loss, optimizer, optimizer_idx):
        loss.backward(retain_graph=True)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        out = self.forward(x)
        validiation_loss = nn.functional.cross_entropy(out, y)
        self.log("validation_loss", validiation_loss)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

In [4]:
bad_net = BadNet()
trainer = pl.Trainer(max_epochs=epochs)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [5]:
trainer = trainer.fit(bad_net, train_loader)

  rank_zero_warn("You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.")

  | Name  | Type       | Params
-------------------------------------
0 | conv1 | Sequential | 416   
1 | conv2 | Sequential | 12.8 K
2 | f1    | Sequential | 262 K 
3 | out   | Sequential | 5.1 K 
-------------------------------------
281 K     Trainable params
0         Non-trainable params
281 K     Total params
1.124     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

In [7]:
from tqdm.notebook import tqdm

def get_logits(model, dataloader):
    logits = []
    labels = []
    model.eval()

    for x, y in tqdm(dataloader):
        with torch.no_grad():
            logits.append(model(x))
            labels.append(y)
    
    return torch.cat(logits), torch.cat(labels)

In [22]:
def calc_prediction_sets(model, calibration_loader, test_loader, alpha=0.3):
    calib_logits, calib_y = get_logits(model, calibration_loader)
    n = len(calibration_loader)

    probs = calib_logits.softmax(dim=1)

    scores = 1 - torch.gather(probs, 1, calib_y.unsqueeze(dim=1))
    qhat = torch.quantile(scores, np.ceil((n + 1) * (1 - alpha))/n)

    test_logits, test_y = get_logits(model, test_loader)
    smx = test_logits.softmax(dim=1)

    return (smx > (1 - qhat)).nonzero()

In [23]:
prediction_sets = calc_prediction_sets(bad_net, calibration_loader, test_loader)
print(prediction_sets)

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

tensor([[   0,    7],
        [   1,    2],
        [   2,    1],
        ...,
        [9998,    5],
        [9998,    8],
        [9999,    6]])


In [34]:
nps = prediction_sets.detach().numpy()

x0 = nps[::, 0]
x1 = nps[::, 1]
print(np.unique(x0).shape)
print(np.unique(x1))
print(nps)

(10000,)
[0 1 2 3 4 5 6 7 8 9]
[[   0    7]
 [   1    2]
 [   2    1]
 ...
 [9998    5]
 [9998    8]
 [9999    6]]
