In [25]:
from torchcp.classification.scores import THR,APS,SAPS,RAPS
from torchcp.classification.predictors import SplitPredictor,ClusterPredictor,ClassWisePredictor

import torch
from torch.utils.data import DataLoader

from model_prepare import load_resnet18,load_densenet,training_model
from data_prepare import load_data

import warnings
warnings.filterwarnings('ignore')

from tqdm import tqdm

In [26]:
# Use the first CUDA GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

## Prepare data

In [27]:
# Fashion-MNIST (10 classes).
# load_data's result is unpacked in the order: calibration, train, test.
cal_data_fashion_mnist, train_data_fashion_mnist, test_data_fashion_mnist = load_data(
    "fashionmnist", seed=42
)
# Only the training loader shuffles; the calibration and test loaders do not.
train_dataloader_fashion = DataLoader(
    dataset=train_data_fashion_mnist, batch_size=64, shuffle=True
)
cal_dataloader_fashion = DataLoader(
    dataset=cal_data_fashion_mnist, batch_size=64, shuffle=False
)
test_dataloader_fashion = DataLoader(
    dataset=test_data_fashion_mnist, batch_size=64, shuffle=False
)

# CIFAR-100 (100 classes), same split order and loader settings as above.
cal_data_cifar100, train_data_cifar100, test_data_cifar100 = load_data(
    "cifar100", seed=42
)
train_dataloader_cifar100 = DataLoader(
    dataset=train_data_cifar100, batch_size=64, shuffle=True
)
cal_dataloader_cifar100 = DataLoader(
    dataset=cal_data_cifar100, batch_size=64, shuffle=False
)
test_dataloader_cifar100 = DataLoader(
    dataset=test_data_cifar100, batch_size=64, shuffle=False
)

Files already downloaded and verified
Files already downloaded and verified


## Prepare models

In [28]:
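# NOTE: training code kept commented out for reference. The next cell loads saved
# checkpoints from models/ instead (presumably the ones these calls wrote; verify).
# resnet18 = load_resnet18(10)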
# densenet = load_densenet(10)

# # training_model(train_dataloader=train_dataloader_fashion,test_dataloader=test_dataloader_fashion,model=densenet,saving_name="densenet121_fashionmnist",num_epochs=10)

# resnet18 = load_resnet18(100)
# densenet = load_densenet(100)

# training_model(train_dataloader=train_dataloader_cifar100,test_dataloader=test_dataloader_cifar100,model=resnet18,saving_name="resnet18_cifar100",num_epochs=30)
# # training_model(train_dataloader=train_dataloader_cifar100,test_dataloader=test_dataloader_cifar100,model=densenet,saving_name="densenet121_cifar100",num_epochs=30)

In [29]:
# Load the trained models from disk.
#
# map_location=device places every model on the device chosen earlier in the notebook,
# so a checkpoint saved on a GPU still loads on a CPU-only machine (and vice versa).
#
# These files hold whole pickled model objects (eval_cp calls .eval() on them directly):
#   * torch.load unpickles arbitrary Python objects, so only load checkpoints you trust.
#   * NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which rejects
#     full-model pickles; pass weights_only=False there. Confirm against the installed version.
densenet121_cifar100 = torch.load("models/densenet121_cifar100.pth", map_location=device)
densenet121_fashionmnist = torch.load("models/densenet121_fashionmnist.pth", map_location=device)

resnet18_cifar100 = torch.load("models/resnet18_cifar100.pth", map_location=device)
resnet18_fashionminist = torch.load("models/resnet18_fashionmnist.pth", map_location=device)

## Evaluate conformal predictors

In [30]:
# predictor = SplitPredictor(score_function = THR(), model = resnet18_fashionminist)

# # Calibrating the predictor with significance level as alpha
# predictor.calibrate(cal_dataloader_fashion, 0.1)
# print("calibration finished")

In [31]:
def eval_cp(model_name, model, cal_dataloader, test_dataloader, alpha=0.1,
            score_functions=None, predictor_classes=None):
    """Calibrate and evaluate every (score function, predictor) pairing for one model.

    For each pairing a fresh predictor is built around ``model``, calibrated on
    ``cal_dataloader`` at significance level ``alpha``, and evaluated on
    ``test_dataloader``. Progress and metrics are printed as before, and are
    also returned so they can be tabulated (e.g. ``pd.DataFrame(results)``).

    Each pairing calls ``calibrate`` and ``evaluate`` afresh, so runtime grows with
    ``len(score_functions) * len(predictor_classes)``.

    Args:
        model_name: Label used only in the printed header and the returned records.
        model: Model handed to each predictor; ``model.eval()`` is called once up front.
        cal_dataloader: Loader passed to ``predictor.calibrate``.
        test_dataloader: Loader passed to ``predictor.evaluate``.
        alpha: Significance level passed to ``predictor.calibrate``. Defaults to 0.1.
        score_functions: Score-function instances to try. ``None`` selects
            ``[THR(), APS(), SAPS(weight=0.2), RAPS(1, 0)]``.
        predictor_classes: Predictor classes to try; each is called as
            ``cls(score_function=..., model=...)``. ``None`` selects
            ``[SplitPredictor, ClusterPredictor, ClassWisePredictor]``.

    Returns:
        A list with one dict per pairing, in evaluation order, with keys
        ``model``, ``score``, ``predictor``, ``alpha``, ``coverage_rate`` and
        ``average_size``. In a notebook, assign the result or end the call with
        ``;`` to avoid echoing the list.

    Raises:
        KeyError: If ``predictor.evaluate`` returns a mapping without
            ``'Coverage_rate'`` or ``'Average_size'``.
    """
    model.eval()

    # Fall back to the full grid when the caller does not narrow it down.
    if score_functions is None:
        score_functions = [THR(), APS(), SAPS(weight=0.2), RAPS(1, 0)]
    if predictor_classes is None:
        predictor_classes = [SplitPredictor, ClusterPredictor, ClassWisePredictor]

    results = []
    for score in tqdm(score_functions):
        for class_predictor in predictor_classes:
            predictor = class_predictor(score_function=score, model=model)

            # Calibrate the predictor with significance level alpha.
            predictor.calibrate(cal_dataloader, alpha)
            print("calibration finished")

            # Coverage rate and average prediction-set size on the test loader.
            result_dict = predictor.evaluate(test_dataloader)
            coverage_rate = result_dict['Coverage_rate']
            average_size = result_dict['Average_size']

            print(f"----------------{model_name},{score.__class__.__name__},{predictor.__class__.__name__},alpha = {alpha}------------------")
            print(f"coverage_rate:{coverage_rate}, Average_size:{average_size}")

            results.append({
                "model": model_name,
                "score": score.__class__.__name__,
                "predictor": predictor.__class__.__name__,
                "alpha": alpha,
                "coverage_rate": coverage_rate,
                "average_size": average_size,
            })

    return results


In [32]:
# Evaluate the checkpoint loaded from models/resnet18_fashionmnist.pth
# on the Fashion-MNIST calibration and test loaders.
eval_cp(
    model_name="resnet18_fashionminist",
    model=resnet18_fashionminist,
    cal_dataloader=cal_dataloader_fashion,
    test_dataloader=test_dataloader_fashion,
)

  0%|          | 0/4 [00:00<?, ?it/s]

calibration finished
----------------resnet18_fashionminist,THR,SplitPredictor,alpha = 0.1------------------
coverage_rate:0.8925, Average_size:1.05825
calibration finished
----------------resnet18_fashionminist,THR,ClusterPredictor,alpha = 0.1------------------
coverage_rate:0.8925, Average_size:1.0585
calibration finished


 25%|██▌       | 1/4 [04:16<12:48, 256.05s/it]

----------------resnet18_fashionminist,THR,ClassWisePredictor,alpha = 0.1------------------
coverage_rate:0.900625, Average_size:1.182
calibration finished
----------------resnet18_fashionminist,APS,SplitPredictor,alpha = 0.1------------------
coverage_rate:0.897125, Average_size:1.2195
calibration finished
----------------resnet18_fashionminist,APS,ClusterPredictor,alpha = 0.1------------------
coverage_rate:0.8935, Average_size:1.213375
calibration finished


 50%|█████     | 2/4 [07:50<07:43, 231.78s/it]

----------------resnet18_fashionminist,APS,ClassWisePredictor,alpha = 0.1------------------
coverage_rate:0.90175, Average_size:1.349
calibration finished
----------------resnet18_fashionminist,SAPS,SplitPredictor,alpha = 0.1------------------
coverage_rate:0.892875, Average_size:1.187375
calibration finished
----------------resnet18_fashionminist,SAPS,ClusterPredictor,alpha = 0.1------------------
coverage_rate:0.891625, Average_size:1.184625
calibration finished


 75%|███████▌  | 3/4 [11:21<03:42, 222.06s/it]

----------------resnet18_fashionminist,SAPS,ClassWisePredictor,alpha = 0.1------------------
coverage_rate:0.892125, Average_size:1.38425
calibration finished
----------------resnet18_fashionminist,RAPS,SplitPredictor,alpha = 0.1------------------
coverage_rate:0.8955, Average_size:1.073
calibration finished
----------------resnet18_fashionminist,RAPS,ClusterPredictor,alpha = 0.1------------------
coverage_rate:0.8955, Average_size:1.0745
calibration finished


100%|██████████| 4/4 [14:50<00:00, 222.52s/it]

----------------resnet18_fashionminist,RAPS,ClassWisePredictor,alpha = 0.1------------------
coverage_rate:0.893375, Average_size:1.329625





In [33]:
# Evaluate the checkpoint loaded from models/densenet121_fashionmnist.pth
# on the Fashion-MNIST calibration and test loaders.
eval_cp(
    model_name="densenet121_fashionminist",
    model=densenet121_fashionmnist,
    cal_dataloader=cal_dataloader_fashion,
    test_dataloader=test_dataloader_fashion,
)

  0%|          | 0/4 [00:00<?, ?it/s]

calibration finished
----------------densenet121_fashionminist,THR,SplitPredictor,alpha = 0.1------------------
coverage_rate:0.9005, Average_size:1.181375
calibration finished
----------------densenet121_fashionminist,THR,ClusterPredictor,alpha = 0.1------------------
coverage_rate:0.900125, Average_size:1.179125
calibration finished


 25%|██▌       | 1/4 [51:56<2:35:49, 3116.45s/it]

----------------densenet121_fashionminist,THR,ClassWisePredictor,alpha = 0.1------------------
coverage_rate:0.90325, Average_size:1.205125
calibration finished
----------------densenet121_fashionminist,APS,SplitPredictor,alpha = 0.1------------------
coverage_rate:0.900375, Average_size:1.357875
calibration finished
----------------densenet121_fashionminist,APS,ClusterPredictor,alpha = 0.1------------------
coverage_rate:0.902, Average_size:1.352125
calibration finished


 50%|█████     | 2/4 [1:46:56<1:47:28, 3224.38s/it]

----------------densenet121_fashionminist,APS,ClassWisePredictor,alpha = 0.1------------------
coverage_rate:0.903125, Average_size:1.40425
calibration finished
----------------densenet121_fashionminist,SAPS,SplitPredictor,alpha = 0.1------------------
coverage_rate:0.90075, Average_size:1.338
calibration finished
----------------densenet121_fashionminist,SAPS,ClusterPredictor,alpha = 0.1------------------
coverage_rate:0.9005, Average_size:1.33825
calibration finished


 75%|███████▌  | 3/4 [2:43:02<54:48, 3288.88s/it]  

----------------densenet121_fashionminist,SAPS,ClassWisePredictor,alpha = 0.1------------------
coverage_rate:0.904375, Average_size:1.501125
calibration finished
----------------densenet121_fashionminist,RAPS,SplitPredictor,alpha = 0.1------------------
coverage_rate:0.900375, Average_size:1.205625
calibration finished
----------------densenet121_fashionminist,RAPS,ClusterPredictor,alpha = 0.1------------------
coverage_rate:0.8975, Average_size:1.20325
calibration finished


100%|██████████| 4/4 [3:39:13<00:00, 3288.32s/it]

----------------densenet121_fashionminist,RAPS,ClassWisePredictor,alpha = 0.1------------------
coverage_rate:0.89875, Average_size:1.458





In [34]:
# Evaluate the checkpoint loaded from models/resnet18_cifar100.pth
# on the CIFAR-100 calibration and test loaders.
eval_cp(
    model_name="resnet18_cifar100",
    model=resnet18_cifar100,
    cal_dataloader=cal_dataloader_cifar100,
    test_dataloader=test_dataloader_cifar100,
)

  0%|          | 0/4 [00:00<?, ?it/s]

calibration finished
----------------resnet18_cifar100,THR,SplitPredictor,alpha = 0.1------------------
coverage_rate:0.89925, Average_size:16.645


  0%|          | 0/4 [03:58<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# Evaluate the checkpoint loaded from models/densenet121_cifar100.pth
# on the CIFAR-100 calibration and test loaders.
eval_cp(
    model_name="densenet121_cifar100",
    model=densenet121_cifar100,
    cal_dataloader=cal_dataloader_cifar100,
    test_dataloader=test_dataloader_cifar100,
)