## Ranking from Binary Comparisons

In this task there is a set $\mathcal{O}$ of $N$ objects with some intrinsic value $v : \mathcal{O} \to \mathbb{R}$, and $M$ raters who can determine the relative ordering of value between two objects, $r : \mathcal{O} \times \mathcal{O} \to \{0, 1\}$. That is, the raters cannot measure the value of each object, or even the difference in value between two objects. They can only say whether one object's value is higher or lower than another's. In this example raters are not allowed to declare two objects equal in value.

Each rater provides $Q$ ratings of $Q$ randomly chosen, distinct pairs of objects, and each rating is wrong with probability $p$. For rater $i$ and objects $o_j$ and $o_k$, the rating $r_{ijk}$ is 1 if the rater reports $v(o_j) > v(o_k)$ and 0 otherwise; an error-free rating would equal $\mathbb{1}(v(o_j) > v(o_k))$.

The objective of the experimenter is to deduce the relative value of each object on a given scale.

In this example we use a 5-star rating scale, so each object in $\mathcal{O}$ is assigned a value in $[1, 5]$ (inclusive).

To do this we place a prior on each object $o_i \in \mathcal{O}$, $v(o_i) \triangleq v_i \sim \mathcal{U}[1,5]$, and use a log-likelihood equal to minus the number of violated constraints, i.e. ratings that disagree with the candidate values:


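$$\log L(v_{1:N}) = -\sum_{(i,j,k)} \Big[ r_{ijk}\,\mathbb{1}(v_j < v_k) + (1 - r_{ijk})\,\mathbb{1}(v_j > v_k) \Big],$$

where the sum runs over the $MQ$ rated triples (rater $i$, pair $(o_j, o_k)$).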

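As a concrete check of this objective, the following pure-Python sketch counts violations for a hypothetical toy example with three objects. The helper `count_violations` and the toy data are illustrative assumptions, not part of jaxns; the arrays follow the same convention as the notebook code (`toy_I[n]`, `toy_J[n]` index the compared pair and `toy_S[n]` is the reported outcome of `value[I] > value[J]`).

```python
def count_violations(values, I, J, S):
    """Count ratings that disagree with the candidate values.

    Rating n claims values[I[n]] > values[J[n]] when S[n] is True and the
    opposite when S[n] is False; it is violated when the values disagree.
    """
    return sum((values[i] > values[j]) != s for i, j, s in zip(I, J, S))


# Hypothetical toy data: three objects (0-indexed) and four ratings.
toy_values = [4.2, 1.5, 3.0]
toy_I = [0, 0, 1, 1]
toy_J = [1, 2, 2, 2]
toy_S = [True, True, False, True]  # the last two ratings of pair (1, 2) contradict each other

violations = count_violations(toy_values, toy_I, toy_J, toy_S)
toy_log_likelihood = -violations
print(violations, toy_log_likelihood)  # 1 -1
```

Because pair (1, 2) is rated twice with opposite outcomes, no choice of values reaches zero violations here; the same happens in the simulated data whenever ratings contradict each other.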
In [1]:
```python
from itertools import combinations

from jax import numpy as jnp, random
# Written against an older jaxns API (PriorChain era); newer releases may organise these differently.
from jaxns import NestedSampler, PriorChain
from jaxns import save_results, load_results, marginalise_dynamic
from jaxns.prior_transforms import UniformPrior
from jaxns.plotting import plot_diagnostics, plot_cornerplot
from jaxns.utils import summary, resample
```

In [2]:


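```python
def get_constraints(num_options, num_raters, tests_per_rater, p_wrong):
    """Simulate noisy pairwise ratings.

    Returns the true values, the index arrays I and J of each compared pair,
    and S, where S[n] is True if the rater reported value[I[n]] > value[J[n]].
    """
    key, true_key = random.split(random.PRNGKey(47573957), 2)
    # True (hidden) value of each object, drawn from U[1, 5].
    actual_rank = random.uniform(true_key, shape=(num_options,), minval=1., maxval=5.)

    # All unique unordered pairs (j, k) with j < k.
    pairs = jnp.asarray(list(combinations(range(num_options), 2)), dtype=jnp.int32)
    I, J, S = [], [], []
    for rater in range(num_raters):
        # Use separate keys for choosing pairs and for the errors.
        key, choice_key, flip_key = random.split(key, 3)
        choices = random.choice(choice_key, pairs.shape[0], shape=(tests_per_rater,), replace=False)
        i, j = pairs[choices, 0], pairs[choices, 1]
        correct = actual_rank[i] > actual_rank[j]
        # Each rating is flipped independently with probability p_wrong.
        flipped = random.uniform(flip_key, shape=(tests_per_rater,)) < p_wrong
        I.append(i)
        J.append(j)
        S.append(jnp.where(flipped, ~correct, correct))

    return actual_rank, jnp.concatenate(I), jnp.concatenate(J), jnp.concatenate(S)
```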

In [3]:

```python
num_options = 10
num_raters = 20
tests_per_rater = 3
p_wrong = 0.1

actual_rank, I, J, S = get_constraints(num_options, num_raters, tests_per_rater, p_wrong=p_wrong)
```



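With these settings the data are sparse. As a rough side calculation, assuming each rater's $Q$ pairs are drawn uniformly without replacement (as `random.choice(..., replace=False)` does) and each rating is wrong independently with probability $p$:

$$
\binom{N}{2} = \binom{10}{2} = 45 \text{ pairs}, \qquad MQ = 20 \times 3 = 60 \text{ ratings}, \qquad \mathbb{E}[\#\text{wrong}] = pMQ = 0.1 \times 60 = 6,
$$

$$
\mathbb{E}[\#\text{unrated pairs}] = 45 \left(\tfrac{42}{45}\right)^{20} \approx 11.3 .
$$

So roughly a quarter of the pairs are never compared directly, and their order can only be inferred through transitivity.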
In [None]:

```python
def log_likelihood(rank):
    # Ordering implied by the candidate values for every rated pair.
    order = rank[I] > rank[J]
    # A constraint is violated when that ordering disagrees with the rating.
    violations = jnp.sum(order != S)
    return -violations


# Independent U[1, 5] prior on the value of each object.
with PriorChain() as prior_chain:
    UniformPrior('rank', jnp.ones(num_options), 5 * jnp.ones(num_options))

ns = NestedSampler(loglikelihood=log_likelihood,
                   prior_chain=prior_chain,
                   sampler_kwargs=dict(gradient_boost=True))

results = ns(random.PRNGKey(32564),
             termination_max_num_steps=40,
             maximise_likelihood=True)
save_results(results, 'ranking_save.npz')
# To reload instead of re-running: results = load_results('ranking_save.npz')
```



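The objective above charges exactly 1 unit of log-likelihood per violation. If one takes the noise model from the problem statement literally (each rating independently wrong with probability $p$), the exact log-likelihood of $n$ ratings with $V$ violations is $V \log p + (n - V)\log(1-p) = n\log(1-p) - V\log\frac{1-p}{p}$. That is the same objective up to an additive constant and a factor $\log\frac{1-p}{p}$ ($= \log 9 \approx 2.2$ for $p = 0.1$), which acts like an inverse temperature: `-violations` yields a broader posterior than the literal noise model would. A small stdlib sketch of the comparison; `exact_log_likelihood` is a hypothetical helper, not a jaxns function:

```python
import math


def exact_log_likelihood(num_violations, num_ratings, p):
    """Log-likelihood of the ratings under an assumed model in which each
    rating is independently flipped with probability p."""
    return num_violations * math.log(p) + (num_ratings - num_violations) * math.log1p(-p)


p, n_ratings = 0.1, 60
penalty = math.log((1 - p) / p)  # log-likelihood cost of one extra violation
for v in (0, 3, 6):
    # Equals n*log(1 - p) - v*penalty: an affine function of -v.
    affine = n_ratings * math.log1p(-p) - v * penalty
    print(v, round(exact_log_likelihood(v, n_ratings, p), 3), round(affine, 3))
```

To make the posterior reflect the assumed error rate, one option is to multiply `-violations` in `log_likelihood` by `jnp.log((1 - p_wrong) / p_wrong)`; this is a modelling choice, not something the notebook does.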
In [None]:

```python
summary(results)
plot_diagnostics(results)
plot_cornerplot(results, vars=['rank'])
```

In [None]:
```python
# The maximum-likelihood sample has the fewest violations of any sample found.
rank_L_max_estimate = results.sample_L_max['rank']
log_L_max = results.log_L_max
print(f"Number of violations at maximum likelihood estimate: {-log_L_max}")

# Compare with the posterior median of each object's value.
samples = resample(random.PRNGKey(245944), results.samples, results.log_dp_mean, S=int(results.ESS))
rank_median_estimate = jnp.median(samples['rank'], axis=0)
log_L_median_post = log_likelihood(rank_median_estimate)
print(f"Number of violations at posterior median: {-log_L_median_post}")

# At the true values, the violation count equals the number of wrong ratings.
log_L_actual = log_likelihood(actual_rank)
print(f"Number of violations at true values: {-log_L_actual}")

for i in range(num_options):
    print(f"Option {i}: true value={actual_rank[i]}, max(L) value={rank_L_max_estimate[i]}, "
          f"posterior median value={rank_median_estimate[i]}")
```

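Printed side by side, the values are hard to compare at a glance. One common single-number summary, offered here as an addition rather than part of the original analysis, is the Kendall tau distance: the number of object pairs that two value vectors order differently. A stdlib sketch with hypothetical inputs:

```python
from itertools import combinations


def kendall_tau_distance(a, b):
    """Number of object pairs that value vectors a and b order differently.

    Assumes no ties within a or within b.
    """
    return sum((a[j] > a[k]) != (b[j] > b[k])
               for j, k in combinations(range(len(a)), 2))


# Hypothetical values: the estimate swaps the order of objects 1 and 3.
true_values = [2.0, 4.5, 1.2, 3.3]
estimate = [2.1, 3.9, 1.0, 4.0]
d = kendall_tau_distance(true_values, estimate)
num_pairs = len(true_values) * (len(true_values) - 1) // 2
print(d, num_pairs)  # 1 6
```

In the notebook one could call, for example, `kendall_tau_distance(actual_rank.tolist(), rank_L_max_estimate.tolist())` to score each estimate out of $\binom{N}{2}$ pairs.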
In [None]:
```python
# Let's see what the relative order is.
# jnp.argsort gives object indices sorted from lowest to highest value.
# Note: if marginalise_dynamic returns a posterior-weighted average, then
# marginalising argsort averages object indices slot by slot, so the result
# is generally not an integer ordering; treat that column with care.
ordering_marginalised = marginalise_dynamic(random.PRNGKey(42), results.samples, results.log_dp_mean, results.ESS, lambda rank: jnp.argsort(rank))
ordering_L_max = jnp.argsort(rank_L_max_estimate)
ordering_median_posterior = jnp.argsort(rank_median_estimate)
ordering_true = jnp.argsort(actual_rank)
for i in range(num_options):
    print(f"Position {i} (0 = lowest): true = {ordering_true[i]}, max(L) = {ordering_L_max[i]}, "
          f"median posterior = {ordering_median_posterior[i]}, marginalised = {ordering_marginalised[i]}")
```
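An alternative that always yields a valid ordering, sketched here as an assumption rather than the notebook's method, is to average each object's rank *position* (the argsort of the argsort) over posterior samples, then sort objects by that expected position. The toy samples and weights below are hypothetical stand-ins for posterior draws.

```python
def rank_positions(values):
    """Position of each object when sorted ascending (0 = lowest value)."""
    order = sorted(range(len(values)), key=lambda i: values[i])  # argsort
    positions = [0] * len(values)
    for pos, obj in enumerate(order):  # invert the permutation
        positions[obj] = pos
    return positions


def expected_rank_positions(samples, weights):
    """Weighted mean rank position of each object across samples."""
    total = sum(weights)
    mean = [0.0] * len(samples[0])
    for values, w in zip(samples, weights):
        for obj, pos in enumerate(rank_positions(values)):
            mean[obj] += w * pos / total
    return mean


# Hypothetical posterior samples of the values of three objects.
toy_samples = [[1.2, 3.4, 2.0],
               [1.5, 2.1, 2.9],
               [1.1, 3.8, 2.5]]
toy_weights = [1.0, 1.0, 1.0]
mean_pos = expected_rank_positions(toy_samples, toy_weights)
consensus_order = sorted(range(len(mean_pos)), key=lambda i: mean_pos[i])
print(mean_pos, consensus_order)
```

Assuming `resample` returns equally weighted draws, the resampled `samples['rank'].tolist()` from the earlier cell could be passed with unit weights.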