<a href="https://colab.research.google.com/github/TheAmirHK/OceanFun_RL/blob/main/EcholocationHunting/%C3%A9cholocalisation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# -*- coding: utf-8 -*-
"""
Created on Fri Jan 19 20:14:11 2025

@author: amirh
"""

import gymnasium as gym
from gymnasium import spaces
import numpy as np
import pygame
import cv2
import os
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
import random

In [None]:
class EcoEnv(gym.Env):
    """Multi-whale echolocation hunting environment.

    Whales (the agents) move on a square grid and try to catch fish. Each whale
    has five discrete actions: 0 = up, 1 = down, 2 = left, 3 = right and
    4 = toggle stealth mode. Whales in stealth mode ignore movement actions and
    do not echolocate.

    On every step:
      1. agent actions are applied,
      2. whales not in stealth mode echolocate. Each uncaught fish within
         ``echo_range`` of such a whale is detected with probability
         ``echo_accuracy``,
      3. every whale is steered one cell per axis toward the nearest detected
         fish, or toward the centroid of the remaining fish if none was
         detected,
      4. free fish random-walk, and are pulled back when far from the centre,
      5. any fish within ``catch_radius`` of a whale is caught for
         ``catch_reward``.

    Observation (float32 vector):
        whale x/y pairs (2 * num_whales), then
        x/y pairs for every fish, caught or not (2 * num_fish), then
        stealth flags as 0.0/1.0 (num_whales).

    All randomness comes from ``self.np_random``, so ``reset(seed=...)`` makes
    episodes reproducible.
    """

    # (dx, dy) for each movement action. "Up" decreases y, which is drawn
    # downward on screen in ``render``.
    _MOVES = {0: (0, -1), 1: (0, 1), 2: (-1, 0), 3: (1, 0)}
    _TOGGLE_STEALTH = 4

    def __init__(self):
        super().__init__()
        self.grid_size = 100
        self.num_fish = 100
        self.max_steps = 2000
        self.num_whales = 5

        self.echo_range = 10  # echolocation detection range (grid units)
        self.echo_accuracy = 0.8  # probability that an in-range fish is detected
        self.catch_radius = 1.0  # a whale closer than this to a fish catches it
        self.catch_reward = 2  # reward per caught fish
        self.stealth_mode = [False] * self.num_whales

        # action space: 5 actions per whale (up, down, left, right, toggle stealth)
        self.action_space = spaces.MultiDiscrete([5] * self.num_whales)

        # Observation space holds whale positions, all fish positions and
        # stealth flags. Positions are clipped to [0, grid_size] in the
        # dynamics so observations always stay inside this box.
        self.observation_space = spaces.Box(
            low=0,
            high=self.grid_size,
            shape=(2 * self.num_whales + 2 * self.num_fish + self.num_whales,),
            dtype=np.float32,
        )

        self.whale_pos = None
        self.fish_pos = None
        self.steps = 0
        self.caught_fish = np.zeros(self.num_fish, dtype=bool)

        self.screen_size = 500
        self.cell_size = self.screen_size // self.grid_size
        pygame.init()
        self.screen = pygame.display.set_mode((self.screen_size, self.screen_size))
        pygame.display.set_caption("Echolocation - RL Simulation")
        self.clock = pygame.time.Clock()

        self.frame_dir = "frames"
        os.makedirs(self.frame_dir, exist_ok=True)
        self.frame_count = 0

    def reset(self, seed=None, options=None):
        """Start a new episode and return ``(observation, info)``.

        All whales start at the same cell near the top edge. Fish are spread
        on a disc around the grid centre.
        """
        super().reset(seed=seed)
        self.whale_pos = np.array(
            [[self.grid_size // 2, 2]] * self.num_whales, dtype=np.float32
        )

        # endpoint=False keeps the first and last fish from sharing one angle
        # (0 and 2*pi point the same way)
        angle = np.linspace(0, 2 * np.pi, self.num_fish, endpoint=False)
        radius = self.np_random.uniform(5, self.grid_size / 2, self.num_fish)
        centre = self.grid_size // 2
        self.fish_pos = np.empty((self.num_fish, 2), dtype=np.float32)
        self.fish_pos[:, 0] = centre + radius * np.cos(angle)
        self.fish_pos[:, 1] = centre + radius * np.sin(angle)
        np.clip(self.fish_pos, 0, self.grid_size, out=self.fish_pos)

        self.steps = 0
        self.frame_count = 0
        self.stealth_mode = [False] * self.num_whales
        self.caught_fish = np.zeros(self.num_fish, dtype=bool)
        return self._get_obs(), {}

    def step(self, action):
        """Advance the simulation by one step.

        Returns ``(obs, reward, terminated, truncated, info)``. ``terminated``
        becomes True once every fish has been caught. ``truncated`` becomes
        True once ``max_steps`` steps have elapsed.
        """
        self.steps += 1

        self._apply_actions(action)
        detected = self._echolocate()
        self._steer_whales(detected)
        self._move_fish()
        reward = self._catch_fish()

        terminated = bool(np.all(self.caught_fish))
        truncated = self.steps >= self.max_steps
        return self._get_obs(), reward, terminated, truncated, {}

    def _whale_fish_distances(self):
        """Return the (num_whales, num_fish) matrix of whale-to-fish distances."""
        return np.linalg.norm(
            self.whale_pos[:, None, :] - self.fish_pos[None, :, :], axis=2
        )

    def _apply_actions(self, action):
        """Apply each whale's action: move one cell (clamped) or toggle stealth."""
        upper = self.grid_size - 1
        for i in range(self.num_whales):
            a = int(action[i])
            if a == self._TOGGLE_STEALTH:
                self.stealth_mode[i] = not self.stealth_mode[i]
            elif not self.stealth_mode[i] and a in self._MOVES:
                # Movement is ignored while the whale is in stealth mode.
                # Every direction moves by exactly one cell.
                self.whale_pos[i] = np.clip(self.whale_pos[i] + self._MOVES[a], 0, upper)

    def _echolocate(self):
        """Return a boolean mask of uncaught fish detected by non-stealth whales."""
        active = ~np.asarray(self.stealth_mode, dtype=bool)
        dists = self._whale_fish_distances()
        in_range = active[:, None] & ~self.caught_fish[None, :] & (dists < self.echo_range)
        # each (active whale, in-range fish) pair succeeds independently
        hits = in_range & (self.np_random.random(dists.shape) < self.echo_accuracy)
        return hits.any(axis=0)

    def _steer_whales(self, detected):
        """Move each whale one cell per axis toward its target, staying in the grid.

        Detections are shared by the whole pod. The target is the nearest
        detected fish, or the centroid of the still-free fish when nothing was
        detected.
        """
        targets = self.fish_pos[detected]
        remaining = self.fish_pos[~self.caught_fish]
        fallback = remaining.mean(axis=0) if len(remaining) else None

        for i in range(self.num_whales):
            if len(targets):
                nearest = np.argmin(np.linalg.norm(targets - self.whale_pos[i], axis=1))
                direction = targets[nearest] - self.whale_pos[i]
            elif fallback is not None:
                direction = fallback - self.whale_pos[i]
            else:
                continue  # every fish is caught; nothing to chase
            if np.linalg.norm(direction) > 0.1:
                self.whale_pos[i] += np.sign(direction)  # per-axis step of -1, 0 or +1

        np.clip(self.whale_pos, 0, self.grid_size - 1, out=self.whale_pos)

    def _move_fish(self):
        """Random-walk every free fish and nudge those far from the centre back in."""
        free = ~self.caught_fish
        self.fish_pos[free] += self.np_random.uniform(-1, 1, size=(int(free.sum()), 2))

        centre = np.array([self.grid_size / 2, self.grid_size / 2], dtype=np.float32)
        offset = centre - self.fish_pos
        far = free & (np.linalg.norm(offset, axis=1) > self.grid_size / 3)
        self.fish_pos[far] += offset[far] * 0.05
        np.clip(self.fish_pos, 0, self.grid_size, out=self.fish_pos)

    def _catch_fish(self):
        """Mark fish within ``catch_radius`` of any whale as caught and return the reward."""
        close = (self._whale_fish_distances() < self.catch_radius).any(axis=0)
        newly_caught = ~self.caught_fish & close
        self.caught_fish |= newly_caught
        return self.catch_reward * int(newly_caught.sum())

    def _get_obs(self):
        """Concatenate whale positions, all fish positions and stealth flags as float32."""
        return np.concatenate([
            self.whale_pos.ravel(),
            self.fish_pos.ravel(),
            np.asarray(self.stealth_mode, dtype=np.float32),
        ]).astype(np.float32)

    def render(self, save_frame=False):
        """Draw the current state in the pygame window.

        Free fish are drawn as green dots. Whales are drawn red, or orange
        while in stealth mode. When ``save_frame`` is True the screen is also
        written to ``frame_dir/frame_NNNN.png`` and ``frame_count`` is
        incremented.
        """
        # Process window events so the OS does not flag the window as unresponsive.
        pygame.event.pump()
        self.screen.fill((0, 0, 0))

        for x, y in self.fish_pos[~self.caught_fish]:
            pygame.draw.circle(
                self.screen, (0, 255, 0),
                (int(x * self.cell_size), int(y * self.cell_size)), 3,
            )

        for (x, y), stealthy in zip(self.whale_pos, self.stealth_mode):
            colour = (255, 100, 0) if stealthy else (255, 0, 0)
            pygame.draw.circle(
                self.screen, colour,
                (int(x * self.cell_size), int(y * self.cell_size)), 6,
            )

        pygame.display.flip()
        self.clock.tick(10)  # cap the display at 10 FPS

        if save_frame:
            frame_path = os.path.join(self.frame_dir, f"frame_{self.frame_count:04d}.png")
            pygame.image.save(self.screen, frame_path)
            self.frame_count += 1

    def close(self):
        """Shut down pygame and its window."""
        pygame.quit()

In [None]:
env = EcoEnv()
# Validate the env, train PPO briefly, then roll out one episode while saving
# a PNG per step for the video cell below. try/finally ensures the pygame
# window is closed even if validation, training or the rollout raises.
try:
    check_env(env)

    model = PPO("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=5000)

    obs, _ = env.reset()
    for _ in range(500):
        action, _ = model.predict(obs)
        obs, reward, terminated, truncated, info = env.step(action)
        env.render(save_frame=True)
        if terminated or truncated:
            break
finally:
    env.close()


In [None]:
# Build an MP4 from the frames saved during the latest rollout. Only frames
# 0..frame_count-1 belong to it: the counter restarts at 0 on every reset, so
# files left in frame_dir by an earlier, longer run must not be included.
frame_files = [
    os.path.join(env.frame_dir, f"frame_{k:04d}.png") for k in range(env.frame_count)
]
if frame_files:
    first = cv2.imread(frame_files[0])
    if first is None:
        raise FileNotFoundError(f"could not read frame {frame_files[0]}")
    height, width = first.shape[:2]
    video = cv2.VideoWriter(
        "echlocation_simulation.mp4", cv2.VideoWriter_fourcc(*"mp4v"), 30, (width, height)
    )
    if not video.isOpened():
        raise RuntimeError("could not open the video writer (is the mp4v codec available?)")
    try:
        for frame_file in frame_files:
            frame = cv2.imread(frame_file)
            if frame is None:
                raise FileNotFoundError(f"could not read frame {frame_file}")
            video.write(frame)
    finally:
        video.release()
    print("Video is saved")