In [None]:
import random
import math
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import pygame
import cv2  # for video capture

# Turn on cuDNN benchmark mode (only relevant when running on a GPU).
torch.backends.cudnn.benchmark = True

# Constants for gridworld and visualization
GRID_SIZE = 15  # number of cells along each side of the square grid
CELL_SIZE = 50   # side length of one rendered cell; kept small so the window fits a 1440p screen
WINDOW_SIZE = GRID_SIZE * CELL_SIZE  # side length of the square render window / video frame
NUM_OBSTACLES = 25  # number of distinct obstacle cells placed on the grid

# ----------------------------
# Environment (Gridworld)
# ----------------------------
class Gridworld:
    """Two-agent grid environment with static obstacles and congestion zones.

    Both agents start in the top-left cell ``(0, 0)``. Agent 0 must reach the
    bottom-right corner and agent 1 the bottom-left corner. Cells are
    addressed as ``(x, y)`` tuples; ``y`` grows downwards when rendered.

    A shortest obstacle-free path is computed once per agent with A*. Agents
    are rewarded for standing on their path, and congestion zones are sampled
    from the cells of those paths and re-sampled every
    ``congestion_duration`` steps.

    ``step`` returns a single scalar reward shared by both agents.
    """

    # Upper bound on random obstacle layouts tried before giving up on
    # finding one in which both goals are reachable from the start cell.
    MAX_LAYOUT_ATTEMPTS = 100

    def __init__(self):
        """Build a random, solvable layout and initialise reward parameters.

        Raises:
            RuntimeError: if no layout with both goals reachable is found
                within ``MAX_LAYOUT_ATTEMPTS`` attempts.
        """
        self.grid_size = GRID_SIZE
        self.cell_size = CELL_SIZE
        self.vision_radius = 5  # drones can sense obstacles within 5 cells

        # Define goals:
        # Agent 0's goal is bottom-right; Agent 1's goal is bottom-left.
        self.goals = [(GRID_SIZE - 1, GRID_SIZE - 1), (0, GRID_SIZE - 1)]

        # Random obstacles can wall off a goal. An unreachable goal means the
        # episode can never finish and the path-based rewards are meaningless,
        # so keep drawing layouts until A* finds a path for both agents.
        for _ in range(self.MAX_LAYOUT_ATTEMPTS):
            self.obstacles = self.generate_obstacles()
            # Best paths via A* (only obstacles are considered).
            self.best_path_agent0 = self.a_star((0, 0), self.goals[0])
            self.best_path_agent1 = self.a_star((0, 0), self.goals[1])
            if self.best_path_agent0 and self.best_path_agent1:
                break
        else:
            raise RuntimeError(
                "Could not generate an obstacle layout in which both goals "
                f"are reachable after {self.MAX_LAYOUT_ATTEMPTS} attempts"
            )

        # Congestion zones (dynamic) with a timer
        self.congestion_zones = []
        self.congestion_duration = 50  # congestion zones persist for 50 steps
        self.congestion_timer = 0

        # Both agents start at top-left (0,0)
        self.agents = [(0, 0), (0, 0)]

        # Penalty and reward values
        self.movement_penalty = -1
        self.closeness_penalty = -10
        self.collision_penalty = -50
        self.congestion_penalty = -20
        self.goal_reward = 50
        self.shaping_factor = 0.5

        # Parameters for following the best path
        self.best_path_bonus = 10
        self.deviation_penalty = -10

    def generate_obstacles(self, num_obstacles=None):
        """Return a list of distinct random obstacle cells.

        The start cell ``(0, 0)`` and the goal cells are never chosen.

        Args:
            num_obstacles: how many obstacles to place; defaults to the
                module-level ``NUM_OBSTACLES``.

        Returns:
            A list of ``(x, y)`` tuples.

        Raises:
            ValueError: if more obstacles are requested than there are free
                cells (the sampling loop could otherwise never terminate).
        """
        if num_obstacles is None:
            num_obstacles = NUM_OBSTACLES

        reserved = {(0, 0), *self.goals}
        reserved_in_grid = sum(
            1 for (x, y) in reserved
            if 0 <= x < self.grid_size and 0 <= y < self.grid_size
        )
        free_cells = self.grid_size * self.grid_size - reserved_in_grid
        if num_obstacles > free_cells:
            raise ValueError(
                f"Cannot place {num_obstacles} obstacles: only {free_cells} "
                "free cells are available"
            )

        obstacles = set()
        while len(obstacles) < num_obstacles:
            x = random.randint(0, self.grid_size - 1)
            y = random.randint(0, self.grid_size - 1)
            if (x, y) not in reserved:
                obstacles.add((x, y))
        return list(obstacles)

    def heuristic(self, node, goal):
        """Return the Manhattan distance between ``node`` and ``goal``."""
        return abs(node[0] - goal[0]) + abs(node[1] - goal[1])

    def get_neighbors(self, node):
        """Return the in-grid, obstacle-free 4-connected neighbours of ``node``."""
        (x, y) = node
        neighbors = []
        for dx, dy in [(1, 0), (-1, 0), (0, 1), (0, -1)]:
            nx, ny = x + dx, y + dy
            if 0 <= nx < self.grid_size and 0 <= ny < self.grid_size:
                if (nx, ny) not in self.obstacles:
                    neighbors.append((nx, ny))
        return neighbors

    def a_star(self, start, goal):
        """Find a shortest path from ``start`` to ``goal`` avoiding obstacles.

        Every move costs 1 and the Manhattan distance is the heuristic.

        Returns:
            The path as a list of cells including both endpoints, or an empty
            list when ``goal`` cannot be reached.
        """
        import heapq
        open_set = []
        heapq.heappush(open_set, (0, start))
        came_from = {}
        g_score = {start: 0}
        while open_set:
            _, current = heapq.heappop(open_set)
            if current == goal:
                # Walk the predecessor links back to the start.
                path = [current]
                while current in came_from:
                    current = came_from[current]
                    path.append(current)
                path.reverse()
                return path
            for neighbor in self.get_neighbors(current):
                tentative_g = g_score[current] + 1
                if neighbor not in g_score or tentative_g < g_score[neighbor]:
                    came_from[neighbor] = current
                    g_score[neighbor] = tentative_g
                    f_score = tentative_g + self.heuristic(neighbor, goal)
                    heapq.heappush(open_set, (f_score, neighbor))
        return []

    def update_congestion_zones(self):
        """Advance the congestion timer and re-sample zones when it expires.

        While the timer is positive it is only decremented. Once it reaches
        zero, up to six cells are sampled from the union of both best paths
        and the timer is reset to ``congestion_duration``.
        """
        if self.congestion_timer > 0:
            self.congestion_timer -= 1
            return
        num_zones = 6  # increased number of congestion zones
        possible_zones = list(set(self.best_path_agent0 + self.best_path_agent1))
        if possible_zones:
            self.congestion_zones = random.sample(
                possible_zones, min(num_zones, len(possible_zones))
            )
        else:
            self.congestion_zones = []
        self.congestion_timer = self.congestion_duration

    def _window_view(self, agent_position, marked_cells):
        """Return a flat occupancy window centred on ``agent_position``.

        The window spans ``vision_radius`` cells in every direction and is
        flattened row by row (``dy`` outer, ``dx`` inner). An entry is ``1.0``
        when the cell lies inside the grid and is in ``marked_cells``, and
        ``0.0`` otherwise (including cells outside the grid).
        """
        # Build the lookup set once instead of scanning a list per cell.
        marked = set(marked_cells)
        ax, ay = agent_position
        radius = self.vision_radius
        size = self.grid_size
        view = []
        for dy in range(-radius, radius + 1):
            cy = ay + dy
            for dx in range(-radius, radius + 1):
                cx = ax + dx
                inside = 0 <= cx < size and 0 <= cy < size
                view.append(1.0 if inside and (cx, cy) in marked else 0.0)
        return view

    def get_obstacle_view(self, agent_position):
        """Return the flattened local window marking obstacle cells with 1.0."""
        return self._window_view(agent_position, self.obstacles)

    def get_congestion_view(self, agent_position):
        """Return the flattened local window marking congestion zones with 1.0."""
        return self._window_view(agent_position, self.congestion_zones)

    def reset(self):
        """Move both agents back to ``(0, 0)`` and return their states.

        Obstacles, best paths, congestion zones and the congestion timer are
        left untouched.
        """
        self.agents = [(0, 0), (0, 0)]
        return self.get_states()

    def get_states(self):
        """Return one flat state list per agent.

        Each state is ``[x, y, goal_x, goal_y]`` followed by the obstacle
        window and then the congestion window for that agent's position.
        """
        states = []
        for idx, pos in enumerate(self.agents):
            goal = self.goals[idx]
            base_state = [pos[0], pos[1], goal[0], goal[1]]
            obstacle_view = self.get_obstacle_view(pos)
            congestion_view = self.get_congestion_view(pos)
            states.append(base_state + obstacle_view + congestion_view)
        return states

    def step(self, actions):
        """Apply one action per agent and return the transition.

        Actions: 0 = up (y - 1), 1 = down (y + 1), 2 = left (x - 1),
        3 = right (x + 1); any other value leaves the agent in place.
        Moves leaving the grid are cancelled; moves into an obstacle are
        cancelled and incur the collision penalty.

        Args:
            actions: a sequence with one action per agent.

        Returns:
            ``(next_states, reward, done)`` where ``reward`` is a single
            scalar shared by both agents and ``done`` is true only when both
            agents are on their respective goals.
        """
        self.update_congestion_zones()
        reward = 0
        reward += self.movement_penalty * len(self.agents)

        # Distances before the move, used for the progress-shaping term.
        old_distances = []
        for idx, (x, y) in enumerate(self.agents):
            goal = self.goals[idx]
            old_distances.append(abs(x - goal[0]) + abs(y - goal[1]))

        new_positions = []
        for idx, (x, y) in enumerate(self.agents):
            dx, dy = 0, 0
            action = actions[idx]
            if action == 0:
                dy = -1
            elif action == 1:
                dy = 1
            elif action == 2:
                dx = -1
            elif action == 3:
                dx = 1
            new_x, new_y = x + dx, y + dy
            if new_x < 0 or new_x >= self.grid_size or new_y < 0 or new_y >= self.grid_size:
                new_x, new_y = x, y
            if (new_x, new_y) in self.obstacles:
                reward += self.collision_penalty
                new_x, new_y = x, y
            new_positions.append((new_x, new_y))
        self.agents = new_positions

        # Penalty for standing inside a congestion zone.
        for pos in self.agents:
            if pos in self.congestion_zones:
                reward += self.congestion_penalty

        # Shaping: reward the reduction in Manhattan distance to the goal.
        for idx, (x, y) in enumerate(self.agents):
            goal = self.goals[idx]
            new_distance = abs(x - goal[0]) + abs(y - goal[1])
            reward += self.shaping_factor * (old_distances[idx] - new_distance)

        # Bonus for being on the precomputed best path. Leaving it is only
        # penalised when no congestion zone is visible in the agent's window.
        for idx, pos in enumerate(self.agents):
            best_path = self.best_path_agent0 if idx == 0 else self.best_path_agent1
            if pos in best_path:
                reward += self.best_path_bonus
            else:
                congestion_view = self.get_congestion_view(pos)
                if not any(cell == 1.0 for cell in congestion_view):
                    reward += self.deviation_penalty

        # Penalty when the two agents are fewer than 3 cells apart.
        manhattan_distance = (
            abs(self.agents[0][0] - self.agents[1][0])
            + abs(self.agents[0][1] - self.agents[1][1])
        )
        if manhattan_distance < 3:
            reward += self.closeness_penalty

        # Goal reward is granted on every step an agent occupies its goal.
        for idx, pos in enumerate(self.agents):
            if pos == self.goals[idx]:
                reward += self.goal_reward

        done = (self.agents[0] == self.goals[0] and self.agents[1] == self.goals[1])
        next_states = self.get_states()
        return next_states, reward, done

    def _cell_rect(self, cell):
        """Return the pygame rectangle covering grid cell ``cell``."""
        return pygame.Rect(
            cell[0] * self.cell_size, cell[1] * self.cell_size,
            self.cell_size, self.cell_size,
        )

    def _cell_center(self, cell):
        """Return the centre point of grid cell ``cell`` in window coordinates."""
        half = self.cell_size // 2
        return (cell[0] * self.cell_size + half, cell[1] * self.cell_size + half)

    def render(self, screen):
        """Draw the current environment onto ``screen`` and flip the display.

        Layers are drawn in this order: background, grid lines, obstacles
        (black), congestion zones (orange), best paths (magenta for agent 0,
        cyan for agent 1), goals and finally the agents (agent 0 red,
        agent 1 blue).
        """
        screen.fill((255, 255, 255))
        for x in range(0, WINDOW_SIZE, self.cell_size):
            pygame.draw.line(screen, (200, 200, 200), (x, 0), (x, WINDOW_SIZE))
        for y in range(0, WINDOW_SIZE, self.cell_size):
            pygame.draw.line(screen, (200, 200, 200), (0, y), (WINDOW_SIZE, y))
        for obs in self.obstacles:
            pygame.draw.rect(screen, (0, 0, 0), self._cell_rect(obs))
        for cz in self.congestion_zones:
            pygame.draw.rect(screen, (255, 165, 0), self._cell_rect(cz))
        # Draw computed best paths (a polyline needs at least two points).
        path_styles = [
            (self.best_path_agent0, (255, 0, 255)),
            (self.best_path_agent1, (0, 255, 255)),
        ]
        for path, color in path_styles:
            if path and len(path) > 1:
                points = [self._cell_center(cell) for cell in path]
                pygame.draw.lines(screen, color, False, points, 3)
        # Draw goals with colors matching the drones:
        goal_colors = [(255, 0, 0), (0, 0, 255)]
        for idx, goal in enumerate(self.goals):
            pygame.draw.rect(screen, goal_colors[idx], self._cell_rect(goal))
        # Draw agents as circles (agent 0 in red, agent 1 in blue)
        colors = [(255, 0, 0), (0, 0, 255)]
        for idx, pos in enumerate(self.agents):
            pygame.draw.circle(
                screen, colors[idx], self._cell_center(pos), self.cell_size // 3
            )
        pygame.display.flip()

# ----------------------------
# Replay Buffer for DQN
# ----------------------------
class ReplayBuffer:
    """Fixed-capacity store of transitions that overwrites the oldest entry.

    ``buffer`` grows until it holds ``capacity`` transitions; after that
    ``position`` wraps around and new transitions replace existing slots.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.position = 0

    def push(self, state, action, reward, next_state, done):
        """Store one ``(state, action, reward, next_state, done)`` tuple."""
        transition = (state, action, reward, next_state, done)
        still_growing = len(self.buffer) < self.capacity
        if still_growing:
            # Reserve the slot that is written just below.
            self.buffer.append(None)
        write_index = self.position
        self.buffer[write_index] = transition
        self.position = (write_index + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen at random."""
        chosen = random.sample(self.buffer, batch_size)
        return chosen

    def __len__(self):
        """Return the number of transitions currently stored."""
        stored = len(self.buffer)
        return stored

# ----------------------------
# DQN Network and Agent (GPU-enabled)
# ----------------------------
class DQN(nn.Module):
    """Fully connected Q-network mapping a state vector to one value per action.

    Architecture: input -> 64 -> 64 -> 64 -> 32 -> output, with a ReLU after
    every layer except the last.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        widths = (input_dim, 64, 64, 64, 32)
        stack = []
        # Each hidden layer is a Linear followed by a ReLU, in order.
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        # The output layer has no activation: raw Q-values.
        stack.append(nn.Linear(widths[-1], output_dim))
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Return the network output for input batch ``x``."""
        q_values = self.net(x)
        return q_values

class DQNAgent:
    """DQN learner with a policy network, a target network and replay memory.

    Args:
        input_dim: length of the state vector fed to the networks.
        output_dim: number of discrete actions.
        lr: learning rate of the Adam optimizer.
        gamma: discount factor used in the bootstrap target.
        buffer_capacity: maximum number of transitions kept for replay.
        batch_size: number of transitions sampled per ``update`` call.
    """

    def __init__(self, input_dim, output_dim, lr=1e-3, gamma=0.99,
                 buffer_capacity=10000, batch_size=32):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Remembered so random exploration covers exactly the network's actions.
        self.output_dim = output_dim
        self.policy_net = DQN(input_dim, output_dim).to(self.device)
        self.target_net = DQN(input_dim, output_dim).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
        self.gamma = gamma
        self.replay_buffer = ReplayBuffer(buffer_capacity)
        self.batch_size = batch_size

    def select_action(self, state, epsilon):
        """Pick an action epsilon-greedily.

        With probability ``epsilon`` a uniformly random action index in
        ``range(output_dim)`` is returned; otherwise the index of the largest
        Q-value predicted by the policy network for ``state``.
        """
        if random.random() < epsilon:
            return random.randrange(self.output_dim)
        state_tensor = torch.as_tensor(
            state, dtype=torch.float32, device=self.device
        ).unsqueeze(0)
        with torch.no_grad():
            q_values = self.policy_net(state_tensor)
        return q_values.argmax().item()

    def update(self):
        """Run one gradient step on a random replay batch.

        Does nothing until the replay buffer holds at least ``batch_size``
        transitions. The loss is the mean squared error between the policy
        network's Q-value for the taken action and the one-step target
        ``reward + gamma * max_a Q_target(next_state, a) * (1 - done)``.
        """
        if len(self.replay_buffer) < self.batch_size:
            return
        transitions = self.replay_buffer.sample(self.batch_size)
        states, actions, rewards, next_states, dones = zip(*transitions)

        batch_state = torch.as_tensor(states, dtype=torch.float32, device=self.device)
        batch_action = torch.as_tensor(
            actions, dtype=torch.int64, device=self.device
        ).unsqueeze(1)
        batch_reward = torch.as_tensor(
            rewards, dtype=torch.float32, device=self.device
        ).unsqueeze(1)
        batch_next_state = torch.as_tensor(
            next_states, dtype=torch.float32, device=self.device
        )
        batch_done = torch.as_tensor(
            dones, dtype=torch.float32, device=self.device
        ).unsqueeze(1)

        current_q = self.policy_net(batch_state).gather(1, batch_action)
        # The target is a constant for this step: computing it without autograd
        # keeps backward() from traversing the target network and accumulating
        # gradients on parameters that no optimizer ever clears.
        with torch.no_grad():
            next_q = self.target_net(batch_next_state).max(1)[0].unsqueeze(1)
            target_q = batch_reward + self.gamma * next_q * (1 - batch_done)

        loss = nn.MSELoss()(current_q, target_q)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_target(self):
        """Copy the policy network's weights into the target network."""
        self.target_net.load_state_dict(self.policy_net.state_dict())

# ----------------------------
# Training Loop with Milestone Rendering and Video Capture
# ----------------------------
def train():
    """Train two DQN agents in a shared ``Gridworld`` and save their weights.

    Runs 1000 episodes of at most 500 steps each. Both agents receive the
    same scalar reward from the environment and each learns from its own
    replay buffer. Epsilon decays multiplicatively per episode down to a
    floor, and the target networks are synchronised every 10 episodes.

    Episodes 500, 750 and 1000 are rendered with pygame and recorded to
    ``episode_<n>.mp4``. Closing the window during a rendered episode aborts
    training immediately, without saving the models. After the final episode
    each policy network is saved to ``agent_<i>_trained.pth``.
    """
    num_episodes = 1000
    max_steps = 500
    epsilon = 1.0
    epsilon_min = 0.05
    epsilon_decay = 0.9967
    target_update_freq = 10
    milestones = {int(num_episodes * 0.5), int(num_episodes * 0.75), num_episodes}

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    env = Gridworld()
    state_dim = 4 + 2 * ((2 * env.vision_radius + 1) ** 2)
    action_dim = 4
    agents = [DQNAgent(state_dim, action_dim), DQNAgent(state_dim, action_dim)]

    for episode in range(num_episodes):
        states = env.reset()
        total_reward = 0
        do_render = (episode + 1) in milestones
        screen = None
        video_writer = None
        # try/finally guarantees the display and the video file are released
        # on every exit path: normal completion, window close, or an error
        # raised while stepping or learning.
        try:
            if do_render:
                pygame.init()
                screen = pygame.display.set_mode((WINDOW_SIZE, WINDOW_SIZE))
                pygame.display.set_caption(f"Training Episode {episode + 1}")
                fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                video_filename = f"episode_{episode+1}.mp4"
                fps = 20
                video_writer = cv2.VideoWriter(
                    video_filename, fourcc, fps, (WINDOW_SIZE, WINDOW_SIZE)
                )

            for step in range(max_steps):
                actions = [
                    agent.select_action(state, epsilon)
                    for agent, state in zip(agents, states)
                ]
                next_states, reward, done = env.step(actions)
                total_reward += reward

                # Every agent stores its own view of the shared transition.
                for agent, state, action, next_state in zip(
                    agents, states, actions, next_states
                ):
                    agent.replay_buffer.push(state, action, reward, next_state, done)

                states = next_states

                for agent in agents:
                    agent.update()

                if do_render:
                    env.render(screen)
                    frame = pygame.surfarray.array3d(screen)
                    # surfarray is indexed (x, y, channel); OpenCV expects
                    # (row, column, channel) in BGR order.
                    frame = np.transpose(frame, (1, 0, 2))
                    frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                    video_writer.write(frame)
                    if any(event.type == pygame.QUIT for event in pygame.event.get()):
                        # Window closed: abort training; cleanup runs in finally.
                        return
                    pygame.time.wait(50)

                if done:
                    break

            if do_render:
                # Hold the final frame briefly before the window closes.
                pygame.time.wait(1000)
        finally:
            if do_render:
                pygame.quit()
                if video_writer is not None:
                    video_writer.release()

        if (episode + 1) % target_update_freq == 0:
            for agent in agents:
                agent.update_target()

        epsilon = max(epsilon_min, epsilon * epsilon_decay)
        print(f"Episode {episode + 1}, Total Reward: {total_reward:.2f}, Epsilon: {epsilon:.3f}")

    # Save the trained model for each agent after all episodes have completed.
    for i, agent in enumerate(agents):
        torch.save(agent.policy_net.state_dict(), f"agent_{i}_trained.pth")
    print("Training complete. Models saved.")

    
# Start training only when this file is executed directly, not when imported.
if __name__ == '__main__':
    train()
