# PPO vs. DDQN vs. DQN

# In [None]:
import numpy as np
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque, namedtuple
import random
import time
import matplotlib.pyplot as plt
from torch.distributions import Categorical
from gymnasium import spaces

 
class MazeEnv(gym.Env):
    """Square grid maze with randomly scattered walls (Gymnasium step API).

    The agent starts at cell ``(0, 0)`` and must reach
    ``(maze_size - 1, maze_size - 1)``. Observations are the flattened grid
    with cell codes 0 = free, 1 = wall, 2 = agent, 3 = goal.

    Per-step reward: +100 on reaching the goal (``terminated``), -10 when the
    step budget ``max_steps`` runs out first (``truncated``), -1 for bumping
    into a wall, -0.1 for any other move.
    """

    def __init__(self, maze_size=8):
        super().__init__()
        self.maze_size = maze_size
        # Must describe what get_state() actually returns: a flat vector of
        # maze_size**2 cell codes in [0, 3].
        self.observation_space = spaces.Box(
            low=0, high=3, shape=(maze_size * maze_size,), dtype=np.float32
        )
        # Action -> move: 0: x - 1, 1: y + 1, 2: x + 1, 3: y - 1 (clipped to the grid).
        self.action_space = spaces.Discrete(4)
        self.maze = None
        self.current_pos = None
        self.goal_pos = None
        self.max_steps = maze_size * maze_size * 2
        self.steps = 0

    def generate_maze(self):
        """Build a fresh maze and put the agent and goal in opposite corners.

        Up to ``maze_size // 2`` walls are drawn at random (duplicate draws
        collapse onto the same cell). NOTE: connectivity is not checked, so the
        walls can cut the start off from the goal.
        """
        self.maze = np.zeros((self.maze_size, self.maze_size))
        n_walls = self.maze_size // 2
        # Draw from the env's own RNG so reset(seed=...) reproduces the maze.
        rows = self.np_random.choice(self.maze_size, size=n_walls)
        cols = self.np_random.choice(self.maze_size, size=n_walls)
        self.maze[rows, cols] = 1

        self.current_pos = (0, 0)
        self.goal_pos = (self.maze_size - 1, self.maze_size - 1)
        # Start and goal are always kept free.
        self.maze[self.current_pos] = 0
        self.maze[self.goal_pos] = 0

    def get_state(self):
        """Return the flattened float32 grid with agent (2) and goal (3) marked."""
        state = self.maze.copy()
        state[self.current_pos] = 2
        state[self.goal_pos] = 3
        return state.flatten().astype(np.float32)

    def reset(self, seed=None, options=None):
        """Seed the env RNG (if ``seed`` is given), draw a new maze, return ``(obs, {})``."""
        super().reset(seed=seed)
        self.generate_maze()
        self.steps = 0
        return self.get_state(), {}

    def step(self, action):
        """Apply ``action``; return ``(obs, reward, terminated, truncated, info)``.

        ``terminated`` is True only when the goal is reached. ``truncated`` is
        True only when the step budget is exhausted without reaching it.
        """
        self.steps += 1
        x, y = self.current_pos

        if action == 0:
            new_pos = (max(0, x - 1), y)
        elif action == 1:
            new_pos = (x, min(self.maze_size - 1, y + 1))
        elif action == 2:
            new_pos = (min(self.maze_size - 1, x + 1), y)
        else:
            new_pos = (x, max(0, y - 1))

        hit_wall = bool(self.maze[new_pos] == 1)
        if not hit_wall:
            self.current_pos = new_pos

        terminated = self.current_pos == self.goal_pos
        truncated = (not terminated) and self.steps >= self.max_steps
        if terminated:
            reward = 100
        elif truncated:
            reward = -10
        elif hit_wall:
            # Previously assigned and then unconditionally overwritten by -0.1.
            reward = -1
        else:
            reward = -0.1

        return self.get_state(), reward, terminated, truncated, {}

 
class MazeNetwork(nn.Module):
    """Fully connected net: input -> 128 -> 128 -> output, ReLU after each hidden layer.

    The output layer has no activation, so it returns raw per-action scores.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        hidden_units = 128
        self.fc1 = nn.Linear(input_dim, hidden_units)
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.fc3 = nn.Linear(hidden_units, output_dim)

    def forward(self, x):
        for hidden_layer in (self.fc1, self.fc2):
            x = F.relu(hidden_layer(x))
        return self.fc3(x)

 
class MazeDQNAgent:
    """DQN agent with epsilon-greedy acting, uniform replay and a target network.

    Each ``train()`` call (once the buffer holds ``batch_size`` transitions)
    performs one gradient step and multiplies ``epsilon`` by
    ``epsilon_decay`` (floored at ``epsilon_min``). It also copies the online
    weights into the target network every ``target_update`` steps.
    """

    def __init__(self, state_dim, action_dim, learning_rate=1e-3):
        self.action_dim = action_dim
        self.q_network = MazeNetwork(state_dim, action_dim)
        self.target_network = MazeNetwork(state_dim, action_dim)
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=learning_rate)
        self.memory = ReplayBuffer(10000)
        self.batch_size = 64
        self.gamma = 0.99
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.target_update = 10  # gradient steps between target-network syncs
        self.steps = 0

    def select_action(self, state):
        """Return a uniformly random action with probability ``epsilon``, else the greedy one."""
        if random.random() < self.epsilon:
            # Explore over the configured action space rather than a hard-coded 4.
            return random.randrange(self.action_dim)
        with torch.no_grad():
            q_values = self.q_network(torch.FloatTensor(state))
        return q_values.argmax().item()

    def train(self):
        """Run one DQN update on a replay minibatch; no-op until the buffer is large enough."""
        if len(self.memory) < self.batch_size:
            return

        transitions = self.memory.sample(self.batch_size)
        batch = self.memory.Transition(*zip(*transitions))

        state_batch = torch.FloatTensor(np.array(batch.state))
        action_batch = torch.LongTensor(batch.action)
        reward_batch = torch.FloatTensor(batch.reward)
        next_state_batch = torch.FloatTensor(np.array(batch.next_state))
        done_batch = torch.FloatTensor(batch.done)

        # squeeze(1), not squeeze(): stays a (batch,) vector even when batch_size == 1.
        current_q_values = self.q_network(state_batch).gather(1, action_batch.unsqueeze(1)).squeeze(1)
        with torch.no_grad():
            next_q_values = self.target_network(next_state_batch).max(1)[0]
        # Terminal transitions (done == 1) do not bootstrap.
        target_q_values = reward_batch + (1 - done_batch) * self.gamma * next_q_values

        loss = F.smooth_l1_loss(current_q_values, target_q_values)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

        self.steps += 1
        if self.steps % self.target_update == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())

 
class MazeDDQNAgent(MazeDQNAgent):
    """Double DQN: the online network picks the next action, the target network scores it.

    Construction, action selection and hyper-parameters are inherited from
    MazeDQNAgent; only the bootstrap target in ``train`` differs.
    """

    def train(self):
        """Run one Double-DQN update on a replay minibatch; no-op until the buffer is large enough."""
        if len(self.memory) < self.batch_size:
            return

        transitions = self.memory.sample(self.batch_size)
        batch = self.memory.Transition(*zip(*transitions))

        state_batch = torch.FloatTensor(np.array(batch.state))
        action_batch = torch.LongTensor(batch.action)
        reward_batch = torch.FloatTensor(batch.reward)
        next_state_batch = torch.FloatTensor(np.array(batch.next_state))
        done_batch = torch.FloatTensor(batch.done)

        current_q_values = self.q_network(state_batch).gather(1, action_batch.unsqueeze(1)).squeeze(1)
        with torch.no_grad():
            next_actions = self.q_network(next_state_batch).argmax(1, keepdim=True)
            # squeeze(1) keeps the target (batch,) so it lines up with reward_batch.
            # Leaving it (batch, 1) broadcast the target to (batch, batch).
            next_q_values = self.target_network(next_state_batch).gather(1, next_actions).squeeze(1)
        target_q_values = reward_batch + (1 - done_batch) * self.gamma * next_q_values

        loss = F.smooth_l1_loss(current_q_values, target_q_values)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

        self.steps += 1
        if self.steps % self.target_update == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())

 
class MazePPOAgent:
    """PPO-style policy-gradient agent with no learned critic.

    The network outputs action logits. ``train`` takes one episode's
    transitions and computes discounted returns. It normalises them to zero
    mean (and unit variance when there is more than one step); the mean
    subtraction acts as a simple baseline. It then optimises the clipped
    surrogate objective for ``K`` epochs.
    """

    def __init__(self, state_dim, action_dim, learning_rate=1e-3):
        self.policy_network = MazeNetwork(state_dim, action_dim)
        self.optimizer = optim.Adam(self.policy_network.parameters(), lr=learning_rate)
        self.gamma = 0.99
        self.epsilon_clip = 0.2
        self.K = 4  # optimisation epochs per call to train()

    def select_action(self, state):
        """Sample an action from the categorical policy defined by the logits."""
        with torch.no_grad():
            logits = self.policy_network(torch.FloatTensor(state))
        return Categorical(logits=logits).sample().item()

    def _discounted_returns(self, rewards, dones):
        """Return discounted returns as a float tensor, resetting the sum after ``done`` steps."""
        returns = []
        running = 0.0
        for reward, done in zip(reversed(rewards), reversed(dones)):
            if done:
                running = 0.0
            running = reward + self.gamma * running
            returns.append(running)
        returns.reverse()
        return torch.FloatTensor(returns)

    def train(self, memory):
        """Update the policy from ``memory``: a list of (state, action, reward, next_state, done)."""
        if not memory:
            return

        states, actions, rewards, _next_states, dones = zip(*memory)
        states = torch.FloatTensor(np.array(states))
        actions = torch.LongTensor(actions)

        returns = self._discounted_returns(rewards, dones)
        advantages = returns - returns.mean()
        if advantages.numel() > 1:
            advantages = advantages / (returns.std() + 1e-8)

        # Log-probs of the behaviour policy stay fixed across the K epochs.
        with torch.no_grad():
            old_log_probs = Categorical(logits=self.policy_network(states)).log_prob(actions)

        for _ in range(self.K):
            dist = Categorical(logits=self.policy_network(states))
            log_probs = dist.log_prob(actions)
            ratios = torch.exp(log_probs - old_log_probs)

            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.epsilon_clip, 1 + self.epsilon_clip) * advantages
            loss = -torch.min(surr1, surr2).mean()

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

 
class ReplayBuffer:
    """Bounded FIFO store of transitions with uniform random sampling.

    Once ``capacity`` transitions are held, each push evicts the oldest one.
    ``Transition`` is the namedtuple type of the stored records.
    """

    def __init__(self, capacity):
        self.Transition = namedtuple(
            'Transition', ['state', 'action', 'reward', 'next_state', 'done']
        )
        self.buffer = deque([], maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Store one transition, evicting the oldest if the buffer is full."""
        record = self.Transition(
            state=state, action=action, reward=reward, next_state=next_state, done=done
        )
        self.buffer.append(record)

    def sample(self, batch_size):
        """Return ``batch_size`` stored transitions drawn uniformly without replacement."""
        return random.sample(self.buffer, k=batch_size)

    def __len__(self):
        return len(self.buffer)

def train_maze_agent(agent, env, episodes, render=False):
    """Run ``episodes`` episodes of ``agent`` in ``env`` and collect metrics.

    Agents that have a ``memory`` attribute (the replay-based DQN family) get
    every transition pushed to it and ``train()`` called after each step. Any
    other agent gets ``train(memory)`` once per episode, with the episode's
    list of ``(state, action, reward, next_state, done)`` tuples.

    The stored ``done`` flag is ``terminated`` only. A time-limit truncation
    is not a true terminal state, so value targets should still bootstrap
    from the next state. The episode loop itself stops on either flag.

    Returns:
        ``(rewards, success_rate, path_lengths)``: per-episode total reward,
        success flag (final step's reward > 0) and number of steps.
    """
    rewards = []
    success_rate = []
    path_lengths = []
    uses_replay = hasattr(agent, 'memory')

    for episode in range(episodes):
        state, _ = env.reset()
        total_reward = 0
        steps = 0
        done = False
        memory = []

        while not done:
            action = agent.select_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            memory.append((state, action, reward, next_state, terminated))
            if uses_replay:
                agent.memory.push(state, action, reward, next_state, terminated)
                agent.train()

            state = next_state
            total_reward += reward
            steps += 1

            if render and episode % 100 == 0:
                env.render()

        if not uses_replay and hasattr(agent, 'train'):
            agent.train(memory)

        rewards.append(total_reward)
        path_lengths.append(steps)
        success = reward > 0
        success_rate.append(success)

        if (episode + 1) % 10 == 0:
            recent_success_rate = sum(success_rate[-100:]) / min(100, len(success_rate))
            avg_path_length = sum(path_lengths[-100:]) / min(100, len(path_lengths))
            print(f"Episode {episode + 1}, Success Rate: {recent_success_rate:.2%}, "
                  f"Avg Path Length: {avg_path_length:.1f}, Reward: {total_reward:.2f}")

    return rewards, success_rate, path_lengths

def compare_algorithms(env, episodes):
    """Train fresh DQN, DDQN and PPO agents on ``env``, one after another.

    All three agents are built up front, then each is trained for
    ``episodes`` episodes with ``train_maze_agent`` on the same env object.

    Returns:
        dict mapping algorithm name to
        ``{"rewards": [...], "success_rate": [...], "path_lengths": [...]}``.
    """
    state_dim = env.maze_size * env.maze_size  # flattened grid observation
    # Take the action count from the env instead of hard-coding 4.
    action_dim = int(env.action_space.n)

    agents = {
        "DQN": MazeDQNAgent(state_dim, action_dim),
        "DDQN": MazeDDQNAgent(state_dim, action_dim),
        "PPO": MazePPOAgent(state_dim, action_dim),
    }

    results = {}
    for algo_name, agent in agents.items():
        print(f"Training {algo_name}...")
        rewards, success_rate, path_lengths = train_maze_agent(agent, env, episodes)
        results[algo_name] = {
            "rewards": rewards,
            "success_rate": success_rate,
            "path_lengths": path_lengths,
        }

    return results

 
maze_size = 8
env = MazeEnv(maze_size=maze_size)
episodes = 100

results = compare_algorithms(env, episodes)

# Trailing window used for the success-rate and path-length moving averages.
window = 100


def _moving_average(values, window):
    """For each prefix ``values[:i]`` (i = 1..len), average its last ``window`` items."""
    return [
        sum(values[max(0, i - window):i]) / min(i, window)
        for i in range(1, len(values) + 1)
    ]


plt.figure(figsize=(18, 6))

for algo_name, result in results.items():
    # (series, line label, panel title, y-axis label) for the three side-by-side panels.
    panels = (
        (result["rewards"],
         f'Total Reward per Episode ({algo_name})', 'Learning Curve', 'Total Reward'),
        (_moving_average(result["success_rate"], window),
         f'Success Rate (Moving Average) ({algo_name})', 'Success Rate', 'Success Rate'),
        (_moving_average(result["path_lengths"], window),
         f'Steps per Episode (Moving Average) ({algo_name})', 'Average Path Length', 'Steps'),
    )
    for panel_index, (series, label, title, ylabel) in enumerate(panels, start=1):
        plt.subplot(1, 3, panel_index)
        plt.plot(series, label=label)
        plt.title(title)
        plt.xlabel('Episode')
        plt.ylabel(ylabel)
        plt.legend(loc='upper left')

plt.tight_layout()
plt.show()


# PPO

# In [None]:

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.distributions import Categorical

# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class Env:
    """51 x 31 grid world with wall obstacles and eight-directional moves.

    ``reset()`` returns the 14-dim float32 state from ``_get_state``.
    ``step(action)`` returns ``(state, reward, done, info)``. An episode ends
    when the agent reaches the goal, or once ``max_steps`` steps have been
    taken (steps that bump into a wall count too).
    """

    def __init__(self):
        self.x_range = 51
        self.y_range = 31
        # The eight neighbouring offsets; an action is an index into this list.
        self.motions = [(-1, 0), (-1, 1), (0, 1), (1, 1),
                        (1, 0), (1, -1), (0, -1), (-1, -1)]
        self.obs = self.obs_map()
        # 6 position/goal features + one obstacle distance per motion.
        self.state_dim = 14
        self.action_dim = len(self.motions)
        self.prev_dist = None
        self.steps_taken = 0
        self.max_steps = 300
        self.visited_positions = set()
        self.last_positions = []     # sliding window of the last 10 positions
        self.position_history = []   # every position since reset
        self.stuck_threshold = 20
        self.exploration_bonus_map = {}

    def obs_map(self):
        """Return the set of obstacle cells: the border walls plus four interior walls."""
        x = self.x_range
        y = self.y_range
        obs = set()
        for i in range(x):
            obs.add((i, 0))
        for i in range(x):
            obs.add((i, y - 1))
        for i in range(y):
            obs.add((0, i))
        for i in range(y):
            obs.add((x - 1, i))
        for i in range(10, 21):
            obs.add((i, 15))
        for i in range(15):
            obs.add((20, i))
        for i in range(15, 30):
            obs.add((30, i))
        for i in range(16):
            obs.add((40, i))
        return obs

    def reset(self):
        """Place the agent at (5, 5), the goal at (45, 25), clear trackers, return the state."""
        self.start = (5, 5)
        self.goal = (45, 25)
        self.position = self.start
        self.prev_dist = np.linalg.norm(np.array(self.start) - np.array(self.goal))
        self.steps_taken = 0
        self.visited_positions = {self.start}
        self.last_positions = [self.start]
        self.position_history = [self.start]
        self.exploration_bonus_map = {}
        return self._get_state()

    def _get_state(self):
        """Build the 14-dim state.

        The first six entries are: agent x/y and goal x/y (each divided by the
        grid size), distance to goal divided by the grid diagonal, and angle to
        goal divided by pi. The last eight are, for each motion, the number of
        steps along that ray to the first obstacle or edge, divided by 10 and
        capped at 1.
        """
        pos = np.array(self.position)
        goal = np.array(self.goal)
        dist_to_goal = np.linalg.norm(pos - goal)
        angle_to_goal = np.arctan2(goal[1] - pos[1], goal[0] - pos[0]) / np.pi

        obstacle_dists = []
        for dx, dy in self.motions:
            x, y = self.position
            dist = 0
            while True:
                x += dx
                y += dy
                dist += 1
                if (x, y) in self.obs or not (0 <= x < self.x_range and 0 <= y < self.y_range):
                    obstacle_dists.append(min(dist / 10.0, 1.0))
                    break

        return np.array([
            pos[0] / self.x_range,
            pos[1] / self.y_range,
            goal[0] / self.x_range,
            goal[1] / self.y_range,
            dist_to_goal / np.sqrt(self.x_range**2 + self.y_range**2),
            angle_to_goal,
            *obstacle_dists
        ], dtype=np.float32)

    def is_stuck(self):
        """True when the last ``stuck_threshold`` positions cover at most 3 distinct cells."""
        if len(self.position_history) < self.stuck_threshold:
            return False
        recent_positions = self.position_history[-self.stuck_threshold:]
        unique_positions = set(recent_positions)
        return len(unique_positions) <= 3

    def get_direction_vector(self, from_pos, to_pos):
        """Return the (approximately) unit vector from ``from_pos`` to ``to_pos``."""
        vector = np.array(to_pos) - np.array(from_pos)
        norm = np.linalg.norm(vector)
        return vector / (norm + 1e-6)  # epsilon avoids division by zero

    def get_exploration_bonus(self, position):
        """Decay every stored bonus by 0.995, then return ``position``'s bonus (10.0 if new).

        Note: the decay pass walks every visited cell, so each call costs
        O(number of visited cells).
        """
        for pos in self.exploration_bonus_map:
            self.exploration_bonus_map[pos] *= 0.995

        if position not in self.exploration_bonus_map:
            self.exploration_bonus_map[position] = 10.0
        return self.exploration_bonus_map[position]

    def step(self, action):
        """Move by ``self.motions[action]``; return ``(state, reward, done, {})``.

        A blocked move leaves the agent in place with reward -2.0. Reaching the
        goal gives 1000.0 and ends the episode. Any other move gets a shaped
        reward clipped to [-10, 10].
        """
        self.steps_taken += 1
        x, y = self.position
        dx, dy = self.motions[action]
        new_position = (x + dx, y + dy)

        in_bounds = (0 <= new_position[0] < self.x_range
                     and 0 <= new_position[1] < self.y_range)
        if new_position in self.obs or not in_bounds:
            # Blocked: stay put. The step still counts toward max_steps, so an
            # agent pressed against a wall cannot keep the episode alive forever.
            reward = -2.0
            done = self.steps_taken >= self.max_steps
            return self._get_state(), reward, done, {}

        current_dist = np.linalg.norm(np.array(new_position) - np.array(self.goal))
        # Check for oscillation before recording the move. Once new_position is
        # appended it is trivially among the most recent entries, and the
        # penalty would always fire.
        oscillating = new_position in self.last_positions[-3:]

        self.position = new_position
        self.visited_positions.add(new_position)
        self.last_positions.append(new_position)
        self.position_history.append(new_position)
        if len(self.last_positions) > 10:
            self.last_positions.pop(0)

        if self.position == self.goal:
            reward = 1000.0
            done = True
        else:
            progress_reward = (self.prev_dist - current_dist) * 10.0
            distance_factor = 1.0 - (current_dist / np.sqrt(self.x_range**2 + self.y_range**2))
            distance_reward = distance_factor * 2.0

            exploration_bonus = self.get_exploration_bonus(new_position)

            direction_to_goal = self.get_direction_vector(self.position, self.goal)
            movement_direction = self.get_direction_vector((x, y), new_position)
            alignment_reward = max(0, np.dot(direction_to_goal, movement_direction)) * 2.0

            stuck_penalty = -5.0 if self.is_stuck() else 0.0
            oscillation_penalty = -1.0 if oscillating else 0.0
            time_penalty = -0.1

            reward = (progress_reward +
                      distance_reward +
                      exploration_bonus +
                      alignment_reward +
                      stuck_penalty +
                      oscillation_penalty +
                      time_penalty)

            # Only the shaped reward is clipped; the goal reward above is not.
            reward = np.clip(reward, -10.0, 10.0)
            done = self.steps_taken >= self.max_steps

        self.prev_dist = current_dist
        return self._get_state(), reward, done, {}

class ActorCritic(nn.Module):
    """Separate actor and critic MLPs with identical 512-256-128 ReLU trunks.

    ``forward`` returns ``(action_probs, value)``. The actor ends in a softmax
    over ``action_dim``; the critic ends in a single linear output.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.actor = nn.Sequential(
            *self._trunk(state_dim),
            nn.Linear(128, action_dim),
            nn.Softmax(dim=-1),
        )
        self.critic = nn.Sequential(
            *self._trunk(state_dim),
            nn.Linear(128, 1),
        )

    @staticmethod
    def _trunk(state_dim):
        """Fresh hidden layers state_dim -> 512 -> 256 -> 128, each followed by ReLU."""
        return [
            nn.Linear(state_dim, 512), nn.ReLU(),
            nn.Linear(512, 256), nn.ReLU(),
            nn.Linear(256, 128), nn.ReLU(),
        ]

    def forward(self, state):
        probs = self.actor(state)
        state_value = self.critic(state)
        return probs, state_value

class PPOAgent:
    """Clipped-objective PPO agent built on ActorCritic.

    Actions are sampled from ``policy_old``. ``update`` trains ``policy`` for
    ``K_epochs`` on one rollout, using the clipped surrogate, a 0.5-weighted
    value MSE and an entropy bonus. It then copies ``policy`` into
    ``policy_old``.
    """

    def __init__(self, state_dim, action_dim, lr=1e-4, gamma=0.99, eps_clip=0.2, K_epochs=8, entropy_coef=0.01):
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs
        self.entropy_coef = entropy_coef
        self.policy = ActorCritic(state_dim, action_dim).to(device)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)
        self.policy_old = ActorCritic(state_dim, action_dim).to(device)
        self.policy_old.load_state_dict(self.policy.state_dict())
        # Learning rate is multiplied by 0.95 every 500 scheduler.step() calls.
        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=500, gamma=0.95)

    def select_action(self, state, training=True):
        """Return ``(action, log_prob)`` sampled from ``policy_old``.

        When ``training`` is False, returns the greedy action and ``None``.
        """
        with torch.no_grad():
            state = torch.FloatTensor(state).to(device)
            action_probs, _ = self.policy_old(state)
            if not training:
                action = torch.argmax(action_probs).item()
                return action, None
            dist = Categorical(action_probs)
            action = dist.sample()
            return action.item(), dist.log_prob(action)

    def update(self, states, actions, log_probs, rewards, next_states, dones):
        """Run ``K_epochs`` PPO updates on one rollout.

        Rewards are standardised within the rollout, then turned into
        discounted returns (the running sum is reset at ``done`` steps).
        ``log_probs`` must be the tensors returned by ``select_action`` with
        ``training=True``. ``next_states`` is accepted for interface
        compatibility and is unused.
        """
        if len(rewards) == 0:
            return  # nothing collected; avoids NaN statistics on an empty rollout

        rewards = torch.tensor(rewards, dtype=torch.float32)
        rewards = rewards - rewards.mean()
        # std() of a single element is NaN, so only rescale when there are >= 2 rewards.
        if rewards.numel() > 1:
            rewards = rewards / (rewards.std() + 1e-8)

        returns = []
        discounted_reward = 0.0
        for reward, done in zip(reversed(rewards.tolist()), reversed(dones)):
            if done:
                discounted_reward = 0.0
            discounted_reward = reward + self.gamma * discounted_reward
            returns.append(discounted_reward)
        returns.reverse()

        states = torch.FloatTensor(np.array(states)).to(device)
        actions = torch.LongTensor(actions).to(device)
        # stack (rather than FloatTensor) works for per-step tensors on any device.
        old_log_probs = torch.stack(list(log_probs)).to(device).detach()
        returns = torch.tensor(returns, dtype=torch.float32, device=device)

        value_loss_fn = nn.MSELoss()
        for _ in range(self.K_epochs):
            action_probs, state_values = self.policy(states)
            # The critic outputs (N, 1); flatten to (N,) to match returns.
            # Otherwise returns - state_values broadcasts to an (N, N) matrix.
            state_values = state_values.squeeze(-1)
            dist = Categorical(action_probs)
            new_log_probs = dist.log_prob(actions)
            entropy = dist.entropy().mean()

            ratios = torch.exp(new_log_probs - old_log_probs)
            advantages = returns - state_values.detach()
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages

            loss = (-torch.min(surr1, surr2).mean()
                    + 0.5 * value_loss_fn(state_values, returns)
                    - self.entropy_coef * entropy)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        self.policy_old.load_state_dict(self.policy.state_dict())

def train_agent(env, agent, num_episodes=3000, max_steps=300):
    """Train ``agent`` on ``env`` with one ``agent.update`` per episode.

    Each episode runs for at most ``max_steps`` steps or until ``done``. The
    rollout is passed to ``agent.update`` and ``agent.scheduler`` is stepped.
    Every 100 episodes progress is printed and the episode's path plotted.

    Returns:
        ``(episode_rewards, episode_lengths, best_path, last_success_path)``.
        Both paths are lists of positions, or ``None`` if the goal was never
        reached.
    """
    episode_rewards = []
    episode_lengths = []
    best_path = None
    best_reward = -float('inf')
    last_success_path = None
    last_success_reward = None
    last_success_length = None

    for episode in range(num_episodes):
        state = env.reset()
        total_reward = 0
        step_count = 0
        path = [env.position]

        states, actions, log_probs = [], [], []
        rewards, next_states, dones = [], [], []

        for _ in range(max_steps):
            action, log_prob = agent.select_action(state)
            next_state, reward, done, _ = env.step(action)

            states.append(state)
            actions.append(action)
            log_probs.append(log_prob)
            rewards.append(reward)
            next_states.append(next_state)
            dones.append(done)

            state = next_state
            total_reward += reward
            step_count += 1
            path.append(env.position)

            if done:
                # done with the agent on the goal is a success (not a timeout).
                if env.position == env.goal:
                    last_success_path = path.copy()
                    last_success_reward = total_reward
                    last_success_length = step_count
                    if total_reward > best_reward:
                        best_path = path.copy()
                        best_reward = total_reward
                break

        agent.update(states, actions, log_probs, rewards, next_states, dones)
        agent.scheduler.step()

        episode_rewards.append(total_reward)
        episode_lengths.append(step_count)

        if episode % 100 == 0:
            print(f"Episode {episode}, Reward: {total_reward:.2f}, Steps: {step_count}, LR: {agent.scheduler.get_last_lr()[0]:.6f}")
            visualize_agent_path(env, path, total_reward, step_count)

    print("\nTraining Complete!")
    if best_path:
        # Guarded: len(best_path) raised TypeError when no episode reached the goal.
        print(f"Best path found - Reward: {best_reward:.2f}, Length: {len(best_path)}")
        print("Visualizing best path found during training:")
        visualize_agent_path(env, best_path, best_reward, len(best_path))

    print("\nLast successful path statistics:")
    if last_success_path:
        print(f"Reward: {last_success_reward:.2f}, Length: {last_success_length}")
        print("Visualizing last successful path:")
        visualize_agent_path(env, last_success_path, last_success_reward, last_success_length)
    else:
        print("No successful paths found during training")

    return episode_rewards, episode_lengths, best_path, last_success_path

def visualize_agent_path(env, path, episode_reward, episode_length):
    """Plot obstacles, start, goal and the agent's path in a new figure, then plt.show().

    ``path`` is a sequence of (x, y) positions. ``env`` must have been reset
    so that ``env.start`` and ``env.goal`` exist.
    """
    plt.figure(figsize=(10, 6))

    obstacle_xs, obstacle_ys = [], []
    for obstacle in env.obs:
        obstacle_xs.append(obstacle[0])
        obstacle_ys.append(obstacle[1])
    plt.plot(obstacle_xs, obstacle_ys, "sk", label="Obstacles")

    start_x, start_y = env.start[0], env.start[1]
    goal_x, goal_y = env.goal[0], env.goal[1]
    plt.plot(start_x, start_y, "bs", label="Start")
    plt.plot(goal_x, goal_y, "gs", label="Goal")

    xs = [point[0] for point in path]
    ys = [point[1] for point in path]
    plt.plot(xs, ys, 'r-', label="Agent Path")

    plt.title(f"Episode Reward: {episode_reward:.2f}, Length: {episode_length}")
    plt.legend()
    plt.grid(True)
    plt.axis("equal")
    plt.show()

if __name__ == "__main__":
    env = Env()
    agent = PPOAgent(state_dim=env.state_dim, action_dim=env.action_dim)
    # train_agent returns four values; unpacking only two raised ValueError.
    episode_rewards, episode_lengths, best_path, last_success_path = train_agent(
        env, agent, num_episodes=3000
    )

## Without federated learning (FL)

# In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import matplotlib.pyplot as plt

class DualTrainer:
    """Train one independent PPOAgent per environment (no parameter sharing).

    In every episode, each agent runs one rollout in its own env and makes one
    PPO update. Every 100 episodes the average rewards are printed and the
    latest paths plotted. At the end, the best successful paths and smoothed
    reward curves are shown.
    """

    def __init__(self, envs, episodes=1000):
        self.envs = envs
        self.episodes = episodes
        self.agents = [PPOAgent(env.state_dim, env.action_dim) for env in envs]
        self.rewards_history = [[] for _ in envs]
        self.best_paths = [None] * len(envs)
        self.best_rewards = [-float('inf')] * len(envs)
        self.current_paths = [[] for _ in envs]

    def train(self):
        """Run ``self.episodes`` rounds of one episode per agent, then plot the final results."""
        print("Starting Dual PPO Training")
        for episode in range(self.episodes):
            for agent_id in range(len(self.agents)):
                self._run_episode(agent_id)

            if (episode + 1) % 100 == 0:
                print(f"\nEpisode {episode + 1}")
                for agent_id in range(len(self.agents)):
                    avg_reward = np.mean(self.rewards_history[agent_id][-100:])
                    print(f"Agent {agent_id + 1} Average Reward: {avg_reward:.2f}")
                self.visualize_current_paths(episode + 1)

        self.visualize_final_results()

    def _run_episode(self, agent_id):
        """Roll out one episode for ``agent_id``, record its path/reward and update that agent."""
        env = self.envs[agent_id]
        agent = self.agents[agent_id]
        state = env.reset()
        current_path = [env.position]

        states, actions, log_probs, rewards = [], [], [], []
        next_states, dones = [], []
        episode_reward = 0

        for _ in range(env.max_steps):
            action, log_prob = agent.select_action(state)
            next_state, reward, done, _ = env.step(action)

            states.append(state)
            actions.append(action)
            log_probs.append(log_prob)
            rewards.append(reward)
            next_states.append(next_state)
            dones.append(done)
            current_path.append(env.position)

            state = next_state
            episode_reward += reward

            if done:
                break

        # Record the path even when the step loop ran out without ``done``,
        # so the periodic plot never shows a previous episode's path.
        self.current_paths[agent_id] = current_path
        if env.position == env.goal and episode_reward > self.best_rewards[agent_id]:
            self.best_paths[agent_id] = current_path.copy()
            self.best_rewards[agent_id] = episode_reward
            print(f"New best path found for Agent {agent_id + 1}! Reward: {episode_reward:.2f}")

        self.rewards_history[agent_id].append(episode_reward)
        agent.update(states, actions, log_probs, rewards, next_states, dones)

    def visualize_current_paths(self, episode):
        """Show current paths for both agents"""
        n_agents = len(self.agents)
        # squeeze=False keeps ``axes`` 2-D for any agent count (including 1).
        fig, axes = plt.subplots(1, n_agents, figsize=(15, 5), squeeze=False)
        fig.suptitle(f'Current Paths - Episode {episode}')

        for i, ax in enumerate(axes[0]):
            env = self.envs[i]
            self._plot_environment(ax, env)

            if self.current_paths[i]:
                path_x = [p[0] for p in self.current_paths[i]]
                path_y = [p[1] for p in self.current_paths[i]]
                ax.plot(path_x, path_y, 'b-', label="Current Path", linewidth=2)

            ax.set_title(f'Agent {i+1}\nCurrent Reward: {self.rewards_history[i][-1]:.2f}')
            ax.legend()
            ax.grid(True)
            ax.axis("equal")

        plt.tight_layout()
        plt.show()

    def visualize_final_results(self):
        """Show final results with best paths and learning curves"""
        n_agents = len(self.agents)
        fig = plt.figure(figsize=(20, 10))
        # One column per agent in the top row; the reward curves span the bottom row.
        gs = plt.GridSpec(2, max(n_agents, 1), figure=fig)

        for i in range(n_agents):
            ax_path = fig.add_subplot(gs[0, i])
            env = self.envs[i]
            self._plot_environment(ax_path, env)

            if self.best_paths[i]:
                path_x = [p[0] for p in self.best_paths[i]]
                path_y = [p[1] for p in self.best_paths[i]]
                ax_path.plot(path_x, path_y, 'r-', label="Best Path", linewidth=2)

            ax_path.set_title(f'Agent {i+1} Best Path\nBest Reward: {self.best_rewards[i]:.2f}')
            ax_path.legend()
            ax_path.grid(True)
            ax_path.axis("equal")

        ax_metrics = fig.add_subplot(gs[1, :])
        for i, rewards in enumerate(self.rewards_history):
            if not rewards:
                continue  # a zero-width window would divide by zero
            window = min(100, len(rewards))
            smoothed = np.convolve(rewards, np.ones(window) / window, mode='valid')
            ax_metrics.plot(range(window - 1, len(rewards)), smoothed,
                            linewidth=2, label=f'Agent {i+1}')

        ax_metrics.set_title('Training Progress')
        ax_metrics.set_xlabel('Episode')
        ax_metrics.set_ylabel('Average Reward (100-episode window)')
        ax_metrics.legend()
        ax_metrics.grid(True)

        plt.suptitle('Final Training Results')
        plt.tight_layout()
        plt.show()

    def _plot_environment(self, ax, env):
        """Helper method to plot the environment elements"""
        if hasattr(env, 'obs') and env.obs:
            obs_x = [x[0] for x in env.obs]
            obs_y = [x[1] for x in env.obs]
            ax.scatter(obs_x, obs_y, c='black', marker='s', s=100, label="Obstacles")

        ax.scatter(env.start[0], env.start[1], c='blue', marker='o', s=100, label="Start")
        ax.scatter(env.goal[0], env.goal[1], c='green', marker='*', s=200, label="Goal")

        ax.grid(True, linestyle='--', alpha=0.6)
        ax.set_aspect('equal')

def train_dual_ppo(episodes=1000):
    """Main function to train two independent PPO agents"""
    # Each agent gets its own Env instance; DualTrainer handles training and plotting.
    trainer = DualTrainer([Env() for _ in range(2)], episodes=episodes)
    trainer.train()
    best_paths = trainer.best_paths
    best_rewards = trainer.best_rewards
    return best_paths, best_rewards

if __name__ == "__main__":
    # Two independent PPO agents trained for 2000 episodes, with no federated averaging.
    trained = train_dual_ppo(episodes=2000)
    best_paths, best_rewards = trained

## With federated learning (FL)

# In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import copy
from torch.distributions import Categorical
from collections import defaultdict
import matplotlib.pyplot as plt

# Prefer the GPU when CUDA is available; fall back to the CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class FederatedPPO:
    """Server side of a FedAvg-style federated PPO setup.

    Holds one global ``ActorCritic`` model plus ``num_clients`` local
    ``PPOAgent`` instances. Clients train locally. Their weight deltas
    (relative to the global model) are averaged, weighted by the number of
    samples each client collected, and added to the global model. The result
    is then broadcast back to every client.
    """

    def __init__(self, num_clients=2, state_dim=14, action_dim=8):
        self.num_clients = num_clients
        self.clients = []
        self.global_model = ActorCritic(state_dim, action_dim).to(device)

        for _ in range(num_clients):
            client = PPOAgent(state_dim, action_dim)
            self.clients.append(client)

        # Start every client from the same global weights. Otherwise each
        # PPOAgent keeps its own random initialisation, and the first round
        # averages unrelated networks.
        self.distribute_global_model()

    def federated_average(self, weight_accumulator):
        """Return the sample-weighted average of the clients' weight deltas.

        Args:
            weight_accumulator: ``{client_id: {'weights': {name: delta},
                'n_samples': int}}``.

        Returns:
            ``{name: tensor}`` with one entry per global ``state_dict`` key.
            Keys that no client reported stay zero. When no samples were
            collected at all, every entry is zero, which leaves the global
            model unchanged instead of dividing by zero.
        """
        averaged_weights = {
            name: torch.zeros_like(param)
            for name, param in self.global_model.state_dict().items()
        }

        total_samples = sum(acc['n_samples'] for acc in weight_accumulator.values())
        if total_samples <= 0:
            return averaged_weights

        for accumulator in weight_accumulator.values():
            weight = accumulator['n_samples'] / total_samples
            for name, diff in accumulator['weights'].items():
                averaged_weights[name] += weight * diff

        return averaged_weights

    def update_global_model(self, weight_accumulator):
        """Add the averaged client deltas to the global model's parameters in place."""
        averaged_weights = self.federated_average(weight_accumulator)
        with torch.no_grad():
            # NOTE: only registered parameters are updated; any buffers the
            # model may have keep their current values.
            for name, param in self.global_model.named_parameters():
                param.add_(averaged_weights[name])

    def distribute_global_model(self):
        """Copy the global weights into every client's ``policy`` and ``policy_old``."""
        global_state = self.global_model.state_dict()
        for client in self.clients:
            client.policy.load_state_dict(global_state)
            client.policy_old.load_state_dict(global_state)

class FederatedTrainer:
    """Run federated training rounds for a ``FederatedPPO`` server.

    Each round, every client trains locally on its own environment for
    ``local_episodes`` episodes. The per-client weight deltas relative to the
    global model are then aggregated, applied to the global model and
    broadcast back to the clients.
    """

    def __init__(self, federated_ppo, envs, rounds=30, local_episodes=50):
        self.federated_ppo = federated_ppo
        self.envs = envs
        self.rounds = rounds
        self.local_episodes = local_episodes
        # Mean of the clients' mean episode rewards, one entry per round.
        self.global_rewards = []
        # Per client: mean episode reward for each round.
        self.client_rewards = [[] for _ in range(len(envs))]
        # NOTE(review): best_path / best_reward are not updated by this class.
        self.best_path = None
        self.best_reward = -float('inf')
        # Per client: best goal-reaching path seen so far and its reward.
        self.client_best_paths = [None] * len(envs)
        self.client_best_rewards = [-float('inf')] * len(envs)

    def compute_weight_update(self, client_id, old_model, new_model):
        """Return ``{name: new - old}`` for every entry of ``old_model``'s state dict.

        ``client_id`` is unused and is kept only for interface compatibility.
        """
        new_state = new_model.state_dict()  # fetch once, not once per tensor
        return {
            name: new_state[name] - old_param
            for name, old_param in old_model.state_dict().items()
        }

    def train_client(self, client_id, env):
        """Run ``local_episodes`` episodes of local PPO training for one client.

        After each episode, the client's ``update`` is called with that
        episode's transitions.

        Returns:
            ``(total_samples, best_local_path, best_local_reward, mean_reward)``.
            ``best_local_path`` is the highest-reward path that ended on
            ``env.goal``, or ``None`` if the goal was never reached.
        """
        client = self.federated_ppo.clients[client_id]
        total_samples = 0
        episode_rewards = []
        best_local_path = None
        best_local_reward = -float('inf')

        for _ in range(self.local_episodes):
            state = env.reset()
            path = [env.position]

            states, actions, log_probs, rewards = [], [], [], []
            next_states, dones = [], []
            episode_reward = 0

            for _ in range(env.max_steps):
                action, log_prob = client.select_action(state)
                next_state, reward, done, _info = env.step(action)

                states.append(state)
                actions.append(action)
                log_probs.append(log_prob)
                rewards.append(reward)
                next_states.append(next_state)
                dones.append(done)
                path.append(env.position)

                state = next_state
                episode_reward += reward

                if done:
                    if env.position == env.goal and episode_reward > best_local_reward:
                        best_local_path = path.copy()
                        best_local_reward = episode_reward
                    break

            episode_rewards.append(episode_reward)
            client.update(states, actions, log_probs, rewards, next_states, dones)
            total_samples += len(states)

        return total_samples, best_local_path, best_local_reward, np.mean(episode_rewards)

    def train(self):
        """Run ``rounds`` federated rounds and print a summary after each one.

        The clients' best paths are plotted after every 5th round.
        """
        num_clients = self.federated_ppo.num_clients
        print(f"Starting Federated Training with {num_clients} Clients")
        for round_num in range(self.rounds):
            print(f"\nFederated Round {round_num + 1}/{self.rounds}")
            weight_accumulator = {}
            round_rewards = []

            # The global model is not modified until every client has trained,
            # so each delta below is taken against the same global weights.
            for client_id in range(num_clients):
                print(f"Training Client {client_id + 1}")
                n_samples, local_best_path, local_best_reward, avg_reward = self.train_client(
                    client_id, self.envs[client_id]
                )

                self.client_rewards[client_id].append(avg_reward)
                round_rewards.append(avg_reward)

                if local_best_reward > self.client_best_rewards[client_id]:
                    self.client_best_paths[client_id] = local_best_path
                    self.client_best_rewards[client_id] = local_best_reward

                weight_diff = self.compute_weight_update(
                    client_id,
                    self.federated_ppo.global_model,
                    self.federated_ppo.clients[client_id].policy
                )

                weight_accumulator[client_id] = {
                    'weights': weight_diff,
                    'n_samples': n_samples
                }

            # Aggregate, then broadcast the new global weights to all clients.
            self.federated_ppo.update_global_model(weight_accumulator)
            self.federated_ppo.distribute_global_model()

            avg_round_reward = np.mean(round_rewards)
            self.global_rewards.append(avg_round_reward)

            print(f"Round {round_num + 1} Summary:")
            print(f"Global Average Reward: {avg_round_reward:.2f}")
            for client_id in range(num_clients):
                print(f"Client {client_id + 1} Average Reward: {round_rewards[client_id]:.2f}")

            if (round_num + 1) % 5 == 0:
                self.visualize_all_paths(round_num + 1)

    def visualize_all_paths(self, round_num):
        """Plot each client's best goal-reaching path in its own panel.

        Panels for clients that have not reached the goal yet stay empty.
        """
        num_panels = max(1, len(self.client_best_paths))
        # squeeze=False keeps ``axes`` 2-D, so indexing works for any number of
        # clients (a single panel would otherwise come back as a bare Axes).
        fig, axes = plt.subplots(1, num_panels, figsize=(7.5 * num_panels, 5), squeeze=False)
        fig.suptitle(f'Best Paths after Round {round_num}')

        for i, (path, reward) in enumerate(zip(self.client_best_paths, self.client_best_rewards)):
            if path:
                ax = axes[0][i]
                env = self.envs[i]

                obs_x = [x[0] for x in env.obs]
                obs_y = [x[1] for x in env.obs]
                ax.plot(obs_x, obs_y, "sk", label="Obstacles")

                ax.plot(env.start[0], env.start[1], "bs", label="Start")
                ax.plot(env.goal[0], env.goal[1], "gs", label="Goal")

                path_x = [p[0] for p in path]
                path_y = [p[1] for p in path]
                ax.plot(path_x, path_y, 'r-', label="Path")

                ax.set_title(f'Client {i+1}\nReward: {reward:.2f}')
                ax.legend()
                ax.grid(True)
                ax.axis("equal")

        plt.tight_layout()
        plt.show()

    def plot_training_progress(self):
        """Plot the per-round global average reward and each client's reward curve."""
        plt.figure(figsize=(12, 6))

        plt.plot(self.global_rewards, 'k-', label='Global Average', linewidth=2)

        for i, rewards in enumerate(self.client_rewards):
            plt.plot(rewards, '--', label=f'Client {i+1}', alpha=0.7)

        plt.title('Training Progress')
        plt.xlabel('Round')
        plt.ylabel('Average Reward')
        plt.legend()
        plt.grid(True)
        plt.show()

def train_federated_ppo(rounds=30, local_episodes=50):
    """Train a two-client federated PPO setup and plot its progress.

    Args:
        rounds: number of federated rounds.
        local_episodes: local training episodes per client per round.

    Returns:
        ``(client_best_paths, client_best_rewards)`` from the trainer.
    """
    client_envs = [Env() for _ in range(2)]
    reference_env = client_envs[0]

    fed_ppo = FederatedPPO(
        num_clients=2,
        state_dim=reference_env.state_dim,
        action_dim=reference_env.action_dim,
    )
    trainer = FederatedTrainer(
        federated_ppo=fed_ppo,
        envs=client_envs,
        rounds=rounds,
        local_episodes=local_episodes,
    )

    trainer.train()
    trainer.plot_training_progress()

    return trainer.client_best_paths, trainer.client_best_rewards

if __name__ == "__main__":
    # Script entry point: 30 federated rounds of 50 local episodes per client.
    fl_settings = dict(rounds=30, local_episodes=50)
    best_paths, best_rewards = train_federated_ppo(**fl_settings)

# MULTI-AGENT NAVIGATION AMONG STATIC OBSTACLES

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.distributions import Categorical
from itertools import combinations
from copy import deepcopy
# Run on CUDA when it is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class CBSConstraints:
    """Accumulated CBS constraints: one set of vertex and one set of edge constraints."""

    def __init__(self):
        self.vertex_constraints, self.edge_constraints = set(), set()

    def add_constraint(self, other):
        """Merge ``other``'s vertex and edge constraints into this object in place."""
        self.vertex_constraints.update(other.vertex_constraints)
        self.edge_constraints.update(other.edge_constraints)

class Location(object):
    """A grid cell ``(x, y)``; both coordinates default to -1."""

    def __init__(self, x=-1, y=-1):
        self.x = x
        self.y = y

    def __eq__(self, other):
        # Objects without x/y are not comparable; returning NotImplemented lets
        # Python fall back to False instead of raising AttributeError.
        try:
            return self.x == other.x and self.y == other.y
        except AttributeError:
            return NotImplemented

    def __hash__(self):
        # Defining __eq__ alone sets __hash__ to None (unhashable); hash the
        # same fields __eq__ compares so equal locations hash equally.
        return hash((self.x, self.y))

    def __str__(self):
        return str((self.x, self.y))

class State:
    """An agent's ``location`` at the discrete time step ``time``."""

    def __init__(self, time, location):
        self.time, self.location = time, location

    def __eq__(self, other):
        same_time = self.time == other.time
        return same_time and self.location == other.location

    def __hash__(self):
        loc = self.location
        return hash((self.time, loc.x, loc.y))

    def is_equal_except_time(self, state):
        """Return whether ``state`` is at the same location, ignoring time."""
        return self.location == state.location

    def __str__(self):
        loc = self.location
        return "({}, {}, {})".format(self.time, loc.x, loc.y)

class Conflict:
    """A conflict between two agents detected at time step ``time``.

    ``type`` is ``VERTEX`` (both agents at the same location) or ``EDGE``
    (the agents swap locations between ``time`` and ``time + 1``); it is -1
    until set.
    """

    VERTEX = 1
    EDGE = 2

    def __init__(self):
        self.time = self.type = -1
        self.agent_1 = self.agent_2 = ''
        self.location_1, self.location_2 = Location(), Location()

    def __str__(self):
        fields = (self.time, self.agent_1, self.agent_2, self.location_1, self.location_2)
        return "({}, {}, {}, {}, {})".format(*fields)

def get_state(solution, agent_name, t):
    """Return ``agent_name``'s state at time ``t``.

    Once ``t`` runs past the end of the path, the agent is considered to wait
    at its final state, so the last entry is returned.
    """
    path = solution[agent_name]
    return path[t] if t < len(path) else path[-1]

def detect_first_conflict(solution):
    """Return the earliest conflict between any two agents, or ``None``.

    ``solution`` maps agent names to lists of ``State`` objects. An agent
    whose path has ended is treated as waiting at its last state (see
    ``get_state``). At each time step ``t``, every agent pair is first
    checked for a vertex conflict (same location at ``t``). Then every pair is
    checked for an edge conflict (the two agents swap locations between ``t``
    and ``t + 1``).

    Agents with an empty path are ignored. With fewer than two remaining
    agents no conflict is possible, so ``None`` is returned. Previously an
    empty ``solution`` raised ``ValueError`` and an empty path raised
    ``IndexError``.
    """
    active = {name: path for name, path in solution.items() if path}
    if len(active) < 2:
        return None

    max_t = max(len(path) for path in active.values())
    for t in range(max_t):
        for agent1, agent2 in combinations(active, 2):
            state1 = get_state(active, agent1, t)
            state2 = get_state(active, agent2, t)
            if state1.is_equal_except_time(state2):
                conflict = Conflict()
                conflict.time = t
                conflict.type = Conflict.VERTEX
                conflict.agent_1 = agent1
                conflict.agent_2 = agent2
                conflict.location_1 = state1.location
                return conflict
        for agent1, agent2 in combinations(active, 2):
            state1a = get_state(active, agent1, t)
            state1b = get_state(active, agent1, t + 1)
            state2a = get_state(active, agent2, t)
            state2b = get_state(active, agent2, t + 1)
            if state1a.is_equal_except_time(state2b) and state1b.is_equal_except_time(state2a):
                conflict = Conflict()
                conflict.time = t
                conflict.type = Conflict.EDGE
                conflict.agent_1 = agent1
                conflict.agent_2 = agent2
                conflict.location_1 = state1a.location
                conflict.location_2 = state1b.location
                return conflict
    return None

def detect_conflict(paths):
    """Return the first conflict among ``paths`` (lists of (x, y) cells), or ``None``.

    Path ``i`` becomes agent ``"agent{i}"``, with one ``State`` per cell,
    timed by the cell's index in the path.
    """
    solution = {
        f"agent{idx}": [State(step, Location(cell[0], cell[1])) for step, cell in enumerate(path)]
        for idx, path in enumerate(paths)
    }
    return detect_first_conflict(solution)

class VertexConstraint:
    """Constraint forbidding occupation of ``location`` at time step ``time``."""

    def __init__(self, time, location):
        self.time, self.location = time, location

    def __eq__(self, other):
        same_time = self.time == other.time
        return same_time and self.location == other.location

    def __hash__(self):
        # Hash of the concatenated string forms of time and location.
        return hash("".join((str(self.time), str(self.location))))

class EdgeConstraint:
    """Constraint forbidding the move ``location_1`` -> ``location_2`` at time step ``time``."""

    def __init__(self, time, location_1, location_2):
        self.time = time
        self.location_1, self.location_2 = location_1, location_2

    def __eq__(self, other):
        result = self.time == other.time
        result = result and self.location_1 == other.location_1
        return result and self.location_2 == other.location_2

    def __hash__(self):
        parts = (str(self.time), str(self.location_1), str(self.location_2))
        return hash("".join(parts))

class CBSNode:
    """A node of the CBS search tree; nodes are ordered by ``cost``."""

    def __init__(self):
        self.solution, self.constraints, self.cost = {}, {}, 0

    def __lt__(self, other):
        return self.cost < other.cost

def compute_potential_field(pos, other_pos, other_path, goal_pos, obs, x_range, y_range):
    """Return a unit-length artificial-potential-field direction for an agent.

    The total force is the weighted sum of:
      * attraction toward ``goal_pos``;
      * repulsion from the other agent when it is closer than ``d0``;
      * repulsion from cells of ``other_path`` closer than ``d_path``;
      * repulsion from each of the 8 neighbouring cells that is an obstacle,
        lies outside the ``x_range`` x ``y_range`` grid, or is on ``other_path``.

    Args:
        pos, other_pos, goal_pos: 2-D coordinates of the agent, the other
            agent and the goal.
        other_path: iterable of 2-D cells of the other agent's path; may be
            empty or ``None``.
        obs: container of obstacle cells as ``(x, y)`` int tuples.
        x_range, y_range: grid size.

    Returns:
        An array of shape (2,): the normalised total force, or the
        unnormalised (zero) vector when its norm is 0.
    """
    pos = np.array(pos)
    other_pos = np.array(other_pos)
    goal_pos = np.array(goal_pos)
    # ``None`` used to crash the neighbour membership test below.
    other_path = other_path or []

    k_attr = 1.0    # goal attraction gain
    k_rep = 150.0   # agent-agent repulsion gain
    k_obs = 100.0   # obstacle / boundary repulsion gain
    k_path = 100.0  # other-path repulsion gain
    d0 = 7.0        # influence radius for agent and obstacle repulsion
    d_path = 5.0    # influence radius for path repulsion

    f_attr = k_attr * (goal_pos - pos)

    # At distance 0 the repulsion direction is undefined, and 1/0 would turn
    # the whole result into NaN, so coincident points contribute no force.
    dist_to_other = np.linalg.norm(pos - other_pos)
    if 0 < dist_to_other < d0:
        f_rep = k_rep * (1/dist_to_other - 1/d0) * (1/dist_to_other**2) * (pos - other_pos)
    else:
        f_rep = np.zeros(2)

    f_path = np.zeros(2)
    for path_pos in other_path:
        path_pos = np.array(path_pos)
        dist_to_path = np.linalg.norm(pos - path_pos)
        if 0 < dist_to_path < d_path:
            f_path += k_path * (1/dist_to_path - 1/d_path) * (1/dist_to_path**2) * (pos - path_pos)

    f_obs = np.zeros(2)
    for dx, dy in [(0,1), (1,0), (0,-1), (-1,0), (1,1), (-1,1), (1,-1), (-1,-1)]:
        check_pos = (int(pos[0] + dx), int(pos[1] + dy))
        if (check_pos in obs or
            not (0 <= check_pos[0] < x_range and 0 <= check_pos[1] < y_range) or
            check_pos in other_path):
            dist = np.linalg.norm(np.array([dx, dy]))
            if dist < d0:
                f_obs += k_obs * (1/dist - 1/d0) * (1/dist**2) * np.array([-dx, -dy])

    f_total = f_attr + 2*f_rep + f_obs + 1.5*f_path

    f_norm = np.linalg.norm(f_total)
    if f_norm > 0:
        f_total = f_total / f_norm

    return f_total

class MultiAgentEnv:
    """Two-agent grid world with static obstacles and CBS-style constraints.

    The grid is ``x_range`` x ``y_range`` cells. Each agent has a fixed start
    and goal and chooses one of the 8 offsets in ``motions`` per step.
    ``reset`` returns one 23-dim float32 observation per agent, and ``step``
    returns ``(states, rewards, dones, info)``.
    """

    def __init__(self):
        self.x_range = 61
        self.y_range = 41
        # 8-connected moves; an action is an index into this list.
        self.motions = [(-1, 0), (-1, 1), (0, 1), (1, 1),
                        (1, 0), (1, -1), (0, -1), (-1, -1)]
        self.obs = self.obs_map()
        # 14 scalar features + 8 ray distances + 1 step fraction (see _get_state).
        self.state_dim = 23
        self.action_dim = len(self.motions)
        self.max_steps = 500
        self.cbs_constraints = CBSConstraints()
        self.agent_paths = [[], []]

        self.safety_radius = 3
        self.collision_penalty = -40
        self.path_penalty = -20
        self.min_path_distance = 2
        self.path_memory_length = 5

        self.stagnation_penalty = -400

        self.goal_reached_bonus = 2000.0
        self.progress_reward_scale = 30.0

        # Read by get_exploration_bonus(). These were never initialised, so that
        # method raised AttributeError. NOTE(review): assumed defaults; tune as needed.
        self.exploration_decay = 0.9
        self.min_exploration_bonus = 0.1

        self.agents = [
            {'start': (10, 30), 'goal': (55, 5), 'position': (10, 30), 'steps_taken': 0},
            {'start': (5, 5), 'goal': (45, 35), 'position': (5, 5), 'steps_taken': 0}
        ]

        self.path_crossing_cooldown = 20
        self.last_crossing = [0, 0]
        self.alternative_path_bonus = 15
        self.deadlock_threshold = 15
        self.deadlock_counter = [0, 0]
        self.path_history = [[] for _ in range(len(self.agents))]
        self.path_obstacle_memory = 5

    def apply_path_constraints(self):
        """Add a vertex constraint for every recorded cell of every agent's path.

        The constraint time is the cell's index in ``agent_paths[i]``. Every
        agent's whole path is used; previously, cells of agent 1 beyond the
        length of agent 0's path were silently skipped.
        """
        for path in self.agent_paths:
            for t, cell in enumerate(path):
                loc = Location(cell[0], cell[1])
                self.cbs_constraints.vertex_constraints.add(VertexConstraint(t, loc))

    def obs_map(self):
        """Build the obstacle set.

        It contains the border walls, four fixed wall segments and up to 30
        random cells. Random cells that hit a protected start/goal cell are
        skipped.
        """
        x = self.x_range
        y = self.y_range
        obs = set()
        for i in range(x):
            obs.add((i, 0))
        for i in range(x):
            obs.add((i, y - 1))
        for i in range(y):
            obs.add((0, i))
        for i in range(y):
            obs.add((x - 1, i))
        for i in range(10, 21):
            obs.add((i, 15))
        for i in range(15):
            obs.add((20, i))
        for i in range(15, 30):
            obs.add((30, i))
        for i in range(16):
            obs.add((40, i))
        # Cells random obstacles must not cover: the agents' starts and goals.
        # (45, 35), agent 1's goal, was missing, so a random obstacle could make
        # that goal unreachable. (55, 35) is kept from the original list.
        protected = {(5, 5), (55, 35), (55, 5), (10, 30), (45, 35)}
        num_random_obstacles = 30
        for _ in range(num_random_obstacles):
            rx = np.random.randint(1, self.x_range - 1)
            ry = np.random.randint(1, self.y_range - 1)
            if (rx, ry) in protected:
                continue
            obs.add((rx, ry))
        return obs

    def reset(self):
        """Start a new episode and return the initial per-agent observations.

        The obstacle map and the accumulated ``cbs_constraints`` are kept.
        """
        self.agent_paths = [[], []]

        self.agents = [
            {'start': (10, 30), 'goal': (55, 5), 'position': (10, 30), 'steps_taken': 0},
            {'start': (5, 5), 'goal': (45, 35), 'position': (5, 5), 'steps_taken': 0}
        ]
        self.visited_positions = [{agent['start']} for agent in self.agents]
        self.exploration_bonuses = [{} for _ in range(2)]
        self.stuck_counters = [0, 0]
        self.prev_positions = [[], []]
        # Per-episode bookkeeping used by step(); previously it leaked over
        # from the previous episode.
        self.deadlock_counter = [0, 0]
        self.path_history = [[] for _ in range(len(self.agents))]
        for agent in self.agents:
            start = np.array(agent['start'])
            goal = np.array(agent['goal'])
            agent['prev_dist'] = np.linalg.norm(start - goal)
        return self._get_state()

    def _get_state(self):
        """Return one float32 observation vector (length 23) per agent.

        Features: own position, goal, distance/angle to the goal, the other
        agent's position, distance/angle to it, its last displacement, a clipped
        time-to-collision estimate, normalised distance to the other agent's
        recorded path, 8 ray distances to the nearest blocking cell, and the
        fraction of ``max_steps`` used.
        """
        states = []
        for i, agent in enumerate(self.agents):
            pos = np.array(agent['position'])
            goal = np.array(agent['goal'])
            other_agent_pos = np.array(self.agents[1-i]['position'])
            other_agent_path = self.agent_paths[1-i]

            min_path_dist = float('inf')
            if other_agent_path:
                for path_pos in other_agent_path:
                    dist = np.linalg.norm(np.array(path_pos) - pos)
                    min_path_dist = min(min_path_dist, dist)
            min_path_dist = min(min_path_dist, np.sqrt(self.x_range**2 + self.y_range**2))

            dist_to_goal = np.linalg.norm(pos - goal)
            dist_to_other_agent = np.linalg.norm(pos - other_agent_pos)
            angle_to_goal = np.arctan2(goal[1]-pos[1], goal[0]-pos[0]) / np.pi
            angle_to_other_agent = np.arctan2(other_agent_pos[1]-pos[1], other_agent_pos[0]-pos[0]) / np.pi

            other_agent_prev_pos = np.array(self.agents[1-i].get('prev_position', other_agent_pos))
            relative_velocity = (other_agent_pos - other_agent_prev_pos)
            rel_vel_x = relative_velocity[0] / self.x_range
            rel_vel_y = relative_velocity[1] / self.y_range

            if np.any(relative_velocity):
                time_to_collision = dist_to_other_agent / np.linalg.norm(relative_velocity)
                time_to_collision = np.clip(time_to_collision / 10.0, 0, 1)
            else:
                time_to_collision = 1.0

            # Cast a ray along each motion until it hits an obstacle, the map
            # edge, the other agent's path or the other agent's safety radius.
            obstacle_dists = []
            for dx, dy in self.motions:
                x, y = agent['position']
                dist = 0
                while True:
                    x += dx
                    y += dy
                    dist += 1
                    next_pos = (x, y)
                    if (next_pos in self.obs or
                        not (0 <= x < self.x_range and 0 <= y < self.y_range) or
                        next_pos in other_agent_path or
                        (np.linalg.norm(np.array([x, y]) - other_agent_pos) < self.safety_radius)):
                        obstacle_dists.append(min(dist/10.0, 1.0))
                        break

            norm_path_dist = min_path_dist / np.sqrt(self.x_range**2 + self.y_range**2)

            state = np.array([
                pos[0] / self.x_range,
                pos[1] / self.y_range,
                goal[0] / self.x_range,
                goal[1] / self.y_range,
                dist_to_goal / np.sqrt(self.x_range**2 + self.y_range**2),
                angle_to_goal,
                other_agent_pos[0] / self.x_range,
                other_agent_pos[1] / self.y_range,
                dist_to_other_agent / np.sqrt(self.x_range**2 + self.y_range**2),
                angle_to_other_agent,
                rel_vel_x,
                rel_vel_y,
                time_to_collision,
                norm_path_dist,
                *obstacle_dists,
                agent['steps_taken'] / self.max_steps
            ], dtype=np.float32)
            states.append(state)
        return states

    def get_exploration_bonus(self, agent_idx, position):
        """Return 1.0 on the first visit to ``position``, then decay it per revisit.

        The bonus is multiplied by ``exploration_decay`` on each revisit and is
        floored at ``min_exploration_bonus``.
        """
        if position not in self.exploration_bonuses[agent_idx]:
            self.exploration_bonuses[agent_idx][position] = 1.0
        else:
            self.exploration_bonuses[agent_idx][position] *= self.exploration_decay
            self.exploration_bonuses[agent_idx][position] = max(
                self.exploration_bonuses[agent_idx][position],
                self.min_exploration_bonus
            )
        return self.exploration_bonuses[agent_idx][position]

    def validate_move(self, pos, time, agent_id):
        """Return whether agent ``agent_id`` may occupy ``pos`` at ``time``.

        A move is rejected if the cell is out of bounds, is an obstacle, or is
        within 4 cells of the other agent or any cell of its recorded path. It
        is also rejected if it lies inside a 5-cell corridor around a segment of
        that path, or if it is vertex-constrained at ``time``.
        """
        x, y = pos
        if not (0 <= x < self.x_range and 0 <= y < self.y_range):
            return False
        if pos in self.obs:
            return False

        other_agent_path = self.agent_paths[1-agent_id]
        other_agent_pos = self.agents[1-agent_id]['position']
        pos_array = np.array(pos)
        safety_distance = 4.0

        other_pos_array = np.array(other_agent_pos)
        if np.linalg.norm(pos_array - other_pos_array) < safety_distance:
            return False

        if other_agent_path:
            for path_pos in other_agent_path:
                path_pos_array = np.array(path_pos)
                dist = np.linalg.norm(pos_array - path_pos_array)
                if dist < safety_distance:
                    return False

        if len(other_agent_path) >= 2:
            current_pos = np.array(self.agents[agent_id]['position'])

            for i in range(len(other_agent_path) - 1):
                path_seg_start = np.array(other_agent_path[i])
                path_seg_end = np.array(other_agent_path[i + 1])
                corridor_width = 5.0
                path_vector = path_seg_end - path_seg_start
                path_length = np.linalg.norm(path_vector)
                if path_length == 0:
                    continue

                perp_vector = np.array([-path_vector[1], path_vector[0]]) / path_length

                pos_relative = pos_array - path_seg_start
                dist_from_path = abs(np.dot(pos_relative, perp_vector))

                if dist_from_path < corridor_width:
                    proj_length = np.dot(pos_relative, path_vector) / path_length
                    if 0 <= proj_length <= path_length:
                        return False

        for i, agent in enumerate(self.agents):
            # NOTE(review): an agent's start and goal are different cells, so this
            # condition is always False; `or` may have been intended. Kept as-is.
            if i != agent_id and pos == agent['goal'] and pos == agent['start']:
                return False

        loc = Location(pos[0], pos[1])
        vertex_const = VertexConstraint(time, loc)
        return vertex_const not in self.cbs_constraints.vertex_constraints

    def validate_transition(self, pos1, pos2, time):
        """Return whether the move ``pos1`` -> ``pos2`` at ``time`` is not edge-constrained."""
        loc1 = Location(pos1[0], pos1[1])
        loc2 = Location(pos2[0], pos2[1])
        edge_const = EdgeConstraint(time, loc1, loc2)
        return edge_const not in self.cbs_constraints.edge_constraints

    def check_stuck(self, agent_idx, position):
        """Record ``position``; report stuck if the last 10 positions span <= 3 cells."""
        self.prev_positions[agent_idx].append(position)
        if len(self.prev_positions[agent_idx]) > 10:
            self.prev_positions[agent_idx].pop(0)
        if len(self.prev_positions[agent_idx]) == 10:
            unique_positions = len(set(self.prev_positions[agent_idx]))
            if unique_positions <= 3:
                self.stuck_counters[agent_idx] += 1
                return True
        return False

    def step(self, actions):
        """Advance both agents by one step.

        Args:
            actions: one index into ``self.motions`` per agent.

        Returns:
            ``(states, rewards, dones, info)``: per-agent observations,
            per-agent rewards, per-agent done flags and an empty dict. An
            agent's done flag is set on the step it reaches its goal. Both flags
            are set once every agent has reached its goal or used up
            ``max_steps``.
        """
        rewards = [0, 0]
        dones = [False, False]
        current_time = self.agents[0]['steps_taken']

        for i, agent in enumerate(self.agents):
            position = agent['position']
            # _get_state reads 'prev_position' to derive the other agent's
            # displacement. It was never written before, so the velocity and
            # time-to-collision features were constant.
            agent['prev_position'] = position
            self.path_history[i].append(position)

            if len(self.path_history[i]) > self.path_obstacle_memory:
                self.path_history[i].pop(0)

        # Phase 1: decide each agent's candidate cell against the other agent's
        # current (pre-move) position.
        candidate_positions = []
        for i, (agent, action) in enumerate(zip(self.agents, actions)):
            if agent.get('reached_goal', False):
                candidate_positions.append(agent['position'])
                continue

            x, y = agent['position']
            dx, dy = self.motions[action]
            new_pos = (x+dx, y+dy)

            if self.check_deadlock(i, new_pos):
                self.deadlock_counter[i] += 1
            else:
                self.deadlock_counter[i] = 0

            path_violation = self.check_path_violation(i, new_pos, i)
            other_agent_dist = np.linalg.norm(np.array(new_pos) - np.array(self.agents[1-i]['position']))

            if path_violation:
                candidate_positions.append(agent['position'])
                rewards[i] += self.path_penalty * 0.5
            elif new_pos == self.agents[1-i]['position']:
                candidate_positions.append(agent['position'])
                rewards[i] += self.collision_penalty
            elif other_agent_dist < self.safety_radius:
                # Too close to the other agent: allow only moves that make progress.
                current_dist_to_goal = np.linalg.norm(np.array(new_pos) - np.array(agent['goal']))
                if current_dist_to_goal < agent['prev_dist']:
                    candidate_positions.append(new_pos)
                    rewards[i] += self.collision_penalty * 0.2
                else:
                    candidate_positions.append(agent['position'])
                    rewards[i] += self.collision_penalty * 0.3
            elif self.validate_move(new_pos, current_time + 1, i):
                candidate_positions.append(new_pos)
                if self.is_alternative_path(i, new_pos):
                    rewards[i] += self.alternative_path_bonus
            else:
                candidate_positions.append(agent['position'])
                rewards[i] += self.collision_penalty * 0.2

        # Phase 2: apply the moves and compute rewards.
        for i, agent in enumerate(self.agents):
            if agent.get('reached_goal', False):
                continue

            new_position = candidate_positions[i]
            old_position = agent['position']
            agent['position'] = new_position
            agent['steps_taken'] += 1

            if self.deadlock_counter[i] >= self.deadlock_threshold:
                self.min_path_distance = max(1, self.min_path_distance - 0.5)
                rewards[i] += self.stagnation_penalty * 5
            else:
                self.min_path_distance = min(2, self.min_path_distance + 0.1)

            self.update_path_history(i, new_position)

            rewards[i] += self.compute_rewards(i, new_position, old_position)

            if new_position == agent['goal']:
                rewards[i] += self.goal_reached_bonus
                dones[i] = True
                agent['reached_goal'] = True

        # Agents stop counting steps once they reach their goal. Ending only
        # when both arrive on the same step, or both hit max_steps, therefore
        # never ended the episode if they arrived on different steps.
        if all(agent.get('reached_goal', False) or agent['steps_taken'] >= self.max_steps
               for agent in self.agents):
            dones = [True, True]

        self.apply_path_constraints()

        return self._get_state(), rewards, dones, {}

    def is_valid_move(self, agent_id, position):
        """Return False if ``position`` is an obstacle or in another agent's ``path_history``.

        ``self`` was missing from the signature, so calling this method always failed.
        """
        if position in self.obs:
            return False

        for j, path in enumerate(self.path_history):
            if j != agent_id and position in path:
                return False
        return True

    def check_deadlock(self, agent_idx, position):
        """Return True when the agent stays put and at least 6 of its 8 neighbours are invalid."""
        if position == self.agents[agent_idx]['position']:
            blocked_directions = 0
            for dx, dy in self.motions:
                check_pos = (position[0] + dx, position[1] + dy)
                if not self.validate_move(check_pos, self.agents[agent_idx]['steps_taken'] + 1, agent_idx):
                    blocked_directions += 1
            return blocked_directions >= 6
        return False

    def check_path_violation(self, agent_idx, position, current_agent_idx):
        """Return whether ``position`` intrudes on the other agent's recent path.

        Only the last ``path_memory_length`` recorded cells of that path are
        considered. The position violates the path if it is inside a 5-cell
        corridor around one of those segments. It is also a violation if the
        cell is vertex-constrained at the current agent's next step.
        """
        other_agent_path = self.agent_paths[1-agent_idx][-self.path_memory_length:]
        if not other_agent_path:
            return False

        pos_array = np.array(position)
        for i, path_seg_start in enumerate(other_agent_path[:-1]):
            path_seg_end = other_agent_path[i + 1]

            corridor_width = 5.0

            path_vector = np.array(path_seg_end) - np.array(path_seg_start)
            path_length = np.linalg.norm(path_vector)
            if path_length == 0:
                continue

            perp_vector = np.array([-path_vector[1], path_vector[0]]) / path_length

            pos_relative = pos_array - np.array(path_seg_start)
            dist_from_path = abs(np.dot(pos_relative, perp_vector))

            if dist_from_path < corridor_width:
                proj_length = np.dot(pos_relative, path_vector) / path_length
                if 0 <= proj_length <= path_length:
                    return True

        for i, agent in enumerate(self.agents):
            # NOTE(review): always False (start != goal); `or` may have been intended.
            if i != agent_idx and position == agent['goal'] and position == agent['start']:
                return True

        loc = Location(position[0], position[1])
        vertex_const = VertexConstraint(self.agents[current_agent_idx]['steps_taken'] + 1, loc)
        return vertex_const in self.cbs_constraints.vertex_constraints

    def is_alternative_path(self, agent_idx, position):
        """Return whether ``position`` deviates from the other agent's goal heading.

        The angle, seen from the other agent, between its goal direction and
        ``position`` must exceed pi/2.5 (72 degrees).
        """
        other_agent = self.agents[1-agent_idx]
        other_pos = np.array(other_agent['position'])
        other_goal = np.array(other_agent['goal'])

        direct_path_vector = other_goal - other_pos
        pos_vector = np.array(position) - other_pos

        dot_product = np.dot(direct_path_vector, pos_vector)
        norms_product = np.linalg.norm(direct_path_vector) * np.linalg.norm(pos_vector)

        if norms_product == 0:
            return False

        angle = np.arccos(np.clip(dot_product / norms_product, -1.0, 1.0))

        return abs(angle) > np.pi/2.5

    def update_path_history(self, agent_idx, position):
        """Append ``position`` to the agent's recorded path.

        Once the path is longer than ``path_memory_length``, it is first
        randomly thinned. Older entries are kept with lower probability (0.5
        up to 1.0).
        """
        if len(self.agent_paths[agent_idx]) > self.path_memory_length:
            keep_prob = np.linspace(0.5, 1.0, len(self.agent_paths[agent_idx]))
            self.agent_paths[agent_idx] = [pos for pos, prob in zip(self.agent_paths[agent_idx], keep_prob)
                                         if np.random.random() < prob]
        self.agent_paths[agent_idx].append(position)

    def compute_rewards(self, agent_idx, new_position, old_position):
        """Return the shaping reward for one agent's move and update its ``prev_dist``.

        The reward combines goal progress (doubled on an alternative path), a
        penalty for being within 5 cells of the other agent's recorded path, a
        new-cell bonus, and a stagnation penalty for standing still while not
        blocked.
        """
        reward = 0
        current_dist = np.linalg.norm(np.array(new_position) - np.array(self.agents[agent_idx]['goal']))
        progress = self.agents[agent_idx]['prev_dist'] - current_dist
        self.agents[agent_idx]['prev_dist'] = current_dist

        other_agent_path = self.agent_paths[1-agent_idx]
        if other_agent_path:
            min_path_dist = float('inf')
            pos_array = np.array(new_position)

            for path_pos in other_agent_path:
                path_pos_array = np.array(path_pos)
                dist = np.linalg.norm(pos_array - path_pos_array)
                min_path_dist = min(min_path_dist, dist)

            if min_path_dist < 5.0:
                path_penalty = -200.0 * (5.0 - min_path_dist)
                reward += path_penalty

        if self.is_alternative_path(agent_idx, new_position):
            progress *= 2.0

        reward += progress * self.progress_reward_scale

        if new_position not in self.visited_positions[agent_idx]:
            self.visited_positions[agent_idx].add(new_position)
            reward += 20.0
        if new_position == old_position and not self.is_blocked(agent_idx):
            reward += self.stagnation_penalty * 2

        return reward

    def is_blocked(self, agent_idx):
        """Return True when none of the agent's 8 neighbouring cells passes ``validate_move``."""
        agent = self.agents[agent_idx]
        for dx, dy in self.motions:
            new_pos = (agent['position'][0] + dx, agent['position'][1] + dy)
            if self.validate_move(new_pos, agent['steps_taken'] + 1, agent_idx):
                return False
        return True

    def get_agent_path(self, agent_idx):
        """Return the positions recorded by ``check_stuck`` (at most the last 10)."""
        return self.prev_positions[agent_idx]

    def apply_cbs_constraints(self, conflict):
        """Turn ``conflict`` into vertex/edge constraints (skipped once any agent reached its goal).

        An edge conflict adds constraints in both directions.
        """
        if any(agent.get('reached_goal', False) for agent in self.agents):
            return

        if conflict.type == Conflict.VERTEX:
            loc = conflict.location_1
            vertex_const = VertexConstraint(conflict.time, loc)
            new_constraints = CBSConstraints()
            new_constraints.vertex_constraints.add(vertex_const)
            self.cbs_constraints.add_constraint(new_constraints)
        elif conflict.type == Conflict.EDGE:
            new_constraints = CBSConstraints()
            edge_const1 = EdgeConstraint(conflict.time, conflict.location_1, conflict.location_2)
            edge_const2 = EdgeConstraint(conflict.time, conflict.location_2, conflict.location_1)
            new_constraints.edge_constraints.add(edge_const1)
            new_constraints.edge_constraints.add(edge_const2)
            self.cbs_constraints.add_constraint(new_constraints)

    def _segments_intersect(self, p1, p2, p3, p4):
        """Return whether segment p1-p2 (shifted by +0.1 on both axes) crosses p3-p4."""
        def ccw(A, B, C):
            return (C[1]-A[1]) * (B[0]-A[0]) > (B[1]-A[1]) * (C[0]-A[0])

        buffer = 0.1
        # Convert first: adding a float to a plain (x, y) tuple raises TypeError.
        p1 = np.asarray(p1, dtype=float) + buffer
        p2 = np.asarray(p2, dtype=float) + buffer
        return ccw(p1,p3,p4) != ccw(p2,p3,p4) and ccw(p1,p2,p3) != ccw(p1,p2,p4)

class MultiAgentPPO:
    """Independent learners: one ``PPOAgent`` per agent, no shared parameters."""

    def __init__(self, state_dim, action_dim, num_agents=2):
        self.num_agents = num_agents
        self.agents = []
        for _ in range(num_agents):
            self.agents.append(PPOAgent(state_dim, action_dim))

    def select_actions(self, states, training=True):
        """Query each learner with its own observation.

        Returns:
            (actions, log_probs): two lists aligned with ``self.agents``.
        """
        chosen, chosen_log_probs = [], []
        for idx, learner in enumerate(self.agents):
            act, lp = learner.select_action(states[idx], training)
            chosen.append(act)
            chosen_log_probs.append(lp)
        return chosen, chosen_log_probs

    def update(self, agent_trajectories):
        """Run one update per learner from its own trajectory dict.

        Each ``agent_trajectories[i]`` must provide the keys 'states',
        'actions', 'log_probs', 'rewards', 'next_states' and 'dones'; they are
        passed positionally in that order to ``PPOAgent.update``.
        """
        field_order = ('states', 'actions', 'log_probs', 'rewards', 'next_states', 'dones')
        for idx in range(self.num_agents):
            trajectory = agent_trajectories[idx]
            self.agents[idx].update(*(trajectory[key] for key in field_order))

def train_multi_agent_ppo(num_episodes=3000, max_steps=500):
    """Train one independent PPO learner per agent in ``MultiAgentEnv``.

    Each episode resets the environment and rolls out for at most
    ``max_steps`` steps, stopping early once every done flag is True. It
    records each agent's transitions and visited positions. After every step
    it looks for the first conflict between the recorded paths and passes it
    to ``env.apply_cbs_constraints``. Both learners are updated once per
    episode. Progress is printed and plotted every 100 episodes and on the
    final episode.

    The "best" episode has the highest combined reward among episodes where
    both agents reached their goals and the final paths are conflict-free. Its
    paths are plotted at the end.

    Args:
        num_episodes: number of training episodes.
        max_steps: per-episode step cap enforced by this loop.

    Returns:
        ``(best_episode_paths, total_rewards)``. ``best_episode_paths`` is
        ``None`` when no episode qualified. NOTE(review): ``total_rewards``
        belongs to the *last* episode, not the best one
        (``best_individual_rewards``). Confirm this pairing is intended.
    """
    env = MultiAgentEnv()
    agent_wrapper = MultiAgentPPO(env.state_dim, env.action_dim)
    
    best_episode_paths = None
    best_combined_reward = float('-inf')
    best_individual_rewards = [0, 0]  
    
    for episode in range(num_episodes):
        states = env.reset()
        total_rewards = [0, 0]
        # Positions are recorded after each move, so a path does not start
        # with the agent's start cell.
        paths = [[], []]
        agent_trajectories = [
            {'states': [], 'actions': [], 'log_probs': [], 'rewards': [], 'next_states': [], 'dones': []}
            for _ in range(2)
        ]
        
        for step in range(max_steps):
            actions, log_probs = agent_wrapper.select_actions(states)
            next_states, rewards, dones, _ = env.step(actions)
            
            for i in range(2):
                agent_trajectories[i]['states'].append(states[i])
                agent_trajectories[i]['actions'].append(actions[i])
                agent_trajectories[i]['log_probs'].append(log_probs[i])
                agent_trajectories[i]['rewards'].append(rewards[i])
                agent_trajectories[i]['next_states'].append(next_states[i])
                agent_trajectories[i]['dones'].append(dones[i])
                paths[i].append(env.agents[i]['position'])
                total_rewards[i] += rewards[i]
            
            # Re-check the whole path so far and feed any conflict back to the
            # environment as a CBS constraint.
            conflict = detect_conflict(paths)
            if conflict:
                env.apply_cbs_constraints(conflict)
            
            states = next_states
            if all(dones):
                break
        
        combined_reward = sum(total_rewards)
        # Only fully successful, conflict-free episodes can become the best.
        if (all(env.agents[i].get('reached_goal', False) for i in range(2)) and 
            combined_reward > best_combined_reward and
            not detect_conflict(paths)):
            best_combined_reward = combined_reward
            best_episode_paths = deepcopy(paths)
            best_individual_rewards = deepcopy(total_rewards)  
        
        agent_wrapper.update(agent_trajectories)
        
        if episode % 100 == 0 or episode == num_episodes-1:
            print(f"Episode {episode}")
            print(f"Rewards: Agent1 = {total_rewards[0]:.2f}, Agent2 = {total_rewards[1]:.2f}")
            visualize_multi_agent_paths(env, paths, total_rewards)
    
    print("\nBest successful paths found:")
    if best_episode_paths:
        plt.figure(figsize=(12, 8))
        obs_x = [x for x, y in env.obs]
        obs_y = [y for x, y in env.obs]
        plt.plot(obs_x, obs_y, "sk", label="Obstacles")
    
        for i, agent_data in enumerate(env.agents):
            plt.plot(agent_data['start'][0], agent_data['start'][1], f"{['b','g'][i]}s", label=f"Start {i+1}")
            plt.plot(agent_data['goal'][0], agent_data['goal'][1], f"{['r','m'][i]}s", label=f"Goal {i+1}")
    
        colors = ['b-', 'g-']
        for i, path in enumerate(best_episode_paths):
            path_x = [p[0] for p in path]
            path_y = [p[1] for p in path]
            plt.plot(path_x, path_y, colors[i], label=f"Agent {i+1} Path")
    
        plt.title(f"Best Successful Paths\nAgent 1 Reward: {best_individual_rewards[0]:.2f}, Agent 2 Reward: {best_individual_rewards[1]:.2f}")
        plt.legend()
        plt.grid(True)
        plt.axis("equal")
        plt.show()
    else:
        print("No completely successful episodes found where both agents reached their goals.")
    
    return best_episode_paths, total_rewards

def visualize_multi_agent_paths(env, paths, rewards):
    """Plot one episode: static obstacles, each agent's start/goal, and its path.

    Args:
        env: object exposing ``obs`` (iterable of (x, y) cells) and ``agents``
            (dicts with 'start' and 'goal').
        paths: per-agent lists of (x, y) positions. The agent's start cell is
            prepended when a path is empty or does not already begin there.
        rewards: per-agent episode returns; the first two appear in the title.

    Marker and line colours are indexed per agent, so at most two agents are supported.
    """
    plt.figure(figsize=(12, 8))

    obstacle_xs = [cx for cx, _ in env.obs]
    obstacle_ys = [cy for _, cy in env.obs]
    plt.plot(obstacle_xs, obstacle_ys, "sk", label="Obstacles")

    start_colours = ['b', 'g']
    goal_colours = ['r', 'm']
    for idx, agent in enumerate(env.agents):
        start, goal = agent['start'], agent['goal']
        plt.plot(start[0], start[1], f"{start_colours[idx]}s", label=f"Start {idx+1}")
        plt.plot(goal[0], goal[1], f"{goal_colours[idx]}s", label=f"Goal {idx+1}")

    line_styles = ['b-', 'g-']
    for idx, path in enumerate(paths):
        start = env.agents[idx]['start']
        if len(path) > 0 and path[0] == start:
            full_path = path
        else:
            full_path = [start] + path

        xs = [cell[0] for cell in full_path]
        ys = [cell[1] for cell in full_path]
        plt.plot(xs, ys, line_styles[idx], label=f"Agent {idx+1} Path")

    plt.title(f"Multi-Agent Paths\nAgent 1 Reward: {rewards[0]:.2f}, Agent 2 Reward: {rewards[1]:.2f}")
    plt.legend()
    plt.grid(True)
    plt.axis("equal")
    plt.show()


if __name__ == "__main__":
    # Run the two-agent PPO training with 2000 episodes, overriding the
    # function's default of 3000. Progress plots appear every 100 episodes.
    train_multi_agent_ppo(num_episodes=2000)

# MULTI-AGENT IN DYNAMIC ENVIRONMENT

In [None]:
import numpy as np
from matplotlib.animation import FuncAnimation
import matplotlib.patches as patches
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.distributions import Categorical
from itertools import combinations
from copy import deepcopy
# Select the CUDA device when available, otherwise the CPU. Its consumers are
# outside this view.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class CBSConstraints(object):
    """Accumulated Conflict-Based-Search constraints.

    ``vertex_constraints`` holds ``VertexConstraint`` objects and
    ``edge_constraints`` holds ``EdgeConstraint`` objects, each in a set.
    """

    def __init__(self):
        self.vertex_constraints, self.edge_constraints = set(), set()

    def add_constraint(self, other):
        """Merge every constraint from *other* into this object in place."""
        vertex_sets = self.vertex_constraints
        vertex_sets |= other.vertex_constraints
        edge_sets = self.edge_constraints
        edge_sets |= other.edge_constraints

class Location(object):
    """A grid cell ``(x, y)`` used by the CBS conflict/constraint classes.

    Defaults of ``-1`` mark an unset location (see ``Conflict``).
    """

    def __init__(self, x=-1, y=-1):
        self.x = x
        self.y = y

    def __eq__(self, other):
        # Duck-typed like before, but return NotImplemented for objects without
        # x/y instead of raising AttributeError (e.g. ``loc == None``).
        try:
            return self.x == other.x and self.y == other.y
        except AttributeError:
            return NotImplemented

    def __hash__(self):
        # Defining __eq__ alone set __hash__ to None, making Location
        # unhashable; hash consistently with equality.
        return hash((self.x, self.y))

    def __str__(self):
        return str((self.x, self.y))

    def __repr__(self):
        return f"Location({self.x!r}, {self.y!r})"

class State(object):
    """One timed position on an agent's path: ``(time, location)``.

    ``location`` is expected to expose ``x`` and ``y`` (a ``Location``).
    """

    def __init__(self, time, location):
        self.time, self.location = time, location

    def __eq__(self, other):
        return self.time == other.time and self.is_equal_except_time(other)

    def __hash__(self):
        loc = self.location
        return hash((self.time, loc.x, loc.y))

    def is_equal_except_time(self, state):
        """True when both states are at the same location, regardless of time."""
        return self.location == state.location

    def __str__(self):
        loc = self.location
        return "({}, {}, {})".format(self.time, loc.x, loc.y)

class Conflict(object):
    """The first clash found between two agents' timed paths.

    ``type`` is ``Conflict.VERTEX`` (both agents on ``location_1`` at
    ``time``) or ``Conflict.EDGE`` (between ``time`` and ``time + 1`` agent_1
    moves ``location_1 -> location_2`` while agent_2 moves the opposite way).
    ``-1`` marks an unset time/type.
    """

    VERTEX = 1
    EDGE = 2

    def __init__(self):
        self.time, self.type = -1, -1
        self.agent_1, self.agent_2 = '', ''
        self.location_1, self.location_2 = Location(), Location()

    def __str__(self):
        fields = (self.time, self.agent_1, self.agent_2, self.location_1, self.location_2)
        return "({}, {}, {}, {}, {})".format(*fields)

def get_state(solution, agent_name, t):
    """Return *agent_name*'s state at time *t*.

    Once *t* runs past the end of the path, the final state is returned, so
    agents are treated as waiting at their last position.
    """
    path = solution[agent_name]
    return path[t] if t < len(path) else path[-1]

def detect_first_conflict(solution):
    """Return the earliest conflict between any two agents, or None.

    At each timestep ``t`` (shorter paths hold their last state, see
    ``get_state``) every agent pair is checked first for a vertex conflict
    (same location at ``t``). It is then checked for an edge conflict (the
    two agents swap locations between ``t`` and ``t + 1``).

    Args:
        solution: mapping of agent name -> list of ``State`` objects.

    Returns:
        A populated ``Conflict``, or ``None`` if there is none. Solutions with
        fewer than two non-empty paths (including an empty mapping) return
        ``None``.
    """
    # Agents without recorded states cannot conflict. Dropping them avoids
    # max() over an empty sequence and get_state() indexing an empty list.
    active = {name: path for name, path in solution.items() if len(path) > 0}
    if len(active) < 2:
        return None

    max_t = max(len(path) for path in active.values())
    for t in range(max_t):
        for agent1, agent2 in combinations(active.keys(), 2):
            state1 = get_state(active, agent1, t)
            state2 = get_state(active, agent2, t)
            if state1.is_equal_except_time(state2):
                conflict = Conflict()
                conflict.time = t
                conflict.type = Conflict.VERTEX
                conflict.agent_1 = agent1
                conflict.agent_2 = agent2
                conflict.location_1 = state1.location
                return conflict
        for agent1, agent2 in combinations(active.keys(), 2):
            state1a = get_state(active, agent1, t)
            state1b = get_state(active, agent1, t + 1)
            state2a = get_state(active, agent2, t)
            state2b = get_state(active, agent2, t + 1)
            # Swap: agent1 moves onto agent2's old cell while agent2 moves onto agent1's.
            if state1a.is_equal_except_time(state2b) and state1b.is_equal_except_time(state2a):
                conflict = Conflict()
                conflict.time = t
                conflict.type = Conflict.EDGE
                conflict.agent_1 = agent1
                conflict.agent_2 = agent2
                conflict.location_1 = state1a.location
                conflict.location_2 = state1b.location
                return conflict
    return None

def detect_conflict(paths):
    """Convert raw per-agent position lists into CBS states and find the first conflict.

    Args:
        paths: list of per-agent lists of (x, y) positions; index ``t`` in a
            list becomes timestep ``t``. Agent ``i`` is named ``"agent{i}"``.

    Returns:
        The first ``Conflict`` found by ``detect_first_conflict``, or ``None``.
    """
    solution = {
        f"agent{idx}": [State(step, Location(cell[0], cell[1])) for step, cell in enumerate(path)]
        for idx, path in enumerate(paths)
    }
    return detect_first_conflict(solution)

class VertexConstraint(object):
    """Forbids occupying ``location`` at timestep ``time``."""

    def __init__(self, time, location):
        self.time = time
        self.location = location

    def __eq__(self, other):
        same_time = self.time == other.time
        return same_time and self.location == other.location

    def __hash__(self):
        # Hash the textual form so the location itself need not be hashable.
        key = "".join((str(self.time), str(self.location)))
        return hash(key)

class EdgeConstraint(object):
    """Forbids the move ``location_1 -> location_2`` at timestep ``time``."""

    def __init__(self, time, location_1, location_2):
        self.time = time
        self.location_1, self.location_2 = location_1, location_2

    def __eq__(self, other):
        same_time = self.time == other.time
        return (same_time
                and self.location_1 == other.location_1
                and self.location_2 == other.location_2)

    def __hash__(self):
        # Hash the textual form so the locations themselves need not be hashable.
        key = "".join((str(self.time), str(self.location_1), str(self.location_2)))
        return hash(key)

class CBSNode:
    """A node of a Conflict-Based-Search tree, ordered by ``cost``.

    Starts with an empty ``solution`` dict, an empty ``constraints`` dict and
    cost 0. The class is not used elsewhere in this view.
    """

    def __init__(self):
        self.solution, self.constraints, self.cost = {}, {}, 0

    def __lt__(self, other):
        """Lower cost sorts first (e.g. for ``heapq``/``sorted``)."""
        return other.cost > self.cost

def compute_potential_field(pos, other_pos, other_path, goal_pos, obs, x_range, y_range):
    """Return a unit steering vector from an artificial potential field.

    The field is the weighted sum of four terms:
      * attraction toward ``goal_pos`` (gain ``k_attr``);
      * repulsion from ``other_pos`` within radius ``d0`` (weighted x2);
      * repulsion from every cell of ``other_path`` within ``d_path`` (x1.5);
      * a push away from each of the 8 neighbouring cells that is a static
        obstacle, outside the ``[0, x_range) x [0, y_range)`` grid, or on
        ``other_path``.

    Args:
        pos, other_pos, goal_pos: (x, y) coordinates.
        other_path: iterable of (x, y) cells of the other agent, or ``None``.
        obs: container of blocked integer (x, y) cells.
        x_range, y_range: grid dimensions.

    Returns:
        Length-2 float array of unit length, or the zero vector when all forces cancel.
    """
    pos = np.asarray(pos, dtype=float)
    other_pos = np.asarray(other_pos, dtype=float)
    goal_pos = np.asarray(goal_pos, dtype=float)
    # Treat None as "no path" (previously ``check_pos in None`` raised), and
    # use a set for the per-neighbour membership test.
    other_path = list(other_path) if other_path is not None else []
    path_cells = {tuple(p) for p in other_path}

    k_attr = 1.0
    k_rep = 150.0
    k_obs = 100.0
    k_path = 100.0
    d0 = 7.0
    d_path = 5.0

    f_attr = k_attr * (goal_pos - pos)

    # A zero distance has no defined push direction and divided by zero
    # (yielding NaN in the result), so coincident points are skipped.
    f_rep = np.zeros(2)
    offset = pos - other_pos
    dist_to_other = np.linalg.norm(offset)
    if 0 < dist_to_other < d0:
        f_rep = k_rep * (1 / dist_to_other - 1 / d0) * (1 / dist_to_other**2) * offset

    f_path = np.zeros(2)
    for path_pos in other_path:
        offset = pos - np.asarray(path_pos, dtype=float)
        dist_to_path = np.linalg.norm(offset)
        if 0 < dist_to_path < d_path:
            f_path += k_path * (1 / dist_to_path - 1 / d_path) * (1 / dist_to_path**2) * offset

    f_obs = np.zeros(2)
    for dx, dy in [(0, 1), (1, 0), (0, -1), (-1, 0), (1, 1), (-1, 1), (1, -1), (-1, -1)]:
        check_pos = (int(pos[0] + dx), int(pos[1] + dy))
        if (check_pos in obs
                or not (0 <= check_pos[0] < x_range and 0 <= check_pos[1] < y_range)
                or check_pos in path_cells):
            dist = np.linalg.norm(np.array([dx, dy]))
            if dist < d0:
                f_obs += k_obs * (1 / dist - 1 / d0) * (1 / dist**2) * np.array([-dx, -dy])

    f_total = f_attr + 2 * f_rep + f_obs + 1.5 * f_path

    f_norm = np.linalg.norm(f_total)
    if f_norm > 0:
        f_total = f_total / f_norm

    return f_total

class MultiAgentEnv:
    def __init__(self):
        """Build the 61x41 two-agent grid world with static walls and moving obstacles.

        Per-episode bookkeeping (``prev_dist``, visited cells, exploration
        bonuses, stuck counters) is created by ``reset()``. ``step()`` reads
        it, so ``reset()`` must be called before the first ``step()``.
        """
        self.x_range = 61
        self.y_range = 41
        # 8-connected moves; an action is an index into this list.
        self.motions = [(-1, 0), (-1, 1), (0, 1), (1, 1),
                        (1, 0), (1, -1), (0, -1), (-1, -1)]
        self.obs = self.obs_map()
        self.action_dim = len(self.motions)
        self.max_steps = 500
        self.cbs_constraints = CBSConstraints()
        self.agent_paths = [[], []]

        self.safety_radius = 3
        self.collision_penalty = -40
        self.path_penalty = -20
        self.min_path_distance = 2
        self.path_memory_length = 5

        self.stagnation_penalty = -400

        self.goal_reached_bonus = 2000.0
        self.progress_reward_scale = 30.0

        # Exploration-bonus decay used by get_exploration_bonus(). These were
        # read there but never defined, so revisiting a cell raised
        # AttributeError. NOTE(review): values are chosen defaults; tune as needed.
        self.exploration_decay = 0.9
        self.min_exploration_bonus = 0.1

        self.agents = [
            {'start': (10, 30), 'goal': (55, 5), 'position': (10, 30), 'steps_taken': 0},
            {'start': (5, 5), 'goal': (45, 35), 'position': (5, 5), 'steps_taken': 0}
        ]

        self.path_crossing_cooldown = 20
        self.last_crossing = [0, 0]
        self.alternative_path_bonus = 15
        self.deadlock_threshold = 15
        self.deadlock_counter = [0, 0]
        self.path_history = [[] for _ in range(len(self.agents))]
        self.path_obstacle_memory = 5
        self.moving_obstacles = [
            {
                'path_type': 'circle',
                'radius': 8,
                'speed': 0.1,
                'center': (15, 25),
                'position': [15, 25],
                'angle': 0
            },
            {
                'path_type': 'vertical',
                'y_min': 10,
                'y_max': 35,
                'speed': 0.2,
                'position': [35, 25],
                'direction': 1
            },
            {
                'path_type': 'horizontal',
                'x_min': 15,
                'x_max': 35,
                'speed': 0.15,
                'position': [35, 15],
                'direction': 1
            },
            {
                'path_type': 'figure_eight',
                'a': 5,
                'b': 3,
                'center': (45, 10),
                'position': [45, 10],
                'speed': 0.05,
                'angle': 0
            },
            {
                'path_type': 'diagonal',
                'x_min': 10,
                'x_max': 25,
                'y_min': 10,
                'y_max': 20,
                'speed': 0.1,
                'position': [15, 5],
                'direction_x': 1,
                'direction_y': 1
            }
        ]
        # 23 per-agent features (see _get_state) + 4 per moving obstacle.
        self.state_dim = 23 + (4 * len(self.moving_obstacles))

    def apply_path_constraints(self):
        """Register every recorded path cell as a vertex constraint.

        The cell at index ``t`` of each agent's ``agent_paths`` list becomes a
        ``VertexConstraint(t, cell)``. Every agent's whole path is covered.
        Previously the loop was bounded by agent 0's path length, so later
        cells of a longer agent-1 path (e.g. after agent 0 stopped at its goal)
        were never constrained.
        """
        for path in self.agent_paths:
            for t, cell in enumerate(path):
                self.cbs_constraints.vertex_constraints.add(
                    VertexConstraint(t, Location(cell[0], cell[1]))
                )

    def obs_map(self):
        x = self.x_range
        y = self.y_range
        obs = set()
        for i in range(x):
            obs.add((i, 0))
        for i in range(x):
            obs.add((i, y - 1))
        for i in range(y):
            obs.add((0, i))
        for i in range(y):
            obs.add((x - 1, i))
        for i in range(10, 21):
            obs.add((i, 15))
        for i in range(15):
            obs.add((20, i))
        for i in range(15, 30):
            obs.add((30, i))
        for i in range(16):
            obs.add((40, i))
        return obs

    def reset(self):
        self.agent_paths = [[], []]

        self.agents = [
            {'start': (10, 30), 'goal': (55, 5), 'position': (10, 30), 'steps_taken': 0},
            {'start': (5, 5), 'goal': (45, 35), 'position': (5, 5), 'steps_taken': 0}
        ]
        self.visited_positions = [{agent['start']} for agent in self.agents]
        self.exploration_bonuses = [{} for _ in range(2)]
        self.stuck_counters = [0, 0]
        self.prev_positions = [[], []]
        for agent in self.agents:
            start = np.array(agent['start'])
            goal = np.array(agent['goal'])
            agent['prev_dist'] = np.linalg.norm(start - goal)
        self.obs = self.obs_map()
        return self._get_state()

    def _get_state(self):
        """Build one flat observation vector per agent.

        Layout for agent ``i`` (the "other" agent is ``1 - i``):
          [0:14]  own x/y and goal x/y (divided by x_range/y_range); distance
                  and angle to the goal; the other agent's x/y, distance and
                  angle; the other agent's last displacement (x, y); a clipped
                  time-to-collision estimate; distance to the nearest cell of
                  the other agent's recorded path;
          next    one ray distance per entry of ``self.motions`` (8 here);
          next    ``steps_taken / max_steps``;
          then    4 values per moving obstacle (distance, angle, dx, dy).
        Distances are divided by the grid diagonal, angles by pi, and
        displacements by the grid size. With 8 motions the length is
        ``23 + 4 * len(self.moving_obstacles)``, matching ``self.state_dim``.

        Returns:
            list of 1-D numpy arrays, one per agent.
        """
        states = []
        for i, agent in enumerate(self.agents):
            pos = np.array(agent['position'])
            goal = np.array(agent['goal'])
            other_agent_pos = np.array(self.agents[1-i]['position'])
            other_agent_path = self.agent_paths[1-i]

            # Nearest cell on the other agent's recorded path, capped at the grid diagonal.
            min_path_dist = float('inf')
            if other_agent_path:
                for path_pos in other_agent_path:
                    dist = np.linalg.norm(np.array(path_pos) - pos)
                    min_path_dist = min(min_path_dist, dist)
            min_path_dist = min(min_path_dist, np.sqrt(self.x_range**2 + self.y_range**2))

            dist_to_goal = np.linalg.norm(pos - goal)
            dist_to_other_agent = np.linalg.norm(pos - other_agent_pos)
            angle_to_goal = np.arctan2(goal[1]-pos[1], goal[0]-pos[0]) / np.pi
            angle_to_other_agent = np.arctan2(other_agent_pos[1]-pos[1], other_agent_pos[0]-pos[0]) / np.pi

            # Without a stored 'prev_position' this falls back to the current
            # position, i.e. zero displacement.
            other_agent_prev_pos = np.array(self.agents[1-i].get('prev_position', other_agent_pos))
            relative_velocity = (other_agent_pos - other_agent_prev_pos)
            rel_vel_x = relative_velocity[0] / self.x_range
            rel_vel_y = relative_velocity[1] / self.y_range

            # Distance / displacement magnitude, scaled by 1/10 and clipped to
            # [0, 1]; 1.0 when the other agent did not move.
            if np.any(relative_velocity):
                time_to_collision = dist_to_other_agent / np.linalg.norm(relative_velocity)
                time_to_collision = np.clip(time_to_collision / 10.0, 0, 1)
            else:
                time_to_collision = 1.0

            # Ray-cast along each motion until a static obstacle, the grid edge,
            # a cell of the other agent's path, or the other agent's safety
            # radius is hit; step count / 10, capped at 1. The border walls
            # from obs_map() guarantee termination.
            obstacle_dists = []
            for dx, dy in self.motions:
                x, y = agent['position']
                dist = 0
                while True:
                    x += dx
                    y += dy
                    dist += 1
                    next_pos = (x, y)
                    if (next_pos in self.obs or
                        not (0 <= x < self.x_range and 0 <= y < self.y_range) or
                        next_pos in other_agent_path or
                        (np.linalg.norm(np.array([x, y]) - other_agent_pos) < self.safety_radius)):
                        obstacle_dists.append(min(dist/10.0, 1.0))
                        break

            norm_path_dist = min_path_dist / np.sqrt(self.x_range**2 + self.y_range**2)

            state = np.array([
                pos[0] / self.x_range,
                pos[1] / self.y_range,
                goal[0] / self.x_range,
                goal[1] / self.y_range,
                dist_to_goal / np.sqrt(self.x_range**2 + self.y_range**2),
                angle_to_goal,
                other_agent_pos[0] / self.x_range,
                other_agent_pos[1] / self.y_range,
                dist_to_other_agent / np.sqrt(self.x_range**2 + self.y_range**2),
                angle_to_other_agent,
                rel_vel_x,
                rel_vel_y,
                time_to_collision,
                norm_path_dist,
                *obstacle_dists,
                agent['steps_taken'] / self.max_steps
            ], dtype=np.float32)

            obstacle_info = []
            for obstacle in self.moving_obstacles:
                obstacle_pos = np.array(obstacle['position'])
                dist = np.linalg.norm(pos - obstacle_pos)
                angle = np.arctan2(obstacle_pos[1]-pos[1], obstacle_pos[0]-pos[0]) / np.pi

                # NOTE(review): these velocity terms look like derivatives of the
                # parametric motion applied by update_moving_obstacles (not
                # visible here); confirm they match. An unrecognised path_type
                # leaves dx/dy from the previous obstacle (or unbound for the first).
                if obstacle['path_type'] == 'circle':
                    dx = -obstacle['radius'] * np.sin(obstacle['angle']) * obstacle['speed']
                    dy = obstacle['radius'] * np.cos(obstacle['angle']) * obstacle['speed']
                elif obstacle['path_type'] == 'vertical':
                    dx = 0
                    dy = obstacle['speed'] * obstacle['direction']
                elif obstacle['path_type'] == 'horizontal':
                    dx = obstacle['speed'] * obstacle['direction']
                    dy = 0
                elif obstacle['path_type'] == 'figure_eight':
                    dx = obstacle['a'] * np.cos(obstacle['angle']) * obstacle['speed']
                    dy = obstacle['b'] * np.cos(obstacle['angle'] * 2) * 2 * obstacle['speed']
                elif obstacle['path_type'] == 'diagonal':
                    dx = obstacle['speed'] * obstacle['direction_x']
                    dy = obstacle['speed'] * obstacle['direction_y']

                obstacle_info.extend([
                    dist / np.sqrt(self.x_range**2 + self.y_range**2),
                    angle,
                    dx / self.x_range,
                    dy / self.y_range
                ])

            # np.append returns a new (float64) array with the obstacle features.
            state = np.append(state, obstacle_info)
            states.append(state)
        return states

    def get_exploration_bonus(self, agent_idx, position):
        if position not in self.exploration_bonuses[agent_idx]:
            self.exploration_bonuses[agent_idx][position] = 1.0
        else:
            self.exploration_bonuses[agent_idx][position] *= self.exploration_decay
            self.exploration_bonuses[agent_idx][position] = max(
                self.exploration_bonuses[agent_idx][position],
                self.min_exploration_bonus
            )
        return self.exploration_bonuses[agent_idx][position]

    def validate_move(self, pos, time, agent_id):
        """Return True if agent ``agent_id`` may occupy ``pos`` at timestep ``time``.

        The move is rejected (False) if ``pos``:
          * lies outside the ``x_range`` x ``y_range`` grid or on a static obstacle;
          * is closer than 4.0 to the other agent's current position;
          * is closer than 4.0 to any cell of the other agent's recorded path;
          * lies within 5.0 of, and projects onto, any segment of that path;
          * is the other agent's start *and* goal cell (see note below);
          * matches a stored CBS vertex constraint for ``time``;
          * is closer than 2.0 to any moving obstacle.

        Args:
            pos: candidate (x, y) cell (a tuple, so it can be looked up in ``self.obs``).
            time: timestep used for the vertex-constraint lookup.
            agent_id: index of the moving agent; the other agent is ``1 - agent_id``.
        """
        x, y = pos
        if not (0 <= x < self.x_range and 0 <= y < self.y_range):
            return False
        if pos in self.obs:
            return False

        other_agent_path = self.agent_paths[1-agent_id]
        other_agent_pos = self.agents[1-agent_id]['position']
        pos_array = np.array(pos)
        safety_distance = 4.0

        other_pos_array = np.array(other_agent_pos)
        if np.linalg.norm(pos_array - other_pos_array) < safety_distance:
            return False

        if other_agent_path:
            for path_pos in other_agent_path:
                path_pos_array = np.array(path_pos)
                dist = np.linalg.norm(pos_array - path_pos_array)
                if dist < safety_distance:
                    return False

        # Corridor test: perpendicular distance to each path segment, counted
        # only when the projection falls within the segment.
        if len(other_agent_path) >= 2:
            # NOTE(review): current_pos is computed but never used below.
            current_pos = np.array(self.agents[agent_id]['position'])

            for i in range(len(other_agent_path) - 1):
                path_seg_start = np.array(other_agent_path[i])
                path_seg_end = np.array(other_agent_path[i + 1])
                corridor_width = 5.0
                path_vector = path_seg_end - path_seg_start
                path_length = np.linalg.norm(path_vector)
                if path_length == 0:
                    continue

                perp_vector = np.array([-path_vector[1], path_vector[0]]) / path_length

                pos_relative = pos_array - path_seg_start
                dist_from_path = abs(np.dot(pos_relative, perp_vector))

                if dist_from_path < corridor_width:
                    proj_length = np.dot(pos_relative, path_vector) / path_length
                    if 0 <= proj_length <= path_length:
                        return False

        # NOTE(review): this requires pos to equal both the other agent's goal
        # and its start. That never holds for the start/goal pairs set in
        # __init__/reset, so the check is dead. Confirm whether `or` was intended.
        for i, agent in enumerate(self.agents):
            if i != agent_id and pos == agent['goal'] and pos == agent['start']:
                return False

        loc = Location(pos[0], pos[1])
        vertex_const = VertexConstraint(time, loc)
        if vertex_const in self.cbs_constraints.vertex_constraints:
            return False

        for obstacle in self.moving_obstacles:
            obstacle_pos = np.array(obstacle['position'])
            # Tighter radius for moving obstacles (shadows the 4.0 used above).
            safety_distance = 2.0
            if np.linalg.norm(pos_array - obstacle_pos) < safety_distance:
                return False

        return True  

    def validate_transition(self, pos1, pos2, time):
        """Return False if moving ``pos1 -> pos2`` at ``time`` hits a stored edge constraint."""
        move = EdgeConstraint(time, Location(pos1[0], pos1[1]), Location(pos2[0], pos2[1]))
        forbidden = self.cbs_constraints.edge_constraints
        return move not in forbidden

    def check_stuck(self, agent_idx, position):
        self.prev_positions[agent_idx].append(position)
        if len(self.prev_positions[agent_idx]) > 10:
            self.prev_positions[agent_idx].pop(0)
        if len(self.prev_positions[agent_idx]) == 10:
            unique_positions = len(set(self.prev_positions[agent_idx]))
            if unique_positions <= 3:
                self.stuck_counters[agent_idx] += 1
                return True
        return False

    def step(self, actions):
        """Advance both agents by one action each.

        Moving obstacles are advanced first. Each agent that has not reached
        its goal proposes ``position + self.motions[action]``. The proposal is
        accepted, accepted with a penalty, or rejected (the agent stays put),
        depending on:
          * path-corridor violations;
          * a collision with the other agent;
          * the safety radius;
          * ``validate_move``.
        Rewards come from ``compute_rewards`` plus a goal bonus.

        Args:
            actions: one index into ``self.motions`` per agent.

        Returns:
            ``(states, rewards, dones, {})``: per-agent observations from
            ``_get_state``, per-agent rewards and per-agent done flags. An agent
            that has reached its goal reports ``done=True`` on every later
            step. All flags become True once every agent has either reached
            its goal or used ``self.max_steps`` steps.
        """
        self.update_moving_obstacles()

        rewards = [0, 0]
        dones = [False, False]
        # NOTE(review): both agents are validated against agent 0's clock; once
        # agent 0 reaches its goal, its steps_taken stops advancing.
        current_time = self.agents[0]['steps_taken']

        for i, agent in enumerate(self.agents):
            position = agent['position']
            # Remember the pre-move position. _get_state reads it to estimate
            # the other agent's velocity. It was never written before, so the
            # relative-velocity and time-to-collision features were constant.
            agent['prev_position'] = position
            self.path_history[i].append(position)

            if len(self.path_history[i]) > self.path_obstacle_memory:
                self.path_history[i].pop(0)

        # Phase 1: decide each agent's next cell, using pre-move positions for both agents.
        candidate_positions = []
        for i, (agent, action) in enumerate(zip(self.agents, actions)):
            if agent.get('reached_goal', False):
                candidate_positions.append(agent['position'])
                continue

            x, y = agent['position']
            dx, dy = self.motions[action]
            new_pos = (x+dx, y+dy)

            if self.check_deadlock(i, new_pos):
                self.deadlock_counter[i] += 1
            else:
                self.deadlock_counter[i] = 0

            path_violation = self.check_path_violation(i, new_pos, i)
            other_agent_dist = np.linalg.norm(np.array(new_pos) - np.array(self.agents[1-i]['position']))

            if path_violation:
                # Entering the other agent's recent corridor: stay put.
                candidate_positions.append(agent['position'])
                rewards[i] += self.path_penalty * 0.5
            elif new_pos == self.agents[1-i]['position']:
                # Direct collision with the other agent: stay put.
                candidate_positions.append(agent['position'])
                rewards[i] += self.collision_penalty
            elif other_agent_dist < self.safety_radius:
                # Inside the safety radius: move only if it brings the agent closer to its goal.
                current_dist_to_goal = np.linalg.norm(np.array(new_pos) - np.array(agent['goal']))
                if current_dist_to_goal < agent['prev_dist']:
                    candidate_positions.append(new_pos)
                    rewards[i] += self.collision_penalty * 0.2
                else:
                    candidate_positions.append(agent['position'])
                    rewards[i] += self.collision_penalty * 0.3
            elif self.validate_move(new_pos, current_time + 1, i):
                candidate_positions.append(new_pos)
                if self.is_alternative_path(i, new_pos):
                    rewards[i] += self.alternative_path_bonus
            else:
                candidate_positions.append(agent['position'])
                rewards[i] += self.collision_penalty * 0.2

        # Phase 2: commit the moves and compute rewards.
        for i, agent in enumerate(self.agents):
            if agent.get('reached_goal', False):
                # Keep finished agents flagged. Previously the flag was True only
                # on the arrival step, so agents arriving on different steps
                # never produced an all-done step.
                dones[i] = True
                continue

            new_position = candidate_positions[i]
            old_position = agent['position']
            agent['position'] = new_position
            agent['steps_taken'] += 1

            if self.deadlock_counter[i] >= self.deadlock_threshold:
                self.min_path_distance = max(1, self.min_path_distance - 0.5)
                rewards[i] += self.stagnation_penalty * 5
            else:
                self.min_path_distance = min(2, self.min_path_distance + 0.1)

            self.update_path_history(i, new_position)

            rewards[i] += self.compute_rewards(i, new_position, old_position)

            if new_position == agent['goal']:
                rewards[i] += self.goal_reached_bonus
                dones[i] = True
                agent['reached_goal'] = True

        # The episode ends when every agent is finished, either at its goal or
        # out of steps. Agents at their goal stop counting steps, so they must
        # not also be required to reach max_steps.
        if all(agent.get('reached_goal', False) or agent['steps_taken'] >= self.max_steps
               for agent in self.agents):
            dones = [True, True]

        self.apply_path_constraints()

        return self._get_state(), rewards, dones, {}

    def is_valid_move(agent_id, position):
        if position in self.obs:
            return False

        for j, path in enumerate(self.path_history):
            if j != agent_id and position in path:
                return False
        return True

    def check_deadlock(self, agent_idx, position):
        if position == self.agents[agent_idx]['position']:
            blocked_directions = 0
            for dx, dy in self.motions:
                check_pos = (position[0] + dx, position[1] + dy)
                if not self.validate_move(check_pos, self.agents[agent_idx]['steps_taken'] + 1, agent_idx):
                    blocked_directions += 1
            return blocked_directions >= 6
        return False

    def check_path_violation(self, agent_idx, position, current_agent_idx):
        """Return True if *position* would intrude on the other agent's recent path.

        Only the last ``path_memory_length`` cells of the other agent's
        recorded path (``agent_paths[1 - agent_idx]``) are considered. A
        violation is any of:
          * lying within 5.0 of, and projecting onto, one of its segments;
          * being the other agent's start *and* goal cell (see note below);
          * matching a CBS vertex constraint at
            ``agents[current_agent_idx]['steps_taken'] + 1``.

        Args:
            agent_idx: the moving agent (``step`` passes the same value for
                ``current_agent_idx``).
            position: candidate (x, y) cell.
            current_agent_idx: agent whose step count sets the constraint time.
        """
        other_agent_path = self.agent_paths[1-agent_idx][-self.path_memory_length:]
        if not other_agent_path:
            return False

        pos_array = np.array(position)
        for i, path_seg_start in enumerate(other_agent_path[:-1]):
            path_seg_end = other_agent_path[i + 1]

            corridor_width = 5.0

            path_vector = np.array(path_seg_end) - np.array(path_seg_start)
            path_length = np.linalg.norm(path_vector)
            # Zero-length segments (agent stayed put) have no direction.
            if path_length == 0:
                continue

            perp_vector = np.array([-path_vector[1], path_vector[0]]) / path_length

            pos_relative = pos_array - np.array(path_seg_start)
            dist_from_path = abs(np.dot(pos_relative, perp_vector))

            if dist_from_path < corridor_width:
                proj_length = np.dot(pos_relative, path_vector) / path_length
                if 0 <= proj_length <= path_length:
                    return True

        # NOTE(review): requires position to be both the other agent's goal and
        # start, which never holds for the pairs set in __init__/reset; `or` may
        # have been intended. Confirm.
        for i, agent in enumerate(self.agents):
            if i != agent_idx and position == agent['goal'] and position == agent['start']:
                return True

        loc = Location(position[0], position[1])
        vertex_const = VertexConstraint(self.agents[current_agent_idx]['steps_taken'] + 1, loc)
        return vertex_const in self.cbs_constraints.vertex_constraints

    def is_alternative_path(self, agent_idx, position):
        other_agent = self.agents[1-agent_idx]
        other_pos = np.array(other_agent['position'])
        other_goal = np.array(other_agent['goal'])

        direct_path_vector = other_goal - other_pos
        pos_vector = np.array(position) - other_pos

        dot_product = np.dot(direct_path_vector, pos_vector)
        norms_product = np.linalg.norm(direct_path_vector) * np.linalg.norm(pos_vector)

        if norms_product == 0:
            return False

        angle = np.arccos(np.clip(dot_product / norms_product, -1.0, 1.0))

        return abs(angle) > np.pi/2.5

    def update_path_history(self, agent_idx, position):
        if len(self.agent_paths[agent_idx]) > self.path_memory_length:
            keep_prob = np.linspace(0.5, 1.0, len(self.agent_paths[agent_idx]))
            self.agent_paths[agent_idx] = [pos for pos, prob in zip(self.agent_paths[agent_idx], keep_prob)
                                        if np.random.random() < prob]
        self.agent_paths[agent_idx].append(position)

    def compute_rewards(self, agent_idx, new_position, old_position):
        """Return the shaped step reward for agent *agent_idx* moving old -> new position.

        Side effects: stores the new goal distance in the agent's 'prev_dist' and
        may add *new_position* to ``self.visited_positions[agent_idx]``.

        Terms, accumulated in this order:
          * ``-200 * (5 - d)`` when the closest point ``d`` of the other agent's
            recorded path is nearer than 5.0;
          * goal progress (previous distance - new distance), doubled when
            ``is_alternative_path`` holds, times ``progress_reward_scale``;
          * ``+20`` for a position this agent has not visited yet;
          * ``2 * stagnation_penalty`` when the agent did not move although
            ``is_blocked`` reports a free move.
        """
        agent = self.agents[agent_idx]
        goal_dist = np.linalg.norm(np.array(new_position) - np.array(agent['goal']))
        progress = agent['prev_dist'] - goal_dist
        agent['prev_dist'] = goal_dist

        reward = 0
        rival_trail = self.agent_paths[1 - agent_idx]
        if rival_trail:
            here = np.array(new_position)
            distances = [np.linalg.norm(here - np.array(point)) for point in rival_trail]
            nearest = min([float('inf'), *distances])
            if nearest < 5.0:
                reward += -200.0 * (5.0 - nearest)

        if self.is_alternative_path(agent_idx, new_position):
            progress *= 2.0
        reward += progress * self.progress_reward_scale

        visited = self.visited_positions[agent_idx]
        if new_position not in visited:
            visited.add(new_position)
            reward += 20.0
        # Only penalise standing still when a legal move was actually available.
        if new_position == old_position and not self.is_blocked(agent_idx):
            reward += self.stagnation_penalty * 2

        return reward

    def is_blocked(self, agent_idx):
        """Return True when none of ``self.motions`` gives a move accepted by ``validate_move``.

        Each candidate is the agent's current position shifted by (dx, dy) and is
        checked for time step ``steps_taken + 1``; checking stops at the first
        valid move.
        """
        agent = self.agents[agent_idx]
        return not any(
            self.validate_move(
                (agent['position'][0] + dx, agent['position'][1] + dy),
                agent['steps_taken'] + 1,
                agent_idx,
            )
            for dx, dy in self.motions
        )

    def get_agent_path(self, agent_idx):
        """Return the stored previous-position list for agent *agent_idx* (the live list, not a copy)."""
        history = self.prev_positions
        return history[agent_idx]

    def apply_cbs_constraints(self, conflict):
        """Turn a detected conflict into CBS constraints and merge them into ``self.cbs_constraints``.

        Nothing is added once any agent has 'reached_goal'. A vertex conflict
        forbids ``conflict.location_1`` at ``conflict.time``; an edge conflict
        forbids the traversal between location_1 and location_2 in both
        directions at that time. Other conflict types are ignored.
        """
        if any(agent.get('reached_goal', False) for agent in self.agents):
            return

        if conflict.type == Conflict.VERTEX:
            forbidden_cell = VertexConstraint(conflict.time, conflict.location_1)
            bundle = CBSConstraints()
            bundle.vertex_constraints.add(forbidden_cell)
        elif conflict.type == Conflict.EDGE:
            bundle = CBSConstraints()
            here, there = conflict.location_1, conflict.location_2
            for src, dst in ((here, there), (there, here)):
                bundle.edge_constraints.add(EdgeConstraint(conflict.time, src, dst))
        else:
            return

        self.cbs_constraints.add_constraint(bundle)

    def _segments_intersect(self, p1, p2, p3, p4):
        def ccw(A, B, C):
            return (C[1]-A[1]) * (B[0]-A[0]) > (B[1]-A[1]) * (C[0]-A[0])

        buffer = 0.1
        p1 = p1 + buffer
        p2 = p2 + buffer
        return ccw(p1,p3,p4) != ccw(p2,p3,p4) and ccw(p1,p2,p3) != ccw(p1,p2,p4)

    def update_moving_obstacles(self):
        """Advance every obstacle in ``self.moving_obstacles`` by one step, in place.

        By 'path_type':
          * 'circle': 'angle' += 'speed'; position on a circle of 'radius' around 'center'.
          * 'figure_eight': 'angle' += 'speed'; x = cx + a*sin(angle), y = cy + b*sin(2*angle).
          * 'vertical' / 'horizontal': move by speed*direction along y / x and bounce
            between the min/max bounds.
          * 'diagonal': move along both axes, bouncing independently on each.
        A direction becomes -1 once the coordinate is at/above the max bound and
        +1 once at/below the min bound. Unknown path types are left untouched.
        """

        def bounce(ob, value, low_key, high_key, dir_key):
            # Flip the travel direction at the bounds; otherwise leave it alone.
            if value >= ob[high_key]:
                ob[dir_key] = -1
            elif value <= ob[low_key]:
                ob[dir_key] = 1

        for ob in self.moving_obstacles:
            kind = ob['path_type']

            if kind == 'circle':
                ob['angle'] += ob['speed']
                pos = ob['position']
                pos[0] = ob['center'][0] + ob['radius'] * np.cos(ob['angle'])
                pos[1] = ob['center'][1] + ob['radius'] * np.sin(ob['angle'])

            elif kind == 'vertical':
                pos = ob['position']
                pos[1] += ob['speed'] * ob['direction']
                bounce(ob, pos[1], 'y_min', 'y_max', 'direction')

            elif kind == 'horizontal':
                pos = ob['position']
                pos[0] += ob['speed'] * ob['direction']
                bounce(ob, pos[0], 'x_min', 'x_max', 'direction')

            elif kind == 'figure_eight':
                ob['angle'] += ob['speed']
                pos = ob['position']
                pos[0] = ob['center'][0] + ob['a'] * np.sin(ob['angle'])
                pos[1] = ob['center'][1] + ob['b'] * np.sin(ob['angle'] * 2)

            elif kind == 'diagonal':
                pos = ob['position']
                pos[0] += ob['speed'] * ob['direction_x']
                pos[1] += ob['speed'] * ob['direction_y']
                bounce(ob, pos[0], 'x_min', 'x_max', 'direction_x')
                bounce(ob, pos[1], 'y_min', 'y_max', 'direction_y')

class MultiAgentPPO:
    """Independent-learner wrapper: one PPOAgent per agent.

    Agent ``k`` acts on ``states[k]`` and is updated only from its own
    trajectory; no parameters are shared between agents.
    """

    def __init__(self, state_dim, action_dim, num_agents=2):
        self.num_agents = num_agents
        self.agents = [PPOAgent(state_dim, action_dim) for _ in range(num_agents)]

    def select_actions(self, states, training=True):
        """Query every agent with its own state.

        Returns:
            (actions, log_probs): two lists ordered like ``self.agents``.
        """
        chosen_actions, chosen_log_probs = [], []
        for idx, learner in enumerate(self.agents):
            act, logp = learner.select_action(states[idx], training)
            chosen_actions.append(act)
            chosen_log_probs.append(logp)
        return chosen_actions, chosen_log_probs

    def update(self, agent_trajectories):
        """Run one PPO update per agent from ``agent_trajectories[k]``.

        Each trajectory is a dict of per-step lists, passed positionally in the
        order: states, actions, log_probs, rewards, next_states, dones.
        """
        field_order = ('states', 'actions', 'log_probs', 'rewards', 'next_states', 'dones')
        for idx in range(self.num_agents):
            trajectory = agent_trajectories[idx]
            self.agents[idx].update(*[trajectory[field] for field in field_order])
def visualize_multi_agent_paths(env, paths, rewards):
    """Plot obstacles, each agent's start/goal, and the given agent paths.

    Args:
        env: environment exposing ``obs`` (static (x, y) cells),
            ``moving_obstacles`` (dicts with a 'position') and ``agents`` (dicts
            with 'start' and 'goal').
        paths: one sequence of (x, y) positions per agent. The agent's start
            point is prepended when the path does not already begin there.
        rewards: a list/tuple of per-agent rewards, or a single total reward,
            shown in the title.

    Moving obstacles are drawn at their current positions with the same marker
    as static obstacles. They get one shared legend entry instead of one per
    obstacle. Colours cycle, so any number of agents can be drawn.
    """
    agent_colors = ['b', 'g', 'c', 'y', 'k']
    goal_colors = ['r', 'm']

    plt.figure(figsize=(12, 8))
    obs_x = [x for x, y in env.obs]
    obs_y = [y for x, y in env.obs]
    plt.plot(obs_x, obs_y, "sk", label="Obstacles")

    # Plot dynamic obstacles as static obstacles; label only the first one so
    # the legend does not repeat the same entry for every obstacle.
    for k, obs in enumerate(env.moving_obstacles):
        label = "Dynamic Obstacles (Treated as Static)" if k == 0 else "_nolegend_"
        plt.plot(obs['position'][0], obs['position'][1], "sk", label=label)

    for i, agent_data in enumerate(env.agents):
        start_color = agent_colors[i % len(agent_colors)]
        goal_color = goal_colors[i % len(goal_colors)]
        plt.plot(agent_data['start'][0], agent_data['start'][1], f"{start_color}s", label=f"Start {i+1}")
        plt.plot(agent_data['goal'][0], agent_data['goal'][1], f"{goal_color}s", label=f"Goal {i+1}")

    for i, path in enumerate(paths):
        start = env.agents[i]['start']
        full_path = list(path)
        # Compare as tuples so list/array positions do not produce ambiguous truth values.
        if not full_path or tuple(full_path[0]) != tuple(start):
            full_path = [start] + full_path

        path_x = [p[0] for p in full_path]
        path_y = [p[1] for p in full_path]
        plt.plot(path_x, path_y, f"{agent_colors[i % len(agent_colors)]}-", label=f"Agent {i+1} Path")

    if isinstance(rewards, (list, tuple)):
        per_agent = ", ".join(f"Agent {i+1} Reward: {r:.2f}" for i, r in enumerate(rewards))
        title = f"Multi-Agent Paths\n{per_agent}"
    else:
        title = f"Multi-Agent Paths\nTotal Reward: {rewards:.2f}"

    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.axis("equal")
    plt.show()

def _collect_multi_agent_episode(env, agent_wrapper, max_steps):
    """Roll out one episode of at most *max_steps* steps for both agents.

    After every step the accumulated paths are checked with detect_conflict()
    and any conflict is passed to env.apply_cbs_constraints(). The episode
    stops early once every agent reports done.

    Returns:
        (paths, total_rewards, trajectories), each a two-element list indexed by agent.
    """
    fields = ('states', 'actions', 'log_probs', 'rewards', 'next_states', 'dones')
    states = env.reset()
    totals = [0, 0]
    paths = [[], []]
    trajectories = [{field: [] for field in fields} for _ in range(2)]

    for _step in range(max_steps):
        actions, log_probs = agent_wrapper.select_actions(states)
        next_states, rewards, dones, _info = env.step(actions)

        for k in range(2):
            record = (states[k], actions[k], log_probs[k], rewards[k], next_states[k], dones[k])
            for field, value in zip(fields, record):
                trajectories[k][field].append(value)
            paths[k].append(env.agents[k]['position'])
            totals[k] += rewards[k]

        conflict = detect_conflict(paths)
        if conflict:
            env.apply_cbs_constraints(conflict)

        states = next_states
        if all(dones):
            break

    return paths, totals, trajectories


def train_multi_agent_ppo(num_episodes=3000, max_steps=500):
    """Train one PPOAgent per agent in a fresh MultiAgentEnv and report the results.

    Tracks two things across episodes where both agents set 'reached_goal':
      * the most recent such episode;
      * the one with the highest combined reward whose paths are conflict-free
        according to detect_conflict().
    Both are plotted after training when they exist.

    Returns:
        (last_success_paths, best_episode_paths, env). Either path value is
        None when no qualifying episode occurred.
    """
    env = MultiAgentEnv()
    agent_wrapper = MultiAgentPPO(env.state_dim, env.action_dim)

    best_episode_paths, best_combined_reward = None, float('-inf')
    last_success_paths = last_success_rewards = last_success_lengths = None

    for episode in range(num_episodes):
        paths, total_rewards, trajectories = _collect_multi_agent_episode(env, agent_wrapper, max_steps)
        combined_reward = sum(total_rewards)

        both_arrived = all(env.agents[k].get('reached_goal', False) for k in range(2))
        if both_arrived:
            last_success_paths = deepcopy(paths)
            last_success_rewards = total_rewards.copy()
            last_success_lengths = [len(paths[0]), len(paths[1])]

            improved = combined_reward > best_combined_reward
            if improved and not detect_conflict(paths):
                best_combined_reward = combined_reward
                best_episode_paths = deepcopy(paths)

        agent_wrapper.update(trajectories)

        if episode % 100 == 0:
            print(f"Episode {episode}")
            print(f"Rewards: Agent1 = {total_rewards[0]:.2f}, Agent2 = {total_rewards[1]:.2f}")

    print("\nTraining Complete!")
    if best_episode_paths:
        print(f"Best path found - Combined Reward: {best_combined_reward:.2f}")
        print("Visualizing best path found during training:")
        visualize_multi_agent_paths(env, best_episode_paths, best_combined_reward)

    print("\nLast successful path statistics:")
    if not last_success_paths:
        print("No successful paths found during training")
    else:
        print(f"Rewards: Agent1 = {last_success_rewards[0]:.2f}, Agent2 = {last_success_rewards[1]:.2f}")
        print(f"Lengths: Agent1 = {last_success_lengths[0]}, Agent2 = {last_success_lengths[1]}")
        print("Visualizing last successful paths:")
        visualize_multi_agent_paths(env, last_success_paths, sum(last_success_rewards))

    return last_success_paths, best_episode_paths, env

def calculate_path_straightness(path):
    """Score how far *path* strays from the straight line joining its endpoints.

    Returns the sum, over every point of *path*, of its perpendicular distance
    to the line through path[0] and path[-1]. A score of 0 means perfectly
    straight; larger is less straight.

    Edge cases:
      * Paths with fewer than two points score float('inf').
      * A path that ends where it started has no line to measure against. In
        that case each point's Euclidean distance from the start is summed
        instead of dividing 0 by 0, which would yield NaN.
    """
    if len(path) < 2:
        return float('inf')

    start, end = path[0], path[-1]
    dx = end[0] - start[0]
    dy = end[1] - start[1]
    # The chord length is identical for every point, so compute it once.
    chord = np.sqrt(dy**2 + dx**2)

    if chord == 0:
        return sum(np.hypot(p[0] - start[0], p[1] - start[1]) for p in path)

    total_deviation = 0
    for point in path:
        # |cross product| / chord length = perpendicular distance to the start-end line.
        num = abs(dy * point[0] - dx * point[1] + end[0] * start[1] - end[1] * start[0])
        total_deviation += num / chord
    return total_deviation

def find_straightest_path(paths):
    """Return the path with the lowest calculate_path_straightness score.

    Ties keep the earliest path. Returns None when *paths* is empty or when no
    path scores strictly below infinity (e.g. every path has fewer than two points).
    """
    scored = ((calculate_path_straightness(candidate), candidate) for candidate in paths)

    winner, winner_score = None, float('inf')
    for score, candidate in scored:
        if score < winner_score:
            winner, winner_score = candidate, score
    return winner

def animate_last_successful_path(env, last_successful_path):
    """Animate the agents following *last_successful_path* among the moving obstacles.

    What is drawn:
      * static obstacles;
      * start and goal markers;
      * moving obstacles, each with a fixed 2.0-radius halo;
      * each agent, with an ``env.safety_radius`` halo and its growing trail.

    The animation is saved to 'dynamic_obstacles_path.gif' with the Pillow
    writer and then shown.

    Obstacle motion is simulated on a deep copy of ``env.moving_obstacles`` and
    precomputed once per frame. As a result frame ``k`` always shows the same
    obstacle positions, no matter how often it is drawn (initial draw, GIF
    save, on-screen replay, repeats). Previously the frame callback advanced
    the obstacles on every call, so positions drifted between those draws.

    Args:
        env: environment exposing ``obs``, ``moving_obstacles``, ``agents`` (with
            'start' and 'goal'), ``safety_radius``, ``x_range`` and ``y_range``.
        last_successful_path: one sequence of (x, y) positions per agent, in the
            order of ``env.agents``. Must contain at least one agent path.
            Empty per-agent paths are skipped.
    """
    fig, ax = plt.subplots(figsize=(12, 8))

    obs_x = [x for x, y in env.obs]
    obs_y = [y for x, y in env.obs]
    ax.plot(obs_x, obs_y, "sk", markersize=5, label="Obstacles")

    moving_obstacles = []
    safety_circles = []
    # Simulate on a copy so the environment's own obstacles are not advanced.
    obstacle_data = deepcopy(env.moving_obstacles)

    for _ in obstacle_data:
        obstacle = ax.plot([], [], 'ro', markersize=8)[0]
        moving_obstacles.append(obstacle)

        safety = plt.Circle((0, 0), 2.0, color='red', alpha=0.1)
        ax.add_patch(safety)
        safety_circles.append(safety)

    agents = []
    paths = []
    agent_safety_circles = []
    colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k']

    for i, agent_data in enumerate(env.agents):
        color = colors[i % len(colors)]
        ax.plot(agent_data['start'][0], agent_data['start'][1],
                f"{color}s", markersize=10, label=f"Start {i+1}")
        ax.plot(agent_data['goal'][0], agent_data['goal'][1],
                'rs', markersize=10, label=f"Goal {i+1}")

        agent = ax.plot([], [], f"{color}o", markersize=8)[0]
        path = ax.plot([], [], f"{color}-", alpha=0.5,
                       label=f"Agent {i+1} Path")[0]

        safety = plt.Circle((0, 0), env.safety_radius, color=color, alpha=0.1)
        ax.add_patch(safety)

        agents.append(agent)
        paths.append(path)
        agent_safety_circles.append(safety)

    ax.set_xlim(-1, env.x_range + 1)
    ax.set_ylim(-1, env.y_range + 1)
    ax.grid(True)
    ax.legend()
    ax.set_aspect('equal')
    title = ax.set_title("Step: 0")

    max_steps = max(len(agent_path) for agent_path in last_successful_path)

    def update_obstacle_positions():
        # Same motion model as the environment's update_moving_obstacles().
        for obstacle in obstacle_data:
            if obstacle['path_type'] == 'circle':
                obstacle['angle'] += obstacle['speed']
                obstacle['position'][0] = obstacle['center'][0] + obstacle['radius'] * np.cos(obstacle['angle'])
                obstacle['position'][1] = obstacle['center'][1] + obstacle['radius'] * np.sin(obstacle['angle'])

            elif obstacle['path_type'] == 'vertical':
                obstacle['position'][1] += obstacle['speed'] * obstacle['direction']
                if obstacle['position'][1] >= obstacle['y_max']:
                    obstacle['direction'] = -1
                elif obstacle['position'][1] <= obstacle['y_min']:
                    obstacle['direction'] = 1

            elif obstacle['path_type'] == 'horizontal':
                obstacle['position'][0] += obstacle['speed'] * obstacle['direction']
                if obstacle['position'][0] >= obstacle['x_max']:
                    obstacle['direction'] = -1
                elif obstacle['position'][0] <= obstacle['x_min']:
                    obstacle['direction'] = 1

            elif obstacle['path_type'] == 'figure_eight':
                obstacle['angle'] += obstacle['speed']
                obstacle['position'][0] = obstacle['center'][0] + obstacle['a'] * np.sin(obstacle['angle'])
                obstacle['position'][1] = obstacle['center'][1] + obstacle['b'] * np.sin(obstacle['angle'] * 2)

            elif obstacle['path_type'] == 'diagonal':
                obstacle['position'][0] += obstacle['speed'] * obstacle['direction_x']
                obstacle['position'][1] += obstacle['speed'] * obstacle['direction_y']

                if obstacle['position'][0] >= obstacle['x_max']:
                    obstacle['direction_x'] = -1
                elif obstacle['position'][0] <= obstacle['x_min']:
                    obstacle['direction_x'] = 1

                if obstacle['position'][1] >= obstacle['y_max']:
                    obstacle['direction_y'] = -1
                elif obstacle['position'][1] <= obstacle['y_min']:
                    obstacle['direction_y'] = 1

    # Precompute obstacle positions so animate(frame) is a pure function of frame.
    # Frame k shows the obstacles after k + 1 updates.
    obstacle_frames = []
    for _ in range(max_steps):
        update_obstacle_positions()
        obstacle_frames.append([(o['position'][0], o['position'][1]) for o in obstacle_data])

    def animate(frame):
        for marker, halo, (ox, oy) in zip(moving_obstacles, safety_circles, obstacle_frames[frame]):
            marker.set_data([ox], [oy])
            halo.center = (ox, oy)

        for i, (agent, path, safety_circle) in enumerate(zip(agents, paths, agent_safety_circles)):
            agent_path = last_successful_path[i]
            if len(agent_path) == 0:
                continue
            # Agents that already finished stay at their last position.
            current_pos = agent_path[min(frame, len(agent_path) - 1)]

            agent.set_data([current_pos[0]], [current_pos[1]])

            shown = agent_path[:frame + 1]
            path.set_data([pos[0] for pos in shown], [pos[1] for pos in shown])

            safety_circle.center = current_pos

        title.set_text(f"Step: {frame}")

        return ([title] + moving_obstacles + safety_circles + agents + paths + agent_safety_circles)

    anim = FuncAnimation(fig, animate, frames=max_steps, interval=100, blit=False, repeat=True)

    anim.save('dynamic_obstacles_path.gif', writer='pillow')
    plt.show()

if __name__ == "__main__":
    # train_multi_agent_ppo returns (last_success_paths, best_episode_paths, env).
    # The middle value is bound to `rewards` here although it holds paths.
    last_successful_path, rewards, env = train_multi_agent_ppo(num_episodes=1500)

    if not last_successful_path:
        print("No successful paths found")
    else:
        animate_last_successful_path(env, last_successful_path)

# FEDERATED LEARNING (FL) COMBINED WITH BLOCKCHAIN FOR VERIFICATION (GANACHE)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import sys
import random
import inspect
if not hasattr(inspect, 'getargspec'):
    inspect.getargspec = inspect.getfullargspec
import matplotlib.pyplot as plt
from copy import deepcopy
import hashlib
import json
from web3 import Web3
import os
import json

# Web3 connection to the local JSON-RPC node at 127.0.0.1:7545 (presumably the
# Ganache instance named in this section's title). Every contract call below goes through it.
w3 = Web3(Web3.HTTPProvider("http://127.0.0.1:7545"))

def load_contract(artifact_path='chain/build/contracts/FederatedLearningContract.json',
                  contract_address="0xC2E962d33E95f2cce8d0EB493025d4215220712B"):
    """Return a web3 contract handle for the deployed FederatedLearningContract.

    Args:
        artifact_path: compiled-contract JSON file. Only its 'abi' entry is read.
        contract_address: address of the deployed contract. The default is the
            address this notebook was written against; pass another one after
            redeploying.

    The previous version also queried ``w3.eth.chain_id`` into an unused
    variable; that extra RPC round-trip has been dropped.
    """
    with open(artifact_path, encoding='utf-8') as f:
        contract_abi = json.load(f)['abi']
    return w3.eth.contract(address=contract_address, abi=contract_abi)

# Module-level contract handle used by check_contract_state and by the
# federated server/client classes below.
contract = load_contract()

def generate_model_hash(model):
    """Return the keccak-256 hash of ``model.policy``'s parameters as a '0x'-prefixed hex string.

    Parameter values are serialised, in ``parameters()`` order, as nested JSON
    lists, and that text is hashed. Buffers (e.g. running statistics) are not
    included.

    The prefix is added explicitly rather than relying on ``HexBytes.hex()``.
    Whether that method emits '0x' differs between hexbytes releases, and the
    on-chain comparison expects the prefixed form.
    """
    params = [param.data.cpu().numpy().tolist() for param in model.policy.parameters()]

    model_json = json.dumps(params)
    hash_bytes = Web3.keccak(text=model_json)
    return '0x' + bytes(hash_bytes).hex()

def check_contract_state():
    """Print the contract's round counter, aggregation flag and stored global-model hashes.

    Hashes are printed for every round from 0 up to and including
    ``currentRound``. The previous version printed round 0 when the counter was
    0 but omitted the current round once it was above 0. A failure to read one
    round is reported and the remaining rounds are still printed.
    """
    current_round = contract.functions.currentRound().call()
    aggregation_complete = contract.functions.aggregationComplete().call()
    print(f"Contract state: current round = {current_round}, aggregation complete = {aggregation_complete}")

    for r in range(current_round + 1):
        try:
            global_hash = contract.functions.globalModelHashes(r).call()
        except Exception as e:  # report and continue with the next round
            print(f"  Error getting global model hash for round {r}: {e}")
            continue

        if isinstance(global_hash, bytes):
            # bytes() strips any HexBytes subclass so .hex() never adds its own prefix.
            print(f"  Global Model Hash (Round {r}): 0x{bytes(global_hash).hex()[:10]}...")
        else:
            print(f"  Global Model Hash (Round {r}): {global_hash[:10]}...")

class BlockchainFederatedServer:
    """Federated-averaging server that records each global model's hash on-chain.

    On construction it creates the global PPOAgent and uses account
    ``w3.eth.accounts[0]`` as the server account. It then registers accounts
    ``1..num_clients`` as clients through the contract.
    """

    def __init__(self, state_dim, action_dim, num_clients):
        self.global_model = PPOAgent(state_dim, action_dim)
        self.num_clients = num_clients
        self.ema_alpha = 0.8  # NOTE(review): not read by any method of this class
        self.previous_global_state = None
        self.momentum = 0.9  # NOTE(review): not read by any method of this class
        self.velocity = None  # NOTE(review): not read by any method of this class
        # Replaced by a list of floats on the first aggregation.
        self.client_trust_scores = np.ones(num_clients)
        self.training_round = 0
        self.current_round = 0  # refreshed from the contract by start_new_round()

        self.server_account = w3.eth.accounts[0]

        for i in range(1, num_clients + 1):
            client_addr = w3.eth.accounts[i]
            self.register_client(client_addr)

    def register_client(self, client_address):
        """Register *client_address* with the contract from the server account and wait for the receipt."""
        tx = contract.functions.registerClient(client_address).transact({
            'from': self.server_account,
            'gas': 2000000
        })

        receipt = w3.eth.waitForTransactionReceipt(tx)
        if receipt.status != 1:
            print(f"Registration of client {client_address} failed. Transaction hash: {tx.hex()}")
            return
        print(f"Client {client_address} registered. Transaction hash: {tx.hex()}")

    def start_new_round(self):
        """Ask the contract to open a new round.

        Returns True on success, after syncing ``self.current_round`` from the
        contract. Returns False if the previous aggregation is incomplete or the
        transaction fails.
        """
        if not contract.functions.aggregationComplete().call():
            print("Cannot start new round: Previous round's aggregation is not complete")
            return False

        try:
            tx = contract.functions.startNewRound().transact({
                'from': self.server_account,
                'gas': 2000000
            })
            receipt = w3.eth.waitForTransactionReceipt(tx)

            if receipt.status == 1:
                self.current_round = contract.functions.currentRound().call()
                return True
            else:
                print("Transaction to start new round failed")
                return False

        except Exception as e:
            print(f"Error starting new round: {str(e)}")
            return False

    def submit_global_model(self):
        """Store the current global model's hash on-chain.

        After a successful transaction, the value stored for
        ``self.current_round`` is read back and printed.

        Returns:
            The 32-byte hash that was submitted.
        """
        model_hash_hex = generate_model_hash(self.global_model)
        model_hash_bytes = Web3.toBytes(hexstr=model_hash_hex)

        tx = contract.functions.submitGlobalModel(model_hash_bytes).transact({
            'from': self.server_account,
            'gas': 2000000
        })

        receipt = w3.eth.waitForTransactionReceipt(tx)
        if receipt.status == 1:
            print(f"Global model successfully submitted.")
            stored_hash = contract.functions.globalModelHashes(self.current_round).call()
            if isinstance(stored_hash, bytes):
                print(f"Verified stored hash: 0x{bytes(stored_hash).hex()[:10]}...")
            else:
                print(f"Verified stored hash: {stored_hash[:10]}...")
        else:
            print(f"Transaction failed when submitting global model!")

        return model_hash_bytes

    def verify_client_submissions(self):
        """Return True only if every client account has updated and the contract reports aggregation complete."""
        all_updated = True
        for i in range(1, self.num_clients + 1):
            client_addr = w3.eth.accounts[i]
            has_updated = contract.functions.hasUpdated(client_addr).call()
            if not has_updated:
                all_updated = False
                print(f"Client {client_addr} has not submitted an update")

        return all_updated and contract.functions.aggregationComplete().call()

    def aggregate_weights(self, client_models, client_rewards, client_training_sizes=None, client_update_times=None):
        """Trust- and reward-weighted average of the client policies into the global model.

        Steps:
          1. Update per-client trust scores (see _init_trust_scores /
             _update_trust_scores). Scores are always kept in [0.5, 1.5].
          2. Turn rewards into base weights (shifted to be positive when any
             reward is negative; uniform when they sum to 0), multiply by trust
             and normalise.
          3. Average every state-dict entry except BatchNorm running statistics,
             load the result into the global model, and submit its hash on-chain.

        ``client_training_sizes`` and ``client_update_times`` are accepted for
        API compatibility but not used.
        """
        self.training_round += 1
        if self.training_round == 1:
            self._init_trust_scores(client_rewards)
        else:
            self._update_trust_scores(client_models, client_rewards)

        weights = self._aggregation_weights(client_rewards)

        print(f"Round {self.training_round - 1} Analysis:")
        print(f"  Rewards:        {[f'{r:.2f}' for r in client_rewards]}")
        print(f"  Trust Scores:   {[f'{t:.2f}' for t in self.client_trust_scores]}")
        print(f"  Aggregation Weights: {[f'{w:.3f}' for w in weights]}")

        new_state_dict = self._weighted_state_dict(client_models, weights)

        self.previous_global_state = deepcopy(new_state_dict)
        self.global_model.policy.load_state_dict(new_state_dict)

        self.submit_global_model()

    def _init_trust_scores(self, client_rewards):
        """Seed trust from first-round rewards: ``1 + (reward / best - 0.5)``, clipped to [0.5, 1.5].

        The ratio to the best reward is only meaningful when the best reward is
        positive. Otherwise every client gets the neutral ratio 1, mirroring the
        later-round rule. This avoids dividing by zero and avoids inverting the
        ranking when all rewards are negative.
        """
        max_reward = max(client_rewards)
        scores = []
        for reward in client_rewards:
            ratio = reward / max_reward if max_reward > 0 else 1.0
            scores.append(max(0.5, min(1.5, 1.0 + (ratio - 0.5))))
        self.client_trust_scores = scores

    def _update_trust_scores(self, client_models, client_rewards):
        """EMA trust update from reward share and parameter divergence from the current global model.

        Each score becomes ``0.7 * old + 0.2 * reward_factor + 0.1 * (1 - normalised divergence)``,
        clipped to [0.5, 1.5]. Scores are left unchanged when no client diverges.
        """
        global_state = self.global_model.policy.state_dict()

        divergences = []
        for client in client_models:
            client_state = client.policy.state_dict()
            divergence = 0
            for key, tensor in client_state.items():
                if 'running_mean' in key or 'running_var' in key:
                    continue
                # Integer entries (e.g. BatchNorm's num_batches_tracked, if any)
                # are skipped: torch.norm rejects non-floating tensors.
                if not tensor.is_floating_point():
                    continue
                divergence += torch.norm(tensor - global_state[key]).item()
            divergences.append(divergence)

        max_divergence = max(divergences)
        if max_divergence > 0:
            best_reward = max(client_rewards)
            for i in range(self.num_clients):
                # Higher reward -> more trust; neutral factor when no reward is positive.
                reward_factor = client_rewards[i] / best_reward if best_reward > 0 else 1
                # Less divergence from the global model -> more trust.
                divergence_factor = 1 - divergences[i] / max_divergence
                score = (
                    0.7 * self.client_trust_scores[i] +
                    0.2 * reward_factor +
                    0.1 * divergence_factor
                )
                self.client_trust_scores[i] = max(0.5, min(1.5, score))

    def _aggregation_weights(self, client_rewards):
        """Return normalised per-client weights: reward-based base weights times trust scores."""
        if sum(client_rewards) == 0:
            base_weights = [1.0 / self.num_clients] * self.num_clients
        else:
            min_reward = min(client_rewards)
            # Shift negative rewards so every weight is positive.
            shifted_rewards = [r - min_reward + 1e-5 for r in client_rewards] if min_reward < 0 else client_rewards
            total_reward = sum(shifted_rewards)
            base_weights = [r / total_reward for r in shifted_rewards]

        weights = [w * trust for w, trust in zip(base_weights, self.client_trust_scores)]
        total_weight = sum(weights)
        return [w / total_weight for w in weights]

    def _weighted_state_dict(self, client_models, weights):
        """Return the weighted sum of the client state dicts; BatchNorm running stats are copied from client 0."""
        # Build each client's state dict once instead of once per key.
        client_states = [model.policy.state_dict() for model in client_models]
        new_state_dict = deepcopy(client_states[0])
        for key in list(new_state_dict.keys()):
            if 'running_mean' in key or 'running_var' in key:
                continue
            new_state_dict[key] = sum(weights[i] * client_states[i][key]
                                      for i in range(self.num_clients))
        return new_state_dict
class SecurityError(Exception):
    """Exception type for security failures in the federated workflow.

    NOTE(review): not raised anywhere in this section. Unauthorized clients
    currently terminate the process via sys.exit in submit_update_to_blockchain.
    """

class BlockchainFederatedClient:
    """Federated client that trains a local PPOAgent and reports to the contract.

    A single PPOAgent chooses actions for both agents of the client's
    MultiAgentEnv and is updated once per agent trajectory. After local training
    the client submits its model hash and best reward on-chain from
    ``w3.eth.accounts[account_index]``.
    """

    def __init__(self, state_dim, action_dim, env, agent_id, account_index):
        self.model = PPOAgent(state_dim, action_dim)
        self.env = env
        self.agent_id = agent_id
        self.best_path = None
        self.best_reward = -float('inf')
        self.training_size = 0
        self.last_update_time = 0
        self.account_address = w3.eth.accounts[account_index]

    @staticmethod
    def _normalize_hash(value):
        """Return *value* (bytes or hex string, with or without '0x') as lowercase '0x'-prefixed hex.

        bytes() strips any HexBytes subclass so that .hex() never contributes its own prefix.
        """
        if isinstance(value, (bytes, bytearray)):
            return '0x' + bytes(value).hex()
        text = str(value).lower()
        return text if text.startswith('0x') else '0x' + text

    def submit_update_to_blockchain(self, reward):
        """Submit this client's model hash and reward (scaled by 100, truncated to int) for the current round.

        Returns:
            (True, model_hash) when the transaction succeeds; (False, None) when
            the client already submitted this round, the transaction fails, or
            it raises.

        If the contract does not list this account as authorized, the whole
        process is terminated with ``sys.exit(1)``.
        """
        has_updated = contract.functions.hasUpdated(self.account_address).call()

        if has_updated:
            print(f"Client {self.agent_id} has already submitted an update for this round.")
            return False, None

        print(f"Verifying client {self.agent_id} with account {self.account_address} before submitting update...")

        is_authorized = contract.functions.authorizedClients(self.account_address).call()
        if not is_authorized:
            print(f"ERROR: Client {self.agent_id} is not authorized and cannot submit updates!")
            print("SECURITY ALERT: Unauthorized client detected - terminating federation process")
            print("===== FEDERATION PROCESS TERMINATED =====")
            sys.exit(1)  # raises SystemExit; nothing below runs for unauthorized clients
        print(f"Client {self.agent_id} verification successful - client is authorized in the contract.")

        model_hash = generate_model_hash(self.model)
        reward_int = int(reward * 100)

        try:
            tx = contract.functions.submitUpdate(Web3.toBytes(hexstr=model_hash), reward_int).transact({
                'from': self.account_address,
                'gas': 2000000
            })

            receipt = w3.eth.waitForTransactionReceipt(tx)

            if receipt.status == 1:
                print(f"Client {self.agent_id} submitted update. Transaction hash: {tx.hex()}")
                print(f"Update verified on-chain via the onlyAuthorizedClient modifier")
                return True, model_hash
            else:
                print(f"Transaction failed for Client {self.agent_id}.")
                return False, None

        except Exception as e:
            print(f"Error submitting update: {str(e)}")
            if "revert" in str(e).lower():
                print(f"Transaction reverted - client likely failed the onlyAuthorizedClient verification")
            return False, None

    def receive_global_model(self, global_model, round_num):
        """Load *global_model*'s policy weights only if its hash matches the on-chain hash for *round_num*.

        Returns True when the model was verified and loaded, False otherwise.
        """
        expected_hash = generate_model_hash(global_model)
        if self.verify_global_model(expected_hash, round_num):
            self.model.policy.load_state_dict(global_model.policy.state_dict())
            print(f"Client {self.agent_id} successfully received verified global model")
            return True
        else:
            print(f"Client {self.agent_id} rejected unverified global model")
            return False

    def _run_episode(self):
        """Play one episode (at most env.max_steps) with the shared model.

        After every step the accumulated paths are checked with detect_conflict()
        and any conflict is passed to env.apply_cbs_constraints().

        Returns:
            (paths, total_rewards, agent_trajectories, finished), where
            ``finished`` is True when the episode ended because every agent
            reported done.
        """
        states = self.env.reset()
        total_rewards = [0, 0]
        paths = [[], []]
        agent_trajectories = [
            {'states': [], 'actions': [], 'log_probs': [], 'rewards': [],
             'next_states': [], 'dones': []}
            for _ in range(2)
        ]

        for _step in range(self.env.max_steps):
            actions, log_probs = [], []
            for i in range(2):
                action, log_prob = self.model.select_action(states[i])
                actions.append(action)
                log_probs.append(log_prob)

            next_states, rewards, dones, _ = self.env.step(actions)

            for i in range(2):
                agent_trajectories[i]['states'].append(states[i])
                agent_trajectories[i]['actions'].append(actions[i])
                agent_trajectories[i]['log_probs'].append(log_probs[i])
                agent_trajectories[i]['rewards'].append(rewards[i])
                agent_trajectories[i]['next_states'].append(next_states[i])
                agent_trajectories[i]['dones'].append(dones[i])
                paths[i].append(self.env.agents[i]['position'])
                total_rewards[i] += rewards[i]

            conflict = detect_conflict(paths)
            if conflict:
                self.env.apply_cbs_constraints(conflict)

            states = next_states
            if all(dones):
                return paths, total_rewards, agent_trajectories, True

        return paths, total_rewards, agent_trajectories, False

    def train(self, num_episodes=10, round_num=0):
        """Train locally for *num_episodes*, record the best path, and submit the update on-chain.

        An episode counts as successful only when it ended with every agent done
        *and* every agent flagged 'reached_goal'. This matches the criterion in
        train_multi_agent_ppo. (Presumably ``dones`` can also be set by other
        terminal events — confirm against the env.) ``best_path`` is the
        highest-reward successful episode, falling back to the highest-reward
        episode overall.

        Returns:
            (model, submission_succeeded, model_hash_or_None)
        """
        self.last_update_time = round_num
        best_local_path = None
        best_local_reward = -float('inf')
        successful_paths = []
        self.training_size = num_episodes

        for episode in range(num_episodes):
            paths, total_rewards, agent_trajectories, finished = self._run_episode()
            combined_reward = sum(total_rewards)

            reached_goals = all(self.env.agents[i].get('reached_goal', False) for i in range(2))
            if finished and reached_goals:
                successful_paths.append((deepcopy(paths), combined_reward))

            if combined_reward > best_local_reward:
                best_local_reward = combined_reward
                best_local_path = deepcopy(paths)

            for i in range(2):
                self.model.update(
                    agent_trajectories[i]['states'],
                    agent_trajectories[i]['actions'],
                    agent_trajectories[i]['log_probs'],
                    agent_trajectories[i]['rewards'],
                    agent_trajectories[i]['next_states'],
                    agent_trajectories[i]['dones']
                )

        if successful_paths:
            self.best_path = max(successful_paths, key=lambda x: x[1])[0]
        else:
            self.best_path = best_local_path

        self.best_reward = best_local_reward
        submission_result, model_hash = self.submit_update_to_blockchain(self.best_reward)

        return self.model, submission_result, model_hash

    def verify_global_model(self, expected_hash, round_num):
        """Compare *expected_hash* with the hash stored on-chain for *round_num*.

        Both sides are normalised to lowercase '0x'-prefixed hex first. A
        prefix or bytes/HexBytes difference therefore cannot cause a false
        mismatch.
        """
        on_chain_hash = contract.functions.globalModelHashes(round_num).call()
        expected = self._normalize_hash(expected_hash)
        actual = self._normalize_hash(on_chain_hash)

        print(f"Verification ")
        print(f"  Expected hash: {expected[:10]}...")
        print(f"  On-chain hash: {actual[:10]}...")

        if actual == expected:
            print(f"  Global model verified")
            return True
        else:
            print(f"  Global model verification failed")
            return False
def test_unregistered_client():
    """Run a short training pass with a client tied to ``w3.eth.accounts[5]``.

    The client is built with agent id 999 and trailing constructor argument 5
    (presumably the account index, matching ``w3.eth.accounts[5]`` — TODO
    confirm). It then trains for 10 episodes in round 0. The on-chain outcome
    depends on the contract, which is not visible here. Reads module-level
    ``state_dim`` / ``action_dim``, so those must be defined before calling.
    """
    account_index = 5
    unregistered_account = w3.eth.accounts[account_index]
    print(f"\n=== Testing unregistered client with account {unregistered_account} ===")
    rogue_client = BlockchainFederatedClient(
        state_dim, action_dim, MultiAgentEnv(), 999, account_index
    )
    print("Training unregistered client...")
    rogue_client.train(num_episodes=10, round_num=0)

# Probe a throwaway environment once to learn the per-agent state size.
sample_env = MultiAgentEnv()
sample_state = sample_env.reset()
state_dim = len(sample_state[0])
action_dim = 8
no_rounds = 5
num_episodes = 100

# One independent environment per federated client.
num_clients = 2
envs = [MultiAgentEnv() for _ in range(num_clients)]
server = BlockchainFederatedServer(state_dim, action_dim, num_clients)
# Client ids are 0-based. The trailing argument is id + 1 (presumably the
# w3.eth.accounts index, leaving account 0 for the server — TODO confirm).
clients = [
    BlockchainFederatedClient(state_dim, action_dim, client_env, client_id, client_id + 1)
    for client_id, client_env in enumerate(envs)
]

# Mean reward per round, plus each client's best reward per round.
rewards_history = []
client_rewards_history = [[] for _ in clients]
import sys  # sys.exit is used below; sys is not among the file's visible top-level imports

check_contract_state()
# For checking against random agent which is not registered
# test_unregistered_client()
for round_idx in range(no_rounds):
    print(f"\n==== STARTING TRAINING ROUND {round_idx}====\n")

    # Round 0 runs in the contract's current round; later rounds ask the server
    # to open a new one first. The on-chain counter is then read either way and
    # handed to the clients.
    if round_idx > 0:
        server.start_new_round()
    current_round = contract.functions.currentRound().call()

    client_models = []
    round_rewards = []
    training_sizes = []
    update_times = []

    # Local training. Every client must submit its update successfully,
    # otherwise the whole federation is aborted.
    for client in clients:
        print(f"Client {client.agent_id} training in round {round_idx}...")
        model, submission_success, _ = client.train(num_episodes, current_round)
        if not submission_success:
            print(f"Client {client.agent_id} failed to submit update. Terminating federation process.")
            print("===== FEDERATION PROCESS TERMINATED =====")
            sys.exit(1)
        client_models.append(model)
        round_rewards.append(client.best_reward)
        client_rewards_history[client.agent_id].append(client.best_reward)
        training_sizes.append(client.training_size)
        update_times.append(client.last_update_time)

    check_contract_state()
    print("\nAggregating client models...")
    server.aggregate_weights(
        client_models,
        round_rewards,
        training_sizes,
        update_times
    )

    check_contract_state()
    print("\nDistributing global model to clients...")
    for client in clients:
        if not client.receive_global_model(server.global_model, current_round):
            # Any failed verification aborts the run immediately.
            print(f"Verification failed for client {client.agent_id}, stopping federation process")
            sys.exit(1)

    avg_reward = np.mean(round_rewards)
    rewards_history.append(avg_reward)
    print(f"Round {round_idx}: Avg Reward = {avg_reward:.2f}")

# Training curves. The solid black line is the per-round mean of the clients'
# best rewards. Each dashed line is one client's best reward per round.
plt.figure(figsize=(12, 6))
plt.plot(rewards_history, 'k-', linewidth=2, label="Global Avg Reward")
for i in range(num_clients):
    plt.plot(client_rewards_history[i], linestyle='--', marker='o', label=f"Client {i} Reward")
plt.xlabel("Rounds")
plt.ylabel("Reward")
plt.title("Blockchain-Enhanced Federated Learning Training Progress")
plt.legend()
# Pass True explicitly: a bare plt.grid() toggles visibility, which would
# hide the grid under styles that enable it by default.
plt.grid(True)
plt.tight_layout()
plt.show()

# Replay the first recorded best path. Paths are collected from each client's
# own environment, and clients[idx] was constructed with envs[idx]. So the
# matching env is tracked here rather than always animating on envs[0].
last_successful_path = None
path_env = envs[0]
for idx, client in enumerate(clients):
    if client.best_path:
        last_successful_path = client.best_path
        path_env = envs[idx]
        break

if last_successful_path:
    animate_last_successful_path(path_env, last_successful_path)
else:
    print("No successful paths found")

print("Blockchain-enhanced federated training complete!")