In [1]:
import argparse
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import gc

import numpy as np
from sympy.logic.boolalg import Boolean
from tqdm import tqdm

from jax import nn, random, vmap, clear_caches
import jax.numpy as jnp

import numpyro
from numpyro import handlers
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.infer.reparam import LocScaleReparam
from numpyro.ops.indexing import Vindex

import json

from sklearn.metrics import log_loss,f1_score
from sklearn.preprocessing import OneHotEncoder

from utils.data_utils import read_jsonl

from scipy.stats import entropy
from scipy.spatial.distance import jensenshannon


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
res = read_jsonl("data/ghc_train.jsonl")
annotators = np.array([np.array(it["annotators"]) for it in res if len(it["annotators"]) == 3])
annotations = np.array([np.array(it["labels"]) for it in res if len(it["annotators"]) == 3])
logits = np.load("llm_data/Qwen2.5-32B/train/logits.npy")
logits = np.array([x for i, x in enumerate(logits[:, :2]) if len(res[i]["annotators"]) == 3])




In [5]:
def multinomial(annotations,logits=None,test:bool=False):
    """
    This model corresponds to the plate diagram in Figure 1 of reference [1].
    """
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        zeta = numpyro.sample("zeta", dist.Dirichlet(jnp.ones(num_classes)))

    if logits is None:
        pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))
    # pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)),obs=nn.softmax(logits).mean(0))
    # pi = nn.softmax(logits).mean(0)
    # c = numpyro.sample("c", dist.Categorical(logits = logits[:,np.newaxis,:]), infer={"enumerate": "parallel"})

    with numpyro.plate("item", num_items, dim=-2):
        if logits is None:
            c = numpyro.sample("c", dist.Categorical(probs=pi), infer={"enumerate": "parallel"})
        else:
            c = numpyro.sample("c", dist.Categorical(logits = logits[:,np.newaxis,:]), infer={"enumerate": "parallel"})
        with numpyro.plate("position", num_positions):
            if test:
                numpyro.sample("y", dist.Categorical(zeta[c]))
            else:
                numpyro.sample("y", dist.Categorical(zeta[c]), obs=annotations)

def item_difficulty():
    pass


In [14]:
model = multinomial

mcmc = MCMC(
    NUTS(model),
    num_warmup=500,
    num_samples=500,
    num_chains=1,
    progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
)

train_size=round(annotations.shape[0]*0.9)

train_data = (
    (annotations[:train_size], logits[:train_size])
    if model in [multinomial, item_difficulty]
    else (positions_, annotations_[:train_size], masks_[:train_size], global_num_classes, True, logits[:train_size])
    if model == dawid_skene
    else (annotators[:train_size], annotations[:train_size], logits[:train_size])
)

test_data = (
    (annotations[train_size:], logits[train_size:], [True] * annotations[train_size:].shape[0])
    if model in [multinomial, item_difficulty]
    else (positions_, annotations_[train_size:], masks_[train_size:], global_num_classes, True, logits[train_size:], [True] * annotations[train_size:].shape[0])
    if model == dawid_skene
    else (annotators[train_size:], annotations[train_size:], logits[train_size:], [True] * annotations[train_size:].shape[0])
)

mcmc.run(random.PRNGKey(0), *train_data)
mcmc.print_summary()

posterior_samples = mcmc.get_samples()
predictive = Predictive(model, posterior_samples, infer_discrete=True)
discrete_samples = predictive(random.PRNGKey(1), *test_data)

annotator_probs = vmap(lambda x: x.mean(0), in_axes=1)(
    discrete_samples["y"]
)

pred_probs = np.vstack((annotator_probs.mean(1),1-annotator_probs.mean(1)))
emp_probs = np.vstack((annotations[train_size:].mean(1),1-annotations[train_size:].mean(1)))

print(f'Average Jensen-Shannon divergence across items= {np.power(jensenshannon(emp_probs,pred_probs),2).mean()}')
print(f'Average KL divergence across items= {entropy(emp_probs,pred_probs).mean()}')
print(f'Binary F1 score with majority vote = {f1_score(np.rint(annotations[train_size:].mean(1)),np.rint(np.rint(annotator_probs).mean(1)),pos_label=1)}')



sample: 100%|██████████| 1000/1000 [00:29<00:00, 33.49it/s, 7 steps of size 8.37e-01. acc. prob=0.93]



                mean       std    median      5.0%     95.0%     n_eff     r_hat
 zeta[0,0]      0.94      0.00      0.94      0.94      0.94    317.32      1.00
 zeta[0,1]      0.06      0.00      0.06      0.06      0.06    317.32      1.00
 zeta[1,0]      0.56      0.00      0.56      0.55      0.57    271.48      1.00
 zeta[1,1]      0.44      0.00      0.44      0.43      0.45    271.48      1.00

Number of divergences: 0
Average Jensen-Shannon divergence across items= 0.1621229768657305
Average KL divergence across items= 0.5235341788748941
Binary F1 score with majority vote = 0.15275590551181104


In [15]:
model = multinomial

mcmc = MCMC(
    NUTS(model),
    num_warmup=500,
    num_samples=500,
    num_chains=1,
    progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
)

train_size=round(annotations.shape[0]*0.9)

train_data = (
    (annotations[:train_size], logits[:train_size])
    if model in [multinomial, item_difficulty]
    else (positions_, annotations_[:train_size], masks_[:train_size], global_num_classes, True, logits[:train_size])
    if model == dawid_skene
    else (annotators[:train_size], annotations[:train_size], logits[:train_size])
)

test_data = (
    (annotations[train_size:], logits[train_size:], [True] * annotations[train_size:].shape[0])
    if model in [multinomial, item_difficulty]
    else (positions_, annotations_[train_size:], masks_[train_size:], global_num_classes, True, logits[train_size:], [True] * annotations[train_size:].shape[0])
    if model == dawid_skene
    else (annotators[train_size:], annotations[train_size:], logits[train_size:], [True] * annotations[train_size:].shape[0])
)

mcmc.run(random.PRNGKey(0), *train_data)
mcmc.print_summary()

posterior_samples = mcmc.get_samples()
predictive = Predictive(model, posterior_samples, infer_discrete=False)
discrete_samples = predictive(random.PRNGKey(1), *test_data)

annotator_probs = vmap(lambda x: x.mean(0), in_axes=1)(
    discrete_samples["y"]
)

pred_probs = np.vstack((annotator_probs.mean(1),1-annotator_probs.mean(1)))
emp_probs = np.vstack((annotations[train_size:].mean(1),1-annotations[train_size:].mean(1)))

print(f'Average Jensen-Shannon divergence across items= {np.power(jensenshannon(emp_probs,pred_probs),2).mean()}')
print(f'Average KL divergence across items= {entropy(emp_probs,pred_probs).mean()}')
print(f'Binary F1 score with majority vote = {f1_score(np.rint(annotations[train_size:].mean(1)),np.rint(np.rint(annotator_probs).mean(1)),pos_label=1)}')



sample: 100%|██████████| 1000/1000 [00:33<00:00, 29.42it/s, 7 steps of size 8.37e-01. acc. prob=0.93] 



                mean       std    median      5.0%     95.0%     n_eff     r_hat
 zeta[0,0]      0.94      0.00      0.94      0.94      0.94    317.32      1.00
 zeta[0,1]      0.06      0.00      0.06      0.06      0.06    317.32      1.00
 zeta[1,0]      0.56      0.00      0.56      0.55      0.57    271.48      1.00
 zeta[1,1]      0.44      0.00      0.44      0.43      0.45    271.48      1.00

Number of divergences: 0
Average Jensen-Shannon divergence across items= 0.04204168142459972
Average KL divergence across items= 0.15842332199139497
Binary F1 score with majority vote = 0.0


In [7]:
predictive = Predictive(model, posterior_samples, infer_discrete=True)
discrete_samples = predictive(random.PRNGKey(1), *test_data)

In [8]:
for i in range(0,500):
    with numpyro.plate("item", discrete_samples["c"].shape[1], dim=-2):
        c = discrete_samples["c"][i]
        with numpyro.plate("position", 3):
            y = numpyro.sample("y", dist.Categorical(posterior_samples['zeta'][i][c]),rng_key=random.PRNGKey(1))
    if i == 0:
        ys=y
    elif i == 1:
        ys = np.vstack((ys[np.newaxis,...],y[np.newaxis,...]))
    else:
        ys = np.vstack((ys,y[np.newaxis,...]))

In [9]:
annotator_probs = vmap(lambda x: x.mean(0), in_axes=1)(
    ys
)

pred_probs = np.vstack((annotator_probs.mean(1),1-annotator_probs.mean(1)))
emp_probs = np.vstack((annotations[train_size:].mean(1),1-annotations[train_size:].mean(1)))

print(f'Average Jensen-Shannon divergence across items= {np.power(jensenshannon(emp_probs,pred_probs),2).mean()}')
print(f'Average KL divergence across items= {entropy(emp_probs,pred_probs).mean()}')
print(f'Binary F1 score with majority vote = {f1_score(np.rint(annotations[train_size:].mean(1)),np.rint(np.rint(annotator_probs).mean(1)),pos_label=1)}')

Average Jensen-Shannon divergence across items= 0.09626321907655509
Average KL divergence across items= inf
Binary F1 score with majority vote = 0.11487481590574374


In [10]:
predictive = Predictive(model, posterior_samples, infer_discrete=False)
discrete_samples = predictive(random.PRNGKey(1), *test_data)

In [11]:
for i in range(0,500):
    with numpyro.plate("item", discrete_samples["c"].shape[1], dim=-2):
        c = discrete_samples["c"][i]
        with numpyro.plate("position", 3):
            y = numpyro.sample("y", dist.Categorical(posterior_samples['zeta'][i][c]),rng_key=random.PRNGKey(1))
    if i == 0:
        ys=y
    elif i == 1:
        ys = np.vstack((ys[np.newaxis,...],y[np.newaxis,...]))
    else:
        ys = np.vstack((ys,y[np.newaxis,...]))

In [12]:
annotator_probs = vmap(lambda x: x.mean(0), in_axes=1)(
    ys
)

pred_probs = np.vstack((annotator_probs.mean(1),1-annotator_probs.mean(1)))
emp_probs = np.vstack((annotations[train_size:].mean(1),1-annotations[train_size:].mean(1)))

print(f'Average Jensen-Shannon divergence across items= {np.power(jensenshannon(emp_probs,pred_probs),2).mean()}')
print(f'Average KL divergence across items= {entropy(emp_probs,pred_probs).mean()}')
print(f'Binary F1 score with majority vote = {f1_score(np.rint(annotations[train_size:].mean(1)),np.rint(np.rint(annotator_probs).mean(1)),pos_label=1)}')

Average Jensen-Shannon divergence across items= 0.05609169892066523
Average KL divergence across items= inf
Binary F1 score with majority vote = 0.3125
