# DDPG Training for Balancing Robot

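This notebook trains a DDPG (Deep Deterministic Policy Gradient) agent on the balancing robot environment, plots its training metrics, and renders a demonstration episode using the trained policy.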

In [None]:
import sys
sys.path.append('..')

import torch
import numpy as np
from pathlib import Path

from src.balancing_robot.models import Actor, Critic, ReplayBuffer
from src.balancing_robot.environment import BalancerEnv
from src.balancing_robot.training import DDPGTrainer
from src.balancing_robot.visualization import plot_training_metrics, create_episode_animation

In [None]:
# Seed every RNG this notebook may rely on. Python's built-in `random` is
# seeded too, because project code (e.g. ReplayBuffer sampling) might use it.
# NOTE(review): confirm which RNG ReplayBuffer actually draws from.
import random

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# NOTE(review): the environment itself is not seeded here. If BalancerEnv
# follows the Gymnasium API, consider passing `seed=SEED` to the first reset().

# Create the environment; 'rgb_array' asks the renderer for image frames.
env = BalancerEnv(render_mode='rgb_array')

# Network dimensions are derived from the environment's spaces.
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
# Only the first upper bound of the action space is read, so this presumably
# assumes every action dimension shares the same (symmetric) limit.
max_action = float(env.action_space.high[0])

# Actor/critic networks and an experience replay buffer of 1M transitions.
actor = Actor(state_dim, action_dim, max_action)
critic = Critic(state_dim, action_dim)
buffer = ReplayBuffer(1_000_000)

# DDPG trainer wiring the networks, buffer and optimisation hyperparameters.
trainer = DDPGTrainer(
    env=env,
    actor=actor,
    critic=critic,
    buffer=buffer,
    actor_lr=1e-4,       # actor learning rate
    critic_lr=3e-4,      # critic learning rate
    gamma=0.99,          # discount factor
    tau=0.005,           # presumably the target-network soft-update rate
    action_noise=0.1,    # presumably exploration-noise scale during training
)

In [None]:
# Directory that receives training logs, checkpoints and figures.
log_dir = Path('logs/ddpg_training')
log_dir.mkdir(parents=True, exist_ok=True)

# Training schedule. `max_steps` is also reused by the demo rollout later on.
num_episodes, max_steps = 2000, 500
batch_size = 512
# Evaluation / checkpoint cadence (presumably measured in episodes -- the
# exact semantics live inside DDPGTrainer.train).
eval_freq, save_freq = 10, 100

train_kwargs = {
    'num_episodes': num_episodes,
    'max_steps': max_steps,
    'batch_size': batch_size,
    'eval_freq': eval_freq,
    'save_freq': save_freq,
    'log_dir': log_dir,
}

# Run training; the returned history feeds the plotting cell.
history = trainer.train(**train_kwargs)

In [None]:
# Plot training results and save the figure alongside the training logs.
# pyplot must be imported explicitly: `plt` is not brought into scope by any
# other cell, so plt.show() would otherwise raise NameError.
import matplotlib.pyplot as plt

fig = plot_training_metrics(history, save_path=log_dir / 'training_metrics.png')
plt.show()

In [None]:
# Create a demonstration video: roll out one episode with the trained policy
# (training=False presumably disables exploration noise) and animate it.
from IPython.display import HTML

# Gymnasium-style envs return (obs, info) from reset(); legacy Gym returns
# the observation alone. Accept either.
reset_result = env.reset()
state = reset_result[0] if isinstance(reset_result, tuple) else reset_result
states = [state]

for _ in range(max_steps):
    action = trainer.select_action(state, training=False)
    step_result = env.step(action)
    if len(step_result) == 5:
        # Gymnasium API: (obs, reward, terminated, truncated, info)
        state, _reward, terminated, truncated, _info = step_result
        done = terminated or truncated
    else:
        # Legacy Gym API: (obs, reward, done, info)
        state, _reward, done, _info = step_result
    # Record every observation, including the terminal one, so the
    # animation shows how the episode actually ended.
    states.append(state)
    if done:
        break

states = np.array(states)
anim = create_episode_animation(states, save_path=log_dir / 'demo.mp4')
HTML(anim.to_jshtml())