In [1]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import io
from PIL import Image
import os
from IPython.display import Image as IPImage
import torch
from PIL import Image as PILImage
from matplotlib.animation import FuncAnimation
from IPython.display import display, HTML
import ipywidgets as widgets
from IPython.display import display, IFrame
from matplotlib.animation import HTMLWriter
from matplotlib.patches import Circle

In [2]:
class Entity:
    """A single actor on the grid: a position plus a category label.

    Args:
        x: Column coordinate, stored as given.
        y: Row coordinate, stored as given.
        entity_type: Category label (this notebook uses 'vehicle',
            'pedestrian' or 'obstacle'); exposed as the ``type`` attribute.
    """

    def __init__(self, x, y, entity_type):
        self.x, self.y = x, y
        # Callers read the label back through the short attribute name ``type``.
        self.type = entity_type

    def move(self, dx, dy):
        """Shift this entity in place by ``dx`` columns and ``dy`` rows."""
        self.x += dx
        self.y += dy

def detect_collision(entities):
    """List every pair of entities that occupy exactly the same (x, y) cell.

    Pairs come out in index order: for each entity at position ``i``, each later
    entity (position ``j > i``) with equal x and equal y yields ``(entities[i], entities[j])``.
    A cell holding k entities therefore contributes k*(k-1)/2 pairs.
    """
    count = len(entities)
    return [
        (entities[first], entities[second])
        for first in range(count)
        for second in range(first + 1, count)
        if entities[first].x == entities[second].x and entities[first].y == entities[second].y
    ]

In [3]:
class SyntheticDataGenerator:
    """Generate, animate and analyze random-walk scenarios on a square grid.

    Every entity (vehicles, pedestrians and obstacles alike) performs an
    independent random walk. The only constraint on a move is that it stays
    inside the grid, so collisions are not avoided; they are detected and
    reported after the fact.
    """

    # (color, marker, legend label) per entity type. Shared by the legend proxies
    # and the plotted points so the two can never drift apart. Unknown types are
    # drawn with the 'obstacle' style.
    _STYLES = {
        'vehicle': ('blue', '^', 'Vehicle'),
        'pedestrian': ('red', 'o', 'Pedestrian'),
        'obstacle': ('green', 's', 'Obstacle'),
    }

    def __init__(self, grid_size=10, num_vehicles=5, num_pedestrians=5, num_obstacles=5):
        """
        Args:
            grid_size: Side length of the square grid; valid cells are 0..grid_size-1.
            num_vehicles: Vehicles spawned per scenario.
            num_pedestrians: Pedestrians spawned per scenario.
            num_obstacles: Obstacles spawned per scenario.
        """
        self.grid_size = grid_size
        self.num_vehicles = num_vehicles
        self.num_pedestrians = num_pedestrians
        self.num_obstacles = num_obstacles

    def is_valid_move(self, new_x, new_y, entities, moving_entity):
        """Return True iff (new_x, new_y) lies inside the grid.

        ``entities`` and ``moving_entity`` are accepted for interface
        compatibility but are currently unused: occupancy is NOT checked.
        """
        return bool(0 <= new_x < self.grid_size and 0 <= new_y < self.grid_size)

    def get_safe_move(self, entity):
        """Pick a random step (dx, dy) in {-1, 0, 1}^2 that keeps ``entity`` on the grid.

        "Safe" means in-bounds only; the step may land on an occupied cell.
        Randomness comes from the global ``np.random`` state.
        """
        # All nine neighbourhood steps, including staying in place.
        possible_moves = [(dx, dy) for dx in [-1, 0, 1] for dy in [-1, 0, 1]]
        np.random.shuffle(possible_moves)

        for dx, dy in possible_moves:
            if self.is_valid_move(entity.x + dx, entity.y + dy, [], entity):
                return dx, dy

        # Only reachable when the entity is already outside the grid.
        return 0, 0

    def _spawn_entities(self):
        """Place every configured entity on an independent, uniformly random cell."""
        entities = []
        for entity_type, count in (('vehicle', self.num_vehicles),
                                   ('pedestrian', self.num_pedestrians),
                                   ('obstacle', self.num_obstacles)):
            for _ in range(count):
                x, y = np.random.randint(0, self.grid_size, size=2)
                # Store built-in ints: NumPy >= 2 reprs scalars as ``np.int64(...)``,
                # which would otherwise leak into printed collision locations.
                entities.append(Entity(int(x), int(y), entity_type))
        return entities

    def generate_data(self, num_samples=100, num_steps=10):
        """Simulate ``num_samples`` independent scenarios.

        Returns:
            A list of scenarios. Each scenario is a list of frames (the initial
            frame plus ``num_steps - 1`` moves). Each frame is a list of fresh
            ``Entity`` objects ordered vehicles, pedestrians, obstacles.
        """
        data = []
        for _sample in range(num_samples):
            entities = self._spawn_entities()
            scenario = [entities]

            for _step in range(num_steps - 1):
                next_entities = []
                for entity in entities:
                    dx, dy = self.get_safe_move(entity)
                    # Copy rather than mutate so earlier frames keep their positions.
                    moved = Entity(entity.x, entity.y, entity.type)
                    moved.move(dx, dy)
                    next_entities.append(moved)
                scenario.append(next_entities)
                entities = next_entities

            data.append(scenario)
        return data

    def visualize_data(self, scenario, sample_id):
        """Render ``scenario`` as an HTML animation and return the file path.

        Creates ``animations/`` if needed and writes
        ``animations/scenario_{sample_id}_allrandom.html`` at 2 frames per second.
        """
        fig = plt.figure(figsize=(8, 6))
        try:
            # Main plot on the left, a frameless axis on the right to host the legend.
            gs = gridspec.GridSpec(1, 2, width_ratios=[4, 1], figure=fig)
            ax = fig.add_subplot(gs[0])
            legend_ax = fig.add_subplot(gs[1])
            legend_ax.axis('off')

            # Pre-compute unique collision cells per frame. dict.fromkeys dedupes while
            # keeping order, so a pile-up of k entities draws one circle, not k*(k-1)/2.
            collision_data = {}
            for frame, entities in enumerate(scenario):
                cells = dict.fromkeys((e1.x, e1.y) for e1, _e2 in detect_collision(entities))
                if cells:
                    collision_data[frame] = list(cells)

            cell_edges = np.arange(self.grid_size + 1) - 0.5

            def update(frame):
                ax.clear()
                legend_ax.clear()
                legend_ax.axis('off')

                # Empty scatters act as legend proxies, one per entity type.
                legend_elements = [
                    ax.scatter([], [], color=color, marker=marker, s=200, label=label)
                    for color, marker, label in self._STYLES.values()
                ]

                # Bucket positions by type so each type is drawn with one scatter call.
                positions = {}
                for entity in scenario[frame]:
                    kind = entity.type if entity.type in self._STYLES else 'obstacle'
                    xs, ys = positions.setdefault(kind, ([], []))
                    xs.append(entity.x)
                    ys.append(entity.y)
                for kind, (xs, ys) in positions.items():
                    color, marker, _label = self._STYLES[kind]
                    ax.scatter(xs, ys, color=color, marker=marker, s=200)

                if frame in collision_data:
                    for x, y in collision_data[frame]:
                        ax.add_patch(Circle((x, y), radius=1, color='red', alpha=0.3))
                    ax.text(0.02, 0.98, 'COLLISION!', transform=ax.transAxes,
                            color='red', fontsize=12, verticalalignment='top')
                    # Bind the proxy to ``ax`` explicitly; plt.scatter would draw on
                    # whatever pyplot considers the "current" axes.
                    legend_elements.append(
                        ax.scatter([], [], c='red', alpha=0.3, marker='o', s=500,
                                   label='Collision Area'))

                ax.set_xlim(-0.5, self.grid_size - 0.5)
                ax.set_ylim(-0.5, self.grid_size - 0.5)
                ax.set_title(f"Time Step {frame+1}")

                # Major ticks at cell centres (integer coordinates).
                ax.set_xticks(range(self.grid_size))
                ax.set_yticks(range(self.grid_size))
                # Minor ticks on cell boundaries; without them the minor grid draws nothing.
                ax.set_xticks(cell_edges, minor=True)
                ax.set_yticks(cell_edges, minor=True)

                ax.grid(True, which='major', linestyle='-', linewidth=1)
                ax.grid(True, which='minor', linestyle=':', linewidth=0.5)

                ax.set_xlabel('X Coordinate')
                ax.set_ylabel('Y Coordinate')

                # Labels are taken from each handle's ``label``.
                legend_ax.legend(handles=legend_elements, loc='center left',
                                 bbox_to_anchor=(0, 0.5))
                fig.tight_layout()

            ani = FuncAnimation(fig, update, frames=len(scenario), interval=500, repeat=False)

            os.makedirs('animations', exist_ok=True)
            html_path = f'animations/scenario_{sample_id}_allrandom.html'
            ani.save(html_path, writer=HTMLWriter(fps=2))
        finally:
            # Always release the figure, even if saving fails, to avoid leaking it.
            plt.close(fig)
        return html_path

    def analyze_collisions(self, scenario):
        """Flatten all collisions in ``scenario`` into report records.

        Each record is a dict with a 1-based ``'timestep'``, the ``'location'``
        (x, y) of the first entity in the pair, and the ``'entities'`` type pair.
        """
        collision_report = []
        for timestep, entities in enumerate(scenario, start=1):
            for entity1, entity2 in detect_collision(entities):
                collision_report.append({
                    'timestep': timestep,
                    'location': (entity1.x, entity1.y),
                    'entities': (entity1.type, entity2.type),
                })
        return collision_report

In [4]:
# Seed the global NumPy RNG so the scenario and its collision report are reproducible
# under Restart & Run All: every random draw in the generator goes through np.random.
SEED = 42
np.random.seed(SEED)

generator = SyntheticDataGenerator(grid_size=20, num_vehicles=5, num_pedestrians=5, num_obstacles=5)
data = generator.generate_data(num_samples=1, num_steps=100)

for i, scenario in enumerate(data, start=1):
    print(f"\nScenario {i}:")
    html_path = generator.visualize_data(scenario, sample_id=i)
    display(IFrame(src=html_path, width=800, height=800))

    # Detailed, per-pair collision report for this scenario.
    collision_report = generator.analyze_collisions(scenario)
    if collision_report:
        print("\nCollision Report:")
        for collision in collision_report:
            print(f"Time Step {collision['timestep']}: "
                  f"{collision['entities'][0].capitalize()} collided with "
                  f"{collision['entities'][1]} at position {collision['location']}")
    else:
        print("No collisions detected in this scenario")
    print("-" * 50)


Scenario 1:



Collision Report:
Time Step 1: Vehicle collided with vehicle at position (16, 8)
Time Step 6: Vehicle collided with vehicle at position (11, 12)
Time Step 8: Vehicle collided with pedestrian at position (16, 8)
Time Step 13: Vehicle collided with vehicle at position (18, 3)
Time Step 13: Pedestrian collided with pedestrian at position (9, 8)
Time Step 15: Pedestrian collided with obstacle at position (7, 7)
Time Step 21: Vehicle collided with obstacle at position (13, 16)
Time Step 23: Obstacle collided with obstacle at position (3, 9)
Time Step 32: Pedestrian collided with obstacle at position (3, 3)
Time Step 41: Pedestrian collided with obstacle at position (11, 6)
Time Step 52: Vehicle collided with pedestrian at position (15, 6)
Time Step 57: Vehicle collided with pedestrian at position (6, 10)
Time Step 65: Vehicle collided with pedestrian at position (12, 4)
Time Step 73: Vehicle collided with obstacle at position (10, 3)
Time Step 74: Pedestrian collided with obstacle at posit