In [1]:
import torch
import torchvision.models as models
from torchvision.datasets import ImageFolder
from torchvision.models import ResNet50_Weights
import numpy as np
import torchvision.transforms as transforms        
from torch.utils.data import DataLoader
from torch.utils.data import Subset
from src.temperature_scaling import ModelWithTemperature
from src.saps_hyp_opt import lambda_optimization_saps

#  Reprocess
data_transform = transforms.Compose([
    transforms.CenterCrop(256),
    transforms.Resize(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 
])

sorted_val_path = "D:\\Download\\ImageNet-1K\\Validation_Set\\sorted_ImageNet_val"
dataset = ImageFolder(root=sorted_val_path, transform=data_transform)

# load pre-trained model 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1).to(device)

subset_size = len(dataset) // 10
indices = np.random.choice(len(dataset), subset_size, replace=False)
subset_dataset = Subset(dataset, indices)
train_loader = DataLoader(subset_dataset, batch_size=32, shuffle=False, num_workers=2)
model = ModelWithTemperature(model, temperature = 1.0).to(device)
model.set_temperature(train_loader)

model.eval()

Before temperature - NLL: 1.156, ECE: 0.020
Optimal temperature: 0.989
After temperature - NLL: 1.155, ECE: 0.020


ModelWithTemperature(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(

# $\alpha$=0.1

In [1]:
lambda_values = [0, 1e-3, 0.01, 0.05, 0.1, 0.15, 0.3, 0.5, 1.0]

# lambda optimization
print("Looking for optimal lambda...")
optimal_lambda = lambda_optimization_saps(model, dataset, lambda_values, device)
if optimal_lambda is None:
    print("No optimal lambda is found")
else:
    print(f"Optimal lambda is {optimal_lambda}\n")

Before temperature - NLL: 1.154, ECE: 0.026
Optimal temperature: 0.989
After temperature - NLL: 1.153, ECE: 0.026
Looking for optimal lambda...
Optimal lambda is 0.15



# $\alpha$=0.05

In [2]:
lambda_values = [0.01, 0.03, 0.05, 0.07, 0.1, 0.15, 0.3]

# lambda optimization
print("Looking for optimal lambda...")
optimal_lambda = lambda_optimization_saps(model, dataset, lambda_values, device, alpha=0.05)
if optimal_lambda is None:
    print("No optimal lambda is found")
else:
    print(f"Optimal lambda is {optimal_lambda}\n")

Looking for optimal lambda...
Optimal lambda is 0.05



# $\alpha$=0.2

In [3]:
lambda_values = [0.01, 0.05, 0.07, 0.1, 0.13, 0.15, 0.3]

# lambda optimization
print("Looking for optimal lambda...")
optimal_lambda = lambda_optimization_saps(model, dataset, lambda_values, device, alpha=0.2)
if optimal_lambda is None:
    print("No optimal lambda is found")
else:
    print(f"Optimal lambda is {optimal_lambda}\n")

Looking for optimal lambda...
Optimal lambda is 0.1

