<a href="https://colab.research.google.com/github/TheAmirHK/OceanFun_RL/blob/main/BubbleNetFeeding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# -*- coding: utf-8 -*-
"""
Created on Thu Feb  5 12:18:35 2025

@author: amirh
"""
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import pygame
import cv2
import os
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env

In [None]:
class BubbleNetEnv(gym.Env):
    """Scripted bubble-net feeding simulation exposed through the Gymnasium API.

    A leader whale circles the fish school while dropping bubbles, the other
    whales hold positions on a slightly larger ring around the school, and the
    fish drift towards their own centroid and towards the centre of the
    bubbles released so far.

    NOTE(review): ``step`` never reads ``action``; whale motion is fully
    scripted, so a trained policy currently has no influence on the outcome.

    Observation layout (flat ``float32`` vector): whale ``(x, y)`` pairs, then
    fish ``(x, y)`` pairs, then up to ``max_bubbles`` bubble ``(x, y)`` pairs,
    zero-padded for bubbles that have not been released yet.
    """

    def __init__(self):
        super().__init__()
        self.grid_size = 50
        self.num_fish = 500
        self.max_steps = 2000
        self.max_bubbles = 1000
        self.num_whales = 5

        self.action_space = spaces.MultiDiscrete([5] * self.num_whales)
        obs_len = 2 * (self.num_whales + self.num_fish + self.max_bubbles)
        self.observation_space = spaces.Box(
            low=0, high=self.grid_size, shape=(obs_len,), dtype=np.float32
        )

        # Simulation state; populated by reset().
        self.whale_pos = None
        self.fish_pos = None
        self.bubbles = None
        self.steps = 0

        # Rendering state. The pygame surface is created on the first render()
        # call so the environment can be built and trained without a display.
        self.screen_size = 1000
        self.cell_size = self.screen_size // self.grid_size
        self.screen = None
        self.clock = None
        self._has_window = False

        self.frame_dir = "frames"
        os.makedirs(self.frame_dir, exist_ok=True)
        self.frame_count = 0

    def reset(self, seed=None, options=None):
        """Start a new episode and return ``(observation, info)``.

        All whales start at ``(grid_size // 2, 2)``. Fish are placed at evenly
        spaced angles around the grid centre with a random radius between 5
        and ``grid_size / 2``. The radii come from ``self.np_random`` so that
        passing ``seed`` reproduces the same layout.
        """
        super().reset(seed=seed)
        centre = self.grid_size // 2

        start = np.array([centre, 2], dtype=np.float32)
        self.whale_pos = np.tile(start, (self.num_whales, 1))

        angle = np.linspace(0, 2 * np.pi, self.num_fish)
        radius = self.np_random.uniform(5, self.grid_size / 2, self.num_fish)
        self.fish_pos = np.column_stack(
            (centre + radius * np.cos(angle), centre + radius * np.sin(angle))
        ).astype(np.float32)

        self.bubbles = []
        self.steps = 0
        self.frame_count = 0
        return self._get_obs(), {}

    def step(self, action):
        """Advance the scripted simulation by one tick.

        Behaviour per tick:

        * While fewer than ``max_bubbles`` bubbles exist, the leader whale
          (index 0) moves along a circle around the fish centroid and drops a
          bubble at its new position.
        * The remaining whales are placed at evenly spaced angles on a ring
          two cells wider than the leader's circle.
        * Fish are pulled weakly towards their own centroid and, once bubbles
          exist, more strongly towards the mean bubble position.

        ``action`` is accepted for API compatibility but is not used.

        Returns ``(obs, reward, done, False, {})``. ``reward`` is the number of
        (whale, fish) pairs closer than 0.1 cells; a fish close to several
        whales is counted once per whale. ``done`` is true when ``max_steps``
        is reached or the reward reaches ``num_fish``.
        """
        self.steps += 1

        # Centre and extent of the fish school drive the whale formation.
        fish_center = self.fish_pos.mean(axis=0)
        fish_spread = np.linalg.norm(self.fish_pos - fish_center, axis=1).max()
        r = max(5, fish_spread * 0.9)  # never let the ring collapse below 5 cells

        # Leader whale: circle the school and leave a bubble behind.
        if len(self.bubbles) < self.max_bubbles:
            theta = self.steps * 0.05  # angular speed of the leader
            leader = fish_center + r * np.array([np.cos(theta), np.sin(theta)])
            # Clip so the observation stays inside the declared Box bounds.
            self.whale_pos[0] = np.clip(leader, 0, self.grid_size)
            self.bubbles.append(self.whale_pos[0].tolist())

        # Other whales: evenly spaced slots on a slightly larger ring.
        if self.num_whales > 1:
            slots = np.arange(1, self.num_whales) / self.num_whales * 2 * np.pi
            ring = np.column_stack((np.cos(slots), np.sin(slots))) * (r + 2)
            self.whale_pos[1:] = np.clip(fish_center + ring, 0, self.grid_size)

        # Weak attraction of every fish towards the school centroid.
        self.fish_pos += (fish_center - self.fish_pos) * 0.0005

        # Stronger pull towards the centre of the bubbles released so far.
        if self.bubbles:
            bubble_center = np.asarray(self.bubbles, dtype=np.float32).mean(axis=0)
            self.fish_pos += (bubble_center - self.fish_pos) * 0.01

        # Reward: count (whale, fish) pairs that are within 0.1 cells.
        offsets = self.fish_pos[np.newaxis, :, :] - self.whale_pos[:, np.newaxis, :]
        distances = np.linalg.norm(offsets, axis=2)
        reward = float(np.count_nonzero(distances < 0.1))

        done = bool(self.steps >= self.max_steps or reward >= self.num_fish)
        return self._get_obs(), reward, done, False, {}

    def _get_obs(self):
        """Return the flat ``float32`` observation (whales, fish, padded bubbles)."""
        bubble_obs = np.zeros(self.max_bubbles * 2, dtype=np.float32)
        if self.bubbles:
            flat = np.asarray(self.bubbles, dtype=np.float32).ravel()
            flat = flat[: bubble_obs.size]  # never write past the padded buffer
            bubble_obs[: flat.size] = flat
        parts = [self.whale_pos.ravel(), self.fish_pos.ravel(), bubble_obs]
        return np.concatenate(parts).astype(np.float32)

    def _ensure_display(self):
        """Create the pygame drawing surface on first use.

        Tries to open a window; if pygame reports an error (for example no
        video device on a headless host) it falls back to an off-screen
        surface so frames can still be saved to disk.
        """
        if self.screen is not None:
            return
        pygame.init()
        size = (self.screen_size, self.screen_size)
        try:
            self.screen = pygame.display.set_mode(size)
            pygame.display.set_caption("Bubble Net Feeding - RL Simulation")
            self._has_window = True
        except pygame.error:
            self.screen = pygame.Surface(size)
            self._has_window = False
        self.clock = pygame.time.Clock()

    def _to_pixels(self, pos):
        """Convert a grid position to integer screen coordinates."""
        return int(pos[0] * self.cell_size), int(pos[1] * self.cell_size)

    def render(self):
        """Draw the current state, save it as a PNG frame and update the window.

        Bubbles are blue, fish green, the leader whale yellow and the other
        whales red. Each call writes ``frame_NNNN.png`` into ``frame_dir``.
        """
        self._ensure_display()
        if self._has_window:
            # Let pygame process OS events so the window stays responsive.
            pygame.event.pump()

        self.screen.fill((0, 0, 0))

        for bubble in self.bubbles:
            pygame.draw.circle(self.screen, (0, 0, 255), self._to_pixels(bubble), 5)

        for fish in self.fish_pos:
            pygame.draw.circle(self.screen, (0, 255, 0), self._to_pixels(fish), 5)

        for i, whale in enumerate(self.whale_pos):
            color = (255, 255, 0) if i == 0 else (255, 0, 0)
            pygame.draw.circle(self.screen, color, self._to_pixels(whale), 10)

        frame_path = os.path.join(self.frame_dir, f"frame_{self.frame_count:04d}.png")
        pygame.image.save(self.screen, frame_path)
        self.frame_count += 1

        if self._has_window:
            pygame.display.flip()
            self.clock.tick(20)  # cap the on-screen frame rate

    def close(self):
        """Shut pygame down; a later ``render()`` call re-creates the surface."""
        if self.screen is not None:
            pygame.quit()
            self.screen = None
            self.clock = None
            self._has_window = False



In [None]:
# Train a PPO agent on the environment, roll it out for up to 200 rendered
# steps, then stitch the saved frames into an MP4.
env = BubbleNetEnv()
try:
    check_env(env)

    model = PPO("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=10000)

    obs, _ = env.reset()
    for _step in range(200):
        action, _ = model.predict(obs)
        obs, reward, done, truncated, info = env.step(action)
        env.render()
        # Stop on either end-of-episode signal of the Gymnasium step API.
        if done or truncated:
            break
finally:
    # Release pygame even if training or the rollout fails.
    env.close()

# Only use the frames written by this rollout: the frame directory may still
# hold higher-numbered files from an earlier, longer run.
frame_files = sorted(
    os.path.join(env.frame_dir, f)
    for f in os.listdir(env.frame_dir)
    if f.startswith("frame_")
)[: env.frame_count]

if frame_files:
    first_frame = cv2.imread(frame_files[0])
    if first_frame is None:
        raise RuntimeError(f"Could not read frame image: {frame_files[0]}")
    height, width, _ = first_frame.shape
    video = cv2.VideoWriter(
        "bubble_net_simulation.mp4", cv2.VideoWriter_fourcc(*"mp4v"), 30, (width, height)
    )
    try:
        for frame_file in frame_files:
            frame = cv2.imread(frame_file)
            # cv2.imread returns None for unreadable files; skip those.
            if frame is not None:
                video.write(frame)
    finally:
        video.release()
    print("Video saved as bubble_net_simulation.mp4")