In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18
from torch import nn
import numpy as np
from sklearn.decomposition import PCA
from sklearn.neighbors import NearestNeighbors
from scipy.stats import entropy
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

# Preprocessing pipeline attached to the dataset object.
# NOTE(review): the arrays used below are read straight from `cifar10.data`,
# so this pipeline does not touch them -- confirm that is intended.
_preprocessing_steps = [
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]
transform = transforms.Compose(_preprocessing_steps)
cifar10 = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# The first `calibration_size` images form the calibration split; the
# remainder is the pool the test sample is later drawn from.
calibration_size = 1000
X_test = cifar10.data
y_test = np.array(cifar10.targets)

X_calib, X_final = X_test[:calibration_size], X_test[calibration_size:]
y_calib, y_final = y_test[:calibration_size], y_test[calibration_size:]


def _to_scaled_tensor(images):
    """Move axis 3 to position 1 and divide by 255.

    Presumably `images` is a uint8 (N, H, W, C) array, which makes the result a
    float (N, C, H, W) tensor in [0, 1] -- verify against the dataset.
    """
    return torch.tensor(images).permute(0, 3, 1, 2) / 255.0


X_calib = _to_scaled_tensor(X_calib)
X_final = _to_scaled_tensor(X_final)

# Pretrained ResNet-18 with its classification head replaced by an identity,
# so a forward pass returns the features that used to feed `fc`.
model = resnet18(pretrained=True)
model.fc = nn.Identity()
model.eval()

def extract_features(X, batch_size=256):
    """Run the module-level ``model`` over ``X`` and return the stacked outputs.

    The samples are pushed through ``model`` in mini-batches instead of one
    forward pass per sample, which is much cheaper for thousands of images.
    This assumes ``model`` processes each sample independently of the rest of
    the batch; results may differ from per-sample inference by floating-point
    rounding only.

    Parameters
    ----------
    X : torch.Tensor or sequence of torch.Tensor
        Samples indexed along the first dimension. A non-tensor sequence is
        stacked chunk by chunk, so its items must share a shape within a chunk.
    batch_size : int, default 256
        Maximum number of samples per forward pass.

    Returns
    -------
    numpy.ndarray
        Model outputs concatenated along axis 0, one row per sample.

    Raises
    ------
    ValueError
        If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    samples = X if torch.is_tensor(X) else list(X)
    features = []
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        for start in range(0, len(samples), batch_size):
            chunk = samples[start:start + batch_size]
            batch = chunk if torch.is_tensor(chunk) else torch.stack(chunk)
            features.append(model(batch))
    return torch.cat(features, dim=0).numpy()

# Embed both splits with `model` via extract_features (one row per image).
X_calib_features = extract_features(X_calib)
X_final_features = extract_features(X_final)

# A fixed seed keeps the highlighted test sample (and the PCA projection, when
# scikit-learn picks a randomized SVD solver) identical across
# "Restart & Run All". Change RANDOM_SEED to look at a different sample.
RANDOM_SEED = 0
rng = np.random.default_rng(RANDOM_SEED)

idx = int(rng.integers(len(X_final_features)))
x_test_sample = X_final_features[idx]
print(f"Random test sample index: {idx}, label: {y_final[idx]}")

# The 2-D projection is fitted on the calibration features only; the test
# sample is then projected with that same fitted basis.
pca = PCA(n_components=2, random_state=RANDOM_SEED)
X_calib_pca = pca.fit_transform(X_calib_features)
x_test_sample_pca = pca.transform(x_test_sample.reshape(1, -1))

k = 20
# The index is built with k + 1 neighbours because a calibration point queried
# against its own index comes back as its own nearest neighbour and is dropped.
# NOTE: this rebinds `nn`, shadowing the `from torch import nn` import from
# here on; the name is kept so existing references keep working.
nn = NearestNeighbors(n_neighbors=k + 1, algorithm='ball_tree').fit(X_calib_pca)


def _local_label_entropy(neighbor_labels, num_classes=10):
    """Return the entropy (in bits) of the label histogram of `neighbor_labels`."""
    label_counts = np.bincount(neighbor_labels, minlength=num_classes)
    return entropy(label_counts / label_counts.sum(), base=2)


# Calibration points: query all of them in one call, then discard column 0
# (the point itself) so every entropy is computed over exactly k neighbours.
_, calib_neighbor_indices = nn.kneighbors(X_calib_pca)
local_entropies_calib = np.array([
    _local_label_entropy(y_calib[row[1:]]) for row in calib_neighbor_indices
])

# Test point: it is not part of the fitted index, so there is no self-match to
# drop. Ask for exactly k neighbours so its entropy is comparable with the
# calibration entropies (the default would use k + 1 labels).
distances, indices = nn.kneighbors(x_test_sample_pca, n_neighbors=k)
neighbor_labels = y_calib[indices[0]]
local_entropy_test = _local_label_entropy(neighbor_labels)
print(f"Local label entropy (k={k}) around test point: {local_entropy_test:.4f} bits")


In [None]:
# Smallest and largest calibration entropy; together they bound the bin edges
# built by map_entropies_to_discrete_sizes.
min_entropy = local_entropies_calib.min()
max_entropy = local_entropies_calib.max()

def map_entropies_to_discrete_sizes(entropies, max_entropy, max_size=3, min_size=1):
    """Map entropy values to integer prediction-set sizes in [min_size, max_size].

    Bin edges are ``min_entropy + (max_entropy - min_entropy) * sqrt(t)`` for
    ``max_size - min_size + 1`` evenly spaced ``t`` in [0, 1]. A value at or
    below the first edge gets ``min_size``; each further edge it exceeds adds
    one. Values above ``max_entropy`` are clamped to ``max_size`` (previously
    they produced ``max_size + 1``).

    NOTE: the lower bound is read from the module-level ``min_entropy``, which
    must be defined before this function is called.

    Parameters
    ----------
    entropies : float or array-like
        Entropy value(s) to discretise.
    max_entropy : float
        Upper end of the entropy range used to place the bin edges.
    max_size : int, default 3
        Largest size that can be returned.
    min_size : int, default 1
        Smallest size that can be returned.

    Returns
    -------
    numpy integer scalar or numpy.ndarray of ints
        Same shape as ``entropies``.
    """
    num_bins = max_size - min_size + 1
    scaled = np.linspace(0, 1, num_bins) ** (1/2)
    bins = min_entropy + (max_entropy - min_entropy) * scaled
    print(bins)
    sizes = np.digitize(entropies, bins, right=True) + min_size
    # digitize returns num_bins for inputs beyond the last edge; clamp so the
    # result never leaves the advertised [min_size, max_size] range.
    return np.clip(sizes, min_size, max_size)

# Discretise the calibration entropies first, then the single test entropy,
# both against the same upper bound.
prediction_set_sizes, test_prediction_set_size = (
    map_entropies_to_discrete_sizes(values, max_entropy)
    for values in (local_entropies_calib, local_entropy_test)
)

In [None]:
from matplotlib.colors import ListedColormap, BoundaryNorm

def _finalize_pca_figure(output_path):
    """Apply the axis styling shared by both PCA figures, then save and show.

    Works on the current pyplot figure: sets fixed axis limits, major/minor
    ticks, axis labels and the legend, writes the figure to `output_path` as a
    PDF with a tight bounding box, and finally displays it.
    """
    plt.xlim([-10, 10])
    plt.ylim([-8, 12])
    plt.xticks(np.arange(-10, 11, 4), fontsize=12, minor=False)
    plt.xticks(np.arange(-10, 11, 2), fontsize=12, minor=True)
    plt.yticks(np.arange(-8, 13, 4), fontsize=12, minor=False)
    plt.yticks(np.arange(-8, 13, 2), fontsize=12, minor=True)
    for axis in ('x', 'y'):
        plt.tick_params(axis=axis, which='major', length=10, labelsize=14)
        plt.tick_params(axis=axis, which='minor', length=5, labelsize=10)
    plt.xlabel("PCA 1")
    plt.ylabel("PCA 2")
    plt.legend()
    plt.savefig(output_path, format="pdf", bbox_inches="tight")
    plt.show()


# first figure: local entropy
plt.figure(figsize=(7, 4))
scatter1 = plt.scatter(
    X_calib_pca[:, 0], X_calib_pca[:, 1],
    c=local_entropies_calib, cmap='viridis', s=10
)
# Colour the star through the scatter's own norm and colormap so it matches the
# colourbar; dividing by the maximum alone is off whenever the smallest
# calibration entropy is not 0.
star_color_entropy = scatter1.cmap(scatter1.norm(local_entropy_test))
plt.scatter(
    x_test_sample_pca[0, 0], x_test_sample_pca[0, 1],
    c=[star_color_entropy], s=200, edgecolors='black', label="Test Sample", marker='*'
)
cbar1 = plt.colorbar(scatter1)
# Ten evenly spaced colourbar ticks across the calibration entropy range.
ticks = np.linspace(local_entropies_calib.min(), local_entropies_calib.max(), 10)
rounded_ticks = np.round(ticks, 2)
cbar1.set_ticks(rounded_ticks)
cbar1.set_label('Local Label Entropy (bits)',labelpad=20)
_finalize_pca_figure("pca1.pdf")


# second figure: prediction set size
plt.figure(figsize=(7, 4))
# Three discrete colours, one per prediction-set size 1, 2, 3; the boundaries
# sit halfway between consecutive sizes.
colors = plt.cm.viridis([0.1,0.6,0.95])
cmap = ListedColormap(colors)
norm = BoundaryNorm([0.5, 1.5, 2.5, 3.5], cmap.N)
scatter2 = plt.scatter(
    X_calib_pca[:, 0], X_calib_pca[:, 1],
    c=prediction_set_sizes, cmap=cmap, norm=norm, s=10
)
star_color_size = cmap(norm(test_prediction_set_size))
plt.scatter(
    x_test_sample_pca[0, 0], x_test_sample_pca[0, 1],
    c=[star_color_size], s=200, edgecolors='black', label="Test Sample", marker='*'
)
cbar2 = plt.colorbar(scatter2, ticks=[1, 2, 3])
cbar2.set_label('Prediction Set Size',labelpad=20)
_finalize_pca_figure("pca2.pdf")
