# SAR Confidence

In [1]:
import torch
from typing import List, Dict
from sentence_transformers import CrossEncoder

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import pickle

with open('/Users/ivan/Documents/uni-projects/LLM-Calibration-Study/cache/mmlu_pro_shared_sampling_gpt-4.1-mini_2_0', 'rb') as file:
    data = pickle.load(file)
data[0]["top_logprobs"]

[{'``': -0.002475777640938759,
  'Explanation': -6.002475738525391,
  '```': -15.75247573852539,
  '\n': -20.87747573852539,
  'Explain': -21.25247573852539},
 {'`\n': 0.0, '`\n\n': -20.25, '>\n': -20.75, '  \n': -21.5, '}\n': -21.625},
 {'Explanation': 0.0,
  'Expl': -22.25,
  ' Explanation': -22.25,
  'Explain': -23.125,
  'Ex': -24.5},
 {':': 0.0, ':\n': -19.0, ':\n\n': -23.25, ':The': -25.625, ':**': -26.75},
 {' For': -0.6039973497390747,
  ' If': -0.8539973497390747,
  ' In': -4.353997230529785,
  ' Both': -4.603997230529785,
  ' Since': -5.478997230529785},
 {' an': -0.0006354739889502525,
  ' autos': -7.375635623931885,
  ' a': -11.750635147094727,
  ' two': -14.125635147094727,
  ' the': -16.500635147094727},
 {' autos': 0.0,
  ' individual': -20.5,
  ' aut': -21.0,
  ' Autos': -22.875,
  ' auto': -24.0},
 {'omal': 0.0, 'ome': -22.0, 'om': -24.0, 'omial': -24.875, 'omes': -25.0},
 {' recess': 0.0,
  ' recession': -26.625,
  ' recessed': -28.375,
  '-re': -29.0,
  ' Re': -29.12

In [3]:
with open('/Users/ivan/Documents/uni-projects/LLM-Calibration-Study/cache/mmlu_pro_shared_sampling_Llama-3.3-70B-Instruct-Turbo-Free_3_0', 'rb') as file:
    display(pickle.load(file)[0]["top_logprobs"])

[{'Explanation': -9.059906e-06,
  'Expl': -12.125,
  ' Explanation': -13,
  'Ex': -13.625,
  'Expression': -16.375},
 {':': 0, ' :': -31.5, ':\n': -32, ':.': -35, ';': -36},
 {' Both': -0.54296875,
  ' Since': -1.0390625,
  ' The': -3.546875,
  ' For': -3.921875,
  ' To': -5.03125},
 {' parents': 0,
  'parents': -18.25,
  ' Parents': -19.625,
  'Parents': -22.25,
  '_parents': -23.375},
 {' are': -0.13085938,
  ' have': -2.125,
  ' having': -5.625,
  ' being': -9.5,
  ' must': -14.625},
 {' affected': -3.5762787e-07,
  'affected': -14.875,
  ' carriers': -19,
  ' afflicted': -19.125,
  ' affect': -19.5},
 {' with': -0.0019989014, ' by': -6.25, ',': -9.5, ' and': -22, 'with': -23},
 {' the': -0.07910156,
  ' an': -2.578125,
  ' a': -19.625,
  ' autos': -21.25,
  ' Autos': -25.375},
 {' same': 0,
  'same': -18.25,
  ' Same': -21.625,
  ' autos': -21.625,
  '_same': -24.375},
 {' autos': 0,
  'autos': -16.75,
  ' Autos': -18.25,
  '-aut': -26.25,
  ' recess': -26.625},
 {'omal': -1.192092

In [7]:
def token_sar_confidence(top_logprobs: List[Dict[str, float]]) -> float:
    """
    Compute SAR-based confidence from only top_logprobs (one per token).
    Assumes the first token in each dict is the generated token.
    
    Args:
        top_logprobs (List[Dict[str, float]]): List of token: logprob dictionaries.
        relevance_model (CrossEncoder): Semantic similarity model.
    
    Returns:
        float: SAR-style confidence score in [0, 1].
    """
    relevance_model = CrossEncoder("cross-encoder/nli-deberta-v3-small")
    # Step 1: Extract selected tokens
    tokens = [list(logprob_dict.keys())[0] for logprob_dict in top_logprobs]

    # Step 2: Compute entropy for each token
    entropies = []
    for logp_dict in top_logprobs:
        logprobs = torch.tensor(list(logp_dict.values()))
        probs = torch.exp(logprobs)
        entropy = -torch.sum(probs * torch.log(probs + 1e-10)).item()
        entropies.append(entropy)
    
    # Step 3: Compute semantic relevance by measuring meaning drop
    sentence = " ".join(tokens)
    pairs = []
    for i in range(len(tokens)):
        reduced = " ".join(tokens[:i] + tokens[i+1:])
        pairs.append((sentence, reduced))

    # Predict all at once
    sims = relevance_model.predict(pairs, show_progress_bar=True)

    # Convert to relevances
    relevances = []
    for sim in sims:
        try:
            rel = max(0.0, 1.0 - sim)
        except:
            print(sim)
            rel = 0.0
        relevances.append(rel)
    
    # Normalise relevance
    total_rel = sum(relevances) + 1e-8
    weights = [r / total_rel for r in relevances]

    # Step 4: Compute token-level SAR
    token_SAR = sum(e * w for e, w in zip(entropies, weights))

    # Step 5: Convert to confidence score
    confidence = 1 / (1 + token_SAR)
    return confidence


In [8]:
q = 0
print(token_sar_confidence(data[q]["top_logprobs"]), data[q]["logit_perplexity"])

Batches: 100%|██████████| 3/3 [00:01<00:00,  2.33it/s]

[-2.65173    3.5465643 -0.9929013]
[-2.687227   3.6040251 -1.0089436]
[-2.7685914  3.638867  -0.961026 ]
[-2.6997988  3.6106145 -1.0019344]
[-2.9453888  3.873794  -1.0651044]
[-2.7122345  3.6113677 -0.9949131]
[-2.6555138  3.59995   -1.05159  ]
[-2.641945   3.5421128 -1.004776 ]
[-2.6737678  3.5908263 -1.0103232]
[-2.678328   3.5674772 -0.9878888]
[-2.7303782  3.5925689 -0.9551976]
[-2.816792    3.722238   -0.99260557]
[-2.6872663  3.6044455 -1.0131857]
[-2.7338817   3.5838804  -0.93588185]
[-2.6755815  3.5744944 -0.9899627]
[-2.6490116  3.4902349 -0.9284319]
[-2.668934   3.5838776 -1.0034283]
[-2.6407058  3.517412  -0.9682974]
[-2.6851997  3.575359  -0.9856201]
[-2.6912513  3.5966635 -1.0041847]
[-2.6800718  3.6080747 -1.0252361]
[-2.6837335   3.5580716  -0.96664107]
[-2.6833086   3.5071826  -0.89204586]
[-2.6955867   3.4277985  -0.77136445]
[-2.5602767   3.3831453  -0.92502403]
[-2.2848728  3.2087624 -1.0485114]
[-2.7087588  3.577721  -0.9561218]
[-2.4970608  3.406516  -0.9898862]
[-




In [6]:
print(data[q]["response"])

```
Explanation: For an autosomal recessive disorder, affected individuals have two recessive alleles (aa). If both parents are affected (aa x aa), all their children will inherit the recessive allele from both parents, making all children affected (aa).

Answer: F

Confidence (0-100): 100%
```
