In [14]:
from cgeval import Report
from cgeval.report import GenericReport
from cgeval.distribution import Beta, BetaParams
from prettytable import PrettyTable
from decimal import Decimal


import textwrap
import re
from collections import Counter
import numpy as np
from omegaconf import OmegaConf
import matplotlib.pyplot as plt
from dataclasses import dataclass
from scipy.stats import beta
import numpy as np
from scipy.stats import norm
from scipy.integrate import quad
from scipy.spatial.distance import jensenshannon


In [15]:
TEXT = {
    'L370B': '../../out/pipeline/2025-06-25_sentiment_llama3-3',
    'L27B': '../../out/pipeline/2025-06-25_sentiment_llama2',
    'Mi7B': '../../out/pipeline/2025-06-25_sentiment_mistral',
}

IMAGE = {
    'SD35_count': '../../out/pipeline/2025-06-25_animal_detection_stable-diffusion_count',
    'SD35_animal': '../../out/pipeline/2025-06-25_animal_detection_stable-diffusion_animal',
    'StCa_count': '../../out/pipeline/2025-06-25_animal_detection_stable-cascade_count',
    'StCa_animal': '../../out/pipeline/2025-06-25_animal_detection_stable-cascade_animal',
    'FLX1_count': '../../out/pipeline/2025-06-25_animal_detection_flux_count',
    'FLX1_animal': '../../out/pipeline/2025-06-25_animal_detection_flux_animal',
}


cls_name2id = {
    'FinancialBERT': 'FIB',
    'lxyuan_DistilBert': 'DSS',
    'ollama3.2': 'LL3',
    'Yolov8': 'VO8',
    'DETR': 'DTR',
    'LLaVA': 'LLV'
}


In [16]:
def load_reports(cfg, base_path):
    reports = []

    for classifier in cfg.classifier:
        report = GenericReport()
        report.load(f"{base_path}/cls_report_{classifier.id}.json")
        reports.append(report)

    return reports

In [17]:
def get_distributions(cfg, reports):
    distributions = {}

    for idx, cls in enumerate(cfg.classifier):
        report = vars(reports[idx])

        dist = report['dist_report']

        oracle = dist[0]
        p = dist[1]
        p_obs = dist[2]

        oracle_dist = Beta(params=BetaParams(oracle['a'], oracle['b']))
        p_dist = Beta(params=BetaParams(p['a'], p['b']))
        p_obs_dist = Beta(params=BetaParams(p_obs['a'], p_obs['b']))

        distributions[cls.id] = {
            'oracle': oracle_dist,
            'p': p_dist,
            'p_obs': p_obs_dist,
        }

    return distributions

def beta_to_normal_params(alpha, beta_param):
    mu = alpha / (alpha + beta_param)
    var = (alpha * beta_param) / ((alpha + beta_param) ** 2 * (alpha + beta_param + 1))
    sigma = np.sqrt(var)
    return mu, sigma

def js_gauss(mu1, s1, mu2, s2):
    x = np.linspace(0, 1, 1000)
    
    p = norm.pdf(x, mu1, s1)
    q = norm.pdf(x, mu2, s2)

    return jensenshannon(p, q, base=2) 

def js_beta(dist_p, dist_q):
    xs_upper = np.arange(1, 1001) / 1000
    xs_lower = np.arange(0, 1000) / 1000

    p = beta.cdf(xs_upper, dist_p.params.a, dist_p.params.b) - beta.cdf(xs_lower, dist_p.params.a, dist_p.params.b)
    q = beta.cdf(xs_upper, dist_q.params.a, dist_q.params.b) - beta.cdf(xs_lower, dist_q.params.a, dist_q.params.b)

    return jensenshannon(p, q, base=2)


In [18]:
def add_row(task, id, base_path):
    cfg = OmegaConf.load(f'{base_path}/config.yaml')

    reports = load_reports(cfg, f'{base_path}/quantify')
    distributions = get_distributions(cfg, reports)

    rows = []
    for cls, dists in distributions.items():
        oracle = dists['oracle']
        p = dists['p']
        p_obs = dists['p_obs']

        jsd_p = js_beta(oracle, p)
        jsd_p_obs = js_beta(oracle, p_obs)


        _, p_var, _, _ = p.stats()
        _, p_obs_var, _, _ = p_obs.stats()
        _, oracle_var, _, _ = oracle.stats()

        oracle_var = "%.2E" % oracle_var

        cc = f'{round(jsd_p_obs, 4)} / {"%.2E" % p_obs_var}'
        bcc = f'{round(jsd_p, 4)} / {"%.2E" % p_var}'

        row = [task, f'{id}', f'{cls_name2id[cls]}', cc, bcc, oracle_var]
        rows.append(row)

    return rows


In [19]:
import pandas as pd

columns = ['Task', 'Generator', 'Classifier', 'CC', 'BCC', 'Oracle Var']
rows = []
t = PrettyTable(columns)


for id, base_path in TEXT.items():
    r = add_row('Text', id, base_path)
    rows.extend(r)
    t.add_rows(r)

for id, base_path in IMAGE.items():
    r = add_row('Images', id, base_path)
    rows.extend(r)
    t.add_rows(r)

t
df = pd.DataFrame(rows, columns=columns)
df.to_json('tab_7_jensen-shannon.json', orient='records')


t


Task,Generator,Classifier,CC,BCC,Oracle Var
Text,L370B,FIB,1.0 / 2.49E-05,0.162 / 1.14E-03,0.00115
Text,L370B,DSS,1.0 / 2.38E-05,0.0726 / 9.04E-04,0.00115
Text,L370B,LL3,0.9982 / 1.97E-05,0.1239 / 7.52E-04,0.00115
Text,L27B,FIB,1.0 / 2.49E-05,0.1305 / 1.19E-03,0.00152
Text,L27B,DSS,0.9994 / 2.33E-05,0.3561 / 8.44E-04,0.00152
Text,L27B,LL3,0.9648 / 2.03E-05,0.1686 / 8.53E-04,0.00152
Text,Mi7B,FIB,1.0 / 2.30E-05,0.1255 / 1.25E-03,0.0017
Text,Mi7B,DSS,0.9995 / 2.42E-05,0.216 / 1.02E-03,0.0017
Text,Mi7B,LL3,0.9399 / 2.09E-05,0.1942 / 9.69E-04,0.0017
Images,SD35_count,VO8,0.7897 / 2.37E-05,0.3455 / 1.34E-03,0.00231


In [20]:

def pretty_print_latex(latex_str):
    lines = latex_str.replace(r" \\ ", r" \\" + "\n").splitlines()
    formatted_lines = []
    indent_level = 0
    for line in lines:
        if r"\begin" in line:
            formatted_lines.append(line)
            indent_level += 1
        elif r"\end" in line:
            indent_level -= 1
            formatted_lines.append(line)
        else:
            formatted_lines.append("    " * indent_level + line)
    return "\n".join(formatted_lines)

formatted_latex = pretty_print_latex(t.get_latex_string())

print(formatted_latex)

\begin{tabular}{cccccc}
    Task & Generator & Classifier & CC & BCC & Oracle Var \\
    Text & L370B & FIB & 1.0 / 2.49E-05 & 0.162 / 1.14E-03 & 1.15E-03 \\
    Text & L370B & DSS & 1.0 / 2.38E-05 & 0.0726 / 9.04E-04 & 1.15E-03 \\
    Text & L370B & LL3 & 0.9982 / 1.97E-05 & 0.1239 / 7.52E-04 & 1.15E-03 \\
    Text & L27B & FIB & 1.0 / 2.49E-05 & 0.1305 / 1.19E-03 & 1.52E-03 \\
    Text & L27B & DSS & 0.9994 / 2.33E-05 & 0.3561 / 8.44E-04 & 1.52E-03 \\
    Text & L27B & LL3 & 0.9648 / 2.03E-05 & 0.1686 / 8.53E-04 & 1.52E-03 \\
    Text & Mi7B & FIB & 1.0 / 2.30E-05 & 0.1255 / 1.25E-03 & 1.70E-03 \\
    Text & Mi7B & DSS & 0.9995 / 2.42E-05 & 0.216 / 1.02E-03 & 1.70E-03 \\
    Text & Mi7B & LL3 & 0.9399 / 2.09E-05 & 0.1942 / 9.69E-04 & 1.70E-03 \\
    Images & SD35_count & VO8 & 0.7897 / 2.37E-05 & 0.3455 / 1.34E-03 & 2.31E-03 \\
    Images & SD35_count & DTR & 0.7992 / 2.35E-05 & 0.235 / 1.02E-03 & 2.31E-03 \\
    Images & SD35_count & LLV & 0.8374 / 2.46E-05 & 0.0996 / 1.67E-03 & 2.3