In [8]:
import io
import json
import os
import shutil
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import ipywidgets as widgets
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import torch
from IPython.display import Image as IPImage
from IPython.display import display, HTML
from IPython.display import display, IFrame
from matplotlib.animation import FuncAnimation
from matplotlib.animation import HTMLWriter
from matplotlib.patches import Circle
from PIL import Image
from PIL import Image as PILImage

In [9]:
class Entity:
    """One actor on the grid: a position plus a type label."""

    def __init__(self, x, y, entity_type):
        """Store the position ``(x, y)`` and the label exactly as given."""
        self.x, self.y = x, y
        self.type = entity_type

    def move(self, dx, dy):
        """Shift this entity in place by the offset ``(dx, dy)``."""
        self.x += dx
        self.y += dy

def detect_collision(entities):
    """Return every pair of entities that share the same ``(x, y)`` cell.

    Each unordered pair is reported once, as ``(earlier, later)`` by position
    in ``entities``; pairs are ordered by the earlier entity first, then by
    the later one.
    """
    return [
        (first, second)
        for idx, first in enumerate(entities)
        for second in entities[idx + 1:]
        if first.x == second.x and first.y == second.y
    ]

In [10]:
class SyntheticDataGenerator:
    """Generate random-walk scenarios of entities on a square integer grid.

    A scenario is a list of frames; each frame is a list of ``Entity``
    objects in a fixed order (vehicles, then pedestrians, then obstacles).
    Randomness comes from NumPy's global RNG (``np.random``), so seed it
    before calling ``generate_data`` if repeatable output is needed.
    """

    def __init__(self, grid_size=10, num_vehicles=5, num_pedestrians=5, num_obstacles=5):
        """Store the grid side length and the number of entities of each type."""
        self.grid_size = grid_size
        self.num_vehicles = num_vehicles
        self.num_pedestrians = num_pedestrians
        self.num_obstacles = num_obstacles

    def is_valid_move(self, new_x, new_y, entities, moving_entity):
        """Return True when ``(new_x, new_y)`` lies inside the grid.

        ``entities`` and ``moving_entity`` are accepted but not consulted:
        only the grid boundaries are checked, so moving onto an occupied
        cell is allowed.
        """
        in_x = 0 <= new_x < self.grid_size
        in_y = 0 <= new_y < self.grid_size
        # bool() keeps the result a plain Python bool even for NumPy inputs.
        return bool(in_x and in_y)

    def get_safe_move(self, entity):
        """Pick a random ``(dx, dy)`` step that keeps ``entity`` on the grid.

        The nine candidate steps (each of dx, dy in -1, 0, 1, including
        staying put) are shuffled and the first in-bounds one is returned.
        """
        possible_moves = [(dx, dy) for dx in [-1, 0, 1] for dy in [-1, 0, 1]]

        # Shuffle in place so the first valid candidate is a random one.
        np.random.shuffle(possible_moves)

        for dx, dy in possible_moves:
            if self.is_valid_move(entity.x + dx, entity.y + dy, [], entity):
                return dx, dy

        # Fallback: stay in place if no candidate was accepted.
        return 0, 0

    def _spawn_entities(self):
        """Create the initial frame with one uniformly random cell per entity."""
        counts = (
            ('vehicle', self.num_vehicles),
            ('pedestrian', self.num_pedestrians),
            ('obstacle', self.num_obstacles),
        )
        entities = []
        for entity_type, count in counts:
            for _ in range(count):
                x, y = np.random.randint(0, self.grid_size, size=2)
                # randint yields NumPy integers; store plain ints so positions
                # print and serialize as ordinary numbers downstream.
                entities.append(Entity(int(x), int(y), entity_type))
        return entities

    def generate_data(self, num_samples=100, num_steps=10):
        """Generate ``num_samples`` scenarios of random movement.

        Each scenario starts from a random initial frame and is followed by
        ``num_steps - 1`` further frames. In every step each entity --
        obstacles included -- takes one step chosen by ``get_safe_move``.
        Entities may start on, or move onto, the same cell.

        Returns:
            A list of scenarios; each scenario is a list of frames and each
            frame is a list of freshly created ``Entity`` objects.
        """
        data = []
        for _ in range(num_samples):
            entities = self._spawn_entities()
            scenario = [entities]

            for _ in range(num_steps - 1):
                next_entities = []
                for entity in entities:
                    dx, dy = self.get_safe_move(entity)
                    # Copy before moving so earlier frames stay untouched.
                    moved = Entity(entity.x, entity.y, entity.type)
                    moved.move(dx, dy)
                    next_entities.append(moved)
                scenario.append(next_entities)
                entities = next_entities

            data.append(scenario)
        return data

    def visualize_data(self, scenario, sample_id):
        """Render ``scenario`` as an HTML animation and return its path.

        The animation is written to
        ``animations/scenario_{sample_id}_allrandom.html`` (the directory is
        created if needed). Frames in which two entities share a cell get a
        translucent red circle at that cell, a "COLLISION!" caption and an
        extra legend entry.
        """
        # Marker styling per entity type; any other type is drawn as an obstacle.
        marker_styles = {
            'vehicle': {'color': 'blue', 'marker': '^'},
            'pedestrian': {'color': 'red', 'marker': 'o'},
        }
        obstacle_style = {'color': 'green', 'marker': 's'}

        # Pre-compute, per frame, the distinct cells where a collision occurs.
        # dict.fromkeys de-duplicates while keeping first-seen order, so a
        # cell shared by three or more entities is highlighted only once.
        collision_data = {}
        for frame, entities in enumerate(scenario):
            cells = list(dict.fromkeys(
                (first.x, first.y) for first, _ in detect_collision(entities)
            ))
            if cells:
                collision_data[frame] = cells

        # Wider main plot on the left, narrow legend-only axes on the right.
        fig = plt.figure(figsize=(8, 6))
        try:
            gs = gridspec.GridSpec(1, 2, width_ratios=[4, 1])
            ax = fig.add_subplot(gs[0])
            legend_ax = fig.add_subplot(gs[1])
            legend_ax.axis('off')  # Hide the legend axes frame

            def update(frame):
                ax.clear()
                legend_ax.clear()
                legend_ax.axis('off')

                # Empty, labelled scatters serve as legend handles.
                legend_elements = [
                    ax.scatter([], [], s=200, label='Vehicle', **marker_styles['vehicle']),
                    ax.scatter([], [], s=200, label='Pedestrian', **marker_styles['pedestrian']),
                    ax.scatter([], [], s=200, label='Obstacle', **obstacle_style),
                ]

                # Draw actual entities
                for entity in scenario[frame]:
                    style = marker_styles.get(entity.type, obstacle_style)
                    ax.scatter(entity.x, entity.y, s=200, **style)

                # Highlight collision cells and extend the legend accordingly.
                if frame in collision_data:
                    for x, y in collision_data[frame]:
                        ax.add_patch(Circle((x, y), radius=1, color='red', alpha=0.3))
                    ax.text(0.02, 0.98, 'COLLISION!', transform=ax.transAxes,
                            color='red', fontsize=12, verticalalignment='top')
                    # Created on ``ax`` explicitly rather than via pyplot's
                    # "current axes", which is whichever axes was made last.
                    legend_elements.append(
                        ax.scatter([], [], c='red', alpha=0.3, marker='o', s=500,
                                   label='Collision Area')
                    )

                # One cell per integer coordinate, with half a cell of margin.
                ax.set_xlim(-0.5, self.grid_size - 0.5)
                ax.set_ylim(-0.5, self.grid_size - 0.5)
                ax.set_title(f"Time Step {frame+1}")

                # Set major ticks at integer positions
                ax.set_xticks(range(self.grid_size))
                ax.set_yticks(range(self.grid_size))

                ax.grid(True, which='major', linestyle='-', linewidth=1)
                # NOTE(review): no minor ticks are configured here, so this
                # presumably draws nothing unless a minor locator is set.
                ax.grid(True, which='minor', linestyle=':', linewidth=0.5)

                ax.set_xlabel('X Coordinate')
                ax.set_ylabel('Y Coordinate')

                # Labels come from each handle's own ``label``.
                legend_ax.legend(handles=legend_elements,
                                 loc='center left',
                                 bbox_to_anchor=(0, 0.5))

                # Adjust layout to prevent overlap
                fig.tight_layout()

            ani = FuncAnimation(fig, update, frames=len(scenario), interval=500, repeat=False)

            # Create the animations folder if it doesn't exist
            os.makedirs('animations', exist_ok=True)

            # Save the animation as an HTML file
            html_path = f'animations/scenario_{sample_id}_allrandom.html'
            ani.save(html_path, writer=HTMLWriter(fps=2))
        finally:
            # Release the figure even if rendering or saving fails.
            plt.close(fig)
        return html_path

    def analyze_collisions(self, scenario):
        """List every collision in ``scenario``.

        Returns:
            A list of dicts, one per colliding pair, with keys ``'timestep'``
            (1-based frame number), ``'location'`` (an ``(x, y)`` tuple of
            plain ints) and ``'entities'`` (the two entity type labels).
        """
        collision_report = []
        for timestep, entities in enumerate(scenario, start=1):
            for entity1, entity2 in detect_collision(entities):
                collision_report.append({
                    'timestep': timestep,
                    # Plain ints keep the report printable and JSON-friendly
                    # even when positions are NumPy integers.
                    'location': (int(entity1.x), int(entity1.y)),
                    'entities': (entity1.type, entity2.type)
                })
        return collision_report

In [11]:
# Demo configuration: a 10x10 grid with one vehicle, one pedestrian and no obstacles
base_generator = SyntheticDataGenerator(
    grid_size=10,
    num_vehicles=1,
    num_pedestrians=1,
    num_obstacles=0,
)

In [14]:
# Seed NumPy's global RNG so the scenario, and therefore the collision
# report printed below, is reproducible across kernel restarts.
np.random.seed(42)
data = base_generator.generate_data(num_samples=1, num_steps=100)

for i, scenario in enumerate(data):
    print(f"\nScenario {i+1}:")
    html_path = base_generator.visualize_data(scenario, sample_id=i+1)
    display(IFrame(src=html_path, width=800, height=800))

    # Get and display detailed collision report
    collision_report = base_generator.analyze_collisions(scenario)
    if collision_report:
        print("\nCollision Report:")
        for collision in collision_report:
            # Format coordinates as plain ints: positions may be NumPy
            # integers, whose repr inside a tuple differs across NumPy versions.
            x, y = collision['location']
            print(f"Time Step {collision['timestep']}: "
                  f"{collision['entities'][0].capitalize()} collided with "
                  f"{collision['entities'][1]} at position ({int(x)}, {int(y)})")
    else:
        print("No collisions detected in this scenario")
    print("-" * 50)


Scenario 1:



Collision Report:
Time Step 13: Vehicle collided with pedestrian at position (7, 1)
--------------------------------------------------


In [4]:
def convert_to_serializable(obj):
    """Recursively replace NumPy values with native Python equivalents.

    Handles NumPy booleans, integers, floats and arrays, and descends into
    tuples, lists and dicts (converting dict keys as well as values, since
    ``json`` rejects NumPy scalars in either position). Tuples stay tuples;
    any other object is returned unchanged.
    """
    # np.bool_ is neither np.integer nor np.floating, so it needs its own case.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, tuple):
        return tuple(convert_to_serializable(item) for item in obj)
    if isinstance(obj, list):
        return [convert_to_serializable(item) for item in obj]
    if isinstance(obj, dict):
        return {
            convert_to_serializable(key): convert_to_serializable(value)
            for key, value in obj.items()
        }
    return obj


In [5]:
@dataclass
class Sample:
    """A window of consecutive scenario frames plus its collision label."""
    # Frames of the window in scenario order; each frame is a list of
    # entities. The window length is whatever the extractor was asked for.
    # "Entity" is a forward reference, so this cell does not need the
    # Entity class to exist at definition time.
    frames: List[List["Entity"]]
    has_collision: bool
    # Collision description; left as None for no-collision samples.
    collision_details: Optional[Dict] = None
    # Identifier assigned later via set_id(); None until then.
    sample_id: Optional[int] = None

    def set_id(self, new_id: int) -> "Sample":
        """Set the sample ID and return ``self`` so calls can be chained."""
        self.sample_id = new_id
        return self

In [6]:
class GNNDatasetGenerator:
    """Build an on-disk dataset of fixed-length frame windows.

    Windows are cut from scenarios produced by a ``SyntheticDataGenerator``
    and stored under ``traffic_dataset/`` in two sub-directories:
    ``collision`` (only the final frame of the window contains a collision)
    and ``no_collision`` (no frame contains one). Each sample directory holds
    one PNG per frame plus a ``metadata.json``.
    """

    def __init__(self, base_generator: "SyntheticDataGenerator"):
        """Remember the scenario generator and derive the output paths."""
        self.generator = base_generator
        self.dataset_root = "traffic_dataset"
        self.collision_dir = os.path.join(self.dataset_root, "collision")
        self.no_collision_dir = os.path.join(self.dataset_root, "no_collision")
        self.metadata_path = os.path.join(self.dataset_root, "metadata.json")

    def setup_directories(self):
        """Delete any existing dataset directory and recreate it empty."""
        if os.path.exists(self.dataset_root):
            shutil.rmtree(self.dataset_root)

        for directory in (self.dataset_root, self.collision_dir, self.no_collision_dir):
            os.makedirs(directory)

    def save_frame(self, entities: List["Entity"], path: str) -> None:
        """Render one frame to an image file at ``path``.

        Vehicles are blue triangles, pedestrians red circles and every other
        entity type a green square.
        """
        marker_styles = {
            'vehicle': {'color': 'blue', 'marker': '^'},
            'pedestrian': {'color': 'red', 'marker': 'o'},
        }
        obstacle_style = {'color': 'green', 'marker': 's'}

        fig, ax = plt.subplots(figsize=(8, 8))
        try:
            for entity in entities:
                style = marker_styles.get(entity.type, obstacle_style)
                ax.scatter(entity.x, entity.y, s=200, **style)

            # One cell per integer coordinate, with half a cell of margin.
            grid_size = self.generator.grid_size
            ax.set_xlim(-0.5, grid_size - 0.5)
            ax.set_ylim(-0.5, grid_size - 0.5)
            ax.grid(True)
            ax.set_xticks(range(grid_size))
            ax.set_yticks(range(grid_size))

            # Save through the figure itself, not pyplot's "current figure".
            fig.savefig(path)
        finally:
            # Always release the figure; thousands of frames are rendered.
            plt.close(fig)

    def check_sequence_validity(self, frames: List[List["Entity"]], check_all: bool = False) -> bool:
        """Check whether ``frames`` qualifies as a sample.

        - Collision samples (``check_all=False``): the last frame must contain
          a collision and no earlier frame may. An empty sequence is rejected.
        - No-collision samples (``check_all=True``): no frame may contain a
          collision.
        """
        collided = [len(detect_collision(frame)) > 0 for frame in frames]
        if check_all:
            return not any(collided)
        if not collided:
            # An empty window has no final frame to hold the collision.
            return False
        return collided[-1] and not any(collided[:-1])

    def _iter_windows(self, scenario, window_size):
        """Yield ``(last_frame_index, frames)`` for every full sliding window."""
        if window_size < 1:
            raise ValueError("window_size must be at least 1")
        for end in range(window_size - 1, len(scenario)):
            yield end, scenario[end - window_size + 1:end + 1]

    def extract_collision_samples(self, scenario: List[List["Entity"]], window_size: int = 7) -> List["Sample"]:
        """Extract every window whose only collision is in its last frame.

        Each returned sample carries ``collision_details`` with the scenario
        index of the colliding frame and, per colliding pair, the two entity
        types and the collision cell.

        Raises:
            ValueError: if ``window_size`` is smaller than 1.
        """
        samples = []
        for end, sequence in self._iter_windows(scenario, window_size):
            if not self.check_sequence_validity(sequence, check_all=False):
                continue
            collision_info = {
                "frame": int(end),
                "collisions": [
                    {
                        "types": (first.type, second.type),
                        # Plain ints so the details are JSON-serializable.
                        "location": (int(first.x), int(first.y))
                    } for first, second in detect_collision(sequence[-1])
                ]
            }
            samples.append(Sample(sequence, True, collision_info))
        return samples

    def extract_no_collision_samples(self, scenario: List[List["Entity"]], window_size: int = 7) -> List["Sample"]:
        """Extract every window in which no frame contains a collision.

        Raises:
            ValueError: if ``window_size`` is smaller than 1.
        """
        return [
            Sample(sequence, False, None)
            for _, sequence in self._iter_windows(scenario, window_size)
            if self.check_sequence_validity(sequence, check_all=True)
        ]

    def save_sample(self, sample: "Sample"):
        """Write a sample's frame images and its ``metadata.json``.

        Raises:
            ValueError: if the sample has no ID yet.
        """
        if sample.sample_id is None:
            raise ValueError("Sample ID must be set before saving")

        # Determine target directory
        base_dir = self.collision_dir if sample.has_collision else self.no_collision_dir
        sample_dir = os.path.join(base_dir, f"sample_{sample.sample_id:05d}")
        os.makedirs(sample_dir)

        # Save frames
        for i, frame in enumerate(sample.frames):
            self.save_frame(frame, os.path.join(sample_dir, f"frame_{i:02d}.png"))

        # Save sample-level metadata
        metadata = {
            "sample_id": int(sample.sample_id),
            "has_collision": sample.has_collision,
            "collision_details": convert_to_serializable(sample.collision_details)
        }

        with open(os.path.join(sample_dir, "metadata.json"), 'w') as f:
            json.dump(metadata, f, indent=2)

    def generate_dataset(self, num_samples: int = 100, num_steps: int = 500, window_size: int = 7):
        """Generate a dataset balanced between collision and no-collision samples.

        Scenarios of ``num_steps`` frames are generated until
        ``num_samples // 2`` samples of each category have been collected;
        the samples and a global ``metadata.json`` are then written to disk.
        Any existing dataset directory is deleted first. IDs start at 0
        independently in each category.

        Raises:
            ValueError: if the configuration can never yield the requested
                samples (window longer than a scenario, or fewer than two
                entities). This is checked before anything is deleted.
        """
        samples_per_category = num_samples // 2

        # Fail fast on settings that would keep the collection loop below
        # running forever, and do so before the old dataset is removed.
        if samples_per_category > 0:
            if window_size < 1:
                raise ValueError("window_size must be at least 1")
            # A scenario always holds at least its initial frame.
            if max(num_steps, 1) < window_size:
                raise ValueError(
                    "num_steps must be at least window_size to form a sample"
                )
            total_entities = (self.generator.num_vehicles
                              + self.generator.num_pedestrians
                              + self.generator.num_obstacles)
            if total_entities < 2:
                raise ValueError(
                    "at least two entities are required to produce collision samples"
                )

        self.setup_directories()

        collision_samples = []
        no_collision_samples = []

        while len(collision_samples) < samples_per_category or len(no_collision_samples) < samples_per_category:
            # Generate a long scenario
            scenario = self.generator.generate_data(num_samples=1, num_steps=num_steps)[0]

            # Extract collision samples if needed
            if len(collision_samples) < samples_per_category:
                remaining = samples_per_category - len(collision_samples)
                for sample in self.extract_collision_samples(scenario, window_size)[:remaining]:
                    sample.set_id(len(collision_samples))
                    collision_samples.append(sample)

            # Extract no-collision samples if needed
            if len(no_collision_samples) < samples_per_category:
                remaining = samples_per_category - len(no_collision_samples)
                for sample in self.extract_no_collision_samples(scenario, window_size)[:remaining]:
                    sample.set_id(len(no_collision_samples))
                    no_collision_samples.append(sample)

        # Save all samples
        for sample in collision_samples[:samples_per_category]:
            self.save_sample(sample)
        for sample in no_collision_samples[:samples_per_category]:
            self.save_sample(sample)

        # Save global metadata
        global_metadata = {
            # Count what was actually written: an odd num_samples is rounded
            # down to an even total by the per-category split.
            "total_samples": int(2 * samples_per_category),
            "collision_samples": int(samples_per_category),
            "no_collision_samples": int(samples_per_category),
            "window_size": int(window_size),
            "grid_size": int(self.generator.grid_size),
            "num_vehicles": int(self.generator.num_vehicles),
            "num_pedestrians": int(self.generator.num_pedestrians),
            "num_obstacles": int(self.generator.num_obstacles)
        }

        with open(self.metadata_path, 'w') as f:
            json.dump(global_metadata, f, indent=2)

In [7]:
# Seed NumPy's global RNG so the generated dataset is reproducible
np.random.seed(42)

# Initialize the base generator
base_generator = SyntheticDataGenerator(grid_size=10, num_vehicles=1, num_pedestrians=1, num_obstacles=0)

# Create the GNN dataset generator
dataset_generator = GNNDatasetGenerator(base_generator)

# Generate the dataset (this deletes and recreates the traffic_dataset/ directory)
dataset_generator.generate_dataset(num_samples=1000, num_steps=500, window_size=7)