In [88]:
import spacy
from collections import Counter
import requests
from requests.exceptions import HTTPError
from SPARQLWrapper import SPARQLWrapper, JSON
import json
import time
from urllib.parse import quote
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
from collections import defaultdict
from rouge_score import rouge_scorer
from datasets import load_dataset
import os
from bert_score import score as bert_score
import numpy as np
from transformers import AutoTokenizer
from fuzzywuzzy import fuzz


EntityFinder: Extracts named entities from text and sorts them based on frequency.

Inputs: raw text or a file path

Outputs: a list of unique (entity text, entity label) tuples, ordered from most to least frequent.



In [None]:
class EntityFinder:
    """Named-entity extraction backed by spaCy's ``en_core_web_trf`` pipeline."""

    def __init__(self):
        self.nlp = spacy.load("en_core_web_trf")

    def ner(self, input_text):
        """Return ``(text, label)`` tuples for the entity labels of interest.

        ``input_text`` is read as a UTF-8 file when it names an existing file;
        otherwise it is used as the raw text itself. Entities come back in the
        order ``doc.ents`` yields them, duplicates included.
        """
        wanted_labels = ["PERSON", "ORG", "GPE", "EVENT", "WORK_OF_ART", "DATE", "NORP", "LAW"]

        text = input_text
        if os.path.isfile(input_text):
            with open(input_text, "r", encoding="utf-8") as file:
                text = file.read()

        entities = []
        for ent in self.nlp(text).ents:
            if ent.label_ in wanted_labels:
                entities.append((ent.text, ent.label_))
        return entities

    def get_sorted_entities(self, entities):
        """Deduplicate ``entities`` and order them from most to least frequent."""
        ranked = Counter(entities).most_common()
        return [entity for entity, _count in ranked]
 
# entity_finder = EntityFinder()
# extracted_entities = entity_finder.ner("InputText.txt")
# sorted_entities = entity_finder.get_sorted_entities(extracted_entities)
# for entity in sorted_entities:
#     print(entity)

CacheManager: manages a singleton cache for storing and retrieving Wikidata entities, ensuring that cached data persists across program runs.

In [None]:
class CacheManager:
    """Process-wide (singleton) cache of Wikidata lookups, persisted as JSON.

    Every construction returns the same instance. The cache lives in memory in
    ``wikidata_cache`` and, when caching is enabled, is mirrored to
    ``wikidata_cache.json`` (a path relative to the working directory).

    Args:
        use_cache: ``True``/``False`` to enable/disable caching. ``None`` (the
            default) means "no preference": caching is enabled if this call
            creates the instance, and the current setting is left untouched if
            the instance already exists. An explicit value is always honoured,
            even when the instance already exists.
    """

    _instance = None  # the single shared instance, created on first construction

    def __new__(cls, use_cache=None):
        if cls._instance is None:
            instance = super(CacheManager, cls).__new__(cls)
            instance.cache_file = "wikidata_cache.json"
            instance.use_cache = True if use_cache is None else use_cache
            instance.wikidata_cache = instance.load_cache() if instance.use_cache else {}
            cls._instance = instance
        elif use_cache is not None and bool(use_cache) != bool(cls._instance.use_cache):
            # An explicit request that differs from the current setting used to be
            # silently ignored; re-initialise the in-memory cache for the new mode.
            cls._instance.use_cache = use_cache
            cls._instance.wikidata_cache = cls._instance.load_cache() if use_cache else {}
        return cls._instance

    def load_cache(self):
        """Read the cache file and return its contents as a dict.

        If the file is missing or is not valid JSON, it is (re)created holding an
        empty object and ``{}`` is returned.
        """
        try:
            with open(self.cache_file, "r", encoding="utf-8") as f:
                return json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            with open(self.cache_file, "w", encoding="utf-8") as f:
                json.dump({}, f)
            return {}

    def save_cache(self):
        """Write the in-memory cache to the cache file; a no-op when caching is disabled."""
        if self.use_cache:
            with open(self.cache_file, "w", encoding="utf-8") as f:
                json.dump(self.wikidata_cache, f, indent=4)

    def get(self, entity_name):
        """Return the cached entry for ``entity_name``, or ``None`` if there is none."""
        return self.wikidata_cache.get(entity_name)

    def update(self, entity_name, data):
        """Store ``data`` under ``entity_name`` and persist the cache (if enabled)."""
        self.wikidata_cache[entity_name] = data
        self.save_cache()

# cache_manager = CacheManager(use_cache=False)

EntityLinker: links named entities to Wikidata by retrieving their IDs and labels, using a caching system to minimize redundant API requests.

Inputs: a list of (entity_name, entity_type) tuples.

Outputs: a dictionary containing entity names and their Wikidata information.

In [None]:
class EntityLinker:
    """Resolve entity names to Wikidata IDs/labels, consulting the shared cache first."""

    WIKIDATA_API_URL = "https://www.wikidata.org/w/api.php"
    REQUEST_TIMEOUT = 30  # seconds; prevents a stalled connection from hanging the run

    def __init__(self):
        self.cache_manager = CacheManager()

    def _search_wikidata(self, entity_name):
        """Return the ``wbsearchentities`` hits for ``entity_name`` (``[]`` on failure)."""
        # Passing the query through ``params`` lets requests URL-encode it, so names
        # containing '&', '#', '+' or spaces no longer corrupt the query string.
        params = {
            "action": "wbsearchentities",
            "search": entity_name,
            "language": "en",
            "format": "json",
        }
        try:
            response = requests.get(self.WIKIDATA_API_URL, params=params, timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()
            data = response.json()
        except (requests.RequestException, ValueError) as exc:
            # One failed lookup should not abort linking of the remaining entities.
            print(f"Wikidata lookup failed for {entity_name}: {exc}")
            return []
        return data.get("search") or []

    def link_entities(self, sorted_entities):
        """Map each entity name to its Wikidata information.

        Args:
            sorted_entities: iterable of ``(entity_name, entity_type)`` tuples.

        Returns:
            dict keyed by entity name. Cache hits yield ``{"info", "type"}``;
            fresh lookups yield ``{"info": (id, label), "triples", "type"}``.
            Entities with no search hit (or whose lookup failed) are omitted.
        """
        wikidata_entities = {}

        for entity_name, entity_type in sorted_entities:
            cached_data = self.cache_manager.get(entity_name)

            # Only a cache entry that actually carries "info" counts as a hit;
            # an entry holding just triples would otherwise raise KeyError here.
            if self.cache_manager.use_cache and cached_data and "info" in cached_data:
                print(f"Retrieved {entity_name} from cache")
                wikidata_entities[entity_name] = {
                    "info": cached_data["info"],
                    "type": cached_data.get("type", entity_type),
                }
                continue

            print(f"Fetching {entity_name} from Wikidata API")
            matches = self._search_wikidata(entity_name)
            if not matches:
                continue

            best_match = matches[0]
            entity_info = {
                # Fall back to the surface form when the hit carries no label.
                "info": (best_match["id"], best_match.get("label", entity_name)),
                # Keep triples already cached under this name instead of wiping them.
                "triples": cached_data.get("triples", []) if cached_data else [],
                "type": entity_type,
            }

            # Cache the entity
            if self.cache_manager.use_cache:
                self.cache_manager.update(entity_name, entity_info)

            wikidata_entities[entity_name] = entity_info

        return wikidata_entities
    

# entity_linker = EntityLinker()
# wikidata_entities = entity_linker.link_entities(sorted_entities)
# print(wikidata_entities)

KnowledgeExtractor: extracts factual knowledge from Wikidata by querying entity relationships, using caching to reduce redundant API calls.

Inputs: a dictionary of Wikidata entities with their IDs and types.

Outputs: a list of formatted knowledge statements extracted from Wikidata.

In [None]:
class KnowledgeExtractor:
    """Fetch (predicate, object) facts for linked entities from the Wikidata SPARQL endpoint."""

    def __init__(self):
        self.cache_manager = CacheManager()
        self.query_templates = self.open_queries()
        self.sparql = SPARQLWrapper("https://query.wikidata.org/sparql")
        self.sparql.setReturnFormat(JSON)

    def open_queries(self):
        """Load the query templates from the JSON file."""
        with open("wikidata_queries.json", "r", encoding="utf-8") as f:
            return json.load(f)

    def get_triples(self, entity_id, entity_type):
        """Run the query template for ``entity_type`` against ``entity_id``.

        Returns:
            list of ``(predicate_label, object_label)`` tuples; a missing label is
            replaced by ``'Unknown Relationship'`` / ``'Unknown Object'``.

        Raises:
            KeyError: if there is no template for ``entity_type``.
        """
        query = self.query_templates[entity_type].replace("{entity_id}", entity_id)
        self.sparql.setQuery(query)
        results = self.sparql.query().convert()

        return [
            (
                binding.get('predicateLabel', {}).get('value', 'Unknown Relationship'),
                binding.get('objectLabel', {}).get('value', 'Unknown Object'),
            )
            for binding in results['results']['bindings']
        ]

    def extract_knowledge(self, wikidata_entities):
        """Turn linked entities into ``"label - predicate - object"`` strings.

        Args:
            wikidata_entities: dict of entity name to a dict with ``"info"``
                (``(id, label)``) and ``"type"``.

        Returns:
            list of knowledge strings, in entity order. Entities whose type has no
            query template (and no cached triples) are skipped.
        """
        extracted_knowledge = []
        use_cache = self.cache_manager.use_cache

        for entity_name, entity_data in wikidata_entities.items():
            entity_id = entity_data["info"][0]
            entity_label = entity_data["info"][1]
            entity_type = entity_data["type"]

            cached_data = self.cache_manager.get(entity_name) if use_cache else None
            # .get(): an entry cached by the linker before any triples were fetched
            # may lack the key entirely.
            triples = (cached_data or {}).get("triples") or []

            if triples:
                print(f"Using cached triples for {entity_name}")
            elif entity_type not in self.query_templates:
                # Not every entity label necessarily has a query template.
                print(f"No query template for entity type {entity_type}; skipping {entity_name}")
                continue
            else:
                if use_cache:
                    print(f"Fetching triples for {entity_name} from Wikidata...")
                else:
                    print(f"Fetching triples for {entity_name} from Wikidata (cache disabled)...")
                triples = self.get_triples(entity_id, entity_type)

                if use_cache:
                    if not cached_data:
                        # Store a complete entry so later cache readers also find
                        # "info" and "type", not just the triples.
                        cached_data = {"info": (entity_id, entity_label), "type": entity_type}
                    cached_data["triples"] = triples
                    self.cache_manager.update(entity_name, cached_data)

            for triple in triples:
                knowledge = f"{entity_label} - {' - '.join(triple)}"
                extracted_knowledge.append(knowledge)

        return extracted_knowledge
    

# knowledge_extractor = KnowledgeExtractor()  
# extracted_knowledge = knowledge_extractor.extract_knowledge(wikidata_entities)
# print (extracted_knowledge)
# # for knowledge in extracted_knowledge:
# #     print(knowledge)

KnowledgeOptimizer: reduces redundancy and organizes Wikidata triples into a structured format by grouping relationships under their respective subjects.

Inputs: a list of knowledge triples as strings.

Outputs: a concise, human-readable string in which facts are grouped by subject and then by predicate.

In [None]:
class KnowledgeOptimizer:
    """Group ``"subject - predicate - object"`` strings by subject and predicate."""

    def __init__(self):
        # subject -> predicate -> set of objects (the set removes duplicates)
        self.grouped_info = {}

    def add_triples(self, triples):
        """Merge ``triples`` into the grouped structure.

        Each item must split into exactly three parts on ``" - "``; anything else
        is reported and skipped.
        """
        for triple in triples:
            parts = triple.split(" - ")
            if len(parts) != 3:
                # NOTE(review): a label that itself contains " - " ends up here
                # and is dropped; confirm whether that is acceptable.
                print(f"Skipping malformed triple: {triple}")
                continue

            subject, predicate, obj = parts
            self.grouped_info.setdefault(subject, {}).setdefault(predicate, set()).add(obj)

    def get_optimized_triples(self):
        """Render the grouped facts as one block of text per subject.

        Objects are emitted in sorted order: iterating the set directly made the
        rendered text differ from run to run for identical input.
        """
        optimized_list = []
        for subject, predicates in self.grouped_info.items():
            # One " predicate - obj1, obj2" line per predicate
            predicate_object_pairs = [
                f" {pred} - {', '.join(sorted(objs))}" for pred, objs in predicates.items()
            ]
            merged_facts = ",\n".join(predicate_object_pairs)
            optimized_list.append(f"{subject}: \n{merged_facts}")

        return "\n".join(optimized_list)

    def reset(self):
        """Forget every triple added so far."""
        self.grouped_info.clear()
        

# # Example usage
# optimizer = KnowledgeOptimizer()
# optimizer.add_triples(extracted_knowledge)
# optimized_knowledge = optimizer.get_optimized_triples()
# print(optimized_knowledge)
# optimizer.reset()

PromptCreator: generates a structured prompt by merging input text with optimized knowledge, ensuring external information is clearly integrated.

Inputs: raw text (or file path) and structured knowledge.

Outputs: a formatted prompt combining both elements.

In [None]:
class PromptCreator:
    """Wrap the knowledge and the article in ``[KNOWLEDGE]`` / ``[TEXT]`` sections."""

    def __init__(self):
        pass

    def create_prompt(self, input_text, optimized_knowledge):
        """Build the augmented prompt.

        ``input_text`` is read as a UTF-8 file when it names an existing file;
        otherwise it is used verbatim as the article text.
        """
        text = input_text
        if os.path.isfile(input_text):
            with open(input_text, "r", encoding="utf-8") as file:
                text = file.read()

        # Two leading blank lines and an indented trailing line are part of the
        # emitted prompt.
        prompt_lines = [
            "",
            "",
            "[KNOWLEDGE]",
            f"{optimized_knowledge}",
            "[/KNOWLEDGE]",
            "",
            "[TEXT]",
            f"{text}",
            "[/TEXT]",
            "        ",
        ]
        return "\n".join(prompt_lines)
    
#Example usage:
# prompt_creator = PromptCreator()
# augmented_text = prompt_creator.create_prompt("InputText.txt", optimized_knowledge)
# # prompt_creator.save_prompt(augmented_text, "AugmentedPrompt.txt")
# print(augmented_text)

MistralSummarizer: generates summaries using Mistral Large with optional knowledge augmentation to integrate relevant external facts.

Inputs:
augmented text (includes [KNOWLEDGE] and [TEXT]).

compression ratio for output length.

knowledge augmentation flag (True/False).

Outputs: a concise summary, optionally enriched with bold-highlighted knowledge.

In [None]:
class MistralSummarizer:
    """Summarize text through the Mistral chat-completions HTTP API.

    Args:
        api_key: bearer token sent in the ``Authorization`` header.
        timeout: seconds passed to ``requests.post`` so a stalled request cannot
            hang the run indefinitely.
    """

    MODEL = "mistral-large-2407"
    TEMPERATURE = 0.4
    # Whitespace that surrounds the text inside the user message.
    _USER_PROMPT_INDENT = " " * 20

    def __init__(self, api_key, timeout=120):
        self.api_key = api_key
        self.api_url = "https://api.mistral.ai/v1/chat/completions"
        self.timeout = timeout

    def count_tokens(self, augmented_text):
        """Return (and remember in ``self.n_tokens``) the token count of ``augmented_text``.

        The tokenizer is loaded lazily on first use and then reused.
        """
        if not hasattr(self, 'tokenizer'):
            self.tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

        tokens = self.tokenizer(augmented_text, return_tensors="pt")
        self.n_tokens = len(tokens.input_ids[0])
        return self.n_tokens

    def _build_payload(self, augmented_text, num_tokens, use_knowledge):
        """Assemble the request body; only the two prompt strings depend on the mode."""
        if use_knowledge:
            system_prompt = "You are a helpful assistant that summarizes text directly introducing relevant facts from the [KNOWLEDGE] section accurately using **bold**."
            instruction = "Please summarize the following text concisely. When information in the [TEXT] relates to entities mentioned in the [KNOWLEDGE] section, incorporate those relevant facts, using **bold**, to provide additional information."
        else:
            system_prompt = "You are a helpful assistant that summarizes text directly."
            instruction = "Please summarize the following text concisely. "

        pad = self._USER_PROMPT_INDENT
        user_prompt = f"{instruction}\n\n{pad}{augmented_text}\n{pad}\n{pad}"

        return {
            "model": self.MODEL,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            "max_tokens": num_tokens,
            "temperature": self.TEMPERATURE,
        }

    def summarize(self, augmented_text, compression_ratio=0.4, knwoledge_augmentation=True,
                  knowledge_augmentation=None):
        """Request a summary of ``augmented_text``.

        Args:
            augmented_text: text to summarize (optionally with [KNOWLEDGE]/[TEXT] sections).
            compression_ratio: ``max_tokens`` is ``max(100, tokens(input) * ratio)``.
            knwoledge_augmentation: (historical misspelling, kept for existing
                callers) whether to ask the model to weave in [KNOWLEDGE] facts.
            knowledge_augmentation: correctly spelled alias; when not ``None`` it
                takes precedence over ``knwoledge_augmentation``.

        Returns:
            The stripped content of the first returned choice.

        Raises:
            Exception: if the API responds with a status code other than 200.
        """
        if knowledge_augmentation is not None:
            knwoledge_augmentation = knowledge_augmentation

        input_length = self.count_tokens(augmented_text)
        num_tokens = max(100, int(input_length * compression_ratio))
        payload = self._build_payload(augmented_text, num_tokens, knwoledge_augmentation)

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

        response = requests.post(
            self.api_url, headers=headers, data=json.dumps(payload), timeout=self.timeout
        )

        if response.status_code != 200:
            raise Exception(f"Mistral API request failed with status code {response.status_code}: {response.text}")

        result = response.json()
        summary = result["choices"][0]["message"]["content"].strip()
        return summary


#  # Example Usage
# api_key = os.environ["MISTRAL_API_KEY"]  # read the key from the environment; never hard-code it
# summarizer = MistralSummarizer(api_key) 
# summary = summarizer.summarize(augmented_text)
# print(f"\n Summary: {summary}")

KnowledgeAugmentationPipeline: executes a full knowledge-augmented summarization pipeline, integrating Wikidata knowledge into summaries using Mistral Large.

Inputs: path to input text or raw text.

Outputs:
augmented summary (with knowledge integration).

non-augmented summary (standard summarization).

structured input prompt used for summarization.

In [None]:
class KnowledgeAugmentationPipeline:
    """End-to-end run: entities -> Wikidata facts -> prompt -> two summaries."""

    def __init__(self, api_key, use_cache=True):
        self.api_key = api_key
        self.entity_finder = EntityFinder()
        self.cache_manager = CacheManager(use_cache=use_cache)  # Singleton handles caching
        self.entity_linker = EntityLinker()
        self.knowledge_extractor = KnowledgeExtractor()
        self.optimizer = KnowledgeOptimizer()
        self.prompt_creator = PromptCreator()
        self.summarizer = MistralSummarizer(api_key)

    @staticmethod
    def _read_input(input_path):
        """Return the file's contents if ``input_path`` names a file, else ``input_path`` itself."""
        if os.path.isfile(input_path):
            with open(input_path, "r", encoding="utf-8") as file:
                return file.read()
        return input_path

    def process_text(self, input_path):
        """Summarize an article with and without knowledge augmentation.

        Args:
            input_path: raw article text, or the path of a UTF-8 file holding it.

        Returns:
            ``(augmented_summary, non_augmented_summary, ka_text)`` where
            ``ka_text`` is the prompt used for the augmented summary.
        """
        # Resolve the input once. The baseline summary used to receive
        # ``input_path`` unchanged, so for a file input it summarized the path
        # string rather than the article.
        text = self._read_input(input_path)

        try:
            extracted_entities = self.entity_finder.ner(text)
            sorted_entities = self.entity_finder.get_sorted_entities(extracted_entities)

            wikidata_entities = self.entity_linker.link_entities(sorted_entities)

            extracted_info = self.knowledge_extractor.extract_knowledge(wikidata_entities)

            self.optimizer.add_triples(extracted_info)
            optimized_knowledge = self.optimizer.get_optimized_triples()

            ka_text = self.prompt_creator.create_prompt(text, optimized_knowledge)

            augmented_summary = self.summarizer.summarize(ka_text, knwoledge_augmentation=True)
            non_augmented_summary = self.summarizer.summarize(text, knwoledge_augmentation=False)
        finally:
            # Always clear the optimizer, so a failure on one article cannot leak
            # its triples into the next article's prompt.
            self.optimizer.reset()

        return augmented_summary, non_augmented_summary, ka_text
        

Example Execution

In [None]:
# The Mistral API key comes from the environment so it is never stored in the notebook.
api_key = os.environ.get("MISTRAL_API_KEY")
if not api_key:
    raise RuntimeError("Set the MISTRAL_API_KEY environment variable before running this cell.")

# NOTE(review): the pipeline treats this as a file path only if such a file exists;
# otherwise the literal string is summarized. Confirm the intended file name.
full_article = "InputText"

pipeline = KnowledgeAugmentationPipeline(api_key=api_key, use_cache=False)

# Run the pipeline before opening the results file, so a failed run appends nothing.
augmented_summary, non_augmented_summary, ka_text = pipeline.process_text(full_article)

with open("comparisons.txt", "a", encoding="utf-8") as f:
    f.write("Article:\n")
    f.write(f"Knowledge-Augmented Text (ka_text):\n{ka_text}\n\n")
    f.write(f"Augmented Summary:\n{augmented_summary}\n\n")
    f.write(f"Non-Augmented Summary:\n{non_augmented_summary}\n\n")
    f.write("=" * 80 + "\n\n")

print("Summarization results saved to comparisons.txt.")

In [None]:
# The Mistral API key comes from the environment so it is never stored in the notebook.
api_key = os.environ.get("MISTRAL_API_KEY")
if not api_key:
    raise RuntimeError("Set the MISTRAL_API_KEY environment variable before running this cell.")

# Initialize your Knowledge Augmentation Pipeline
pipeline = KnowledgeAugmentationPipeline(api_key=api_key, use_cache=True)

# Number of articles to process
num_articles = 1
results_path = "comparisons.txt"

# Load the XSum dataset
dataset = load_dataset("xsum", split="test", trust_remote_code=True)

# Open file to save results
with open(results_path, "a", encoding="utf-8") as f:
    for i, article in enumerate(dataset):
        if i >= num_articles:  # Stop after processing the specified number of articles
            break

        full_article = article["document"]  # Access the full article text

        # Process the article to get both augmented and non-augmented summaries
        augmented_summary, non_augmented_summary, ka_text = pipeline.process_text(full_article)

        # Save results to the file
        f.write(f"Article {i+1}:\n")
        f.write(f"Knowledge-Augmented Text (ka_text):\n{ka_text}\n\n")
        f.write(f"Augmented Summary:\n{augmented_summary}\n\n")
        f.write(f"Non-Augmented Summary:\n{non_augmented_summary}\n\n")
        f.write("=" * 80 + "\n\n")  # Separator for readability

# Report the file that was actually written (the message used to name sum_results.txt).
print(f"Summarization results saved to {results_path}.")