<a href="https://colab.research.google.com/github/TheAmirHK/OceanFun_RL/blob/main/BubbleNetFeeding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

I'm still working on the bubbling simulation, but it's not bad at the moment.

In [None]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import pygame
import cv2
import os
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env

In [None]:
# Environment for Bubble Net Feeding
class BubbleNetEnv(gym.Env):
    """Multi-whale bubble-net feeding environment on a square grid.

    Action space: ``MultiDiscrete([5] * num_whales)``; per whale,
    0 = y + 1, 1 = y - 1, 2 = x - 1, 3 = x + 1, 4 = blow a bubble
    (only whale 0 actually emits bubbles).

    Observation: a flat ``float32`` vector made of the whale (x, y) pairs,
    the fish (x, y) pairs, and ``max_bubbles`` bubble (x, y) pairs
    (unused bubble slots are zero), all within ``[0, grid_size]``.

    Reward: the number of fish within distance 1.0 of at least one whale.
    Each fish is counted once, even when several whales are close to it.
    """

    # Per-action (dx, dy) displacement for the four movement actions.
    _MOVES = {0: (0, 1), 1: (0, -1), 2: (-1, 0), 3: (1, 0)}
    _BUBBLE_ACTION = 4
    _TRAP_RADIUS = 1.0

    def __init__(self):
        super().__init__()
        self.grid_size = 20
        self.num_fish = 30
        self.max_steps = 1000
        self.max_bubbles = 200
        self.num_whales = 5

        # Define action and observation space
        self.action_space = spaces.MultiDiscrete([5] * self.num_whales)
        obs_len = 2 * (self.num_whales + self.num_fish + self.max_bubbles)
        self.observation_space = spaces.Box(
            low=0, high=self.grid_size, shape=(obs_len,), dtype=np.float32
        )

        # Initialize state (populated by reset())
        self.whale_pos = None
        self.fish_pos = None
        self.bubbles = None
        self.steps = 0

        # Pygame initialization
        self.screen_size = 500
        self.cell_size = self.screen_size // self.grid_size
        pygame.init()
        self.screen = pygame.display.set_mode((self.screen_size, self.screen_size))
        pygame.display.set_caption("Bubble Net Feeding - RL Simulation")
        self.clock = pygame.time.Clock()

        # Video recording setup: render() writes one PNG per call here.
        self.frame_dir = "frames"
        os.makedirs(self.frame_dir, exist_ok=True)
        self.frame_count = 0

    def reset(self, seed=None, options=None):
        """Start a new episode.

        Args:
            seed: optional seed forwarded to ``gym.Env.reset``; fish
                positions are drawn from ``self.np_random`` so a given seed
                reproduces the same layout.
            options: unused.

        Returns:
            ``(observation, info)`` where ``info`` is an empty dict.
        """
        super().reset(seed=seed)

        # All whales start stacked at the same cell near the y = 2 row.
        start = np.array([self.grid_size // 2, 2], dtype=np.float32)
        self.whale_pos = np.tile(start, (self.num_whales, 1))

        # Use the environment's seeded generator rather than global np.random.
        self.fish_pos = self.np_random.uniform(
            0, self.grid_size, size=(self.num_fish, 2)
        ).astype(np.float32)
        # Keep fish away from the whales' starting rows.
        self.fish_pos[:, 1] = np.clip(self.fish_pos[:, 1], 4, self.grid_size)

        self.bubbles = []
        self.steps = 0
        self.frame_count = 0
        return self._get_obs(), {}

    def step(self, action):
        """Advance the simulation by one step.

        Args:
            action: sequence of ``num_whales`` integers in ``[0, 4]``.

        Returns:
            ``(observation, reward, terminated, truncated, info)``.
            ``terminated`` is True once every fish is trapped;
            ``truncated`` is True once ``max_steps`` steps have elapsed.
        """
        self.steps += 1

        self._apply_actions(action)
        self._attract_fish()

        trapped = self._count_trapped_fish()
        terminated = trapped >= self.num_fish
        truncated = self.steps >= self.max_steps

        return self._get_obs(), float(trapped), terminated, truncated, {}

    def _apply_actions(self, action):
        """Move each whale, or emit a spiral bubble for whale 0."""
        upper = self.grid_size - 1
        for i in range(self.num_whales):
            act = int(action[i])
            if act in self._MOVES:
                dx, dy = self._MOVES[act]
                self.whale_pos[i][0] = min(max(self.whale_pos[i][0] + dx, 0), upper)
                self.whale_pos[i][1] = min(max(self.whale_pos[i][1] + dy, 0), upper)
            elif act == self._BUBBLE_ACTION:
                # only the first whale creates the spiral
                if i == 0 and len(self.bubbles) < self.max_bubbles:
                    self._emit_bubble(self.whale_pos[i])

    def _emit_bubble(self, origin):
        """Append one bubble on a spiral around ``origin``."""
        # Angle and radius both grow with the step counter.
        theta = self.steps * 0.05
        r = self.steps * 0.01
        bubble_x = origin[0] + r * np.cos(theta)
        bubble_y = origin[1] + r * np.sin(theta)
        # Clamp to the grid so observations stay inside observation_space.
        bubble_x = float(np.clip(bubble_x, 0, self.grid_size))
        bubble_y = float(np.clip(bubble_y, 0, self.grid_size))
        self.bubbles.append([bubble_x, bubble_y])

    def _attract_fish(self):
        """Pull every fish 10% of the way toward the mean bubble position."""
        if not self.bubbles:
            return
        bubble_center = np.mean(self.bubbles, axis=0)
        pull = (bubble_center - self.fish_pos) * 0.1
        self.fish_pos += pull.astype(np.float32)

    def _count_trapped_fish(self):
        """Return how many fish lie within the trap radius of any whale."""
        # Pairwise distances, shape (num_fish, num_whales).
        offsets = self.fish_pos[:, None, :] - self.whale_pos[None, :, :]
        dists = np.linalg.norm(offsets, axis=-1)
        return int(np.count_nonzero((dists < self._TRAP_RADIUS).any(axis=1)))

    def _get_obs(self):
        """Build the flat float32 observation: whales, fish, padded bubbles."""
        whale_obs = self.whale_pos.flatten()
        fish_obs = self.fish_pos.flatten()
        bubble_obs = np.zeros(self.max_bubbles * 2, dtype=np.float32)
        if self.bubbles:
            flat = np.asarray(self.bubbles, dtype=np.float32).flatten()
            flat = flat[: bubble_obs.size]
            bubble_obs[: flat.size] = flat
        return np.concatenate([whale_obs, fish_obs, bubble_obs]).astype(np.float32)

    def _draw_points(self, points, color, radius):
        """Draw each grid-space (x, y) point as a filled circle."""
        for point in points:
            center = (int(point[0] * self.cell_size), int(point[1] * self.cell_size))
            pygame.draw.circle(self.screen, color, center, radius)

    def render(self):
        """Draw the current state, save it as a PNG frame, and update the window."""
        # Service the event queue so the window stays responsive.
        pygame.event.pump()

        self.screen.fill((0, 0, 0))

        self._draw_points(self.bubbles, (0, 0, 255), 5)      # bubbles are in blue
        self._draw_points(self.fish_pos, (0, 255, 0), 5)     # fish are in green
        self._draw_points(self.whale_pos, (255, 0, 0), 10)   # whales are in red

        frame_path = os.path.join(self.frame_dir, f"frame_{self.frame_count:04d}.png")
        pygame.image.save(self.screen, frame_path)
        self.frame_count += 1

        pygame.display.flip()
        self.clock.tick(10)  # FPS rate here !

    def close(self):
        """Shut down pygame."""
        pygame.quit()

In [None]:
# Build the environment and validate it against the Gym API.
env = BubbleNetEnv()
check_env(env)

# Train a PPO agent on the environment.
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10000)

# Roll out the trained policy; env.render() saves one PNG frame per step.
obs, _ = env.reset()
for _ in range(500):
    action, _ = model.predict(obs)
    obs, reward, done, truncated, info = env.step(action)
    env.render()
    # Stop on either kind of episode end, not only on termination.
    if done or truncated:
        break

env.close()

# Use only the frames written by this rollout. Listing the directory would
# also pick up stale, higher-numbered frames left by an earlier, longer run.
frame_files = [
    os.path.join(env.frame_dir, f"frame_{index:04d}.png")
    for index in range(env.frame_count)
]
frame_files = [path for path in frame_files if os.path.exists(path)]

if frame_files:
    first_frame = cv2.imread(frame_files[0])
    height, width, _ = first_frame.shape
    video = cv2.VideoWriter(
        "bubble_net_simulation.mp4",
        cv2.VideoWriter_fourcc(*"mp4v"),
        30,
        (width, height),
    )
    try:
        for frame_file in frame_files:
            frame = cv2.imread(frame_file)
            # cv2.imread returns None for unreadable files; skip those.
            if frame is not None:
                video.write(frame)
    finally:
        # Always finalize the container, even if a write fails.
        video.release()
    print("Video saved as bubble_net_simulation.mp4")