In [3]:
#!/usr/bin/env python3
"""
extract_relations.py

Read the JSON outputs from pdf_ingestion (data/text_json/),
run SciSpaCy + rule-based NER to pull out METHODS, DATASETS,
and METRIC mentions, then do simple co-occurrence relation
extraction. Write all triples to nlp/relations.json.
"""

import os
import json
import re
from tqdm import tqdm

import en_core_sci_sm
from spacy.pipeline import EntityRuler

# --- Config ----------------------------------------------------------------

# Where your step-2 JSON lives:
TEXT_JSON_DIR = os.path.expanduser(r"C:\Users\offic\AGENT\data\text_json")
# Where to write relation triples:
OUTPUT_DIR    = os.path.expanduser(r"C:\Users\offic\AGENT\data\nlp")
RELATIONS_OUT = os.path.join(OUTPUT_DIR, "relations.json")

os.makedirs(OUTPUT_DIR, exist_ok=True)

# --- Pipeline Setup --------------------------------------------------------

# 1) Load the SciSpaCy model directly
nlp = en_core_sci_sm.load()

# 2) Define your patterns
patterns = [
    {"label": "METHOD",
     "pattern": [{"TEXT": {"REGEX": "^(ResNet|ConvNeXt|ViT|T5|BERT)[0-9A-Za-z-]*$"}}]},
    {"label": "DATASET",
     "pattern": [{"TEXT": {"REGEX": "^(CIFAR|ImageNet|MNIST|FC100|CUB)-?[0-9]*$"}}]},
    {"label": "METRIC",
     "pattern": [{"TEXT": {"REGEX": "^(accuracy|F1|precision|recall|ROC-AUC|log\\.?loss)$",
                           "LOWER": True}}]},
]

# 3) Add *your* EntityRuler via the factory name, passing patterns in config
ruler = nlp.add_pipe(
    "entity_ruler",
    name="scientific_ruler",
    before="ner",
    config={"overwrite_ents": True}
)
ruler.add_patterns(patterns)

# 4) Verify
print(f"Loaded model: {nlp.meta['name']}")
print("Pipeline components:", nlp.pipe_names)


# --- Extraction Functions --------------------------------------------------

def extract_entities(text):
    """Return list of {'text', 'label'} for METHOD, DATASET, METRIC."""
    doc = nlp(text)
    return [{"text": ent.text, "label": ent.label_}
            for ent in doc.ents
            if ent.label_ in {"METHOD", "DATASET", "METRIC"}]

def extract_relations(text, entities):
    """
    For each sentence, if METHOD + DATASET co-occur,
    optionally attach any numeric METRIC found in that sentence.
    """
    doc = nlp(text)
    triples = []
    for sent in doc.sents:
        sent_text = sent.text
        sent_ents = [e for e in entities if e["text"] in sent_text]
        methods  = [e["text"] for e in sent_ents if e["label"] == "METHOD"]
        datasets = [e["text"] for e in sent_ents if e["label"] == "DATASET"]
        metrics  = re.findall(r"\b\d+(?:\.\d+)?\s?%|\blog\.?\s?loss\b",
                              sent_text, flags=re.I)
        for m in methods:
            for d in datasets:
                if metrics:
                    for met in metrics:
                        triples.append({
                            "method":   m,
                            "dataset":  d,
                            "metric":   met,
                            "sentence": sent_text
                        })
                else:
                    triples.append({
                        "method":   m,
                        "dataset":  d,
                        "sentence": sent_text
                    })
    return triples

# --- Main Loop -------------------------------------------------------------

def main():
    all_relations = []
    files = [f for f in os.listdir(TEXT_JSON_DIR) if f.endswith(".json")]
    if not files:
        print("No JSON files found in", TEXT_JSON_DIR)
        return

    for fname in tqdm(files, desc="Processing Papers"):
        paper_id = fname.rsplit(".",1)[0]
        path = os.path.join(TEXT_JSON_DIR, fname)
        paper = json.load(open(path, encoding="utf-8"))
        text  = paper.get("full_text", "")
        ents  = extract_entities(text)
        rels  = extract_relations(text, ents)
        for r in rels:
            r["paper_id"] = paper_id
        all_relations.extend(rels)

    # Write out
    with open(RELATIONS_OUT, "w", encoding="utf-8") as f:
        json.dump(all_relations, f, indent=2, ensure_ascii=False)

    print(f"\n✅ Wrote {len(all_relations)} relation triples to {RELATIONS_OUT}")

if __name__ == "__main__":
    main()


Loaded model: core_sci_sm
Pipeline components: ['tok2vec', 'tagger', 'attribute_ruler', 'lemmatizer', 'parser', 'scientific_ruler', 'ner']


Processing Papers: 100%|██████████| 67/67 [05:41<00:00,  5.09s/it]



✅ Wrote 325838 relation triples to C:\Users\offic\AGENT\data\nlp\relations.json
