In [None]:
from pathlib import Path

import torch
import numpy as np
from transformers import (
    TokenClassificationPipeline,
    AutoModelForTokenClassification,
    AutoTokenizer,
)
from transformers.pipelines import AggregationStrategy
import spacy
from spacy import displacy

from llm_ol.dataset import data_model

# Tensors created without an explicit device go to the first GPU when CUDA is
# available, and to the CPU otherwise.
torch.set_default_device("cpu" if not torch.cuda.is_available() else "cuda:0")

In [None]:
# Define keyphrase extraction pipeline
class KeyphraseExtractionPipeline(TokenClassificationPipeline):
    """Token-classification pipeline that returns the distinct keyphrases of a text.

    The model and tokenizer are both loaded from ``model_name``; predictions are
    aggregated with ``AggregationStrategy.SIMPLE`` and the pipeline runs on GPU 0
    when CUDA is available (CPU otherwise).
    """

    def __init__(self, model_name):
        use_gpu = torch.cuda.is_available()
        super().__init__(
            model=AutoModelForTokenClassification.from_pretrained(model_name),
            tokenizer=AutoTokenizer.from_pretrained(model_name),
            aggregation_strategy=AggregationStrategy.SIMPLE,
            device=0 if use_gpu else -1,
        )

    def postprocess(self, *args, **kwargs):
        """Reduce the aggregated entities to a sorted array of unique, stripped words."""
        entities = super().postprocess(*args, **kwargs)
        words = [entity.get("word").strip() for entity in entities]
        # np.unique both de-duplicates and sorts.
        return np.unique(words)


# Load the keyphrase-extraction pipeline from a Hugging Face Hub checkpoint.
model_name = "ml6team/keyphrase-extraction-kbir-inspec"
extractor = KeyphraseExtractionPipeline(model_name=model_name)

In [None]:
file_path = Path("out/data/wikipedia/v1/full/full_graph.json")

G = data_model.load_graph(file_path, depth=1)
# Keep the first occurrence of each page title; `titles[k]` and
# `abstracts[k]` always describe the same page.
seen = set()
titles, abstracts = [], []
for _, data in G.nodes(data=True):
    for page in data["pages"]:
        if page["title"] not in seen:
            seen.add(page["title"])
            titles.append(page["title"])
            abstracts.append(page["abstract"])

In [None]:
# Run the extractor on every abstract and pool the lowercased keyphrases.
keyphrases = {
    keyphrase.lower()
    for abstract in abstracts
    for keyphrase in extractor(abstract)
}

In [None]:
# Write one keyphrase per line.
# - Sorted: iterating a set of strings yields a hash-seed-dependent order, so an
#   unsorted dump differs from run to run.
# - UTF-8 pinned: keyphrases may contain non-ASCII characters, which must not
#   depend on the platform's default encoding.
with open("out/data/wikipedia/v1/full/keyphrases.txt", "w", encoding="utf-8") as f:
    for keyphrase in sorted(keyphrases):
        f.write(keyphrase + "\n")

In [None]:
import regex as re


class HearstPattern:
    """Base class for Hearst-style hypernym patterns over "NP_"-tagged text.

    Subclasses compile ``self.pattern`` in ``__init__`` with ``re.compile``
    (the notebook imports the third-party ``regex`` module as ``re``; its match
    objects provide ``captures``). The pattern must define a named group
    ``src`` for the hypernym and a repeatable named group ``tgt`` for the
    hyponyms. :meth:`findall` yields one ``(src, tgt)`` pair per captured
    hyponym.
    """

    # A single NP token: "NP_" followed by word characters, apostrophes or hyphens.
    _NP = r"NP_[\w'-]+"
    # One or more NP tokens separated by spaces, commas, "and" or "or"; every
    # NP in the list is recorded by the repeatable `tgt` group.
    _NP_LIST = rf"((?P<tgt>{_NP})( |,|and|or)*)+"

    # Compiled pattern; subclasses assign it in __init__.
    pattern = None

    def findall(self, text: str):
        """Return an iterator of ``(hypernym, hyponym)`` string pairs found in ``text``.

        Raises:
            NotImplementedError: when the class defines no pattern (the base class).
        """
        if self.pattern is None:
            raise NotImplementedError
        return self._iter_pairs(text)

    def _iter_pairs(self, text: str):
        # One match can capture several hyponyms ("NP_a, NP_b and NP_c"); they
        # all share the match's single hypernym.
        for match in self.pattern.finditer(text):
            src = match.group("src")
            for tgt in match.captures("tgt"):
                yield src, tgt


class SuchAsPattern(HearstPattern):
    """``NP_x such as NP_y, NP_z and NP_w`` -> (NP_x, NP_y), (NP_x, NP_z), (NP_x, NP_w)."""

    def __init__(self):
        self.pattern = re.compile(rf"(?P<src>{self._NP}),? such as {self._NP_LIST}")


class AsPattern(HearstPattern):
    """``NP_x as NP_y, NP_z`` -> (NP_x, NP_y), (NP_x, NP_z).

    NOTE(review): the classic Hearst form is "such NP as NP". No separate
    "such" is required here, presumably because noun-chunk merging puts it
    inside the hypernym token (e.g. "NP_such_authors") — confirm.
    """

    def __init__(self):
        self.pattern = re.compile(rf"(?P<src>{self._NP}),? as {self._NP_LIST}")


class AndOtherPattern(HearstPattern):
    """``NP_y, NP_z and other NP_x`` (or "or other") -> (NP_x, NP_y), (NP_x, NP_z)."""

    def __init__(self):
        self.pattern = re.compile(rf"{self._NP_LIST}(and|or) other (?P<src>{self._NP})")


class IncludePattern(HearstPattern):
    """``NP_x include NP_y, NP_z`` -> (NP_x, NP_y), (NP_x, NP_z)."""

    def __init__(self):
        self.pattern = re.compile(rf"(?P<src>{self._NP}),? include {self._NP_LIST}")


class EspeciallyPattern(HearstPattern):
    """``NP_x especially NP_y, NP_z`` -> (NP_x, NP_y), (NP_x, NP_z)."""

    def __init__(self):
        self.pattern = re.compile(rf"(?P<src>{self._NP}),? especially {self._NP_LIST}")


# Catalogue of Hearst-pattern regexes written against an older "NP_\w+" token
# format (no apostrophes/hyphens). The first five entries correspond to the
# HearstPattern subclasses above; the commented-out ones are candidate patterns.
# NOTE(review): nothing in this notebook appears to read this list — confirm
# before relying on it.
HEARST_PATTERNS = [
    r"(NP_\w+,? such as (NP_\w+ ?(, )?(and |or )?)+)",
    r"(such NP_\w+,? as (NP_\w+ ?(, )?(and |or )?)+)",
    r"((NP_\w+ ?(, )?)+(and |or )?other NP_\w+)",
    r"(NP_\w+,? include (NP_\w+ ?(, )?(and |or )?)+)",
    r"(NP_\w+,? especially (NP_\w+ ?(, )?(and |or )?)+)",
    # r"((NP_\w+ ?(, )?)+(and |or )?any other NP_\w+)",
    # r"((NP_\w+ ?(, )?)+(and |or )?some other NP_\w+)",
    # r"((NP_\w+ ?(, )?)+(and |or )?be a NP_\w+)",
    # r"(NP_\w+,? like (NP_\w+ ?,? (and |or )?)*NP_\w)",
    # r"such (NP_\w+,? as (NP_\w+ ?,? (and |or )?)+)",
    # r"((NP_\w+ ?(, )?)+(and |or )?like other NP_\w+)",
    # r"((NP_\w+ ?(, )?)+(and |or )?one of the NP_\w+)",
    # r"((NP_\w+ ?(, )?)+(and |or )?one of these NP_\w+)",
    # r"((NP_\w+ ?(, )?)+(and |or )?one of those NP_\w+)",
    # r"example of (NP_\w+,? be (NP_\w+ ?,? (and |or )?)+)",
    # r"((NP_\w+ ?(, )?)+(and |or )?be example of NP_\w+)",
    # r"(NP_\w+,? for example,? (NP_\w+ ?(, )?(and |or )?)+)",
    # r"((NP_\w+ ?(, )?)+(and |or )?which be call NP_\w+)",
    # r"((NP_\w+ ?(, )?)+(and |or )?which be name NP_\w+)",
    # r"(NP_\w+,? mainly (NP_\w+ ?,? (and |or )?)+)",
    # r"(NP_\w+,? mostly (NP_\w+ ?,? (and |or )?)+)",
    # r"(NP_\w+,? notably (NP_\w+ ?,? (and |or )?)+)",
    # r"(NP_\w+,? particularly (NP_\w+ ?,? (and |or )?)+)",
    # r"(NP_\w+,? principally (NP_\w+ ?,? (and |or )?)+)",
    # r"(NP_\w+,? in particular (NP_\w+ ?,? (and |or )?)+)",
    # r"(NP_\w+,? except (NP_\w+ ?,? (and |or )?)+)",
    # r"(NP_\w+,? other than (NP_\w+ ?,? (and |or )?)+)",
    # r"(NP_\w+,? e.g.,? (NP_\w+ ?,? (and |or )?)+)",
    # r"(NP_\w+ \( (e.g.|i.e.),? (NP_\w+ ?,? (and |or )?)+(\. )?\))",
    # r"(NP_\w+,? i.e.,? (NP_\w+ ?,? (and |or )?)+)",
    # r"((NP_\w+ ?(, )?)+(and|or)? a kind of NP_\w+)",
    # r"((NP_\w+ ?(, )?)+(and|or)? kind of NP_\w+)",
    # r"((NP_\w+ ?(, )?)+(and|or)? form of NP_\w+)",
    # r"((NP_\w+ ?(, )?)+(and |or )?which look like NP_\w+)",
    # r"((NP_\w+ ?(, )?)+(and |or )?which sound like NP_\w+)",
    # r"(NP_\w+,? which be similar to (NP_\w+ ?,? (and |or )?)+)",
    # r"(NP_\w+,? example of this be (NP_\w+ ?,? (and |or )?)+)",
    # r"(NP_\w+,? type (NP_\w+ ?,? (and |or )?)+)",
    # r"((NP_\w+ ?(, )?)+(and |or )? NP_\w+ type)",
    # r"(NP_\w+,? whether (NP_\w+ ?,? (and |or )?)+)",
    # r"(compare (NP_\w+ ?(, )?)+(and |or )?with NP_\w+)",
    # r"(NP_\w+,? compare to (NP_\w+ ?,? (and |or )?)+)",
    # # r"(NP_\w+,? among -PRON- (NP_\w+ ?,? (and |or )?)+)",
    # # r"((NP_\w+ ?(, )?)+(and |or )?as NP_\w+)",  <-- bad
    # r"(NP_\w+,?  (NP_\w+ ?,? (and |or )?)+ for instance)",
    # r"((NP_\w+ ?(, )?)+(and|or)? sort of NP_\w+)",
    # r"(NP_\w+,? which may include (NP_\w+ ?(, )?(and |or )?)+)",
]


# Small English pipeline without NER. The appended built-in "merge_noun_chunks"
# component collapses each noun chunk into a single token; `extract_hyponyms`
# later rewrites such tokens as one "NP_..." token (spaces -> underscores).
# (A dead, commented-out EntityRuler/Matcher experiment that referenced an
# un-imported `Matcher` was removed from this cell.)
nlp = spacy.load("en_core_web_sm", disable=["ner"])
nlp.add_pipe("merge_noun_chunks")

In [None]:
from typing import Iterator, Tuple, Union

from spacy.tokens import Doc, Span
from spacy.symbols import NOUN, PROPN, PRON
from spacy.errors import Errors
from spacy.util import filter_spans


def _chunk_start(doc: Doc, word) -> int:
    """Return the doc index at which the chunk headed by ``word`` should start.

    Begins at the left edge of ``word``'s subtree and skips leading tokens whose
    dependency label is "det", "poss" or "case", never moving past ``word``.
    """
    start = word.left_edge.i
    while start < word.i and doc[start].dep_ in ("det", "poss", "case"):
        start += 1
    return start


def noun_chunks(doclike: Union[Doc, Span]):
    """
    Detect base noun phrases from a dependency parse. Works on both Doc and Span.

    A NOUN/PROPN/PRON token heads a chunk when its dependency label is one of
    the argument labels below, or when it is a ``conj`` coordinated to such a
    token. Each chunk spans from the token's left edge to the token itself,
    minus leading "det"/"poss"/"case" tokens. Chunks never overlap.

    Yields:
        ``Span`` objects over ``doclike.doc``, labelled "NP".

    Raises:
        ValueError: if the document has no dependency parse.
    """
    labels = [
        "oprd",
        "nsubj",
        "dobj",
        "nsubjpass",
        "pcomp",
        "pobj",
        "dative",
        "appos",
        "attr",
        "ROOT",
    ]
    doc = doclike.doc  # Ensure works on both Doc and Span.
    if not doc.has_annotation("DEP"):
        raise ValueError(Errors.E029)
    np_deps = {doc.vocab.strings.add(label) for label in labels}
    conj = doc.vocab.strings.add("conj")
    np_label = doc.vocab.strings.add("NP")
    prev_end = -1
    for word in doclike:
        if word.pos not in (NOUN, PROPN, PRON):
            continue
        # Prevent nested chunks from being produced
        if word.left_edge.i <= prev_end:
            continue
        if word.dep in np_deps:
            is_chunk = True
        elif word.dep == conj:
            head = word.head
            while head.dep == conj and head.head.i < head.i:
                head = head.head
            # If the head is an NP, and we're coordinated to it, we're an NP
            is_chunk = head.dep in np_deps
        else:
            is_chunk = False
        if is_chunk:
            prev_end = word.i
            # Both branches share the same start computation; the conjunct
            # branch previously computed it and then ignored it, yielding the
            # unstripped left edge (e.g. "the cats" instead of "cats").
            yield Span(doc, _chunk_start(doc, word), word.i + 1, label=np_label)


def merge_noun_chunks(doc: Doc) -> Doc:
    """Merge every chunk produced by ``noun_chunks`` into a single token.

    Each merged token receives the tag and dependency label of its chunk's
    root. The retokenization is applied to ``doc`` itself, which is returned.
    """
    with doc.retokenize() as retokenizer:
        chunks = list(noun_chunks(doc))
        for chunk in chunks:
            root = chunk.root
            retokenizer.merge(chunk, attrs={"tag": root.tag, "dep": root.dep})  # type: ignore[arg-type]
    return doc

In [None]:
# `nlp` already ends with spaCy's built-in "merge_noun_chunks" component, so
# its output is pre-merged and re-merging would mix both chunkers. Disable that
# component while parsing so the render shows only the custom merge above.
with nlp.select_pipes(disable=["merge_noun_chunks"]):
    doc = nlp(abstracts[1])
doc = merge_noun_chunks(doc)
displacy.render(doc, style="dep", jupyter=True, options={"compact": True})

In [None]:
import tqdm

# text = "Programming languages such as Python, Java, and C++ are popular."

i = 2  # NOTE(review): not referenced by the code below — possibly a leftover.

# One instance of each Hearst pattern; `extract_hyponyms` applies them in order.
patterns = [
    pattern_cls()
    for pattern_cls in (
        SuchAsPattern,
        AsPattern,
        AndOtherPattern,
        IncludePattern,
        EspeciallyPattern,
    )
]


def extract_hyponyms(text: str):
    """Find (hypernym, hyponym, pattern name) triples in ``text``.

    ``text`` is parsed with ``nlp``. Every NOUN/PROPN token becomes
    ``"NP_" + token text`` with spaces turned into underscores; every other
    token is replaced by its lemma. Original whitespace is kept. Each pattern
    in ``patterns`` is then run over that rewritten string.

    Returns:
        A set of ``(src, tgt, pattern_class_name)`` tuples, where ``src`` and
        ``tgt`` are the ``NP_...`` tokens captured by the pattern.
    """
    doc = nlp(text)

    def render(token) -> str:
        if token.pos_ in ("NOUN", "PROPN"):
            return "NP_" + token.text.replace(" ", "_") + token.whitespace_
        return token.lemma_ + token.whitespace_

    tagged_text = "".join(render(token) for token in doc)

    return {
        (src, tgt, type(pattern).__name__)
        for pattern in patterns
        for src, tgt in pattern.findall(tagged_text)
    }


# Pool the (hypernym, hyponym, pattern-name) triples found in every abstract.
hyponyms = {
    triple
    for abstract in tqdm.tqdm(abstracts)
    for triple in extract_hyponyms(abstract)
}

In [None]:
import random

# A second, unmodified pipeline (NER enabled, no noun-chunk merging) used by
# `normalize` to re-parse extracted phrases token by token.
nlp_normalized = spacy.load(name="en_core_web_sm")


def denormalize(text: str):
    """Turn an ``NP_``-tagged token back into a plain phrase.

    Removes the single leading "NP_" marker and maps the underscores that
    `extract_hyponyms` substituted for spaces back to spaces, e.g.
    "NP_programming_languages" -> "programming languages".
    """
    # Strip only the leading marker. Replacing every "NP " after the
    # underscore->space conversion also deleted "NP " from inside phrases
    # (e.g. "NP_NP_complete_problems" became "complete problems").
    return text.removeprefix("NP_").replace("_", " ")


# Every distinct phrase that appears as either end of an extracted relation.
categories = list(
    {denormalize(src) for src, _, _ in hyponyms}.union(
        {denormalize(tgt) for _, tgt, _ in hyponyms}
    )
)

In [None]:
def normalize(np_tag: str):
    """Canonicalize an ``NP_``-tagged token into a lowercase phrase.

    The tag is turned back into plain text (leading "NP_" removed, underscores
    -> spaces) and re-parsed with ``nlp_normalized``. Then:
    - tokens labelled "det" or "poss" are dropped;
    - plural nouns (tags NNS/NNPS) are replaced by their lemma;
    - everything is lowercased.
    """
    # Strip only the leading marker; replacing every "NP " would also delete
    # that text from inside phrases such as "NP complete problems".
    text = np_tag.removeprefix("NP_").replace("_", " ")
    pieces = []
    for token in nlp_normalized(text):
        if token.dep_ in ("det", "poss"):
            continue
        word = token.lemma_ if token.tag_ in ("NNS", "NNPS") else token.text
        pieces.append(word.lower() + token.whitespace_)
    # The last kept token can carry trailing whitespace (e.g. a tag ending in
    # "_", or a dropped final det/poss), which would make equal phrases differ.
    return "".join(pieces).strip()

In [None]:
# Normalize both ends of every extracted relation; the pattern name is dropped.
relations = [
    (normalize(hypernym), normalize(hyponym))
    for hypernym, hyponym, _pattern in tqdm.tqdm(hyponyms)
]

In [None]:
from collections import defaultdict

# Summary counts over the normalized (hypernym, hyponym) relations.
src_count = len({src for src, _ in relations})
tgt_count = len({tgt for _, tgt in relations})
rel_count = len(set(relations))  # relations are already (src, tgt) pairs
print(f"Unique source: {src_count}")
print(f"Unique target: {tgt_count}")
print(f"Unique relations: {rel_count}")

# Hypernym -> set of hyponyms. `defaultdict(set)` instead of a lambda factory
# keeps `tree` picklable (e.g. for `pickle.dump` / `%store`).
tree = defaultdict(set)
for src, tgt in relations:
    tree[src].add(tgt)

In [None]:
# Debug: show how `nlp` tags "including" (text, POS, dependency, lemma) — relevant
# because IncludePattern matches on the lemmatized form.
for token in nlp("including"):
    print(f"{token.text} {token.pos_} {token.dep_} {token.lemma_}")