<a href="https://colab.research.google.com/github/TheAmirHK/OceanFun_RL/blob/main/BubbleNetFeeding/Codes/BubbleNetFeeding3D_v0_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import os
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
import imageio


In [None]:
# In[3D Environment]
class BubbleNetEnv3D(gym.Env):
    """Scripted 3D simulation of whales herding a school of fish with bubbles.

    A leader whale circles the school and drops bubbles along its path while
    the remaining whales orbit the school's centre. Fish drift towards the
    centre of the school and, once bubbles exist, towards the mean bubble
    position while rising along the Y axis. Fish are caught either in bulk
    (when enough of them are packed near the centre) or individually when the
    leader whale comes very close.

    NOTE: ``step`` accepts an action to satisfy the Gymnasium API, but the
    whale trajectories are fully scripted and the action is not used.

    Observation: flat ``float32`` vector holding whale positions, fish
    positions and up to ``max_bubbles`` bubble positions (zero padded),
    clipped to ``[0, grid_size]`` so it always lies inside
    ``observation_space``.
    """

    def __init__(self):
        super().__init__()
        self.grid_size = 100
        self.num_fish = 500
        self.max_steps = 2000
        self.max_bubbles = 500
        self.num_whales = 5

        # One discrete choice (6 options) per whale.
        self.action_space = spaces.MultiDiscrete([6] * self.num_whales)
        obs_len = 3 * (self.num_whales + self.num_fish + self.max_bubbles)
        self.observation_space = spaces.Box(
            low=0, high=self.grid_size, shape=(obs_len,), dtype=np.float32
        )

        self.whale_pos = None
        self.fish_pos = None
        self.bubbles = None
        self.steps = 0

        # Rendered frames are written here and later stitched into a GIF.
        self.frame_dir = "frames"
        os.makedirs(self.frame_dir, exist_ok=True)
        self.frame_count = 0
        self.caught_fish = np.zeros(self.num_fish, dtype=bool)

    def reset(self, seed=None, options=None):
        """Start a new episode and return ``(observation, info)``.

        All randomness is drawn from ``self.np_random`` so that passing
        ``seed`` makes the episode reproducible.
        """
        super().reset(seed=seed)
        centre = self.grid_size // 2
        self.whale_pos = np.array(
            [[centre, 2, centre]] * self.num_whales, dtype=np.float32
        )

        # Fish are scattered on a disc around the grid centre (X/Z plane) at
        # random heights in the lower half of the grid.
        angle = np.linspace(0, 2 * np.pi, self.num_fish)
        radius = self.np_random.uniform(5, self.grid_size / 2, self.num_fish)
        self.fish_pos = np.zeros((self.num_fish, 3), dtype=np.float32)
        self.fish_pos[:, 0] = centre + radius * np.cos(angle)
        self.fish_pos[:, 1] = self.np_random.uniform(2, centre, self.num_fish)
        self.fish_pos[:, 2] = centre + radius * np.sin(angle)

        self.caught_fish = np.zeros(self.num_fish, dtype=bool)
        self.bubbles = []
        self.steps = 0
        self.frame_count = 0
        return self._get_obs(), {}

    def step(self, action):
        """Advance the simulation by one tick.

        Returns ``(obs, reward, terminated, truncated, info)``. ``reward`` is
        1000 per fish caught during this step. ``terminated`` becomes true
        when every fish is caught or ``max_steps`` is reached; ``truncated``
        is always ``False`` (the time limit is reported through
        ``terminated`` so that loops checking only that flag still stop).
        ``info["caught_fish"]`` holds the running total of caught fish.
        """
        self.steps += 1
        active = ~self.caught_fish
        if not active.any():
            # Nothing left to simulate; avoid taking the mean of an empty set.
            return self._get_obs(), 0.0, True, False, {"caught_fish": int(self.num_fish)}

        # School statistics are computed over fish that are still free only;
        # caught fish keep stale positions and must not influence them.
        fish_center = np.mean(self.fish_pos[active], axis=0)
        fish_spread = np.max(
            np.linalg.norm(self.fish_pos[active] - fish_center, axis=1)
        )
        r = max(5, fish_spread * 0.7)

        self._move_whales(fish_center, fish_spread, r)
        self._move_fish(active, fish_center)
        newly_caught = self._catch_fish(active, fish_center)

        self.caught_fish[newly_caught] = True
        reward = 1000.0 * len(newly_caught)

        done = bool(self.steps >= self.max_steps or np.all(self.caught_fish))
        info = {"caught_fish": int(np.count_nonzero(self.caught_fish))}
        return self._get_obs(), reward, done, False, info

    def _move_whales(self, fish_center, fish_spread, r):
        """Move the leader (dropping a bubble) and the orbiting whales."""
        if fish_spread > 10:
            # While the school is spread out the leader circles it and leaves
            # a bubble at each visited position, until the bubble budget is
            # used up (after which the leader stays where it is).
            if len(self.bubbles) < self.max_bubbles:
                theta = self.steps * 0.1
                offset = np.array(
                    [r * np.cos(theta), r * np.sin(theta), np.sin(theta) * 5]
                )
                self.whale_pos[0] = fish_center + offset
                self.bubbles.append(self.whale_pos[0].copy())
        else:
            # School is compact: the leader moves to its centre.
            self.whale_pos[0] = fish_center.copy()

        # Remaining whales orbit the school centre, evenly phase-shifted,
        # with a small random jitter.
        idx = np.arange(1, self.num_whales)
        angles = self.steps * 0.05 + (idx / self.num_whales) * 2 * np.pi
        orbit = np.stack(
            [(r + 2) * np.cos(angles), np.sin(angles), (r + 2) * np.sin(angles)],
            axis=1,
        )
        jitter = self.np_random.uniform(-0.1, 0.1, size=orbit.shape)
        self.whale_pos[1:] = fish_center + jitter + orbit

    def _move_fish(self, active, fish_center):
        """Apply schooling, noise and bubble forces to the free fish."""
        offsets = self.fish_pos[active] - fish_center
        distance = np.linalg.norm(offsets, axis=1)

        move = -0.005 * offsets  # weak pull towards the school centre
        # Fish very close to the centre are pushed back out slightly.
        move += np.where(distance[:, None] < 3, 0.01 * offsets, 0.0)
        move += self.np_random.uniform(-0.5, 0.5, size=offsets.shape)
        positions = self.fish_pos[active] + move

        if self.bubbles:
            # Bubbles pull fish towards their mean position and drive them
            # upwards along Y.
            bubble_center = np.mean(self.bubbles, axis=0)
            positions += (bubble_center - positions) * 0.005
            positions[:, 1] += 0.4

        # Clip after all forces so fish never leave the grid.
        self.fish_pos[active] = np.clip(positions, 0, self.grid_size)

    def _catch_fish(self, active, fish_center):
        """Return indices of free fish that get caught this step."""
        active_idx = np.flatnonzero(active)
        positions = self.fish_pos[active_idx]

        # Bulk attack: more than 75% of all fish are free and packed within
        # 3 units of the school centre.
        packed = np.linalg.norm(positions - fish_center, axis=1) < 3
        if np.count_nonzero(packed) > 0.75 * self.num_fish:
            return active_idx[packed]

        # Otherwise only fish practically touching the leader are caught.
        touching = np.linalg.norm(positions - self.whale_pos[0], axis=1) < 0.1
        return active_idx[touching]

    def _get_obs(self):
        """Build the flat observation vector, clipped to the space bounds."""
        bubble_obs = np.zeros(self.max_bubbles * 3, dtype=np.float32)
        if self.bubbles:
            flat = np.asarray(
                self.bubbles[: self.max_bubbles], dtype=np.float32
            ).ravel()
            bubble_obs[: flat.size] = flat
        obs = np.concatenate(
            [self.whale_pos.ravel(), self.fish_pos.ravel(), bubble_obs]
        )
        # Whale and bubble positions are not confined to the grid, so clip to
        # keep the observation inside ``observation_space``.
        return np.clip(obs, 0, self.grid_size).astype(np.float32)

    def render(self):
        """Draw the current state and save it as the next PNG frame."""
        fig = plt.figure(figsize=(10, 10))
        try:
            ax = fig.add_subplot(111, projection='3d')

            # Each group is only plotted when non-empty; scatter() cannot be
            # called without coordinates.
            free_fish = self.fish_pos[~self.caught_fish]
            if len(free_fish):
                ax.scatter(free_fish[:, 0], free_fish[:, 1], free_fish[:, 2],
                           c='green', marker='o', label='Fish')

            if self.bubbles:
                bubbles = np.asarray(self.bubbles)
                ax.scatter(bubbles[:, 0], bubbles[:, 1], bubbles[:, 2],
                           c='blue', marker='o', label='Bubbles')

            followers = self.whale_pos[1:]
            if len(followers):
                ax.scatter(followers[:, 0], followers[:, 1], followers[:, 2],
                           c='red', marker='o', s=100, label='Whales')

            ax.scatter(*self.whale_pos[0], c='yellow', marker='o', s=100,
                       label='Leader Whale')

            ax.set_xlim(0, self.grid_size)
            ax.set_ylim(0, self.grid_size)
            ax.set_zlim(0, self.grid_size)
            ax.set_xlabel('X')
            ax.set_ylabel('Y')
            ax.set_zlabel('Z')
            ax.legend()

            os.makedirs(self.frame_dir, exist_ok=True)
            fig.savefig(
                os.path.join(self.frame_dir, f"frame_{self.frame_count:04d}.png")
            )
            self.frame_count += 1
        finally:
            # Always release the figure, even if drawing or saving fails.
            plt.close(fig)

    def save_animation(self, filename="Bubble_net3D.gif"):
        """Combine the frames rendered this episode into an animated GIF."""
        images = [
            imageio.imread(os.path.join(self.frame_dir, f"frame_{i:04d}.png"))
            for i in range(self.frame_count)
        ]
        imageio.mimsave(filename, images, fps=20)

    def close(self):
        """No resources are held between calls; nothing to release."""
        pass

In [None]:
# In[Train]
# Build the environment and validate it against the Gymnasium API before
# handing it to the learner.
env = BubbleNetEnv3D()
check_env(env)

# Short PPO run with an MLP policy.
model = PPO(policy="MlpPolicy", env=env, verbose=1)
model.learn(total_timesteps=3000)

In [None]:
# In[Test]
# Roll out one episode with the trained policy, rendering a frame per step.
obs, _ = env.reset()
done = False
print("Test started ...")
while not done:
    action, _ = model.predict(obs)
    obs, reward, terminated, truncated, info = env.step(action)
    # An episode ends on either flag; checking only the first one can loop
    # forever if the environment reports a time limit via `truncated`.
    done = terminated or truncated
    env.render()

env.close()

In [None]:
# In[Save as gif format]
# Stitch the frames rendered during the test rollout into a GIF.
env.save_animation(filename="Bubble_net3D.gif")