In [4]:
from transformers import T5ForConditionalGeneration, AutoTokenizer
from alphaarc.task import Task
import torch
import matplotlib.pyplot as plt
import numpy as np 
import json
import torch, tqdm, pickle
from pathlib import Path
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc
from transformers import T5ForConditionalGeneration, AutoTokenizer
from alphaarc.task import Task
import torch
import matplotlib.pyplot as plt
import numpy as np 
import json
import tqdm

In [5]:
# Load the train/val split file; its 'val' entry lists the validation task keys.
split_file = '../data/split_keys.json'
with open(split_file) as split_fp:
    json_object = json.load(split_fp)

validation_tasks = json_object['val']


In [6]:
from alphaarc.policy.tokenize import tokenize_task

def encode_task(task, tokenizer, model, input_state_max=256, n_examples=10, max_length=256):
    """Tokenize ``task`` with ``tokenize_task`` and return its ``input_ids`` as a numpy array.

    ``model`` is accepted for call-site compatibility but is not used.
    The remaining keyword arguments are forwarded positionally to
    ``tokenize_task`` as (n_examples, input_state_max, max_length).
    """
    encoded = tokenize_task(task, tokenizer, n_examples, input_state_max, max_length)
    input_ids = encoded['input_ids']
    return np.array(input_ids)

In [37]:
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc, \
                            accuracy_score, confusion_matrix, classification_report
import torch

# --- metrics.py -------------------------------------------------------------
import torch

def token_score(logits: torch.Tensor, ids: torch.Tensor | None = None,
                kind: str = "entropy") -> torch.Tensor:
   
    log_p = torch.log_softmax(logits, -1)         # [1, L, V]
    p      = log_p.exp()

    if kind == "entropy":
        score = -(p * log_p).sum(-1)              # [1, L]
    elif kind == "nll":
        if ids is None:
            raise ValueError("ids required for kind='nll'")
        score = -log_p.gather(-1, ids.unsqueeze(0).unsqueeze(-1)).squeeze(-1)
    elif kind == "max_prob":
        score = 1.0 - p.max(-1).values            # 1 – confidence
    else:
        raise ValueError(f"Unknown kind: {kind}")

    return score.squeeze(0).detach().cpu()        # [L]


# --- eval.py ----------------------------------------------------------------
@torch.no_grad()
def gather_val_records(val_tasks, tok, model, score_kind="entropy",
                       tau=0.5, batch_size=4,
                       top_k=8,
                       prob_mass=0.9,
                       prob_thresh=0):
    """Score every target token of every validation task under teacher forcing.

    For each task the encoder input comes from ``encode_task`` and the
    decoder labels are ``tok(task.program_lines)``; the model is called once
    as ``model(src_ids, labels=tgt_ids)``. Position i of the returned logits
    is compared with ``tgt_ids[:, i]`` — this presumes the model shifts the
    labels internally to build decoder inputs (TODO confirm for the model used).

    A token counts as *wrong* unless the reference token is among the
    model's ``top_k`` most likely tokens AND its probability is at least
    ``prob_thresh``. With the defaults (``top_k=8``, ``prob_thresh=0``) this
    is a lenient top-8 criterion, not greedy (argmax) accuracy.

    Args:
        val_tasks: iterable of tasks exposing ``program_lines``.
        tok: tokenizer for the target program.
        model: seq2seq model; put into eval mode here.
        score_kind: forwarded to ``token_score`` ("entropy", "nll", "max_prob").
        tau: a wrong token whose score is below ``tau`` is a confident error.
        batch_size: unused (tasks are scored one at a time); kept for
            call-site compatibility.
        top_k: size of the candidate set, clamped to [1, vocab size].
        prob_mass: unused; kept for call-site compatibility.
        prob_thresh: minimum probability the reference token must receive.

    Returns:
        Three parallel lists of Python floats, one entry per target token:
            score              : per-token score from ``token_score``
            is_error           : 1.0 if the token is wrong, else 0.0
            is_confident_error : 1.0 if wrong *and* score < tau, else 0.0

    Also prints how many tasks contain no confident errors.
    """
    model.eval()
    score_lst, err_lst, conf_err_lst = [], [], []
    n_tasks = 0
    n_tasks_without_conf_err = 0

    for task in tqdm.tqdm(val_tasks, desc=f"val ({score_kind})"):
        n_tasks += 1
        src_ids = torch.tensor(encode_task(task, tok, model)).unsqueeze(0)
        tgt_ids = tok(task.program_lines, return_tensors="pt").input_ids
        src_ids, tgt_ids = src_ids.to(model.device), tgt_ids.to(model.device)

        logits = model(src_ids, labels=tgt_ids).logits         # [1, L, V]
        score = token_score(logits, tgt_ids, kind=score_kind)  # [L], CPU

        probs = logits.softmax(-1)                             # [1, L, V]
        k = max(1, min(top_k, probs.shape[-1]))                # topk needs k <= V
        topk_idx = probs.topk(k, dim=-1).indices               # [1, L, k]
        tgt_exp = tgt_ids.unsqueeze(-1)                        # [1, L, 1]

        in_topk = (topk_idx == tgt_exp).any(-1)                # [1, L]
        tgt_prob = probs.gather(-1, tgt_exp).squeeze(-1)       # [1, L]
        is_correct = in_topk & (tgt_prob >= prob_thresh)       # [1, L]
        # token_score returns CPU tensors; move `wrong` to the CPU as well so
        # the mask below does not mix devices when the model is on a GPU.
        wrong = (~is_correct).squeeze(0).float().cpu()         # [L]

        conf_error = ((score < tau) & wrong.bool()).float().tolist()
        score_lst.extend(score.tolist())
        err_lst.extend(wrong.tolist())
        conf_err_lst.extend(conf_error)

        if sum(conf_error) == 0:
            n_tasks_without_conf_err += 1

    print(f"Tasks without confident errors: {n_tasks_without_conf_err}/{n_tasks}")
    return score_lst, err_lst, conf_err_lst


def entropy_reliability(records, tau=0.5):
    """Print reliability diagnostics for per-token uncertainty scores.

    Args:
        records: iterable of ``(score, is_error, is_confident_error)`` triples,
            one per token (larger score = less confident; is_error in {0, 1}).
        tau: a token with ``score < tau`` is treated as "confident". Used for
            the two-way table; the confident-error count is read from
            ``records`` as-is.

    Notes:
        The ROC/PR curves rank tokens by ``-score`` (i.e. by confidence) as a
        predictor of error, because ``roc_auc_score`` treats larger values as
        evidence for the positive class. An AUC well below 0.5 therefore means
        confident tokens are *less* likely to be wrong; ``1 - AUC`` is the
        AUROC of the raw score used as an error detector.

        The two-way table predicts "error" for every confident token
        (``score < tau``). Its "error" row therefore reads
        precision = P(wrong | confident) and
        recall = fraction of wrong tokens that are confident.
    """
    records = list(records)
    if not records:
        print("No records to evaluate.")
        return

    score, err, conf_err = map(torch.tensor, zip(*records))
    y_true = err.numpy().astype(int)              # 1 = wrong token
    y_pred = (score < tau).numpy().astype(int)    # 1 = confident (score < tau)

    n_tokens = len(score)
    n_err = int(y_true.sum())
    n_conf_err = int(conf_err.sum().item())
    frac_conf_err = n_conf_err / max(n_err, 1)

    print(f"Total tokens         : {n_tokens:,}")
    print(f"Errors               : {n_err:,}")
    print(f"Confident errors (<τ): {n_conf_err:,}  "
          f"({frac_conf_err*100:.1f}% of all errors with τ={tau})\n")

    # Ranking metrics are undefined when only one class is present.
    if 0 < n_err < n_tokens:
        confidence = -score.numpy()
        auc_roc = roc_auc_score(y_true, confidence)
        prec, rec, _ = precision_recall_curve(y_true, confidence)
        pr_auc = auc(rec, prec)
        print(f"ROC-AUC (conf vs err): {auc_roc:.3f}")
        print(f"PR-AUC               : {pr_auc:.3f}")
    else:
        print("ROC-AUC / PR-AUC     : n/a (need both correct and wrong tokens)")

    # labels=[0, 1] keeps a 2x2 matrix even if a class is absent.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    overall_acc = accuracy_score(y_true, y_pred)
    specificity = tn / (tn + fp) if (tn + fp) else float("nan")
    recall = tp / (tp + fn) if (tp + fn) else float("nan")

    print(f"Overall accuracy     : {overall_acc:.3f}")
    print(f"Specificity (correct): {specificity:.3f}")
    print(f"Recall (errors)      : {recall:.3f}\n")
    print(classification_report(y_true, y_pred, labels=[0, 1],
                                target_names=["correct", "error"], digits=3,
                                zero_division=0))


In [40]:
# Model weights come from a local fine-tuned checkpoint; the tokenizer is
# loaded separately by hub name (assumes the two match — TODO confirm).
model = T5ForConditionalGeneration.from_pretrained('../finetune-checkpoint/dev-checkpoint')
tok = AutoTokenizer.from_pretrained('Salesforce/codet5p-220m')

# Build one Task per validation key from its JSON file.
val_tasks = []
for file_name in json_object['val']:
    task_path = f'../data/training/{file_name}.json'
    task = Task.from_json(task_path)
    val_tasks.append(task)

tau = 0.1
score_kind = "entropy"          # or "nll" / "max_prob"

score, err, conf_err = gather_val_records(
    val_tasks, tok, model, score_kind=score_kind, tau=tau)

records = list(zip(score, err, conf_err))
entropy_reliability(records, tau=tau)

val (entropy):   4%|▍         | 4/89 [00:00<00:07, 12.14it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (520 > 512). Running this sequence through the model will result in indexing errors
val (entropy): 100%|██████████| 89/89 [00:12<00:00,  7.23it/s]

62
Total tokens         : 13,092
Errors               : 655.0
Confident errors (<τ): 53.0  (8.1% of all errors with τ=0.1)

ROC-AUC (conf vs err): 0.057
PR-AUC               : 0.026
Overall accuracy     : 0.168
Specificity (correct): 0.173
Recall (errors)      : 0.081

              precision    recall  f1-score   support

     correct      0.781     0.173     0.283     12437
       error      0.005     0.081     0.010       655

    accuracy                          0.168     13092
   macro avg      0.393     0.127     0.146     13092
weighted avg      0.742     0.168     0.270     13092




