In [9]:
import torch
import numpy as np
import re
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity
from collections import defaultdict
import json

In [10]:
# Load the LLM-Embedder model (same as your code)
tokenizer = AutoTokenizer.from_pretrained("BAAI/llm-embedder")
model = AutoModel.from_pretrained("BAAI/llm-embedder")

In [11]:
def get_embeddings(texts, instruction=""):
    if instruction:
        texts = [f"{instruction} {text}" for text in texts]
    
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt", max_length=512)
    
    with torch.no_grad():
        outputs = model(**inputs)
        embeddings = outputs.last_hidden_state.mean(dim=1)
    
    return embeddings

In [12]:
class SimpleGraphRAG:
    def __init__(self):
        self.documents = []
        self.entities = {}  # entity -> {docs: [doc_ids], embedding: tensor}
        self.relationships = []  # [(entity1, relation, entity2, doc_id)]
        self.doc_embeddings = None
        
    def extract_entities_and_relations(self, text, doc_id):
        entities = []
        relationships = []
        
        # Extract named entities (simple patterns for football context)
        # Players (capitalized names)
        players = re.findall(r'\b[A-Z][a-z]+ [A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', text)
        # Teams
        teams = re.findall(r'\b(?:Real Madrid|Barcelona|Madrid|Barça)\b', text)
        # Locations
        locations = re.findall(r'\b(?:Santiago Bernabéu|Camp Nou|La Liga)\b', text)
        
        all_entities = players + teams + locations
        entities.extend(all_entities)
        
        # Extract relationships using simple patterns
        # Goal relationships
        goal_patterns = [
            (r'(\w+(?:\s+\w+)*)\s+(?:scored|goal|finish)', 'SCORED'),
            (r'(\w+(?:\s+\w+)*)\s+took.*?lead', 'TOOK_LEAD'),
            (r'(\w+(?:\s+\w+)*)\s+sealed.*?win', 'SEALED_WIN'),
            (r'(\w+(?:\s+\w+)*)\s+responded.*?goal', 'RESPONDED_WITH_GOAL')
        ]
        
        for pattern, relation in goal_patterns:
            matches = re.findall(pattern, text, re.IGNORECASE)
            for match in matches:
                if match.strip():
                    relationships.append((match.strip(), relation, 'GOAL', doc_id))
        
        # Team vs Team relationships
        if 'Real Madrid' in text and 'Barcelona' in text:
            relationships.append(('Real Madrid', 'PLAYED_AGAINST', 'Barcelona', doc_id))
            relationships.append(('Barcelona', 'PLAYED_AGAINST', 'Real Madrid', doc_id))
        
        # Player-Team relationships
        player_team_patterns = [
            (r'(Jude Bellingham|Federico Valverde|Vinícius Jr\.)', 'Real Madrid'),
            (r'(Robert Lewandowski)', 'Barcelona')
        ]
        
        for player_pattern, team in player_team_patterns:
            if re.search(player_pattern, text):
                relationships.append((re.search(player_pattern, text).group(1), 'PLAYS_FOR', team, doc_id))
        
        return entities, relationships
    
    def index_documents(self, documents):
        self.documents = documents
        
        print("🔍 Extracting entities and relationships...")
        
        # Extract entities and relationships
        for doc_id, doc in enumerate(documents):
            entities, relations = self.extract_entities_and_relations(doc, doc_id)
            
            # Store entities
            for entity in entities:
                if entity not in self.entities:
                    self.entities[entity] = {'docs': [], 'embedding': None}
                self.entities[entity]['docs'].append(doc_id)
            
            # Store relationships
            self.relationships.extend(relations)
        
        # Create embeddings for entities
        print("🧠 Creating entity embeddings...")
        entity_names = list(self.entities.keys())
        if entity_names:
            entity_embeddings = get_embeddings(entity_names, "Represent this entity:")
            for i, entity in enumerate(entity_names):
                self.entities[entity]['embedding'] = entity_embeddings[i]
        
        # Create document embeddings
        print("📄 Creating document embeddings...")
        self.doc_embeddings = get_embeddings(documents)
        
        print(f"✅ Indexed {len(documents)} documents, {len(self.entities)} entities, {len(self.relationships)} relationships")
    
    def examine_graph(self):
        print("\n" + "="*60)
        print("📊 GRAPH ANALYSIS")
        print("="*60)
        
        print(f"\n🏷️  ENTITIES ({len(self.entities)}):")
        for entity, data in self.entities.items():
            doc_count = len(data['docs'])
            print(f"  • {entity} (appears in {doc_count} documents)")
        
        print(f"\n🔗 RELATIONSHIPS ({len(self.relationships)}):")
        relation_counts = defaultdict(int)
        for entity1, relation, entity2, doc_id in self.relationships:
            relation_counts[relation] += 1
            print(f"  • {entity1} --[{relation}]--> {entity2} (doc {doc_id})")
        
        print(f"\n📈 RELATIONSHIP TYPES:")
        for relation, count in relation_counts.items():
            print(f"  • {relation}: {count} instances")
        
        # Quality assessment
        print(f"\n🎯 QUALITY ASSESSMENT:")
        print(f"  • Entity coverage: {len(self.entities)} unique entities")
        print(f"  • Relationship diversity: {len(relation_counts)} different relation types")
        print(f"  • Average relations per document: {len(self.relationships)/len(self.documents):.1f}")
        
        # Identify potential issues
        issues = []
        if len(self.entities) < 5:
            issues.append("Low entity count - may need better extraction patterns")
        if len(relation_counts) < 3:
            issues.append("Limited relationship types - consider more patterns")
        
        if issues:
            print(f"\n⚠️  POTENTIAL ISSUES:")
            for issue in issues:
                print(f"  • {issue}")
        else:
            print(f"\n✅ Graph quality looks good!")
    
    def graph_enhanced_retrieval(self, query, top_k=2):
        print(f"\n🔍 Query: {query}")
        print("="*50)
        
        # Get query embedding
        query_embedding = get_embeddings([query], "Represent this sentence for searching relevant passages:")
        
        # 1. Traditional document similarity
        doc_similarities = cosine_similarity(query_embedding.numpy(), self.doc_embeddings.numpy())[0]
        
        # 2. Entity-based retrieval
        entity_scores = {}
        if self.entities:
            entity_embeddings = torch.stack([data['embedding'] for data in self.entities.values()])
            entity_similarities = cosine_similarity(query_embedding.numpy(), entity_embeddings.numpy())[0]
            
            for i, (entity, data) in enumerate(self.entities.items()):
                entity_scores[entity] = entity_similarities[i]
        
        # 3. Combine scores (simple approach)
        final_scores = doc_similarities.copy()
        
        # Boost documents that contain highly relevant entities
        for entity, score in entity_scores.items():
            if score > 0.5:  # High entity relevance threshold
                for doc_id in self.entities[entity]['docs']:
                    final_scores[doc_id] += 0.2 * score  # Boost factor
        
        # Get top documents
        top_indices = np.argsort(final_scores)[-top_k:][::-1]
        
        print("📄 RETRIEVED DOCUMENTS:")
        retrieved_docs = []
        for i, idx in enumerate(top_indices):
            doc = self.documents[idx]
            doc_score = doc_similarities[idx]
            final_score = final_scores[idx]
            print(f"  {i+1}. Document {idx} (doc_sim: {doc_score:.3f}, final: {final_score:.3f})")
            print(f"     {doc}")
            retrieved_docs.append(doc)
        
        # Show relevant entities
        relevant_entities = [(e, s) for e, s in entity_scores.items() if s > 0.3]
        relevant_entities.sort(key=lambda x: x[1], reverse=True)
        
        if relevant_entities:
            print(f"\n🏷️  RELEVANT ENTITIES:")
            for entity, score in relevant_entities[:5]:
                print(f"  • {entity} (similarity: {score:.3f})")
        
        # Show relevant relationships
        relevant_relations = []
        for entity1, relation, entity2, doc_id in self.relationships:
            if entity1 in [e for e, s in relevant_entities] or entity2 in [e for e, s in relevant_entities]:
                relevant_relations.append((entity1, relation, entity2, doc_id))
        
        if relevant_relations:
            print(f"\n🔗 RELEVANT RELATIONSHIPS:")
            for entity1, relation, entity2, doc_id in relevant_relations[:5]:
                print(f"  • {entity1} --[{relation}]--> {entity2}")
        
        return retrieved_docs

In [13]:
if __name__ == "__main__":
    # Your football documents
    documents = [
        "Real Madrid and Barcelona clashed in another thrilling edition of El Clásico.",
        "The match was filled with intensity, showcasing world-class football from both sides.",
        "Real Madrid took an early lead with a clinical finish from Jude Bellingham.",
        "Barcelona responded quickly with a brilliant goal by Robert Lewandowski.",
        "The midfield battle was fierce, with both teams pressing high and forcing turnovers.",
        "Vinícius Jr. caused constant problems for Barcelona's defense with his pace and dribbling.",
        "A late goal from Federico Valverde sealed the win for Real Madrid.",
        "The Santiago Bernabéu erupted as Madrid secured a vital victory in La Liga.",
        "The win pushed Real Madrid to the top of the table, asserting dominance in the title race."
    ]
    
    # Create and run GraphRAG
    graph_rag = SimpleGraphRAG()
    graph_rag.index_documents(documents)
    graph_rag.examine_graph()
    
    # Test queries
    queries = [
        "Who scored the first goal for Real Madrid?",
        "Which player sealed the victory?",
        "What happened at Santiago Bernabéu?"
    ]
    
    for query in queries:
        graph_rag.graph_enhanced_retrieval(query)
        print("\n" + "="*70 + "\n")

🔍 Extracting entities and relationships...
🧠 Creating entity embeddings...
📄 Creating document embeddings...
✅ Indexed 9 documents, 9 entities, 12 relationships

📊 GRAPH ANALYSIS

🏷️  ENTITIES (9):
  • Real Madrid (appears in 8 documents)
  • Barcelona (appears in 3 documents)
  • Jude Bellingham (appears in 1 documents)
  • Robert Lewandowski (appears in 1 documents)
  • Federico Valverde (appears in 1 documents)
  • The Santiago (appears in 1 documents)
  • La Liga (appears in 2 documents)
  • Madrid (appears in 1 documents)
  • Santiago Bernabéu (appears in 1 documents)

🔗 RELATIONSHIPS (12):
  • Real Madrid --[PLAYED_AGAINST]--> Barcelona (doc 0)
  • Barcelona --[PLAYED_AGAINST]--> Real Madrid (doc 0)
  • Real Madrid took an early lead with a clinical --[SCORED]--> GOAL (doc 2)
  • Real Madrid --[TOOK_LEAD]--> GOAL (doc 2)
  • Jude Bellingham --[PLAYS_FOR]--> Real Madrid (doc 2)
  • Barcelona responded quickly with a brilliant --[SCORED]--> GOAL (doc 3)
  • Barcelona --[RESPONDED_W