In [3]:
import torch
import pickle
from torch.nn import Softmax
import numpy as np

# Load cached logits
with open('/home/coder/early-exit/models/tiny-imagenet/resnet18/blocks/cached_logits.pkl', 'rb') as f:
    cached_data = pickle.load(f)

# Initialize counters for correct predictions and confidence tracking
correct_predictions = [0, 0, 0, 0]  # One for each exit
confidence_sums = [0, 0, 0, 0]  # Track sum of confidence for each exit
total_samples = len(cached_data)

# Initialize bins for reliability diagram (10 bins from 0 to 1)
num_bins = 10
bins = np.linspace(0, 1, num_bins + 1)
bin_confidences = [[] for _ in range(4)]  # Store confidences for each exit
bin_accuracies = [[] for _ in range(4)]  # Store accuracies for each exit

# Softmax layer for converting logits to probabilities
softmax = Softmax(dim=0)

# Evaluate each sample
for sample in cached_data:
    label = sample['label']
    logits = sample['logits']
    
    # Check predictions from each exit
    for i, exit_logits in enumerate(logits):
        # Apply softmax and get prediction
        probs = softmax(exit_logits)
        pred = torch.argmax(probs)
        confidence = torch.max(probs).item()
        
        # Update confidence sum
        confidence_sums[i] += confidence
        
        # Track prediction correctness
        is_correct = (pred == label)
        if is_correct:
            correct_predictions[i] += 1
            
        # Add to appropriate bin
        bin_idx = np.digitize(confidence, bins) - 1
        if bin_idx < num_bins:  # Ensure we don't exceed array bounds
            bin_confidences[i].append(confidence)
            bin_accuracies[i].append(float(is_correct))

# Calculate accuracies and average confidences
accuracies = [correct / total_samples for correct in correct_predictions]
avg_confidences = [conf_sum / total_samples for conf_sum in confidence_sums]

print("Model calibration metrics for each exit:")
for i in range(4):
    print(f"\nExit {i+1}:")
    print(f"Accuracy: {accuracies[i]:.4f}")
    print(f"Average confidence: {avg_confidences[i]:.4f}")
    print(f"Calibration error: {abs(accuracies[i] - avg_confidences[i]):.4f}")
    
    # Calculate reliability per bin
    for bin_idx in range(num_bins):
        bin_mask = np.digitize(bin_confidences[i], bins) == bin_idx + 1
        if np.any(bin_mask):
            bin_acc = np.mean(np.array(bin_accuracies[i])[bin_mask])
            bin_conf = np.mean(np.array(bin_confidences[i])[bin_mask])
            print(f"Bin {bin_idx+1}: Confidence={bin_conf:.3f}, Accuracy={bin_acc:.3f}")


Model calibration metrics for each exit:

Exit 1:
Accuracy: 0.2625
Average confidence: 0.2391
Calibration error: 0.0234
Bin 1: Confidence=0.076, Accuracy=0.082
Bin 2: Confidence=0.146, Accuracy=0.163
Bin 3: Confidence=0.245, Accuracy=0.261
Bin 4: Confidence=0.345, Accuracy=0.397
Bin 5: Confidence=0.446, Accuracy=0.491
Bin 6: Confidence=0.544, Accuracy=0.607
Bin 7: Confidence=0.646, Accuracy=0.749
Bin 8: Confidence=0.747, Accuracy=0.772
Bin 9: Confidence=0.846, Accuracy=0.843
Bin 10: Confidence=0.943, Accuracy=0.967

Exit 2:
Accuracy: 0.3091
Average confidence: 0.2973
Calibration error: 0.0118
Bin 1: Confidence=0.080, Accuracy=0.084
Bin 2: Confidence=0.149, Accuracy=0.150
Bin 3: Confidence=0.246, Accuracy=0.242
Bin 4: Confidence=0.346, Accuracy=0.375
Bin 5: Confidence=0.445, Accuracy=0.471
Bin 6: Confidence=0.546, Accuracy=0.586
Bin 7: Confidence=0.649, Accuracy=0.683
Bin 8: Confidence=0.746, Accuracy=0.806
Bin 9: Confidence=0.846, Accuracy=0.898
Bin 10: Confidence=0.951, Accuracy=0.951