<a href="https://colab.research.google.com/github/Hicham-Yezza/Neurosymbolic-LLM-Project/blob/main/KG_LLM_Project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# KG-Enhanced Summarisation Project - Hicham Yezza 2024

In [None]:
# Uninstall the incompatible version of pyarrow
!pip uninstall pyarrow -y

# Install the compatible version of pyarrow
!pip install pyarrow==14.0.1

# Restart the runtime (you will need to manually restart the runtime after this step)

In [None]:
!export LC_ALL=C.UTF-8
!export LANG=C.UTF-8

In [None]:
import torch

# Pick the compute device: the CUDA GPU when one is visible, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using GPU: {torch.cuda.get_device_name(0)}" if device.type == "cuda" else "Using CPU")

# Allow TF32 for CUDA matrix multiplications
torch.backends.cuda.matmul.allow_tf32 = True

In [None]:
# Re-derive the device with the same rule as the previous cell and report it
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Install all necessary libraries in one line, avoiding redundancy
!pip install wandb rouge_score sacrebleu bert-score spacy networkx datasets pandas tqdm transformers nltk torch evaluate node2vec sentence_transformers
# Download SpaCy models (both transformer-based and small model)
!python -m spacy download en_core_web_trf
!python -m spacy download en_core_web_sm

In [None]:
# Standard Library Imports
import gc
import json
import logging
# Stable alias for the standard-library logger: the bare name `logging` is rebound
# further down by `from transformers import logging`.
import logging as py_logging
import os
import re
import time
import warnings
from collections import defaultdict
from typing import List, Tuple, Dict

# Third-Party Library Imports
import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stats
from tqdm import tqdm, notebook  # For progress bars
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import autocast, GradScaler
from sklearn.metrics.pairwise import cosine_similarity
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

# SpaCy for NLP
import spacy

# Hugging Face Transformers
# NOTE: from here on `logging` refers to transformers' logging utilities
# (used for `logging.set_verbosity_error()`); use `py_logging` for log messages.
from transformers import logging
from transformers import (
    BartTokenizer, BartForConditionalGeneration, AdamW, get_scheduler, BartConfig, pipeline
)

# Datasets, Evaluation, and Metrics
from datasets import load_dataset
from evaluate import load as load_metric
from rouge_score import rouge_scorer
import sacrebleu
import bert_score

# Knowledge Graph and Node Embedding
from node2vec import Node2Vec

# Weights and Biases for Experiment Tracking
import wandb

# Sentence Transformers for Embeddings
from sentence_transformers import SentenceTransformer

In [None]:
# Setup
# Unix timestamp (whole seconds) taken once; later cells use it in run and file names
timestamp = int(time.time())
# Hide the "Some weights of RobertaModel were not initialized..." warning
warnings.filterwarnings("ignore", message="Some weights of RobertaModel were not initialized from the model checkpoint")

# `logging` here is the module imported with `from transformers import logging`,
# so this sets the transformers verbosity to ERROR
logging.set_verbosity_error()

In [None]:
# Load the SpaCy model for NER (Transformer-based for higher accuracy)
# NOTE(review): KnowledgeGraphExtractor loads its own pipeline into `self.nlp`; no visible
# cell reads this module-level `nlp` — confirm it is still needed before keeping the extra load.
nlp = spacy.load("en_core_web_trf")

In [None]:
# Entity extraction, relation extraction, KG generation

In [None]:
# Define KnowledgeGraphExtractor class with improved entity extraction and relation extraction
class KnowledgeGraphExtractor:
    """Extract entities and (subject, relation, object) triples from spaCy docs.

    Entities are the named entities plus the noun chunks of a document.
    Relations come from a token Matcher holding a few subject-verb-object
    patterns and from a walk over the dependency parse.
    """

    def __init__(self, use_trf_model=True):
        """Load the spaCy pipeline and register the matcher patterns.

        Args:
            use_trf_model: load "en_core_web_trf" when True, otherwise "en_core_web_sm".
        """
        if use_trf_model:
            # prefer_gpu() uses the GPU when one is available and silently stays on
            # CPU otherwise; require_gpu() would raise on a CPU-only runtime.
            spacy.prefer_gpu()
            self.nlp = spacy.load("en_core_web_trf")
        else:
            self.nlp = spacy.load("en_core_web_sm")
        self.matcher = spacy.matcher.Matcher(self.nlp.vocab)
        self._add_patterns()

    def _add_patterns(self):
        """Register the subject-verb-object token patterns under the "SVO" key."""
        # Each pattern describes a run of adjacent tokens by dependency label
        patterns = [
            [{'DEP': 'nsubj'}, {'DEP': 'ROOT'}, {'DEP': 'dobj'}],
            [{'DEP': 'nsubj'}, {'DEP': 'ROOT'}, {'DEP': 'prep'}, {'DEP': 'pobj'}],
            [{'DEP': 'nsubj'}, {'DEP': 'ROOT'}, {'DEP': 'attr'}],
            [{'DEP': 'nsubjpass'}, {'DEP': 'ROOT'}, {'DEP': 'pobj'}],  # For passive voice
        ]
        for pattern in patterns:
            self.matcher.add("SVO", [pattern])

    def _merge_similar_entities(self, entities):
        """Collapse entities that differ only by case or by the articles the/a/an.

        The first (entity, label) pair seen for each normalised key is kept, so
        the order of `entities` decides which surface form and label survive.
        Partial-name matches (e.g. "John Terry" vs "Terry") are NOT merged.
        """
        merged_entities = {}
        for entity, label in entities:
            key = re.sub(r"\b(the|a|an)\b", "", entity.lower()).strip()  # Remove articles
            if key not in merged_entities:
                merged_entities[key] = (entity, label)
        return list(merged_entities.values())

    def _filter_low_value_entities(self, entities):
        """Drop entities whose whole text is one of a few pronoun-like words."""
        low_value_terms = {"this", "we", "his", "her", "it", "they", "he", "she"}
        return [(entity, label) for entity, label in entities if entity.lower() not in low_value_terms]

    def clean_entities(self, entities):
        """Remove case-insensitive duplicates, very short strings and pure numbers."""
        cleaned_entities = []
        seen = set()

        for entity, label in entities:
            if entity.lower() not in seen:
                seen.add(entity.lower())
                if len(entity) > 2 and not entity.isdigit():  # Keep only 3+ character, non-numeric entities
                    cleaned_entities.append((entity, label))

        return cleaned_entities

    def extract_entities(self, doc):
        """Return the cleaned list of (text, label) entities for a spaCy doc.

        Named entities are listed before noun chunks so that, when both produce
        the same text, the NER label wins over "NOUN_CHUNK" and the result does
        not depend on set iteration order.
        """
        py_logging.info(f"Document text: {doc.text[:100]}")  # Log first 100 characters of the document
        py_logging.info(f"Detected entities: {[ent.text for ent in doc.ents]}")

        # Named entities first, noun chunks second (order matters for merging)
        candidates = [(ent.text.strip(), ent.label_) for ent in doc.ents]
        candidates.extend((chunk.text.strip(), "NOUN_CHUNK") for chunk in doc.noun_chunks)

        # Merge similar entities, filter out low-value ones, and clean the entities
        entities = self._merge_similar_entities(candidates)
        entities = self._filter_low_value_entities(entities)
        entities = self.clean_entities(entities)

        return entities

    def extract_relations(self, doc):
        """Return filtered (subject, relation, object) triples for a spaCy doc."""
        matches = self.matcher(doc)
        relations = []

        # Triples from the matcher: first token, second token, last token of the span
        for match_id, start, end in matches:
            span = doc[start:end]
            if len(span) >= 3:
                subj, verb, obj = span[0].text.strip(), span[1].text, span[-1].text.strip()
                relations.append((subj, verb, obj))

        # Triples from the dependency parse: subject -> head verb -> object
        for token in doc:
            if token.dep_ in ("nsubj", "nsubjpass"):
                for child in token.head.children:
                    if child.dep_ == "dobj":
                        relations.append((token.text.strip(), token.head.text, child.text.strip()))
                    elif child.dep_ == "prep":
                        # Relation text is "<verb> <preposition>", object is the preposition's pobj
                        for pobj in child.children:
                            if pobj.dep_ == "pobj":
                                relations.append((token.text.strip(), token.head.text + " " + child.text, pobj.text.strip()))

        # Filter out redundant or trivial relations
        relations = self._filter_trivial_relations(relations)
        return relations

    def _filter_trivial_relations(self, relations):
        """Drop triples with a low-value verb and remove exact duplicates.

        Duplicates arise because the matcher and the dependency walk can emit
        the same triple; first-occurrence order is preserved.
        """
        low_value_verbs = {"is", "are", "was", "were", "have", "has", "had", "do", "did"}
        kept = [(subj, verb, obj) for subj, verb, obj in relations if verb.lower() not in low_value_verbs]
        return list(dict.fromkeys(kept))

# Define KnowledgeGraphGenerator class for generating and post-processing the KG
class KnowledgeGraphGenerator:
    """Build a NetworkX knowledge graph from raw text via KnowledgeGraphExtractor."""

    def __init__(self, use_trf_model=True):
        """Create the underlying extractor; `use_trf_model` is passed through to it."""
        self.extractor = KnowledgeGraphExtractor(use_trf_model)

    def create_graph(self, text):
        """Parse `text` and return an undirected graph of its entities and relations.

        Nodes are entity strings carrying a `label` attribute. An edge with a
        `relation` attribute is added only when both ends of a triple are
        already entity nodes.
        """
        # Process text and extract entities and relations
        doc = self.extractor.nlp(text)
        entities = self.extractor.extract_entities(doc)
        relations = self.extractor.extract_relations(doc)

        # Log extracted entities and relations for debugging, limiting the output size
        py_logging.info(f"Extracted entities: {entities[:5]}...")  # Limit entity log output to first 5 items
        py_logging.info(f"Extracted relations: {relations[:5]}...")  # Limit relation log output to first 5 items

        # Create the knowledge graph using NetworkX
        G = nx.Graph()

        # Add nodes for each entity
        for entity, label in entities:
            G.add_node(entity, label=label)

        # Add edges for each relation
        for subj, pred, obj in relations:
            if G.has_node(subj) and G.has_node(obj):
                G.add_edge(subj, obj, relation=pred)

        return G

    def clean_knowledge_graph(self, G):
        """Remove low-value and isolated nodes from `G` in place and return it.

        Low-value nodes (names starting with this/that/we/he/she/they as a whole
        word) are removed first, so that nodes left without any edge by that
        step are also caught by the isolated-node pass.
        """
        low_value_nodes = [node for node in G.nodes if re.match(r"\b(this|that|we|he|she|they)\b", node, re.I)]
        G.remove_nodes_from(low_value_nodes)

        # Degree-0 nodes carry no relation information
        isolated_nodes = [node for node, degree in G.degree() if degree == 0]
        G.remove_nodes_from(isolated_nodes)

        return G

In [None]:
# Wandb setup: authenticate this session with Weights & Biases
wandb.login()

# Configuration dict for a wandb sweep (hyperparameter tuning)
# NOTE(review): the visible evaluation code logs metrics named like "<name>_rouge1";
# no metric literally called 'rouge' is logged — confirm the sweep metric name.

sweep_config = {
    'method': 'bayes',  # Bayesian search
    'metric': {'name': 'rouge', 'goal': 'maximize'},  # Optimize based on Rouge score
    'parameters': {
        'max_length': {'values': [100, 150, 200]},
        'min_length': {'values': [40, 50, 60]},
        'num_beams': {'values': [4, 6, 8]},
        'graph_attention_weight': {'values': [0.3, 0.5, 0.7]},  # For graph-aware models
        'walk_length': {'values': [10, 30, 50]},  # Node2Vec
        'num_walks': {'values': [100, 200, 300]},  # Node2Vec
        'dimensions': {'values': [64, 128]},  # Node2Vec
    }
}

In [None]:
# Implementing the four summarizer classes

In [None]:
class BaseBARTSummarizer: ## base model
    """Plain BART summarizer used as the baseline (no knowledge graph)."""

    def __init__(self, model_name: str = 'facebook/bart-large'):
        """Load tokenizer and model for `model_name` and place the model on `device`."""
        self.tokenizer = BartTokenizer.from_pretrained(model_name)
        bart = BartForConditionalGeneration.from_pretrained(model_name)
        bart.to(device)
        self.model = bart

    def summarize(self, text: str, max_length: int = 150, min_length: int = 40) -> str:
        """Return a beam-search (4 beams) summary of `text`.

        The input is truncated to 1024 tokens; `max_length` / `min_length`
        bound the generated sequence.
        """
        encoded = self.tokenizer(text, return_tensors="pt", max_length=1024, truncation=True).to(device)
        generated = self.model.generate(
            encoded['input_ids'],
            max_length=max_length,
            min_length=min_length,
            num_beams=4,
        )
        best_sequence = generated[0]
        return self.tokenizer.decode(best_sequence, skip_special_tokens=True)

In [None]:
class KGEnhancedSummarizer: # KG-Enhanced Summarizer Base Class
    """Base class for summarizers that use a knowledge graph built from the input text."""

    def __init__(self, model_name: str = 'facebook/bart-large'):
        """Load tokenizer and model for `model_name` and create the KG generator."""
        self.tokenizer = BartTokenizer.from_pretrained(model_name)
        self.model = BartForConditionalGeneration.from_pretrained(model_name)
        # Subclasses move their inputs to `device`, so the model must live there too
        # (same as BaseBARTSummarizer) instead of relying on a later cell to move it.
        self.model.to(device)
        self.kg_generator = KnowledgeGraphGenerator()

    def summarize(self, text: str, max_length: int = 150, min_length: int = 40) -> str:
        """Produce a summary of `text`; must be overridden.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("This method should be implemented by subclasses")

    def save_model(self, path: str):
        """Save model weights and tokenizer files under `path`."""
        self.model.save_pretrained(path)
        self.tokenizer.save_pretrained(path)

In [None]:
class KGEnhancedInputSummarizer(KGEnhancedSummarizer):  # With structured prompting
    """Summarizer that prepends the extracted entities and relations to the article."""

    def summarize(self, text: str, max_length: int = 150, min_length: int = 40) -> str:
        """Summarize `text` from a prompt listing KG entities and relations, then the article.

        The prompt is truncated to 1024 tokens and decoded with 4-beam search.
        """
        # Knowledge graph of the input article
        G = self.kg_generator.create_graph(text)

        # Flatten the graph into two comma-separated strings
        entities = ", ".join(str(node) for node in G.nodes())
        relation_parts = []
        for u, v, edge_data in G.edges(data=True):
            relation_parts.append(f"{u}-{edge_data.get('relation', 'related')}-{v}")
        relations = ", ".join(relation_parts)

        # Prompt layout: KG information, then the article, then the task statement
        structured_prompt = f"""
        Knowledge Graph Information:
        - Entities: {entities}
        - Relations: {relations}

        Article:
        {text}

        Task: Summarize the article considering the knowledge graph information above, focusing on how the entities and relationships influence the article's content.
        """

        # Encode the prompt and generate
        encoded = self.tokenizer(structured_prompt, return_tensors="pt", max_length=1024, truncation=True).to(device)
        generation_args = dict(max_length=max_length, min_length=min_length, num_beams=4)
        summary_ids = self.model.generate(encoded['input_ids'], **generation_args)

        # First returned sequence, without special tokens
        return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)

In [None]:
class GraphAwareAttentionSummarizer(KGEnhancedSummarizer): ## KG-Enhanced Graph-Aware summarizer
    """
    This class enhances BART-based summarization with graph-aware attention,
    including second-order neighbors in the attention mask.
    """

    def __init__(self, model_name='facebook/bart-large', device='cuda'):
        """
        Args:
            model_name: checkpoint to load for both tokenizer and model.
            device: requested device; used only when CUDA is available, otherwise CPU.
        """
        # Pass model_name on so the base class builds the tokenizer for the same checkpoint
        # (previously the base class always loaded the default one).
        super().__init__(model_name)
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        # Replace the stock BART created by the base class with the graph-aware variant
        self.model = BartForConditionalGenerationWithGraphAttention.from_pretrained(model_name)
        self.model.to(self.device)

    def _create_graph_attention_mask(self, graph_data, input_length):
        """
        Generates an attention mask based on the provided knowledge graph data,
        including second-order neighbors.

        Args:
            graph_data: Mapping that may hold an 'adjacency_matrix' tensor.
            input_length: The length of the input sequence for the summarization model.

        Returns:
            A graph-aware attention mask of shape (input_length, input_length).
            An all-ones mask is returned when the matrix is missing or its first
            dimension differs from `input_length`.
        """
        adjacency_matrix = graph_data.get('adjacency_matrix', None)

        if adjacency_matrix is None or adjacency_matrix.shape[0] != input_length:
            py_logging.warning("Graph data missing or mismatched with input length.")
            return torch.ones((input_length, input_length), device=self.device)  # Default to all 1s (no extra attention)

        # Row-wise softmax turns the adjacency values into weights that sum to 1
        normalized_adjacency_matrix = F.softmax(adjacency_matrix.float(), dim=-1)

        # Average with the transpose so the mask is symmetric
        graph_attention_mask = 0.5 * (normalized_adjacency_matrix + normalized_adjacency_matrix.T)

        # Add second-order neighbors by multiplying the mask with itself
        second_order_neighbors = torch.mm(graph_attention_mask, graph_attention_mask)
        graph_attention_mask = graph_attention_mask + second_order_neighbors

        # Re-normalize rows; the epsilon guards against division by zero
        graph_attention_mask = graph_attention_mask / (graph_attention_mask.sum(dim=-1, keepdim=True) + 1e-9)

        return graph_attention_mask

    def summarize(self, input_text, graph_data=None, max_length=142, min_length=56):
        """
        Summarizes the given input text with optional graph-aware attention.

        Args:
            input_text: The input text to summarize.
            graph_data: Knowledge graph data, including adjacency matrix and entities.
                When falsy, no graph mask is built and None is passed to the model.
            max_length: Maximum length of the generated summary.
            min_length: Minimum length of the generated summary.

        Returns:
            The generated summary text.
        """
        inputs = self.tokenizer(input_text, return_tensors='pt', max_length=1024, truncation=True)
        input_ids = inputs['input_ids'].to(self.device)
        attention_mask = inputs['attention_mask'].to(self.device)

        # If graph data is provided, create the graph attention mask
        graph_attention_mask = self._create_graph_attention_mask(graph_data, input_ids.shape[1]) if graph_data else None

        # Perform summarization with the model
        summary_ids = self.model.generate(
            input_ids,
            attention_mask=attention_mask,
            graph_attention_mask=graph_attention_mask,  # Pass the graph-aware mask
            max_length=max_length,
            min_length=min_length,
            num_beams=4,
            length_penalty=2.0,
            early_stopping=True
        )

        # Decode and return the generated summary
        return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    def visualize_graph_attention(self, attention_data, graph_data):
        """
        Visualizes the graph structure for debugging purposes.

        Args:
            attention_data: Attention weights from the model (currently not used by the plot).
            graph_data: Mapping holding the 'adjacency_matrix' tensor to draw.
        """
        adjacency_matrix = graph_data.get('adjacency_matrix', None)
        if adjacency_matrix is None:
            py_logging.error("No graph data to visualize.")
            return

        # from_numpy_array replaces from_numpy_matrix, which was removed in NetworkX 3.0
        graph = nx.from_numpy_array(adjacency_matrix.cpu().numpy())
        plt.figure(figsize=(8, 6))
        pos = nx.spring_layout(graph)
        nx.draw(graph, pos, with_labels=True, node_color="lightblue", edge_color="gray", node_size=500, font_size=10)

        plt.title("Graph Attention Visualization")
        plt.show()


In [None]:
class KGConsistencyCheckingSummarizer(KGEnhancedSummarizer):
    """Summarizer that regenerates its summary when the summary's knowledge graph
    scores below a consistency threshold against the article's knowledge graph."""

    def __init__(self, model_name: str = 'facebook/bart-large', threshold: float = 0.3):
        """
        Args:
            model_name: checkpoint passed to the base class.
            threshold: minimum consistency score accepted without regeneration.
        """
        super().__init__(model_name)
        self.consistency_checker = ImprovedConsistencyChecker()
        self.threshold = threshold  # Allowing a flexible threshold

    def summarize(self, text: str, max_length: int = 150, min_length: int = 40) -> str:
        """Summarize `text`, retrying with different decoding settings if inconsistent.

        Returns:
            The final summary. An empty string is returned when no knowledge-graph
            nodes could be extracted from the article.
        """
        # Step 1: Generate the knowledge graph for the original article
        G_article = self.kg_generator.create_graph(text)
        if not G_article.nodes():
            py_logging.error("No nodes generated in the KG for the article.")
            return ""

        # Step 2: Generate the initial summary
        inputs = self.tokenizer(text, return_tensors="pt", max_length=1024, truncation=True).to(device)
        summary_ids = self.model.generate(inputs['input_ids'], max_length=max_length, min_length=min_length, num_beams=4)
        summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)

        # Step 3: Generate the graph for the generated summary
        G_summary = self.kg_generator.create_graph(summary)
        if not G_summary.nodes():
            py_logging.error("No nodes generated in the KG for the summary.")
            return summary

        # Step 4: Calculate the consistency score between the article and the summary
        consistency_score = self.consistency_checker.calculate_consistency(G_article, G_summary)
        py_logging.info(f"Initial Consistency Score: {consistency_score}")

        # Step 5: If the consistency score is below the threshold, regenerate the summary
        retry_count = 0
        max_retries = 5  # Avoid excessive retries
        while consistency_score < self.threshold and retry_count < max_retries:
            py_logging.info("Consistency score below threshold, regenerating summary with updated parameters.")
            summary_ids = self.model.generate(inputs['input_ids'], max_length=max_length, min_length=min_length, num_beams=6, repetition_penalty=2.0)
            regenerated = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
            retry_count += 1

            # Beam search is deterministic: an unchanged summary means every further
            # retry would return the same text and score, so stop instead of redoing it.
            if regenerated == summary:
                py_logging.info("Regeneration produced an identical summary; stopping retries.")
                break
            summary = regenerated

            # Step 6: Recalculate consistency for the new summary
            G_summary = self.kg_generator.create_graph(summary)
            consistency_score = self.consistency_checker.calculate_consistency(G_article, G_summary)
            py_logging.info(f"Revised Consistency Score: {consistency_score}")

        # Return the final summary (either the original or the regenerated one)
        return summary

In [None]:
class ImprovedConsistencyChecker:
    """Score how similar two knowledge graphs are using sentence embeddings."""

    def __init__(self):
        # Load a pre-trained sentence transformer model
        self.model = SentenceTransformer('all-MiniLM-L6-v2')

    def calculate_consistency(self, G1: nx.Graph, G2: nx.Graph) -> float:
        """Return the mean of node similarity and edge-relation similarity.

        For each node (and each relation label) of `G1`, the best cosine
        similarity against `G2` is taken and these best scores are averaged.

        Returns:
            A float score; 0.0 when either graph has no nodes or node encoding fails.
        """
        # Extract nodes from both graphs
        nodes1 = list(G1.nodes())
        nodes2 = list(G2.nodes())

        # Without nodes on either side there is nothing to compare: score 0.0
        if not nodes1 or not nodes2:
            py_logging.warning("One or both graphs have no nodes.")
            return 0.0

        # Generate node embeddings using the pre-trained model
        try:
            embeddings1 = self.model.encode(nodes1)
            embeddings2 = self.model.encode(nodes2)
        except Exception as e:
            py_logging.error(f"Error encoding nodes: {e}")
            return 0.0  # Return default score in case of failure

        # Best match in G2 for every G1 node, averaged
        similarity_matrix = cosine_similarity(embeddings1, embeddings2)
        node_similarity = similarity_matrix.max(axis=1).mean()

        # Extract edges and relations from both graphs
        edges1 = list(G1.edges(data=True))
        edges2 = list(G2.edges(data=True))

        # Handle cases where one or both graphs have no edges
        if not edges1 and not edges2:
            edge_similarity = 1.0  # Both graphs have no edges, consider them similar
        elif not edges1 or not edges2:
            edge_similarity = 0.0  # One graph has edges, the other doesn't
        else:
            # Extract relation labels and compute embeddings for the relations
            edge_relations1 = [edge_data.get('relation', '') for _, _, edge_data in edges1]
            edge_relations2 = [edge_data.get('relation', '') for _, _, edge_data in edges2]

            try:
                relation_embeddings1 = self.model.encode(edge_relations1)
                relation_embeddings2 = self.model.encode(edge_relations2)
                edge_similarity_matrix = cosine_similarity(relation_embeddings1, relation_embeddings2)
                edge_similarity = edge_similarity_matrix.max(axis=1).mean()
            except Exception as e:
                py_logging.error(f"Error encoding edges: {e}")
                edge_similarity = 0.0  # Return default score in case of failure

        # Plain Python float, matching the annotated return type
        return float((node_similarity + edge_similarity) / 2)





In [None]:
# Custom BART models

In [None]:
class BartForConditionalGenerationWithKGEmbeddings(BartForConditionalGeneration):
    """BART variant that concatenates per-token KG embeddings to the token embeddings
    and projects the result back to the model dimension before encoding."""

    # Keyword arguments that are forwarded to the encoder call; everything else
    # (labels, decoder inputs, ...) is only passed to the full model forward.
    _ENCODER_KWARGS = ("head_mask", "output_attentions", "output_hidden_states", "return_dict")

    def __init__(self, config):
        """
        Args:
            config: BART config. An optional `kg_embedding_dim` attribute sets the
                width of the KG embeddings; it defaults to 64.
        """
        super().__init__(config)
        kg_dim = getattr(config, "kg_embedding_dim", 64)
        self.kg_embedding_projection = torch.nn.Linear(config.d_model + kg_dim, config.d_model)
        self.kg_embedding_projection.to(device)  # Move to the correct device

    def forward(self, input_ids, attention_mask=None, kg_embeddings=None, **kwargs):
        """Run BART, optionally fusing `kg_embeddings` into the encoder input.

        Args:
            input_ids: token ids for the encoder.
            attention_mask: encoder attention mask.
            kg_embeddings: tensor concatenated to the token embeddings on the last
                dimension. # assumes shape (batch, seq, kg_dim) — TODO confirm
            **kwargs: remaining arguments of `BartForConditionalGeneration.forward`.
        """
        # Nothing to fuse, or the encoder has already been run by the caller
        if kg_embeddings is None or kwargs.get("encoder_outputs") is not None:
            return super().forward(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

        encoder = self.model.encoder
        # Use the encoder's `embed_scale` when it exposes one; otherwise apply no extra scaling here
        embed_scale = getattr(encoder, "embed_scale", 1.0)
        inputs_embeds = encoder.embed_tokens(input_ids) * embed_scale
        inputs_embeds = torch.cat([inputs_embeds, kg_embeddings], dim=-1)
        inputs_embeds = self.kg_embedding_projection(inputs_embeds)

        # The encoder rejects decoder-side arguments, so pass only what it understands
        encoder_kwargs = {key: kwargs[key] for key in self._ENCODER_KWARGS if key in kwargs}
        encoder_outputs = encoder(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **encoder_kwargs)
        # NOTE(review): no input_ids are passed below, so callers presumably must supply
        # labels or decoder inputs — confirm.
        return super().forward(encoder_outputs=encoder_outputs, attention_mask=attention_mask, **kwargs)

In [None]:
class BartForConditionalGenerationWithGraphAttention(BartForConditionalGeneration):
    """BART variant whose forward pass can blend a graph-derived mask into the attention mask."""

    # Blend weight for the graph mask. It was read in forward() but never defined;
    # set it on an instance to override the default.
    graph_attention_weight = 0.5

    def forward(self, input_ids, attention_mask=None, graph_attention_mask=None, **kwargs):
        """
        The forward method combines the standard attention mask with the graph attention mask.

        Args:
            input_ids: token ids for the encoder (may be None when the caller supplies other inputs).
            attention_mask: standard attention mask; when missing and a graph mask is
                given, an all-ones mask shaped like `input_ids` is used.
            graph_attention_mask: optional mask derived from the knowledge graph.
            **kwargs: remaining arguments of `BartForConditionalGeneration.forward`.
        """
        if graph_attention_mask is not None and attention_mask is None and input_ids is not None:
            # No padding information available: treat every position as visible
            attention_mask = torch.ones_like(input_ids)

        if graph_attention_mask is not None and attention_mask is not None:
            # Combine normal attention with graph-based attention using a scaling factor
            # NOTE(review): the result has one more dimension than `attention_mask`;
            # confirm the parent forward accepts a mask of that shape.
            combined_attention = attention_mask.unsqueeze(1) * (1 + self.graph_attention_weight * graph_attention_mask.unsqueeze(0))
            combined_attention = combined_attention / (combined_attention.sum(dim=-1, keepdim=True) + 1e-9)
        else:
            combined_attention = attention_mask

        # Pass the combined attention through the BART model
        return super().forward(input_ids, attention_mask=combined_attention, **kwargs)

In [None]:
# Evaluation

In [None]:
class SummarizerEvaluator:
    """Score a candidate summary against a reference with ROUGE, BLEU, METEOR and BERTScore."""

    def __init__(self):
        self.rouge_scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
        # `evaluate.load` is imported in this notebook as `load_metric`
        # (the bare module name `evaluate` is not defined).
        self.meteor = load_metric("meteor")
        self.smoothie = SmoothingFunction().method1
        # Build the BERTScore model once instead of on every evaluate() call
        self.bert_scorer = bert_score.BERTScorer(lang="en")

    def evaluate(self, reference: str, candidate: str) -> Dict[str, float]:
        """Return a dict of metric name -> score for one reference/candidate pair.

        Keys: rouge1, rouge2, rougeL (F-measures), bleu, meteor,
        bertscore_precision, bertscore_recall, bertscore_f1.
        """
        rouge_scores = self.rouge_scorer.score(reference, candidate)
        # Sentence-level BLEU on whitespace tokens, smoothed (method1)
        bleu_score = sentence_bleu([reference.split()], candidate.split(), smoothing_function=self.smoothie)
        meteor_score = self.meteor.compute(predictions=[candidate], references=[reference])['meteor']

        # Calculate BERTScore
        P, R, F1 = self.bert_scorer.score([candidate], [reference], verbose=False)

        return {
            "rouge1": rouge_scores['rouge1'].fmeasure,
            "rouge2": rouge_scores['rouge2'].fmeasure,
            "rougeL": rouge_scores['rougeL'].fmeasure,
            "bleu": bleu_score,
            "meteor": meteor_score,
            "bertscore_precision": P.item(),
            "bertscore_recall": R.item(),
            "bertscore_f1": F1.item()
        }

In [None]:
# Comparison

In [None]:
def compare_summarizers(dataset, summarizers: Dict[str, KGEnhancedSummarizer], num_samples: int = 50):
    """Run every summarizer on the first `num_samples` articles and score the outputs.

    Args:
        dataset: dataset with 'article' and 'highlights' fields, supporting `select` and `len`.
        summarizers: mapping of display name -> object with a `summarize(text)` method.
        num_samples: number of leading items to process (capped at the dataset size).

    Returns:
        (results, avg_results, example_summaries):
        results maps name -> list of per-article score dicts (NaN rows for failures),
        avg_results maps name -> per-metric mean over the successful articles,
        example_summaries maps name -> up to 5 dicts with article/reference/generated.
    """
    evaluator = SummarizerEvaluator()
    results = {name: [] for name in summarizers.keys()}
    example_summaries = {name: [] for name in summarizers.keys()}
    metric_names: List[str] = []  # learned from the first successful evaluation

    # select() raises when asked for more rows than exist
    num_samples = min(num_samples, len(dataset))

    for item in tqdm(dataset.select(range(num_samples)), desc="Processing articles"):
        article = item['article']
        reference_summary = item['highlights']

        for name, summarizer in summarizers.items():
            try:
                generated_summary = summarizer.summarize(article)
                scores = evaluator.evaluate(reference_summary, generated_summary)
            except Exception as e:
                print(f"Error with {name} summarizer: {str(e)}")
                # Placeholder; becomes a NaN row once the metric names are known
                results[name].append(None)
                continue

            if not metric_names:
                metric_names = list(scores.keys())
            results[name].append(scores)

            # Save example summaries (let's save the first 5)
            if len(example_summaries[name]) < 5:
                example_summaries[name].append({
                    'article': article,
                    'reference': reference_summary,
                    'generated': generated_summary
                })

            # Log individual sample results to wandb; a logging problem must not
            # discard or duplicate the scores that were already recorded
            try:
                wandb.log({f"{name}_{metric}": score for metric, score in scores.items()})
            except Exception as e:
                print(f"Error logging {name} scores to wandb: {str(e)}")

    # Turn failure placeholders into NaN rows so every row has the same keys
    for name, scores_list in results.items():
        results[name] = [
            {metric: float('nan') for metric in metric_names} if scores is None else scores
            for scores in scores_list
        ]

    # Calculate average scores, ignoring NaN rows from failed articles
    avg_results = {}
    for name, scores_list in results.items():
        if scores_list:
            avg_scores = {metric: float(np.nanmean([s[metric] for s in scores_list])) for metric in metric_names}
        else:
            avg_scores = {}
        avg_results[name] = avg_scores

        # Log average scores to wandb
        if avg_scores:
            wandb.log({f"avg_{name}_{metric}": score for metric, score in avg_scores.items()})

    return results, avg_results, example_summaries

In [None]:
# Main execution

In [None]:
# Start a new wandb run in the "kg-enhanced-summarization" project
run = wandb.init(project="kg-enhanced-summarization", name=f"comparison-run-{timestamp}") # run name embeds the timestamp from the setup cell

In [None]:
# Dataset loading: first 100 examples of the CNN/DailyMail (config '3.0.0') training split
dataset = load_dataset('cnn_dailymail', '3.0.0', split='train[:100]')

In [None]:
print(f"Dataset size: {len(dataset)}")

# Debug: print the dataset summary, its features and the first record, each under a heading
for heading, payload in (
    ("Dataset info:", dataset),
    ("\nDataset features:", dataset.features),
    ("\nFirst item in dataset:", dataset[0]),
):
    print(heading)
    print(payload)

In [None]:
# Display name -> summarizer instance; each constructor is called with its defaults
summarizers = {
    "Base BART": BaseBARTSummarizer(),
    "KG-Enhanced Input": KGEnhancedInputSummarizer(),
    "Graph-Aware Attention": GraphAwareAttentionSummarizer(),
    "KG Consistency Checking": KGConsistencyCheckingSummarizer(),
    ## "KG Embedding Integration": KGEmbeddingIntegrationSummarizer() ## deprecated
}

In [None]:
# Move every summarizer's model to the global `device` chosen at the top of the notebook
for summarizer in summarizers.values():
    summarizer.model.to(device)

In [None]:
# Debug: show the first article and its reference summary, each under a heading
print("\nFirst article:", dataset[0]['article'], sep="\n")
print("\nFirst summary:", dataset[0]['highlights'], sep="\n")

In [None]:
# Run comparison with the default number of samples; returns per-article scores,
# per-summarizer averages and a few example summaries
results, avg_results, example_summaries = compare_summarizers(dataset, summarizers)

In [None]:
# Display results: one row per summarizer, one column per metric
df_results = pd.DataFrame(avg_results).T
print("\nAverage Scores:", df_results, sep="\n")

In [None]:
# Save example summaries
# Explicit UTF-8 so the output does not depend on the platform's default encoding
with open(f'example_summaries_{timestamp}.txt', 'w', encoding='utf-8') as f:
    for name, summaries in example_summaries.items():
        f.write(f"\n\n{name} Summaries:\n")
        for i, summary in enumerate(summaries):
            f.write(f"\nExample {i+1}:\n")
            # Only the first 200 characters of the article are written
            f.write(f"Article: {summary['article'][:200]}...\n")
            f.write(f"Reference: {summary['reference']}\n")
            f.write(f"Generated: {summary['generated']}\n")

In [None]:
# Log the example summaries file written in the previous cell to wandb
wandb.save(f'example_summaries_{timestamp}.txt')

In [None]:
# Save the averaged results table to CSV
csv_filename = f"summarizer_comparison_results_{timestamp}.csv" # timestamp keeps the filename unique per run
df_results.to_csv(csv_filename)
print(f"\nResults saved to {csv_filename}")

In [None]:
# Log the CSV file written in the previous cell to wandb
wandb.save(csv_filename)

In [None]:
# Statistical Analysis: per metric, compare the summarizers' per-article scores
print("\nStatistical Analysis:")
for metric in df_results.columns:
    print(f"\nMetric: {metric}")
    # One list of per-article scores per summarizer, built once and reused below
    score_groups = [
        [scores[metric] for scores in results[name]]
        for name in summarizers.keys()
    ]
    _, p_value = stats.f_oneway(*score_groups)
    print(f"One-way ANOVA p-value: {p_value:.4f}")

    # Log p-value to wandb
    wandb.log({f"p_value_{metric}": p_value})

    if p_value < 0.05:
        print("Significant difference detected. Performing post-hoc Tukey HSD test.")
        tukey_results = stats.tukey_hsd(*score_groups)
        print(tukey_results)

        # Log the Tukey HSD report as text: the result object itself is not a
        # plain number or string that wandb can store
        wandb.log({f"tukey_hsd_{metric}": str(tukey_results)})

In [None]:
# Visualizations
plt.figure(figsize=(15, 8))  # Increased figure size to accommodate more metrics
sns.boxplot(data=pd.melt(pd.DataFrame(results)))
plt.title("Distribution of Scores Across Metrics and Models")
plt.xlabel("Metric")
plt.ylabel("Score")
plt.xticks(rotation=45)
plt.tight_layout()

In [None]:
# Save and log plot to wandb
plot_filename = "score_distribution.png"
plt.savefig(plot_filename)
wandb.log({"score_distribution": wandb.Image(plot_filename)})

plt.close()

In [None]:
# End the wandb run started earlier in the notebook
wandb.finish()

In [None]:
# unit tests

In [None]:
# KG extraction

In [None]:
import unittest

class TestKnowledgeGraphExtractor(unittest.TestCase):
    """Smoke tests for KnowledgeGraphExtractor on short sentences.

    Expected values depend on the spaCy parse of each sentence, so the
    assertions are deliberately tolerant about surface forms.
    """

    @classmethod
    def setUpClass(cls):
        # Load the spaCy pipeline once for the whole test case rather than per test
        cls.extractor = KnowledgeGraphExtractor()

    def _extract(self, text):
        """Parse `text` and return (entities, relations) from the extractor's public methods."""
        doc = self.extractor.nlp(text)
        return self.extractor.extract_entities(doc), self.extractor.extract_relations(doc)

    @staticmethod
    def _names(entities):
        """Entity texts only; the label of a text found by both NER and noun chunking may vary."""
        return [name for name, _ in entities]

    def test_simple_subject_object_relation(self):
        text = "John gave Mary a book."
        entities, relations = self._extract(text)
        names = self._names(entities)
        self.assertIn('John', names)
        self.assertIn('Mary', names)
        # Noun chunks keep their determiner (e.g. "a book"), so match by substring
        self.assertTrue(any('book' in name.lower() for name in names))
        self.assertIn(('John', 'gave', 'book'), relations)

    def test_passive_voice_with_agent(self):
        text = "The book was given by John."
        entities, relations = self._extract(text)
        names = self._names(entities)
        self.assertIn('John', names)
        self.assertTrue(any('book' in name.lower() for name in names))
        # NOTE(review): the extractor has no rule for passive "by"-agents, so it may
        # emit no relation here; only require that any relation ending in John starts at "book"
        for subj, _, obj in relations:
            if obj == 'John':
                self.assertEqual(subj, 'book')

    def test_prepositional_relation(self):
        text = "Mary put the book on the table."
        entities, relations = self._extract(text)
        names = self._names(entities)
        self.assertIn('Mary', names)
        self.assertTrue(any('book' in name.lower() for name in names))
        # Prepositional relations are emitted as "<verb> <preposition>" with the pobj as object
        self.assertTrue(any(obj == 'table' and rel.endswith('on') for _, rel, obj in relations))

    def test_no_relations(self):
        text = "This is a test sentence with no relations."
        entities, relations = self._extract(text)
        self.assertEqual(len(relations), 0)

if __name__ == "__main__":
    # Inside a notebook sys.argv holds kernel arguments and unittest would call sys.exit();
    # pass a dummy argv and disable the exit so the cell finishes normally
    unittest.main(argv=['first-arg-is-ignored'], exit=False)


In [None]:
# KG consistency checker
# Demo: score two small hand-built graphs that differ only in the verb linking
# John to the book ("gave" vs "handed").

import networkx as nx

# First graph: John -gave- book -to- Mary
G1 = nx.Graph()
G1.add_nodes_from([
    ("John", {"label": "PERSON"}),
    ("Mary", {"label": "PERSON"}),
    ("book", {"label": "OBJECT"}),
])
G1.add_edges_from([
    ("John", "book", {"relation": "gave"}),
    ("book", "Mary", {"relation": "to"}),
])

# Second graph: same nodes, but John -handed- book
G2 = nx.Graph()
G2.add_nodes_from([
    ("John", {"label": "PERSON"}),
    ("Mary", {"label": "PERSON"}),
    ("book", {"label": "OBJECT"}),
])
G2.add_edges_from([
    ("John", "book", {"relation": "handed"}),
    ("book", "Mary", {"relation": "to"}),
])

# Compare the two graphs and report the resulting score
checker = ImprovedConsistencyChecker()
consistency_score = checker.calculate_consistency(G1, G2)
print(f"Consistency Score: {consistency_score}")


In [None]:
# Generating a KG-enhanced version of CNN/Daily Mail

In [None]:
# Load the CNN/Daily Mail dataset using the datasets library
def load_cnn_dailymail_dataset(num_articles=100, split='train'):
    """Load the first ``num_articles`` samples of a CNN/Daily Mail split.

    Args:
        num_articles: Number of leading samples to load. Defaults to 100, the
            subset size this notebook was written for.
        split: Name of the dataset split to slice. Defaults to ``'train'``.

    Returns:
        Whatever ``load_dataset`` returns for ``'<split>[:<num_articles>]'`` of
        ``cnn_dailymail`` version ``3.0.0``.

    Raises:
        ValueError: If ``num_articles`` is smaller than 1.
    """
    if num_articles < 1:
        raise ValueError(f"num_articles must be at least 1, got {num_articles}")

    # Slice notation keeps only the first `num_articles` rows of the split.
    subset = f'{split}[:{num_articles}]'
    return load_dataset('cnn_dailymail', '3.0.0', split=subset)

In [None]:
# Process CNN/Daily Mail dataset articles using the KnowledgeGraphGenerator
def process_cnn_dailymail_articles(dataset):
    """Build one knowledge-graph record per article in ``dataset``.

    Every sample's ``'article'`` text is turned into a graph with a
    ``KnowledgeGraphGenerator`` (created with ``use_trf_model=True``) and then
    passed through its ``clean_knowledge_graph`` step.

    Args:
        dataset: Iterable of samples exposing ``sample['article']`` and
            ``sample.get('highlights', '')``.

    Returns:
        A list of dicts with the keys ``'article'``, ``'entities'`` (node names of
        the cleaned graph), ``'relations'`` (``(u, v, relation)`` triples taken
        from the graph's edges) and ``'summary'`` (the sample's highlights, or an
        empty string when absent).
    """
    generator = KnowledgeGraphGenerator(use_trf_model=True)

    def _to_record(sample):
        text = sample['article']
        reference = sample.get('highlights', '')

        # Build the raw graph, then let the generator clean it.
        graph = generator.clean_knowledge_graph(generator.create_graph(text))

        # Nodes come back as (name, label) pairs; only the name is stored.
        node_names = [pair[0] for pair in graph.nodes(data="label")]
        triples = [
            (src, dst, attrs['relation'])
            for src, dst, attrs in graph.edges(data=True)
        ]

        return {
            'article': text,
            'entities': node_names,
            'relations': triples,
            'summary': reference,
        }

    progress = tqdm(dataset, desc="Processing articles", unit="article")
    return [_to_record(sample) for sample in progress]

In [None]:
# Save the KG-enhanced dataset to a CSV file
def save_kg_enhanced_articles_dataset(kg_data, output_path):
    """Write the KG records to ``output_path`` as CSV, without the index column.

    Args:
        kg_data: Records accepted by ``pd.DataFrame`` (the notebook passes the
            list of dicts built by ``process_cnn_dailymail_articles``).
        output_path: Destination path of the CSV file.
    """
    pd.DataFrame(kg_data).to_csv(output_path, index=False)

In [None]:
# Main execution flow
# Loads the CNN/Daily Mail subset, builds a knowledge graph per article and
# writes the combined records to CSV. `dataset`, `kg_data` and `output_path`
# are left as notebook-level variables.
# NOTE(review): the import cell runs `from transformers import logging` after
# `import logging`, so unless `logging` is rebound again outside this view,
# `logging.info` below does not resolve to the standard-library module — confirm
# these calls work and that INFO-level messages are actually displayed.
if __name__ == "__main__":
    # Step 1: Load the dataset
    logging.info("Loading CNN/Daily Mail dataset...")
    dataset = load_cnn_dailymail_dataset()
    logging.info("Dataset loaded successfully.")

    # Step 2: Process the CNN/Daily Mail articles to extract entities, relations, and create KG
    logging.info("Starting to process CNN/Daily Mail articles for KG extraction...")
    kg_data = process_cnn_dailymail_articles(dataset)
    logging.info("KG extraction completed successfully.")

    # Step 3: Save the KG-enhanced articles dataset to a new CSV file
    # Relative path: the file lands in the current working directory, and the
    # fine-tuning section has to read it back from that same location.
    output_path = 'kg_enhanced_cnn_dailymail_articles.csv'
    logging.info(f"Saving the KG-enhanced articles dataset to {output_path}...")
    save_kg_enhanced_articles_dataset(kg_data, output_path)
    logging.info(f"Dataset saved successfully at {output_path}.")

    # (Optional) Print the first entry of the KG-enhanced data to verify the output
    logging.info("Printing the first entry of the KG-enhanced data for verification...")
    print(kg_data[0])

In [None]:
# Fine-tuning BART (facebook/bart-large checkpoint) on the KG-enhanced dataset

In [None]:
# Step 1: Initialize Wandb for logging
# The config values are read back through wandb.config by the fine-tuning code.
wandb.init(project="KG-enhanced-BART-summarization", config={
    "learning_rate": 5e-5,
    "batch_size": 4,
    "epochs": 3,
    "subset_size": 100  # Subset size of CNN/Daily Mail for initial fine-tuning
})

# Step 2: Load KG-Enhanced CNN/Daily Mail Subset (replace with your KG-enhanced CSV file path)
# Same relative path the KG-generation cell writes to. The previous absolute
# '/content/...' path only resolved when the working directory was Colab's
# /content, so the notebook broke anywhere else.
kg_dataset_path = 'kg_enhanced_cnn_dailymail_articles.csv'
df = pd.read_csv(kg_dataset_path)

# Step 3: Create a custom dataset class for BART fine-tuning
class CNN_DailyMailDataset(Dataset):
    """Torch dataset that tokenizes article/summary pairs for BART fine-tuning.

    Only the ``'article'`` and ``'summary'`` columns of the frame are read.

    Args:
        dataframe: Frame with ``'article'`` and ``'summary'`` columns, accessed
            through ``.iloc``.
        tokenizer: Callable tokenizer that accepts ``max_length``, ``padding``,
            ``truncation`` and ``return_tensors`` and returns a mapping with
            ``"input_ids"`` and ``"attention_mask"`` tensors.
        max_length: Padded/truncated length of the tokenized article.
        max_target_length: Padded/truncated length of the tokenized summary.
            Defaults to ``max_length``, which is the previous behaviour.
    """

    # Label value the loss ignores (the -100 convention used by Hugging Face
    # seq2seq models and torch's cross-entropy default).
    IGNORE_INDEX = -100

    def __init__(self, dataframe, tokenizer, max_length=1024, max_target_length=None):
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.max_target_length = max_length if max_target_length is None else max_target_length

    def __len__(self):
        return len(self.data)

    def _encode(self, text, max_length):
        """Tokenize ``text`` to exactly ``max_length`` tokens (pad or truncate)."""
        return self.tokenizer(
            text, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
        )

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` for row ``idx``.

        Each value is a 1-D tensor, so the default DataLoader collation stacks
        them into ``(batch, length)`` tensors.
        """
        row = self.data.iloc[idx]
        source = self._encode(row['article'], self.max_length)

        # The tokenizer returns tensors with a leading dimension of size 1;
        # drop it here so batches need no reshaping later.
        labels = self._encode(row['summary'], self.max_target_length)["input_ids"].squeeze(0)

        # Padding positions must not contribute to the loss.
        pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is not None:
            labels = labels.masked_fill(labels == pad_id, self.IGNORE_INDEX)

        return {
            "input_ids": source["input_ids"].squeeze(0),
            "attention_mask": source["attention_mask"].squeeze(0),
            "labels": labels,
        }

# Evaluation method

def evaluate_model_during_training(model, dataset, tokenizer, max_input_length=1024,
                                   max_summary_length=150, num_beams=5, length_penalty=2.0):
    """
    Evaluate the fine-tuned model using ROUGE, BLEU, and BERTScore during training.

    A summary is generated for every row of ``dataset`` and compared with the
    row's reference summary. The model is switched to evaluation mode and is
    left in that mode on return.

    Args:
        model: Model exposing ``eval()``, ``parameters()`` and ``generate()``.
        dataset: Table with ``'article'`` and ``'summary'`` columns, accessed
            through ``len()`` and ``.iloc`` (the caller passes a pandas frame).
        tokenizer: Tokenizer used to encode articles and decode generated ids.
        max_input_length: Truncation length for the tokenized article.
        max_summary_length: ``max_length`` passed to ``model.generate``.
        num_beams: Beam width passed to ``model.generate``.
        length_penalty: Length penalty passed to ``model.generate``.

    Returns:
        A tuple ``(rouge_scores, bleu, bertscore)``: the list of per-example
        ROUGE-1/2/L score dicts, the corpus BLEU score, and the value returned
        by ``bert_score.score`` (the caller reads it as precision, recall, F1).
    """
    model.eval()  # Set the model to evaluation mode

    # Run on whatever device the model lives on instead of assuming CUDA, so
    # evaluation also works on a CPU-only runtime.
    device = next(model.parameters()).device

    predictions = []
    references = []

    with torch.no_grad():  # Disable gradient calculation for evaluation
        for idx in range(len(dataset)):
            row = dataset.iloc[idx]

            input_ids = tokenizer(
                row['article'], max_length=max_input_length, return_tensors="pt", truncation=True
            ).input_ids.to(device)
            generated_ids = model.generate(
                input_ids, max_length=max_summary_length, num_beams=num_beams, length_penalty=length_penalty
            )

            predictions.append(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
            references.append(row['summary'])

    # Calculate ROUGE scores (the scorer takes the reference first, then the prediction)
    rouge_scorer_instance = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    rouge_scores = [rouge_scorer_instance.score(ref, pred) for ref, pred in zip(references, predictions)]

    # Calculate BLEU score (one reference stream, hence the extra list level)
    bleu = sacrebleu.corpus_bleu(predictions, [references])

    # Calculate BERTScore
    bertscore = bert_score.score(predictions, references, lang="en")

    return rouge_scores, bleu.score, bertscore

# Step 4: Initialize BART tokenizer and model
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
# Place the model on the `device` chosen at the top of the notebook rather than
# calling .cuda() unconditionally, so the cell also runs without a GPU.
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large").to(device)

# Step 5: Prepare the dataset and dataloader
dataset = CNN_DailyMailDataset(df, tokenizer)
dataloader = DataLoader(dataset, batch_size=wandb.config.batch_size, shuffle=True)

# Step 6: Setup optimizer, scaler (for AMP), and learning rate scheduler
# Gradients are accumulated over this many mini-batches before each optimizer
# step. The notebook has always used the batch-size value for this.
accumulation_steps = wandb.config.batch_size

# The scheduler advances once per optimizer step, not once per mini-batch, so
# size it by the number of updates. Sizing it by batches, with a fixed 100 warm-up
# steps, left the learning rate in warm-up for the whole run.
updates_per_epoch = (len(dataloader) + accumulation_steps - 1) // accumulation_steps
total_update_steps = updates_per_epoch * wandb.config.epochs
warmup_steps = min(100, total_update_steps // 10)

optimizer = AdamW(model.parameters(), lr=wandb.config.learning_rate)
scaler = GradScaler()
scheduler = get_scheduler(
    "cosine", optimizer=optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_update_steps
)

# Step 7: Fine-tuning loop with AMP and gradient accumulation
for epoch in range(wandb.config.epochs):
    model.train()  # Set the model to training mode
    optimizer.zero_grad()  # Start every epoch from clean gradients
    total_loss = 0.0
    num_batches = len(dataloader)

    for i, batch in enumerate(dataloader):
        # Flatten each tensor to (batch, length) and move it to the device.
        # Unlike squeeze(), this keeps the batch dimension when a batch holds a
        # single example, and it accepts both (batch, 1, length) and
        # (batch, length) inputs.
        batch = {key: value.reshape(value.size(0), -1).to(device) for key, value in batch.items()}

        with autocast():  # Automatic Mixed Precision for faster training
            outputs = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
            loss = outputs.loss  # Loss per batch

        # Divide by the accumulation length so the summed gradients form an
        # average over the accumulated mini-batches.
        scaler.scale(loss / accumulation_steps).backward()

        # Step after every full accumulation window, and also on the final
        # batch so a trailing partial window is not dropped.
        is_last_batch = (i + 1) == num_batches
        if (i + 1) % accumulation_steps == 0 or is_last_batch:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()  # Clear gradients after step
            scheduler.step()  # Update the learning rate based on scheduler

        total_loss += loss.item()
        wandb.log({"training_loss": total_loss / (i + 1)})  # Log the average loss up to this point

    # Log epoch-wise metrics and save the model at the end of each epoch
    model.save_pretrained(f"bart_finetuned_epoch_{epoch}")
    wandb.log({"epoch": epoch, "avg_loss": total_loss / len(dataloader)})

    # **Evaluate the model after each epoch**
    # NOTE(review): `df` is the same frame the model is trained on, so these are
    # training-set metrics rather than held-out performance.
    rouge_scores, bleu_score_value, bertscore = evaluate_model_during_training(model, df, tokenizer)

    # Mean F-measure per ROUGE variant across all evaluated examples
    mean_rouge = {
        variant: sum(score[variant].fmeasure for score in rouge_scores) / len(rouge_scores)
        for variant in ("rouge1", "rouge2", "rougeL")
    }

    # **Log evaluation metrics to W&B**
    wandb.log({
        "epoch": epoch,
        "ROUGE-1": mean_rouge["rouge1"],
        "ROUGE-2": mean_rouge["rouge2"],
        "ROUGE-L": mean_rouge["rougeL"],
        "BLEU": bleu_score_value,
        "BERTScore_Precision": bertscore[0].mean().item(),
        "BERTScore_Recall": bertscore[1].mean().item(),
        "BERTScore_F1": bertscore[2].mean().item()
    })

# Finalize logging after all epochs are done
wandb.finish()