In [None]:
import os
import sys

# Make the repository root (two directories above the notebook's working
# directory) importable so the project packages used below (`training`,
# `environment`, `agent`, `utils`) resolve.
# The path is made absolute so a later os.chdir() cannot break imports, and the
# membership check keeps re-runs of this cell from piling up duplicate entries.
PROJECT_ROOT = os.path.abspath(os.path.join("..", ".."))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [None]:
import nest_asyncio

# Allow nested asyncio event loops. Presumably the simulator client runs asyncio
# code inside Jupyter's already-running loop (TODO confirm).
nest_asyncio.apply()

import tensorflow as tf

# Let TensorFlow allocate GPU memory on demand instead of reserving it all up
# front. TensorFlow requires the same memory-growth setting on every visible
# GPU, so it is applied to all of them, not just the first. The setting can
# only be changed before the GPUs are initialised; re-running this cell
# afterwards raises RuntimeError, which is reported rather than failing the cell.
gpus = tf.config.list_physical_devices("GPU")
for gpu in gpus:
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as err:
        print(f"Could not enable memory growth on {gpu.name}: {err}")

In [None]:
from training.dqnetwork import DQNetwork

In [None]:
from environment.deepqlearning.phototaxis_env import PhototaxisEnv
from utils.reader import get_yaml_path, read_file

## Connect to Simulator

In [None]:
# Simulator endpoint and the client name handed to PhototaxisEnv.
server_address, client_name = "localhost:50051", "PhototaxisDQNClient"

env = PhototaxisEnv(server_address, client_name)
env.connect_to_client()

## Load Configurations


In [None]:
import glob
import os

# Generated scenario configs, relative to the notebook's working directory.
# Sorted so scenario order (and therefore configs[0]) is deterministic.
config_dir = os.path.join("..", "..", "scripts", "resources", "generated", "photo")
config_files = sorted(glob.glob(os.path.join(config_dir, "environment_*.yml")))

# Fail with a clear message instead of an IndexError on configs[0] below.
if not config_files:
    raise FileNotFoundError(
        f"No environment_*.yml files found in {os.path.abspath(config_dir)}"
    )

configs = [read_file(config_file) for config_file in config_files]

print(f"Loaded {len(configs)} configuration files")

# Initialise the simulator with the first scenario, then reset so the agents
# are properly initialized before the networks are built.
env.init(configs[0])
_ = env.reset()

## Network Architecture

In [None]:
# Width of each hidden layer passed to DQNetwork, one entry per layer.
neuron_count_per_hidden_layer = [
    128,
    64,
]

## Hyperparameters

In [None]:
# --- Episode limits ---
episode_count = 1_000  # total number of training episodes
episode_max_steps = 2_000  # step cap for a single episode

# --- Experience replay ---
replay_memory_max_size = 100_000  # capacity of the replay memory, in transitions
replay_memory_init_size = 10_000  # replay memory size required before training starts
batch_size = 512  # transitions per training mini-batch

# --- Update schedule ---
step_per_update = 4  # steps between action-model updates
step_per_update_target_model = 1_000  # steps between target-model updates

# --- Exploration ---
max_epsilon = 1.0  # starting exploration probability
min_epsilon = 0.01  # lower bound on the exploration probability
# NOTE(review): epsilon_decay is not passed to DQAgent in the "Create Agent"
# cell and is not referenced elsewhere in this notebook — confirm it is used.
epsilon_decay = 0.0002  # decay rate for the exploration probability

# --- Return ---
gamma = 0.99  # discount factor for future rewards

# --- Early stopping ---
moving_avg_window_size = 20  # window size of the reward moving average
moving_avg_stop_thr = 100  # moving-average threshold that triggers early stopping

## Create Agent

In [None]:
from agent.scala_dqagent import DQAgent


def _build_q_network(summary):
    """Construct a DQNetwork sized to ``env``'s observation and action spaces."""
    return DQNetwork(
        env.observation_space.shape,
        neuron_count_per_hidden_layer,
        env.action_space.n,
        summary=summary,
    )


# The action and target networks share one architecture; only the action
# network is built with summary=True.
# NOTE(review): `epsilon_decay` from the hyperparameter cell is not passed here —
# confirm whether DQAgent expects it.
agent1 = DQAgent(
    env,
    agent_id="00000000-0000-0000-0000-000000000001",
    action_model=_build_q_network(summary=True),
    target_model=_build_q_network(summary=False),
    epsilon_max=max_epsilon,
    epsilon_min=min_epsilon,
    gamma=gamma,
    replay_memory_max_size=replay_memory_max_size,
    replay_memory_init_size=replay_memory_init_size,
    batch_size=batch_size,
    step_per_update=step_per_update,
    step_per_update_target_model=step_per_update_target_model,
    moving_avg_window_size=moving_avg_window_size,
    moving_avg_stop_thr=moving_avg_stop_thr,
    episode_max_steps=episode_max_steps,
    episodes=episode_count,
)

agents = [agent1]

## Training

In [None]:
import time
from training.multi_agent_dqlearning import DQLearning

# Set to True to train from scratch. Off by default: the trainer is still built,
# because the visualisation cells below drive it, and pre-trained weights are
# loaded in the "Load Pre-trained Models" section instead.
run_training = False

train_start_time = time.time()

trainer = DQLearning(
    env,
    agents,
    configs,
    episode_count=episode_count,
    episode_max_steps=episode_max_steps,
)

if run_training:
    train_rewards = trainer.simple_dqn_training()

    train_elapsed_time = time.time() - train_start_time
    train_avg_episode_time = train_elapsed_time / episode_count

    print(f"Train time: {train_elapsed_time / 60.0:.1f}m [{train_avg_episode_time:.1f}s]")

## Evaluation with Visualization

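Watch the agent perform phototaxis in real time. Unless training was run in the cell above, the networks here are still untrained; pre-trained weights are loaded later, in **Load Pre-trained Models**.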

**Keyboard Controls:**
- `ESC/Q`: Quit
- `SPACE`: Pause/Resume
- `↑/↓`: Adjust FPS

In [None]:
# Replay one episode at a 90 FPS cap. Weights are only loaded from checkpoints in
# a later cell, so unless training ran above, these are the freshly built networks.
trainer.play_with_pygame(fps=90, episodes=1)

## Save Trained Models

In [None]:
import os

# Set to True to write the current weights to ./checkpoints, one pair of .keras
# files per agent. Off by default so re-running the notebook never overwrites
# existing checkpoints by accident.
save_models = False

if save_models:
    os.makedirs("checkpoints", exist_ok=True)

    for i, agent in enumerate(agents):
        agent.action_model.save(f"checkpoints/phototaxis_dqn_agent{i}_action_model.keras")
        agent.target_model.save(f"checkpoints/phototaxis_dqn_agent{i}_target_model.keras")
        print(f"Agent {i} models saved successfully")

## Load Pre-trained Models (Optional)

In [None]:
import os

# Checkpoint pair shared by all agents; the paths are loop-invariant, so build
# them once. Each agent still gets its own independently loaded copy.
checkpoint_dir = os.path.join("..", "..", "scripts", "checkpoints", "phototaxis_final")
action_model_path = os.path.join(checkpoint_dir, "action_model.keras")
target_model_path = os.path.join(checkpoint_dir, "target_model.keras")

for i, agent in enumerate(agents):
    agent.action_model = tf.keras.models.load_model(action_model_path)
    agent.target_model = tf.keras.models.load_model(target_model_path)
    print(f"Agent {i} models loaded successfully")

In [None]:
import tensorflow as tf

# The first call to a model is slower than later ones. Run one throwaway
# forward pass through each network so that cost is paid here, not during the
# first rendered frames.
print("Warming up models for smooth visualization...")
warmup_input_shape = (1,) + env.observation_space.shape
for agent in agents:
    dummy_state = tf.random.normal(warmup_input_shape)
    for model in (agent.action_model, agent.target_model):
        _ = model(dummy_state, training=False)
print("✓ Models warmed up and ready!")

In [None]:
import tensorflow as tf


@tf.function
def _predict_q_values(model, state):
    """Return ``model(state)`` in inference mode, compiled into a TF graph.

    Defined once at cell level, not inside ``play_optimized``, so the traced
    graph is reused across calls instead of being re-traced for every scenario.
    """
    return model(state, training=False)


def _draw_hud(screen, font, ep, step, total_reward, current_fps, paused):
    """Blit the two-line status overlay onto the top-left corner of ``screen``."""
    info_texts = [
        f"Ep: {ep + 1} | Step: {step} | Reward: {total_reward:.1f} | FPS: {current_fps}",
        "PAUSED (SPACE)" if paused else "ESC: Quit | SPACE: Pause | ↑↓: FPS",
    ]
    color = (255, 255, 0) if paused else (255, 255, 255)
    for row, text in enumerate(info_texts):
        text_surf = font.render(text, True, color, (0, 0, 0, 180))
        screen.blit(text_surf, (10, 10 + row * 25))


def play_optimized(trainer, episodes=1, fps=60, render_scale=(800, 600)):
    """Play the agents' greedy policy in a pygame window.

    On each step, every agent in ``trainer.agents`` takes the ``argmax`` of its
    ``action_model`` Q-values (no exploration). The environment is then
    stepped, and ``trainer.env.render()`` is scaled to ``render_scale`` and shown.

    Keyboard: ESC/Q or closing the window quits (remaining episodes are
    skipped). SPACE toggles pause; the step counter does not advance while
    paused. UP/DOWN change the frame-rate cap by 10, clamped to 10..240.

    Args:
        trainer: object exposing ``env`` (``reset``/``step``/``render``) and
            ``agents`` (each with ``id``, ``action_model`` and ``terminated``).
        episodes: number of episodes to play.
        fps: initial frame-rate cap.
        render_scale: ``(width, height)`` of the window in pixels.

    Returns:
        ``False`` if the user quit the viewer, otherwise ``True``.
    """
    import numpy as np
    import pygame

    pygame.init()
    try:
        screen = pygame.display.set_mode(render_scale)
        pygame.display.set_caption("DQN Agent - Optimized View")
        clock = pygame.time.Clock()
        font = pygame.font.Font(None, 20)

        running = True
        paused = False
        current_fps = fps
        last_frame = None  # most recent scaled env frame, redrawn while paused

        for ep in range(episodes):
            states, _ = trainer.env.reset()
            done = False
            total_reward = 0
            step = 0

            for agent in trainer.agents:
                agent.terminated = False

            while not done and running:
                for event in pygame.event.get():
                    if event.type == pygame.QUIT:
                        running = False
                    elif event.type == pygame.KEYDOWN:
                        if event.key in (pygame.K_ESCAPE, pygame.K_q):
                            running = False
                        elif event.key == pygame.K_SPACE:
                            paused = not paused
                        elif event.key == pygame.K_UP:
                            current_fps = min(240, current_fps + 10)
                        elif event.key == pygame.K_DOWN:
                            current_fps = max(10, current_fps - 10)

                # Quit immediately rather than taking one more env step.
                if not running:
                    break

                if paused:
                    # Keep the last frame visible with the PAUSED banner on top.
                    if last_frame is None:
                        screen.fill((0, 0, 0))
                    else:
                        screen.blit(last_frame, (0, 0))
                    _draw_hud(screen, font, ep, step, total_reward, current_fps, paused)
                    pygame.display.flip()
                    pygame.time.wait(100)
                    continue

                # Count only real environment steps, never paused iterations.
                step += 1

                # Greedy action per agent from its action network.
                actions = {}
                for agent in trainer.agents:
                    state_batch = tf.convert_to_tensor(
                        states[agent.id][np.newaxis], dtype=tf.float32
                    )
                    q_values = _predict_q_values(agent.action_model, state_batch)
                    actions[agent.id] = int(tf.argmax(q_values[0]).numpy())

                next_states, rewards, terminateds, truncateds, _ = trainer.env.step(actions)
                done = all(
                    terminateds[agent.id] or truncateds[agent.id]
                    for agent in trainer.agents
                )
                total_reward += rewards[trainer.agents[0].id]
                states = next_states

                # Assumes render() returns (height, width, 3) — TODO confirm;
                # surfarray expects (width, height, 3), hence the transpose.
                rgb_array = trainer.env.render()
                frame = pygame.surfarray.make_surface(np.transpose(rgb_array, (1, 0, 2)))
                # smoothscale filters the resize for quality; it is not faster
                # than pygame.transform.scale.
                last_frame = pygame.transform.smoothscale(frame, render_scale)
                screen.blit(last_frame, (0, 0))
                _draw_hud(screen, font, ep, step, total_reward, current_fps, paused)

                pygame.display.flip()
                clock.tick(current_fps)

            print(f"Episode {ep + 1}/{episodes} - Reward: {total_reward:.2f}")

            # Don't start further episodes after the user quit.
            if not running:
                break

        return running
    finally:
        # Always tear the window down, even if the env or a model raises.
        pygame.quit()


print("✓ Optimized visualization function loaded")

In [None]:
import time

# Replay at most this many scenarios; set to None to replay every loaded config.
max_scenarios = 30
scenarios_to_test = configs if max_scenarios is None else configs[:max_scenarios]

scenarios_played = 0
for i, config in enumerate(scenarios_to_test):
    print()
    print("=" * 60)
    print(f"Testing Scenario {i + 1}/{len(scenarios_to_test)}")
    print("=" * 60)

    # Load this scenario into the simulator before playing it.
    env.init(config)

    # A False return means the viewer was closed by the user; any other value
    # (including None) means keep going.
    viewer_open = play_optimized(trainer, episodes=1, fps=60)
    scenarios_played += 1
    if viewer_open is False:
        print("Viewer closed - skipping remaining scenarios.")
        break

    time.sleep(0.3)  # Brief pause between scenarios

print()
print(f"✓ Finished testing {scenarios_played}/{len(configs)} scenarios!")