In [202]:
import re
import itertools
import json
import pandas as pd
import numpy as np
from scipy.stats import beta
from prettytable import PrettyTable
from omegaconf import OmegaConf
from sklearn.metrics import f1_score, confusion_matrix, roc_curve, multilabel_confusion_matrix, matthews_corrcoef
from scipy.stats import spearmanr
from collections import Counter

from cgeval.rating import Ratings, Label, Observation
from cgeval import Report
from cgeval.report import GenericReport
from cgeval.distribution import Beta, BetaParams

In [203]:
# When True, the table cells iterate over the (B, M) grid found under
# `quantify.subsampling` in each run's config (if that section exists);
# otherwise they fall back to a single fixed (B, M) setting.
SUBSAMPLING = False
BASE_PATH = '../../out/pipeline'

# Generator display name -> output directory of its sentiment-analysis run.
TEXT = {
    display_name: f'{BASE_PATH}/2025-05-15_sentiment_analysis_{run_suffix}'
    for display_name, run_suffix in [
        ('Llama-3.3-70B', 'llama3-3'),
        ('Llama-2-7B', 'llama2'),
        ('Mistral-7B', 'mistral'),
    ]
}

# Generator display name -> output directory of its animal-detection run
# (the "animal" variant).
IMAGE_MODEL_ANIMAL = {
    display_name: f'{BASE_PATH}/2025-05-15_animal_detection_{run_suffix}_animal'
    for display_name, run_suffix in [
        ('Stable Diffusion 3.5', 'stable-diffusion'),
        ('Stable Cascade', 'stable-cascade'),
        ('FLUX.1-dev', 'flux'),
    ]
}

# Generator display name -> output directory of its animal-detection run
# (the "count" variant).
IMAGE_MODEL_COUNT = {
    display_name: f'{BASE_PATH}/2025-05-15_animal_detection_{run_suffix}_count'
    for display_name, run_suffix in [
        ('Stable Diffusion 3.5', 'stable-diffusion'),
        ('Stable Cascade', 'stable-cascade'),
        ('FLUX.1-dev', 'flux'),
    ]
}

# Classifier name -> three-letter short id (the prefixes used in the table
# column headers).
cls_name2id = {
    'FinancialBERT': 'FIB',
    'lxyuan_DistilBert': 'DSS',
    'ollama3.2': 'LL3',
    'Yolov8': 'VO8',
    'DETR': 'DTR',
    'LLaVA': 'LLV',
}

# Generator display name -> short id.
model_name2id = {
    'Llama-3.3-70B': 'L370B',
    'Llama-2-7B': 'L27B',
    'Mistral-7B': 'Mi7B',
    'Stable Diffusion 3.5': 'SD35',
    'Stable Cascade': 'StCa',
    'FLUX.1-dev': 'FLX1',
}

In [204]:
def pretty_print_latex(latex_str):
    """Re-indent a LaTeX snippet so environment bodies are visually nested.

    Every occurrence of ``" \\\\ "`` (a LaTeX row terminator surrounded by
    spaces) is first turned into a row terminator followed by a newline, so a
    table given on one line is split into one row per line. Each line is then
    indented by four spaces per enclosing environment.

    Args:
        latex_str: LaTeX source, e.g. the result of
            ``PrettyTable.get_latex_string()``.

    Returns:
        The re-indented source, with lines joined by ``"\\n"``.
    """
    indent_unit = "    "
    lines = latex_str.replace(r" \\ ", r" \\" + "\n").splitlines()

    formatted_lines = []
    depth = 0
    for line in lines:
        if r"\begin" in line:
            # The opening line sits at the depth of its parent environment,
            # so nested environments line up with the surrounding body.
            formatted_lines.append(indent_unit * depth + line)
            depth += 1
        elif r"\end" in line:
            # Clamp at zero: a stray \end must not push the depth negative,
            # which would cancel the indentation of the next environment.
            depth = max(depth - 1, 0)
            formatted_lines.append(indent_unit * depth + line)
        else:
            formatted_lines.append(indent_unit * depth + line)
    return "\n".join(formatted_lines)

In [205]:
def load_reports(cfg, report_path, subsampling=None):
    """Load one ``GenericReport`` per classifier listed in ``cfg.classifier``.

    Args:
        cfg: Config object whose ``classifier`` attribute is iterable and
            whose items expose an ``id`` attribute.
        report_path: Directory that holds the ``cls_report_*.json`` files.
        subsampling: Optional ``(B, M)`` pair. When given, the reports are
            read from ``{report_path}/subsampling/cls_report_{id}_{B}_{M}.json``
            instead of ``{report_path}/cls_report_{id}.json``.

    Returns:
        List of loaded reports, in the order of ``cfg.classifier``.
    """
    loaded = []

    for classifier in cfg.classifier:
        if subsampling is None:
            json_path = f"{report_path}/cls_report_{classifier.id}.json"
        else:
            B, M = subsampling
            json_path = f"{report_path}/subsampling/cls_report_{classifier.id}_{B}_{M}.json"

        report = GenericReport()
        report.load(json_path)
        loaded.append(report)

    return loaded


In [206]:
def get_distributions(cfg, reports):
    """Build the Beta distributions stored in each classifier's report.

    For the i-th classifier in ``cfg.classifier`` the i-th report is read; the
    first three entries of its ``dist_report`` attribute are taken, in order,
    as the ``'oracle'``, ``'p'`` and ``'p_obs'`` parameters. Each entry must
    provide the keys ``'a'`` and ``'b'``.

    Args:
        cfg: Config object whose ``classifier`` items expose an ``id``.
        reports: Reports aligned index-by-index with ``cfg.classifier``.

    Returns:
        Dict mapping classifier id to
        ``{'oracle': Beta, 'p': Beta, 'p_obs': Beta}``.
    """
    roles = ('oracle', 'p', 'p_obs')
    by_classifier = {}

    for position, classifier in enumerate(cfg.classifier):
        entries = vars(reports[position])['dist_report']

        by_classifier[classifier.id] = {
            role: Beta(params=BetaParams(entries[slot]['a'], entries[slot]['b']))
            for slot, role in enumerate(roles)
        }

    return by_classifier


In [207]:
def sample_value(a, b, B, M):
    """Return ``(a + b - B) / (M - B)``, or ``0`` when ``M == B``.

    Args:
        a: First Beta parameter.
        b: Second Beta parameter.
        B: Value subtracted from ``a + b`` and from ``M``.
        M: Upper value; ``M - B`` is the denominator.

    Returns:
        The ratio above, or the integer ``0`` when the denominator would be
        zero.
    """
    return 0 if M == B else (a + b - B) / (M - B)

In [209]:
# Sample-value table for the text generators: one row per (generator, B, M)
# setting, followed by a (V, EAS) column pair for each classifier.
#   V   = sample_value(a, b, B, M) of the classifier's 'p' distribution
#   EAS = a + b - B for the same distribution, rounded to 4 decimals
# NOTE(review): 'FIB EAS2' differs from the other 'EAS' headers — confirm
# whether the trailing '2' is intentional.
t = PrettyTable(['Generator', 'B', 'M', 'FIB V', 'FIB EAS2', 'DSS V', 'DSS EAS', 'LL3 V', 'LL3 EAS'])

# The loop variable is `generator`, not `id`, so the builtin id() is not
# overwritten for the rest of the kernel session.
for generator, base_path in TEXT.items():
    cfg = OmegaConf.load(f'{base_path}/config.yaml')

    # Each setting is (b, m, subsampling argument forwarded to load_reports).
    if 'subsampling' in cfg.quantify and SUBSAMPLING:
        B = cfg.quantify.subsampling.B
        M = cfg.quantify.subsampling.M
        settings = [(b, m, (b, m)) for b, m in itertools.product(B, M)]
    else:
        B = 100
        M = 10_000
        settings = [(B, M, None)]

    for b, m, subsampling in settings:
        reports = load_reports(cfg, f'{base_path}/quantify', subsampling)
        distributions = get_distributions(cfg, reports)

        # One (V, EAS) pair per classifier, in config order. The row is built
        # from every classifier, so a config whose classifier count does not
        # match the header is rejected by add_row instead of being silently
        # truncated.
        row = [generator, b, m]
        for dists in distributions.values():
            p = dists['p']
            row.append("%.2E" % sample_value(p.params.a, p.params.b, B=b, M=m))
            row.append(round(p.params.a + p.params.b - b, 4))

        t.add_row(row)

t

Generator,B,M,FIB V,FIB EAS2,DSS V,DSS EAS,LL3 V,LL3 EAS
Llama-3.3-70B,100,10000,0.00104,10.28,0.00281,27.81,0.00566,56.05
Llama-2-7B,100,10000,0.00227,22.44,0.00594,58.81,0.0078,77.19
Mistral-7B,100,10000,0.00298,29.52,0.00547,54.16,0.00644,63.73


In [210]:
# Render the table built above as LaTeX and re-indent it. print() is used on
# purpose: a bare expression would show the escaped repr instead of
# copy-pasteable LaTeX.
latex_source = t.get_latex_string()
l = pretty_print_latex(latex_source)
print(l)

\begin{tabular}{ccccccccc}
    Generator & B & M & FIB V & FIB EAS2 & DSS V & DSS EAS & LL3 V & LL3 EAS \\
    Llama-3.3-70B & 100 & 10000 & 1.04E-03 & 10.28 & 2.81E-03 & 27.81 & 5.66E-03 & 56.05 \\
    Llama-2-7B & 100 & 10000 & 2.27E-03 & 22.44 & 5.94E-03 & 58.81 & 7.80E-03 & 77.19 \\
    Mistral-7B & 100 & 10000 & 2.98E-03 & 29.52 & 5.47E-03 & 54.16 & 6.44E-03 & 63.73 \\
\end{tabular}


In [213]:
# Sample-value table for the image generators (count runs): one row per
# (generator, B, M) setting, followed by a (V, EAS) column pair for each
# classifier.
#   V   = sample_value(a, b, B, M) of the classifier's 'p' distribution
#   EAS = a + b - B for the same distribution, rounded to 4 decimals
# NOTE(review): 'DTR EAS2' differs from the other 'EAS' headers — confirm
# whether the trailing '2' is intentional.
t = PrettyTable(['Generator', 'B', 'M', 'VO8 V', 'VO8 EAS', 'DTR V', 'DTR EAS2', 'LLV V', 'LLV EAS'])

# The loop variable is `generator`, not `id`, so the builtin id() is not
# overwritten for the rest of the kernel session.
for generator, base_path in IMAGE_MODEL_COUNT.items():
    cfg = OmegaConf.load(f'{base_path}/config.yaml')

    # Each setting is (b, m, subsampling argument forwarded to load_reports).
    if 'subsampling' in cfg.quantify and SUBSAMPLING:
        B = cfg.quantify.subsampling.B
        M = cfg.quantify.subsampling.M
        settings = [(b, m, (b, m)) for b, m in itertools.product(B, M)]
    else:
        B = 100
        M = 10_000
        settings = [(B, M, None)]

    for b, m, subsampling in settings:
        reports = load_reports(cfg, f'{base_path}/quantify', subsampling)
        distributions = get_distributions(cfg, reports)

        # One (V, EAS) pair per classifier, in config order. The row is built
        # from every classifier, so a config whose classifier count does not
        # match the header is rejected by add_row instead of being silently
        # truncated.
        row = [generator, b, m]
        for dists in distributions.values():
            p = dists['p']
            row.append("%.2E" % sample_value(p.params.a, p.params.b, B=b, M=m))
            row.append(round(p.params.a + p.params.b - b, 4))

        t.add_row(row)

t

Generator,B,M,VO8 V,VO8 EAS,DTR V,DTR EAS2,LLV V,LLV EAS
Stable Diffusion 3.5,100,10000,0.00693,68.65,0.0113,111.46,0.00394,39.05
Stable Cascade,100,10000,0.00146,14.48,0.00833,82.43,-0.0033,-32.63
FLUX.1-dev,100,10000,0.0128,126.43,0.0259,256.76,0.0124,122.59


In [212]:
# Render the table built above as LaTeX and re-indent it. print() is used on
# purpose: a bare expression would show the escaped repr instead of
# copy-pasteable LaTeX.
latex_source = t.get_latex_string()
l = pretty_print_latex(latex_source)
print(l)

\begin{tabular}{ccccccccc}
    Generator & B & M & VO8 V & VO8 EAS & DTR V & DTR EAS2 & LLV V & LLV EAS \\
\end{tabular}
