# In [None]:
import numpy as np
import random
import matplotlib.pyplot as plt
from PIL import Image
from typing import List, Tuple
from scipy.optimize import linear_sum_assignment  # For Hungarian algorithm
from collections import deque


class GridWorld:
    """Square occupancy grid in which 0 marks a free cell and -1 marks an obstacle."""

    def __init__(self, grid_size: int, n_obstacles: int, min_obstacle_len: int, max_obstacle_len: int):
        """Create an empty ``grid_size`` x ``grid_size`` grid and scatter line obstacles on it.

        Args:
            grid_size: Number of rows (and columns) of the grid.
            n_obstacles: Number of straight obstacle segments to place.
            min_obstacle_len: Minimum segment length in cells (inclusive).
            max_obstacle_len: Maximum segment length in cells (inclusive).
        """
        self.grid_size = grid_size
        self.grid = np.zeros((grid_size, grid_size), dtype=int)
        self._place_obstacles(n_obstacles, min_obstacle_len, max_obstacle_len)

    def _place_obstacles(self, n_obstacles: int, min_len: int, max_len: int) -> None:
        """Mark ``n_obstacles`` random horizontal/vertical segments as obstacles (-1).

        Segments that run past an edge wrap around to the opposite side
        (coordinates are taken modulo ``grid_size``), and segments may overlap.
        """
        for _ in range(n_obstacles):
            start_x, start_y = random.randint(0, self.grid_size - 1), random.randint(0, self.grid_size - 1)
            length = random.randint(min_len, max_len)
            orientation = random.choice(["horizontal", "vertical"])
            for i in range(length):
                if orientation == "horizontal":
                    x, y = start_x, (start_y + i) % self.grid_size
                else:
                    x, y = (start_x + i) % self.grid_size, start_y
                self.grid[x, y] = -1  # Mark as obstacle

    def bfs(self, start_x: int, start_y: int) -> np.ndarray:
        """Return BFS step counts from ``(start_x, start_y)`` to every cell.

        Movement is 4-connected and never enters obstacle cells. Cells that are
        unreachable (including obstacles) keep the value -1; the start cell is 0.
        """
        directions = [(-1, 0), (1, 0), (0, -1), (0, 1)]  # Up, Down, Left, Right
        queue = deque([(start_x, start_y)])
        distances = np.full_like(self.grid, -1, dtype=int)  # -1 means unvisited
        distances[start_x, start_y] = 0

        while queue:
            x, y = queue.popleft()
            for dx, dy in directions:
                nx, ny = x + dx, y + dy
                if 0 <= nx < self.grid_size and 0 <= ny < self.grid_size and self.grid[nx, ny] != -1:
                    if distances[nx, ny] == -1:  # Not visited
                        distances[nx, ny] = distances[x, y] + 1
                        queue.append((nx, ny))
        return distances

    def calculate_total_distance(self, agent_positions: List[Tuple[int, int]]) -> int:
        """Sum, over all free cells, the Manhattan distance to the nearest agent.

        Note: this is the L1 distance and deliberately ignores obstacles (it does
        not use ``bfs``). Agent cells that are free contribute 0.

        Raises:
            ValueError: if ``agent_positions`` is empty while free cells exist.
        """
        free_x, free_y = np.nonzero(self.grid == 0)
        if free_x.size == 0:
            return 0
        if not agent_positions:
            raise ValueError("agent_positions must contain at least one position")
        agents = np.asarray(agent_positions, dtype=int).reshape(-1, 2)
        # (n_free, n_agents) matrix of Manhattan distances, reduced to nearest agent per cell.
        dists = np.abs(free_x[:, None] - agents[:, 0]) + np.abs(free_y[:, None] - agents[:, 1])
        return int(dists.min(axis=1).sum())

    def find_optimal_agent_positions(self, n_agents: int) -> List[Tuple[int, int]]:
        """Greedily choose ``n_agents`` distinct free cells that minimise total coverage distance.

        Agents are placed one at a time; each is put on the free cell that
        minimises ``calculate_total_distance`` given the agents already placed
        (ties go to the first cell in row-major order). The result is a greedy
        heuristic, not a guaranteed global optimum.

        Raises:
            ValueError: if ``n_agents`` exceeds the number of free cells.
        """
        free_positions = [(i, j) for i in range(self.grid_size) for j in range(self.grid_size) if self.grid[i, j] == 0]
        if n_agents > len(free_positions):
            raise ValueError(
                f"cannot place {n_agents} agents: only {len(free_positions)} free cells"
            )
        agent_positions: List[Tuple[int, int]] = []
        for _ in range(n_agents):
            # min() returns the first minimiser, matching a strict "<" scan.
            best_position = min(
                free_positions,
                key=lambda pos: self.calculate_total_distance(agent_positions + [pos]),
            )
            agent_positions.append(best_position)
            free_positions.remove(best_position)  # each cell hosts at most one agent
        return agent_positions




class Simulation:
    """Grid-world swarm simulation in which agents are repeatedly matched to goals.

    Each tick assigns agents to goals with the Hungarian algorithm (cost = BFS
    path length), moves every assigned agent one step along its path, respawns
    any goal that was reached, and renders a frame. ``run_simulation`` writes
    the collected frames to ``swarm.gif``.
    """

    def __init__(
        self, grid_size: int, n_obstacles: int, min_obstacle_len: int, max_obstacle_len: int,
        n_agents: int, n_goals: int, swarm_center: Tuple[int, int], swarm_radius: int
    ):
        """Build the world, spawn agents near ``swarm_center`` and goals anywhere free.

        Args:
            grid_size, n_obstacles, min_obstacle_len, max_obstacle_len: forwarded to ``GridWorld``.
            n_agents: Number of agents to spawn.
            n_goals: Number of simultaneously active goals.
            swarm_center: (row, col) around which agents are spawned.
            swarm_radius: Half-width of the square spawn window around ``swarm_center``.
        """
        self.grid_world = GridWorld(grid_size, n_obstacles, min_obstacle_len, max_obstacle_len)
        self.grid = self.grid_world.grid
        self.agents = self._initialize_positions(n_agents, swarm_center, swarm_radius)
        self.goals = self._initialize_positions(n_goals, None, None)

    def _initialize_positions(self, n: int, center: Tuple[int, int], radius: int) -> List[Tuple[int, int]]:
        """Sample ``n`` distinct obstacle-free cells uniformly at random.

        If ``center`` is truthy, cells are drawn from the square window
        ``center +/- radius`` clipped to the grid; otherwise (``center=None``)
        from the whole grid.

        Raises:
            ValueError: if the sampling region has fewer than ``n`` free cells,
                which would otherwise make the rejection-sampling loop spin forever.
        """
        rows, cols = self.grid.shape
        if center:
            x_lo, x_hi = max(0, center[0] - radius), min(rows - 1, center[0] + radius)
            y_lo, y_hi = max(0, center[1] - radius), min(cols - 1, center[1] + radius)
        else:
            x_lo, x_hi, y_lo, y_hi = 0, rows - 1, 0, cols - 1

        if x_lo > x_hi or y_lo > y_hi:
            available = 0  # window lies entirely outside the grid
        else:
            available = int(np.count_nonzero(self.grid[x_lo:x_hi + 1, y_lo:y_hi + 1] != -1))
        if n > available:
            raise ValueError(
                f"cannot place {n} distinct positions: only {available} free cells in the sampling region"
            )

        positions: List[Tuple[int, int]] = []
        while len(positions) < n:
            cell = (random.randint(x_lo, x_hi), random.randint(y_lo, y_hi))
            if cell not in positions and self.grid[cell] != -1:
                positions.append(cell)
        return positions

    def shortest_path(self, start: Tuple[int, int], goal: Tuple[int, int]) -> List[Tuple[int, int]]:
        """Return a shortest 4-connected path from ``start`` to ``goal`` using BFS.

        The path includes both endpoints (``[start]`` when they coincide) and
        never enters an obstacle cell; ``[]`` is returned when ``goal`` is
        unreachable. Neighbours are explored up, down, left, right, which fixes
        the tie-break among equally short paths.
        """
        if start == goal:
            return [start]
        if self.grid[start] == -1:
            return []
        rows, cols = self.grid.shape
        # parent map replaces per-entry path copies; deque gives O(1) pops.
        parent = {start: None}
        frontier = deque([start])
        while frontier:
            x, y = frontier.popleft()
            for nxt in ((x - 1, y), (x + 1, y), (x, y - 1), (x, y + 1)):
                nx, ny = nxt
                if not (0 <= nx < rows and 0 <= ny < cols) or nxt in parent or self.grid[nx, ny] == -1:
                    continue
                parent[nxt] = (x, y)
                if nxt == goal:
                    path = [nxt]
                    while parent[path[-1]] is not None:
                        path.append(parent[path[-1]])
                    path.reverse()
                    return path
                frontier.append(nxt)
        return []

    def save_frame(self, iteration: int, image_list: "List[Image.Image]") -> None:
        """Render the grid with goals (2) and agents (-5), save it as ``frame_<iteration>.png``,
        and append the loaded image to ``image_list``."""
        grid_display = self.grid.copy()
        for gx, gy in self.goals:
            grid_display[gx, gy] = 2
        for agent_x, agent_y in self.agents:
            grid_display[agent_x, agent_y] = -5
        # Use a dedicated figure so we never draw on (or close) a caller's figure.
        fig, axes = plt.subplots()
        axes.imshow(grid_display, cmap="gray", interpolation="nearest")
        axes.set_title(f"Iteration {iteration}")
        axes.axis("off")
        fig.tight_layout()
        filename = f"frame_{iteration}.png"
        fig.savefig(filename)
        plt.close(fig)
        # Image.open is lazy and keeps the file open; copy() loads the pixels so
        # the handle can be released immediately.
        with Image.open(filename) as frame:
            image_list.append(frame.copy())

    def _step(self) -> None:
        """Advance one tick: assign agents to goals, move them one cell, respawn reached goals."""
        # Cost = BFS path length; cache the paths so assigned pairs are not recomputed.
        paths = {}
        cost_matrix = np.zeros((len(self.agents), len(self.goals)))
        for i, agent in enumerate(self.agents):
            for j, goal in enumerate(self.goals):
                path = self.shortest_path(agent, goal)
                paths[i, j] = path
                cost_matrix[i, j] = len(path) if path else 1e6  # Assign high cost if no path

        # Hungarian algorithm; with fewer goals than agents some agents stay unassigned.
        agent_indices, goal_indices = linear_sum_assignment(cost_matrix)

        # Start from the current positions so unassigned agents, and agents with
        # no route to their assigned goal, stay where they are.
        new_positions = list(self.agents)
        reached_goals = set()
        for a_idx, g_idx in zip(agent_indices, goal_indices):
            a_idx, g_idx = int(a_idx), int(g_idx)
            path = paths[a_idx, g_idx]
            if not path:
                continue
            goal = self.goals[g_idx]
            next_step = path[1] if len(path) > 1 else path[0]
            new_positions[a_idx] = next_step
            if next_step == goal:
                reached_goals.add(goal)

        # Replace each reached goal with a random free cell that is not already a goal.
        rows, cols = self.grid.shape
        for goal in reached_goals:
            self.goals.remove(goal)
            while True:
                new_goal = (random.randint(0, rows - 1), random.randint(0, cols - 1))
                if new_goal not in self.goals and self.grid[new_goal] != -1:
                    self.goals.append(new_goal)
                    break
        self.agents = new_positions

    def run_simulation(self, n_iters: int, gif_frame_duration: int) -> None:
        """Run ``n_iters`` ticks, saving one frame per tick, then write ``swarm.gif``.

        Args:
            n_iters: Number of ticks (and frames). Nothing is written when <= 0.
            gif_frame_duration: Forwarded to Pillow's GIF writer as ``duration``
                (per-frame display time; milliseconds per the Pillow docs).
        """
        # TODO: presumably meant to steer agents left without a goal, e.g.
        # self.grid_world.find_optimal_agent_positions(n_agents=len(self.agents) - len(self.goals))
        image_list = []
        for iteration in range(n_iters):
            self._step()
            self.save_frame(iteration, image_list)

        if not image_list:
            return  # no frames rendered; image_list[0] would raise IndexError

        # Save as GIF
        image_list[0].save(
            "swarm.gif",
            save_all=True,
            append_images=image_list[1:],
            duration=gif_frame_duration,
            loop=0,
        )


# Scenario: a 50x50 world with 20 straight obstacles (6-20 cells long),
# 8 agents spawned within 8 cells of (8, 8), and 3 active goals.
simulation = Simulation(
    # world geometry
    grid_size=50,
    n_obstacles=20,
    min_obstacle_len=6,
    max_obstacle_len=20,
    # swarm and targets
    n_agents=8,
    n_goals=3,
    swarm_center=(8, 8),
    swarm_radius=8,
)
# 128 ticks, rendered at 50 per frame in swarm.gif.
simulation.run_simulation(
    n_iters=128,
    gif_frame_duration=50,
)