# In [None]:

# Ontology-Guided KG Extraction


import json
import os
import re
import sys
import time
from collections import Counter, defaultdict

import openai
import pandas as pd
from tqdm import tqdm

# Load your OpenAI API key from the OPENAI_API_KEY environment variable.
# setdefault only installs the placeholder when the variable is unset, so a key
# already exported in the shell is no longer overwritten with an empty string.
os.environ.setdefault("OPENAI_API_KEY", "")  # add your openAI key here if not exported
openai.api_key = os.environ.get("OPENAI_API_KEY")

def load_ontology_schema(schema_file_path="ontology_schema.json"):
    """
    Read an ontology schema JSON file and turn it into prompt-ready text.

    When the file carries a pre-rendered ``schema_string`` entry it is used
    verbatim; otherwise the text is assembled by ``build_schema_from_components``.

    Args:
        schema_file_path: Location of the ontology schema JSON file.

    Returns:
        str | None: The schema text, or None when the file is missing, is not
        valid JSON, or any other error occurs while loading it (the problem is
        printed to stdout).
    """
    try:
        print(f" Loading ontology schema from: {schema_file_path}")
        with open(schema_file_path, 'r', encoding='utf-8') as handle:
            data = json.load(handle)

        if 'schema_string' not in data:
            # No pre-rendered text: build it from the individual components
            print(" Building schema string from components")
            schema_text = build_schema_from_components(data)
        else:
            schema_text = data['schema_string']
            print(" Using pre-formatted schema string from file")

        n_classes = len(data.get('classes', []))
        n_props = len(data.get('all_properties', []))
        print(f" Schema loaded with {n_classes} classes and {n_props} properties")
        return schema_text

    except FileNotFoundError:
        print(f" Error: {schema_file_path} not found")
        print("Please ensure you have run the ontology schema extractor first")
    except json.JSONDecodeError:
        print(f" Error: Invalid JSON in {schema_file_path}")
    except Exception as err:
        print(f" Error loading schema: {err}")
    return None

def build_schema_from_components(schema_data):
    """
    Render ontology components as plain text for the extraction prompt.

    Sections are emitted only for components that are present and non-empty:
    sorted classes, sorted properties (expanded to one line per
    domain/range pair when both are known), and subclass relations.

    Args:
        schema_data: Mapping that may contain ``classes``, ``all_properties``,
            ``property_domains``, ``property_ranges`` and ``class_hierarchy``.

    Returns:
        str: Newline-joined schema text; an empty string if nothing is present.
    """
    def present(key):
        # Same test as "key in data and data[key]": missing or empty -> falsy
        return key in schema_data and schema_data[key]

    lines = []

    classes = present('classes')
    if classes:
        lines.append(f"Ontology Classes:\n- {', '.join(sorted(classes))}")

    properties = present('all_properties')
    if properties:
        lines.append("\nOntology Relationships:")
        domains_by_prop = schema_data.get('property_domains', {})
        ranges_by_prop = schema_data.get('property_ranges', {})
        for prop in sorted(properties):
            doms = domains_by_prop.get(prop, [])
            rngs = ranges_by_prop.get(prop, [])
            if not (doms and rngs):
                # Without both ends of the relation, list the bare name
                lines.append(f"- {prop}")
                continue
            lines.extend(f"- {prop} ({d} -> {r})" for d in doms for r in rngs)

    hierarchy = present('class_hierarchy')
    if hierarchy:
        lines.append("\nClass Hierarchy:")
        lines.extend(
            f"- {sub} subClassOf {sup}"
            for sup, subs in hierarchy.items()
            for sub in subs
        )

    return "\n".join(lines)

# Load the input chunks file
print(" Loading chunked data...")
chunks_df = pd.read_csv("chunks.csv", encoding="utf-8")

# Clean column names: remove the BOM *before* stripping whitespace, because
# str.strip() does not treat U+FEFF as whitespace (a "\ufeff chunk_id" header
# would otherwise keep its leading space).
chunks_df.columns = chunks_df.columns.str.replace('\ufeff', '', regex=False).str.strip()
print(f" Loaded {len(chunks_df)} chunks")

# Check available columns
print(f" Available columns: {list(chunks_df.columns)}")

# Verify we have the required columns
required_columns = ['chunk_id', 'text']
missing_columns = [col for col in required_columns if col not in chunks_df.columns]

if missing_columns:
    print(f" Missing required columns: {missing_columns}")
    print("Available columns:", list(chunks_df.columns))
    # sys.exit rather than the site-module exit(), which is intended for the
    # interactive shell and is unavailable when Python runs with -S.
    sys.exit(1)

print(f" Data preview:")
print(f"   • Shape: {chunks_df.shape}")
print(f"   • Columns: {list(chunks_df.columns)}")
if len(chunks_df) > 0:
    print(f"   • Sample chunk_id: {chunks_df['chunk_id'].iloc[0]}")
    print(f"   • Sample text: {str(chunks_df['text'].iloc[0])[:100]}...")

# Load ontology schema dynamically
ontology_schema = load_ontology_schema("ontology_schema.json")

if not ontology_schema:
    print(" Cannot proceed without ontology schema. Exiting...")
    sys.exit(1)

print("\n Using the following ontology schema:")
print("="*60)
print(ontology_schema)
print("="*60)

# Few-shot example embedded in every GPT-4 Turbo prompt: one policy passage
# followed by the triples expected from it, one "(subject, predicate, object)"
# per line. Assembled line by line; the text starts and ends with a newline.
example = "\n".join([
    "",
    "Text:",
    "Institutions must assume that EEA students have the right to remain in the UK. Once a student is enrolled, the institution is expected to take all reasonable steps to ensure that the student can complete their programme.",
    "",
    "Triples:",
    "(institution, has_assumption, eea_students_right_to_remain)",
    "(student, is_enrolled_in, study_programme)",
    "(institution, has_legal_duty, ensure_student_completion)",
    "(institution, provides_support, student)",
    "",
])

# Prompt builder: assembles the user message sent for each chunk
def build_prompt(text_chunk):
    """
    Compose the extraction prompt for a single text chunk.

    The prompt embeds, in order: the chunk wrapped in triple quotes, the
    module-level ``ontology_schema`` text, the few-shot ``example``, and the
    output-format instructions. Both globals are read at call time.

    Args:
        text_chunk: Text to extract triples from.

    Returns:
        str: The full prompt; it begins with a newline and ends with two.
    """
    sections = [
        "",
        "You are a knowledge extraction model that converts educational policy text into triples.",
        "",
        "Below is a text chunk from an educational policy document:",
        "",
        f'"""{text_chunk}"""',
        "",
        "Use the ontology schema below to guide your extraction:",
        "",
        f"{ontology_schema}",
        "",
        "---",
        "",
        "Here is an example:",
        "",
        f"{example}",
        "",
        "---",
        "",
        "Now, extract all triples from the provided chunk in the format:",
        "(subject, predicate, object)",
        "",
        "-Do not use full sentences or long descriptions as object values. Instead, normalize them into short, meaningful phrases or identifiers (e.g., `GDPR_compliance`, `funding_16_19`, `right_to_remain`).",
        "",
        "",
    ]
    return "\n".join(sections)

# Function to query GPT-4
def extract_triples(text_chunk, max_chars=2000):
    """
    Ask the chat model to extract triples from one text chunk.

    Args:
        text_chunk: The chunk text. Missing values (None, or the float NaN that
            pandas produces for an empty CSV cell) and whitespace-only strings
            are skipped without calling the API; other non-string values are
            converted with ``str``.
        max_chars: The chunk is truncated to this many characters before it is
            placed in the prompt (default 2000, the previous hard-coded limit).

    Returns:
        str: The model's raw reply, or "" if the chunk was skipped, the API
        call failed, or the reply carried no text content.
    """
    # Validate before slicing: the slice sits outside the try below, and a NaN
    # cell would otherwise raise TypeError and abort the whole extraction run.
    if not isinstance(text_chunk, str):
        if pd.isna(text_chunk):
            return ""
        text_chunk = str(text_chunk)
    if not text_chunk.strip():
        return ""

    prompt = build_prompt(text_chunk[:max_chars])  # Truncate if too long
    try:
        # NOTE(review): openai.ChatCompletion is the pre-1.0 SDK interface and was
        # removed in openai>=1.0 (calls fail there and are reported as errors
        # below). Pin openai<1.0 or migrate to client.chat.completions.create.
        response = openai.ChatCompletion.create(
            model="gpt-4-turbo",
            messages=[
                {"role": "system", "content": "You are an expert in ontology-based information extraction."},
                {"role": "user",   "content": prompt}
            ],
            temperature=0.2,
            max_tokens=512
        )
        # message.content is nullable in the API schema; normalise to "" so
        # callers can always call .strip() on the result.
        return response['choices'][0]['message']['content'] or ""
    except Exception as e:
        print(f"Error during API call: {e}")
        return ""

# Compiled once at import time. A triple line looks like "(s, p, o)": the
# subject and predicate may not contain commas, while the object takes
# everything up to the closing parenthesis, so commas inside it are tolerated.
triple_pattern = re.compile(
    r"^\(\s*"
    r"([^,]+?)\s*,\s*"  # subject
    r"([^,]+?)\s*,\s*"  # predicate
    r"(.+?)\s*\)$"      # object (may contain commas)
)

def analyze_triples_statistics(triples_df, output_filename, total_chunks=None):
    """
    Print and return summary statistics for the extracted triples.

    Args:
        triples_df: DataFrame with ``chunk_id``, ``subject``, ``predicate`` and
            ``object`` columns. An empty frame (with those columns) is handled.
        output_filename: Name of the output CSV file (kept for interface
            compatibility; not used in the computation).
        total_chunks: Number of input chunks, used for the coverage figure.
            When None, falls back to ``len(chunks_df)`` if a module-level
            ``chunks_df`` exists; otherwise coverage is reported as N/A.

    Returns:
        dict: Nested statistics dictionary. Values that would be NaN (e.g.
        averages over an empty frame) are stored as None so the result stays
        valid JSON.
    """
    def _to_float(value):
        # json.dump would otherwise emit the non-standard token NaN
        return float(value) if pd.notna(value) else None

    print(f"\n GENERATING TRIPLES STATISTICS AND SUMMARY")
    print("="*60)

    # Basic statistics
    total_triples = len(triples_df)
    unique_chunks = triples_df['chunk_id'].nunique()
    unique_subjects = triples_df['subject'].nunique()
    unique_predicates = triples_df['predicate'].nunique()
    unique_objects = triples_df['object'].nunique()

    # Frequency analysis
    subject_counts = Counter(triples_df['subject'])
    predicate_counts = Counter(triples_df['predicate'])
    object_counts = Counter(triples_df['object'])

    # Chunk-level analysis
    triples_per_chunk = triples_df.groupby('chunk_id').size()
    # min()/max() of an empty Series are NaN, which int() cannot convert
    if triples_per_chunk.empty:
        min_per_chunk = max_per_chunk = 0
    else:
        min_per_chunk = int(triples_per_chunk.min())
        max_per_chunk = int(triples_per_chunk.max())

    # Relationship patterns
    subject_predicate_pairs = Counter(zip(triples_df['subject'], triples_df['predicate']))
    predicate_object_pairs = Counter(zip(triples_df['predicate'], triples_df['object']))

    # Print detailed statistics
    print(f" BASIC STATISTICS:")
    print(f"   • Total Triples: {total_triples:,}")
    print(f"   • Unique Chunks Processed: {unique_chunks:,}")
    print(f"   • Unique Subjects: {unique_subjects:,}")
    print(f"   • Unique Predicates: {unique_predicates:,}")
    print(f"   • Unique Objects: {unique_objects:,}")
    print(f"   • Average Triples per Chunk: {triples_per_chunk.mean():.1f}")
    print(f"   • Min Triples per Chunk: {min_per_chunk}")
    print(f"   • Max Triples per Chunk: {max_per_chunk}")

    print(f"\n TOP 10 MOST FREQUENT SUBJECTS:")
    for i, (subject, count) in enumerate(subject_counts.most_common(10), 1):
        print(f"   {i:2d}. {subject} ({count} occurrences)")

    print(f"\n TOP 10 MOST FREQUENT PREDICATES:")
    for i, (predicate, count) in enumerate(predicate_counts.most_common(10), 1):
        print(f"   {i:2d}. {predicate} ({count} occurrences)")

    print(f"\n TOP 10 MOST FREQUENT OBJECTS:")
    for i, (obj, count) in enumerate(object_counts.most_common(10), 1):
        print(f"   {i:2d}. {obj} ({count} occurrences)")

    print(f"\n CHUNK DISTRIBUTION:")
    chunk_stats = triples_per_chunk.describe()
    print(f"   • Mean: {chunk_stats['mean']:.1f}")
    print(f"   • Std Dev: {chunk_stats['std']:.1f}")
    print(f"   • 25th Percentile: {chunk_stats['25%']:.0f}")
    print(f"   • Median: {chunk_stats['50%']:.0f}")
    print(f"   • 75th Percentile: {chunk_stats['75%']:.0f}")

    print(f"\n TOP 10 SUBJECT-PREDICATE PATTERNS:")
    for i, ((subject, predicate), count) in enumerate(subject_predicate_pairs.most_common(10), 1):
        print(f"   {i:2d}. ({subject}, {predicate}) - {count} times")

    print(f"\n TOP 10 PREDICATE-OBJECT PATTERNS:")
    for i, ((predicate, obj), count) in enumerate(predicate_object_pairs.most_common(10), 1):
        print(f"   {i:2d}. ({predicate}, {obj}) - {count} times")

    # Quality metrics
    print(f"\n QUALITY METRICS:")
    avg_subject_length = triples_df['subject'].str.len().mean()
    avg_predicate_length = triples_df['predicate'].str.len().mean()
    avg_object_length = triples_df['object'].str.len().mean()

    print(f"   • Average Subject Length: {avg_subject_length:.1f} characters")
    print(f"   • Average Predicate Length: {avg_predicate_length:.1f} characters")
    print(f"   • Average Object Length: {avg_object_length:.1f} characters")

    # Coverage analysis: prefer the explicit count, else the module-level chunks_df
    if total_chunks is None and 'chunks_df' in globals():
        total_chunks = len(chunks_df)
    if total_chunks:
        coverage_rate = unique_chunks / total_chunks * 100
        print(f"   • Chunk Coverage: {coverage_rate:.1f}% ({unique_chunks} out of {total_chunks} chunks)")
    else:
        # Unknown (or zero) total: report N/A instead of a misleading 0.0%
        print(f"   • Chunk Coverage: N/A ({unique_chunks} out of N/A chunks)")

    # Compile statistics dictionary
    stats = {
        'basic_stats': {
            'total_triples': total_triples,
            'unique_chunks': unique_chunks,
            'unique_subjects': unique_subjects,
            'unique_predicates': unique_predicates,
            'unique_objects': unique_objects,
            'avg_triples_per_chunk': _to_float(triples_per_chunk.mean()),
            'min_triples_per_chunk': min_per_chunk,
            'max_triples_per_chunk': max_per_chunk,
        },
        'top_subjects': dict(subject_counts.most_common(20)),
        'top_predicates': dict(predicate_counts.most_common(20)),
        'top_objects': dict(object_counts.most_common(20)),
        'top_subject_predicate_patterns': {f"{s}||{p}": count for (s, p), count in subject_predicate_pairs.most_common(20)},
        'top_predicate_object_patterns': {f"{p}||{o}": count for (p, o), count in predicate_object_pairs.most_common(20)},
        'chunk_statistics': {k: _to_float(v) for k, v in chunk_stats.to_dict().items()},
        'triple_statistics': {
            'avg_subject_length': _to_float(avg_subject_length),
            'avg_predicate_length': _to_float(avg_predicate_length),
            'avg_object_length': _to_float(avg_object_length),
        },
    }

    return stats

def save_detailed_summary(triples_df, stats, output_filename):
    """
    Save the statistics dictionary as ``<base>_statistics.json``.

    ``<base>`` is ``output_filename`` with a trailing ``.csv`` removed. Only the
    suffix is removed, so a name that contains ".csv" elsewhere keeps it;
    names without the suffix are used unchanged.

    Args:
        triples_df: DataFrame of extracted triples (unused; kept for interface
            compatibility).
        stats: JSON-serialisable statistics dictionary.
        output_filename: CSV filename the statistics file is named after.

    Returns:
        str: Path of the JSON file that was written.
    """
    suffix = '.csv'
    if output_filename.endswith(suffix):
        base_name = output_filename[:-len(suffix)]
    else:
        base_name = output_filename

    # Save statistics to JSON
    stats_filename = f"{base_name}_statistics.json"
    with open(stats_filename, 'w', encoding='utf-8') as f:
        json.dump(stats, f, indent=2, ensure_ascii=False)
    print(f" Detailed statistics saved to: {stats_filename}")
    return stats_filename

def _parse_triples(result, chunk_id):
    """
    Extract ``(subject, predicate, object)`` triples from raw model output.

    Each line is trimmed, stripped of leading bullets/numbering such as
    "1. ", "- " or "* ", and stripped of backticks, then matched against
    ``triple_pattern``. Lines that do not match (prose, headings, code fences)
    are ignored.

    Args:
        result: Raw text returned by the model.
        chunk_id: Identifier recorded on every parsed triple.

    Returns:
        list[dict]: One dict per matched line with chunk_id, subject,
        predicate and object keys.
    """
    parsed = []
    for line in result.splitlines():
        cleaned = line.strip().lstrip("-*0123456789. ").strip('`')
        match = triple_pattern.match(cleaned)
        if match is None:
            continue
        s, p, o = match.groups()
        parsed.append({
            "chunk_id":  chunk_id,
            "subject":   s.strip(),
            "predicate": p.strip(),
            "object":    o.strip().strip('"'),
        })
    return parsed


# Extraction loop
def main():
    """
    Run ontology-guided triple extraction over every row of ``chunks_df``.

    For each chunk the model is queried (throttled), the raw reply is appended
    to ``raw_outputs.txt``, and triples are parsed from it. All triples are
    written to ``Ontology_Guided__KG_Triples.csv``, always with a header row.
    When at least one triple was found, a ``*_statistics.json`` summary is
    written as well.
    """
    print(f"\n Starting knowledge graph extraction...")
    print(f" Processing chunks using dynamic ontology schema")

    triples_list = []

    test_chunks = chunks_df
    print(f" Processing {len(test_chunks)} chunks in FULL mode")

    for _, row in tqdm(test_chunks.iterrows(), total=len(test_chunks), desc="Extracting triples"):
        chunk_id   = row['chunk_id']
        chunk_text = row['text']
        # "or ''" guards against a None reply so .strip() below cannot raise
        result     = extract_triples(chunk_text) or ""
        time.sleep(1.2)  # throttle requests

        # Append raw output for inspection (reopened per chunk so everything
        # written so far is on disk even if a later chunk fails)
        with open("raw_outputs.txt", "a", encoding="utf-8") as f:
            f.write(f"\n=== CHUNK ID: {chunk_id} ===\n{result}\n")

        if not result.strip():
            continue

        triples_list.extend(_parse_triples(result, chunk_id))

    # Save triples to CSV. Explicit columns keep the header row even when no
    # triples were extracted (pd.DataFrame([]) would write an empty file).
    output_df = pd.DataFrame(triples_list, columns=["chunk_id", "subject", "predicate", "object"])
    output_filename = "Ontology_Guided__KG_Triples.csv"
    output_df.to_csv(output_filename, index=False)

    # Generate comprehensive statistics and summary
    stats_written = False
    if len(triples_list) > 0:
        stats = analyze_triples_statistics(output_df, output_filename)
        save_detailed_summary(output_df, stats, output_filename)
        stats_written = True

    # Print summary
    print(f"\n Triple extraction complete!")
    print(f" Extraction Summary:")
    print(f"   • Processed chunks: {len(test_chunks)}")
    print(f"   • Total triples extracted: {len(triples_list)}")
    if len(test_chunks) > 0:
        print(f"   • Average triples per chunk: {len(triples_list)/len(test_chunks):.1f}")
    print(f"   • Results saved to: {output_filename}")
    print(f"   • Raw outputs saved to: raw_outputs.txt")

    # Show sample triples
    if len(triples_list) > 0:
        print(f"\n Sample extracted triples:")
        for i, triple in enumerate(triples_list[:5]):
            print(f"   {i+1}. ({triple['subject']}, {triple['predicate']}, {triple['object']})")
        if len(triples_list) > 5:
            print(f"   ... and {len(triples_list) - 5} more")

    # List the output files generated; the statistics file exists only when
    # at least one triple was extracted.
    print(f"\n OUTPUT FILES GENERATED:")
    print(f"    Main results: {output_filename}")
    if stats_written:
        print(f"    Statistics: {output_filename.replace('.csv', '_statistics.json')}")
    print(f"    Raw outputs: raw_outputs.txt")

if __name__ == "__main__":
    main()