# Graph DB Configuration

In [1]:
import os
from dotenv import load_dotenv, dotenv_values 

load_dotenv() 

NEO4j_URI = 'bolt://localhost:7687'
NEO4j_USER = 'neo4j'
NEO4j_PASSWORD = os.environ.get('pass')
GROQ_API = os.environ.get('GROQ_API')

# Connect with graphDB

In [2]:
from langchain_community.graphs import Neo4jGraph

In [3]:
graph = Neo4jGraph(url=NEO4j_URI, username=NEO4j_USER, password=NEO4j_PASSWORD)

In [4]:
graph

<langchain_community.graphs.neo4j_graph.Neo4jGraph at 0x2bbb14ad0c0>

# Use the Groq API for free LLM Models specifically using llama-3.2-90b-text-preview

In [5]:
from langchain_groq import ChatGroq

In [8]:
llm = ChatGroq(groq_api_key=GROQ_API, model_name='llama-3.2-90b-vision-preview')

In [9]:
llm

ChatGroq(client=<groq.resources.chat.completions.Completions object at 0x000002BBCB3C1C00>, async_client=<groq.resources.chat.completions.AsyncCompletions object at 0x000002BBCB3C27D0>, model_name='llama-3.2-90b-vision-preview', model_kwargs={}, groq_api_key=SecretStr('**********'))

# Creating appropriate dataset for GraphDB

In [10]:
from typing import Dict, List, Optional, Set
import pandas as pd
import numpy as np
from datetime import datetime
import re
from tqdm import tqdm
import hashlib
from collections import Counter

In [11]:
class CausalGraph:
    def __init__(self, graph: Neo4jGraph):
        self.graph = graph
        self.debug_info = {
            'total_triggers_found': 0,
            'triggers_per_event': []
        }

    def clear_database(self):
        clear_queries = [
            "MATCH (n) DETACH DELETE n",
            "CALL apoc.schema.assert({}, {})"
        ]
        for query in clear_queries:
            try:
                self.graph.query(query)
            except Exception as e:
                print(f"Warning during cleanup: {e}")
    
    def create_indexes(self):
        constraints = [
            "CREATE CONSTRAINT IF NOT EXISTS FOR (e:Event) REQUIRE e.id IS UNIQUE",
            "CREATE CONSTRAINT IF NOT EXISTS FOR (c:Cause) REQUIRE c.id IS UNIQUE",
            "CREATE CONSTRAINT IF NOT EXISTS FOR (e:Effect) REQUIRE e.id IS UNIQUE",
            "CREATE CONSTRAINT IF NOT EXISTS FOR (t:Trigger) REQUIRE t.id IS UNIQUE"
        ]
        for query in constraints:
            try:
                self.graph.query(query)
            except Exception as e:
                print(f"Warning during index creation: {e}")
    
    def clean_text(self, text: str, preserve_case: bool = False) -> str:
        if pd.isna(text) or text is None:
            return ""
        cleaned = str(text).strip()
        cleaned = re.sub(r'\s+', ' ', cleaned)
        cleaned = cleaned.replace('"', "'").replace('\\', '')
        return cleaned if preserve_case else cleaned.lower()
    
    def generate_hash(self, text: str, event_id: str = "") -> str:
        text_to_hash = f"{text}_{event_id}" if event_id else text
        return hashlib.md5(text_to_hash.encode()).hexdigest()
    
    def extract_elements(self, tagged_text: str) -> Optional[Dict]:
        if not isinstance(tagged_text, str) or tagged_text == 'NoTag':
            return None
            
        try:
            patterns = {
                'causes': r'<cause>((?:(?!</cause>).)*)</cause>',
                'effects': r'<effect>((?:(?!</effect>).)*)</effect>',
                'triggers': r'<trigger>((?:(?!</trigger>).)*)</trigger>'
            }
            
            elements = {}
            for key, pattern in patterns.items():
                matches = re.findall(pattern, tagged_text, re.DOTALL | re.IGNORECASE)
                if key == 'triggers':
                    elements[key] = [m.strip() for m in matches if m.strip()]
                    self.debug_info['triggers_per_event'].append(len(elements[key]))
                    self.debug_info['total_triggers_found'] += len(elements[key])
                else:
                    cleaned_matches = [m.strip() for m in matches if m.strip()]
                    elements[key] = list(dict.fromkeys(cleaned_matches))
            
            return elements if any(elements.values()) else None
            
        except Exception as e:
            print(f"Error extracting elements: {str(e)}")
            return None
    
    def create_event_graph(self, text: str, tagged_text: str, event_id: str):
        elements = self.extract_elements(tagged_text)
        if not elements:
            return
        
        event_query = """
        MERGE (e:Event {id: $id})
        SET e.text = $text,
            e.tagged_text = $tagged_text,
            e.created_at = datetime()
        """
        
        self.graph.query(
            event_query,
            params={
                'id': event_id,
                'text': self.clean_text(text),
                'tagged_text': tagged_text
            }
        )
        
        for cause in elements['causes']:
            cause_id = self.generate_hash(self.clean_text(cause))
            self.graph.query("""
            MATCH (e:Event {id: $event_id})
            MERGE (c:Cause {id: $cause_id})
            SET c.text = $cause_text
            MERGE (c)-[r:CAUSES]->(e)
            SET r.created_at = datetime()
            """, params={'event_id': event_id, 'cause_id': cause_id, 'cause_text': cause})
        
        for effect in elements['effects']:
            effect_id = self.generate_hash(self.clean_text(effect))
            self.graph.query("""
            MATCH (e:Event {id: $event_id})
            MERGE (eff:Effect {id: $effect_id})
            SET eff.text = $effect_text
            MERGE (e)-[r:RESULTS_IN]->(eff)
            SET r.created_at = datetime()
            """, params={'event_id': event_id, 'effect_id': effect_id, 'effect_text': effect})
        
        for trigger in elements['triggers']:
            trigger_id = self.generate_hash(self.clean_text(trigger), event_id)
            self.graph.query("""
            MATCH (e:Event {id: $event_id})
            MERGE (t:Trigger {id: $trigger_id})
            SET t.text = $trigger_text,
                t.event_id = $event_id
            MERGE (e)-[r:HAS_TRIGGER]->(t)
            SET r.created_at = datetime()
            """, params={
                'event_id': event_id,
                'trigger_id': trigger_id,
                'trigger_text': trigger
            })
    
    def analyze_dataset(self, csv_path: str):
        df = pd.read_csv(csv_path)
        df = df.replace({np.nan: None})
        tagged_rows = df[df['tagged_sentence'] != 'NoTag']
        
        total_stats = {'causes': 0, 'effects': 0, 'triggers': 0}
        unique_elements = {'causes': set(), 'effects': set(), 'triggers': set()}
        
        print("Analyzing dataset...")
        for _, row in tagged_rows.iterrows():
            elements = self.extract_elements(str(row['tagged_sentence']))
            if elements:
                for key in ['causes', 'effects', 'triggers']:
                    total_stats[key] += len(elements[key])
                    unique_elements[key].update(elements[key])
        
        print("\nDataset Analysis:")
        print(f"Total tagged sentences: {len(tagged_rows)}")
        print("\nTotal elements found:")
        for key, value in total_stats.items():
            print(f"Total {key}: {value}")
        print("\nUnique elements:")
        for key, value in unique_elements.items():
            print(f"Unique {key}: {len(value)}")
        
        print("\nTrigger Statistics:")
        print(f"Total triggers found: {self.debug_info['total_triggers_found']}")
        trigger_counts = Counter(self.debug_info['triggers_per_event'])
        print("Events by trigger count:")
        for count, freq in sorted(trigger_counts.items()):
            print(f"{count} trigger(s): {freq} events")
    
    def load_dataset(self, csv_path: str, clear_existing: bool = True):
        if clear_existing:
            self.clear_database()
        self.create_indexes()
        
        self.analyze_dataset(csv_path)
        
        df = pd.read_csv(csv_path)
        df = df.replace({np.nan: None})
        tagged_rows = df[df['tagged_sentence'] != 'NoTag']
        
        print("\nLoading data into graph...")
        for idx, row in tqdm(tagged_rows.iterrows(), total=len(tagged_rows)):
            try:
                self.create_event_graph(
                    text=str(row['text']),
                    tagged_text=str(row['tagged_sentence']),
                    event_id=f"event_{idx}"
                )
            except Exception as e:
                print(f"\nError processing row {idx}: {str(e)}")
        
        stats = self.get_graph_statistics()
        print("\nFinal Graph Statistics:")
        for key, value in stats.items():
            print(f"{key}: {value}")
    
    def get_graph_statistics(self) -> Dict:
        node_stats_query = """
        MATCH (n)
        RETURN {
            events: count(CASE WHEN n:Event THEN 1 END),
            causes: count(CASE WHEN n:Cause THEN 1 END),
            effects: count(CASE WHEN n:Effect THEN 1 END),
            triggers: count(CASE WHEN n:Trigger THEN 1 END)
        } as stats
        """
        
        rel_stats_query = """
        MATCH ()-[r]->()
        RETURN count(r) as relationships
        """
        
        node_results = self.graph.query(node_stats_query)
        rel_results = self.graph.query(rel_stats_query)
        
        stats = node_results[0]['stats']
        stats['relationships'] = rel_results[0]['relationships']
        return stats

In [12]:
graph = Neo4jGraph(url=NEO4j_URI, username=NEO4j_USER, password=NEO4j_PASSWORD)
graph_creation = CausalGraph(graph)
graph_creation.load_dataset('Causal_dataset.csv')

Analyzing dataset...

Dataset Analysis:
Total tagged sentences: 1030

Total elements found:
Total causes: 1177
Total effects: 1125
Total triggers: 1106

Unique elements:
Unique causes: 1153
Unique effects: 1119
Unique triggers: 577

Trigger Statistics:
Total triggers found: 1106
Events by trigger count:
0 trigger(s): 100 events
1 trigger(s): 806 events
2 trigger(s): 95 events
3 trigger(s): 15 events
4 trigger(s): 8 events
5 trigger(s): 3 events
6 trigger(s): 3 events

Loading data into graph...


100%|██████████| 1030/1030 [00:21<00:00, 48.26it/s]



Final Graph Statistics:
events: 1021
effects: 1118
triggers: 1102
causes: 1147
relationships: 3404


# Check if correct data has been added to DB 

In [13]:
def analyze_dataset_distribution(csv_path: str):
    df = pd.read_csv(csv_path)
    total_sentences = len(df)
    tagged_sentences = df[df['tagged_sentence'] != 'NoTag'].shape[0]
    
    cause_count = df['tagged_sentence'].str.count('<cause>').sum()
    effect_count = df['tagged_sentence'].str.count('<effect>').sum()
    trigger_count = df['tagged_sentence'].str.count('<trigger>').sum()
    
    print(f"Dataset Analysis:")
    print(f"Total sentences: {total_sentences}")
    print(f"Tagged sentences: {tagged_sentences}")
    print(f"Unique causes: {cause_count}")
    print(f"Unique effects: {effect_count}")
    print(f"Unique triggers: {trigger_count}")

In [14]:
analyze_dataset_distribution('Causal_dataset.csv')

Dataset Analysis:
Total sentences: 2005
Tagged sentences: 1030
Unique causes: 1182.0
Unique effects: 1132.0
Unique triggers: 1107.0
