# Week 6 — Graph-RAG over Your Corpus
**Goal:** Build an entity graph from your corpus and retrieve evidence using neighborhood expansion.

**Colab tips:** You can run this notebook in Colab. If you are using your own PDFs, mount Google Drive first.


In [None]:

#@title Setup (Colab-friendly)
import os, re, json, math, time, random
import numpy as np
import pandas as pd
import networkx as nx
from collections import defaultdict
from typing import List, Dict, Tuple

# All artifacts for this week live under one directory, created on first run.
DATA_DIR = "./data_week6"
os.makedirs(DATA_DIR, exist_ok=True)

# Write a five-document demo corpus only when no corpus.csv exists yet, so an
# already-present file is read as-is instead of being overwritten.
corpus_path = os.path.join(DATA_DIR, "corpus.csv")
if not os.path.exists(corpus_path):
    demo = pd.DataFrame(
        [
            ("doc1", "Method X was introduced by Author A and compared on Dataset D1 with F1=0.78."),
            ("doc2", "Author A collaborated with Author B; Method X improved Metric F1 on D1."),
            ("doc3", "Dataset D2 was used to evaluate Method Y introduced by Author C."),
            ("doc4", "Paper P3 applies Method X to Dataset D2 and reports Accuracy 0.82."),
            ("doc5", "Survey S1 links Method Y, Dataset D2, and Metric AUC."),
        ],
        columns=["doc_id", "text"],
    )
    demo.to_csv(corpus_path, index=False)

# Always load from disk so the demo and a user-provided corpus share one path.
corpus = pd.read_csv(corpus_path)
corpus.head()


## 1. Lightweight Entity Extraction

In [None]:

# One regex per entity type, applied case-sensitively.
# - Every pattern is wrapped in \b so it only fires on whole tokens: "F1" no
#   longer matches inside "F10", nor "AUC" inside "ROCAUC", nor "Method X"
#   inside "SubMethod X".
# - The identifier part consumes the whole alphanumeric token, so
#   "Author Alice" and "Author Adam" stay distinct instead of both being
#   truncated to "Author A".
ENTITY_PATTERNS = {
    "METHOD": r"\bMethod\s+[A-Z][A-Za-z0-9]*\b",
    "AUTHOR": r"\bAuthor\s+[A-Z][A-Za-z0-9]*\b",
    "DATASET": r"\bDataset\s+[A-Z0-9][A-Za-z0-9]*\b",
    "PAPER": r"\b(?:Paper|Survey)\s+[A-Z0-9][A-Za-z0-9]*\b",
    "METRIC": r"\b(?:F1|Accuracy|AUC)\b",
}
import re
def extract_entities(text: str):
    """Find every entity mention in *text* using ``ENTITY_PATTERNS``.

    Args:
        text: The string to scan. Non-string values (for example the NaN that
            pandas produces for an empty CSV cell) are treated as having no
            entities.

    Returns:
        A list of ``(mention, type, start, end)`` tuples, where ``start`` and
        ``end`` are character offsets into *text*. Results are grouped by
        entity type in ``ENTITY_PATTERNS`` order, and in text order within
        each type.
    """
    if not isinstance(text, str):
        return []
    ents = []
    for typ, pat in ENTITY_PATTERNS.items():
        ents.extend((m.group(0), typ, m.start(), m.end()) for m in re.finditer(pat, text))
    return ents

# Flatten entity mentions into one row per mention. Each row keeps the
# character offsets plus a context snippet reaching up to 40 characters on
# either side of the mention (slicing clamps the window to the document).
rows = []
for doc_id, doc_text in zip(corpus["doc_id"], corpus["text"]):
    for mention, label, begin, finish in extract_entities(doc_text):
        snippet = doc_text[max(0, begin - 40):finish + 40]
        rows.append({
            "doc_id": doc_id,
            "entity": mention,
            "type": label,
            "start": begin,
            "end": finish,
            "span": snippet,
        })
ents_df = pd.DataFrame(rows)
ents_df.head(10)


## 2. Relation Extraction (Co-occurrence within sentence)

In [None]:

def sentence_split(text):
    """Split *text* into sentences at whitespace that follows ``.``, ``!`` or ``?``.

    The terminating punctuation stays attached to its sentence. A period that
    is not followed by whitespace (as in ``0.78``) does not split. Text that
    ends in whitespace after a terminator yields a trailing empty string.
    """
    boundary = re.compile(r"(?<=[.!?])\s+")
    return boundary.split(text)

from itertools import combinations

# Co-occurrence relations: every unordered pair of distinct entities mentioned
# in the same sentence becomes one edge row, with the sentence kept as evidence.
edges = []
for doc_id, doc_text in zip(corpus["doc_id"], corpus["text"]):
    if not isinstance(doc_text, str):
        # Empty CSV cells load as NaN; there is nothing to split or scan.
        continue
    for sent in sentence_split(doc_text):
        for left, right in combinations(extract_entities(sent), 2):
            e1, t1, *_ = left
            e2, t2, *_ = right
            if e1 == e2:
                # The same entity mentioned twice in one sentence is not a
                # relation; pairing it with itself would add a self-loop.
                continue
            edges.append({"doc_id": doc_id, "head": e1, "type1": t1, "tail": e2, "type2": t2, "sentence": sent})

# Naming the columns keeps the schema stable even when no edge was found.
edges_df = pd.DataFrame(edges, columns=["doc_id", "head", "type1", "tail", "type2", "sentence"])
edges_df.head(10)


## 3. Build Graph (NetworkX) and Attach Evidence

In [None]:

# Undirected entity graph: one node per distinct entity string (tagged with its
# type), one edge per co-occurring entity pair.
G = nx.Graph()
for _, e in ents_df.iterrows():
    G.add_node(e["entity"], type=e["type"])

for _, ed in edges_df.iterrows():
    head, tail = ed["head"], ed["tail"]
    # Re-adding an existing edge in nx.Graph overwrites its attributes, so the
    # scalar doc_id/sentence always reflect the most recent co-occurrence
    # (unchanged behaviour for readers of those two attributes).
    G.add_edge(head, tail, doc_id=ed["doc_id"], sentence=ed["sentence"])
    # Additionally accumulate every distinct supporting sentence, so evidence
    # from earlier documents is not lost when a pair co-occurs more than once.
    item = {"doc_id": ed["doc_id"], "sentence": ed["sentence"]}
    evidence = G[head][tail].setdefault("evidence", [])
    if item not in evidence:
        evidence.append(item)

len(G.nodes()), len(G.edges())


## 4. Graph-aware Retrieval

In [None]:

def detect_seed_entities(query: str):
    """Pick the graph nodes that a query refers to.

    Two heuristics are applied against the module-level graph ``G``:

    1. A node is a seed when the last whitespace-separated token of its name
       (e.g. ``X`` for ``Method X``) appears in the query as a whole token,
       compared case-insensitively. Matching whole tokens rather than
       substrings stops single-letter identifiers from firing on any query
       that merely contains that letter (``Author A`` vs. "d-a-taset").
    2. Every ``METHOD`` / ``DATASET`` node is a seed when the lower-cased type
       name occurs anywhere in the lower-cased query (substring match, so
       plurals such as "datasets" also count).

    Args:
        query: Free-text user question.

    Returns:
        Seed node names without duplicates: name matches first, then type
        matches, each in graph node order.
    """
    query_lc = query.lower()

    def mentioned(token: str) -> bool:
        # The token must not be embedded in a longer run of word characters.
        pattern = rf"(?<!\w){re.escape(token)}(?!\w)"
        return re.search(pattern, query_lc) is not None

    seeds = []
    for node in G.nodes():
        parts = str(node).lower().split()
        # Guard against a blank node name, which has no last token.
        if parts and mentioned(parts[-1]):
            seeds.append(node)
    for node, data in G.nodes(data=True):
        node_type = data.get("type")
        if node_type in ("METHOD", "DATASET") and node_type.lower() in query_lc:
            seeds.append(node)
    # dict.fromkeys removes duplicates while keeping first-seen order.
    return list(dict.fromkeys(seeds))

def neighborhood_evidence(seeds, hops=1, max_spans=12):
    """Collect evidence sentences from the graph neighbourhood of the seeds.

    For each seed, every node within ``hops`` edges of it is visited, and each
    edge incident to a visited node contributes evidence. The far endpoint of
    such an edge may therefore lie ``hops + 1`` edges from the seed. Each
    undirected edge is reported at most once across all seeds.

    Edges that carry an optional ``evidence`` list (one ``{doc_id, sentence}``
    dict per supporting sentence) contribute one span per entry; edges with
    only scalar ``doc_id`` / ``sentence`` attributes contribute a single span.

    Args:
        seeds: Node names to expand from. Names not present in ``G`` are
            skipped instead of raising ``NodeNotFound``.
        hops: Radius, in edges, of the node neighbourhood around each seed.
        max_spans: Upper bound on the number of spans returned; a value of
            zero or less yields an empty list.

    Returns:
        A list of dicts with keys ``u``, ``v``, ``doc_id`` and ``sentence``.
    """
    spans = []
    if max_spans <= 0:
        return spans
    seen_edges = set()
    for s in seeds:
        if s not in G:
            continue
        nodes = nx.single_source_shortest_path_length(G, s, cutoff=hops)
        for u in nodes:
            for v in G.neighbors(u):
                # Order-independent key so (u, v) and (v, u) count as one edge.
                e = tuple(sorted([u, v]))
                if e in seen_edges:
                    continue
                seen_edges.add(e)
                data = G.get_edge_data(u, v)
                records = data.get("evidence") or [data]
                for rec in records:
                    spans.append({"u": u, "v": v, "doc_id": rec.get("doc_id"), "sentence": rec.get("sentence")})
                    if len(spans) >= max_spans:
                        return spans
    return spans

def graph_rag(query: str, hops=1):
    """Run graph retrieval for *query*.

    Detects seed entities in the query, then gathers evidence spans from their
    graph neighbourhood using the given ``hops`` radius.

    Returns:
        A dict with ``"seeds"`` (the detected seed nodes) and ``"spans"`` (the
        evidence records returned by ``neighborhood_evidence``).
    """
    seed_nodes = detect_seed_entities(query)
    return {
        "seeds": seed_nodes,
        "spans": neighborhood_evidence(seed_nodes, hops=hops),
    }

# Run retrieval on a sample question with the default hops=1; the bare
# expression on the last line displays the result as the cell output.
demo_out = graph_rag("Which dataset evaluated Method X with F1?")
demo_out


## 5. Prompt Assembly (stub)

In [None]:

def assemble_prompt(query, seeds, spans):
    """Build the LLM prompt text from a query and its retrieved evidence.

    Args:
        query: The user question, inserted verbatim.
        seeds: Seed entity names, listed comma-separated on the ``Seeds:`` line.
        spans: Evidence records; each must provide ``doc_id`` and ``sentence``
            and is rendered as one ``- (doc_id) sentence`` bullet.

    Returns:
        The prompt string: a system instruction, the query, the seeds, the
        evidence bullets, and a trailing ``Answer:`` cue.
    """
    header = (
        "System: Answer using ONLY the evidence and cite (doc_id).\n"
        f"Query: {query}\n"
        "\n"
        f"Seeds: {', '.join(seeds)}\n"
        "\n"
        "Evidence:\n"
    )
    bullets = "\n".join(f"- ({span['doc_id']}) {span['sentence']}" for span in spans)
    return header + bullets + "\n\nAnswer:"

# Show the assembled prompt for the sample question, using the seeds and
# evidence spans stored in demo_out.
print(assemble_prompt("Which dataset evaluated Method X with F1?", demo_out["seeds"], demo_out["spans"]))


## 6. Optional: Graph Visualization

In [None]:

import matplotlib.pyplot as plt

# A fixed seed keeps the spring layout reproducible between runs.
pos = nx.spring_layout(G, seed=7)

# One colour per entity type; nodes without a known type fall back to grey.
type_to_color = {
    "METHOD": "#6aa84f",
    "AUTHOR": "#3c78d8",
    "DATASET": "#cc0000",
    "PAPER": "#674ea7",
    "METRIC": "#e69138",
}
node_colors = []
for _, attrs in G.nodes(data=True):
    node_colors.append(type_to_color.get(attrs.get("type", ""), "#999"))

plt.figure(figsize=(6, 4))
nx.draw(
    G,
    pos,
    with_labels=True,
    node_color=node_colors,
    node_size=900,
    font_size=9,
    edge_color="#bbb",
)
plt.title("Entity Graph (demo)")
plt.tight_layout()
plt.show()
