In [198]:
import argparse
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import gc

import numpy as np
from sympy.logic.boolalg import Boolean
from tqdm import tqdm

from jax import nn, random, vmap, clear_caches
import jax.numpy as jnp

import numpyro
from numpyro import handlers
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.infer.reparam import LocScaleReparam
from numpyro.ops.indexing import Vindex

import json

from sklearn.metrics import log_loss,f1_score
from sklearn.preprocessing import OneHotEncoder

from utils.data_utils import read_jsonl

from scipy.stats import entropy
from scipy.spatial.distance import jensenshannon


In [199]:
res = read_jsonl("data/ghc_train.jsonl")
annotators = np.array([np.array(it["annotators"]) for it in res if len(it["annotators"]) == 3])
annotations = np.array([np.array(it["labels"]) for it in res if len(it["annotators"]) == 3])
logits = np.load("llm_data/Qwen2.5-32B/train/logits.npy")
logits = np.array([x for i, x in enumerate(logits[:, :2]) if len(res[i]["annotators"]) == 3])




In [200]:
def multinomial(annotations,logits=None,test:bool=False):
    """
    This model corresponds to the plate diagram in Figure 1 of reference [1].
    """
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        zeta = numpyro.sample("zeta", dist.Dirichlet(jnp.ones(num_classes)))

    if logits is None:
        pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))
    # pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)),obs=nn.softmax(logits).mean(0))
    # pi = nn.softmax(logits).mean(0)
    # c = numpyro.sample("c", dist.Categorical(logits = logits[:,np.newaxis,:]), infer={"enumerate": "parallel"})

    with numpyro.plate("item", num_items, dim=-2):
        if logits is None:
            c = numpyro.sample("c", dist.Categorical(probs=pi), infer={"enumerate": "parallel"})
        else:
            c = numpyro.sample("c", dist.Categorical(logits = logits[:,np.newaxis,:]), infer={"enumerate": "parallel"})
        with numpyro.plate("position", num_positions):
            if test:
                numpyro.sample("y", dist.Categorical(zeta[c]))
            else:
                numpyro.sample("y", dist.Categorical(zeta[c]), obs=annotations)

def item_difficulty():
    pass


Function to print prediction results

In [201]:
def pred_results(y):
    annotator_probs = vmap(lambda x: x.mean(0), in_axes=1)(
        y
    )

    pred_probs = np.vstack((annotator_probs.mean(1),1-annotator_probs.mean(1)))
    emp_probs = np.vstack((annotations[train_size:].mean(1),1-annotations[train_size:].mean(1)))

    print(f'Average Jensen-Shannon divergence across items= {np.power(jensenshannon(emp_probs,pred_probs),2).mean()}')
    print(f'Average KL divergence across items= {entropy(emp_probs,pred_probs).mean()}')
    print(f'Binary F1 score with majority vote = {f1_score(np.rint(annotations[train_size:].mean(1)),np.rint(np.rint(annotator_probs).mean(1)),pos_label=1)}')

Define multinomial_test for purposes of forward sampling with posterior latent samples as observations

In [202]:
def multinomial_test(annotations,logits=None,test:bool=False, zeta_samples = None, c_samples = None):
    """
    This model corresponds to the plate diagram in Figure 1 of reference [1].
    """
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        zeta = numpyro.sample("zeta", dist.Dirichlet(jnp.ones(num_classes)),obs=zeta_samples)

    if logits is None:
        pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))
    # pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)),obs=nn.softmax(logits).mean(0))
    # pi = nn.softmax(logits).mean(0)
    # c = numpyro.sample("c", dist.Categorical(logits = logits[:,np.newaxis,:]), infer={"enumerate": "parallel"})

    with numpyro.plate("item", num_items, dim=-2):
        if logits is None:
            c = numpyro.sample("c", dist.Categorical(probs=pi), infer={"enumerate": "parallel"})
        else:
            c = numpyro.sample("c", dist.Categorical(logits = logits[:,np.newaxis,:]), infer={"enumerate": "parallel"}, obs=c_samples)
            # print(c.shape)
        with numpyro.plate("position", num_positions):
            # print(zeta[c].shape)
            z=posterior_samples['zeta'][:,np.newaxis,...]
            c=discrete_samples['c'][...,np.newaxis]
            # print(c.shape)
            # print(z.shape)
            # print(np.take_along_axis(z,c,axis=2).shape)
            if test:
                numpyro.sample("y", dist.Categorical(np.take_along_axis(z,c,axis=2)))
            else:
                numpyro.sample("y", dist.Categorical(zeta[c]), obs=annotations)

Infer discrete is True

In [14]:
model = multinomial

mcmc = MCMC(
    NUTS(model),
    num_warmup=500,
    num_samples=500,
    num_chains=1,
    progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
)

train_size=round(annotations.shape[0]*0.9)

train_data = (
    (annotations[:train_size], logits[:train_size])
    if model in [multinomial, item_difficulty]
    else (positions_, annotations_[:train_size], masks_[:train_size], global_num_classes, True, logits[:train_size])
    if model == dawid_skene
    else (annotators[:train_size], annotations[:train_size], logits[:train_size])
)

test_data = (
    (annotations[train_size:], logits[train_size:], [True] * annotations[train_size:].shape[0])
    if model in [multinomial, item_difficulty]
    else (positions_, annotations_[train_size:], masks_[train_size:], global_num_classes, True, logits[train_size:], [True] * annotations[train_size:].shape[0])
    if model == dawid_skene
    else (annotators[train_size:], annotations[train_size:], logits[train_size:], [True] * annotations[train_size:].shape[0])
)

mcmc.run(random.PRNGKey(0), *train_data)
mcmc.print_summary()

posterior_samples = mcmc.get_samples()
predictive = Predictive(model, posterior_samples, infer_discrete=True)
discrete_samples = predictive(random.PRNGKey(1), *test_data)

annotator_probs = vmap(lambda x: x.mean(0), in_axes=1)(
    discrete_samples["y"]
)

pred_probs = np.vstack((annotator_probs.mean(1),1-annotator_probs.mean(1)))
emp_probs = np.vstack((annotations[train_size:].mean(1),1-annotations[train_size:].mean(1)))

print(f'Average Jensen-Shannon divergence across items= {np.power(jensenshannon(emp_probs,pred_probs),2).mean()}')
print(f'Average KL divergence across items= {entropy(emp_probs,pred_probs).mean()}')
print(f'Binary F1 score with majority vote = {f1_score(np.rint(annotations[train_size:].mean(1)),np.rint(np.rint(annotator_probs).mean(1)),pos_label=1)}')



sample: 100%|██████████| 1000/1000 [00:29<00:00, 33.49it/s, 7 steps of size 8.37e-01. acc. prob=0.93]



                mean       std    median      5.0%     95.0%     n_eff     r_hat
 zeta[0,0]      0.94      0.00      0.94      0.94      0.94    317.32      1.00
 zeta[0,1]      0.06      0.00      0.06      0.06      0.06    317.32      1.00
 zeta[1,0]      0.56      0.00      0.56      0.55      0.57    271.48      1.00
 zeta[1,1]      0.44      0.00      0.44      0.43      0.45    271.48      1.00

Number of divergences: 0
Average Jensen-Shannon divergence across items= 0.1621229768657305
Average KL divergence across items= 0.5235341788748941
Binary F1 score with majority vote = 0.15275590551181104


Infer discrete is False

In [19]:
model = multinomial

mcmc = MCMC(
    NUTS(model),
    num_warmup=500,
    num_samples=500,
    num_chains=1,
    progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
)

train_size=round(annotations.shape[0]*0.9)

train_data = (
    (annotations[:train_size], logits[:train_size])
    if model in [multinomial, item_difficulty]
    else (positions_, annotations_[:train_size], masks_[:train_size], global_num_classes, True, logits[:train_size])
    if model == dawid_skene
    else (annotators[:train_size], annotations[:train_size], logits[:train_size])
)

test_data = (
    (annotations[train_size:], logits[train_size:], [True] * annotations[train_size:].shape[0])
    if model in [multinomial, item_difficulty]
    else (positions_, annotations_[train_size:], masks_[train_size:], global_num_classes, True, logits[train_size:], [True] * annotations[train_size:].shape[0])
    if model == dawid_skene
    else (annotators[train_size:], annotations[train_size:], logits[train_size:], [True] * annotations[train_size:].shape[0])
)

mcmc.run(random.PRNGKey(0), *train_data)
mcmc.print_summary()

posterior_samples = mcmc.get_samples()
predictive = Predictive(model, posterior_samples, infer_discrete=False)
discrete_samples = predictive(random.PRNGKey(1), *test_data)

annotator_probs = vmap(lambda x: x.mean(0), in_axes=1)(
    discrete_samples["y"]
)

pred_probs = np.vstack((annotator_probs.mean(1),1-annotator_probs.mean(1)))
emp_probs = np.vstack((annotations[train_size:].mean(1),1-annotations[train_size:].mean(1)))

print(f'Average Jensen-Shannon divergence across items= {np.power(jensenshannon(emp_probs,pred_probs),2).mean()}')
print(f'Average KL divergence across items= {entropy(emp_probs,pred_probs).mean()}')
print(f'Binary F1 score with majority vote = {f1_score(np.rint(annotations[train_size:].mean(1)),np.rint(np.rint(annotator_probs).mean(1)),pos_label=1)}')



sample: 100%|██████████| 1000/1000 [00:30<00:00, 32.85it/s, 7 steps of size 8.37e-01. acc. prob=0.93]



                mean       std    median      5.0%     95.0%     n_eff     r_hat
 zeta[0,0]      0.94      0.00      0.94      0.94      0.94    317.32      1.00
 zeta[0,1]      0.06      0.00      0.06      0.06      0.06    317.32      1.00
 zeta[1,0]      0.56      0.00      0.56      0.55      0.57    271.48      1.00
 zeta[1,1]      0.44      0.00      0.44      0.43      0.45    271.48      1.00

Number of divergences: 0
Average Jensen-Shannon divergence across items= 0.04204168142459972
Average KL divergence across items= 0.15842332199139497
Binary F1 score with majority vote = 0.0


Try manual implementation of prediction looping over 500 samples with infer discrete True

In [213]:
predictive = Predictive(model, posterior_samples, infer_discrete=True)
discrete_samples = predictive(random.PRNGKey(1), *test_data)

In [214]:
for i in range(0,500):
    with numpyro.plate("item", discrete_samples["c"].shape[1], dim=-2):
        c = discrete_samples["c"][i]
        with numpyro.plate("position", 3):
            y = numpyro.sample("y", dist.Categorical(posterior_samples['zeta'][i][c]),rng_key=random.PRNGKey(1))
    if i == 0:
        ys=y
    elif i == 1:
        ys = np.vstack((ys[np.newaxis,...],y[np.newaxis,...]))
    else:
        ys = np.vstack((ys,y[np.newaxis,...]))

In [217]:
pred_results(ys)

Average Jensen-Shannon divergence across items= 0.09626321907655509
Average KL divergence across items= inf
Binary F1 score with majority vote = 0.11487481590574374


Try manual implementation of prediction looping over 500 samples with infer discrete False

In [218]:
predictive = Predictive(model, posterior_samples, infer_discrete=False)
discrete_samples = predictive(random.PRNGKey(1), *test_data)

In [220]:
for i in range(0,500):
    with numpyro.plate("item", discrete_samples["c"].shape[1], dim=-2):
        c = discrete_samples["c"][i]
        with numpyro.plate("position", 3):
            y = numpyro.sample("y", dist.Categorical(posterior_samples['zeta'][i][c]),rng_key=random.PRNGKey(1))
    if i == 0:
        ys=y
    elif i == 1:
        ys = np.vstack((ys[np.newaxis,...],y[np.newaxis,...]))
    else:
        ys = np.vstack((ys,y[np.newaxis,...]))

In [221]:
pred_results(ys)

Average Jensen-Shannon divergence across items= 0.05609169892066523
Average KL divergence across items= inf
Binary F1 score with majority vote = 0.3125


Now instead of looping over 500 samples, try forward sampling from multinomial_test with posterior samples passed as observations
First get class samples from posterior from multionomial with infer_discrete = False

In [175]:
predictive = Predictive(model, posterior_samples, infer_discrete=False)
discrete_samples = predictive(random.PRNGKey(1), *test_data)

In [176]:
data = (annotations[train_size:], logits[train_size:], [True] * annotations[train_size:].shape[0], posterior_samples['zeta'],discrete_samples['c'])

Now run predictive on multionomial_test but the Predictive class is now sampling from what it thinks is the prior model but it's actually the posterior because the posterior samples are passed in the data


In [177]:
predictive1 = Predictive(multinomial_test, num_samples=1, infer_discrete=False)

discrete_samples1 = predictive1(random.PRNGKey(1), *data)

Results are consistent with just predicting using plain multionomial with infer_discrete=False. This is strange because it is different from the manual loop implementation earlier but seems to be doing the same thing as the loop.

In [179]:
pred_results(discrete_samples1['y'].squeeze())

Average Jensen-Shannon divergence across items= 0.04199882242942597
Average KL divergence across items= 0.15812588526376475
Binary F1 score with majority vote = 0.0


Try manually running the multinomial_test model without wrapping it in the Predictive class

In [215]:
annotations1,logits1,test, zeta_samples , c_samples = data
num_classes = int(np.max(annotations1)) + 1
num_items, num_positions = annotations1.shape

with numpyro.plate("class", num_classes):
    zeta = numpyro.sample("zeta", dist.Dirichlet(jnp.ones(num_classes)),obs=zeta_samples)

if logits is None:
    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))
# pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)),obs=nn.softmax(logits).mean(0))
# pi = nn.softmax(logits).mean(0)
# c = numpyro.sample("c", dist.Categorical(logits = logits[:,np.newaxis,:]), infer={"enumerate": "parallel"})

with numpyro.plate("item", num_items, dim=-2):
    if logits is None:
        c = numpyro.sample("c", dist.Categorical(probs=pi), infer={"enumerate": "parallel"})
    else:
        c = numpyro.sample("c", dist.Categorical(logits = logits1[:,np.newaxis,:]), infer={"enumerate": "parallel"}, obs=c_samples)
        # print(c.shape)
    with numpyro.plate("position", num_positions):
        # print(zeta[c].shape)
        # z=posterior_samples['zeta'][:,np.newaxis,...]
        # c=discrete_samples['c'][...,np.newaxis]
        z = zeta_samples[:,np.newaxis,...]
        c = c_samples[...,np.newaxis]
        # print(c.shape)
        # print(z.shape)
        # print(np.take_along_axis(z,c,axis=2).shape)
        if test:
            y=numpyro.sample("y", dist.Categorical(np.take_along_axis(z,c,axis=2)),rng_key=random.PRNGKey(1))
        else:
            numpyro.sample("y", dist.Categorical(zeta[c]), obs=annotations)

Still consistent results with using Predictive class. So something in the loop implementation is wrong

In [216]:
pred_results(y)

Average Jensen-Shannon divergence across items= 0.04210504575377209
Average KL divergence across items= 0.1589373858139542
Binary F1 score with majority vote = 0.0


Try directly sampling y from zc = np.take_along_axis(z,c,axis=2) and make sure this is the same as the parameters for y in the loop method. Define zs by stacking the 500 samples in the loop


In [182]:
for i in range(0,500):
    if i == 0:
        zs=posterior_samples['zeta'][i][discrete_samples["c"][i]]
    elif i == 1:
        zs = np.vstack((zs[np.newaxis,...],posterior_samples['zeta'][i][discrete_samples["c"][i]][np.newaxis,...]))
    else:
        zs = np.vstack((zs,posterior_samples['zeta'][i][discrete_samples["c"][i]][np.newaxis,...]))

In [None]:
zc = np.take_along_axis(z,c,axis=2)

Both zc and zs are identical

In [183]:
(zc == zs).sum()

Array(2412000, dtype=int32)

sample y directly conditioned on zc or zs which are the posterior samples

In [184]:
y = numpyro.sample("y", dist.Categorical(jnp.tile(zc[np.newaxis,...],(3,1,1,1,1))),rng_key=random.PRNGKey(1))

In [187]:
pred_results(y.squeeze().transpose(1,2,0))

Average Jensen-Shannon divergence across items= 0.04210504575377209
Average KL divergence across items= 0.1589373858139542
Binary F1 score with majority vote = 0.0


Nothing wrong. Somehow, looping over 500 samples gives different results

In [191]:
jnp.tile(zc[np.newaxis,...],(3,1,1,1,1)).shape

(3, 500, 2412, 1, 2)

Loop now over second axis of the modified zc

In [222]:
for i in range(0,500):
    y = numpyro.sample("y", dist.Categorical(jnp.tile(zc[np.newaxis,...],(3,1,1,1,1))[:,i]),rng_key=random.PRNGKey(1))
    if i == 0:
        ys=y
    elif i == 1:
        ys = np.vstack((ys[np.newaxis,...],y[np.newaxis,...]))
    else:
        ys = np.vstack((ys,y[np.newaxis,...]))

In [197]:
pred_results(ys.squeeze().transpose(0,2,1))

Average Jensen-Shannon divergence across items= 0.05609169892066523
Average KL divergence across items= inf
Binary F1 score with majority vote = 0.3125


In [None]:
With looping the results are different

In [68]:
with numpyro.handlers.seed(rng_seed=1):
    trace = numpyro.handlers.trace(multinomial_test).get_trace(*data)
print(numpyro.util.format_shapes(trace))

(500, 2412, 1, 1)
(500, 1, 2, 2)
(500, 2412, 1, 2)
(2412, 3)
(2412, 3)
(2412, 3)
 Trace Shapes:                  
  Param Sites:                  
 Sample Sites:                  
   class plate             2 |  
     zeta dist             2 | 2
         value      500    2 | 2
    item plate          2412 |  
        c dist     2412    1 |  
         value 500 2412    1 |  
position plate             3 |  
        y dist 500 2412    3 |  
         value 500 2412    3 |  


In [65]:
with numpyro.handlers.seed(rng_seed=1):
    trace = numpyro.handlers.trace(multinomial).get_trace(*(annotations[train_size:], logits[train_size:], [True] * annotations[train_size:].shape[0]))
print(numpyro.util.format_shapes(trace))

(2412, 1)
(2412, 1, 2)
 Trace Shapes:              
  Param Sites:              
 Sample Sites:              
   class plate         2 |  
     zeta dist         2 | 2
         value         2 | 2
    item plate      2412 |  
        c dist 2412    1 |  
         value 2412    1 |  
position plate         3 |  
        y dist 2412    3 |  
         value 2412    3 |  
