In [11]:
import argparse
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import gc

import numpy as np
from sympy.logic.boolalg import Boolean
from tqdm import tqdm

from jax import nn, random, vmap, clear_caches
import jax.numpy as jnp

import numpyro
from numpyro import handlers
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.infer.reparam import LocScaleReparam
from numpyro.ops.indexing import Vindex

import json

from sklearn.metrics import log_loss,f1_score
from sklearn.preprocessing import OneHotEncoder

#from models import read_jsonl, multinomial, item_difficulty


In [12]:
def read_jsonl(file:str) -> List[Dict]:
    with open(file, "r") as f:
        x = f.read()
        x = x.split("\n")
        res = []
        for x_val in x:
            try:
                res.append(json.loads(x_val))
            except:
                print(x_val)
    return res

In [13]:
res = read_jsonl("ghc_train.jsonl")
annotators = np.array([np.array(it["annotators"]) for it in res if len(it["annotators"]) == 3])
annotations = np.array([np.array(it["labels"]) for it in res if len(it["annotators"]) == 3])
logits = np.load("outputs/Qwen2.5-32B/train/logits.npy")
logits = np.array([x for i, x in enumerate(logits[:, :2]) if len(res[i]["annotators"]) == 3])




In [14]:
logits[:,np.newaxis,:].shape
logits is None

False

In [15]:
def multinomial(annotations,logits=None,test:bool=False):
    """
    This model corresponds to the plate diagram in Figure 1 of reference [1].
    """
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        zeta = numpyro.sample("zeta", dist.Dirichlet(jnp.ones(num_classes)))

    if logits is None:
        pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))
    # pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)),obs=nn.softmax(logits).mean(0))
    # pi = nn.softmax(logits).mean(0)
    # c = numpyro.sample("c", dist.Categorical(logits = logits[:,np.newaxis,:]), infer={"enumerate": "parallel"})

    with numpyro.plate("item", num_items, dim=-2):
        if logits is None:
            c = numpyro.sample("c", dist.Categorical(probs=pi), infer={"enumerate": "parallel"})
        else:
            c = numpyro.sample("c", dist.Categorical(logits = logits[:,np.newaxis,:]), infer={"enumerate": "parallel"})
        with numpyro.plate("position", num_positions):
            if test:
                numpyro.sample("y", dist.Categorical(zeta[c]))
            else:
                numpyro.sample("y", dist.Categorical(zeta[c]), obs=annotations)

In [16]:
def dawid_skene(positions, annotations,logits):
    """
    This model corresponds to the plate diagram in Figure 2 of reference [1].
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("annotator", num_annotators, dim=-2):
        with numpyro.plate("class", num_classes):
            beta = numpyro.sample("beta", dist.Dirichlet(jnp.ones(num_classes)))

    # pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))
    # pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)),obs = nn.softmax(logits).mean(0))
    # pi = nn.softmax(logits).mean(0)
    # c = numpyro.sample("c", dist.Categorical(logits = logits[:,np.newaxis,:]), infer={"enumerate": "parallel"})
    

    with numpyro.plate("item", num_items, dim=-2):
        # c = numpyro.sample("c", dist.Categorical(probs=pi), infer={"enumerate": "parallel"})
        c = numpyro.sample("c", dist.Categorical(logits = logits[:,np.newaxis,:]), infer={"enumerate": "parallel"})

        # here we use Vindex to allow broadcasting for the second index `c`
        # ref: http://num.pyro.ai/en/latest/utilities.html#numpyro.contrib.indexing.vindex
        with numpyro.plate("position", num_positions):
            y=numpyro.sample(
                "y", dist.Categorical(Vindex(beta)[positions, c, :]), obs=annotations
            )

In [33]:
def mace(positions, annotations, logits, test:bool=False): #pat
    """
    This model corresponds to the plate diagram in Figure 3 of reference [1].
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("annotator", num_annotators):
        epsilon = numpyro.sample("epsilon", dist.Dirichlet(jnp.full(num_classes, 10)))
        theta = numpyro.sample("theta", dist.Beta(0.5, 0.5))

    if logits is None:
        pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):

        if logits is None:
            c = numpyro.sample("c", dist.Categorical(probs=pi), infer={"enumerate": "parallel"})
        else:
            c = numpyro.sample("c", dist.Categorical(logits = logits[:,np.newaxis,:]), infer={"enumerate": "parallel"})

        with numpyro.plate("position", num_positions):
            s = numpyro.sample(
                "s",
                dist.Bernoulli(1 - theta[positions]),
                infer={"enumerate": "parallel"},
            )
            probs = jnp.where(
                s[..., None] == 0, nn.one_hot(c, num_classes), epsilon[positions]
            )
            if test:
                numpyro.sample("y", dist.Categorical(probs))
            else:
                numpyro.sample("y", dist.Categorical(probs), obs=annotations)


In [40]:
def hierarchical_dawid_skene(positions, annotations,logits, test:bool=False):
#pat
    """
    This model corresponds to the plate diagram in Figure 4 of reference [1].
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        # NB: we define `beta` as the `logits` of `y` likelihood; but `logits` is
        # invariant up to a constant, so we'll follow [1]: fix the last term of `beta`
        # to 0 and only define hyperpriors for the first `num_classes - 1` terms.
        zeta = numpyro.sample(
            "zeta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1)
        )
        omega = numpyro.sample(
            "Omega", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1)
        )

    with numpyro.plate("annotator", num_annotators, dim=-2):
        with numpyro.plate("class", num_classes):
            # non-centered parameterization
            with handlers.reparam(config={"beta": LocScaleReparam(0)}):
                beta = numpyro.sample("beta", dist.Normal(zeta, omega).to_event(1))
            # pad 0 to the last item
            beta = jnp.pad(beta, [(0, 0)] * (jnp.ndim(beta) - 1) + [(0, 1)])

    if logits is None:
        pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))
    

    with numpyro.plate("item", num_items, dim=-2):
        if logits is None:
            c = numpyro.sample("c", dist.Categorical(probs=pi), infer={"enumerate": "parallel"})
        else:
            c = numpyro.sample("c", dist.Categorical(logits = logits[:,np.newaxis,:]), infer={"enumerate": "parallel"})

        with numpyro.plate("position", num_positions):
            if test:
                local_logits = Vindex(beta)[positions, c, :]
                numpyro.sample("y", dist.Categorical(logits=local_logits))
            else:
                local_logits = Vindex(beta)[positions, c, :]
                numpyro.sample("y", dist.Categorical(logits=local_logits), obs=annotations)


In [41]:
def item_difficulty(annotations,logits):
    """
    This model corresponds to the plate diagram in Figure 5 of reference [1].
    """
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        eta = numpyro.sample(
            "eta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1)
        )
        chi = numpyro.sample(
            "Chi", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1)
        )

    # pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))
    # pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)),obs = nn.softmax(logits).mean(0))
    # pi = nn.softmax(logits).mean(0)
    c = numpyro.sample("c", dist.Categorical(logits = logits[:,np.newaxis,:]), infer={"enumerate": "parallel"})

    with numpyro.plate("item", num_items, dim=-2):
        # c = numpyro.sample("c", dist.Categorical(probs=pi), infer={"enumerate": "parallel"})

        with handlers.reparam(config={"theta": LocScaleReparam(0)}):
            theta = numpyro.sample("theta", dist.Normal(eta[c], chi[c]).to_event(1))
            theta = jnp.pad(theta, [(0, 0)] * (jnp.ndim(theta) - 1) + [(0, 1)])

        with numpyro.plate("position", annotations.shape[-1]):
            numpyro.sample("y", dist.Categorical(logits=theta), obs=annotations)


In [42]:
def logistic_random_effects(positions, annotations,logits):
    """
    This model corresponds to the plate diagram in Figure 5 of reference [1].
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        zeta = numpyro.sample(
            "zeta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1)
        )
        omega = numpyro.sample(
            "Omega", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1)
        )
        chi = numpyro.sample(
            "Chi", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1)
        )

    with numpyro.plate("annotator", num_annotators, dim=-2):
        with numpyro.plate("class", num_classes):
            with handlers.reparam(config={"beta": LocScaleReparam(0)}):
                beta = numpyro.sample("beta", dist.Normal(zeta, omega).to_event(1))
                beta = jnp.pad(beta, [(0, 0)] * (jnp.ndim(beta) - 1) + [(0, 1)])

    # pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))
    # pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)),obs = nn.softmax(logits).mean(0))
    # pi = nn.softmax(logits).mean(0)
    c = numpyro.sample("c", dist.Categorical(logits = logits[:,np.newaxis,:]), infer={"enumerate": "parallel"})

    with numpyro.plate("item", num_items, dim=-2):
        # c = numpyro.sample("c", dist.Categorical(pi), infer={"enumerate": "parallel"})

        with handlers.reparam(config={"theta": LocScaleReparam(0)}):
            theta = numpyro.sample("theta", dist.Normal(0, chi[c]).to_event(1))
            theta = jnp.pad(theta, [(0, 0)] * (jnp.ndim(theta) - 1) + [(0, 1)])

        with numpyro.plate("position", num_positions):
            logits = Vindex(beta)[positions, c, :] - theta
            numpyro.sample("y", dist.Categorical(logits=logits), obs=annotations)


In [43]:
with numpyro.handlers.seed(rng_seed=1):
    trace = numpyro.handlers.trace(dawid_skene).get_trace(annotators,annotations,logits)
print(numpyro.util.format_shapes(trace))

  Trace Shapes:                
   Param Sites:                
  Sample Sites:                
annotator plate          18 |  
    class plate           2 |  
      beta dist    18     2 | 2
          value    18     2 | 2
     item plate       24119 |  
         c dist 24119     1 |  
          value 24119     1 |  
 position plate           3 |  
         y dist 24119     3 |  
          value 24119     3 |  


In [50]:
# model = mace
model = hierarchical_dawid_skene

mcmc = MCMC(
    NUTS(model),
    num_warmup=5,
    num_samples=15,
    num_chains=1,
    progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
)
data = (
        (annotations,logits)
        if model in [multinomial, item_difficulty]
        else (annotators, annotations,logits)
    )
mcmc.run(random.PRNGKey(0), *data)
mcmc.print_summary()

posterior_samples = mcmc.get_samples()
predictive = Predictive(model, posterior_samples, infer_discrete=True)
discrete_samples = predictive(random.PRNGKey(1), *data)

item_class = vmap(lambda x: jnp.bincount(x, length=2), in_axes=1)(
    discrete_samples["c"].squeeze(-1)
)
print("Histogram of the predicted class of each item:")
row_format = "{:>10}" * 3
print(row_format.format("", *["c={}".format(i) for i in range(2)]))
for i, row in enumerate(item_class):
    print(row_format.format(f"item[{i}]", *row))

sample: 100%|██████████| 20/20 [00:02<00:00,  6.76it/s, 3 steps of size 4.16e-02. acc. prob=0.00]



                             mean       std    median      5.0%     95.0%     n_eff     r_hat
             Omega[0,0]      0.26      0.00      0.26      0.26      0.26       nan       nan
             Omega[1,0]      0.13      0.00      0.13      0.13      0.13       nan       nan
 beta_decentered[0,0,0]     -1.11      0.00     -1.11     -1.11     -1.11      1.01       nan
 beta_decentered[0,1,0]     -0.02      0.00     -0.02     -0.02     -0.02       nan       nan
 beta_decentered[1,0,0]     -0.08      0.00     -0.08     -0.08     -0.08       nan       nan
 beta_decentered[1,1,0]      0.79      0.00      0.79      0.79      0.79      1.01      0.93
 beta_decentered[2,0,0]      0.15      0.00      0.15      0.15      0.15      1.01      0.93
 beta_decentered[2,1,0]     -0.97      0.00     -0.97     -0.97     -0.97      1.01      0.93
 beta_decentered[3,0,0]     -0.27      0.00     -0.27     -0.27     -0.27       nan       nan
 beta_decentered[3,1,0]     -1.69      0.00     -1.69     -

In [45]:
#np.all(discrete_samples['c'] == discrete_samples1['c'])

In [52]:
model = hierarchical_dawid_skene

mcmc = MCMC(
    NUTS(model),
    num_warmup=100,
    num_samples=500,
    num_chains=1,
    progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
)

train_size=round(annotations.shape[0]*0.9)

train_data = (
        (annotations[:train_size],None)
        if model in [multinomial, item_difficulty]
        else (annotators[:train_size], annotations[:train_size],logits[:train_size])
    )

test_data = (
        (annotations[train_size:],None,True)
        if model in [multinomial, item_difficulty]
        else (annotators[train_size:], annotations[train_size:],logits[train_size:],True)
    )

mcmc.run(random.PRNGKey(0), *train_data)
mcmc.print_summary()

posterior_samples = mcmc.get_samples()
predictive = Predictive(model, posterior_samples, infer_discrete=True)
discrete_samples = predictive(random.PRNGKey(1), *test_data)

# item_class = vmap(lambda x: jnp.bincount(x, length=2), in_axes=1)(
#     discrete_samples["c"].squeeze(-1)
# )
# print("Histogram of the predicted class of each item:")
# row_format = "{:>10}" * 3
# print(row_format.format("", *["c={}".format(i) for i in range(2)]))
# for i, row in enumerate(item_class):
#     print(row_format.format(f"item[{i}]", *row))

annotator_probs = vmap(lambda x: x.mean(0), in_axes=1)(
    discrete_samples["y"]
)
# print("Histogram of the annotator predictions of each item:")
# row_format = "{:>10}" * 4
# print(row_format.format("", *["pos={}".format(i) for i in range(3)]))
# for i, row in enumerate(annotator_probs):
#     row = np.rint(row).astype(int)
#     print(row_format.format(f"item[{i}]", *row))

print(f'log loss over positions = {log_loss(annotations[train_size:].flatten(),annotator_probs.flatten())}')
print(f'log loss over items = {log_loss(np.rint(annotations[train_size:].mean(1)),annotator_probs.mean(1))}')
print('F1 score with position predictions')
print(f'Binary F1 score = {f1_score(annotations[train_size:].flatten(),np.rint(annotator_probs).flatten(),pos_label=1)}')
print(f'Micro F1 score = {f1_score(annotations[train_size:].flatten(),np.rint(annotator_probs).flatten(),pos_label=1,average="micro")}')
print(f'Macro F1 score = {f1_score(annotations[train_size:].flatten(),np.rint(annotator_probs).flatten(),pos_label=1,average="macro")}')
print(f'Weighted F1 score = {f1_score(annotations[train_size:].flatten(),np.rint(annotator_probs).flatten(),pos_label=1,average="weighted")}')
print('F1 score with item predictions')
print(f'Micro F1 score = {f1_score(annotations[train_size:].sum(1),np.rint(annotator_probs.sum(1)),average="micro")}')
print(f'Macro F1 score = {f1_score(annotations[train_size:].sum(1),np.rint(annotator_probs.sum(1)),average="macro")}')
print(f'Weighted F1 score = {f1_score(annotations[train_size:].sum(1),np.rint(annotator_probs.sum(1)),average="weighted")}')


sample: 100%|██████████| 600/600 [10:22<00:00,  1.04s/it, 255 steps of size 1.67e-02. acc. prob=0.97] 



                             mean       std    median      5.0%     95.0%     n_eff     r_hat
             Omega[0,0]      1.62      0.30      1.58      1.13      2.02    132.16      1.02
             Omega[1,0]      1.13      0.20      1.12      0.81      1.42    187.02      1.00
 beta_decentered[0,0,0]     -1.74      0.44     -1.72     -2.47     -1.00    109.09      1.01
 beta_decentered[0,1,0]     -2.83      0.53     -2.79     -3.76     -2.01    180.23      1.00
 beta_decentered[1,0,0]      1.24      0.34      1.24      0.73      1.84    104.51      1.01
 beta_decentered[1,1,0]      1.23      0.33      1.21      0.72      1.79    140.85      1.00
 beta_decentered[2,0,0]      1.16      0.32      1.14      0.68      1.70    115.72      1.01
 beta_decentered[2,1,0]      0.53      0.25      0.52      0.13      0.96    134.15      1.00
 beta_decentered[3,0,0]     -0.16      0.27     -0.17     -0.51      0.37     85.96      1.00
 beta_decentered[3,1,0]     -0.22      0.22     -0.21     -

In [53]:

model = mace

mcmc = MCMC(
    NUTS(model),
    num_warmup=100,
    num_samples=500,
    num_chains=1,
    progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
)

train_size=round(annotations.shape[0]*0.9)

train_data = (
        (annotations[:train_size],logits[:train_size])
        if model in [multinomial, item_difficulty]
        else (annotators[:train_size], annotations[:train_size],logits[:train_size])
    )

test_data = (
        (annotations[train_size:],logits[train_size:],True)
        if model in [multinomial, item_difficulty]
        else (annotators[train_size:], annotations[train_size:],logits[train_size:],True)
    )

mcmc.run(random.PRNGKey(0), *train_data)
mcmc.print_summary()

posterior_samples = mcmc.get_samples()
predictive = Predictive(model, posterior_samples, infer_discrete=True)
discrete_samples = predictive(random.PRNGKey(1), *test_data)

# item_class = vmap(lambda x: jnp.bincount(x, length=2), in_axes=1)(
#     discrete_samples["c"].squeeze(-1)
# )
# print("Histogram of the predicted class of each item:")
# row_format = "{:>10}" * 3
# print(row_format.format("", *["c={}".format(i) for i in range(2)]))
# for i, row in enumerate(item_class):
#     print(row_format.format(f"item[{i}]", *row))

annotator_probs = vmap(lambda x: x.mean(0), in_axes=1)(
    discrete_samples["y"]
)
# print("Histogram of the annotator predictions of each item:")
# row_format = "{:>10}" * 4
# print(row_format.format("", *["pos={}".format(i) for i in range(3)]))
# for i, row in enumerate(annotator_probs):
#     row = np.rint(row).astype(int)
#     print(row_format.format(f"item[{i}]", *row))

print(f'log loss over positions = {log_loss(annotations[train_size:].flatten(),annotator_probs.flatten())}')
print(f'log loss over items = {log_loss(np.rint(annotations[train_size:].mean(1)),annotator_probs.mean(1))}')
print('F1 score with position predictions')
print(f'Binary F1 score = {f1_score(annotations[train_size:].flatten(),np.rint(annotator_probs).flatten(),pos_label=1)}')
print(f'Micro F1 score = {f1_score(annotations[train_size:].flatten(),np.rint(annotator_probs).flatten(),pos_label=1,average="micro")}')
print(f'Macro F1 score = {f1_score(annotations[train_size:].flatten(),np.rint(annotator_probs).flatten(),pos_label=1,average="macro")}')
print(f'Weighted F1 score = {f1_score(annotations[train_size:].flatten(),np.rint(annotator_probs).flatten(),pos_label=1,average="weighted")}')
print('F1 score with predictions of number of 1s out of all positions')
print(f'Micro F1 score = {f1_score(annotations[train_size:].sum(1),np.rint(annotator_probs).sum(1),average="micro")}')
print(f'Macro F1 score = {f1_score(annotations[train_size:].sum(1),np.rint(annotator_probs).sum(1),average="macro")}')
print(f'Weighted F1 score = {f1_score(annotations[train_size:].sum(1),np.rint(annotator_probs).sum(1),average="weighted")}')
print('F1 score with majority vote')
print(f'Binary F1 score = {f1_score(np.rint(annotations[train_size:].mean(1)),np.rint(np.rint(annotator_probs).mean(1)),pos_label=1)}')
print(f'Micro F1 score = {f1_score(np.rint(annotations[train_size:].mean(1)),np.rint(np.rint(annotator_probs).mean(1)),pos_label=1,average="micro")}')
print(f'Macro F1 score = {f1_score(np.rint(annotations[train_size:].mean(1)),np.rint(np.rint(annotator_probs).mean(1)),pos_label=1,average="macro")}')
print(f'Weighted F1 score = {f1_score(np.rint(annotations[train_size:].mean(1)),np.rint(np.rint(annotator_probs).mean(1)),pos_label=1,average="weighted")}')


sample: 100%|██████████| 600/600 [1:01:54<00:00,  6.19s/it, 1023 steps of size 1.60e-03. acc. prob=0.91]



                   mean       std    median      5.0%     95.0%     n_eff     r_hat
 epsilon[0,0]      0.17      0.02      0.17      0.14      0.19     21.84      1.19
 epsilon[0,1]      0.83      0.02      0.83      0.81      0.86     21.84      1.19
 epsilon[1,0]      0.99      0.00      0.99      0.99      1.00      4.74      1.28
 epsilon[1,1]      0.01      0.00      0.01      0.00      0.01      4.74      1.28
 epsilon[2,0]      1.00      0.00      1.00      1.00      1.00      3.88      1.77
 epsilon[2,1]      0.00      0.00      0.00      0.00      0.00      3.79      1.70
 epsilon[3,0]      0.92      0.01      0.92      0.91      0.93     42.67      1.00
 epsilon[3,1]      0.08      0.01      0.08      0.07      0.09     42.67      1.00
 epsilon[4,0]      0.99      0.00      0.99      0.98      0.99     27.21      1.00
 epsilon[4,1]      0.01      0.00      0.01      0.01      0.02     27.21      1.00
 epsilon[5,0]      0.88      0.01      0.88      0.86      0.90     50.45  

In [47]:

model = multinomial

mcmc = MCMC(
    NUTS(model),
    num_warmup=100,
    num_samples=500,
    num_chains=1,
    progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
)

train_size=round(annotations.shape[0]*0.9)

train_data = (
        (annotations[:train_size],logits[:train_size])
        if model in [multinomial, item_difficulty]
        else (annotators[:train_size], annotations[:train_size],logits[:train_size])
    )

test_data = (
        (annotations[train_size:],logits[train_size:],True)
        if model in [multinomial, item_difficulty]
        else (annotators[train_size:], annotations[train_size:],logits[train_size:],True)
    )

mcmc.run(random.PRNGKey(0), *train_data)
mcmc.print_summary()

posterior_samples = mcmc.get_samples()
predictive = Predictive(model, posterior_samples, infer_discrete=True)
discrete_samples = predictive(random.PRNGKey(1), *test_data)

# item_class = vmap(lambda x: jnp.bincount(x, length=2), in_axes=1)(
#     discrete_samples["c"].squeeze(-1)
# )
# print("Histogram of the predicted class of each item:")
# row_format = "{:>10}" * 3
# print(row_format.format("", *["c={}".format(i) for i in range(2)]))
# for i, row in enumerate(item_class):
#     print(row_format.format(f"item[{i}]", *row))

annotator_probs = vmap(lambda x: x.mean(0), in_axes=1)(
    discrete_samples["y"]
)
# print("Histogram of the annotator predictions of each item:")
# row_format = "{:>10}" * 4
# print(row_format.format("", *["pos={}".format(i) for i in range(3)]))
# for i, row in enumerate(annotator_probs):
#     row = np.rint(row).astype(int)
#     print(row_format.format(f"item[{i}]", *row))

print(f'log loss over positions = {log_loss(annotations[train_size:].flatten(),annotator_probs.flatten())}')
print(f'log loss over items = {log_loss(np.rint(annotations[train_size:].mean(1)),annotator_probs.mean(1))}')
print('F1 score with position predictions')
print(f'Binary F1 score = {f1_score(annotations[train_size:].flatten(),np.rint(annotator_probs).flatten(),pos_label=1)}')
print(f'Micro F1 score = {f1_score(annotations[train_size:].flatten(),np.rint(annotator_probs).flatten(),pos_label=1,average="micro")}')
print(f'Macro F1 score = {f1_score(annotations[train_size:].flatten(),np.rint(annotator_probs).flatten(),pos_label=1,average="macro")}')
print(f'Weighted F1 score = {f1_score(annotations[train_size:].flatten(),np.rint(annotator_probs).flatten(),pos_label=1,average="weighted")}')
print('F1 score with predictions of number of 1s out of all positions')
print(f'Micro F1 score = {f1_score(annotations[train_size:].sum(1),np.rint(annotator_probs).sum(1),average="micro")}')
print(f'Macro F1 score = {f1_score(annotations[train_size:].sum(1),np.rint(annotator_probs).sum(1),average="macro")}')
print(f'Weighted F1 score = {f1_score(annotations[train_size:].sum(1),np.rint(annotator_probs).sum(1),average="weighted")}')
print('F1 score with majority vote')
print(f'Binary F1 score = {f1_score(np.rint(annotations[train_size:].mean(1)),np.rint(np.rint(annotator_probs).mean(1)),pos_label=1)}')
print(f'Micro F1 score = {f1_score(np.rint(annotations[train_size:].mean(1)),np.rint(np.rint(annotator_probs).mean(1)),pos_label=1,average="micro")}')
print(f'Macro F1 score = {f1_score(np.rint(annotations[train_size:].mean(1)),np.rint(np.rint(annotator_probs).mean(1)),pos_label=1,average="macro")}')
print(f'Weighted F1 score = {f1_score(np.rint(annotations[train_size:].mean(1)),np.rint(np.rint(annotator_probs).mean(1)),pos_label=1,average="weighted")}')


sample: 100%|██████████| 600/600 [00:07<00:00, 83.13it/s, 3 steps of size 6.25e-01. acc. prob=0.95] 



                mean       std    median      5.0%     95.0%     n_eff     r_hat
 zeta[0,0]      0.94      0.00      0.94      0.94      0.94    357.93      1.00
 zeta[0,1]      0.06      0.00      0.06      0.06      0.06    357.94      1.00
 zeta[1,0]      0.56      0.00      0.56      0.55      0.57    383.18      1.00
 zeta[1,1]      0.44      0.00      0.44      0.43      0.45    383.18      1.00

Number of divergences: 0
log loss over positions = 0.6922026726650096
log loss over items = 0.691429105740423
F1 score with position predictions
Binary F1 score = 0.21279587532224045
Micro F1 score = 0.5357932559425097
Macro F1 score = 0.4418217495180531
Weighted F1 score = 0.6112806398503334
F1 score with predictions of number of 1s out of all positions
Micro F1 score = 0.22470978441127695
Macro F1 score = 0.17095477443648727
Weighted F1 score = 0.25568417882500183
F1 score with majority vote
Binary F1 score = 0.15275590551181104
Micro F1 score = 0.5538971807628524
Macro F1 score = 0.4

In [48]:
np.rint(np.rint(annotator_probs).mean(1))

array([0., 0., 0., ..., 1., 0., 0.], dtype=float32)

In [49]:
annotator_sums = vmap(lambda x: x.mean(0), in_axes=1)(
    discrete_samples["y"]
)

In [None]:
enc=OneHotEncoder(sparse_output=False)

In [None]:
enc.fit(np.rint(annotator_probs).sum(1).reshape(-1,1))

In [None]:
test=enc.transform(np.rint(annotator_probs).sum(1).reshape(-1,1))

In [None]:
log_loss(annotations[train_size:].sum(1),test)

In [None]:
f1_score(annotations[train_size:].sum(1),np.rint(annotator_probs.sum(1)),average='micro')

In [None]:
log_loss(np.rint(annotations[train_size:].mean(1)),np.rint(annotator_probs.mean(1)))

7.587835118650642

: 

: 

In [None]:
discrete_samples['y'][:,1000,:].sum(0)

Array([ 5, 11,  3], dtype=int32)

: 

: 

In [None]:
for i in range(discrete_samples['y'].shape[1]):
    arr = discrete_samples['y'][:,i,:]
    if not np.all(np.tile(arr[0],(15,1)) == arr):
        print(i)

: 

: 

In [None]:
from sklearn.metrics import log_loss,f1_score

: 

: 

In [None]:
discrete_samples['y'].mean(0)

Array([[1., 0., 0.],
       [0., 0., 0.],
       [1., 1., 0.],
       ...,
       [0., 0., 0.],
       [0., 1., 0.],
       [0., 0., 0.]], dtype=float32)

: 

: 

In [None]:
np.rint(annotator_probs)

array([[0., 0., 1.],
       [0., 0., 1.],
       [1., 1., 1.],
       ...,
       [1., 1., 1.],
       [1., 0., 0.],
       [1., 1., 0.]], dtype=float32)

: 

: 

In [None]:
log_loss(annotations[train_size:].flatten(),annotator_probs.flatten())

0.7317581959541963

: 

: 

In [None]:
f1_score(annotations[train_size:].flatten(),np.rint(annotator_probs).flatten(),pos_label=1,average='weighted')

0.5782398912935731

: 

: 

In [None]:
for i in range(discrete_samples['y'].shape[1]):
    arr = discrete_samples['y'][0,i,:]
    if not np.all(annotations[train_size:][i] == arr):
        print(i)

: 

: 

In [None]:
jnp.unique_counts(discrete_samples['y'][:,0,:])

_UniqueCountsResult(values=Array([0, 1], dtype=int32), counts=Array([30, 15], dtype=int32))

: 

: 

In [None]:
discrete_samples['y'][:,0,:]

Array([[1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0]], dtype=int32)

: 

: 

: 

: 