In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import spacy
from datasets import ClassLabel, Dataset, Features, Sequence, Value
from spacy.tokens import Doc, Span
from tqdm import tqdm

from llm_ol.dataset import data_model

In [None]:
# Load the training graph from disk.
G = data_model.load_graph("out/data/wikipedia/v2/train_eval_split/train_graph.json")

# Concept vocabulary: the title of every node in the graph.
concepts = {attrs["title"] for _, attrs in G.nodes(data=True)}

# Pages attached to nodes, de-duplicated by page id (a page listed under
# several nodes is kept once; the last occurrence wins).
pages_by_id = {
    page["id"]: page for _, attrs in G.nodes(data=True) for page in attrs["pages"]
}
pages = list(pages_by_id.values())

In [None]:
# Only lemmas are needed downstream, so the parser and NER are left disabled.
# "tok2vec" must stay enabled: the tagger in en_core_web_sm listens to the
# shared tok2vec component, and the rule-based lemmatizer relies on the tags
# the tagger (via attribute_ruler) produces. Without it the tagger runs on
# empty vectors and the lemmas degrade.
nlp = spacy.load(
    "en_core_web_sm",
    enable=["tok2vec", "tagger", "attribute_ruler", "lemmatizer"],
)

In [None]:
# Run the pipeline over the concept titles, then over the page abstracts,
# using 16 worker processes and showing a progress bar for each pass.
concept_docs = list(tqdm(nlp.pipe(concepts, n_process=16), total=len(concepts)))

abstracts = [page["abstract"] for page in pages]
page_docs = list(tqdm(nlp.pipe(abstracts, n_process=16), total=len(pages)))

In [None]:
# Build a trie over the concepts, keyed by token lemma.
# Each level maps a lemma to the sub-trie of concepts continuing with it;
# the reserved key "" marks the end of a concept and stores its original text.
trie = {}
for concept_doc in concept_docs:
    cursor = trie
    for tok in concept_doc:
        cursor = cursor.setdefault(tok.lemma_, {})
    cursor[""] = concept_doc.text

In [None]:
# Texts of all concepts that were matched at least once (filled by match_concept).
matched_concepts = set()


def match_concept(span: Span, trie) -> Span | None:
    """Return the longest concept that starts at the first token of ``span``.

    The trie is walked one token lemma at a time. Every time a node carrying
    the end-of-concept marker ``""`` is reached the match is remembered, and
    the walk continues so that a longer concept sharing the same prefix
    (e.g. "machine learning" vs. "machine") wins over the shorter one.

    Args:
        span: Tokens to match against; matching is anchored at ``span[0]``.
        trie: Nested dict keyed by lemma. A ``""`` key marks the end of a
            concept and maps to that concept's text.

    Returns:
        ``span[:k]`` covering the longest matched concept, or ``None`` when no
        concept starts at the first token.

    Side effects:
        Adds the text of the matched concept to the module-level
        ``matched_concepts`` set.
    """
    node = trie
    best_len = 0
    best_text = None
    for depth, token in enumerate(span, start=1):
        lemma = token.lemma_
        # "" is reserved as the end-of-concept marker, so it must never be
        # followed as an edge (its value is a string, not a sub-trie).
        if lemma == "" or lemma not in node:
            break
        node = node[lemma]
        if "" in node:
            best_len, best_text = depth, node[""]

    if best_len == 0:
        return None
    matched_concepts.add(best_text)
    return span[:best_len]


def find_concepts(doc: Doc, trie) -> list[Span]:
    """Collect the concept match anchored at each token position of ``doc``.

    For every start index the remainder of the document is handed to
    ``match_concept``; positions without a match are dropped. The returned
    spans may overlap one another.
    """
    candidates = (match_concept(doc[start:], trie) for start in range(len(doc)))
    return [hit for hit in candidates if hit is not None]


# Tag every page abstract with the concept spans found in it
# (one list of spans per page, aligned with page_docs).
page_concepts = [find_concepts(page_doc, trie) for page_doc in tqdm(page_docs)]

print(f"Matched {len(matched_concepts)}/{len(concepts)} concepts")

In [None]:
# Make a token-classification dataset: one row per page, with the page's
# tokens and a parallel list of BIO tags marking the concept spans.
data = {
    "tokens": [],
    "ner_tags": [],
}
for doc, concept_spans in zip(tqdm(page_docs), page_concepts):
    tokens = [token.text for token in doc]
    tags = ["O"] * len(tokens)
    # filter_spans removes overlapping candidates, preferring longer spans,
    # so each token ends up with at most one tag.
    for span in spacy.util.filter_spans(concept_spans):
        tags[span.start] = "B-MISC"
        for i in range(span.start + 1, span.end):
            tags[i] = "I-MISC"
    data["tokens"].append(tokens)
    data["ner_tags"].append(tags)

# Explicit schema: the string tags above are encoded to integer class ids by
# the ClassLabel feature, in the order the names are listed here.
ds = Dataset.from_dict(
    data,
    features=Features(
        {
            "tokens": Sequence(Value("string")),
            "ner_tags": Sequence(ClassLabel(names=["O", "B-MISC", "I-MISC"])),
        }
    ),
)