# Extract entities and triplets

In [None]:
import json
import re
from tqdm import tqdm
from dotenv import dotenv_values
import requests
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
from langchain_core.prompts import HumanMessagePromptTemplate, ChatPromptTemplate


# Key/value pairs parsed from the local .env file. HuggingFaceWrapper falls back
# to config["HF_API_KEY"] when no api_key is passed explicitly.
config = dotenv_values(".env")

class HuggingFaceWrapper:
    """Minimal client for the Hugging Face Inference API text-generation endpoint.

    Calling an instance with a prompt string POSTs it to
    ``https://api-inference.huggingface.co/models/<model_id>`` and returns only the
    generated continuation, or ``None`` when the HTTP request or the decoding of
    the response fails (errors are printed, not raised, so batch loops continue).
    """

    def __init__(self, model_id, api_key=None, temperature=0.9, max_tokens=512, timeout=60):
        """Configure the client.

        Args:
            model_id: Hugging Face model repository id.
            api_key: API token. When falsy, ``HF_API_KEY`` from the module-level
                ``config`` mapping (parsed from ``.env``) is used instead.
            temperature: Sampling temperature sent with every request.
            max_tokens: Sent to the API as ``max_new_tokens``.
            timeout: Seconds to wait for each HTTP request. ``requests`` applies
                no timeout by default, so a stalled connection would block forever.

        Raises:
            ValueError: if no key is passed and ``HF_API_KEY`` is missing or empty.
        """
        self.model_id = model_id
        # Only consult the .env-derived config when no key was given explicitly.
        self.api_key = api_key or config.get("HF_API_KEY")
        if not self.api_key:
            raise ValueError(
                "API key not provided. Set HF_API_KEY in .env or pass api_key parameter."
            )
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.timeout = timeout
        self.api_url = f"https://api-inference.huggingface.co/models/{model_id}"
        self.headers = {"Authorization": f"Bearer {self.api_key}"}

    def __call__(self, prompt):
        """Generate text for *prompt*.

        Args:
            prompt: Full prompt string sent as ``inputs``.

        Returns:
            The generated text (prompt not echoed), or ``None`` on failure.
        """
        payload = {
            "inputs": prompt,
            "parameters": {
                "temperature": self.temperature,
                "max_new_tokens": self.max_tokens,
                "return_full_text": False,  # Only return the generated text, not the prompt
            },
        }
        try:
            response = requests.post(
                self.api_url, headers=self.headers, json=payload, timeout=self.timeout
            )
            response.raise_for_status()  # Raise an exception for HTTP errors
            # The indexing below expects a body of the form [{"generated_text": "..."}].
            return response.json()[0]["generated_text"]
        except (requests.RequestException, ValueError, KeyError, IndexError, TypeError) as e:
            # Network/HTTP errors, undecodable JSON, or an unexpected body shape
            # (e.g. an {"error": ...} object) are reported and mapped to None.
            print(f"Error calling Hugging Face API: {e}")
            return None


#model for extraction
# NOTE(review): "meta-llama/Llama-3-70b-chat-hf" resembles a third-party hosting
# naming scheme; the Hugging Face repo id appears to be
# "meta-llama/Meta-Llama-3-70B-Instruct" — confirm the endpoint resolves before
# starting a long extraction run.
model_id = 'meta-llama/Llama-3-70b-chat-hf'  # Example model ID
# Shared client used by process_text() for both the NER and the OpenIE call.
llm = HuggingFaceWrapper(model_id)


# NER instructions
# System prompt for the NER step: the model must reply with a single JSON object
# {"named_entities": [...]}, which parse_json_response() later extracts.
ner_instruction = (
    "Your task is to extract named entities from the given paragraph.\n"
    "Focus on people, organizations, locations, dates, and other proper nouns.\n"
    "Respond with a JSON object with a key 'named_entities' whose value is a list of entities.\n"
    "If there are no entities, return {\"named_entities\": []}.\n"
    "Output only the JSON."
)


# One-shot example passage, reused by both the NER and the OpenIE prompt.
ner_example_passage_1 = (
    "Apple Inc. announced its new iPhone 14 series on September 7, 2022, at the Steve Jobs Theater.\n"
    "CEO Tim Cook led the presentation where he also introduced Apple Watch Series 8.\n"
    "The company expects to ship these devices across the United States and Europe starting next week."
)

# Expected NER answer for the example passage, in exactly the requested JSON format.
ner_example_entities_1 = (
    "{\"named_entities\": [\"Apple Inc.\", \"iPhone 14\", \"September 7, 2022\", \"Steve Jobs Theater\", \"Tim Cook\", \"Apple Watch Series 8\", \"United States\", \"Europe\"]}"
)

#prompt template for NER
# Chat layout: system instruction, one worked example (human -> AI), then the real
# paragraph. Only the last message is a template (variable: {user_input}); the
# example messages are fixed message objects, so the braces in their JSON are
# not interpreted as placeholders.
# NOTE(review): the real paragraph is fenced with ``` but the example is not;
# consider making the two consistent.
ner_prompts = ChatPromptTemplate.from_messages([
    SystemMessage(ner_instruction),
    HumanMessage(f"Paragraph:\n{ner_example_passage_1}\n"),
    AIMessage(ner_example_entities_1),
    HumanMessagePromptTemplate.from_template("Paragraph:\n```\n{user_input}\n```")
])



# System prompt for the OpenIE step: turn a passage plus its entity list into
# {"triples": [[subject, relation, object], ...]} JSON.
openie_post_ner_instruction = (
    "Your task is to extract relationship triples from the given paragraph using the named entities provided.\n"
    "Each triple should be in the form [subject, relation, object] where subject and object are preferably named entities.\n"
    "Respond with a JSON object with a key 'triples' whose value is a list of triples.\n"
    "Make sure to resolve pronouns to their explicit references.\n"
    "Output only the JSON."
)

# Example OpenIE user message: the example passage followed by its NER JSON,
# mirroring the user-message template of openie_post_ner_prompts below.
openie_example_input_1 = (
    f"Convert the following passage into RDF triples using the named entities.\n"
    f"Paragraph:\n{ner_example_passage_1}\n\n{ner_example_entities_1}"
)

# Expected OpenIE answer for the example input.
openie_example_triples_1 = (
    "{\"triples\": [\n"
    "    [\"Apple Inc.\", \"announced\", \"iPhone 14\"],\n"
    "    [\"Apple Inc.\", \"announced on\", \"September 7, 2022\"],\n"
    "    [\"Apple Inc.\", \"held event at\", \"Steve Jobs Theater\"],\n"
    "    [\"Tim Cook\", \"is CEO of\", \"Apple Inc.\"],\n"
    "    [\"Tim Cook\", \"introduced\", \"Apple Watch Series 8\"],\n"
    "    [\"Apple Inc.\", \"will ship to\", \"United States\"],\n"
    "    [\"Apple Inc.\", \"will ship to\", \"Europe\"]\n"
    "]}"
)


# One-shot OpenIE prompt; template variables are {passage} and {named_entity_json}.
openie_post_ner_prompts = ChatPromptTemplate.from_messages([
    SystemMessage(openie_post_ner_instruction),
    HumanMessage(openie_example_input_1),
    AIMessage(openie_example_triples_1),
    HumanMessagePromptTemplate.from_template(
        "Convert the following passage into RDF triples using the named entities.\n"
        "Paragraph:\n{passage}\n\n{named_entity_json}"
    )
])


def parse_json_response(response_text, default_key):
    """
    Extract the value stored under *default_key* from an LLM reply containing JSON.

    Markdown code fences are removed, then every ``{`` in the text is tried as
    the start of a JSON object using ``json.JSONDecoder.raw_decode``, which
    decodes one complete value (nested objects and braces inside strings
    included) and ignores trailing prose. The first object that contains
    *default_key* wins. If none does, the whole cleaned text is parsed as JSON
    as a last resort.

    Args:
        response_text: Raw model output, or ``None`` if the model call failed.
        default_key: Top-level key to look up, e.g. ``"triples"``.

    Returns:
        The JSON value under *default_key* (any JSON type). If the whole reply is a
        JSON non-object (e.g. a bare list), that value is returned. Returns ``[]``
        when nothing usable is found.
    """
    try:
        if response_text is None:
            return []

        # Clean up the response text to handle markdown formatting
        cleaned = response_text.replace("```json", "").replace("```", "").strip()

        # A non-greedy regex such as {.*?} stops at the first '}', which truncates
        # any object containing nested objects or a '}' inside a string.
        # raw_decode parses exactly one complete JSON value from a given offset.
        decoder = json.JSONDecoder()
        start = cleaned.find("{")
        while start != -1:
            try:
                parsed, _ = decoder.raw_decode(cleaned, start)
            except json.JSONDecodeError:
                parsed = None
            if isinstance(parsed, dict) and default_key in parsed:
                return parsed[default_key]
            start = cleaned.find("{", start + 1)

        # Try parsing the whole cleaned response
        try:
            parsed = json.loads(cleaned)
        except json.JSONDecodeError:
            # If parsing fails, return an empty list
            return []
        if isinstance(parsed, dict):
            return parsed.get(default_key, [])
        return parsed

    except Exception as e:
        # Best effort: e.g. a non-string response_text must not abort a batch.
        print(f"Error parsing JSON: {e}")
        return []


# ---------------------------
# Processing Functions
# ---------------------------
def process_text(text):
    """Extract named entities and relation triples from one passage.

    Makes two calls through the module-level ``llm`` client. The NER prompt
    (``ner_prompts``) runs first. The OpenIE prompt (``openie_post_ner_prompts``)
    then receives the NER result serialized as ``{"named_entities": [...]}``.

    Args:
        text: The passage to analyse.

    Returns:
        ``{"extracted_entities": list, "extracted_triples": list}``. A list is
        empty when the model call failed or the reply did not hold a JSON list
        under the expected key.
    """
    # --- NER Step ---
    final_ner_prompt = ner_prompts.format(user_input=text)
    ner_response = llm(final_ner_prompt)
    print("NER Response received")
    named_entities = parse_json_response(ner_response, "named_entities")
    # parse_json_response returns whatever JSON value sits under the key (null,
    # a string, an object, ...). Consumers of "extracted_entities" iterate it
    # element-wise, so anything other than a list is normalised to [].
    if not isinstance(named_entities, list):
        named_entities = []

    # --- OpenIE Step ---
    final_openie_prompt = openie_post_ner_prompts.format(
        passage=text,
        named_entity_json=json.dumps({"named_entities": named_entities})
    )

    openie_response = llm(final_openie_prompt)
    print("OpenIE Response received")
    triples = parse_json_response(openie_response, "triples")
    # Same normalisation: downstream code expects a list of triples.
    if not isinstance(triples, list):
        triples = []

    return {
        "extracted_entities": named_entities,
        "extracted_triples": triples
    }

def process_jsonl_file(input_file, output_file):
    """Run entity/triple extraction for every document of a JSONL file.

    Each non-blank line must be a JSON object with a ``"text"`` (or, failing that,
    ``"passage"``) field and an optional ``"title"``. Lines without text are
    skipped. Lines that fail to parse or process are reported with their 1-based
    line number and skipped, so one bad record does not abort the batch.

    Args:
        input_file: Path to the UTF-8 JSONL input.
        output_file: Path of the JSON array written with all results.

    Returns:
        The list of result dicts written to ``output_file``.
    """
    results = []
    with open(input_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    for line_no, raw_line in enumerate(tqdm(lines, desc="Processing documents"), start=1):
        line = raw_line.strip()
        if not line:
            continue  # blank lines are not documents; don't report them as errors
        try:
            entry = json.loads(line)
            text = entry.get("text") or entry.get("passage")
            if not text:
                continue
            processed = process_text(text)
            results.append({
                "title": entry.get("title", ""),
                "passage": text,
                "extracted_entities": processed["extracted_entities"],
                "extracted_triples": processed["extracted_triples"]
            })
        except Exception as e:
            print(f"Error processing line {line_no}: {e}")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    return results


if __name__ == "__main__":
    # Input is JSONL with one {"title": ..., "text": ...} object per line;
    # results are written as a single JSON array.
    input_file, output_file = "filtered_output.jsonl", "extraction_results.json"
    results = process_jsonl_file(input_file, output_file)
    print(f"Processed {len(results)} documents. Results saved to {output_file}.")

# Knowledge Graph Creation

In [None]:
import pandas as pd
from scipy.sparse import csr_array
from processing import *
from glob import glob

import os
import json
from tqdm import tqdm
import pickle
import argparse
import copy

# Presumably silences the Hugging Face `tokenizers` parallelism/fork warning.
os.environ['TOKENIZERS_PARALLELISM'] = 'FALSE'

version = 'v3'
inter_triple_weight = 1.0
similarity_max = 1.0


file_path = "corpus.json" #corpus file
# Read through a context manager so the file handle is closed. Decode as UTF-8
# explicitly: extraction output in this notebook is written with
# ensure_ascii=False and may contain raw non-ASCII text. The platform default
# encoding is not guaranteed to be UTF-8.
with open(file_path, encoding="utf-8") as corpus_file:
    extracted_file = json.load(corpus_file)

cosine_sim_edges = False

# Per-document records; the loop below reads 'passage', 'extracted_entities'
# and 'extracted_triples' from each one.
extracted_triples = extracted_file['docs']


phrase_type = 'ents_only_lower_preprocess'
if cosine_sim_edges:
    graph_type = 'facts_and_sim'
else:
    graph_type = 'facts'

# Accumulators filled by the per-document loop below.
passage_json = []
phrases = []
entities = []
relations = {}
incorrectly_formatted_triples = []
triples_wo_ner_entity = []
triple_tuples = []
full_neighborhoods = {}
correct_wiki_format = 0
correct_wiki_format = 0

# Normalise every document's triples. Well-formed (3-element) triples feed the
# global entity list, relation map and neighbourhood index; other multi-element
# triples are recorded as noisy. Per-document results are written back into
# each row (see doc_json below).
# processing_phrases presumably comes from `from processing import *` and
# normalises a phrase string — TODO confirm.
for i, row in tqdm(enumerate(extracted_triples), total=len(extracted_triples)):
    # NOTE(review): document and raw_ner_entities are assigned but never used below.
    document = row['passage']
    raw_ner_entities = row['extracted_entities']
    # NOTE(review): a list, so each `in` test below is linear in its length.
    ner_entities = [processing_phrases(p) for p in row['extracted_entities']]

    triples = row['extracted_triples']

    # Alias, not a copy: the keys set at the end of the loop body are stored
    # directly in the row objects of extracted_file['docs'].
    doc_json = row

    clean_triples = []
    unclean_triples = []
    doc_entities = set()


    for triple in triples:

        # Stringify each element. NOTE(review): a triple given as a plain string
        # would be split into characters here.
        triple = [str(s) for s in triple]

        # Single-element (or empty) triples are dropped silently.
        if len(triple) > 1:
            if len(triple) != 3:
                # NOTE(review): this clean_triple is computed but never used.
                clean_triple = [processing_phrases(p) for p in triple]

                incorrectly_formatted_triples.append(triple)
                unclean_triples.append(triple)
            else:
                clean_triple = [processing_phrases(p) for p in triple]

                clean_triples.append(clean_triple)
                phrases.extend(clean_triple)

                head_ent = clean_triple[0]
                tail_ent = clean_triple[2]

                # Track triples where neither endpoint matches a normalised NER entity.
                if head_ent not in ner_entities and tail_ent not in ner_entities:
                    triples_wo_ner_entity.append(triple)

                # Keyed by (head, tail): a later relation for the same pair overwrites earlier ones.
                relations[(head_ent, tail_ent)] = clean_triple[1]

                raw_head_ent = triple[0]
                raw_tail_ent = triple[2]

                # Index the raw (un-normalised) triple under both of its endpoints.
                entity_neighborhood = full_neighborhoods.get(raw_head_ent, set())
                entity_neighborhood.add((raw_head_ent, triple[1], raw_tail_ent))
                full_neighborhoods[raw_head_ent] = entity_neighborhood

                entity_neighborhood = full_neighborhoods.get(raw_tail_ent, set())
                entity_neighborhood.add((raw_head_ent, triple[1], raw_tail_ent))
                full_neighborhoods[raw_tail_ent] = entity_neighborhood

                for triple_entity in [clean_triple[0], clean_triple[2]]:
                    entities.append(triple_entity)
                    doc_entities.add(triple_entity)

    # doc_entities is already a set; the extra set() just converts for list().
    doc_json['entities'] = list(set(doc_entities))
    doc_json['clean_triples'] = clean_triples
    doc_json['noisy_triples'] = unclean_triples
    triple_tuples.append(clean_triples)


# numpy is used for np.unique below, but this cell never imports it. It may
# arrive via `from processing import *`, but relying on a star import is
# fragile, so import it explicitly.
import numpy as np

# Sorted, de-duplicated entity strings and relation labels ('equivalent' is
# always part of the relation vocabulary).
unique_phrases = list(np.unique(entities))
unique_relations = np.unique(list(relations.values()) + ['equivalent'])

all_phrases = copy.deepcopy(unique_phrases)

# Two copies of each vocabulary, tagged 'query' and 'kb', stacked into one frame.
kb = pd.DataFrame(unique_phrases, columns=['strings'])
kb2 = copy.deepcopy(kb)
kb['type'] = 'query'
kb2['type'] = 'kb'
kb_full = pd.concat([kb, kb2])

rel_kb = pd.DataFrame(unique_relations, columns=['strings'])
rel_kb2 = copy.deepcopy(rel_kb)
rel_kb['type'] = 'query'
rel_kb2['type'] = 'kb'
rel_kb_full = pd.concat([rel_kb, rel_kb2])


# Re-tags kb only; kb_full is a separate object built above and keeps its 'query' rows.
kb['type'] = 'kb'

create_graph_flag = True
if create_graph_flag:
    print('Creating Graph')

    # Node table: one entry per unique phrase, indexed by position.
    node_json = [{'idx': i, 'name': p} for i, p in enumerate(unique_phrases)]
    kb_phrase_df = pd.DataFrame(unique_phrases)
    kb_phrase_dict = {p: i for i, p in enumerate(unique_phrases)}

    # Flatten the per-document clean triples into one fact list (duplicates kept).
    lose_facts = []

    for triples in triple_tuples:
        lose_facts.extend([tuple(t) for t in triples])

    lose_fact_dict = {f: i for i, f in enumerate(lose_facts)}
    # fact_json is consumed by the knowledge-graph construction cell.
    fact_json = [{'idx': i, 'head': t[0], 'relation': t[1], 'tail': t[2]} for i, t in enumerate(lose_facts)]

# passage mapping

import json
import sys

# Load the triple -> source-passage mapping from output_mapping.json.
# Abort with sys.exit rather than the builtin `exit`. `exit` is a helper
# injected by the `site` module for interactive shells; it is not guaranteed
# to exist, or to raise SystemExit, in every environment (e.g. notebook kernels).
try:
    with open('output_mapping.json', 'r', encoding='utf-8') as f:
        input_data = json.load(f)
except FileNotFoundError:
    print("Error: output_mapping.json file not found!")
    sys.exit(1)
except json.JSONDecodeError:
    print("Error: output_mapping.json is not a valid JSON file!")
    sys.exit(1)

# Single pass: number distinct passages passage1, passage2, ... in first-seen
# order, and map every triple key to the ID of its passage.
unique_passages = {}
passage_id_to_original = {}
triple_to_passage_id = {}
passage_counter = 1

for triple, passage in input_data.items():
    if passage not in unique_passages:
        passage_id = f"passage{passage_counter}"
        unique_passages[passage] = passage_id
        passage_id_to_original[passage_id] = passage
        passage_counter += 1
    triple_to_passage_id[triple] = unique_passages[passage]

# Write the output files
with open('triple_to_passage_id.json', 'w') as f:
    json.dump(triple_to_passage_id, f, indent=2)

with open('passage_id_to_original.json', 'w') as f:
    json.dump(passage_id_to_original, f, indent=2)

print("Files created successfully!")
print(f"Created mappings for {len(triple_to_passage_id)} triples and {len(passage_id_to_original)} unique passages.")

# entity to passage mapping
import json
import ast
import sys


def _parse_triple_key(triple_str):
    """Return ``(head, tail)`` parsed from a triple key string, or ``None`` if it has < 2 parts.

    Keys are expected to be Python tuple reprs such as
    ``"('apple inc.', 'announced on', 'september 7, 2022')"``. They are parsed with
    ``ast.literal_eval``, which evaluates literals only and never executes code,
    so entities containing commas or quotes stay intact. A naive ``split(',')``
    would turn the tail above into ``'2022'``. Keys that are not valid literals
    fall back to that naive split.
    """
    try:
        parsed = ast.literal_eval(triple_str)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        parsed = None

    if isinstance(parsed, (tuple, list)):
        parts = [str(p) for p in parsed]
    else:
        # Fallback: drop the surrounding brackets, split on commas, strip quotes.
        parts = [p.strip().strip("'") for p in triple_str[1:-1].split(',')]

    if len(parts) < 2:
        return None
    return parts[0], parts[-1]  # first part is the head, last part the tail


# Load data from output_mapping.json (triple key string -> source passage text).
# sys.exit is used instead of the interactive-only builtin `exit`.
try:
    with open('output_mapping.json', 'r', encoding='utf-8') as f:
        input_data = json.load(f)
except FileNotFoundError:
    print("Error: output_mapping.json file not found!")
    sys.exit(1)
except json.JSONDecodeError:
    print("Error: output_mapping.json is not a valid JSON file!")
    sys.exit(1)

# Single pass: assign passage IDs (passage1, passage2, ... in first-seen order,
# including for triples skipped below) and index head/tail entities.
unique_passages = {}
passage_id_to_original = {}
entity_to_passage_ids = {}
passage_counter = 1

for triple_str, passage in input_data.items():
    if passage not in unique_passages:
        passage_id = f"passage{passage_counter}"
        unique_passages[passage] = passage_id
        passage_id_to_original[passage_id] = passage
        passage_counter += 1

    head_tail = _parse_triple_key(triple_str)
    if head_tail is None:
        print(f"Warning: Skipping triple with insufficient parts: {triple_str}")
        continue

    passage_id = unique_passages[passage]
    # Map both the head and the tail entity to this passage ID.
    for entity in head_tail:
        entity_to_passage_ids.setdefault(entity, set()).add(passage_id)

# Convert sets to sorted lists for JSON serialization
entity_to_passage_ids_list = {
    entity: sorted(passage_ids) for entity, passage_ids in entity_to_passage_ids.items()
}

# Write the output files
with open('entity_to_passage_id.json', 'w') as f:
    json.dump(entity_to_passage_ids_list, f, indent=2)

with open('passage_id_to_original.json', 'w') as f:
    json.dump(passage_id_to_original, f, indent=2)

print("Files created successfully!")
print(f"Created mappings for {len(entity_to_passage_ids)} unique entities (head/tail) and {len(passage_id_to_original)} unique passages.")


import json
import random
import networkx as nx

# Load the entity -> passage-ID mapping written by the entity mapping cell.
with open('entity_to_passage_id.json', 'r', encoding='utf-8') as f:
    entity_to_passage_ids = json.load(f)

# Facts from the knowledge-graph cell: dicts with 'idx', 'head', 'relation', 'tail'.
triples = fact_json

# Directed graph: nodes are entities carrying their passage IDs; edges carry the relation.
kg = nx.DiGraph()

for triple in triples:
    head, relation, tail = triple["head"], triple["relation"], triple["tail"]

    # Entities missing from the mapping get an empty passage list.
    kg.add_node(head, passage_ids=entity_to_passage_ids.get(head, []))
    kg.add_node(tail, passage_ids=entity_to_passage_ids.get(tail, []))

    # A DiGraph keeps one edge per (head, tail) pair, so a later relation between
    # the same two entities replaces the earlier one's attribute.
    kg.add_edge(head, tail, relation=relation)

print(f"📌 Knowledge Graph Created: {kg.number_of_nodes()} Nodes, {kg.number_of_edges()} Edges")

# Example of accessing node metadata (random.choice would raise on an empty graph).
if kg.number_of_nodes():
    random_node = random.choice(list(kg.nodes()))
    print(f"Example node '{random_node}' has these passage IDs: {kg.nodes[random_node]['passage_ids']}")
else:
    print("Knowledge graph is empty; no example node to show.")

# Q/A generation via KG

In [None]:
# Usage: python multihop_qa_generator.py --input your_input_file.jsonl --output multihop_results.json

import json
import re
import os
import requests
from tqdm import tqdm
from dotenv import load_dotenv

# Load variables from a .env file into the process environment so that
# MistralWrapper can read HF_API_KEY through os.getenv.
load_dotenv()

class MistralWrapper:
    """Client for Mistral-family instruct models on the Hugging Face Inference API.

    Prompts are wrapped in Mistral's ``[INST] ... [/INST]`` instruction format
    before being sent. Calling the instance returns the generated continuation,
    or ``None`` when the request or the decoding of the response fails.
    """

    def __init__(self, model_id="mistralai/Mistral-7B-Instruct-v0.2", api_key=None, temperature=0.1, max_tokens=1024, timeout=120):
        """
        Initialize the Mistral model wrapper.

        Args:
            model_id: The Hugging Face model ID
            api_key: Hugging Face API key (if None, loads from HF_API_KEY env variable)
            temperature: Temperature for text generation
            max_tokens: Maximum number of tokens to generate
            timeout: Seconds to wait for each HTTP request. ``requests`` has no
                default timeout, so a stalled connection would block forever.

        Raises:
            ValueError: if no API key is passed and HF_API_KEY is not set.
        """
        self.model_id = model_id
        self.api_key = api_key if api_key else os.getenv("HF_API_KEY")
        if not self.api_key:
            raise ValueError("API key not provided. Set HF_API_KEY environment variable or pass api_key parameter.")

        self.temperature = temperature
        self.max_tokens = max_tokens
        self.timeout = timeout
        self.api_url = f"https://api-inference.huggingface.co/models/{model_id}"
        self.headers = {"Authorization": f"Bearer {self.api_key}"}

    def __call__(self, prompt):
        """
        Generate a response using the Mistral model.

        Args:
            prompt: The input prompt

        Returns:
            The generated text, or None if the request or response parsing failed
        """
        # Format the prompt according to Mistral's chat template.
        # NOTE(review): the server-side tokenizer may already prepend BOS, in which
        # case the literal "<s>" produces a doubled BOS token — confirm.
        formatted_prompt = f"""<s>[INST] {prompt} [/INST]"""

        payload = {
            "inputs": formatted_prompt,
            "parameters": {
                "temperature": self.temperature,
                "max_new_tokens": self.max_tokens,
                "return_full_text": False,
                "do_sample": True,
            }
        }

        try:
            response = requests.post(
                self.api_url, headers=self.headers, json=payload, timeout=self.timeout
            )
            response.raise_for_status()
            # The indexing below expects a body of the form [{"generated_text": "..."}].
            return response.json()[0]["generated_text"]
        except (requests.RequestException, ValueError, KeyError, IndexError, TypeError) as e:
            # Network/HTTP errors, undecodable JSON, or an unexpected body shape.
            print(f"Error calling Hugging Face API: {e}")
            return None

def parse_json_document(json_line):
    """
    Decode one JSONL record and pull out its tagged facts and passages.

    The record's ``document`` string is searched for a ``<facts>...</facts>``
    section containing ``<fact>`` items and a ``<passages>...</passages>``
    section containing ``<passage>`` items. Item text is whitespace-stripped.

    Args:
        json_line: One line of the input file (a JSON object as text).

    Returns:
        ``(facts, passages, title)``; ``([], [], "")`` if the line is not valid JSON.
    """
    try:
        record = json.loads(json_line)
    except json.JSONDecodeError:
        print(f"Error parsing JSON line: {json_line[:100]}...")
        return [], [], ""

    document = record.get("document", "")
    title = record.get("chapter_title", "")

    def tagged_items(section, item):
        # Locate the first <section>...</section> span, then collect every <item> inside it.
        span = re.search(rf"<{section}>(.*?)<\/{section}>", document, re.DOTALL)
        if span is None:
            return []
        found = re.findall(rf"<{item}>(.*?)<\/{item}>", span.group(1), re.DOTALL)
        return [chunk.strip() for chunk in found]

    facts = tagged_items("facts", "fact")
    passages = tagged_items("passages", "passage")
    return facts, passages, title

def create_multihop_prompt(facts, passages):
    """
    Build the instruction prompt asking the model for tagged multi-hop QA.

    Each fact is wrapped in ``<fact>`` tags and each passage in ``<passage>``
    tags, one per line, and both lists are embedded in a fixed template that
    prescribes the exact output tags parsed by ``extract_qa_from_response``.

    Args:
        facts: Fact strings.
        passages: Passage strings.

    Returns:
        The complete prompt string.
    """
    fact_block = "\n".join(f"<fact>{entry}</fact>" for entry in facts)
    passage_block = "\n".join(f"<passage>{entry}</passage>" for entry in passages)

    return f"""You are a Multi-Hop Factual Question Formulation Assistant. Generate precise, fact-based, multi-hop questions requiring integration of specific information from provided texts, resulting in very brief factoid answers (2-3 words maximum).

Given the following facts and passages, please generate multi-hop questions:

Facts:
{fact_block}

Passages:
{passage_block}

Generate 1-2 multi-hop questions that require integrating multiple pieces of information from these texts.
Each question should have a clear factoid answer (2-3 words maximum).
Begin with subquestions that build towards a final complex question.
Please follow this exact format:

<sub-question>[First focused question]</sub-question>
<answer-to-sub-question>[Concise answer]</answer-to-sub-question>

<sub-question>[Follow-up question using previous answer]</sub-question>
<answer-to-sub-question>[Concise answer]</answer-to-sub-question>

[Additional sub-questions as needed]

<question-type>[Number of reasoning steps, e.g., "2-hop"]</question-type>

<complex-question>[Question that integrates all previous sub-questions]</complex-question>
<answer-to-complex-question>[Final answer, 2-3 words]</answer-to-complex-question>

<explanation>[Break down how to answer the complex question step-by-step]</explanation>
<done>
"""

def extract_qa_from_response(response):
    """
    Parse the tagged multi-hop QA structure out of a model reply.

    Sub-questions and sub-answers are paired positionally; unmatched extras are
    dropped. For each single-valued field only the first tag occurrence is used,
    and it stays ``None`` when the tag is absent.

    Args:
        response: Generated text, or a falsy value if generation failed.

    Returns:
        ``None`` for a falsy response, otherwise a dict with keys
        ``sub_questions``, ``sub_answers``, ``question_type``,
        ``complex_question``, ``complex_answer`` and ``explanation``.
    """
    if not response:
        return None

    def first_tag(tag):
        hit = re.search(rf"<{tag}>(.*?)<\/{tag}>", response, re.DOTALL)
        return hit.group(1).strip() if hit else None

    asked = re.findall(r'<sub-question>(.*?)<\/sub-question>', response, re.DOTALL)
    answered = re.findall(r'<answer-to-sub-question>(.*?)<\/answer-to-sub-question>', response, re.DOTALL)
    pairs = list(zip(asked, answered))

    return {
        "sub_questions": [q.strip() for q, _ in pairs],
        "sub_answers": [a.strip() for _, a in pairs],
        "question_type": first_tag("question-type"),
        "complex_question": first_tag("complex-question"),
        "complex_answer": first_tag("answer-to-complex-question"),
        "explanation": first_tag("explanation"),
    }

def _generate_multihop_results(lines, model):
    """Generate multi-hop QA records for a sequence of JSON lines.

    Shared by ``process_jsonl_file`` and ``process_text_data``.
    - Blank lines are skipped silently.
    - Documents without facts or passages are skipped with a message.
    - Any other per-document error is printed and skipped, so one bad record does
      not abort a long batch.

    Args:
        lines: Sequence of JSON strings, one document each.
        model: Callable mapping a prompt string to generated text (or ``None``).

    Returns:
        List of ``{"title", "facts", "passages", "qa_data"}`` dicts.
    """
    results = []
    for line_idx, line in enumerate(tqdm(lines, desc="Processing documents")):
        if not line.strip():
            continue
        try:
            facts, passages, title = parse_json_document(line)

            if not facts or not passages:
                print(f"Skipping document {line_idx} - missing facts or passages")
                continue

            prompt = create_multihop_prompt(facts, passages)
            response = model(prompt)

            qa_data = extract_qa_from_response(response)

            # extract_qa_from_response returns a dict for *any* non-empty reply,
            # even one with none of the expected tags. Require the final complex
            # question, so such replies are reported as failures instead of
            # being saved as empty QA records.
            if qa_data and qa_data.get("complex_question"):
                results.append({
                    "title": title,
                    "facts": facts,
                    "passages": passages,
                    "qa_data": qa_data
                })
                print(f"Generated QA for document {line_idx} - {title}")
            else:
                print(f"Failed to extract QA from response for document {line_idx}")

        except Exception as e:
            print(f"Error processing document {line_idx}: {e}")

    return results


def process_jsonl_file(input_file, output_file, model):
    """
    Process a JSONL file containing documents and generate multi-hop questions.

    Args:
        input_file: Path to the input JSONL file (UTF-8)
        output_file: Path to save the output JSON results
        model: The language model wrapper (callable: prompt -> text or None)

    Returns:
        A list of results containing generated multi-hop QAs
    """
    with open(input_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    results = _generate_multihop_results(lines, model)

    # Save the results to the output file
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

    return results


def process_text_data(text_data, output_file=None, model=None):
    """
    Process a string containing multiple JSON documents and generate multi-hop questions.

    Args:
        text_data: String containing JSON documents (one per line)
        output_file: Optional path to save the output JSON results
        model: Optional model wrapper; when omitted a MistralWrapper for
            "mistralai/Mixtral-8x22B-Instruct-v0.1" is created, as before.

    Returns:
        A list of results containing generated multi-hop QAs
    """
    if model is None:
        model = MistralWrapper(model_id="mistralai/Mixtral-8x22B-Instruct-v0.1")

    lines = text_data.strip().split('\n')
    results = _generate_multihop_results(lines, model)

    # Save the results to the output file if specified
    if output_file:
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=2)

    return results

def main(argv=None):
    """Command-line entry point: generate multi-hop QA pairs from a JSONL file.

    Args:
        argv: Argument list to parse. ``None`` (the default) means ``sys.argv[1:]``,
            as before. Passing a list allows calling this from a notebook, e.g.
            ``main(["--input", "docs.jsonl"])``.
    """
    # Parse command line arguments
    import argparse
    parser = argparse.ArgumentParser(description='Generate multi-hop QA pairs from documents')
    # --input has no usable default. Without required=True a missing flag left
    # args.input as None and failed later inside open() with an unclear TypeError.
    parser.add_argument('--input', type=str, required=True, help='Input JSONL file path')
    parser.add_argument('--output', type=str, default='multihop_qa_results.json', help='Output JSON file path')
    parser.add_argument('--model', type=str, default='mistralai/Mixtral-8x22B-Instruct-v0.1', help='Hugging Face model ID')

    args = parser.parse_args(argv)

    # Initialize the language model
    try:
        model = MistralWrapper(model_id=args.model)
        print(f"Using model: {args.model}")
    except ValueError as e:
        print(f"Error initializing model: {e}")
        return

    # Process the input file
    results = process_jsonl_file(args.input, args.output, model)
    print(f"Processed {len(results)} documents. Results saved to {args.output}.")

# Script entry point for the QA generator.
# NOTE(review): inside a notebook kernel __name__ is also "__main__" and
# sys.argv holds the kernel's own arguments, which argparse will likely
# reject — confirm before running this cell as-is.
if __name__ == "__main__":
    main()

# Path Sampling

In [None]:
import json
import networkx as nx
import random


# Passage-ID -> passage-text lookup used by sample_paths_with_passages.
# NOTE(review): the passage-mapping cells write 'passage_id_to_original.json'
# in the working directory; confirm this mappings/ file has the same contents.
with open("mappings/passage_id_to_passage.json", "r") as f:
    passage_id_to_text = json.load(f)

def sample_paths_with_passages(kg, num_paths=50, min_hops=3, max_hops=5, min_passages=2, max_passages=5,
                               id_to_text=None, rng=None):
    """Sample random multi-hop fact paths through *kg* together with their passages.

    Each walk starts at a node linked to at least ``min_passages`` passage IDs.
    It then follows outgoing edges to unvisited neighbours, picking randomly
    among the three that add the most new passages, while keeping the number of
    distinct passages at most ``max_passages``. A walk is kept when it has at
    least ``min_hops`` facts and covers at least ``min_passages`` passages.
    At most ``num_paths * 20`` walks are attempted.

    Args:
        kg: Graph whose nodes carry a ``passage_ids`` list and whose edges carry a
            ``relation`` attribute (as built by the knowledge-graph cell).
        num_paths: Maximum number of paths to return.
        min_hops, max_hops: Inclusive range for each walk's target length.
        min_passages, max_passages: Bounds on distinct passages per path.
        id_to_text: Mapping from passage ID (as ``str``) to passage text; defaults
            to the module-level ``passage_id_to_text``.
        rng: Optional ``random.Random`` for reproducible sampling; defaults to the
            global ``random`` module.

    Returns:
        List of ``{"facts": [...], "passages": [...]}`` dicts. Passage IDs missing
        from ``id_to_text`` are dropped from ``passages``.
    """
    if id_to_text is None:
        id_to_text = passage_id_to_text
    if rng is None:
        rng = random

    def passages_of(node):
        # .get: a node created implicitly by add_edge carries no 'passage_ids'.
        return kg.nodes[node].get('passage_ids', [])

    # Valid start nodes never change between attempts, so compute them once.
    # Rescanning all nodes on every attempt costs O(nodes) per attempt.
    candidate_nodes = [n for n in kg.nodes() if len(passages_of(n)) >= min_passages]
    if not candidate_nodes:
        print("No suitable starting nodes found. Lower 'min_passages'.")
        return []

    paths = []
    attempts = 0
    max_attempts = num_paths * 20

    while len(paths) < num_paths and attempts < max_attempts:
        attempts += 1

        current_node = rng.choice(candidate_nodes)
        # Sorted so the subsample does not depend on set iteration order; a
        # seeded rng then reproduces the same paths.
        current_passages = sorted(set(passages_of(current_node)), key=str)
        if len(current_passages) > max_passages:
            current_passages = rng.sample(current_passages, max_passages)

        facts = []
        visited_nodes = {current_node}
        visited_passages = set(current_passages)
        hop_count = rng.randint(min_hops, max_hops)

        while len(facts) < hop_count:
            neighbors = []
            for neighbor in kg.neighbors(current_node):
                if neighbor in visited_nodes:
                    continue
                new_passages = set(passages_of(neighbor)) - visited_passages
                # Only step to neighbours that keep the path within max_passages.
                if len(visited_passages | new_passages) <= max_passages:
                    relation = kg.edges[current_node, neighbor]['relation']
                    fact_str = f"{current_node} {relation} {neighbor}"
                    neighbors.append((neighbor, fact_str, new_passages))

            if not neighbors:
                break  # dead end; the walk may still qualify below

            # Favour neighbours that add the most passages, with some randomness.
            neighbors.sort(key=lambda x: len(x[2]), reverse=True)
            chosen_neighbor, fact_str, passages_added = rng.choice(neighbors[:3])

            facts.append(fact_str)
            visited_nodes.add(chosen_neighbor)
            visited_passages.update(passages_added)
            current_node = chosen_neighbor

            if len(visited_passages) == max_passages:
                break

        if len(facts) >= min_hops and len(visited_passages) >= min_passages:
            # Convert passage IDs to actual passage texts (stable order).
            text_passages = [
                id_to_text[str(pid)]
                for pid in sorted(visited_passages, key=str)
                if str(pid) in id_to_text
            ]
            paths.append({
                "facts": facts,
                "passages": text_passages
            })

    print(f"✅ Sampled {len(paths)} paths (after {attempts} attempts).")
    return paths

# Generate the data: up to 10,000 paths of 3-5 hops spanning 2-5 passages each.
sampled_data = sample_paths_with_passages(
    kg,
    num_paths=10000,
    min_hops=3,
    max_hops=5,
    min_passages=2,
    max_passages=5
)

# Save to file
with open("sampled_paths_with_text_passages.json", "w", encoding="utf-8") as f:
    json.dump(sampled_data, f, indent=4, ensure_ascii=False)

print("📄 Saved to 'sampled_paths_with_text_passages.json'.")


# Fine-tuning

import json
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training

def load_qa_data(jsonl_path):
    """Load question/answer pairs from a JSONL file into a one-column ``Dataset``.

    Each non-blank line must be a JSON object with ``"question"`` and
    ``"answer"`` strings. Each pair is rendered into one training string in the
    ``[INST] ... [/INST]`` instruction format and stored under ``"text"``.

    Args:
        jsonl_path: Path to the UTF-8 JSONL file.

    Returns:
        ``datasets.Dataset`` with a single ``"text"`` column.

    Raises:
        KeyError: if a record lacks ``question`` or ``answer``.
        json.JSONDecodeError: on a malformed (non-blank) line.
    """
    data = []
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank or trailing empty lines
            item = json.loads(line)
            q = item["question"].strip()
            a = item["answer"].strip()
            # NOTE(review): the tokenizer likely prepends its own BOS, so the
            # literal "<s>" may yield a doubled BOS — check tokenizer(...).input_ids.
            full_text = f"<s>[INST] Answer the following question:\n\n{q}\n[/INST]\n{a}</s>"
            data.append({"text": full_text})
    return Dataset.from_list(data)

def tokenize_fn(examples, tokenizer, max_length=2048):
    """Tokenize a batch of ``{"text": [...]}`` examples for causal-LM training.

    Every sequence is truncated to ``max_length`` tokens and padded up to exactly
    ``max_length``, so all rows share one fixed length.
    """
    batch_texts = examples["text"]
    return tokenizer(
        batch_texts,
        max_length=max_length,
        padding="max_length",
        truncation=True,
    )

def main():
    """Fine-tune Mistral-7B with 4-bit LoRA on question/answer pairs from qa_data.jsonl.

    Loads the base model in 4-bit and wraps it with LoRA adapters on the q/v
    projections. Trains for 4 epochs on a 95/5 train/eval split, then saves the
    wrapped model and the tokenizer to ./mistral-qa-final.
    """
    model_id = "mistralai/Mistral-7B-v0.1"
    jsonl_path = "qa_data.jsonl"

    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # DataCollatorForLanguageModeling(mlm=False) sets the label to -100 at every
    # position whose token id equals pad_token_id. Padding with EOS would
    # therefore also mask the real "</s>" that ends each answer, and the model
    # would never be trained to stop. Pad with a distinct token (UNK) when the
    # tokenizer has one, and fall back to EOS only otherwise.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.unk_token or tokenizer.eos_token

    # NOTE(review): newer transformers releases expect
    # quantization_config=BitsAndBytesConfig(load_in_4bit=True) rather than the
    # bare load_in_4bit kwarg — confirm against the installed version.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        load_in_4bit=True,
        device_map="auto"
    )
    model = prepare_model_for_kbit_training(model)

    # LoRA Config
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type=TaskType.CAUSAL_LM
    )
    model = get_peft_model(model, lora_config)

    # Load and tokenize dataset (5% held out for per-epoch evaluation)
    dataset = load_qa_data(jsonl_path)
    dataset = dataset.train_test_split(test_size=0.05)
    tokenized = dataset.map(lambda x: tokenize_fn(x, tokenizer), batched=True)
    tokenized = tokenized.remove_columns(["text"])

    # Training setup
    # NOTE(review): `evaluation_strategy` is named `eval_strategy` in newer
    # transformers releases, and Trainer's `tokenizer=` became
    # `processing_class=` — confirm against the installed version.
    training_args = TrainingArguments(
        output_dir="./checkpoints-mistral-qa",
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        gradient_accumulation_steps=4,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        num_train_epochs=4,
        fp16=True,
        logging_dir="./logs",
        save_total_limit=2,
        load_best_model_at_end=True,
        report_to="none"
    )

    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=tokenized["train"],
        eval_dataset=tokenized["test"],
        data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    )

    # Train & Save
    trainer.train()
    model.save_pretrained("./mistral-qa-final")
    tokenizer.save_pretrained("./mistral-qa-final")

# Start the fine-tuning run. In a notebook kernel __name__ is also "__main__",
# so executing this cell launches training immediately.
if __name__ == "__main__":
    main()
