# Predator-Prey Ecosystem

### Model

In [142]:
import torch.nn as nn
import torch.nn.functional as F


class DDQNLSTM(nn.Module):
    def __init__(self, input_shape, n_actions):
        super(DDQNLSTM, self).__init__()

        # Convolutional layers with padding to preserve dimensions
        self.conv1 = nn.Conv2d(in_channels=input_shape[0], out_channels=32, kernel_size=4, stride=4)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=2)

        # LSTM layer
        self.lstm = nn.LSTM(input_size=256, hidden_size=256, batch_first=True)

        # Fully connected layers for state-value and advantage-value streams
        self.fc_output_layer = nn.Linear(256, 128)
        self.output_layer = nn.Linear(128, n_actions)

    def forward(self, x, hidden_state=None):
        batch_size = x.size(0)

        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))

        x = x.view(batch_size, -1)
        x = x.unsqueeze(1)

        if hidden_state is None:
            x, hidden_state = self.lstm(x)
        else:
            x, hidden_state = self.lstm(x, hidden_state)

        x = x.squeeze(1)

        state = F.relu(self.fc_output_layer(x))
        output = self.output_layer(state)

        return output, hidden_state

### Agent

In [143]:
import random


class Agent():
    def __init__(self, id, role, position):
        self.id = id    # TODO think about guid - unique in whole env
        self.role = role
        self.position = position

        if role == 'predator':
            self.health = random.uniform(0.5, 1)
        else:
            self.health = 1

    def set_position(self, position):
        self.position = position

    def get_position(self):
        return self.position

    def get_random_action(self):
        return random.choice([1, 2, 3, 4])   # Actions: 1=up, 2=down, 3=left, 4=right

    def add_health(self, health_gained):
        self.health += health_gained

### Environment

In [144]:
import math
import random
import numpy as np
from pettingzoo.utils.env import ParallelEnv


class PredatorPreyEnv(ParallelEnv):
    def __init__(self, grid_size=(15, 15), num_predators=2, num_prey=3, num_walls=5, 
                 predator_scope=2, health_gained=0.3):
        """
        Initializes the environment.
        grid_size: Tuple[int, int] - dimensions of the grid.
        num_predators: int - number of predator agents.
        num_prey: int - number of prey agents.
        num_walls: int - number of wall elements.
        predator_scope: int - range of predator, where preys are killed
        health_gained: float - value of health restored with killing a prey
        """
        self.grid_size = grid_size
        self.num_predators = num_predators
        self.num_prey = num_prey
        self.num_walls = num_walls
        self.predator_scope = predator_scope
        self.health_gained = health_gained

        self.max_num_predators = 10000
        self.max_num_preys = 10000

        self.agents = []
        # self.agent_positions = {agent: None for agent in self.agents}
        # self.agent_health = {agent: 1 for agent in self.agents}
        self.walls_positions = []

        self.grid = np.zeros(self.grid_size, dtype=object)

        # self.reset()

    def reset(self):
        """Resets the environment."""
        self.grid.fill(0)
        self.walls_positions.clear()

        # Place walls
        for _ in range(self.num_walls):
            while True:
                x, y = random.randint(0, self.grid_size[0] - 1), random.randint(0, self.grid_size[1] - 1)
                if self.grid[x, y] == 0:
                    self.grid[x, y] = -1  # Wall
                    self.walls_positions.append((x, y))
                    break

        # Create and place predators
        for i in range(self.num_predators):
            while True:
                x, y = random.randint(0, self.grid_size[0] - 1), random.randint(0, self.grid_size[1] - 1)
                if self.grid[x, y] == 0:
                    predator = Agent(f"pr_{i}", "predator", (x, y))
                    self.agents.append(predator)
                    self.grid[x, y] = predator  # Predator
                    break

        # Create and place prey
        for i in range(self.num_prey):
            while True:
                x, y = random.randint(0, self.grid_size[0] - 1), random.randint(0, self.grid_size[1] - 1)
                if self.grid[x, y] == 0:
                    prey = Agent(f"py_{i}", "prey", (x, y))
                    self.agents.append(prey)
                    self.grid[x, y] = prey  # Prey
                    break

        return {agent.id: self.get_observation(agent) for agent in self.agents}

    def agents_move(self, actions):
        """Make a move of each agent"""
        new_positions = {}

        for agent in self.agents:
            x, y = agent.get_position()
            new_x, new_y = x, y

            # random actions for now
            action = actions[agent.id]

            if action == 1:  # up
                new_x = (x - 1) % self.grid_size[0]
            elif action == 2:  # down
                new_x = (x + 1) % self.grid_size[0]
            elif action == 3:  # left
                new_y = (y - 1) % self.grid_size[1]
            elif action == 4:  # right
                new_y = (y + 1) % self.grid_size[1]

            if self.grid[new_x, new_y] == 0:  # Move if the cell is empty
                new_positions[agent.id] = (new_x, new_y)
            else:  # Stay in place if the cell is occupied
                new_positions[agent.id] = (x, y)

        # Update grid and agent positions
        self.grid.fill(0)
        for wall in self.walls_positions:
            self.grid[wall[0], wall[1]] = -1

        for agent in self.agents:
            x, y = new_positions[agent.id]
            self.grid[x, y] = agent
            agent.set_position((x, y))

    # TODO try to optimise with object pointers
    def hunting(self, rewards, dones):
        """Handle predator prey interaction - hunting"""
        for predator in [a for a in self.agents if "predator" in a.role]:
            px, py = predator.get_position()
            prey_in_scope = []

            for dx in range(-self.predator_scope, self.predator_scope + 1):
                for dy in range(-self.predator_scope, self.predator_scope + 1):
                    if dx == 0 and dy == 0:
                        continue
                    nx, ny = (px + dx) % self.grid_size[0], (py + dy) % self.grid_size[1]
                    if type(self.grid[nx, ny]) == Agent and self.grid[nx, ny].role == 'prey':
                        distance = abs(dx) + abs(dy)  # Manhattan
                        prey_in_scope.append((distance, (nx, ny)))

            if prey_in_scope:
                # Kill the nearest prey
                prey_in_scope.sort()
                target_prey_pos = prey_in_scope[0][1]
                for prey in self.agents:
                    pos = prey.get_position()
                    if pos == target_prey_pos:
                        self.agents.remove(prey)
                        self.grid[target_prey_pos[0], target_prey_pos[1]] = 0
                        rewards[predator.id] += 1  # Reward for eating prey
                        rewards[prey.id] += -1
                        predator.add_health(self.health_gained)  # Add constant value
                        dones[prey.id] = True
                        # print(f'{prey.id} killed')
                        break

        return rewards, dones

    def predator_hunger(self, dones):
        """Decrease predator health and remove dead predators"""
        for predator in [a for a in self.agents if "predator" in a.role]:
            predator.add_health(-0.01)
            if predator.health <= 0:
                px, py = predator.get_position()
                self.agents.remove(predator)
                self.grid[px, py] = 0
                dones[predator.id] = True
                # print(f'{predator.id} killed')
        return dones

    def generate_new_agents(self, p_predator=0.003, p_prey=0.006):
        """
        Generates new predators and prey based on the provided formula.
        p_predator: float - probability factor for generating new predators.
        p_prey: float - probability factor for generating new prey.
        """
        # Calculate the number of new predators and prey
        num_predators = len([a for a in self.agents if "predator" in a.role])
        num_preys = len([a for a in self.agents if "prey" in a.role])

        new_preys = 0
        new_predators = 0
        if num_predators < self.max_num_predators:
            new_predators = max(1, math.ceil(num_predators * p_predator))
        if num_preys < self.max_num_preys:
            new_preys = max(1, math.ceil(num_preys * p_prey))

        # Add new predators
        for _ in range(new_predators):
            predator_id = f"pr_{len([a for a in self.agents if 'predator' in a.role])}"
            while True:
                x, y = random.randint(0, self.grid_size[0] - 1), random.randint(0, self.grid_size[1] - 1)
                if self.grid[x, y] == 0:  # Empty cell
                    created_agent = Agent(predator_id, 'predator', (x, y))
                    self.grid[x, y] = created_agent  # Predator
                    self.agents.append(created_agent)
                    break

        # Add new preys
        for _ in range(new_preys):
            prey_id = f"py_{len([a for a in self.agents if 'prey' in a.role])}"
            while True:
                x, y = random.randint(0, self.grid_size[0] - 1), random.randint(0, self.grid_size[1] - 1)
                if self.grid[x, y] == 0:  # Empty cell
                    created_agent = Agent(prey_id, 'prey', (x, y))
                    self.grid[x, y] = created_agent  # Prey
                    self.agents.append(created_agent)
                    break

    def step(self, actions):
        """Takes a step in the environment based on the actions and environment rules."""
        rewards = {agent.id: 0 for agent in self.agents}
        dones = {agent.id: False for agent in self.agents}

        self.agents_move(actions)

        rewards, dones = self.hunting(rewards, dones)

        dones = self.predator_hunger(dones)

        self.generate_new_agents(0.003, 0.006)

        observations = {agent.id: self.get_observation(agent) for agent in self.agents}

        return observations, rewards, dones

    def get_observation(self, agent):
        """Returns a 4-channel local grid observation for the given agent."""
        ax, ay = agent.get_position()
        size = self.predator_scope * 8 + 1

        wall_layer = np.zeros((size, size), dtype=int)
        predator_layer = np.zeros((size, size), dtype=int)
        prey_layer = np.zeros((size, size), dtype=int)
        health_layer = np.zeros((size, size), dtype=float)

        for dx in range(-5*self.predator_scope, 5*self.predator_scope + 1):
            for dy in range(-5*self.predator_scope, 5*self.predator_scope + 1):
                nx, ny = (ax + dx) % self.grid_size[0], (ay + dy) % self.grid_size[1]
                local_x, local_y = dx + self.predator_scope, dy + self.predator_scope

                if self.grid[nx, ny] == -1:
                    wall_layer[local_x, local_y] = 1
                elif type(self.grid[nx, ny]) == Agent and self.grid[nx, ny].role == 'predator':
                    predator_layer[local_x, local_y] = 1
                    health_layer[local_x, local_y] = self.grid[nx, ny].health
                elif type(self.grid[nx, ny]) == Agent and self.grid[nx, ny].role == 'prey':
                    prey_layer[local_x, local_y] = 1
                    health_layer[local_x, local_y] = self.grid[nx, ny].health

        return np.stack([wall_layer, predator_layer, prey_layer, health_layer], axis=0)

    def render(self):
        """Renders the environment in the console."""
        render_grid = np.full(self.grid.shape, '.')

        render_grid[self.grid == -1] = '#'  # Wall
        render_grid[self.grid == 1] = 'O'  # Prey
        render_grid[self.grid == 2] = 'X'  # Predator

        print("\n".join("".join(row) for row in render_grid))
        print()

### Renderer

In [145]:
import numpy as np
import cv2
import os
from IPython.display import Video


class PredatorPreyRenderer:
    def __init__(self, env):
        """
        Initializes the renderer.
        :param env: The PredatorPrey environment.
        """
        self.env = env
        self.replay_buffer = []

    def render_frame(self, grid, agent_data, cell_size=20, focus_area=None):
        """
        Renders a single frame of the environment.
        :param grid: The grid state to render.
        :param agent_data: A dictionary of agent positions {agent_id: (x, y)}.
        :param cell_size: Size of each grid cell in pixels.
        :param focus_area: If provided, renders only the specified area (x_start, y_start, x_end, y_end).
        :return: A numpy array representing the frame.
        """
        if focus_area:
            x_start, y_start, x_end, y_end = focus_area
            frame = np.zeros(((x_end - x_start) * cell_size, (y_end - y_start) * cell_size, 3), dtype=np.uint8)
            frame.fill(255)  # White background
            grid_slice = grid[x_start:x_end, y_start:y_end]
        else:
            frame = np.zeros((self.env.grid_size[0] * cell_size, self.env.grid_size[1] * cell_size, 3), dtype=np.uint8)
            frame.fill(255)  # White background
            grid_slice = grid

        # Draw hunt scope for predators as a background
        for agent_id, (x, y) in agent_data.items():
            agent = next((a for a in self.env.agents if a.id == agent_id), None)
            if agent and agent.role == 'predator':  # Only draw scope for predators
                if focus_area:
                    # Adjust predator position relative to the focus area
                    x -= focus_area[0]
                    y -= focus_area[1]

                center_x = y * cell_size + cell_size / 2
                center_y = x * cell_size + cell_size / 2

                scope_size = self.env.predator_scope * cell_size
                top_left = (int(center_x - scope_size), int(center_y - scope_size))
                bottom_right = (int(center_x + scope_size), int(center_y + scope_size))

                # Draw the hunt scope as a light green filled rectangle (background)
                cv2.rectangle(frame, top_left, bottom_right, (200, 255, 200), -1)

        # Draw walls, predators, and prey
        for i in range(grid_slice.shape[0]):
            for j in range(grid_slice.shape[1]):
                x, y = i * cell_size, j * cell_size
                if grid_slice[i, j] == -1:
                    cv2.rectangle(frame, (y, x), (y + cell_size, x + cell_size), (0, 255, 255), -1)  # Yellow walls
                elif isinstance(grid_slice[i, j], Agent):
                    if grid_slice[i, j].role == 'predator':
                        cv2.rectangle(frame, (y, x), (y + cell_size, x + cell_size), (255, 0, 0), -1)  # Blue predator
                    else:
                        cv2.rectangle(frame, (y, x), (y + cell_size, x + cell_size), (0, 0, 255), -1)  # Red prey

        return frame

    def snapshot(self):
        """
        Captures a snapshot of the current environment state and adds it to the replay buffer.
        """
        agent_data = {agent.id: agent.get_position() for agent in self.env.agents}
        self.replay_buffer.append((self.env.grid.copy(), agent_data))

    def show_replay(self, file_name="replay.mp4", fps=10, cell_size=20, focus_size=None):
        """
        Generates a video from the replay buffer and displays it in Jupyter Notebook.
        :param file_name: Name of the output video file.
        :param fps: Frames per second for video rendering.
        :param cell_size: Size of each grid cell in pixels.
        :param focus_size: Size of the focused area (in cells) for 'focus' mode.
        :return: Video object for display in Jupyter Notebook.
        """
        if not self.replay_buffer:
            print("Replay buffer is empty. No video to display.")
            return

        if focus_size:
            x_center = self.env.grid_size[0] // 2
            y_center = self.env.grid_size[1] // 2
            focus_area = (
                max(0, x_center - focus_size // 2),
                max(0, y_center - focus_size // 2),
                min(self.env.grid_size[0], x_center + focus_size // 2),
                min(self.env.grid_size[1], y_center + focus_size // 2)
            )
        else:
            focus_area = None

        if focus_area:
            height = (focus_area[2] - focus_area[0]) * cell_size
            width = (focus_area[3] - focus_area[1]) * cell_size
        else:
            height = self.env.grid_size[0] * cell_size
            width = self.env.grid_size[1] * cell_size

        temp_video_path = "temp_video.mp4"
        try:
            video = cv2.VideoWriter(temp_video_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
            if not video.isOpened():
                raise Exception("Failed to create video file. Check dimensions or codec support.")

            for snapshot in self.replay_buffer:
                grid, agent_data = snapshot
                frame = self.render_frame(grid, agent_data, cell_size=cell_size, focus_area=focus_area)
                video.write(frame)

            video.release()

            os.system(f"ffmpeg -y -i {temp_video_path} -vcodec libx264 -x264opts keyint=123:min-keyint=120 -an {file_name}")

            return Video(file_name, embed=True)

        except Exception as e:
            print(f"Error during video generation: {e}")
            return None

    def reset(self):
        """
        Resets the replay buffer.
        """
        self.replay_buffer = []

### Training

In [146]:
import sys
import unittest
from collections import deque

import random
import csv
import torch
import torch.optim as optim


def batchify(data, batch_size):
    return [data[i:i + batch_size] for i in range(0, len(data), batch_size)]

def update_weights(agent_replay_buffer, agent_policy_model, agent_target_model, agent_optimizer, device='cpu'):
    batch = random.sample(agent_replay_buffer, BUFFER_SIZE)

    mini_batches = batchify(batch, BATCH_SIZE)
    for minibatch in mini_batches:
        q_values_batch = []
        target_q_values = []
        for obs_mn, action_mn, reward_mn, done_mn, next_obs_mn, hidden_state_mn, next_hidden_state_mn in minibatch:
            with torch.no_grad():
                next_obs = torch.tensor(next_obs_mn, dtype=torch.float32).unsqueeze(0).to(device)
                next_action = torch.argmax(agent_policy_model(next_obs, next_hidden_state_mn)[0])
                target_q_value = reward_mn + GAMMA * (1 - done_mn) * \
                                 agent_target_model(next_obs, next_hidden_state_mn)[0].squeeze(0)[next_action]
                target_q_values.append(target_q_value)
            q_values, _ = agent_policy_model(torch.tensor(obs_mn, dtype=torch.float32, device=device).unsqueeze(0), hidden_state_mn)
            q_value = q_values.gather(1, action_mn.view(1, 1)).squeeze()
            q_values_batch.append(q_value)
        target_q_values = torch.stack(target_q_values)

        q_values_batch = torch.stack((q_values_batch))
        # Compute current Q-values and loss
        # if all(x is None for x in hidden_state_batch):
        #     q_values, _ = agent_policy_model(obs_batch)
        # else:
        #     q_values, _ = agent_policy_model(obs_batch, hidden_state_batch)
        # q_values = q_values.gather(1, action_batch.unsqueeze(1)).squeeze()
        loss = torch.nn.functional.mse_loss(q_values_batch, target_q_values)

        agent_optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(agent_policy_model.parameters(), 1.0)
        agent_optimizer.step()

    if i % UPDATE_FREQ == 0:
        agent_target_model.load_state_dict(agent_policy_model.state_dict())
        torch.save(predator_target_model.state_dict(), "models/predator_target_model.pth")
        torch.save(predator_policy_model.state_dict(), "models/predator_policy_model.pth")

        torch.save(prey_target_model.state_dict(), "models/prey_target_model.pth")
        torch.save(prey_policy_model.state_dict(), "models/prey_policy_model.pth")
    agent_replay_buffer.clear()

# Wrapping the environment - Can be added in the future
def env_creator():
    env = PredatorPreyEnv((600, 600), 1000, 1000, 1000, 5, 1.0)
    return env

RUN_TESTS_BEFORE = False

def run_tests():
    print("Running tests...")
    
    test_suite = unittest.defaultTestLoader.discover(start_dir='.', pattern='test_*.py')
    test_runner = unittest.TextTestRunner()
    result = test_runner.run(test_suite)

    if not result.wasSuccessful():
        print("Tests failed! The program will be terminated...")
        sys.exit(1)
    else:
        print("All tests passed! Proceeding to main program...")

# ==============================================================================

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

if RUN_TESTS_BEFORE:
    run_tests() 
else:
    print("WARNING: running without tests...")    

# Hyperparameters
BUFFER_SIZE = 64
BATCH_SIZE = 64
EPSILON = 0.1
UPDATE_FREQ = 50
GAMMA = 0.99
LEARNING_RATE = 0.0001

env = env_creator()
obs = env.reset()

renderer = PredatorPreyRenderer(env)

csv_file = 'Eval_output_ENV_1_more_hunger_ceil_more_reward_bigger_observation.csv'
data = []

predator_replay_buffer = deque()
prey_replay_buffer = deque()

# Models
predator_policy_model = DDQNLSTM((4, 51, 51), 4).to(device)
predator_target_model = DDQNLSTM((4, 51, 51), 4).to(device)
prey_policy_model = DDQNLSTM((4, 51, 51), 4).to(device)
prey_target_model = DDQNLSTM((4, 51, 51), 4).to(device)

# Optimizers
predator_optimizer = optim.Adam(predator_policy_model.parameters(), lr=LEARNING_RATE)
prey_optimizer = optim.Adam(prey_policy_model.parameters(), lr=LEARNING_RATE)

hidden_states = {agent.id: None for agent in env.agents}
new_hidden_states = {agent.id: None for agent in env.agents}

for i in range(30):
    actions = {}
    # actions = {agent.id: random.randint(0, 4) for agent in env.agents}
    for agent in env.agents:
        obs_tensor = torch.tensor(obs[agent.id], dtype=torch.float32).unsqueeze(0).to(device)
        if agent.id not in hidden_states.keys():
            hidden_state = None
            hidden_states[agent.id] = None
        else:
            hidden_state = hidden_states[agent.id]
        if agent.role == 'predator':
            action_values, new_hidden_state = predator_policy_model(obs_tensor, hidden_state)
        else:
            action_values, new_hidden_state = prey_policy_model(obs_tensor, hidden_state)

        if random.random() < EPSILON:  # Exploration
            actions[agent.id] = torch.tensor(random.randint(0, 3), device=device)  # Assuming action space is [0, 1, 2, 3]
        else:  # Exploitation
            actions[agent.id] = torch.argmax(action_values)
        new_hidden_states[agent.id] = new_hidden_state

    new_obs, rewards, dones = env.step(actions)
    renderer.snapshot()

    for agent_id in actions.keys():
        if dones[agent_id]:
            new_obs_to_save = torch.zeros_like(torch.tensor(obs[agent_id], dtype=torch.float32)).to(device)  # Placeholder
        else:
            new_obs_to_save = new_obs[agent_id]
        experience = (
            obs[agent_id],  # Current observation
            actions[agent_id],  # Action taken
            rewards[agent_id],  # Reward received
            dones[agent_id],  # Done flag
            new_obs_to_save,  # Next observation
            hidden_states[agent_id],  # Current hidden state
            new_hidden_states[agent_id]
        )
        if agent_id[:2] == 'pr':
            predator_replay_buffer.append(experience)
        else:
            prey_replay_buffer.append(experience)

    if len(predator_replay_buffer) >= BUFFER_SIZE:
        # Sample a minibatch and train (same as before)
        update_weights(predator_replay_buffer, predator_policy_model, predator_target_model, predator_optimizer, device)
    if len(prey_replay_buffer) >= BUFFER_SIZE:
        # Sample a minibatch and train (same as before)
        update_weights(prey_replay_buffer, prey_policy_model, prey_target_model, prey_optimizer, device)

    num_predators = len([a for a in env.agents if "predator" in a.role])
    num_preys = len([a for a in env.agents if "prey" in a.role])
    data.append([i, num_predators, num_preys])

    obs = new_obs
    hidden_state = new_hidden_states
    
    print(i, num_predators, num_preys)
    
    with open(csv_file, mode='a', newline='') as file:
        writer = csv.writer(file)
        writer.writerow([i, num_predators, num_preys])
        
torch.save(predator_target_model.state_dict(), "predator_target_model.pth")
torch.save(predator_policy_model.state_dict(), "predator_policy_model.pth")

torch.save(prey_target_model.state_dict(), "prey_target_model.pth")
torch.save(prey_policy_model.state_dict(), "prey_policy_model.pth")

Using device: cpu


  next_obs = torch.tensor(next_obs_mn, dtype=torch.float32).unsqueeze(0).to(device)


0 1003 758
1 1007 721
2 1011 716
3 1015 712
4 1019 716
5 1023 710
6 1027 703
7 1031 697
8 1035 694
9 1039 690
10 1043 690
11 1047 687
12 1051 683
13 1055 678
14 1059 673
15 1063 666
16 1067 662
17 1071 653
18 1075 648
19 1079 641
20 1083 637
21 1087 626
22 1091 616
23 1095 606
24 1099 590
25 1103 578
26 1107 577
27 1111 575
28 1115 572
29 1119 569


In [147]:
# Option 1: Render the entire grid
display(renderer.show_replay("full_grid.mp4", fps=5, cell_size=2))

# Option 2: Render a focus area
display(renderer.show_replay("focus_area.mp4", fps=5, cell_size=20, focus_size=120))