# In [None]:
import pandas as pd
import os
import sys
import json
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score, precision_recall_curve, accuracy_score
from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve
import matplotlib.pyplot as plt
import numpy as np

# ACC
from sklearn.metrics import accuracy_score

def _metric_score(mode, labels, preds):
    """Compute a single metric for one experiment.

    Args:
        mode: "auroc", "auprc" or "acc".
        labels: numpy array of ground-truth labels.
        preds: numpy array of prediction scores, same length as ``labels``.
            For "acc" the scores are binarised with ``preds > 0.5``.

    Returns:
        The scikit-learn metric value.
    """
    if mode == "auroc":
        return roc_auc_score(labels, preds)
    if mode == "auprc":
        return average_precision_score(labels, preds)
    return accuracy_score(labels, (preds > 0.5).astype(int))


def _pr_operating_points(labels, preds):
    """Return (R@P90, R@P80, P@R90, P@R80) for one experiment.

    R@Pxx is the largest recall among precision-recall curve points with
    precision >= xx%; P@Rxx is the largest precision among points with
    recall >= xx%.
    """
    precision, recall, _ = precision_recall_curve(labels, preds)
    return (
        float(np.max(recall[precision >= 0.9])),
        float(np.max(recall[precision >= 0.8])),
        float(np.max(precision[recall >= 0.9])),
        float(np.max(precision[recall >= 0.8])),
    )


def bootstrap_confidence_interval(
    preds_list,
    labels_list,
    mode,
    calc_ci=False,
    *,
    bootstrap_repeat=1000,
    seed=None,
    max_retries=1000,
):
    """Score predictions, optionally with a bootstrap 95% confidence interval.

    Args:
        preds_list: ``[[pred_per_img] per experiment]``, or a flat list of
            per-image prediction scores for a single experiment.
        labels_list: ``[[label_per_img] per experiment]`` (or flat),
            parallel to ``preds_list``.
        mode: "auroc", "auprc" or "acc" (accuracy at threshold 0.5).
        calc_ci: if False, return only the mean score across experiments.
        bootstrap_repeat: number of bootstrap replicates (keyword-only).
        seed: optional seed for a private ``np.random.RandomState``; ``None``
            keeps drawing from the global ``np.random`` state (keyword-only).
        max_retries: resamples to try per replicate before giving up when
            scikit-learn rejects a resample (keyword-only).

    Returns:
        ``"score"`` formatted to 4 decimals when ``calc_ci`` is False;
        otherwise ``"score (low, high)"``, and for "auprc" the mean
        R@P90 / R@P80 / P@R90 / P@R80 over experiments is appended.

    Raises:
        AssertionError: unsupported ``mode`` or non-list inputs.
        ValueError: empty inputs, mismatched number of experiments, or (when
            ``calc_ci``) experiments of different lengths or
            ``bootstrap_repeat < 1``.
        RuntimeError: no acceptable resample within ``max_retries`` draws.
    """
    assert mode in ["auroc", "auprc", "acc"], f"mode {mode} not supported!"
    assert isinstance(preds_list, list) and isinstance(labels_list, list)
    if not preds_list or not labels_list:
        raise ValueError("preds_list and labels_list must be non-empty")
    # A flat list of scalar scores is treated as a single experiment.
    if np.ndim(preds_list[0]) == 0:
        preds_list = [preds_list]
        labels_list = [labels_list]
    if len(preds_list) != len(labels_list):
        raise ValueError(
            f"got {len(preds_list)} prediction lists but "
            f"{len(labels_list)} label lists"
        )

    # Convert once; the bootstrap loop below only indexes these arrays.
    experiments = [
        (np.asarray(labels), np.asarray(preds))
        for preds, labels in zip(preds_list, labels_list)
    ]

    if calc_ci:
        if bootstrap_repeat < 1:
            raise ValueError("bootstrap_repeat must be >= 1")
        # The same resampled indices are applied to every experiment (paired
        # bootstrap), so all experiments must have the same number of samples.
        if len({len(labels) for labels, _ in experiments}) != 1:
            raise ValueError(
                "all experiments must have the same number of samples "
                "to compute a bootstrap confidence interval"
            )

    score = np.mean(
        [_metric_score(mode, labels, preds) for labels, preds in experiments]
    )
    if not calc_ci:
        return f"{score:.4f}"

    n_samples = len(experiments[0][0])
    rng = np.random if seed is None else np.random.RandomState(seed)
    bootstrap_score_list = []
    for _ in range(bootstrap_repeat):
        for _attempt in range(max_retries):
            sample_idx = rng.choice(n_samples, size=n_samples, replace=True)
            try:
                replicate = np.mean([
                    _metric_score(mode, labels[sample_idx], preds[sample_idx])
                    for labels, preds in experiments
                ])
            except ValueError:
                # scikit-learn raises ValueError when a resample makes the
                # metric undefined (e.g. ROC AUC with a single class); redraw.
                continue
            bootstrap_score_list.append(replicate)
            break
        else:
            raise RuntimeError(
                f"no valid bootstrap resample found in {max_retries} attempts"
            )

    # Basic (reverse-percentile) bootstrap interval: the 2.5th / 97.5th
    # percentile replicates are reflected around the point estimate.
    # With bootstrap_repeat=1000 these are indices 24 and 974.
    sorted_score_list = sorted(bootstrap_score_list)
    lo_idx = max(int(round(0.025 * bootstrap_repeat)) - 1, 0)
    hi_idx = min(int(round(0.975 * bootstrap_repeat)) - 1, bootstrap_repeat - 1)
    score_25, score_975 = sorted_score_list[lo_idx], sorted_score_list[hi_idx]
    score_string = f"{score:.4f} ({2*score-score_975:.4f}, {2*score-score_25:.4f})"
    if mode == "auprc":
        rec_prec90, rec_prec80, prec_rec90, prec_rec80 = np.mean(
            [_pr_operating_points(labels, preds) for labels, preds in experiments],
            axis=0,
        )
        score_string += f"; R@P90 {rec_prec90:.2f}; R@P80 {rec_prec80:.2f}; "
        score_string += f"P@R90 {prec_rec90:.2f}; P@R80 {prec_rec80:.2f}"
    return score_string

def get_prob(outputs):
    """Convert one prediction record into a ``(prob, label)`` pair.

    The answer confidence is ``exp(mean(logprob))`` -- the geometric mean of
    the per-token probabilities -- over every generated token that precedes
    the first ``"<|im_end|>"`` token. If the response mentions ``"Benign"``
    or contains the character ``"0"``, the confidence is flipped to
    ``1 - confidence``. Presumably this makes ``prob`` the probability of the
    malignant / label-1 class; confirm against the prompt format.

    ``label`` is 1 when ``outputs["labels"]`` contains ``"Malignant"`` or
    ``"1"``, otherwise 0.

    NOTE(review): the "0"/"1" checks are substring/membership tests, so a
    response such as "Malignant (10 mm)" would be treated as benign. Confirm
    that responses and labels are bare class names or digits.
    NOTE(review): if ``"<|im_end|>"`` is the very first token, the mean is
    taken over an empty list and ``prob`` comes out as NaN.
    """
    import math
    import numpy as np

    token_logprobs = []
    for entry in outputs["logprobs"]["content"]:
        if entry["token"] == "<|im_end|>":
            break
        token_logprobs.append(entry["logprob"])
    answer_conf = math.exp(np.mean(token_logprobs))

    response = outputs["response"]
    says_benign = "Benign" in response or "0" in response
    prob = 1 - answer_conf if says_benign else answer_conf

    gold = outputs["labels"]
    label = 1 if ("Malignant" in gold or "1" in gold) else 0
    return prob, label

def evaluate(json_path, metrics=("acc",)):
    """Score a JSON-lines prediction file and print bootstrap CIs per metric.

    Each line of the file is one record passed to ``get_prob``, which reads
    its ``"logprobs"``, ``"response"`` and ``"labels"`` fields.

    Args:
        json_path: path (or anything ``pd.read_json`` accepts) of the
            JSON-lines prediction file.
        metrics: modes passed to ``bootstrap_confidence_interval``
            ("acc", "auroc", "auprc"). The default prints accuracy only,
            exactly as before.

    Returns:
        dict mapping each metric mode to its formatted result string.
    """
    # dtype=False / convert_dates=False keep every field exactly as written in
    # the file. pandas' default dtype inference coerces a column whose values
    # are all digit strings (e.g. "0"/"1" labels or responses) to numbers, and
    # get_prob's substring checks ("1" in outputs["labels"]) then raise
    # TypeError on ints.
    records = pd.read_json(
        json_path,
        lines=True,
        dtype=False,
        convert_dates=False,
    ).to_dict(orient="records")

    probs = []
    labels = []
    for record in records:
        prob, label = get_prob(record)
        probs.append(prob)
        labels.append(label)

    results = {}
    for mode in metrics:
        results[mode] = bootstrap_confidence_interval(probs, labels, mode, calc_ci=True)
        print(mode.upper(), results[mode])
    return results