# Setup

In [None]:
import time
from pathlib import Path
from typing import Optional

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
from stable_baselines3 import PPO, DQN
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback, CallbackList
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.monitor import Monitor
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

import babybot01_env

'''
Use DQN to reach the target.

 - Run the env at 60 fps (below 60 fps the simulation misbehaves).
 - Set sim_factor to 5 to make the env frequency 5x slower than the sim frequency
   (note: the config below currently sets sim_factor = 60).
 - Expose 2 phase biases (indices 0 and 65).
 - Perform discrete actions (6 according to earlier notes, but only 5 are listed
   here — confirm against the env's action space):
    - do nothing
    - increase phase bias 0
    - decrease phase bias 0
    - increase phase bias 65
    - decrease phase bias 65
 - Observe 3 continuous values:
    - phase bias 0 [-1, 1]
    - phase bias 65 [-1, 1]
    - alignment [-1, 1]

Task:
    Learn how to increment/decrement phase biases 0 and 65 given the current
    phase biases 0 and 65 and the current alignment.

'''

# ---- Experiment identity ---------------------------------------------------
name = 'dqn_spidy_v4_4'
env_id = "Spidy-v4_4"

# ---- Environment configuration (consumed by make_env) ----------------------
sim_frequence = 60            # simulation rate passed to the env
sim_factor = 60               # passed to the env unchanged
n_steps = 60                  # becomes max_episode_steps in make_env
n_envs = 1
exposed_phases_indexes = [0, 65]
action_mode = "discrete"

# ---- Agent settings ----------------------------------------------------------
policy = 'MlpPolicy'
device = 'cpu'

# ---- Output locations (all rooted at ./<name>/) ------------------------------
tensorboard_log = f"./{name}/t_logs/"
log_path = f"./{name}/logs/"
save_path = f"./{name}/model/"
path = save_path + name       # i.e. ./<name>/model/<name>

def make_env(render_mode: Optional[str] = None, debug_mode: bool = False):
    """Create one instance of the configured environment via ``gym.make``.

    Apart from the two arguments, every setting comes from the module-level
    config: ``env_id``, ``sim_frequence``, ``sim_factor``,
    ``exposed_phases_indexes``, ``action_mode`` and ``n_steps``. ``n_steps`` is
    used as ``max_episode_steps``, so episodes are truncated after that many
    steps.

    :param render_mode: forwarded unchanged to the env (this notebook passes
        ``"human"`` for visual runs and leaves it ``None`` for training).
    :param debug_mode: forwarded unchanged to the env.
    :return: the environment returned by ``gym.make``.
    """
    return gym.make(
        env_id,
        sim_frequence=sim_frequence,
        sim_factor=sim_factor,
        max_episode_steps=n_steps,
        exposed_phases_indexes=exposed_phases_indexes,
        render_mode=render_mode,
        debug_mode=debug_mode,
        action_mode=action_mode,
    )
    

# Test action

In [None]:
env = make_env(render_mode="human")
try:
    obs, info = env.reset()

    print(f"First Obs: {obs}")
    #print(f"First Phases biases: {info['phase_biases']}")

    # The env is built with action_mode="discrete", so a continuous vector such
    # as [0.1, -0.1] is not a valid action. Sample a valid one from the env's
    # own action space instead; this works for either action mode.
    action = env.action_space.sample()
    obs, rew, terminated, truncated, info = env.step(action)

    print(f"Action: {action}")
    print(f"After action Obs: {obs}")
    #print(f"Second Phases biases: {info['phase_biases']}")

    obs, info = env.reset()
    print(f"After reset Obs: {obs}")
finally:
    # Always release the (rendering) simulator, even if a step raises.
    env.close()


# Test Env

In [None]:
# Execute Setup
# Re-run every *code* cell tagged "setup" in this notebook so this cell also
# works on a fresh kernel. NOTE: exec runs code read from the notebook file;
# only use it with a notebook you trust.
import nbformat
from IPython import get_ipython
with open("06_reach_DQN.ipynb", "r", encoding="utf-8") as f:
    notebook = nbformat.read(f, as_version=4)
for cell in notebook.cells:
    # Markdown/raw cells are not Python; exec'ing a tagged one would raise
    # SyntaxError, so only code cells are executed.
    if cell.cell_type == "code" and "setup" in cell.metadata.get("tags", []):
        exec(cell.source)

# Create Env
env = make_env(debug_mode=False, render_mode='human')

# Run env: take up to 10 random actions and print what the env reports.
try:
    obs, info = env.reset()
    for t in range(10):

        # Sample from the env's own action space instead of a hard-coded
        # randint(0, 5) (values 0-4): that cannot produce every action if the
        # space is larger (the experiment notes mention 6 actions).
        action = env.action_space.sample()
        obs, rew, terminated, truncated, info = env.step(action)
        print(f"obs: {obs}")
        print(f"rew: {rew}")
        print(f"t: {t}")
        print(f"proximity_reward: {info['proximity_reward']}, alignement_reward: {info['alignement_reward']}")
        print(f"terminate: {terminated}, truncated: {truncated}")
        if terminated or truncated:
            break
finally:
    # Close the simulator even if a step or an info lookup fails.
    env.close()

# Create agent

In [None]:
train_env = make_env()

# PPO
# model = PPO(
#     policy, 
#     train_env, 
#     batch_size = 30, 
#     verbose=0, 
#     n_steps=n_steps, 
#     tensorboard_log=tensorboard_log,
#     ent_coef=0           
# )

# Build a fresh (untrained) DQN agent on the configured device, so creation and
# the later DQN.load(..., device=device) agree.
model = DQN(
    policy,
    train_env,
    verbose=0,
    tensorboard_log=tensorboard_log,
    device=device,
)

# Persist the untrained agent; the Train cell reloads it from `path`.
model.save(path)

# The env was only needed to construct the agent; release the simulator.
train_env.close()

# Train

In [None]:
# Execute Setup
# Re-run every *code* cell tagged "setup" in this notebook so this cell also
# works on a fresh kernel. NOTE: exec runs code read from the notebook file;
# only use it with a notebook you trust.
import nbformat
from IPython import get_ipython
with open("06_reach_DQN.ipynb", "r", encoding="utf-8") as f:
    notebook = nbformat.read(f, as_version=4)
for cell in notebook.cells:
    # Markdown/raw cells are not Python; exec'ing a tagged one would raise
    # SyntaxError, so only code cells are executed.
    if cell.cell_type == "code" and "setup" in cell.metadata.get("tags", []):
        exec(cell.source)

class ActionLoggerCallback(BaseCallback):
    """Log the mean and standard deviation of the actions taken at each step.

    Writes through its own TensorBoard ``SummaryWriter`` placed in the SB3
    logger's directory, under the ``policy/actions_mean`` and
    ``policy/actions_std`` tags, indexed by ``num_timesteps``.
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)
        self.writer = None

    def _on_training_start(self) -> None:
        # Write next to SB3's own event files so both appear in the same run.
        if self.writer is not None:
            self.writer.close()
        self.writer = SummaryWriter(log_dir=self.logger.dir)

    def _on_step(self) -> bool:
        # SB3's rollout-collection loop exposes the actions chosen at this step
        # in self.locals; .get keeps training alive if the key is ever absent.
        actions = self.locals.get('actions')
        if actions is not None:
            self.writer.add_scalar("policy/actions_mean", np.mean(actions), self.num_timesteps)
            self.writer.add_scalar("policy/actions_std", np.std(actions), self.num_timesteps)
        return True

    def _on_training_end(self) -> None:
        # Flush pending events and release the event-file handle.
        if self.writer is not None:
            self.writer.close()
            self.writer = None

class InfoCallback(BaseCallback):
    """Record per-step env info and the chosen action to the SB3 logger.

    Every ``log_freq`` timesteps it reads the first environment's ``info`` dict
    from ``self.locals['infos']`` and records:
    the exposed phase biases, ``proximity_reward``, ``alignement_reward``, and
    the first env's action. It then dumps the logger at the current timestep.

    :param verbose: SB3 verbosity level.
    :param log_freq: log every ``log_freq`` timesteps (default 1 = every step).
    :param phase_indexes: labels used for the entries of
        ``info["exposed_phases_biases"]``, in order; the default (0, 65) matches
        the phase indexes this notebook exposes.
    """

    def __init__(self, verbose=0, log_freq=1, phase_indexes=(0, 65)):
        super().__init__(verbose)
        self.log_freq = log_freq
        self.phase_indexes = tuple(phase_indexes)

    def _on_step(self) -> bool:
        if self.num_timesteps % self.log_freq != 0:
            return True

        # Info dicts of the current step, one per env; only the first is logged.
        info = self.locals['infos'][0]

        for label, bias in zip(self.phase_indexes, info["exposed_phases_biases"]):
            self.logger.record(f"exposed_phases_biases/phase_{label}", bias)
        self.logger.record("rewards/proximity_reward", info["proximity_reward"])
        self.logger.record("rewards/alignement_reward", info["alignement_reward"])

        self.logger.record("actions", self.locals["actions"][0])

        # dump() writes (and clears) everything recorded so far at this step,
        # including values other code recorded since the previous dump.
        self.logger.dump(step=self.num_timesteps)
        return True
            
# Monitor records per-episode return/length (EvalCallback warns when the eval
# env is not wrapped in it).
train_env = Monitor(make_env())
eval_env = Monitor(make_env())

# Periodically evaluate the greedy policy on eval_env and keep the best model
# under save_path; evaluation results go to log_path.
eval_callback = EvalCallback(
    eval_env,
    best_model_save_path=save_path,
    log_path=log_path,
    eval_freq=1_000,  # an int number of callback calls, as EvalCallback expects
    deterministic=True,
    render=False,
)

class SaveOnStep(BaseCallback):
    """Save the model to a fixed path every ``steps`` callback calls.

    The file at ``path`` is overwritten each time (no step number is added to
    the filename), so it always holds the most recent checkpoint.

    :param steps: save interval, in calls to ``_on_step``. Must be positive;
        floats such as ``1e3`` are accepted and truncated to int.
    :param path: destination passed to ``model.save``.
    :raises ValueError: if ``steps`` is smaller than 1 after conversion.
    """

    def __init__(self, steps: int, path: str):
        super().__init__()
        self.steps = int(steps)
        if self.steps <= 0:
            # A zero interval would make the modulo in _on_step divide by zero.
            raise ValueError(f"steps must be a positive integer, got {steps!r}")
        self.save_path = path

    def _on_step(self) -> bool:
        if self.n_calls % self.steps == 0:
            print(f"Saving model at step {self.n_calls} to {self.save_path}")
            self.model.save(self.save_path)
        return True
    
callbacks = [
    SaveOnStep(1_000, path),  # rolling checkpoint, overwritten at `path`
    InfoCallback(),
    ActionLoggerCallback(),
    eval_callback,
]

# Resume from the agent stored at `path` (the Create agent cell's model or the
# latest checkpoint) and attach the Monitor-wrapped training env.
model = DQN.load(path, train_env, device=device)

try:
    # reset_num_timesteps=False continues the step counter (and TensorBoard
    # curves) from the loaded model; total_timesteps is added on top of it.
    model.learn(
        total_timesteps=100_000,
        progress_bar=True,
        callback=callbacks,
        reset_num_timesteps=False,
    )
    model.save(path + "_final")
finally:
    # Release both simulators even if training raises or is interrupted.
    train_env.close()
    eval_env.close()


# Display

In [None]:
# Execute Setup
# Re-run every *code* cell tagged "setup" in this notebook so this cell also
# works on a fresh kernel. NOTE: exec runs code read from the notebook file;
# only use it with a notebook you trust.
import nbformat
from IPython import get_ipython
with open("06_reach_DQN.ipynb", "r", encoding="utf-8") as f:
    notebook = nbformat.read(f, as_version=4)
for cell in notebook.cells:
    # Markdown/raw cells are not Python; exec'ing a tagged one would raise
    # SyntaxError, so only code cells are executed.
    if cell.cell_type == "code" and "setup" in cell.metadata.get("tags", []):
        exec(cell.source)

test_env = make_env(render_mode='human')
# NOTE(review): `path` holds the latest SaveOnStep checkpoint (or the untrained
# Create agent model). The end-of-training model is saved at path + "_final",
# and EvalCallback's best model is written under save_path.
model = DQN.load(path)
info = {}  # last step's info; the plotting cell below reads it

try:
    for episode in range(1):

        obs, info = test_env.reset()
        for t in range(n_steps):

            # deterministic=True -> greedy action. The default (False) keeps
            # DQN's epsilon-greedy exploration and would show random moves.
            action = model.predict(obs, deterministic=True)[0]
            obs, reward, terminate, trunc, info = test_env.step(action)

            # Print before the termination check so the final step is shown too.
            print(f"t: {t}, Obs: {np.array_str(obs, precision=2)}, Action: {np.array_str(action, precision=2)}, Rew: {reward:.2f} ")

            if terminate or trunc:
                break
finally:
    test_env.close()




In [None]:
# Show CPG parameters taken from the last step's info dict.
coupling_weights = info['coupling_weights']
phase_biases = info['phase_biases']


def _plot_matrix(fig, ax, matrix, title):
    """Draw `matrix` as a heatmap on `ax` with its own colorbar and one tick per row/column."""
    image = ax.imshow(matrix, cmap='viridis', aspect='equal')
    # Each panel's colorbar must use that panel's own image so it shows
    # the right value range.
    fig.colorbar(image, ax=ax, orientation='vertical')
    ax.set_title(title)
    ax.set_xlabel("Column Index")
    ax.set_ylabel("Row Index")
    # Tick every index of the actual matrix instead of a hard-coded 12.
    n_rows, n_cols = np.shape(matrix)[:2]
    ax.set_xticks(range(n_cols))
    ax.set_yticks(range(n_rows))
    return image


fig, axs = plt.subplots(2, 1, figsize=(12, 6))  # Adjust size if needed

im1 = _plot_matrix(fig, axs[0], coupling_weights, "coupling_weights")
im2 = _plot_matrix(fig, axs[1], phase_biases, "phase_biases")

plt.tight_layout()
plt.show()