# In [1]:
import numpy as np, torch, math, random
from pathlib import Path
from numba import njit, set_num_threads
from collections import defaultdict

# ----------------------------------------------------------------------
#  R‑value helper copied from your code (only the inner kernel is needed)
# ----------------------------------------------------------------------
@njit(fastmath=True, cache=True)
def _greedy_majority_inplace(sorted_idx, n_val, alpha,
                             counts, chosen, r_value, order_out, keep_mask):
    """Greedy majority ranking of columns, followed by a keep/drop decision.

    All results are written into the caller-supplied buffers; nothing is
    returned.

    Arguments:
        sorted_idx: 2-D integer array of shape (n_rows, n_cols). Its entries
            are used as indices into `counts`, so they must lie in
            [0, n_cols).
        n_val: ids below `n_val` are the ones counted when locating the
            boundary; ids from `n_val` upwards are the ones written to
            `keep_mask`.
        alpha: miscoverage level; the boundary is the r-value of the
            k-th selected id below `n_val`, k = ceil((n_val + 1) * (1 - alpha)).
        counts: integer scratch buffer, length >= n_cols (reset here).
        chosen: boolean scratch buffer, length >= n_cols (reset here).
        r_value: output, r_value[id] = (selection step + 1) / n_cols.
        order_out: output, order_out[step] = id selected at that step.
        keep_mask: output, length >= n_cols - n_val;
            keep_mask[j] = r_value[n_val + j] <= boundary.
    """
    n_rows, n_cols = sorted_idx.shape

    # Start from clean scratch buffers.
    for col in range(n_cols):
        chosen[col] = False
        counts[col] = 0

    for rank in range(n_cols):
        # Every row casts one vote for the id it holds at this rank;
        # votes accumulate across ranks.
        for row in range(n_rows):
            counts[sorted_idx[row, rank]] += 1

        # Select the not-yet-selected id with the most accumulated votes.
        # The strict '>' keeps the lowest id on ties.
        winner = -1
        winner_votes = -1
        for col in range(n_cols):
            if chosen[col]:
                continue
            if counts[col] > winner_votes:
                winner_votes = counts[col]
                winner = col

        chosen[winner] = True
        order_out[rank] = winner
        r_value[winner] = (rank + 1) / n_cols

    # Walk the selection order until the k-th id below n_val is met; its
    # r-value is the boundary. If fewer than k exist the boundary stays 1.0.
    k = math.ceil((n_val + 1) * (1 - alpha))
    boundary = 1.0
    n_seen = 0
    for rank in range(n_cols):
        picked = order_out[rank]
        if picked >= n_val:
            continue
        n_seen += 1
        if n_seen == k:
            boundary = r_value[picked]
            break

    # Ids from n_val upwards are kept when they rank no worse than the boundary.
    for offset in range(n_cols - n_val):
        keep_mask[offset] = r_value[n_val + offset] <= boundary


# ----------------------------------------------------------------------
#  LAC helpers copied from your code (unchanged)
# ----------------------------------------------------------------------

def calibrate_lac(scores, targets, alpha=0.1, return_dist=False):
    """
    Estimates the conformal quantile on held-out calibration data.
    The conformity score of a sample is `1 - softmax_score[true_class]`.

    Arguments:
        scores: softmax scores of the calibration set, shape (n, n_classes)
        targets: corresponding integer labels of the calibration set, shape (n,);
            any integer dtype is accepted
        alpha: parameter for the desired coverage level (1-alpha)
        return_dist: if True, also return the per-sample conformity scores

    Returns:
       qhat: the estimated quantile (0-dim tensor), taken at level
           ceil((n + 1) * (1 - alpha)) / n with "higher" interpolation
       score_dist: the score distribution (only when `return_dist` is True)

    Raises:
        AssertionError: if `scores` and `targets` differ in length, the
            calibration set is empty, or the quantile level falls outside
            [0, 1] (too few calibration samples for the requested alpha).
    """
    # as_tensor accepts arrays, lists and tensors alike without the
    # copy-construct warning torch.tensor emits for tensor inputs.
    scores = torch.as_tensor(scores, dtype=torch.float)
    # take_along_dim needs int64 indices; labels loaded from disk are
    # frequently int32 or uint8, so cast explicitly.
    targets = torch.as_tensor(targets).long()
    assert scores.size(0) == targets.size(0)
    assert targets.size(0)
    n = torch.tensor(targets.size(0))

    # Conformity score: one minus the probability assigned to the true class.
    score_dist = torch.take_along_dim(1 - scores, targets.unsqueeze(1), 1).flatten()
    assert (
        0 <= torch.ceil((n + 1) * (1 - alpha)) / n <= 1
    ), f"{alpha=} {n=} {torch.ceil((n+1)*(1-alpha))/n=}"
    q_level = torch.ceil((n + 1) * (1 - alpha)) / n
    qhat = torch.quantile(score_dist, q_level, interpolation="higher")
    return (qhat, score_dist) if return_dist else qhat


def inference_lac(scores, qhat, allow_empty_sets=False):
    """
    Builds prediction sets for new test data.

    Arguments:
        scores: softmax scores of the test set, shape (n, n_classes)
        qhat: estimated quantile of the calibration set from the `calibrate_lac` function
        allow_empty_sets: if True a prediction set may contain no class at all;
            otherwise the top-scoring class of every row is always included

    Returns:
       prediction_sets: boolean mask of shape (n, n_classes)
           (True if the class is included in the prediction set; otherwise False)
    """
    scores = torch.tensor(scores, dtype=torch.float)

    # A class is included when its score reaches 1 - qhat.
    threshold = 1 - qhat
    prediction_sets = scores >= threshold

    if allow_empty_sets:
        return prediction_sets

    # Force the arg-max class into every set so that no set is empty.
    row_ids = torch.arange(scores.size(0))
    top_class = scores.argmax(1)
    prediction_sets[row_ids, top_class] = True
    return prediction_sets


# ----------------------------------------------------------------------
#  0.  Load data and helper label map
# ----------------------------------------------------------------------
data_dir = Path("./data")
scores_path    = data_dir.joinpath("clip_probs_matrix_cifar100_2k.npy")
targets_path   = data_dir.joinpath("clip_targets_cifar100_2k.npy")
rephrased_path = data_dir.joinpath("label_rephrased_dict_cifar100.json")

# Per the original notes: scores is (N, 31, 100) and targets is (N,).
scores  = np.load(scores_path)
targets = np.load(targets_path)
N, N_rep, N_cls = scores.shape   # unpacking requires a 3-D score array

# The file is opened only to confirm it exists; its content is not read.
with open(rephrased_path) as f:
    pass

# Class id -> class name. Ids are the positions in the tuple below
# (ten names per row; the trailing comment gives the id range of the row).
label_names = dict(enumerate((
    "apple", "aquarium_fish", "baby", "bear", "beaver",
    "bed", "bee", "beetle", "bicycle", "bottle",                    # 0-9
    "bowl", "boy", "bridge", "bus", "butterfly",
    "camel", "can", "castle", "caterpillar", "cattle",              # 10-19
    "chair", "chimpanzee", "clock", "cloud", "cockroach",
    "couch", "crab", "crocodile", "cup", "dinosaur",                # 20-29
    "dolphin", "elephant", "flatfish", "forest", "fox",
    "girl", "hamster", "house", "kangaroo", "keyboard",             # 30-39
    "lamp", "lawn_mower", "leopard", "lion", "lizard",
    "lobster", "man", "maple_tree", "motorcycle", "mountain",       # 40-49
    "mouse", "mushroom", "oak_tree", "orange", "orchid",
    "otter", "palm_tree", "pear", "pickup_truck", "pine_tree",      # 50-59
    "plain", "plate", "poppy", "porcupine", "possum",
    "rabbit", "raccoon", "ray", "road", "rocket",                   # 60-69
    "rose", "sea", "seal", "shark", "shrew",
    "skunk", "skyscraper", "snail", "snake", "spider",              # 70-79
    "squirrel", "streetcar", "sunflower", "sweet_pepper", "table",
    "tank", "telephone", "television", "tiger", "tractor",          # 80-89
    "train", "trout", "tulip", "turtle", "wardrobe",
    "whale", "willow_tree", "wolf", "woman", "worm",                # 90-99
)))


alpha = 0.05                                   # target miscoverage level
rng   = np.random.default_rng(seed=2025)       # single seeded source of randomness

# ----------------------------------------------------------------------
#  1.  Simple single split (50 % calibration / 50 % test)
# ----------------------------------------------------------------------
perm      = rng.permutation(N)
cal_idx   = perm[: N//2]
test_idx  = perm[N//2 :]

cal_scores_full = scores[cal_idx]          # calibration rows of `scores`
cal_targets     = targets[cal_idx]
test_scores_full= scores[test_idx]         # test rows of `scores`
test_targets    = targets[test_idx]

# pick **one** random test example
# Drawn from the seeded generator (not the unseeded global `random` module)
# so that the whole demo is reproducible from the seed above.
test_pick_pos   = int(rng.integers(len(test_idx)))    # position within test split
test_pick_idx   = test_idx[test_pick_pos]             # index into the full dataset
true_label      = test_targets[test_pick_pos]

# ----------------------------------------------------------------------
#  2-A.  R-value conformal set for the picked example
#       (uses ALL prompt variants + greedy majority rule)
# ----------------------------------------------------------------------
# True-class score of every calibration sample under every prompt variant.
# Indexing axes 0 and 2 with arrays (slice in between) yields
# (n_cal, n_rep); the transpose makes it (n_rep, n_cal).
cal_rows     = np.arange(len(cal_idx))
val_true     = cal_scores_full[cal_rows, :, cal_targets].T
n_rep, n_val = val_true.shape
n_total      = n_val + N_cls          # calibration columns + candidate classes

# Work buffers that the kernel fills in place.
full      = np.empty((n_rep, n_total), dtype=np.float32)
counts    = np.empty(n_total, dtype=np.int32)
chosen    = np.empty(n_total, dtype=np.bool_)
r_value   = np.empty(n_total, dtype=np.float32)
order_out = np.empty(n_total, dtype=np.int32)
keep_mask = np.empty(N_cls,   dtype=np.bool_)

# Left part: calibration true-class scores; right part: all class scores
# of the picked test image.
full[:, :n_val] = val_true
full[:, n_val:] = test_scores_full[test_pick_pos]

# Per prompt variant, column ids ordered by descending score.
sidx = np.argsort(-full, axis=1).astype(np.int32)
_greedy_majority_inplace(sidx, n_val, alpha,
                         counts, chosen, r_value, order_out, keep_mask)

test_rv        = r_value[n_val:]              # r-values of the candidate classes
rvalue_indices = np.flatnonzero(keep_mask)    # class ids kept in the set
rvalue_scores  = test_rv[rvalue_indices]

# Order the kept classes by r-value (lower = higher rank).
rv_sorted_pairs = sorted(
    zip(rvalue_indices, rvalue_scores),
    key=lambda pair: pair[1],
)

# ----------------------------------------------------------------------
#  2-B.  Standard (LAC) conformal set, computed on the per-image
#        **mean** score over prompt variants (axis 1)
# ----------------------------------------------------------------------
cal_scores_mean  = np.mean(cal_scores_full, axis=1)    # calibration: one score vector per image
test_scores_mean = np.mean(test_scores_full, axis=1)   # test: one score vector per image

# Calibrate the threshold, build sets for the whole test split, then keep
# only the row of the picked example as a NumPy boolean vector.
qhat     = calibrate_lac(cal_scores_mean, cal_targets, alpha=alpha)
lac_sets = inference_lac(test_scores_mean, qhat)
lac_mask = lac_sets[test_pick_pos].numpy()

# ----------------------------------------------------------------------
#  3.  Pretty print results
# ----------------------------------------------------------------------
print(f"\n=== Demo on CIFAR‑100 image index {test_pick_idx} ===")
print(f"True label : {label_names[true_label]} (id={true_label})")
print(f"alpha      : {alpha}")

# ---------- R-value set: one row per kept class, best rank first ------
rv_header = (
    "\nR‑value prediction set "
    f"(size={len(rv_sorted_pairs)}) – sorted by r‑value:\n"
    "  id  r‑value  name"
)
print(rv_header)
for cls, rv in rv_sorted_pairs:
    print(f"  {cls:3d}  {rv:7.3f}  {label_names[cls]}")

# ---------- LAC set: one row per kept class, most probable first ------
lac_indices = np.flatnonzero(lac_mask)
probs_vec   = test_scores_mean[test_pick_pos]   # mean scores of the picked image
lac_sorted_pairs = sorted(
    zip(lac_indices, probs_vec[lac_indices]),
    key=lambda pair: pair[1],
    reverse=True,
)

lac_header = (
    "\nLAC prediction set "
    f"(size={len(lac_sorted_pairs)}) – sorted by probability:\n"
    "  id   prob    name"
)
print(lac_header)
for cls, p in lac_sorted_pairs:
    print(f"  {cls:3d}  {p:6.3f}  {label_names[cls]}")


# Output: ModuleNotFoundError: No module named 'torch'