In [2]:
import torch, tqdm, pickle
from pathlib import Path
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc
from transformers import T5ForConditionalGeneration, AutoTokenizer
from alphaarc.task import Task
import torch
import matplotlib.pyplot as plt
import numpy as np 
import json

In [3]:
with open('../data/split_keys.json') as fp:
    json_object = json.load(fp)

json_object['val']

['6150a2bd',
 'c8f0f002',
 '68b16354',
 'c9e6f938',
 '6fa7a44f',
 'ac0a08a4',
 '4258a5f9',
 '9ecd008a',
 '25ff71a9',
 '3af2c5a8',
 '67e8384a',
 '56ff96f3',
 'dc1df850',
 'ae4f1146',
 'aabf363d',
 'aedd82e4',
 '445eab21',
 '41e4d17e',
 'f5b8619d',
 '00d62c1b',
 '3906de3d',
 '8d5021e8',
 '0ca9ddb6',
 'd0f5fe59',
 'ba97ae07',
 'd9fac9be',
 'ded97339',
 'f8ff0b80',
 'dbc1a6ce',
 '496994bd',
 'fcb5c309',
 'ff805c23',
 '5c0a986e',
 '7f4411dc',
 '6d75e8bb',
 'e76a88a6',
 '8f2ea7aa',
 '137eaa0f',
 '6e82a1ae',
 'a5f85a15',
 '4be741c5',
 '681b3aeb',
 '3ac3eb23',
 '253bf280',
 '3428a4f5',
 'a9f96cdd',
 '1caeab9d',
 '3de23699',
 'bdad9b1f',
 '88a62173',
 'f8b3ba0a',
 '88a10436',
 '7e0986d6',
 'd4a91cb9',
 '913fb3ed',
 '3631a71a',
 '95990924',
 '93b581b8',
 '9edfc990',
 'a65b410d',
 'ef135b50',
 '6e19193c',
 'cbded52d',
 'ae3edfdc',
 '4612dd53',
 'e48d4e1a',
 '846bdb03',
 '2204b7a8',
 'e5062a87',
 '539a4f51',
 '91413438',
 'e8dc4411',
 'e40b9e2f',
 'd6ad076f',
 'b190f7f5',
 '6cdd2623',
 'f9012d9b',

In [4]:
from alphaarc.policy.tokenize import tokenize_task
def encode_task(task, tokenizer, model, input_state_max=256, n_examples=10, max_length=256): 
    tokenized_task = np.array(tokenize_task(task, tokenizer, n_examples, input_state_max, max_length)['input_ids'])
    return tokenized_task

In [5]:

def token_stats(logits, ids):
    """
    Given logits   [1, L, V]  and ground-truth ids  [1, L],
    return per-token entropy, is_correct flag, and p_true.
    """
    log_p  = torch.log_softmax(logits, -1)           # [1, L, V]
    probs  = log_p.exp()

    # entropy H_t
    entropy = -(probs * log_p).sum(-1).squeeze(0)    # [L]

    # probability assigned to the ground-truth token
    p_true  = probs.gather(2, ids.unsqueeze(-1)).squeeze(0).squeeze(-1)  # [L]

    # was the ground-truth the top-1 token?
    p_max, _ = probs.max(-1)
    is_correct = (p_true == p_max.squeeze(0)).float()                    # [L]

    return entropy.cpu(), is_correct.cpu(), p_true.cpu()


In [6]:
def gather_val_records(val_tasks, tok, model, cache_file="entropy_val.pkl"):
    """
    Returns a list of (entropy, is_error) per token across the whole split.
    Uses a cache on disk so you can rerun analyses instantly.
    """
    """if Path(cache_file).exists():
        return pickle.load(open(cache_file, "rb"))"""

    model.eval()
    records = []      # (entropy, is_error)

    for task in tqdm.tqdm(val_tasks, desc="val"):
        # --- encode once, as you already do ------------------------------
        input_ = torch.tensor(encode_task(task, tok, model)).unsqueeze(0).to(model.device)
        ids    = tok(task.program_lines, return_tensors="pt").input_ids.to(model.device)

        with torch.no_grad():
            logits = model(input_, labels=ids).logits   # [1, L, V]

        H, correct, _ = token_stats(logits, ids)
        for h, c in zip(H, correct):
            records.append((h.item(), int(1 - c.item())))   # is_error = 1 - correct

    pickle.dump(records, open(cache_file, "wb"))
    return records


In [7]:
def recall_at_tau(entropy, is_error, tau):
    ent = torch.tensor(entropy)
    err = torch.tensor(is_error).bool()
    tp  = ((ent > tau) & err).sum()
    fn  = ((ent <= tau) & err).sum()
    return (tp / (tp + fn + 1e-9)).item()          # avoid 0-div



In [8]:
import torch

def entropy_reliability(records, recall_target=0.95, tau=0.5):
      ent, err = zip(*records)               # lists of floats / ints
      ent = torch.tensor(ent)
      err = torch.tensor(err)                # 1 = error, 0 = correct

      print(f"Recall @ τ={tau} bits : {recall_at_tau(ent, err, tau):.3f}")
      # ROC-AUC
      auc_roc = roc_auc_score(err, ent)
      print(f"ROC-AUC (entropy vs error): {auc_roc:.3f}")

      # Precision-Recall
      prec, rec, thr = precision_recall_curve(err, ent)
      pr_auc = auc(rec, prec)
      print(f"PR-AUC                         : {pr_auc:.3f}")

      """# choose τ to capture recall_target of errors
      idx = (rec >= recall_target).nonzero()[0][0]
      tau = thr[idx].item()
      print(f"τ at {recall_target:.0%} error recall: {tau:.2f} "
            f"(precision={prec[idx]:.2f})")
      """
      from sklearn.metrics import accuracy_score, confusion_matrix, classification_report


      # ------------------------------------------------------------
      # inputs you already have
      # ------------------------------------------------------------
      ent = torch.tensor(ent)        # [N]  entropy per token (or per sequence)
      err = torch.tensor(err).bool() # [N]  ground-truth: 1 = error, 0 = correct

      # ------------------------------------------------------------
      # 1. predicted class from the entropy rule
      # ------------------------------------------------------------
      pred_err = (ent > tau)         # True ⇒ "I think it's an error"

      # ------------------------------------------------------------
      # 2. overall accuracy
      # ------------------------------------------------------------
      overall_acc = accuracy_score(err, pred_err)
      print(f"Overall accuracy : {overall_acc:.3f}")

      # ------------------------------------------------------------
      # 3. confusion matrix  →  per-class accuracies
      # ------------------------------------------------------------
      tn, fp, fn, tp = confusion_matrix(err, pred_err).ravel()
      print(f"Correct-class accuracy (specificity) : {tn / (tn+fp):.3f}")
      print(f"Error-class  accuracy (recall)       : {tp / (tp+fn):.3f}")

      # same numbers—plus precision & F1—in one call:
      print(classification_report(err, pred_err,
                              target_names=["correct", "error"], digits=3))

      return tau


In [11]:

model = T5ForConditionalGeneration.from_pretrained('../finetune-checkpoint/dev-checkpoint')
tok = AutoTokenizer.from_pretrained('Salesforce/codet5p-220m')

val_tasks = []


for file_name in json_object['val']:
    task = Task.from_json(f'../data/training/{file_name}.json')
    val_tasks.append(task )

# 1) gather statistics once
records = gather_val_records(val_tasks, tok, model)

# 2) measure reliability & pick τ
tau = entropy_reliability(records, tau=1 ) 

val:   4%|▍         | 4/89 [00:00<00:06, 13.11it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (520 > 512). Running this sequence through the model will result in indexing errors
val: 100%|██████████| 89/89 [00:11<00:00,  7.61it/s]

Recall @ τ=1 bits : 0.484
ROC-AUC (entropy vs error): 0.951
PR-AUC                         : 0.733
Overall accuracy : 0.915
Correct-class accuracy (specificity) : 0.980
Error-class  accuracy (recall)       : 0.484
              precision    recall  f1-score   support

     correct      0.927     0.980     0.953     11383
       error      0.783     0.484     0.598      1709

    accuracy                          0.915     13092
   macro avg      0.855     0.732     0.776     13092
weighted avg      0.908     0.915     0.906     13092




  ent = torch.tensor(entropy)
  err = torch.tensor(is_error).bool()
  ent = torch.tensor(ent)        # [N]  entropy per token (or per sequence)
  err = torch.tensor(err).bool() # [N]  ground-truth: 1 = error, 0 = correct
