In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random

# Hyperparameters
gamma = 0.99
epsilon_clip = 0.2
lr = 1e-3
batch_size = 64
epochs = 3
max_steps = 200
max_episodes = 1000

# Define the environment
class GridEnvironment:
    def __init__(self, size=10, target=(9, 9)):
        self.size = size
        self.target = target
        self.reset()

    def reset(self):
        self.agent_positions = [(0, 0), (self.size-1, self.size-1)]
        self.steps = 0
        return self.get_state()

    def step(self, actions):
        rewards = []
        next_state = []
        for i, action in enumerate(actions):
            x, y = self.agent_positions[i]
            if action == 0:  # move up
                x = max(0, x-1)
            elif action == 1:  # move down
                x = min(self.size-1, x+1)
            elif action == 2:  # move left
                y = max(0, y-1)
            elif action == 3:  # move right
                y = min(self.size-1, y+1)

            self.agent_positions[i] = (x, y)
            if (x, y) == self.target:
                rewards.append(100)
            else:
                rewards.append(-1)

            next_state.append(self.get_state()[i])

        self.steps += 1
        done = self.steps >= max_steps or all(pos == self.target for pos in self.agent_positions)
        return next_state, rewards, done

    def get_state(self):
        states = []
        for pos in self.agent_positions:
            state = np.zeros((self.size, self.size))
            state[pos] = 1
            state[self.target] = 2
            states.append(state.flatten())
        return states

# Define the PPO network
class PPOAgent(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(PPOAgent, self).__init__()
        self.fc1 = nn.Linear(state_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.policy_head = nn.Linear(256, action_dim)
        self.value_head = nn.Linear(256, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        policy = self.policy_head(x)
        value = self.value_head(x)
        return policy, value

def select_action(agent, state):
    state = torch.FloatTensor(state).unsqueeze(0)
    policy, _ = agent(state)
    action_prob = torch.softmax(policy, dim=-1)
    action = np.random.choice(len(action_prob[0]), p=action_prob.detach().numpy()[0])
    return action

def compute_returns(rewards, dones, gamma):
    R = 0
    returns = deque()
    for r, d in zip(reversed(rewards), reversed(dones)):
        R = r + gamma * R * (1 - d)
        returns.appendleft(R)
    return list(returns)

def update_policy(agent, optimizer, states, actions, old_log_probs, returns, advantages):
    for _ in range(epochs):
        for i in range(0, len(states), batch_size):
            state_batch = torch.FloatTensor(states[i:i+batch_size])
            action_batch = torch.LongTensor(actions[i:i+batch_size])
            old_log_prob_batch = torch.FloatTensor(old_log_probs[i:i+batch_size])
            return_batch = torch.FloatTensor(returns[i:i+batch_size])
            advantage_batch = torch.FloatTensor(advantages[i:i+batch_size])

            policy, value = agent(state_batch)
            value = value.squeeze()
            action_prob = torch.softmax(policy, dim=-1)
            dist = torch.distributions.Categorical(action_prob)
            log_prob = dist.log_prob(action_batch)

            ratio = torch.exp(log_prob - old_log_prob_batch)
            surr1 = ratio * advantage_batch
            surr2 = torch.clamp(ratio, 1 - epsilon_clip, 1 + epsilon_clip) * advantage_batch

            policy_loss = -torch.min(surr1, surr2).mean()
            value_loss = nn.functional.mse_loss(value, return_batch)

            loss = policy_loss + 0.5 * value_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

# Main training loop
env = GridEnvironment()
state_dim = env.size * env.size
action_dim = 4

agent_a = PPOAgent(state_dim, action_dim)
agent_b = PPOAgent(state_dim, action_dim)
optimizer_a = optim.Adam(agent_a.parameters(), lr=lr)
optimizer_b = optim.Adam(agent_b.parameters(), lr=lr)

for episode in range(max_episodes):
    states = env.reset()
    states_a = []
    states_b = []
    actions_a = []
    actions_b = []
    rewards_a = []
    rewards_b = []
    log_probs_a = []
    log_probs_b = []
    dones = []

    for _ in range(max_steps):
        action_a = select_action(agent_a, states[0])
        action_b = select_action(agent_b, states[1])

        actions = [action_a, action_b]
        next_states, rewards, done = env.step(actions)

        states_a.append(states[0])
        states_b.append(states[1])
        actions_a.append(action_a)
        actions_b.append(action_b)
        rewards_a.append(rewards[0])
        rewards_b.append(rewards[1])
        dones.append(done)

        states = next_states

        if done:
            break

    returns_a = compute_returns(rewards_a, dones, gamma)
    returns_b = compute_returns(rewards_b, dones, gamma)

    states_tensor_a = torch.FloatTensor(states_a)
    states_tensor_b = torch.FloatTensor(states_b)
    actions_tensor_a = torch.LongTensor(actions_a)
    actions_tensor_b = torch.LongTensor(actions_b)

    policy_a, values_a = agent_a(states_tensor_a)
    policy_b, values_b = agent_b(states_tensor_b)

    values_a = values_a.squeeze().detach().numpy()
    values_b = values_b.squeeze().detach().numpy()

    action_probs_a = torch.softmax(policy_a, dim=-1).detach().numpy()
    action_probs_b = torch.softmax(policy_b, dim=-1).detach().numpy()

    old_log_probs_a = np.log([action_probs_a[i, actions_a[i]] for i in range(len(actions_a))])
    old_log_probs_b = np.log([action_probs_b[i, actions_b[i]] for i in range(len(actions_b))])

    advantages_a = np.array(returns_a) - values_a
    advantages_b = np.array(returns_b) - values_b

    # Normalize advantages
    advantages_a = (advantages_a - advantages_a.mean()) / (advantages_a.std() + 1e-8)
    advantages_b = (advantages_b - advantages_b.mean()) / (advantages_b.std() + 1e-8)

    # Update policy for Agent A
    update_policy(agent_a, optimizer_a, states_a, actions_a, old_log_probs_a, returns_a, advantages_a)

    # Update policy for Agent B
    update_policy(agent_b, optimizer_b, states_b, actions_b, old_log_probs_b, returns_b, advantages_b)

    if episode % 10 == 0:
        print(f"Episode {episode} complete")

# Testing the trained agents
def test_agents(env, agent_a, agent_b, max_steps=200):
    state = env.reset()
    total_reward = 0

    for step in range(max_steps):
        action_a = select_action(agent_a, state[0])
        action_b = select_action(agent_b, state[1])
        actions = [action_a, action_b]
        next_state, rewards, done = env.step(actions)
        total_reward += sum(rewards)
        state = next_state
        if done:
            break

    return total_reward

# Evaluate trained agents
total_rewards = []
for _ in range(10):
    total_reward = test_agents(env, agent_a, agent_b)
    total_rewards.append(total_reward)

print(f"Average total reward over 10 episodes: {np.mean(total_rewards)}")



  states_tensor_a = torch.FloatTensor(states_a)


Episode 0 complete
Episode 10 complete
Episode 20 complete
Episode 30 complete
Episode 40 complete
Episode 50 complete
Episode 60 complete
Episode 70 complete
Episode 80 complete
Episode 90 complete
Episode 100 complete
Episode 110 complete
Episode 120 complete
Episode 130 complete
Episode 140 complete
Episode 150 complete
Episode 160 complete
Episode 170 complete
Episode 180 complete
Episode 190 complete
Episode 200 complete
Episode 210 complete
Episode 220 complete
Episode 230 complete
Episode 240 complete
Episode 250 complete
Episode 260 complete
Episode 270 complete
Episode 280 complete
Episode 290 complete
Episode 300 complete
Episode 310 complete
Episode 320 complete
Episode 330 complete
Episode 340 complete
Episode 350 complete
Episode 360 complete
Episode 370 complete
Episode 380 complete
Episode 390 complete
Episode 400 complete
Episode 410 complete
Episode 420 complete
Episode 430 complete
Episode 440 complete
Episode 450 complete
Episode 460 complete
Episode 470 complete
Epi