In [21]:
from sklearn.metrics import multilabel_confusion_matrix
from OTE import divide
import numpy as np

In [22]:
MODEL_NAME = "deepset/gbert-base"
MAX_TOKENS = 256
RANDOM_SEED = 43
BATCH_SIZE = 16
N_EPOCHS = 1
LEARNING_RATE = 5e-06

# BIO tag scheme: 'B' opens a phrase, 'I' continues it, 'O' is outside any phrase.
# The integer id doubles as the column index of the one-hot / probability rows.
label2id = {
    'O': 0,
    'B': 1,
    'I': 2,
}
# Derived from label2id so the forward and inverse mappings can never drift apart.
id2label = {idx: label for label, idx in label2id.items()}

n_labels = len(id2label)

In [23]:
# NOTE: allow_pickle=True lets np.load unpickle arbitrary objects, which can
# execute code — only load metrics files produced by a trusted source.
loaded_data = np.load("metrics_data.npy", allow_pickle=True)
# The array wraps a single object that is indexed by string keys below;
# unwrap it once instead of calling .item() for every key.
metrics_data = loaded_data.item()
aspect_categories = metrics_data["aspect_categories"]
predictions = metrics_data["predictions"]

labels = metrics_data["true_labels"]

In [24]:
def one_hot_to_label(one_hot, mapping=None):
    """
    Map a one-hot (or 0/1 mask) vector to its tag.

    Args:
        one_hot (sequence of int): Vector whose positions index ``mapping``;
            the first position equal to 1 selects the tag.
        mapping (dict, optional): Index-to-tag lookup. Defaults to the
            notebook-level ``id2label``.

    Returns:
        str: The tag for the first position that equals 1.

    Raises:
        ValueError: If no position equals 1 (e.g. an all-zero row).
    """
    if mapping is None:
        mapping = id2label
    for idx, value in enumerate(one_hot):
        if value == 1:
            return mapping[idx]
    # Raise ValueError rather than letting next() leak StopIteration: iterator
    # consumers (map, generators) would silently treat it as end-of-sequence
    # or convert it to RuntimeError.
    raise ValueError(f"no position equals 1 in one-hot vector {one_hot!r}")

In [25]:
def find_bio_phrases(bio_list):
    """
    Group a sequence of BIO tags into contiguous phrases.

    A phrase opens at every 'B' and extends over the following 'I' tags. It is
    closed by the next 'B' (which immediately opens a new phrase), by an 'O',
    or by the end of the sequence. An 'I' with no open phrase before it is
    ignored, as is any tag other than 'B' or 'O' when no phrase is open.

    Args:
        bio_list (sequence of str): Tags, expected to be 'B', 'I' or 'O'.

    Returns:
        list of dict: One {"start": int, "end": int} per phrase, with inclusive
        token indices, in order of appearance.
    """
    spans = []
    open_at = None

    for pos, tag in enumerate(bio_list):
        # Only 'B' and 'O' change state; 'I' (or anything else) just extends.
        if tag != 'B' and tag != 'O':
            continue
        if open_at is not None:
            spans.append({"start": open_at, "end": pos - 1})
        open_at = pos if tag == 'B' else None

    # A phrase still open at the end runs to the last token.
    if open_at is not None:
        spans.append({"start": open_at, "end": len(bio_list) - 1})

    return spans

In [26]:
import numpy as np

def calculate_tp_tn_fp_fn_spans(pred, label):
    """
    Count exact-match span agreement between predicted and gold phrases.

    A predicted span is a true positive only if a gold span has exactly the
    same 'start' and 'end'; a partial overlap counts as one FP plus one FN.
    Duplicate spans within either list are counted once.

    Args:
        pred (list of dict): Predicted spans, each with 'start' and 'end' values.
        label (list of dict): Gold spans, each with 'start' and 'end' values.

    Returns:
        tuple: (tp, tn, fp, fn). ``tn`` is always 0 — true negatives are not
        computed for span matching; the slot only keeps the result shaped
        like a confusion-matrix tuple.
    """
    # Tuples are hashable, so the spans can be compared directly as sets.
    pred_spans = {(span['start'], span['end']) for span in pred}
    gold_spans = {(span['start'], span['end']) for span in label}

    matched = pred_spans & gold_spans
    tp = len(matched)
    fp = len(pred_spans - matched)
    fn = len(gold_spans - matched)

    return tp, 0, fp, fn


In [27]:
def calculate_tp_tn_fp_fn_labels_OTE(predictions, labels):
    """
    Sum span-level (tp, tn, fp, fn) counts over a batch of sentences.

    Each prediction row is hardened to a 0/1 mask marking its maximum score.
    Gold and predicted rows are mapped to tags with one_hot_to_label and
    grouped into phrases with find_bio_phrases. The per-sentence counts from
    calculate_tp_tn_fp_fn_spans are then accumulated.

    Args:
        predictions (np.ndarray): Scores indexed as [sentence, token, class];
            must be 3-D because the max is taken over axis 2.
        labels (sequence): One entry per sentence, each an iterable of one-hot
            rows. Only the first len(labels) sentences of predictions are used.

    Returns:
        tuple: Summed (tp, tn, fp, fn) over all sentences.

    NOTE(review): every stored token position is scored — no attention or
    padding mask is applied here. Confirm whether the arrays include padding.
    """
    # 1 wherever a score equals its row maximum. Ties give several 1s, and
    # one_hot_to_label then picks the lowest class index.
    row_max = predictions.max(axis=2)[:, :, np.newaxis]
    hard_preds = (predictions == row_max).astype(int)

    totals = [0, 0, 0, 0]  # tp, tn, fp, fn
    for sent_idx in range(len(labels)):
        # List comprehensions (not map) so any exception from
        # one_hot_to_label propagates instead of truncating the tag list.
        gold_spans = find_bio_phrases([one_hot_to_label(row) for row in labels[sent_idx]])
        pred_spans = find_bio_phrases([one_hot_to_label(row) for row in hard_preds[sent_idx]])
        tp, tn, fp, fn = calculate_tp_tn_fp_fn_spans(pred_spans, gold_spans)
        totals = [acc + delta for acc, delta in zip(totals, (tp, tn, fp, fn))]
    return tuple(totals)

# Span-level counts in (tp, tn, fp, fn) order; tn is always reported as 0.
ote_counts = calculate_tp_tn_fp_fn_labels_OTE(predictions, labels)
ote_counts

(264, 0, 12322, 269)

In [32]:
# Compact float display (2 decimals, no scientific notation) for the arrays inspected below.
np.set_printoptions(suppress=True, precision=2)

In [33]:
# Class scores of sentence 1, first 20 tokens; column j is the score for id2label[j] (O, B, I).
predictions[1][0:20]

array([[0.03, 0.35, 0.61],
       [0.01, 0.51, 0.49],
       [0.01, 0.24, 0.75],
       [0.03, 0.53, 0.45],
       [0.03, 0.43, 0.54],
       [0.04, 0.34, 0.62],
       [0.05, 0.69, 0.26],
       [0.03, 0.7 , 0.27],
       [0.04, 0.55, 0.42],
       [0.03, 0.56, 0.4 ],
       [0.01, 0.87, 0.12],
       [0.05, 0.48, 0.47],
       [0.04, 0.26, 0.7 ],
       [0.02, 0.63, 0.36],
       [0.02, 0.62, 0.36],
       [0.02, 0.32, 0.66],
       [0.03, 0.55, 0.43],
       [0.04, 0.25, 0.71],
       [0.02, 0.62, 0.36],
       [0.05, 0.22, 0.73]], dtype=float32)

In [34]:
# Gold one-hot tags of sentence 1, first 20 tokens; column j corresponds to id2label[j] (O, B, I).
labels[1][0:20]

array([[1, 0, 0],
       [0, 1, 0],
       [0, 0, 1],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0]])