In [None]:
# Enable IPython autoreload so edits to imported project modules (e.g. llm_ol)
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Set OMP_NUM_THREADS in the kernel environment before the heavy imports below;
# presumably to bound OpenMP threads used during model inference — confirm.
%env OMP_NUM_THREADS=16

In [None]:
import random
import json
from collections import defaultdict

import dotenv
from transformers import pipeline

from llm_ol.dataset import data_model

dotenv.load_dotenv()

In [None]:
# Model weights and tokenizer are both loaded from the same hub checkpoint.
_REBEL_CHECKPOINT = "Babelscape/rebel-large"
triplet_extractor = pipeline(
    task="text2text-generation",
    model=_REBEL_CHECKPOINT,
    tokenizer=_REBEL_CHECKPOINT,
)


# Function to parse the generated text and extract the triplets
def extract_triplets(text):
    triplets = []
    relation, subject, relation, object_ = "", "", "", ""
    text = text.strip()
    current = "x"
    for token in (
        text.replace("<s>", "").replace("<pad>", "").replace("</s>", "").split()
    ):
        if token == "<triplet>":
            current = "t"
            if relation != "":
                triplets.append(
                    {
                        "head": subject.strip(),
                        "type": relation.strip(),
                        "tail": object_.strip(),
                    }
                )
                relation = ""
            subject = ""
        elif token == "<subj>":
            current = "s"
            if relation != "":
                triplets.append(
                    {
                        "head": subject.strip(),
                        "type": relation.strip(),
                        "tail": object_.strip(),
                    }
                )
            object_ = ""
        elif token == "<obj>":
            current = "o"
            relation = ""
        else:
            if current == "t":
                subject += " " + token
            elif current == "s":
                object_ += " " + token
            elif current == "o":
                relation += " " + token
    if subject != "" and relation != "" and object_ != "":
        triplets.append(
            {"head": subject.strip(), "type": relation.strip(), "tail": object_.strip()}
        )
    return triplets

In [None]:
G = data_model.load_graph("out/data/wikipedia/v2/train_eval_split/train_graph.json")

# Flatten every node's "pages" list into a single pool to sample abstracts from.
pages = [
    page_record
    for node_id in G.nodes
    for page_record in G.nodes[node_id]["pages"]
]

In [None]:
abstract = random.choice(pages)["abstract"]

# Ask the pipeline for raw token ids and decode them ourselves, so that the
# special tokens (<triplet>, <subj>, <obj>) extract_triplets parses on are kept.
generation = triplet_extractor(abstract, return_tensors=True, return_text=False)
token_ids = generation[0]["generated_token_ids"]
extracted_text = triplet_extractor.tokenizer.batch_decode([token_ids])

decoded = extracted_text[0]
print(abstract)
print(decoded)
extracted_triplets = extract_triplets(decoded)
print(extracted_triplets)

In [None]:
# Draw two abstracts (in order: abstract1 first, then abstract2) and run the
# pipeline on them as a batch; the bare expression is the cell's displayed output.
abstract1, abstract2 = (random.choice(pages)["abstract"] for _ in range(2))

triplet_extractor([abstract1, abstract2])

In [None]:
# For each relation type, collect the distinct (tail, head) pairs extracted
# across all pages, then print relation types ordered by how many distinct
# pairs they have (most frequent first).
counts = defaultdict(set)
# Open explicitly as UTF-8: the default is the platform/locale encoding, which
# can misread non-ASCII entity names in the extracted text.
with open(
    "out/experiments/rebel/v1/test/categorised_pages.jsonl", encoding="utf-8"
) as f:
    for line in f:
        if not line.strip():
            continue  # tolerate blank lines instead of crashing json.loads
        page = json.loads(line)
        for triplet in page["triplets"]:
            counts[triplet["type"]].add((triplet["tail"], triplet["head"]))

print(
    sorted(
        ((relation_type, len(pairs)) for relation_type, pairs in counts.items()),
        key=lambda item: item[1],
        reverse=True,
    )
)