# Calculate Mean Reciprocal Expected Rank
Just a quick notebook for calculating the expected rank for each document, then using that to calculate MRER for the whole data set.

The expected rank is calculated by treating the ranking task as an urn problem. Given an urn filled with $n$ marbles of which $t$ are green, the expected number of draws (without replacement) needed to get the first green marble is $\frac{n+1}{t+1}$.
This is equivalent to assigning an i.i.d. random score to every statement and asking how far down the first true statement occurs.

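The Mean Reciprocal Expected Rank measures the average reciprocal of the expected rank over every document: $$\frac{1}{\left|\mathcal{D}_p\right|}\sum^{\left|\mathcal{D}_p\right|}_{i=1}{\frac{t_i+1}{n_i+1}}$$ where, for the $i$-th document in $\mathcal{D}_p$ (the documents that contain relation $p$), $n_i$ is the number of candidate statements and $t_i$ is the number of true statements.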

In [1]:
import json
# Schema typing
from typing import List, Dict, TypedDict, Tuple, NamedTuple, Iterator, TypedDict
from functools import cached_property
from collections import Counter, defaultdict
import random
from enum import Flag, auto
from itertools import count
import pickle

def flatten(xss):
    """Return one flat list holding the items of every iterable in ``xss``, in order."""
    flat = []
    for xs in xss:
        flat.extend(xs)
    return flat

# Fix the seed of the module-level RNG so any later use of `random` is repeatable.
random.seed(4423)


In [8]:
class Label(NamedTuple):
    """One annotated relation instance of a document."""
    # Relation identifier; used as a key into the relation-info table.
    r: str
    # Entity indexes of the two arguments (presumably head and tail -- confirm against the dataset schema).
    h: int
    t: int
    # List of ints attached to the label (presumably supporting sentence ids -- confirm).
    evidence: List[int]

    def toDict(self):
        """Return the label as a plain dict with keys ``r``, ``h``, ``t``, ``evidence``."""
        keys = ("r", "h", "t", "evidence")
        return {key: getattr(self, key) for key in keys}

class Mention(NamedTuple):
    """A single textual mention of an entity inside one sentence."""
    name: str
    # Two-element token span inside the sentence; used elsewhere as slice/range bounds (start, end).
    pos: Tuple[int, int]
    # Index of the sentence that contains the mention.
    sent_id: int
    # Entity type label of this mention.
    type: str
    # Surface text of the mention.
    text: str
    # Identifier supplied by the creator of the mention; not exported by toDict().
    id: int
    # Optional extras; they are only exported when `index` is truthy.
    global_pos: Tuple[int, int] = None
    index: str = None

    def toDict(self):
        """Return the serialisable fields of the mention as a dict.

        ``name``, ``pos``, ``sent_id``, ``type`` and ``text`` are always present.
        ``index`` and ``global_pos`` are added only when ``index`` is truthy.
        ``id`` is never included.
        """
        always = ("name", "pos", "sent_id", "type", "text")
        out = {field: getattr(self, field) for field in always}
        if not self.index:
            return out
        out["index"] = self.index
        out["global_pos"] = self.global_pos
        return out


# Entity = List[Mention]

# A sentence is a list of token strings; a document body is a list of such sentences.
WordList = List[str]
SentenceList = List[WordList]

class Entity(List[Mention]):
    """A list of Mention objects grouped as one entity."""
    # Document number; not set by this class, callers assign it after construction.
    doc: int
    # Entity number; set by Entity.new().
    ent: int

    @staticmethod
    def new(ls: List, ent:int):
        """Build an Entity holding the mentions in ``ls`` and tag it with ``ent``."""
        entity = Entity(ls)
        entity.ent = ent
        return entity

    @cached_property
    def type(self):
        """Most common ``type`` among the mentions, or None for an empty entity.

        The value is computed on first access and then cached on the instance,
        so later changes to the list are not reflected.
        """
        if len(self) <= 0:
            return None
        tally = Counter([mention.type for mention in self])
        top_type, _ = tally.most_common(1)[0]
        return top_type
        
        

# All entities of one document, as a list of Entity objects.
VertexSet = List[Entity]


# class ReplaceMode(Flag):
#     ENTITY = auto()
#     MENTION = auto()
#     MENTION_MARKER = auto()
#     MASK = auto()
#     FIXED_WIDTH = auto()
#     FULL_WIDTH = auto()
#     POSITION_UPDATE = auto()


class Document:
    """One annotated document: its sentences, entities (vertices) and relation labels.

    Attributes:
        vertexSet: One Entity per input vertex, at the same index as in the input.
        labels: The relation labels, converted to Label tuples.
        title: Document title.
        sents: Tokenised sentences.
        seen: Per-sentence record of the token spans already claimed by a mention
            (filled by check()).
        answerset: Entity indexes that occur as ``h`` or ``t`` of at least one label.
    """
    vertexSet: VertexSet
    labels: List[Label]
    title: str
    sents: SentenceList

    def __init__(self, docnum, vertexSet, labels, title, sents):
        """Build the document from its raw (JSON-style) parts.

        Args:
            docnum: Number stored on every entity as ``doc``.
            vertexSet: List of entities, each a list of mention dicts. Every dict
                needs ``sent_id`` and ``pos`` and is expanded into Mention(**m, ...).
            labels: List of dicts expanded into Label(**l).
            title: Document title.
            sents: List of sentences, each a list of tokens.
        """
        self.seen = defaultdict(list)
        self.answerset = set()
        c = count()
        self.labels = [Label(**raw) for raw in labels]
        for lab in self.labels:
            self.answerset.add(lab.t)
            self.answerset.add(lab.h)

        # A mention may be assigned to several entities. To handle that as
        # gracefully as possible, entities that take part in a true relation are
        # processed first, and a mention is dropped when any of its tokens was
        # already claimed by an earlier mention (see check()). Each entity is
        # still stored at its original index.
        def priority(indexed):
            idx = indexed[0]
            return (0, idx) if idx in self.answerset else (1, idx)

        self.vertexSet = [None]*len(vertexSet)
        for i, e in sorted(enumerate(vertexSet), key=priority):
            en = []
            for m in e:
                if not self.check(m):
                    continue
                if "text" in m:
                    en.append(Mention(**m, id=next(c)))
                else:
                    # No text given: take the tokens of the span from the sentence.
                    en.append(Mention(**m, text=sents[m['sent_id']][m['pos'][0]:m['pos'][1]], id=next(c)))
            self.vertexSet[i] = Entity.new(en, i)
            self.vertexSet[i].doc = docnum

        self.title = title
        self.sents = sents

    def check(self, m: Dict) -> bool:
        """Claim the token span of mention dict ``m`` unless it overlaps a claimed one.

        Returns:
            False if any token index in range(*m['pos']) was already claimed in
            sentence m['sent_id']; otherwise records the span and returns True.
        """
        s = m['sent_id']

        # A set gives constant-time membership tests for the overlap check.
        covered = set(flatten(self.seen[s]))
        candidates = list(range(*m['pos']))
        if any( w in covered for w in candidates ):
            return False
        self.seen[s].append(candidates)
        return True

    def print(self, sep=" "):
        """Print the document text: tokens joined by a space, sentences joined by ``sep``."""
        print(sep.join(' '.join(sent) for sent in self.sents))

    def getMentionsInSentence(self, sentenceId: int) -> Iterator[Tuple[int, Mention]]:
        """Yield ``(entity_index, mention)`` for the mentions located in one sentence.

        At most one mention is yielded per start position (``pos[0]``); the first
        one met in vertexSet order wins.
        """
        # Must live for the whole call: re-creating it per mention would make the
        # duplicate-start filter a no-op.
        seen = set()
        for i, e in enumerate(self.vertexSet):
            for m in e:
                if m.sent_id != sentenceId or m.pos[0] in seen:
                    continue
                seen.add(m.pos[0])
                yield i, m

    def sentenceReplaceWithMarkers(self, sents=None) -> List[List[str]]:
        """Return the sentences with every mention collapsed into one marker token.

        The first token of a mention becomes ``<M_{start}>`` (start = ``pos[0]``)
        and the remaining tokens of the mention are dropped.

        Args:
            sents: Sentences to process; ``self.sents`` is used when this is
                None or empty.
        """
        marked = []
        if not sents:
            sents = self.sents
        for i, sent in enumerate(sents):
            # Sorted once per sentence; the order does not depend on the token.
            mentions = sorted(self.getMentionsInSentence(i), key=lambda x: x[1].pos[0])
            row = []
            for j, s in enumerate(sent):
                v = [s]
                for _, m in mentions:
                    if j in range(*m.pos):
                        # Marker on the first token of the span, nothing for the rest.
                        v = [f'<M_{m.pos[0]}>'] if j == m.pos[0] else []
                        break
                row.extend(v)
            marked.append(row)
        return marked

    def sentenceUpdate(self, apply=False):
        """Rebuild the sentences by expanding each marker back into the mention text.

        Args:
            apply: When true, each mention's ``pos`` is overwritten in place with
                its span in the rebuilt sentence. This needs ``pos`` to be a
                mutable sequence (a list, as produced by json.load); a tuple
                would raise TypeError.

        Returns:
            The rebuilt sentences. ``self.sents`` is not modified.
        """
        rebuilt = []
        for i, sent in enumerate(self.sentenceReplaceWithMarkers()):
            # Keyed by the start position that is embedded in the marker token.
            mentions = {m.pos[0]: m for _, m in self.getMentionsInSentence(i)}
            row = []
            for s in sent:
                if s.startswith("<M_"):
                    m = mentions[int(s[3:-1])]
                    if apply:
                        m.pos[0] = len(row)
                    row.extend(m.text)
                    if apply:
                        m.pos[1] = len(row)
                else:
                    row.append(s)
            rebuilt.append(row)
        return rebuilt
    
    def toDict(self):
        """Return the document as a dict with keys title, vertexSet, labels, sents."""
        return {
            "title": self.title,
            "vertexSet": [[m.toDict() for m in e] for e in self.vertexSet],
            "labels": [lab.toDict() for lab in self.labels],
            "sents": self.sents
        }
    
    def toJson(self):
        """Return toDict() serialised as a JSON string; non-ASCII text is kept unescaped."""
        return json.dumps(self.toDict(), ensure_ascii=False)

In [9]:
# with open("data/re-docred/rel_info_domain_range.pickle", 'rb') as pick:
#     dev_answers = pickle.load(pick)

# Load the relation metadata. The encoding is given explicitly, as for the
# dataset file, so the result does not depend on the platform default.
with open("data/re-docred/rel_info_full.json", 'r', encoding='utf8') as ri:
    rel_info = json.load(ri)
# Bare expression: shown as the cell output to inspect one entry.
rel_info["P17"]

{'name': 'country',
 'desc': 'sovereign state that this item is in (not to be used for human beings)',
 'prompt_xy': '?x is in the country of ?y.',
 'prompt_yx': '?y is the country where ?x is located.',
 'domain': ['ORG', 'LOC', 'MISC'],
 'range': ['LOC'],
 'reflexive': False,
 'irreflexive': True,
 'symmetric': False,
 'antisymmetric': False,
 'transitive': False,
 'implied_by': [],
 'tokens': ['?x', 'is', 'in', 'the', 'country', 'of', '?y', '.'],
 'verb': 1}

In [None]:
task_name = 're-docred'
dset = 'dev_c'


with open(f'data/{task_name}/{dset}.json', 'r', encoding='utf8') as docred_dev_json:
    docred_dev = json.load(docred_dev_json)

# Same explicit encoding as the dataset file above.
with open(f"data/{task_name}/rel_info_full.json", 'r', encoding='utf8') as ri:
    rel_info = json.load(ri)

# Schema: List[{vertexSet, labels: Label, title: str, sents: SentenceList}]

# Reciprocal expected ranks per relation: one value for every document in which
# the relation has at least one label.
rers = {r:[] for r in rel_info}

if 'bio' in task_name:
    all_types = {'DiseaseOrPhenotypicFeature', 'GeneOrGeneProduct', 'ChemicalEntity', 'SequenceVariant', 'OrganismTaxon', 'CellLine'}
else:
    all_types = {'PER','LOC','NUM','TIME','ORG','MISC'}

for i, doc in enumerate(docred_dev):
    d = Document(docnum=i, **doc)
    # Number of true statements per relation in this document.
    rel_set = Counter(ans.r for ans in d.labels)

    # Group the entity numbers by entity type; entities without a type
    # (no surviving mentions) are left out.
    ent_by_type = {ent_type: [] for ent_type in all_types}
    for e in d.vertexSet:
        if e.type:
            ent_by_type[e.type].append(e.ent)

    for rel, t in rel_set.items():
        info = rel_info[rel]
        # Entities whose type is allowed as first / second argument of the relation.
        _dom = set()
        for ent_type in info['domain']:
            _dom.update(ent_by_type[ent_type])
        _ran = set()
        for ent_type in info['range']:
            _ran.update(ent_by_type[ent_type])

        # n = number of candidate (x, y) pairs. For an irreflexive relation the
        # pairs (x, x) are not candidates; there is one per entity in both sets.
        n = len(_dom)*len(_ran)
        if info['irreflexive']:
            n -= len(_dom&_ran)

        # Expected rank of the first true statement (urn model), then its reciprocal.
        er = (n+1)/(t+1)
        rers[rel].append(1/er)

mrers = []

if 'bio' in task_name:
    relsel = ['Association','Bind','Negative_Correlation','Positive_Correlation']
else:
    relsel = ['P17', 'P27', 'P131', 'P150', 'P161', 'P175', 'P527', 'P569', 'P570', 'P577']

for rel in relsel:
    rer = rers[rel]
    # A selected relation may be absent from this split; report NaN for it
    # instead of failing with ZeroDivisionError.
    mrer = sum(rer)/len(rer) if rer else float('nan')
    mrers.append(mrer)
    print(f"MRER for {rel}:", mrer)

print(', '.join(f'{m*100:.2f}' for m in mrers))


# print(rel_set)




MRER for P17: 0.17407445029865207
MRER for P27: 0.3272475147234804
MRER for P131: 0.2585455308747644
MRER for P150: 0.13219489306885063
MRER for P161: 0.27225270606268115
MRER for P175: 0.13170181668098732
MRER for P527: 0.027260517926438497
MRER for P569: 0.18630934428136633
MRER for P570: 0.16771662674820365
MRER for P577: 0.2502119580913329
17.41,32.72,25.85,13.22,27.23,13.17,2.73,18.63,16.77,25.02
Counter({'P131': 13, 'P17': 11, 'P1001': 2, 'P150': 1})
