# CORD-19 GraphRAG Pipeline

A pipeline that builds a knowledge graph from CORD-19 paper metadata and abstracts, then combines it with SciBERT embeddings and a FAISS index for semantic search and citation-aware summarization.

## Overview
This notebook implements a complete GraphRAG (Graph Retrieval-Augmented Generation) pipeline with the following steps:

1. **Setup & Dataset Loading** - Initialize models and load CORD-19 data
2. **Load & Preprocess Metadata** - Clean and prepare the dataset
3. **Entity Extraction** - Extract named entities with spaCy's general-purpose `en_core_web_sm` model
4. **Graph Construction** - Build knowledge graph with NetworkX
5. **Semantic Embeddings** - Generate embeddings using SciBERT
6. **FAISS Index** - Create vector index for fast retrieval
7. **GraphRAG Retrieval** - Combine semantic search with graph context
8. **LLM Summarization** - Generate summaries with citations using GPT-4o-mini

## Requirements
```bash
pip install pandas scispacy spacy transformers faiss-cpu networkx openai torch matplotlib seaborn tqdm
python -m spacy download en_core_web_sm
```


## Step 1: Setup & Dataset Loading


In [None]:
# Import required libraries
import pandas as pd
import spacy
import networkx as nx
import numpy as np
import torch
import faiss
from transformers import AutoTokenizer, AutoModel
from openai import OpenAI
import os
from pathlib import Path
import json
from tqdm import tqdm
import warnings
import matplotlib.pyplot as plt
import seaborn as sns

warnings.filterwarnings('ignore')

# Set up plotting
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

print("✅ Libraries imported successfully")


In [None]:
# Load spaCy model
try:
    nlp = spacy.load("en_core_web_sm")
    print("✅ Loaded spaCy model: en_core_web_sm")
except OSError:
    print("❌ spaCy model not found. Please run: python -m spacy download en_core_web_sm")
    print("   You can also try: pip install https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl")
    raise  # later cells need `nlp`, so stop here instead of failing later with a NameError


In [None]:
# Load SciBERT model for biomedical text
print("üì• Loading SciBERT model...")
tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")
model = AutoModel.from_pretrained("allenai/scibert_scivocab_uncased")
print("‚úÖ Loaded SciBERT model")


In [None]:
# Initialize OpenAI client (optional - set your API key)
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
if OPENAI_API_KEY:
    client = OpenAI(api_key=OPENAI_API_KEY)
    print("✅ OpenAI client initialized")
else:
    client = None
    print("⚠️ OpenAI API key not found. GPT-4 summarization will be skipped.")
    print("   Set your API key: export OPENAI_API_KEY='your-key-here'")


## Step 2: Load & Preprocess Metadata


In [None]:
# Load CORD-19 metadata
metadata_path = "2020-04-10/metadata.csv"
print(f"üìÇ Loading metadata from {metadata_path}...")

df = pd.read_csv(metadata_path)
print(f"   Total papers: {len(df):,}")

# Keep only papers with abstracts
df = df.dropna(subset=["abstract"])
print(f"   Papers with abstracts: {len(df):,}")

# Select relevant columns (source_x is kept for the source distribution below)
df = df[["cord_uid", "title", "abstract", "authors", "journal", "publish_time", "source_x"]]

print("\n📋 Sample data:")
df.head()


In [None]:
# Dataset overview
print("üìä Dataset Overview:")
print(f"   Total papers: {len(df):,}")
print(f"   Columns: {len(df.columns)}")
print(f"   Date range: {df['publish_time'].min()} to {df['publish_time'].max()}")

# Analyze missing data
print("\nüìà Data Completeness:")
missing_data = df.isnull().sum()
missing_percent = (missing_data / len(df)) * 100

for col in df.columns:
    if missing_percent[col] > 0:
        print(f"   {col}: {missing_percent[col]:.1f}% missing ({missing_data[col]:,} papers)")


In [None]:
# Analyze sources and journals
print("üìö Source Distribution:")
if 'source_x' in df.columns:
    source_counts = df['source_x'].value_counts()
    for source, count in source_counts.head(5).items():
        print(f"   {source}: {count:,} papers ({count/len(df)*100:.1f}%)")

print("\nüìñ Top 10 Journals:")
journal_counts = df['journal'].value_counts().head(10)
for journal, count in journal_counts.items():
    print(f"   {journal}: {count:,} papers")


## Step 3: Entity Extraction (spaCy)


In [None]:
# Define entity extraction function
def extract_entities(text):
    """Extract entities from text using spaCy"""
    doc = nlp(text)
    return [(ent.text, ent.label_) for ent in doc.ents]

# Example on one abstract
if len(df) > 0:
    print("üîç Example entity extraction:")
    sample_text = df['abstract'].iloc[0]
    entities = extract_entities(sample_text)
    
    print(f"   Text: {sample_text[:200]}...")
    print(f"   Entities found: {entities[:10]}")  # Show first 10 entities
    
    # Show entity types
    entity_types = [ent[1] for ent in entities]
    entity_type_counts = pd.Series(entity_types).value_counts()
    print(f"\n   Entity types: {dict(entity_type_counts)}")


In [None]:
# Analyze entity distribution across a sample
sample_size = 100
sample_df = df.head(sample_size)

print(f"üß¨ Analyzing entities in {sample_size} papers...")

all_entities = []
entity_types = []

for idx, row in tqdm(sample_df.iterrows(), total=len(sample_df), desc="Extracting entities"):
    entities = extract_entities(row['abstract'])
    all_entities.extend([ent[0] for ent in entities])
    entity_types.extend([ent[1] for ent in entities])

# Analyze entity types
entity_type_counts = pd.Series(entity_types).value_counts()
print(f"\nüìä Entity Type Distribution:")
for entity_type, count in entity_type_counts.head(10).items():
    print(f"   {entity_type}: {count:,} entities")

# Plot entity types
plt.figure(figsize=(12, 6))
entity_type_counts.head(10).plot(kind='bar')
plt.title('Top 10 Entity Types in CORD-19 Abstracts')
plt.xlabel('Entity Type')
plt.ylabel('Count')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()
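
The loop above also collects the entity strings themselves in `all_entities`, but only their labels are summarized. A minimal sketch of counting the most frequent entity strings follows; `top_entities` is a hypothetical helper, and the normalization (strip, lowercase, drop one-character strings) is an assumption you may want to adjust for biomedical acronyms.

```python
from collections import Counter

def top_entities(entity_texts, n=10, min_len=2):
    """Count entity strings after stripping whitespace and lowercasing.

    Lowercasing is an assumption: it merges "Coronavirus" and "coronavirus",
    but it can also merge an acronym with an ordinary word.
    """
    normalized = (t.strip().lower() for t in entity_texts)
    counts = Counter(t for t in normalized if len(t) >= min_len)
    return counts.most_common(n)

# Toy input standing in for all_entities
sample = ["SARS-CoV-2", "sars-cov-2 ", "Wuhan", "2", "Wuhan", "China"]
print(top_entities(sample, n=3))
```

In the notebook this would be called as `top_entities(all_entities)`.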


## Step 4: Graph Construction (NetworkX)


In [None]:
# Initialize graph
G = nx.Graph()

# Build graph from sample of papers
sample_size = 500  # Limit for speed
sample_df = df.head(sample_size)

print(f"üï∏Ô∏è Building graph from {len(sample_df)} papers...")

for idx, row in tqdm(sample_df.iterrows(), total=len(sample_df), desc="Building graph"):
    paper_id = row['cord_uid']
    
    # Add paper node
    G.add_node(paper_id, 
              type="paper", 
              title=row["title"],
              journal=row["journal"],
              year=row["publish_time"])
    
    # Add entities
    entities = extract_entities(row["abstract"])
    for ent, label in entities:
        # Clean entity text
        ent_clean = ent.strip()
        if len(ent_clean) > 1:  # Skip single characters
            G.add_node(ent_clean, type=label)
            G.add_edge(paper_id, ent_clean, relation="mentions")
    
    # Add authors
    if pd.notna(row["authors"]):
        authors = [author.strip() for author in row["authors"].split(";") if author.strip()]
        for author in authors:
            G.add_node(author, type="author")
            G.add_edge(paper_id, author, relation="authored_by")

print(f"✅ Graph constructed:")
print(f"   Nodes: {len(G.nodes):,}")
print(f"   Edges: {len(G.edges):,}")
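
Entities and authors are keyed by their raw text, so an entity string that equals an author name collapses into a single node, and the later `add_node` call overwrites its `type`. One way to keep node kinds apart, sketched here under the assumption that prefixed string keys are acceptable downstream, is to namespace entity and author keys while paper nodes keep their `cord_uid`. `node_key` and the `kind::text` convention are hypothetical; if the same text can carry different spaCy labels, the label could be folded into the key as well.

```python
import networkx as nx

def node_key(kind, text):
    """Build a namespaced node key such as 'entity::Wuhan' (hypothetical convention)."""
    return f"{kind}::{text.strip()}"

G_demo = nx.Graph()

# Raw string keys: the second add_node call overwrites the first node's 'type'
G_demo.add_node("Zhang", type="PERSON")
G_demo.add_node("Zhang", type="author")
print(G_demo.nodes["Zhang"]["type"])

# Namespaced keys: the entity and the author stay separate nodes
paper_id = "abc12345"  # paper nodes keep their raw cord_uid, as in the notebook
G_demo.add_node(paper_id, type="paper")
for kind, text, label, relation in [
    ("entity", "Zhang", "PERSON", "mentions"),
    ("author", "Zhang", "author", "authored_by"),
]:
    key = node_key(kind, text)
    G_demo.add_node(key, type=label, name=text)
    G_demo.add_edge(paper_id, key, relation=relation)

print(sorted(G_demo.neighbors(paper_id)))
```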


In [None]:
# Analyze graph structure
node_types = {}
for node, data in G.nodes(data=True):
    node_type = data.get('type', 'unknown')
    node_types[node_type] = node_types.get(node_type, 0) + 1

print(f"üìä Node types:")
for node_type, count in sorted(node_types.items(), key=lambda x: x[1], reverse=True):
    print(f"   {node_type}: {count:,}")

# Plot node types
plt.figure(figsize=(12, 6))
node_type_counts = pd.Series(node_types)
node_type_counts.plot(kind='bar')
plt.title('Node Types in Knowledge Graph')
plt.xlabel('Node Type')
plt.ylabel('Count')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()
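
Node-type counts describe the graph's composition but not which entities tie papers together. In the graph built above, entity nodes are linked only to papers, and `nx.Graph` stores each paper-entity edge once, so an entity's degree is the number of sampled papers that mention it. A short sketch, assuming the `type` attribute convention from Step 4 (`paper`, `author`, or a spaCy label); `top_connected` is a hypothetical helper.

```python
import networkx as nx

def top_connected(G, n=5, exclude_types=("paper", "author")):
    """Rank nodes whose type is not in exclude_types by degree (highest first)."""
    candidates = [
        (node, deg) for node, deg in G.degree()
        if G.nodes[node].get("type") not in exclude_types
    ]
    return sorted(candidates, key=lambda x: x[1], reverse=True)[:n]

# Toy graph mirroring the notebook's schema
G_toy = nx.Graph()
for pid in ["p1", "p2", "p3"]:
    G_toy.add_node(pid, type="paper")
G_toy.add_node("Wuhan", type="GPE")
G_toy.add_node("2019", type="DATE")
G_toy.add_node("Li Wang", type="author")
G_toy.add_edges_from([("p1", "Wuhan"), ("p2", "Wuhan"), ("p3", "Wuhan"), ("p1", "2019")])
G_toy.add_edge("p1", "Li Wang")

print(top_connected(G_toy))
```

On the real graph, `top_connected(G, n=15)` would list the most widely mentioned entities in the sample.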


## Step 5: Semantic Embeddings with SciBERT


In [None]:
# Define embedding function
def embed(text):
    """Generate embeddings using SciBERT"""
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.last_hidden_state.mean(dim=1).numpy()

# Test embedding on a sample text
sample_text = "COVID-19 is a respiratory disease caused by SARS-CoV-2 virus."
sample_embedding = embed(sample_text)
print(f"✅ Sample embedding shape: {sample_embedding.shape}")
print(f"   Embedding dimension: {sample_embedding.shape[1]}")
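
`embed` averages `last_hidden_state` over every token position. That is fine for one text at a time, but if abstracts are batched with `padding=True`, the pad positions get averaged in too. A common fix is to weight the mean by the tokenizer's `attention_mask`. The sketch below shows only the pooling arithmetic in NumPy on toy arrays; with real SciBERT outputs, the same expression would apply to `outputs.last_hidden_state.numpy()` and `inputs["attention_mask"].numpy()`.

```python
import numpy as np

def masked_mean_pool(hidden, mask):
    """Mean-pool token vectors, ignoring padded positions.

    hidden: (batch, seq_len, dim) token embeddings
    mask:   (batch, seq_len) with 1 for real tokens, 0 for padding
    """
    mask = mask[..., None].astype(hidden.dtype)      # (batch, seq_len, 1)
    summed = (hidden * mask).sum(axis=1)             # (batch, dim)
    counts = np.clip(mask.sum(axis=1), 1e-9, None)   # avoid dividing by zero
    return summed / counts

# Two toy sequences: the second has one padded position (all zeros)
hidden = np.array([
    [[1.0, 1.0], [3.0, 3.0]],
    [[2.0, 4.0], [0.0, 0.0]],
])
mask = np.array([[1, 1], [1, 0]])

print(masked_mean_pool(hidden, mask))  # second row ignores the pad position
print(hidden.mean(axis=1))             # plain mean is pulled toward the pad vector
```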


In [None]:
# Compute embeddings for abstracts
embedding_size = 1000  # Number of abstracts to embed
abstracts = df["abstract"].head(embedding_size).tolist()

print(f"üìä Generating embeddings for {len(abstracts)} abstracts...")

embeddings_list = []
for i, abstract in enumerate(tqdm(abstracts, desc="Generating embeddings")):
    try:
        emb = embed(abstract)
        embeddings_list.append(emb)
    except Exception as e:
        print(f"   ‚ö†Ô∏è Error processing abstract {i}: {e}")
        # Add zero embedding as fallback
        embeddings_list.append(np.zeros((1, 768)))

embeddings = np.vstack(embeddings_list)
print(f"‚úÖ Generated embeddings: {embeddings.shape}")


## Step 6: FAISS Index for Retrieval


In [None]:
# Create FAISS index
dim = embeddings.shape[1]
print(f"üìä Building FAISS index with dimension {dim}")

index = faiss.IndexFlatL2(dim)
index.add(embeddings.astype('float32'))

print(f"‚úÖ FAISS index built with {index.ntotal} vectors")

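`IndexFlatL2` ranks by Euclidean distance on raw mean-pooled vectors, so vector length influences the ranking. If cosine similarity is preferred, a common alternative is to L2-normalize both the stored embeddings and the query and use an inner-product index (in FAISS, `faiss.normalize_L2` on float32 arrays followed by `faiss.IndexFlatIP`). The NumPy sketch below reproduces that brute-force search so the arithmetic is visible; whether cosine retrieves better abstracts for SciBERT embeddings is something to check on your own queries.

```python
import numpy as np

def normalize_rows(x):
    """Scale each row to unit L2 norm (what faiss.normalize_L2 does in place)."""
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    return x / np.clip(norms, 1e-12, None)

def cosine_search(db, queries, k):
    """Brute-force inner-product search over unit vectors, like IndexFlatIP."""
    db_n = normalize_rows(db.astype("float32"))
    q_n = normalize_rows(queries.astype("float32"))
    scores = q_n @ db_n.T                       # (n_queries, n_db) cosine similarities
    ids = np.argsort(-scores, axis=1)[:, :k]    # highest similarity first
    return np.take_along_axis(scores, ids, axis=1), ids

# Row 0 points the same way as the query but is much longer
db = np.array([[10.0, 11.0], [0.0, 1.0], [1.0, 0.0]])
query = np.array([[0.9, 1.0]])

scores, ids = cosine_search(db, query, k=2)
l2 = np.linalg.norm(db - query, axis=1)
print("cosine top-2:", ids[0], "L2 top-2:", np.argsort(l2)[:2])
```

The two rankings differ because L2 penalizes the long vector in row 0 even though it points almost exactly in the query's direction.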

In [None]:
# Test semantic search
query = "What drugs are being tested for COVID-19 treatment?"

def embed_query(query_text):
    """Embed query text with the same SciBERT pooling as `embed` (FAISS expects float32)"""
    return embed(query_text).astype('float32')

q_vec = embed_query(query)
D, I = index.search(q_vec, k=5)

print(f"üîç Query: '{query}'")
print(f"üìã Top 5 results:")

results_df = df.iloc[I[0]][["title", "abstract"]]
for i, (idx, row) in enumerate(results_df.iterrows()):
    print(f"\n   {i+1}. {row['title'][:100]}...")
    print(f"      {row['abstract'][:200]}...")

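`index.search` returns row positions, which line up with `df.iloc` only because the first `embedding_size` rows of `df` were embedded in order. When fewer than `k` vectors are available, FAISS fills the unused result slots with `-1`, and `df.iloc[-1]` would silently return the last row. A small sketch of a defensive mapping step; `hits_to_ids` is a hypothetical helper, and the toy ID list stands in for `df["cord_uid"].tolist()[:index.ntotal]`.

```python
import numpy as np

def hits_to_ids(D, I, doc_ids):
    """Pair each valid FAISS hit with its document ID, dropping -1 placeholders.

    D, I: arrays of shape (1, k) as returned by index.search for one query
    doc_ids: sequence where position i holds the ID of the i-th indexed vector
    """
    return [
        (doc_ids[int(i)], float(d))
        for d, i in zip(D[0], I[0])
        if i >= 0
    ]

# Toy result: only 2 of the 3 requested neighbors exist
doc_ids = ["uid_a", "uid_b"]
D = np.array([[0.5, 1.2, 3.4e38]], dtype="float32")
I = np.array([[1, 0, -1]])

hits = hits_to_ids(D, I, doc_ids)
print(hits)
```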

## Step 7: GraphRAG Retrieval


In [None]:
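# Get graph context for retrieved papers
def get_context_from_graph(paper_ids, G, max_items=10):
    """Summarize each retrieved paper's graph neighborhood (entities and authors)"""
    context = []
    for pid in paper_ids:
        # The graph covers only the first `sample_size` papers while FAISS indexes
        # `embedding_size` papers, so some retrieved IDs may be missing from G
        if pid not in G:
            context.append(f"Paper {pid}: not in the graph sample")
            continue
        title = G.nodes[pid].get('title', 'Unknown')
        neighbors = list(G.neighbors(pid))
        entities = [n for n in neighbors if G.edges[pid, n].get('relation') == 'mentions']
        authors = [n for n in neighbors if G.edges[pid, n].get('relation') == 'authored_by']
        context.append(
            f"Paper '{title}' (ID: {pid}) mentions: {entities[:max_items]}; "
            f"authors: {authors[:max_items]}"
        )
    return "\n".join(context)

# Get context for our search results
paper_ids = df.iloc[I[0]]["cord_uid"].tolist()
graph_context = get_context_from_graph(paper_ids, G)

print("📊 Graph context retrieved:")
print(graph_context[:1000] + "..." if len(graph_context) > 1000 else graph_context)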

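`get_context_from_graph` lists each retrieved paper's direct neighbors. One way to let the graph contribute *new* candidates, sketched here as a possible extension rather than part of this pipeline, is a two-hop expansion: follow entity edges out of the retrieved papers and rank other papers by how many of those entities they share. It assumes the Step 4 schema (paper nodes with `type="paper"`, entity edges with `relation="mentions"`); `expand_via_entities` is a hypothetical helper.

```python
from collections import Counter

import networkx as nx

def expand_via_entities(G, seed_papers, top_n=5):
    """Rank non-seed papers by the number of entities they share with the seeds."""
    seeds = [p for p in seed_papers if p in G]   # skip papers missing from the graph
    shared = Counter()
    for pid in seeds:
        for ent in G.neighbors(pid):
            if G.edges[pid, ent].get("relation") != "mentions":
                continue                          # ignore author edges
            for other in G.neighbors(ent):
                if other not in seeds and G.nodes[other].get("type") == "paper":
                    shared[other] += 1
    return shared.most_common(top_n)

G_toy = nx.Graph()
for pid in ["p1", "p2", "p3", "p4"]:
    G_toy.add_node(pid, type="paper")
for pid, ent in [("p1", "ACE2"), ("p1", "remdesivir"), ("p2", "ACE2"),
                 ("p3", "ACE2"), ("p3", "remdesivir"), ("p4", "Wuhan")]:
    G_toy.add_node(ent, type="ENTITY")
    G_toy.add_edge(pid, ent, relation="mentions")

print(expand_via_entities(G_toy, ["p1", "p999"]))
```

In the notebook, `expand_via_entities(G, paper_ids)` would return papers related to the FAISS hits through shared entities, which could then be added to the LLM context.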

## Step 8: Summarization with GPT-4o-mini (OpenAI API)


In [None]:
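# Generate summary with the OpenAI API (if an API key is available)
if client:
    # Label each retrieved abstract with its cord_uid and title so the model can cite it
    retrieved = df.iloc[I[0]]
    context_text = "\n\n".join(
        f"[{row['cord_uid']}] {row['title']}\n{row['abstract']}"
        for _, row in retrieved.iterrows()
    )
    
    prompt = f"""
You are a biomedical research assistant specializing in COVID-19 research.
Summarize the following abstracts and their graph context in relation to the query.
Cite papers by the bracketed ID that precedes each abstract, and name specific entities and authors where relevant.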

Query: {query}

Abstracts:
{context_text}

Graph Context:
{graph_context}

Please provide a comprehensive summary with specific citations.
"""
    
    try:
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": prompt}],
            max_tokens=1000,
            temperature=0.3
        )
        
        summary = response.choices[0].message.content
        print("✅ Summary generated:")
        print(summary)
        
    except Exception as e:
        print(f"❌ Error generating summary: {e}")
        summary = None
else:
    print("⚠️ OpenAI client not available. Skipping GPT-4 summarization.")
    print("\n📋 Manual Summary of Retrieved Papers:")
    
    # Manual summary of results
    for i, (idx, row) in enumerate(results_df.iterrows()):
        print(f"\n{i+1}. {row['title']}")
        print(f"   Abstract: {row['abstract'][:300]}...")
    
    summary = "Manual summary provided above"
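
The prompt asks the model to cite papers, but nothing checks the citations it returns (the "citation tracking and validation" item under Next Steps). A minimal sketch, assuming the model cites papers as bracketed `cord_uid` values and that those IDs are 8 alphanumeric characters (loosen `CITATION_RE` if yours differ): extract the bracketed IDs and compare them with the IDs that were actually retrieved. `check_citations` and the sample IDs are hypothetical.

```python
import re

CITATION_RE = re.compile(r"\[([A-Za-z0-9]{8})\]")  # assumes 8-character cord_uid tags

def check_citations(summary_text, retrieved_ids):
    """Split bracketed IDs in a summary into retrieved vs. unknown, and list uncited hits."""
    cited = set(CITATION_RE.findall(summary_text))
    allowed = set(retrieved_ids)
    return {
        "valid": sorted(cited & allowed),
        "unknown": sorted(cited - allowed),   # possibly hallucinated citations
        "uncited": sorted(allowed - cited),   # retrieved but never cited
    }

summary_text = "ACE2 binding is discussed in [abc12345] and [zzzzzzzz]."
report = check_citations(summary_text, ["abc12345", "xyz98765"])
print(report)
```

In the notebook this would be `check_citations(summary, paper_ids)` after a successful API call.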


## Save Results


In [None]:
# Create results directory
Path("results").mkdir(exist_ok=True)

# Save graph
nx.write_gml(G, "results/cord19_graph.gml")
print("✅ Graph saved to results/cord19_graph.gml")

# Save embeddings
np.save("results/embeddings.npy", embeddings)
print("✅ Embeddings saved to results/embeddings.npy")

# Save FAISS index
faiss.write_index(index, "results/faiss_index.bin")
print("✅ FAISS index saved to results/faiss_index.bin")

# Save query results
results = {
    "query": query,
    "paper_ids": paper_ids,
    "graph_context": graph_context,
    "summary": summary if 'summary' in locals() else "No summary generated",
    "timestamp": pd.Timestamp.now().isoformat()
}

with open("results/query_results.json", "w") as f:
    json.dump(results, f, indent=2)
print("✅ Query results saved to results/query_results.json")

# Save processed metadata
df.to_csv("results/processed_metadata.csv", index=False)
print("✅ Processed metadata saved to results/processed_metadata.csv")

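To reuse these artifacts in a later session without recomputing, they can be read back with the matching loaders: `nx.read_gml`, `np.load`, `json.load`, and `faiss.read_index` for the index. The sketch below round-trips a toy graph, embedding array, and results dict through a temporary directory (FAISS is left out so it runs without it); with the real files, the same calls would point at `results/`.

```python
import json
import tempfile
from pathlib import Path

import networkx as nx
import numpy as np

with tempfile.TemporaryDirectory() as tmp:
    out = Path(tmp)

    # Toy artifacts shaped like the notebook's outputs
    G_small = nx.Graph()
    G_small.add_node("abc12345", type="paper", title="Toy paper")
    G_small.add_node("ACE2", type="ORG")
    G_small.add_edge("abc12345", "ACE2", relation="mentions")
    emb = np.random.default_rng(0).random((3, 768), dtype=np.float32)

    nx.write_gml(G_small, str(out / "graph.gml"))
    np.save(out / "embeddings.npy", emb)
    (out / "query_results.json").write_text(json.dumps({"paper_ids": ["abc12345"]}))

    # Reload: read_gml keys nodes by their 'label' (the original string key) by default
    G_loaded = nx.read_gml(str(out / "graph.gml"))
    emb_loaded = np.load(out / "embeddings.npy")
    results_loaded = json.loads((out / "query_results.json").read_text())

print(G_loaded.edges["abc12345", "ACE2"]["relation"], emb_loaded.shape, results_loaded["paper_ids"])
```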

## Interactive Query Interface


In [None]:
# Interactive query function
def query_cord19(query_text, top_k=5):
    """Query the CORD-19 GraphRAG system"""
    print(f"🔍 Querying: '{query_text}'")
    
    # Generate query embedding
    q_vec = embed_query(query_text)
    
    # Search FAISS index; drop the -1 placeholders FAISS returns when top_k exceeds index.ntotal
    D, I = index.search(q_vec, k=top_k)
    hits = [i for i in I[0] if i >= 0]
    
    # Get results
    results_df = df.iloc[hits][["title", "abstract", "authors", "journal"]]
    
    print(f"\n📋 Top {len(hits)} results:")
    for i, (idx, row) in enumerate(results_df.iterrows()):
        print(f"\n{i+1}. {row['title']}")
        # Authors can be missing (NaN), so convert to str before slicing
        print(f"   Authors: {str(row['authors'])[:100]}...")
        print(f"   Journal: {row['journal']}")
        print(f"   Abstract: {row['abstract'][:300]}...")
    
    # Get graph context
    paper_ids = df.iloc[hits]["cord_uid"].tolist()
    graph_context = get_context_from_graph(paper_ids, G)
    
    return results_df, graph_context

# Example queries
example_queries = [
    "What are the symptoms of COVID-19?",
    "How is SARS-CoV-2 transmitted?",
    "What treatments are available for coronavirus?",
    "What is the mortality rate of COVID-19?",
    "How effective are masks in preventing COVID-19?"
]

print("üí° Example queries you can try:")
for i, q in enumerate(example_queries, 1):
    print(f"   {i}. {q}")


## Summary and Next Steps


In [None]:
print("üéâ CORD-19 GraphRAG Pipeline Complete!")
print("=" * 50)
print(f"‚úÖ Processed {len(df):,} papers")
print(f"‚úÖ Built graph with {len(G.nodes):,} nodes and {len(G.edges):,} edges")
print(f"‚úÖ Generated {embeddings.shape[0]:,} embeddings")
print(f"‚úÖ Created FAISS index with {index.ntotal:,} vectors")

print("\nüìÅ Results saved to 'results' directory:")
print("   - cord19_graph.gml: Knowledge graph")
print("   - embeddings.npy: SciBERT embeddings")
print("   - faiss_index.bin: Vector search index")
print("   - query_results.json: Query results")
print("   - processed_metadata.csv: Cleaned dataset")

print("\nüöÄ Next Steps:")
print("   1. Experiment with different queries")
print("   2. Increase sample sizes for better coverage")
print("   3. Add more sophisticated entity extraction")
print("   4. Implement graph-based reasoning")
print("   5. Add citation tracking and validation")
print("   6. Deploy as a web application")

print("\nüí° Tips:")
print("   - Use specific biomedical terms for better results")
print("   - Try different query formulations")
print("   - Explore the graph structure for insights")
print("   - Combine multiple queries for comprehensive analysis")
