In [11]:
!pip install transformers wikipedia torch neo4j pandas numpy pyarrow

Collecting pyarrow
  Downloading pyarrow-15.0.1-cp311-cp311-win_amd64.whl.metadata (3.1 kB)
Downloading pyarrow-15.0.1-cp311-cp311-win_amd64.whl (24.8 MB)
   ---------------------------------------- 0.0/24.8 MB ? eta -:--:--
   ---------------------------------------- 0.0/24.8 MB ? eta -:--:--
   ---------------------------------------- 0.0/24.8 MB 435.7 kB/s eta 0:00:57
   ---------------------------------------- 0.1/24.8 MB 544.7 kB/s eta 0:00:46
   ---------------------------------------- 0.2/24.8 MB 952.6 kB/s eta 0:00:26
   ---------------------------------------- 0.3/24.8 MB 1.3 MB/s eta 0:00:19
    --------------------------------------- 0.4/24.8 MB 1.5 MB/s eta 0:00:17
    --------------------------------------- 0.5/24.8 MB 1.7 MB/s eta 0:00:15
    --------------------------------------- 0.5/24.8 MB 1.7 MB/s eta 0:00:15
    --------------------------------------- 0.5/24.8 MB 1.7 MB/s eta 0:00:15
   - -------------------------------------- 1.1/24.8 MB 2.3 MB/s eta 0:00:11
   - -

In [12]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import math
import torch
import wikipedia
import pandas as pd
import re
from neo4j import GraphDatabase

In [13]:
class KB():
    def __init__(self):
        self.entities = {}
        self.relations = []

    def are_relations_equal(self, r1, r2):
        return all(r1[attr] == r2[attr] for attr in ["head", "type", "tail"])

    def exists_relation(self, r1):
        return any(self.are_relations_equal(r1, r2) for r2 in self.relations)

    def merge_relations(self, r1):
        r2 = [r for r in self.relations
              if self.are_relations_equal(r1, r)][0]
        spans_to_add = [span for span in r1["meta"]["spans"]
                        if span not in r2["meta"]["spans"]]
        r2["meta"]["spans"] += spans_to_add

    def get_wikipedia_data(self, candidate_entity):
        try:
            page = wikipedia.page(candidate_entity, auto_suggest=False)
            entity_data = {
                "title": page.title,
                "url": page.url,
                "summary": page.summary
            }
            return entity_data
        except:
            return None

    def add_entity(self, e):
        self.entities[e["title"]] = {k:v for k,v in e.items() if k != "title"}

    def add_relation(self, r):
        # check on wikipedia
        candidate_entities = [r["head"], r["tail"]]
        entities = [self.get_wikipedia_data(ent) for ent in candidate_entities]

        # if one entity does not exist, stop
        if any(ent is None for ent in entities):
            return

        # manage new entities
        for e in entities:
            self.add_entity(e)

        # rename relation entities with their wikipedia titles
        r["head"] = entities[0]["title"]
        r["tail"] = entities[1]["title"]

        # manage new relation
        if not self.exists_relation(r):
            self.relations.append(r)
        else:
            self.merge_relations(r)

    def print(self):
        print("Entities:")
        for e in self.entities.items():
            print(f"  {e}")
        print("Relations:")
        for r in self.relations:
            print(f"  {r}")

In [14]:
class REBELGraph():
    def __init__(self):
        self.tokenizer=AutoTokenizer.from_pretrained("Babelscape/rebel-large")
        self.model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")

    def extract_relations_from_model_output(self,text):
        relations = []
        relation, subject, relation, object_ = '', '', '', ''
        text = text.strip()
        current = 'x'
        text_replaced = text.replace("<s>", "").replace("<pad>", "").replace("</s>", "")
        for token in text_replaced.split():
            if token == "<triplet>":
                current = 't'
                if relation != '':
                    relations.append({
                        'head': subject.strip(),
                        'type': relation.strip(),
                        'tail': object_.strip()
                    })
                    relation = ''
                subject = ''
            elif token == "<subj>":
                current = 's'
                if relation != '':
                    relations.append({
                        'head': subject.strip(),
                        'type': relation.strip(),
                        'tail': object_.strip()
                    })
                object_ = ''
            elif token == "<obj>":
                current = 'o'
                relation = ''
            else:
                if current == 't':
                    subject += ' ' + token
                elif current == 's':
                    object_ += ' ' + token
                elif current == 'o':
                    relation += ' ' + token
        if subject != '' and relation != '' and object_ != '':
            relations.append({
                'head': subject.strip(),
                'type': relation.strip(),
                'tail': object_.strip()
            })
        return relations
    
    def from_text_to_kb(self,text, span_length=128, verbose=False):
        # tokenize whole text
        inputs = self.tokenizer([text],max_length=4092,truncation=True, return_tensors="pt")

        # compute span boundaries
        num_tokens = len(inputs["input_ids"][0])
        if verbose:
            print(f"Input has {num_tokens} tokens")
        num_spans = math.ceil(num_tokens / span_length)
        if verbose:
            print(f"Input has {num_spans} spans")
        overlap = math.ceil((num_spans * span_length - num_tokens) /
                            max(num_spans - 1, 1))
        spans_boundaries = []
        start = 0
        for i in range(num_spans):
            spans_boundaries.append([start + span_length * i,
                                    start + span_length * (i + 1)])
            start -= overlap
        if verbose:
            print(f"Span boundaries are {spans_boundaries}")

        # transform input with spans
        tensor_ids = [inputs["input_ids"][0][boundary[0]:boundary[1]]
                    for boundary in spans_boundaries]
        tensor_masks = [inputs["attention_mask"][0][boundary[0]:boundary[1]]
                        for boundary in spans_boundaries]
        inputs = {
            "input_ids": torch.stack(tensor_ids),
            "attention_mask": torch.stack(tensor_masks)
        }

        # generate relations
        num_return_sequences = 3
        gen_kwargs = {
            "max_length": 256,
            "length_penalty": 0,
            "num_beams": 3,
            "num_return_sequences": num_return_sequences
        }
        generated_tokens = self.model.generate(
            **inputs,
            **gen_kwargs,
        )

        # decode relations
        decoded_preds = self.tokenizer.batch_decode(generated_tokens,
                                            skip_special_tokens=False)

        # create kb
        kb = KB()
        i = 0
        for sentence_pred in decoded_preds:
            current_span_index = i // num_return_sequences
            relations = self.extract_relations_from_model_output(sentence_pred)
            for relation in relations:
                relation["meta"] = {
                    "spans": [spans_boundaries[current_span_index]]
                }
                kb.add_relation(relation)
            i += 1

        return kb
    
    def send_to_neo4j(self,kb,username="neo4j",password="neo4j"):
        uri = "bolt://localhost:7687"
        svo_triples = [tuple(relation.values()) for relation in kb.relations]

        cypher_query = """
        UNWIND $triples AS triple
        MERGE (s:Subject {name: triple[0]})
        MERGE (o:Object {name: triple[2]})
        WITH $triples AS triples
        UNWIND triples AS triple
        MATCH (s:Subject {name: triple[0]})
        MATCH (o:Object {name: triple[2]})
        MERGE (s)-[:PROVIDED {verb: triple[1]}]->(o)
        """

        # Define a function to execute the Cypher query
        def execute_query(tx):
            tx.run(cypher_query, triples=svo_triples)

        # Connect to the Neo4j database and execute the query
        driver = GraphDatabase.driver(uri, auth=(username, password))
        with driver.session() as session:
            session.execute_write(execute_query)

        # Close the driver
        driver.close()

In [15]:
def load_text_from_file(filepath,col="tokens"):
    def remove_parenthesis_and_spaces_around_fullstops(text):
        text=re.sub(r'\s*[\(\[]\s*.*?\s*[\)\]]\s*', '', text)
        text = re.sub(r'\s+\.', '.', text)
        text = re.sub(r'\.(\s+)', '.', text)
        text = re.sub(r'\.{2,}', '.', text)
        text = re.sub(r'\.', '. ', text)
        text = re.sub(r'\s+,', ',', text)
        return text
    
    tokens=pd.read_parquet(filepath)[col][:5]
    str_list=[]
    for token in tokens:
        str_list.append(" ".join(token))
    result="".join(str_list)
    result=remove_parenthesis_and_spaces_around_fullstops(result)
    return result

In [16]:
text=load_text_from_file("train-00000-of-00001.parquet")
rebel=REBELGraph()
kb=rebel.from_text_to_kb(text)
rebel.send_to_neo4j(kb,username="neo4j",password="password")



  lis = BeautifulSoup(html).find_all('li')


In [33]:
print("Total number of entities:",len(kb.entities))
for key in kb.entities.keys():
    print({key:kb.entities[key]})

Total number of entities: 16
{'Adenomatous polyposis coli': {'url': 'https://en.wikipedia.org/wiki/Adenomatous_polyposis_coli', 'summary': 'Adenomatous polyposis coli (APC) also known as deleted in polyposis 2.5 (DP2.5) is a protein that in humans is encoded by the APC gene. The APC protein is a negative regulator  that controls beta-catenin concentrations and interacts with E-cadherin, which are involved in cell adhesion. Mutations in the APC gene may result in colorectal cancer and desmoid tumors.APC is classified as a tumor suppressor gene. Tumor suppressor genes prevent the uncontrolled growth of cells that may result in cancerous tumors. The protein made by the APC gene plays a critical role in several cellular processes that determine whether a cell may develop into a tumor. The APC protein helps control how often a cell divides, how it attaches to other cells within a tissue, how the cell polarizes and the morphogenesis of the 3D structures, or whether a cell moves within or awa

In [35]:
print("Total number of relations:",len(kb.relations))
for relation in kb.relations:
    print(relation)

Total number of relations: 15
{'head': 'Adenomatous polyposis coli', 'type': 'instance of', 'tail': 'Gene', 'meta': {'spans': [[0, 128]]}}
{'head': 'Chromosome', 'type': 'has part', 'tail': 'DNA', 'meta': {'spans': [[378, 506]]}}
{'head': 'Spindle apparatus', 'type': 'part of', 'tail': 'Chromosome', 'meta': {'spans': [[378, 506], [2016, 2144]]}}
{'head': 'DNA', 'type': 'part of', 'tail': 'Chromosome', 'meta': {'spans': [[378, 506]]}}
{'head': 'Stop codon', 'type': 'part of', 'tail': 'Protein', 'meta': {'spans': [[504, 632]]}}
{'head': 'Stop codon', 'type': 'subclass of', 'tail': 'Protein', 'meta': {'spans': [[504, 632], [2142, 2270]]}}
{'head': 'Carcinogenesis', 'type': 'subclass of', 'tail': 'Cancer', 'meta': {'spans': [[882, 1010]]}}
{'head': 'Small molecule', 'type': 'subclass of', 'tail': 'Biological agent', 'meta': {'spans': [[1134, 1262]]}}
{'head': 'Biological agent', 'type': 'has part', 'tail': 'Small molecule', 'meta': {'spans': [[1134, 1262]]}}
{'head': 'Ubiquitin', 'type': '