In [1]:
# With real Neo4j
from data.kg_connector import KGConnector
from data.graph_constructor import GraphConstructor
from data.feature_processor import FeatureProcessor
from data.data_loader import create_train_val_test_loaders

# Connect to KG
kg = KGConnector("bolt://localhost:7687", "neo4j", "password")

# Build graph
constructor = GraphConstructor(kg)
graph = constructor.build_hetero_graph("Amsterdam_North")

# Process features
processor = FeatureProcessor()
processor.process_graph_features(graph)

# Create loaders for retrofit task
train_loader, val_loader, test_loader = create_train_val_test_loaders(
    graph, task='retrofit', batch_size=32
)

# Ready for GNN training!

KeyboardInterrupt: 

In [None]:
# check_neo4j_schema.py
"""
Check what labels and properties actually exist in your Neo4j database.
"""

from neo4j import GraphDatabase

# Neo4j credentials
NEO4J_URI = "bolt://localhost:7687"
NEO4J_USER = "neo4j"
NEO4J_PASSWORD = "aminasad"

def check_schema():
    """Check Neo4j schema to understand actual data structure."""
    
    driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
    
    with driver.session() as session:
        print("="*60)
        print("NEO4J DATABASE SCHEMA")
        print("="*60)
        
        # 1. Get all node labels
        print("\n📊 NODE LABELS:")
        result = session.run("""
            CALL db.labels() YIELD label
            RETURN label
            ORDER BY label
        """)
        labels = [record['label'] for record in result]
        for label in labels:
            print(f"  - {label}")
        
        # 2. Count nodes for each label
        print("\n📊 NODE COUNTS:")
        for label in labels:
            result = session.run(f"MATCH (n:{label}) RETURN count(n) as count")
            count = result.single()['count']
            if count > 0:
                print(f"  - {label}: {count}")
        
        # 3. Get relationship types
        print("\n📊 RELATIONSHIP TYPES:")
        result = session.run("""
            CALL db.relationshipTypes() YIELD relationshipType
            RETURN relationshipType
            ORDER BY relationshipType
        """)
        rel_types = [record['relationshipType'] for record in result]
        for rel_type in rel_types:
            print(f"  - {rel_type}")
        
        # 4. Get properties for main labels
        print("\n📊 PROPERTIES BY LABEL:")
        for label in labels[:10]:  # Check first 10 labels
            result = session.run(f"""
                MATCH (n:{label})
                WITH n LIMIT 1
                RETURN keys(n) as properties
            """)
            record = result.single()
            if record and record['properties']:
                print(f"\n  {label}:")
                for prop in record['properties']:
                    print(f"    - {prop}")
        
        # 5. Check for Building nodes specifically
        print("\n📊 BUILDING NODE SAMPLE:")
        result = session.run("""
            MATCH (b:Building)
            RETURN b
            LIMIT 1
        """)
        record = result.single()
        if record:
            building = dict(record['b'])
            print("  Sample Building properties:")
            for key, value in building.items():
                print(f"    - {key}: {value}")
        
        # 6. Find grid-like structures
        print("\n📊 CHECKING FOR GRID STRUCTURE:")
        
        # Check for any hierarchical relationships
        result = session.run("""
            MATCH (a)-[r]->(b)
            WHERE labels(a) <> labels(b)
            RETURN DISTINCT labels(a)[0] as from_label, 
                   type(r) as rel_type, 
                   labels(b)[0] as to_label,
                   count(*) as count
            ORDER BY count DESC
            LIMIT 10
        """)
        
        print("\n  Hierarchical relationships found:")
        for record in result:
            print(f"    {record['from_label']} -[{record['rel_type']}]-> {record['to_label']}: {record['count']}")
        
        # 7. Check for region/area properties
        print("\n📊 REGION/AREA PROPERTIES:")
        result = session.run("""
            MATCH (n)
            WHERE n.region IS NOT NULL 
               OR n.area IS NOT NULL 
               OR n.district IS NOT NULL
               OR n.zone IS NOT NULL
            RETURN DISTINCT 
                labels(n)[0] as label,
                n.region as region,
                n.area as area,
                n.district as district,
                n.zone as zone
            LIMIT 10
        """)
        
        for record in result:
            print(f"  {record['label']}: region={record['region']}, area={record['area']}")
    
    driver.close()
    print("\n" + "="*60)
    print("Please share this output so we can adapt the code to your schema!")
    print("="*60)

if __name__ == "__main__":
    check_schema()

NEO4J DATABASE SCHEMA

📊 NODE LABELS:
  - AdjacencyCluster
  - BatterySystem
  - Building
  - CableGroup
  - CableSegment
  - ConnectionPoint
  - EnergyState
  - GridComponent
  - HeatPumpSystem
  - LVCabinet
  - LV_Network
  - MV_Transformer
  - Metadata
  - SolarSystem
  - Substation
  - SystemBaseline
  - TimeSlot
  - Transformer

📊 NODE COUNTS:
  - AdjacencyCluster: 327
  - BatterySystem: 1485
  - Building: 1517
  - CableGroup: 209
  - CableSegment: 4455
  - ConnectionPoint: 1517
  - EnergyState: 95424
  - GridComponent: 576
  - HeatPumpSystem: 1138
  - LVCabinet: 316
  - Metadata: 1
  - SolarSystem: 986
  - Substation: 2
  - SystemBaseline: 1
  - TimeSlot: 672
  - Transformer: 49

📊 RELATIONSHIP TYPES:
  - ADJACENT_TO
  - CAN_INSTALL
  - CONNECTED_TO
  - CONNECTS_TO
  - DURING
  - FEEDS_FROM
  - HAS_CONNECTION_POINT
  - HAS_INSTALLED
  - IN_ADJACENCY_CLUSTER
  - NEAR_MV
  - ON_SEGMENT
  - PART_OF
  - SHOULD_ELECTRIFY

📊 PROPERTIES BY LABEL:

  AdjacencyCluster:
    - solar_penetra



    EnergyState -[DURING]-> TimeSlot: 95424
    CableSegment -[PART_OF]-> CableGroup: 4455
    Building -[IN_ADJACENCY_CLUSTER]-> AdjacencyCluster: 2233
    ConnectionPoint -[ON_SEGMENT]-> CableSegment: 1517
    Building -[CONNECTED_TO]-> CableGroup: 1517
    Building -[HAS_CONNECTION_POINT]-> ConnectionPoint: 1517
    Building -[CAN_INSTALL]-> BatterySystem: 1463
    Building -[SHOULD_ELECTRIFY]-> HeatPumpSystem: 1079
    Building -[CAN_INSTALL]-> SolarSystem: 926
    CableGroup -[CONNECTS_TO]-> Transformer: 301

📊 REGION/AREA PROPERTIES:
  Building: region=None, area=4115.0
  Building: region=None, area=230.0
  Building: region=None, area=4955.0
  Building: region=None, area=5131.0
  Building: region=None, area=527.0
  Building: region=None, area=2698.0
  Building: region=None, area=1626.0
  Building: region=None, area=2335.0
  Building: region=None, area=48.0
  Building: region=None, area=3519.0

Please share this output so we can adapt the code to your schema!


In [None]:
# simple_test.py - Simple working test with all fixes applied
"""
Simple test to verify the pipeline works with your Neo4j schema.
"""

from data.kg_connector import KGConnector
from data.graph_constructor import GraphConstructor
from data.feature_processor import FeatureProcessor

# Connection settings
NEO4J_URI = "bolt://localhost:7687"
NEO4J_USER = "neo4j"
NEO4J_PASSWORD = "aminasad"

def test_pipeline_fixed():
    """Test the complete pipeline with proper error handling."""
    
    from data.kg_connector import KGConnector
    from data.graph_constructor import GraphConstructor
    from data.feature_processor import FeatureProcessor
    
    NEO4J_URI = "bolt://localhost:7687"
    NEO4J_USER = "neo4j"
    NEO4J_PASSWORD = "aminasad"
    
    print("="*60)
    print("TESTING COMPLETE PIPELINE")
    print("="*60)
    
    # 1. Connect to Neo4j
    print("\n1. Connecting to Neo4j...")
    kg = KGConnector(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)
    
    if not kg.verify_connection():
        print("❌ Cannot connect to Neo4j")
        return
    
    print("✅ Connected to Neo4j")
    
    # 2. Get a district
    print("\n2. Finding a district...")
    with kg.driver.session() as session:
        result = session.run("""
            MATCH (b:Building)
            WHERE b.district_name IS NOT NULL
            RETURN DISTINCT b.district_name as district
            LIMIT 1
        """).single()
        
        district = result['district']
        print(f"✅ Using district: {district}")
    
    # 3. Get basic statistics
    print("\n3. Getting district statistics...")
    with kg.driver.session() as session:
        stats = session.run("""
            MATCH (b:Building {district_name: $district})
            RETURN count(b) as count
        """, district=district).single()
        
        print(f"✅ Found {stats['count']} buildings")
    
    # 4. Build graph
    print("\n4. Building PyTorch Geometric graph...")
    constructor = GraphConstructor(kg)
    
    try:
        graph = constructor.build_hetero_graph(district)
        print(f"✅ Graph built successfully!")
        print(f"   Node types: {list(graph.num_nodes_dict.keys())}")
        print(f"   Node counts: {graph.num_nodes_dict}")
        
        # Show edge types
        edge_types = []
        for edge_type in graph.edge_types:
            if hasattr(graph[edge_type], 'edge_index'):
                num_edges = graph[edge_type].edge_index.shape[1]
                edge_types.append(f"{edge_type}: {num_edges} edges")
        if edge_types:
            print(f"   Edge types: {edge_types[:3]}...")
        
    except Exception as e:
        print(f"❌ Graph construction failed: {e}")
        kg.close()
        return
    
    # 5. Process features
    print("\n5. Processing features...")
    processor = FeatureProcessor()
    
    try:
        processor.process_graph_features(graph)
        print(f"✅ Features processed successfully!")
        
        # Show feature dimensions
        for node_type in ['building', 'cable_group', 'adjacency_cluster']:
            if node_type in graph.node_types and hasattr(graph[node_type], 'x'):
                shape = graph[node_type].x.shape
                print(f"   {node_type}: {shape[0]} nodes × {shape[1]} features")
                
    except Exception as e:
        print(f"❌ Feature processing failed: {e}")
    
    # 6. Create task-specific graphs
    print("\n6. Creating task-specific graphs...")
    
    # Retrofit task
    try:
        retrofit_graph = constructor.build_subgraph_for_task(district, 'retrofit')
        if hasattr(retrofit_graph['building'], 'y'):
            num_retrofit = retrofit_graph['building'].y.sum().item()
            total = len(retrofit_graph['building'].y)
            print(f"✅ Retrofit task: {num_retrofit}/{total} buildings need retrofit")
    except Exception as e:
        print(f"⚠️  Retrofit task failed: {e}")
    
    # Energy sharing task
    try:
        sharing_graph = constructor.build_subgraph_for_task(district, 'energy_sharing')
        if 'adjacency_cluster' in sharing_graph.node_types:
            num_clusters = sharing_graph['adjacency_cluster'].x.shape[0]
            print(f"✅ Energy sharing: {num_clusters} clusters found")
            if hasattr(sharing_graph['adjacency_cluster'], 'y'):
                avg_potential = sharing_graph['adjacency_cluster'].y.mean().item()
                print(f"   Average sharing potential: {avg_potential:.3f}")
    except Exception as e:
        print(f"⚠️  Energy sharing task failed: {e}")
    
    # Solar task
    try:
        solar_graph = constructor.build_subgraph_for_task(district, 'solar')
        if hasattr(solar_graph['building'], 'y'):
            total_potential = solar_graph['building'].y.sum().item()
            print(f"✅ Solar optimization: {total_potential:.0f} kWh/year potential")
    except Exception as e:
        print(f"⚠️  Solar task failed: {e}")
    
    # 7. Test data queries
    print("\n7. Testing specific queries...")
    
    # Test cable group aggregation - fixed to get correct cable group
    with kg.driver.session() as session:
        result = session.run("""
            MATCH (cg:CableGroup)<-[:CONNECTED_TO]-(b:Building {district_name: $district})
            WITH cg, count(b) as count
            WHERE count > 0 AND count < 50  // Get reasonable sized cable group
            RETURN cg.group_id as id, count
            LIMIT 1
        """, district=district).single()
        
        if result and result['id']:
            print(f"✅ Cable group {result['id']}: {result['count']} buildings connected")
    
    # Test adjacency clusters - fixed to handle None/string values
    clusters = kg.get_adjacency_clusters(district, min_cluster_size=2)
    if clusters:
        print(f"✅ Found {len(clusters)} adjacency clusters")
        # Find first cluster with valid sharing potential
        for cluster in clusters:
            potential = cluster.get('sharing_potential')
            if potential is not None:
                try:
                    # Try to convert to float if it's a string
                    if isinstance(potential, str):
                        potential = float(potential)
                    print(f"   Best cluster: {cluster['cluster_id']} "
                          f"(potential: {potential:.2f})")
                    break
                except (ValueError, TypeError):
                    # Skip if conversion fails
                    continue
        else:
            # No valid potential found
            print(f"   Clusters found but no valid sharing potential values")
    
    kg.close()
    print("\n" + "="*60)
    print("✅ ALL TESTS COMPLETED SUCCESSFULLY!")
    print("="*60)
if __name__ == "__main__":
    test_pipeline_fixed()

NameError: name 'np' is not defined

In [None]:
# test.py
from data.kg_connector import KGConnector
from data.graph_constructor import GraphConstructor

kg = KGConnector("bolt://localhost:7687", "neo4j", "aminasad")
constructor = GraphConstructor(kg)

# Get first available district
with kg.driver.session() as session:
    result = session.run("MATCH (b:Building) WHERE b.district_name IS NOT NULL RETURN DISTINCT b.district_name LIMIT 1").single()
    district = result['district_name']

# Build graph
graph = constructor.build_hetero_graph(district)
print(f"✅ Success! Graph has {graph.num_nodes_dict}")

kg.close()

KeyError: 'district_name'

In [None]:
# check_and_fix_sharing.py
"""
Check what cluster properties are actually available and fix the energy sharing task.
"""

from data.kg_connector import KGConnector
import torch
from torch_geometric.data import HeteroData

NEO4J_URI = "bolt://localhost:7687"
NEO4J_USER = "neo4j"
NEO4J_PASSWORD = "aminasad"

def check_cluster_properties():
    """Check what properties adjacency clusters actually have."""
    
    print("="*60)
    print("CHECKING ADJACENCY CLUSTER PROPERTIES")
    print("="*60)
    
    kg = KGConnector(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)
    
    # Get a sample cluster with all its properties
    with kg.driver.session() as session:
        result = session.run("""
            MATCH (ac:AdjacencyCluster)<-[:IN_ADJACENCY_CLUSTER]-(b:Building)
            WHERE b.district_name = 'Buitenveldert-Oost'
            WITH ac, count(b) as building_count
            WHERE building_count > 3
            RETURN ac, building_count
            LIMIT 5
        """).data()
        
        print("\n📊 Sample Adjacency Clusters:")
        for i, record in enumerate(result, 1):
            cluster = dict(record['ac'])
            print(f"\n{i}. Cluster ID: {cluster.get('cluster_id')}")
            print(f"   Buildings: {record['building_count']}")
            print("   Properties:")
            for key, value in cluster.items():
                if value is not None:
                    print(f"     - {key}: {value}")
    
    # Check statistics of cluster properties
    with kg.driver.session() as session:
        stats = session.run("""
            MATCH (ac:AdjacencyCluster)<-[:IN_ADJACENCY_CLUSTER]-(b:Building)
            WHERE b.district_name = 'Buitenveldert-Oost'
            WITH ac
            RETURN 
                count(ac) as total_clusters,
                count(ac.energy_sharing_potential) as clusters_with_sharing,
                count(ac.solar_penetration) as clusters_with_solar,
                count(ac.hp_penetration) as clusters_with_hp,
                count(ac.battery_penetration) as clusters_with_battery,
                avg(ac.member_count) as avg_members,
                avg(ac.solar_penetration) as avg_solar_pen,
                avg(ac.hp_penetration) as avg_hp_pen
        """).single()
        
        print("\n📈 Cluster Statistics:")
        print(f"   Total clusters: {stats['total_clusters']}")
        print(f"   With energy_sharing_potential: {stats['clusters_with_sharing']}")
        print(f"   With solar_penetration: {stats['clusters_with_solar']}")
        print(f"   With hp_penetration: {stats['clusters_with_hp']}")
        print(f"   With battery_penetration: {stats['clusters_with_battery']}")
        if stats['avg_solar_pen'] is not None:
            print(f"   Average solar penetration: {stats['avg_solar_pen']:.3f}")
        if stats['avg_hp_pen'] is not None:
            print(f"   Average HP penetration: {stats['avg_hp_pen']:.3f}")
    
    kg.close()
    return result


def fixed_build_energy_sharing_graph(kg_connector, district_name: str, min_cluster_size: int = 3) -> HeteroData:
    """
    Fixed version of build_energy_sharing_graph that uses available properties.
    Add this to your graph_constructor.py
    """
    from data.graph_constructor import GraphConstructor
    
    constructor = GraphConstructor(kg_connector)
    
    # Get adjacency clusters
    clusters = kg_connector.get_adjacency_clusters(district_name, min_cluster_size)
    
    # Build graph with cluster edges
    graph = constructor.build_hetero_graph(district_name, include_energy_sharing=True)
    
    # Since energy_sharing_potential is None, calculate it from other properties
    if 'adjacency_cluster' in graph.node_types and hasattr(graph['adjacency_cluster'], 'x'):
        cluster_features = graph['adjacency_cluster'].x
        
        # Features from _add_nodes_to_graph:
        # 0: member_count
        # 1: energy_sharing_potential (always 0)
        # 2: solar_penetration
        # 3: hp_penetration
        # 4: battery_penetration
        # 5: thermal_benefit
        # 6: cable_savings
        
        # Calculate sharing score based on available metrics
        member_count = cluster_features[:, 0]
        solar_pen = cluster_features[:, 2]
        hp_pen = cluster_features[:, 3]
        battery_pen = cluster_features[:, 4]
        
        # Simple scoring: more members and more DER penetration = better sharing potential
        # Normalize member count (assume max 20 members)
        member_score = torch.clamp(member_count / 20.0, 0, 1)
        
        # Average DER penetration
        der_score = (solar_pen + hp_pen + battery_pen) / 3.0
        
        # Combined score
        sharing_scores = (member_score * 0.3 + der_score * 0.7)
        
        graph['adjacency_cluster'].y = sharing_scores
        
        print(f"   Calculated sharing scores - Mean: {sharing_scores.mean():.3f}, Max: {sharing_scores.max():.3f}")
    
    graph.task = 'energy_sharing'
    
    return graph


def test_fixed_energy_sharing():
    """Test the fixed energy sharing task."""
    
    print("\n" + "="*60)
    print("TESTING FIXED ENERGY SHARING")
    print("="*60)
    
    from data.kg_connector import KGConnector
    from data.graph_constructor import GraphConstructor
    
    kg = KGConnector(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)
    
    # Build the energy sharing graph with the fix
    try:
        graph = fixed_build_energy_sharing_graph(kg, "Buitenveldert-Oost")
        
        if hasattr(graph['adjacency_cluster'], 'y'):
            scores = graph['adjacency_cluster'].y
            print(f"\n✅ Energy sharing task successful!")
            print(f"   Clusters: {len(scores)}")
            print(f"   Average score: {scores.mean().item():.3f}")
            print(f"   Max score: {scores.max().item():.3f}")
            print(f"   Min score: {scores.min().item():.3f}")
            
            # Find best clusters
            top_indices = torch.topk(scores, min(5, len(scores))).indices
            print(f"\n   Top 5 clusters by sharing potential:")
            for i, idx in enumerate(top_indices, 1):
                print(f"     {i}. Cluster index {idx}: score {scores[idx].item():.3f}")
        
    except Exception as e:
        print(f"❌ Error: {e}")
        import traceback
        traceback.print_exc()
    
    kg.close()


def create_final_fix_for_graph_constructor():
    """
    The final fix to add to your graph_constructor.py
    """
    
    fix_code = '''
# In graph_constructor.py, replace the _build_energy_sharing_graph method with this:

def _build_energy_sharing_graph(self, district_name: str,
                               min_cluster_size: int = 3) -> HeteroData:
    """Build graph for energy sharing analysis."""
    
    # Get adjacency clusters
    clusters = self.kg.get_adjacency_clusters(district_name, min_cluster_size)
    
    # Build graph with cluster edges
    graph = self.build_hetero_graph(district_name, include_energy_sharing=True)
    
    # Calculate sharing scores from available properties
    if 'adjacency_cluster' in graph.node_types and hasattr(graph['adjacency_cluster'], 'x'):
        cluster_features = graph['adjacency_cluster'].x
        
        # Extract relevant features
        member_count = cluster_features[:, 0]
        solar_pen = cluster_features[:, 2]
        hp_pen = cluster_features[:, 3]
        battery_pen = cluster_features[:, 4]
        
        # Calculate sharing potential score
        # More members = better sharing opportunities
        member_score = torch.clamp(member_count / 20.0, 0, 1)
        
        # Higher DER penetration = more energy to share
        der_score = (solar_pen + hp_pen + battery_pen) / 3.0
        
        # Combined score (weighted average)
        sharing_scores = (member_score * 0.3 + der_score * 0.7)
        
        graph['adjacency_cluster'].y = sharing_scores
    
    graph.task = 'energy_sharing'
    
    return graph
'''
    
    print("\n" + "="*60)
    print("FIX TO ADD TO graph_constructor.py:")
    print("="*60)
    print(fix_code)


if __name__ == "__main__":
    # First check what properties are available
    clusters = check_cluster_properties()
    
    # Test the fixed energy sharing
    test_fixed_energy_sharing()
    
    # Show the fix to apply
    create_final_fix_for_graph_constructor()

Couldn't import dot_parser, loading of dot files will not be possible.
CHECKING ADJACENCY CLUSTER PROPERTIES

📊 Sample Adjacency Clusters:

1. Cluster ID: ROW_LV_GROUP_0021_4818282
   Buildings: 4
   Properties:
     - thermal_benefit: HIGH
     - cable_savings: HIGH
     - avg_solar_potential_kwp: 3.833760011315345
     - pattern: LINEAR
     - created_at: 2025-08-18T01:28:23.809000000+00:00
     - export_potential_kw: 0.0
     - cluster_type: ROW_HOUSES
     - energy_sharing_potential: LOW
     - function_diversity: 1
     - solar_penetration: 0.0
     - district_name: Buitenveldert-Oost
     - cluster_id: ROW_LV_GROUP_0021_4818282
     - hp_penetration: 0.0
     - total_solar_generation_kw: 0.0
     - battery_penetration: 0.0
     - self_sufficiency_ratio: 0.0
     - total_demand_kw: 0.3503184523809521
     - lv_group_id: LV_GROUP_0021
     - sharing_benefit_kwh: 0.0
     - member_count: 4
     - avg_shared_walls: 2.0

2. Cluster ID: ROW_LV_GROUP_0021_4818281
   Buildings: 4
   Prop

In [None]:
# test_data_pipeline.py
"""
Test script for the complete data processing pipeline.
Tests: kg_connector, graph_constructor, feature_processor, data_loader
"""

import sys
import logging
import torch
import numpy as np
from pathlib import Path

# Add your project path if needed
# sys.path.append('../')

from data.kg_connector import KGConnector
from data.graph_constructor import GraphConstructor
from data.feature_processor import FeatureProcessor
from data.data_loader import TaskSpecificLoader, create_train_val_test_loaders

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class DataPipelineTest:
    """Test suite for the complete data pipeline."""
    
    def __init__(self, neo4j_uri: str, neo4j_user: str, neo4j_password: str):
        """Initialize test suite with Neo4j connection."""
        self.uri = neo4j_uri
        self.user = neo4j_user
        self.password = neo4j_password
        self.kg_connector = None
        self.graph_constructor = None
        self.feature_processor = None
        self.results = {}
        
    def test_kg_connector(self):
        """Test 1: KG Connector functionality."""
        print("\n" + "="*60)
        print("TEST 1: KG CONNECTOR")
        print("="*60)
        
        try:
            # Initialize connector
            self.kg_connector = KGConnector(self.uri, self.user, self.password)
            
            # Test connection
            is_connected = self.kg_connector.verify_connection()
            assert is_connected, "Failed to connect to Neo4j"
            print("✓ Neo4j connection successful")
            
            # Test getting district hierarchy
            district = "Buitenveldert-Oost"  # Replace with your district
            hierarchy = self.kg_connector.get_district_hierarchy(district)
            
            if hierarchy:
                print(f"✓ Retrieved hierarchy for district {district}")
                if 'transformers' in hierarchy:
                    print(f"  - Found {len(hierarchy['transformers'])} transformers")
            else:
                print(f"⚠ No hierarchy data found for {district}")
            
            # Test getting grid topology
            topology = self.kg_connector.get_grid_topology(district)
            assert topology is not None, "Failed to get topology"
            
            print("✓ Grid topology retrieved:")
            for node_type, nodes in topology['nodes'].items():
                if nodes:
                    print(f"  - {node_type}: {len(nodes)} nodes")
            
            # Test getting retrofit candidates
            candidates = self.kg_connector.get_retrofit_candidates(
                district, 
                energy_labels=['E', 'F', 'G'],
                age_filter='19'
            )
            print(f"✓ Found {len(candidates)} cable groups with retrofit candidates")
            
            # Test getting adjacency clusters
            clusters = self.kg_connector.get_adjacency_clusters(district, min_cluster_size=3)
            print(f"✓ Found {len(clusters)} adjacency clusters")
            
            # Test getting time series (if available)
            if topology['nodes']['buildings']:
                sample_building_ids = [
                    str(b.get('ogc_fid', '')) 
                    for b in topology['nodes']['buildings'][:5]
                    if b.get('ogc_fid')
                ]
                
                if sample_building_ids:
                    time_series = self.kg_connector.get_building_time_series(
                        sample_building_ids, 
                        lookback_hours=24
                    )
                    
                    if time_series:
                        print(f"✓ Retrieved time series for {len(time_series)} buildings")
                        for bid, ts_data in list(time_series.items())[:1]:
                            print(f"  - Building {bid}: shape {ts_data.shape}")
                    else:
                        print("⚠ No time series data available")
            
            self.results['kg_connector'] = "PASSED"
            return True
            
        except Exception as e:
            print(f"✗ KG Connector test failed: {e}")
            self.results['kg_connector'] = f"FAILED: {e}"
            return False
    
    def test_graph_constructor(self):
        """Test 2: Graph Constructor functionality."""
        print("\n" + "="*60)
        print("TEST 2: GRAPH CONSTRUCTOR")
        print("="*60)
        
        try:
            # Initialize graph constructor
            self.graph_constructor = GraphConstructor(self.kg_connector)
            print("✓ Graph constructor initialized")
            
            district = "Buitenveldert-Oost"  # Replace with your district
            
            # Test building basic graph
            print("\nBuilding basic graph...")
            graph = self.graph_constructor.build_hetero_graph(
                district, 
                include_energy_sharing=True,
                include_temporal=False  # Start without temporal
            )
            
            print("✓ Basic graph built:")
            print(f"  Node types: {graph.node_types}")
            for node_type in graph.node_types:
                if hasattr(graph[node_type], 'x'):
                    print(f"  - {node_type}: {graph[node_type].x.shape}")
            
            print(f"  Edge types: {len(graph.edge_types)}")
            for edge_type in graph.edge_types:
                edge_index = graph[edge_type].edge_index
                print(f"  - {edge_type}: {edge_index.shape[1]} edges")
            
            # Test building graph with temporal features
            print("\nBuilding graph with temporal features...")
            graph_temporal = self.graph_constructor.build_hetero_graph(
                district,
                include_energy_sharing=True,
                include_temporal=True,
                lookback_hours=24
            )
            
            # Check temporal features
            has_temporal = False
            for node_type in graph_temporal.node_types:
                if hasattr(graph_temporal[node_type], 'x_temporal'):
                    has_temporal = True
                    shape = graph_temporal[node_type].x_temporal.shape
                    print(f"✓ Temporal features for {node_type}: {shape}")
            
            if not has_temporal:
                print("⚠ No temporal features found (may not have time series data)")
            
            # Test task-specific graphs
            print("\nBuilding task-specific graphs...")
            
            # Retrofit graph
            retrofit_graph = self.graph_constructor._build_retrofit_graph(
                district,
                energy_labels=['E', 'F', 'G']
            )
            if 'building' in retrofit_graph.node_types and hasattr(retrofit_graph['building'], 'y'):
                retrofit_labels = retrofit_graph['building'].y
                print(f"✓ Retrofit graph: {retrofit_labels.sum().item():.0f} retrofit candidates")
            
            # Energy sharing graph
            sharing_graph = self.graph_constructor._build_energy_sharing_graph(
                district,
                min_cluster_size=3
            )
            if 'adjacency_cluster' in sharing_graph.node_types:
                print(f"✓ Energy sharing graph built")
            
            # Solar graph
            solar_graph = self.graph_constructor._build_solar_graph(district)
            if 'building' in solar_graph.node_types and hasattr(solar_graph['building'], 'y'):
                print(f"✓ Solar graph: max potential {solar_graph['building'].y.max().item():.0f} kWh/year")
            
            self.results['graph_constructor'] = "PASSED"
            self.graph = graph  # Save for next tests
            return True
            
        except Exception as e:
            print(f"✗ Graph Constructor test failed: {e}")
            self.results['graph_constructor'] = f"FAILED: {e}"
            return False
    
    def test_feature_processor(self):
        """Test 3: Feature Processor functionality."""
        print("\n" + "="*60)
        print("TEST 3: FEATURE PROCESSOR")
        print("="*60)
        
        try:
            # Initialize feature processor
            self.feature_processor = FeatureProcessor()
            print("✓ Feature processor initialized")
            
            # Check if we have a graph from previous test
            if not hasattr(self, 'graph'):
                print("⚠ No graph available, building new one...")
                district = "Buitenveldert-Oost"
                self.graph = self.graph_constructor.build_hetero_graph(district)
            
            # Test processing graph features
            print("\nProcessing graph features...")
            original_shapes = {}
            for node_type in self.graph.node_types:
                if hasattr(self.graph[node_type], 'x'):
                    original_shapes[node_type] = self.graph[node_type].x.shape
            
            # Process features
            self.feature_processor.process_graph_features(self.graph, fit=True)
            
            print("✓ Features processed:")
            for node_type in self.graph.node_types:
                if hasattr(self.graph[node_type], 'x'):
                    processed_shape = self.graph[node_type].x.shape
                    print(f"  - {node_type}: {original_shapes[node_type]} -> {processed_shape}")
                    
                    # Check for engineered features
                    if hasattr(self.graph[node_type], 'x_engineered'):
                        eng_shape = self.graph[node_type].x_engineered.shape
                        print(f"    + Engineered features: {eng_shape}")
            
            # Test task-specific features
            print("\nCreating task-specific features...")
            
            for task in ['retrofit', 'energy_sharing', 'solar', 'electrification']:
                task_features = self.feature_processor.create_task_specific_features(
                    self.graph, task
                )
                if task_features:
                    print(f"✓ {task} features:")
                    for feat_name, feat_tensor in task_features.items():
                        if feat_tensor is not None:
                            print(f"  - {feat_name}: shape {feat_tensor.shape}")
            
            # Test temporal feature processing if available
            if any(hasattr(self.graph[nt], 'x_temporal') for nt in self.graph.node_types):
                print("\nProcessing temporal features...")
                for node_type in self.graph.node_types:
                    if hasattr(self.graph[node_type], 'x_temporal'):
                        temporal = self.graph[node_type].x_temporal
                        processed_temporal = self.feature_processor.process_temporal_features(
                            temporal, normalize=True
                        )
                        print(f"✓ {node_type} temporal: {temporal.shape} -> {processed_temporal.shape}")
                        
                        # Test pattern extraction
                        patterns = self.feature_processor.extract_temporal_patterns(temporal)
                        if patterns:
                            print(f"  Extracted patterns: {list(patterns.keys())}")
            
            # Test saving/loading processors
            print("\nTesting save/load processors...")
            save_path = "test_processors.pkl"
            self.feature_processor.save_processors(save_path)
            print(f"✓ Saved processors to {save_path}")
            
            # Create new processor and load
            new_processor = FeatureProcessor()
            new_processor.load_processors(save_path)
            print(f"✓ Loaded processors from {save_path}")
            
            # Clean up
            Path(save_path).unlink(missing_ok=True)
            
            self.results['feature_processor'] = "PASSED"
            return True
            
        except Exception as e:
            print(f"✗ Feature Processor test failed: {e}")
            self.results['feature_processor'] = f"FAILED: {e}"
            return False
    








# Replace the test_data_loader method in your test file with this:

    def test_data_loader(self):
        """Test 4: Data Loader functionality."""
        print("\n" + "="*60)
        print("TEST 4: DATA LOADER")
        print("="*60)
        
        try:
            # Ensure we have a properly processed graph
            if not hasattr(self, 'graph') or self.graph is None:
                print("⚠ Building and processing graph...")
                district = "Buitenveldert-Oost"
                
                # Build graph WITHOUT temporal features for basic testing
                self.graph = self.graph_constructor.build_hetero_graph(
                    district,
                    include_energy_sharing=True,
                    include_temporal=False  # Start without temporal
                )
                
                # Process features to ensure they're tensors
                self.feature_processor.process_graph_features(self.graph)
            
            # Verify all features are tensors
            print("\nVerifying feature types...")
            for node_type in self.graph.node_types:
                if hasattr(self.graph[node_type], 'x'):
                    features = self.graph[node_type].x
                    if not isinstance(features, torch.Tensor):
                        print(f"⚠ Converting {node_type} features to tensor")
                        self.graph[node_type].x = torch.tensor(features, dtype=torch.float)
                    print(f"✓ {node_type} features are tensors: {features.shape}")
            
            # Test TaskSpecificLoader
            print("\nTesting TaskSpecificLoader...")
            loader_creator = TaskSpecificLoader(batch_size=16)
            
            # Test different task loaders
            tasks = ['retrofit', 'energy_sharing', 'solar', 'grid_planning', 'electrification']
            
            for task in tasks:
                print(f"\nTesting {task} loader...")
                
                try:
                    # Create train/val/test splits
                    train_loader, val_loader, test_loader = create_train_val_test_loaders(
                        self.graph,
                        task=task,
                        train_ratio=0.7,
                        val_ratio=0.15,
                        batch_size=16
                    )
                    
                    print(f"✓ Created {task} loaders")
                    
                    # Test loading a batch
                    batch_count = 0
                    for batch in train_loader:
                        if batch_count == 0:  # Just test first batch
                            print(f"  Sample batch:")
                            
                            # Check node types in batch
                            for node_type in batch.node_types:
                                if hasattr(batch[node_type], 'x'):
                                    x = batch[node_type].x
                                    print(f"    - {node_type}: {x.shape} (type: {type(x).__name__})")
                            
                            # Check edge types
                            edge_count = sum(1 for _ in batch.edge_types)
                            print(f"    - Edge types: {edge_count}")
                            
                            batch_count += 1
                            break
                    
                    if batch_count == 0:
                        print(f"  ⚠ No batches generated for {task}")
                        
                except Exception as e:
                    print(f"  ⚠ {task} loader error: {e}")
                    continue
            
            self.results['data_loader'] = "PASSED"
            return True
            
        except Exception as e:
            print(f"✗ Data Loader test failed: {e}")
            self.results['data_loader'] = f"FAILED: {e}"
            return False












    def run_all_tests(self):
        """Run all tests in sequence."""
        print("\n" + "="*60)
        print("RUNNING COMPLETE DATA PIPELINE TESTS")
        print("="*60)
        
        # Run tests in order
        tests = [
            ('KG Connector', self.test_kg_connector),
            ('Graph Constructor', self.test_graph_constructor),
            ('Feature Processor', self.test_feature_processor),
            ('Data Loader', self.test_data_loader)
        ]
        
        for test_name, test_func in tests:
            try:
                success = test_func()
                if not success:
                    print(f"\n⚠ {test_name} failed, but continuing with other tests...")
            except Exception as e:
                print(f"\n✗ {test_name} crashed: {e}")
                self.results[test_name.lower().replace(' ', '_')] = f"CRASHED: {e}"
        
        # Print summary
        print("\n" + "="*60)
        print("TEST SUMMARY")
        print("="*60)
        
        for component, result in self.results.items():
            status = "✓" if result == "PASSED" else "✗"
            print(f"{status} {component}: {result}")
        
        # Close connections
        if self.kg_connector:
            self.kg_connector.close()
            print("\n✓ Neo4j connection closed")
    
    def quick_integration_test(self):
        """Quick test to verify the complete pipeline works end-to-end."""
        print("\n" + "="*60)
        print("QUICK INTEGRATION TEST")
        print("="*60)
        
        try:
            # 1. Connect to KG
            kg = KGConnector(self.uri, self.user, self.password)
            assert kg.verify_connection(), "KG connection failed"
            print("✓ Step 1: Connected to Neo4j")
            
            # 2. Build graph
            constructor = GraphConstructor(kg)
            district = "Buitenveldert-Oost"  # Replace with your district
            graph = constructor.build_hetero_graph(
                district,
                include_energy_sharing=True,
                include_temporal=True
            )
            print(f"✓ Step 2: Built graph with {len(graph.node_types)} node types")
            
            # 3. Process features
            processor = FeatureProcessor()
            processor.process_graph_features(graph, fit=True)
            print("✓ Step 3: Processed features")
            
            # 4. Create data loaders
            train_loader, val_loader, test_loader = create_train_val_test_loaders(
                graph,
                task='retrofit',
                batch_size=32
            )
            print("✓ Step 4: Created data loaders")
            
            # 5. Test one batch
            for batch in train_loader:
                print(f"✓ Step 5: Successfully loaded batch with {batch.num_nodes} total nodes")
                break
            
            print("\n✓ INTEGRATION TEST PASSED!")
            kg.close()
            return True
            
        except Exception as e:
            print(f"\n✗ INTEGRATION TEST FAILED: {e}")
            if kg:
                kg.close()
            return False


# Main execution
if __name__ == "__main__":
    # Configuration - UPDATE THESE VALUES
    NEO4J_URI = "bolt://localhost:7687"
    NEO4J_USER = "neo4j"
    NEO4J_PASSWORD = "aminasad"
    
    # Create test suite
    tester = DataPipelineTest(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)
    
    # Choose test mode
    print("Select test mode:")
    print("1. Run all tests (comprehensive)")
    print("2. Quick integration test")
    print("3. Test specific component")
    
    choice = input("\nEnter choice (1/2/3): ").strip()
    
    if choice == "1":
        tester.run_all_tests()
    elif choice == "2":
        tester.quick_integration_test()
    elif choice == "3":
        print("\nSelect component:")
        print("1. KG Connector")
        print("2. Graph Constructor")
        print("3. Feature Processor")
        print("4. Data Loader")
        
        comp_choice = input("\nEnter choice (1-4): ").strip()
        
        if comp_choice == "1":
            tester.test_kg_connector()
        elif comp_choice == "2":
            tester.test_kg_connector()  # Need KG first
            tester.test_graph_constructor()
        elif comp_choice == "3":
            tester.test_kg_connector()
            tester.test_graph_constructor()
            tester.test_feature_processor()
        elif comp_choice == "4":
            tester.test_kg_connector()
            tester.test_graph_constructor()
            tester.test_feature_processor()
            tester.test_data_loader()
    else:
        print("Running quick integration test by default...")
        tester.quick_integration_test()

Select test mode:
1. Run all tests (comprehensive)
2. Quick integration test
3. Test specific component


2025-08-20 18:15:01,663 - data.kg_connector - INFO - Connected to Neo4j at bolt://localhost:7687



RUNNING COMPLETE DATA PIPELINE TESTS

TEST 1: KG CONNECTOR
✓ Neo4j connection successful
✓ Retrieved hierarchy for district Buitenveldert-Oost
  - Found 6 transformers


2025-08-20 18:15:03,984 - data.kg_connector - INFO - Edge counts - B->CG: 335, CG->T: 13, T->S: 0, B->AC: 346


✓ Grid topology retrieved:
  - buildings: 335 nodes
  - cable_groups: 21 nodes
  - transformers: 44 nodes
  - adjacency_clusters: 95 nodes
✓ Found 4 cable groups with retrofit candidates


2025-08-20 18:15:04,275 - data.kg_connector - INFO - Fetching time series for 5 buildings
2025-08-20 18:15:04,275 - data.kg_connector - INFO - Time range: 1706396400000 to 1706482800000 (24 hours)
2025-08-20 18:15:04,446 - data.kg_connector - INFO - Retrieved time series for 5 buildings out of 5 requested
2025-08-20 18:15:04,447 - data.graph_constructor - INFO - Building graph for district Buitenveldert-Oost
2025-08-20 18:15:04,448 - data.graph_constructor - INFO - Temporal features: False, Lookback: 24 hours


✓ Found 95 adjacency clusters
✓ Retrieved time series for 5 buildings
  - Building 4236986: shape (24, 8)

TEST 2: GRAPH CONSTRUCTOR
✓ Graph constructor initialized

Building basic graph...


2025-08-20 18:15:04,590 - data.kg_connector - INFO - Edge counts - B->CG: 335, CG->T: 13, T->S: 0, B->AC: 346
2025-08-20 18:15:04,819 - data.graph_constructor - INFO - No transformer_to_substation edges (substations might not exist)
2025-08-20 18:15:04,820 - data.graph_constructor - INFO - Graph built: {'building': 335, 'cable_group': 21, 'transformer': 44, 'substation': 0, 'adjacency_cluster': 95}
2025-08-20 18:15:04,821 - data.graph_constructor - INFO - Building graph for district Buitenveldert-Oost
2025-08-20 18:15:04,821 - data.graph_constructor - INFO - Temporal features: True, Lookback: 24 hours
2025-08-20 18:15:04,970 - data.kg_connector - INFO - Edge counts - B->CG: 335, CG->T: 13, T->S: 0, B->AC: 346
2025-08-20 18:15:04,973 - data.graph_constructor - INFO - Fetching temporal features for 335 buildings...
2025-08-20 18:15:04,978 - data.kg_connector - INFO - Fetching time series for 335 buildings
2025-08-20 18:15:04,979 - data.kg_connector - INFO - Time range: 1706396400000 to 1

✓ Basic graph built:
  Node types: ['building', 'cable_group', 'transformer', 'adjacency_cluster']
  - building: torch.Size([335, 17])
  - cable_group: torch.Size([21, 12])
  - transformer: torch.Size([44, 3])
  - adjacency_cluster: torch.Size([95, 11])
  Edge types: 3
  - ('building', 'connected_to', 'cable_group'): 335 edges
  - ('cable_group', 'connects_to', 'transformer'): 13 edges
  - ('building', 'in_cluster', 'adjacency_cluster'): 346 edges

Building graph with temporal features...


2025-08-20 18:15:05,509 - data.kg_connector - INFO - Retrieved time series for 335 buildings out of 335 requested
2025-08-20 18:15:05,513 - data.graph_constructor - INFO - Added temporal features: torch.Size([335, 24, 8])
2025-08-20 18:15:05,704 - data.graph_constructor - INFO - Fetching temporal features for 95 clusters...
2025-08-20 18:15:06,038 - data.graph_constructor - INFO - Added cluster temporal features: torch.Size([95, 24, 7])
2025-08-20 18:15:06,040 - data.graph_constructor - INFO - No transformer_to_substation edges (substations might not exist)
2025-08-20 18:15:06,041 - data.graph_constructor - INFO - Temporal features added for: ['building', 'adjacency_cluster']
2025-08-20 18:15:06,041 - data.graph_constructor - INFO - Graph built: {'building': 335, 'cable_group': 21, 'transformer': 44, 'substation': 0, 'adjacency_cluster': 95}
2025-08-20 18:15:06,050 - data.graph_constructor - INFO - Building graph for district Buitenveldert-Oost
2025-08-20 18:15:06,050 - data.graph_cons

✓ Temporal features for building: torch.Size([335, 24, 8])
✓ Temporal features for adjacency_cluster: torch.Size([95, 24, 7])

Building task-specific graphs...


2025-08-20 18:15:06,624 - data.kg_connector - INFO - Retrieved time series for 335 buildings out of 335 requested
2025-08-20 18:15:06,627 - data.graph_constructor - INFO - Added temporal features: torch.Size([335, 24, 8])
2025-08-20 18:15:06,728 - data.graph_constructor - INFO - Fetching temporal features for 95 clusters...
2025-08-20 18:15:06,965 - data.graph_constructor - INFO - Added cluster temporal features: torch.Size([95, 24, 7])
2025-08-20 18:15:06,967 - data.graph_constructor - INFO - No transformer_to_substation edges (substations might not exist)
2025-08-20 18:15:06,967 - data.graph_constructor - INFO - Temporal features added for: ['building', 'adjacency_cluster']
2025-08-20 18:15:06,968 - data.graph_constructor - INFO - Graph built: {'building': 335, 'cable_group': 21, 'transformer': 44, 'substation': 0, 'adjacency_cluster': 95}
2025-08-20 18:15:07,153 - data.graph_constructor - INFO - Building graph for district Buitenveldert-Oost
2025-08-20 18:15:07,153 - data.graph_cons

✓ Retrofit graph: 6 retrofit candidates


2025-08-20 18:15:07,337 - data.kg_connector - INFO - Edge counts - B->CG: 335, CG->T: 13, T->S: 0, B->AC: 346
2025-08-20 18:15:07,340 - data.graph_constructor - INFO - Fetching temporal features for 335 buildings...
2025-08-20 18:15:07,344 - data.kg_connector - INFO - Fetching time series for 335 buildings
2025-08-20 18:15:07,345 - data.kg_connector - INFO - Time range: 1706396400000 to 1706482800000 (24 hours)
2025-08-20 18:15:07,909 - data.kg_connector - INFO - Retrieved time series for 335 buildings out of 335 requested
2025-08-20 18:15:07,913 - data.graph_constructor - INFO - Added temporal features: torch.Size([335, 24, 8])
2025-08-20 18:15:08,026 - data.graph_constructor - INFO - Fetching temporal features for 95 clusters...
2025-08-20 18:15:08,350 - data.graph_constructor - INFO - Added cluster temporal features: torch.Size([95, 24, 7])
2025-08-20 18:15:08,352 - data.graph_constructor - INFO - No transformer_to_substation edges (substations might not exist)
2025-08-20 18:15:08,3

✓ Energy sharing graph built


2025-08-20 18:15:08,944 - data.kg_connector - INFO - Retrieved time series for 335 buildings out of 335 requested
2025-08-20 18:15:08,948 - data.graph_constructor - INFO - Added temporal features: torch.Size([335, 24, 8])
2025-08-20 18:15:09,069 - data.graph_constructor - INFO - Fetching temporal features for 95 clusters...
2025-08-20 18:15:09,331 - data.graph_constructor - INFO - Added cluster temporal features: torch.Size([95, 24, 7])
2025-08-20 18:15:09,332 - data.graph_constructor - INFO - No transformer_to_substation edges (substations might not exist)
2025-08-20 18:15:09,334 - data.graph_constructor - INFO - Temporal features added for: ['building', 'adjacency_cluster']
2025-08-20 18:15:09,334 - data.graph_constructor - INFO - Graph built: {'building': 335, 'cable_group': 21, 'transformer': 44, 'substation': 0, 'adjacency_cluster': 95}
2025-08-20 18:15:09,336 - data.feature_processor - INFO - Processing graph features
2025-08-20 18:15:09,345 - data.feature_processor - INFO - Adde

✓ Solar graph: max potential 2666089 kWh/year

TEST 3: FEATURE PROCESSOR
✓ Feature processor initialized

Processing graph features...
✓ Features processed:
  - building: torch.Size([335, 17]) -> torch.Size([335, 17])
    + Engineered features: torch.Size([335, 7])
  - cable_group: torch.Size([21, 12]) -> torch.Size([21, 12])
    + Engineered features: torch.Size([21, 4])
  - transformer: torch.Size([44, 3]) -> torch.Size([44, 3])
  - adjacency_cluster: torch.Size([95, 11]) -> torch.Size([95, 11])
    + Engineered features: torch.Size([95, 5])

Creating task-specific features...
✓ retrofit features:
  - retrofit_priority: shape torch.Size([335])
  - retrofit_cost: shape torch.Size([335])
✓ energy_sharing features:
  - self_sufficiency: shape torch.Size([95])
  - system_penetration: shape torch.Size([95])
  - sharing_efficiency: shape torch.Size([95])
✓ solar features:
  - solar_suitability: shape torch.Size([335])
  - solar_generation: shape torch.Size([335])
✓ electrification features

In [None]:
# full_pipeline_test.py
"""
Full pipeline test now that basic connectivity works.
"""

import torch
import numpy as np
from data.kg_connector import KGConnector
from data.graph_constructor import GraphConstructor
from data.feature_processor import FeatureProcessor
from data.data_loader import create_train_val_test_loaders
import logging

logging.basicConfig(level=logging.INFO)

def test_full_pipeline():
    """Test the complete pipeline."""
    
    print("="*80)
    print("FULL PIPELINE TEST")
    print("="*80)
    
    # 1. Connect to KG
    print("\n1. Connecting to KG...")
    kg = KGConnector(
        uri="bolt://localhost:7687",
        user="neo4j",
        password="aminasad"
    )
    
    if not kg.verify_connection():
        print("❌ Failed to connect to Neo4j")
        return
    print("✅ Connected to Neo4j")
    
    # 2. Build graph with temporal features
    print("\n2. Building graph with temporal features...")
    constructor = GraphConstructor(kg)
    
    try:
        graph = constructor.build_hetero_graph(
            district_name="Buitenveldert-Oost",
            include_temporal=True,  # Enable temporal
            include_energy_sharing=True,
            lookback_hours=24
        )
        print(f"✅ Graph built: {graph.num_nodes_dict}")
        
        # Check temporal features
        temporal_count = 0
        for node_type in graph.node_types:
            if hasattr(graph[node_type], 'x_temporal'):
                shape = graph[node_type].x_temporal.shape
                print(f"  {node_type} temporal: {shape}")
                temporal_count += 1
        
        if temporal_count == 0:
            print("⚠️ No temporal features found")
        else:
            print(f"✅ Temporal features added for {temporal_count} node types")
            
    except Exception as e:
        print(f"❌ Error building graph: {e}")
        import traceback
        traceback.print_exc()
        kg.close()
        return
    
    # 3. Process features
    print("\n3. Processing features...")
    processor = FeatureProcessor()
    
    try:
        processor.process_graph_features(graph, fit=True)
        
        # Check for engineered features
        eng_count = 0
        for node_type in graph.node_types:
            if hasattr(graph[node_type], 'x_engineered'):
                shape = graph[node_type].x_engineered.shape
                print(f"  {node_type} engineered: {shape}")
                eng_count += 1
        
        if eng_count > 0:
            print(f"✅ Engineered features created for {eng_count} node types")
        else:
            print("⚠️ No engineered features created")
            
    except Exception as e:
        print(f"❌ Error processing features: {e}")
        import traceback
        traceback.print_exc()
        kg.close()
        return
    
    # 4. Test data loaders for each task
    print("\n4. Testing data loaders...")
    tasks = ['retrofit', 'energy_sharing', 'solar', 'electrification']
    
    for task in tasks:
        print(f"\n  Testing {task} loader...")
        try:
            train_loader, val_loader, test_loader = create_train_val_test_loaders(
                graph, task, batch_size=32
            )
            
            # Get one batch to verify
            for batch in train_loader:
                print(f"    ✅ {task}: Batch with {len(batch.node_types)} node types")
                
                # Check if temporal features are preserved
                has_temporal = False
                for node_type in batch.node_types:
                    if hasattr(batch[node_type], 'x_temporal'):
                        has_temporal = True
                        print(f"      Temporal features preserved for {node_type}")
                
                break  # Just check first batch
                
        except Exception as e:
            print(f"    ❌ {task}: {e}")
    
    # 5. Check specific task features
    print("\n5. Testing task-specific features...")
    try:
        for task in ['retrofit', 'energy_sharing', 'solar', 'electrification']:
            task_features = processor.create_task_specific_features(graph, task)
            if task_features:
                print(f"  ✅ {task}: {len(task_features)} specific features")
                for feat_name, feat_tensor in task_features.items():
                    if isinstance(feat_tensor, torch.Tensor):
                        print(f"    - {feat_name}: {feat_tensor.shape}")
            else:
                print(f"  ⚠️ {task}: No specific features")
    except Exception as e:
        print(f"  ❌ Error creating task features: {e}")
    
    # 6. Verify temporal patterns
    print("\n6. Checking temporal patterns...")
    if hasattr(graph['building'], 'x_temporal'):
        try:
            patterns = processor.extract_temporal_patterns(graph['building'].x_temporal)
            if patterns:
                print(f"  ✅ Extracted {len(patterns)} temporal patterns")
                for pattern_name, pattern_tensor in patterns.items():
                    print(f"    - {pattern_name}: {pattern_tensor.shape}")
            else:
                print("  ⚠️ No temporal patterns extracted")
        except Exception as e:
            print(f"  ❌ Error extracting patterns: {e}")
    
    # Close connection
    kg.close()
    
    print("\n" + "="*80)
    print("PIPELINE TEST COMPLETE")
    print("="*80)

if __name__ == "__main__":
    test_full_pipeline()

Couldn't import dot_parser, loading of dot files will not be possible.


INFO:data.kg_connector:Connected to Neo4j at bolt://localhost:7687


FULL PIPELINE TEST

1. Connecting to KG...


INFO:data.graph_constructor:Building graph for district Buitenveldert-Oost
INFO:data.graph_constructor:Temporal features: True, Lookback: 24 hours


✅ Connected to Neo4j

2. Building graph with temporal features...


INFO:data.kg_connector:Edge counts - B->CG: 335, CG->T: 13, T->S: 0, B->AC: 346
INFO:data.graph_constructor:Fetching temporal features for 335 buildings...
INFO:data.kg_connector:Fetching time series for 335 buildings
INFO:data.kg_connector:Time range: 1706396400000 to 1706482800000 (24 hours)
INFO:data.kg_connector:Retrieved time series for 335 buildings out of 335 requested
INFO:data.graph_constructor:Added temporal features: torch.Size([335, 24, 8])
INFO:data.graph_constructor:Fetching temporal features for 95 clusters...
INFO:data.graph_constructor:Added cluster temporal features: torch.Size([95, 24, 7])
INFO:data.graph_constructor:No transformer_to_substation edges (substations might not exist)
INFO:data.graph_constructor:Temporal features added for: ['building', 'adjacency_cluster']
INFO:data.graph_constructor:Graph built: {'building': 335, 'cable_group': 21, 'transformer': 44, 'substation': 0, 'adjacency_cluster': 95}
INFO:data.feature_processor:Processing graph features
INFO:da

✅ Graph built: {'building': 335, 'cable_group': 21, 'transformer': 44, 'substation': 0, 'adjacency_cluster': 95}
  building temporal: torch.Size([335, 24, 8])
  adjacency_cluster temporal: torch.Size([95, 24, 7])
✅ Temporal features added for 2 node types

3. Processing features...
  building engineered: torch.Size([335, 7])
  cable_group engineered: torch.Size([21, 4])
  adjacency_cluster engineered: torch.Size([95, 5])
✅ Engineered features created for 3 node types

4. Testing data loaders...

  Testing retrofit loader...
    ✅ retrofit: Batch with 4 node types
      Temporal features preserved for building
      Temporal features preserved for adjacency_cluster

  Testing energy_sharing loader...
    ✅ energy_sharing: Batch with 4 node types
      Temporal features preserved for building
      Temporal features preserved for adjacency_cluster

  Testing solar loader...
    ✅ solar: Batch with 4 node types
      Temporal features preserved for building
      Temporal features preserv

In [None]:
# Quick test script
from data.kg_connector import KGConnector
from data.graph_constructor import GraphConstructor

# Connect
kg = KGConnector("bolt://localhost:7687", "neo4j", "aminasad")

# Get topology
topology = kg.get_grid_topology("Buitenveldert-Oost")

# Check edges
print("Edge counts:")
for edge_type, edges in topology['edges'].items():
    print(f"  {edge_type}: {len(edges)}")

# Build graph
constructor = GraphConstructor(kg)
graph = constructor.build_hetero_graph("Buitenveldert-Oost", include_temporal=False)

print("\nGraph edge types:")
for edge_type in graph.edge_types:
    if hasattr(graph[edge_type], 'edge_index'):
        print(f"  {edge_type}: {graph[edge_type].edge_index.shape}")

kg.close()

INFO:data.kg_connector:Connected to Neo4j at bolt://localhost:7687
INFO:data.kg_connector:Edge counts - B->CG: 335, CG->T: 13, T->S: 0, B->AC: 346
INFO:data.graph_constructor:Building graph for district Buitenveldert-Oost
INFO:data.graph_constructor:Temporal features: False, Lookback: 24 hours
INFO:data.kg_connector:Edge counts - B->CG: 335, CG->T: 13, T->S: 0, B->AC: 346


Edge counts:
  building_to_cable: 335
  cable_to_transformer: 13
  transformer_to_substation: 0
  building_to_cluster: 346


INFO:data.graph_constructor:No transformer_to_substation edges (substations might not exist)
INFO:data.graph_constructor:Graph built: {'building': 335, 'cable_group': 21, 'transformer': 44, 'substation': 0, 'adjacency_cluster': 95}
INFO:data.kg_connector:Neo4j connection closed



Graph edge types:
  ('building', 'connected_to', 'cable_group'): torch.Size([2, 335])
  ('building', 'in_cluster', 'adjacency_cluster'): torch.Size([2, 346])


# Model

In [None]:
base_gnn.py (Foundation)
attention_layers.py (Enhanced message passing)
temporal_layers.py (Time series processing)
physics_layers.py (Domain constraints)
task_heads.py (Task-specific outputs)

1. LV-Focused Study Area

You clipped to a local LV network area
Transformers at the boundary (not all cable groups connect within area)
Substations definitely outside (MV/HV level)
This is CORRECT for distribution grid optimization!
2. Adjacency = PHYSICAL (Shared Walls)

These are thermal sharing clusters (row houses, apartments)
NOT electrical neighborhoods
2-8 buildings sharing walls makes perfect sense
Energy sharing = thermal energy through walls, not electrical P2P!


we ned to focus on LV group buildings right? as we clipped area and made analysis, it is possible that some not connected to mv or hv, or not exist in that clipped area!
for adjancy do you consider that that adjancy means neighbour buildings that share wall! not other thing!??







You're absolutely right - this is a clipped area effect, not a data error! When you clip any area:

Cable groups get cut (222 buildings = partial group, most outside boundary)
Transformers are at edges (naturally disconnected)
This is NORMAL for any spatial clip!





ELECTRICAL ENERGY:                    THERMAL ENERGY:
┌──────────────────┐                 ┌──────────────────┐
│ Flows through:   │                 │ Flows through:   │
│ • Cable Groups   │                 │ • Shared Walls   │
│ • Transformers   │                 │ • Adjacent air   │
│ • Grid wires     │                 │ • Physical contact│
└──────────────────┘                 └──────────────────┘
     ↓                                      ↓
Buildings in same                    Buildings in same
cable group                          adjacency cluster



Why We Need Cluster Temporal:
For THERMAL sharing (heating):
python# Adjacency cluster = buildings sharing walls
Building A & B share a wall:

Time  | Building A Heat | Building B Heat | Opportunity
18:00 | HIGH (cooking)  | LOW (empty)     | B→A thermal transfer
02:00 | LOW (sleeping)  | HIGH (night shift) | A→B thermal transfer
For ELECTRICAL sharing:
python# This uses CABLE GROUPS, not adjacency clusters!
Buildings in same cable group:

Time  | Building A Elec | Building B Elec | Opportunity  
12:00 | HIGH (AC on)    | LOW (solar gen) | B→A via cable group

In [None]:
# graph_diagnostics.py
"""
Diagnostic script to understand the actual graph structure and features
Run this to get detailed information about your data
"""

import torch
import numpy as np
from data.kg_connector import KGConnector
from data.graph_constructor import GraphConstructor
from data.feature_processor import FeatureProcessor
from data.data_loader import TaskSpecificLoader
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def analyze_graph_structure(graph):
    """Analyze the graph structure in detail"""
    
    print("\n" + "="*60)
    print("GRAPH STRUCTURE ANALYSIS")
    print("="*60)
    
    # Node types and counts
    print("\n1. NODE TYPES AND COUNTS:")
    print("-"*40)
    for node_type in graph.node_types:
        if hasattr(graph[node_type], 'x'):
            count = graph[node_type].x.shape[0]
            features = graph[node_type].x.shape[1]
            print(f"  {node_type:20s}: {count:5d} nodes, {features:3d} features")
            
            # Check for temporal features
            if hasattr(graph[node_type], 'x_temporal'):
                temp_shape = graph[node_type].x_temporal.shape
                print(f"    └── Temporal: {temp_shape}")
            
            # Check for engineered features
            if hasattr(graph[node_type], 'x_engineered'):
                eng_shape = graph[node_type].x_engineered.shape
                print(f"    └── Engineered: {eng_shape}")
    
    # Edge types and counts
    print("\n2. EDGE TYPES AND COUNTS:")
    print("-"*40)
    for edge_type in graph.edge_types:
        edge_index = graph[edge_type].edge_index
        num_edges = edge_index.shape[1]
        src_type, rel_type, dst_type = edge_type
        print(f"  ({src_type}, {rel_type}, {dst_type})")
        print(f"    └── {num_edges} edges")
        
        # Check connectivity
        if num_edges > 0:
            src_nodes = edge_index[0].unique().numel()
            dst_nodes = edge_index[1].unique().numel()
            print(f"    └── Connects {src_nodes} {src_type} to {dst_nodes} {dst_type}")
    
    # Check for isolated nodes
    print("\n3. CONNECTIVITY CHECK:")
    print("-"*40)
    for node_type in graph.node_types:
        if hasattr(graph[node_type], 'x'):
            total_nodes = graph[node_type].x.shape[0]
            connected_nodes = set()
            
            # Check all edges involving this node type
            for edge_type in graph.edge_types:
                src_type, _, dst_type = edge_type
                edge_index = graph[edge_type].edge_index
                
                if src_type == node_type:
                    connected_nodes.update(edge_index[0].tolist())
                if dst_type == node_type:
                    connected_nodes.update(edge_index[1].tolist())
            
            isolated = total_nodes - len(connected_nodes)
            print(f"  {node_type}: {isolated} isolated nodes out of {total_nodes}")
    
    return graph

def analyze_features(graph):
    """Analyze feature distributions and properties"""
    
    print("\n" + "="*60)
    print("FEATURE ANALYSIS")
    print("="*60)
    
    for node_type in graph.node_types:
        if hasattr(graph[node_type], 'x'):
            features = graph[node_type].x
            
            print(f"\n{node_type.upper()} FEATURES:")
            print("-"*40)
            
            # Basic statistics
            print(f"  Shape: {features.shape}")
            print(f"  Range: [{features.min():.3f}, {features.max():.3f}]")
            print(f"  Mean: {features.mean():.3f}")
            print(f"  Std: {features.std():.3f}")
            
            # Check for NaN or Inf
            nan_count = torch.isnan(features).sum().item()
            inf_count = torch.isinf(features).sum().item()
            if nan_count > 0:
                print(f"  ⚠️ WARNING: {nan_count} NaN values!")
            if inf_count > 0:
                print(f"  ⚠️ WARNING: {inf_count} Inf values!")
            
            # Feature-wise statistics for buildings (most important)
            if node_type == 'building':
                print("\n  Feature-wise statistics:")
                feature_names = [
                    'area', 'energy_score', 'solar_score', 'electrify_score',
                    'age', 'roof_area', 'height', 'has_solar', 'has_battery',
                    'has_heat_pump', 'shared_walls', 'x_coord', 'y_coord',
                    'avg_electricity', 'avg_heating', 'peak_electricity', 'energy_intensity'
                ]
                
                for i, name in enumerate(feature_names[:features.shape[1]]):
                    col = features[:, i]
                    print(f"    {i:2d}. {name:20s}: mean={col.mean():7.3f}, std={col.std():7.3f}, "
                          f"min={col.min():7.3f}, max={col.max():7.3f}")
            
            # Temporal features analysis
            if hasattr(graph[node_type], 'x_temporal'):
                temporal = graph[node_type].x_temporal
                print(f"\n  TEMPORAL FEATURES:")
                print(f"    Shape: {temporal.shape} (nodes, timesteps, features)")
                print(f"    Range: [{temporal.min():.3f}, {temporal.max():.3f}]")
                
                # Check temporal patterns
                if len(temporal.shape) == 3:
                    # Average across nodes
                    avg_pattern = temporal.mean(dim=0)  # [timesteps, features]
                    print(f"    Temporal pattern shape: {avg_pattern.shape}")
                    
                    # Check for variation
                    temporal_std = temporal.std(dim=1).mean()
                    print(f"    Average temporal variation: {temporal_std:.3f}")

def analyze_hierarchy(graph):
    """Analyze the grid hierarchy structure"""
    
    print("\n" + "="*60)
    print("GRID HIERARCHY ANALYSIS")
    print("="*60)
    
    # Check building -> cable_group connections
    if ('building', 'connected_to', 'cable_group') in graph.edge_types:
        edge_index = graph['building', 'connected_to', 'cable_group'].edge_index
        
        # Buildings per cable group
        cable_groups = {}
        for i in range(edge_index.shape[1]):
            building = edge_index[0, i].item()
            cable_group = edge_index[1, i].item()
            
            if cable_group not in cable_groups:
                cable_groups[cable_group] = []
            cable_groups[cable_group].append(building)
        
        sizes = [len(buildings) for buildings in cable_groups.values()]
        print(f"\n  Buildings per Cable Group:")
        print(f"    Min: {min(sizes) if sizes else 0}")
        print(f"    Max: {max(sizes) if sizes else 0}")
        print(f"    Mean: {np.mean(sizes) if sizes else 0:.1f}")
        print(f"    Total Cable Groups: {len(cable_groups)}")
    
    # Check cable_group -> transformer connections
    if ('cable_group', 'connects_to', 'transformer') in graph.edge_types:
        edge_index = graph['cable_group', 'connects_to', 'transformer'].edge_index
        
        # Cable groups per transformer
        transformers = {}
        for i in range(edge_index.shape[1]):
            cable_group = edge_index[0, i].item()
            transformer = edge_index[1, i].item()
            
            if transformer not in transformers:
                transformers[transformer] = []
            transformers[transformer].append(cable_group)
        
        sizes = [len(cables) for cables in transformers.values()]
        print(f"\n  Cable Groups per Transformer:")
        print(f"    Min: {min(sizes) if sizes else 0}")
        print(f"    Max: {max(sizes) if sizes else 0}")
        print(f"    Mean: {np.mean(sizes) if sizes else 0:.1f}")
        print(f"    Total Transformers: {len(transformers)}")
    
    # Check adjacency clusters
    if ('building', 'in_cluster', 'adjacency_cluster') in graph.edge_types:
        edge_index = graph['building', 'in_cluster', 'adjacency_cluster'].edge_index
        
        # Buildings per cluster
        clusters = {}
        for i in range(edge_index.shape[1]):
            building = edge_index[0, i].item()
            cluster = edge_index[1, i].item()
            
            if cluster not in clusters:
                clusters[cluster] = []
            clusters[cluster].append(building)
        
        sizes = [len(buildings) for buildings in clusters.values()]
        print(f"\n  Buildings per Adjacency Cluster:")
        print(f"    Min: {min(sizes) if sizes else 0}")
        print(f"    Max: {max(sizes) if sizes else 0}")
        print(f"    Mean: {np.mean(sizes) if sizes else 0:.1f}")
        print(f"    Total Clusters: {len(clusters)}")
        
        # Cluster size distribution
        print(f"\n  Cluster Size Distribution:")
        for size in [3, 5, 10, 20, 50]:
            count = sum(1 for s in sizes if s >= size)
            print(f"    >= {size:2d} buildings: {count} clusters")

def analyze_task_requirements(graph):
    """Analyze requirements for each task"""
    
    print("\n" + "="*60)
    print("TASK-SPECIFIC REQUIREMENTS")
    print("="*60)
    
    # Retrofit task
    print("\n1. RETROFIT TASK:")
    print("-"*40)
    if 'building' in graph.node_types and hasattr(graph['building'], 'x'):
        features = graph['building'].x
        energy_scores = features[:, 1]  # Index 1 is energy_score
        poor_buildings = (energy_scores <= 3).sum().item()  # E, F, G labels
        print(f"  Buildings with poor energy labels (<=3): {poor_buildings}")
        print(f"  Percentage: {poor_buildings / features.shape[0] * 100:.1f}%")
    
    # Solar optimization
    print("\n2. SOLAR OPTIMIZATION:")
    print("-"*40)
    if 'building' in graph.node_types and hasattr(graph['building'], 'x'):
        features = graph['building'].x
        solar_scores = features[:, 2]  # Index 2 is solar_score
        roof_areas = features[:, 5]    # Index 5 is roof_area
        has_solar = features[:, 7]     # Index 7 is has_solar
        
        good_solar = (solar_scores >= 2).sum().item()
        large_roofs = (roof_areas >= 100).sum().item()
        existing_solar = has_solar.sum().item()
        
        print(f"  Buildings with good solar potential (>=2): {good_solar}")
        print(f"  Buildings with large roofs (>=100m²): {large_roofs}")
        print(f"  Buildings with existing solar: {existing_solar}")
        print(f"  Potential new installations: {good_solar - existing_solar}")
    
    # Energy sharing
    print("\n3. ENERGY SHARING:")
    print("-"*40)
    if 'adjacency_cluster' in graph.node_types and hasattr(graph['adjacency_cluster'], 'x'):
        cluster_features = graph['adjacency_cluster'].x
        member_counts = cluster_features[:, 0]  # Index 0 is member_count
        sharing_potential = cluster_features[:, 1]  # Index 1 is sharing_potential
        
        viable_clusters = (member_counts >= 3).sum().item()
        high_potential = (sharing_potential >= 0.5).sum().item()
        
        print(f"  Clusters with >=3 members: {viable_clusters}")
        print(f"  Clusters with high sharing potential: {high_potential}")
        
        # Check temporal features
        if hasattr(graph['adjacency_cluster'], 'x_temporal'):
            print(f"  ✓ Temporal features available for energy sharing")
    
    # Electrification
    print("\n4. ELECTRIFICATION:")
    print("-"*40)
    if 'building' in graph.node_types and hasattr(graph['building'], 'x'):
        features = graph['building'].x
        electrify_scores = features[:, 3]  # Index 3 is electrify_score
        has_heat_pump = features[:, 9]     # Index 9 is has_heat_pump
        
        ready_buildings = (electrify_scores >= 2).sum().item()
        existing_hp = has_heat_pump.sum().item()
        
        print(f"  Buildings ready for electrification (>=2): {ready_buildings}")
        print(f"  Buildings with existing heat pumps: {existing_hp}")
        print(f"  Potential new heat pumps: {ready_buildings - existing_hp}")

def main():
    """Run complete diagnostics"""
    
    # Configuration
    NEO4J_URI = "bolt://localhost:7687"
    NEO4J_USER = "neo4j"
    NEO4J_PASSWORD = "aminasad"
    DISTRICT = "Buitenveldert-Oost"
    
    print("\n" + "="*60)
    print("ENERGY GNN - GRAPH DIAGNOSTICS")
    print("="*60)
    
    # Initialize components
    print("\nInitializing components...")
    kg = KGConnector(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)
    graph_builder = GraphConstructor(kg)
    feature_processor = FeatureProcessor()
    
    # Build graph with temporal features
    print(f"\nBuilding graph for district: {DISTRICT}")
    graph = graph_builder.build_hetero_graph(
        DISTRICT,
        include_energy_sharing=True,
        include_temporal=True,
        lookback_hours=24
    )
    
    # Process features
    print("\nProcessing features...")
    feature_processor.process_graph_features(graph, fit=True)
    
    # Run analyses
    analyze_graph_structure(graph)
    analyze_features(graph)
    analyze_hierarchy(graph)
    analyze_task_requirements(graph)
    
    # Test data loaders for each task
    print("\n" + "="*60)
    print("DATA LOADER TESTING")
    print("="*60)
    
    loader_creator = TaskSpecificLoader(batch_size=32)
    
    tasks = ['retrofit', 'energy_sharing', 'solar', 'grid_planning', 'electrification']
    
    for task in tasks:
        print(f"\n{task.upper()} LOADER:")
        print("-"*40)
        try:
            loader = loader_creator.create_loader(graph, task, 'train')
            
            # Get first batch
            for batch in loader:
                print(f"  Batch info:")
                for key, value in batch.items():
                    if hasattr(value, 'shape'):
                        print(f"    {key}: {value.shape}")
                    elif hasattr(value, 'edge_index'):
                        print(f"    {key}: edge_index shape {value.edge_index.shape}")
                break
                
        except Exception as e:
            print(f"  ⚠️ Error: {e}")
    
    print("\n" + "="*60)
    print("DIAGNOSTICS COMPLETE")
    print("="*60)
    
    # Close connection
    kg.close()

if __name__ == "__main__":
    main()

2025-08-20 18:16:05,358 - data.kg_connector - INFO - Connected to Neo4j at bolt://localhost:7687
2025-08-20 18:16:05,358 - data.graph_constructor - INFO - Building graph for district Buitenveldert-Oost
2025-08-20 18:16:05,359 - data.graph_constructor - INFO - Temporal features: True, Lookback: 24 hours



ENERGY GNN - GRAPH DIAGNOSTICS

Initializing components...

Building graph for district: Buitenveldert-Oost


2025-08-20 18:16:07,581 - data.kg_connector - INFO - Edge counts - B->CG: 335, CG->T: 13, T->S: 0, B->AC: 346
2025-08-20 18:16:07,584 - data.graph_constructor - INFO - Fetching temporal features for 335 buildings...
2025-08-20 18:16:07,588 - data.kg_connector - INFO - Fetching time series for 335 buildings
2025-08-20 18:16:07,589 - data.kg_connector - INFO - Time range: 1706396400000 to 1706482800000 (24 hours)
2025-08-20 18:16:08,048 - data.kg_connector - INFO - Retrieved time series for 335 buildings out of 335 requested
2025-08-20 18:16:08,051 - data.graph_constructor - INFO - Added temporal features: torch.Size([335, 24, 8])
2025-08-20 18:16:08,136 - data.graph_constructor - INFO - Fetching temporal features for 95 clusters...
2025-08-20 18:16:08,356 - data.graph_constructor - INFO - Added cluster temporal features: torch.Size([95, 24, 7])
2025-08-20 18:16:08,357 - data.graph_constructor - INFO - No transformer_to_substation edges (substations might not exist)
2025-08-20 18:16:08,3


Processing features...

GRAPH STRUCTURE ANALYSIS

1. NODE TYPES AND COUNTS:
----------------------------------------
  building            :   335 nodes,  17 features
    └── Temporal: torch.Size([335, 24, 8])
    └── Engineered: torch.Size([335, 7])
  cable_group         :    21 nodes,  12 features
    └── Engineered: torch.Size([21, 4])
  transformer         :    44 nodes,   3 features
  adjacency_cluster   :    95 nodes,  11 features
    └── Temporal: torch.Size([95, 24, 7])
    └── Engineered: torch.Size([95, 5])

2. EDGE TYPES AND COUNTS:
----------------------------------------
  (building, connected_to, cable_group)
    └── 335 edges
    └── Connects 335 building to 21 cable_group
  (cable_group, connects_to, transformer)
    └── 13 edges
    └── Connects 13 cable_group to 6 transformer
  (building, in_cluster, adjacency_cluster)
    └── 346 edges
    └── Connects 98 building to 79 adjacency_cluster

3. CONNECTIVITY CHECK:
----------------------------------------
  building: 0 

In [None]:
# explore_lv_network.py
"""
Explore the LV network structure with correct understanding
Focus on building-cable group relationships and physical adjacency
"""

import torch
import numpy as np
from data.kg_connector import KGConnector
from data.graph_constructor import GraphConstructor
import matplotlib.pyplot as plt
import networkx as nx
from typing import Dict, List, Set
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def analyze_lv_network_structure(graph):
    """Analyze the LV network focusing on building-cable group relationships"""
    
    print("\n" + "="*60)
    print("LV NETWORK STRUCTURE ANALYSIS")
    print("="*60)
    
    # Analyze cable groups (the key LV infrastructure)
    if ('building', 'connected_to', 'cable_group') in graph.edge_types:
        edge_index = graph['building', 'connected_to', 'cable_group'].edge_index
        
        print("\n1. CABLE GROUP ANALYSIS:")
        print("-"*40)
        
        # Map buildings to cable groups
        cable_group_map = {}
        building_to_cg = {}
        
        for i in range(edge_index.shape[1]):
            b_idx = edge_index[0, i].item()
            cg_idx = edge_index[1, i].item()
            
            if cg_idx not in cable_group_map:
                cable_group_map[cg_idx] = []
            cable_group_map[cg_idx].append(b_idx)
            building_to_cg[b_idx] = cg_idx
        
        # Analyze distribution
        sizes = [len(buildings) for buildings in cable_group_map.values()]
        
        print(f"  Total Cable Groups: {len(cable_group_map)}")
        print(f"  Buildings per Cable Group:")
        print(f"    Min: {min(sizes)}")
        print(f"    Max: {max(sizes)}")
        print(f"    Mean: {np.mean(sizes):.1f}")
        print(f"    Median: {np.median(sizes):.1f}")
        print(f"    Std: {np.std(sizes):.1f}")
        
        # Distribution histogram
        print(f"\n  Size Distribution:")
        for threshold in [1, 5, 10, 20, 50, 100, 200]:
            count = sum(1 for s in sizes if s >= threshold)
            if count > 0:
                print(f"    >= {threshold:3d} buildings: {count:2d} cable groups")
        
        # Find the outlier (222 buildings in one group)
        max_cg = max(cable_group_map.items(), key=lambda x: len(x[1]))
        print(f"\n  Largest Cable Group (ID {max_cg[0]}): {len(max_cg[1])} buildings")
        print(f"    (This might be a data collection artifact)")
        
        # Check cable group features
        if 'cable_group' in graph.node_types:
            cg_features = graph['cable_group'].x
            print(f"\n  Cable Group Features:")
            print(f"    Shape: {cg_features.shape}")
            
            # Analyze key features
            if cg_features.shape[1] >= 12:
                print(f"    Total length (m): mean={cg_features[:, 0].mean():.1f}")
                print(f"    Segment count: mean={cg_features[:, 1].mean():.1f}")
                print(f"    Building count: mean={cg_features[:, 2].mean():.1f}")
                print(f"    Peak demand: mean={cg_features[:, 10].mean():.1f}")
        
        return cable_group_map, building_to_cg
    
    return {}, {}

def analyze_physical_adjacency(graph):
    """Analyze physical adjacency (shared walls) for thermal sharing"""
    
    print("\n" + "="*60)
    print("PHYSICAL ADJACENCY ANALYSIS (Thermal Sharing)")
    print("="*60)
    
    if ('building', 'in_cluster', 'adjacency_cluster') in graph.edge_types:
        edge_index = graph['building', 'in_cluster', 'adjacency_cluster'].edge_index
        
        print("\n1. SHARED WALL CLUSTERS:")
        print("-"*40)
        
        # Map clusters
        cluster_map = {}
        building_to_cluster = {}
        
        for i in range(edge_index.shape[1]):
            b_idx = edge_index[0, i].item()
            c_idx = edge_index[1, i].item()
            
            if c_idx not in cluster_map:
                cluster_map[c_idx] = []
            cluster_map[c_idx].append(b_idx)
            building_to_cluster[b_idx] = c_idx
        
        sizes = [len(buildings) for buildings in cluster_map.values()]
        
        print(f"  Total Physical Clusters: {len(cluster_map)}")
        print(f"  Buildings with shared walls: {len(building_to_cluster)}/{graph['building'].x.shape[0]}")
        print(f"  Buildings per Cluster:")
        print(f"    Min: {min(sizes) if sizes else 0}")
        print(f"    Max: {max(sizes) if sizes else 0}")
        print(f"    Mean: {np.mean(sizes) if sizes else 0:.1f}")
        
        # Building types in clusters
        print(f"\n  Cluster Types (by size):")
        size_counts = {}
        for size in sizes:
            size_counts[size] = size_counts.get(size, 0) + 1
        
        for size, count in sorted(size_counts.items()):
            building_type = ""
            if size == 2:
                building_type = "(Semi-detached)"
            elif size <= 4:
                building_type = "(Row houses)"
            elif size <= 8:
                building_type = "(Apartment block)"
            print(f"    {size} buildings: {count:2d} clusters {building_type}")
        
        # Analyze thermal sharing potential
        if 'building' in graph.node_types and hasattr(graph['building'], 'x'):
            features = graph['building'].x
            
            print(f"\n2. THERMAL SHARING POTENTIAL:")
            print("-"*40)
            
            # Check shared walls feature (index 10)
            for cluster_id, building_ids in list(cluster_map.items())[:3]:  # Sample
                if len(building_ids) >= 3:
                    shared_walls = [features[bid, 10].item() for bid in building_ids]
                    heating_demand = [features[bid, 14].item() for bid in building_ids]
                    
                    print(f"\n  Cluster {cluster_id} ({len(building_ids)} buildings):")
                    print(f"    Avg shared walls: {np.mean(shared_walls):.1f}")
                    print(f"    Heating variance: {np.std(heating_demand):.3f}")
                    print(f"    (High variance = good thermal sharing potential)")
        
        return cluster_map, building_to_cluster
    
    return {}, {}

def analyze_boundary_conditions(graph):
    """Analyze transformers as boundary conditions"""
    
    print("\n" + "="*60)
    print("BOUNDARY CONDITIONS ANALYSIS")
    print("="*60)
    
    print("\n1. TRANSFORMER STATUS:")
    print("-"*40)
    
    # Check transformer nodes
    if 'transformer' in graph.node_types:
        num_transformers = graph['transformer'].x.shape[0]
        print(f"  Total Transformers in data: {num_transformers}")
        print(f"  Status: ISOLATED (expected - they're boundary nodes)")
        print(f"  Purpose: Capacity constraints for cable groups")
        
        # Check if cable groups reference transformers
        if ('cable_group', 'connects_to', 'transformer') in graph.edge_types:
            edge_index = graph['cable_group', 'connects_to', 'transformer'].edge_index
            if edge_index.shape[1] > 0:
                print(f"  Connected Cable Groups: {edge_index.shape[1]}")
        else:
            print(f"  Note: Transformers likely outside study area")
            print(f"  Treatment: Use as capacity constraints, not graph nodes")
    
    print("\n2. RECOMMENDED APPROACH:")
    print("-"*40)
    print("  • Focus GNN on building-cable group relationships")
    print("  • Use transformers as external capacity constraints")
    print("  • Model thermal sharing via adjacency clusters")
    print("  • Don't model MV/HV levels (outside scope)")

def analyze_task_feasibility(graph, cable_group_map, cluster_map):
    """Analyze which tasks are feasible with LV-only data"""
    
    print("\n" + "="*60)
    print("TASK FEASIBILITY WITH LV NETWORK")
    print("="*60)
    
    print("\n✅ FEASIBLE TASKS:")
    print("-"*40)
    
    print("\n1. BUILDING-LEVEL OPTIMIZATION:")
    print("  • Retrofit targeting (51% candidates)")
    print("  • Solar placement (184 candidates)")
    print("  • Heat pump planning (130 candidates)")
    print("  • Battery placement")
    
    print("\n2. THERMAL ENERGY SHARING:")
    print("  • Between buildings with shared walls")
    print("  • Small clusters (2-8 buildings)")
    print("  • Complementary heating patterns")
    
    print("\n3. LV CABLE GROUP OPTIMIZATION:")
    print("  • Load balancing within cable groups")
    print("  • Phase balancing")
    print("  • Local congestion management")
    
    print("\n⚠️ LIMITED FEASIBILITY:")
    print("-"*40)
    
    print("\n1. ELECTRICAL P2P TRADING:")
    print("  • Only within same cable group")
    print("  • Cannot trade across cable groups")
    print("  • Limited to ~16 buildings average")
    
    print("\n2. GRID PLANNING:")
    print("  • Only LV cable planning")
    print("  • Cannot optimize MV transformers")
    print("  • Treat transformer capacity as fixed")
    
    print("\n❌ NOT FEASIBLE:")
    print("-"*40)
    print("\n1. MV/HV OPTIMIZATION")
    print("2. Substation planning")
    print("3. Cross-transformer energy sharing")

def visualize_lv_structure(graph, cable_group_map, cluster_map, sample_size=50):
    """Visualize a sample of the LV network structure"""
    
    print("\n" + "="*60)
    print("LV NETWORK VISUALIZATION")
    print("="*60)
    
    # Create NetworkX graph for visualization
    G = nx.Graph()
    
    # Sample buildings
    building_indices = list(range(min(sample_size, graph['building'].x.shape[0])))
    
    # Add nodes
    for b_idx in building_indices:
        G.add_node(f"B{b_idx}", type='building')
    
    # Add cable groups for these buildings
    if ('building', 'connected_to', 'cable_group') in graph.edge_types:
        edge_index = graph['building', 'connected_to', 'cable_group'].edge_index
        
        added_cgs = set()
        for i in range(edge_index.shape[1]):
            b_idx = edge_index[0, i].item()
            cg_idx = edge_index[1, i].item()
            
            if b_idx in building_indices:
                if cg_idx not in added_cgs:
                    G.add_node(f"CG{cg_idx}", type='cable_group')
                    added_cgs.add(cg_idx)
                G.add_edge(f"B{b_idx}", f"CG{cg_idx}", type='electrical')
    
    # Add adjacency clusters
    if ('building', 'in_cluster', 'adjacency_cluster') in graph.edge_types:
        edge_index = graph['building', 'in_cluster', 'adjacency_cluster'].edge_index
        
        added_clusters = set()
        for i in range(edge_index.shape[1]):
            b_idx = edge_index[0, i].item()
            c_idx = edge_index[1, i].item()
            
            if b_idx in building_indices:
                if c_idx not in added_clusters:
                    G.add_node(f"AC{c_idx}", type='adjacency')
                    added_clusters.add(c_idx)
                G.add_edge(f"B{b_idx}", f"AC{c_idx}", type='thermal')
    
    print(f"\nSample Network Statistics:")
    print(f"  Nodes: {G.number_of_nodes()}")
    print(f"  Edges: {G.number_of_edges()}")
    print(f"  Components: {nx.number_connected_components(G)}")
    
    return G

def check_data_quality_issues(graph):
    """Check for specific data quality issues"""
    
    print("\n" + "="*60)
    print("DATA QUALITY CHECKS")
    print("="*60)
    
    issues_found = []
    
    # Check 1: The 222-building cable group
    if ('building', 'connected_to', 'cable_group') in graph.edge_types:
        edge_index = graph['building', 'connected_to', 'cable_group'].edge_index
        
        cg_counts = {}
        for i in range(edge_index.shape[1]):
            cg = edge_index[1, i].item()
            cg_counts[cg] = cg_counts.get(cg, 0) + 1
        
        max_count = max(cg_counts.values())
        if max_count > 100:
            issues_found.append(f"Cable group with {max_count} buildings (likely data error)")
    
    # Check 2: Normalized features that shouldn't be
    if 'adjacency_cluster' in graph.node_types:
        cluster_features = graph['adjacency_cluster'].x
        member_counts = cluster_features[:, 0]
        
        if member_counts.max() < 2:  # Normalized when should be counts
            issues_found.append("Cluster member counts incorrectly normalized")
    
    # Check 3: Missing temporal data
    if hasattr(graph['adjacency_cluster'], 'x_temporal'):
        temporal = graph['adjacency_cluster'].x_temporal
        if temporal.sum() == 0:
            issues_found.append("Cluster temporal features are all zeros")
    
    # Print issues
    if issues_found:
        print("\n⚠️ ISSUES FOUND:")
        for i, issue in enumerate(issues_found, 1):
            print(f"  {i}. {issue}")
    else:
        print("\n✅ No major issues detected")
    
    return issues_found

def main():
    """Run complete LV network analysis"""
    
    # Configuration
    NEO4J_URI = "bolt://localhost:7687"
    NEO4J_USER = "neo4j"
    NEO4J_PASSWORD = "aminasad"
    DISTRICT = "Buitenveldert-Oost"
    
    print("\n" + "="*60)
    print("LV NETWORK FOCUSED ANALYSIS")
    print("="*60)
    
    # Initialize
    kg = KGConnector(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)
    graph_builder = GraphConstructor(kg)
    
    # Build graph
    graph = graph_builder.build_hetero_graph(
        DISTRICT,
        include_energy_sharing=True,
        include_temporal=True,
        lookback_hours=24
    )
    
    # Analyze LV structure
    cable_group_map, building_to_cg = analyze_lv_network_structure(graph)
    
    # Analyze physical adjacency
    cluster_map, building_to_cluster = analyze_physical_adjacency(graph)
    
    # Analyze boundary conditions
    analyze_boundary_conditions(graph)
    
    # Check task feasibility
    analyze_task_feasibility(graph, cable_group_map, cluster_map)
    
    # Check data quality
    issues = check_data_quality_issues(graph)
    
    # Visualize sample
    G = visualize_lv_structure(graph, cable_group_map, cluster_map)
    
    # Close connection
    kg.close()
    
    print("\n" + "="*60)
    print("ANALYSIS COMPLETE")
    print("="*60)

if __name__ == "__main__":
    main()

INFO:data.kg_connector:Connected to Neo4j at bolt://localhost:7687
INFO:data.graph_constructor:Building graph for district Buitenveldert-Oost
INFO:data.graph_constructor:Temporal features: True, Lookback: 24 hours



LV NETWORK FOCUSED ANALYSIS


INFO:data.kg_connector:Edge counts - B->CG: 335, CG->T: 13, T->S: 0, B->AC: 346
INFO:data.graph_constructor:Fetching temporal features for 335 buildings...
INFO:data.kg_connector:Fetching time series for 335 buildings
INFO:data.kg_connector:Time range: 1706396400000 to 1706482800000 (24 hours)
INFO:data.kg_connector:Retrieved time series for 335 buildings out of 335 requested
INFO:data.graph_constructor:Added temporal features: torch.Size([335, 24, 8])
INFO:data.graph_constructor:Fetching temporal features for 95 clusters...
INFO:data.graph_constructor:Added cluster temporal features: torch.Size([95, 24, 7])
INFO:data.graph_constructor:No transformer_to_substation edges (substations might not exist)
INFO:data.graph_constructor:Temporal features added for: ['building', 'adjacency_cluster']
INFO:data.graph_constructor:Graph built: {'building': 335, 'cable_group': 21, 'transformer': 44, 'substation': 0, 'adjacency_cluster': 95}
INFO:data.kg_connector:Neo4j connection closed



LV NETWORK STRUCTURE ANALYSIS

1. CABLE GROUP ANALYSIS:
----------------------------------------
  Total Cable Groups: 21
  Buildings per Cable Group:
    Min: 1
    Max: 222
    Mean: 16.0
    Median: 4.0
    Std: 46.2

  Size Distribution:
    >=   1 buildings: 21 cable groups
    >=   5 buildings: 10 cable groups
    >=  10 buildings:  6 cable groups
    >=  20 buildings:  1 cable groups
    >=  50 buildings:  1 cable groups
    >= 100 buildings:  1 cable groups
    >= 200 buildings:  1 cable groups

  Largest Cable Group (ID 0): 222 buildings
    (This might be a data collection artifact)

  Cable Group Features:
    Shape: torch.Size([21, 12])
    Total length (m): mean=3279.9
    Segment count: mean=91.2
    Building count: mean=43.5
    Peak demand: mean=15.8

PHYSICAL ADJACENCY ANALYSIS (Thermal Sharing)

1. SHARED WALL CLUSTERS:
----------------------------------------
  Total Physical Clusters: 79
  Buildings with shared walls: 98/335
  Buildings per Cluster:
    Min: 2
    

In [None]:
# diagnose_id_mismatch.py
"""
Diagnose the ID mismatch between Neo4j and graph construction
"""

from data.kg_connector import KGConnector
from data.graph_constructor import GraphConstructor
import logging

logging.basicConfig(level=logging.INFO)

def diagnose_id_mismatch():
    kg = KGConnector("bolt://localhost:7687", "neo4j", "aminasad")
    
    print("\n" + "="*60)
    print("ID MISMATCH DIAGNOSIS")
    print("="*60)
    
    # First, check what IDs are actually in Neo4j
    with kg.driver.session() as session:
        print("\n1. CABLE GROUP IDs IN NEO4J:")
        print("-"*40)
        
        query = """
        MATCH (cg:CableGroup)
        RETURN cg.group_id as id, cg.voltage_level as voltage
        ORDER BY cg.group_id
        LIMIT 10
        """
        results = session.run(query).data()
        
        for r in results:
            print(f"  Neo4j ID: '{r['id']}', Voltage: {r['voltage']}")
        
        print("\n2. BUILDING CONNECTIONS IN NEO4J:")
        print("-"*40)
        
        # Try different ID formats
        id_formats = [
            "0",
            "LV_GROUP_0001", 
            "LV_GROUP_0002",
            "CG_0",
            "1"
        ]
        
        for test_id in id_formats:
            query = f"""
            MATCH (cg:CableGroup {{group_id: '{test_id}'}})<-[:CONNECTED_TO]-(b:Building)
            RETURN count(b) as building_count
            """
            result = session.run(query).single()
            count = result['building_count'] if result else 0
            print(f"  ID '{test_id}': {count} buildings")
        
        print("\n3. CHECK ACTUAL CABLE GROUP WITH MANY BUILDINGS:")
        print("-"*40)
        
        query = """
        MATCH (cg:CableGroup)<-[:CONNECTED_TO]-(b:Building)
        WITH cg, count(b) as building_count
        ORDER BY building_count DESC
        LIMIT 5
        RETURN cg.group_id as id, building_count
        """
        results = session.run(query).data()
        
        for r in results:
            print(f"  Cable Group '{r['id']}': {r['building_count']} buildings")
    
    # Now check what the graph constructor is doing
    print("\n4. GRAPH CONSTRUCTOR MAPPING:")
    print("-"*40)
    
    graph_builder = GraphConstructor(kg)
    topology = kg.get_grid_topology("Buitenveldert-Oost")
    
    # Check cable group nodes
    cable_groups = topology['nodes'].get('cable_groups', [])
    print(f"\nFirst 5 cable groups from topology:")
    for i, cg in enumerate(cable_groups[:5]):
        actual_id = cg.get('group_id', f'MISSING_{i}')
        print(f"  Index {i}: Neo4j ID = '{actual_id}'")
    
    # Check the node mappings
    print("\n5. NODE MAPPINGS IN GRAPH CONSTRUCTOR:")
    print("-"*40)
    
    graph = graph_builder.build_hetero_graph("Buitenveldert-Oost")
    
    if hasattr(graph_builder, 'node_mappings'):
        if 'cable_group' in graph_builder.node_mappings:
            cg_mappings = graph_builder.node_mappings['cable_group']
            print(f"\nCable Group ID mappings (first 5):")
            for neo4j_id, graph_idx in list(cg_mappings.items())[:5]:
                print(f"  Neo4j: '{neo4j_id}' → Graph Index: {graph_idx}")
    
    kg.close()

# Run diagnosis
diagnose_id_mismatch()

INFO:data.kg_connector:Connected to Neo4j at bolt://localhost:7687



ID MISMATCH DIAGNOSIS

1. CABLE GROUP IDs IN NEO4J:
----------------------------------------
  Neo4j ID: 'HV_GROUP_0001', Voltage: HV
  Neo4j ID: 'HV_GROUP_0002', Voltage: HV
  Neo4j ID: 'HV_GROUP_0003', Voltage: HV
  Neo4j ID: 'HV_GROUP_0004', Voltage: HV
  Neo4j ID: 'HV_GROUP_0005', Voltage: HV
  Neo4j ID: 'HV_GROUP_0006', Voltage: HV
  Neo4j ID: 'LV_GROUP_0001', Voltage: LV
  Neo4j ID: 'LV_GROUP_0002', Voltage: LV
  Neo4j ID: 'LV_GROUP_0003', Voltage: LV
  Neo4j ID: 'LV_GROUP_0004', Voltage: LV

2. BUILDING CONNECTIONS IN NEO4J:
----------------------------------------
  ID '0': 0 buildings
  ID 'LV_GROUP_0001': 0 buildings
  ID 'LV_GROUP_0002': 5 buildings
  ID 'CG_0': 0 buildings
  ID '1': 0 buildings

3. CHECK ACTUAL CABLE GROUP WITH MANY BUILDINGS:
----------------------------------------


INFO:data.kg_connector:Edge counts - B->CG: 335, CG->T: 13, T->S: 0, B->AC: 346
INFO:data.graph_constructor:Building graph for district Buitenveldert-Oost
INFO:data.graph_constructor:Temporal features: True, Lookback: 24 hours


  Cable Group 'LV_GROUP_0003': 731 buildings
  Cable Group 'LV_GROUP_0020': 228 buildings
  Cable Group 'LV_GROUP_0049': 71 buildings
  Cable Group 'LV_GROUP_0037': 44 buildings
  Cable Group 'LV_GROUP_0062': 34 buildings

4. GRAPH CONSTRUCTOR MAPPING:
----------------------------------------

First 5 cable groups from topology:
  Index 0: Neo4j ID = 'LV_GROUP_0003'
  Index 1: Neo4j ID = 'LV_GROUP_0053'
  Index 2: Neo4j ID = 'LV_GROUP_0086'
  Index 3: Neo4j ID = 'LV_GROUP_0028'
  Index 4: Neo4j ID = 'LV_GROUP_0049'

5. NODE MAPPINGS IN GRAPH CONSTRUCTOR:
----------------------------------------


INFO:data.kg_connector:Edge counts - B->CG: 335, CG->T: 13, T->S: 0, B->AC: 346
INFO:data.graph_constructor:Fetching temporal features for 335 buildings...
INFO:data.kg_connector:Fetching time series for 335 buildings
INFO:data.kg_connector:Time range: 1706396400000 to 1706482800000 (24 hours)
INFO:data.kg_connector:Retrieved time series for 335 buildings out of 335 requested
INFO:data.graph_constructor:Added temporal features: torch.Size([335, 24, 8])
INFO:data.graph_constructor:Fetching temporal features for 95 clusters...
INFO:data.graph_constructor:Added cluster temporal features: torch.Size([95, 24, 7])
INFO:data.graph_constructor:No transformer_to_substation edges (substations might not exist)
INFO:data.graph_constructor:Temporal features added for: ['building', 'adjacency_cluster']
INFO:data.graph_constructor:Graph built: {'building': 335, 'cable_group': 21, 'transformer': 44, 'substation': 0, 'adjacency_cluster': 95}
INFO:data.kg_connector:Neo4j connection closed



Cable Group ID mappings (first 5):
  Neo4j: 'LV_GROUP_0003' → Graph Index: 0
  Neo4j: 'LV_GROUP_0053' → Graph Index: 1
  Neo4j: 'LV_GROUP_0086' → Graph Index: 2
  Neo4j: 'LV_GROUP_0028' → Graph Index: 3
  Neo4j: 'LV_GROUP_0049' → Graph Index: 4


In [None]:
# debug_edge_creation.py
"""
Debug why cable_group->transformer edges are lost
"""

from data.kg_connector import KGConnector
from data.graph_constructor import GraphConstructor
import torch
import logging

logging.basicConfig(level=logging.DEBUG)  # Set to DEBUG for more detail

def debug_edge_creation():
    kg = KGConnector("bolt://localhost:7687", "neo4j", "aminasad")
    graph_builder = GraphConstructor(kg)
    
    print("\n" + "="*60)
    print("DEBUGGING EDGE CREATION")
    print("="*60)
    
    # Get the raw topology
    topology = kg.get_grid_topology("Buitenveldert-Oost")
    
    # Check what edges Neo4j returns
    print("\n1. RAW EDGES FROM NEO4J:")
    print("-"*40)
    
    cable_to_transformer = topology['edges'].get('cable_to_transformer', [])
    print(f"Cable→Transformer edges: {len(cable_to_transformer)}")
    
    if cable_to_transformer:
        print("\nFirst 5 cable→transformer edges:")
        for edge in cable_to_transformer[:5]:
            print(f"  {edge}")
    
    # Check the node mappings
    print("\n2. NODE MAPPINGS:")
    print("-"*40)
    
    # Build graph to populate mappings
    graph = graph_builder.build_hetero_graph("Buitenveldert-Oost")
    
    cg_mappings = graph_builder.node_mappings.get('cable_group', {})
    t_mappings = graph_builder.node_mappings.get('transformer', {})
    
    print(f"Cable group mappings: {len(cg_mappings)} entries")
    print(f"Transformer mappings: {len(t_mappings)} entries")
    
    # Show some transformer mappings
    print("\nFirst 5 transformer mappings:")
    for t_id, idx in list(t_mappings.items())[:5]:
        print(f"  '{t_id}' → {idx}")
    
    # Try to manually create the edges
    print("\n3. MANUAL EDGE CREATION TEST:")
    print("-"*40)
    
    successful_edges = []
    failed_edges = []
    
    for edge in cable_to_transformer:
        src_id = str(edge['src'])
        dst_id = str(edge['dst'])
        
        if src_id in cg_mappings and dst_id in t_mappings:
            src_idx = cg_mappings[src_id]
            dst_idx = t_mappings[dst_id]
            successful_edges.append([src_idx, dst_idx])
            print(f"  ✅ '{src_id}' ({src_idx}) → '{dst_id}' ({dst_idx})")
        else:
            failed_edges.append(edge)
            if src_id not in cg_mappings:
                print(f"  ❌ Source '{src_id}' not in cable_group mappings")
            if dst_id not in t_mappings:
                print(f"  ❌ Dest '{dst_id}' not in transformer mappings")
    
    print(f"\nSuccessful: {len(successful_edges)}, Failed: {len(failed_edges)}")
    
    # Check what the graph actually has
    print("\n4. ACTUAL GRAPH EDGES:")
    print("-"*40)
    
    for edge_type in graph.edge_types:
        edge_index = graph[edge_type].edge_index
        print(f"  {edge_type}: {edge_index.shape[1]} edges")
    
    # Direct Neo4j query to verify
    print("\n5. VERIFY IN NEO4J DIRECTLY:")
    print("-"*40)
    
    with kg.driver.session() as session:
        query = """
        MATCH (cg:CableGroup)-[:CONNECTS_TO]->(t:Transformer)
        RETURN cg.group_id as cg_id, t.station_id as t_id
        LIMIT 5
        """
        results = session.run(query).data()
        
        print(f"Neo4j query found {len(results)} edges:")
        for r in results:
            print(f"  '{r['cg_id']}' → '{r['t_id']}'")
            
            # Check if these IDs exist in mappings
            cg_exists = r['cg_id'] in cg_mappings
            t_exists = str(r['t_id']) in t_mappings if r['t_id'] else False
            print(f"    Cable group in mappings: {cg_exists}")
            print(f"    Transformer in mappings: {t_exists}")
    
    kg.close()

# Run debug
debug_edge_creation()

2025-08-20 18:16:41,734 - data.kg_connector - INFO - Connected to Neo4j at bolt://localhost:7687



DEBUGGING EDGE CREATION


2025-08-20 18:16:43,903 - data.kg_connector - INFO - Edge counts - B->CG: 335, CG->T: 13, T->S: 0, B->AC: 346
2025-08-20 18:16:43,904 - data.graph_constructor - INFO - Building graph for district Buitenveldert-Oost
2025-08-20 18:16:43,905 - data.graph_constructor - INFO - Temporal features: True, Lookback: 24 hours
2025-08-20 18:16:44,028 - data.kg_connector - INFO - Edge counts - B->CG: 335, CG->T: 13, T->S: 0, B->AC: 346
2025-08-20 18:16:44,031 - data.graph_constructor - INFO - Fetching temporal features for 335 buildings...
2025-08-20 18:16:44,034 - data.kg_connector - INFO - Fetching time series for 335 buildings
2025-08-20 18:16:44,035 - data.kg_connector - INFO - Time range: 1706396400000 to 1706482800000 (24 hours)



1. RAW EDGES FROM NEO4J:
----------------------------------------
Cable→Transformer edges: 13

First 5 cable→transformer edges:
  {'dst': '1099526076017', 'src': 'LV_GROUP_0003'}
  {'dst': '1099526076075', 'src': 'LV_GROUP_0053'}
  {'dst': '1099525275604', 'src': 'LV_GROUP_0086'}
  {'dst': '1099524500983', 'src': 'LV_GROUP_0028'}
  {'dst': '1099527241171', 'src': 'LV_GROUP_0049'}

2. NODE MAPPINGS:
----------------------------------------


2025-08-20 18:16:44,464 - data.kg_connector - INFO - Retrieved time series for 335 buildings out of 335 requested
2025-08-20 18:16:44,468 - data.graph_constructor - INFO - Added temporal features: torch.Size([335, 24, 8])
2025-08-20 18:16:44,574 - data.graph_constructor - INFO - Fetching temporal features for 95 clusters...
2025-08-20 18:16:44,863 - data.graph_constructor - INFO - Added cluster temporal features: torch.Size([95, 24, 7])
2025-08-20 18:16:44,864 - data.graph_constructor - INFO - No transformer_to_substation edges (substations might not exist)
2025-08-20 18:16:44,865 - data.graph_constructor - INFO - Temporal features added for: ['building', 'adjacency_cluster']
2025-08-20 18:16:44,866 - data.graph_constructor - INFO - Graph built: {'building': 335, 'cable_group': 21, 'transformer': 44, 'substation': 0, 'adjacency_cluster': 95}
2025-08-20 18:16:44,912 - data.kg_connector - INFO - Neo4j connection closed


Cable group mappings: 21 entries
Transformer mappings: 44 entries

First 5 transformer mappings:
  '1099527241171' → 0
  '183597013441725' → 1
  '1099520585253' → 2
  '1099524480815' → 3
  '1099524500983' → 4

3. MANUAL EDGE CREATION TEST:
----------------------------------------
  ✅ 'LV_GROUP_0003' (0) → '1099526076017' (12)
  ✅ 'LV_GROUP_0053' (1) → '1099526076075' (13)
  ✅ 'LV_GROUP_0086' (2) → '1099525275604' (7)
  ✅ 'LV_GROUP_0028' (3) → '1099524500983' (4)
  ✅ 'LV_GROUP_0049' (4) → '1099527241171' (0)
  ✅ 'LV_GROUP_0030' (10) → '1099526076075' (13)
  ✅ 'LV_GROUP_0021' (12) → '1099525833113' (9)
  ✅ 'LV_GROUP_0024' (13) → '1099525833113' (9)
  ✅ 'LV_GROUP_0026' (15) → '1099525833113' (9)
  ✅ 'LV_GROUP_0022' (16) → '1099525833113' (9)
  ✅ 'LV_GROUP_0088' (17) → '1099525275604' (7)
  ✅ 'LV_GROUP_0085' (18) → '1099525275604' (7)
  ✅ 'LV_GROUP_0105' (19) → '1099525275604' (7)

Successful: 13, Failed: 0

4. ACTUAL GRAPH EDGES:
----------------------------------------
  ('building', 'co

## base_gnn

1. base_gnn.py
   ↓
   Creates embeddings (learns patterns)
   ↓
2. task_heads.py
   ↓
   Predicts intervention impacts
   ↓
3. tasks/*.py
   ↓
   Detailed calculations & validation
   ↓
4. inference/intervention_planner.py
   ↓
   Final recommendations & roadmap

In [None]:
# test_base_gnn_explore.py
"""
First, let's explore what methods KGConnector actually has
"""

import sys
import torch
import numpy as np
from pathlib import Path
import logging
import inspect

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Add project paths
sys.path.append('.')
sys.path.append('./data')
sys.path.append('./models')

def explore_kg_connector():
    """Explore KGConnector methods"""
    print("\n" + "="*60)
    print("EXPLORING KG CONNECTOR")
    print("="*60)
    
    try:
        from data.kg_connector import KGConnector
        
        # Create instance
        kg = KGConnector(
            uri="bolt://localhost:7687",
            user="neo4j",
            password="aminasad"
        )
        
        print("\n1. Available methods in KGConnector:")
        methods = [method for method in dir(kg) if not method.startswith('_')]
        for method in methods:
            print(f"  - {method}")
        
        print("\n2. Let's check the signature of key methods:")
        for method_name in methods:
            if not method_name.startswith('_'):
                method = getattr(kg, method_name)
                if callable(method):
                    try:
                        sig = inspect.signature(method)
                        print(f"  {method_name}{sig}")
                    except:
                        print(f"  {method_name}()")
        
        return kg
        
    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        return None

def test_graph_constructor_directly():
    """Test using GraphConstructor directly"""
    print("\n" + "="*60)
    print("TESTING GRAPH CONSTRUCTOR DIRECTLY")
    print("="*60)
    
    try:
        from data.kg_connector import KGConnector
        from data.graph_constructor import GraphConstructor
        
        # Create KG connector
        kg = KGConnector(
            uri="bolt://localhost:7687",
            user="neo4j",
            password="aminasad"
        )
        
        print("\n1. Creating GraphConstructor...")
        graph_constructor = GraphConstructor(kg)
        
        print("\n2. Checking GraphConstructor methods:")
        methods = [method for method in dir(graph_constructor) if not method.startswith('_')]
        for method in methods[:10]:  # Show first 10
            print(f"  - {method}")
        
        print("\n3. Attempting to build graph...")
        
        # Try to build a graph - use the method that worked in your earlier test
        try:
            # Try without district filter first
            hetero_graph = graph_constructor.build_hetero_graph(
                district_name=None,  # Try with None first
                include_energy_sharing=False,
                include_temporal=False,
                lookback_hours=0
            )
            
            print(f"\n4. Graph built successfully!")
            print(f"  Node types: {hetero_graph.node_types}")
            print(f"  Edge types: {hetero_graph.edge_types}")
            
            for node_type in hetero_graph.node_types:
                if hasattr(hetero_graph[node_type], 'x'):
                    print(f"  {node_type}: {hetero_graph[node_type].x.shape}")
            
            return kg, graph_constructor, hetero_graph
            
        except Exception as e:
            print(f"  Failed to build graph: {e}")
            
            # Try alternative approach
            print("\n  Trying with district name...")
            try:
                hetero_graph = graph_constructor.build_hetero_graph(
                    district_name="Buitenveldert-Oost",
                    include_energy_sharing=False,
                    include_temporal=False,
                    lookback_hours=0
                )
                print(f"  Success with district name!")
                return kg, graph_constructor, hetero_graph
            except Exception as e2:
                print(f"  Also failed: {e2}")
                return kg, graph_constructor, None
        
    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        return None, None, None

def test_base_gnn_minimal():
    """Minimal test of base GNN with dummy data"""
    print("\n" + "="*60)
    print("TESTING BASE GNN WITH MINIMAL DATA")
    print("="*60)
    
    try:
        from models.base_gnn import EnergyGNNBase, create_energy_gnn_base
        from torch_geometric.data import HeteroData
        
        # Create minimal dummy data
        print("\n1. Creating dummy heterogeneous graph...")
        
        hetero_graph = HeteroData()
        
        # Add minimal node features
        hetero_graph['building'].x = torch.randn(100, 17)  # 100 buildings, 17 features
        hetero_graph['cable_group'].x = torch.randn(10, 4)  # 10 LV groups, 4 features
        hetero_graph['transformer'].x = torch.randn(5, 2)  # 5 transformers, 2 features
        hetero_graph['adjacency_cluster'].x = torch.randn(20, 4)  # 20 clusters, 4 features
        
        # Add minimal edges
        hetero_graph['building', 'connected_to', 'cable_group'].edge_index = \
            torch.stack([torch.randint(0, 100, (50,)), torch.randint(0, 10, (50,))])
        
        hetero_graph['cable_group', 'connects_to', 'transformer'].edge_index = \
            torch.stack([torch.randint(0, 10, (5,)), torch.randint(0, 5, (5,))])
        
        hetero_graph['building', 'in_cluster', 'adjacency_cluster'].edge_index = \
            torch.stack([torch.randint(0, 100, (30,)), torch.randint(0, 20, (30,))])
        
        print(f"  Created graph with {hetero_graph.node_types}")
        
        # Prepare inputs
        x_dict = {
            'building': hetero_graph['building'].x,
            'cable_group': hetero_graph['cable_group'].x,
            'transformer': hetero_graph['transformer'].x,
            'adjacency_cluster': hetero_graph['adjacency_cluster'].x
        }
        
        edge_index_dict = {
            ('building', 'connected_to', 'cable_group'): 
                hetero_graph['building', 'connected_to', 'cable_group'].edge_index,
            ('cable_group', 'connects_to', 'transformer'): 
                hetero_graph['cable_group', 'connects_to', 'transformer'].edge_index,
            ('building', 'in_cluster', 'adjacency_cluster'): 
                hetero_graph['building', 'in_cluster', 'adjacency_cluster'].edge_index
        }
        
        temporal_context = {
            'season': torch.tensor([0]),
            'is_weekend': torch.tensor([0]),
            'hour': torch.tensor([14])
        }
        
        print("\n2. Creating base GNN model...")
        config = {
            'hidden_dim': 128,
            'num_layers': 3,
            'dropout': 0.1
        }
        
        model = create_energy_gnn_base(config)
        model.eval()
        
        print("\n3. Testing forward pass...")
        with torch.no_grad():
            output_dict = model(x_dict, edge_index_dict, temporal_context)
        
        print("\n4. Output shapes:")
        for key, value in output_dict.items():
            if isinstance(value, torch.Tensor):
                print(f"  {key}: {value.shape}")
        
        print("\n✅ Base GNN works with dummy data!")
        return model, output_dict
        
    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        return None, None

def main():
    """Main test function"""
    print("\n" + "="*60)
    print("COMPREHENSIVE BASE GNN TEST")
    print("="*60)
    
    # Step 1: Explore KGConnector
    kg = explore_kg_connector()
    
    if kg is not None:
        # Step 2: Try GraphConstructor
        kg2, graph_constructor, hetero_graph = test_graph_constructor_directly()
        
        if hetero_graph is not None:
            print("\n✅ Successfully built graph from Neo4j!")
            
            # Test with real data
            try:
                from models.base_gnn import create_energy_gnn_base
                
                # Prepare inputs from real graph
                x_dict = {}
                for node_type in hetero_graph.node_types:
                    if hasattr(hetero_graph[node_type], 'x'):
                        x_dict[node_type] = hetero_graph[node_type].x
                
                edge_index_dict = {}
                for edge_type in hetero_graph.edge_types:
                    edge_index_dict[edge_type] = hetero_graph[edge_type].edge_index
                
                config = {'hidden_dim': 128, 'num_layers': 3, 'dropout': 0.1}
                model = create_energy_gnn_base(config)
                model.eval()
                
                # Test forward pass
                print("\nTesting GNN with real Neo4j data...")
                with torch.no_grad():
                    temporal_context = {
                        'season': torch.tensor([0]),
                        'is_weekend': torch.tensor([0]),
                        'hour': torch.tensor([14])
                    }
                    output = model(x_dict, edge_index_dict, temporal_context)
                
                print("✅ GNN works with real Neo4j data!")
                
                for key, value in output.items():
                    if isinstance(value, torch.Tensor):
                        print(f"  {key}: {value.shape}")
                
            except Exception as e:
                print(f"Failed with real data: {e}")
        
        # Clean up
        if kg2:
            kg2.close()
    
    # Step 3: Test with dummy data as fallback
    print("\n" + "="*60)
    print("Testing with dummy data as verification...")
    model, outputs = test_base_gnn_minimal()
    
    if outputs is not None:
        print("\n" + "="*60)
        print("SUMMARY")
        print("="*60)
        print("✅ Base GNN model is working correctly!")
        print("Check the output above to see if Neo4j connection worked.")
    else:
        print("\n❌ Base GNN has issues even with dummy data.")

if __name__ == "__main__":
    main()

Couldn't import dot_parser, loading of dot files will not be possible.

COMPREHENSIVE BASE GNN TEST

EXPLORING KG CONNECTOR


2025-08-20 21:09:46,402 - numexpr.utils - INFO - Note: NumExpr detected 20 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
2025-08-20 21:09:46,402 - numexpr.utils - INFO - NumExpr defaulting to 8 threads.
2025-08-20 21:09:46,724 - data.kg_connector - INFO - Connected to Neo4j at bolt://localhost:7687



1. Available methods in KGConnector:
  - aggregate_to_cable_group
  - close
  - driver
  - get_adjacency_clusters
  - get_building_time_series
  - get_buildings_by_cable_group
  - get_cluster_time_series
  - get_district_hierarchy
  - get_energy_states
  - get_grid_topology
  - get_retrofit_candidates
  - get_systems_installed
  - verify_connection

2. Let's check the signature of key methods:
  aggregate_to_cable_group(group_id: str) -> Dict[str, Any]
  close()
  get_adjacency_clusters(district_name: str, min_cluster_size: int = 3) -> List[Dict]
  get_building_time_series(building_ids: List[str], lookback_hours: int = 24, end_time: Optional[int] = None) -> Dict[str, numpy.ndarray]
  get_buildings_by_cable_group(group_id: str) -> List[Dict]
  get_cluster_time_series(cluster_id: str, lookback_hours: int = 24) -> Dict[str, Any]
  get_district_hierarchy(district_name: str) -> Dict[str, Any]
  get_energy_states(building_ids: List[str], time_range: Optional[Dict] = None) -> pandas.core.fra

2025-08-20 21:09:48,648 - data.kg_connector - INFO - Connected to Neo4j at bolt://localhost:7687
2025-08-20 21:09:48,648 - data.graph_constructor - INFO - Building graph for district None
2025-08-20 21:09:48,649 - data.graph_constructor - INFO - Temporal features: False, Lookback: 0 hours



1. Creating GraphConstructor...

2. Checking GraphConstructor methods:
  - build_hetero_graph
  - build_subgraph_for_task
  - edge_types
  - kg
  - node_mappings
  - node_types

3. Attempting to build graph...


2025-08-20 21:09:50,726 - data.kg_connector - INFO - Edge counts - B->CG: 0, CG->T: 0, T->S: 0, B->AC: 0
2025-08-20 21:09:50,727 - data.graph_constructor - INFO - No transformer_to_substation edges (substations might not exist)
2025-08-20 21:09:50,728 - data.graph_constructor - INFO - Graph built: {'building': 0, 'cable_group': 0, 'transformer': 44, 'substation': 0, 'adjacency_cluster': 0}



4. Graph built successfully!
  Node types: ['transformer']
  Edge types: []
  transformer: torch.Size([44, 3])

✅ Successfully built graph from Neo4j!


2025-08-20 21:09:50,931 - models.base_gnn - INFO - Initialized EnergyGNNBase with 3 layers
2025-08-20 21:09:50,932 - models.base_gnn - INFO - Created EnergyGNNBase with 467,124 parameters
2025-08-20 21:09:50,933 - models.base_gnn - INFO - Trainable parameters: 467,124
2025-08-20 21:09:50,948 - data.kg_connector - INFO - Neo4j connection closed
2025-08-20 21:09:50,968 - models.base_gnn - INFO - Initialized EnergyGNNBase with 3 layers
2025-08-20 21:09:50,970 - models.base_gnn - INFO - Created EnergyGNNBase with 467,124 parameters
2025-08-20 21:09:50,970 - models.base_gnn - INFO - Trainable parameters: 467,124



Testing GNN with real Neo4j data...
✅ GNN works with real Neo4j data!
  transformer: torch.Size([44, 128])

Testing with dummy data as verification...

TESTING BASE GNN WITH MINIMAL DATA

1. Creating dummy heterogeneous graph...
  Created graph with ['building', 'cable_group', 'transformer', 'adjacency_cluster']

2. Creating base GNN model...

3. Testing forward pass...
Error: mat1 and mat2 shapes cannot be multiplied (5x2 and 3x128)

❌ Base GNN has issues even with dummy data.


Traceback (most recent call last):
  File "D:\Drives\Temp\ipykernel_11784\3288622359.py", line 201, in test_base_gnn_minimal
    output_dict = model(x_dict, edge_index_dict, temporal_context)
  File "d:\New folder (2)\Anaconda\DDsaie\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "d:\Documents\daily\Qiuari\models\base_gnn.py", line 242, in forward
    h_dict['transformer'] = self.transformer_encoder(x_dict['transformer'])
  File "d:\New folder (2)\Anaconda\DDsaie\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "d:\Documents\daily\Qiuari\models\base_gnn.py", line 75, in forward
    return self.input_projection(x)
  File "d:\New folder (2)\Anaconda\DDsaie\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "d:\New folder (2)\Anaconda\DDsaie\lib\site-packages\torch\nn\modules\con

In [None]:
# debug_neo4j_data.py
"""
Debug why buildings aren't being found in Neo4j
"""

from neo4j import GraphDatabase

# Connect directly to Neo4j
driver = GraphDatabase.driver(
    "bolt://localhost:7687",
    auth=("neo4j", "aminasad")
)

with driver.session() as session:
    # Check what nodes exist
    result = session.run("""
        MATCH (n)
        RETURN labels(n)[0] as label, count(*) as count
        ORDER BY count DESC
        LIMIT 10
    """)
    
    print("Node types in Neo4j:")
    for record in result:
        print(f"  {record['label']}: {record['count']}")
    
    # Check relationships
    result = session.run("""
        MATCH ()-[r]->()
        RETURN type(r) as type, count(*) as count
        ORDER BY count DESC
        LIMIT 10
    """)
    
    print("\nRelationship types:")
    for record in result:
        print(f"  {record['type']}: {record['count']}")

driver.close()

Node types in Neo4j:
  EnergyState: 1114848
  DailyProfile: 42476
  CableSegment: 4455
  TimeSlot: 1848
  Building: 1517
  ConnectionPoint: 1517
  MonthlyProfile: 1517
  BatterySystem: 1485
  HeatPumpSystem: 1138
  SolarSystem: 986

Relationship types:
  DURING: 1114848
  FOR_BUILDING: 1019424
  PROFILE_FOR: 43993
  PART_OF: 4455
  CAN_INSTALL: 2389
  IN_ADJACENCY_CLUSTER: 2233
  HAS_CONNECTION_POINT: 1517
  ON_SEGMENT: 1517
  CONNECTED_TO: 1517
  SHOULD_ELECTRIFY: 1079


In [None]:
# check_missing_nodes.py
"""
Check if grid infrastructure nodes exist with different names
"""

from neo4j import GraphDatabase

driver = GraphDatabase.driver(
    "bolt://localhost:7687",
    auth=("neo4j", "aminasad")
)

with driver.session() as session:
    print("Searching for grid infrastructure nodes...")
    
    # Check various possible node labels
    possible_labels = [
        'CableGroup', 'Cable_Group', 'LVGroup', 'LV_Group',
        'Transformer', 'MVTransformer', 'MV_Transformer',
        'Substation', 'HVSubstation', 'HV_Substation',
        'AdjacencyCluster', 'Adjacency_Cluster', 'Cluster'
    ]
    
    for label in possible_labels:
        result = session.run(f"""
            MATCH (n:{label})
            RETURN count(n) as count
        """)
        count = result.single()['count']
        if count > 0:
            print(f"  Found {label}: {count} nodes")
    
    # Check if Buildings have cable group info as properties
    print("\\nChecking Building properties for cable group info...")
    result = session.run("""
        MATCH (b:Building)
        WHERE b.lv_group_id IS NOT NULL OR b.cable_group_id IS NOT NULL
        RETURN 
            CASE WHEN b.lv_group_id IS NOT NULL THEN 'lv_group_id' 
                 ELSE 'cable_group_id' END as property,
            count(*) as count
    """)
    for record in result:
        print(f"  Buildings with {record['property']}: {record['count']}")

driver.close()

Searching for grid infrastructure nodes...




  Found CableGroup: 209 nodes
  Found Transformer: 49 nodes
  Found Substation: 2 nodes




  Found AdjacencyCluster: 327 nodes
\nChecking Building properties for cable group info...




  Buildings with lv_group_id: 1517


In [None]:
# test_base_gnn_fixed.py
"""
Fixed test for base_gnn.py that directly queries Neo4j
"""

import sys
import torch
import numpy as np
from neo4j import GraphDatabase
from torch_geometric.data import HeteroData
import logging

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Add project paths
sys.path.append('.')
sys.path.append('./models')

from models.base_gnn import create_energy_gnn_base

def get_neo4j_data():
    """Directly get data from Neo4j"""
    print("\\n" + "="*60)
    print("FETCHING DATA FROM NEO4J")
    print("="*60)
    
    driver = GraphDatabase.driver(
        "bolt://localhost:7687",
        auth=("neo4j", "aminasad")
    )
    
    with driver.session() as session:
        # 1. Get Buildings
        print("\\n1. Fetching buildings...")
        result = session.run("""
            MATCH (b:Building)
            RETURN b.ogc_fid as id, b.area as area, b.height as height, 
                   b.roof_area as roof_area, b.has_solar as has_solar,
                   b.has_battery as has_battery, b.has_heat_pump as has_hp,
                   b.x as x, b.y as y, b.lv_group_id as lv_group,
                   b.energy_label_simple as energy_label,
                   b.avg_electricity_demand as avg_elec,
                   b.avg_heating_demand as avg_heat,
                   b.peak_demand_kw as peak_demand,
                   b.energy_intensity as energy_intensity,
                   b.building_function as function,
                   b.shared_wall_north as wall_n,
                   b.shared_wall_south as wall_s
            ORDER BY id
        """)
        
        buildings = []
        building_id_map = {}
        for i, record in enumerate(result):
            building_id_map[record['id']] = i
            # Create feature vector (17 dimensions)
            features = [
                float(record['area'] or 100),
                float(ord(record['energy_label'][0]) - ord('A') + 1) if record['energy_label'] else 4,  # A=1, B=2, etc
                1.0 if record['roof_area'] and record['roof_area'] > 100 else 0.0,  # Solar potential
                1.0,  # Electrify score placeholder
                2020 - 1970,  # Age placeholder
                float(record['roof_area'] or 50),
                float(record['height'] or 10),
                float(record['has_solar'] or 0),
                float(record['has_battery'] or 0),
                float(record['has_hp'] or 0),
                float((record['wall_n'] or 0) + (record['wall_s'] or 0)),  # Total shared walls
                float(record['x'] or 0),
                float(record['y'] or 0),
                float(record['avg_elec'] or 100),
                float(record['avg_heat'] or 50),
                float(record['peak_demand'] or 10),
                float(record['energy_intensity'] or 100)
            ]
            buildings.append(features)
        
        building_features = torch.tensor(buildings, dtype=torch.float32)
        print(f"  Loaded {len(buildings)} buildings with shape {building_features.shape}")
        
        # 2. Get Cable Groups
        print("\\n2. Fetching cable groups...")
        result = session.run("""
            MATCH (cg:CableGroup)
            OPTIONAL MATCH (b:Building)-[:CONNECTED_TO]->(cg)
            WITH cg, count(b) as building_count
            RETURN cg.group_id as id, 
                   building_count,
                   cg.baseline_peak_kw as peak_kw,
                   cg.baseline_diversity as diversity
            ORDER BY id
        """)
        
        cable_groups = []
        cg_id_map = {}
        for i, record in enumerate(result):
            cg_id_map[record['id']] = i
            features = [
                float(record['building_count'] or 0),
                float(record['peak_kw'] or 100),
                float(record['diversity'] or 1),
                1.0  # LV voltage level
            ]
            cable_groups.append(features)
        
        cg_features = torch.tensor(cable_groups, dtype=torch.float32) if cable_groups else torch.zeros((0, 4))
        print(f"  Loaded {len(cable_groups)} cable groups with shape {cg_features.shape}")
        
        # 3. Get Transformers
        print("\\n3. Fetching transformers...")
        result = session.run("""
            MATCH (t:Transformer)
            RETURN t.transformer_id as id, t.x as x, t.y as y
            ORDER BY id
        """)
        
        transformers = []
        transformer_id_map = {}
        for i, record in enumerate(result):
            transformer_id_map[record['id']] = i
            features = [
                float(record['x'] or 0),
                float(record['y'] or 0),
                250.0  # Default capacity placeholder
            ]
            transformers.append(features)
        
        transformer_features = torch.tensor(transformers, dtype=torch.float32) if transformers else torch.zeros((0, 3))
        print(f"  Loaded {len(transformers)} transformers with shape {transformer_features.shape}")
        
        # 4. Get Adjacency Clusters
        print("\\n4. Fetching adjacency clusters...")
        result = session.run("""
            MATCH (ac:AdjacencyCluster)
            RETURN ac.cluster_id as id, 
                   ac.member_count as member_count,
                   ac.avg_shared_walls as avg_walls,
                   ac.energy_sharing_potential as sharing_potential
            ORDER BY id
        """)
        
        clusters = []
        cluster_id_map = {}
        for i, record in enumerate(result):
            cluster_id_map[record['id']] = i
            features = [
                float(record['member_count'] or 5),
                float(record['avg_walls'] or 1),
                1.0,  # Cluster type encoded
                0.5 if record['sharing_potential'] == 'HIGH' else 0.3
            ]
            clusters.append(features)
        
        cluster_features = torch.tensor(clusters, dtype=torch.float32) if clusters else torch.zeros((0, 4))
        print(f"  Loaded {len(clusters)} clusters with shape {cluster_features.shape}")
        
        # 5. Get Edges
        print("\\n5. Fetching edges...")
        edges = {}
        
        # Building -> CableGroup
        result = session.run("""
            MATCH (b:Building)-[:CONNECTED_TO]->(cg:CableGroup)
            RETURN b.ogc_fid as building_id, cg.group_id as cg_id
        """)
        b_to_cg = []
        for record in result:
            if record['building_id'] in building_id_map and record['cg_id'] in cg_id_map:
                b_to_cg.append([building_id_map[record['building_id']], 
                               cg_id_map[record['cg_id']]])
        
        if b_to_cg:
            edges[('building', 'connected_to', 'cable_group')] = torch.tensor(b_to_cg, dtype=torch.long).t()
            print(f"  Building->CableGroup: {len(b_to_cg)} edges")
        
        # CableGroup -> Transformer
        result = session.run("""
            MATCH (cg:CableGroup)-[:CONNECTS_TO]->(t:Transformer)
            RETURN cg.group_id as cg_id, t.transformer_id as t_id
        """)
        cg_to_t = []
        for record in result:
            if record['cg_id'] in cg_id_map and record['t_id'] in transformer_id_map:
                cg_to_t.append([cg_id_map[record['cg_id']], 
                               transformer_id_map[record['t_id']]])
        
        if cg_to_t:
            edges[('cable_group', 'connects_to', 'transformer')] = torch.tensor(cg_to_t, dtype=torch.long).t()
            print(f"  CableGroup->Transformer: {len(cg_to_t)} edges")
        
        # Building -> AdjacencyCluster
        result = session.run("""
            MATCH (b:Building)-[:IN_ADJACENCY_CLUSTER]->(ac:AdjacencyCluster)
            RETURN b.ogc_fid as building_id, ac.cluster_id as cluster_id
        """)
        b_to_cluster = []
        for record in result:
            if record['building_id'] in building_id_map and record['cluster_id'] in cluster_id_map:
                b_to_cluster.append([building_id_map[record['building_id']], 
                                    cluster_id_map[record['cluster_id']]])
        
        if b_to_cluster:
            edges[('building', 'in_cluster', 'adjacency_cluster')] = torch.tensor(b_to_cluster, dtype=torch.long).t()
            print(f"  Building->Cluster: {len(b_to_cluster)} edges")
    
    driver.close()
    
    return {
        'building': building_features,
        'cable_group': cg_features,
        'transformer': transformer_features,
        'adjacency_cluster': cluster_features
    }, edges

def test_base_gnn_with_real_data():
    """Test base GNN with real Neo4j data"""
    print("\\n" + "="*60)
    print("TESTING BASE GNN WITH REAL NEO4J DATA")
    print("="*60)
    
    # Get data from Neo4j
    x_dict, edge_index_dict = get_neo4j_data()
    
    # Create temporal context
    temporal_context = {
        'season': torch.tensor([0]),  # Winter
        'is_weekend': torch.tensor([0]),  # Weekday
        'hour': torch.tensor([14])  # 2 PM
    }
    
    # Create and test model
    print("\\n" + "="*60)
    print("TESTING GNN MODEL")
    print("="*60)
    
    config = {
        'hidden_dim': 128,
        'num_layers': 3,
        'dropout': 0.1
    }
    
    model = create_energy_gnn_base(config)
    model.eval()
    
    print("\\n1. Testing forward pass...")
    try:
        with torch.no_grad():
            output_dict = model(x_dict, edge_index_dict, temporal_context)
        
        print("\\n2. Output shapes:")
        for key, value in output_dict.items():
            if isinstance(value, torch.Tensor):
                print(f"  {key}: {value.shape}")
        
        print("\\n3. Checking for NaN values:")
        for key, value in output_dict.items():
            if isinstance(value, torch.Tensor):
                has_nan = torch.isnan(value).any().item()
                status = "✗ HAS NaN!" if has_nan else "✓ No NaN"
                print(f"  {key}: {status}")
        
        print("\\n4. Value statistics:")
        for key, value in output_dict.items():
            if isinstance(value, torch.Tensor) and key != 'attention_weights':
                print(f"  {key}:")
                print(f"    Mean: {value.mean().item():.4f}")
                print(f"    Std:  {value.std().item():.4f}")
                print(f"    Min:  {value.min().item():.4f}")
                print(f"    Max:  {value.max().item():.4f}")
        
        print("\\n✅ BASE GNN WORKS WITH YOUR NEO4J DATA!")
        
        return model, output_dict
        
    except Exception as e:
        print(f"\\n❌ Error: {e}")
        import traceback
        traceback.print_exc()
        return None, None

if __name__ == "__main__":
    model, outputs = test_base_gnn_with_real_data()
    
    if outputs is not None:
        print("\\n" + "="*60)
        print("SUCCESS!")
        print("="*60)
        print("\\nNext steps:")
        print("1. The base GNN is working correctly")
        print("2. It processes all your node types properly")
        print("3. Ready to implement task heads for specific objectives")
    else:
        print("\\n" + "="*60)
        print("NEEDS DEBUGGING")
        print("="*60)

Couldn't import dot_parser, loading of dot files will not be possible.
TESTING BASE GNN WITH REAL NEO4J DATA
FETCHING DATA FROM NEO4J
\n1. Fetching buildings...




  Loaded 1517 buildings with shape torch.Size([1517, 17])
\n2. Fetching cable groups...
  Loaded 209 cable groups with shape torch.Size([209, 4])
\n3. Fetching transformers...
  Loaded 49 transformers with shape torch.Size([49, 3])
\n4. Fetching adjacency clusters...
  Loaded 327 clusters with shape torch.Size([327, 4])
\n5. Fetching edges...
  Building->CableGroup: 1517 edges
  CableGroup->Transformer: 301 edges
  Building->Cluster: 2233 edges
TESTING GNN MODEL


2025-08-20 23:24:48,365 - models.base_gnn - INFO - Initialized EnergyGNNBase with 3 layers
2025-08-20 23:24:48,366 - models.base_gnn - INFO - Created EnergyGNNBase with 417,780 parameters
2025-08-20 23:24:48,367 - models.base_gnn - INFO - Trainable parameters: 417,780


\n1. Testing forward pass...
\n2. Output shapes:
  building: torch.Size([1517, 128])
  cable_group: torch.Size([209, 128])
  transformer: torch.Size([49, 128])
  adjacency_cluster: torch.Size([327, 64])
\n3. Checking for NaN values:
  building: ✓ No NaN
  cable_group: ✓ No NaN
  transformer: ✓ No NaN
  adjacency_cluster: ✓ No NaN
\n4. Value statistics:
  building:
    Mean: 0.0554
    Std:  0.5013
    Min:  -1.6188
    Max:  1.0281
  cable_group:
    Mean: -0.0467
    Std:  0.5161
    Min:  -1.6924
    Max:  1.3478
  transformer:
    Mean: -0.0769
    Std:  0.5969
    Min:  -1.7125
    Max:  1.9594
  adjacency_cluster:
    Mean: 0.0391
    Std:  0.4236
    Min:  -1.0102
    Max:  1.3087
\n✅ BASE GNN WORKS WITH YOUR NEO4J DATA!
SUCCESS!
\nNext steps:
1. The base GNN is working correctly
2. It processes all your node types properly
3. Ready to implement task heads for specific objectives


## attention_layers 

In [None]:
import torch
import torch.nn as nn
from models.attention_layers import EnergyComplementarityAttention

def test_attention_layer():
    print("Testing attention layer...")
    
    config = {
        'hidden_dim': 128,
        'attention_heads': 8,
        'dropout': 0.1
    }
    
    # Create attention module
    attention = EnergyComplementarityAttention(config)
    attention.eval()
    
    # Create dummy embeddings
    num_buildings = 10
    embed_dim = 128
    
    dummy_embeddings = {
        'building': torch.randn(num_buildings, embed_dim),
        'cable_group': torch.randn(5, embed_dim),
        'transformer': torch.randn(2, embed_dim)
    }
    
    # Forward pass
    with torch.no_grad():
        output = attention(
            dummy_embeddings,
            {},  # Empty edge dict
            temporal_features=None,
            return_attention=True
        )
    
    print("✅ Attention layer works!")
    print(f"Enhanced embeddings shape: {output['embeddings']['building'].shape}")
    print(f"Complementarity matrix shape: {output['complementarity_matrix'].shape}")
    
    return output

# Run test
result = test_attention_layer()

2025-08-20 23:25:58,059 - models.attention_layers - INFO - Initialized EnergyComplementarityAttention


Testing attention layer...
✅ Attention layer works!
Enhanced embeddings shape: torch.Size([10, 128])
Complementarity matrix shape: torch.Size([10, 10])


In [None]:
# Quick fix - update your base_gnn initialization with correct dimensions
import torch
import torch.nn as nn
import yaml
from data.kg_connector import KGConnector
from data.graph_constructor import GraphConstructor
from data.feature_processor import FeatureProcessor
from models.base_gnn import EnergyGNNBase
from models.attention_layers import create_attention_module

NEO4J_URI = "bolt://localhost:7687"
NEO4J_USER = "neo4j"
NEO4J_PASSWORD = "aminasad"

def test_with_correct_dimensions():
    """Test with correct feature dimensions from Neo4j"""
    
    print("="*60)
    print("TESTING WITH CORRECT FEATURE DIMENSIONS")
    print("="*60)
    
    # 1. First, get actual dimensions from Neo4j
    print("\\n1. Getting actual feature dimensions from Neo4j...")
    kg = KGConnector(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)
    graph_builder = GraphConstructor(kg)
    
    graph_data = graph_builder.build_hetero_graph(
        district_name="Buitenveldert-Oost",
        include_energy_sharing=True,
        include_temporal=False
    )
    
    # Print actual dimensions
    actual_dims = {}
    for node_type, features in graph_data.x_dict.items():
        if features is not None:
            actual_dims[node_type] = features.shape[1]
            print(f"   {node_type}: {features.shape} → {features.shape[1]} features")
    
    # 2. Create config with CORRECT dimensions
    config = {
        'hidden_dim': 128,
        'num_layers': 3,
        'dropout': 0.1,
        'attention_heads': 8
    }
    
    # 3. Create base GNN with CORRECT feature dimensions
    print("\\n2. Creating base GNN with correct dimensions...")
    base_gnn = EnergyGNNBase(
        config=config,
        node_features=actual_dims  # Use actual dimensions from Neo4j!
    )
    
    attention_module = create_attention_module(config)
    
    base_gnn.eval()
    attention_module.eval()
    print("   ✓ Models created with correct dimensions")
    
    # 4. Process features
    print("\\n3. Processing features...")
    feature_processor = FeatureProcessor()
    feature_processor.process_graph_features(graph_data, fit=True)
    
    # 5. Run base GNN
    print("\\n4. Running base GNN...")
    with torch.no_grad():
        temporal_context = {
            'season': torch.tensor([0]),
            'is_weekend': torch.tensor([0]),
            'hour': torch.tensor([14])
        }
        
        base_embeddings = base_gnn(
            graph_data.x_dict,
            graph_data.edge_index_dict,
            temporal_context
        )
    
    print("   Base embeddings created:")
    for node_type, emb in base_embeddings.items():
        if emb is not None:
            print(f"   - {node_type}: {emb.shape}")
    
    # 6. Run attention
    print("\\n5. Running attention layer...")
    with torch.no_grad():
        attention_output = attention_module(
            base_embeddings,
            graph_data.edge_index_dict,
            temporal_features=None,
            return_attention=True
        )
    
    comp_matrix = attention_output['complementarity_matrix']
    print(f"\\n   Complementarity matrix: {comp_matrix.shape}")
    print(f"   Stats: min={comp_matrix.min():.3f}, max={comp_matrix.max():.3f}, mean={comp_matrix.mean():.3f}")
    
    # Find top pairs
    print("\\n6. Top complementary pairs:")
    mask = torch.eye(comp_matrix.shape[0], dtype=torch.bool)
    comp_matrix_masked = comp_matrix.clone()
    comp_matrix_masked[mask] = -1
    
    values, indices = torch.topk(comp_matrix_masked.flatten(), 10)
    seen = set()
    for val, idx in zip(values, indices):
        i, j = idx // comp_matrix.shape[0], idx % comp_matrix.shape[0]
        if i != j:
            pair = tuple(sorted([i.item(), j.item()]))
            if pair not in seen:
                seen.add(pair)
                print(f"   Building {pair[0]} <-> Building {pair[1]}: {val:.3f}")
                if len(seen) >= 5:
                    break
    
    print("\\n✅ SUCCESS! Pipeline works with correct dimensions!")
    kg.close()
    
    return attention_output, comp_matrix

# Run the corrected test
try:
    output, comp_matrix = test_with_correct_dimensions()
except Exception as e:
    print(f"\\nError: {e}")
    import traceback
    traceback.print_exc()

2025-08-20 23:26:13,010 - data.kg_connector - INFO - Connected to Neo4j at bolt://localhost:7687
2025-08-20 23:26:13,010 - data.graph_constructor - INFO - Building graph for district Buitenveldert-Oost
2025-08-20 23:26:13,011 - data.graph_constructor - INFO - Temporal features: False, Lookback: 24 hours


TESTING WITH CORRECT FEATURE DIMENSIONS
\n1. Getting actual feature dimensions from Neo4j...


2025-08-20 23:26:15,228 - data.kg_connector - INFO - Edge counts - B->CG: 335, CG->T: 13, T->S: 0, B->AC: 346
2025-08-20 23:26:15,369 - data.graph_constructor - INFO - No transformer_to_substation edges (substations might not exist)
2025-08-20 23:26:15,370 - data.graph_constructor - INFO - Graph built: {'building': 335, 'cable_group': 21, 'transformer': 44, 'substation': 0, 'adjacency_cluster': 95}
2025-08-20 23:26:15,386 - models.base_gnn - INFO - Initialized EnergyGNNBase with 3 layers
2025-08-20 23:26:15,388 - models.attention_layers - INFO - Initialized EnergyComplementarityAttention
2025-08-20 23:26:15,390 - data.feature_processor - INFO - Processing graph features
2025-08-20 23:26:15,392 - data.feature_processor - INFO - Added engineered features for building: shape torch.Size([335, 7])
2025-08-20 23:26:15,393 - data.feature_processor - INFO - Added engineered features for cable_group: shape torch.Size([21, 4])
2025-08-20 23:26:15,395 - data.feature_processor - INFO - Added engin

   building: torch.Size([335, 17]) → 17 features
   cable_group: torch.Size([21, 12]) → 12 features
   transformer: torch.Size([44, 3]) → 3 features
   adjacency_cluster: torch.Size([95, 11]) → 11 features
\n2. Creating base GNN with correct dimensions...
   ✓ Models created with correct dimensions
\n3. Processing features...
\n4. Running base GNN...
   Base embeddings created:
   - building: torch.Size([335, 128])
   - cable_group: torch.Size([21, 128])
   - transformer: torch.Size([44, 128])
   - adjacency_cluster: torch.Size([95, 64])
\n5. Running attention layer...
\n   Complementarity matrix: torch.Size([335, 335])
   Stats: min=0.475, max=0.500, mean=0.488
\n6. Top complementary pairs:
   Building 95 <-> Building 116: 0.500
   Building 95 <-> Building 118: 0.500
   Building 95 <-> Building 117: 0.500
   Building 116 <-> Building 118: 0.500
   Building 116 <-> Building 117: 0.500
\n✅ SUCCESS! Pipeline works with correct dimensions!


In [None]:
# simple_attention_test.py
"""
Simpler test focusing just on attention layer functionality
"""

import torch
from models.attention_layers import EnergyComplementarityAttention

def simple_test():
    print("Simple Attention Test with Mock Neo4j-like Data")
    print("="*50)
    
    # Configuration matching your setup
    config = {
        'hidden_dim': 128,
        'attention_heads': 8,
        'dropout': 0.1
    }
    
    # Create attention module
    attention = EnergyComplementarityAttention(config)
    attention.eval()
    
    # Simulate Neo4j data dimensions
    # Based on your actual data:
    num_buildings = 335  # From your Neo4j
    num_cable_groups = 21
    num_transformers = 44
    num_clusters = 95
    embed_dim = 128
    
    # Create mock embeddings (as if from base_gnn)
    mock_embeddings = {
        'building': torch.randn(num_buildings, embed_dim),
        'cable_group': torch.randn(num_cable_groups, embed_dim),
        'transformer': torch.randn(num_transformers, embed_dim),
        'adjacency_cluster': torch.randn(num_clusters, 64)
    }
    
    # Mock edge indices
    mock_edges = {
        ('building', 'connected_to', 'cable_group'): torch.randint(0, min(num_buildings, num_cable_groups), (2, 335)),
        ('cable_group', 'connects_to', 'transformer'): torch.randint(0, min(num_cable_groups, num_transformers), (2, 13))
    }
    
    print(f"Testing with:")
    print(f"  Buildings: {num_buildings}")
    print(f"  Cable Groups: {num_cable_groups}")
    print(f"  Transformers: {num_transformers}")
    print(f"  Clusters: {num_clusters}")
    
    # Run attention
    with torch.no_grad():
        output = attention(
            mock_embeddings,
            mock_edges,
            temporal_features=None,
            return_attention=True
        )
    
    # Check outputs
    comp_matrix = output['complementarity_matrix']
    print(f"\\n✅ Complementarity matrix shape: {comp_matrix.shape}")
    print(f"   Expected: [{num_buildings}, {num_buildings}]")
    
    # Find some high complementarity pairs
    top_k = 5
    values, indices = torch.topk(comp_matrix.flatten(), top_k)
    
    print(f"\\nTop {top_k} complementarity scores:")
    for i, (val, idx) in enumerate(zip(values, indices)):
        row = idx // num_buildings
        col = idx % num_buildings
        print(f"  {i+1}. Building {row} <-> Building {col}: {val:.3f}")
    
    print("\\n✅ Attention layer successfully processed Neo4j-sized data!")
    
    return output

# Run simple test
if __name__ == "__main__":
    result = simple_test()

2025-08-20 23:26:19,969 - models.attention_layers - INFO - Initialized EnergyComplementarityAttention


Simple Attention Test with Mock Neo4j-like Data
Testing with:
  Buildings: 335
  Cable Groups: 21
  Transformers: 44
  Clusters: 95
\n✅ Complementarity matrix shape: torch.Size([335, 335])
   Expected: [335, 335]
\nTop 5 complementarity scores:
  1. Building 155 <-> Building 218: 0.531
  2. Building 218 <-> Building 155: 0.531
  3. Building 286 <-> Building 331: 0.524
  4. Building 331 <-> Building 286: 0.524
  5. Building 293 <-> Building 331: 0.523
\n✅ Attention layer successfully processed Neo4j-sized data!


## Temporal layers

In [None]:
# test_temporal_layers.py
"""
Test temporal layers with real Neo4j data
Tests the complete pipeline: base_gnn -> attention -> temporal
"""

import torch
import torch.nn as nn
import numpy as np
from neo4j import GraphDatabase
import logging
from datetime import datetime, timedelta

# Import your modules
from models.base_gnn import EnergyGNNBase, create_energy_gnn_base
from models.attention_layers import EnergyComplementarityAttention
from models.temporal_layers import TemporalProcessor

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class Neo4jTemporalDataFetcher:
    """Fetch temporal data from Neo4j"""
    
    def __init__(self, uri, user, password):
        self.driver = GraphDatabase.driver(uri, auth=(user, password))
        logger.info(f"Connected to Neo4j for temporal data fetching")
    
    def fetch_consumption_history(self, limit_buildings=100):
        """Fetch 24-hour consumption history from EnergyState nodes"""
        
        with self.driver.session() as session:
            # First, get building IDs that have energy states
            result = session.run("""
                MATCH (b:Building)-[:HAS_STATE_AT]->(es:EnergyState)
                WITH b.ogc_fid as building_id, COUNT(es) as state_count
                WHERE state_count >= 24
                RETURN building_id
                ORDER BY building_id
                LIMIT $limit
            """, limit=limit_buildings)
            
            building_ids = [record['building_id'] for record in result]
            logger.info(f"Found {len(building_ids)} buildings with sufficient energy states")
            
            if not building_ids:
                logger.warning("No buildings with energy states found, creating synthetic data")
                return self._create_synthetic_consumption(limit_buildings)
            
            # Fetch energy states for these buildings
            consumption_data = {}
            
            for building_id in building_ids[:10]:  # Test with first 10 buildings
                result = session.run("""
                    MATCH (b:Building {ogc_fid: $building_id})-[:HAS_STATE_AT]->(es:EnergyState)-[:DURING]->(ts:TimeSlot)
                    RETURN es.electricity_demand_kw as elec_demand,
                           es.heating_demand_kw as heat_demand,
                           es.cooling_demand_kw as cool_demand,
                           es.solar_generation_kw as solar_gen,
                           es.net_demand_kw as net_demand,
                           es.is_surplus as is_surplus,
                           ts.hour_of_day as hour,
                           ts.is_weekend as is_weekend,
                           ts.season as season
                    ORDER BY ts.timestamp DESC
                    LIMIT 24
                """, building_id=building_id)
                
                records = list(result)
                if records:
                    consumption_data[building_id] = records
            
            return self._format_consumption_data(consumption_data, building_ids)
    
    def _format_consumption_data(self, consumption_data, building_ids):
        """Format consumption data into tensor format"""
        
        if not consumption_data:
            return self._create_synthetic_consumption(len(building_ids))
        
        # Create tensor [num_buildings, 24_hours, 8_features]
        num_buildings = len(building_ids)
        consumption_tensor = torch.zeros(1, num_buildings, 24, 8)
        
        for idx, building_id in enumerate(building_ids):
            if building_id in consumption_data:
                records = consumption_data[building_id]
                for t, record in enumerate(records[:24]):
                    consumption_tensor[0, idx, t, 0] = record.get('elec_demand', 0) or 0
                    consumption_tensor[0, idx, t, 1] = record.get('heat_demand', 0) or 0
                    consumption_tensor[0, idx, t, 2] = record.get('cool_demand', 0) or 0
                    consumption_tensor[0, idx, t, 3] = record.get('solar_gen', 0) or 0
                    consumption_tensor[0, idx, t, 4] = record.get('net_demand', 0) or 0
                    consumption_tensor[0, idx, t, 5] = 1.0 if record.get('is_surplus') else 0.0
                    consumption_tensor[0, idx, t, 6] = record.get('hour', t) or t
                    consumption_tensor[0, idx, t, 7] = 1.0 if record.get('is_weekend') else 0.0
        
        # Get season from last record
        season = 0  # Default winter
        is_weekend = False
        
        if consumption_data:
            first_building_data = next(iter(consumption_data.values()))
            if first_building_data:
                season_str = first_building_data[0].get('season', 'winter')
                season_map = {'winter': 0, 'spring': 1, 'summer': 2, 'autumn': 3, 'fall': 3}
                season = season_map.get(season_str, 0)
                is_weekend = bool(first_building_data[0].get('is_weekend', False))
        
        return consumption_tensor, season, is_weekend
    
    def _create_synthetic_consumption(self, num_buildings):
        """Create synthetic consumption patterns for testing"""
        logger.info("Creating synthetic consumption patterns for testing")
        
        # Create different consumption patterns
        consumption_tensor = torch.zeros(1, num_buildings, 24, 8)
        
        for i in range(num_buildings):
            # Create different patterns based on building index
            pattern_type = i % 4
            
            for h in range(24):
                if pattern_type == 0:  # Residential pattern
                    base = 5.0
                    if 6 <= h <= 9:  # Morning peak
                        demand = base * 2.5
                    elif 18 <= h <= 22:  # Evening peak
                        demand = base * 2.0
                    else:
                        demand = base
                
                elif pattern_type == 1:  # Office pattern
                    base = 3.0
                    if 9 <= h <= 17:  # Business hours
                        demand = base * 4.0
                    else:
                        demand = base
                
                elif pattern_type == 2:  # Retail pattern
                    base = 4.0
                    if 10 <= h <= 20:  # Shopping hours
                        demand = base * 3.0
                    else:
                        demand = base
                
                else:  # Mixed pattern
                    base = 4.0
                    demand = base * (1 + np.sin(h * np.pi / 12))
                
                # Add some noise
                demand += np.random.normal(0, 0.5)
                
                consumption_tensor[0, i, h, 0] = max(0, demand)  # Electricity
                consumption_tensor[0, i, h, 1] = max(0, demand * 0.3)  # Heating
                consumption_tensor[0, i, h, 4] = max(0, demand * 1.2)  # Net demand
                consumption_tensor[0, i, h, 6] = h  # Hour
        
        season = 0  # Winter
        is_weekend = False
        
        return consumption_tensor, season, is_weekend
    
    def close(self):
        self.driver.close()


def test_temporal_layers():
    """Main test function"""
    
    print("\n" + "="*60)
    print("TESTING TEMPORAL LAYERS WITH NEO4J DATA")
    print("="*60 + "\n")
    
    # Configuration
    config = {
        'hidden_dim': 128,
        'num_layers': 3,
        'dropout': 0.1,
        'attention_heads': 8
    }
    
    # Neo4j connection
    neo4j_uri = "bolt://localhost:7687"
    neo4j_user = "neo4j"
    neo4j_password = "aminasad"  # Update with your password
    
    try:
        # 1. Fetch temporal data from Neo4j
        print("1. Fetching temporal data from Neo4j...")
        fetcher = Neo4jTemporalDataFetcher(neo4j_uri, neo4j_user, neo4j_password)
        consumption_history, season, is_weekend = fetcher.fetch_consumption_history(limit_buildings=100)
        print(f"   Consumption history shape: {consumption_history.shape}")
        print(f"   Season: {season}, Weekend: {is_weekend}")
        
        # 2. Create dummy base embeddings (simulating base_gnn output)
        print("\n2. Creating base embeddings...")
        num_buildings = consumption_history.shape[1]
        building_embeddings = torch.randn(num_buildings, 128)
        cable_group_embeddings = torch.randn(20, 128)  # Assuming 20 cable groups
        transformer_embeddings = torch.randn(10, 128)  # Assuming 10 transformers
        cluster_embeddings = torch.randn(30, 64)  # Assuming 30 clusters
        
        embeddings_dict = {
            'building': building_embeddings,
            'cable_group': cable_group_embeddings,
            'transformer': transformer_embeddings,
            'adjacency_cluster': cluster_embeddings
        }
        print(f"   Building embeddings: {building_embeddings.shape}")
        
        # 3. Create attention layer and process
        print("\n3. Running attention layer...")
        attention_layer = EnergyComplementarityAttention(config)
        attention_layer.eval()
        
        # Create dummy edge indices
        edge_index_dict = {
            ('building', 'connected_to', 'cable_group'): torch.randint(0, min(num_buildings, 20), (2, num_buildings)),
            ('cable_group', 'connects_to', 'transformer'): torch.randint(0, 10, (2, 20)),
        }
        
        with torch.no_grad():
            attention_output = attention_layer(embeddings_dict, edge_index_dict, return_attention=False)
        
        enhanced_embeddings = attention_output['embeddings']
        print(f"   Enhanced building embeddings: {enhanced_embeddings['building'].shape}")
        print(f"   Complementarity matrix: {attention_output['complementarity_matrix'].shape}")
        
        # 4. Create and test temporal processor
        print("\n4. Testing temporal processor...")
        temporal_processor = TemporalProcessor(config)
        temporal_processor.eval()
        
        temporal_data = {
            'consumption_history': consumption_history,
            'season': torch.tensor(season),
            'is_weekend': torch.tensor(is_weekend)
        }
        
        with torch.no_grad():
            # Test single hour processing
            print("\n   a) Testing single hour (hour 14)...")
            temporal_output = temporal_processor(
                enhanced_embeddings,
                temporal_data=temporal_data,
                current_hour=14,
                return_all_hours=False
            )
            
            print(f"      Final embeddings: {temporal_output['embeddings']['building'].shape}")
            print(f"      Consumption predictions: {temporal_output['consumption_predictions'].shape}")
            print(f"      Temporal complementarity: {temporal_output['temporal_complementarity'].shape}")
            print(f"      Peak indicators: {temporal_output['peak_indicators'].shape}")
            
            # Test all hours processing
            print("\n   b) Testing all 24 hours...")
            temporal_output_all = temporal_processor(
                enhanced_embeddings,
                temporal_data=temporal_data,
                current_hour=None,
                return_all_hours=True
            )
            
            print(f"      Hourly embeddings: {temporal_output_all['hourly_embeddings'].shape}")
            
        # 5. Analyze outputs
        print("\n5. Analyzing outputs...")
        
        # Check for NaN values
        for key, value in temporal_output.items():
            if isinstance(value, torch.Tensor):
                has_nan = torch.isnan(value).any().item()
                print(f"   {key}: {'⚠️ HAS NaN' if has_nan else '✓ No NaN'}")
        
        # Analyze complementarity patterns
        print("\n6. Complementarity Analysis...")
        comp_matrix = temporal_output['temporal_complementarity']
        if comp_matrix.dim() == 3:
            comp_matrix = comp_matrix[0]  # Remove batch dimension
        
        # Find most complementary pairs
        comp_values = comp_matrix.flatten()
        top_k = 5
        top_indices = torch.topk(comp_values.abs(), top_k).indices
        
        print(f"   Top {top_k} complementary pairs:")
        for idx in top_indices:
            i = idx // num_buildings
            j = idx % num_buildings
            if i < j:  # Avoid duplicates
                score = comp_matrix[i, j].item()
                print(f"      Building {i} <-> Building {j}: {score:.3f}")
        
        # Analyze peak hours
        print("\n7. Peak Hour Analysis...")
        peak_indicators = temporal_output['peak_indicators']
        if peak_indicators.dim() == 3:
            peak_indicators = peak_indicators[0]
        
        # Find buildings with most peak hours
        peak_counts = peak_indicators.sum(dim=1)
        top_peak_buildings = torch.topk(peak_counts, min(5, num_buildings)).indices
        
        print("   Buildings with most peak hours:")
        for idx in top_peak_buildings:
            count = peak_counts[idx].item()
            peak_hours = torch.where(peak_indicators[idx] > 0.5)[0].tolist()
            print(f"      Building {idx}: {int(count)} peak hours at {peak_hours}")
        
        # Consumption prediction analysis
        print("\n8. Consumption Prediction Sample...")
        predictions = temporal_output['consumption_predictions']
        if predictions.dim() == 3:
            predictions = predictions[0]
        
        # Show predictions for first 3 buildings
        for i in range(min(3, num_buildings)):
            pred_values = predictions[i, :6].tolist()  # First 6 hours
            print(f"   Building {i} next 6 hours: {[f'{v:.1f}' for v in pred_values]}")
        
        print("\n" + "="*60)
        print("✅ TEMPORAL LAYERS TEST SUCCESSFUL!")
        print("="*60)
        
        print("\nKey Insights:")
        print("1. Temporal processor successfully processes consumption history")
        print("2. Creates hour-specific embeddings for dynamic clustering")
        print("3. Identifies complementary consumption patterns")
        print("4. Predicts future consumption for planning")
        print("5. Identifies peak hours for load management")
        
        fetcher.close()
        
    except Exception as e:
        logger.error(f"Test failed: {str(e)}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    test_temporal_layers()

Couldn't import dot_parser, loading of dot files will not be possible.


2025-08-20 23:56:27,840 - __main__ - INFO - Connected to Neo4j for temporal data fetching



TESTING TEMPORAL LAYERS WITH NEO4J DATA

1. Fetching temporal data from Neo4j...


2025-08-20 23:56:29,900 - __main__ - INFO - Found 0 buildings with sufficient energy states
2025-08-20 23:56:29,901 - __main__ - INFO - Creating synthetic consumption patterns for testing
2025-08-20 23:56:29,967 - models.attention_layers - INFO - Initialized EnergyComplementarityAttention
2025-08-20 23:56:30,071 - models.temporal_layers - INFO - Initialized TemporalProcessor with all components


   Consumption history shape: torch.Size([1, 100, 24, 8])
   Season: 0, Weekend: False

2. Creating base embeddings...
   Building embeddings: torch.Size([100, 128])

3. Running attention layer...
   Enhanced building embeddings: torch.Size([100, 128])
   Complementarity matrix: torch.Size([100, 100])

4. Testing temporal processor...

   a) Testing single hour (hour 14)...
      Final embeddings: torch.Size([100, 128])
      Consumption predictions: torch.Size([100, 24])
      Temporal complementarity: torch.Size([100, 100])
      Peak indicators: torch.Size([100, 24])

   b) Testing all 24 hours...
      Hourly embeddings: torch.Size([100, 24, 128])

5. Analyzing outputs...
   temporal_encoding: ✓ No NaN
   consumption_predictions: ✓ No NaN
   temporal_complementarity: ✓ No NaN
   peak_indicators: ✓ No NaN
   peak_probabilities: ✓ No NaN

6. Complementarity Analysis...
   Top 5 complementary pairs:
      Building 10 <-> Building 40: 0.157
      Building 40 <-> Building 96: 0.155

7. 

In [None]:
# test_temporal_fixed.py
import torch
from models.temporal_layers import TemporalProcessor  # Use fixed version

# Configuration
config = {
    'hidden_dim': 128,
    'num_layers': 3,
    'dropout': 0.1,
    'attention_heads': 8
}

# Create processor
processor = TemporalProcessor(config)
processor.eval()

# Test 1: Without temporal data (uses random)
print("Test 1: Without temporal data...")
embeddings_dict = {
    'building': torch.randn(100, 128),
    'cable_group': torch.randn(20, 128),
    'transformer': torch.randn(10, 128),
    'adjacency_cluster': torch.randn(30, 64)
}

with torch.no_grad():
    output = processor(embeddings_dict, temporal_data=None, current_hour=14)
    print(f"✅ Output shape: {output['embeddings']['building'].shape}")

# Test 2: With synthetic temporal data
print("\nTest 2: With temporal data...")
temporal_data = {
    'consumption_history': torch.randn(1, 100, 24, 8),  # [batch, buildings, hours, features]
    'season': torch.tensor(0),  # Winter
    'is_weekend': torch.tensor(False)
}

with torch.no_grad():
    output = processor(embeddings_dict, temporal_data=temporal_data, current_hour=14)
    print(f"✅ Output shape: {output['embeddings']['building'].shape}")
    print(f"✅ Predictions shape: {output['consumption_predictions'].shape}")
    print(f"✅ Complementarity shape: {output['temporal_complementarity'].shape}")

print("\n🎉 Temporal processor is working correctly!")

2025-08-20 23:57:00,743 - models.temporal_layers - INFO - Initialized TemporalProcessor with all components


Test 1: Without temporal data...
✅ Output shape: torch.Size([100, 128])

Test 2: With temporal data...
✅ Output shape: torch.Size([100, 128])
✅ Predictions shape: torch.Size([100, 24])
✅ Complementarity shape: torch.Size([100, 100])

🎉 Temporal processor is working correctly!


## physical_layers

❌ Constraints We CANNOT Enforce (Missing Data):
What We DON'T Have:

❌ Actual transformer capacity numbers (you said no 250kVA)
❌ Voltage levels or limits
❌ Line impedances or resistances
❌ Power factors
❌ Detailed loss calculations
❌ Reactive power

In [None]:
# models/physics_layers.py
"""
Physics constraint layers for energy system
Enforces energy balance, LV boundaries, and distance-based losses
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict, Tuple, Optional, List
import logging

logger = logging.getLogger(__name__)


class LVGroupBoundaryEnforcer(nn.Module):
    """Ensures energy sharing only within same LV group"""
    
    def __init__(self):
        super().__init__()
        self.violation_penalty_weight = nn.Parameter(torch.tensor(10.0))
        
    def forward(self, 
                sharing_matrix: torch.Tensor,
                lv_group_ids: torch.Tensor,
                valid_lv_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply LV group constraints to sharing matrix
        
        Args:
            sharing_matrix: [batch, N, N] or [batch, N, N, T] proposed sharing
            lv_group_ids: [batch, N] or [N] LV group assignment for each building
            valid_lv_mask: [batch, N] or [N] mask for buildings in valid LV groups
            
        Returns:
            masked_sharing: Sharing matrix with invalid connections zeroed
            boundary_penalty: Penalty for attempted cross-boundary sharing
        """
        # Handle different input dimensions
        if lv_group_ids.dim() == 1:
            lv_group_ids = lv_group_ids.unsqueeze(0)
        
        batch_size, num_buildings = lv_group_ids.shape
        device = sharing_matrix.device
        
        # Create mask for same LV group
        lv_i = lv_group_ids.unsqueeze(2)  # [B, N, 1]
        lv_j = lv_group_ids.unsqueeze(1)  # [B, 1, N]
        same_lv_mask = (lv_i == lv_j).float()  # [B, N, N]
        
        # Apply valid LV mask if provided (skip orphaned groups)
        if valid_lv_mask is not None:
            if valid_lv_mask.dim() == 1:
                valid_lv_mask = valid_lv_mask.unsqueeze(0)
            valid_i = valid_lv_mask.unsqueeze(2)  # [B, N, 1]
            valid_j = valid_lv_mask.unsqueeze(1)  # [B, 1, N]
            valid_pair_mask = valid_i * valid_j  # Both buildings must be valid
            same_lv_mask = same_lv_mask * valid_pair_mask
        
        # Calculate penalty for violations (before masking)
        if sharing_matrix.dim() == 4:  # Has time dimension
            same_lv_mask = same_lv_mask.unsqueeze(-1)  # [B, N, N, 1]
            violations = sharing_matrix * (1 - same_lv_mask)
        else:
            violations = sharing_matrix * (1 - same_lv_mask)
        
        # Soft penalty (squared violations)
        boundary_penalty = (violations ** 2).sum() / (num_buildings ** 2)
        boundary_penalty = boundary_penalty * self.violation_penalty_weight
        
        # Apply mask to zero out invalid connections
        masked_sharing = sharing_matrix * same_lv_mask
        
        return masked_sharing, boundary_penalty


class DistanceBasedLossCalculator(nn.Module):
    """Calculates energy losses based on distance between buildings"""
    
    def __init__(self, 
                 base_efficiency: float = 0.98,
                 loss_per_meter: float = 0.0001):
        super().__init__()
        self.base_efficiency = base_efficiency
        self.loss_per_meter = loss_per_meter
        self.max_distance_penalty = nn.Parameter(torch.tensor(1000.0))
        
    def forward(self,
                sharing_matrix: torch.Tensor,
                positions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply distance-based losses to sharing
        
        Args:
            sharing_matrix: [batch, N, N] or [batch, N, N, T] energy sharing
            positions: [batch, N, 2] or [N, 2] building x,y coordinates
            
        Returns:
            loss_adjusted_sharing: Sharing with efficiency losses applied
            distance_loss: Total loss due to distance
        """
        if positions.dim() == 2:
            positions = positions.unsqueeze(0)
        
        batch_size, num_buildings, _ = positions.shape
        device = positions.device
        
        # Calculate pairwise distances
        pos_i = positions.unsqueeze(2)  # [B, N, 1, 2]
        pos_j = positions.unsqueeze(1)  # [B, 1, N, 2]
        distances = torch.norm(pos_i - pos_j, dim=-1)  # [B, N, N]
        
        # Calculate efficiency based on distance
        # Efficiency decreases with distance
        efficiency = torch.clamp(
            self.base_efficiency - self.loss_per_meter * distances,
            min=0.85,  # Minimum 85% efficiency
            max=1.0    # Maximum 100% efficiency
        )
        
        # Apply efficiency to sharing
        if sharing_matrix.dim() == 4:  # Has time dimension
            efficiency = efficiency.unsqueeze(-1)  # [B, N, N, 1]
        
        loss_adjusted_sharing = sharing_matrix * efficiency
        
        # Calculate total energy lost
        energy_lost = sharing_matrix - loss_adjusted_sharing
        distance_loss = energy_lost.abs().sum() / (num_buildings ** 2)
        
        # Add penalty for very long distance sharing
        long_distance_mask = (distances > self.max_distance_penalty).float()
        if sharing_matrix.dim() == 4:
            long_distance_mask = long_distance_mask.unsqueeze(-1)
        long_distance_penalty = (sharing_matrix * long_distance_mask).abs().sum()
        
        total_loss = distance_loss + 0.1 * long_distance_penalty
        
        return loss_adjusted_sharing, total_loss


class EnergyBalanceChecker(nn.Module):
    """Ensures energy conservation within each LV group"""
    
    def __init__(self, tolerance: float = 0.05):
        super().__init__()
        self.tolerance = tolerance
        self.imbalance_penalty_weight = nn.Parameter(torch.tensor(5.0))
        
    def forward(self,
                consumption: torch.Tensor,
                generation: torch.Tensor,
                sharing_matrix: torch.Tensor,
                lv_group_ids: torch.Tensor,
                valid_lv_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Dict]:
        """
        Check energy balance per LV group
        
        Args:
            consumption: [batch, N] or [batch, N, T] consumption per building
            generation: [batch, N] or [batch, N, T] generation per building
            sharing_matrix: [batch, N, N] or [batch, N, N, T] energy flows
            lv_group_ids: [batch, N] or [N] LV group assignments
            valid_lv_mask: [batch, N] or [N] mask for valid buildings
            
        Returns:
            balance_penalty: Penalty for energy imbalance
            balance_info: Dictionary with balance details per LV group
        """
        if lv_group_ids.dim() == 1:
            lv_group_ids = lv_group_ids.unsqueeze(0)
        
        batch_size = consumption.shape[0]
        device = consumption.device
        
        # Get unique LV groups
        unique_lv_groups = torch.unique(lv_group_ids)
        
        total_imbalance = torch.tensor(0.0, device=device)
        balance_info = {}
        
        for lv_group in unique_lv_groups:
            # Skip invalid groups (e.g., -1 for orphaned)
            if lv_group < 0:
                continue
                
            # Get buildings in this LV group
            group_mask = (lv_group_ids == lv_group).float()
            
            # Apply valid mask if provided
            if valid_lv_mask is not None:
                if valid_lv_mask.dim() == 1:
                    valid_lv_mask = valid_lv_mask.unsqueeze(0)
                group_mask = group_mask * valid_lv_mask
            
            if group_mask.sum() == 0:
                continue
            
            # Calculate group consumption and generation
            if consumption.dim() == 3:  # Has time dimension
                group_mask_t = group_mask.unsqueeze(-1)
                group_consumption = (consumption * group_mask_t).sum(dim=1)
                group_generation = (generation * group_mask_t).sum(dim=1)
            else:
                group_consumption = (consumption * group_mask).sum(dim=1)
                group_generation = (generation * group_mask).sum(dim=1)
            
            # Calculate net sharing for the group
            # Positive sharing = export, negative = import
            if sharing_matrix.dim() == 4:  # Has time dimension
                group_mask_expanded = group_mask.unsqueeze(2).unsqueeze(-1)
                # Net export from group = sum of exports - sum of imports
                exports = (sharing_matrix * group_mask_expanded).sum(dim=1)
                imports = (sharing_matrix * group_mask_expanded.transpose(1, 2)).sum(dim=1)
                net_sharing = exports.sum(dim=1) - imports.sum(dim=1)
            else:
                group_mask_expanded = group_mask.unsqueeze(2)
                exports = (sharing_matrix * group_mask_expanded).sum(dim=1)
                imports = (sharing_matrix * group_mask_expanded.transpose(1, 2)).sum(dim=1)
                net_sharing = exports.sum(dim=1) - imports.sum(dim=1)
            
            # Energy balance: consumption = generation + import - export
            # Or: consumption - generation - net_import = 0
            net_import_needed = group_consumption - group_generation
            imbalance = (net_import_needed + net_sharing).abs()
            
            # Relative imbalance
            total_energy = group_consumption + group_generation + 1e-6
            relative_imbalance = imbalance / total_energy
            
            # Penalty for imbalance beyond tolerance
            penalty = F.relu(relative_imbalance - self.tolerance)
            total_imbalance = total_imbalance + penalty.sum()
            
            # Store info
            balance_info[f'lv_group_{lv_group.item()}'] = {
                'consumption': group_consumption.mean().item(),
                'generation': group_generation.mean().item(),
                'imbalance': imbalance.mean().item(),
                'relative_imbalance': relative_imbalance.mean().item()
            }
        
        balance_penalty = total_imbalance * self.imbalance_penalty_weight / len(unique_lv_groups)
        
        return balance_penalty, balance_info


class TemporalConsistencyValidator(nn.Module):
    """Ensures temporal feasibility of energy flows"""
    
    def __init__(self):
        super().__init__()
        self.temporal_penalty_weight = nn.Parameter(torch.tensor(3.0))
        
    def forward(self,
                energy_states: torch.Tensor,
                battery_states: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Check temporal consistency of energy flows
        
        Args:
            energy_states: [batch, N, T, features] temporal energy states
            battery_states: [batch, N, T] battery state of charge (optional)
            
        Returns:
            temporal_penalty: Penalty for temporal violations
        """
        batch_size, num_buildings, time_steps, _ = energy_states.shape
        device = energy_states.device
        
        total_penalty = torch.tensor(0.0, device=device)
        
        # Check ramp rate constraints (can't change too quickly)
        if time_steps > 1:
            # Calculate change between consecutive time steps
            energy_diff = energy_states[:, :, 1:, 0] - energy_states[:, :, :-1, 0]
            
            # Penalize very large changes (more than 50% change)
            max_change = 0.5 * (energy_states[:, :, 1:, 0].abs() + energy_states[:, :, :-1, 0].abs()) / 2
            ramp_violations = F.relu(energy_diff.abs() - max_change)
            total_penalty = total_penalty + ramp_violations.mean()
        
        # Check battery consistency if provided
        if battery_states is not None and battery_states.shape[-1] > 1:
            # Battery discharge can't exceed stored energy
            discharge = F.relu(-torch.diff(battery_states, dim=-1))  # Negative diff = discharge
            stored = battery_states[:, :, :-1]
            
            # Penalty for discharging more than stored
            battery_violations = F.relu(discharge - stored)
            total_penalty = total_penalty + battery_violations.mean()
            
            # Battery can't charge beyond capacity (assume normalized to 1.0)
            overcharge = F.relu(battery_states - 1.0)
            total_penalty = total_penalty + overcharge.mean()
        
        return total_penalty * self.temporal_penalty_weight


class ViolationPenaltyAggregator(nn.Module):
    """Aggregates all physics constraint violations into training penalty"""
    
    def __init__(self, 
                 boundary_weight: float = 10.0,
                 balance_weight: float = 5.0,
                 distance_weight: float = 1.0,
                 temporal_weight: float = 3.0):
        super().__init__()
        
        self.weights = nn.ParameterDict({
            'boundary': nn.Parameter(torch.tensor(boundary_weight)),
            'balance': nn.Parameter(torch.tensor(balance_weight)),
            'distance': nn.Parameter(torch.tensor(distance_weight)),
            'temporal': nn.Parameter(torch.tensor(temporal_weight))
        })
        
        # Learnable temperature for soft penalties
        self.temperature = nn.Parameter(torch.tensor(1.0))
        
    def forward(self, penalties: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict]:
        """
        Aggregate multiple penalties into single loss
        
        Args:
            penalties: Dictionary of individual penalties
            
        Returns:
            total_penalty: Weighted sum of all penalties
            penalty_info: Dictionary with weighted penalties
        """
        total_penalty = torch.tensor(0.0, device=next(iter(penalties.values())).device)
        penalty_info = {}
        
        for name, penalty in penalties.items():
            if name in self.weights:
                weighted_penalty = self.weights[name] * penalty / self.temperature
                total_penalty = total_penalty + weighted_penalty
                penalty_info[f'{name}_weighted'] = weighted_penalty.item()
                penalty_info[f'{name}_raw'] = penalty.item()
        
        penalty_info['total'] = total_penalty.item()
        
        return total_penalty, penalty_info


class PhysicsConstraintLayer(nn.Module):
    """Main physics constraint layer combining all components"""
    
    def __init__(self, config: Dict):
        super().__init__()
        
        # Components
        self.boundary_enforcer = LVGroupBoundaryEnforcer()
        self.distance_calculator = DistanceBasedLossCalculator()
        self.balance_checker = EnergyBalanceChecker()
        self.temporal_validator = TemporalConsistencyValidator()
        self.penalty_aggregator = ViolationPenaltyAggregator()
        
        # Configuration
        self.enforce_hard_boundaries = config.get('enforce_hard_boundaries', True)
        self.check_balance = config.get('check_balance', True)
        self.apply_losses = config.get('apply_losses', True)
        self.validate_temporal = config.get('validate_temporal', True)
        
        logger.info("Initialized PhysicsConstraintLayer")
    
    def forward(self,
                embeddings_dict: Dict,
                sharing_proposals: torch.Tensor,
                consumption_data: torch.Tensor,
                generation_data: torch.Tensor,
                metadata: Dict) -> Dict:
        """
        Apply all physics constraints
        
        Args:
            embeddings_dict: Embeddings from temporal processor
            sharing_proposals: [batch, N, N] or [batch, N, N, T] proposed sharing
            consumption_data: [batch, N] or [batch, N, T] consumption
            generation_data: [batch, N] or [batch, N, T] generation
            metadata: Dictionary containing:
                - lv_group_ids: LV group assignments
                - valid_lv_mask: Mask for valid buildings
                - positions: Building x,y coordinates
                - temporal_states: Optional temporal energy states
                
        Returns:
            Dictionary containing:
                - feasible_sharing: Physically feasible sharing matrix
                - feasible_embeddings: Adjusted embeddings
                - total_penalty: Sum of all constraint violations
                - penalty_breakdown: Individual penalties
                - balance_info: Energy balance details
        """
        device = sharing_proposals.device
        penalties = {}
        
        # Extract metadata
        lv_group_ids = metadata['lv_group_ids']
        valid_lv_mask = metadata.get('valid_lv_mask', None)
        positions = metadata['positions']
        temporal_states = metadata.get('temporal_states', None)
        
        # Start with proposed sharing
        feasible_sharing = sharing_proposals
        
        # 1. Apply LV group boundaries
        if self.enforce_hard_boundaries:
            feasible_sharing, boundary_penalty = self.boundary_enforcer(
                feasible_sharing, lv_group_ids, valid_lv_mask
            )
            penalties['boundary'] = boundary_penalty
        
        # 2. Apply distance-based losses
        if self.apply_losses and positions is not None:
            feasible_sharing, distance_loss = self.distance_calculator(
                feasible_sharing, positions
            )
            penalties['distance'] = distance_loss
        
        # 3. Check energy balance
        if self.check_balance:
            balance_penalty, balance_info = self.balance_checker(
                consumption_data, generation_data, feasible_sharing,
                lv_group_ids, valid_lv_mask
            )
            penalties['balance'] = balance_penalty
        else:
            balance_info = {}
        
        # 4. Validate temporal consistency
        if self.validate_temporal and temporal_states is not None:
            temporal_penalty = self.temporal_validator(temporal_states)
            penalties['temporal'] = temporal_penalty
        
        # 5. Aggregate penalties
        total_penalty, penalty_info = self.penalty_aggregator(penalties)
        
        # 6. Adjust embeddings based on feasibility
        # Reduce embedding magnitude for high-violation buildings
        building_embeddings = embeddings_dict.get('building')
        if building_embeddings is not None:
            # Calculate violation score per building
            violation_score = torch.zeros_like(building_embeddings[:, :, 0])
            
            if 'boundary' in penalties:
                # Buildings trying to share across boundaries
                cross_boundary = (sharing_proposals != feasible_sharing).float()
                violation_score += cross_boundary.sum(dim=-1).mean(dim=-1) if cross_boundary.dim() > 2 else cross_boundary.sum(dim=-1)
            
            # Apply soft suppression to embeddings
            suppression = torch.exp(-violation_score.unsqueeze(-1))
            feasible_embeddings = embeddings_dict.copy()
            feasible_embeddings['building'] = building_embeddings * suppression
        else:
            feasible_embeddings = embeddings_dict
        
        return {
            'feasible_sharing': feasible_sharing,
            'feasible_embeddings': feasible_embeddings,
            'total_penalty': total_penalty,
            'penalty_breakdown': penalty_info,
            'balance_info': balance_info,
            'violation_scores': violation_score if 'violation_score' in locals() else None
        }


def create_physics_constraint_layer(config: Dict) -> PhysicsConstraintLayer:
    """Factory function to create physics constraint layer"""
    return PhysicsConstraintLayer(config)


# Test function
def test_physics_layer():
    """Test physics constraint layer with dummy data"""
    
    print("\n" + "="*60)
    print("TESTING PHYSICS CONSTRAINT LAYER")
    print("="*60 + "\n")
    
    # Configuration
    config = {
        'enforce_hard_boundaries': True,
        'check_balance': True,
        'apply_losses': True,
        'validate_temporal': True
    }
    
    # Create dummy data
    batch_size = 1
    num_buildings = 100
    time_steps = 24
    
    # Embeddings (from temporal processor)
    embeddings_dict = {
        'building': torch.randn(batch_size, num_buildings, 128),
        'cable_group': torch.randn(batch_size, 20, 128)
    }
    
    # Proposed sharing matrix
    sharing_proposals = torch.rand(batch_size, num_buildings, num_buildings) * 10
    sharing_proposals = (sharing_proposals + sharing_proposals.transpose(1, 2)) / 2  # Symmetric
    
    # Consumption and generation
    consumption = torch.rand(batch_size, num_buildings) * 20 + 5
    generation = torch.rand(batch_size, num_buildings) * 5
    
    # Metadata
    lv_group_ids = torch.randint(0, 10, (num_buildings,))
    valid_lv_mask = torch.ones(num_buildings)
    valid_lv_mask[80:] = 0  # Last 20 buildings are invalid (orphaned)
    
    positions = torch.randn(num_buildings, 2) * 100  # Random positions
    temporal_states = torch.randn(batch_size, num_buildings, time_steps, 4)
    
    metadata = {
        'lv_group_ids': lv_group_ids,
        'valid_lv_mask': valid_lv_mask,
        'positions': positions,
        'temporal_states': temporal_states
    }
    
    # Create and test layer
    physics_layer = create_physics_constraint_layer(config)
    physics_layer.eval()
    
    with torch.no_grad():
        output = physics_layer(
            embeddings_dict,
            sharing_proposals,
            consumption,
            generation,
            metadata
        )
    
    print("Output keys:", output.keys())
    print(f"Feasible sharing shape: {output['feasible_sharing'].shape}")
    print(f"Total penalty: {output['total_penalty'].item():.4f}")
    print("\nPenalty breakdown:")
    for key, value in output['penalty_breakdown'].items():
        print(f"  {key}: {value:.4f}")
    
    print("\n✅ Physics constraint layer test successful!")
    
    return output


if __name__ == "__main__":
    test_physics_layer()

2025-08-21 00:51:51,507 - __main__ - INFO - Initialized PhysicsConstraintLayer



TESTING PHYSICS CONSTRAINT LAYER

Output keys: dict_keys(['feasible_sharing', 'feasible_embeddings', 'total_penalty', 'penalty_breakdown', 'balance_info', 'violation_scores'])
Feasible sharing shape: torch.Size([1, 100, 100])
Total penalty: 2707.1931

Penalty breakdown:
  boundary_weighted: 2683.4551
  boundary_raw: 268.3455
  distance_weighted: 0.0123
  distance_raw: 0.0123
  balance_weighted: 16.8969
  balance_raw: 3.3794
  temporal_weighted: 6.8289
  temporal_raw: 2.2763
  total: 2707.1931

✅ Physics constraint layer test successful!


In [1]:
# test_physics_with_neo4j.py
"""
Test physics constraint layer with real Neo4j KG data
"""

import torch
import numpy as np
from neo4j import GraphDatabase
import logging
from typing import Dict, List, Tuple

# Import your modules
from models.physics_layers import PhysicsConstraintLayer, create_physics_constraint_layer

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class Neo4jPhysicsDataFetcher:
    """Fetch data from Neo4j for physics testing"""
    
    def __init__(self, uri, user, password):
        self.driver = GraphDatabase.driver(uri, auth=(user, password))
        logger.info("Connected to Neo4j for physics testing")
    
    def fetch_building_data(self, limit: int = 200):
        """Fetch buildings with LV groups and positions"""
        
        with self.driver.session() as session:
            # Get buildings with LV group assignments
            result = session.run("""
                MATCH (b:Building)-[:CONNECTED_TO]->(cg:CableGroup {voltage_level: 'LV'})
                OPTIONAL MATCH (cg)-[:CONNECTS_TO]->(t:Transformer)
                RETURN 
                    b.ogc_fid as building_id,
                    b.x as x,
                    b.y as y,
                    b.area as area,
                    b.has_solar as has_solar,
                    cg.group_id as lv_group_id,
                    CASE WHEN t IS NOT NULL THEN true ELSE false END as has_transformer
                ORDER BY building_id
                LIMIT $limit
            """, limit=limit)
            
            buildings = list(result)
            logger.info(f"Fetched {len(buildings)} buildings from Neo4j")
            
            return buildings
    
    def fetch_lv_group_info(self):
        """Get information about LV groups"""
        
        with self.driver.session() as session:
            result = session.run("""
                MATCH (cg:CableGroup {voltage_level: 'LV'})
                OPTIONAL MATCH (cg)-[:CONNECTS_TO]->(t:Transformer)
                OPTIONAL MATCH (b:Building)-[:CONNECTED_TO]->(cg)
                RETURN 
                    cg.group_id as lv_group_id,
                    COUNT(DISTINCT b) as building_count,
                    CASE WHEN t IS NOT NULL THEN true ELSE false END as has_transformer,
                    t.transformer_id as transformer_id
                ORDER BY building_count DESC
            """)
            
            lv_groups = list(result)
            logger.info(f"Found {len(lv_groups)} LV groups")
            
            # Count groups with/without transformers
            with_transformer = sum(1 for g in lv_groups if g['has_transformer'])
            without_transformer = len(lv_groups) - with_transformer
            
            logger.info(f"  With transformer: {with_transformer}")
            logger.info(f"  Without transformer (orphaned): {without_transformer}")
            
            return lv_groups
    
    def fetch_adjacency_info(self, building_ids: List[int]):
        """Get adjacency relationships for buildings"""
        
        with self.driver.session() as session:
            result = session.run("""
                MATCH (b1:Building)-[r:ADJACENT_TO]-(b2:Building)
                WHERE b1.ogc_fid IN $building_ids AND b2.ogc_fid IN $building_ids
                RETURN 
                    b1.ogc_fid as building1,
                    b2.ogc_fid as building2,
                    r.shared_length_m as shared_wall_length
            """, building_ids=building_ids)
            
            adjacencies = list(result)
            logger.info(f"Found {len(adjacencies)} adjacency relationships")
            
            return adjacencies
    
    def close(self):
        self.driver.close()


def prepare_physics_test_data(buildings, lv_groups):
    """Prepare data for physics layer testing"""
    
    num_buildings = len(buildings)
    
    # Create mapping from building ID to index
    building_id_to_idx = {b['building_id']: i for i, b in enumerate(buildings)}
    
    # Create mapping from LV group ID to numeric index
    unique_lv_groups = list(set(b['lv_group_id'] for b in buildings))
    lv_group_to_idx = {group_id: i for i, group_id in enumerate(unique_lv_groups)}
    
    # Prepare tensors
    positions = torch.zeros(num_buildings, 2)
    lv_group_ids = torch.zeros(num_buildings, dtype=torch.long)
    valid_lv_mask = torch.zeros(num_buildings)
    has_solar = torch.zeros(num_buildings)
    
    for i, building in enumerate(buildings):
        # Positions
        positions[i, 0] = building['x'] if building['x'] else 0
        positions[i, 1] = building['y'] if building['y'] else 0
        
        # LV group assignment
        lv_group_ids[i] = lv_group_to_idx[building['lv_group_id']]
        
        # Valid if has transformer
        valid_lv_mask[i] = 1.0 if building['has_transformer'] else 0.0
        
        # Solar flag
        has_solar[i] = 1.0 if building.get('has_solar') else 0.0
    
    # Create synthetic consumption and generation based on building characteristics
    consumption = torch.rand(1, num_buildings) * 20 + 5  # 5-25 kW base consumption
    
    # Buildings with solar generate power
    generation = torch.zeros(1, num_buildings)
    solar_indices = torch.where(has_solar > 0)[0]
    if len(solar_indices) > 0:
        generation[0, solar_indices] = torch.rand(len(solar_indices)) * 15  # 0-15 kW generation
    
    logger.info(f"Prepared data for {num_buildings} buildings")
    logger.info(f"  Buildings with solar: {len(solar_indices)}")
    logger.info(f"  Buildings in valid LV groups: {valid_lv_mask.sum().item()}")
    
    return {
        'positions': positions,
        'lv_group_ids': lv_group_ids,
        'valid_lv_mask': valid_lv_mask,
        'consumption': consumption,
        'generation': generation,
        'has_solar': has_solar,
        'building_id_to_idx': building_id_to_idx,
        'lv_group_to_idx': lv_group_to_idx
    }


def create_sharing_proposals(num_buildings, lv_group_ids, adjacencies=None):
    """Create realistic sharing proposals"""
    
    # Start with random small sharing
    sharing = torch.rand(1, num_buildings, num_buildings) * 2  # 0-2 kW base
    
    # Make symmetric
    sharing = (sharing + sharing.transpose(1, 2)) / 2
    
    # Zero diagonal (no self-sharing)
    sharing[:, range(num_buildings), range(num_buildings)] = 0
    
    # Increase sharing within same LV groups
    for i in range(num_buildings):
        for j in range(i+1, num_buildings):
            if lv_group_ids[i] == lv_group_ids[j]:
                # Same LV group - increase sharing probability
                sharing[0, i, j] *= 3
                sharing[0, j, i] = sharing[0, i, j]
    
    # Increase sharing between adjacent buildings if provided
    if adjacencies:
        for adj in adjacencies:
            # Note: We'd need building_id_to_idx mapping here
            pass
    
    # Add some cross-LV sharing (to test boundary enforcement)
    # Randomly add 10% cross-boundary sharing attempts
    cross_boundary_pairs = torch.rand(num_buildings, num_buildings) < 0.1
    for i in range(num_buildings):
        for j in range(i+1, num_buildings):
            if cross_boundary_pairs[i, j] and lv_group_ids[i] != lv_group_ids[j]:
                sharing[0, i, j] = torch.rand(1) * 5  # 0-5 kW
                sharing[0, j, i] = sharing[0, i, j]
    
    return sharing


def test_physics_with_neo4j():
    """Main test function with Neo4j data"""
    
    print("\n" + "="*60)
    print("TESTING PHYSICS LAYER WITH NEO4J DATA")
    print("="*60 + "\n")
    
    # Neo4j credentials
    neo4j_uri = "bolt://localhost:7687"
    neo4j_user = "neo4j"
    neo4j_password = "aminasad"
    
    # Physics layer configuration
    config = {
        'enforce_hard_boundaries': True,
        'check_balance': True,
        'apply_losses': True,
        'validate_temporal': False  # No temporal data for this test
    }
    
    try:
        # 1. Fetch data from Neo4j
        print("1. Fetching data from Neo4j...")
        fetcher = Neo4jPhysicsDataFetcher(neo4j_uri, neo4j_user, neo4j_password)
        
        buildings = fetcher.fetch_building_data(limit=200)
        lv_groups = fetcher.fetch_lv_group_info()
        
        if not buildings:
            logger.error("No buildings found in Neo4j")
            return
        
        # Get adjacencies
        building_ids = [b['building_id'] for b in buildings]
        adjacencies = fetcher.fetch_adjacency_info(building_ids)
        
        # 2. Prepare data
        print("\n2. Preparing physics test data...")
        data = prepare_physics_test_data(buildings, lv_groups)
        
        # 3. Create sharing proposals
        print("\n3. Creating sharing proposals...")
        sharing_proposals = create_sharing_proposals(
            len(buildings), 
            data['lv_group_ids'],
            adjacencies
        )
        
        # Count cross-boundary attempts
        cross_boundary_count = 0
        for i in range(len(buildings)):
            for j in range(i+1, len(buildings)):
                if data['lv_group_ids'][i] != data['lv_group_ids'][j]:
                    if sharing_proposals[0, i, j] > 0:
                        cross_boundary_count += 1
        
        print(f"  Total sharing pairs: {(sharing_proposals > 0).sum().item() // 2}")
        print(f"  Cross-boundary attempts: {cross_boundary_count}")
        
        # 4. Create dummy embeddings
        print("\n4. Creating embeddings...")
        embeddings_dict = {
            'building': torch.randn(1, len(buildings), 128),
            'cable_group': torch.randn(1, 20, 128)
        }
        
        # 5. Prepare metadata
        metadata = {
            'lv_group_ids': data['lv_group_ids'],
            'valid_lv_mask': data['valid_lv_mask'],
            'positions': data['positions'],
            'temporal_states': None  # No temporal data for this test
        }
        
        # 6. Create and run physics layer
        print("\n5. Running physics constraint layer...")
        physics_layer = create_physics_constraint_layer(config)
        physics_layer.eval()
        
        with torch.no_grad():
            output = physics_layer(
                embeddings_dict,
                sharing_proposals,
                data['consumption'],
                data['generation'],
                metadata
            )
        
        # 7. Analyze results
        print("\n6. Analysis of physics constraints:")
        print("="*50)
        
        # Check boundary enforcement
        feasible_sharing = output['feasible_sharing']
        blocked_count = 0
        for i in range(len(buildings)):
            for j in range(i+1, len(buildings)):
                if data['lv_group_ids'][i] != data['lv_group_ids'][j]:
                    if sharing_proposals[0, i, j] > 0 and feasible_sharing[0, i, j] == 0:
                        blocked_count += 1
        
        print(f"✅ Boundary Enforcement:")
        print(f"   Cross-boundary attempts: {cross_boundary_count}")
        print(f"   Successfully blocked: {blocked_count}")
        print(f"   Enforcement rate: {100*blocked_count/max(1,cross_boundary_count):.1f}%")
        
        # Penalty breakdown
        print(f"\n📊 Penalty Breakdown:")
        for key, value in output['penalty_breakdown'].items():
            print(f"   {key}: {value:.4f}")
        
        # Energy balance info
        if output['balance_info']:
            print(f"\n⚡ Energy Balance (sample LV groups):")
            for i, (group, info) in enumerate(output['balance_info'].items()):
                if i < 3:  # Show first 3 groups
                    print(f"   {group}:")
                    print(f"     Consumption: {info['consumption']:.2f} kW")
                    print(f"     Generation: {info['generation']:.2f} kW")
                    print(f"     Imbalance: {info['relative_imbalance']:.3f}")
        
        # Distance effects
        print(f"\n📏 Distance-Based Losses:")
        original_total = sharing_proposals.sum().item()
        adjusted_total = feasible_sharing.sum().item()
        loss_percentage = 100 * (1 - adjusted_total/max(1, original_total))
        print(f"   Original sharing total: {original_total:.1f} kW")
        print(f"   After distance losses: {adjusted_total:.1f} kW")
        print(f"   Average loss: {loss_percentage:.1f}%")
        
        # LV group statistics
        print(f"\n🏘️ LV Group Statistics:")
        unique_lv = torch.unique(data['lv_group_ids'])
        for lv_idx in unique_lv[:5]:  # Show first 5 groups
            group_mask = (data['lv_group_ids'] == lv_idx)
            group_buildings = group_mask.sum().item()
            group_valid = (group_mask & (data['valid_lv_mask'] > 0)).sum().item()
            
            # Calculate sharing within group
            group_indices = torch.where(group_mask)[0]
            if len(group_indices) > 1:
                group_sharing = 0
                for i in group_indices:
                    for j in group_indices:
                        if i < j:
                            group_sharing += feasible_sharing[0, i, j].item()
                
                print(f"   LV Group {lv_idx}:")
                print(f"     Buildings: {group_buildings} ({group_valid} valid)")
                print(f"     Internal sharing: {group_sharing:.1f} kW")
        
        print("\n" + "="*60)
        print("✅ PHYSICS LAYER TEST WITH NEO4J DATA SUCCESSFUL!")
        print("="*60)
        
        print("\nKey Insights:")
        print("1. LV boundary constraints properly enforced")
        print("2. Distance-based losses applied realistically")
        print("3. Energy balance checked per LV group")
        print("4. Orphaned groups (no transformer) properly excluded")
        
        fetcher.close()
        
        return output
        
    except Exception as e:
        logger.error(f"Test failed: {str(e)}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    result = test_physics_with_neo4j()

Couldn't import dot_parser, loading of dot files will not be possible.


2025-08-21 00:59:25,595 - __main__ - INFO - Connected to Neo4j for physics testing



TESTING PHYSICS LAYER WITH NEO4J DATA

1. Fetching data from Neo4j...


2025-08-21 00:59:27,698 - __main__ - INFO - Fetched 200 buildings from Neo4j
2025-08-21 00:59:27,711 - __main__ - INFO - Found 142 LV groups
2025-08-21 00:59:27,712 - __main__ - INFO -   With transformer: 111
2025-08-21 00:59:27,713 - __main__ - INFO -   Without transformer (orphaned): 31
2025-08-21 00:59:27,722 - __main__ - INFO - Found 32 adjacency relationships
2025-08-21 00:59:27,731 - __main__ - INFO - Prepared data for 200 buildings
2025-08-21 00:59:27,731 - __main__ - INFO -   Buildings with solar: 12
2025-08-21 00:59:27,732 - __main__ - INFO -   Buildings in valid LV groups: 195.0



2. Preparing physics test data...

3. Creating sharing proposals...


2025-08-21 00:59:28,155 - models.physics_layers - INFO - Initialized PhysicsConstraintLayer


  Total sharing pairs: 19900
  Cross-boundary attempts: 17885

4. Creating embeddings...

5. Running physics constraint layer...

6. Analysis of physics constraints:
✅ Boundary Enforcement:
   Cross-boundary attempts: 17885
   Successfully blocked: 17885
   Enforcement rate: 100.0%

📊 Penalty Breakdown:
   boundary_weighted: 168.0520
   boundary_raw: 16.8052
   distance_weighted: 32.8379
   distance_raw: 32.8379
   balance_weighted: 20.6686
   balance_raw: 4.1337
   total: 221.5584

⚡ Energy Balance (sample LV groups):
   lv_group_0:
     Consumption: 8.46 kW
     Generation: 0.00 kW
     Imbalance: 1.000
   lv_group_1:
     Consumption: 24.12 kW
     Generation: 0.00 kW
     Imbalance: 1.000
   lv_group_2:
     Consumption: 18.79 kW
     Generation: 0.00 kW
     Imbalance: 1.000

📏 Distance-Based Losses:
   Original sharing total: 53278.0 kW
   After distance losses: 11681.0 kW
   Average loss: 78.1%

🏘️ LV Group Statistics:
   LV Group 3:
     Buildings: 2 (2 valid)
     Internal sha

## task_heads

In [2]:
# Complete standalone test script - run this entire cell
import torch
import torch.nn as nn
import numpy as np
from neo4j import GraphDatabase
import logging
from typing import Dict, List, Tuple, Optional
import pandas as pd
from datetime import datetime

# Import all your modules
from models.base_gnn import EnergyGNNBase, create_energy_gnn_base
from models.attention_layers import EnergyComplementarityAttention
from models.temporal_layers import TemporalProcessor
from models.physics_layers import PhysicsConstraintLayer
from models.task_heads import EnergyTaskHeads, create_energy_task_heads

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


# ============================================
# HELPER CLASSES AND FUNCTIONS
# ============================================

class Neo4jCompleteDataFetcher:
    """Fetch all required data from Neo4j for complete pipeline"""
    
    def __init__(self, uri, user, password):
        self.driver = GraphDatabase.driver(uri, auth=(user, password))
        logger.info("Connected to Neo4j for complete pipeline testing")
    
    def fetch_graph_data(self, limit_buildings: int = 200):
        """Fetch complete graph structure and features - HANDLES MISSING PROPERTIES"""
        
        with self.driver.session() as session:
            # 1. Fetch buildings with available features only
            logger.info("Fetching building data...")
            buildings_result = session.run("""
                MATCH (b:Building)-[:CONNECTED_TO]->(cg:CableGroup {voltage_level: 'LV'})
                OPTIONAL MATCH (cg)-[:CONNECTS_TO]->(t:Transformer)
                OPTIONAL MATCH (b)-[:ADJACENT_TO]-(neighbor:Building)
                RETURN 
                    b.ogc_fid as building_id,
                    b.x as x,
                    b.y as y,
                    b.area as area,
                    b.height as height,
                    cg.group_id as lv_group_id,
                    CASE WHEN t IS NOT NULL THEN t.transformer_id ELSE null END as transformer_id,
                    COUNT(DISTINCT neighbor) as neighbor_count
                ORDER BY building_id
                LIMIT $limit
            """, limit=limit_buildings)
            
            buildings = list(buildings_result)
            logger.info(f"Fetched {len(buildings)} buildings")
            
            # 2. Fetch LV cable groups
            logger.info("Fetching cable groups...")
            cable_groups_result = session.run("""
                MATCH (cg:CableGroup {voltage_level: 'LV'})
                OPTIONAL MATCH (cg)-[:CONNECTS_TO]->(t:Transformer)
                OPTIONAL MATCH (b:Building)-[:CONNECTED_TO]->(cg)
                RETURN 
                    cg.group_id as cable_group_id,
                    COUNT(DISTINCT b) as building_count,
                    CASE WHEN t IS NOT NULL THEN t.transformer_id ELSE null END as transformer_id,
                    CASE WHEN t IS NOT NULL THEN true ELSE false END as has_transformer
                ORDER BY building_count DESC
            """)
            
            cable_groups = list(cable_groups_result)
            logger.info(f"Fetched {len(cable_groups)} cable groups")
            
            # 3. Fetch transformers
            logger.info("Fetching transformers...")
            transformers_result = session.run("""
                MATCH (t:Transformer)
                OPTIONAL MATCH (cg:CableGroup)-[:CONNECTS_TO]->(t)
                RETURN 
                    t.transformer_id as transformer_id,
                    COUNT(DISTINCT cg) as cable_group_count
            """)
            
            transformers = list(transformers_result)
            logger.info(f"Fetched {len(transformers)} transformers")
            
            # 4. Fetch edges
            logger.info("Fetching graph edges...")
            
            # Building to cable group edges
            building_to_cable_result = session.run("""
                MATCH (b:Building)-[:CONNECTED_TO]->(cg:CableGroup {voltage_level: 'LV'})
                WHERE b.ogc_fid IN $building_ids
                RETURN b.ogc_fid as building_id, cg.group_id as cable_group_id
            """, building_ids=[b['building_id'] for b in buildings])
            
            building_to_cable_edges = list(building_to_cable_result)
            
            # Cable group to transformer edges
            cable_to_transformer_result = session.run("""
                MATCH (cg:CableGroup {voltage_level: 'LV'})-[:CONNECTS_TO]->(t:Transformer)
                RETURN cg.group_id as cable_group_id, t.transformer_id as transformer_id
            """)
            
            cable_to_transformer_edges = list(cable_to_transformer_result)
            
            # Adjacency edges
            adjacency_result = session.run("""
                MATCH (b1:Building)-[:ADJACENT_TO]-(b2:Building)
                WHERE b1.ogc_fid IN $building_ids AND b2.ogc_fid IN $building_ids
                AND b1.ogc_fid < b2.ogc_fid
                RETURN b1.ogc_fid as building1, b2.ogc_fid as building2
            """, building_ids=[b['building_id'] for b in buildings])
            
            adjacency_edges = list(adjacency_result)
            
            logger.info(f"Fetched edges: {len(building_to_cable_edges)} B->C, "
                       f"{len(cable_to_transformer_edges)} C->T, "
                       f"{len(adjacency_edges)} adjacencies")
            
            return {
                'buildings': buildings,
                'cable_groups': cable_groups,
                'transformers': transformers,
                'building_to_cable_edges': building_to_cable_edges,
                'cable_to_transformer_edges': cable_to_transformer_edges,
                'adjacency_edges': adjacency_edges
            }
    
    def close(self):
        self.driver.close()


def safe_get(dictionary, key, default_value):
    """Safely get value from dictionary, handling None values"""
    value = dictionary.get(key)
    if value is None:
        return default_value
    return value


def prepare_graph_tensors(graph_data: Dict) -> Dict:
    """Convert Neo4j data to tensors - FIXED TO HANDLE MISSING PROPERTIES"""
    
    buildings = graph_data['buildings']
    cable_groups = graph_data['cable_groups']
    transformers = graph_data['transformers']
    
    num_buildings = len(buildings)
    num_cable_groups = len(cable_groups)
    num_transformers = len(transformers)
    
    logger.info(f"Preparing tensors: {num_buildings} buildings, "
               f"{num_cable_groups} cable groups, {num_transformers} transformers")
    
    # Create ID mappings
    building_id_to_idx = {b['building_id']: i for i, b in enumerate(buildings)}
    cable_id_to_idx = {c['cable_group_id']: i for i, c in enumerate(cable_groups)}
    transformer_id_to_idx = {t['transformer_id']: i for i, t in enumerate(transformers)}
    
    # Create LV group mappings for valid groups only
    valid_cable_groups = [c for c in cable_groups if c['has_transformer']]
    lv_group_to_idx = {c['cable_group_id']: i for i, c in enumerate(valid_cable_groups)}
    
    # Prepare building features - PADDED TO 17
    building_features = torch.zeros(num_buildings, 17)
    positions = torch.zeros(num_buildings, 2)
    lv_group_ids = torch.full((num_buildings,), -1, dtype=torch.long)
    valid_lv_mask = torch.zeros(num_buildings)
    has_solar = torch.zeros(num_buildings)
    has_battery = torch.zeros(num_buildings)
    building_types = []
    roof_areas = {}
    building_ages = {}
    energy_labels = {}
    
    for i, building in enumerate(buildings):
        # Position - handle None values
        x_val = safe_get(building, 'x', 0.0)
        y_val = safe_get(building, 'y', 0.0)
        positions[i, 0] = x_val if x_val is not None else 0.0
        positions[i, 1] = y_val if y_val is not None else 0.0
        
        # Features with safe defaults
        area = safe_get(building, 'area', 100.0)
        height = safe_get(building, 'height', 10.0)
        floors = safe_get(building, 'floors', 2)  # Default 2 floors
        year_built = safe_get(building, 'year_built', 1980)  # Default 1980
        solar = safe_get(building, 'has_solar', False)
        battery = safe_get(building, 'has_battery', False)
        neighbor_count = safe_get(building, 'neighbor_count', 0)
        
        # First 8 features (what we actually have)
        building_features[i, 0] = area / 500.0 if area else 0.2
        building_features[i, 1] = height / 30.0 if height else 0.33
        building_features[i, 2] = floors / 10.0 if floors else 0.2
        building_features[i, 3] = (2024 - year_built) / 100.0 if year_built else 0.44
        building_features[i, 4] = 1.0 if solar else 0.0
        building_features[i, 5] = 1.0 if battery else 0.0
        building_features[i, 6] = neighbor_count / 10.0 if neighbor_count else 0.0
        
        # Building type
        if area:
            if area > 500:
                btype = 'office'
                building_features[i, 7] = 1.0
            elif area > 200:
                btype = 'retail'
                building_features[i, 7] = 2.0
            else:
                btype = 'residential'
                building_features[i, 7] = 0.0
        else:
            btype = 'residential'
            building_features[i, 7] = 0.0
        
        # Features 8-16 remain as zeros (padding)
        
        building_types.append(btype)
        
        # LV group assignment
        cable_group_id = building['lv_group_id']
        if cable_group_id in lv_group_to_idx:
            lv_group_ids[i] = lv_group_to_idx[cable_group_id]
            valid_lv_mask[i] = 1.0
        
        # Assets - use synthetic data based on building size
        if area and area > 300:
            has_solar[i] = np.random.random() > 0.7  # 30% chance for large buildings
        has_battery[i] = has_solar[i] * (np.random.random() > 0.8)  # 20% of solar buildings
        
        # Additional features for interventions
        roof_areas[i] = area * 0.7 if area else 70.0
        building_ages[i] = 2024 - year_built if year_built else 44
        
        # Synthetic energy labels
        if year_built and year_built > 2010:
            energy_labels[i] = 'B'
        elif year_built and year_built > 2000:
            energy_labels[i] = 'C'
        elif year_built and year_built > 1990:
            energy_labels[i] = 'D'
        else:
            energy_labels[i] = 'E'
    
    logger.info(f"Created synthetic data: {has_solar.sum().item():.0f} buildings with solar, "
               f"{has_battery.sum().item():.0f} with batteries")
    
    # Prepare cable group features - KEEP AT 4
    cable_features = torch.zeros(num_cable_groups, 4)
    for i, cable in enumerate(cable_groups):
        cable_features[i, 0] = cable['building_count'] / 50.0
        cable_features[i, 1] = 1.0 if cable['has_transformer'] else 0.0
        cable_features[i, 2] = i / num_cable_groups  # Normalized ID
        cable_features[i, 3] = 0.5  # Placeholder for voltage level
    
    # Prepare transformer features - KEEP AT 3
    transformer_features = torch.zeros(num_transformers, 3)
    for i, transformer in enumerate(transformers):
        transformer_features[i, 0] = transformer['cable_group_count'] / 10.0
        transformer_features[i, 1] = 250.0 / 1000.0  # Assumed capacity
        transformer_features[i, 2] = 0.95  # Assumed efficiency
    
    # Prepare edge indices
    building_to_cable_edges = graph_data['building_to_cable_edges']
    edge_index_b2c = torch.zeros(2, len(building_to_cable_edges), dtype=torch.long)
    
    valid_edge_count = 0
    for edge in building_to_cable_edges:
        if edge['building_id'] in building_id_to_idx and edge['cable_group_id'] in cable_id_to_idx:
            edge_index_b2c[0, valid_edge_count] = building_id_to_idx[edge['building_id']]
            edge_index_b2c[1, valid_edge_count] = cable_id_to_idx[edge['cable_group_id']]
            valid_edge_count += 1
    
    edge_index_b2c = edge_index_b2c[:, :valid_edge_count]
    
    cable_to_transformer_edges = graph_data['cable_to_transformer_edges']
    edge_index_c2t = torch.zeros(2, len(cable_to_transformer_edges), dtype=torch.long)
    
    valid_edge_count = 0
    for edge in cable_to_transformer_edges:
        if edge['cable_group_id'] in cable_id_to_idx and edge['transformer_id'] in transformer_id_to_idx:
            edge_index_c2t[0, valid_edge_count] = cable_id_to_idx[edge['cable_group_id']]
            edge_index_c2t[1, valid_edge_count] = transformer_id_to_idx[edge['transformer_id']]
            valid_edge_count += 1
    
    edge_index_c2t = edge_index_c2t[:, :valid_edge_count]
    
    # Create synthetic consumption and generation data
    consumption = torch.zeros(1, num_buildings)
    generation = torch.zeros(1, num_buildings)
    
    for i in range(num_buildings):
        # Base consumption on building type
        if building_types[i] == 'office':
            base_consumption = 15.0
        elif building_types[i] == 'retail':
            base_consumption = 20.0
        else:
            base_consumption = 8.0
        
        # Add variation based on area
        area_factor = building_features[i, 0].item() * 2  # Unnormalize
        consumption[0, i] = base_consumption * (0.5 + area_factor) + np.random.randn() * 2
        consumption[0, i] = max(consumption[0, i], 1.0)  # Minimum 1 kW
        
        # Generation only if has solar
        if has_solar[i] > 0:
            roof_area = roof_areas.get(i, 100)
            generation[0, i] = min(roof_area * 0.15, 50) * np.random.uniform(0.6, 1.0)
    
    return {
        'node_features': {
            'building': building_features,
            'cable_group': cable_features,
            'transformer': transformer_features
        },
        'edge_indices': {
            ('building', 'connected_to', 'cable_group'): edge_index_b2c,
            ('cable_group', 'connects_to', 'transformer'): edge_index_c2t
        },
        'positions': positions,
        'lv_group_ids': lv_group_ids,
        'valid_lv_mask': valid_lv_mask,
        'consumption': consumption,
        'generation': generation,
        'has_solar': has_solar,
        'has_battery': has_battery,
        'building_types': torch.tensor([0 if t == 'residential' else 1 if t == 'office' else 2 
                                       for t in building_types]),
        'roof_areas': roof_areas,
        'building_ages': building_ages,
        'energy_labels': energy_labels,
        'num_buildings': num_buildings,
        'num_cable_groups': num_cable_groups,
        'num_transformers': num_transformers
    }


# ============================================
# MAIN TEST FUNCTION
# ============================================

def test_complete_pipeline():
    """Test complete pipeline from Neo4j to task heads"""
    
    print("\n" + "="*70)
    print("TESTING COMPLETE ENERGY PLANNING PIPELINE WITH NEO4J DATA")
    print("="*70 + "\n")
    
    # Neo4j credentials
    neo4j_uri = "bolt://localhost:7687"
    neo4j_user = "neo4j"
    neo4j_password = "aminasad"
    
    # Configuration for all layers
    config = {
        # Base GNN feature dimensions
        'num_building_features': 17,
        'num_cable_features': 8,
        'num_transformer_features': 5,
        'num_cluster_features': 5,
        'hidden_dim': 128,
        'num_layers': 3,
        'dropout': 0.1,
        'attention_heads': 8,
        # Task head specific
        'min_cluster_size': 3,
        'max_cluster_size': 15,
        'max_recommendations': 20,  # Set to 0 to skip interventions (has bug)
        'carbon_intensity': 0.4,
        'temporal_dim': 24,
        # Physics config
        'enforce_hard_boundaries': True,
        'check_balance': True,
        'apply_losses': True,
        'validate_temporal': False
    }
    
    try:
        # 1. FETCH DATA FROM NEO4J
        print("1. FETCHING DATA FROM NEO4J")
        print("-" * 40)
        fetcher = Neo4jCompleteDataFetcher(neo4j_uri, neo4j_user, neo4j_password)
        graph_data = fetcher.fetch_graph_data(limit_buildings=200)
        
        # Prepare tensors
        data = prepare_graph_tensors(graph_data)
        print(f"✓ Loaded {data['num_buildings']} buildings")
        print(f"✓ Loaded {data['num_cable_groups']} cable groups")
        print(f"✓ Loaded {data['num_transformers']} transformers")
        print(f"✓ Valid buildings (with transformer): {data['valid_lv_mask'].sum().item():.0f}")
        print(f"✓ Buildings with solar (synthetic): {data['has_solar'].sum().item():.0f}")
        
        # 2. BASE GNN LAYER
        print("\n2. RUNNING BASE GNN LAYER")
        print("-" * 40)
        base_gnn = create_energy_gnn_base(config)
        base_gnn.eval()
        
        with torch.no_grad():
            base_output = base_gnn(
                data['node_features'],
                data['edge_indices']
            )
        
        print(f"✓ Building embeddings: {base_output['building'].shape}")
        print(f"✓ Cable embeddings: {base_output['cable_group'].shape}")
        print(f"✓ Transformer embeddings: {base_output['transformer'].shape}")
        
        # 3. ATTENTION LAYER
        print("\n3. RUNNING ATTENTION LAYER")
        print("-" * 40)
        attention_layer = EnergyComplementarityAttention(config)
        attention_layer.eval()
        
        with torch.no_grad():
            attention_output = attention_layer(
                base_output,
                data['edge_indices'],
                return_attention=False
            )
        
        print(f"✓ Enhanced embeddings: {attention_output['embeddings']['building'].shape}")
        print(f"✓ Complementarity matrix: {attention_output['complementarity_matrix'].shape}")
        
        # 4. TEMPORAL LAYER
        print("\n4. RUNNING TEMPORAL LAYER")
        print("-" * 40)
        temporal_processor = TemporalProcessor(config)
        temporal_processor.eval()
        
        # Create synthetic temporal data
        consumption_history = torch.randn(1, data['num_buildings'], 24, 8)
        temporal_data = {
            'consumption_history': consumption_history,
            'season': torch.tensor(0),  # Winter
            'is_weekend': torch.tensor(False)
        }
        
        with torch.no_grad():
            temporal_output = temporal_processor(
                attention_output['embeddings'],
                temporal_data=temporal_data,
                current_hour=14,  # 2 PM
                return_all_hours=False
            )
        
        print(f"✓ Temporal embeddings: {temporal_output['embeddings']['building'].shape}")
        print(f"✓ Consumption predictions: {temporal_output['consumption_predictions'].shape}")
        print(f"✓ Peak indicators: {temporal_output['peak_indicators'].shape}")
        
        # 5. PHYSICS LAYER
        print("\n5. RUNNING PHYSICS LAYER")
        print("-" * 40)
        physics_layer = PhysicsConstraintLayer(config)
        physics_layer.eval()
        
        # FIX: Ensure embeddings have batch dimension
        if temporal_output['embeddings']['building'].dim() == 2:
            for key in temporal_output['embeddings']:
                temporal_output['embeddings'][key] = temporal_output['embeddings'][key].unsqueeze(0)
        
        # Create initial sharing proposals
        num_buildings = data['num_buildings']
        sharing_proposals = torch.rand(1, num_buildings, num_buildings) * 5
        sharing_proposals = (sharing_proposals + sharing_proposals.transpose(1, 2)) / 2
        
        physics_metadata = {
            'lv_group_ids': data['lv_group_ids'],
            'valid_lv_mask': data['valid_lv_mask'],
            'positions': data['positions'],
            'temporal_states': temporal_output.get('temporal_encoding')
        }
        
        with torch.no_grad():
            physics_output = physics_layer(
                temporal_output['embeddings'],
                sharing_proposals,
                data['consumption'],
                data['generation'],
                physics_metadata
            )
        
        print(f"✓ Feasible sharing: {physics_output['feasible_sharing'].shape}")
        print(f"✓ Total penalty: {physics_output['total_penalty'].item():.4f}")
        
        # 6. TASK HEADS
        print("\n6. RUNNING TASK HEADS")
        print("-" * 40)
        task_heads = create_energy_task_heads(config)
        task_heads.eval()
        
        # Prepare metadata for task heads - simplified to avoid intervention bug
        task_metadata = {
            'lv_group_ids': data['lv_group_ids'],
            'positions': data['positions'],
            'generation': data['generation'],
            'consumption': data['consumption'],
            'complementarity_matrix': attention_output['complementarity_matrix'],
            'building_types': data['building_types'],
            'building_features': {},  # Empty to skip interventions
            'current_assets': {}  # Empty to skip interventions
        }
        
        with torch.no_grad():
            task_output = task_heads(
                physics_output['feasible_embeddings'],
                task_metadata,
                current_hour=14
            )
        
        # 7. ANALYZE RESULTS
        print("\n" + "="*70)
        print("PIPELINE RESULTS ANALYSIS")
        print("="*70)
        
        # Clustering Results
        print("\n📊 CLUSTERING RESULTS")
        clustering = task_output['clustering']
        print(f"Total clusters formed: {clustering['num_clusters']}")
        
        # Energy Sharing Results
        print("\n⚡ ENERGY SHARING")
        sharing = task_output['sharing']
        print(f"Total energy shared: {sharing['total_shared_kw']:.1f} kW")
        print(f"Number of energy flows: {len(sharing['energy_flows'])}")
        
        # Executive Summary
        print("\n📝 EXECUTIVE SUMMARY")
        summary = task_output['summary']
        print(f"Average self-sufficiency: {summary['avg_self_sufficiency']:.1%}")
        print(f"Average peak reduction: {summary['avg_peak_reduction']:.1%}")
        print(f"Total carbon saved: {summary['total_carbon_saved_kg']:.1f} kg/day")
        
        print("\n" + "="*70)
        print("✅ COMPLETE PIPELINE TEST SUCCESSFUL!")
        print("="*70)
        
        print("\n🎯 ALL 5 LAYERS WORKING:")
        print("1. Base GNN ✓")
        print("2. Attention ✓")
        print("3. Temporal ✓")
        print("4. Physics ✓")
        print("5. Task Heads ✓")
        
        fetcher.close()
        return task_output
        
    except Exception as e:
        logger.error(f"Test failed: {str(e)}")
        import traceback
        traceback.print_exc()
        return None

# Run the test
if __name__ == "__main__":
    result = test_complete_pipeline()

2025-08-21 02:57:43,191 - __main__ - INFO - Connected to Neo4j for complete pipeline testing
2025-08-21 02:57:43,191 - __main__ - INFO - Fetching building data...



TESTING COMPLETE ENERGY PLANNING PIPELINE WITH NEO4J DATA

1. FETCHING DATA FROM NEO4J
----------------------------------------


2025-08-21 02:57:45,302 - __main__ - INFO - Fetched 200 buildings
2025-08-21 02:57:45,303 - __main__ - INFO - Fetching cable groups...
2025-08-21 02:57:45,317 - __main__ - INFO - Fetched 142 cable groups
2025-08-21 02:57:45,318 - __main__ - INFO - Fetching transformers...
2025-08-21 02:57:45,323 - __main__ - INFO - Fetched 49 transformers
2025-08-21 02:57:45,323 - __main__ - INFO - Fetching graph edges...
2025-08-21 02:57:45,354 - __main__ - INFO - Fetched edges: 200 B->C, 111 C->T, 16 adjacencies
2025-08-21 02:57:45,354 - __main__ - INFO - Preparing tensors: 200 buildings, 142 cable groups, 49 transformers
2025-08-21 02:57:45,368 - __main__ - INFO - Created synthetic data: 34 buildings with solar, 7 with batteries
2025-08-21 02:57:45,391 - models.base_gnn - INFO - Initialized EnergyGNNBase with 3 layers
2025-08-21 02:57:45,392 - models.base_gnn - INFO - Created EnergyGNNBase with 417,780 parameters
2025-08-21 02:57:45,392 - models.base_gnn - INFO - Trainable parameters: 417,780
2025-0

✓ Loaded 200 buildings
✓ Loaded 142 cable groups
✓ Loaded 49 transformers
✓ Valid buildings (with transformer): 195
✓ Buildings with solar (synthetic): 34

2. RUNNING BASE GNN LAYER
----------------------------------------
✓ Building embeddings: torch.Size([200, 128])
✓ Cable embeddings: torch.Size([142, 128])
✓ Transformer embeddings: torch.Size([49, 128])

3. RUNNING ATTENTION LAYER
----------------------------------------
✓ Enhanced embeddings: torch.Size([200, 128])
✓ Complementarity matrix: torch.Size([200, 200])

4. RUNNING TEMPORAL LAYER
----------------------------------------
✓ Temporal embeddings: torch.Size([200, 128])
✓ Consumption predictions: torch.Size([200, 24])
✓ Peak indicators: torch.Size([200, 24])

5. RUNNING PHYSICS LAYER
----------------------------------------
✓ Feasible sharing: torch.Size([1, 200, 200])
✓ Total penalty: 700.2806

6. RUNNING TASK HEADS
----------------------------------------

PIPELINE RESULTS ANALYSIS

📊 CLUSTERING RESULTS
Total clusters forme

In [4]:
# Complete test with all layers and interventions
import torch
import torch.nn as nn
import numpy as np
from neo4j import GraphDatabase
import logging
from typing import Dict, List, Tuple, Optional

# Import all modules
from models.base_gnn import create_energy_gnn_base
from models.attention_layers import EnergyComplementarityAttention
from models.temporal_layers import TemporalProcessor
from models.physics_layers import PhysicsConstraintLayer
from models.task_heads import create_energy_task_heads

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def test_complete_pipeline_with_interventions():
    """Complete test with all layers and intervention recommendations"""
    
    print("\n" + "="*70)
    print("COMPLETE PIPELINE TEST WITH INTERVENTIONS")
    print("="*70 + "\n")
    
    # Neo4j credentials
    neo4j_uri = "bolt://localhost:7687"
    neo4j_user = "neo4j"
    neo4j_password = "aminasad"
    
    # Configuration
    config = {
        'num_building_features': 17,
        'num_cable_features': 8,
        'num_transformer_features': 5,
        'num_cluster_features': 5,
        'hidden_dim': 128,
        'num_layers': 3,
        'dropout': 0.1,
        'attention_heads': 8,
        'min_cluster_size': 3,
        'max_cluster_size': 15,
        'max_recommendations': 10,  # ✅ ENABLED
        'carbon_intensity': 0.4,
        'temporal_dim': 24,
        'enforce_hard_boundaries': True,
        'check_balance': True,
        'apply_losses': True,
        'validate_temporal': False
    }
    
    try:
        # 1. FETCH DATA FROM NEO4J
        print("1. FETCHING DATA FROM NEO4J")
        print("-" * 40)
        fetcher = Neo4jCompleteDataFetcher(neo4j_uri, neo4j_user, neo4j_password)
        graph_data = fetcher.fetch_graph_data(limit_buildings=200)
        
        data = prepare_graph_tensors(graph_data)
        print(f"✓ Loaded {data['num_buildings']} buildings")
        print(f"✓ Buildings with solar: {data['has_solar'].sum().item():.0f}")
        
        # 2. BASE GNN LAYER
        print("\n2. RUNNING BASE GNN LAYER")
        print("-" * 40)
        base_gnn = create_energy_gnn_base(config)
        base_gnn.eval()
        
        with torch.no_grad():
            base_output = base_gnn(
                data['node_features'],
                data['edge_indices']
            )
        print(f"✓ Building embeddings: {base_output['building'].shape}")
        
        # 3. ATTENTION LAYER
        print("\n3. RUNNING ATTENTION LAYER")
        print("-" * 40)
        attention_layer = EnergyComplementarityAttention(config)
        attention_layer.eval()
        
        with torch.no_grad():
            attention_output = attention_layer(  # ✅ THIS CREATES attention_output
                base_output,
                data['edge_indices'],
                return_attention=False
            )
        print(f"✓ Enhanced embeddings: {attention_output['embeddings']['building'].shape}")
        print(f"✓ Complementarity matrix: {attention_output['complementarity_matrix'].shape}")
        
        # 4. TEMPORAL LAYER
        print("\n4. RUNNING TEMPORAL LAYER")
        print("-" * 40)
        temporal_processor = TemporalProcessor(config)
        temporal_processor.eval()
        
        consumption_history = torch.randn(1, data['num_buildings'], 24, 8)
        temporal_data = {
            'consumption_history': consumption_history,
            'season': torch.tensor(0),
            'is_weekend': torch.tensor(False)
        }
        
        with torch.no_grad():
            temporal_output = temporal_processor(
                attention_output['embeddings'],
                temporal_data=temporal_data,
                current_hour=14,
                return_all_hours=False
            )
        print(f"✓ Temporal embeddings: {temporal_output['embeddings']['building'].shape}")
        
        # 5. PHYSICS LAYER
        print("\n5. RUNNING PHYSICS LAYER")
        print("-" * 40)
        physics_layer = PhysicsConstraintLayer(config)
        physics_layer.eval()
        
        # Fix batch dimension
        if temporal_output['embeddings']['building'].dim() == 2:
            for key in temporal_output['embeddings']:
                temporal_output['embeddings'][key] = temporal_output['embeddings'][key].unsqueeze(0)
        
        num_buildings = data['num_buildings']
        sharing_proposals = torch.rand(1, num_buildings, num_buildings) * 5
        sharing_proposals = (sharing_proposals + sharing_proposals.transpose(1, 2)) / 2
        
        physics_metadata = {
            'lv_group_ids': data['lv_group_ids'],
            'valid_lv_mask': data['valid_lv_mask'],
            'positions': data['positions'],
            'temporal_states': temporal_output.get('temporal_encoding')
        }
        
        with torch.no_grad():
            physics_output = physics_layer(
                temporal_output['embeddings'],
                sharing_proposals,
                data['consumption'],
                data['generation'],
                physics_metadata
            )
        print(f"✓ Feasible sharing: {physics_output['feasible_sharing'].shape}")
        
        # 6. TASK HEADS WITH INTERVENTIONS
        print("\n6. RUNNING TASK HEADS WITH INTERVENTIONS")
        print("-" * 40)
        
        # Create proper intervention data
        building_features_for_interventions = {
            'roof_area': {},
            'orientation_score': {},
            'building_age': {},
            'energy_label': {},
            'peak_demand': {},
            'heating_demand': {},
            'consumption_history': torch.randn(num_buildings, 24)
        }
        
        for i in range(num_buildings):
            building_features_for_interventions['roof_area'][i] = data['roof_areas'].get(i, 100)
            building_features_for_interventions['orientation_score'][i] = 0.7 + np.random.random() * 0.3
            building_features_for_interventions['building_age'][i] = data['building_ages'].get(i, 30)
            building_features_for_interventions['energy_label'][i] = data['energy_labels'].get(i, 'D')
            building_features_for_interventions['peak_demand'][i] = data['consumption'][0, i].item() * 1.5
            building_features_for_interventions['heating_demand'][i] = data['consumption'][0, i].item() * 0.4
        
        current_assets_dict = {
            'has_solar': {i: bool(data['has_solar'][i].item()) for i in range(num_buildings)},
            'has_battery': {i: bool(data['has_battery'][i].item()) for i in range(num_buildings)}
        }
        
        task_metadata = {
            'lv_group_ids': data['lv_group_ids'],
            'positions': data['positions'],
            'generation': data['generation'],
            'consumption': data['consumption'],
            'complementarity_matrix': attention_output['complementarity_matrix'],  # ✅ NOW DEFINED
            'building_types': data['building_types'],
            'building_features': building_features_for_interventions,
            'current_assets': current_assets_dict
        }
        
        task_heads = create_energy_task_heads(config)
        task_heads.eval()
        
        with torch.no_grad():
            task_output = task_heads(
                physics_output['feasible_embeddings'],
                task_metadata,
                current_hour=14
            )
        
        # 7. ANALYZE RESULTS
        print("\n" + "="*70)
        print("COMPLETE RESULTS WITH INTERVENTIONS")
        print("="*70)
        
        # Clustering
        print("\n📊 CLUSTERING")
        print(f"Clusters formed: {task_output['clustering']['num_clusters']}")
        
        # Energy Sharing
        print("\n⚡ ENERGY SHARING")
        print(f"Total shared: {task_output['sharing']['total_shared_kw']:.1f} kW")
        
        # Metrics
        print("\n📈 PERFORMANCE METRICS")
        print(f"Avg Self-Sufficiency: {task_output['summary']['avg_self_sufficiency']:.1%}")
        print(f"Avg Peak Reduction: {task_output['summary']['avg_peak_reduction']:.1%}")
        print(f"Carbon Saved: {task_output['summary']['total_carbon_saved_kg']:.1f} kg/day")
        
        # INTERVENTIONS
        print("\n🔧 INTERVENTION RECOMMENDATIONS")
        print("-" * 40)
        recommendations = task_output['recommendations']
        
        if recommendations:
            print(f"Total recommendations: {len(recommendations)}")
            
            # Count by type
            solar_count = sum(1 for r in recommendations if r.intervention_type == 'solar')
            battery_count = sum(1 for r in recommendations if r.intervention_type == 'battery')
            retrofit_count = sum(1 for r in recommendations if r.intervention_type == 'retrofit')
            
            print(f"\nBy Type:")
            print(f"  🌞 Solar panels: {solar_count}")
            print(f"  🔋 Batteries: {battery_count}")
            print(f"  🏠 Retrofits: {retrofit_count}")
            
            print(f"\n📋 Top 5 Interventions:")
            
            for i, rec in enumerate(recommendations[:5]):
                icons = {'solar': '🌞', 'battery': '🔋', 'retrofit': '🏠'}
                print(f"\n{i+1}. {icons.get(rec.intervention_type, '🔧')} Building {rec.building_id}")
                print(f"   Type: {rec.intervention_type.upper()}")
                
                if rec.intervention_type == 'solar':
                    print(f"   Capacity: {rec.capacity:.1f} kWp")
                elif rec.intervention_type == 'battery':
                    print(f"   Capacity: {rec.capacity:.1f} kWh")
                else:
                    print(f"   Level: {'Full' if rec.capacity > 0.8 else 'Partial'}")
                
                print(f"   SSR Impact: +{rec.impact_ssr*100:.1f}%")
                print(f"   Peak Impact: -{rec.impact_peak*100:.1f}%")
                print(f"   ROI: {rec.roi_years:.1f} years")
                print(f"   Confidence: {rec.confidence:.1%}")
            
            # Summary
            total_ssr_impact = sum(r.impact_ssr for r in recommendations)
            avg_roi = np.mean([r.roi_years for r in recommendations])
            
            print(f"\n📊 Summary:")
            print(f"  Total SSR potential: +{total_ssr_impact*100:.1f}%")
            print(f"  Average ROI: {avg_roi:.1f} years")
            
            # Best ROI
            best_roi = min(recommendations, key=lambda x: x.roi_years)
            print(f"\n🏆 Best ROI: Building {best_roi.building_id} "
                  f"({best_roi.intervention_type}, {best_roi.roi_years:.1f} years)")
        else:
            print("No interventions recommended")
        
        print("\n" + "="*70)
        print("✅ COMPLETE PIPELINE WITH INTERVENTIONS SUCCESSFUL!")
        print("="*70)
        
        print("\n🎉 ALL COMPONENTS WORKING:")
        print("1. Neo4j KG ✓")
        print("2. Base GNN ✓")
        print("3. Attention ✓")
        print("4. Temporal ✓")
        print("5. Physics ✓")
        print("6. Task Heads ✓")
        print("7. Interventions ✓")
        
        fetcher.close()
        return task_output
        
    except Exception as e:
        logger.error(f"Test failed: {str(e)}")
        import traceback
        traceback.print_exc()
        return None


# Run it!
if __name__ == "__main__":
    result = test_complete_pipeline_with_interventions()

2025-08-21 03:04:43,752 - __main__ - INFO - Connected to Neo4j for complete pipeline testing
2025-08-21 03:04:43,753 - __main__ - INFO - Fetching building data...



COMPLETE PIPELINE TEST WITH INTERVENTIONS

1. FETCHING DATA FROM NEO4J
----------------------------------------


2025-08-21 03:04:45,832 - __main__ - INFO - Fetched 200 buildings
2025-08-21 03:04:45,833 - __main__ - INFO - Fetching cable groups...
2025-08-21 03:04:45,845 - __main__ - INFO - Fetched 142 cable groups
2025-08-21 03:04:45,845 - __main__ - INFO - Fetching transformers...
2025-08-21 03:04:45,850 - __main__ - INFO - Fetched 49 transformers
2025-08-21 03:04:45,851 - __main__ - INFO - Fetching graph edges...
2025-08-21 03:04:45,878 - __main__ - INFO - Fetched edges: 200 B->C, 111 C->T, 16 adjacencies
2025-08-21 03:04:45,879 - __main__ - INFO - Preparing tensors: 200 buildings, 142 cable groups, 49 transformers
2025-08-21 03:04:45,893 - __main__ - INFO - Created synthetic data: 26 buildings with solar, 5 with batteries
2025-08-21 03:04:45,916 - models.base_gnn - INFO - Initialized EnergyGNNBase with 3 layers
2025-08-21 03:04:45,918 - models.base_gnn - INFO - Created EnergyGNNBase with 417,780 parameters
2025-08-21 03:04:45,918 - models.base_gnn - INFO - Trainable parameters: 417,780
2025-0

✓ Loaded 200 buildings
✓ Buildings with solar: 26

2. RUNNING BASE GNN LAYER
----------------------------------------
✓ Building embeddings: torch.Size([200, 128])

3. RUNNING ATTENTION LAYER
----------------------------------------
✓ Enhanced embeddings: torch.Size([200, 128])
✓ Complementarity matrix: torch.Size([200, 200])

4. RUNNING TEMPORAL LAYER
----------------------------------------
✓ Temporal embeddings: torch.Size([200, 128])

5. RUNNING PHYSICS LAYER
----------------------------------------
✓ Feasible sharing: torch.Size([1, 200, 200])

6. RUNNING TASK HEADS WITH INTERVENTIONS
----------------------------------------

COMPLETE RESULTS WITH INTERVENTIONS

📊 CLUSTERING
Clusters formed: 36

⚡ ENERGY SHARING
Total shared: 0.0 kW

📈 PERFORMANCE METRICS
Avg Self-Sufficiency: 9.0%
Avg Peak Reduction: 0.0%
Carbon Saved: 434.0 kg/day

🔧 INTERVENTION RECOMMENDATIONS
----------------------------------------
Total recommendations: 10

By Type:
  🌞 Solar panels: 7
  🔋 Batteries: 3
  🏠 

In [3]:
# Quick start training
from training.train_energy_gnn import train_energy_gnn

# Use your Neo4j data
model, trainer = train_energy_gnn(
    neo4j_data=data,  # Your prepared data
    config=config,
    epochs=100
)

# Check results
print(f"Final SSR: {trainer.train_history['metrics']['avg_ssr'][-1]:.1%}")
print(f"Final Peak Reduction: {trainer.train_history['metrics']['avg_peak_reduction'][-1]:.1%}")

NameError: name 'data' is not defined

In [5]:
# complete_training.py
"""
Complete self-contained training script for Energy GNN
No external imports needed - everything included
"""

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from neo4j import GraphDatabase
import logging
from typing import Dict, List, Tuple, Optional
from datetime import datetime
import matplotlib.pyplot as plt
from tqdm import tqdm

# Import your model components
from models.base_gnn import create_energy_gnn_base
from models.attention_layers import EnergyComplementarityAttention
from models.temporal_layers import TemporalProcessor
from models.physics_layers import PhysicsConstraintLayer
from models.task_heads import create_energy_task_heads

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# ============================================
# DATA LOADING COMPONENTS
# ============================================

class Neo4jCompleteDataFetcher:
    """Fetch all required data from Neo4j for complete pipeline"""
    
    def __init__(self, uri, user, password):
        self.driver = GraphDatabase.driver(uri, auth=(user, password))
        logger.info("Connected to Neo4j for data fetching")
    
    def fetch_graph_data(self, limit_buildings: int = 200):
        """Fetch complete graph structure and features"""
        
        with self.driver.session() as session:
            # Fetch buildings
            logger.info("Fetching building data...")
            buildings_result = session.run("""
                MATCH (b:Building)-[:CONNECTED_TO]->(cg:CableGroup {voltage_level: 'LV'})
                OPTIONAL MATCH (cg)-[:CONNECTS_TO]->(t:Transformer)
                OPTIONAL MATCH (b)-[:ADJACENT_TO]-(neighbor:Building)
                RETURN 
                    b.ogc_fid as building_id,
                    b.x as x,
                    b.y as y,
                    b.area as area,
                    b.height as height,
                    cg.group_id as lv_group_id,
                    CASE WHEN t IS NOT NULL THEN t.transformer_id ELSE null END as transformer_id,
                    COUNT(DISTINCT neighbor) as neighbor_count
                ORDER BY building_id
                LIMIT $limit
            """, limit=limit_buildings)
            
            buildings = list(buildings_result)
            logger.info(f"Fetched {len(buildings)} buildings")
            
            # Fetch cable groups
            cable_groups_result = session.run("""
                MATCH (cg:CableGroup {voltage_level: 'LV'})
                OPTIONAL MATCH (cg)-[:CONNECTS_TO]->(t:Transformer)
                OPTIONAL MATCH (b:Building)-[:CONNECTED_TO]->(cg)
                RETURN 
                    cg.group_id as cable_group_id,
                    COUNT(DISTINCT b) as building_count,
                    CASE WHEN t IS NOT NULL THEN t.transformer_id ELSE null END as transformer_id,
                    CASE WHEN t IS NOT NULL THEN true ELSE false END as has_transformer
                ORDER BY building_count DESC
            """)
            
            cable_groups = list(cable_groups_result)
            
            # Fetch transformers
            transformers_result = session.run("""
                MATCH (t:Transformer)
                OPTIONAL MATCH (cg:CableGroup)-[:CONNECTS_TO]->(t)
                RETURN 
                    t.transformer_id as transformer_id,
                    COUNT(DISTINCT cg) as cable_group_count
            """)
            
            transformers = list(transformers_result)
            
            # Fetch edges
            building_to_cable_result = session.run("""
                MATCH (b:Building)-[:CONNECTED_TO]->(cg:CableGroup {voltage_level: 'LV'})
                WHERE b.ogc_fid IN $building_ids
                RETURN b.ogc_fid as building_id, cg.group_id as cable_group_id
            """, building_ids=[b['building_id'] for b in buildings])
            
            building_to_cable_edges = list(building_to_cable_result)
            
            cable_to_transformer_result = session.run("""
                MATCH (cg:CableGroup {voltage_level: 'LV'})-[:CONNECTS_TO]->(t:Transformer)
                RETURN cg.group_id as cable_group_id, t.transformer_id as transformer_id
            """)
            
            cable_to_transformer_edges = list(cable_to_transformer_result)
            
            return {
                'buildings': buildings,
                'cable_groups': cable_groups,
                'transformers': transformers,
                'building_to_cable_edges': building_to_cable_edges,
                'cable_to_transformer_edges': cable_to_transformer_edges
            }
    
    def close(self):
        self.driver.close()


def safe_get(dictionary, key, default_value):
    """Safely get value from dictionary"""
    value = dictionary.get(key)
    if value is None:
        return default_value
    return value


def prepare_graph_tensors(graph_data: Dict) -> Dict:
    """Convert Neo4j data to tensors"""
    
    buildings = graph_data['buildings']
    cable_groups = graph_data['cable_groups']
    transformers = graph_data['transformers']
    
    num_buildings = len(buildings)
    num_cable_groups = len(cable_groups)
    num_transformers = len(transformers)
    
    # Create ID mappings
    building_id_to_idx = {b['building_id']: i for i, b in enumerate(buildings)}
    cable_id_to_idx = {c['cable_group_id']: i for i, c in enumerate(cable_groups)}
    transformer_id_to_idx = {t['transformer_id']: i for i, t in enumerate(transformers)}
    
    # Create LV group mappings
    valid_cable_groups = [c for c in cable_groups if c['has_transformer']]
    lv_group_to_idx = {c['cable_group_id']: i for i, c in enumerate(valid_cable_groups)}
    
    # Prepare features
    building_features = torch.zeros(num_buildings, 17)
    positions = torch.zeros(num_buildings, 2)
    lv_group_ids = torch.full((num_buildings,), -1, dtype=torch.long)
    valid_lv_mask = torch.zeros(num_buildings)
    has_solar = torch.zeros(num_buildings)
    has_battery = torch.zeros(num_buildings)
    
    for i, building in enumerate(buildings):
        # Position
        positions[i, 0] = safe_get(building, 'x', 0.0)
        positions[i, 1] = safe_get(building, 'y', 0.0)
        
        # Features
        area = safe_get(building, 'area', 100.0)
        height = safe_get(building, 'height', 10.0)
        
        building_features[i, 0] = area / 500.0 if area else 0.2
        building_features[i, 1] = height / 30.0 if height else 0.33
        building_features[i, 2] = 0.2  # floors
        building_features[i, 3] = 0.44  # age
        building_features[i, 6] = safe_get(building, 'neighbor_count', 0) / 10.0
        
        # Building type based on area
        if area and area > 500:
            building_features[i, 7] = 1.0  # office
        elif area and area > 200:
            building_features[i, 7] = 2.0  # retail
        else:
            building_features[i, 7] = 0.0  # residential
        
        # LV group assignment
        cable_group_id = building['lv_group_id']
        if cable_group_id in lv_group_to_idx:
            lv_group_ids[i] = lv_group_to_idx[cable_group_id]
            valid_lv_mask[i] = 1.0
        
        # Synthetic solar/battery
        if area and area > 300:
            has_solar[i] = np.random.random() > 0.7
        has_battery[i] = has_solar[i] * (np.random.random() > 0.8)
        
        building_features[i, 4] = has_solar[i]
        building_features[i, 5] = has_battery[i]
    
    # Cable group features
    cable_features = torch.zeros(num_cable_groups, 4)
    for i, cable in enumerate(cable_groups):
        cable_features[i, 0] = cable['building_count'] / 50.0
        cable_features[i, 1] = 1.0 if cable['has_transformer'] else 0.0
    
    # Transformer features
    transformer_features = torch.zeros(num_transformers, 3)
    for i, transformer in enumerate(transformers):
        transformer_features[i, 0] = transformer['cable_group_count'] / 10.0
        transformer_features[i, 1] = 0.25  # Assumed capacity
        transformer_features[i, 2] = 0.95  # Efficiency
    
    # Edge indices
    edge_index_b2c = torch.zeros(2, len(graph_data['building_to_cable_edges']), dtype=torch.long)
    valid_count = 0
    for edge in graph_data['building_to_cable_edges']:
        if edge['building_id'] in building_id_to_idx and edge['cable_group_id'] in cable_id_to_idx:
            edge_index_b2c[0, valid_count] = building_id_to_idx[edge['building_id']]
            edge_index_b2c[1, valid_count] = cable_id_to_idx[edge['cable_group_id']]
            valid_count += 1
    edge_index_b2c = edge_index_b2c[:, :valid_count]
    
    edge_index_c2t = torch.zeros(2, len(graph_data['cable_to_transformer_edges']), dtype=torch.long)
    valid_count = 0
    for edge in graph_data['cable_to_transformer_edges']:
        if edge['cable_group_id'] in cable_id_to_idx and edge['transformer_id'] in transformer_id_to_idx:
            edge_index_c2t[0, valid_count] = cable_id_to_idx[edge['cable_group_id']]
            edge_index_c2t[1, valid_count] = transformer_id_to_idx[edge['transformer_id']]
            valid_count += 1
    edge_index_c2t = edge_index_c2t[:, :valid_count]
    
    # Synthetic consumption/generation
    consumption = torch.zeros(1, num_buildings)
    generation = torch.zeros(1, num_buildings)
    
    for i in range(num_buildings):
        btype = building_features[i, 7].item()
        if btype == 1:  # office
            base_consumption = 15.0
        elif btype == 2:  # retail
            base_consumption = 20.0
        else:
            base_consumption = 8.0
        
        consumption[0, i] = base_consumption * (0.5 + building_features[i, 0].item() * 2) + np.random.randn() * 2
        consumption[0, i] = max(consumption[0, i], 1.0)
        
        if has_solar[i] > 0:
            generation[0, i] = min(area * 0.15 if area else 0, 50) * np.random.uniform(0.6, 1.0)
    
    return {
        'node_features': {
            'building': building_features,
            'cable_group': cable_features,
            'transformer': transformer_features
        },
        'edge_indices': {
            ('building', 'connected_to', 'cable_group'): edge_index_b2c,
            ('cable_group', 'connects_to', 'transformer'): edge_index_c2t
        },
        'positions': positions,
        'lv_group_ids': lv_group_ids,
        'valid_lv_mask': valid_lv_mask,
        'consumption': consumption,
        'generation': generation,
        'has_solar': has_solar,
        'has_battery': has_battery,
        'num_buildings': num_buildings,
        'num_cable_groups': num_cable_groups,
        'num_transformers': num_transformers
    }


# ============================================
# TRAINING COMPONENTS
# ============================================

class EnergyGNNModel(nn.Module):
    """Complete Energy Planning GNN Model"""
    
    def __init__(self, config: Dict):
        super().__init__()
        
        self.base_gnn = create_energy_gnn_base(config)
        self.attention_layer = EnergyComplementarityAttention(config)
        self.temporal_processor = TemporalProcessor(config)
        self.physics_layer = PhysicsConstraintLayer(config)
        self.task_heads = create_energy_task_heads(config)
        
        self.config = config
        
    def forward(self, data: Dict, current_hour: int = 14) -> Dict:
        """Forward pass through all layers"""
        
        # 1. Base GNN
        base_output = self.base_gnn(
            data['node_features'],
            data['edge_indices']
        )
        
        # 2. Attention
        attention_output = self.attention_layer(
            base_output,
            data['edge_indices'],
            return_attention=False
        )
        
        # 3. Temporal
        temporal_output = self.temporal_processor(
            attention_output['embeddings'],
            temporal_data=data.get('temporal_data'),
            current_hour=current_hour,
            return_all_hours=False
        )
        
        # Ensure batch dimension
        if temporal_output['embeddings']['building'].dim() == 2:
            for key in temporal_output['embeddings']:
                temporal_output['embeddings'][key] = temporal_output['embeddings'][key].unsqueeze(0)
        
        # 4. Physics
        num_buildings = data['num_buildings']
        sharing_proposals = torch.rand(1, num_buildings, num_buildings) * 5
        sharing_proposals = (sharing_proposals + sharing_proposals.transpose(1, 2)) / 2
        
        physics_metadata = {
            'lv_group_ids': data['lv_group_ids'],
            'valid_lv_mask': data['valid_lv_mask'],
            'positions': data['positions'],
            'temporal_states': temporal_output.get('temporal_encoding')
        }
        
        physics_output = self.physics_layer(
            temporal_output['embeddings'],
            sharing_proposals,
            data['consumption'],
            data['generation'],
            physics_metadata
        )
        
        # 5. Task Heads
        task_metadata = {
            'lv_group_ids': data['lv_group_ids'],
            'positions': data['positions'],
            'generation': data['generation'],
            'consumption': data['consumption'],
            'complementarity_matrix': attention_output['complementarity_matrix'],
            'building_features': {},
            'current_assets': {}
        }
        
        task_output = self.task_heads(
            physics_output['feasible_embeddings'],
            task_metadata,
            current_hour=current_hour
        )
        
        return {
            'tasks': task_output,
            'physics_penalty': physics_output['total_penalty']
        }


class SimpleLoss(nn.Module):
    """Simplified loss for training"""
    
    def __init__(self):
        super().__init__()
        
    def forward(self, outputs, data):
        # Physics penalty
        physics_loss = outputs['physics_penalty']
        
        # Self-sufficiency (maximize local generation use)
        if 'tasks' in outputs and 'summary' in outputs['tasks']:
            summary = outputs['tasks']['summary']
            # Convert SSR to loss (we want to maximize, so negate)
            ssr_loss = 1.0 - summary['avg_self_sufficiency']
            # Peak reduction loss
            peak_loss = 1.0 - summary['avg_peak_reduction']
        else:
            ssr_loss = torch.tensor(1.0)
            peak_loss = torch.tensor(1.0)
        
        total_loss = physics_loss + ssr_loss + peak_loss
        
        return total_loss, {
            'physics': physics_loss.item() if hasattr(physics_loss, 'item') else physics_loss,
            'ssr': ssr_loss.item() if hasattr(ssr_loss, 'item') else ssr_loss,
            'peak': peak_loss.item() if hasattr(peak_loss, 'item') else peak_loss,
            'total': total_loss.item() if hasattr(total_loss, 'item') else total_loss
        }


# ============================================
# MAIN TRAINING FUNCTION
# ============================================

def train_energy_gnn():
    """Main training function"""
    
    print("\n" + "="*60)
    print("ENERGY GNN TRAINING")
    print("="*60 + "\n")
    
    # Configuration
    config = {
        'num_building_features': 17,
        'num_cable_features': 8,
        'num_transformer_features': 5,
        'num_cluster_features': 5,
        'hidden_dim': 128,
        'num_layers': 3,
        'dropout': 0.1,
        'attention_heads': 8,
        'min_cluster_size': 3,
        'max_cluster_size': 15,
        'max_recommendations': 0,  # Disable for training
        'carbon_intensity': 0.4,
        'temporal_dim': 24,
        'enforce_hard_boundaries': True,
        'check_balance': True,
        'apply_losses': True,
        'validate_temporal': False
    }
    
    # Load data
    print("Loading data from Neo4j...")
    neo4j_uri = "bolt://localhost:7687"
    neo4j_user = "neo4j"
    neo4j_password = "aminasad"
    
    fetcher = Neo4jCompleteDataFetcher(neo4j_uri, neo4j_user, neo4j_password)
    graph_data = fetcher.fetch_graph_data(limit_buildings=200)
    data = prepare_graph_tensors(graph_data)
    fetcher.close()
    
    print(f"✓ Loaded {data['num_buildings']} buildings")
    print(f"✓ Buildings with solar: {data['has_solar'].sum().item():.0f}\n")
    
    # Add temporal data
    data['temporal_data'] = {
        'consumption_history': torch.randn(1, data['num_buildings'], 24, 8),
        'season': torch.tensor(0),
        'is_weekend': torch.tensor(False)
    }
    
    # Create model
    print("Initializing model...")
    model = EnergyGNNModel(config)
    
    # Loss and optimizer
    criterion = SimpleLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    # Training
    epochs = 50
    print(f"Training for {epochs} epochs...\n")
    
    history = {'loss': [], 'metrics': {}}
    
    for epoch in range(epochs):
        model.train()
        
        # Forward
        outputs = model(data)
        loss, metrics = criterion(outputs, data)
        
        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Store history
        history['loss'].append(metrics['total'])
        
        # Print progress
        if epoch % 10 == 0:
            print(f"Epoch {epoch:3d}/{epochs}")
            print(f"  Loss: {metrics['total']:.4f}")
            print(f"  Components: Physics={metrics['physics']:.3f}, SSR={metrics['ssr']:.3f}, Peak={metrics['peak']:.3f}")
            
            if 'tasks' in outputs and 'summary' in outputs['tasks']:
                s = outputs['tasks']['summary']
                print(f"  Performance: SSR={s['avg_self_sufficiency']:.1%}, Shared={s['total_energy_shared_kw']:.1f}kW")
            print()
    
    print("✅ Training complete!")
    
    # Save model
    torch.save(model.state_dict(), 'trained_model.pth')
    print("✅ Model saved to 'trained_model.pth'\n")
    
    # Final evaluation
    model.eval()
    with torch.no_grad():
        final_output = model(data)
        if 'tasks' in final_output and 'summary' in final_output['tasks']:
            summary = final_output['tasks']['summary']
            print("Final Performance:")
            print(f"  Self-Sufficiency: {summary['avg_self_sufficiency']:.1%}")
            print(f"  Peak Reduction: {summary['avg_peak_reduction']:.1%}")
            print(f"  Energy Shared: {summary['total_energy_shared_kw']:.1f} kW")
            print(f"  Carbon Saved: {summary['total_carbon_saved_kg']:.1f} kg/day")
    
    return model, history


# ============================================
# RUN TRAINING
# ============================================

if __name__ == "__main__":
    model, history = train_energy_gnn()

INFO:__main__:Connected to Neo4j for data fetching
INFO:__main__:Fetching building data...



ENERGY GNN TRAINING

Loading data from Neo4j...


INFO:__main__:Fetched 200 buildings
INFO:models.base_gnn:Initialized EnergyGNNBase with 3 layers


✓ Loaded 200 buildings
✓ Buildings with solar: 21

Initializing model...


INFO:models.base_gnn:Created EnergyGNNBase with 417,780 parameters
INFO:models.base_gnn:Trainable parameters: 417,780
INFO:models.attention_layers:Initialized EnergyComplementarityAttention
INFO:models.temporal_layers:Initialized TemporalProcessor with all components
INFO:models.physics_layers:Initialized PhysicsConstraintLayer
INFO:models.task_heads:Initialized DynamicSubClusteringHead
INFO:models.task_heads:Initialized EnergySharingPredictor
INFO:models.task_heads:Initialized SelfSufficiencyMetricsCalculator
INFO:models.task_heads:Initialized InterventionRecommender
INFO:models.task_heads:Initialized EnergyTaskHeads with all components


Training for 50 epochs...

Epoch   0/50
  Loss: 703.4460
  Components: Physics=701.455, SSR=0.991, Peak=1.000
  Performance: SSR=0.9%, Shared=0.0kW

Epoch  10/50
  Loss: 690.2335
  Components: Physics=688.243, SSR=0.990, Peak=1.000
  Performance: SSR=1.0%, Shared=0.0kW

Epoch  20/50
  Loss: 690.1982
  Components: Physics=688.207, SSR=0.991, Peak=1.000
  Performance: SSR=0.9%, Shared=0.0kW

Epoch  30/50
  Loss: 671.9907
  Components: Physics=670.000, SSR=0.991, Peak=1.000
  Performance: SSR=0.9%, Shared=0.0kW

Epoch  40/50
  Loss: 666.5448
  Components: Physics=664.554, SSR=0.991, Peak=1.000
  Performance: SSR=0.9%, Shared=0.0kW

✅ Training complete!
✅ Model saved to 'trained_model.pth'

Final Performance:
  Self-Sufficiency: 0.9%
  Peak Reduction: 0.0%
  Energy Shared: 0.0 kW
  Carbon Saved: 106.0 kg/day
