In [2]:
import torch
import numpy as np
import pandas as pd
from pathlib import Path
import json
import pickle
import folium
from typing import List, Union, Optional, Tuple, Dict
import sys
from collections import defaultdict
from tqdm import tqdm


def load_all_real_trajectories(preprocessing_dir: Path, transport_mode: str) -> List[Dict]:
    """Load ALL real trajectories for a transport mode.

    Reads ``interpolated_trips.pkl`` from ``preprocessing_dir``, keeps the
    trips whose ``'category'`` equals ``transport_mode`` and converts each of
    them into a small summary dict.

    Args:
        preprocessing_dir: Directory containing ``interpolated_trips.pkl``
            (a ``Path`` or anything ``Path()`` accepts).
        transport_mode: Category value to select; compared with ``==``.

    Returns:
        One dict per matching trip with the keys ``trajectory`` (float32
        array made of columns 1-3 of the trip's ``gps_points``),
        ``trip_id``, ``user_id``, ``category``, ``trip_type``,
        ``original_duration``, ``length`` (number of rows in ``trajectory``)
        and ``weight`` (1.0 when the trip carries none). An empty list is
        returned when no trip matches.
    """
    trips_file = Path(preprocessing_dir) / 'interpolated_trips.pkl'

    # NOTE: pickle.load can execute arbitrary code embedded in the file.
    # Only point this at preprocessing output you produced yourself.
    with open(trips_file, 'rb') as f:
        interpolated_trips = pickle.load(f)

    # Filter by transport mode
    mode_trips = [t for t in interpolated_trips if t['category'] == transport_mode]

    if not mode_trips:
        print(f"No trips found for mode: {transport_mode}")
        return []

    print(f"Found {len(mode_trips)} {transport_mode} trajectories")

    real_trajectories = []
    for trip in mode_trips:
        # GPS points: [timestamp, lat, lon, speed]
        # np.asarray lets trips stored as plain lists of rows work as well
        # as trips that already hold an ndarray.
        gps_points = np.asarray(trip['gps_points'])
        # Drop the timestamp column, keep lat, lon, speed
        trajectory = gps_points[:, 1:4].astype(np.float32)

        real_trajectories.append({
            'trajectory': trajectory,
            'trip_id': trip['trip_id'],
            'user_id': trip['user_id'],
            'category': trip['category'],
            'trip_type': trip['trip_type'],
            'original_duration': trip['duration_minutes'],
            'length': len(trajectory),
            'weight': trip.get('weight', 1.0)
        })

    return real_trajectories

def _valid_latlon_points(trajectory) -> List[Tuple[float, float]]:
    """Return the usable (lat, lon) pairs of one trajectory as Python floats."""
    traj = np.asarray(trajectory)
    if traj.ndim != 2 or traj.shape[1] < 2:
        return []

    # Rows that are zero in every column are treated as padding.
    is_padding = np.all(traj == 0, axis=1)
    coords = traj[:, :2].astype(np.float64)
    # NaN/inf coordinates cannot be drawn and would poison the map centre.
    is_finite = np.isfinite(coords).all(axis=1)

    kept = coords[~is_padding & is_finite]
    return [(lat, lon) for lat, lon in kept.tolist()]


def create_all_real_trajectories_map(
    all_real_trajectories: Dict[str, List[Dict]],
    output_file: str = "all_real_trajectories.html",
    max_trajectories_per_mode: Optional[int] = None
) -> "Optional[folium.Map]":
    """
    Create a map with ALL real trajectories, colored by transport type.

    Each trajectory dict must provide ``'trajectory'``, a 2-D array whose
    first two columns are used as latitude and longitude. All-zero rows and
    rows with non-finite coordinates are dropped; a trajectory is drawn only
    if more than one point remains. The legend and the printed summary report
    the number of trajectories actually drawn per mode.

    Args:
        all_real_trajectories: Dict with transport mode as key and list of trajectory dicts as value
        output_file: Path to save the HTML file
        max_trajectories_per_mode: Optional limit on trajectories per mode (for performance).
            When a mode has more, a random sample without replacement is drawn.

    Returns:
        The saved ``folium.Map``, or ``None`` when there is nothing to draw
        (in which case no file is written).
    """

    # Define colors for each transport mode
    mode_colors = {
        'CAR': '#FF0000',        # Red
        'WALKING': '#00FF00',    # Green
        'BIKE': '#0000FF',       # Blue
        'PUBLIC_TRANSPORT': '#FF00FF',  # Magenta
        'MIXED': '#FFA500'       # Orange
    }
    default_color = '#808080'  # Gray for modes without a dedicated color

    # Collect all drawable trajectories with their modes
    all_trajectories = []
    all_lats = []
    all_lons = []
    displayed_per_mode = {}

    for mode, trajectories in all_real_trajectories.items():
        if not trajectories:
            continue

        # Limit trajectories per mode if specified
        mode_trajectories = trajectories
        if max_trajectories_per_mode and len(trajectories) > max_trajectories_per_mode:
            sample_indices = np.random.choice(len(trajectories), max_trajectories_per_mode, replace=False)
            mode_trajectories = [trajectories[i] for i in sample_indices]

        shown = 0
        for traj_info in mode_trajectories:
            points = _valid_latlon_points(traj_info['trajectory'])
            if len(points) < 2:  # A line needs at least two points
                continue

            all_trajectories.append({
                'points': points,
                'mode': mode,
                'color': mode_colors.get(mode, default_color),
                'label': f"{mode} - Trip ID: {traj_info.get('trip_id', 'Unknown')}"
            })

            # Collect coordinates for centering
            lats, lons = zip(*points)
            all_lats.extend(lats)
            all_lons.extend(lons)
            shown += 1

        displayed_per_mode[mode] = shown

    total_trajectories = len(all_trajectories)
    print(f"Creating map with {total_trajectories} real trajectories")

    if not all_trajectories:
        print("No valid trajectories found!")
        return None

    # Center the map on the mean of all drawn coordinates
    center = (float(np.mean(all_lats)), float(np.mean(all_lons)))

    # Create map
    m = folium.Map(location=center, zoom_start=10)

    # Add trajectories with thicker lines for real data
    for traj_data in all_trajectories:
        folium.PolyLine(
            traj_data['points'],
            color=traj_data['color'],
            weight=3,  # Slightly thicker for real trajectories
            opacity=0.8,  # Slightly more opaque
            popup=traj_data['label']
        ).add_to(m)

    # Create legend
    legend_items = []
    mode_counts = {}
    for mode, trajectories in all_real_trajectories.items():
        if not trajectories:
            continue

        actual_count = len(trajectories)
        # Count what was really drawn, not just the sampling cap: trajectories
        # with too few valid points are skipped above.
        displayed_count = displayed_per_mode.get(mode, 0)
        mode_counts[mode] = (displayed_count, actual_count)
        color = mode_colors.get(mode, default_color)

        # Total weight of every trajectory of this mode (drawn or not)
        total_weight = sum(t.get('weight', 1.0) for t in trajectories)

        # &#9679; is the filled-circle bullet, written as an HTML entity so
        # the legend does not depend on the source/file encoding.
        item = (
            f'<span style="color: {color};">&#9679;</span> {mode}: '
            f'{displayed_count:,} trajectories (weight: {total_weight:,.0f})'
        )
        if displayed_count < actual_count:
            item += f' [showing {displayed_count} of {actual_count:,} total]'
        legend_items.append(item)

    legend_body = '<br>'.join(legend_items)
    legend_html = f'''
    <div style="position: fixed; 
                top: 20px; right: 20px; width: 350px; height: auto;
                background-color: white; border:2px solid grey; z-index:9999; 
                font-size:14px; padding: 15px; border-radius: 5px;">
    <h4 style="margin: 0 0 10px 0;">Real Trajectories by Transport Mode</h4>
    <p style="margin: 5px 0;"><b>Total displayed: {total_trajectories:,} trajectories</b></p>
    {legend_body}
    <p style="margin: 10px 0 0 0; font-size: 12px; color: #666;">
    Click on any trajectory line for trip details
    </p>
    </div>
    '''
    m.get_root().html.add_child(folium.Element(legend_html))

    # Save map
    m.save(output_file)
    print(f"All real trajectories visualization saved to: {output_file}")

    # Print summary
    print("\nVisualization Summary:")
    for mode, (displayed, total) in mode_counts.items():
        print(f"  {mode}: {displayed:,} trajectories displayed (of {total:,} total)")

    return m


def load_and_visualize_all_real_trajectories(
    preprocessing_dir: Path,
    transport_modes: List[str] = ["CAR", "WALKING", "MIXED", "BIKE", "PUBLIC_TRANSPORT"],
    output_file: str = "all_real_trajectories.html",
    max_trajectories_per_mode: int = 500
):
    """
    Load the real trajectories of several transport modes and map them together.

    Every requested mode is loaded with ``load_all_real_trajectories``; modes
    that yield nothing are left out. The remaining ones are handed to
    ``create_all_real_trajectories_map`` and per-mode counts and weights are
    printed afterwards.

    Args:
        preprocessing_dir: Path to the preprocessing directory
        transport_modes: List of transport modes to include
        output_file: Path to save the HTML file
        max_trajectories_per_mode: Maximum trajectories to display per mode

    Returns:
        ``(map_obj, trajectories_by_mode)``; ``(None, {})`` when no mode
        produced any trajectory.
    """
    banner = "=" * 60

    print("Loading real trajectories for all transport modes...")
    print(banner)

    # Gather the non-empty modes, keeping the order they were requested in
    trips_by_mode = {}
    for mode_name in transport_modes:
        print(f"\nLoading {mode_name} trajectories...")
        loaded = load_all_real_trajectories(preprocessing_dir, mode_name)

        if not loaded:
            print(f"  No {mode_name} trajectories found")
            continue

        trips_by_mode[mode_name] = loaded
        print(f"  Loaded {len(loaded)} {mode_name} trajectories")

    # Nothing loaded at all: bail out before any map work
    if not trips_by_mode:
        print("No trajectories found to visualize!")
        return None, {}

    print(f"\n{banner}")
    print("Creating combined map with all real trajectories...")
    print(banner)

    combined_map = create_all_real_trajectories_map(
        trips_by_mode,
        output_file=output_file,
        max_trajectories_per_mode=max_trajectories_per_mode
    )

    # Per-mode and overall statistics (over everything loaded, not only what is drawn)
    print(f"\n{banner}")
    print("Real Trajectories Statistics:")
    print(banner)

    count_sum = 0
    weight_sum = 0
    for mode_name, trips in trips_by_mode.items():
        n_trips = len(trips)
        mode_weight = sum(trip.get('weight', 1.0) for trip in trips)
        count_sum += n_trips
        weight_sum += mode_weight

        print(f"{mode_name:20s}: {n_trips:6,} trajectories (weight: {mode_weight:8,.0f})")

    print(f"{'TOTAL':20s}: {count_sum:6,} trajectories (weight: {weight_sum:8,.0f})")

    return combined_map, trips_by_mode


# Script entry point: build the combined map of real trajectories
if __name__ == "__main__":
    # Preprocessed data lives one level above the working directory
    preprocessing_dir = Path("..") / "data" / "processed"

    # Load every mode and render them on a single map; the per-mode cap
    # keeps the generated HTML at a manageable size.
    map_obj, real_trajectories = load_and_visualize_all_real_trajectories(
        preprocessing_dir,
        transport_modes=["CAR", "WALKING", "MIXED", "BIKE", "PUBLIC_TRANSPORT"],
        max_trajectories_per_mode=500,
        output_file="all_real_trajectories.html",
    )

Loading real trajectories for all transport modes...

Loading CAR trajectories...
Found 25169 CAR trajectories
  Loaded 25169 CAR trajectories

Loading WALKING trajectories...
Found 19948 WALKING trajectories
  Loaded 19948 WALKING trajectories

Loading MIXED trajectories...
Found 7526 MIXED trajectories
  Loaded 7526 MIXED trajectories

Loading BIKE trajectories...
Found 5979 BIKE trajectories
  Loaded 5979 BIKE trajectories

Loading PUBLIC_TRANSPORT trajectories...
Found 9432 PUBLIC_TRANSPORT trajectories
  Loaded 9432 PUBLIC_TRANSPORT trajectories

Creating combined map with all real trajectories...
Creating map with 2500 real trajectories
All real trajectories visualization saved to: all_real_trajectories.html

Visualization Summary:
  CAR: 500 trajectories displayed (of 25,169 total)
  WALKING: 500 trajectories displayed (of 19,948 total)
  MIXED: 500 trajectories displayed (of 7,526 total)
  BIKE: 500 trajectories displayed (of 5,979 total)
  PUBLIC_TRANSPORT: 500 trajectories di