In [1]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import io
from PIL import Image
import os
from IPython.display import Image as IPImage
import torch
from PIL import Image as PILImage
from matplotlib.animation import FuncAnimation
from IPython.display import display, HTML
import ipywidgets as widgets
from IPython.display import display, IFrame
from matplotlib.animation import HTMLWriter
from matplotlib.patches import Circle

In [2]:
class Entity:
    """One occupant of a grid cell, identified by position and type.

    Attributes:
        x, y: current grid coordinates.
        type: the ``entity_type`` label passed to the constructor.
        is_movable: ``False`` only when the type is ``'obstacle'``.
    """

    def __init__(self, x, y, entity_type):
        self.x, self.y = x, y
        self.type = entity_type
        # Every type except 'obstacle' is allowed to change position.
        self.is_movable = entity_type != 'obstacle'

    def move(self, dx, dy):
        """Shift the position by ``(dx, dy)``; a no-op for immovable entities."""
        if not self.is_movable:
            return
        self.x += dx
        self.y += dy




In [3]:
def detect_collision(entities):
    """Return every pair of entities that occupy the same grid cell.

    Each unordered pair is reported once as ``(earlier, later)`` following
    the order of ``entities``; an empty list means no two entities overlap.
    """
    total = len(entities)
    return [
        (entities[first], entities[second])
        for first in range(total)
        for second in range(first + 1, total)
        if entities[first].x == entities[second].x
        and entities[first].y == entities[second].y
    ]

In [4]:
class SyntheticDataGenerator:
    """Random grid-world scenario generator with animation and collision reporting.

    A *scenario* is a list of time steps and each time step is a list of
    ``Entity`` objects. Obstacles keep their cell for the whole scenario;
    vehicles and pedestrians take a random step of at most one cell per axis
    at every time step. Movable entities only avoid obstacle cells, so two of
    them may end up on the same cell, which ``detect_collision`` reports.
    """

    # (entity type, legend label, colour, marker), listed in legend order.
    _ENTITY_STYLES = (
        ('vehicle', 'Vehicle', 'red', 'o'),
        ('pedestrian', 'Pedestrian', 'blue', '^'),
        ('obstacle', 'Obstacle', 'green', 's'),
    )

    def __init__(self, grid_size=10, num_vehicles=5, num_pedestrians=5, num_obstacles=5):
        """Store the side length of the square grid and the entity counts."""
        self.grid_size = grid_size
        self.num_vehicles = num_vehicles
        self.num_pedestrians = num_pedestrians
        self.num_obstacles = num_obstacles

    def _place_entities(self, count, entity_type, positions, entities):
        """Append ``count`` entities of ``entity_type`` on cells not yet in ``positions``."""
        for _ in range(count):
            while True:
                # Plain ints (not NumPy scalars) keep location tuples printing
                # as "(3, 16)" regardless of the installed NumPy version.
                x, y = (int(v) for v in np.random.randint(0, self.grid_size, size=2))
                if (x, y) not in positions:
                    positions.add((x, y))
                    entities.append(Entity(x, y, entity_type))
                    break

    def generate_initial_positions(self):
        """Generate non-overlapping initial positions for all entities.

        Returns:
            list: obstacles first, then vehicles, then pedestrians, each on
            a distinct cell of the grid.

        Raises:
            ValueError: if more entities are requested than the grid has
                cells (the rejection sampling below could never finish).
        """
        placement_plan = (
            (self.num_obstacles, 'obstacle'),    # fixed entities go first
            (self.num_vehicles, 'vehicle'),
            (self.num_pedestrians, 'pedestrian'),
        )
        requested = sum(max(count, 0) for count, _ in placement_plan)
        capacity = self.grid_size * self.grid_size
        if requested > capacity:
            raise ValueError(
                f"Cannot place {requested} entities on a "
                f"{self.grid_size}x{self.grid_size} grid ({capacity} cells)"
            )

        positions = set()
        entities = []
        for count, entity_type in placement_plan:
            self._place_entities(count, entity_type, positions, entities)
        return entities

    def generate_data(self, num_samples=100, num_steps=10, max_attempts=10):
        """Simulate ``num_samples`` independent scenarios.

        Args:
            num_samples: number of scenarios to generate.
            num_steps: time steps per scenario, including the initial one
                (a scenario always contains at least the initial step).
            max_attempts: random moves tried per movable entity and step
                before it is left in place.

        Returns:
            list: one scenario per sample; every time step holds fresh
            ``Entity`` objects in the same order as the initial placement.
        """
        data = []
        for _sample in range(num_samples):
            entities = self.generate_initial_positions()
            scenario = [entities]

            # Obstacles never move, so their cells are fixed for the scenario.
            obstacle_cells = {(e.x, e.y) for e in entities if not e.is_movable}

            for _step in range(num_steps - 1):
                new_entities = []
                for entity in entities:
                    # Start from a copy; it stays put unless a move is accepted.
                    successor = Entity(entity.x, entity.y, entity.type)
                    if entity.is_movable:
                        for _attempt in range(max_attempts):
                            dx, dy = (int(d) for d in np.random.randint(-1, 2, size=2))
                            new_x = entity.x + dx
                            new_y = entity.y + dy
                            # Accept only in-grid targets that are not obstacle cells;
                            # cells held by other movable entities are allowed.
                            if (0 <= new_x < self.grid_size and
                                    0 <= new_y < self.grid_size and
                                    (new_x, new_y) not in obstacle_cells):
                                successor.move(dx, dy)
                                break
                    new_entities.append(successor)

                scenario.append(new_entities)
                entities = new_entities

            data.append(scenario)
        return data

    def visualize_data(self, scenario, sample_id, output_dir='animations'):
        """Render ``scenario`` as an HTML animation and return the file path.

        Args:
            scenario: list of time steps as produced by ``generate_data``.
            sample_id: value embedded in the output file name.
            output_dir: directory for the HTML file; created if missing.

        Returns:
            str: ``'<output_dir>/scenario_<sample_id>_obstacle_fixed.html'``.
        """
        # Map each frame to the cells where a collision happens in THAT frame,
        # so a frame never highlights collisions belonging to other frames.
        collisions_by_frame = {}
        for frame_index, frame_entities in enumerate(scenario):
            cells = [(first.x, first.y) for first, _second in detect_collision(frame_entities)]
            if cells:
                collisions_by_frame[frame_index] = cells

        fig = plt.figure(figsize=(8, 6))
        try:
            # Wide main plot on the left, narrow frameless axes for the legend.
            gs = gridspec.GridSpec(1, 2, width_ratios=[4, 1])
            ax = fig.add_subplot(gs[0])
            legend_ax = fig.add_subplot(gs[1])
            legend_ax.axis('off')

            def update(frame):
                ax.clear()
                legend_ax.clear()
                legend_ax.axis('off')

                # Bucket coordinates per type so each type is drawn with a
                # single scatter call that also serves as its legend handle.
                buckets = {name: ([], []) for name, _l, _c, _m in self._ENTITY_STYLES}
                for entity in scenario[frame]:
                    # Unknown types are drawn like obstacles.
                    key = entity.type if entity.type in buckets else 'obstacle'
                    xs, ys = buckets[key]
                    xs.append(entity.x)
                    ys.append(entity.y)

                legend_elements = []
                for name, label, color, marker in self._ENTITY_STYLES:
                    xs, ys = buckets[name]
                    legend_elements.append(
                        ax.scatter(xs, ys, color=color, marker=marker, label=label))

                frame_collisions = collisions_by_frame.get(frame, ())
                if frame_collisions:
                    for x, y in frame_collisions:
                        ax.add_patch(Circle((x, y), radius=1.0, color='red', alpha=0.3))
                    ax.text(0.02, 0.98, 'COLLISION!', transform=ax.transAxes,
                            color='red', fontsize=12, verticalalignment='top')
                    # Empty scatter used purely as a legend entry.
                    legend_elements.append(
                        ax.scatter([], [], c='red', alpha=0.3, marker='o', s=300,
                                   label='Collision Area'))

                ax.set_xlim(-1, self.grid_size + 1)
                ax.set_ylim(-1, self.grid_size + 1)
                ax.set_title(f"Time Step {frame+1}")
                ax.grid(True)

                # Labels come from the handles themselves.
                legend_ax.legend(handles=legend_elements,
                                 loc='center left',
                                 bbox_to_anchor=(0, 0.5))

                # Lay out this figure explicitly rather than pyplot's current one.
                fig.tight_layout()

            ani = FuncAnimation(fig, update, frames=len(scenario), interval=500, repeat=False)

            os.makedirs(output_dir, exist_ok=True)
            html_path = f'{output_dir}/scenario_{sample_id}_obstacle_fixed.html'
            ani.save(html_path, writer=HTMLWriter(fps=2))
        finally:
            # Release the figure even when rendering or saving fails.
            plt.close(fig)
        return html_path

    def analyze_collisions(self, scenario):
        """List every collision in ``scenario``.

        Returns:
            list[dict]: one entry per colliding pair with keys ``'timestep'``
            (1-based), ``'location'`` (``(x, y)`` of the shared cell) and
            ``'entities'`` (the two entity types).
        """
        return [
            {
                'timestep': timestep + 1,
                'location': (entity1.x, entity1.y),
                'entities': (entity1.type, entity2.type),
            }
            for timestep, entities in enumerate(scenario)
            for entity1, entity2 in detect_collision(entities)
        ]

In [7]:
# Simulate one 20-step scenario on a 20x20 grid.
generator = SyntheticDataGenerator(grid_size=20, num_vehicles=3, num_pedestrians=4, num_obstacles=4)
data = generator.generate_data(num_samples=1, num_steps=20)

for i, scenario in enumerate(data):
    print(f"\nScenario {i+1}:")

    # Write the animation to disk, then embed the saved HTML in the cell output.
    html_path = generator.visualize_data(scenario, sample_id=i+1)
    display(IFrame(src=html_path, width=800, height=800))

    # Summarise every collision found in this scenario.
    collision_report = generator.analyze_collisions(scenario)
    if not collision_report:
        print("No collisions detected in this scenario")
    else:
        print("\nCollision Report:")
        for collision in collision_report:
            first_type, second_type = collision['entities']
            print(f"Time Step {collision['timestep']}: "
                  f"{first_type.capitalize()} collided with "
                  f"{second_type} at position {collision['location']}")
    print("-" * 50)


Scenario 1:



Collision Report:
Time Step 15: Vehicle collided with pedestrian at position (3, 16)
Time Step 17: Pedestrian collided with pedestrian at position (13, 2)
--------------------------------------------------
