In [10]:
from typing import List, Dict
import numpy as np
from scipy.stats import bootstrap, t, norm
from scipy import stats
from pathlib import Path
from collections import Counter
import json
import shutil
from sequence_level_trigger_warning_assignment.utility import major_warning, infer_label, multilabel_to_multiclass

In [17]:
scores_dir = Path("../../resources/classification-results/scores")
dataset_file = Path("../../resources/generated/dataset.jsonl")
predictions_dir = Path("../../resources/classification-results/predictions")
output_dir = Path("../../resources/classification-results")
output_dir.mkdir(exist_ok=True)

warning_plot_order = ["violence", "death", "war", "abduction", "racism", "homophobia", "misogyny", "ableism"]
model_plot_order = ["binary", "binary-extended", "multiclass", "multilabel", "multilabel-extended", "gpt3-5-turbo", "gpt4-turbo", "mistral-7b", "mixtral-8x7b", "llama-7b", "llama-13b"]

warning_map = {json.loads(example)["example_id"]: json.loads(example)["warning"] for example in open(dataset_file)}
votes_map = {json.loads(example)["example_id"]: sum(json.loads(example)["labels"]) for example in open(dataset_file)}

def _r(value, decimals=2, pct=False):
    v = round(value, decimals)
    if pct:
        v = v * 100
    v = str(v)
    if len(v) == 3:
        v = v + "0"
    return v

def _parse_file_name(filename):
    model = filename[:-3].removeprefix("acl24-")\
            .removeprefix("fanbert-").removesuffix("-20ep")\
            .removesuffix("-1e-5lr").removesuffix("-2e-5lr").removesuffix("-5e-5lr")
    aggregation = "minority" if model.endswith("minority") else "majority"
    model = model.removesuffix("-minority").removesuffix("-majority")
    distribution = "id" if model.endswith("id") else "ood"
    model = model.removesuffix("-id").removesuffix("-ood")
    return aggregation, distribution, model

def _load_examples(examples):
    if isinstance(examples[0]['labels'], list):
        mode = 'multilabel'
    elif 2 in {_["labels"] for _ in examples}:  # multiclass file
        mode = 'multiclass'
    else:  # multiclass file
        mode = 'binary'
    for example in examples:
        warning = warning_map[example["id"]]
        if mode == 'multilabel':
            truth = 1 if multilabel_to_multiclass(example["labels"], level='minor') == warning else 0
            prediction = 1 if multilabel_to_multiclass(example["predictions"], level='minor') == warning else 0
        elif mode == 'multiclass':
            truth = 0 if example["labels"] == 8 else 1
            prediction = 1 if example["predictions"] == example["labels"] and example["predictions"] != 8 else 0
        else: # binary file
            truth = example["labels"]
            prediction = example["predictions"]
        yield example["id"], "true" if truth == prediction else "false"

In [28]:
"""scores_dir: {
    "model": {"distribution": {"aggregation": {"warning": {"result": set(IDS)}}}}
}
"""

combinations = [("id","minority"), ("id","majority")] #, ("ood","minority"),("ood","majority")]
ft_true = set()  # add IDs of the documents
ft_false = set()
llm_true = set() # add IDs of the documents
llm_false = set()
results = {}
for predictions_file in predictions_dir.glob("*.jsonl"):
    aggregation, distribution, model = _parse_file_name(predictions_file.stem)
    examples = [json.loads(line) for line in open(predictions_file)]
    for _id, result in _load_examples(examples):
        results.setdefault(model, {}).setdefault(distribution, {}).setdefault(aggregation, {})\
            .setdefault(result, set()).add(_id)

# stuff thats correctly predicted by every model
for distribution, aggregation in combinations:
    for warning in warning_plot_order:
        for model in ["mixtral-8x7b"]:
            llm_true = llm_true.union(results[model][distribution][aggregation]["true"])
            llm_false = llm_false.union(results[model][distribution][aggregation]["false"])

        for model in ["binary"]:
            ft_true = ft_true.union(results[model][distribution][aggregation]["true"])
            ft_false = ft_false.union(results[model][distribution][aggregation]["false"])

unique_examples = llm_true.union(llm_false)
llm_always_true = llm_true - llm_false
llm_always_false = llm_false - llm_true
ft_always_true = ft_true - ft_false
ft_always_false = ft_false - ft_true

print("Unique Examples", unique_examples)
print("LLM Always true instances", len(llm_always_true), llm_always_true)
print("LLM Always false instances", len(llm_always_false), llm_always_false)
print("FT Always true instances",  len(ft_always_true), ft_always_true)
print("FT Always false instances",  len(ft_always_false), ft_always_false)

# by warning
print("llm_always_true warnings", Counter([warning_map[_id] for _id in llm_always_true]))
print("llm_always_false warnings", Counter([warning_map[_id] for _id in llm_always_false]))
print("ft_always_true warnings", Counter([warning_map[_id] for _id in ft_always_true]))
print("ft_always_false warnings", Counter([warning_map[_id] for _id in ft_always_false]))

# by votes
all_counts = Counter([votes for _id, votes in votes_map.items() if _id in unique_examples])
print(all_counts)
print("llm_always_true votes", [(elem, count, _r(count/all_counts[elem])) for elem, count in Counter([votes_map[_id] for _id in llm_always_true]).most_common()])
print("llm_always_false votes", [(elem, count, _r(count/all_counts[elem])) for elem, count in Counter([votes_map[_id] for _id in llm_always_false]).most_common()])
print("ft_always_true votes", [(elem, count, _r(count/all_counts[elem])) for elem, count in Counter([votes_map[_id] for _id in ft_always_true]).most_common()])
print("ft_always_false votes", [(elem, count, _r(count/all_counts[elem])) for elem, count in Counter([votes_map[_id] for _id in ft_always_false]).most_common()])

Unique Examples {86017, 86018, 86019, 86021, 86022, 83855, 86028, 86029, 95254, 86032, 83857, 86036, 86037, 86038, 86041, 86042, 86044, 86045, 86046, 86049, 86050, 86051, 86053, 83861, 86057, 86058, 86059, 83862, 86061, 86062, 86063, 86064, 86066, 86067, 86068, 86069, 86072, 86074, 86076, 86077, 86081, 86082, 86088, 86089, 86090, 86091, 83868, 86093, 83869, 86097, 93001, 86099, 86100, 86101, 86102, 83871, 91200, 83875, 83879, 87951, 92975, 83896, 95296, 95304, 84024, 83910, 88909, 92055, 92980, 83918, 88911, 83931, 92073, 83932, 83934, 88914, 83946, 83949, 94701, 94702, 94707, 94713, 94714, 94715, 94719, 94722, 94731, 94732, 94733, 94739, 94740, 94750, 94751, 94752, 94760, 94761, 94764, 94765, 83964, 83965, 94773, 94774, 84050, 94776, 94777, 94778, 94780, 94781, 94782, 94785, 92991, 84036, 94797, 94798, 94799, 88921, 94801, 94802, 94805, 83972, 94807, 94809, 106017, 94813, 94814, 84037, 94815, 94817, 94824, 83976, 94826, 94828, 94831, 94834, 94836, 88863, 94838, 94839, 94840, 92120, 92