In [40]:
from pathlib import Path
from tqdm import tqdm
import numpy as np
import pandas as pd
from collections import defaultdict
import pickle

data_dir = Path('./data')

# Model probabilities and integer labels per dataset. The code below indexes
# `scores` as a 3-D (example, rephrasing, class) array -- e.g. `scores[:, 0, :]`
# and `probs.shape[2]`; axis 1 is what the r-value code calls `n_reph` -- and
# uses `targets` directly as class indices.
datasets = {
    'cifar100': {
        'scores': np.load(data_dir / 'clip_probs_matrix_cifar100_2k.npy'),
        'targets': np.load(data_dir / 'clip_targets_cifar100_2k.npy'),
    },
}


def _validate_datasets(datasets):
    """Fail fast if a loaded dataset does not have the layout assumed below."""
    for name, ds in datasets.items():
        scores, targets = ds['scores'], ds['targets']
        assert scores.ndim == 3, f"{name}: expected 3-D scores, got shape {scores.shape}"
        assert targets.ndim == 1, f"{name}: expected 1-D targets, got shape {targets.shape}"
        assert len(scores) == len(targets), (
            f"{name}: {len(scores)} score rows vs {len(targets)} targets"
        )
        assert np.issubdtype(targets.dtype, np.integer), (
            f"{name}: targets must be integer labels, got dtype {targets.dtype}"
        )
        n_classes = scores.shape[2]
        assert ((targets >= 0) & (targets < n_classes)).all(), (
            f"{name}: labels must lie in [0, {n_classes})"
        )


_validate_datasets(datasets)

In [41]:
import math
import numpy as np
from numba import njit, set_num_threads
from collections import defaultdict
from tqdm import tqdm


# ────────────────────────────────────────────────────────────────────────────────
#  1) In-place greedy majority rule, using caller-provided scratch arrays
# ────────────────────────────────────────────────────────────────────────────────
@njit(fastmath=True, cache=True)
def _greedy_majority_inplace(sorted_idx, n_val, alpha,
                             counts, chosen, r_value, order_out, keep_mask):
    """Greedy rank aggregation over the rows of `sorted_idx`, written into buffers.

    Each row of `sorted_idx` is a ranking of column indices, best first (the
    caller passes the argsort of the negated scores). At step `s` the column
    found at position `s` of every row is added to `counts`, so, assuming each
    row is a permutation (as argsort output is), `counts[c]` is the number of
    rows ranking column `c` within their top `s + 1`. The not-yet-chosen column
    with the largest count is picked; ties go to the lowest column index, so a
    validation column wins a tie against a test column. The column picked at
    step `s` receives r-value `(s + 1) / n_cols`.

    Columns `[0, n_val)` are validation columns and `[n_val, n_cols)` test
    columns. The boundary is the r-value of the k-th validation column in pick
    order, with k = ceil((n_val + 1) * (1 - alpha)). If the k-th validation
    column is never reached (k > n_val, or k < 1), the boundary stays at 1.0
    and every test column is kept.

    Parameters
    ----------
    sorted_idx : 2-D integer array, shape (n_rows, n_cols)
        Per-row rankings of column indices.
    n_val : int
        Number of leading validation columns.
    alpha : float
        Miscoverage level used to pick k.
    counts, chosen, r_value, order_out : 1-D arrays, length >= n_cols
        Scratch/output buffers. `counts` and `chosen` are reset here. On return
        `r_value[c]` is column c's r-value and `order_out[s]` is the column
        picked at step s.
    keep_mask : 1-D bool array, length >= n_cols - n_val
        Output: `keep_mask[j]` is True when test column `n_val + j` has an
        r-value <= the boundary.

    Returns
    -------
    None. All results are written into the buffers.
    """
    n_rows, n_cols = sorted_idx.shape

    # reset the buffers (the caller reuses them across calls)
    for i in range(n_cols):
        counts[i] = 0
        chosen[i] = False

    # greedy picks: one column is fixed per step, so exactly n_cols picks
    for step in range(n_cols):
        # accumulate the columns that appear at rank `step` in every row
        col_vec = sorted_idx[:, step]
        for r in range(n_rows):
            counts[col_vec[r]] += 1

        # find best: highest cumulative count among unchosen columns; the
        # strict `>` keeps the lowest index on ties
        best = -1
        best_cnt = -1
        for c in range(n_cols):
            if (not chosen[c]) and counts[c] > best_cnt:
                best_cnt = counts[c]
                best = c

        # earlier pick -> smaller r-value; values are distinct per column
        chosen[best]    = True
        r_value[best]   = (step + 1) / n_cols
        order_out[step] = best

    # compute boundary on the validation columns: r-value of the k-th
    # validation column encountered in pick order
    k = math.ceil((n_val + 1) * (1 - alpha))
    seen = 0
    boundary = 1.0
    for step in range(n_cols):
        idx = order_out[step]
        if idx < n_val:
            seen += 1
            if seen == k:
                boundary = r_value[idx]
                break

    # build the mask for test columns (keep_mask index j <-> column n_val + j)
    for c in range(n_val, n_cols):
        keep_mask[c - n_val] = r_value[c] <= boundary


# ────────────────────────────────────────────────────────────────────────────────
#  2) calculate_r_value_fast with full bookkeeping
# ────────────────────────────────────────────────────────────────────────────────
def calculate_r_value_fast(datasets, alpha, num_trials, num_threads=None):
    """
    Estimate coverage and average set size of r-value prediction sets.

    Each trial shuffles each dataset and splits it in half. The first half is
    the calibration ("val") set and the second half the test set. For each test
    example, two blocks of columns are placed side by side, one row per
    rephrasing:
      - the true-class probabilities of all calibration examples;
      - the test example's class probabilities.
    Each row is ranked, and the rankings are aggregated with
    `_greedy_majority_inplace`. The prediction set is the test classes whose
    r-value does not exceed the conformal boundary.

    Arguments:
        datasets: mapping name -> {"scores": array (N, n_reph, n_cls),
                  "targets": integer class labels (N,)}
        alpha: miscoverage level (target coverage is 1 - alpha)
        num_trials: number of random calibration/test splits
        num_threads: if given, forwarded to numba's `set_num_threads`

    Returns:
      all_average_sizes: name -> list of per-trial mean prediction-set sizes
      all_correct_predictions: name -> list of per-trial empirical coverages

    Raises:
        ValueError: if a dataset has fewer than 2 examples (a split would be empty)
    """
    if num_threads is not None:
        # NOTE(review): the kernel is not compiled with parallel=True, so this
        # likely has no effect on it.
        set_num_threads(num_threads)

    all_average_sizes       = defaultdict(list)
    all_correct_predictions = defaultdict(list)

    for _ in tqdm(range(num_trials)):
        for name, ds in datasets.items():
            probs   = ds["scores"]
            targets = ds["targets"]
            if len(probs) < 2:
                raise ValueError(
                    f"dataset {name!r} needs at least 2 examples to split, got {len(probs)}"
                )

            # random 50/50 calibration / test split
            perm       = np.random.permutation(len(probs))
            probs_shuf = probs[perm]
            targs_shuf = targets[perm]

            mid   = len(probs) // 2
            val_p = probs_shuf[:mid]
            tst_p = probs_shuf[mid:]
            val_t = targs_shuf[:mid]
            tst_t = targs_shuf[mid:]

            # true-class probability of each calibration example per rephrasing:
            # advanced indexing gives (mid, n_reph); transpose to (n_reph, n_val)
            val_true = val_p[np.arange(mid), :, val_t].T
            n_reph, n_val = val_true.shape
            n_cls  = probs.shape[2]
            n_cols = n_val + n_cls

            # Columns [0, n_val) hold the calibration scores, identical for every
            # test example, so fill them once; columns [n_val, n_cols) are
            # overwritten with each test example's class scores below.
            full = np.empty((n_reph, n_cols), np.float32)
            full[:, :n_val] = val_true

            # scratch buffers reused by every kernel call in this trial
            counts    = np.empty(n_cols, np.int32)
            chosen    = np.empty(n_cols, np.bool_)
            r_value   = np.empty(n_cols, np.float32)
            order_out = np.empty(n_cols, np.int32)
            keep_mask = np.empty(n_cls,  np.bool_)

            sizes = []
            hits  = 0
            for test_scores, test_target in zip(tst_p, tst_t):
                full[:, n_val:] = test_scores
                sidx = np.argsort(-full, axis=1).astype(np.int32)

                _greedy_majority_inplace(
                    sidx, n_val, alpha,
                    counts, chosen, r_value, order_out, keep_mask
                )

                # keep_mask[j] is True iff class j is in the prediction set, so
                # set size and coverage are read straight from it.
                sizes.append(int(np.count_nonzero(keep_mask)))
                if keep_mask[test_target]:
                    hits += 1

            all_average_sizes[name].append(np.mean(sizes))
            all_correct_predictions[name].append(hits / len(tst_p))

    print(f'R-VALUE COVERAGE at alpha: {alpha}')
    print()
    for name, results in all_correct_predictions.items():
        print(name.center(50, '-'))
        print(f'Coverage: {np.mean(results):.2%} +/- {np.std(results):.2%}')
        print()

    print('********************')
    print(f'SET SIZES at alpha: {alpha}')
    print()
    for name, results in all_average_sizes.items():
        print(name.center(50, '-'))
        print(f'Set Size: {np.mean(results):.2f} +/- {np.std(results):.1f}')
        print()

    return all_average_sizes, all_correct_predictions


In [42]:
from functools import partial
import numpy as np
import torch


def calibrate_lac(scores, targets, alpha=0.1, return_dist=False):
    """
    Estimates the conformal LAC threshold on held-out calibration data.
    The score function is `1 - softmax_score[true_label]`, i.e. one minus the
    probability the model assigns to the correct class.

    Arguments:
        scores: softmax scores of the calibration set, shape (n, n_classes)
        targets: integer labels of the calibration set, shape (n,)
        alpha: parameter for the desired coverage level (1-alpha), in [0, 1]
        return_dist: if True, also return the per-example score distribution

    Returns:
       qhat: the estimated quantile as a 0-d float tensor. When the calibration
             set is too small for the requested alpha
             (ceil((n+1)(1-alpha)) > n), this is +inf, which makes
             `inference_lac` admit every class.
       score_dist: the score distribution (only if `return_dist` is True)

    Raises:
        ValueError: if the calibration set is empty, scores/targets lengths
            differ, or alpha is outside [0, 1].
    """
    scores = torch.as_tensor(scores, dtype=torch.float)
    # take_along_dim needs int64 indices; labels loaded from .npy may be int32
    targets = torch.as_tensor(targets).long()
    n = targets.size(0)
    if n == 0:
        raise ValueError("calibration set is empty")
    if scores.size(0) != n:
        raise ValueError(
            f"scores has {scores.size(0)} rows but targets has {n} entries"
        )
    if not 0 <= alpha <= 1:
        raise ValueError(f"alpha must be in [0, 1], got {alpha=}")

    score_dist = torch.take_along_dim(1 - scores, targets.unsqueeze(1), 1).flatten()

    # Finite-sample corrected level ceil((n+1)(1-alpha))/n, computed in float64
    # (a float32 tensor product can round across an integer boundary).
    level = float(np.ceil((n + 1) * (1 - alpha))) / n
    if level > 1:
        # Too few calibration points for this alpha: the conformal threshold is
        # infinite, i.e. every label must be admitted.
        qhat = score_dist.new_tensor(float("inf"))
    else:
        # NOTE(review): interpolation="higher" selects sorted index
        # ceil(level * (n - 1)). For 0 < k < n that is the (k+1)-th smallest
        # score rather than the k-th, which is slightly conservative.
        qhat = torch.quantile(score_dist, level, interpolation="higher")
    return (qhat, score_dist) if return_dist else qhat


def inference_lac(scores, qhat, allow_empty_sets=False):
    """
    Builds LAC prediction sets for new test data.

    A class is admitted when its softmax score is at least `1 - qhat`.

    Arguments:
        scores: softmax scores of the test set
        qhat: threshold estimated on the calibration set by `calibrate_lac`
        allow_empty_sets: if True, a prediction set may end up empty (the sets
            then satisfy the upper bound of marginal coverage); if False, the
            top-scoring class is always admitted

    Returns:
       prediction_sets: boolean mask of prediction sets (True if the class is
       in the prediction set, otherwise False)
    """
    probs = torch.tensor(scores, dtype=torch.float)
    prediction_sets = probs >= (1 - qhat)
    if allow_empty_sets:
        return prediction_sets

    # Force each row's argmax class into its set so that no set is empty.
    rows = torch.arange(probs.size(0))
    prediction_sets[rows, probs.argmax(dim=1)] = True
    return prediction_sets



def get_coverage(psets, targets, precision=None):
    """
    Calculates empirical coverage of prediction sets.

    Arguments:
        psets: prediction sets of the test set, boolean (n, n_classes); a
            tensor, numpy array, or nested list
        targets: ground-truth labels of the test set, shape (n,)
        precision: accepted for backward compatibility; currently ignored
            (no rounding is applied)

    Returns:
       coverage: fraction of examples whose true label is in the prediction set
    """
    # as_tensor avoids the copy (and UserWarning) that torch.tensor makes when
    # `psets` is already a tensor, as returned by `inference_lac`; nothing below
    # mutates it, so no defensive clone is needed.
    psets = torch.as_tensor(psets)
    # Index with int64: a uint8/bool label tensor would be treated as a mask.
    targets = torch.as_tensor(targets).long()
    rows = torch.arange(psets.shape[0])
    return psets[rows, targets].float().mean().item()


def get_size(psets, precision=1):
    """
    Calculates the empirical average size of prediction sets (can be read as
    the average uncertainty of the model).

    Arguments:
        psets: prediction sets of the test set, boolean (n, n_classes); a
            tensor, numpy array, or nested list
        precision: accepted for backward compatibility; currently ignored
            (no rounding is applied)

    Returns:
       size: average number of classes per prediction set
    """
    # Accept array-likes like `get_coverage` does. Nothing below mutates
    # `psets`, so no defensive copy is needed.
    psets = torch.as_tensor(psets)
    return psets.sum(dim=1).float().mean().item()


from collections import defaultdict

def _run_lac_trials(datasets, alpha, num_trials, extract_scores, heading_prefix=''):
    """Shared driver for `get_lac` and `get_lac_using_mean`.

    For each trial and dataset, the helper does the following:
      - shuffle the examples;
      - use the first half as the calibration set and the second half as the
        test set;
      - calibrate LAC and build the test prediction sets;
      - record empirical coverage and mean set size.

    Arguments:
        datasets: mapping name -> {'scores': ..., 'targets': ...}
        alpha: miscoverage level passed to `calibrate_lac`
        num_trials: number of random calibration/test splits
        extract_scores: callable mapping `results['scores']` to a 2-D
            (n_examples, n_classes) score matrix
        heading_prefix: text prepended to the printed section headings

    Returns:
        all_coverage, all_size: dataset name -> list with one value per trial
    """
    all_coverage = defaultdict(list)
    all_size = defaultdict(list)

    for _ in range(num_trials):
        for name, results in datasets.items():
            scores = extract_scores(results['scores'])
            targets = results['targets']

            # shuffle data (apply the permutation once, not once per split)
            index = np.arange(len(scores))
            np.random.shuffle(index)
            scores_shuf = scores[index]
            targets_shuf = targets[index]

            # split data into calibration and test sets
            n = len(scores) // 2
            cal_scores, val_scores = scores_shuf[:n], scores_shuf[n:]
            cal_targets, val_targets = targets_shuf[:n], targets_shuf[n:]

            # estimate 1-alpha quantile on calibration set, then build test sets
            q = calibrate_lac(cal_scores, cal_targets, alpha=alpha)
            psets = inference_lac(val_scores, q)

            all_coverage[name].append(get_coverage(psets, val_targets))
            all_size[name].append(get_size(psets))

    print(f'{heading_prefix}COVERAGE at alpha: {alpha}')
    print()

    for name in datasets.keys():
        print(name.center(50, '-'))
        mean_coverage = np.mean(all_coverage[name])
        std_coverage = np.std(all_coverage[name])
        print(f'{mean_coverage:.2%} +/- {std_coverage:.2%}')
        print()

    print('********************')
    print(f'{heading_prefix}SET SIZES at alpha: {alpha}')
    print()

    for name in datasets.keys():
        print(name.center(50, '-'))
        mean_size = np.mean(all_size[name])
        std_size = np.std(all_size[name])
        print(f'{mean_size:.2f} +/- {std_size:.1f}')
        print()

    return all_coverage, all_size


def get_lac(datasets, alpha, num_trials):
    """LAC coverage/set size using the first entry along axis 1 of the scores.

    Uses `scores[:, 0, :]` (the first rephrasing) for every example. Returns
    `(all_coverage, all_size)`, each mapping dataset name to per-trial values.
    """
    return _run_lac_trials(
        datasets, alpha, num_trials,
        extract_scores=lambda s: s[:, 0, :],
    )


def get_lac_using_mean(datasets, alpha, num_trials):
    """LAC coverage/set size using scores averaged over axis 1.

    Uses `scores.mean(axis=1)` (the mean over rephrasings) for every example.
    Returns `(all_coverage, all_size)`, each mapping dataset name to
    per-trial values.
    """
    return _run_lac_trials(
        datasets, alpha, num_trials,
        extract_scores=lambda s: s.mean(axis=1),
        heading_prefix='MEAN ',
    )



In [None]:
# Run the r-value method and both LAC variants with the same alpha and number
# of random calibration/test splits, then average each per-trial metric.
alpha = 0.05      # Desired error rate
num_trials = 100  # Number of random trials to perform
SEED = 0          # seed for NumPy's global RNG, which all three methods use to split

np.random.seed(SEED)  # make the random splits (and every number below) reproducible

# calculate_r_value_fast returns (per-trial average set sizes, per-trial coverages)
all_sizes_rvalue, all_correct_predictions = calculate_r_value_fast(datasets, alpha, num_trials)
all_coverage, all_size = get_lac(datasets, alpha, num_trials)
all_coverage_using_mean, all_size_using_mean = get_lac_using_mean(datasets, alpha, num_trials)

# One scalar per dataset: the mean over trials of each metric.
lac_coverage = {name: np.mean(all_coverage[name]) for name in datasets}
r_value_coverage = {name: np.mean(all_correct_predictions[name]) for name in datasets}
lac_set_sizes = {name: np.mean(all_size[name]) for name in datasets}
r_value_set_sizes = {name: np.mean(all_sizes_rvalue[name]) for name in datasets}
lac_using_mean_coverage = {name: np.mean(all_coverage_using_mean[name]) for name in datasets}
lac_using_mean_set_sizes = {name: np.mean(all_size_using_mean[name]) for name in datasets}

all_results = {
    'lac_coverage': lac_coverage,
    'r_value_coverage': r_value_coverage,
    'lac_set_sizes': lac_set_sizes,
    'r_value_set_sizes': r_value_set_sizes,
    'lac_using_mean_coverage': lac_using_mean_coverage,
    'lac_using_mean_set_sizes': lac_using_mean_set_sizes,
}
    

In [None]:
# Summary table: one row per dataset, one column per metric.
pd.DataFrame(all_results)