# In [None]:
####################
## Algorithm     ###
## Azizi   ##
## Adrani  ##
####################
from env.env import play
from collections import deque
import heapq, time, pygame
import heapq, time
from collections import deque
from heapq import heappush, heappop
##########################
# Utility Structures   ###
##########################

class Node:
    """Search-tree node tying a state to the step that produced it.

    Fields:
        state:  the state object held by this node.
        parent: the node this one was generated from (``None`` for a root).
        action: the action applied to the parent to reach ``state``.
        g:      accumulated path cost from the root.
        depth:  number of steps from the root.

    ``__slots__`` is declared, so instances carry no per-instance ``__dict__``.
    """

    __slots__ = ("state", "parent", "action", "g", "depth")

    def __init__(self, state, parent=None, action=None, g=0.0, depth=0):
        # One tuple assignment instead of five separate statements.
        (self.state, self.parent, self.action, self.g, self.depth) = (
            state, parent, action, g, depth,
        )


def reconstruct(node):
    """Return the actions leading from the root down to ``node``.

    Walks the ``parent`` links upward, collecting each node's ``action``,
    and stops at the first node without a parent (the root contributes no
    action). The result is ordered root-first; it is empty when ``node`` is
    falsy or is itself a root.
    """
    steps = []
    current = node
    while current and current.parent:
        steps.append(current.action)
        current = current.parent
    # Collected leaf-to-root; flip in place to get root-to-leaf order.
    steps.reverse()
    return steps


def state_key(s):
    """Build the hashable key used to deduplicate search states.

    The key is ``(agent position, sorted tuple of target positions, phase)``
    where ``phase`` is the step count modulo the length of the enemy path,
    or 0 when the enemy path is empty.
    """
    enemy_path = s.get_enemy_path()
    if enemy_path:
        phase = s.get_step_count() % len(enemy_path)
    else:
        phase = 0
    agent = s.get_agent_position()
    targets = tuple(sorted(s.get_targets_positions()))
    return (agent, targets, phase)

#################################
# SMART BREADTH-FIRST SEARCH   ##
#################################

def bfs(initial_state):
    """Breadth-first search over game states.

    Args:
        initial_state: state object exposing ``get_grid_size()``,
            ``is_goal_state()``, ``is_collision_state()`` and
            ``get_successors()`` (an iterable of
            ``(action, cost, next_state)`` triples).

    Returns:
        The list of actions from the initial state to the first goal state
        dequeued, or ``[]`` when the frontier is exhausted or the expansion
        limit is hit.

    States are deduplicated with ``state_key``. Collision states are neither
    enqueued nor expanded. Two limits scale with the grid size: the number
    of expansions, and the depth beyond which nodes are no longer expanded.
    """
    start_time = time.time()
    frontier = deque([Node(initial_state)])
    # Only membership is ever tested, so a set is enough (no parent keys stored).
    visited = {state_key(initial_state)}
    expanded = 0

    rows, cols = initial_state.get_grid_size()
    max_expansions = min(500000000, rows * cols * 800000)
    max_depth = rows + cols + 3000

    while frontier:
        node = frontier.popleft()
        state = node.state

        # The goal test happens when a node is dequeued, not when generated.
        if state.is_goal_state():
            elapsed = time.time() - start_time
            print(f"[BFS] ✅ Goal found | Expanded {expanded} | {elapsed:.2f}s")
            return reconstruct(node)

        if state.is_collision_state():
            continue

        expanded += 1
        if expanded > max_expansions:
            print(f"[BFS] ⚠️ Expansion limit reached ({expanded}) — stopping search.")
            break

        # Nodes at the depth limit are counted as expanded but get no children.
        if node.depth >= max_depth:
            continue

        # Successors are enqueued in the order get_successors() yields them.
        for a, c, nxt in state.get_successors():
            if nxt.is_collision_state():
                continue
            key = state_key(nxt)
            if key in visited:
                continue
            visited.add(key)
            frontier.append(Node(nxt, node, a, node.g + c, node.depth + 1))

    print("[BFS] ❌ No solution found within limits.")
    return []


# ################################
# ADAPTIVE UNIFORM-COST SEARCH  ##
# ################################

def ucs(initial_state):
    """Uniform-cost search with a grid-size-dependent cost ceiling.

    Nodes are popped in order of accumulated cost ``g``; ties are broken by
    insertion order. Paths whose cost exceeds ``rows * cols * 300`` are
    discarded, and the search stops after a grid-size-dependent number of
    expansions.

    Returns:
        The action list to the first goal state popped, or ``[]`` when no
        path is found within the limits.
    """
    start_time = time.time()
    unseen = float('inf')

    cheapest = {state_key(initial_state): 0}
    # Heap entries are (cost, tie-breaker, node); the root takes tie-breaker 0.
    heap = [(0, 0, Node(initial_state))]
    expanded = 0
    tie = 0

    rows, cols = initial_state.get_grid_size()
    expansion_limit = min(500000000, rows * cols * 800000)
    cost_cap = rows * cols * 300

    while heap:
        g, _, node = heapq.heappop(heap)
        state = node.state

        # Skip stale heap entries (a cheaper path to this key was recorded
        # later) and anything above the cost ceiling.
        if g > cheapest.get(state_key(state), unseen) or g > cost_cap:
            continue

        if state.is_goal_state():
            elapsed = time.time() - start_time
            print(f"[UCS] ✅ Goal reached | Expanded {expanded} | {elapsed:.2f}s | Cost {g}")
            return reconstruct(node)

        if state.is_collision_state():
            continue

        expanded += 1
        if expanded > expansion_limit:
            print("[UCS] ⚠️ Expansion limit reached.")
            break

        for action, step_cost, child in state.get_successors():
            if child.is_collision_state():
                continue
            child_cost = g + step_cost
            if child_cost > cost_cap:
                continue
            child_key = state_key(child)
            if child_cost < cheapest.get(child_key, unseen):
                cheapest[child_key] = child_cost
                tie += 1
                child_node = Node(child, node, action, child_cost, node.depth + 1)
                heapq.heappush(heap, (child_cost, tie, child_node))

    print("[UCS] ❌ No path found within cost bound.")
    return []

##########################################
# Iterative Deepening Search (IDS)    ####
##########################################

def _cheap_progress_heuristic(state):
    agent = state.get_agent_position()
    targets = state.get_targets_positions()
    if not targets:
        return 0
    ax, ay = agent
    ##################################################
    # Manhattan distance (cheap) to nearest target  ##
    ##################################################
    h_min = min(abs(ax - tx) + abs(ay - ty) for (tx, ty) in targets)
    ##########################################################
    # small penalty if enemy is about to chill near agent   ##
    ##########################################################
    en_next = state.get_enemy_next_position()
    if en_next is not None:
        h_min += 0.4 / (1 + abs(ax - en_next[0]) + abs(ay - en_next[1]))
    return h_min

def _ordered_successors(state):
    """Return the state's successors as a new list, most promising first.

    Each successor is an ``(action, cost, next_state)`` triple. They are
    sorted ascending by the cheap heuristic of ``next_state``, then by
    ``cost``. The sort is stable, so ties keep their original order.
    """
    def rank(successor):
        _action, step_cost, next_state = successor
        return (_cheap_progress_heuristic(next_state), step_cost)

    ordered = list(state.get_successors())
    ordered.sort(key=rank)
    return ordered

def dls(initial_state, limit_depth, time_limit_sec=10.0, max_expansions=200000):
    """Depth-limited depth-first search using an explicit stack.

    Args:
        initial_state: state to search from.
        limit_depth: maximum number of actions in a returned path.
        time_limit_sec: wall-clock budget for this call.
        max_expansions: cap on the number of nodes pushed onto the stack.

    Returns:
        The action list reaching a goal state, or ``[]`` when none is found,
        the time budget runs out, or the expansion cap is exceeded. A goal
        at the initial state also yields ``[]`` (the empty plan).
    """
    t0 = time.time()
    # Frame layout: [state, depth, actions so far, ordered successors, cursor].
    STATE, DEPTH, ACTIONS, SUCCS, CURSOR = range(5)
    frames = [[initial_state, 0, [], _ordered_successors(initial_state), 0]]
    # state_key -> largest remaining depth with which that key was pushed.
    best_remaining = {}
    pushed = 0

    while frames:
        if time.time() - t0 > time_limit_sec:
            print(f"[DLS] ⏱ time limit reached at depth {limit_depth}")
            return []
        if pushed > max_expansions:
            print(f"[DLS] ⚠ expansion cap reached at depth {limit_depth}")
            return []

        top = frames[-1]
        state = top[STATE]
        depth = top[DEPTH]

        if state.is_goal_state():
            return top[ACTIONS]

        # Backtrack on a collision, at the depth limit, or once every
        # successor of this frame has been tried.
        if (state.is_collision_state()
                or depth >= limit_depth
                or top[CURSOR] >= len(top[SUCCS])):
            frames.pop()
            continue

        action, _cost, child = top[SUCCS][top[CURSOR]]
        top[CURSOR] += 1

        if child.is_collision_state():
            continue

        child_depth = depth + 1
        child_key = state_key(child)
        budget_left = limit_depth - child_depth
        # Skip a child already reached with at least as much depth budget left.
        if budget_left <= best_remaining.get(child_key, -1):
            continue
        best_remaining[child_key] = budget_left

        pushed += 1
        frames.append([
            child,
            child_depth,
            top[ACTIONS] + [action],
            _ordered_successors(child),
            0,
        ])

    return []

def ids(initial_state,
        start_depth=50,
        max_depth=1000,
        total_time_limit=120.0,
        time_frac_per_depth=0.8,
        expansions_per_depth=500000):
    """Iterative deepening: run ``dls`` with growing depth limits.

    Args:
        initial_state: state to search from.
        start_depth: first depth limit tried.
        max_depth: largest depth limit tried (inclusive).
        total_time_limit: wall-clock budget in seconds across all depths.
        time_frac_per_depth: fraction of the remaining time handed to each
            ``dls`` call.
        expansions_per_depth: expansion cap passed to each ``dls`` call.

    Returns:
        The action list found by the first successful ``dls`` call, or
        ``[]`` when no depth up to ``max_depth`` succeeds, the global time
        limit is reached, or the initial state is already a goal.
    """
    t_start = time.time()

    # dls() reports a goal at the root as [] (the empty plan), which cannot be
    # told apart from failure below; answer that case directly.
    if initial_state.is_goal_state():
        print("[IDS] ✅ initial state is already a goal")
        return []

    depth = start_depth

    while depth <= max_depth:
        elapsed = time.time() - t_start
        remaining = total_time_limit - elapsed
        if remaining <= 0:
            print(f"[IDS] ⏱ global time limit reached at depth {depth}")
            return []

        time_budget = remaining * time_frac_per_depth
        print(f"[IDS] trying depth {depth}, time budget {time_budget:.2f}s")
        actions = dls(initial_state,
                      limit_depth=depth,
                      time_limit_sec=time_budget,
                      max_expansions=expansions_per_depth)
        if actions:
            print(f"[IDS] ✅ solution found at depth {depth}")
            return actions

        # Grow by 1.5x below 200, then in steps of 100.
        if depth < 200:
            # int(depth * 1.5) equals depth for depth <= 1 and shrinks for
            # negative depths; max() guarantees the limit always advances.
            # For depth >= 2 the schedule is unchanged.
            depth = max(depth + 1, int(depth * 1.5))
        else:
            depth += 100

    print("[IDS] ❌ no solution found up to max_depth")
    return []

# ##############################################
# GLOBAL TERRAIN CACHE (shared across runs)   ##
################################################
# Weighted grids keyed by (rows, cols, sorted bushes, sorted rocks), built once
# per terrain layout and reused across calls.
TERRAIN_CACHE = {}

def heuristic(state):
    """Terrain-aware estimate of the remaining cost to collect all targets.

    The estimate is the sum of:
      * the cheapest weighted path cost from the agent to any target,
      * a greedy nearest-neighbour chain over up to three hops between
        targets, starting from the first target in the list,
      * a bias of 1.5 per target beyond the first.
    The result is capped at 999999. Returns 0 when there are no targets.

    Entering a plain cell costs 1 and entering a bush cell costs 6; rock
    cells cannot be entered. Because of the greedy chain and the bias, this
    estimate is not guaranteed to be a lower bound on the true cost.

    NOTE: a second function also named ``heuristic`` is defined later in
    this module; once the module has finished loading, the name refers to
    that later definition.
    """
    agent = state.get_agent_position()
    targets = state.get_targets_positions()
    if not targets:
        return 0

    unreachable = 999999
    rows, cols = state.get_grid_size()
    bushes = set(state.get_bushes_positions())
    rocks = set(state.get_rocks_positions())

    # Build the weighted grid only once per terrain layout.
    terrain_key = (rows, cols, tuple(sorted(bushes)), tuple(sorted(rocks)))
    grid = TERRAIN_CACHE.get(terrain_key)
    if grid is None:
        grid = [[1] * cols for _ in range(rows)]
        for br, bc in bushes:
            grid[br][bc] = 6
        for rr, rc in rocks:
            grid[rr][rc] = unreachable
        TERRAIN_CACHE[terrain_key] = grid

    def min_cost(start, goals):
        """Return the cheapest cost from ``start`` to any cell in ``goals``.

        Dijkstra is required here: step costs differ (1 vs 6), so a plain
        FIFO queue would return the cost of the first path found rather
        than the cheapest one.
        """
        best = {start: 0}
        heap = [(0, start)]
        while heap:
            cost, cell = heapq.heappop(heap)
            if cost > best.get(cell, float("inf")):
                continue  # stale entry
            if cell in goals:
                return cost
            r, c = cell
            for dr, dc in ((1, 0), (-1, 0), (0, 1), (0, -1)):
                nr, nc = r + dr, c + dc
                if not (0 <= nr < rows and 0 <= nc < cols):
                    continue
                step = grid[nr][nc]
                if step >= unreachable:
                    continue
                new_cost = cost + step
                if new_cost < best.get((nr, nc), float("inf")):
                    best[(nr, nc)] = new_cost
                    heapq.heappush(heap, (new_cost, (nr, nc)))
        return unreachable

    nearest = min_cost(agent, set(targets))

    # Greedy chain between targets: at most three hops, each to the cheapest
    # remaining target from the current one.
    remaining = list(targets)
    chain_cost = 0
    if len(remaining) > 1:
        current = remaining.pop(0)
        for _ in range(min(3, len(remaining))):
            # Compute each candidate distance once and reuse it for the sum.
            cost_to = {p: min_cost(current, {p}) for p in remaining}
            nxt = min(remaining, key=cost_to.__getitem__)
            chain_cost += cost_to[nxt]
            remaining.remove(nxt)
            current = nxt

    bias = 1.5 * (len(targets) - 1)
    return min(nearest + chain_cost + bias, unreachable)


# Cache of ALT preprocessing packs (weight grid, landmarks, per-landmark
# distance tables), populated by _ensure_alt_preprocessed.
_ALT_CACHE = {}
# Sentinel used both as the weight of rock cells and as the distance of
# cells that have not been reached.
INF = 10**9

def _build_weight_grid(state):
    rows, cols = state.get_grid_size()
    grid = [[1] * cols for _ in range(rows)]
    for r, c in state.get_bushes_positions():
        grid[r][c] = 6
    for r, c in state.get_rocks_positions():
        grid[r][c] = INF
    return grid

def _valid_neighbors(r, c, rows, cols):
    if r > 0:         yield r-1, c
    if r+1 < rows:    yield r+1, c
    if c > 0:         yield r, c-1
    if c+1 < cols:    yield r, c+1

def _dijkstra_grid(start, grid):
    """Compute cheapest path costs from ``start`` to every cell of ``grid``.

    Moving into a cell costs that cell's weight; cells whose weight is
    ``INF`` or more are never entered. Returns a table with the same shape
    as ``grid`` holding the cost per cell, with ``INF`` for cells that were
    not reached. If ``start`` itself is impassable, every entry is ``INF``.
    """
    n_rows = len(grid)
    n_cols = len(grid[0])
    table = [[INF for _ in range(n_cols)] for _ in range(n_rows)]

    r0, c0 = start
    if grid[r0][c0] < INF:
        table[r0][c0] = 0
        heap = [(0, r0, c0)]
        while heap:
            cost, r, c = heappop(heap)
            # Only process entries that still match the best known cost;
            # anything else is a stale duplicate.
            if cost == table[r][c]:
                for nr, nc in _valid_neighbors(r, c, n_rows, n_cols):
                    weight = grid[nr][nc]
                    if weight < INF:
                        candidate = cost + weight
                        if candidate < table[nr][nc]:
                            table[nr][nc] = candidate
                            heappush(heap, (candidate, nr, nc))
    return table

def _pick_landmarks(state, grid, k=6):
    """Choose up to ``k`` landmark cells for the ALT heuristic.

    Landmarks are the grid corners whose weight is below ``INF``, in the
    order top-left, top-right, bottom-left, bottom-right. If no corner
    qualifies, the agent's position is used as the only landmark. Since
    only corners are candidates, at most four landmarks are returned
    whatever the value of ``k``.

    Args:
        state: provides ``get_grid_size()`` and ``get_agent_position()``.
        grid: weight grid indexed as ``grid[r][c]``.
        k: maximum number of landmarks to return.
    """
    rows, cols = state.get_grid_size()
    corners = [(0, 0), (0, cols - 1), (rows - 1, 0), (rows - 1, cols - 1)]
    # On single-row or single-column grids several corners coincide. Drop the
    # repeats (keeping order) so no distance table is computed twice.
    landmarks = [(r, c) for (r, c) in dict.fromkeys(corners) if grid[r][c] < INF]
    if not landmarks:
        landmarks.append(state.get_agent_position())
    return landmarks[:k]

def _ensure_alt_preprocessed(state, landmark_count=6):
    """Return the ALT preprocessing pack for the state's terrain, building it once.

    The pack is a dict with:
      * ``"grid"``: the weight grid from ``_build_weight_grid``,
      * ``"landmarks"``: the cells chosen by ``_pick_landmarks``,
      * ``"dists"``: one ``_dijkstra_grid`` distance table per landmark.

    Packs are cached in ``_ALT_CACHE``. The key covers the grid size, the
    bush and rock positions, and ``landmark_count``, so a request for a
    different number of landmarks never receives a pack built for another.
    """
    rows, cols = state.get_grid_size()
    bushes = tuple(sorted(state.get_bushes_positions()))
    rocks = tuple(sorted(state.get_rocks_positions()))
    key = (rows, cols, bushes, rocks, landmark_count)

    pack = _ALT_CACHE.get(key)
    if pack is None:
        grid = _build_weight_grid(state)
        landmarks = _pick_landmarks(state, grid, k=landmark_count)
        dists = [_dijkstra_grid(lm, grid) for lm in landmarks]
        pack = {"grid": grid, "landmarks": landmarks, "dists": dists}
        _ALT_CACHE[key] = pack
    return pack

def _alt_distance(pack, a, b):
    (ar, ac), (br, bc) = a, b
    best = 0
    for distL in pack["dists"]:
        da = distL[ar][ac]
        db = distL[br][bc]
        if da >= INF or db >= INF: continue
        best = max(best, abs(db - da))
    return best

def heuristic(state):
    """ALT-based heuristic used by ``astar``.

    Returns the smallest landmark distance estimate from the agent to any
    target, plus 1.0 for every target beyond the first. Returns 0 when
    there are no targets. The preprocessing pack is fetched (and built on
    first use) before the targets are inspected.
    """
    pack = _ensure_alt_preprocessed(state)
    here = state.get_agent_position()
    goals = state.get_targets_positions()
    if goals:
        closest = min(_alt_distance(pack, here, goal) for goal in goals)
        extra = 1.0 * (len(goals) - 1)
        return closest + extra
    return 0

def astar(initial_state):
    """A* search ordered by ``g + heuristic(state)``.

    Heap entries are ``(priority, tie-breaker, node)``; the tie-breaker is an
    insertion counter. States are deduplicated by ``state_key``, keeping the
    cheapest ``g`` seen per key. Collision states are neither pushed nor
    expanded.

    Returns:
        The action list to the first goal state popped, or ``[]`` when the
        open list empties or the expansion cap is exceeded.
    """
    start_time = time.time()
    unseen = float('inf')
    root = Node(initial_state)
    best_g = {state_key(initial_state): 0}
    tiebreak = 0
    open_heap = [(heuristic(initial_state), tiebreak, root)]
    expanded = 0
    expansion_cap = 30000000

    while open_heap:
        # The stored priority is only needed for heap ordering.
        _, _, current = heappop(open_heap)
        state = current.state
        g = current.g

        # Skip stale entries: a cheaper path to this key was recorded later.
        if g > best_g.get(state_key(state), unseen):
            continue

        if state.is_goal_state():
            elapsed = time.time() - start_time
            print(f"[ALT-A*] ✅ Goal found | Expanded {expanded} | {elapsed:.2f}s")
            return reconstruct(current)

        if state.is_collision_state():
            continue

        expanded += 1
        if expanded > expansion_cap:
            print(f"[ALT-A*] ⚠️ Expansion cap reached ({expanded}).")
            break

        for action, step_cost, child in state.get_successors():
            if child.is_collision_state():
                continue
            child_g = g + step_cost
            child_key = state_key(child)
            if child_g < best_g.get(child_key, unseen):
                best_g[child_key] = child_g
                tiebreak += 1
                child_node = Node(child, current, action, child_g, current.depth + 1)
                heappush(open_heap, (child_g + heuristic(child), tiebreak, child_node))

    print("[ALT-A*] ❌ No path found.")
    return []

########################
# Interactive Menu   ###
########################

def run_menu():
    """Show the pygame menu and launch the chosen search algorithm.

    Keys:
        1-4      run BFS / UCS / IDS / A* on the current map via ``play``.
        E, M, H  select the easy / medium / hard map.
        ESC      exit (closing the window also exits).

    ``play`` is called as ``play(difficulty, algorithm, delay=delay)``.
    The loop is capped at 30 iterations per second so that an idle menu
    does not spin a CPU core.
    """
    delay = 200
    difficulty = "easy"
    pygame.init()
    screen = pygame.display.set_mode((620, 250))
    pygame.display.set_caption("AI Search Algorithms - Angry Birds: Star Wars")
    font = pygame.font.Font(None, 32)

    def draw(sel=None):
        """Redraw the menu, highlighting the entry whose number equals ``sel``."""
        screen.fill((15, 15, 15))
        opts = [
            "1 - BFS (uninformed)",
            "2 - UCS (cost-based)",
            "3 - IDS (depth-limited)",
            "4 - A* (informed, graded)",
            "",
            f"Current Map: {difficulty.upper()}",
            "E - Easy | M - Medium | H - Hard",
            "ESC - Exit"
        ]
        for i, t in enumerate(opts):
            color = (255, 255, 0) if str(i + 1) == sel else (200, 200, 200)
            txt = font.render(t, True, color)
            screen.blit(txt, (40, 30 + i * 28))
        pygame.display.flip()

    # Key -> map name, and key -> (menu label to highlight, search function).
    maps_by_key = {
        pygame.K_e: "easy",
        pygame.K_m: "medium",
        pygame.K_h: "hard",
    }
    algorithms_by_key = {
        pygame.K_1: ("1", bfs),
        pygame.K_2: ("2", ucs),
        pygame.K_3: ("3", ids),
        pygame.K_4: ("4", astar),
    }
    clock = pygame.time.Clock()

    draw()
    running = True
    while running:
        for e in pygame.event.get():
            if e.type == pygame.QUIT:
                running = False
            elif e.type == pygame.KEYDOWN:
                k = e.key
                if k == pygame.K_ESCAPE:
                    running = False
                elif k in maps_by_key:
                    difficulty = maps_by_key[k]
                    draw()
                elif k in algorithms_by_key:
                    label, search = algorithms_by_key[k]
                    draw(label)
                    play(difficulty, search, delay=delay)
        # Throttle the polling loop; without this it busy-waits at full speed.
        clock.tick(30)
    pygame.quit()
    
########
# Start the interactive menu only when this file is executed as a script.
if __name__ == "__main__":
    run_menu()
