In [20]:
import argparse
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import gc

import numpy as np
from sympy.logic.boolalg import Boolean
from tqdm import tqdm

from jax import nn, random, vmap, clear_caches
import jax.numpy as jnp

import numpyro
from numpyro import handlers
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.infer.reparam import LocScaleReparam
from numpyro.ops.indexing import Vindex

import json

from sklearn.metrics import log_loss,f1_score
from sklearn.preprocessing import OneHotEncoder

from utils.data_utils import read_jsonl

from scipy.stats import entropy
from scipy.spatial.distance import jensenshannon


In [21]:
res = read_jsonl("data/ghc_train.jsonl")
annotators = np.array([np.array(it["annotators"]) for it in res if len(it["annotators"]) == 3])
annotations = np.array([np.array(it["labels"]) for it in res if len(it["annotators"]) == 3])
logits = np.load("llm_data/Qwen2.5-32B/train/logits.npy")
logits = np.array([x for i, x in enumerate(logits[:, :2]) if len(res[i]["annotators"]) == 3])




In [71]:
def item_difficulty(annotations,logits,mask=None):
    """
    This model corresponds to the plate diagram in Figure 5 of reference [1].
    """
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        eta = numpyro.sample(
            "eta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1)
        )
        chi = numpyro.sample(
            "Chi", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1)
        )

    with numpyro.plate("item", num_items, dim=-2):
        c = numpyro.sample("c", dist.Categorical(logits = logits[:,np.newaxis,:]), infer={"enumerate": "parallel"})

        with handlers.reparam(config={"theta": LocScaleReparam(0)}):
            theta = numpyro.sample("theta", dist.Normal(eta[c], chi[c]).to_event(1))
            theta = jnp.pad(theta, [(0, 0)] * (jnp.ndim(theta) - 1) + [(0, 1)])

        with numpyro.plate("position", num_positions):
            numpyro.sample("y", dist.Categorical(logits=theta), obs=annotations, obs_mask=mask)

In [23]:
def multinomial(num_classes, num_items, num_positions, annotations=None,logits=None,mask=None):
    """
    This model corresponds to the plate diagram in Figure 1 of reference [1].
    """
    # num_classes = int(np.max(annotations)) + 1
    # num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        zeta = numpyro.sample("zeta", dist.Dirichlet(jnp.ones(num_classes)))

    if logits is None:
        pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))
    # pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)),obs=nn.softmax(logits).mean(0))
    # pi = nn.softmax(logits).mean(0)
    # c = numpyro.sample("c", dist.Categorical(logits = logits[:,np.newaxis,:]), infer={"enumerate": "parallel"})

    with numpyro.plate("item", num_items, dim=-2):
        if logits is None:
            c = numpyro.sample("c", dist.Categorical(probs=pi), infer={"enumerate": "parallel"})
        else:
            c = numpyro.sample("c", dist.Categorical(logits = logits[:,np.newaxis,:]), infer={"enumerate": "parallel"})

        with numpyro.plate("position", num_positions):
            # if test:
            #     numpyro.sample("y", dist.Categorical(zeta[c]))
            # else:
                numpyro.sample("y", dist.Categorical(zeta[c]), obs=annotations, obs_mask=mask)


In [8]:
with numpyro.handlers.seed(rng_seed=1):
    trace = numpyro.handlers.trace(multinomial).get_trace(3,annotations.shape[0], 3, annotations, logits, mask)
print(numpyro.util.format_shapes(trace))

NameError: name 'mask' is not defined

In [73]:
model = item_difficulty

mcmc = MCMC(
    NUTS(model),
    num_warmup=10,
    num_samples=10,
    num_chains=1,
    progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
)

train_size=round(annotations.shape[0]*0.9)
#
# train_data = (
#     (annotations[:train_size], logits[:train_size])
#     if model in [multinomial, item_difficulty]
#     else (positions_, annotations_[:train_size], masks_[:train_size], global_num_classes, True, logits[:train_size])
#     if model == dawid_skene
#     else (annotators[:train_size], annotations[:train_size], logits[:train_size])
# )
#
# test_data = (
#     (annotations[train_size:], logits[train_size:], [True] * annotations[train_size:].shape[0])
#     if model in [multinomial, item_difficulty]
#     else (positions_, annotations_[train_size:], masks_[train_size:], global_num_classes, True, logits[train_size:], [True] * annotations[train_size:].shape[0])
#     if model == dawid_skene
#     else (annotators[train_size:], annotations[train_size:], logits[train_size:], [True] * annotations[train_size:].shape[0])
# )


mask=jnp.array([True]*train_size + [False]*(annotations.shape[0]-train_size)).reshape(-1,1)
mask=jnp.tile(mask,(1,3))
train_data = (
    (annotations,logits, mask)
    if model in [multinomial, item_difficulty]
    else (positions_, annotations_[:train_size], masks_[:train_size], global_num_classes, True, logits[:train_size])
    if model == dawid_skene
    else (annotators[:train_size], annotations[:train_size], logits[:train_size])
)

mcmc.run(random.PRNGKey(0), *train_data)
mcmc.print_summary()

posterior_samples = mcmc.get_samples()
predictive = Predictive(model, posterior_samples, infer_discrete=True)
# discrete_samples = predictive(random.PRNGKey(1), *test_data)
discrete_samples = predictive(random.PRNGKey(1), *train_data)

annotator_probs = vmap(lambda x: x.mean(0), in_axes=1)(
    discrete_samples["y_unobserved"][:,train_size:],
)

pred_probs = np.vstack((annotator_probs.mean(1),1-annotator_probs.mean(1)))
emp_probs = np.vstack((annotations[train_size:].mean(1),1-annotations[train_size:].mean(1)))

print(f'Average Jensen-Shannon divergence across items= {np.power(jensenshannon(emp_probs,pred_probs),2).mean()}')
print(f'Average KL divergence across items= {entropy(emp_probs,pred_probs).mean()}')
print(f'Binary F1 score with majority vote = {f1_score(np.rint(annotations[train_size:].mean(1)),np.rint(np.rint(annotator_probs).mean(1)),pos_label=1)}')



  mcmc.run(random.PRNGKey(0), *train_data)
sample: 100%|██████████| 20/20 [02:16<00:00,  6.81s/it, 255 steps of size 1.86e-02. acc. prob=0.91]
  n_eff = np.prod(x.shape[:2]) / tau



                                 mean       std    median      5.0%     95.0%     n_eff     r_hat
                   Chi[0,0]      0.13      0.14      0.10      0.00      0.41      2.98      2.75
                   Chi[1,0]      0.40      0.23      0.53      0.02      0.60      3.21      1.76
                   eta[0,0]      2.79      0.02      2.79      2.76      2.83      9.42      0.91
                   eta[1,0]      0.26      0.02      0.25      0.23      0.30      5.09      1.37
    theta_decentered[0,0,0]      0.18      0.91      0.12     -1.27      1.73     30.91      0.90
    theta_decentered[1,0,0]     -0.00      1.31      0.11     -1.77      1.64     84.37      0.94
    theta_decentered[2,0,0]      0.04      1.50      0.10     -2.00      1.83     30.32      0.94
    theta_decentered[3,0,0]     -0.20      0.45     -0.17     -0.79      0.49     15.54      0.94
    theta_decentered[4,0,0]      0.07      0.62      0.03     -0.95      0.89     24.28      0.92
    theta_decentere

In [65]:
posterior_samples = mcmc.get_samples()
predictive = Predictive(model, posterior_samples, infer_discrete=True)
# discrete_samples = predictive(random.PRNGKey(1), *train_data)
discrete_samples = predictive(random.PRNGKey(1), *(2,logits[:].shape[0],3, annotations,logits[:],mask))

In [66]:
annotator_probs = vmap(lambda x: x.mean(0), in_axes=1)(
    discrete_samples["y_unobserved"][:,train_size:],
)

pred_probs = np.vstack((annotator_probs.mean(1),1-annotator_probs.mean(1)))
emp_probs = np.vstack((annotations[train_size:].mean(1),1-annotations[train_size:].mean(1)))

print(f'Average Jensen-Shannon divergence across items= {np.power(jensenshannon(emp_probs,pred_probs),2).mean()}')
print(f'Average KL divergence across items= {entropy(emp_probs,pred_probs).mean()}')
print(f'Binary F1 score with majority vote = {f1_score(np.rint(annotations[train_size:].mean(1)),np.rint(np.rint(annotator_probs).mean(1)),pos_label=1)}')

Average Jensen-Shannon divergence across items= 0.16220774515307626
Average KL divergence across items= 0.5238433605106096
Binary F1 score with majority vote = 0.14330708661417324


In [60]:
discrete_samples["y_unobserved"][train_size:].shape

(0, 24119, 3)

In [58]:
posterior_samples['y'].shape

(500, 2, 1, 24119, 3)

In [57]:
discrete_samples.keys()

dict_keys(['c', 'y_observed', 'y_unobserved'])

In [51]:
yt = posterior_samples['y'][:,:,0,train_size:,:].transpose((0,2,1,3))

In [28]:
yt.shape

(500, 2412, 2, 3)

In [43]:
discrete_samples['c'].shape

(500, 2412, 1)

In [34]:
yt[0,0,discrete_samples['c'][0,0]]

Array([[0, 0, 0]], dtype=int32)

In [50]:
c=discrete_samples['c']

In [52]:
# First, remove the singleton dimension from c if needed
c_idx = jnp.squeeze(c, axis=-1)  # shape: (500, 2412)

# Use jnp.take_along_axis to index the 3rd dimension (axis=2)
# yt shape: (500, 2412, 2, 3) --> we want to select 1 from axis=2

# Expand c_idx to match yt's shape along axis=2
c_idx_exp = c_idx[..., None, None]  # shape: (500, 2412, 1, 1)

# Tile to shape: (500, 2412, 1, 3) to match yt
c_idx_exp = jnp.tile(c_idx_exp, (1, 1, 1, 3))

# Now gather along axis=2
a = jnp.take_along_axis(yt, c_idx_exp, axis=2)  # shape: (500, 2412, 1, 3)

# Squeeze out the axis=2
a = jnp.squeeze(a, axis=2)  # shape: (500, 2412, 3)

In [46]:
a.shape

(500, 2412, 3)

In [2]:
ys=jnp.take(yt,discrete_samples["c"],axis=2)

NameError: name 'jnp' is not defined

In [14]:
discrete_samples['c'].shape

(500, 2412, 1)

In [38]:
annotator_probs = vmap(lambda x: x.mean(0), in_axes=1)(
    posterior_samples["y"][:,train_size:]
)

pred_probs = np.vstack((annotator_probs.mean(1),1-annotator_probs.mean(1)))
emp_probs = np.vstack((annotations[train_size:].mean(1),1-annotations[train_size:].mean(1)))

print(f'Average Jensen-Shannon divergence across items= {np.power(jensenshannon(emp_probs,pred_probs),2).mean()}')
print(f'Average KL divergence across items= {entropy(emp_probs,pred_probs).mean()}')
print(f'Binary F1 score with majority vote = {f1_score(np.rint(annotations[train_size:].mean(1)),np.rint(np.rint(annotator_probs).mean(1)),pos_label=1)}')

ValueError: operands could not be broadcast together with shapes (2,2412) (0,24119,3) 

In [19]:
annotator_probs.shape

(24119, 3)

In [20]:
discrete_samples["y_unobserved"][train_size:].shape

(0, 24119, 3)

In [26]:
posterior_samples.keys()

dict_keys(['y', 'zeta'])

In [29]:
posterior_samples['y'].shape

(500, 2, 1, 24119, 3)

In [34]:
discrete_samples['c'].shape

(500, 2412, 1)

In [36]:
posterior_samples['y'][:,:,:,train_size:,:].shape

(500, 2, 1, 2412, 3)

In [37]:
y = posterior_samples['y'][:,:,:,train_size:,:]

In [48]:
a = np.empty(discrete_samples['c'].shape)

In [49]:
a=np.tile(a,(1,1,3))

In [58]:
yt=y.transpose((0,3,1,4,2))

In [50]:
for i in range(500):
    for j in range(2412):
        a[i,j] = y[i,discrete_samples['c'][i,j],0,j,:]

In [53]:
a.shape

(500, 2412, 3)

In [55]:
discrete_samples['c'].shape

(500, 2412, 1)

In [1]:
yt[discrete_samples['c']]

NameError: name 'yt' is not defined

In [54]:
annotator_probs = vmap(lambda x: x.mean(0), in_axes=1)(
    a
)

pred_probs = np.vstack((annotator_probs.mean(1),1-annotator_probs.mean(1)))
emp_probs = np.vstack((annotations[train_size:].mean(1),1-annotations[train_size:].mean(1)))

print(f'Average Jensen-Shannon divergence across items= {np.power(jensenshannon(emp_probs,pred_probs),2).mean()}')
print(f'Average KL divergence across items= {entropy(emp_probs,pred_probs).mean()}')
print(f'Binary F1 score with majority vote = {f1_score(np.rint(annotations[train_size:].mean(1)),np.rint(np.rint(annotator_probs).mean(1)),pos_label=1)}')

Average Jensen-Shannon divergence across items= 0.06398697845361095
Average KL divergence across items= inf
Binary F1 score with majority vote = 0.45874587458745875


In [145]:
discrete_samples["y_unobserved"][train_size:]

Array([], shape=(0, 24119, 3), dtype=int32)

In [99]:
predictive = Predictive(model, posterior_samples, infer_discrete=True)
discrete_samples = predictive(random.PRNGKey(1), *test_data)

In [82]:
for i in range(0,500):
    with numpyro.plate("item", discrete_samples["c"].shape[1], dim=-2):
        c = discrete_samples["c"][i]
        with numpyro.plate("position", 3):
            y = numpyro.sample("y", dist.Categorical(posterior_samples['zeta'][i][c]),rng_key=random.PRNGKey(1))
    if i == 0:
        ys=y
    elif i == 1:
        ys = np.vstack((ys[np.newaxis,...],y[np.newaxis,...]))
    else:
        ys = np.vstack((ys,y[np.newaxis,...]))

In [40]:
with numpyro.plate("my_plate1", 3):
    a =  numpyro.sample('a', dist.Bernoulli(0.5),rng_key=random.PRNGKey(1))
    b = numpyro.sample('b',  dist.Bernoulli(0.5),rng_key=random.PRNGKey(1))
    with numpyro.plate("my_plate2", 2):
        c = numpyro.sample('c', dist.Bernoulli(0.5),rng_key=random.PRNGKey(1))

In [41]:
print(a.shape)
print(b.shape)
print(c.shape)

(3,)
(3,)
(2, 3)
