In [1]:
import torch
import torchvision.transforms as transforms              
from torchvision.datasets import CIFAR10      
from torch.utils.data import DataLoader
from src.temperature_scaling import ModelWithTemperature
from src.raps_hyp_opt import lambda_optimization_raps, k_reg_optimization
from src.inception import inception_v3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dict_path = "C:\\Users\\jiayang\\ipynb\\trainedModel\\Inception_CIFAR10.pth"
model = inception_v3(pretrained=True, dict_path=dict_path).to(device)

# preprocess the images from CIFAR10
data_transform = transforms.Compose([
    transforms.ToTensor(),       
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))  
])

# load images from CIFAR10
dataset = CIFAR10(root="../../../data", train=False, download=True, transform=data_transform)

# Temperature Scaling
temp_scal_loader = DataLoader(dataset, batch_size=32, shuffle=True)
model = ModelWithTemperature(model, temperature=5.2).to(device)
model.set_temperature(temp_scal_loader)
model.eval()

Loading weights from: C:\Users\jiayang\ipynb\trainedModel\Inception_CIFAR10.pth
Files already downloaded and verified
Before temperature - NLL: 0.366, ECE: 0.029
Optimal temperature: 5.125
After temperature - NLL: 0.894, ECE: 0.410


ModelWithTemperature(
  (model): Inception3(
    (Conv2d_1a_3x3): BasicConv2d(
      (conv): Conv2d(3, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (Mixed_5b): InceptionA(
      (branch1x1): BasicConv2d(
        (conv): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
      (branch5x5_1): BasicConv2d(
        (conv): Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
      (branch5x5_2): BasicConv2d(
        (conv): Conv2d(48, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
        (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
      (branch3x3dbl_1): BasicConv2d(
       

# $\alpha$=0.1

In [1]:

lambda_values = [0, 1e-3, 0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.7, 1.0]
k_reg_values = [1, 2, 3, 4, 5, 7, 9]

# lambda optimization
print("Looking for optimal lambda...")
optimal_lambda = lambda_optimization_raps(model, dataset, lambda_values, k_reg = 4, device=device)
if optimal_lambda is None:
    print("No optimal lambda is found")
else:
    print(f"Optimal lambda is {optimal_lambda}\n")
    
    # k_reg optimization
    print("Looking for optimal k_reg...")
    optimal_k = k_reg_optimization(model, dataset, optimal_lambda, k_reg_values, device=device)
    
    if optimal_k is None:
        print("No optimal k_reg is found")
    else:
        print(f"Optimal k_reg is {optimal_k}")

Loading weights from: C:\Users\jiayang\ipynb\trainedModel\Inception_CIFAR10.pth
Files already downloaded and verified
Before temperature - NLL: 0.427, ECE: 0.055
Optimal temperature: 5.128
After temperature - NLL: 0.769, ECE: 0.341
Looking for optimal lambda...
Optimal lambda is 0.2

Looking for optimal k_reg...
Optimal k_reg is 2


In [2]:
lambda_values = [0.05, 0.1, 0.15, 0.2, 0.25, 0.3]
k_reg_values = [1, 2, 3, 4, 5, 6, 7]

# lambda optimization
print("Looking for optimal lambda...")
optimal_lambda = lambda_optimization_raps(model, dataset, lambda_values, k_reg = 4, device=device, alpha=0.05)
if optimal_lambda is None:
    print("No optimal lambda is found")
else:
    print(f"Optimal lambda is {optimal_lambda}\n")
    
    # k_reg optimization
    print("Looking for optimal k_reg...")
    optimal_k = k_reg_optimization(model, dataset, optimal_lambda, k_reg_values, device=device, alpha=0.05)
    
    if optimal_k is None:
        print("No optimal k_reg is found")
    else:
        print(f"Optimal k_reg is {optimal_k}")

Looking for optimal lambda...
Optimal lambda is 0.2

Looking for optimal k_reg...
Optimal k_reg is 2
