## Ranking from Binary Comparisons

In this task there is a set $\mathcal{O}$ of $N$ objects with some intrinsic value $v : \mathcal{O} \to \mathbb{R}$, and $M$ raters who can determine the relative ordering of value between two objects, $r : \mathcal{O} \times \mathcal{O} \to \{0, 1\}$. That is, the raters cannot measure the value of each object, or even the difference in value between two objects. They can only say whether one object's value is higher or lower than another's. In this example we do not allow raters to declare two objects equal in value.

Each rater provides $Q$ ratings on $Q$ unique, randomly chosen pairs of objects, and has a chance $p$ of getting it wrong. For rater $i$ and objects $j$ and $k$, a correct rating $r_{ijk}$ is 1 if $v(o_j) > v(o_k)$ and 0 otherwise; a wrong rating is the opposite.
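
As a rough illustration of the noise model, the sketch below generates a single rating. It assumes the error is drawn independently for every rating, which is one possible reading of "a chance $p$ of getting it wrong"; the notebook code further down makes one draw per rater instead. The helper name `noisy_rating` is hypothetical, and the sketch uses Python's standard `random` module rather than `jax.random` to stay dependency-free.

```python
import random


def noisy_rating(v_j, v_k, p_wrong, rng):
    """Return 1 if the rater reports that object j beats object k, else 0.

    Assumption: each rating is flipped independently with probability p_wrong.
    """
    truth = v_j > v_k                  # what a perfect rater would say
    is_wrong = rng.random() < p_wrong  # does this rating get flipped?
    # The reported rating equals the truth unless it is flipped.
    return int(truth != is_wrong)


rng = random.Random(0)
# Ten noisy ratings of the same pair, where object j really is better.
ratings = [noisy_rating(3.2, 1.7, p_wrong=0.1, rng=rng) for _ in range(10)]
print(ratings)
```

With `p_wrong=0` the function always returns the truthful answer, and with `p_wrong=1` it always returns the opposite.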

The objective of the experimenter is to deduce the relative value of each object on a given scale.

In this example, we will let that scale be the 5-star rating system, so that each object in $\mathcal{O}$ will be assigned a value in $[1, 5]$ (inclusive).

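To do this we apply a prior that each object $o_i \in \mathcal{O}$ has a value $v(o_i) \triangleq v_i \sim \mathcal{U}[1,5]$, and then form a log-likelihood that is the negative of the number of violated constraints. A constraint is violated when a recorded rating $r_{ijk}$ disagrees with the ordering implied by the candidate values:

$\log L(v_{1:N}) = -\sum_{ijk} \mathbb{1}\left(r_{ijk} \neq \mathbb{1}(v_j > v_k)\right)$

where the sum runs over every recorded rating.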
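
To make the violation count concrete, here is a small hand-checkable sketch in plain Python. It mirrors the counting done in the notebook's `log_likelihood` function (a rating is violated when it disagrees with `v[j] > v[k]`), but the helper name `count_violations` and the toy data are assumptions made up for illustration.

```python
def count_violations(values, ratings):
    """Count ratings (j, k, r) that disagree with the ordering of `values`.

    r == 1 means the rater said object j is worth more than object k.
    """
    violations = 0
    for j, k, r in ratings:
        implied = int(values[j] > values[k])  # ordering implied by the values
        if implied != r:
            violations += 1
    return violations


# Three objects with candidate values, and four recorded ratings.
values = [1.0, 3.0, 2.0]
ratings = [
    (0, 1, 0),  # "0 is not better than 1": agrees with the values
    (1, 2, 1),  # "1 is better than 2": agrees
    (0, 2, 1),  # "0 is better than 2": violated
    (2, 1, 1),  # "2 is better than 1": violated
]

log_likelihood = -count_violations(values, ratings)
print(log_likelihood)  # -2
```

Because the count depends only on the ordering of the values, any set of values with the same ordering gives the same log-likelihood.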

In [1]:
from jax import numpy as jnp, random
from jaxns import GlobalOptimiser, PriorChain
from jaxns import save_results, marginalise_dynamic
from jaxns.prior_transforms import UniformPrior
from itertools import combinations



In [2]:


def get_constraints(num_options, num_raters, tests_per_rater, p_wrong):
    key, true_key = random.split(random.PRNGKey(47573957), 2)
    actual_rank = random.uniform(true_key,shape=(num_options,),minval=1., maxval=5.)

    pairs = jnp.asarray(list(combinations(range(num_options), 2)))
    I = []
    J = []
    S = []
    for rater in range(num_raters):
        key, sample_key1, sample_key2, sample_key3 = random.split(key, 4)
        choices = random.choice(sample_key1,pairs.shape[0], shape=(tests_per_rater,), replace=False)
        I.append(pairs[choices,0])
        J.append(pairs[choices,1])
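        # One uniform draw per rater: the rater is either right on all of
        # their tests or wrong on all of them.
        guess = jnp.where(random.uniform(sample_key1) < p_wrong,
                          actual_rank[I[-1]] < actual_rank[J[-1]],  # wrong
                          actual_rank[I[-1]] > actual_rank[J[-1]])  # right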
        S.append(guess)

    return actual_rank, jnp.concatenate(I),jnp.concatenate(J),jnp.concatenate(S)

In [3]:

num_options=10
num_raters=20
tests_per_rater=3
p_wrong=0.1

actual_rank, I, J, S = get_constraints(num_options, num_raters, tests_per_rater, p_wrong=p_wrong)



In [4]:

def log_likelihood(rank):
    order = rank[I] > rank[J]
    violations = jnp.sum(order != S)
    return -violations

with PriorChain() as prior_chain:
    UniformPrior('rank', jnp.ones(num_options), 5*jnp.ones(num_options))

go = GlobalOptimiser(loglikelihood=log_likelihood,
                   prior_chain=prior_chain)

results = go(random.PRNGKey(32564),
             termination_frac_likelihood_improvement=1e-3,termination_patience=3,
             termination_max_num_likelihood_evaluations=10e6)

In [5]:
go.summary(results)

--------
Termination Conditions:
On a plateau
--------
# likelihood evals: 171570
# samples: 939
# likelihood evals / sample: 182.7
--------
Maximum logL=-9.0
--------
rank[#]: max(L) est.
rank[0]: 2.13
rank[1]: 2.85
rank[2]: 3.24
rank[3]: 3.77
rank[4]: 4.38
rank[5]: 1.61
rank[6]: 3.76
rank[7]: 1.47
rank[8]: 4.67
rank[9]: 1.47
--------



In [6]:
# The maximum likelihood estimate has the fewest violations

rank_L_max_estimate = results.sample_L_max['rank']
log_L_max = results.log_L_max
print(f"Number of violations at maximum likelihood estimate: {-log_L_max}")

# Compare to the number of violations at the true values

log_L_actual = log_likelihood(actual_rank)
print(f"Number of violations at actual rank: {-log_L_actual}")

for i in range(num_options):
    print(f"Option {i}: True rank={actual_rank[i]}, L_Max rank={rank_L_max_estimate[i]}")

Number of violations at maximum likelihood estimate: 9.0
Number of violations at actual rank: 9
Option 0: True rank=2.3695788383483887, L_Max rank=2.126412868499756
Option 1: True rank=2.6293601989746094, L_Max rank=2.8469901084899902
Option 2: True rank=2.764465808868408, L_Max rank=3.243753433227539
Option 3: True rank=3.8243227005004883, L_Max rank=3.767789602279663
Option 4: True rank=2.954990863800049, L_Max rank=4.38162088394165
Option 5: True rank=2.3232431411743164, L_Max rank=1.6057441234588623
Option 6: True rank=4.1021881103515625, L_Max rank=3.7617671489715576
Option 7: True rank=1.4255428314208984, L_Max rank=1.4677155017852783
Option 8: True rank=4.619992733001709, L_Max rank=4.666685104370117
Option 9: True rank=1.1152663230895996, L_Max rank=1.4665040969848633


In [7]:
# Let's see what the relative order is.
# argsort lists the option indices from lowest to highest value.
ordering_L_max = jnp.argsort(rank_L_max_estimate)
ordering_true = jnp.argsort(actual_rank)
for i in range(num_options):
    print(f"True rank = {ordering_true[i]}, max(L) rank = {ordering_L_max[i]}")

True rank = 9, max(L) rank = 9
True rank = 7, max(L) rank = 7
True rank = 5, max(L) rank = 5
True rank = 0, max(L) rank = 0
True rank = 1, max(L) rank = 1
True rank = 2, max(L) rank = 2
True rank = 4, max(L) rank = 6
True rank = 3, max(L) rank = 3
True rank = 6, max(L) rank = 4
True rank = 8, max(L) rank = 8
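
One way to summarise how closely the two orderings agree is to count the pairs of options that are ordered the same way by the true values and by the maximum likelihood estimate. The sketch below does this in plain Python, with the values copied (rounded to four decimals) from the printed output above; the helper name `concordant_pairs` is hypothetical and not part of jaxns.

```python
from itertools import combinations

# Values copied from the printed output above, rounded to four decimals.
true_values = [2.3696, 2.6294, 2.7645, 3.8243, 2.9550,
               2.3232, 4.1022, 1.4255, 4.6200, 1.1153]
l_max_values = [2.1264, 2.8470, 3.2438, 3.7678, 4.3816,
                1.6057, 3.7618, 1.4677, 4.6667, 1.4665]


def concordant_pairs(a, b):
    """Return (number of pairs ordered the same way in a and b, total pairs)."""
    pairs = list(combinations(range(len(a)), 2))
    same = sum((a[j] > a[k]) == (b[j] > b[k]) for j, k in pairs)
    return same, len(pairs)


same, total = concordant_pairs(true_values, l_max_values)
print(f"{same} of {total} pairs are ordered consistently")
```

For these values, 42 of the 45 pairs agree; the three discordant pairs all involve options 3, 4 and 6, which matches the rows where the two orderings printed above differ.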
