In [1]:
"""
@article{ding2023classconditional,
  title={Class-Conditional Conformal Prediction with Many Classes},
  author={Ding, Tiffany and Angelopoulos, Anastasios N and Bates, 
          Stephen and Jordan, Michael I and Tibshirani, Ryan J},
  journal={arXiv preprint arXiv:2306.09335},
  year={2023}
}
@article{huang2023conformal,
  title={Conformal Prediction for Deep Classifier via Label Ranking},
  author={Huang, Jianguo and Xi, Huajun and Zhang, Linjun and Yao, Huaxiu and Qiu, Yue and Wei, Hongxin},
  journal={arXiv preprint arXiv:2310.06430},
  year={2023}
}
"""

'\n@article{ding2023classconditional,\n  title={Class-Conditional Conformal Prediction with Many Classes},\n  author={Ding, Tiffany and Angelopoulos, Anastasios N and Bates, \n          Stephen and Jordan, Michael I and Tibshirani, Ryan J},\n  journal={arXiv preprint arXiv:2306.09335},\n  year={2023}\n}\n@article{huang2023conformal,\n  title={Conformal Prediction for Deep Classifier via Label Ranking},\n  author={Huang, Jianguo and Xi, Huajun and Zhang, Linjun and Yao, Huaxiu and Qiu, Yue and Wei, Hongxin},\n  journal={arXiv preprint arXiv:2310.06430},\n  year={2023}\n}\n'

In [2]:
#pip install torchcp

In [47]:
from torchcp.classification.scores import THR
from torchcp.classification.scores import APS
from torchcp.classification.scores import RAPS
from torchcp.classification.scores import SAPS

from torchcp.classification.predictors import SplitPredictor

import os
import time
import os.path as osp

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader,Subset

from torchvision.datasets import CIFAR10
from torchvision import datasets
from torchvision import transforms
import torchvision
import torchvision.models as models

from transformers import CLIPProcessor, CLIPModel

from PIL import Image, ImageFilter
import matplotlib.pyplot as plt
from PIL import Image
from clip import clip

import argparse
import os

import torch
import torchvision
import torchvision.datasets as dset
import torchvision.transforms as trn
from tqdm import tqdm

from torchcp.classification.predictors import ClusterPredictor, ClassWisePredictor, SplitPredictor
from torchcp.classification.scores import THR, APS, SAPS, RAPS
from torchcp.classification import Metrics
from torchcp.utils import fix_randomness
from examples.common.dataset import build_dataset

In [26]:
# cifar10 transform
BATCH_SIZE = 1024

transform_cifar10_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_cifar10_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

train_set = torchvision.datasets.CIFAR10(root='../data', train=True,
                                        download=True, transform=transform_cifar10_train)
train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=2)

test_set = torchvision.datasets.CIFAR10(root='../data', train=False,
                                       download=True, transform=transform_cifar10_test)
test_dataloader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=2)

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

Files already downloaded and verified
Files already downloaded and verified


In [39]:
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 4, 3)  
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(4, 8, 3)  
        self.fc1 = nn.Linear(8 * 6 * 6, 32)
        self.fc2 = nn.Linear(32, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 8 * 6 * 6)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [40]:
model =torch.load("./model.pth")

In [44]:
model_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(model_device)
model.eval()

dataset = torchvision.datasets.CIFAR10(root='../data', train=True,download=True, transform=transform_cifar10_train)

cal_dataset, test_dataset = torch.utils.data.random_split(dataset, [25000, 25000])
cal_data_loader = torch.utils.data.DataLoader(cal_dataset, batch_size=1024, shuffle=False, pin_memory=True)
test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1024, shuffle=False, pin_memory=True)

# Options of score function: THR, APS, SAPS, RAPS
# Define a conformal prediction algorithm. Optional: SplitPredictor, ClusterPredictor, ClassWisePredictor

def test(score_function):
    predictor = SplitPredictor(score_function=THR(), model=model)

    # Calibrating the predictor with significance level as 0.1
    predictor.calibrate(cal_data_loader, alpha=0.1)

    #########################################
    # Evaluating the coverage rate and average set size on a given dataset.
    ########################################
    result_dict = predictor.evaluate(test_data_loader)
    print(result_dict["Coverage_rate"], result_dict["Average_size"])


Files already downloaded and verified


In [50]:
test(THR())
test(APS())

0.89832 4.38196
0.90116 4.43996


In [53]:
s = {1,2,4,8,16,32}
for i in s:
    test(SAPS(i))              #weight must be positive

0.89804 4.43744
0.9012 4.45904
0.90004 4.4214
0.90268 4.45996
0.90168 4.43832
0.9016 4.43072


In [56]:
#penalty, kreg=0      penalty must be postive and kreg must be a natural
s = {1,2,4,8,16,32,64,100}
for i in s:
    test(RAPS(i))

0.90156 4.39476
0.9002 4.432
0.90116 4.42608
0.90132 4.4156
0.90248 4.46168
0.90136 4.4348
0.89844 4.41972
0.90024 4.42884


In [57]:

for i in range(11):
    test(RAPS(8,i))

0.902 4.43176
0.90068 4.44996
0.89972 4.448
0.90196 4.47336
0.8988 4.40336
0.90152 4.482
0.8992 4.40908
0.90224 4.4506
0.9016 4.4056
0.90152 4.45036
0.9008 4.41828
