In [1]:
%config IPCompleter.greedy=True

### TEI Preprocessing
We preprocess a PDF of our source material: *Graph Representation Learning* by Hamilton, available [here](https://www.cs.mcgill.ca/~wlh/grl_book/files/GRL_Book.pdf).

Text extraction is done following Alpizar-Chacon & Sosnovsky, 2020.

Their data pipeline is available as a web service at https://intextbooks.science.uu.nl/.

The code for the TEI pipeline is available on Github ([link](https://github.com/intextbooks/ITCore?tab=readme-ov-file)), but requires the deployment and coordination of multiple software components. Specifically, it requires MySQL, Apache Jena, and a partial local copy of DBPedia. We use the web service to avoid the effort of deploying the extraction pipeline locally.

We optionally enabled "identify index terms in text" and "link entities to DBPedia" using the category "https://<span/>dbpedia.org/page/Category:Technology."

### XML Data-Munging
We process the XML output of the TEI pipeline as described in Yao 2023.

#### install stuff

In [2]:
!pip install xmltodict==0.13.0



In [3]:
import itertools
import re
import xmltodict

#### ingest xml

In [4]:
# f = open("DLB_TEI/teiModel.xml")
f = open("DLB_TEI/teiModel.xml")

book = xmltodict.parse(
    f.read(),
    xml_attribs=True,
)

f.close()

#### get section headings from table of contents

In [5]:
# crawl the XML tree in search of items with text
# return a list of table of contents item headings
def grab_toc_headings(node):
    items = []
    if type(node) is dict:
        keys = node.keys()
        if "#text" in keys:
            tup = (
                node["#text"],
                node["ref"].get("@target", "NO_TARGET"),
                
            )
            items.append(tup)
        if "item" in keys:
            items += grab_toc_headings(node["item"])
        if "list" in keys:
            items += grab_toc_headings(node["list"])
    if type(node) is list:
        for elem in node:
            items += grab_toc_headings(elem)
    return items

# remove section numbers from heading text like "1.2.3 foo bar section"
def strip_toc_headings(lst):
    return [
        # (heading.split(" ", maxsplit=1)[-1].strip(), ref) for heading, ref in lst
        (re.split(r"\d+", heading)[-1].strip(), ref) for heading, ref in lst
    ]

table_of_contents = book["TEI"]["front"]["div"]
toc_headings = grab_toc_headings(table_of_contents)
clean_toc_headings = strip_toc_headings(toc_headings)

In [6]:
toc_headings[:4]

[('1 introduction', 'seg_1'),
 ('i applied math and machine learning basics', 'NO_TARGET'),
 ('ii deep networks: modern practices', 'NO_TARGET'),
 ('1.1 who should read this book?', 'seg_3')]

In [7]:
clean_toc_headings[:4]

[('introduction', 'seg_1'),
 ('i applied math and machine learning basics', 'NO_TARGET'),
 ('ii deep networks: modern practices', 'NO_TARGET'),
 ('who should read this book?', 'seg_3')]

#### get index entries
For some books, TEI fails to distinguish the *bibliography* section and the *index* section. It just combines papers, citations, and index terms. We filter these out. It also tends to interpret page ranges in bib citations ("pages 177-228") as if they were indexes back into the book text.

We use a simple heuristic of checking the string length of items that TEI identifies as index entries and set a cutoff between the point where the actual index items end and the bibliographic citations begin.

This is not completely effective, because TEI also has trouble with two-column layouts that are common in book indexes. For about 15% of the items, it produces combinations like "Point estimator, 119 Reinforcement learning." This results in several relative long, garbled index items.

For the Deep Learning Book, we just set a heuristic of 65 chars. In this book, it separates index items from bib citations. In other cases, it might also filter out extra-long garbled index items.

In [8]:
# remove bibliography citations that TEI mixed into the index for some reason
# we just use excessive length as the heuristic
# every index item for the Deep Learning Book is under 70 characters
def remove_overlong_items(lst):
    return [
        elem for elem in lst
        if type(elem) is dict and len(elem.get("#text", "")) < 70
    ]

# filter out stuff like "Object detection, 444 Probability distribution"
def filter_combined_items(lst):
    return [
        elem for elem in lst
        if len(re.split(r",\s+\d+", elem.get("#text", ""))) == 1
    ]

# index tuples are ("foo", set(seg_id...))
def grab_index_tuples(lst):
    tuples = []
    for elem in lst:
        elem_name = elem["#text"]
        if "ref" in elem.keys():
            ref = elem["ref"]
            if type(ref) is dict:
                target = ref.get("@target", "NO_TARGET")
                tup = (elem_name, set([target]))
                tuples.append(tup)
            if type(ref) is list:
                targets = set(r.get("@target", "NO_TARGET") for r in ref)
                tup = (elem_name, targets)
                tuples.append(tup)
    return tuples

all_pairs = lambda lst: itertools.permutations(lst, 2)

# for index items that have a "FOO, see BAR"
# we make a dict of all pairs "foo=bar" and "bar=foo"
# all lowercase, for canonical lookups
def grab_index_aliases(lst):
    alias_dict = {}
    sameas_uri = "owl:sameAs"
    for elem in lst:
        elem_name = elem["#text"]
        ref = elem.get("seg", {}).get("ref", {})
        
        if type(ref) is dict and ref.get("@property", "") == sameas_uri:
            equivalents = map(str.lower, [
                elem_name,
                normalize_prop_uri(ref["@resource"]),
            ])
            alias_dict.update(all_pairs(equivalents))

        if type(ref) is list:
            equivalents = map(str.lower, [
                elem_name,
                *(
                    normalize_prop_uri(r["@resource"])
                    for r in ref
                    if r.get("@property", "") == sameas_uri
                ),
            ])
            alias_dict.update(all_pairs(equivalents))
    
    return alias_dict

# add the "FOO, see BAR" terms to the index dict
# with the same segments as BAR
def enrich_with_aliases(index_dict, alias_dict):
    index_dict_copy = dict(index_dict)
    keys = list(index_dict_copy.keys())
    keys_lower = list(map(str.lower, keys))
    for key, alias in alias_dict.items():
        if key not in keys_lower:
            matching_term = next((k for k in keys if k.lower() == alias), None)
            if matching_term != None:
                index_dict_copy[key] = index_dict[matching_term]
    return index_dict_copy

# property URIs look something like https://intextbooks.science.uu.nl/model/XXX/property_name
model_domain = book["TEI"]["teiHeader"]["fileDesc"]["publicationStmt"]["pubPlace"]
# model_id = book["TEI"]["@xml:id"]
model_id = book["TEI"]["teiHeader"]["fileDesc"]["titleStmt"]["title"].rsplit(" ", maxsplit=1)[-1]
model_property_uri_prefix = f"{model_domain}model/{model_id}/"
normalize_prop_uri = lambda s: s.removeprefix(model_property_uri_prefix).replace("_", " ").lower()
# normalize_prop_uri = lambda s: s

index_items = book["TEI"]["back"]["div"]["list"]["item"]
index_items_filtered = filter_combined_items(remove_overlong_items(index_items))
index_tuples = grab_index_tuples(index_items_filtered)
alias_dict = grab_index_aliases(index_items_filtered)
index_dict = dict(index_tuples)
index_dict_all = enrich_with_aliases(index_dict, alias_dict)

In [9]:
# [*itertools.islice(index_dict_all.items(), 20)]

In [10]:
# sorted([len(elem["#text"]) for elem in index_items if type(elem) is dict], reverse=True)[:50]

#### Assembling Book Data into Triples

In [11]:
from collections import defaultdict

segment2indexitem = defaultdict(set)
for item, segs in index_dict_all.items():
    for seg in segs:
        segment2indexitem[seg].add(item.lower())

lower_toc_headings = defaultdict(set)
for heading, seg in clean_toc_headings:
    if seg != "NO_TARGET":
        lower_toc_headings[f"the section about {heading.lower()}"].add(seg)

triples = []
for heading, segs in lower_toc_headings.items():
    for seg in segs:
        for index_item in segment2indexitem[seg]:
            triple = (heading.lower(), "contains", index_item.lower())
            triples.append(triple)


In [12]:
# triples from the book look like this
_tmp = '\t\n'.join(
    map(
        lambda tup: ', '.join(tup),
        triples[-10:],
    )
)
print(f"Num triples from book: {len(triples)}\nSample triples:\n{_tmp}")

Num triples from book: 542
Sample triples:
the section about directed generative nets, contains, ais	
the section about directed generative nets, contains, fully-visible bayes network	
the section about directed generative nets, contains, generative moment matching networks	
the section about directed generative nets, contains, nade	
the section about directed generative nets, contains, annealed importance sampling	
the section about directed generative nets, contains, moment matching	
the section about directed generative nets, contains, generative adversarial networks	
the section about directed generative nets, contains, lapgan	
the section about directed generative nets, contains, dcgan	
the section about directed generative nets, contains, approximate bayesian computation


### Wikidata enrichment
Use the Neo4J query API to look for matching entities in Wikidata.

#### Wikidata Query Notes
See:
- https://www.wikidata.org/wiki/Wikidata:SPARQL_tutorial
- https://www.mediawiki.org/wiki/Wikidata_Query_Service/User_Manual
- https://www.wikidata.org/wiki/Wikidata:SPARQL_query_service/queries

Notes:
- RDF resource description framework (w3c standard)
- OWL web ontology language
- subject-predicate-object
- <http://www.wikidata.org/entity/Q30> x 3 or wd:Q30  wdt:P36  wd:Q61 .
- wdt for truthy, props have a ranking of current truthiness
- subj and prop are uri's, value not necessarily

In [13]:
from operator import itemgetter
from urllib.request import Request, urlopen
from urllib.parse import urlencode
import json
import time

wikidata_url = "https://query.wikidata.org/sparql"

# limit for sparql queries
limit_results = 30

headers = {
    "User-Agent": "Mozilla/5.0",
    "Accept": "application/sparql-results+json",
}

# try to map string terms to wikidata entity URIs
# this is a very loose mapping: we don't control for ambiguous terms,
# or distinct things with the same name
# we just search for the term by its wikidata label,
# attempting to match on lowercase, CAPS, and Title Case
def get_sparql_for_entity_term_lookup(term):
    return {
        "query": f"""
            PREFIX wikibase: <http://wikiba.se/ontology#>
            
            SELECT DISTINCT ?item ?itemLabel ?itemDescription
            
            WHERE {{
              VALUES ( ?termAsis ?termTitle ?termLower ?termCaps )
              
              {{ (
                  "{term}"@en
                  "{term.title()}"@en
                  "{term.lower()}"@en
                  "{term.upper()}"@en
              ) }}
              
              {{ ?item rdfs:label ?termAsis }}
              UNION
              {{ ?item rdfs:label ?termTitle }}
              UNION
              {{ ?item rdfs:label ?termLower }}
              UNION
              {{ ?item rdfs:label ?termCaps }} .
              
              SERVICE wikibase:label {{
                bd:serviceParam wikibase:language "en" .
              }}
            }}
        """
    }

def get_sparql_for_entity_term_count_lookup(term):
    return {
        "query": f"""
            PREFIX wikibase: <http://wikiba.se/ontology#>
            
            SELECT DISTINCT (COUNT(?item) AS ?count)
            
            WHERE {{
              VALUES ( ?termAsis ?termTitle ?termLower ?termCaps )
              
              {{ (
                  "{term}"@en
                  "{term.title()}"@en
                  "{term.lower()}"@en
                  "{term.upper()}"@en
              ) }}
              
              {{ ?item rdfs:label ?termAsis }}
              UNION
              {{ ?item rdfs:label ?termTitle }}
              UNION
              {{ ?item rdfs:label ?termLower }}
              UNION
              {{ ?item rdfs:label ?termCaps }} .
              
              SERVICE wikibase:label {{
                bd:serviceParam wikibase:language "en" .
              }}
            }}
        """
    }

# For entity E, query all (E Relation Object),
# filtering for English-language object and relation labels.
# E has to be a full URI, e.g. "http://www.wikidata.org/entity/Q3757".
def get_sparql_for_subject_position_triples(entity):
    term, uri, label = itemgetter("term", "uri", "label")(entity)
    return {
        "query": f"""
            PREFIX wikibase: <http://wikiba.se/ontology#>

            SELECT ?subjectLabel ?predicate ?predicateEntityLabel ?object ?objectLabel {{

                # suggested here: https://www.wikidata.org/wiki/Wikidata:SPARQL_query_service/queries#Adding_labels_for_properties
                hint:Query hint:optimizer "None" .

                VALUES ( ?subjectLabel ) {{ (
                    "{label}"
                ) }}

                <{uri}> ?predicate ?object .

                # directClaim associates predicates (WDT's) with entities (WD's), and only entities have labels
                ?predicateEntity wikibase:directClaim ?predicate .

                ?predicateEntity rdfs:label ?predicateEntityLabel .

                ?object rdfs:label ?objectLabel .

                FILTER (
                    lang(?predicateEntityLabel) = "en" &&
                    lang(?objectLabel) = "en" 
                )

            }}
            LIMIT {limit_results}
        """
    }

# For entity E, query all (Subject Relation Entity),
# filtering for English-language subject and relation labels,
# and EXCLUDING relation P921 "main subject,"
# which returns results that are too numerous and too specific.
# E has to be a full URI, e.g. "http://www.wikidata.org/entity/Q3757".
def get_sparql_for_object_position_triples(entity):
    term, uri, label = itemgetter("term", "uri", "label")(entity)
    return {
        "query": f"""
            PREFIX wikibase: <http://wikiba.se/ontology#>

            SELECT ?subject ?subjectLabel ?predicate ?predicateEntityLabel ?objectLabel {{

                # suggested here: https://www.wikidata.org/wiki/Wikidata:SPARQL_query_service/queries#Adding_labels_for_properties
                hint:Query hint:optimizer "None" .

                VALUES ( ?objectLabel ) {{ (
                    "{label}"
                ) }}

                ?subject ?predicate <{uri}> .

                # directClaim associates predicates (WDT's) with entities (WD's), and only entities have labels
                ?predicateEntity wikibase:directClaim ?predicate .

                ?predicateEntity rdfs:label ?predicateEntityLabel .

                ?subject rdfs:label ?subjectLabel .

                FILTER (
                    lang(?predicateEntityLabel) = "en" &&
                    lang(?subjectLabel) = "en"
                )

                MINUS {{
                    # not "X main_subject Y" there are too many overly specific ones  
                    ?subject wdt:P921 <{uri}> .
                }}

            }}
            LIMIT {limit_results}
        """
    }


def query_wikidata(sparql):
    body = urlencode(sparql).encode()

    def make_request():
        req = Request(url=wikidata_url, headers=headers, data=body)
        return urlopen(req)

    # respect being throttled by wikidata
    while (response := make_request()) and response.getcode() == 429:
        print(f"Got 429 Too Many Requests when querying for [{query}], backing off!")
        delay = response.getheader("Retry-After") or response.getheader("retry-after")
        delay_secs = 1.05 * float(delay)
        print(f"Wikidata tells us to try again after [{delay}] seconds, sleeping...")
        time.sleep(delay_secs)

    response_json = json.load(response)
    results = response_json["results"]["bindings"]
    
    # if len(results) == 0:
    #     print(f"Got 0 results for query {sparql['query']}!")

    return results


def get_subject_position_triples_for_entity(entity):
    sparql = get_sparql_for_subject_position_triples(entity)
    results = query_wikidata(sparql)
    return [
        (
            r["subjectLabel"]["value"].lower(),
            r["predicateEntityLabel"]["value"].lower(),
            r["objectLabel"]["value"].lower(),
        ) for r in results
    ]

def get_object_position_triples_for_entity(entity):
    sparql = get_sparql_for_object_position_triples(entity)
    results = query_wikidata(sparql)
    return [
        (
            r["subjectLabel"]["value"].lower(),
            r["predicateEntityLabel"]["value"].lower(),
            r["objectLabel"]["value"].lower(),
        ) for r in results
    ]

def get_entities_from_term(term):
    sparql = get_sparql_for_entity_term_lookup(term)
    results = query_wikidata(sparql)
    return [
        {
            "term": term,
            "uri": result["item"]["value"],
            "label": result["itemLabel"]["value"],
            "description": result.get("itemDescription", {}).get("value", "NO_DESCRIPTION")
        }
        for result in results
    ]

def get_all_triples_for_term(term):
    entities = get_entities_from_term(term)
    return [
        *[
            triple
                for e in entities
                for triple in get_subject_position_triples_for_entity(e)
        ],
        *[
            triple
                for e in entities
                for triple in get_object_position_triples_for_entity(e)
        ],
    ]

def count_entities_for_term(term):
    sparql = get_sparql_for_entity_term_count_lookup(term)
    res = query_wikidata(sparql)
    return int(res[0]["count"]["value"])


In [14]:
get_all_triples_for_term("Ancestral sampling")
#get_object_position_triples_for_entity(get_entities_from_term("autoencoder")[0])

[]

In [15]:
get_object_position_triples_for_entity(get_entities_from_term("java")[0])

[('lou sebert', 'place of birth', 'java'),
 ('java', 'different from', 'java'),
 ('java', 'different from', 'java')]

In [16]:
count_entities_for_term("autoencoder")

2

In [17]:
# sorted([(k, count_entities_for_term(k)) for k in index_dict_all.keys()], key=lambda tup: tup[0])

### Putting the Dataset Together
We have the triples of (TOC section heading, "contains", index term). We now enrich this set of triples by looking up index terms in Wikidata. We add triples of (Entity, Relation, index term) and (index term, Relation, Entity) with an upper limit of 50 added per item per position.

In [18]:
terms = set()

for subj, verb, obj in triples:
    terms.add(subj)
    terms.add(obj)

terms = sorted(terms)
print(f"Num terms: {len(terms)}")

Num terms: 505


In [21]:
import pickle 

with open("wiki_triples.pickle", "wb") as f:
    pickle.dump(wiki_triples, f)

In [24]:
with open("wiki_triples.pickle", "rb") as f:
    print(pickle.load(f)[-20:])

[('information entropy', 'wikidata property example', 'normal distribution'), ('numerical differentiation', 'subclass of', 'differentiation'), ('numerical differentiation', 'subclass of', 'numerical algorithm'), ('numerical differentiation', 'maintained by wikiproject', 'wikiproject mathematics'), ('partition function', 'maintained by wikiproject', 'wikiproject mathematics'), ('partition function', 'maintained by wikiproject', 'wikiproject mathematics'), ('partition function', 'subclass of', 'dimensionless quantity'), ('partition function', 'instance of', 'integer sequence'), ('partition function', 'described by source', 'encyclopædia britannica 11th edition'), ('partition function', 'instance of', 'integer-valued function'), ('partition function', 'maintained by wikiproject', 'wikiproject mathematics'), ('perceptron', 'instance of', 'algorithm'), ('perceptron', 'discoverer or inventor', 'frank rosenblatt'), ('perceptron', 'subclass of', 'feedforward neural network'), ('perceptron', 'd

In [27]:
import os
os.path.isfile("wiki_triples.pickle")

True

In [31]:
import pickle
import os

cap_per_term = 50
dump_file = "wiki_triples.pickle"

def load_triples():
    with open(dump_file, "rb") as f:
        data = pickle.load(f)
    return data

def slowly_scrape_triples_from_wiki(terms):
    for i, term in enumerate(terms):
        capped = False
        if term in ("adam", "independence"):
            continue
        extras = get_all_triples_for_term(term)
        if len(extras) > cap_per_term:
            extras = extras[:cap_per_term]
            capped = True
        print(f"[{i:-3d}/{len(terms)}] Got {len(extras)} triples for term {term}. {'(CAPPED)' if capped else ''}")
        wiki_triples.extend(extras)
    
    with open(dump_file, "wb") as f:
        print(f"Saving data to {dump_file}.")
        pickle.dump(wiki_triples, f)
    
    return wiki_triples

if os.path.isfile(dump_file):
    print("Wiki triples file present, loading from file.")
    wiki_triples = load_triples()
else:
    print("Wiki triples pickle file absent, starting scrape.")
    wiki_triples = slowly_scrape_triples_from_wiki()

print(f"Total Wiki triples: {len(wiki_triples)}")
wiki_triples[:10]

Wiki triples file present, loading from file.
Total Wiki triples: 3666


[('accuracy', 'instance of', 'music track with vocals'),
 ('accuracy', 'performer', 'the cure'),
 ('accuracy', 'genre', 'post-punk'),
 ('accuracy', 'different from', 'accuracy'),
 ('accuracy', 'producer', 'chris parry'),
 ('accuracy', 'recording or performance of', 'accuracy'),
 ('accuracy', 'distribution format', 'music streaming'),
 ('accuracy', 'recorded at studio or venue', 'morgan studios'),
 ('accuracy', 'contributor to the creative work or subject', 'robert smith'),
 ('accuracy',
  'contributor to the creative work or subject',
  'michael dempsey')]