# EG-RAG: Evidence Graph based Retrieval-Augmented Generation

This notebook demonstrates the EG-RAG pipeline for handling conflicting information in multi-document QA.

**Paper**: AAMAS 2026 - "EG-RAG: Retrieval-Augmented Generation with Evidence Graph for Reliable Multi-Document Reasoning"

## 1. Setup

First, install the required dependencies and set up your API key.

In [None]:
# Install dependencies (run once). %pip installs into the running kernel's environment.
# %pip install -r ../requirements.txt

In [None]:
import os
import sys
from dotenv import load_dotenv

# Add the parent of the current working directory to the import path.
# The membership check keeps this cell idempotent: re-running it does not
# stack duplicate entries at the front of sys.path.
_parent_dir = os.path.join(os.getcwd(), '..')
if _parent_dir not in sys.path:
    sys.path.insert(0, _parent_dir)

# Load environment variables from .env file
load_dotenv()

# Verify API key is set (do NOT hardcode your key!).
# An explicit raise is used instead of `assert` so the check still runs when
# Python is started with -O (which strips assert statements).
if not os.getenv('OPENAI_API_KEY'):
    raise RuntimeError("Please set OPENAI_API_KEY in your .env file")

## 2. Import Libraries

In [None]:
import re
import numpy as np
import networkx as nx
import torch
from tqdm import tqdm

from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset
import openai

# Report the installed PyTorch version and which backend is available
# (MPS is checked first, then CUDA, otherwise CPU).
print(f"PyTorch version: {torch.__version__}")

if torch.backends.mps.is_available():
    _backend_name = 'MPS'
elif torch.cuda.is_available():
    _backend_name = 'CUDA'
else:
    _backend_name = 'CPU'
print(f"Device: {_backend_name}")

## 3. Core Functions

### 3.1 Text Processing

In [None]:
def simple_sent_tokenize(text: str) -> list[str]:
    """Split text into sentences at whitespace that follows '.', '?' or '!'.

    Args:
        text: Raw text; leading and trailing whitespace is ignored.

    Returns:
        The sentences in their original order, each keeping its terminal
        punctuation. Empty or whitespace-only input yields an empty list.
    """
    stripped = text.strip()
    if not stripped:
        # re.split on an empty string returns [''], which would otherwise be
        # handed downstream as a bogus empty "sentence".
        return []
    return re.split(r'(?<=[.\?\!])\s+', stripped)


def separate_passages(document: str) -> list[str]:
    """Break a context string into passages on the literal 'Document:' marker.

    Every piece is whitespace-trimmed, and pieces that end up empty are dropped.
    Text that precedes the first marker is kept as its own passage.
    """
    passages = []
    for chunk in document.split("Document:"):
        trimmed = chunk.strip()
        if trimmed:
            passages.append(trimmed)
    return passages

### 3.2 Key Sentence Extraction

In [None]:
# Pick the compute device: MPS if available, otherwise CUDA, otherwise CPU
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


# Loaded SentenceTransformer instances keyed by (model name, device string),
# so repeated calls reuse the model instead of re-loading it every time.
_EMBEDDER_CACHE: dict = {}


def _get_embedder(model_name: str):
    """Return the cached SentenceTransformer for model_name on the current device."""
    cache_key = (model_name, str(device))
    if cache_key not in _EMBEDDER_CACHE:
        _EMBEDDER_CACHE[cache_key] = SentenceTransformer(model_name, device=device)
    return _EMBEDDER_CACHE[cache_key]


def extract_key_sentences(
    docs: list[str],
    query: str,
    model_name: str = "all-MiniLM-L6-v2",
    top_k: int = 3
) -> tuple[list[str], list[float]]:
    """
    Extract top-k key sentences from documents based on semantic similarity to query.

    Args:
        docs: List of document passages
        query: User query
        model_name: SentenceTransformer model name
        top_k: Number of top sentences to return (capped at the number of
            sentences available; values <= 0 return nothing)

    Returns:
        Tuple of (sentences, scores), ordered from highest to lowest score.
        Both lists are empty when there are no sentences or top_k <= 0.
    """
    # Tokenize documents into sentences, skipping empty fragments
    sentences = [
        sent
        for doc in docs
        for sent in simple_sent_tokenize(doc)
        if sent.strip()
    ]

    # Guard top_k <= 0 explicitly: with top_k == 0 the slice [-0:] below would
    # select the whole array and return every sentence instead of none.
    if not sentences or top_k <= 0:
        return [], []

    # Load (or reuse) the embedding model
    embedder = _get_embedder(model_name)

    # Encode sentences and query as unit-length vectors
    sent_embs = embedder.encode(sentences, convert_to_numpy=True, normalize_embeddings=True)
    query_emb = embedder.encode([query], convert_to_numpy=True, normalize_embeddings=True)[0]

    # Dot product of normalized embeddings == cosine similarity
    scores = np.dot(sent_embs, query_emb)

    # Indices of the top-k scores, best first
    top_k = min(top_k, len(sentences))
    top_idx = np.argsort(scores)[-top_k:][::-1]

    return [sentences[i] for i in top_idx], [float(scores[i]) for i in top_idx]

### 3.3 NLI Classification

In [None]:
# Load the NLI tokenizer and classifier used to label sentence pairs
nli_model_name = "roberta-large-mnli"
tokenizer = AutoTokenizer.from_pretrained(nli_model_name)
nli_model = AutoModelForSequenceClassification.from_pretrained(nli_model_name)
nli_model.to(device)  # move the weights onto the selected device
nli_model.eval()      # switch to evaluation mode

# Class index -> label name. Model order is {0: contradiction, 1: neutral, 2: entailment};
# the entailment class is reported as "support" in the evidence graph.
idx2label = dict(enumerate(("contradiction", "neutral", "support")))


def nli_classify(sentence_a: str, sentence_b: str) -> tuple[str, float]:
    """
    Label the relation of a sentence pair with the NLI model.

    The pair is encoded as (sentence_a, sentence_b), truncated to at most
    512 tokens, and scored without gradient tracking.

    Returns:
        Tuple of (label, probability), where label is the most probable of
        'contradiction', 'neutral', or 'support' and probability is its
        softmax score.
    """
    encoded = tokenizer(
        sentence_a,
        sentence_b,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )
    encoded = encoded.to(device)

    with torch.no_grad():
        output = nli_model(**encoded)
        class_probs = torch.softmax(output.logits, dim=-1)[0].cpu().numpy()

    best = int(np.argmax(class_probs))
    label = idx2label[best]
    confidence = float(class_probs[best])
    return label, confidence

### 3.4 Evidence Graph Construction

In [None]:
def build_evidence_graph(sentences: list[str], scores: list[float]) -> nx.Graph:
    """
    Assemble the evidence graph for a set of key sentences.

    Nodes: one per sentence, keyed by its position, with attributes (text, score)
    Edges: one per unordered sentence pair, with attributes (label, weight)
           label  = NLI label for the pair (lower-index sentence passed first)
           weight = nli_probability * score_i * score_j
    """
    graph = nx.Graph()

    # One node per sentence
    for node_id, (text, relevance) in enumerate(zip(sentences, scores)):
        graph.add_node(node_id, text=text, score=relevance)

    # Every unordered pair (first < second) is classified exactly once
    total = len(sentences)
    pairs = [(first, second) for first in range(total) for second in range(first + 1, total)]
    for first, second in pairs:
        relation, confidence = nli_classify(sentences[first], sentences[second])
        graph.add_edge(
            first,
            second,
            label=relation,
            weight=confidence * scores[first] * scores[second],
        )

    return graph


def find_subgraphs_by_label(G: nx.Graph, label: str) -> list[nx.Graph]:
    """
    Find connected subgraphs containing only edges with the specified label.

    Args:
        G: Evidence graph whose edges carry a "label" attribute
        label: Edge label to keep (e.g. 'contradiction', 'neutral', 'support')

    Returns:
        One independent copy per connected component (of size > 1) formed by
        the edges carrying `label`. Node attributes are preserved, and each
        returned subgraph holds only edges with that label.
    """
    # Helper graph with just the matching edges; its connected components
    # define which nodes belong together.
    H = nx.Graph((u, v, d) for u, v, d in G.edges(data=True) if d["label"] == label)

    subgraphs = []
    for component in nx.connected_components(H):
        if len(component) <= 1:
            continue
        # G.subgraph() is node-induced, so it also carries edges of *other*
        # labels between these nodes; remove them so the result matches the
        # documented contract.
        sub = G.subgraph(component).copy()
        foreign = [(u, v) for u, v, d in sub.edges(data=True) if d["label"] != label]
        sub.remove_edges_from(foreign)
        subgraphs.append(sub)
    return subgraphs


def extract_clusters(G: nx.Graph, subgraphs: list[nx.Graph]) -> list[list[str]]:
    """Turn each subgraph into the list of sentence texts of its nodes.

    The text is read from the "text" attribute of the matching node in G.
    """
    clusters = []
    for component in subgraphs:
        texts = []
        for node_id in component.nodes():
            texts.append(G.nodes[node_id]["text"])
        clusters.append(texts)
    return clusters

### 3.5 LLM Answer Generation

In [None]:
def create_prompt(
    question: str,
    documents: list[str],
    formatted_key_sents: list[str],
    contradictory_clusters: list[list[str]],
    neutral_clusters: list[list[str]],
    support_clusters: list[list[str]]
) -> list[dict]:
    """Build the chat messages for the QA request.

    The question, passages, key sentences and the three cluster lists are
    interpolated into the user prompt using their default string form.

    Returns:
        A two-element list: a system message followed by the user message.
    """
    user_content = f"""You are an expert retrieval-based QA system.

Follow these steps for each input:

You will receive a question together with a set of key passages.
Each key passage includes a document number, a confidence score, and a key sentence.

Focusing on the highest-scoring sentence in each document, consult all of the key sentences to determine the correct answer step-by-step.

To find the answer, focus on the highest-scoring sentences in each document in key passages, and refer to the clusters.
It's important to find answers by comparing the key sentences extracted from each document.

* ANSWER EXTRACTION RULES
- Extract the *exact* noun phrase(s) or term(s) that answer the question.
- No explanations, no punctuation, no extra text - just the terms.

Reference Information:

Question: {question}
Documents: {documents}
Key passages: {formatted_key_sents}
Contradictory clusters: {contradictory_clusters}
Neutral clusters: {neutral_clusters}
Support clusters: {support_clusters}

You must provide the final answer as follows:
["answer1", "answer2"]
"""

    system_message = {"role": "system", "content": "You are a helpful and logical assistant."}
    user_message = {"role": "user", "content": user_content}
    return [system_message, user_message]


def generate_answer(messages: list[dict], model: str = "gpt-4o-mini") -> str:
    """
    Generate an answer using the OpenAI chat completions API.

    Works with both generations of the `openai` package: the client-based
    interface (openai>=1.0) and the legacy module-level `ChatCompletion`
    interface (openai<1.0).

    Args:
        messages: Chat messages as a list of {"role": ..., "content": ...} dicts
        model: Name of the chat model to query

    Returns:
        The stripped text of the first choice; an empty string if the
        response carries no content.
    """
    if hasattr(openai, "OpenAI"):
        # openai>=1.0: the module-level ChatCompletion API was removed; requests
        # go through a client object and the message is an attribute-style object.
        client = openai.OpenAI()
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=0.0
        )
        content = response.choices[0].message.content
    else:
        # Legacy openai<1.0 interface
        response = openai.ChatCompletion.create(
            model=model,
            messages=messages,
            temperature=0.0
        )
        content = response.choices[0].message["content"]

    # content can be None (e.g. a refusal or empty completion); avoid crashing on .strip()
    return (content or "").strip()

### 3.6 Main EG-RAG Pipeline

In [None]:
def eg_rag_pipeline(
    context: str,
    question: str,
    top_k: int = 3,
    verbose: bool = True
) -> tuple[str, nx.Graph]:
    """
    Execute EG-RAG end to end: passage split, key-sentence selection,
    evidence-graph construction, clustering, and LLM answer generation.

    Args:
        context: Document context (may contain multiple 'Document:' sections)
        question: User question
        top_k: Number of key sentences per document
        verbose: Print intermediate results

    Returns:
        Tuple of (answer_string, evidence_graph)
    """
    # Step 1: split the context into passages
    passages = separate_passages(context)

    # Step 2: pick key sentences passage by passage, remembering the source passage index
    key_sents, key_scores, key_idxs = [], [], []
    for doc_idx, passage in enumerate(passages):
        picked, picked_scores = extract_key_sentences([passage], question, top_k=top_k)
        key_sents += picked
        key_scores += picked_scores
        key_idxs += [doc_idx] * len(picked)

    # "[passage index] (score) sentence" lines for the prompt and the log
    formatted_key_sents = []
    for doc_idx, sent, score in zip(key_idxs, key_sents, key_scores):
        formatted_key_sents.append(f"[{doc_idx}] ({score:.3f}) {sent}")

    # Step 3: evidence graph over all selected sentences
    evidence_graph = build_evidence_graph(key_sents, key_scores)

    # Step 4: one cluster list per relation label
    clusters_by_label = {
        relation: extract_clusters(
            evidence_graph, find_subgraphs_by_label(evidence_graph, relation)
        )
        for relation in ("contradiction", "neutral", "support")
    }
    contradictory_clusters = clusters_by_label["contradiction"]
    neutral_clusters = clusters_by_label["neutral"]
    support_clusters = clusters_by_label["support"]

    if verbose:
        print("Key Sentences:")
        print("\n".join(formatted_key_sents))

        if not contradictory_clusters:
            print("\nNo contradictory clusters found.")
        else:
            print("\nContradictory clusters found:")
            for number, cluster in enumerate(contradictory_clusters, start=1):
                print(f"  Cluster {number}:")
                for sentence in cluster:
                    print(f"   - {sentence}")

    # Step 5: ask the LLM
    messages = create_prompt(
        question,
        passages,
        formatted_key_sents,
        contradictory_clusters,
        neutral_clusters,
        support_clusters,
    )
    answer = generate_answer(messages)

    if verbose:
        print(f"\nAnswer: {answer}")

    return answer, evidence_graph

## 4. Demo: Single Example

In [None]:
# Example with conflicting information: two 'Document:' sections whose text is
# identical except for the immunosuppressive drug named in the last sentence
# (cyclosporin in the first, aspirin in the second).
example_context = """
Document: Anti-inflammatory drugs are often used to control the effects of inflammation. 
Glucocorticoids are the most powerful of these drugs; however, these drugs can have many 
undesirable side effects, such as central obesity, hyperglycemia, osteoporosis. 
Immunosuppressive drugs such as cyclosporin prevent T cells from responding to signals 
correctly by inhibiting signal transduction pathways.

Document: Anti-inflammatory drugs are often used to control the effects of inflammation. 
Glucocorticoids are the most powerful of these drugs; however, these drugs can have many 
undesirable side effects, such as central obesity, hyperglycemia, osteoporosis. 
Immunosuppressive drugs such as aspirin prevent T cells from responding to signals 
correctly by inhibiting signal transduction pathways.
"""

# Question targeting the sentence on which the two example documents disagree
example_question = "What is an example of an immunosuppressive drug that prevents T cell activity by altering signal transduction pathways?"

print(f"Question: {example_question}")
print("=" * 80)

# Run the full pipeline with intermediate output enabled.
# `graph` is kept at notebook scope; the visualization cell reads it later.
answer, graph = eg_rag_pipeline(example_context, example_question, top_k=3, verbose=True)

## 5. Evaluate on FaithEval-Inconsistent Dataset

In [None]:
# Load FaithEval-Inconsistent dataset
dataset = load_dataset("Salesforce/FaithEval-inconsistent-v1.0", split="test")
print(f"Dataset size: {len(dataset)} samples")

# Limit samples for demo (set to None for full evaluation)
num_samples = 10
if num_samples:
    # Clamp to the dataset size so a value larger than the dataset cannot
    # request out-of-range indices from select().
    num_samples = min(num_samples, len(dataset))
    dataset = dataset.select(range(num_samples))
    print(f"Using {num_samples} samples for demo")

In [None]:
# Run evaluation: one pipeline call per dataset example.
# predictions[i] holds the raw answer string returned by the pipeline and
# ground_truths[i] holds that example's 'answers' field; the metrics cells
# read both lists.
predictions = []
ground_truths = []

for i, example in enumerate(tqdm(dataset, desc="Evaluating")):
    # verbose=False keeps intermediate pipeline output out of the progress log;
    # the evidence graph (second return value) is not needed here.
    answer, _ = eg_rag_pipeline(
        example['context'],
        example['question'],
        top_k=3,
        verbose=False
    )
    predictions.append(answer)
    ground_truths.append(example['answers'])
    
    # tqdm.write prints without breaking the progress bar
    tqdm.write(f"[{i}] Predicted: {answer}")

## 6. Evaluation Metrics

In [None]:
import json
import ast

def normalize_answer(s: str) -> str:
    """Normalize an answer for comparison.

    Lowercases the text, deletes every character that is neither a word
    character nor whitespace, then trims surrounding whitespace.
    """
    lowered = s.lower()
    without_punct = re.sub(r'[^\w\s]', '', lowered)
    return without_punct.strip()


def count_matches(true_list: list[str], pred_list: list[str]) -> int:
    """
    Count how many ground-truth answers are matched by at least one prediction.

    Both sides are normalized with normalize_answer(); a ground truth counts
    as matched when it and some prediction contain one another as a substring.

    Args:
        true_list: Ground-truth answers
        pred_list: Predicted answers

    Returns:
        Number of ground-truth entries with at least one matching prediction
        (each entry is counted at most once).
    """
    norm_true = [normalize_answer(t) for t in true_list]
    # Predictions that normalize to nothing can never match
    norm_pred = [p for p in (normalize_answer(raw) for raw in pred_list) if p]

    # A ground truth that normalizes to '' is skipped: '' is a substring of
    # every string, so it would otherwise "match" any prediction at all.
    return sum(
        1
        for t in norm_true
        if t and any(p in t or t in p for p in norm_pred)
    )


def parse_prediction(pred_str: str) -> list[str]:
    """
    Parse a model prediction string into a list of answer strings.

    The text is parsed as JSON first and as a Python literal second
    (which accepts single-quoted lists).

    Args:
        pred_str: Raw model output; None is treated as empty

    Returns:
        - [] for empty input, 'unknown' (case-insensitive), or unparseable text
        - the items as strings when the text parses to a list or tuple
          (None items are dropped)
        - a one-element list when the text parses to a single string or number
        - [] for any other parsed type (e.g. dict, bool, None)
    """
    text = (pred_str or '').strip()
    if not text or text.lower() == 'unknown':
        return []

    for parser in (json.loads, ast.literal_eval):
        try:
            parsed = parser(text)
            break
        # Only parsing failures are caught; a bare `except:` would also swallow
        # KeyboardInterrupt and hide unrelated bugs.
        except (ValueError, SyntaxError, TypeError, RecursionError, MemoryError):
            continue
    else:
        return []

    if isinstance(parsed, (list, tuple)):
        return [str(item) for item in parsed if item is not None]
    # A bare scalar such as "aspirin" must be wrapped: returning the string itself
    # would make callers iterate over its characters as if they were answers.
    if isinstance(parsed, (str, int, float)) and not isinstance(parsed, bool):
        return [str(parsed)]
    return []

In [None]:
# Compute metrics: per-item match ratio plus two aggregate accuracies
total_matched, total_true, fully_correct = 0, 0, 0

for i, (pred, true_list) in enumerate(zip(predictions, ground_truths)):
    pred_list = parse_prediction(pred)
    matched = count_matches(true_list, pred_list)
    n_true = len(true_list)

    # Running totals for the element-level score
    total_matched += matched
    total_true += n_true

    # An item is fully correct when every ground-truth answer was matched
    if matched == n_true:
        fully_correct += 1

    if n_true > 0:
        acc = matched / n_true
    else:
        acc = 0
    print(f"Item {i}: {acc:.2f} ({matched}/{n_true} matched)")

# Summary (both ratios fall back to 0 when there is nothing to divide by)
if total_true > 0:
    element_accuracy = total_matched / total_true
else:
    element_accuracy = 0

if predictions:
    item_accuracy = fully_correct / len(predictions)
else:
    item_accuracy = 0

print("\n" + "=" * 50)
print("RESULTS")
print("=" * 50)
print(f"Element-level Accuracy: {element_accuracy:.2%} ({total_matched}/{total_true})")
print(f"Item-level Accuracy (EM): {item_accuracy:.2%} ({fully_correct}/{len(predictions)})")
print("=" * 50)

## 7. Visualize Evidence Graph

In [None]:
import matplotlib.pyplot as plt

def visualize_graph(G: nx.Graph, title: str = "Evidence Graph"):
    """Draw the evidence graph with edges colored by their NLI label.

    Contradiction edges are red, neutral edges gray and support edges green.
    The layout is a spring layout with a fixed seed.
    """
    layout = nx.spring_layout(G, seed=42)

    fig, ax = plt.subplots(figsize=(10, 8))

    # Nodes and their index labels
    nx.draw_networkx_nodes(G, layout, node_size=500, node_color='lightblue', ax=ax)
    nx.draw_networkx_labels(G, layout, ax=ax)

    # One draw call per label so each gets its own color and legend entry
    color_by_label = {
        'contradiction': 'red',
        'neutral': 'gray',
        'support': 'green'
    }
    for relation, color in color_by_label.items():
        matching = [(u, v) for u, v, attrs in G.edges(data=True) if attrs['label'] == relation]
        if not matching:
            continue
        nx.draw_networkx_edges(
            G, layout, edgelist=matching, edge_color=color,
            width=2, alpha=0.7, label=relation, ax=ax
        )

    ax.legend()
    ax.set_title(title)
    ax.axis('off')
    fig.tight_layout()
    plt.show()


# Visualize the graph from our example.
# `graph` is the evidence graph returned by the single-example demo run in
# section 4, so that cell must have been executed first.
visualize_graph(graph, "Evidence Graph - Immunosuppressive Drug Example")

## 8. Experiment with Different top_k Values

The `top_k` parameter controls how many key sentences are extracted per document:
- `top_k=1`: Fastest, minimal context
- `top_k=2`: Balanced
- `top_k=3`: Default, good coverage

In [None]:
# Compare different top_k values on the same example context and question.
# Each iteration runs the whole pipeline again (including the LLM call) and
# overwrites the notebook-level `answer` variable.
for k in [1, 2, 3]:
    print(f"\n{'='*50}")
    print(f"top_k = {k}")
    print(f"{'='*50}")
    # verbose=False suppresses intermediate output; only the final answer is printed
    answer, _ = eg_rag_pipeline(example_context, example_question, top_k=k, verbose=False)
    print(f"Answer: {answer}")

---

## Citation

If you find this code useful, please cite our paper:

```bibtex
@inproceedings{
hong2026egrag,
title={{EG}-{RAG}: Retrieval-Augmented Generation with Evidence Graph for Reliable Multi-Document Reasoning},
author={Seunggwan Hong and Junhyung Moon and Eunkyeong Lee and Jaehyoung Park and Hyunseung Choo},
booktitle={The 25th International Conference on Autonomous Agents and Multi-Agent Systems},
year={2026},
url={https://openreview.net/forum?id=ahosApE8Ap}
}
```