In [1]:
import numpy as np
from jax import random
import jax.numpy as jnp
from pathlib import Path
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.ops.indexing import Vindex
from tqdm import tqdm
from numpyro.handlers import mask

import json

In [2]:
def get_class_probabilities(logits_path, tokens_json_path, label_tokens=["0", "1"]):
    """
    Computes softmax probabilities for class tokens from LLM logits.

    Args:
        logits_path (str): Path to .npy file containing logits (N x top_k).
        tokens_json_path (str): Path to JSON file mapping logits columns to tokens.
            Each key is a column index (as a string) and each value is a
            two-element sequence whose first element is the token text.
        label_tokens (list): List of class token strings, e.g. ["0", "1"].
            The output columns follow the order of this list.

    Returns:
        tuple: ``(selected_logits, probs)``. Both are arrays of shape
        ``(N, len(label_tokens))``; column ``k`` belongs to ``label_tokens[k]``.
        ``probs`` is the row-wise softmax of ``selected_logits``.

    Raises:
        ValueError: If a class token matches zero or more than one logits column.
    """
    # Load data
    logits = np.load(logits_path)
    with open(tokens_json_path, "r", encoding="utf-8") as f:
        token_map = json.load(f)

    # Collect, per class token, every logits column whose (stripped) token matches.
    columns_by_label = {label: [] for label in label_tokens}
    for pos_str, (token, _) in token_map.items():
        stripped = token.strip()
        if stripped in columns_by_label:
            columns_by_label[stripped].append(int(pos_str))

    # Build the column list in label_tokens order, NOT in token_map order:
    # otherwise probs[:, 0] could silently refer to a different class depending
    # on how the tokens happen to be ranked in the JSON file.
    label_columns = []
    for label in label_tokens:
        matches = columns_by_label[label]
        if len(matches) != 1:
            raise ValueError(
                f"Expected exactly one column for class token {label!r}, "
                f"found {len(matches)}: {matches}"
            )
        label_columns.append(matches[0])

    # Extract relevant logits
    selected_logits = logits[:, label_columns]

    # Softmax over the selected columns; subtracting the row maximum keeps exp() stable.
    shifted = selected_logits - np.max(selected_logits, axis=1, keepdims=True)
    exp_shifted = np.exp(shifted)
    probs = exp_shifted / exp_shifted.sum(axis=1, keepdims=True)

    return selected_logits, probs

# Class-token logits and softmax probabilities for the training and validation splits.
logits_train, probs_train = get_class_probabilities(
    "outputs/train/logits.npy",
    "outputs/train/top_tokens.json",
)
logits_val, probs_val = get_class_probabilities(
    "outputs/val/logits.npy",
    "outputs/val/top_tokens.json",
)

# Show first 5 examples
for row in range(5):
    p_zero = probs_train[row, 0]
    p_one = probs_train[row, 1]
    print(f"Example {row + 1}: P(0) = {p_zero:.3f}, P(1) = {p_one:.3f}")

# Raw view of the same rows: probabilities first, then the logits they came from.
print(probs_train[:5])
print(logits_train[:5])

Example 1: P(0) = 0.998, P(1) = 0.002
Example 2: P(0) = 0.998, P(1) = 0.002
Example 3: P(0) = 0.980, P(1) = 0.020
Example 4: P(0) = 0.005, P(1) = 0.995
Example 5: P(0) = 0.626, P(1) = 0.374
[[0.99847525 0.00152479]
 [0.9979493  0.00205074]
 [0.97997653 0.02002344]
 [0.00530189 0.99469805]
 [0.6261242  0.37387583]]
[[27.625    21.140625]
 [26.46875  20.28125 ]
 [26.703125 22.8125  ]
 [22.90625  28.140625]
 [24.015625 23.5     ]]


In [3]:
# --- DATA LOADING HELPERS --- #

def load_jsonl(file_path, max_items=None):
    """
    Load records from a JSON Lines file (one record per line).

    Args:
        file_path: Path of the file to read.
        max_items: If not None, stop after this many records have been loaded.

    Returns:
        list: The parsed records, in file order.

    Raises:
        ValueError / SyntaxError: If a non-blank line is neither valid JSON nor
            a Python literal.
    """
    # Local import: only needed for the legacy-format fallback below.
    import ast

    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            if max_items is not None and len(data) >= max_items:
                break
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                # Lines written as Python dict reprs (single quotes) are not JSON.
                # literal_eval parses them without executing arbitrary code,
                # which a bare eval() on file content would do.
                record = ast.literal_eval(line)
            data.append(record)
    return data

def create_annotator_mapping(data):
    """
    Assign a dense column index to every (annotator, position) pair in ``data``.

    ``position`` is the index at which an annotator appears inside an item's
    ``'annotators'`` list. Pairs are numbered from 0 in ascending order of
    annotator, then of position.

    Returns:
        dict: ``{(annotator, position): column_index}``.
    """
    # Every distinct (annotator, slot-in-item) pair that occurs anywhere in the data.
    seen_pairs = {
        (annotator, slot)
        for record in data
        for slot, annotator in enumerate(record['annotators'])
    }
    # Tuple ordering sorts by annotator first and by slot second.
    return {pair: column for column, pair in enumerate(sorted(seen_pairs))}

def process_annotations(data, annotator_mapping=None):
    """
    Turn per-item annotation lists into dense matrices.

    Args:
        data: Sequence of items, each with parallel ``'annotators'`` and
            ``'labels'`` lists.
        annotator_mapping: Optional ``{(annotator, position): column}`` dict.
            When omitted it is built from ``data`` with
            ``create_annotator_mapping``.

    Returns:
        tuple: ``(positions, annotations, masks)`` where

        * ``positions[j]`` is the annotator that owns column ``j``;
        * ``annotations[i, j]`` is the label given to item ``i`` in column ``j``
          (0 where there is no annotation);
        * ``masks[i, j]`` is True exactly where ``annotations[i, j]`` is a real label.
    """
    if annotator_mapping is None:
        annotator_mapping = create_annotator_mapping(data)

    # default=-1 makes an empty mapping yield zero columns instead of raising.
    total_positions = max(annotator_mapping.values(), default=-1) + 1
    positions = np.zeros(total_positions, dtype=int)
    annotations = np.zeros((len(data), total_positions), dtype=int)
    masks = np.zeros((len(data), total_positions), dtype=bool)

    # Column ownership is a property of the mapping, not of the data: fill it
    # from the mapping so that columns which never occur in `data` (e.g. when a
    # mapping built on another split is reused) are not attributed to annotator 0.
    for (annotator, _), matrix_pos in annotator_mapping.items():
        positions[matrix_pos] = annotator

    for item_idx, item in enumerate(data):
        for pos, (annotator, label) in enumerate(zip(item['annotators'], item['labels'])):
            matrix_pos = annotator_mapping.get((annotator, pos))
            if matrix_pos is None:
                # Pair unknown to the supplied mapping: no column to store it in.
                continue
            annotations[item_idx, matrix_pos] = label
            masks[item_idx, matrix_pos] = True
    return positions, annotations, masks

# Result: `annotations` is a matrix whose rows are items and whose columns are (annotator, position) pairs;
# annotations[i, j] is the label given to item i in column j (meaningful only where masks[i, j] is True).
# positions[j] is the id of the annotator who owns column j of `annotations`.

In [4]:
# --- DAWID-SKENE MODEL --- #

def dawid_skene(positions, annotations, masks, use_llm_prior=False, llm_probs=None):
    """
    NumPyro Dawid-Skene model: one latent class ``c`` per item and, for every
    annotator, one response distribution per class (``beta``).

    Args:
        positions: Integer array; ``positions[j]`` selects the annotator whose
            ``beta`` entry governs column ``j`` of ``annotations``.
        annotations: 2-D array of observed labels, unpacked as
            ``(num_items, num_positions)``.
        masks: Mask applied to the ``y`` observation site. Presumably True where
            ``annotations`` holds a real label — confirm against the caller.
        use_llm_prior: When True, ``llm_probs`` is used as fixed per-item class
            probabilities and no ``pi`` site is sampled. Otherwise a single
            ``pi`` is drawn from a flat Dirichlet.
        llm_probs: Required when ``use_llm_prior`` is True. Indexed as
            ``llm_probs[:, np.newaxis, :]``, so it is assumed to be
            ``(num_items, num_classes)`` — TODO confirm at the call site.
    """
    # Sizes are inferred from the data rather than passed in.
    # NOTE(review): num_classes is the largest label present in `annotations` + 1.
    # If the highest class never occurs in this subset, it will be smaller than
    # llm_probs.shape[-1] — confirm this cannot happen for the data in use.
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    # beta[a, k, :] is a distribution over the label annotator `a` gives when the
    # latent class is `k`; each row has a flat Dirichlet prior.
    with numpyro.plate("annotator", num_annotators, dim=-2):
        with numpyro.plate("class", num_classes):
            beta = numpyro.sample("beta", dist.Dirichlet(jnp.ones(num_classes)))

    if use_llm_prior:
        assert llm_probs is not None, "LLM probabilities must be provided if use_llm_prior is True"
        # The LLM probabilities are used as-is (fixed), not as Dirichlet
        # concentrations for a sampled `pi`.
        # The inserted axis gives shape (num_items, 1, num_classes), so the batch
        # shape (num_items, 1) lines up with the "item" plate at dim=-2.
        pi = jnp.array(llm_probs[:,np.newaxis,:])  # shape: (num_items, 1, num_classes)

    else:
        # One class-prevalence vector shared by all items.
        pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        # `c` is discrete and marked for parallel enumeration instead of being sampled by the kernel.
        c = numpyro.sample("c", dist.Categorical(probs=pi), infer={"enumerate": "parallel"})

        with numpyro.plate("position", num_positions):
            # Only entries selected by `masks` contribute to the likelihood.
            with mask(mask=masks):
                numpyro.sample(
                    "y",
                    # Pick, for every (item, column), the response distribution of
                    # that column's annotator under the item's latent class.
                    dist.Categorical(Vindex(beta)[positions, c, :]),
                    obs=annotations,
                )


# --- MAIN EXECUTION --- #

def run_ds_on_subset(json_path, use_llm_prior=False, llm_probs=None, max_items=None,
                     num_warmup=500, num_samples=1000, seed=0):
    """
    Fit the Dawid-Skene model with NUTS on (a prefix of) a JSONL annotation file.

    Args:
        json_path: Path of the JSONL file, read with ``load_jsonl``.
        use_llm_prior: Forwarded to ``dawid_skene``.
        llm_probs: 2-D array-like of per-item class probabilities; required when
            ``use_llm_prior`` is True. Only the first ``num_items`` rows are used.
        max_items: Optional cap on the number of items loaded.
        num_warmup: Number of NUTS warmup steps.
        num_samples: Number of posterior samples to draw.
        seed: Base PRNG seed. ``seed`` drives MCMC and ``seed + 1`` drives the
            discrete-site prediction.

    Returns:
        tuple: ``(samples, predicted_labels, beta_mean)``. ``samples`` is the
        MCMC sample dict, ``predicted_labels`` holds the draws of the discrete
        site ``c``, and ``beta_mean`` is the posterior mean of ``beta``.

    Raises:
        ValueError: If no items were loaded, or if ``llm_probs`` is missing,
            not 2-D, or has fewer rows than there are items.
    """
    data = load_jsonl(json_path, max_items=max_items)
    positions, annotations, masks = process_annotations(data)
    num_items = annotations.shape[0]
    if num_items == 0:
        raise ValueError(f"No items loaded from {json_path}")

    if use_llm_prior:
        if llm_probs is None:
            raise ValueError("llm_probs must be provided when use_llm_prior=True")
        # Accept any array-like (list, NumPy or JAX array) for the shape checks below.
        llm_probs = np.asarray(llm_probs)
        if llm_probs.ndim != 2:
            raise ValueError(f"llm_probs should be 2D array, got shape {llm_probs.shape}")
        if llm_probs.shape[0] < num_items:
            raise ValueError(f"Not enough LLM probabilities for all items: {llm_probs.shape[0]} < {num_items}")
        # Rows are matched to items by order, so keep exactly one row per loaded item.
        llm_probs = llm_probs[:num_items]

    # Same model arguments for inference and for the discrete-site prediction.
    model_kwargs = dict(use_llm_prior=use_llm_prior, llm_probs=llm_probs)

    kernel = NUTS(dawid_skene)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.run(random.PRNGKey(seed), positions, annotations, masks, **model_kwargs)
    mcmc.print_summary()

    samples = mcmc.get_samples()
    beta_mean = jnp.mean(samples['beta'], axis=0)

    print("\nInferred confusion matrices (beta) for annotators:")
    for i, matrix in enumerate(beta_mean):
        print(f"Annotator {i}:\n{np.round(matrix, 2)}\n")

    # `c` is enumerated out during MCMC; infer_discrete=True draws it given the posterior samples.
    predictive = Predictive(dawid_skene, samples, infer_discrete=True)
    discrete_samples = predictive(
        random.PRNGKey(seed + 1), positions, annotations, masks, **model_kwargs
    )
    predicted_labels = discrete_samples["c"]
    return samples, predicted_labels, beta_mean

# if __name__ == "__main__":
#    mcmc_samples, predicted_labels, beta_mean = run_ds_on_subset("data/ghc_train.jsonl", use_llm_prior=True, llm_probs=probs_train)

In [8]:
# Disabled cell: everything below is a bare triple-quoted string, so none of it
# executes; the notebook only echoes the string back as this cell's output.
# It is an earlier variant of the validation prediction that conditions on the
# posterior-mean beta (a single pseudo-sample) instead of the full set of beta draws.
'''
from numpyro.infer import Predictive

# train_data = load_jsonl("data/ghc_train.jsonl", max_items=100)
# train_positions, train_annotations, train_masks = process_annotations(train_data)

val_data = load_jsonl("data/ghc_val.jsonl") #, max_items=100)
val_positions, val_annotations, val_masks = process_annotations(val_data)

train_mcmc_samples, predicted_labels_train, beta_mean_train = run_ds_on_subset(
    json_path="data/ghc_train.jsonl",
    use_llm_prior=True,
    llm_probs=probs_train
    )
    # max_items=100

predictive_val = Predictive(
    dawid_skene,
    posterior_samples={"beta": beta_mean_train[None, ...]},  # making shape (1, annotators, C, C)
    infer_discrete=True
)

probs_val = probs_val[:val_annotations.shape[0]]

val_pred = predictive_val(
    random.PRNGKey(1),
    val_positions,
    val_annotations,
    val_masks,
    use_llm_prior=True,
    llm_probs=probs_val[:val_annotations.shape[0]]
)

val_predicted_labels = np.array(val_pred["c"]).squeeze()  # shape (num_val_items,)

'''

'\nfrom numpyro.infer import Predictive\n\n# train_data = load_jsonl("data/ghc_train.jsonl", max_items=100)\n# train_positions, train_annotations, train_masks = process_annotations(train_data)\n\nval_data = load_jsonl("data/ghc_val.jsonl") #, max_items=100)\nval_positions, val_annotations, val_masks = process_annotations(val_data)\n\ntrain_mcmc_samples, predicted_labels_train, beta_mean_train = run_ds_on_subset(\n    json_path="data/ghc_train.jsonl",\n    use_llm_prior=True,\n    llm_probs=probs_train\n    )\n    # max_items=100\n\npredictive_val = Predictive(\n    dawid_skene,\n    posterior_samples={"beta": beta_mean_train[None, ...]},  # making shape (1, annotators, C, C)\n    infer_discrete=True\n)\n\nprobs_val = probs_val[:val_annotations.shape[0]]\n\nval_pred = predictive_val(\n    random.PRNGKey(1),\n    val_positions,\n    val_annotations,\n    val_masks,\n    use_llm_prior=True,\n    llm_probs=probs_val[:val_annotations.shape[0]]\n)\n\nval_predicted_labels = np.array(val_pred[

In [6]:
from numpyro.infer import Predictive
from jax import random

# Validation annotations in matrix form.
val_data = load_jsonl("data/ghc_val.jsonl", max_items=100)
val_positions, val_annotations, val_masks = process_annotations(val_data)

# Fit the model on the training subset, using the LLM probabilities as per-item class priors.
(
    train_mcmc_samples,
    predicted_labels_train_point_estimate,
    beta_mean_train,
) = run_ds_on_subset(
    "data/ghc_train.jsonl",
    use_llm_prior=True,
    llm_probs=probs_train,
    max_items=100,
)

# Predictive that draws only the latent class `c`, once per posterior draw of beta.
predictive_for_val_c_samples = Predictive(
    dawid_skene,
    posterior_samples=dict(beta=train_mcmc_samples["beta"]),
    infer_discrete=True,
    return_sites=["c"],
)

# One row of LLM probabilities per loaded validation item.
num_val_items_loaded = val_annotations.shape[0]
probs_val_for_pred = probs_val[:num_val_items_loaded]

# The model expects a 2D (N, C) array; accept (N, 1, C) by dropping the middle axis.
if probs_val_for_pred.ndim != 2:
    if probs_val_for_pred.ndim == 3 and probs_val_for_pred.shape[1] == 1:
        probs_val_for_pred = probs_val_for_pred.squeeze(axis=1)
    else:
        raise ValueError(f"probs_val_for_pred has an unexpected shape: {probs_val_for_pred.shape}, expected 2D.")

val_c_dist_output = predictive_for_val_c_samples(
    random.PRNGKey(1),
    val_positions,
    val_annotations,
    val_masks,
    use_llm_prior=True,
    llm_probs=probs_val_for_pred,
)

c_val_samples = val_c_dist_output["c"]

sample: 100%|██████████| 1500/1500 [00:05<00:00, 289.04it/s, 7 steps of size 3.39e-01. acc. prob=0.92] 



                  mean       std    median      5.0%     95.0%     n_eff     r_hat
 beta[0,0,0]      0.83      0.12      0.85      0.67      1.00   1557.15      1.00
 beta[0,0,1]      0.17      0.12      0.15      0.00      0.33   1557.15      1.00
 beta[0,1,0]      0.30      0.16      0.28      0.02      0.54   1674.09      1.00
 beta[0,1,1]      0.70      0.16      0.72      0.46      0.98   1674.09      1.00
 beta[1,0,0]      0.84      0.14      0.89      0.64      1.00   1680.09      1.00
 beta[1,0,1]      0.16      0.14      0.11      0.00      0.36   1680.09      1.00
 beta[1,1,0]      0.78      0.18      0.83      0.51      1.00   1747.06      1.00
 beta[1,1,1]      0.22      0.18      0.17      0.00      0.49   1747.06      1.00
 beta[2,0,0]      0.83      0.14      0.87      0.64      1.00   1695.99      1.00
 beta[2,0,1]      0.17      0.14      0.13      0.00      0.36   1695.99      1.00
 beta[2,1,0]      0.51      0.23      0.52      0.15      0.89   1571.79      1.00
 be

In [7]:
def annotator_response_accuracy(c_samples, beta_mean, positions, annotations, masks):
    """
    Count how often the hard-predicted annotator response matches the real one.

    For every draw of the latent classes and every real annotation, the predicted
    response is ``argmax(beta_mean[annotator, c, :])``.

    Args:
        c_samples: Integer array ``(num_draws, num_items)`` of latent-class draws.
        beta_mean: Array ``(num_annotators, num_classes, num_classes)``.
        positions: Integer array ``(num_positions,)``; annotator id of each column.
        annotations: Integer array ``(num_items, num_positions)`` of observed labels.
        masks: Boolean array ``(num_items, num_positions)``; True where a label exists.

    Returns:
        tuple: ``(correct, total)`` as Python ints, where ``total`` is the number of
        real annotations times the number of draws.
    """
    c_samples = np.asarray(c_samples)
    annotations = np.asarray(annotations)
    masks = np.asarray(masks, dtype=bool)
    num_draws = c_samples.shape[0]

    # hard_label[a, k]: the response annotator `a` is most likely to give when the
    # true class is `k`. It depends only on (a, k), so compute it once.
    hard_label = jnp.argmax(jnp.asarray(beta_mean), axis=-1)
    # The gather is done with jnp indexing (not NumPy), so annotator ids are
    # resolved with JAX indexing semantics.
    label_by_column = np.asarray(hard_label[jnp.asarray(positions)])  # (num_positions, num_classes)
    num_classes = label_by_column.shape[1]

    # class_counts[i, k]: number of draws in which item i was assigned class k.
    # The clip is defensive: it keeps every draw inside beta's class axis.
    clipped = np.clip(c_samples, 0, num_classes - 1)
    class_counts = np.stack(
        [(clipped == k).sum(axis=0) for k in range(num_classes)], axis=1
    )  # (num_items, num_classes)

    # hits[i, j, k]: a real annotation at (i, j) that the model would predict
    # correctly if item i had class k. Weighting by class_counts sums over draws
    # without materialising a (draws, items, positions) array.
    hits = (label_by_column[None, :, :] == annotations[:, :, None]) & masks[:, :, None]
    correct = int((hits * class_counts[:, None, :]).sum())
    total = int(masks.sum()) * num_draws
    return correct, total


if 'c_val_samples' not in globals() or 'beta_mean_train' not in globals():
    print("Error: 'c_val_samples' or 'beta_mean_train' not found.")
    print("Please ensure you have run the predictions for 'c' on the validation set (Cell 1) ")
    print("and have the 'beta_mean_train' from the training phase (from Cell 1).")
else:
    if c_val_samples.ndim == 3 and c_val_samples.shape[-1] == 1:
        # If shape is (num_samples, num_items, 1), squeeze out the last dimension
        c_val_samples_processed = c_val_samples.squeeze(axis=-1)
        print(f"Original c_val_samples shape was {c_val_samples.shape}, processed to {c_val_samples_processed.shape}.")
    elif c_val_samples.ndim == 2:
        c_val_samples_processed = c_val_samples
        print(f"c_val_samples shape is {c_val_samples.shape}, using as is.")
    else:
        raise ValueError(f"Unexpected shape for c_val_samples: {c_val_samples.shape}. Expected 2D or 3D with last dim 1 (after running Cell 1).")

    num_mcmc_draws, num_val_items = c_val_samples_processed.shape

    if 'val_annotations' not in globals() or 'val_positions' not in globals() or 'val_masks' not in globals():
        print("Error: Validation annotation data (val_annotations, val_positions, val_masks) not found.")
    elif val_annotations.shape[0] != num_val_items:
         raise ValueError(f"Number of items in processed c_val_samples ({num_val_items}) "
                          f"does not match val_annotations ({val_annotations.shape[0]})")
    else:
        print("Calculating accuracy of predicting annotator responses...")
        print(f"Using {num_mcmc_draws} MCMC samples for c_val.")

        # Vectorised over draws, items and columns; no per-element Python loop.
        correct_predictions_count, total_predictions_made = annotator_response_accuracy(
            c_val_samples_processed,
            beta_mean_train,
            val_positions,
            val_annotations,
            val_masks,
        )

        print("Calculation finished.")

        if total_predictions_made > 0:
            accuracy_of_predicting_annotator_responses = correct_predictions_count / total_predictions_made
            print(f"\nAccuracy of predicting annotator responses on validation: {accuracy_of_predicting_annotator_responses:.4f}")
            print(f"(Based on {total_predictions_made} individual annotator response predictions)")
        else:
            print("\nNo predictions were made for annotator responses.")
            print("Check val_masks (perhaps all False) or the number of MCMC samples.")

Original c_val_samples shape was (1000, 100, 1), processed to (1000, 100).
Calculating accuracy of predicting annotator responses...
Using 1000 MCMC samples for c_val.
  Processed MCMC samples: 100/1000
  Processed MCMC samples: 200/1000
  Processed MCMC samples: 300/1000
  Processed MCMC samples: 400/1000
  Processed MCMC samples: 500/1000
  Processed MCMC samples: 600/1000
  Processed MCMC samples: 700/1000
  Processed MCMC samples: 800/1000
  Processed MCMC samples: 900/1000
  Processed MCMC samples: 1000/1000
Calculation finished.

Accuracy of predicting annotator responses on validation: 0.7342
(Based on 448000 individual annotator response predictions)


In [9]:
# Independent Baseline - LLM priors for 'c' & Ideal Annotators for 'beta'

def count_ideal_annotator_matches(item_labels, annotations, masks):
    """
    Score the "ideal annotator" baseline: every annotator is assumed to answer
    with the item's predicted label.

    Args:
        item_labels: Integer array ``(num_items,)`` of predicted labels per item.
        annotations: Integer array ``(num_items, num_positions)`` of observed labels.
        masks: Boolean array ``(num_items, num_positions)``; True where a label exists.

    Returns:
        tuple: ``(correct, total)`` as Python ints, counted over real annotations only.
    """
    annotations = np.asarray(annotations)
    masks = np.asarray(masks, dtype=bool)
    # Column vector so each item's label is compared against its whole row.
    item_labels = np.asarray(item_labels).reshape(-1, 1)
    hits = (annotations == item_labels) & masks
    return int(hits.sum()), int(masks.sum())


if 'probs_val_for_pred' not in globals(): # LLM priors for validation P(c|item)
    print("Error: 'probs_val_for_pred' (LLM priors for validation) not found.")
else:
    if 'val_annotations' not in globals() or 'val_positions' not in globals() or 'val_masks' not in globals():
        print("Error: Validation annotation data (val_annotations, val_positions, val_masks) not found.")
    else:
        num_val_items_baseline_ideal = probs_val_for_pred.shape[0]
        num_classes_baseline_ideal = probs_val_for_pred.shape[1]

        if val_annotations.shape[0] != num_val_items_baseline_ideal:
            raise ValueError(f"Number of items in probs_val_for_pred ({num_val_items_baseline_ideal}) "
                             f"does not match val_annotations ({val_annotations.shape[0]})")
        if num_classes_baseline_ideal != 2:
            raise ValueError(f"Baseline code (ideal annotator) currently assumes num_classes=2, found {num_classes_baseline_ideal}")

        print("\nCalculating fully independent baseline accuracy (LLM priors for 'c', Ideal Annotators)...")

        # Get hard predictions for 'c' from LLM priors
        # c_pred_llm_hard will be 0 or 1 for each item
        c_pred_llm_hard = (probs_val_for_pred[:, 1] > 0.5).astype(jnp.int32)

        # Annotator identity is irrelevant here: an ideal annotator always
        # answers with the predicted 'true' label of the item.
        (
            correct_predictions_baseline_ideal_count,
            total_predictions_baseline_ideal_made,
        ) = count_ideal_annotator_matches(c_pred_llm_hard, val_annotations, val_masks)

        print("Fully independent baseline calculation finished.")

        if total_predictions_baseline_ideal_made > 0:
            accuracy_baseline_ideal = correct_predictions_baseline_ideal_count / total_predictions_baseline_ideal_made
            print(f"\nFully Independent Baseline Accuracy (LLM for 'c', Ideal Beta): {accuracy_baseline_ideal:.4f}")
            print(f"(Based on {total_predictions_baseline_ideal_made} individual annotator response predictions)")
        else:
            print("\nNo predictions were made for the fully independent baseline.")

Calculating fully independent baseline accuracy (LLM priors for 'c', Ideal Annotators)...
Fully independent baseline calculation finished.

Fully Independent Baseline Accuracy (LLM for 'c', Ideal Beta): 0.6942
(Based on 448 individual annotator response predictions)


In the baseline, we take the LLM priors and, for each text, treat the label with the higher LLM prior as the true label. We then assume that all annotators are "ideal" (each one votes exactly for the true label) and use this to predict every annotator's label. Finally, we compare the predicted annotator labels with the actual ones and report the accuracy.