In [1]:
import argparse
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import gc

import numpy as np
from sympy.logic.boolalg import Boolean
from tqdm import tqdm

from jax import nn, random, vmap, clear_caches
import jax.numpy as jnp

import numpyro
from numpyro import handlers
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.infer.reparam import LocScaleReparam
from numpyro.ops.indexing import Vindex

import json

from sklearn.metrics import log_loss,f1_score
from sklearn.preprocessing import OneHotEncoder

from models import read_jsonl, multinomial, item_difficulty


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def read_jsonl(file:str) -> List[Dict]:
    """
    Load a JSON Lines file into a list of decoded records.

    The file is read as UTF-8 and streamed line by line. Blank lines (including
    the empty trailing line left by a final newline) are skipped silently.
    Lines that are not valid JSON are printed and skipped, so a few corrupt
    rows do not abort the load.

    :param file: path of the ``.jsonl`` file to read.
    :return: the decoded JSON values, in file order.
    """
    records = []
    with open(file, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.rstrip("\n")
            if not line.strip():
                # Nothing to parse; do not report it as a malformed record.
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError:
                # Best-effort load: echo the offending line and keep going.
                print(line)
    return records

In [2]:
res = read_jsonl("ghc_train.jsonl")

# Flag the records annotated by exactly three annotators; only those are kept
# (presumably so the stacked arrays are rectangular — confirm).
has_three = [len(it["annotators"]) == 3 for it in res]

annotators = np.array(
    [np.array(it["annotators"]) for it, keep in zip(res, has_three) if keep]
)
annotations = np.array(
    [np.array(it["labels"]) for it, keep in zip(res, has_three) if keep]
)

# Keep the first two logit columns and drop the rows of filtered-out records.
# Row i of the logits file is matched to record i of `res`.
logits = np.load("llm_data/Qwen2.5-32B/train/logits.npy")
logits = np.array([row for i, row in enumerate(logits[:, :2]) if has_three[i]])




In [17]:
# Sanity checks on the loaded logits: the shape after inserting a singleton
# middle axis (the form the models hand to dist.Categorical), and whether
# `logits` is set. Only the last expression's value is displayed.
logits[:,np.newaxis,:].shape
logits is None

False

In [18]:
def multinomial(annotations,logits=None,test:bool=False):
    """
    This model corresponds to the plate diagram in Figure 1 of reference [1].

    Every item has a latent true class ``c`` (enumerated in parallel) and every
    annotation of that item is drawn from the class-specific response
    distribution ``zeta[c]``, shared by all annotators.

    :param annotations: integer array of shape ``(num_items, num_positions)``;
        the number of classes is taken as ``max(annotations) + 1``.
    :param logits: optional array indexed as ``logits[:, None, :]`` giving
        per-item class logits used as the prior of ``c``. When ``None`` a
        shared class prevalence ``pi ~ Dirichlet(1)`` is sampled instead.
    :param test: when ``True`` the site ``y`` is sampled instead of being
        conditioned on ``annotations``.
    """
    num_items, num_positions = annotations.shape
    num_classes = int(np.max(annotations)) + 1

    with numpyro.plate("class", num_classes):
        zeta = numpyro.sample("zeta", dist.Dirichlet(jnp.ones(num_classes)))

    # Choose the prior over the latent class: learned prevalence or given logits.
    if logits is None:
        pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))
        class_prior = dist.Categorical(probs=pi)
    else:
        class_prior = dist.Categorical(logits = logits[:,np.newaxis,:])

    # Passing obs=None makes `y` an ordinary latent sample site.
    observed = None if test else annotations

    with numpyro.plate("item", num_items, dim=-2):
        c = numpyro.sample("c", class_prior, infer={"enumerate": "parallel"})
        with numpyro.plate("position", num_positions):
            numpyro.sample("y", dist.Categorical(zeta[c]), obs=observed)

In [248]:
def dawid_skene(positions, annotations, logits=None, test: bool = False):
    """
    This model corresponds to the plate diagram in Figure 2 of reference [1].

    Each annotator has their own response distribution ``beta[annotator, c]``
    for every true class ``c``; the latent class of each item is enumerated
    in parallel.

    :param positions: integer array of annotator ids, one per annotation
        (presumably the same shape as ``annotations`` — confirm); the number
        of annotators is taken as ``max(positions) + 1``.
    :param annotations: integer array of shape ``(num_items, num_positions)``;
        the number of classes is taken as ``max(annotations) + 1``.
    :param logits: optional array indexed as ``logits[:, None, :]`` giving
        per-item class logits used as the prior of ``c``. When ``None`` a
        shared class prevalence ``pi ~ Dirichlet(1)`` is sampled instead,
        mirroring ``multinomial``.
    :param test: when ``True`` the site ``y`` is sampled instead of being
        conditioned on ``annotations`` (same contract as ``multinomial``).
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("annotator", num_annotators, dim=-2):
        with numpyro.plate("class", num_classes):
            beta = numpyro.sample("beta", dist.Dirichlet(jnp.ones(num_classes)))

    # Prior over the latent class: learned prevalence or the supplied logits.
    if logits is None:
        pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))
        class_prior = dist.Categorical(probs=pi)
    else:
        class_prior = dist.Categorical(logits=logits[:, np.newaxis, :])

    with numpyro.plate("item", num_items, dim=-2):
        c = numpyro.sample("c", class_prior, infer={"enumerate": "parallel"})

        # here we use Vindex to allow broadcasting for the second index `c`
        # ref: http://num.pyro.ai/en/latest/utilities.html#numpyro.contrib.indexing.vindex
        with numpyro.plate("position", num_positions):
            numpyro.sample(
                "y",
                dist.Categorical(Vindex(beta)[positions, c, :]),
                obs=None if test else annotations,
            )

In [54]:
def mace(positions, annotations, logits=None, test: bool = False):
    """
    This model corresponds to the plate diagram in Figure 3 of reference [1].

    For every annotation a binary switch ``s`` decides whether the annotator
    reports the item's true class (``s == 0``) or draws a label from their
    own distribution ``epsilon``. ``P(s == 0)`` is the per-annotator
    ``theta``.

    :param positions: integer array of annotator ids, one per annotation
        (presumably the same shape as ``annotations`` — confirm).
    :param annotations: integer array of shape ``(num_items, num_positions)``.
    :param logits: optional array indexed as ``logits[:, None, :]`` giving
        per-item class logits used as the prior of ``c``. When ``None`` the
        class prior is uniform over the classes.
    :param test: when ``True`` the site ``y`` is sampled instead of being
        conditioned on ``annotations`` (same contract as ``multinomial``).
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("annotator", num_annotators):
        epsilon = numpyro.sample("epsilon", dist.Dirichlet(jnp.full(num_classes, 10)))
        theta = numpyro.sample("theta", dist.Beta(0.5, 0.5))

    if logits is None:
        class_prior = dist.DiscreteUniform(0, num_classes - 1)
    else:
        class_prior = dist.Categorical(logits=logits[:, np.newaxis, :])

    with numpyro.plate("item", num_items, dim=-2):
        # `c` has one value per item, so it must be declared inside the item
        # plate (as in `multinomial` and `dawid_skene`). Sampling it outside
        # leaves its batch dimension without a plate.
        c = numpyro.sample("c", class_prior, infer={"enumerate": "parallel"})

        with numpyro.plate("position", num_positions):
            s = numpyro.sample(
                "s",
                dist.Bernoulli(1 - theta[positions]),
                infer={"enumerate": "parallel"},
            )
            # s == 0: copy the true class (one-hot); otherwise use the
            # annotator's own label distribution.
            probs = jnp.where(
                s[..., None] == 0, nn.one_hot(c, num_classes), epsilon[positions]
            )
            numpyro.sample(
                "y", dist.Categorical(probs), obs=None if test else annotations
            )


In [56]:
def hierarchical_dawid_skene(positions, annotations, logits=None, test: bool = False):
    """
    This model corresponds to the plate diagram in Figure 4 of reference [1].

    Annotator response logits ``beta`` are drawn from class-level hyperpriors
    (``zeta``, ``Omega``) and indexed by annotator and latent true class.

    :param positions: integer array of annotator ids, one per annotation
        (presumably the same shape as ``annotations`` — confirm).
    :param annotations: integer array of shape ``(num_items, num_positions)``.
    :param logits: optional array indexed as ``logits[:, None, :]`` giving
        per-item class logits used as the prior of ``c``. When ``None`` a
        shared class prevalence ``pi ~ Dirichlet(1)`` is sampled instead.
    :param test: when ``True`` the site ``y`` is sampled instead of being
        conditioned on ``annotations`` (same contract as ``multinomial``).
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        # NB: we define `beta` as the `logits` of `y` likelihood; but `logits` is
        # invariant up to a constant, so we'll follow [1]: fix the last term of `beta`
        # to 0 and only define hyperpriors for the first `num_classes - 1` terms.
        zeta = numpyro.sample(
            "zeta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1)
        )
        omega = numpyro.sample(
            "Omega", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1)
        )

    with numpyro.plate("annotator", num_annotators, dim=-2):
        with numpyro.plate("class", num_classes):
            # non-centered parameterization
            with handlers.reparam(config={"beta": LocScaleReparam(0)}):
                beta = numpyro.sample("beta", dist.Normal(zeta, omega).to_event(1))
            # pad 0 to the last item
            beta = jnp.pad(beta, [(0, 0)] * (jnp.ndim(beta) - 1) + [(0, 1)])

    if logits is None:
        pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))
        class_prior = dist.Categorical(probs=pi)
    else:
        class_prior = dist.Categorical(logits=logits[:, np.newaxis, :])

    with numpyro.plate("item", num_items, dim=-2):
        # `c` has one value per item, so it must be declared inside the item
        # plate (as in `multinomial` and `dawid_skene`).
        c = numpyro.sample("c", class_prior, infer={"enumerate": "parallel"})

        with numpyro.plate("position", num_positions):
            # Separate name so the `logits` argument is not shadowed.
            response_logits = Vindex(beta)[positions, c, :]
            numpyro.sample(
                "y",
                dist.Categorical(logits=response_logits),
                obs=None if test else annotations,
            )


In [13]:
def item_difficulty(annotations, logits=None, test: bool = False):
    """
    This model corresponds to the plate diagram in Figure 5 of reference [1].

    Each item gets its own response logits ``theta``, drawn around
    class-level location ``eta[c]`` and scale ``Chi[c]`` of its latent true
    class; all annotations of the item share those logits.

    :param annotations: integer array of shape ``(num_items, num_positions)``.
    :param logits: optional array indexed as ``logits[:, None, :]`` giving
        per-item class logits used as the prior of ``c``. When ``None`` a
        shared class prevalence ``pi ~ Dirichlet(1)`` is sampled instead.
    :param test: when ``True`` the site ``y`` is sampled instead of being
        conditioned on ``annotations`` (same contract as ``multinomial``).
    """
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        eta = numpyro.sample(
            "eta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1)
        )
        chi = numpyro.sample(
            "Chi", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1)
        )

    if logits is None:
        pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))
        class_prior = dist.Categorical(probs=pi)
    else:
        class_prior = dist.Categorical(logits=logits[:, np.newaxis, :])

    with numpyro.plate("item", num_items, dim=-2):
        # `c` has one value per item, so it must be declared inside the item
        # plate (as in `multinomial` and `dawid_skene`).
        c = numpyro.sample("c", class_prior, infer={"enumerate": "parallel"})

        # non-centered parameterization of the per-item logits
        with handlers.reparam(config={"theta": LocScaleReparam(0)}):
            theta = numpyro.sample("theta", dist.Normal(eta[c], chi[c]).to_event(1))
            # pad 0 as the logit of the last class
            theta = jnp.pad(theta, [(0, 0)] * (jnp.ndim(theta) - 1) + [(0, 1)])

        with numpyro.plate("position", num_positions):
            numpyro.sample(
                "y",
                dist.Categorical(logits=theta),
                obs=None if test else annotations,
            )


In [33]:
def logistic_random_effects(positions, annotations, logits=None, test: bool = False):
    """
    This model corresponds to the plate diagram in Figure 5 of reference [1].

    NOTE(review): ``item_difficulty`` cites the same figure number; confirm
    which figure of the reference this model actually corresponds to.

    The response logits of an annotation are the annotator effect
    ``beta[annotator, c]`` minus a per-item effect ``theta`` whose scale
    ``Chi[c]`` depends on the item's latent true class.

    :param positions: integer array of annotator ids, one per annotation
        (presumably the same shape as ``annotations`` — confirm).
    :param annotations: integer array of shape ``(num_items, num_positions)``.
    :param logits: optional array indexed as ``logits[:, None, :]`` giving
        per-item class logits used as the prior of ``c``. When ``None`` a
        shared class prevalence ``pi ~ Dirichlet(1)`` is sampled instead.
    :param test: when ``True`` the site ``y`` is sampled instead of being
        conditioned on ``annotations`` (same contract as ``multinomial``).
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        zeta = numpyro.sample(
            "zeta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1)
        )
        omega = numpyro.sample(
            "Omega", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1)
        )
        chi = numpyro.sample(
            "Chi", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1)
        )

    with numpyro.plate("annotator", num_annotators, dim=-2):
        with numpyro.plate("class", num_classes):
            # non-centered parameterization
            with handlers.reparam(config={"beta": LocScaleReparam(0)}):
                beta = numpyro.sample("beta", dist.Normal(zeta, omega).to_event(1))
                # pad 0 as the logit of the last class
                beta = jnp.pad(beta, [(0, 0)] * (jnp.ndim(beta) - 1) + [(0, 1)])

    if logits is None:
        pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))
        class_prior = dist.Categorical(probs=pi)
    else:
        class_prior = dist.Categorical(logits=logits[:, np.newaxis, :])

    with numpyro.plate("item", num_items, dim=-2):
        # `c` has one value per item, so it must be declared inside the item
        # plate (as in `multinomial` and `dawid_skene`).
        c = numpyro.sample("c", class_prior, infer={"enumerate": "parallel"})

        with handlers.reparam(config={"theta": LocScaleReparam(0)}):
            theta = numpyro.sample("theta", dist.Normal(0, chi[c]).to_event(1))
            theta = jnp.pad(theta, [(0, 0)] * (jnp.ndim(theta) - 1) + [(0, 1)])

        with numpyro.plate("position", num_positions):
            # Separate name so the `logits` argument is not shadowed.
            response_logits = Vindex(beta)[positions, c, :] - theta
            numpyro.sample(
                "y",
                dist.Categorical(logits=response_logits),
                obs=None if test else annotations,
            )


In [133]:
# Run one seeded forward pass of `dawid_skene` and print the shape of every
# plate and sample site, to check that batch dimensions line up with plates.
with handlers.seed(rng_seed=1):
    traced_model = handlers.trace(dawid_skene)
    trace = traced_model.get_trace(annotators, annotations, logits)
shape_report = numpyro.util.format_shapes(trace)
print(shape_report)

  Trace Shapes:                
   Param Sites:                
  Sample Sites:                
annotator plate          18 |  
    class plate           2 |  
      beta dist    18     2 | 2
          value    18     2 | 2
     item plate       24119 |  
         c dist 24119     1 |  
          value 24119     1 |  
 position plate           3 |  
         y dist 24119     3 |  
          value 24119     3 |  


In [16]:
model = multinomial

# `multinomial` and `item_difficulty` take no annotator-id array; the other
# models expect it as their first argument.
if model is multinomial or model is item_difficulty:
    data = (annotations, logits)
else:
    data = (annotators, annotations, logits)

# Very short chain (5 warm-up + 15 draws).
show_progress = "NUMPYRO_SPHINXBUILD" not in os.environ
kernel = NUTS(model)
mcmc = MCMC(
    kernel,
    num_warmup=5,
    num_samples=15,
    num_chains=1,
    progress_bar=show_progress,
)
mcmc.run(random.PRNGKey(0), *data)
mcmc.print_summary()

# Draw the enumerated discrete sites given the posterior samples.
posterior_samples = mcmc.get_samples()
predictive = Predictive(model, posterior_samples, infer_discrete=True)
discrete_samples = predictive(random.PRNGKey(1), *data)


def _class_histogram(draws):
    """Count how often each of the two classes appears among the draws."""
    return jnp.bincount(draws, length=2)


item_class = vmap(_class_histogram, in_axes=1)(discrete_samples["c"].squeeze(-1))

print("Histogram of the predicted class of each item:")
row_format = "{:>10}" * 3
header_cells = [f"c={k}" for k in range(2)]
print(row_format.format("", *header_cells))
for i, row in enumerate(item_class):
    print(row_format.format(f"item[{i}]", *row))

sample: 100%|██████████| 20/20 [00:03<00:00,  5.79it/s, 7 steps of size 3.28e-02. acc. prob=0.62]



                mean       std    median      5.0%     95.0%     n_eff     r_hat
     pi[0]      0.86      0.01      0.86      0.85      0.87     13.37      0.93
     pi[1]      0.14      0.01      0.14      0.13      0.15     13.37      0.93
 zeta[0,0]      0.94      0.00      0.94      0.94      0.94     16.67      0.94
 zeta[0,1]      0.06      0.00      0.06      0.06      0.06     16.67      0.94
 zeta[1,0]      0.46      0.01      0.46      0.45      0.49     12.64      0.93
 zeta[1,1]      0.54      0.01      0.54      0.51      0.55     12.64      0.93

Number of divergences: 0
Histogram of the predicted class of each item:
                 c=0       c=1
   item[0]         8         7
   item[1]        13         2
   item[2]         9         6
   item[3]         7         8
   item[4]         7         8
   item[5]         6         9
   item[6]         7         8
   item[7]         6         9
   item[8]         3        12
   item[9]        10         5
  item[10]        

In [14]:
# Check whether two prediction runs drew identical latent classes `c`.
# NOTE(review): `discrete_samples1` is not defined in any cell shown here —
# presumably left over from earlier kernel state; confirm, since this cell
# would fail on Restart & Run All.
np.all(discrete_samples['c'] == discrete_samples1['c'])

Array(True, dtype=bool)

In [15]:
model = multinomial

mcmc = MCMC(
    NUTS(model),
    num_warmup=100,
    num_samples=500,
    num_chains=1,
    progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ,
)

# 90/10 split in file order.
train_size = round(annotations.shape[0] * 0.9)

# Baseline run: models without annotator ids get logits=None. The trailing
# True in the test tuples is the models' `test` flag.
if model in (multinomial, item_difficulty):
    train_data = (annotations[:train_size], None)
    test_data = (annotations[train_size:], None, True)
else:
    train_data = (
        annotators[:train_size],
        annotations[:train_size],
        logits[:train_size],
    )
    test_data = (
        annotators[train_size:],
        annotations[train_size:],
        logits[train_size:],
        True,
    )

mcmc.run(random.PRNGKey(0), *train_data)
mcmc.print_summary()

posterior_samples = mcmc.get_samples()
predictive = Predictive(model, posterior_samples, infer_discrete=True)
discrete_samples = predictive(random.PRNGKey(1), *test_data)

# Mean of the sampled labels over posterior draws, per (item, position).
annotator_probs = discrete_samples["y"].mean(axis=0)

# Shared views of truth and predictions used by the metrics below.
held_out = annotations[train_size:]
flat_truth = held_out.flatten()
flat_pred = np.rint(annotator_probs).flatten()
count_truth = held_out.sum(1)
count_pred = np.rint(annotator_probs.sum(1))
averages = (("Micro", "micro"), ("Macro", "macro"), ("Weighted", "weighted"))

print(f'log loss over positions = {log_loss(flat_truth, annotator_probs.flatten())}')
print(f'log loss over items = {log_loss(np.rint(held_out.mean(1)), annotator_probs.mean(1))}')
print('F1 score with position predictions')
print(f'Binary F1 score = {f1_score(flat_truth, flat_pred, pos_label=1)}')
for title, avg in averages:
    print(f'{title} F1 score = {f1_score(flat_truth, flat_pred, pos_label=1, average=avg)}')
# Targets here are the per-item counts of 1-labels, not a single item label.
print('F1 score with item predictions')
for title, avg in averages:
    print(f'{title} F1 score = {f1_score(count_truth, count_pred, average=avg)}')


sample: 100%|██████████| 600/600 [01:01<00:00,  9.76it/s, 7 steps of size 2.79e-01. acc. prob=0.93] 



                mean       std    median      5.0%     95.0%     n_eff     r_hat
     pi[0]      0.86      0.01      0.86      0.84      0.87     66.80      1.00
     pi[1]      0.14      0.01      0.14      0.13      0.16     66.80      1.00
 zeta[0,0]      0.94      0.00      0.94      0.94      0.95     75.19      1.00
 zeta[0,1]      0.06      0.00      0.06      0.05      0.06     75.19      1.00
 zeta[1,0]      0.47      0.02      0.47      0.44      0.50     72.72      1.00
 zeta[1,1]      0.53      0.02      0.53      0.50      0.56     72.72      1.00

Number of divergences: 0
log loss over positions = 0.6940914460098021
log loss over items = 0.6929893482984053
F1 score with position predictions
Binary F1 score = 0.19740141326646912
Micro F1 score = 0.5134051962410171
Macro F1 score = 0.4241345192261944
Weighted F1 score = 0.5918969638514527
F1 score with item predictions
Micro F1 score = 0.13308457711442787
Macro F1 score = 0.10012212325476766
Weighted F1 score = 0.065472879

In [6]:
model = multinomial

mcmc = MCMC(
    NUTS(model),
    num_warmup=100,
    num_samples=500,
    num_chains=1,
    progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ,
)

# 90/10 split in file order.
train_size = round(annotations.shape[0] * 0.9)

# Logits-informed run: every model receives the logits. The trailing True in
# the test tuples is the models' `test` flag.
if model in (multinomial, item_difficulty):
    train_data = (annotations[:train_size], logits[:train_size])
    test_data = (annotations[train_size:], logits[train_size:], True)
else:
    train_data = (
        annotators[:train_size],
        annotations[:train_size],
        logits[:train_size],
    )
    test_data = (
        annotators[train_size:],
        annotations[train_size:],
        logits[train_size:],
        True,
    )

mcmc.run(random.PRNGKey(0), *train_data)
mcmc.print_summary()

posterior_samples = mcmc.get_samples()
predictive = Predictive(model, posterior_samples, infer_discrete=True)
discrete_samples = predictive(random.PRNGKey(1), *test_data)

# Mean of the sampled labels over posterior draws, per (item, position).
annotator_probs = discrete_samples["y"].mean(axis=0)

# Shared views of truth and predictions used by the metrics below.
held_out = annotations[train_size:]
hard_pred = np.rint(annotator_probs)
flat_truth = held_out.flatten()
flat_pred = hard_pred.flatten()
count_truth = held_out.sum(1)
count_pred = hard_pred.sum(1)
vote_truth = np.rint(held_out.mean(1))
vote_pred = np.rint(hard_pred.mean(1))
averages = (("Micro", "micro"), ("Macro", "macro"), ("Weighted", "weighted"))

print(f'log loss over positions = {log_loss(flat_truth, annotator_probs.flatten())}')
print(f'log loss over items = {log_loss(vote_truth, annotator_probs.mean(1))}')
print('F1 score with position predictions')
print(f'Binary F1 score = {f1_score(flat_truth, flat_pred, pos_label=1)}')
for title, avg in averages:
    print(f'{title} F1 score = {f1_score(flat_truth, flat_pred, pos_label=1, average=avg)}')
print('F1 score with predictions of number of 1s out of all positions')
for title, avg in averages:
    print(f'{title} F1 score = {f1_score(count_truth, count_pred, average=avg)}')
print('F1 score with majority vote')
print(f'Binary F1 score = {f1_score(vote_truth, vote_pred, pos_label=1)}')
for title, avg in averages:
    print(f'{title} F1 score = {f1_score(vote_truth, vote_pred, pos_label=1, average=avg)}')


sample: 100%|██████████| 600/600 [00:33<00:00, 17.88it/s, 7 steps of size 5.53e-01. acc. prob=0.96]



                mean       std    median      5.0%     95.0%     n_eff     r_hat
 zeta[0,0]      0.94      0.00      0.94      0.94      0.94    292.82      1.00
 zeta[0,1]      0.06      0.00      0.06      0.06      0.06    292.81      1.00
 zeta[1,0]      0.56      0.01      0.56      0.55      0.57    226.40      1.01
 zeta[1,1]      0.44      0.01      0.44      0.43      0.45    226.40      1.01

Number of divergences: 0
log loss over positions = 0.6922026726650096
log loss over items = 0.691429105740423
F1 score with position predictions
Binary F1 score = 0.21279587532224045
Micro F1 score = 0.5357932559425097
Macro F1 score = 0.4418217495180531
Weighted F1 score = 0.6112806398503334
F1 score with predictions of number of 1s out of all positions
Micro F1 score = 0.22470978441127695
Macro F1 score = 0.17095477443648727
Weighted F1 score = 0.25568417882500183
F1 score with majority vote
Binary F1 score = 0.15275590551181104
Binary F1 score = 0.5538971807628524
Binary F1 score = 0

In [5]:
# Per-item majority vote: round each position's mean prediction to 0/1,
# average those hard labels across positions, then round again.
np.rint(np.rint(annotator_probs).mean(1))

array([0., 0., 0., ..., 1., 0., 0.], dtype=float32)

In [None]:
# Mean of the sampled `y` over the leading axis, for every item and position.
# NOTE(review): despite the name this holds means, not sums, and is the same
# expression used for `annotator_probs` — confirm whether a sum was intended.
annotator_sums = vmap(lambda x: x.mean(0), in_axes=1)(
    discrete_samples["y"]
)

In [35]:
# Dense (non-sparse) one-hot encoder, fitted and applied in the cells below.
enc=OneHotEncoder(sparse_output=False)

In [36]:
# Fit the encoder on the predicted number of 1-labels per item (rounded
# per-position predictions summed across positions), as a single column.
enc.fit(np.rint(annotator_probs).sum(1).reshape(-1,1))

In [38]:
# One-hot encode the same predicted counts the encoder was fitted on.
test=enc.transform(np.rint(annotator_probs).sum(1).reshape(-1,1))

In [40]:
# Log loss of the one-hot predicted counts against the true per-item counts.
# NOTE(review): the encoder's categories come from the predicted counts only,
# so the columns may not match the set of true counts — confirm.
log_loss(annotations[train_size:].sum(1),test)

28.273048180849774

In [21]:
# Micro-averaged F1 on per-item counts of 1-labels. The prediction here sums
# the probabilities first and rounds afterwards.
f1_score(annotations[train_size:].sum(1),np.rint(annotator_probs.sum(1)),average='micro')

0.1384742951907131

In [10]:
# Log loss between rounded per-item mean labels and rounded per-item mean
# predictions, i.e. with hard 0/1 predictions rather than probabilities.
log_loss(np.rint(annotations[train_size:].mean(1)),np.rint(annotator_probs.mean(1)))

7.587835118650642

In [33]:
# Sum of the sampled `y` values over all draws for item 1000, per position.
discrete_samples['y'][:,1000,:].sum(0)

Array([ 5, 11,  3], dtype=int32)

In [23]:
# Print the index of every item whose sampled `y` differs between draws.
# Convert once to NumPy so the loop does not index the JAX array per item.
y_draws = np.asarray(discrete_samples['y'])
for i in range(y_draws.shape[1]):
    arr = y_draws[:, i, :]
    # Broadcast the first draw against all draws; the previous
    # np.tile(arr[0], (15, 1)) only matched a run with exactly 15 samples.
    if not np.all(arr == arr[0]):
        print(i)

In [25]:
from sklearn.metrics import log_loss,f1_score

In [241]:
# Mean of the sampled `y` over the leading (draw) axis.
discrete_samples['y'].mean(0)

Array([[1., 0., 0.],
       [0., 0., 0.],
       [1., 1., 0.],
       ...,
       [0., 0., 0.],
       [0., 1., 0.],
       [0., 0., 0.]], dtype=float32)

In [72]:
# Hard 0/1 predictions: each mean prediction rounded to the nearest integer.
np.rint(annotator_probs)

array([[0., 0., 1.],
       [0., 0., 1.],
       [1., 1., 1.],
       ...,
       [1., 1., 1.],
       [1., 0., 0.],
       [1., 1., 0.]], dtype=float32)

In [71]:
# Log loss over all flattened (item, position) pairs of the held-out rows.
log_loss(annotations[train_size:].flatten(),annotator_probs.flatten())

0.7317581959541963

In [80]:
# Weighted-average F1 over all flattened (item, position) pairs, using the
# rounded predictions.
f1_score(annotations[train_size:].flatten(),np.rint(annotator_probs).flatten(),pos_label=1,average='weighted')

0.5782398912935731

In [24]:
# Print the index of every held-out item whose first sampled draw of `y`
# differs from the recorded annotations in at least one position.
held_out = annotations[train_size:]
first_draw = discrete_samples['y'][0]
for i in range(first_draw.shape[0]):
    arr = first_draw[i]
    if np.any(held_out[i] != arr):
        print(i)

In [211]:
# Distinct sampled values for item 0 and how often each occurs, pooled over
# all draws and positions.
jnp.unique_counts(discrete_samples['y'][:,0,:])

_UniqueCountsResult(values=Array([0, 1], dtype=int32), counts=Array([30, 15], dtype=int32))

In [212]:
# All draws of `y` for item 0: one row per draw, one column per position.
discrete_samples['y'][:,0,:]

Array([[1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0]], dtype=int32)