# Q-Learning for Phototaxis Task - Original Complex Reward

This notebook uses the **Phototaxis** (original fixed) reward variant.

**Reward Components:**
- Progress reward (+10× distance improvement)
- Proximity reward (exponential, rewards being close)
- Goal bonus (+100 when reaching light)
- Obstacle penalty (-10× proximity violation)
- Movement penalties (spinning -1.0, oscillation -0.1×)
- Forward bonus (+0.2)
- Survival reward (logarithmic, small)

**Best for:** Fine-grained control, advanced experiments

**Note:** This is the most complex reward with many tunable components.

In [None]:
import sys
# Add the parent directory to the module search path.
# NOTE(review): presumably ".." is the project root that holds the
# environment/, utils/, agent/ and training/ packages imported below - confirm.
sys.path.append("..")

In [None]:
import nest_asyncio
# Patch asyncio so an event loop can be re-entered while another one is already running.
# NOTE(review): presumably required because the simulator client drives asyncio
# from inside the notebook's own running loop - confirm.
nest_asyncio.apply()

In [None]:
from tqdm import trange

In [None]:
from environment.qlearning.phototaxis_env import PhototaxisEnv
from utils.reader import get_yaml_path, read_file

In [None]:
from agent.qagent import QAgent
from training.qlearning import QLearning

In [None]:
import pygame
import numpy as np

## Connect to Simulator

In [None]:
# Address of the simulator server and the name this notebook registers under.
server_address, client_name = "localhost:50051", "PhototaxisRLClient"

# Build the environment wrapper and open the connection to the simulator.
env = PhototaxisEnv(server_address, client_name)
env.connect_to_client()

## Load Configuration - Original Complex Reward

In [None]:
# Locate resources/configurations/phototaxis.yml and parse it.
config_path = get_yaml_path(
    "resources",
    "configurations",
    "phototaxis.yml",
)
config = read_file(config_path)

# Uncomment to inspect the loaded settings:
# print(config)

In [None]:
# Hand the parsed configuration to the environment before any episode is run.
env.init(config)

## Training Parameters

In [None]:
# episodes: number of training episodes (also handed to QAgent below).
# steps: upper bound on environment steps within a single episode.
episodes, steps = 10, 5000

## Create Agent

In [None]:
# Identifier under which the single learning agent is keyed in the
# observation / action / reward dictionaries exchanged with the environment.
agentId = "00000000-0000-0000-0000-000000000001"

# One tabular Q-learning agent, registered under that identifier.
agent = QAgent(env, episodes=episodes)
agents = {agentId: agent}

## Training Loop



In [None]:
def _save_agents(save_path, message):
    """Save every agent in the module-level ``agents`` dict to ``save_path``.

    The parent directory of ``save_path`` is created first so that a save
    never fails merely because the checkpoint folder does not exist yet.
    ``message`` is printed once per saved agent.

    NOTE(review): all agents are written to the same path, so with more than
    one registered agent later saves overwrite earlier ones.
    """
    import os

    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
    for agent_obj in agents.values():
        agent_obj.save(save_path)
        print(message)


def _draw_overlay(screen, info_font, lines, paused):
    """Blit ``lines`` top-to-bottom in the upper-left corner of ``screen``."""
    # Yellow while paused, white otherwise; identical for every line.
    color = (255, 255, 0) if paused else (255, 255, 255)
    y_offset = 10
    for text in lines:
        text_surface = info_font.render(text, True, color, (0, 0, 0))
        screen.blit(text_surface, (10, y_offset))
        y_offset += 25


def run_episodes(
    episode_count,
    episode_max_steps,
    render=False,
    fps=60,
    checkpoint_interval=None,
    checkpoint_path="checkpoints/phototaxis_agent",
    load_checkpoint=None,
    start_episode=0
):
    """Run Q-learning episodes, either training (default) or rendered evaluation.

    Uses the module-level ``env``, ``agents`` and ``agentId``.

    With ``render=False`` the agents act with ``epsilon_greedy=True``, their
    Q-tables are updated after every step and epsilon is decayed after every
    episode. With ``render=True`` the agents act with ``epsilon_greedy=False``,
    nothing is learned, and each step is drawn in a pygame window that accepts
    keyboard controls: UP/DOWN change the FPS (10-240), SPACE toggles pause,
    S saves a manual checkpoint, ESC/Q (or closing the window) stops the run.

    Args:
        episode_count: Number of episodes to run in this call.
        episode_max_steps: Maximum number of environment steps per episode.
        render: Show the pygame visualisation and disable learning.
        fps: Initial frame-rate cap used while rendering.
        checkpoint_interval: Save every this many episodes (counted within
            this call). Falsy disables periodic and final saves.
        checkpoint_path: Path prefix for saved agents; suffixes ``_ep<N>``,
            ``_manual_ep<N>`` and ``_final`` are appended.
        load_checkpoint: If given, every agent loads this checkpoint first.
        start_episode: Offset added to the episode index, used for the
            epsilon schedule and checkpoint names when resuming.

    Returns:
        None.
    """
    # Restore previously trained agents if requested.
    if load_checkpoint:
        for agent_obj in agents.values():
            agent_obj.load(load_checkpoint)
            print(f"Loaded agent from {load_checkpoint}")

    running = True
    paused = False
    current_fps = fps
    info_font = None

    # Initialize pygame for rendering
    if render:
        pygame.init()
        screen = pygame.display.set_mode((800, 600))
        pygame.display.set_caption(f"Q-Learning Phototaxis - FPS: {current_fps}")
        clock = pygame.time.Clock()

        # The text overlay is best-effort: keep rendering without it if the
        # font cannot be created. Catch Exception rather than everything so
        # that KeyboardInterrupt / SystemExit are not swallowed.
        try:
            info_font = pygame.font.Font(None, 20)
        except Exception:
            info_font = None

    try:
        for ep_idx in trange(episode_count, desc="Training", unit="ep"):
            actual_episode = start_episode + ep_idx
            obs, _ = env.reset()
            done = False
            # Seed one accumulator per registered agent so the per-agent
            # update below cannot hit a missing key.
            total_reward = {k: 0 for k in agents}
            step_count = 0

            while not done and step_count < episode_max_steps:
                # Handle pygame events
                if render:
                    for event in pygame.event.get():
                        if event.type == pygame.QUIT:
                            running = False
                        elif event.type == pygame.KEYDOWN:
                            if event.key in (pygame.K_ESCAPE, pygame.K_q):
                                running = False
                            elif event.key == pygame.K_SPACE:
                                paused = not paused
                            elif event.key == pygame.K_UP:
                                current_fps = min(240, current_fps + 10)
                                pygame.display.set_caption(f"Q-Learning Phototaxis - FPS: {current_fps}")
                            elif event.key == pygame.K_DOWN:
                                current_fps = max(10, current_fps - 10)
                                pygame.display.set_caption(f"Q-Learning Phototaxis - FPS: {current_fps}")
                            elif event.key == pygame.K_s:
                                # Manual save
                                _save_agents(
                                    f"{checkpoint_path}_manual_ep{actual_episode}",
                                    f"\n[Manual Save] Episode {actual_episode}",
                                )

                if not running:
                    break

                # While paused, keep polling events but do not advance the
                # simulation. NOTE: the frame is not redrawn here, so the
                # PAUSED label in the overlay is never actually shown.
                if paused and render:
                    pygame.time.wait(100)
                    continue

                # Choose and execute actions (exploration only while training).
                actions = {
                    k: agents[k].choose_action(v, epsilon_greedy=not render)
                    for k, v in obs.items()
                }
                next_obs, rewards, terminateds, truncateds, _ = env.step(actions)

                # The episode ends when the tracked agent terminates or is truncated.
                done = terminateds[agentId] or truncateds[agentId]

                if not render:
                    # Update Q-table (only during training)
                    for k in next_obs:
                        agents[k].update_q(obs[k], actions[k], rewards[k], next_obs[k], done)
                        total_reward[k] += rewards[k]
                else:
                    # Track reward even during rendering
                    total_reward[agentId] += rewards[agentId]

                obs = next_obs

                # Render visualization
                if render:
                    rgb_array = env.render()
                    # make_surface indexes as (x, y, channel), so swap the first two axes.
                    surface = pygame.surfarray.make_surface(np.transpose(rgb_array, (1, 0, 2)))
                    screen.blit(surface, (0, 0))

                    if info_font is not None:
                        info_texts = [
                            f"Episode: {actual_episode + 1}/{start_episode + episode_count}",
                            f"Step: {step_count}/{episode_max_steps}",
                            f"Reward: {total_reward[agentId]:.2f}",
                            f"Epsilon: {agents[agentId].epsilon:.4f}",
                            f"FPS: {current_fps} (↑/↓ to adjust)",
                            f"{'PAUSED' if paused else 'SPACE: Pause'}"
                        ]
                        _draw_overlay(screen, info_font, info_texts, paused)

                    pygame.display.flip()
                    clock.tick(current_fps)

                step_count += 1

            if not running:
                print("\nTraining interrupted by user")
                break

            # Decay epsilon after each training episode. Rendered runs do not
            # learn or explore, so they must not advance the schedule either.
            if not render:
                for agent_obj in agents.values():
                    agent_obj.decay_epsilon(actual_episode)

            # Save checkpoint at intervals
            if checkpoint_interval and (ep_idx + 1) % checkpoint_interval == 0:
                _save_agents(
                    f"{checkpoint_path}_ep{actual_episode + 1}",
                    f"\n[Checkpoint] Saved at episode {actual_episode + 1}",
                )

    finally:
        # Cleanup pygame
        if render:
            pygame.quit()

        # Final save if checkpointing was enabled and the user did not quit.
        # Because this sits in `finally`, it also runs when an exception is
        # propagating, preserving whatever was learned up to that point.
        if checkpoint_interval and running:
            _save_agents(
                f"{checkpoint_path}_final",
                f"\n[Final Save] Training complete",
            )

## Train the Agent with Checkpoints

In [None]:
# Train without rendering for `episodes` episodes of at most `steps` steps each.
# NOTE(review): checkpoint_interval=100 is larger than the 10 episodes configured
# above, so no periodic checkpoint is written; only
# "checkpoints/phototaxis_classic_final" is saved when the run finishes.
# NOTE(review): start_episode=10 while load_checkpoint is commented out means a
# fresh agent's epsilon schedule is driven with episode indices starting at 10;
# use start_episode=0 for a run from scratch - confirm intent.
run_episodes(
    episode_count=episodes,
    episode_max_steps=steps,
    # load_checkpoint="checkpoints/phototaxis_classic_final",
    start_episode=10,
    checkpoint_interval=100,
    checkpoint_path="checkpoints/phototaxis_classic"
)

## Evaluate the Trained Agent

**Keyboard Controls (during render):**
- `↑/↓`: Adjust FPS (10-240)
- `SPACE`: Pause/Resume
- `S`: Manual checkpoint save
- `ESC/Q`: Quit training

In [None]:
# Watch a single rendered episode of the saved agent (no learning, exploration
# disabled). Use ↑/↓ to change the playback speed and SPACE to pause.
run_episodes(
    episode_count=1,
    episode_max_steps=10000,
    render=True,
    load_checkpoint="checkpoints/phototaxis_classic_final",
)

## Resume Training from Checkpoint (Optional)

If you want to continue training from a saved checkpoint, run this cell.

In [None]:
# Example: load the checkpoint written after episode 200, then train for 100
# further episodes, saving every 25 of them under the same path prefix.
run_episodes(
    load_checkpoint="checkpoints/phototaxis_agent_ep200",
    start_episode=200,
    episode_count=100,
    episode_max_steps=steps,
    checkpoint_interval=25,
    checkpoint_path="checkpoints/phototaxis_agent",
)

## Check Episodes Trained in a Checkpoint

In [None]:
import numpy as np

# np.load returns an NpzFile for .npz archives, which keeps the file open until
# it is closed; the `with` block releases the handle once the value is read.
with np.load('checkpoints/phototaxis_classic_final.npz') as d:
    # Fall back to 0 when the checkpoint does not record the episode count.
    # `files` lists the stored array names.
    if "total_episodes_trained" in d.files:
        total_trained = d["total_episodes_trained"]
    else:
        total_trained = 0
print(f"Resume from episode: {total_trained}")


## Inspect Q-Table Statistics

After training, check how much of the state space was explored.

In [None]:
# Summarise the learned table: overall size, how many entries were ever
# changed from zero, and the range of the learned values.
q_table = agent.Q
print(f"Q-table shape: {q_table.shape}")
print(f"Non-zero entries: {np.count_nonzero(q_table)}")
print(f"Q-table min/max: {q_table.min():.2f} / {q_table.max():.2f}")

# A row (first axis = state) counts as visited when any of its entries is non-zero.
row_has_value = (q_table != 0).any(axis=1)
visited_states = np.nonzero(row_has_value)[0]
print(f"States visited: {len(visited_states)} out of {q_table.shape[0]}")