In [None]:
# main.py

import geopandas as gpd
import pandas as pd
import numpy as np
from grid_traffic_monitor import GridTrafficMonitor
from traffic_congestion_pred import TrafficCongestionPredictor
import time
from datetime import datetime
import h3
from IPython.display import display, clear_output
import warnings
import sys

# Suppress every warning for the rest of the process so the periodically
# refreshed notebook output stays readable. NOTE(review): this also hides
# deprecation and CRS warnings from the libraries imported above.
warnings.filterwarnings('ignore')

def prepare_features(traffic_df: pd.DataFrame, num_nodes: int, num_timesteps: int) -> np.ndarray:
    """
    Prepare features in the correct shape for the predictor.

    Builds a zero-initialised array of shape
    ``(1, num_timesteps, num_nodes, 3)`` whose last axis holds
    ``traffic_density``, ``time_of_day`` and ``day_of_week`` (in that order).
    The ``num_timesteps`` largest distinct ``time_of_day`` values are placed
    on the time axis in ascending order starting at index 0; when fewer
    distinct values exist, the trailing time slots stay zero. Hexagons are
    placed on the node axis in order of first appearance of their ``hex_id``
    in ``traffic_df``. If several rows share the same (time, hexagon) pair,
    the first one wins. Cells without data stay zero. Rows whose ``hex_id``
    or ``time_of_day`` is missing (NaN/None) are ignored.

    NOTE(review): "most recent" is decided by the raw ``time_of_day`` value.
    If that column is an hour-of-day, a window that spans midnight is not in
    chronological order -- confirm whether a real timestamp should be used.

    Args:
        traffic_df: DataFrame containing traffic data. Must provide the
            columns ``hex_id``, ``traffic_density``, ``time_of_day`` and
            ``day_of_week``.
        num_nodes: Number of nodes (hexagons) in the grid
        num_timesteps: Number of timesteps for prediction

    Returns:
        Properly shaped feature array

    Raises:
        ValueError: If a required column is missing, or if the selected time
            window references more distinct hexagons than ``num_nodes``.
    """
    print(f"\nPreparing features:")
    print(f"Input DataFrame shape: {traffic_df.shape}")
    print(f"Number of nodes: {num_nodes}")
    print(f"Number of timesteps: {num_timesteps}")

    # 'hex_id' is needed for node placement, so validate it together with
    # the feature columns instead of failing later with a bare KeyError.
    feature_columns = ['traffic_density', 'time_of_day', 'day_of_week']
    required_columns = feature_columns + ['hex_id']
    missing = [col for col in required_columns if col not in traffic_df.columns]
    if missing:
        raise ValueError(f"Missing one or more required columns: {missing}")

    # Initialize feature array
    features = np.zeros((1, num_timesteps, num_nodes, len(feature_columns)))

    # Rows without a usable key can never be placed in the grid.
    valid = traffic_df.dropna(subset=['hex_id', 'time_of_day'])

    if num_timesteps > 0 and not valid.empty:
        # Sorted distinct timestamps plus, for every row, the position of its
        # timestamp within that sorted array.
        timestamps, time_codes = np.unique(
            valid['time_of_day'].to_numpy(), return_inverse=True
        )
        time_codes = np.asarray(time_codes).reshape(-1)

        # Keep only the most recent `num_timesteps` timestamps; after the
        # shift, rows outside the window have a negative time index.
        first_kept = max(len(timestamps) - num_timesteps, 0)
        step_idx = time_codes - first_kept
        in_window = step_idx >= 0

        # Node index = order of first appearance of each hex_id.
        node_idx, _ = pd.factorize(valid['hex_id'])

        step_idx = step_idx[in_window]
        node_idx = node_idx[in_window]
        values = valid[feature_columns].to_numpy(dtype=float)[in_window]

        if node_idx.size:
            if node_idx.max() >= num_nodes:
                raise ValueError(
                    f"traffic_df references at least {int(node_idx.max()) + 1} "
                    f"distinct hex_id values, but num_nodes is {num_nodes}"
                )

            # Keep the first row for every (timestep, node) cell so that
            # duplicates resolve deterministically.
            cell_key = step_idx * num_nodes + node_idx
            _, first_rows = np.unique(cell_key, return_index=True)
            features[0][step_idx[first_rows], node_idx[first_rows]] = values[first_rows]

    print(f"Output feature array shape: {features.shape}")
    return features

def main(
    city_path: str = '/home/raw/Desktop/Coding/Jhakaas_Rasta/geopkg/Ahmedabad.gpkg',
    boundary_path: str = '/home/raw/Desktop/Coding/Jhakaas_Rasta/geopkg/clipping_boundary.geojson',
    update_interval: float = 300,
):
    """
    Run the simulated real-time traffic monitoring and prediction loop.

    Steps:
      1. Load the city layer through the monitor, clip it to the boundary
         layer (re-projecting the boundary to the city CRS if they differ).
      2. Initialise the hexagon grid and a predictor sized to that grid.
      3. Loop forever: simulate one density reading per hexagon, push it to
         the monitor, keep a rolling window of the last
         ``predictor.num_timesteps`` snapshots, attempt a prediction, print
         grid statistics, display the map and sleep.

    Any failure during steps 1-2 is printed and the function returns early.
    A failure inside the prediction step is printed and the loop continues.
    Any other failure inside the loop is printed and re-raised.
    ``KeyboardInterrupt`` stops the loop cleanly.

    Args:
        city_path: Path passed to ``monitor.load_city_boundary``.
        boundary_path: Path of the clipping boundary read with geopandas.
        update_interval: Seconds to sleep between updates (default 5 minutes).
    """
    # Initialize GridTrafficMonitor with more conservative parameters
    monitor = GridTrafficMonitor(
        base_resolution=8,
        min_resolution=7,
        max_resolution=10,
        min_traffic_density=100.0,
        max_merge_threshold=20.0,
        smoothing_factor=0.3
    )

    # Load city boundary
    try:
        print("Loading geographic data...")
        city_gdf = monitor.load_city_boundary(city_path)
        boundary_gdf = gpd.read_file(boundary_path)

        # Clipping requires both layers to share one coordinate system.
        if city_gdf.crs != boundary_gdf.crs:
            boundary_gdf = boundary_gdf.to_crs(city_gdf.crs)

        city_gdf = gpd.clip(city_gdf, boundary_gdf)
        print("Geographic data loaded successfully")
    except Exception as e:
        print(f"Error loading geographic data: {str(e)}")
        return

    # Initialize grid system
    try:
        print("Initializing grid system...")
        hex_polygons = monitor.initialize_grid(city_gdf)
        num_nodes = len(hex_polygons)
        print(f"Grid system initialized with {num_nodes} hexagons")
    except Exception as e:
        print(f"Error initializing grid system: {str(e)}")
        return

    # Initialize traffic predictor with correct number of nodes
    try:
        print("Initializing traffic predictor...")
        predictor = TrafficCongestionPredictor(
            num_nodes=num_nodes,  # Using the actual number of hexagons
            input_dim=3,  # traffic_density, time_of_day, day_of_week
            hidden_dims=[64, 32, 16],
            output_dim=1,
            num_timesteps=12,
            batch_size=32
        )
        print("Traffic predictor initialized successfully")
    except Exception as e:
        print(f"Error initializing traffic predictor: {str(e)}")
        return

    # Rolling window of per-update DataFrames (oldest first)
    historical_data = []

    # Simulate real-time updates
    try:
        # The clipped city layer never changes inside the loop, so the map
        # centre (centroid of its first geometry) is computed only once.
        first_centroid = city_gdf.geometry.centroid.iloc[0]
        center_lat = first_centroid.y
        center_lon = first_centroid.x

        while True:
            current_time = datetime.now()

            # Values shared by every hexagon in this update.
            # Daily sine profile: 100 at 06:00, 0 at 18:00, 50 at 00:00/12:00.
            base_density = 50 + 50 * np.sin(np.pi * current_time.hour / 12)
            time_of_day = current_time.hour + current_time.minute / 60
            day_of_week = current_time.weekday()

            # Simulate traffic updates. The keys are snapshotted because the
            # monitor is updated while they are being iterated.
            traffic_data = []
            hex_ids = list(monitor.current_grids.keys())

            print(f"Number of hexagons being processed: {len(hex_ids)}")

            for hex_id in hex_ids:
                # Gaussian noise (sd 10) around the profile, clipped at zero.
                noise = np.random.normal(0, 10)
                new_density = max(0, base_density + noise)

                monitor.update_traffic_density(hex_id, new_density)

                traffic_data.append({
                    'hex_id': hex_id,
                    'traffic_density': new_density,
                    'time_of_day': time_of_day,
                    'day_of_week': day_of_week
                })

            # Create DataFrame and store in historical data
            current_df = pd.DataFrame(traffic_data)
            print("\nTraffic DataFrame preview:")
            print(current_df.head())
            print(f"Total records in current update: {len(current_df)}")

            historical_data.append(current_df)

            # Keep only recent history
            if len(historical_data) > predictor.num_timesteps:
                historical_data.pop(0)

            # Combine historical data
            traffic_df = pd.concat(historical_data, ignore_index=True)
            print(f"\nCombined historical data shape: {traffic_df.shape}")

            # Create visualization first (so we always have it)
            m = monitor.visualize_grid(center_lat, center_lon)

            # Try predictions; a failure here must not stop the monitoring.
            try:
                # Only the adjacency matrix from prepare_data is used; the
                # model input is rebuilt by prepare_features below.
                _features, _targets, adj_matrix = predictor.prepare_data(
                    traffic_df,
                    hex_ids
                )

                shaped_features = prepare_features(
                    traffic_df,
                    num_nodes,
                    predictor.num_timesteps
                )

                predictions = predictor.predict(shaped_features, adj_matrix)

                # Display predictions
                print("\nPredicted traffic densities (next timestep):")
                for hex_id, pred in zip(hex_ids, predictions[0]):
                    print(f"{hex_id}: {pred[0]:.2f}")

            except Exception as e:
                print(f"Error in prediction step: {str(e)}")
                print("Continuing monitoring without predictions...")

            # Display statistics
            stats = monitor.get_grid_stats()
            print("\nGrid Statistics:")
            for key, value in stats.items():
                try:
                    print(f"{key}: {value:.2f}")
                except (TypeError, ValueError):
                    # A non-numeric statistic should not abort the loop.
                    print(f"{key}: {value}")

            # Display visualization (guaranteed to exist now)
            display(m)

            # Wait before next update, then clear the cell output so the next
            # iteration replaces it instead of appending.
            time.sleep(update_interval)
            clear_output(wait=True)

    except KeyboardInterrupt:
        print("\nMonitoring stopped by user")
    except Exception as e:
        print(f"Error during monitoring: {str(e)}")
        raise

# Start the monitoring loop only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()