diff --git a/optillm/cot_decoding.py b/optillm/cot_decoding.py index 2f5b173c..8e337917 100644 --- a/optillm/cot_decoding.py +++ b/optillm/cot_decoding.py @@ -27,10 +27,10 @@ def calculate_confidence(logits: List[torch.Tensor], answer_ids: torch.Tensor) - break token_logits = logits[t] probs = torch.softmax(token_logits, dim=-1) - if probs.size(0) > 1: - top_2_probs, _ = torch.topk(probs, min(2, probs.size(0))) - if top_2_probs.size(0) > 1: - confidence_sum += (top_2_probs[0] - top_2_probs[1]).item() + if probs.size(-1) > 1: + top_2_probs, _ = torch.topk(probs, min(2, probs.size(-1))) + if top_2_probs.size(-1) > 1: + confidence_sum += (top_2_probs[-1][0] - top_2_probs[-1][1]).item() else: confidence_sum += 1.0 # Max confidence if there's only one token else: @@ -142,4 +142,4 @@ def cot_decode( return aggregate_paths_based_on_scores(paths) else: return max(paths, key=lambda x: x[1]) - \ No newline at end of file +