In [1]:
from typing import Optional
from matplotlib import pyplot as plt
import matplotlib.patches as patches
import random

import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque

import gymnasium as gym
from gymnasium import spaces
import numpy as np

In [2]:
class PreyPredatorEnv(gym.Env):
    def __init__(self, num_predators=2, num_preys=2):
        super(PreyPredatorEnv, self).__init__()
        self.grid_size = 5
        self.num_predators = num_predators
        self.num_preys = num_preys
        self.action_space = spaces.MultiDiscrete([4] * (num_predators + num_preys))  # Actions for all predators and preys
        self.observation_space = spaces.Box(low=0, high=self.grid_size-1, shape=(2 * (num_predators + num_preys),), dtype=np.int32)
        self.reset()

    def reset(self):
        # Predators start at random positions
        self.predator_positions = np.random.randint(0, self.grid_size, size=(self.num_predators, 2))
        # Preys start at random positions
        self.prey_positions = np.random.randint(0, self.grid_size, size=(self.num_preys, 2))
        return np.concatenate([self.predator_positions.flatten(), self.prey_positions.flatten()])

    def step(self, predator_actions, prey_actions):
        # predator_actions = actions[:self.num_predators]
        # prey_actions = actions[self.num_predators:]

        # Move the predators based on their actions
        print(predator_actions)
        for i, action in enumerate(predator_actions):
            if action == 0:  # up
                self.predator_positions[i][0] = max(self.predator_positions[i][0] - 1, 0)
            elif action == 1:  # down
                self.predator_positions[i][0] = min(self.predator_positions[i][0] + 1, self.grid_size - 1)
            elif action == 2:  # left
                self.predator_positions[i][1] = max(self.predator_positions[i][1] - 1, 0)
            elif action == 3:  # right
                self.predator_positions[i][1] = min(self.predator_positions[i][1] + 1, self.grid_size - 1)

        # Move the preys based on their actions
        for i, action in enumerate(prey_actions):
            if action == 0:  # up
                self.prey_positions[i][0] = max(self.prey_positions[i][0] - 1, 0)
            elif action == 1:  # down
                self.prey_positions[i][0] = min(self.prey_positions[i][0] + 1, self.grid_size - 1)
            elif action == 2:  # left
                self.prey_positions[i][1] = max(self.prey_positions[i][1] - 1, 0)
            elif action == 3:  # right
                self.prey_positions[i][1] = min(self.prey_positions[i][1] + 1, self.grid_size - 1)

        # Check if any predator caught any prey
        done = False
        predator_reward = -0.1 * self.num_predators
        prey_reward = 0
        for predator_pos in self.predator_positions:
            for prey_pos in self.prey_positions:
                if np.array_equal(predator_pos, prey_pos):
                    done = True
                    predator_reward += 1
                    prey_reward -= 10

        return np.concatenate([self.predator_positions.flatten(), self.prey_positions.flatten()]), predator_reward, prey_reward, done, {}

    def render(self, mode='human'):
        fig, ax = plt.subplots()
        ax.set_xlim(0, self.grid_size)
        ax.set_ylim(0, self.grid_size)
        ax.set_xticks(np.arange(0, self.grid_size, 1))
        ax.set_yticks(np.arange(0, self.grid_size, 1))
        ax.grid(True)

        # Add preys
        for prey_pos in self.prey_positions:
            prey_patch = patches.Rectangle(prey_pos[::-1], 1, 1, edgecolor='green', facecolor='green', label="Prey")
            ax.add_patch(prey_patch)

        # Add predators
        for predator_pos in self.predator_positions:
            predator_patch = patches.Rectangle(predator_pos[::-1], 1, 1, edgecolor='red', facecolor='red', label="Predator")
            ax.add_patch(predator_patch)

        plt.legend(handles=[prey_patch, predator_patch])
        plt.title('Predator-Prey Grid World')
        plt.pause(0.3)  # Pause to visualize each step
        plt.show()

    def close(self):
        pass


In [3]:
# Neural Network for the Q-value function
class DQN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(input_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, output_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

class Agent:
    def __init__(self, env, input_dim=4):
        self.env = env
        self.input_dim = input_dim
        # print(env.action_space)
        self.q_network = DQN(input_dim=self.input_dim, output_dim=8)
        self.target_network = DQN(input_dim=self.input_dim, output_dim=8)
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=0.001)
        self.loss_fn = nn.MSELoss()

        self.memory = deque(maxlen=10000)
        self.gamma = 0.99
        self.batch_size = 64
        self.epsilon = 0.1  # Exploration rate
        self.epsilon_decay = 0.995
        self.epsilon_min = 0.001

    def select_action(self, state):
        if random.random() < self.epsilon:
            return self.env.action_space.sample()
        else:
            with torch.no_grad():
                state = torch.FloatTensor(state).unsqueeze(0)
                q_values = self.q_network(state)
                return q_values.argmax().item()

    def store_transition(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def update_q_network(self):
        if len(self.memory) < self.batch_size:
            return
        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = torch.FloatTensor(states)
        actions = torch.LongTensor(actions)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.FloatTensor(next_states)
        dones = torch.FloatTensor(dones)

        current_q_values = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        next_q_values = self.target_network(next_states).max(1)[0]
        target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

        loss = self.loss_fn(current_q_values, target_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_target_network(self):
        self.target_network.load_state_dict(self.q_network.state_dict())

In [4]:
env = PreyPredatorEnv()
predator = Agent(env, 8)
prey = Agent(env, 8)
num_episodes = 5000

predator_rewards_per_episode = []
prey_rewards_per_episode = []
steps_per_episode = []

for episode in range(num_episodes):
    state = env.reset()
    predator_total_reward = 0
    prey_total_reward = 0
    steps = 0
    done = False

    while steps < 100 and not done:
        predator_action = predator.select_action(state)
        prey_action = prey.select_action(state)
        next_state, predator_reward, prey_reward, done, _ = env.step(predator_action, prey_action)
        predator.store_transition(state, predator_action, predator_reward, next_state, done)
        prey.store_transition(state, prey_action, prey_reward, next_state, done)
        predator.update_q_network()
        prey.update_q_network()
        state = next_state
        predator_total_reward += predator_reward
        prey_total_reward += prey_reward
        steps += 1

    predator.update_target_network()
    prey.update_target_network()

    if predator.epsilon > predator.epsilon_min:
        predator.epsilon *= predator.epsilon_decay
    
    if prey.epsilon > prey.epsilon_min:
        prey.epsilon *= prey.epsilon_decay

    predator_rewards_per_episode.append(predator_total_reward)
    prey_rewards_per_episode.append(prey_total_reward)
    steps_per_episode.append(steps)

    print(f"Episode {episode + 1}, Total Reward for Predator: {predator_total_reward}, Total Reward for Prey: {prey_total_reward}, Steps: {steps}")

# Plot the learning process
plt.figure(figsize=(12, 5))

# Plot total reward per episode for predator
plt.subplot(1, 2, 1)
plt.plot(predator_rewards_per_episode, label='Predator')
plt.plot(prey_rewards_per_episode, label='Prey')
plt.title('Total Reward per Episode')
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.legend()

# Plot number of steps per episode
plt.subplot(1, 2, 2)
plt.plot(steps_per_episode)
plt.title('Steps per Episode')
plt.xlabel('Episode')
plt.ylabel('Number of Steps')

plt.tight_layout()
plt.show()


# # Visualize the strategy learned by the agent in a test run
# state = env.reset()
# done = False
# s = 0
# while s < 50 and not done:
#     s += 1
#     env.render()  # Render each step
#     predator_action = predator.select_action(state)  # Use the learned policy
#     prey_action = prey.select_action(state)
#     state, predator_reward, prey_reward, done, _ = env.step(predator_action, prey_action)
# env.render()  # Render the final state
# env.close()


  from .autonotebook import tqdm as notebook_tqdm


3


TypeError: 'int' object is not iterable

In [None]:
# test score
num_episodes = 100
predator_total_reward = 0
prey_total_reward = 0
tot_steps = 0
count_kill = 0
for episode in range(num_episodes):
    if episode % 10 == 0:
        print("Current episode:", episode)
    state = env.reset()
    done = False
    steps = 0
    while steps < 10000 and not done:
        predator_action = predator.select_action(state)
        prey_action = prey.select_action(state)
        next_state, predator_reward, prey_reward, done, _ = env.step(predator_action, prey_action)
        state = next_state
        predator_total_reward += predator_reward
        prey_total_reward += prey_reward
        steps += 1
    
    if done:
        count_kill += 1
        print(steps)
    
    tot_steps += steps

print(f"Total kill: {count_kill}/{num_episodes}")
print(f"Average steps: {tot_steps/num_episodes}")

Current episode: 0
2220
4
Current episode: 10
Current episode: 20
2508
913
Current episode: 30
7836
734
Current episode: 40
1373
Current episode: 50
7642
Current episode: 60
Current episode: 70
Current episode: 80
Current episode: 90
Total kill: 8/100
Average steps: 9432.3
