## Ranking from Binary Comparisons

In this task there is a set $\mathcal{O}$ of $N$ objects with some intrinsic value $v : \mathcal{O} \to \mathbb{R}$, and $M$ raters who can determine the relative ordering of value between two objects, $r : \mathcal{O} \times \mathcal{O} \to \{0, 1\}$. That is, the raters cannot measure the value of each object, or even the size of the difference in value between objects. Rather, they can only say whether one object's value is higher or lower than another's. In this example we do not allow raters to assign equal value to two objects.

Each rater provides $Q$ ratings, one for each of $Q$ randomly chosen unique pairs of objects, and has a chance $p$ of getting each one wrong. For rater $i$ and objects $j$ and $k$, define the rating $r_{ijk}$ as 1 if the rater judges $v(o_j) > v(o_k)$ and 0 otherwise.
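
As an illustration of this setup, here is a minimal plain-Python sketch of one rater who rates $Q$ unique pairs and flips each answer with probability $p$. The function name `simulate_rater`, its arguments, and the toy values are assumptions made for this example. The notebook's own simulation further down instead perturbs the value difference with Gaussian noise.

```python
import random
from itertools import combinations


def simulate_rater(values, num_ratings, p_wrong, rng):
    """Return (j, k, r) triples for one hypothetical rater.

    r == 1 means the rater says object j is worth more than object k.
    Each answer is flipped independently with probability `p_wrong`.
    """
    all_pairs = list(combinations(range(len(values)), 2))
    chosen_pairs = rng.sample(all_pairs, num_ratings)  # Q unique pairs
    ratings = []
    for j, k in chosen_pairs:
        truth = 1 if values[j] > values[k] else 0
        wrong = rng.random() < p_wrong
        ratings.append((j, k, 1 - truth if wrong else truth))
    return ratings


# Toy values (assumed for illustration); a rater with p_wrong=0 never errs.
toy_values = [1.0, 2.0, 3.0, 4.0]
perfect_rater = simulate_rater(toy_values, num_ratings=5, p_wrong=0.0,
                               rng=random.Random(0))
```

With `p_wrong=0.0` every rating matches the true ordering, and with `p_wrong=1.0` every rating is flipped, which gives two easy sanity checks for the sketch.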

The objective of the experimenter is to deduce the relative value of each object on a given scale.

In this example, we will let that scale be the 5 star rating system, so that each object in $\mathcal{O}$ will be assigned a value in $[1, 5]$ (inclusive).

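To do this we apply a prior that each object $o_j \in \mathcal{O}$ has a value $v(o_j) \triangleq v_j \sim \mathcal{U}[1,5]$, and then form a log-likelihood that is the negative of the number of violated constraints. A rating is violated when it disagrees with the ordering implied by the values:

$\log L(v_{1:N}) = -\sum_{ijk} \mathbb{1}\left[r_{ijk} \neq \mathbb{1}(v_j > v_k)\right]$

The sum runs only over the triples $(i, j, k)$ for which rater $i$ actually compared objects $j$ and $k$.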
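
To make the violation count concrete, here is a small plain-Python sketch on toy data. The helper `count_violations` and the toy values and ratings are assumptions made for illustration. It mirrors the comparison done by the notebook's `log_likelihood` function, where a rating counts as violated whenever it disagrees with the ordering implied by the candidate values.

```python
def count_violations(values, ratings):
    """Count ratings that disagree with the ordering implied by `values`.

    ratings: iterable of (j, k, r), with r == 1 meaning
    "object j was judged to be worth more than object k".
    """
    violations = 0
    for j, k, r in ratings:
        implied = 1 if values[j] > values[k] else 0
        if implied != r:
            violations += 1
    return violations


# Toy example (assumed values): three objects, three ratings.
toy_values = [4.2, 1.5, 3.0]
toy_ratings = [(0, 1, 1),   # agrees: 4.2 > 1.5
               (1, 2, 0),   # agrees: 1.5 < 3.0
               (0, 2, 0)]   # violated: 4.2 > 3.0 but the rating says otherwise
toy_log_likelihood = -count_violations(toy_values, toy_ratings)
```

Here one of the three ratings is violated, so the toy log-likelihood is -1.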

In [1]:
# For parallel sampling: set before JAX initialises its backends.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

from itertools import combinations

from jax import numpy as jnp, random, jit
from jaxns import GlobalOptimiser, PriorChain
from jaxns import save_results, marginalise_dynamic
from jaxns.prior_transforms import UniformPrior

In [2]:


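def get_constraints(num_options, num_raters, tests_per_rater, rater_accuracy):
    """Simulate noisy pairwise ratings.

    `rater_accuracy` is the standard deviation of the Gaussian noise added to
    the true value difference, so smaller values mean more accurate raters.
    """
    key, true_key = random.split(random.PRNGKey(47573957), 2)
    # True values, drawn uniformly on the 1-5 star scale.
    actual_rank = random.uniform(true_key, shape=(num_options,), minval=1., maxval=5.)

    # All unordered pairs of objects that a rater could be asked about.
    pairs = jnp.asarray(list(combinations(range(num_options), 2)))
    I = []
    J = []
    S = []
    errors = []
    for rater in range(num_raters):
        key, sample_key1, sample_key2, sample_key3 = random.split(key, 4)
        # Each rater compares `tests_per_rater` unique pairs.
        choices = random.choice(sample_key1, pairs.shape[0], shape=(tests_per_rater,), replace=False)
        I.append(pairs[choices, 0])
        J.append(pairs[choices, 1])

        # NOTE: sample_key1 is reused here, so the noise is not independent of the
        # pair selection; sample_key2 would give an independent draw (and different
        # numbers from the outputs recorded below).
        rate_error = random.normal(sample_key1, shape=(tests_per_rater,)) * rater_accuracy
        difference_guess = actual_rank[I[-1]] - actual_rank[J[-1]]
        # The rater answers "I is higher than J" when the noisy difference is positive.
        guess = difference_guess + rate_error > 0.
        # Record where the noise flipped the answer relative to the truth.
        errors.append((difference_guess > 0) != guess)
        S.append(guess)

    return actual_rank, jnp.concatenate(I), jnp.concatenate(J), jnp.concatenate(S), jnp.concatenate(errors)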
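
In the simulation above, a rating is wrong when the Gaussian noise outweighs the true difference between the two values, so the chance of an error depends on how close the two objects are. A sketch of that probability, assuming the noise standard deviation equals `rater_accuracy` (the helper name `error_probability` is made up for this example):

```python
import math


def error_probability(true_difference, noise_std):
    """P(sign(d + e) != sign(d)) for e ~ Normal(0, noise_std**2).

    This equals Phi(-|d| / noise_std), written with erfc so that it stays
    accurate in the tails.
    """
    z = abs(true_difference) / noise_std
    return 0.5 * math.erfc(z / math.sqrt(2.0))


# With noise_std = 0.1, objects one noise-std apart are confused about 16% of
# the time, while objects a full star apart are essentially never confused.
p_close = error_probability(0.1, 0.1)
p_far = error_probability(1.0, 0.1)
```

This is why only a small fraction of ratings are wrong at `rater_accuracy=0.1`: most random pairs on a 1 to 5 scale differ by much more than 0.1.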

In [3]:

num_options=50
num_raters=20
tests_per_rater=20
rater_accuracy=0.1

actual_rank, I, J, S, errors = get_constraints(num_options, num_raters, tests_per_rater, rater_accuracy)
print(f"Number of rater errors: {jnp.sum(errors)}")

Number of rater errors: 8


In [4]:

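def log_likelihood(rank):
    # Ordering implied by the candidate values for every rated pair (I, J).
    order = rank[I] > rank[J]
    # A rating is violated when the implied ordering disagrees with the rater's answer S.
    violations = jnp.sum(order != S)
    return -violations

with PriorChain() as prior_chain:
    # Independent U[1, 5] prior on each object's value.
    UniformPrior('rank', jnp.ones(num_options), 5*jnp.ones(num_options))

go = GlobalOptimiser(loglikelihood=log_likelihood,
                     prior_chain=prior_chain,
                     num_parallel_samplers=4)

results = jit(go)(random.PRNGKey(32564),
                  termination_patience=2)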



In [5]:
go.summary(results)

--------
Termination Conditions:
On a plateau
--------
# likelihood evals: 7534181
# samples: 65000
# likelihood evals / sample: 115.9
--------
Maximum logL=0.0
--------
rank[#]: max(L) est.
rank[0]: 4.33
rank[1]: 4.7
rank[2]: 4.61
rank[3]: 4.59
rank[4]: 2.81
rank[5]: 2.87
rank[6]: 3.07
rank[7]: 4.29
rank[8]: 2.52
rank[9]: 2.88
rank[10]: 2.66
rank[11]: 1.76
rank[12]: 2.2
rank[13]: 4.44
rank[14]: 4.71
rank[15]: 2.34
rank[16]: 3.33
rank[17]: 1.273
rank[18]: 3.52
rank[19]: 1.32
rank[20]: 2.82
rank[21]: 2.94
rank[22]: 3.59
rank[23]: 3.14
rank[24]: 2.22
rank[25]: 4.5
rank[26]: 3.62
rank[27]: 1.91
rank[28]: 4.02
rank[29]: 3.82
rank[30]: 1.41
rank[31]: 2.36
rank[32]: 2.61
rank[33]: 1.44
rank[34]: 2.58
rank[35]: 2.2
rank[36]: 1.47
rank[37]: 2.18
rank[38]: 2.56
rank[39]: 3.16
rank[40]: 3.54
rank[41]: 2.93
rank[42]: 1.37
rank[43]: 3.98
rank[44]: 3.14
rank[45]: 2.95
rank[46]: 3.8
rank[47]: 1.09
rank[48]: 1.45
rank[49]: 3.25
--------
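
A maximum logL of 0.0 means the optimiser found values that violate none of the ratings. Because the likelihood only looks at the ordering of the values, any strictly increasing transformation of a solution gives exactly the same violation count, so the estimated values are determined only up to their order. A plain-Python sketch of this invariance on toy data (all names and numbers here are assumptions for illustration):

```python
def count_violations(values, ratings):
    """Number of (j, k, r) ratings that disagree with the ordering of `values`."""
    return sum(int((values[j] > values[k]) != bool(r)) for j, k, r in ratings)


toy_values = [1.2, 3.4, 2.8]
toy_ratings = [(0, 1, 0), (1, 2, 1), (0, 2, 0),
               (2, 1, 1)]  # contradicts (1, 2, 1): one of the two must be violated

# Cubing is strictly increasing for positive numbers, so the order is unchanged.
stretched_values = [v ** 3 for v in toy_values]

base_count = count_violations(toy_values, toy_ratings)
stretched_count = count_violations(stretched_values, toy_ratings)
```

Both counts are 1: the stretched values differ numerically from the originals but rank the objects identically.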



In [6]:
# The maximum likelihood estimate has the fewest violations

rank_L_max_estimate = results.sample_L_max['rank']
log_L_max = results.log_L_max
print(f"Number of violations at maximum likelihood estimate: {-log_L_max}")

# Compare to the number of violations at the true values

log_L_actual = log_likelihood(actual_rank)
print(f"Number of violations at actual rank: {-log_L_actual}")

for i in range(num_options):
    print(f"Option {i}: True rank={actual_rank[i]}, L_Max rank={rank_L_max_estimate[i]}")

Number of violations at maximum likelihood estimate: -0.0
Number of violations at actual rank: 8
Option 0: True rank=3.9682798385620117, L_Max rank=4.333476543426514
Option 1: True rank=4.861865520477295, L_Max rank=4.704471111297607
Option 2: True rank=4.602080821990967, L_Max rank=4.609418869018555
Option 3: True rank=4.625879287719727, L_Max rank=4.594964981079102
Option 4: True rank=2.1947169303894043, L_Max rank=2.8111586570739746
Option 5: True rank=2.5244269371032715, L_Max rank=2.87062406539917
Option 6: True rank=3.381387710571289, L_Max rank=3.0678021907806396
Option 7: True rank=4.326211929321289, L_Max rank=4.2943220138549805
Option 8: True rank=2.9776535034179688, L_Max rank=2.522937297821045
Option 9: True rank=2.9064841270446777, L_Max rank=2.877124786376953
Option 10: True rank=2.7344398498535156, L_Max rank=2.6613364219665527
Option 11: True rank=1.8382315635681152, L_Max rank=1.7611408233642578
Option 12: True rank=2.1450705528259277, L_Max rank=2.2023191452026367
Opt
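
The next cell compares orderings rather than raw values, using `jnp.argsort(jnp.argsort(x))` to turn each value into its position in ascending order. A plain-Python sketch of what that double argsort computes, with the second argsort written as an explicit inversion of the sorting permutation (the helper `ordinal_positions` and the toy numbers are assumptions for illustration):

```python
def ordinal_positions(values):
    """Position of each value in ascending order (0 = smallest)."""
    # First argsort: indices that would sort the values.
    order = sorted(range(len(values)), key=lambda idx: values[idx])
    # Second argsort: invert that permutation to get each element's position.
    positions = [0] * len(values)
    for position, idx in enumerate(order):
        positions[idx] = position
    return positions


# Toy values (assumed): objects 0 and 2 swap places in the estimate.
true_values = [3.0, 1.0, 2.0, 4.0]
estimated_values = [2.5, 1.2, 2.9, 3.6]

true_positions = ordinal_positions(true_values)            # [2, 0, 1, 3]
estimated_positions = ordinal_positions(estimated_values)  # [1, 0, 2, 3]
mean_abs_diff = sum(abs(a - b) for a, b in
                    zip(true_positions, estimated_positions)) / len(true_values)
```

Here two of the four objects are off by one position, giving a mean absolute ordering difference of 0.5.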

In [7]:
# Let's see what the relative order is.
ordering_L_max = jnp.argsort(jnp.argsort(rank_L_max_estimate))
ordering_true = jnp.argsort(jnp.argsort(actual_rank))
avg_order_diff = jnp.abs(ordering_L_max - ordering_true).mean()
print(f"Average ordering abs diff: {avg_order_diff}")
for i in range(num_options):
    print(f"True rank = {ordering_true[i]}, max(L) rank = {ordering_L_max[i]}")

Average ordering abs diff: 2.5999999046325684
True rank = 40, max(L) rank = 43
True rank = 48, max(L) rank = 48
True rank = 46, max(L) rank = 47
True rank = 47, max(L) rank = 46
True rank = 13, max(L) rank = 21
True rank = 17, max(L) rank = 23
True rank = 29, max(L) rank = 28
True rank = 43, max(L) rank = 42
True rank = 22, max(L) rank = 16
True rank = 21, max(L) rank = 24
True rank = 19, max(L) rank = 20
True rank = 7, max(L) rank = 8
True rank = 11, max(L) rank = 12
True rank = 42, max(L) rank = 44
True rank = 49, max(L) rank = 49
True rank = 18, max(L) rank = 14
True rank = 31, max(L) rank = 33
True rank = 1, max(L) rank = 1
True rank = 36, max(L) rank = 34
True rank = 2, max(L) rank = 2
True rank = 25, max(L) rank = 22
True rank = 32, max(L) rank = 26
True rank = 39, max(L) rank = 36
True rank = 28, max(L) rank = 29
True rank = 14, max(L) rank = 13
True rank = 44, max(L) rank = 45
True rank = 35, max(L) rank = 37
True rank = 9, max(L) rank = 9
True rank = 37, max(L) rank = 41
True