In [1]:
import csv
import json
import os
import numpy as np
import pandas as pd
import ast
from sklearn.metrics import multilabel_confusion_matrix
import tqdm
from utils import *
from type_vocabs import *
import tiktoken
enc = tiktoken.encoding_for_model("gpt-3.5-turbo-0125")
from collections import defaultdict

In [2]:
def serialize(types):
    cnt = 0
    result = {}
    for type in type_vocab:
        result[type] = 0
    for item in types:
        if item in result:
            result[item] = 1
        else: 
            cnt += 1

    return result, cnt

def f1_score_multilabel(true_list, pred_list, types_label=None):
    if types_label is not None:
        conf_mat = multilabel_confusion_matrix(np.array(true_list),
                                           np.array(pred_list),labels=types_label)
    else: conf_mat = multilabel_confusion_matrix(np.array(true_list),
                                           np.array(pred_list))
    agg_conf_mat = conf_mat.sum(axis=0)
    # Note: Pos F1
    # [[TN FP], [FN, TP]] if we consider 1 as the positive class
    p = agg_conf_mat[1, 1] / agg_conf_mat[1, :].sum()
    r = agg_conf_mat[1, 1] / agg_conf_mat[:, 1].sum()
    
    micro_f1 = 2 * p * r / (p  + r) if (p + r) > 0 else 0.
    class_p = conf_mat[:, 1, 1] /  conf_mat[:, 1, :].sum(axis=1)
    class_r = conf_mat[:, 1, 1] /  conf_mat[:, :, 1].sum(axis=1)
    class_f1 = np.divide(2 * (class_p * class_r), class_p + class_r,
                         out=np.zeros_like(class_p), where=(class_p + class_r) != 0)
    class_f1 = np.nan_to_num(class_f1)
    macro_f1 = class_f1.mean()
    return (micro_f1, macro_f1, class_f1, conf_mat)

In [3]:
def sotab_eval(data_dir, labels, type_vocab, col_pairs=None, is_triplet=False, correct=False):
    preds = []
    ground_truth = []
    num_oov = 0
    hints = []
    with open(data_dir, mode ='r') as file:
        csvFile = csv.reader(file)
        for id, lines in enumerate(csvFile):
                num = lines[0]
                idx = lines[1]
                label = labels[id]
                ground_truth.append(str(label))
                hint = ""
                if len(lines) <= 2:
                    preds.append("")
                    hints.append("")
                    continue
                if len(lines) >= 4:
                    hint = lines[3]
                # if hint is not None and is_triplet: 
                #     hint = hint.split(":", 1)[1].strip()
                hints.append(hint)
                
                try:
                    data = json.loads(lines[2].replace("'", '"'))
                    type_value = data.get('relation', [])
                    if isinstance(type_value, list) and len(type_value) > 0:
                        type_value = type_value[0]
                    preds.append(str(type_value))
                    if type_value not in type_vocab:
                        num_oov += 1
                    if not correct and type_value != label:
                        col_pairs.append((num,idx,type_value,label,hint))
                    if correct and type_value == label:
                        col_pairs.append((num,idx,type_value,label,hint))
                except json.JSONDecodeError:
                    preds.append("")
                    print(f"Error decoding JSON in line: {lines}")
    print("len(preds): ", len(preds))
    micro_f1, macro_f1, class_f1, conf_mat = f1_score_multilabel(ground_truth, preds, type_vocab)
    return micro_f1, macro_f1, class_f1, conf_mat, num_oov, hints

In [4]:
def single_label_eval(data_dir, labels, wrong_col_pairs = None, decode_ok=0):
    preds = []
    ground_truth = []
    empty = serialize([])[0].values()
    empty = [*empty]
    num_oov = 0
    with open(data_dir, mode ='r') as file:
        csvFile = csv.reader(file)
        for id, lines in enumerate(csvFile):
                pred = []
                if len(lines) <= 2:
                    preds.append(empty)
                    continue
                try:
                    num = lines[0]
                    idx = lines[1]
                    data = json.loads(lines[2].replace("'", '"'))
                    type_value = data.get('type', [])
                    label = labels[id]
                    pred_dic, cnt = serialize(type_value)
                    num_oov += cnt
                    pred = pred_dic.values()
                    pred = [*pred]
                    preds.append(pred)
                    decode_ok += 1
                    if len(type_value) > 0 and type_value[0] in label:
                        gt = pred
                    else:
                        gt = serialize(label)[0].values()
                        gt = [*gt]
                        if wrong_col_pairs is not None: 
                            wrong_col_pairs.append((num,idx,type_value ,label))

                except json.JSONDecodeError:
                    
                    preds.append(empty)
                    label = labels[id]
                    gt = serialize(label)[0].values()
                    gt = [*gt]
                    print(f"Error decoding JSON in line: {lines}")
                ground_truth.append(gt)
    print("len(preds): ", len(preds))
    micro_f1, macro_f1, class_f1, conf_mat = f1_score_multilabel(ground_truth, preds)
    return micro_f1, macro_f1, class_f1, conf_mat, num_oov, decode_ok

def multi_label_eval(data_dir, labels):
    ground_truth = []
    for label in labels:
        label = ast.literal_eval(label)
        gt = serialize(label).values()
        gt = [*gt]
        ground_truth.append(gt)
    preds = []
    empty = serialize([])[0].values()
    empty = [*empty]
    with open(data_dir, mode ='r') as file:
        csvFile = csv.reader(file)
        for lines in csvFile:
                pred = []
                if len(lines) <= 2:
                    preds.append(empty)
                    continue
                try:
                    data = json.loads(lines[2].replace("'", '"'))
                    type_value = data.get('type', [])
                    pred = serialize(type_value).values()
                    pred = [*pred]
                    preds.append(pred)
                except json.JSONDecodeError:
                    preds.append(empty)
    
    micro_f1, macro_f1, class_f1, conf_mat = f1_score_multilabel(ground_truth, preds)
    return micro_f1, macro_f1, class_f1, conf_mat 

In [6]:
labels = []
label_dir = '/mmfs1/gscratch/balazinska/linxiwei/TURL/Baseline-TURL/sample_labels.txt'
with open(label_dir, 'r') as file:
    labels = file.readlines()
labels = [line.strip() for line in labels]
pred_dir = ''
micro_f1, macro_f1, class_f1, conf_mat, num_oov, hints = single_label_eval(pred_dir, labels)
print(micro_f1)
print(num_oov)

Error decoding JSON in line: ['149', '0', '```json\n{\n    "type": ["location.citytown"]\n}\n```', "['Henbury', 'Clifton', 'Cotham', 'Clifton', 'Fishponds']"]
Error decoding JSON in line: ['149', '1', '```json\n{\n    "type": ["location.city"]\n}\n```', "['Bristol', 'Somerset', 'Bristol', 'Bristol', 'Bristol']"]
Error decoding JSON in line: ['362', '0', '```json\n{\n    "type": [\n        "book.periodical_subject"\n    ]\n}\n```', "['abdominal aortic aneurysm', 'Athletics Australia', 'year', 'year', 'year']"]
Error decoding JSON in line: ['362', '1', '```json\n{\n    "type": [\n        "basketball.basketball_team"\n    ]\n}\n```', '[]']
Error decoding JSON in line: ['362', '2', '```json\n{\n    "type": [\n        "military.military_person"\n    ]\n}\n```', '[]']
Error decoding JSON in line: ['424', '0', '```json\n{\n    "type": ["people.person"]\n}\n```  ', '[]']
Error decoding JSON in line: ['424', '1', '```json\n{\n    "type": ["location.country"]\n}\n```   ', "['France', 'Spain', 'B

  class_p = conf_mat[:, 1, 1] /  conf_mat[:, 1, :].sum(axis=1)
  class_r = conf_mat[:, 1, 1] /  conf_mat[:, :, 1].sum(axis=1)
