In [None]:
# Install the notebook's dependencies (IPython shell escapes; run once per environment).
# bitsandbytes and accelerate back the 8-bit, device_map="auto" model load in the next
# cell; the last line downloads the spaCy English model that the next cell loads.
!pip install spacy
!pip install bitsandbytes
!pip install accelerate
!pip install transformers
!pip install tqdm
!pip install datasets
!python -m spacy download en_core_web_sm

In [None]:
# Import necessary libraries
import itertools
import re
from functools import lru_cache

import pandas as pd
import requests
import spacy
from tqdm import tqdm
from transformers import AutoTokenizer, BitsAndBytesConfig, pipeline


# Shared English spaCy pipeline (small model): used below both to split the input
# text into sentences and to run named-entity recognition.
SPACY_MODEL_NAME = "en_core_web_sm"
nlp = spacy.load(SPACY_MODEL_NAME)

# Function to load relation schema from a CSV file
def load_relation_schema(csv_file):
    """Read the relation-schema CSV and return one tuple per relation row.

    Each tuple is ``(property, propertyLabel, propertyDescription, domains, ranges)``.
    Missing ``domains`` / ``ranges`` cells become ``''`` and both columns are coerced
    to ``str``; the other columns are returned as pandas parsed them (so a missing
    description stays NaN).
    """
    frame = pd.read_csv(csv_file)
    for column in ('domains', 'ranges'):
        frame[column] = frame[column].fillna('').astype(str)
    wanted = ('property', 'propertyLabel', 'propertyDescription', 'domains', 'ranges')
    records = frame.to_dict('records')
    return [tuple(record[key] for key in wanted) for record in records]

# Load the relation schema from a CSV file:
# one (property, label, description, domains, ranges) tuple per relation
relation_schema = load_relation_schema('data/relations.csv')

# Hugging Face model that answers the yes/no relation prompts
MODEL_ID = 'mistralai/Mistral-7B-Instruct-v0.2'

# Load the tokenizer for the model
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Prompts are generated in batches and therefore padded; if the tokenizer defines
# no PAD token, reuse its UNK token.
pad_token = tokenizer.pad_token
if pad_token is None:
    tokenizer.pad_token = tokenizer.unk_token

# Text-generation pipeline with the model quantized to 8 bits via bitsandbytes.
# Passing `load_in_8bit` straight through `model_kwargs` is deprecated in transformers
# in favour of a BitsAndBytesConfig supplied as `quantization_config`.
generator = pipeline(
    'text-generation',
    model=MODEL_ID,
    tokenizer=tokenizer,
    device_map="auto",
    model_kwargs={"quantization_config": BitsAndBytesConfig(load_in_8bit=True)},
)



# Checks whether one class is a subclass of another; needed for the domain/range check.
@lru_cache(maxsize=None)
def query_subclass_of(subclass_id, class_id):
    """Return whether Wikidata item *subclass_id* reaches *class_id* through zero or
    more ``subclass of`` (P279) hops, using a SPARQL ASK query.

    Results are memoized per ``(subclass_id, class_id)`` pair, so a repeated check
    does not hit the endpoint again. Falls back to ``False`` when the endpoint's
    JSON reply carries no ``boolean`` field.
    """
    ask_query = (
        "\n"
        "    ASK {\n"
        f"        wd:{subclass_id} (wdt:P279)* wd:{class_id} .\n"
        "    }\n"
        "    "
    )
    payload = query_wikidata_sparql(ask_query)
    answer = payload.get('boolean', False)
    return answer


# Checks type compatibility of an entity pair with a relation's domain/range.
# Currently not called: the call site in generate_prompts is disabled.
def is_compatible(entity1_classes, entity2_classes, relation_domain, relation_range, *, subclass_of=None):
    """Return whether two entities fit a relation's domain and range, in either direction.

    The pair is compatible when entity1 fits the domain and entity2 the range, or
    entity2 fits the domain and entity1 the range. An entity fits a constraint when
    any of its classes is (a subclass of) any class listed in that constraint. An
    empty constraint (``''``, which ``load_relation_schema`` stores for a missing
    cell) is treated as "no restriction".

    Args:
        entity1_classes, entity2_classes: iterables of ``(class_id, class_label)``
            pairs, as built by ``search_wikidata``.
        relation_domain, relation_range: comma-separated class IDs (presumably
            Wikidata Q-IDs such as ``"Q5, Q515"`` -- TODO confirm against the CSV).
            Entity URIs are reduced to their trailing ID.
        subclass_of: optional predicate ``(subclass_id, class_id) -> bool``; defaults
            to ``query_subclass_of`` (Wikidata SPARQL).

    Returns:
        bool
    """
    check = query_subclass_of if subclass_of is None else subclass_of

    def _parse_class_ids(spec):
        # Split "Q5, Q515" (or "Q5,Q515") into whole IDs. The old code took [0] of
        # each ID string, i.e. only its first character, and crashed on ''.
        ids = []
        for token in str(spec).split(','):
            token = token.strip()
            if token:
                ids.append(token.rsplit('/', 1)[-1])
        return list(dict.fromkeys(ids))  # dedupe, keep order

    def _fits(entity_classes, allowed_ids):
        # True if any entity class is a subclass of any allowed class; no allowed
        # classes means the relation places no restriction on this slot.
        if not allowed_ids:
            return True
        return any(check(cls[0], allowed) for cls in entity_classes for allowed in allowed_ids)

    domain_ids = _parse_class_ids(relation_domain)
    range_ids = _parse_class_ids(relation_range)

    forward = _fits(entity1_classes, domain_ids) and _fits(entity2_classes, range_ids)
    # Only query the reversed direction if the forward one failed.
    return forward or (_fits(entity2_classes, domain_ids) and _fits(entity1_classes, range_ids))


# Function to generate prompts for relation extraction
def generate_prompts(candidate_sentences, relation_schema, linked_entities_dict):
    """Build one yes/no LLM prompt per (entity pair, relation) in each sentence.

    Args:
        candidate_sentences: spaCy sentence spans (anything with ``.text`` and ``.ents``).
        relation_schema: ``(property, label, description, domains, ranges)`` tuples.
        linked_entities_dict: mapping keyed by entity surface text; only entities
            present in it are paired.

    Returns:
        ``(prompts, entity_info)``: parallel lists where ``entity_info[k]`` is
        ``(entity1, entity2, relation_description)`` for ``prompts[k]``.
    """
    prompts = []
    entity_info = []
    for sentence in candidate_sentences:
        # Linked entity texts in order of first appearance. Repeated mentions of the
        # same text are collapsed: they share one linked_entities_dict entry, so
        # keeping them only produced self-pairs ('X', 'X') and duplicate prompts.
        entities_in_sentence = list(dict.fromkeys(
            ent.text for ent in sentence.ents if ent.text in linked_entities_dict
        ))
        # NOTE: domain/range filtering (is_compatible) is disabled, so every relation
        # in the schema is asked about for every pair.
        for entity1, entity2 in itertools.combinations(entities_in_sentence, 2):
            for _, label, description, _, _ in relation_schema:
                prompt = f"[INST]Answer with yes or no. Does the sentence '{sentence.text}' contain the relationship '{label}' between '{entity1}' and '{entity2}'? Only answer with yes, if you are sure[/INST]"
                prompts.append(prompt)
                entity_info.append((entity1, entity2, description))
    return prompts, entity_info

# Function to get entity mentions from a sentence
def get_entity_mentions(sentence):
    """Run the spaCy pipeline on the string *sentence* and return the text of every
    named entity it detects, in document order."""
    parsed = nlp(sentence)
    mentions = [entity_span.text for entity_span in parsed.ents]
    return mentions

# Function to query Wikidata SPARQL endpoint
def query_wikidata_sparql(query, timeout=60):
    """Run *query* against the Wikidata Query Service and return the decoded JSON reply.

    Args:
        query: SPARQL query text.
        timeout: seconds to wait for the endpoint before giving up (previously the
            request could block forever).

    Returns:
        The parsed JSON response (``dict``).

    Raises:
        requests.HTTPError: on an error status (e.g. rate limiting or a malformed
            query), instead of failing later inside ``response.json()`` on a
            non-JSON error body.
        requests.Timeout: if the endpoint does not answer within *timeout*.
    """
    url = "https://query.wikidata.org/sparql"
    headers = {'User-Agent': 'TypeCheck; https://enexa.eu/contact/'}

    response = requests.get(
        url,
        headers=headers,
        params={'format': 'json', 'query': query},
        timeout=timeout,
    )
    response.raise_for_status()
    return response.json()

# Function to search for an entity in Wikidata and get its classes
def search_wikidata(entity, timeout=30):
    """Find a Wikidata item for *entity* and the classes it is an instance of.

    Takes the first hit of the ``wbsearchentities`` API (as ranked by the API) and
    then queries that item's ``instance of`` (P31) values over SPARQL.

    Args:
        entity: surface text to search for (English).
        timeout: seconds to wait for the search API before giving up.

    Returns:
        ``(wikidata_id, classes)`` where ``classes`` is a list of
        ``(class_id, class_label)`` tuples, or ``(None, [])`` if nothing matched.

    Raises:
        requests.HTTPError: if the search API answers with an error status.
    """
    search_url = "https://www.wikidata.org/w/api.php"
    # Wikimedia asks API clients to send an identifying User-Agent; reuse the same
    # one query_wikidata_sparql sends.
    headers = {'User-Agent': 'TypeCheck; https://enexa.eu/contact/'}
    search_params = {
        "action": "wbsearchentities",
        "language": "en",
        "format": "json",
        "search": entity,
    }
    search_response = requests.get(search_url, params=search_params, headers=headers, timeout=timeout)
    search_response.raise_for_status()
    search_results = search_response.json().get("search", [])
    if not search_results:
        return None, []

    wikidata_id = search_results[0].get("id")  # ID of the first match
    if not wikidata_id:
        # Without an ID the SPARQL query below would be malformed ("wd:None").
        return None, []

    # Query SPARQL to get all classes for this entity
    sparql_query = f"""
    SELECT ?class ?classLabel WHERE {{
        wd:{wikidata_id} wdt:P31 ?class .
        SERVICE wikibase:label {{ bd:serviceParam wikibase:language "en". }}
    }}
    """
    results = query_wikidata_sparql(sparql_query)
    # Class IDs are the last path segment of the entity URI.
    classes = [
        (binding['class']['value'].split('/')[-1], binding['classLabel']['value'])
        for binding in results['results']['bindings']
    ]
    return wikidata_id, classes

# Function to link entities in sentences to Wikidata
def link_entities_to_wikidata(candidate_sentences):
    """Link the named entities of each sentence to Wikidata items.

    Uses the entity spans spaCy already attached to each sentence (``sentence.ents``),
    the same spans ``generate_prompts`` later matches on. Re-parsing the bare
    sentence text out of context could yield different spans, leaving entities
    unlinked. Each distinct entity text is searched only once, avoiding repeated
    HTTP requests for repeated mentions.

    Args:
        candidate_sentences: spaCy sentence spans.

    Returns:
        List of ``(entity_text, wikidata_id, classes)`` tuples, one per distinct
        entity text that matched, in order of first appearance.
    """
    linked_entities = []
    searched = set()
    for sentence in candidate_sentences:
        for ent in sentence.ents:
            entity = ent.text
            if entity in searched:
                continue
            searched.add(entity)
            wikidata_id, classes = search_wikidata(entity)
            if wikidata_id:
                linked_entities.append((entity, wikidata_id, classes))
    return linked_entities

# Function to extract relations in batch using the generator
def batch_extract_relations(candidate_sentences, relation_schema, generator, linked_entities_dict, batch_size=512):
    """Ask the LLM one yes/no question per (entity pair, relation) and keep the "yes" triples.

    Args:
        candidate_sentences: spaCy sentence spans to examine.
        relation_schema: ``(property, label, description, domains, ranges)`` tuples,
            as produced by ``load_relation_schema``.
        generator: Hugging Face ``text-generation`` pipeline that answers the prompts.
        linked_entities_dict: entity surface text -> list of ``(class_id, class_label)``.
        batch_size: number of prompts per pipeline batch.

    Returns:
        List of ``(subject, relation_label, object)`` triples, in prompt order.
    """
    prompts, entity_info_list = generate_prompts(candidate_sentences, relation_schema, linked_entities_dict)
    if not prompts:
        return []

    # Map relation descriptions back to labels once, instead of scanning the schema
    # for every positive answer. The first relation with a given description wins
    # (as with the previous linear scan); missing (NaN) descriptions are skipped,
    # since NaN never compared equal in that scan either.
    label_by_description = {}
    for _, label, description, _, _ in relation_schema:
        if not pd.isna(description):
            label_by_description.setdefault(description, label)

    prompt_generator = (prompt for prompt in prompts)  # stream prompts into the pipeline
    # return_full_text=False: by default the pipeline returns prompt + completion, and
    # the prompt itself contains "yes", so every prompt used to count as a "yes".
    responses = generator(prompt_generator, max_new_tokens=4, batch_size=batch_size, return_full_text=False)

    extracted_relations = []
    progress = tqdm(responses, total=len(prompts), desc="Processing")
    for response, (subject, obj, relation_description) in zip(progress, entity_info_list):
        answer = response[0]['generated_text']
        # Accept only completions whose first word is "yes" (case-insensitive,
        # ignoring leading whitespace/punctuation).
        if not re.match(r"\W*yes\b", answer, re.IGNORECASE):
            continue
        relation_label = label_by_description.get(relation_description)
        if relation_label:
            extracted_relations.append((subject, relation_label, obj))

    return extracted_relations


# Process input text with spaCy to identify candidate sentences for relation extraction.
# A relation needs two entities, so only sentences with at least two named entities qualify.
text = """
Berlin is located in Germany.
Barack Obama was born in Honolulu.
"""
doc = nlp(text)
candidate_sentences = [sent for sent in doc.sents if len(sent.ents) >= 2]
print("Candidate Sentences are:")
print(candidate_sentences)

# Link entities to Wikidata for further processing
linked_entities = link_entities_to_wikidata(candidate_sentences)
print(linked_entities)
# Entity surface text -> list of (class_id, class_label)
linked_entities_dict = {entity: classes for entity, _, classes in linked_entities}

# Extract relations for all candidate sentences in a batch
extracted_relations = batch_extract_relations(candidate_sentences, relation_schema, generator, linked_entities_dict, batch_size=512)

# Print extracted relations
print("Extracted Relations:")
for relation in extracted_relations:
    print(relation)

