# Making of: DocShRED
DocRED, but with entities shuffled between documents:
Document-level SHuffled-entity Relation Extraction Dataset.

We'll need to do a few things:
- Read all the docs.
- Collect all of the entities for each doc
  - Collect all of the mentions for every doc, plus mention type labels
- Estimate the correct type of the entity based on the most common label
- Shuffle all of the entities for a particular type
- Insert those entities into the document

Things to consider:
- Not all entities have the same number of unique mentions...
- Some mentions are attributed incorrectly to multiple entities. First come, first served.
- Some mention type labels conflict. Take the most common label as the entity type.
- Use the first-first rule: When selecting a mention, the first mention of an entity will always be replaced with the first mention of another entity.

In [None]:
import json
# Schema typing
from typing import List, Dict, TypedDict, Tuple, NamedTuple, Iterator, TypedDict
from functools import cached_property
from collections import Counter, defaultdict
import random
from enum import Flag, auto
from itertools import count

def flatten(xss):
    return [ x for xs in xss for x in xs ]

random.seed(4423)

data_set = "re-docred"
data_slice = "test"
data_folder = f"data/{data_set}"
json_path = f"{data_folder}/{data_slice}.json"


In [None]:
class Label(NamedTuple):
    r: str
    h: int
    t: int
    evidence: List[int]

    def toDict(self):
        return {
            "r": self.r,
            "h": self.h,
            "t": self.t,
            "evidence": self.evidence
        }

class Mention(NamedTuple):
    name: str
    pos: Tuple[int, int]
    sent_id: int
    type: str
    text: str
    id: int
    global_pos: Tuple[int, int] = None
    index: str = None

    def toDict(self):
        d = {
            "name": self.name,
            "pos": self.pos,
            "sent_id": self.sent_id,
            "type": self.type,
            "text": self.text
        }
        if self.index:
            d["index"] = self.index
            d["global_pos"] = self.global_pos
        return d


# Entity = List[Mention]

WordList = List[str]
SentenceList = List[WordList]

class Entity(List[Mention]):
    doc: int
    ent: int

    @staticmethod
    def new(ls: List, ent:int):
        e = Entity(ls)
        e.ent = ent
        return e

    @cached_property
    def type(self):
        return Counter([m.type for m in self]).most_common(1)[0][0] if len(self) > 0 else None
        

VertexSet = List[Entity]


# class ReplaceMode(Flag):
#     ENTITY = auto()
#     MENTION = auto()
#     MENTION_MARKER = auto()
#     MASK = auto()
#     FIXED_WIDTH = auto()
#     FULL_WIDTH = auto()
#     POSITION_UPDATE = auto()


class Document:
    vertexSet: VertexSet
    labels: List[Label]
    title: str
    sents: SentenceList

    def __init__(self, docnum, vertexSet, labels, title, sents):
        self.seen = defaultdict(list)
        self.answerset = set()
        c = count()
        self.labels = [Label(**l) for l in labels]
        for lab in self.labels:
            self.answerset.add(lab.t)
            self.answerset.add(lab.h)

        # Absolute chaos.
        # Enumerate the entities (vertices), sort by if they're in an answer (true relation),
        # then filter their mentions by whether tokens from those mentions have already been seen.
        # Basically, we're trying to handle the case that a mention is assigned to multiple entities
        # as gracefully as possible.
        # self.vertexSet = [Entity.new([Mention(**m, text=sents[m['sent_id']][m['pos'][0]:m['pos'][1]], id=next(c)) for m in e if self.check(m)], ent=i) for i, e in 
        #                   sorted(enumerate(vertexSet), key=lambda x: (0, x[0]) if x[0] in self.answerset else (1, x[0]))]
        
        self.vertexSet = [None]*len(vertexSet)
        for i, e in sorted(enumerate(vertexSet), key=lambda x: (0, x[0]) if x[0] in self.answerset else (1, x[0])):
            en = []
            for m in e:
                if self.check(m):
                    if "text" in m:
                        en.append(Mention(**m, id=next(c)))
                    else:
                        en.append(Mention(**m, text=sents[m['sent_id']][m['pos'][0]:m['pos'][1]], id=next(c)))
            self.vertexSet[i] = Entity.new(en, i)
            self.vertexSet[i].doc = docnum

        # for e in self.vertexSet:
        #     e.doc = docnum
        # self.vertexSet.sort(key=lambda x: x.ent)
        self.title = title
        self.sents = sents

    def check(self, m: int):
        s = m['sent_id']

        covered = flatten(self.seen[s])
        candidates = list(range(*m['pos']))
        if any( w in covered for w in candidates ):
            return False
        self.seen[s].append(candidates)
        return True

    def print(self, sep=" "):
        print(sep.join(' '.join(sent) for sent in self.sents))

    def getMentionsInSentence(self, sentenceId: str) -> Iterator[Tuple[int, Mention]]:
        for i, e in enumerate(self.vertexSet):
            for m in e:
                seen = set()
                if m.sent_id == sentenceId:
                    if m.pos[0] not in seen:
                        seen.add(m.pos[0])
                        yield i, m

    def sentenceReplaceWithMarkers(self, sents=None) -> List[List[str]]:
        map = []
        if not sents:
            sents = self.sents
        for i, sent in enumerate(sents):
            mentions = list(self.getMentionsInSentence(i))
            # print(i, mentions)
            row = []
            for j, s in enumerate(sent):
                v = [s]
                for k, m in sorted(mentions, key=lambda x: x[1].pos[0]):
                    if j in range(*m.pos):
                        if j == m.pos[0]:
                            v = [f'<M_{m.pos[0]}>']
                        else:
                            v = []
                        break
                row.extend(v)
            map.append(row)
        return map

    def sentenceUpdate(self, apply=False):
        map = []
        for i, sent in enumerate(self.sentenceReplaceWithMarkers()):
            mentions = {m.pos[0]: m for _, m in self.getMentionsInSentence(i)}
            row = []
            for j, s in enumerate(sent):
                if s.startswith("<M_"):
                    m = mentions[int(s[3:-1])]
                    if apply:
                        m.pos[0] = len(row)
                    row.extend(m.text)
                    if apply:
                        m.pos[1] = len(row)
                    pass
                else:
                    row.append(s)
            map.append(row)
        return map
    
    def toDict(self):
        return {
            "title": self.title,
            "vertexSet": [[m.toDict() for m in e] for e in self.vertexSet],
            "labels": [lab.toDict() for lab in self.labels],
            "sents": self.sents
        }
    
    def toJson(self):
        # def default(obj):
        #     if isinstance(Mention):
        #         return obj.toDict()
            
        return json.dumps(self.toDict(), ensure_ascii=False)

In [None]:
docred_dev = ''
with open(json_path, 'r', encoding='utf8') as docred_dev_json:
    docred_dev = json.load(docred_dev_json)

# Schema: List[{vertexSet, labels: Label, title: str, sents: SentenceList}]

d = Document(docnum=0, **docred_dev[0])
for r in d.sents:
    print(r)
print("---------------------------------")
em = d.sentenceReplaceWithMarkers()
for r in em:
    print(r)

print("---------------------------------")
em = d.sentenceUpdate()
for r in em:
    print(r)

print("---------------------------------")
em = d.sentenceUpdate()
for r in em:
    print(r)

['Lark', 'Force', 'was', 'an', 'Australian', 'Army', 'formation', 'established', 'in', 'March', '1941', 'during', 'World', 'War', 'II', 'for', 'service', 'in', 'New', 'Britain', 'and', 'New', 'Ireland', '.']
['Under', 'the', 'command', 'of', 'Lieutenant', 'Colonel', 'John', 'Scanlan', ',', 'it', 'was', 'raised', 'in', 'Australia', 'and', 'deployed', 'to', 'Rabaul', 'and', 'Kavieng', ',', 'aboard', 'SS', 'Katoomba', ',', 'MV', 'Neptuna', 'and', 'HMAT', 'Zealandia', ',', 'to', 'defend', 'their', 'strategically', 'important', 'harbours', 'and', 'airfields', '.']
['The', 'objective', 'of', 'the', 'force', ',', 'was', 'to', 'maintain', 'a', 'forward', 'air', 'observation', 'line', 'as', 'long', 'as', 'possible', 'and', 'to', 'make', 'the', 'enemy', 'fight', 'for', 'this', 'line', 'rather', 'than', 'abandon', 'it', 'at', 'the', 'first', 'threat', 'as', 'the', 'force', 'was', 'considered', 'too', 'small', 'to', 'withstand', 'any', 'invasion', '.']
['Most', 'of', 'Lark', 'Force', 'was', 'captu

In [4]:
print(d.vertexSet)

[[Mention(name='Lark Force', pos=[0, 2], sent_id=0, type='ORG', text=['Lark', 'Force'], id=0), Mention(name='Lark Force', pos=[2, 4], sent_id=3, type='ORG', text=['Lark', 'Force'], id=1), Mention(name='Lark Force', pos=[3, 5], sent_id=4, type='ORG', text=['Lark', 'Force'], id=2)], [Mention(name='Australian Army', pos=[4, 6], sent_id=0, type='ORG', text=['Australian', 'Army'], id=3)], [Mention(name='March 1941', pos=[9, 11], sent_id=0, type='TIME', text=['March', '1941'], id=4)], [Mention(name='World War II', pos=[12, 15], sent_id=0, type='MISC', text=['World', 'War', 'II'], id=5)], [Mention(name='New Britain', pos=[18, 20], sent_id=0, type='LOC', text=['New', 'Britain'], id=14)], [Mention(name='New Ireland', pos=[21, 23], sent_id=0, type='LOC', text=['New', 'Ireland'], id=15)], [Mention(name='John Scanlan', pos=[6, 8], sent_id=1, type='PER', text=['John', 'Scanlan'], id=6)], [Mention(name='Australia', pos=[13, 14], sent_id=1, type='LOC', text=['Australia'], id=7)], [Mention(name='Rabau

In [5]:
d: Document = Document(docnum=0, **docred_dev[0])


In [None]:
devDocuments = [Document(docnum=i, **d) for i, d in enumerate(docred_dev)]
with open(json_path, 'r', encoding='utf8') as docred_dev_json:
    devDocumentsSafe = [Document(docnum=i, **d) for i, d in enumerate(json.load(docred_dev_json))]


In [7]:
dn = 13
d: Document = devDocumentsSafe[dn]

for i, v in enumerate(d.vertexSet):
    print(i, v)

d.print(sep='\n')

0 [Mention(name='Delaware General Assembly', pos=[1, 4], sent_id=0, type='ORG', text=['Delaware', 'General', 'Assembly'], id=0), Mention(name='Assembly', pos=[30, 31], sent_id=2, type='ORG', text=['Assembly'], id=1)]
1 [Mention(name='U.S.', pos=[9, 10], sent_id=0, type='LOC', text=['U.S.'], id=2)]
2 [Mention(name='Delaware', pos=[12, 13], sent_id=0, type='LOC', text=['Delaware'], id=3), Mention(name='Delaware', pos=[8, 9], sent_id=2, type='LOC', text=['Delaware'], id=4)]
3 [Mention(name='Delaware Senate', pos=[8, 10], sent_id=1, type='ORG', text=['Delaware', 'Senate'], id=5), Mention(name='Senate', pos=[19, 20], sent_id=6, type='ORG', text=['Senate'], id=6)]
4 [Mention(name='21', pos=[11, 12], sent_id=1, type='NUM', text=['21'], id=11)]
5 [Mention(name='Delaware House of Representatives', pos=[15, 19], sent_id=1, type='ORG', text=['Delaware', 'House', 'of', 'Representatives'], id=7)]
6 [Mention(name='41', pos=[20, 21], sent_id=1, type='NUM', text=['41'], id=12)]
7 [Mention(name='Legisl

In [None]:
def checkForErrors(docSet: List[Document]):
    # 1: All mentions map to themselves
    for n, d in enumerate(docSet):
        for e, v in enumerate(d.vertexSet):
            for m in v:
                if d.sents[m.sent_id][m.pos[0]:m.pos[1]] != m.text:
                    print(f"{n}@{m.sent_id}[{m.pos[0]}:{m.pos[1]}]: ({e}){d.sents[m.sent_id][m.pos[0]:m.pos[1]]} != {m.text}")
                    print(list(d.getMentionsInSentence(m.sent_id)))
                    print(list(devDocumentsSafe[n].getMentionsInSentence(m.sent_id)))
                    print(" ".join(d.sents[m.sent_id]))
                    print(" ".join(devDocumentsSafe[n].sents[m.sent_id]))
                    print()
                    print()
                # assert d.sents[m.sent_id][m.pos[0]:m.pos[1]] == m.text, f"{n}@{m.sent_id}[{m.pos[0]}:{m.pos[1]}]: {d.sents[m.sent_id][m.pos[0]:m.pos[1]]} != {m.text}"
    #2: All entities which appear in relations have mentions
    for n, d in enumerate(docSet):
        for lab in d.labels:
            if len(d.vertexSet[lab.h]) == 0:
                print(f"{n}, entity {lab.h} in {lab.toDict()}")
            if len(d.vertexSet[lab.t]) == 0:
                print(f"{n}, entity {lab.t} in {lab.toDict()}")
            # assert len(d.vertexSet[lab.h]) > 0, f"{n}, entity {lab.h} in {lab.toDict()}"
            # assert len(d.vertexSet[lab.t]) > 0, f"{n}, entity {lab.t} in {lab.toDict()}"
            

In [9]:
checkForErrors(devDocumentsSafe)

97, entity 9 in {'r': 'P17', 'h': 9, 't': 3, 'evidence': [0, 3]}
211, entity 16 in {'r': 'P577', 'h': 15, 't': 16, 'evidence': [7]}
241, entity 19 in {'r': 'P17', 'h': 18, 't': 19, 'evidence': [5, 6]}
241, entity 19 in {'r': 'P276', 'h': 18, 't': 19, 'evidence': [5, 6]}
242, entity 1 in {'r': 'P17', 'h': 1, 't': 5, 'evidence': [0, 1]}
266, entity 6 in {'r': 'P27', 'h': 6, 't': 2, 'evidence': [0]}
337, entity 2 in {'r': 'P570', 'h': 0, 't': 2, 'evidence': [0]}


In [None]:
entities_by_type = defaultdict(list)

# random.seed = 0

for d in devDocuments:
    for e in d.vertexSet:
        if e.type:
            entities_by_type[e.type].append(e)

# print(entities_by_type['MISC'])

for t in sorted(entities_by_type.keys()):
    print(t)
    random.shuffle(entities_by_type[t])

# print(entities_by_type['MISC'])
# print(entities_by_type['MISC'])

for t in sorted(entities_by_type.keys()):
    print(t, entities_by_type[t][0])


LOC
MISC
NUM
ORG
PER
TIME
LOC [Mention(name='Europe', pos=[10, 11], sent_id=10, type='LOC', text=['Europe'], id=17)]
MISC [Mention(name='The Bride with White Hair 2', pos=[22, 28], sent_id=5, type='MISC', text=['The', 'Bride', 'with', 'White', 'Hair', '2'], id=10)]
NUM [Mention(name='Eleven', pos=[0, 1], sent_id=4, type='NUM', text=['Eleven'], id=20)]
ORG [Mention(name='Old Man Luedecke', pos=[43, 46], sent_id=5, type='ORG', text=['Old', 'Man', 'Luedecke'], id=25)]
PER [Mention(name='Medan', pos=[5, 6], sent_id=0, type='PER', text=['Medan'], id=1), Mention(name='Medan', pos=[0, 1], sent_id=1, type='PER', text=['Medan'], id=2)]
TIME [Mention(name='25 October 2008', pos=[4, 7], sent_id=2, type='TIME', text=['25', 'October', '2008'], id=11)]


In [None]:
def mentionSwap(oldMent: Mention, newMent: Mention, newType: str):
    return Mention(name=newMent.name, pos=oldMent.pos[:], sent_id=oldMent.sent_id, type=newType, text=newMent.text, id=oldMent.id)


def entitySwap(document:Document, entities: Dict[str, List[Entity]]):
    print("BEFORE:")
    document.print(sep='\n')
    # Step 1: go through entities one by one
    for i, e in enumerate(document.vertexSet):
        if (e.type == 'NUM') or (e.type == 'TIME') or e.type is None:
            # print("SKIP")
            continue
        print(f"{e.type}: {e[0].name}")
        _new = []
        old_ent = document.vertexSet[i]
        # Step 2: select a new entity
        new_ent = entities[e.type].pop(0)
        # Step 3: replace mention 1 with mention 1.
        old_m0 = old_ent[0]
        new_m0 = new_ent[0]
        _new.append(mentionSwap(old_m0, new_m0, e.type))
        # Step 4: For all other mentions of original entity,
        for old_m in old_ent[1:]:
            # randomly choose a mention of the new entity and replace.
            _new.append(mentionSwap(old_m, random.choice(new_ent), e.type))
        # print(f"{e.type}: {e[0].name}")
        assert len(old_ent) == len(_new)
        document.vertexSet[i] = Entity(_new)
    # Step ?: The positions will be off. Maybe update those?

    print()
    print("AFTER:")
    for i, e in enumerate(document.vertexSet):
        print(f"{e.type}: {e[0].name if e.type else '<EMPTY>'}")
    document.sents = document.sentenceUpdate(apply=True)
    
    
# print(entities_by_type.keys())

for d in devDocuments:
    entitySwap(d, entities_by_type)
    d.print()



BEFORE:
Lark Force was an Australian Army formation established in March 1941 during World War II for service in New Britain and New Ireland .
Under the command of Lieutenant Colonel John Scanlan , it was raised in Australia and deployed to Rabaul and Kavieng , aboard SS Katoomba , MV Neptuna and HMAT Zealandia , to defend their strategically important harbours and airfields .
The objective of the force , was to maintain a forward air observation line as long as possible and to make the enemy fight for this line rather than abandon it at the first threat as the force was considered too small to withstand any invasion .
Most of Lark Force was captured by the Imperial Japanese Army after Rabaul and Kavieng were captured in January 1942 .
The officers of Lark Force were transported to Japan , however the NCOs and men were unfortunately torpedoed by the USS Sturgeon while being transported aboard the Montevideo Maru .
Only a handful of the Japanese crew were rescued , with none of the betw

In [12]:
d.print()

John Martin ( born October 15 , 1942 ) is an White Mountains jurist . He served in both houses of the Republican as a member of the U.S. Navy and was most recently an Associate Justice of the Jewish . John Martin was born in Talkeetna Ranger Station in 1942 . He grew up on his family 's dairy farm near Talkeetna Ranger Station . As a teenager , he raised purebred UK hogs to finance his college education . He earned a bachelor of arts degree in economics , political science , and history in 1963 from Supreme Tribunal of Justice . In 1966 , he also earned a law degree from the Jews . John Martin owns a cattle farm in San Bernardino National Wildlife Refuge , near his childhood home . John Martin and his wife George Harrison have three children and four grandchildren .


In [13]:
dn = 13
d: Document = Document(docnum=dn, **docred_dev[dn])

for i, v in enumerate(d.vertexSet):
    print(i, v)

d.print(sep='\n')

0 [Mention(name='Delaware General Assembly', pos=[1, 4], sent_id=0, type='ORG', text=['Delaware', 'General', 'Assembly'], id=0), Mention(name='Assembly', pos=[30, 31], sent_id=2, type='ORG', text=['Assembly'], id=1)]
1 [Mention(name='U.S.', pos=[9, 10], sent_id=0, type='LOC', text=['U.S.'], id=2)]
2 [Mention(name='Delaware', pos=[12, 13], sent_id=0, type='LOC', text=['Delaware'], id=3), Mention(name='Delaware', pos=[8, 9], sent_id=2, type='LOC', text=['Delaware'], id=4)]
3 [Mention(name='Delaware Senate', pos=[8, 10], sent_id=1, type='ORG', text=['Delaware', 'Senate'], id=5), Mention(name='Senate', pos=[19, 20], sent_id=6, type='ORG', text=['Senate'], id=6)]
4 [Mention(name='21', pos=[10, 11], sent_id=1, type='NUM', text=['with'], id=11)]
5 [Mention(name='Delaware House of Representatives', pos=[15, 19], sent_id=1, type='ORG', text=['Delaware', 'House', 'of', 'Representatives'], id=7)]
6 []
7 [Mention(name='Legislative Hall', pos=[3, 5], sent_id=2, type='LOC', text=['Legislative', 'Hal

In [14]:
devDocuments[2].vertexSet

[[Mention(name='Ninth Judicial Circuit Court', pos=[0, 4], sent_id=4, type='ORG', text=['Ninth', 'Judicial', 'Circuit', 'Court'], id=0),
  Mention(name='Ninth Judicial Circuit Court', pos=[0, 4], sent_id=0, type='ORG', text=['Ninth', 'Judicial', 'Circuit', 'Court'], id=1),
  Mention(name='Ninth Judicial Circuit Court', pos=[3, 7], sent_id=5, type='ORG', text=['Ninth', 'Judicial', 'Circuit', 'Court'], id=2)],
 [Mention(name='Afghan', pos=[6, 7], sent_id=0, type='LOC', text=['Afghan'], id=8)],
 [Mention(name='American', pos=[16, 17], sent_id=0, type='LOC', text=['American'], id=3)],
 [Mention(name='1', pos=[5, 6], sent_id=1, type='ORG', text=['1'], id=4)],
 [Mention(name='1st of April 2006', pos=[8, 12], sent_id=2, type='TIME', text=['1st', 'of', 'April', '2006'], id=9)],
 [Mention(name='Bailey County', pos=[14, 16], sent_id=2, type='LOC', text=['Bailey', 'County'], id=5)],
 [Mention(name='University of Otago', pos=[16, 19], sent_id=3, type='ORG', text=['University', 'of', 'Otago'], id=1

### Errors in DocShRED
While I don't like that there are some errors, they're difficult to avoid.
They are by no means unavoidable, but the time and effort needed to circumvent them is too high.
Basically, what it boild down to is that there are errors in the entity recognition for DocRED.
Some mentions are attributed to multiple unique entities, which is clearly wrong.
For such a mention, the entity with the lowest index value was chosen to map it to.
For example, consider the following two conflicting mentions in Document 57, sentence 4:
(14, Mention(name='the United States Navy', pos=[7, 11], sent_id=4, type='ORG', text=['the', 'United', 'States', 'Navy'])),
(15, Mention(name='United States Navy', pos=[8, 11], sent_id=4, type='ORG', text=['United', 'States', 'Navy']))
The mention "the United States Navy" would be kept for entity 14, but the mention for "United States Navy" for entity 15 would be removed.
Unfortunately, the answer set only lists 15 as having valid relations, while 14 would be equally as correct in any of the relations... It's really a shame.

We partially addressed this by handling entities in relations first when looking for overlaps, but there are still a few bad cases left.
The following lists all triples present in DocRED that are un-extractable in DocShRED.


In [15]:
checkForErrors(devDocuments)

97, entity 9 in {'r': 'P17', 'h': 9, 't': 3, 'evidence': [0, 3]}
211, entity 16 in {'r': 'P577', 'h': 15, 't': 16, 'evidence': [7]}
241, entity 19 in {'r': 'P17', 'h': 18, 't': 19, 'evidence': [5, 6]}
241, entity 19 in {'r': 'P276', 'h': 18, 't': 19, 'evidence': [5, 6]}
242, entity 1 in {'r': 'P17', 'h': 1, 't': 5, 'evidence': [0, 1]}
266, entity 6 in {'r': 'P27', 'h': 6, 't': 2, 'evidence': [0]}
337, entity 2 in {'r': 'P570', 'h': 0, 't': 2, 'evidence': [0]}


In [None]:
import os
new_data_set = data_set.replace('red', 'shred')
os.makedirs(f"data/{new_data_set}", exist_ok=True)

with open(f"data/{new_data_set}/{data_slice}.json", 'w', encoding="utf8") as dshred:
    print("[", file=dshred)
    print(",\n".join([d.toJson() for d in devDocuments]), file=dshred)
    print("]", file=dshred)